Compare commits

...

12 Commits

Author  SHA1  Message  Date
github-actions[bot]  1151c15e59  ci: update version string in docs  2024-04-05 17:33:23 +00:00
Jseam  0943e534ee  docs: automated sphinx documentation for python bindings (#714)  2024-04-05 18:33:06 +01:00
    Co-authored-by: dante <45801863+alexander-camuto@users.noreply.github.com>
    Co-authored-by: Ethan Cemer <tylercemer@gmail.com>
    Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
dante  316a9a3b40  chore: update tract (#766)  2024-04-04 18:07:08 +01:00
dante  5389012b68  fix: patch large batch ex (#763)  2024-04-03 02:33:57 +01:00
dante  48223cca11  fix: make commitment optional for backwards compat (#762)  2024-04-03 02:26:50 +01:00
dante  32c3a5e159  fix: hold stacked outputs in a separate map  2024-04-02 21:37:20 +01:00
dante  ff563e93a7  fix: bump python version (#761)  2024-04-02 17:08:26 +01:00
dante  5639d36097  chore: verify aggr wasm unit test (#760)  2024-04-01 20:54:20 +01:00
dante  4ec8d13082  chore: verify aggr in wasm (#758)  2024-03-29 23:28:20 +00:00
dante  12735aefd4  chore: reduce softmax recip DR (#756)  2024-03-27 01:14:29 +00:00
dante  7fe179b8d4  feat: dictionary of reusable constants (#754)  2024-03-26 13:12:09 +00:00
Ethan Cemer  3be988a6a0  fix: use pnpm in build script for in-browser-evm-verifier (#752)  2024-03-25 23:23:02 +00:00
61 changed files with 5858 additions and 769 deletions

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag

View File

@@ -25,7 +25,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -70,7 +70,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: ${{ matrix.target }}
- name: Set Cargo.toml version to match github tag
@@ -115,7 +115,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -128,6 +128,7 @@ jobs:
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- name: Install required libraries
shell: bash
run: |
@@ -139,6 +140,20 @@ jobs:
target: ${{ matrix.target }}
manylinux: auto
args: --release --out dist --features python-bindings
before-script-linux: |
# If we're running on rhel centos, install needed packages.
if command -v yum &> /dev/null; then
yum update -y && yum install -y perl-core openssl openssl-devel pkgconfig libatomic
# If we're running on i686 we need to symlink libatomic
# in order to build openssl with -latomic flag.
if [[ ! -d "/usr/lib64" ]]; then
ln -s /usr/lib/libatomic.so.1 /usr/lib/libatomic.so
fi
else
# If we're running on debian-based system.
apt update -y && apt-get install -y libssl-dev openssl pkg-config
fi
- name: Install built wheel
if: matrix.target == 'x86_64'
@@ -162,7 +177,7 @@ jobs:
# - uses: actions/checkout@v4
# - uses: actions/setup-python@v4
# with:
# python-version: 3.7
# python-version: 3.12
# - name: Install cross-compilation tools for aarch64
# if: matrix.target == 'aarch64'
@@ -214,7 +229,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -249,7 +264,7 @@ jobs:
apk add py3-pip
pip3 install -U pip
python3 -m venv .venv
source .venv/bin/activate
source .venv/bin/activate
pip3 install ezkl --no-index --find-links /io/dist/ --force-reinstall
python3 -c "import ezkl"
@@ -273,7 +288,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
- name: Set Cargo.toml version to match github tag
shell: bash
@@ -345,3 +360,17 @@ jobs:
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ./
doc-publish:
name: Trigger ReadTheDocs Build
runs-on: ubuntu-latest
needs: pypi-publish
steps:
- uses: actions/checkout@v4
- name: Trigger RTDs build
uses: dfm/rtds-action@v1
with:
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
commit_ref: ${{ github.ref_name }}

View File

@@ -184,7 +184,7 @@ jobs:
wasm32-tests:
runs-on: ubuntu-latest
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -207,7 +207,7 @@ jobs:
tutorial:
runs-on: ubuntu-latest
needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -224,7 +224,7 @@ jobs:
mock-proving-tests:
runs-on: non-gpu
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -281,7 +281,7 @@ jobs:
prove-and-verify-evm-tests:
runs-on: non-gpu
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -354,7 +354,7 @@ jobs:
prove-and-verify-tests:
runs-on: non-gpu
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -460,7 +460,7 @@ jobs:
prove-and-verify-mock-aggr-tests:
runs-on: self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -495,7 +495,7 @@ jobs:
prove-and-verify-aggr-tests:
runs-on: large-self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -512,7 +512,7 @@ jobs:
prove-and-verify-aggr-evm-tests:
runs-on: large-self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -557,16 +557,18 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Install cmake
run: sudo apt-get install -y cmake
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: Build python ezkl
@@ -576,12 +578,12 @@ jobs:
accuracy-measurement-tests:
runs-on: ubuntu-latest-32-cores
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
@@ -592,7 +594,7 @@ jobs:
crate: cargo-nextest
locked: true
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Div rebase
@@ -612,7 +614,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
@@ -626,10 +628,14 @@ jobs:
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: Install pip
run: python -m ensurepip --upgrade
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
# - name: authenticate-kaggle-cli
# shell: bash
# env:
@@ -645,7 +651,5 @@ jobs:
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_
# - name: Postgres tutorials
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1

View File

@@ -14,6 +14,40 @@ jobs:
- uses: actions/checkout@v4
- name: Bump version and push tag
id: tag_version
uses: mathieudutour/github-tag-action@v6.1
uses: mathieudutour/github-tag-action@v6.2
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
- name: Set Cargo.toml version to match github tag for docs
shell: bash
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
run: |
mv docs/python/src/conf.py docs/python/src/conf.py.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/src/conf.py.orig >docs/python/src/conf.py
rm docs/python/src/conf.py.orig
mv docs/python/requirements-docs.txt docs/python/requirements-docs.txt.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/requirements-docs.txt.orig >docs/python/requirements-docs.txt
rm docs/python/requirements-docs.txt.orig
- name: Commit files and create tag
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
run: |
git config --local user.email "github-actions[bot]@users.noreply.github.com"
git config --local user.name "github-actions[bot]"
git fetch --tags
git checkout -b release-$RELEASE_TAG
git add .
git commit -m "ci: update version string in docs"
git tag -d $RELEASE_TAG
git tag $RELEASE_TAG
- name: Push changes
uses: ad-m/github-push-action@master
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
with:
branch: release-${{ steps.tag_version.outputs.new_tag }}
force: true
tags: true

.gitignore (vendored): 4 changed lines
View File

@@ -48,4 +48,6 @@ node_modules
/dist
timingData.json
!tests/wasm/pk.key
!tests/wasm/vk.key
!tests/wasm/vk.key
docs/python/build
!tests/wasm/vk_aggr.key

.python-version (new file): 1 line
View File

@@ -0,0 +1 @@
3.12.1

.readthedocs.yaml (new file): 26 lines
View File

@@ -0,0 +1,26 @@
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
version: 2
build:
os: ubuntu-22.04
tools:
python: "3.12"
# Build documentation in the "docs/" directory with Sphinx
sphinx:
configuration: ./docs/python/src/conf.py
# Optionally build your docs in additional formats such as PDF and ePub
# formats:
# - pdf
# - epub
# Optional but recommended, declare the Python requirements required
# to build your documentation
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
install:
- requirements: ./docs/python/requirements-docs.txt

Cargo.lock (generated): 264 changed lines
View File

@@ -112,7 +112,7 @@ checksum = "1a047897373be4bbb0224c1afdabca92648dc57a9c9ef6e7b0be3aff7a859c83"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -211,9 +211,9 @@ checksum = "d301b3b94cb4b2f23d7917810addbbaff90738e0ca2be692bd027e70d7e0330c"
[[package]]
name = "arc-swap"
version = "1.7.0"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b3d0060af21e8d11a926981cc00c6c1541aa91dd64b9f881985c3da1094425f"
checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
[[package]]
name = "ark-bls12-377"
@@ -463,7 +463,7 @@ dependencies = [
"proc-macro2",
"quote",
"serde",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -483,13 +483,13 @@ dependencies = [
[[package]]
name = "async-trait"
version = "0.1.78"
version = "0.1.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "461abc97219de0eaaf81fe3ef974a540158f3d079c2ab200f891f1a2ef201e85"
checksum = "a507401cad91ec6a857ed5513a2073c82a9b9048762b885bb98655b306964681"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -522,20 +522,20 @@ checksum = "3c87f3f15e7794432337fc718554eaa4dc8f04c9677a950ffe366f20a162ae42"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
name = "autocfg"
version = "1.1.0"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80"
[[package]]
name = "backtrace"
version = "0.3.70"
version = "0.3.71"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95d8e92cac0961e91dbd517496b00f7e9b92363dbe6d42c3198268323798860c"
checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d"
dependencies = [
"addr2line",
"cc",
@@ -716,9 +716,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "bytes"
version = "1.5.0"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9"
dependencies = [
"serde",
]
@@ -734,9 +734,9 @@ dependencies = [
[[package]]
name = "cargo-platform"
version = "0.1.7"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "694c8807f2ae16faecc43dc17d74b3eb042482789fd0eb64b39a2e04e087053f"
checksum = "24b1f0365a6c6bb4020cd05806fd0d33c44d38046b8bd7f0e40814b9763cabfc"
dependencies = [
"serde",
]
@@ -779,9 +779,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.35"
version = "0.4.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eaf5903dcbc0a39312feb77df2ff4c76387d591b9fc7b04a238dcf8bb62639a"
checksum = "8a0d04d43504c61aa6c7531f1871dd0d418d91130162063b789da00fd7057a5e"
dependencies = [
"android-tzdata",
"iana-time-zone",
@@ -814,9 +814,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.5.3"
version = "4.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "949626d00e063efc93b6dca932419ceb5432f99769911c0b995f7e884c778813"
checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0"
dependencies = [
"clap_builder",
"clap_derive",
@@ -836,14 +836,14 @@ dependencies = [
[[package]]
name = "clap_derive"
version = "4.5.3"
version = "4.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90239a040c80f5e14809ca132ddc4176ab33d5e17e49691793296e3fcb34d72f"
checksum = "528131438037fd55894f62d6e9f068b8f45ac57ffa77517819645d10aed04f64"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -1177,9 +1177,9 @@ dependencies = [
[[package]]
name = "der"
version = "0.7.8"
version = "0.7.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c"
checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0"
dependencies = [
"const-oid",
"zeroize",
@@ -1409,7 +1409,7 @@ checksum = "6fd000fd6988e73bbe993ea3db9b1aa64906ab88766d654973924340c8cddb42"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -1575,7 +1575,7 @@ dependencies = [
"regex",
"serde",
"serde_json",
"syn 2.0.53",
"syn 2.0.58",
"toml",
"walkdir",
]
@@ -1593,7 +1593,7 @@ dependencies = [
"proc-macro2",
"quote",
"serde_json",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -1619,7 +1619,7 @@ dependencies = [
"serde",
"serde_json",
"strum",
"syn 2.0.53",
"syn 2.0.58",
"tempfile",
"thiserror",
"tiny-keccak",
@@ -1772,7 +1772,7 @@ dependencies = [
"ark-std 0.3.0",
"bincode",
"chrono",
"clap 4.5.3",
"clap 4.5.4",
"colored",
"colored_json",
"console_error_panic_hook",
@@ -1837,9 +1837,9 @@ checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
[[package]]
name = "fastrand"
version = "2.0.1"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5"
checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984"
[[package]]
name = "fastrlp"
@@ -2028,7 +2028,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -2258,7 +2258,7 @@ dependencies = [
[[package]]
name = "halo2curves"
version = "0.3.2"
source = "git+https://github.com/privacy-scaling-explorations/halo2curves?tag=0.3.2#9f5c50810bbefe779ee5cf1d852b2fe85dc35d5e"
source = "git+https://github.com/privacy-scaling-explorations/halo2curves.git?tag=0.3.2#9f5c50810bbefe779ee5cf1d852b2fe85dc35d5e"
dependencies = [
"ff",
"group",
@@ -2602,9 +2602,9 @@ checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683"
[[package]]
name = "indexmap"
version = "2.2.5"
version = "2.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4"
checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26"
dependencies = [
"equivalent",
"hashbrown 0.14.3",
@@ -2626,9 +2626,9 @@ dependencies = [
[[package]]
name = "indoc"
version = "2.0.4"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8"
checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
[[package]]
name = "inout"
@@ -2700,10 +2700,19 @@ dependencies = [
]
[[package]]
name = "itoa"
version = "1.0.10"
name = "itertools"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
[[package]]
name = "jobserver"
@@ -2823,13 +2832,12 @@ checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "libredox"
version = "0.0.1"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8"
checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d"
dependencies = [
"bitflags 2.5.0",
"libc",
"redox_syscall",
]
[[package]]
@@ -2877,7 +2885,7 @@ checksum = "fc2fb41a9bb4257a3803154bdf7e2df7d45197d1941c9b1a90ad815231630721"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -2962,24 +2970,24 @@ dependencies = [
[[package]]
name = "memchr"
version = "2.7.1"
version = "2.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149"
checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d"
[[package]]
name = "memmap2"
version = "0.5.10"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83faa42c0a078c393f6b29d5db232d8be22776a891f8f56e5284faee4a20b327"
checksum = "fe751422e4a8caa417e13c3ea66452215d7d63e19e604f4980461212f3ae1322"
dependencies = [
"libc",
]
[[package]]
name = "memoffset"
version = "0.9.0"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
@@ -3187,7 +3195,7 @@ dependencies = [
"proc-macro-crate 3.1.0",
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -3271,7 +3279,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -3291,9 +3299,9 @@ dependencies = [
[[package]]
name = "openssl-sys"
version = "0.9.101"
version = "0.9.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dda2b0f344e78efc2facf7d195d098df0dd72151b26ab98da807afc26c198dff"
checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2"
dependencies = [
"cc",
"libc",
@@ -3444,9 +3452,9 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "pest"
version = "2.7.8"
version = "2.7.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56f8023d0fb78c8e03784ea1c7f3fa36e68a723138990b8d5a47d916b651e7a8"
checksum = "311fb059dee1a7b802f036316d790138c613a4e8b180c822e3925a662e9f0c95"
dependencies = [
"memchr",
"thiserror",
@@ -3455,9 +3463,9 @@ dependencies = [
[[package]]
name = "pest_derive"
version = "2.7.8"
version = "2.7.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0d24f72393fd16ab6ac5738bc33cdb6a9aa73f8b902e8fe29cf4e67d7dd1026"
checksum = "f73541b156d32197eecda1a4014d7f868fd2bcb3c550d5386087cfba442bf69c"
dependencies = [
"pest",
"pest_generator",
@@ -3465,22 +3473,22 @@ dependencies = [
[[package]]
name = "pest_generator"
version = "2.7.8"
version = "2.7.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdc17e2a6c7d0a492f0158d7a4bd66cc17280308bbaff78d5bef566dca35ab80"
checksum = "c35eeed0a3fab112f75165fdc026b3913f4183133f19b49be773ac9ea966e8bd"
dependencies = [
"pest",
"pest_meta",
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
name = "pest_meta"
version = "2.7.8"
version = "2.7.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "934cd7631c050f4674352a6e835d5f6711ffbfb9345c2fc0107155ac495ae293"
checksum = "2adbf29bb9776f28caece835398781ab24435585fe0d4dc1374a61db5accedca"
dependencies = [
"once_cell",
"pest",
@@ -3550,7 +3558,7 @@ dependencies = [
"phf_shared 0.11.2",
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -3588,14 +3596,14 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
name = "pin-project-lite"
version = "0.2.13"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02"
[[package]]
name = "pin-utils"
@@ -3719,12 +3727,12 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
[[package]]
name = "prettyplease"
version = "0.2.16"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a41cf62165e97c7f814d2221421dbb9afcbcdb0a88068e5ea206e19951c2cbb5"
checksum = "8d3928fb5db768cb86f891ff014f0144589297e3c6a1aba6ed7cecfdace270c7"
dependencies = [
"proc-macro2",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -3933,7 +3941,7 @@ dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -3946,7 +3954,7 @@ dependencies = [
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -4049,9 +4057,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "rayon"
version = "1.9.0"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4963ed1bc86e4f3ee217022bd855b297cef07fb9eac5dfa1f788b220b49b3bd"
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
dependencies = [
"either",
"rayon-core",
@@ -4089,9 +4097,9 @@ dependencies = [
[[package]]
name = "redox_users"
version = "0.4.4"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a18479200779601e498ada4e8c1e1f50e3ee19deb0259c25825a98b5603b2cb4"
checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891"
dependencies = [
"getrandom",
"libredox",
@@ -4100,9 +4108,9 @@ dependencies = [
[[package]]
name = "regex"
version = "1.10.3"
version = "1.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15"
checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c"
dependencies = [
"aho-corasick",
"memchr",
@@ -4123,9 +4131,9 @@ dependencies = [
[[package]]
name = "regex-syntax"
version = "0.8.2"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56"
[[package]]
name = "remove_dir_all"
@@ -4444,9 +4452,9 @@ dependencies = [
[[package]]
name = "scale-info"
version = "2.11.0"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2ef2175c2907e7c8bc0a9c3f86aeb5ec1f3b275300ad58a44d0c3ae379a5e52e"
checksum = "788745a868b0e751750388f4e6546eb921ef714a4317fa6954f7cde114eb2eb7"
dependencies = [
"cfg-if",
"derive_more",
@@ -4456,9 +4464,9 @@ dependencies = [
[[package]]
name = "scale-info-derive"
version = "2.11.0"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "634d9b8eb8fd61c5cdd3390d9b2132300a7e7618955b98b8416f118c1b4e144f"
checksum = "7dc2f4e8bc344b9fc3d5f74f72c2e55bfc38d28dc2ebc69c194a3df424e4d9ac"
dependencies = [
"proc-macro-crate 1.3.1",
"proc-macro2",
@@ -4524,9 +4532,9 @@ dependencies = [
[[package]]
name = "security-framework"
version = "2.9.2"
version = "2.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de"
checksum = "770452e37cad93e0a50d5abc3990d2bc351c36d0328f86cefec2f2fb206eaef6"
dependencies = [
"bitflags 1.3.2",
"core-foundation",
@@ -4537,9 +4545,9 @@ dependencies = [
[[package]]
name = "security-framework-sys"
version = "2.9.1"
version = "2.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a"
checksum = "41f3cc463c0ef97e11c3461a9d3787412d30e8e7eb907c79180c4a57bf7c04ef"
dependencies = [
"core-foundation-sys",
"libc",
@@ -4601,9 +4609,9 @@ dependencies = [
[[package]]
name = "serde-wasm-bindgen"
version = "0.4.5"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3b4c031cd0d9014307d82b8abf653c0290fbdaeb4c02d00c63cf52f728628bf"
checksum = "8302e169f0eddcc139c70f139d19d6467353af16f9fce27e8c30158036a1e16b"
dependencies = [
"js-sys",
"serde",
@@ -4637,14 +4645,14 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
name = "serde_json"
version = "1.0.114"
version = "1.0.115"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0"
checksum = "12dc5c46daa8e9fdf4f5e71b6cf9a53f2487da0e86e55808e2d35539666497dd"
dependencies = [
"itoa",
"ryu",
@@ -4845,12 +4853,12 @@ checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82"
[[package]]
name = "string-interner"
version = "0.14.0"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e2531d8525b29b514d25e275a43581320d587b86db302b9a7e464bac579648"
checksum = "07f9fdfdd31a0ff38b59deb401be81b73913d76c9cc5b1aed4e1330a223420b9"
dependencies = [
"cfg-if",
"hashbrown 0.11.2",
"hashbrown 0.14.3",
"serde",
]
@@ -4880,9 +4888,9 @@ dependencies = [
[[package]]
name = "strsim"
version = "0.11.0"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ee073c9e4cd00e28217186dbe12796d692868f432bf2e97ee73bed0c56dfa01"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
@@ -4903,7 +4911,7 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -4938,9 +4946,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.53"
version = "2.0.58"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032"
checksum = "44cfb93f38070beee36b3fef7d4f5a16f27751d94b187b666a5cc5e9b0d30687"
dependencies = [
"proc-macro2",
"quote",
@@ -5113,7 +5121,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -5179,9 +5187,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.36.0"
version = "1.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931"
checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787"
dependencies = [
"backtrace",
"bytes",
@@ -5202,7 +5210,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -5337,7 +5345,7 @@ source = "git+https://github.com/zkonduit/enum_to_subcommand#42e9870f1f757932bab
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -5365,7 +5373,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -5389,8 +5397,8 @@ dependencies = [
[[package]]
name = "tract-core"
version = "0.20.23-pre"
source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
version = "0.21.3"
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
dependencies = [
"anyhow",
"bit-set",
@@ -5413,12 +5421,12 @@ dependencies = [
[[package]]
name = "tract-data"
version = "0.20.23-pre"
source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
version = "0.21.3"
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
dependencies = [
"anyhow",
"half 2.2.1",
"itertools 0.10.5",
"itertools 0.12.1",
"lazy_static",
"maplit",
"ndarray",
@@ -5432,8 +5440,8 @@ dependencies = [
[[package]]
name = "tract-hir"
version = "0.20.23-pre"
source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
version = "0.21.3"
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
dependencies = [
"derive-new",
"log",
@@ -5442,8 +5450,8 @@ dependencies = [
[[package]]
name = "tract-linalg"
version = "0.20.23-pre"
source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
version = "0.21.3"
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
dependencies = [
"cc",
"derive-new",
@@ -5466,8 +5474,8 @@ dependencies = [
[[package]]
name = "tract-nnef"
version = "0.20.23-pre"
source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
version = "0.21.3"
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
dependencies = [
"byteorder",
"flate2",
@@ -5480,8 +5488,8 @@ dependencies = [
[[package]]
name = "tract-onnx"
version = "0.20.23-pre"
source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
version = "0.21.3"
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
dependencies = [
"bytes",
"derive-new",
@@ -5497,8 +5505,8 @@ dependencies = [
[[package]]
name = "tract-onnx-opl"
version = "0.20.23-pre"
source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
version = "0.21.3"
source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
dependencies = [
"getrandom",
"log",
@@ -5755,7 +5763,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
"wasm-bindgen-shared",
]
@@ -5799,7 +5807,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
@@ -5844,7 +5852,7 @@ checksum = "b7f89739351a2e03cb94beb799d47fb2cac01759b40ec441f7de39b00cbf7ef0"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -6154,7 +6162,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]
[[package]]
@@ -6174,5 +6182,5 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
"syn 2.0.58",
]

View File

@@ -80,7 +80,7 @@ pyo3-asyncio = { version = "0.20.0", features = [
"tokio-runtime",
], default_features = false, optional = true }
pyo3-log = { version = "0.9.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "681a096f02c9d7d363102d9fb0e446d1710ac2c8", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
@@ -95,10 +95,10 @@ getrandom = { version = "0.2.8", features = ["js"] }
instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] }
[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies]
wasm-bindgen-rayon = { version = "1.0", optional = true }
wasm-bindgen-test = "0.3.34"
serde-wasm-bindgen = "0.4"
wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"] }
wasm-bindgen-rayon = { version = "1.2.1", optional = true }
wasm-bindgen-test = "0.3.42"
serde-wasm-bindgen = "0.6.5"
wasm-bindgen = { version = "0.2.92", features = ["serde-serialize"] }
console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"

View File

@@ -1,3 +1,5 @@
use std::collections::HashMap;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
use ezkl::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
@@ -48,7 +50,7 @@ impl Circuit<Fr> for MyCircuit {
) -> Result<(), Error> {
let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L> =
PoseidonChip::new(config);
chip.layout(&mut layouter, &[self.image.clone()], 0)?;
chip.layout(&mut layouter, &[self.image.clone()], 0, &mut HashMap::new())?;
Ok(())
}
}

docs/python/build.sh (new executable file): 2 lines
View File

@@ -0,0 +1,2 @@
#!/bin/sh
sphinx-build ./src build

View File

@@ -0,0 +1,4 @@
ezkl==10.2.9
sphinx
sphinx-rtd-theme
sphinxcontrib-napoleon

docs/python/src/conf.py (new file): 29 lines
View File

@@ -0,0 +1,29 @@
import ezkl
project = 'ezkl'
release = '10.2.9'
version = release
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.inheritance_diagram',
'sphinx.ext.autosectionlabel',
'sphinx.ext.napoleon',
'sphinx_rtd_theme',
]
autosummary_generate = True
autosummary_imported_members = True
templates_path = ['_templates']
exclude_patterns = []
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']

docs/python/src/index.rst (new file): 11 lines
View File

@@ -0,0 +1,11 @@
.. extension documentation master file, created by
sphinx-quickstart on Mon Jun 19 15:02:05 2023.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
ezkl python bindings
================================================
.. automodule:: ezkl
:members:
:undoc-members:

View File

@@ -67,6 +67,7 @@
"model.add(Dense(128, activation='relu'))\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(10, activation='softmax'))\n",
"model.output_names=['output']\n",
"\n",
"\n",
"# Train the model as you like here (skipped for brevity)\n",

View File

@@ -38,7 +38,7 @@
"import logging\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras.optimizers.legacy import Adam\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras.layers import *\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.datasets import mnist\n",
@@ -71,9 +71,11 @@
},
"outputs": [],
"source": [
"opt = Adam()\n",
"ZDIM = 100\n",
"\n",
"opt = Adam()\n",
"\n",
"\n",
"# discriminator\n",
"# 0 if it's fake, 1 if it's real\n",
"x = in1 = Input((28,28))\n",
@@ -114,8 +116,11 @@
"\n",
"gm = Model(in1, x)\n",
"gm.compile('adam', 'mse')\n",
"gm.output_names=['output']\n",
"gm.summary()\n",
"\n",
"opt = Adam()\n",
"\n",
"# GAN\n",
"dm.trainable = False\n",
"x = dm(gm.output)\n",
@@ -415,7 +420,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long

View File

@@ -349,6 +349,8 @@
"z_log_var = Dense(ZDIM)(x)\n",
"z = Lambda(lambda x: x[0] + K.exp(0.5 * x[1]) * K.random_normal(shape=K.shape(x[0])))([z_mu, z_log_var])\n",
"dec = get_decoder()\n",
"dec.output_names=['output']\n",
"\n",
"out = dec(z)\n",
"\n",
"mse_loss = mse(Reshape((28*28,))(in1), Reshape((28*28,))(out)) * 28 * 28\n",

View File

@@ -61,11 +61,10 @@
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.ensemble import RandomForestClassifier as Rf\n",
"import sk2torch\n",
"import torch\n",
"import ezkl\n",
"import os\n",
"from torch import nn\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"\n",
@@ -77,28 +76,12 @@
"clr.fit(X_train, y_train)\n",
"\n",
"\n",
"trees = []\n",
"for tree in clr.estimators_:\n",
" trees.append(sk2torch.wrap(tree))\n",
"\n",
"\n",
"class RandomForest(nn.Module):\n",
" def __init__(self, trees):\n",
" super(RandomForest, self).__init__()\n",
" self.trees = nn.ModuleList(trees)\n",
"\n",
" def forward(self, x):\n",
" out = self.trees[0](x)\n",
" for tree in self.trees[1:]:\n",
" out += tree(x)\n",
" return out / len(self.trees)\n",
"\n",
"\n",
"torch_rf = RandomForest(trees)\n",
"torch_rf = convert(clr, 'torch')\n",
"# assert predictions from torch are = to sklearn \n",
"diffs = []\n",
"for i in range(len(X_test)):\n",
" torch_pred = torch_rf(torch.tensor(X_test[i].reshape(1, -1)))\n",
" torch_pred = torch_rf.predict(torch.tensor(X_test[i].reshape(1, -1)))\n",
" sk_pred = clr.predict(X_test[i].reshape(1, -1))\n",
" diffs.append(torch_pred[0].round() - sk_pred[0])\n",
"\n",
@@ -134,14 +117,12 @@
"\n",
"# export to onnx format\n",
"\n",
"torch_rf.eval()\n",
"\n",
"# Input to the model\n",
"shape = X_train.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=False)\n",
"torch_out = torch_rf(x)\n",
"torch_out = torch_rf.predict(x)\n",
"# Export the model\n",
"torch.onnx.export(torch_rf, # model being run\n",
"torch.onnx.export(torch_rf.model, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
@@ -158,7 +139,7 @@
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
@@ -321,7 +302,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -57,7 +57,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -119,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -163,7 +163,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -217,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -227,6 +227,10 @@
" self.length = self.compute_length(self.file_good)\n",
" self.data = self.load_data(self.file_good)\n",
"\n",
" def __iter__(self):\n",
" for i in range(len(self.data)):\n",
" yield self.data[i]\n",
"\n",
" def parse_json_object(self, line):\n",
" try:\n",
" return json.loads(line)\n",
@@ -749,7 +753,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -209,6 +209,11 @@
" self.length = self.compute_length(self.file_good, self.file_bad)\n",
" self.data = self.load_data(self.file_good, self.file_bad)\n",
"\n",
" def __iter__(self):\n",
" for i in range(len(self.data)):\n",
" yield self.data[i]\n",
"\n",
"\n",
" def parse_json_object(self, line):\n",
" try:\n",
" return json.loads(line)\n",
@@ -637,7 +642,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -0,0 +1,13 @@
{
"input_data": [
[
0.8894134163856506,
0.8894201517105103
]
],
"output_data": [
[
0.8436377
]
]
}

Binary file not shown.

View File

@@ -17,7 +17,7 @@
"clean": "rm -r dist || true",
"build:commonjs": "tsc --project tsconfig.commonjs.json && resolve-tspaths -p tsconfig.commonjs.json",
"build:esm": "tsc --project tsconfig.esm.json && resolve-tspaths -p tsconfig.esm.json",
"build": "pnpm run clean && pnpm run build:commonjs && pnpm run build:esm"
"build": "npm run clean && npm run build:commonjs && npm run build:esm"
},
"dependencies": {
"@ethereumjs/common": "^4.0.0",

View File

@@ -1,5 +1,5 @@
[build-system]
requires = ["maturin>=0.14,<0.15"]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"
[tool.pytest.ini_options]

View File

@@ -1,14 +1,14 @@
attrs==22.2.0
exceptiongroup==1.1.1
importlib-metadata==6.1.0
attrs==23.2.0
exceptiongroup==1.2.0
importlib-metadata==7.1.0
iniconfig==2.0.0
maturin==1.0.1
packaging==23.0
pluggy==1.0.0
pytest==7.2.2
maturin==1.5.1
packaging==24.0
pluggy==1.4.0
pytest==8.1.1
tomli==2.0.1
typing-extensions==4.5.0
zipp==3.15.0
onnx==1.14.1
onnxruntime==1.14.1
numpy==1.21.6
typing-extensions==4.10.0
zipp==3.18.1
onnx==1.15.0
onnxruntime==1.17.1
numpy==1.26.4

View File

@@ -15,6 +15,8 @@ pub use planner::*;
use crate::tensor::{TensorType, ValTensor};
use super::region::ConstantsMap;
/// Module trait used to extend ezkl functionality
pub trait Module<F: PrimeField + TensorType + PartialOrd> {
/// Config
@@ -39,6 +41,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
&self,
layouter: &mut impl Layouter<F>,
input: &[ValTensor<F>],
constants: &mut ConstantsMap<F>,
) -> Result<Self::InputAssignments, Error>;
/// Layout
fn layout(
@@ -46,6 +49,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
layouter: &mut impl Layouter<F>,
input: &[ValTensor<F>],
row_offset: usize,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, Error>;
/// Number of instance values the module uses every time it is applied
fn instance_increment_input(&self) -> Vec<usize>;
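
The hunk above is the interface change behind commit 7fe179b8d4 (feat: dictionary of reusable constants, #754): `layout_inputs` and `layout` now receive a `&mut ConstantsMap<F>` so a field constant assigned by one layout call can be reused by later ones. A minimal caller-side sketch, assuming `ConstantsMap<F>` is a `HashMap`-like alias keyed by the constant's field value; the chip parameters and names here are hypothetical, for illustration only:

```rust
use std::collections::HashMap;

// Hedged sketch: thread one shared map through successive module layouts so
// constants common to both calls are assigned to a single advice cell.
fn layout_twice(
    chip: &PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, 32>,
    layouter: &mut impl Layouter<Fp>,
    message: ValTensor<Fp>,
) -> Result<(), Error> {
    let mut constants = HashMap::new(); // stands in for ConstantsMap<Fp>
    // The first layout assigns any fresh constants and records their cells.
    chip.layout(layouter, &[message.clone()], 0, &mut constants)?;
    // The second layout finds the same constants in the map and reuses them.
    chip.layout(layouter, &[message], 1, &mut constants)?;
    Ok(())
}
```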

View File

@@ -4,6 +4,8 @@ is already implemented in halo2_gadgets, there is no wrapper chip that makes it
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
*/
use std::collections::HashMap;
// This chip adds a set of advice columns to the gadget Chip to store the inputs of the hash
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use halo2_proofs::poly::commitment::{Blind, CommitmentScheme, Params};
@@ -13,6 +15,7 @@ use halo2curves::group::prime::PrimeCurveAffine;
use halo2curves::group::Curve;
use halo2curves::CurveAffine;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor, ValType, VarTensor};
use super::Module;
@@ -107,6 +110,7 @@ impl Module<Fp> for PolyCommitChip {
&self,
_: &mut impl Layouter<Fp>,
_: &[ValTensor<Fp>],
_: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, Error> {
Ok(())
}
@@ -119,11 +123,24 @@ impl Module<Fp> for PolyCommitChip {
layouter: &mut impl Layouter<Fp>,
input: &[ValTensor<Fp>],
_: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
assert_eq!(input.len(), 1);
let local_constants = constants.clone();
layouter.assign_region(
|| "PolyCommit",
|mut region| self.config.inputs.assign(&mut region, 0, &input[0]),
|mut region| {
let mut local_inner_constants = local_constants.clone();
let res = self.config.inputs.assign(
&mut region,
0,
&input[0],
&mut local_inner_constants,
)?;
*constants = local_inner_constants;
Ok(res)
},
)
}
@@ -184,7 +201,12 @@ mod tests {
mut layouter: impl Layouter<Fp>,
) -> Result<(), Error> {
let polycommit_chip = PolyCommitChip::new(config);
polycommit_chip.layout(&mut layouter, &[self.message.clone()], 0);
polycommit_chip.layout(
&mut layouter,
&[self.message.clone()],
0,
&mut HashMap::new(),
);
Ok(())
}
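
A subtlety worth noting in the hunk above: halo2 floor planners may invoke an `assign_region` closure more than once (`SimpleFloorPlanner` runs a shape-measurement pass before the real assignment), so the chip clones the map into a fresh local copy on every invocation and writes the result back only when the closure completes. A stripped-down sketch of that pattern; `assign_inputs` is a hypothetical stand-in for `self.config.inputs.assign`:

```rust
let snapshot = constants.clone();
let out = layouter.assign_region(
    || "PolyCommit",
    |mut region| {
        // Every invocation starts from the same pre-region snapshot, so a
        // measurement pass cannot leave half-recorded constants behind.
        let mut local = snapshot.clone();
        let res = assign_inputs(&mut region, &mut local)?; // hypothetical helper
        // Only the final invocation's inserts survive into the shared map.
        *constants = local;
        Ok(res)
    },
)?;
```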

View File

@@ -18,6 +18,7 @@ use maybe_rayon::slice::ParallelSlice;
use std::marker::PhantomData;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor, ValType};
use super::Module;
@@ -172,12 +173,15 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
&self,
layouter: &mut impl Layouter<Fp>,
message: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, Error> {
assert_eq!(message.len(), 1);
let message = message[0].clone();
let start_time = instant::Instant::now();
let local_constants = constants.clone();
let res = layouter.assign_region(
|| "load message",
|mut region| {
@@ -199,12 +203,26 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
Ok(v.clone())
}
ValType::Constant(f) => region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
),
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants.get(f).unwrap().assigned_cell().ok_or({
log::error!("constant not previously assigned");
Error::Synthesis
})?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;
constants
.insert(*f, ValType::AssignedConstant(res.clone(), *f));
Ok(res)
}
}
e => {
log::error!(
"wrong input type {:?}, must be previously assigned",
@@ -270,8 +288,9 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
layouter: &mut impl Layouter<Fp>,
input: &[ValTensor<Fp>],
row_offset: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input)?;
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
// extract the values from the input cells
let mut assigned_input: Tensor<ValType<Fp>> =
input_cells.iter().map(|e| ValType::from(e.clone())).into();
@@ -434,7 +453,7 @@ mod tests {
*,
};
use std::marker::PhantomData;
use std::{collections::HashMap, marker::PhantomData};
use halo2_gadgets::poseidon::primitives::Spec;
use halo2_proofs::{
@@ -477,7 +496,12 @@ mod tests {
mut layouter: impl Layouter<Fp>,
) -> Result<(), Error> {
let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, L> = PoseidonChip::new(config);
chip.layout(&mut layouter, &[self.message.clone()], 0)?;
chip.layout(
&mut layouter,
&[self.message.clone()],
0,
&mut HashMap::new(),
)?;
Ok(())
}
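
The new `ValType::Constant` arm above is where the constants dictionary pays off: look the value up first, and only call halo2's `assign_advice_from_constant` on a cache miss. Reduced to its core (simplified from the hunk, with the column/row bookkeeping elided):

```rust
// One advice cell per distinct constant f: reuse on hit, assign-and-insert on miss.
let cell = match constants.get(&f).and_then(|v| v.assigned_cell()) {
    Some(prev) => prev, // cache hit: the cell assigned on first use
    None => {
        let fresh = region.assign_advice_from_constant(|| "load const", col, row, f)?;
        constants.insert(f, ValType::AssignedConstant(fresh.clone(), f));
        fresh
    }
};
```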

View File

@@ -345,7 +345,7 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
Self {

View File

@@ -46,7 +46,8 @@ pub enum HybridOp {
dim: usize,
},
Softmax {
scale: utils::F32,
input_scale: utils::F32,
output_scale: utils::F32,
axes: Vec<usize>,
},
RangeCheck(Tolerance),
@@ -70,7 +71,7 @@ pub enum HybridOp {
},
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
@@ -130,9 +131,16 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
kernel_shape,
normalized,
} => tensor::ops::sumpool(&x, *padding, *stride, *kernel_shape, *normalized)?,
HybridOp::Softmax { scale, axes } => {
tensor::ops::nonlinearities::softmax_axes(&x, scale.into(), axes)
}
HybridOp::Softmax {
input_scale,
output_scale,
axes,
} => tensor::ops::nonlinearities::softmax_axes(
&x,
input_scale.into(),
output_scale.into(),
axes,
),
HybridOp::RangeCheck(tol) => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
@@ -203,8 +211,15 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
),
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim),
HybridOp::Softmax { scale, axes } => {
format!("SOFTMAX (scale={}, axes={:?})", scale, axes)
HybridOp::Softmax {
input_scale,
output_scale,
axes,
} => {
format!(
"SOFTMAX (input_scale={}, output_scale={}, axes={:?})",
input_scale, output_scale, axes
)
}
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
HybridOp::Greater => "GREATER".into(),
@@ -324,9 +339,18 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
HybridOp::ReduceArgMin { dim } => {
layouts::argmin_axes(config, region, values[..].try_into()?, *dim)?
}
HybridOp::Softmax { scale, axes } => {
layouts::softmax_axes(config, region, values[..].try_into()?, *scale, axes)?
}
HybridOp::Softmax {
input_scale,
output_scale,
axes,
} => layouts::softmax_axes(
config,
region,
values[..].try_into()?,
*input_scale,
*output_scale,
axes,
)?,
HybridOp::RangeCheck(tol) => layouts::range_check_percent(
config,
region,
@@ -359,8 +383,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
| HybridOp::ReduceArgMax { .. }
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,
HybridOp::Softmax { .. } => 2 * in_scales[0],
HybridOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.0 as f64),
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
multiplier_to_scale(output_scale.0 as f64)
}
_ => in_scales[0],
};
Ok(scale)

View File

@@ -9,6 +9,7 @@ use halo2curves::ff::PrimeField;
use itertools::Itertools;
use log::{error, trace};
use maybe_rayon::{
iter::IntoParallelRefIterator,
prelude::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator},
slice::ParallelSliceMut,
};
@@ -33,7 +34,7 @@ use super::*;
use crate::circuit::ops::lookup::LookupOp;
/// Same as div but splits the division into N parts
pub(crate) fn loop_div<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn loop_div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
@@ -68,7 +69,7 @@ pub(crate) fn loop_div<F: PrimeField + TensorType + PartialOrd>(
}
/// Div accumulated layout
pub(crate) fn div<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
@@ -93,9 +94,9 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd>(
let mut claimed_output: ValTensor<F> = if is_assigned {
let input_evals = input.get_int_evals()?;
tensor::ops::nonlinearities::const_div(&input_evals.clone(), felt_to_i128(div) as f64)
.iter()
.map(|x| Ok(Value::known(i128_to_felt(*x))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
} else {
Tensor::new(
@@ -133,7 +134,7 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd>(
}
/// recip accumulated layout
pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
@@ -166,9 +167,9 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
felt_to_i128(input_scale) as f64,
felt_to_i128(output_scale) as f64,
)
.iter()
.map(|x| Ok(Value::known(i128_to_felt(*x))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
} else {
Tensor::new(
@@ -226,7 +227,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
}
/// Dot product accumulated layout
pub(crate) fn dot<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -337,7 +338,7 @@ pub(crate) fn dot<F: PrimeField + TensorType + PartialOrd>(
}
/// Einsum
pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
inputs: &[ValTensor<F>],
@@ -524,14 +525,7 @@ pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd>(
// Compute the product of all input tensors
for pair in input_pairs {
let product_across_pair = prod(
config,
region,
&[pair.try_into().map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?],
)?;
let product_across_pair = prod(config, region, &[pair.into()])?;
if let Some(product) = prod_res {
prod_res = Some(
@@ -563,7 +557,7 @@ pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
fn _sort_ascending<F: PrimeField + TensorType + PartialOrd>(
fn _sort_ascending<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -574,12 +568,12 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd>(
let is_assigned = !input.any_unknowns()?;
let sorted = if is_assigned {
input
.get_int_evals()?
.iter()
.sorted_by(|a, b| a.cmp(b))
.map(|x| Ok(Value::known(i128_to_felt(*x))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
let mut int_evals = input.get_int_evals()?;
int_evals.par_sort_unstable_by(|a, b| a.cmp(b));
int_evals
.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); input.len()]),
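Aside: the rewrite above leans on two standard rayon APIs, par_sort_unstable_by and par_iter. A minimal standalone illustration (assuming the rayon crate):

use rayon::prelude::*;

fn main() {
    let mut evals: Vec<i128> = vec![5, -3, 9, 0];
    // parallel unstable sort, as now used for the sorted witness
    evals.par_sort_unstable_by(|a, b| a.cmp(b));
    // parallel map over the sorted values
    let doubled: Vec<i128> = evals.par_iter().map(|x| x * 2).collect();
    assert_eq!(doubled, vec![-6, 0, 10, 18]);
}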
@@ -607,7 +601,7 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd>(
}
///
fn _select_topk<F: PrimeField + TensorType + PartialOrd>(
fn _select_topk<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -622,7 +616,7 @@ fn _select_topk<F: PrimeField + TensorType + PartialOrd>(
}
/// Select top k elements
pub(crate) fn topk_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn topk_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -642,11 +636,12 @@ pub(crate) fn topk_axes<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
fn select<F: PrimeField + TensorType + PartialOrd>(
fn select<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let start = instant::Instant::now();
let (mut input, index) = (values[0].clone(), values[1].clone());
input.flatten();
@@ -656,12 +651,13 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()?;
let output: ValTensor<F> = if is_assigned {
let output: ValTensor<F> = if is_assigned && region.witness_gen() {
let felt_evals = input.get_felt_evals()?;
index
.get_int_evals()?
.iter()
.map(|x| Ok(Value::known(input.get_felt_evals()?.get(&[*x as usize]))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
.par_iter()
.map(|x| Value::known(felt_evals.get(&[*x as usize])))
.collect::<Tensor<Value<F>>>()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); index.len()]),
@@ -673,10 +669,13 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
let (_, assigned_output) =
dynamic_lookup(config, region, &[index, output], &[dim_indices, input])?;
let end = start.elapsed();
trace!("select took: {:?}", end);
Ok(assigned_output)
}
fn one_hot<F: PrimeField + TensorType + PartialOrd>(
fn one_hot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -692,7 +691,7 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
let output: ValTensor<F> = if is_assigned {
let int_evals = input.get_int_evals()?;
let res = tensor::ops::one_hot(&int_evals, num_classes, 1)?;
res.iter()
res.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<_>>()
} else {
@@ -728,12 +727,13 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
}
/// Dynamic lookup
pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
lookups: &[ValTensor<F>; 2],
tables: &[ValTensor<F>; 2],
) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
let start = instant::Instant::now();
// error if the lookups are not all the same length
if lookups[0].len() != lookups[1].len() {
return Err("lookups must be same length".into());
@@ -753,20 +753,28 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
let _table_1 = region.assign_dynamic_lookup(&config.dynamic_lookups.tables[1], &table_1)?;
let table_len = table_0.len();
trace!("assigning tables took: {:?}", start.elapsed());
// now create a vartensor of constants for the dynamic lookup index
let table_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), table_len);
let _table_index =
region.assign_dynamic_lookup(&config.dynamic_lookups.tables[2], &table_index)?;
trace!("assigning table index took: {:?}", start.elapsed());
let lookup_0 = region.assign(&config.dynamic_lookups.inputs[0], &lookup_0)?;
let lookup_1 = region.assign(&config.dynamic_lookups.inputs[1], &lookup_1)?;
let lookup_len = lookup_0.len();
trace!("assigning lookups took: {:?}", start.elapsed());
// now set the lookup index
let lookup_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), lookup_len);
let _lookup_index = region.assign(&config.dynamic_lookups.inputs[2], &lookup_index)?;
trace!("assigning lookup index took: {:?}", start.elapsed());
if !region.is_dummy() {
(0..table_len)
.map(|i| {
@@ -802,11 +810,14 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
region.increment_dynamic_lookup_index(1);
region.increment(lookup_len);
let end = start.elapsed();
trace!("dynamic lookup took: {:?}", end);
Ok((lookup_0, lookup_1))
}
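Aside: conceptually, the dynamic lookup asserts that every (lookup_0[i], lookup_1[i], index) row appears in the table tagged with the same dynamic-lookup index. A plain-Rust model of that property (illustrative only — the circuit enforces it with lookup constraints, not a set-membership check):

use std::collections::HashSet;

fn dynamic_lookup_holds(
    lookups: (&[u64], &[u64]),
    tables: (&[u64], &[u64]),
    index: u64,
) -> bool {
    // collect the table rows, each tagged with the lookup index
    let rows: HashSet<(u64, u64, u64)> = tables
        .0
        .iter()
        .zip(tables.1)
        .map(|(a, b)| (*a, *b, index))
        .collect();
    // every looked-up pair must appear among the tagged table rows
    lookups
        .0
        .iter()
        .zip(lookups.1)
        .all(|(a, b)| rows.contains(&(*a, *b, index)))
}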
/// Shuffle arg
pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
input: &[ValTensor<F>; 1],
@@ -869,7 +880,7 @@ pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd>(
}
/// One hot accumulated layout
pub(crate) fn one_hot_axis<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn one_hot_axis<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -922,7 +933,7 @@ pub(crate) fn one_hot_axis<F: PrimeField + TensorType + PartialOrd>(
}
/// Gather accumulated layout
pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -950,7 +961,7 @@ pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd>(
}
/// Gather accumulated layout
pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -973,7 +984,7 @@ pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
}
/// Gather accumulated layout
pub(crate) fn gather_nd<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn gather_nd<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -1024,7 +1035,7 @@ pub(crate) fn gather_nd<F: PrimeField + TensorType + PartialOrd>(
/// Takes a tensor representing a multi-dimensional index and returns a tensor representing the linearized index.
/// The linearized index is the index of the element in the flattened tensor.
/// For instance, if the dims are [3,5,2], the linearized index of [2] at dim 1 is 2*5 + 3 = 13
pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1032,6 +1043,7 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
dim: usize,
is_flat_index: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let start_time = instant::Instant::now();
let index = values[0].clone();
if !is_flat_index {
assert_eq!(index.dims().len(), dims.len());
@@ -1105,6 +1117,9 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
region.apply_in_loop(&mut output, inner_loop_function)?;
let elapsed = start_time.elapsed();
trace!("linearize_element_index took: {:?}", elapsed);
Ok(output.into())
}
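Aside: the row-major arithmetic being constrained here fits in a few lines. For dims [3, 5, 2] the strides are [10, 2, 1], so the coordinate [1, 3, 1] maps to 1*10 + 3*2 + 1 = 17. A sketch of the arithmetic only, not the circuit layout:

// Row-major linear index via Horner's rule: ((c0 * d1 + c1) * d2 + c2) ...
fn linearize(coord: &[usize], dims: &[usize]) -> usize {
    coord.iter().zip(dims).fold(0, |acc, (c, d)| acc * d + c)
}

// linearize(&[1, 3, 1], &[3, 5, 2]) == (1 * 5 + 3) * 2 + 1 == 17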
@@ -1125,7 +1140,7 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
/// If indices_shape[-1] == r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b, where N is an integer equal to the product of 1 and all the elements in the batch dimensions of the indices_shape.
/// Let us think of each such rank-(r-b) tensor as indices_slice. Each scalar value corresponding to data[0:b-1, indices_slice] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor (Example 1 below)
/// If indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension < r-b. Let us think of each such tensor as indices_slice. Each tensor slice corresponding to data[0:b-1, indices_slice, :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor (Examples 2, 3, 4 and 5 below)
pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
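Aside: a minimal model of the GatherND case where indices_shape[-1] equals the rank of data and batch_dims = 0 — each index row selects a single scalar. A hedged illustration of the semantics described above:

// data is 2x2 row-major; every 2-element index row picks one scalar
fn gather_nd_scalars(data: &[[i32; 2]; 2], indices: &[[usize; 2]]) -> Vec<i32> {
    indices.iter().map(|ix| data[ix[0]][ix[1]]).collect()
}

// gather_nd_scalars(&[[0, 1], [2, 3]], &[[0, 0], [1, 1]]) == vec![0, 3]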
@@ -1232,7 +1247,6 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
const_offset += F::from(coord[i] as u64) * dim_multiplier[i];
}
let const_offset = create_constant_tensor(const_offset, 1);
let mut results = vec![];
@@ -1250,16 +1264,18 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
let res = sum(config, region, &[res])?;
results.push(res.get_inner_tensor()?.clone());
// assert that res is less than the product of the dims
assert!(
if region.witness_gen() {
assert!(
res.get_int_evals()?
.iter()
.all(|x| *x < dims.iter().product::<usize>() as i128),
"res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})",
dims.iter().product::<usize>(),
index_val.show(),
index_dim_multiplier.show(),
res.show()
);
}
}
let result_tensor = Tensor::from(results.into_iter());
@@ -1273,7 +1289,9 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
Ok(output.into())
}
pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn get_missing_set_elements<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -1304,7 +1322,7 @@ pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
}
fullset_evals
.iter()
.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
@@ -1337,7 +1355,7 @@ pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
}
/// Gather accumulated layout
pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 3],
@@ -1354,14 +1372,14 @@ pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
let claimed_output: ValTensor<F> = if is_assigned {
let claimed_output: ValTensor<F> = if is_assigned && region.witness_gen() {
let input_inner = input.get_int_evals()?;
let index_inner = index.get_int_evals()?.map(|x| x as usize);
let src_inner = src.get_int_evals()?;
let res = tensor::ops::scatter(&input_inner, &index_inner, &src_inner, dim)?;
res.iter()
res.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
@@ -1419,7 +1437,7 @@ pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
}
/// Scatter Nd
pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 3],
@@ -1433,14 +1451,14 @@ pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
let claimed_output: ValTensor<F> = if is_assigned {
let claimed_output: ValTensor<F> = if is_assigned && region.witness_gen() {
let input_inner = input.get_int_evals()?;
let index_inner = index.get_int_evals()?.map(|x| x as usize);
let src_inner = src.get_int_evals()?;
let res = tensor::ops::scatter_nd(&input_inner, &index_inner, &src_inner)?;
res.iter()
res.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
@@ -1457,7 +1475,6 @@ pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
region.increment(claimed_output.len());
claimed_output.reshape(input.dims())?;
// scatter nd is the inverse of gather nd
let (gather_src, linear_index) =
gather_nd(config, region, &[claimed_output.clone(), index.clone()], 0)?;
@@ -1498,7 +1515,7 @@ pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
}
/// sum accumulated layout
pub(crate) fn sum<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1581,7 +1598,7 @@ pub(crate) fn sum<F: PrimeField + TensorType + PartialOrd>(
}
/// product accumulated layout
pub(crate) fn prod<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1661,7 +1678,7 @@ pub(crate) fn prod<F: PrimeField + TensorType + PartialOrd>(
}
/// Axes wise op wrapper
fn axes_wise_op<F: PrimeField + TensorType + PartialOrd>(
fn axes_wise_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1722,7 +1739,7 @@ fn axes_wise_op<F: PrimeField + TensorType + PartialOrd>(
}
/// Sum accumulated layout
pub(crate) fn prod_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn prod_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1733,7 +1750,7 @@ pub(crate) fn prod_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// Sum accumulated layout
pub(crate) fn sum_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn sum_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1744,7 +1761,7 @@ pub(crate) fn sum_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// argmax layout
pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1762,7 +1779,7 @@ pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// Max accumulated layout
pub(crate) fn max_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn max_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1774,7 +1791,7 @@ pub(crate) fn max_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// Argmin layout
pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1792,7 +1809,7 @@ pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// Min accumulated layout
pub(crate) fn min_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn min_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1804,7 +1821,7 @@ pub(crate) fn min_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// Pairwise (elementwise) op layout
pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -1958,8 +1975,23 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
pub(crate) fn mean_of_squares_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
axes: &[usize],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let squared = pow(config, region, values, 2)?;
let sum_squared = sum_axes(config, region, &[squared], axes)?;
let divisor: usize = values[0].len() / sum_squared.len();
let mean_squared = div(config, region, &[sum_squared], F::from(divisor as u64))?;
Ok(mean_squared)
}
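Aside on scales: squaring a value at fixed-point scale s yields scale 2s, and dividing by the plain element count leaves the scale unchanged — which is why the MeanOfSquares out_scale further down reports 2 * in_scales[0]. A plain-integer sketch of the assumed witness arithmetic:

// mean of squares over all elements: inputs at scale s, output at scale 2s
fn mean_of_squares(xs: &[i128]) -> i128 {
    let sum_sq: i128 = xs.iter().map(|x| x * x).sum();
    sum_sq / xs.len() as i128
}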
/// expand the tensor to the given shape
pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1972,7 +2004,7 @@ pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd>(
}
///
pub(crate) fn greater<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn greater<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -1995,7 +2027,7 @@ pub(crate) fn greater<F: PrimeField + TensorType + PartialOrd>(
}
///
pub(crate) fn greater_equal<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn greater_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2018,7 +2050,7 @@ pub(crate) fn greater_equal<F: PrimeField + TensorType + PartialOrd>(
}
///
pub(crate) fn less<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn less<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2028,7 +2060,7 @@ pub(crate) fn less<F: PrimeField + TensorType + PartialOrd>(
}
///
pub(crate) fn less_equal<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn less_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2038,7 +2070,7 @@ pub(crate) fn less_equal<F: PrimeField + TensorType + PartialOrd>(
}
/// And boolean operation
pub(crate) fn and<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn and<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2049,7 +2081,7 @@ pub(crate) fn and<F: PrimeField + TensorType + PartialOrd>(
}
/// Or boolean operation
pub(crate) fn or<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn or<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2065,7 +2097,7 @@ pub(crate) fn or<F: PrimeField + TensorType + PartialOrd>(
}
/// Equality boolean operation
pub(crate) fn equals<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn equals<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2075,7 +2107,7 @@ pub(crate) fn equals<F: PrimeField + TensorType + PartialOrd>(
}
/// Equality boolean operation
pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2109,7 +2141,7 @@ pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
}
/// Xor boolean operation
pub(crate) fn xor<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn xor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2135,7 +2167,7 @@ pub(crate) fn xor<F: PrimeField + TensorType + PartialOrd>(
}
/// Not boolean operation
pub(crate) fn not<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn not<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2151,7 +2183,7 @@ pub(crate) fn not<F: PrimeField + TensorType + PartialOrd>(
}
/// Iff
pub(crate) fn iff<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn iff<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 3],
@@ -2175,7 +2207,7 @@ pub(crate) fn iff<F: PrimeField + TensorType + PartialOrd>(
}
/// Negation operation accumulated layout
pub(crate) fn neg<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn neg<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2185,7 +2217,7 @@ pub(crate) fn neg<F: PrimeField + TensorType + PartialOrd>(
}
/// Sumpool accumulated layout
pub(crate) fn sumpool<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
@@ -2239,7 +2271,7 @@ pub(crate) fn sumpool<F: PrimeField + TensorType + PartialOrd>(
}
/// Convolution accumulated layout
pub(crate) fn max_pool2d<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn max_pool2d<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2304,7 +2336,7 @@ pub(crate) fn max_pool2d<F: PrimeField + TensorType + PartialOrd>(
/// DeConvolution accumulated layout
pub(crate) fn deconv<
F: PrimeField + TensorType + PartialOrd + std::marker::Send + std::marker::Sync,
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + std::marker::Send + std::marker::Sync,
>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
@@ -2397,7 +2429,7 @@ pub(crate) fn deconv<
/// Convolution accumulated layout
pub(crate) fn conv<
F: PrimeField + TensorType + PartialOrd + std::marker::Send + std::marker::Sync,
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + std::marker::Send + std::marker::Sync,
>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
@@ -2578,7 +2610,7 @@ pub(crate) fn conv<
}
/// Power accumulated layout
pub(crate) fn pow<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn pow<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2594,7 +2626,7 @@ pub(crate) fn pow<F: PrimeField + TensorType + PartialOrd>(
}
/// Rescaled op accumulated layout
pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
@@ -2616,7 +2648,7 @@ pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd>(
}
/// Dummy (no constraints) reshape layout
pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values: &[ValTensor<F>; 1],
new_dims: &[usize],
) -> Result<ValTensor<F>, Box<dyn Error>> {
@@ -2626,7 +2658,7 @@ pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd>(
}
/// Dummy (no constraints) move_axis layout
pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values: &[ValTensor<F>; 1],
source: usize,
destination: usize,
@@ -2637,7 +2669,7 @@ pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd>(
}
/// resize layout
pub(crate) fn resize<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn resize<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2651,7 +2683,7 @@ pub(crate) fn resize<F: PrimeField + TensorType + PartialOrd>(
}
/// Slice layout
pub(crate) fn slice<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn slice<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2674,7 +2706,7 @@ pub(crate) fn slice<F: PrimeField + TensorType + PartialOrd>(
}
/// Trilu layout
pub(crate) fn trilu<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn trilu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2698,7 +2730,7 @@ pub(crate) fn trilu<F: PrimeField + TensorType + PartialOrd>(
}
/// Concat layout
pub(crate) fn concat<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn concat<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values: &[ValTensor<F>],
axis: &usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
@@ -2710,7 +2742,7 @@ pub(crate) fn concat<F: PrimeField + TensorType + PartialOrd>(
}
/// Identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2725,7 +2757,7 @@ pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd>(
}
/// Boolean identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2761,7 +2793,7 @@ pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
}
/// Downsample layout
pub(crate) fn downsample<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn downsample<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2778,7 +2810,7 @@ pub(crate) fn downsample<F: PrimeField + TensorType + PartialOrd>(
}
/// layout for enforcing two sets of cells to be equal
pub(crate) fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn enforce_equality<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2804,7 +2836,7 @@ pub(crate) fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
}
/// layout for range check.
pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2860,12 +2892,13 @@ pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd>(
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
if region.throw_range_check_error() {
let is_assigned = !w.any_unknowns()?;
if is_assigned && region.witness_gen() {
// assert is within range
let int_values = w.get_int_evals()?;
for v in int_values {
if v < range.0 || v > range.1 {
log::debug!("Value ({:?}) out of range: {:?}", v, range);
for v in int_values.iter() {
if v < &range.0 || v > &range.1 {
log::error!("Value ({:?}) out of range: {:?}", v, range);
return Err(Box::new(TensorError::TableLookupError));
}
}
@@ -2885,7 +2918,7 @@ pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd>(
}
/// layout for nonlinearity check.
pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2987,7 +3020,7 @@ pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
}
/// Argmax
pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -3023,7 +3056,7 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
}
/// Argmin
pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -3059,7 +3092,7 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
}
/// max layout
pub(crate) fn max<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn max<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -3069,7 +3102,7 @@ pub(crate) fn max<F: PrimeField + TensorType + PartialOrd>(
}
/// min layout
pub(crate) fn min<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn min<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -3077,7 +3110,7 @@ pub(crate) fn min<F: PrimeField + TensorType + PartialOrd>(
_sort_ascending(config, region, values)?.get_slice(&[0..1])
}
fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd>(
fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -3180,18 +3213,19 @@ fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd>(
}
/// softmax layout
pub(crate) fn softmax_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn softmax_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
scale: utils::F32,
input_scale: utils::F32,
output_scale: utils::F32,
axes: &[usize],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let soft_max_at_scale = move |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1]|
-> Result<ValTensor<F>, Box<dyn Error>> {
softmax(config, region, values, scale)
softmax(config, region, values, input_scale, output_scale)
};
let output = multi_dim_axes_op(config, region, values, axes, soft_max_at_scale)?;
@@ -3199,33 +3233,66 @@ pub(crate) fn softmax_axes<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
/// softmax func
pub(crate) fn softmax<F: PrimeField + TensorType + PartialOrd>(
/// percent func
pub(crate) fn percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
scale: utils::F32,
input_scale: utils::F32,
output_scale: utils::F32,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// elementwise exponential
let ex = nonlinearity(config, region, values, &LookupOp::Exp { scale })?;
let is_assigned = values[0].all_prev_assigned();
let mut input = values[0].clone();
if !is_assigned {
input = region.assign(&config.custom_gates.inputs[0], &values[0])?;
region.increment(input.len());
};
// sum of exps
let denom = sum(config, region, &[ex.clone()])?;
// get the inverse
let felt_scale = F::from(scale.0 as u64);
let inv_denom = recip(config, region, &[denom], felt_scale, felt_scale)?;
let denom = sum(config, region, &[input.clone()])?;
let input_felt_scale = F::from(input_scale.0 as u64);
let output_felt_scale = F::from(output_scale.0 as u64);
let inv_denom = recip(
config,
region,
&[denom],
input_felt_scale,
output_felt_scale,
)?;
// product of num * (1 / denom) = 2*output_scale
let softmax = pairwise(config, region, &[ex, inv_denom], BaseOp::Mult)?;
let percent = pairwise(config, region, &[input, inv_denom], BaseOp::Mult)?;
Ok(softmax)
// rebase the percent to 2x the scale
loop_div(config, region, &[percent], input_felt_scale)
}
/// softmax func
pub(crate) fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
input_scale: utils::F32,
output_scale: utils::F32,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// get the max then subtract it
let max_val = max(config, region, values)?;
// rebase the input to 0
let sub = pairwise(config, region, &[values[0].clone(), max_val], BaseOp::Sub)?;
// elementwise exponential
let ex = nonlinearity(
config,
region,
&[sub],
&LookupOp::Exp { scale: input_scale },
)?;
percent(config, region, &[ex.clone()], input_scale, output_scale)
}
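Aside: rebasing by the max is the standard numerically stable softmax — exp(x - max) differs from exp(x) by a factor exp(-max) that cancels between numerator and denominator, so the output is unchanged while the exponential's argument stays non-positive and inside the lookup range. A floating-point sketch of that invariance (not the fixed-point circuit path):

fn softmax_stable(xs: &[f64]) -> Vec<f64> {
    let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    // the exp(-max) factor cancels in the normalization below
    let exps: Vec<f64> = xs.iter().map(|x| (x - max).exp()).collect();
    let denom: f64 = exps.iter().sum();
    exps.iter().map(|e| e / denom).collect()
}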
/// Checks that the percent error between the expected public output and the actual output value
/// is within the percent error expressed by the `tol` input, where `tol == 1.0` means the percent
/// error tolerance is 1 percent.
pub(crate) fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn range_check_percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],

View File

@@ -137,7 +137,7 @@ impl LookupOp {
}
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self

View File

@@ -27,12 +27,14 @@ pub mod region;
/// A struct representing the result of a forward pass.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
pub(crate) output: Tensor<F>,
}
/// A trait representing operations that can be represented as constraints in a circuit.
pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send + Sync + Any {
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
std::fmt::Debug + Send + Sync + Any
{
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError>;
/// Returns a string representation of the operation.
@@ -98,7 +100,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send +
}
}
impl<F: PrimeField + TensorType + PartialOrd> Clone for Box<dyn Op<F>> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
fn clone(&self) -> Self {
self.clone_dyn()
}
@@ -165,7 +167,7 @@ pub struct Input {
pub datum_type: InputType,
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
Ok(self.scale)
}
@@ -226,7 +228,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Unknown;
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknown {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
Ok(0)
}
@@ -256,7 +258,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
///
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Constant<F: PrimeField + TensorType + PartialOrd> {
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
///
pub quantized_values: Tensor<F>,
///
@@ -266,7 +268,7 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd> {
pub pre_assigned_val: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd> Constant<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
///
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
Self {
@@ -293,8 +295,14 @@ impl<F: PrimeField + TensorType + PartialOrd> Constant<F> {
}
}
impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<'de>> Op<F>
for Constant<F>
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self

View File

@@ -1,6 +1,6 @@
use crate::{
circuit::layouts,
fieldutils::felt_to_i128,
fieldutils::{felt_to_i128, i128_to_felt},
tensor::{self, Tensor, TensorError},
};
@@ -62,6 +62,9 @@ pub enum PolyOp {
Sum {
axes: Vec<usize>,
},
MeanOfSquares {
axes: Vec<usize>,
},
Prod {
axes: Vec<usize>,
len_prod: usize,
@@ -89,8 +92,14 @@ pub enum PolyOp {
},
}
impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<'de>> Op<F>
for PolyOp
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for PolyOp
{
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
@@ -99,10 +108,28 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn as_string(&self) -> String {
match &self {
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
PolyOp::ScatterND { .. } => "SCATTERND".into(),
PolyOp::GatherElements { dim, constant_idx } => format!(
"GATHERELEMENTS (dim={}, constant_idx{})",
dim,
constant_idx.is_some()
),
PolyOp::GatherND {
batch_dims,
indices,
} => format!(
"GATHERND (batch_dims={}, constant_idx{})",
batch_dims,
indices.is_some()
),
PolyOp::MeanOfSquares { axes } => format!("MEANOFSQUARES (axes={:?})", axes),
PolyOp::ScatterElements { dim, constant_idx } => format!(
"SCATTERELEMENTS (dim={}, constant_idx{})",
dim,
constant_idx.is_some()
),
PolyOp::ScatterND { constant_idx } => {
format!("SCATTERND (constant_idx={})", constant_idx.is_some())
}
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -140,6 +167,10 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let mut inputs = inputs.to_vec();
let res = match &self {
PolyOp::MeanOfSquares { axes } => {
let x = inputs[0].map(|x| felt_to_i128(x));
Ok(tensor::ops::nonlinearities::mean_of_squares_axes(&x, axes).map(i128_to_felt))
}
PolyOp::MultiBroadcastTo { shape } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch(
@@ -286,6 +317,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::MultiBroadcastTo { shape } => {
layouts::expand(config, region, values[..].try_into()?, shape)?
}
PolyOp::MeanOfSquares { axes } => {
layouts::mean_of_squares_axes(config, region, values[..].try_into()?, axes)?
}
PolyOp::Xor => layouts::xor(config, region, values[..].try_into()?)?,
PolyOp::Or => layouts::or(config, region, values[..].try_into()?)?,
PolyOp::And => layouts::and(config, region, values[..].try_into()?)?,
@@ -398,6 +432,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
let scale = match self {
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
PolyOp::Iff => in_scales[1],
PolyOp::Einsum { .. } => {

View File

@@ -2,24 +2,28 @@ use crate::{
circuit::table::Range,
tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor},
};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
use halo2_proofs::{
circuit::Region,
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use portable_atomic::AtomicI128 as AtomicInt;
use std::{
cell::RefCell,
collections::HashSet,
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
};
use portable_atomic::AtomicI128 as AtomicInt;
use super::lookup::LookupOp;
/// Constants map
pub type ConstantsMap<F> = HashMap<F, ValType<F>>;
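Aside: the ConstantsMap underpins the reusable-constants change — each distinct field element is assigned once and later uses are served from the map, which is also why the F: std::hash::Hash bound now appears throughout the layouts above. A hedged sketch of the caching pattern, with hypothetical names:

use std::collections::HashMap;

// Assign a constant once; reuse the cached cell index on repeat values.
fn assign_constant_cached<F: std::hash::Hash + Eq + Copy>(
    cache: &mut HashMap<F, usize>,
    next_cell: &mut usize,
    value: F,
) -> usize {
    *cache.entry(value).or_insert_with(|| {
        let cell = *next_cell; // stands in for a real cell assignment
        *next_cell += 1;
        cell
    })
}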
/// Dynamic lookup index
#[derive(Clone, Debug, Default)]
pub struct DynamicLookupIndex {
@@ -120,12 +124,11 @@ impl From<Box<dyn std::error::Error>> for RegionError {
#[derive(Debug)]
/// A context for a region
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
region: Option<RefCell<Region<'a, F>>>,
row: usize,
linear_coord: usize,
num_inner_cols: usize,
total_constants: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
used_lookups: HashSet<LookupOp>,
@@ -133,13 +136,34 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
max_lookup_inputs: i128,
min_lookup_inputs: i128,
max_range_size: i128,
throw_range_check_error: bool,
witness_gen: bool,
assigned_constants: ConstantsMap<F>,
}
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a, F> {
#[cfg(not(target_arch = "wasm32"))]
///
pub fn increment_total_constants(&mut self, n: usize) {
self.total_constants += n;
pub fn debug_report(&self) {
log::debug!(
"(rows={}, coord={}, constants={}, max_lookup_inputs={}, min_lookup_inputs={}, max_range_size={}, dynamic_lookup_col_coord={}, shuffle_col_coord={})",
self.row().to_string().blue(),
self.linear_coord().to_string().yellow(),
self.total_constants().to_string().red(),
self.max_lookup_inputs().to_string().green(),
self.min_lookup_inputs().to_string().green(),
self.max_range_size().to_string().green(),
self.dynamic_lookup_col_coord().to_string().green(),
self.shuffle_col_coord().to_string().green());
}
///
pub fn assigned_constants(&self) -> &ConstantsMap<F> {
&self.assigned_constants
}
///
pub fn update_constants(&mut self, constants: ConstantsMap<F>) {
self.assigned_constants.extend(constants);
}
///
@@ -163,8 +187,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
///
pub fn throw_range_check_error(&self) -> bool {
self.throw_range_check_error
pub fn witness_gen(&self) -> bool {
self.witness_gen
}
/// Create a new region context
@@ -177,7 +201,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
row,
linear_coord,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
@@ -185,9 +208,22 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error: false,
witness_gen: true,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_with_constants(
region: Region<'a, F>,
row: usize,
num_inner_cols: usize,
constants: ConstantsMap<F>,
) -> RegionCtx<'a, F> {
let mut new_self = Self::new(region, row, num_inner_cols);
new_self.assigned_constants = constants;
new_self
}
/// Create a new region context from a wrapped region
pub fn from_wrapped_region(
region: Option<RefCell<Region<'a, F>>>,
@@ -202,7 +238,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index,
shuffle_index,
used_lookups: HashSet::new(),
@@ -210,16 +245,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error: false,
witness_gen: false,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_dummy(
row: usize,
num_inner_cols: usize,
throw_range_check_error: bool,
) -> RegionCtx<'a, F> {
pub fn new_dummy(row: usize, num_inner_cols: usize, witness_gen: bool) -> RegionCtx<'a, F> {
let region = None;
let linear_coord = row * num_inner_cols;
@@ -228,7 +260,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
@@ -236,17 +267,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error,
witness_gen,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_dummy_with_constants(
pub fn new_dummy_with_linear_coord(
row: usize,
linear_coord: usize,
total_constants: usize,
num_inner_cols: usize,
throw_range_check_error: bool,
witness_gen: bool,
) -> RegionCtx<'a, F> {
let region = None;
RegionCtx {
@@ -254,7 +285,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
@@ -262,7 +292,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error,
witness_gen,
assigned_constants: HashMap::new(),
}
}
@@ -312,29 +343,27 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
) -> Result<(), RegionError> {
let row = AtomicUsize::new(self.row());
let linear_coord = AtomicUsize::new(self.linear_coord());
let constants = AtomicUsize::new(self.total_constants());
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
*output = output
.par_enum_map(|idx, _| {
// we kick off the loop with the current offset
let starting_offset = row.load(Ordering::SeqCst);
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
let starting_constants = constants.load(Ordering::SeqCst);
// get inner value of the locked lookups
// we need to make sure that the region is not shared between threads
let mut local_reg = Self::new_dummy_with_constants(
let mut local_reg = Self::new_dummy_with_linear_coord(
starting_offset,
starting_linear_coord,
starting_constants,
self.num_inner_cols,
self.throw_range_check_error,
self.witness_gen,
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
@@ -343,10 +372,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
local_reg.linear_coord() - starting_linear_coord,
Ordering::SeqCst,
);
constants.fetch_add(
local_reg.total_constants() - starting_constants,
Ordering::SeqCst,
);
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
@@ -362,11 +387,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
// update the shuffle index
let mut shuffle_index = shuffle_index.lock().unwrap();
shuffle_index.update(&local_reg.shuffle_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants);
res
})
.map_err(|e| RegionError::from(format!("dummy_loop: {:?}", e)))?;
self.total_constants = constants.into_inner();
self.linear_coord = linear_coord.into_inner();
#[allow(trivial_numeric_casts)]
{
@@ -410,6 +437,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?;
self.assigned_constants = Arc::try_unwrap(constants)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
})?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
})?;
Ok(())
}
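Aside: the merge pattern above — atomics for offsets, fetch_max/fetch_min for lookup bounds, and mutex-guarded maps for shared constants — condenses to the following standalone illustration (assuming the rayon crate):

use rayon::prelude::*;
use std::collections::HashMap;
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Mutex,
};

fn main() {
    let rows = AtomicUsize::new(0);
    let constants: Mutex<HashMap<u64, usize>> = Mutex::new(HashMap::new());

    (0u64..8).into_par_iter().for_each(|i| {
        // each worker advances the shared offset...
        rows.fetch_add(1, Ordering::SeqCst);
        // ...and merges its locally assigned constants back into the map
        constants.lock().unwrap().insert(i % 3, i as usize);
    });

    assert_eq!(rows.into_inner(), 8);
    assert_eq!(constants.into_inner().unwrap().len(), 3);
}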
@@ -435,7 +470,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
range: Range,
) -> Result<(), Box<dyn std::error::Error>> {
if range.0 > range.1 {
return Err("update_max_min_lookup_range: invalid range".into());
return Err(format!("update_max_min_lookup_range: invalid range {:?}", range).into());
}
let range_size = (range.1 - range.0).abs();
@@ -477,7 +512,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
/// Get the total number of constants
pub fn total_constants(&self) -> usize {
self.total_constants
self.assigned_constants.len()
}
/// Get the dynamic lookup index
@@ -525,26 +560,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.max_range_size
}
/// Assign a constant value
pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
self.total_constants += 1;
if let Some(region) = &self.region {
let cell = var.assign_constant(&mut region.borrow_mut(), self.linear_coord, value)?;
Ok(cell.into())
} else {
Ok(value.into())
}
}
/// Assign a valtensor to a vartensor
pub fn assign(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.total_constants += values.num_constants();
if let Some(region) = &self.region {
var.assign(&mut region.borrow_mut(), self.linear_coord, values)
var.assign(
&mut region.borrow_mut(),
self.linear_coord,
values,
&mut self.assigned_constants,
)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
}
@@ -560,14 +593,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.total_constants += values.num_constants();
if let Some(region) = &self.region {
var.assign(
&mut region.borrow_mut(),
self.combined_dynamic_shuffle_coord(),
values,
&mut self.assigned_constants,
)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
}
@@ -594,13 +631,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.linear_coord,
values,
ommissions,
&mut self.assigned_constants,
)
} else {
self.total_constants += values.num_constants();
let inner_tensor = values.get_inner_tensor().unwrap();
let mut values_map = values.create_constants_map();
for o in ommissions {
self.total_constants -= inner_tensor.get_flat_index(**o).is_constant() as usize;
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {
values_map.remove(&value);
}
}
self.assigned_constants.extend(values_map);
Ok(values.clone())
}
}
@@ -615,24 +659,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
) -> Result<(ValTensor<F>, usize), Error> {
if let Some(region) = &self.region {
// duplicates every nth element to adjust for column overflow
let (res, len, total_assigned_constants) = var.assign_with_duplication(
let (res, len) = var.assign_with_duplication(
&mut region.borrow_mut(),
self.row,
self.linear_coord,
values,
check_mode,
single_inner_col,
&mut self.assigned_constants,
)?;
self.total_constants += total_assigned_constants;
Ok((res, len))
} else {
let (_, len, total_assigned_constants) = var.dummy_assign_with_duplication(
let (_, len) = var.dummy_assign_with_duplication(
self.row,
self.linear_coord,
values,
single_inner_col,
&mut self.assigned_constants,
)?;
self.total_constants += total_assigned_constants;
Ok((values.clone(), len))
}
}
@@ -699,9 +743,4 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
Ok(())
}
/// increment constants
pub fn increment_constants(&mut self, n: usize) {
self.total_constants += n
}
}

View File

@@ -98,7 +98,7 @@ pub struct Table<F: PrimeField> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
@@ -138,7 +138,7 @@ pub fn num_cols_required(range_len: i128, col_size: usize) -> usize {
(range_len / (col_size as i128)) as usize + 1
}
impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
/// Configures the table.
pub fn configure(
cs: &mut ConstraintSystem<F>,
@@ -275,7 +275,7 @@ pub struct RangeCheck<F: PrimeField> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> F {
let chunk = chunk as i128;
@@ -303,7 +303,7 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
}
}
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
/// Configures the table.
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
log::debug!("range check range: {:?}", range);

View File

@@ -1911,6 +1911,8 @@ mod add_with_overflow {
#[cfg(test)]
mod add_with_overflow_and_poseidon {
use std::collections::HashMap;
use halo2curves::bn256::Fr;
use crate::circuit::modules::{
@@ -1969,8 +1971,10 @@ mod add_with_overflow_and_poseidon {
let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, WIDTH> =
PoseidonChip::new(config.poseidon.clone());
let assigned_inputs_a = poseidon_chip.layout(&mut layouter, &self.inputs[0..1], 0)?;
let assigned_inputs_b = poseidon_chip.layout(&mut layouter, &self.inputs[1..2], 1)?;
let assigned_inputs_a =
poseidon_chip.layout(&mut layouter, &self.inputs[0..1], 0, &mut HashMap::new())?;
let assigned_inputs_b =
poseidon_chip.layout(&mut layouter, &self.inputs[1..2], 1, &mut HashMap::new())?;
layouter.assign_region(|| "_new_module", |_| Ok(()))?;

View File

@@ -444,7 +444,7 @@ pub enum Commands {
disable_selector_compression: bool,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
commitment: Option<Commitments>,
},
/// Aggregates proofs :)
Aggregate {
@@ -479,7 +479,7 @@ pub enum Commands {
split_proofs: bool,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
commitment: Option<Commitments>,
},
/// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements
CompileCircuit {
@@ -726,7 +726,7 @@ pub enum Commands {
logrows: u32,
/// commitment
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
commitment: Option<Commitments>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that is generated by ezkl

View File

@@ -24,6 +24,8 @@ use crate::pfsys::{
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::{Commitments, RunArgs};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
@@ -337,7 +339,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
logrows,
split_proofs,
disable_selector_compression,
commitment,
commitment.into(),
),
Commands::Aggregate {
proof_path,
@@ -358,7 +360,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
logrows,
check_mode,
split_proofs,
commitment,
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Verify {
@@ -382,7 +384,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
logrows,
reduced_srs,
commitment,
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
#[cfg(not(target_arch = "wasm32"))]
@@ -538,7 +540,7 @@ fn check_srs_hash(
let path = get_srs_path(logrows, srs_path, commitment);
let hash = get_file_hash(&path)?;
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
let predefined_hash = match crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) {
Some(h) => h,
None => return Err(format!("SRS (k={}) hash not found in public set", logrows).into()),
};
@@ -584,7 +586,7 @@ pub(crate) async fn get_srs_cmd(
} else if let Some(settings_p) = settings_path {
if settings_p.exists() {
let settings = GraphSettings::load(&settings_p)?;
settings.run_args.commitment
settings.run_args.commitment.into()
} else {
return Err(err_string.into());
}
@@ -664,27 +666,23 @@ pub(crate) async fn gen_witness(
// if any of the settings have kzg visibility then we need to load the srs
let commitment: Commitments = settings.run_args.commitment.into();
let start_time = Instant::now();
let witness = if settings.module_requires_polycommit() {
if get_srs_path(
settings.run_args.logrows,
srs_path.clone(),
settings.run_args.commitment,
)
.exists()
{
match settings.run_args.commitment {
if get_srs_path(settings.run_args.logrows, srs_path.clone(), commitment).exists() {
match Commitments::from(settings.run_args.commitment) {
Commitments::KZG => {
let srs: ParamsKZG<Bn256> = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path.clone(),
settings.run_args.logrows,
settings.run_args.commitment,
commitment,
)?;
circuit.forward::<KZGCommitmentScheme<_>>(
&mut input,
vk.as_ref(),
Some(&srs),
false,
true,
)?
}
Commitments::IPA => {
@@ -692,22 +690,22 @@ pub(crate) async fn gen_witness(
load_params_prover::<IPACommitmentScheme<G1Affine>>(
srs_path.clone(),
settings.run_args.logrows,
settings.run_args.commitment,
commitment,
)?;
circuit.forward::<IPACommitmentScheme<_>>(
&mut input,
vk.as_ref(),
Some(&srs),
false,
true,
)?
}
}
} else {
warn!("SRS for poly commit does not exist (will be ignored)");
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, false)?
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true)?
}
} else {
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, false)?
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true)?
};
// print each variable tuple (symbol, value) as symbol=value
@@ -819,7 +817,15 @@ impl AccuracyResults {
let error = (original.clone() - calibrated.clone())?;
let abs_error = error.map(|x| x.abs());
let squared_error = error.map(|x| x.powi(2));
let percentage_error = error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i]))?;
let percentage_error = error.enum_map(|i, x| {
// avoid a 0/0 division: if the original value and the error are both zero, report zero error
let res = if original[i] == 0.0 && x == 0.0 {
0.0
} else {
x / original[i]
};
Ok::<f32, TensorError>(res)
})?;
let abs_percentage_error = percentage_error.map(|x| x.abs());
errors.extend(error);
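A standalone check of the guarded division above (a minimal sketch; the real code maps over whole tensors):

```rust
/// Relative error that reports 0.0 for the 0/0 case instead of NaN.
fn percentage_error(original: f32, calibrated: f32) -> f32 {
    let error = original - calibrated;
    if original == 0.0 && error == 0.0 {
        0.0
    } else {
        error / original
    }
}

fn main() {
    assert_eq!(percentage_error(0.0, 0.0), 0.0); // previously NaN
    assert_eq!(percentage_error(2.0, 1.0), 0.5);
}
```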
@@ -888,6 +894,7 @@ pub(crate) fn calibrate(
only_range_check_rebase: bool,
max_logrows: Option<u32>,
) -> Result<GraphSettings, Box<dyn Error>> {
use log::error;
use std::collections::HashMap;
use tabled::Table;
@@ -900,9 +907,9 @@ pub(crate) fn calibrate(
let model = Model::from_run_args(&settings.run_args, &model_path)?;
let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
info!("num of calibration batches: {}", chunks.len());
info!("num calibration batches: {}", chunks.len());
info!("running onnx predictions...");
debug!("running onnx predictions...");
let original_predictions = Model::run_onnx_predictions(
&settings.run_args,
&model_path,
@@ -970,10 +977,18 @@ pub(crate) fn calibrate(
let pb = init_bar(range_grid.len() as u64);
pb.set_message("calibrating...");
let mut num_failed = 0;
let mut num_passed = 0;
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
pb.set_message(format!(
"input scale: {}, param scale: {}, scale rebase multiplier: {}, div rebasing: {}",
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
"i-scale: {}, p-scale: {}, rebase-(x): {}, div-rebase: {}, fail: {}, pass: {}",
input_scale.to_string().blue(),
param_scale.to_string().blue(),
scale_rebase_multiplier.to_string().blue(),
div_rebasing.to_string().yellow(),
num_failed.to_string().red(),
num_passed.to_string().green()
));
let key = (
@@ -1007,7 +1022,9 @@ pub(crate) fn calibrate(
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
Err(e) => {
debug!("circuit creation from run args failed: {:?}", e);
error!("circuit creation from run args failed: {:?}", e);
pb.inc(1);
num_failed += 1;
continue;
}
};
@@ -1039,7 +1056,9 @@ pub(crate) fn calibrate(
Ok(_) => (),
// typically errors will be due to the circuit overflowing the i128 limit
Err(e) => {
debug!("forward pass failed: {:?}", e);
error!("forward pass failed: {:?}", e);
pb.inc(1);
num_failed += 1;
continue;
}
}
@@ -1104,8 +1123,10 @@ pub(crate) fn calibrate(
"found settings: \n {}",
found_settings.as_json()?.to_colored_json_auto()?
);
num_passed += 1;
} else {
debug!("calibration failed {}", res.err().unwrap());
error!("calibration failed {}", res.err().unwrap());
num_failed += 1;
}
pb.inc(1);
@@ -1278,17 +1299,19 @@ pub(crate) fn create_evm_verifier(
render_vk_seperately: bool,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let circuit_settings = GraphSettings::load(&settings_path)?;
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
circuit_settings.run_args.commitment,
settings.run_args.logrows,
commitment,
)?;
let num_instance = circuit_settings.total_instances();
let num_instance = settings.total_instances();
let num_instance: usize = num_instance.iter().sum::<usize>();
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, circuit_settings)?;
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)?;
trace!("params computed");
let generator = halo2_solidity_verifier::SolidityGenerator::new(
@@ -1322,17 +1345,18 @@ pub(crate) fn create_evm_vk(
abi_path: PathBuf,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let circuit_settings = GraphSettings::load(&settings_path)?;
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
circuit_settings.run_args.commitment,
settings.run_args.logrows,
commitment,
)?;
let num_instance = circuit_settings.total_instances();
let num_instance = settings.total_instances();
let num_instance: usize = num_instance.iter().sum::<usize>();
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, circuit_settings)?;
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)?;
trace!("params computed");
let generator = halo2_solidity_verifier::SolidityGenerator::new(
@@ -1601,8 +1625,9 @@ pub(crate) fn setup(
}
let logrows = circuit.settings().run_args.logrows;
let commitment: Commitments = circuit.settings().run_args.commitment.into();
let pk = match circuit.settings().run_args.commitment {
let pk = match commitment {
Commitments::KZG => {
let params = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path,
@@ -1711,7 +1736,8 @@ pub(crate) fn prove(
let transcript: TranscriptType = proof_type.into();
let proof_split_commits: Option<ProofSplitCommit> = data.into();
let commitment = circuit_settings.run_args.commitment;
let commitment = circuit_settings.run_args.commitment.into();
let logrows = circuit_settings.run_args.logrows;
// creates and verifies the proof
let mut snark = match commitment {
Commitments::KZG => {
@@ -1720,7 +1746,7 @@ pub(crate) fn prove(
let params = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
logrows,
Commitments::KZG,
)?;
match strategy {
@@ -1879,7 +1905,9 @@ pub(crate) fn mock_aggregate(
}
Err(_) => {
return Err(
format!("invalid sample commitment type for aggregation, must be KZG").into(),
"invalid sample commitment type for aggregation, must be KZG"
.to_string()
.into(),
);
}
}
@@ -1922,7 +1950,9 @@ pub(crate) fn setup_aggregate(
}
Err(_) => {
return Err(
format!("invalid sample commitment type for aggregation, must be KZG",).into(),
"invalid sample commitment type for aggregation, must be KZG"
.to_string()
.into(),
);
}
}
@@ -1983,7 +2013,9 @@ pub(crate) fn aggregate(
}
Err(_) => {
return Err(
format!("invalid sample commitment type for aggregation, must be KZG").into(),
"invalid sample commitment type for aggregation, must be KZG"
.to_string()
.into(),
);
}
}
@@ -2156,8 +2188,9 @@ pub(crate) fn verify(
let circuit_settings = GraphSettings::load(&settings_path)?;
let logrows = circuit_settings.run_args.logrows;
let commitment = circuit_settings.run_args.commitment.into();
match circuit_settings.run_args.commitment {
match commitment {
Commitments::KZG => {
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
let params: ParamsKZG<Bn256> = if reduced_srs {

View File

@@ -26,6 +26,7 @@ use self::input::{FileSource, GraphData};
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::region::ConstantsMap;
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::felt_to_f64;
@@ -38,7 +39,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
};
use halo2curves::bn256::{self, Fr as Fp, G1Affine};
use halo2curves::ff::PrimeField;
use halo2curves::ff::{Field, PrimeField};
#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;
use log::{debug, error, trace, warn};
@@ -155,7 +156,7 @@ use std::cell::RefCell;
thread_local!(
/// This is a global variable that holds the settings for the graph
/// This is used to pass settings to the layouter and other parts of the circuit without needing to heavily modify the Halo2 API in a new fork
pub static GLOBAL_SETTINGS: RefCell<Option<GraphSettings>> = RefCell::new(None)
pub static GLOBAL_SETTINGS: RefCell<Option<GraphSettings>> = const { RefCell::new(None) }
);
/// Result from a forward pass
@@ -1051,12 +1052,10 @@ impl GraphCircuit {
}
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
let margin = (
(
lookup_safety_margin * min_max_lookup.0,
lookup_safety_margin * min_max_lookup.1,
);
margin
)
}
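The de-nested function simply scales both ends of the observed range; for example, a safety margin of 2 widens [-100, 100] to [-200, 200]. A standalone sketch:

```rust
/// Widens the observed lookup range by a safety factor
/// (mirrors the cleaned-up `calc_safe_lookup_range` above).
fn calc_safe_lookup_range(min_max_lookup: (i128, i128), lookup_safety_margin: i128) -> (i128, i128) {
    (
        lookup_safety_margin * min_max_lookup.0,
        lookup_safety_margin * min_max_lookup.1,
    )
}

fn main() {
    assert_eq!(calc_safe_lookup_range((-100, 100), 2), (-200, 200));
}
```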
fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize {
@@ -1240,7 +1239,7 @@ impl GraphCircuit {
inputs: &mut [Tensor<Fp>],
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&Scheme::ParamsProver>,
throw_range_check_error: bool,
witness_gen: bool,
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
let original_inputs = inputs.to_vec();
@@ -1289,7 +1288,7 @@ impl GraphCircuit {
let mut model_results =
self.model()
.forward(inputs, &self.settings().run_args, throw_range_check_error)?;
.forward(inputs, &self.settings().run_args, witness_gen)?;
if visibility.output.requires_processing() {
let module_outlets = visibility.output.overwrites_inputs();
@@ -1452,7 +1451,8 @@ impl GraphCircuit {
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
struct CircuitSize {
/// Size metrics for the graph circuit
pub struct CircuitSize {
num_instances: usize,
num_advice_columns: usize,
num_fixed: usize,
@@ -1462,7 +1462,8 @@ struct CircuitSize {
}
impl CircuitSize {
pub fn from_cs(cs: &ConstraintSystem<Fp>, logrows: u32) -> Self {
/// Computes circuit size metrics from a constraint system and the number of logrows
pub fn from_cs<F: Field>(cs: &ConstraintSystem<F>, logrows: u32) -> Self {
CircuitSize {
num_instances: cs.num_instance_columns(),
num_advice_columns: cs.num_advice_columns(),
@@ -1604,6 +1605,8 @@ impl Circuit<Fp> for GraphCircuit {
let output_vis = &self.settings().run_args.output_visibility;
let mut graph_modules = GraphModules::new();
let mut constants = ConstantsMap::new();
let mut config = config.clone();
let mut inputs = self
@@ -1649,6 +1652,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut input_outlets,
input_visibility,
&mut instance_offset,
&mut constants,
)?;
// replace inputs with the outlets
for (i, outlet) in outlets.iter().enumerate() {
@@ -1661,6 +1665,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut inputs,
input_visibility,
&mut instance_offset,
&mut constants,
)?;
}
@@ -1697,6 +1702,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut flattened_params,
param_visibility,
&mut instance_offset,
&mut constants,
)?;
let shapes = self.model().const_shapes();
@@ -1725,6 +1731,7 @@ impl Circuit<Fp> for GraphCircuit {
&inputs,
&mut vars,
&outputs,
&mut constants,
)
.map_err(|e| {
log::error!("{}", e);
@@ -1749,6 +1756,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut output_outlets,
&self.settings().run_args.output_visibility,
&mut instance_offset,
&mut constants,
)?;
// replace outputs with the outlets
@@ -1762,6 +1770,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut outputs,
&self.settings().run_args.output_visibility,
&mut instance_offset,
&mut constants,
)?;
}

View File

@@ -5,6 +5,7 @@ use super::vars::*;
use super::GraphError;
use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::RegionCtx;
use crate::circuit::table::Range;
use crate::circuit::Input;
@@ -404,7 +405,7 @@ impl ParsedNodes {
.get(input)
.ok_or(GraphError::MissingNode(*input))?;
let input_dims = node.out_dims();
let input_dim = input_dims.get(0).ok_or(GraphError::MissingNode(*input))?;
let input_dim = input_dims.first().ok_or(GraphError::MissingNode(*input))?;
inputs.push(input_dim.clone());
}
@@ -514,21 +515,24 @@ impl Model {
instance_shapes.len().to_string().blue(),
"instances".blue()
);
// this is the total number of variables we will need to allocate
// for the circuit
let default_value = if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};
let inputs: Vec<ValTensor<Fp>> = self
.graph
.input_shapes()?
.iter()
.map(|shape| {
let mut t: ValTensor<Fp> =
vec![default_value.clone(); shape.iter().product()].into();
let len = shape.iter().product();
let mut t: ValTensor<Fp> = (0..len)
.map(|_| {
if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::random(&mut rand::thread_rng()))
}
})
.collect::<Vec<_>>()
.into();
t.reshape(shape)?;
Ok(t)
})
@@ -577,13 +581,13 @@ impl Model {
&self,
model_inputs: &[Tensor<Fp>],
run_args: &RunArgs,
throw_range_check_error: bool,
witness_gen: bool,
) -> Result<ForwardResult, Box<dyn Error>> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
.collect();
let res = self.dummy_layout(run_args, &valtensor_inputs, throw_range_check_error)?;
let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen)?;
Ok(res.into())
}
@@ -799,13 +803,18 @@ impl Model {
let input_state_idx = input_state_idx(&input_mappings);
let mut output_mappings = vec![];
for mapping in b.output_mapping.iter() {
for (i, mapping) in b.output_mapping.iter().enumerate() {
let mut mappings = vec![];
if let Some(outlet) = mapping.last_value_slot {
mappings.push(OutputMapping::Single {
outlet,
is_state: mapping.state,
});
} else if mapping.state {
mappings.push(OutputMapping::Single {
outlet: i,
is_state: mapping.state,
});
}
if let Some(last) = mapping.scan {
mappings.push(OutputMapping::Stacked {
@@ -814,6 +823,7 @@ impl Model {
is_state: false,
});
}
output_mappings.push(mappings);
}
@@ -1071,6 +1081,8 @@ impl Model {
/// * `layouter` - Halo2 Layouter.
/// * `inputs` - The values to feed into the circuit.
/// * `vars` - The variables for the circuit.
/// * `witnessed_outputs` - The values to compare against.
/// * `constants` - The constants for the circuit.
pub fn layout(
&self,
mut config: ModelConfig,
@@ -1079,6 +1091,7 @@ impl Model {
inputs: &[ValTensor<Fp>],
vars: &mut ModelVars<Fp>,
witnessed_outputs: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Vec<ValTensor<Fp>>, Box<dyn Error>> {
info!("model layout...");
@@ -1104,14 +1117,12 @@ impl Model {
config.base.layout_tables(layouter)?;
config.base.layout_range_checks(layouter)?;
let mut num_rows = 0;
let mut linear_coord = 0;
let mut total_const_size = 0;
let original_constants = constants.clone();
let outputs = layouter.assign_region(
|| "model",
|region| {
let mut thread_safe_region = RegionCtx::new(region, 0, run_args.num_inner_cols);
let mut thread_safe_region = RegionCtx::new_with_constants(region, 0, run_args.num_inner_cols, original_constants.clone());
// we need to do this as this loop is called multiple times
vars.set_instance_idx(instance_idx);
@@ -1157,29 +1168,17 @@ impl Model {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
} else if !run_args.output_visibility.is_private() {
for output in &outputs {
thread_safe_region.increment_total_constants(output.num_constants());
}
}
num_rows = thread_safe_region.row();
linear_coord = thread_safe_region.linear_coord();
total_const_size = thread_safe_region.total_constants();
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
thread_safe_region.debug_report();
*constants = thread_safe_region.assigned_constants().clone();
Ok(outputs)
},
)?;
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
debug!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
num_rows.to_string().blue(),
"rows".blue(),
linear_coord.to_string().yellow(),
total_const_size.to_string().red()
);
)?;
let duration = start_time.elapsed();
trace!("model layout took: {:?}", duration);
@@ -1201,6 +1200,20 @@ impl Model {
.collect();
for (idx, node) in self.graph.nodes.iter() {
debug!("laying out {}: {}", idx, node.as_str(),);
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
region.debug_report();
debug!("input indices: {:?}", node.inputs());
debug!("output scales: {:?}", node.out_scales());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);
let mut values: Vec<ValTensor<Fp>> = if !node.is_input() {
node.inputs()
.iter()
@@ -1212,31 +1225,11 @@ impl Model {
// we re-assign inputs, always from the 0 outlet
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
};
debug!("output dims: {:?}", node.out_dims());
debug!(
"laying out {}: {}, row:{}, coord:{}, total_constants: {}, max_lookup_inputs: {}, min_lookup_inputs: {}",
idx,
node.as_str(),
region.row(),
region.linear_coord(),
region.total_constants(),
region.max_lookup_inputs(),
region.min_lookup_inputs()
);
debug!("dims: {:?}", node.out_dims());
debug!(
"input_dims {:?}",
"input dims {:?}",
values.iter().map(|v| v.dims()).collect_vec()
);
debug!("output scales: {:?}", node.out_scales());
debug!("input indices: {:?}", node.inputs());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);
match &node {
NodeType::Node(n) => {
@@ -1277,8 +1270,8 @@ impl Model {
let num_iter = number_of_iterations(&input_mappings, input_dims.collect());
debug!(
"{} iteration(s) in a subgraph with inputs {:?} and sources {:?}",
num_iter, inputs, model.graph.inputs
"{} iteration(s) in a subgraph with inputs {:?}, sources {:?}, and outputs {:?}",
num_iter, inputs, model.graph.inputs, model.graph.outputs
);
let mut full_results: Vec<ValTensor<Fp>> = vec![];
@@ -1310,6 +1303,7 @@ impl Model {
let res = model.layout_nodes(config, region, &mut subgraph_results)?;
let mut outlets = BTreeMap::new();
let mut stacked_outlets = BTreeMap::new();
for (mappings, outlet_res) in output_mappings.iter().zip(res) {
for mapping in mappings {
@@ -1322,25 +1316,42 @@ impl Model {
let stacked_res = full_results[*outlet]
.clone()
.concat_axis(outlet_res.clone(), axis)?;
outlets.insert(outlet, stacked_res);
} else {
outlets.insert(outlet, outlet_res.clone());
stacked_outlets.insert(outlet, stacked_res);
}
outlets.insert(outlet, outlet_res.clone());
}
}
}
}
full_results = outlets.into_values().collect_vec();
// now extend with stacked elements
let mut pre_stacked_outlets = outlets.clone();
pre_stacked_outlets.extend(stacked_outlets);
let outlets = outlets.into_values().collect_vec();
full_results = pre_stacked_outlets.into_values().collect_vec();
let output_states = output_state_idx(output_mappings);
let input_states = input_state_idx(&input_mappings);
assert_eq!(input_states.len(), output_states.len());
assert_eq!(
input_states.len(),
output_states.len(),
"input and output states must be the same length, got {:?} and {:?}",
input_mappings,
output_mappings
);
for (input_idx, output_idx) in input_states.iter().zip(output_states) {
values[*input_idx] = full_results[output_idx].clone();
assert_eq!(
values[*input_idx].dims(),
outlets[output_idx].dims(),
"input and output dims must be the same, got {:?} and {:?}",
values[*input_idx].dims(),
outlets[output_idx].dims()
);
values[*input_idx] = outlets[output_idx].clone();
}
}
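The two-map split matters because a scan's loop-carried state must read back the per-iteration outlet (whose dims match the loop input, as the new assertion checks), while downstream consumers read the concatenation across iterations. A toy sketch of the bookkeeping, with illustrative values:

```rust
use std::collections::BTreeMap;

fn main() {
    // Per-iteration outlet values, as produced by the last run of the loop body.
    let outlets = BTreeMap::from([(0usize, vec![3i32])]);
    // Stacked view: the same outlet concatenated across all iterations.
    let stacked_outlets = BTreeMap::from([(0usize, vec![1, 2, 3])]);

    // Downstream results prefer the stacked view where one exists...
    let mut pre_stacked = outlets.clone();
    pre_stacked.extend(stacked_outlets);
    assert_eq!(pre_stacked[&0], vec![1, 2, 3]);

    // ...but state feedback still reads the un-stacked outlet.
    assert_eq!(outlets[&0], vec![3]);
}
```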
@@ -1380,7 +1391,7 @@ impl Model {
&self,
run_args: &RunArgs,
inputs: &[ValTensor<Fp>],
throw_range_check_error: bool,
witness_gen: bool,
) -> Result<DummyPassRes, Box<dyn Error>> {
debug!("calculating num of constraints using dummy model layout...");
@@ -1399,29 +1410,31 @@ impl Model {
vars: ModelVars::new_dummy(),
};
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, throw_range_check_error);
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen);
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
if self.visibility.output.is_public() || self.visibility.output.is_fixed() {
let default_value = if !self.visibility.output.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};
let output_scales = self.graph.get_output_scales()?;
let res = outputs
.iter()
.enumerate()
.map(|(i, output)| {
let mut comparator: ValTensor<Fp> = (0..output.len())
.map(|_| {
if !self.visibility.output.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::random(&mut rand::thread_rng()))
}
})
.collect::<Vec<_>>()
.into();
comparator.reshape(output.dims())?;
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
let mut comparator: ValTensor<Fp> =
vec![default_value.clone(); output.dims().iter().product::<usize>()].into();
comparator.reshape(output.dims())?;
dummy_config.layout(
&mut region,
&[output.clone(), comparator],
@@ -1432,7 +1445,7 @@ impl Model {
res?;
} else if !self.visibility.output.is_private() {
for output in &outputs {
region.increment_total_constants(output.num_constants());
region.update_constants(output.create_constants_map());
}
}
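Presumably the switch from `Fp::ONE` placeholders to random field elements keeps the dummy pass honest now that constants are tracked in a shared map: identical placeholders would collapse into one reusable cell and undercount the constant cost. A plain-types sketch of the effect:

```rust
use std::collections::HashSet;

/// Counts how many distinct constant cells a set of placeholder
/// values needs once identical constants are shared.
fn distinct_constants(values: &[u64]) -> usize {
    values.iter().collect::<HashSet<_>>().len()
}

fn main() {
    // All-`ONE` placeholders would collapse into a single shared cell...
    assert_eq!(distinct_constants(&[1, 1, 1, 1]), 1);
    // ...while random placeholders keep the per-value cost visible.
    assert_eq!(distinct_constants(&[7, 42, 3, 19]), 4);
}
```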
@@ -1441,14 +1454,7 @@ impl Model {
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
debug!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
region.row().to_string().blue(),
"rows".blue(),
region.linear_coord().to_string().yellow(),
region.total_constants().to_string().red()
);
region.debug_report();
let outputs = outputs
.iter()

View File

@@ -2,6 +2,7 @@ use crate::circuit::modules::polycommit::{PolyCommitChip, PolyCommitConfig};
use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
use crate::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
use crate::circuit::modules::Module;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor};
use halo2_proofs::circuit::Layouter;
use halo2_proofs::plonk::{Column, ConstraintSystem, Error, Instance, VerifyingKey};
@@ -211,12 +212,13 @@ impl GraphModules {
layouter: &mut impl Layouter<Fp>,
x: &mut Vec<ValTensor<Fp>>,
instance_offset: &mut usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<(), Error> {
// reserve module 0 for ... modules
// hash the input and replace the constrained cells in the input
let cloned_x = (*x).clone();
x[0] = module
.layout(layouter, &cloned_x, instance_offset.to_owned())
.layout(layouter, &cloned_x, instance_offset.to_owned(), constants)
.unwrap();
for inc in module.instance_increment_input().iter() {
// increment the instance offset to make way for future module layouts
@@ -234,6 +236,7 @@ impl GraphModules {
values: &mut [ValTensor<Fp>],
element_visibility: &Visibility,
instance_offset: &mut usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<(), Error> {
if element_visibility.is_polycommit() && !values.is_empty() {
// concat values and sk to get the inputs
@@ -248,7 +251,7 @@ impl GraphModules {
layouter
.assign_region(|| format!("_enter_module_{}", module_offset), |_| Ok(()))
.unwrap();
Self::layout_module(&chip, layouter, x, instance_offset).unwrap();
Self::layout_module(&chip, layouter, x, instance_offset, constants).unwrap();
// increment the current index
self.polycommit_idx += 1;
});
@@ -270,7 +273,7 @@ impl GraphModules {
let mut inputs = values.iter_mut().map(|x| vec![x.clone()]).collect_vec();
// layout the module
inputs.iter_mut().for_each(|x| {
Self::layout_module(&chip, layouter, x, instance_offset).unwrap();
Self::layout_module(&chip, layouter, x, instance_offset, constants).unwrap();
});
// replace the inputs with the outputs
values.iter_mut().enumerate().for_each(|(i, x)| {

View File

@@ -509,7 +509,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
deleted_indices.push(1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
@@ -545,7 +545,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
deleted_indices.push(1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
@@ -582,7 +582,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
deleted_indices.push(1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -734,6 +734,19 @@ pub fn new_op_from_onnx(
SupportedOp::Linear(PolyOp::Sum { axes })
}
"Reduce<MeanOfSquares>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"mean of squares".to_string(),
)));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes = op.axes.into_iter().collect();
SupportedOp::Linear(PolyOp::MeanOfSquares { axes })
}
"Max" => {
// Extract the max value
// first find the input that is a constant
@@ -1072,8 +1085,12 @@ pub fn new_op_from_onnx(
}
};
let in_scale = inputs[0].out_scales()[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
SupportedOp::Hybrid(HybridOp::Softmax {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
input_scale: scale_to_multiplier(in_scale).into(),
output_scale: scale_to_multiplier(max_scale).into(),
axes: softmax_op.axes.to_vec(),
})
}
@@ -1161,7 +1178,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
deleted_indices.push(1);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar pow")
}

View File

@@ -346,7 +346,7 @@ pub struct ModelVars<F: PrimeField + TensorType + PartialOrd> {
pub instance: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd> ModelVars<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
/// Get instance col
pub fn get_instance_col(&self) -> Option<&Column<Instance>> {
if let Some(instance) = &self.instance {

View File

@@ -23,7 +23,6 @@
)]
// we allow this for our dynamic range based indexing scheme
#![allow(clippy::single_range_in_vec_init)]
#![feature(round_ties_even)]
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
@@ -115,6 +114,12 @@ pub enum Commitments {
IPA,
}
impl From<Option<Commitments>> for Commitments {
fn from(value: Option<Commitments>) -> Self {
value.unwrap_or(Commitments::KZG)
}
}
impl FromStr for Commitments {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
@@ -215,7 +220,7 @@ pub struct RunArgs {
pub check_mode: CheckMode,
/// commitment scheme
#[arg(long, default_value = "kzg")]
pub commitment: Commitments,
pub commitment: Option<Commitments>,
}
impl Default for RunArgs {
@@ -235,7 +240,7 @@ impl Default for RunArgs {
div_rebasing: false,
rebase_frac_zero_constants: false,
check_mode: CheckMode::UNSAFE,
commitment: Commitments::KZG,
commitment: None,
}
}
}
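Taken together, the `Option` field and the `From` impl above mean a settings file serialized before this change (with no `commitment` key) still deserializes and falls back to KZG. A self-contained sketch with simplified types:

```rust
#[derive(Debug, PartialEq, Clone, Copy)]
enum Commitments {
    KZG,
    IPA,
}

impl From<Option<Commitments>> for Commitments {
    fn from(value: Option<Commitments>) -> Self {
        value.unwrap_or(Commitments::KZG)
    }
}

fn main() {
    // A legacy settings file deserializes to `None` and
    // silently falls back to the old default, KZG.
    let legacy: Option<Commitments> = None;
    assert_eq!(Commitments::from(legacy), Commitments::KZG);
    assert_eq!(Commitments::from(Some(Commitments::IPA)), Commitments::IPA);
}
```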

View File

@@ -1,4 +1,8 @@
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::CircuitSize;
use crate::pfsys::{Snark, SnarkWitness};
#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;
use halo2_proofs::circuit::AssignedCell;
use halo2_proofs::plonk::{self};
use halo2_proofs::{
@@ -16,6 +20,8 @@ use halo2_wrong_ecc::{
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
#[cfg(not(target_arch = "wasm32"))]
use log::debug;
use log::trace;
use rand::rngs::OsRng;
use snark_verifier::loader::native::NativeLoader;
@@ -193,6 +199,23 @@ impl AggregationConfig {
let main_gate_config = MainGate::<F>::configure(meta);
let range_config =
RangeChip::<F>::configure(meta, &main_gate_config, composition_bits, overflow_bits);
#[cfg(not(target_arch = "wasm32"))]
{
let circuit_size = CircuitSize::from_cs(meta, 23);
// not wasm
debug!(
"circuit size: \n {}",
circuit_size
.as_json()
.unwrap()
.to_colored_json_auto()
.unwrap()
);
}
AggregationConfig {
main_gate_config,
range_config,

View File

@@ -33,6 +33,7 @@ use tokio::runtime::Runtime;
type PyFelt = String;
/// pyclass representing an enum
#[pyclass]
#[derive(Debug, Clone)]
enum PyTestDataSource {
@@ -51,14 +52,17 @@ impl From<PyTestDataSource> for TestDataSource {
}
}
/// pyclass containing the struct used for G1
/// pyclass containing the struct used for G1; this is mostly a helper class
#[pyclass]
#[derive(Debug, Clone)]
struct PyG1 {
#[pyo3(get, set)]
/// Field Element representing x
x: PyFelt,
#[pyo3(get, set)]
/// Field Element representing y
y: PyFelt,
/// Field Element representing z
#[pyo3(get, set)]
z: PyFelt,
}
@@ -134,39 +138,59 @@ impl pyo3::ToPyObject for PyG1Affine {
}
}
/// pyclass containing the struct used for run_args
/// Python class containing the struct used for run_args
///
/// Returns
/// -------
/// PyRunArgs
///
#[pyclass]
#[derive(Clone)]
struct PyRunArgs {
#[pyo3(get, set)]
/// float: The tolerance for error on model outputs
pub tolerance: f32,
#[pyo3(get, set)]
/// int: The denominator in the fixed point representation used when quantizing inputs
pub input_scale: crate::Scale,
#[pyo3(get, set)]
/// int: The denominator in the fixed point representation used when quantizing parameters
pub param_scale: crate::Scale,
#[pyo3(get, set)]
/// int: If the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this is a more advanced parameter, use with caution)
pub scale_rebase_multiplier: u32,
#[pyo3(get, set)]
/// list[int]: The min and max elements in the lookup table input column
pub lookup_range: crate::circuit::table::Range,
#[pyo3(get, set)]
/// int: The log_2 number of rows
pub logrows: u32,
#[pyo3(get, set)]
/// int: The number of inner columns used for the lookup table
pub num_inner_cols: usize,
#[pyo3(get, set)]
/// string: accepts `public`, `private`, `fixed`, `hashed/public`, `hashed/private`, `polycommit`
pub input_visibility: Visibility,
#[pyo3(get, set)]
/// string: accepts `public`, `private`, `fixed`, `hashed/public`, `hashed/private`, `polycommit`
pub output_visibility: Visibility,
#[pyo3(get, set)]
/// string: accepts `public`, `private`, `fixed`, `hashed/public`, `hashed/private`, `polycommit`
pub param_visibility: Visibility,
#[pyo3(get, set)]
/// list[tuple[str, int]]: Hand-written parser for graph variables, e.g. batch_size=1
pub variables: Vec<(String, usize)>,
#[pyo3(get, set)]
/// bool: Rebase the scale using lookup table for division instead of using a range check
pub div_rebasing: bool,
#[pyo3(get, set)]
/// bool: Should constants with 0.0 fraction be rebased to scale 0
pub rebase_frac_zero_constants: bool,
#[pyo3(get, set)]
/// str: check mode, accepts `safe`, `unsafe`
pub check_mode: CheckMode,
#[pyo3(get, set)]
/// str: commitment type, accepts `kzg`, `ipa`
pub commitment: PyCommitments,
}
@@ -197,7 +221,7 @@ impl From<PyRunArgs> for RunArgs {
div_rebasing: py_run_args.div_rebasing,
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
check_mode: py_run_args.check_mode,
commitment: py_run_args.commitment.into(),
commitment: Some(py_run_args.commitment.into()),
}
}
}
@@ -226,7 +250,7 @@ impl Into<PyRunArgs> for RunArgs {
#[pyclass]
#[derive(Debug, Clone)]
/// Pyclass marking the type of commitment
/// pyclass representing an enum, denoting the type of commitment
pub enum PyCommitments {
/// KZG commitment
KZG,
@@ -234,6 +258,16 @@ pub enum PyCommitments {
IPA,
}
impl From<Option<Commitments>> for PyCommitments {
fn from(commitment: Option<Commitments>) -> Self {
match commitment {
Some(Commitments::KZG) => PyCommitments::KZG,
Some(Commitments::IPA) => PyCommitments::IPA,
None => PyCommitments::KZG,
}
}
}
impl From<PyCommitments> for Commitments {
fn from(py_commitments: PyCommitments) -> Self {
match py_commitments {
@@ -263,7 +297,19 @@ impl FromStr for PyCommitments {
}
}
/// Converts a felt to big endian
/// Converts a field element hex string to big endian
///
/// Arguments
/// -------
/// felt: str
/// The field element represented as a string
///
///
/// Returns
/// -------
/// str
/// The big-endian representation of the field element, as a string
///
#[pyfunction(signature = (
felt,
))]
@@ -273,22 +319,45 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
}
/// Converts a field element hex string to an integer
///
/// Arguments
/// -------
/// felt: str
/// The field element represented as a string
///
/// Returns
/// -------
/// int
///
#[pyfunction(signature = (
array,
felt,
))]
fn felt_to_int(array: PyFelt) -> PyResult<i128> {
let felt = crate::pfsys::string_to_field::<Fr>(&array);
fn felt_to_int(felt: PyFelt) -> PyResult<i128> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
let int_rep = felt_to_i128(felt);
Ok(int_rep)
}
/// Converts a field eleement hex string to a floating point number
/// Converts a field element hex string to a floating point number
///
/// Arguments
/// -------
/// felt: str
/// The field element represented as a string
///
/// scale: int
/// The scaling factor used to convert the field element into a floating point representation
///
/// Returns
/// -------
/// float
///
#[pyfunction(signature = (
array,
felt,
scale
))]
fn felt_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&array);
fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
let int_rep = felt_to_i128(felt);
let multiplier = scale_to_multiplier(scale);
let float_rep = int_rep as f64 / multiplier;
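The body is plain dequantization. Assuming the power-of-two multiplier convention used throughout ezkl (multiplier = 2^scale), a standalone sketch:

```rust
/// Decodes a quantized integer back to a float, assuming
/// the power-of-two multiplier convention (multiplier = 2^scale).
fn dequantize(int_rep: i128, scale: u32) -> f64 {
    int_rep as f64 / (1u128 << scale) as f64
}

fn main() {
    // 89 at scale 7 (multiplier 128) decodes to 0.6953125.
    assert_eq!(dequantize(89, 7), 0.6953125);
}
```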
@@ -296,9 +365,23 @@ fn felt_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
}
/// Converts a floating point element to a field element hex string
///
/// Arguments
/// -------
/// input: float
/// The floating point number to quantize
///
/// scale: int
/// The scaling factor used to quantize the float into a field element
///
/// Returns
/// -------
/// str
/// The field element represented as a string
///
#[pyfunction(signature = (
input,
scale
input,
scale
))]
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
let int_rep = quantize_float(&input, 0.0, scale)
@@ -308,9 +391,20 @@ fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
}
/// Converts a buffer to vector of field elements
///
/// Arguments
/// -------
/// buffer: list[int]
/// List of integers representing a buffer
///
/// Returns
/// -------
/// list[str]
/// List of field elements represented as strings
///
#[pyfunction(signature = (
buffer
))]
))]
fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 {
let mut n: u128 = 0;
@@ -369,9 +463,20 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
}
/// Generate a poseidon hash.
///
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represented as strings
///
/// Returns
/// -------
/// list[str]
/// List of field elements represented as strings
///
#[pyfunction(signature = (
message,
))]
))]
fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
let message: Vec<Fr> = message
.iter()
@@ -392,12 +497,31 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
}
/// Generate a kzg commitment.
///
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represented as strings
///
/// vk_path: str
/// Path to the verification key
///
/// settings_path: str
/// Path to the settings file
///
/// srs_path: str
/// Path to the Structured Reference String (SRS) file
///
/// Returns
/// -------
/// list[PyG1Affine]
///
#[pyfunction(signature = (
message,
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
))]
fn kzg_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
@@ -432,12 +556,31 @@ fn kzg_commit(
}
/// Generate an ipa commitment.
///
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represented as strings
///
/// vk_path: str
/// Path to the verification key
///
/// settings_path: str
/// Path to the settings file
///
/// srs_path: str
/// Path to the Structured Reference String (SRS) file
///
/// Returns
/// -------
/// list[PyG1Affine]
///
#[pyfunction(signature = (
message,
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
))]
fn ipa_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
@@ -472,10 +615,19 @@ fn ipa_commit(
}
/// Swap the commitments in a proof
///
/// Arguments
/// -------
/// proof_path: str
/// Path to the proof file
///
/// witness_path: str
/// Path to the witness file
///
#[pyfunction(signature = (
proof_path=PathBuf::from(DEFAULT_PROOF),
witness_path=PathBuf::from(DEFAULT_WITNESS),
))]
))]
fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResult<()> {
crate::execute::swap_proof_commitments_cmd(proof_path, witness_path)
.map_err(|_| PyIOError::new_err("Failed to swap commitments"))?;
@@ -484,11 +636,27 @@ fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResul
}
/// Generates a vk from a pk for a model circuit and saves it to a file
///
/// Arguments
/// -------
/// path_to_pk: str
/// Path to the proving key
///
/// circuit_settings_path: str
/// Path to the settings file
///
/// vk_output_path: str
/// Path to create the vk file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
path_to_pk=PathBuf::from(DEFAULT_PK),
circuit_settings_path=PathBuf::from(DEFAULT_SETTINGS),
vk_output_path=PathBuf::from(DEFAULT_VK),
))]
))]
fn gen_vk_from_pk_single(
path_to_pk: PathBuf,
circuit_settings_path: PathBuf,
@@ -510,10 +678,22 @@ fn gen_vk_from_pk_single(
}
/// Generates a vk from a pk for an aggregate circuit and saves it to a file
///
/// Arguments
/// -------
/// path_to_pk: str
/// Path to the proving key
///
/// vk_output_path: str
/// Path to create the vk file
///
/// Returns
/// -------
/// bool
#[pyfunction(signature = (
path_to_pk=PathBuf::from(DEFAULT_PK_AGGREGATED),
vk_output_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
))]
))]
fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult<bool> {
let pk = load_pk::<KZGCommitmentScheme<Bn256>, AggregationCircuit>(path_to_pk, ())
.map_err(|_| PyIOError::new_err("Failed to load pk"))?;
@@ -528,6 +708,17 @@ fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult
}
/// Displays the table as a string in python
///
/// Arguments
/// ---------
/// model: str
/// Path to the onnx file
///
/// Returns
/// ---------
/// str
/// Table of the nodes in the onnx file
///
#[pyfunction(signature = (
model = PathBuf::from(DEFAULT_MODEL),
py_run_args = None
@@ -543,7 +734,16 @@ fn table(model: PathBuf, py_run_args: Option<PyRunArgs>) -> PyResult<String> {
}
}
/// generates the srs
/// Generates the Structured Reference String (SRS), use this only for testing purposes
///
/// Arguments
/// ---------
/// srs_path: str
/// Path to create the SRS file
///
/// logrows: int
/// The number of logrows for the SRS file
///
#[pyfunction(signature = (
srs_path,
logrows,
@@ -554,7 +754,26 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
Ok(())
}
/// gets a public srs
/// Gets a public srs
///
/// Arguments
/// ---------
/// settings_path: str
/// Path to the settings file
///
/// logrows: int
/// The number of logrows for the SRS file
///
/// srs_path: str
/// Path to create the SRS file
///
/// commitment: str
/// Specify the commitment used ("kzg", "ipa")
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
settings_path=PathBuf::from(DEFAULT_SETTINGS),
logrows=None,
@@ -587,7 +806,23 @@ fn get_srs(
Ok(true)
}
/// generates the circuit settings
/// Generates the circuit settings
///
/// Arguments
/// ---------
/// model: str
/// Path to the onnx file
///
/// output: str
/// Path to create the settings file
///
/// py_run_args: PyRunArgs
/// PyRunArgs object to initialize the settings
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
output=PathBuf::from(DEFAULT_SETTINGS),
@@ -608,7 +843,38 @@ fn gen_settings(
Ok(true)
}
/// calibrates the circuit settings
/// Calibrates the circuit settings
///
/// Arguments
/// ---------
/// data: str
/// Path to the calibration data
///
/// model: str
/// Path to the onnx file
///
/// settings: str
/// Path to the settings file
///
/// lookup_safety_margin: int
/// The lookup safety margin to use for calibration. If the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. Larger = safer, but slower
///
/// scales: list[int]
/// Optional scales to specifically try for calibration
///
/// scale_rebase_multiplier: list[int]
/// Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale.
///
/// max_logrows: int
/// Optional max logrows to use for calibration
///
/// only_range_check_rebase: bool
/// Check ranges when rebasing
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
data = PathBuf::from(DEFAULT_CALIBRATION_FILE),
model = PathBuf::from(DEFAULT_MODEL),
@@ -650,7 +916,30 @@ fn calibrate_settings(
Ok(true)
}
/// runs the forward pass operation
/// Runs the forward pass operation to generate a witness
///
/// Arguments
/// ---------
/// data: str
/// Path to the data file
///
/// model: str
/// Path to the compiled model file
///
/// output: str
/// Path to create the witness file
///
/// vk_path: str
/// Path to the verification key
///
/// srs_path: str
/// Path to the SRS file
///
/// Returns
/// -------
/// dict
/// Python object containing the witness values
///
#[pyfunction(signature = (
data=PathBuf::from(DEFAULT_DATA),
model=PathBuf::from(DEFAULT_MODEL),
@@ -677,7 +966,20 @@ fn gen_witness(
Python::with_gil(|py| Ok(output.to_object(py)))
}
/// mocks the prover
/// Mocks the prover
///
/// Arguments
/// ---------
/// witness: str
/// Path to the witness file
///
/// model: str
/// Path to the compiled model file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
witness=PathBuf::from(DEFAULT_WITNESS),
model=PathBuf::from(DEFAULT_MODEL),
@@ -690,7 +992,23 @@ fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
Ok(true)
}
/// mocks the aggregate prover
/// Mocks the aggregate prover
///
/// Arguments
/// ---------
/// aggregation_snarks: list[str]
/// List of paths to the relevant proof files
///
/// logrows: int
/// Number of logrows to use for the aggregation circuit
///
/// split_proofs: bool
/// Indicates whether the accumulated proofs are segments of a larger proof
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
aggregation_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
@@ -709,7 +1027,32 @@ fn mock_aggregate(
Ok(true)
}
/// runs the prover on a set of inputs
/// Runs the setup process
///
/// Arguments
/// ---------
/// model: str
/// Path to the compiled model file
///
/// vk_path: str
/// Path to create the verification key file
///
/// pk_path: str
/// Path to create the proving key file
///
/// srs_path: str
/// Path to the SRS file
///
/// witness_path: str
/// Path to the witness file
///
/// disable_selector_compression: bool
/// Whether to disable selector compression
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
vk_path=PathBuf::from(DEFAULT_VK),
@@ -742,7 +1085,32 @@ fn setup(
Ok(true)
}
/// runs the prover on a set of inputs
/// Runs the prover on a set of inputs
///
/// Arguments
/// ---------
/// witness: str
/// Path to the witness file
///
/// model: str
/// Path to the compiled model file
///
/// pk_path: str
/// Path to the proving key file
///
/// proof_path: str
/// Path to create the proof file
///
/// proof_type: str
/// Accepts `single`, `for-aggr`
///
/// srs_path: str
/// Path to the SRS file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
witness=PathBuf::from(DEFAULT_WITNESS),
model=PathBuf::from(DEFAULT_MODEL),
@@ -776,7 +1144,29 @@ fn prove(
Python::with_gil(|py| Ok(snark.to_object(py)))
}
/// verifies a given proof
/// Verifies a given proof
///
/// Arguments
/// ---------
/// proof_path: str
/// Path to create the proof file
///
/// settings_path: str
/// Path to the settings file
///
/// vk_path: str
/// Path to the verification key file
///
/// srs_path: str
/// Path to the SRS file
///
/// non_reduced_srs: bool
/// Whether to use the full SRS rather than one reduced to the number of instances (the reduced form only works if the SRS was generated in the same ceremony)
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
proof_path=PathBuf::from(DEFAULT_PROOF),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
@@ -806,6 +1196,38 @@ fn verify(
Ok(true)
}
/// Runs the setup process for an aggregate setup
///
/// Arguments
/// ---------
/// sample_snarks: list[str]
/// List of paths to the various proofs
///
/// vk_path: str
/// Path to create the aggregated VK
///
/// pk_path: str
/// Path to create the aggregated PK
///
/// logrows: int
/// Number of logrows to use
///
/// split_proofs: bool
/// Whether the accumulated proofs are segments of a larger proof
///
/// srs_path: str
/// Path to the SRS file
///
/// disable_selector_compression: bool
/// Whether to disable selector compression
///
/// commitment: str
/// Accepts `kzg`, `ipa`
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
sample_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
vk_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
@@ -844,6 +1266,23 @@ fn setup_aggregate(
Ok(true)
}
/// Compiles the circuit for use in other steps
///
/// Arguments
/// ---------
/// model: str
/// Path to the onnx model file
///
/// compiled_circuit: str
/// Path to output the compiled circuit
///
/// settings_path: str
/// Path to the settings file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
compiled_circuit=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
@@ -862,7 +1301,41 @@ fn compile_circuit(
Ok(true)
}
/// creates an aggregated proof
/// Creates an aggregated proof
///
/// Arguments
/// ---------
/// aggregation_snarks: list[str]
/// List of paths to the various proofs
///
/// proof_path: str
/// Path to output the aggregated proof
///
/// vk_path: str
/// Path to the VK file
///
/// transcript: str
/// Proof transcript type to be used. `evm` used by default. `poseidon` is also supported
///
/// logrows: int
/// Logrows used for aggregation circuit
///
/// check_mode: str
/// Run sanity checks during calculations. Accepts `safe` or `unsafe`
///
/// split_proofs: bool
/// Whether the accumulated proofs are segments of a larger circuit
///
/// srs_path: str
/// Path to the SRS used
///
/// commitment: str
/// Accepts "kzg" or "ipa"
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
aggregation_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
proof_path=PathBuf::from(DEFAULT_PROOF_AGGREGATED),
@@ -905,7 +1378,32 @@ fn aggregate(
Ok(true)
}
/// verifies and aggregate proof
/// Verifies an aggregate proof
///
/// Arguments
/// ---------
/// proof_path: str
/// The path to the proof file
///
/// vk_path: str
/// The path to the verification key file
///
/// logrows: int
/// Logrows used for the aggregation circuit
///
/// commitment: str
/// Accepts "kzg" or "ipa"
///
/// reduced_srs: bool
/// Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
///
/// srs_path: str
/// The path to the SRS file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
proof_path=PathBuf::from(DEFAULT_PROOF_AGGREGATED),
vk_path=PathBuf::from(DEFAULT_VK),
@@ -938,7 +1436,32 @@ fn verify_aggr(
Ok(true)
}
/// creates an EVM compatible verifier, you will need solc installed in your environment to run this
/// Creates an EVM compatible verifier, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// vk_path: str
/// The path to the verification key file
///
/// settings_path: str
/// The path to the settings file
///
/// sol_code_path: str
/// The path to create the Solidity verifier
///
/// abi_path: str
/// The path to create the ABI for the solidity verifier
///
/// srs_path: str
/// The path to the SRS file
///
/// render_vk_separately: bool
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
@@ -971,7 +1494,26 @@ fn create_evm_verifier(
Ok(true)
}
// creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
/// Creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// input_data: str
/// The path to the .json data file, which should contain the necessary calldata and account addresses needed to read from all the on-chain view functions that return the data that the network ingests as inputs
///
/// settings_path: str
/// The path to the settings file
///
/// sol_code_path: str
/// The path to create the Solidity verifier
///
/// abi_path: str
/// The path to create the ABI for the solidity verifier
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
input_data=PathBuf::from(DEFAULT_DATA),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
@@ -993,6 +1535,32 @@ fn create_evm_data_attestation(
Ok(true)
}
/// Setup test evm witness
///
/// Arguments
/// ---------
/// data_path: str
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
///
/// compiled_circuit_path: str
/// The path to the compiled model file (generated using the compile-circuit command)
///
/// test_data: str
/// For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information derived from the file information in the data .json file. Should include both the network input (possibly private) and the network output (public input to the proof)
///
/// input_sources: str
/// Where the input data comes from
///
/// output_source: str
/// Where the output data comes from
///
/// rpc_url: str
/// RPC URL for an EVM compatible node, if None, uses Anvil as a local RPC node
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
data_path,
compiled_circuit_path,
@@ -1027,6 +1595,7 @@ fn setup_test_evm_witness(
Ok(true)
}
/// deploys the solidity verifier
#[pyfunction(signature = (
addr_path,
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
@@ -1059,6 +1628,7 @@ fn deploy_evm(
Ok(true)
}
/// deploys the solidity vk verifier
#[pyfunction(signature = (
addr_path,
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
@@ -1091,6 +1661,7 @@ fn deploy_vk_evm(
Ok(true)
}
/// deploys the solidity da verifier
#[pyfunction(signature = (
addr_path,
input_data,
@@ -1128,6 +1699,27 @@ fn deploy_da_evm(
Ok(true)
}
/// verifies an evm compatible proof, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// addr_verifier: str
/// The path to the verifier contract's address
///
/// proof_path: str
/// The path to the proof file (generated using the prove command)
///
/// rpc_url: str
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
///
/// addr_da: str
/// The address of the data attestation contract, if the verifier uses data attestation
///
/// addr_vk: str
/// The address of the separately rendered verifier key contract, if one was deployed
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
addr_verifier,
proof_path=PathBuf::from(DEFAULT_PROOF),
@@ -1173,7 +1765,35 @@ fn verify_evm(
Ok(true)
}
/// creates an evm compatible aggregate verifier, you will need solc installed in your environment to run this
/// Creates an evm compatible aggregate verifier, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// aggregation_settings: str
/// Path to the settings file
///
/// vk_path: str
/// The path to load the desired verification key file
///
/// sol_code_path: str
/// The path to the Solidity code
///
/// abi_path: str
/// The path to output the Solidity verifier ABI
///
/// logrows: int
/// Number of logrows used during aggregated setup
///
/// srs_path: str
/// The path to the SRS file
///
/// render_vk_separately: bool
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
aggregation_settings=vec![PathBuf::from(DEFAULT_PROOF)],
vk_path=PathBuf::from(DEFAULT_VK_AGGREGATED),

View File

@@ -378,7 +378,7 @@ impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<A
{
fn from(value: Tensor<AssignedCell<Assigned<F>, F>>) -> Tensor<Value<F>> {
let mut output = Vec::new();
for (_, x) in value.iter().enumerate() {
for x in value.iter() {
output.push(x.value_field().evaluate());
}
Tensor::new(Some(&output), value.dims()).unwrap()
@@ -434,6 +434,18 @@ impl<F: PrimeField + TensorType + Clone> From<Tensor<i128>> for Tensor<Value<F>>
}
}
impl<T: Clone + TensorType + std::marker::Send + std::marker::Sync>
maybe_rayon::iter::FromParallelIterator<T> for Tensor<T>
{
fn from_par_iter<I>(par_iter: I) -> Self
where
I: maybe_rayon::iter::IntoParallelIterator<Item = T>,
{
let inner: Vec<T> = par_iter.into_par_iter().collect();
Tensor::new(Some(&inner), &[inner.len()]).unwrap()
}
}
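A usage sketch for the new `FromParallelIterator` impl, assuming `maybe_rayon` forwards rayon's iterator traits as the paths above suggest: a parallel pipeline can now `collect()` directly into a rank-1 tensor.

```rust
use ezkl::tensor::Tensor;
use maybe_rayon::iter::{IntoParallelIterator, ParallelIterator};

fn main() {
    // `from_par_iter` gathers the parallel results into shape `[len]`.
    let squares: Tensor<i128> = (0..10i64).into_par_iter().map(|i| (i * i) as i128).collect();
    assert_eq!(squares.dims(), &[10]);
}
```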
impl<T: Clone + TensorType + std::marker::Send + std::marker::Sync>
maybe_rayon::iter::IntoParallelIterator for Tensor<T>
{
@@ -922,6 +934,7 @@ impl<T: Clone + TensorType> Tensor<T> {
pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<Self, TensorError> {
assert!(source < self.dims.len());
assert!(destination < self.dims.len());
let mut new_dims = self.dims.clone();
new_dims.remove(source);
new_dims.insert(destination, self.dims[source]);
@@ -953,6 +966,8 @@ impl<T: Clone + TensorType> Tensor<T> {
old_coord[source - 1] = *c;
} else if (i < source && source < destination)
|| (i < destination && source > destination)
|| (i > source && source > destination)
|| (i > destination && source < destination)
{
old_coord[i] = *c;
} else if i > source && source < destination {
@@ -965,7 +980,10 @@ impl<T: Clone + TensorType> Tensor<T> {
));
}
}
output.set(&coord, self.get(&old_coord));
let value = self.get(&old_coord);
output.set(&coord, value);
}
Ok(output)
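The shape bookkeeping in the first lines of `move_axis` is easy to sanity-check in isolation (a sketch of just the dims computation, not the element remapping):

```rust
/// Computes the output shape of `move_axis(source, destination)`.
fn moved_dims(dims: &[usize], source: usize, destination: usize) -> Vec<usize> {
    let mut new_dims = dims.to_vec();
    let d = new_dims.remove(source);
    new_dims.insert(destination, d);
    new_dims
}

fn main() {
    assert_eq!(moved_dims(&[2, 3, 4], 0, 2), vec![3, 4, 2]);
    assert_eq!(moved_dims(&[2, 3, 4], 2, 0), vec![4, 2, 3]);
}
```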

View File

@@ -1875,11 +1875,7 @@ pub fn topk<T: TensorType + PartialOrd>(
let mut indexed_a = a.clone();
indexed_a.flatten();
let mut indexed_a = a
.iter()
.enumerate()
.map(|(i, x)| (i, x))
.collect::<Vec<_>>();
let mut indexed_a = a.iter().enumerate().collect::<Vec<_>>();
if largest {
indexed_a.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
@@ -3532,12 +3528,17 @@ pub mod nonlinearities {
}
/// softmax layout
pub fn softmax_axes(a: &Tensor<i128>, scale: f64, axes: &[usize]) -> Tensor<i128> {
pub fn softmax_axes(
a: &Tensor<i128>,
input_scale: f64,
output_scale: f64,
axes: &[usize],
) -> Tensor<i128> {
// we want this to be as small as possible so we set the output scale to 1
let dims = a.dims();
if dims.len() == 1 {
return softmax(a, scale);
return softmax(a, input_scale, output_scale);
}
let cartesian_coord = dims[..dims.len() - 1]
@@ -3560,7 +3561,7 @@ pub mod nonlinearities {
let softmax_input = a.get_slice(&sum_dims).unwrap();
let res = softmax(&softmax_input, scale);
let res = softmax(&softmax_input, input_scale, output_scale);
outputs.push(res);
}
@@ -3587,20 +3588,25 @@ pub mod nonlinearities {
/// Some(&[2, 2, 3, 2, 2, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = softmax(&x, 128.0);
/// let result = softmax(&x, 128.0, 128.0 * 128.0);
/// // doubles the scale of the input
/// let expected = Tensor::<i128>::new(Some(&[2730, 2730, 2751, 2730, 2730, 2688]), &[2, 3]).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2734, 2734, 2755, 2734, 2734, 2692]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn softmax(a: &Tensor<i128>, scale: f64) -> Tensor<i128> {
pub fn softmax(a: &Tensor<i128>, input_scale: f64, output_scale: f64) -> Tensor<i128> {
// the more accurate calculation is commented out; we implement it as below so that it matches the steps in the layout
let exp = exp(a, scale);
let exp = exp(a, input_scale);
let sum = sum(&exp).unwrap();
let inv_denom = recip(&sum, scale, scale);
(exp * inv_denom).unwrap()
let inv_denom = recip(&sum, input_scale, output_scale);
let mut res = (exp * inv_denom).unwrap();
res = res
.iter()
.map(|x| ((*x as f64) / input_scale).round() as i128)
.collect();
res.reshape(a.dims()).unwrap();
res
}
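The updated doctest values can be reproduced by hand. Assuming exp(a, s) computes round(exp(x / s) * s) per element and recip(x, s_in, s_out) computes round(s_in * s_out / x), which is what the numbers imply (the exact helpers live elsewhere in this module), the trace for the doctest input is:

// input_scale = 128, output_scale = 128 * 128 = 16384
// exp:   x = 2 -> round(e^(2/128) * 128) = 130
//        x = 3 -> round(e^(3/128) * 128) = 131
//        x = 0 -> round(e^0       * 128) = 128
// sum:   130 + 130 + 131 + 130 + 130 + 128 = 779
// recip: round(128 * 16384 / 779) = 2692
// multiply, then rescale back down by input_scale:
//        round(130 * 2692 / 128) = 2734
//        round(131 * 2692 / 128) = 2755
//        round(128 * 2692 / 128) = 2692
// giving [2734, 2734, 2755, 2734, 2734, 2692], the expected tensor above.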
/// Applies range_check_percent
@@ -4398,6 +4404,32 @@ pub mod nonlinearities {
let sum = sum(a).unwrap();
const_div(&sum, (scale * a.len()) as f64)
}
/// Mean of squares axes
/// # Arguments
/// * `a` - Tensor
/// * `axis` - [usize]
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::nonlinearities::mean_of_squares_axes;
/// let x = Tensor::<i128>::new(
/// Some(&[2, 15, 2, 1, 1, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = mean_of_squares_axes(&x, &[1]);
/// let expected = Tensor::<i128>::new(
/// Some(&[78, 1]),
/// &[2, 1],
/// ).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn mean_of_squares_axes(a: &Tensor<i128>, axes: &[usize]) -> Tensor<i128> {
let square = a.map(|a_i| a_i * a_i);
let sum = sum_axes(&square, axes).unwrap();
let denominator = a.len() / sum.len();
const_div(&sum, denominator as f64)
}
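The doctest arithmetic, for reference (assuming const_div rounds to nearest, consistent with the fixed-point helpers above):

// squares          = [4, 225, 4, 1, 1, 0]
// sum over axis 1  = [233, 2]
// denominator      = a.len() / sum.len() = 6 / 2 = 3
// round(233 / 3) = 78, round(2 / 3) = 1  ->  [78, 1]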
}
/// Ops that return the transcript, i.e. the intermediate calculations of an op

View File

@@ -1,8 +1,12 @@
use core::{iter::FilterMap, slice::Iter};
use crate::circuit::region::ConstantsMap;
use super::{
ops::{intercalate_values, pad, resize},
*,
};
use halo2_proofs::{arithmetic::Field, plonk::Instance};
use halo2_proofs::{arithmetic::Field, circuit::Cell, plonk::Instance};
pub(crate) fn create_constant_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
@@ -51,6 +55,24 @@ pub enum ValType<F: PrimeField + TensorType + std::marker::Send + std::marker::S
}
impl<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd> ValType<F> {
/// Returns the inner cell of the [ValType].
pub fn cell(&self) -> Option<Cell> {
match self {
ValType::PrevAssigned(cell) => Some(cell.cell()),
ValType::AssignedConstant(cell, _) => Some(cell.cell()),
_ => None,
}
}
/// Returns the assigned cell of the [ValType].
pub fn assigned_cell(&self) -> Option<AssignedCell<F, F>> {
match self {
ValType::PrevAssigned(cell) => Some(cell.clone()),
ValType::AssignedConstant(cell, _) => Some(cell.clone()),
_ => None,
}
}
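A minimal sketch of the two new accessors (construction is illustrative; Fr stands in for any field satisfying the bounds): both return None until the value is actually backed by an assigned cell.

use halo2curves::bn256::Fr;

let v: ValType<Fr> = ValType::Constant(Fr::from(7u64));
assert!(v.cell().is_none()); // a raw constant has no backing cell
assert!(v.assigned_cell().is_none()); // likewise, nothing assigned yet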
/// Returns true if the value is previously assigned.
pub fn is_prev_assigned(&self) -> bool {
matches!(
@@ -293,7 +315,7 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>> f
}
}
impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Allocate a new [ValTensor::Instance] from the ConstraintSystem with the given tensor `dims`, optionally enabling `equality`.
pub fn new_instance(
cs: &mut ConstraintSystem<F>,
@@ -428,10 +450,37 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
/// Returns an iterator over the constants in the [ValTensor].
pub fn num_constants(&self) -> usize {
pub fn create_constants_map_iterator(
&self,
) -> FilterMap<Iter<'_, ValType<F>>, fn(&ValType<F>) -> Option<(F, ValType<F>)>> {
match self {
ValTensor::Value { inner, .. } => inner.iter().filter(|x| x.is_constant()).count(),
ValTensor::Instance { .. } => 0,
ValTensor::Value { inner, .. } => inner.iter().filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
None
}
}),
ValTensor::Instance { .. } => {
unreachable!("Instance tensors do not have constants")
}
}
}
/// Returns a map of the constants in the [ValTensor].
pub fn create_constants_map(&self) -> ConstantsMap<F> {
match self {
ValTensor::Value { inner, .. } => inner
.par_iter()
.filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
None
}
})
.collect(),
ValTensor::Instance { .. } => ConstantsMap::new(),
}
}
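A usage sketch (assuming ConstantsMap<F> is a map from field element to ValType<F>, as the new()/extend calls later in this diff imply; lhs and rhs are hypothetical ValTensor<Fr> operands): a layout can pool each operand's constants into one region-wide map before assigning anything.

let mut constants: ConstantsMap<Fr> = ConstantsMap::new();
for operand in [&lhs, &rhs] {
    // merge this tensor's ValType::Constant entries into the shared map
    constants.extend(operand.create_constants_map());
}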

View File

@@ -2,7 +2,7 @@ use std::collections::HashSet;
use log::{debug, error, warn};
use crate::circuit::CheckMode;
use crate::circuit::{region::ConstantsMap, CheckMode};
use super::*;
/// A wrapper around Halo2's `Column<Fixed>` or `Column<Advice>`.
@@ -289,9 +289,10 @@ impl VarTensor {
&self,
region: &mut Region<F>,
offset: usize,
coord: usize,
constant: F,
) -> Result<AssignedCell<F, F>, halo2_proofs::plonk::Error> {
let (x, y, z) = self.cartesian_coord(offset);
let (x, y, z) = self.cartesian_coord(offset + coord);
match &self {
VarTensor::Advice { inner: advices, .. } => {
region.assign_advice_from_constant(|| "constant", advices[x][y], z, constant)
@@ -304,33 +305,28 @@ impl VarTensor {
}
/// Assigns [ValTensor] to the columns of the inner tensor.
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd>(
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
omissions: &HashSet<&usize>,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
let mut assigned_coord = 0;
let mut res: ValTensor<F> = match values {
ValTensor::Instance { .. } => {
unimplemented!("cannot assign instance to advice columns with omissions")
}
ValTensor::Value { inner: v, .. } => Ok::<_, halo2_proofs::plonk::Error>(
ValTensor::Value { inner: v, .. } => Ok::<ValTensor<F>, halo2_proofs::plonk::Error>(
v.enum_map(|coord, k| {
if omissions.contains(&coord) {
return Ok(k);
return Ok::<_, halo2_proofs::plonk::Error>(k);
}
let cell = self.assign_value(region, offset, k.clone(), assigned_coord)?;
let cell =
self.assign_value(region, offset, k.clone(), assigned_coord, constants)?;
assigned_coord += 1;
match k {
ValType::Constant(f) => Ok::<ValType<F>, halo2_proofs::plonk::Error>(
ValType::AssignedConstant(cell, f),
),
ValType::AssignedConstant(_, f) => Ok(ValType::AssignedConstant(cell, f)),
_ => Ok(ValType::PrevAssigned(cell)),
}
Ok::<_, halo2_proofs::plonk::Error>(cell)
})?
.into(),
),
@@ -340,11 +336,12 @@ impl VarTensor {
}
/// Assigns [ValTensor] to the columns of the inner tensor.
pub fn assign<F: PrimeField + TensorType + PartialOrd>(
pub fn assign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
let mut res: ValTensor<F> = match values {
ValTensor::Instance {
@@ -382,14 +379,7 @@ impl VarTensor {
},
ValTensor::Value { inner: v, .. } => Ok(v
.enum_map(|coord, k| {
let cell = self.assign_value(region, offset, k.clone(), coord)?;
match k {
ValType::Constant(f) => Ok::<ValType<F>, halo2_proofs::plonk::Error>(
ValType::AssignedConstant(cell, f),
),
ValType::AssignedConstant(_, f) => Ok(ValType::AssignedConstant(cell, f)),
_ => Ok(ValType::PrevAssigned(cell)),
}
self.assign_value(region, offset, k.clone(), coord, constants)
})?
.into()),
}?;
@@ -399,13 +389,16 @@ impl VarTensor {
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
/// Duplication occurs by copying the last cell of the column to the first cell of the next column and creating a copy constraint between the two.
pub fn dummy_assign_with_duplication<F: PrimeField + TensorType + PartialOrd>(
pub fn dummy_assign_with_duplication<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
&self,
row: usize,
offset: usize,
values: &ValTensor<F>,
single_inner_col: bool,
) -> Result<(ValTensor<F>, usize, usize), halo2_proofs::plonk::Error> {
constants: &mut ConstantsMap<F>,
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
match values {
ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
ValTensor::Value { inner: v, dims , ..} => {
@@ -430,21 +423,24 @@ impl VarTensor {
// duplicates every nth element to adjust for column overflow
let mut res: ValTensor<F> = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap().into();
let constants_map = res.create_constants_map();
constants.extend(constants_map);
let total_used_len = res.len();
let total_constants = res.num_constants();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
res.reshape(dims).unwrap();
res.set_scale(values.scale());
Ok((res, total_used_len, total_constants))
Ok((res, total_used_len))
}
}
}
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
/// Duplication occurs by copying the last cell of the column to the first cell of the next column and creating a copy constraint between the two.
pub fn assign_with_duplication<F: PrimeField + TensorType + PartialOrd>(
pub fn assign_with_duplication<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
row: usize,
@@ -452,7 +448,8 @@ impl VarTensor {
values: &ValTensor<F>,
check_mode: &CheckMode,
single_inner_col: bool,
) -> Result<(ValTensor<F>, usize, usize), halo2_proofs::plonk::Error> {
constants: &mut ConstantsMap<F>,
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
let mut prev_cell = None;
match values {
@@ -494,7 +491,7 @@ impl VarTensor {
assert_eq!(Into::<i32>::into(k.clone()), Into::<i32>::into(v[coord - 1].clone()));
};
let cell = self.assign_value(region, offset, k.clone(), coord * step)?;
let cell = self.assign_value(region, offset, k.clone(), coord * step, constants)?;
if single_inner_col {
if z == 0 {
@@ -502,28 +499,23 @@ impl VarTensor {
prev_cell = Some(cell.clone());
} else if coord > 0 && z == 0 && single_inner_col {
if let Some(prev_cell) = prev_cell.as_ref() {
region.constrain_equal(prev_cell.cell(),cell.cell())?;
let cell = cell.cell().ok_or({
error!("Error getting cell: {:?}", (x,y));
halo2_proofs::plonk::Error::Synthesis})?;
let prev_cell = prev_cell.cell().ok_or({
error!("Error getting cell: {:?}", (x,y));
halo2_proofs::plonk::Error::Synthesis})?;
region.constrain_equal(prev_cell,cell)?;
} else {
error!("Error copy-constraining previous value: {:?}", (x,y));
return Err(halo2_proofs::plonk::Error::Synthesis);
}
}}
match k {
ValType::Constant(f) => {
Ok(ValType::AssignedConstant(cell, f))
},
ValType::AssignedConstant(_, f) => {
Ok(ValType::AssignedConstant(cell, f))
},
_ => {
Ok(ValType::PrevAssigned(cell))
}
}
Ok(cell)
})?.into()};
let total_used_len = res.len();
let total_constants = res.num_constants();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
res.reshape(dims).unwrap();
@@ -542,42 +534,61 @@ impl VarTensor {
)};
}
Ok((res, total_used_len, total_constants))
Ok((res, total_used_len))
}
}
}
fn assign_value<F: PrimeField + TensorType + PartialOrd>(
fn assign_value<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
offset: usize,
k: ValType<F>,
coord: usize,
) -> Result<AssignedCell<F, F>, halo2_proofs::plonk::Error> {
constants: &mut ConstantsMap<F>,
) -> Result<ValType<F>, halo2_proofs::plonk::Error> {
let (x, y, z) = self.cartesian_coord(offset + coord);
match k {
let res = match k {
ValType::Value(v) => match &self {
VarTensor::Advice { inner: advices, .. } => {
region.assign_advice(|| "k", advices[x][y], z, || v)
ValType::PrevAssigned(region.assign_advice(|| "k", advices[x][y], z, || v)?)
}
_ => unimplemented!(),
},
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => match &self {
ValType::PrevAssigned(v) => match &self {
VarTensor::Advice { inner: advices, .. } => {
v.copy_advice(|| "k", region, advices[x][y], z)
ValType::PrevAssigned(v.copy_advice(|| "k", region, advices[x][y], z)?)
}
_ => {
error!("PrevAssigned is only supported for advice columns");
Err(halo2_proofs::plonk::Error::Synthesis)
_ => unimplemented!(),
},
ValType::AssignedConstant(v, val) => match &self {
VarTensor::Advice { inner: advices, .. } => {
ValType::AssignedConstant(v.copy_advice(|| "k", region, advices[x][y], z)?, val)
}
_ => unimplemented!(),
},
ValType::AssignedValue(v) => match &self {
VarTensor::Advice { inner: advices, .. } => region
.assign_advice(|| "k", advices[x][y], z, || v)
.map(|a| a.evaluate()),
VarTensor::Advice { inner: advices, .. } => ValType::PrevAssigned(
region
.assign_advice(|| "k", advices[x][y], z, || v)?
.evaluate(),
),
_ => unimplemented!(),
},
ValType::Constant(v) => self.assign_constant(region, offset + coord, v),
}
ValType::Constant(v) => {
if let std::collections::hash_map::Entry::Vacant(e) = constants.entry(v) {
let value = ValType::AssignedConstant(
self.assign_constant(region, offset, coord, v)?,
v,
);
e.insert(value.clone());
value
} else {
let cell = constants.get(&v).unwrap();
self.assign_value(region, offset, cell.clone(), coord, constants)?
}
}
};
Ok(res)
}
}
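The Constant arm above is effectively a per-region cache: the first occurrence of a field value is assigned via assign_constant and recorded in the ConstantsMap; every later occurrence re-enters assign_value with the cached AssignedConstant, whose arm lays down a copy_advice equality constraint instead of a fresh assignment. A standalone model of that caching pattern (plain HashMap, not the ezkl API):

use std::collections::HashMap;

fn assign_cached(constants: &mut HashMap<u64, usize>, v: u64, next_cell: &mut usize) -> usize {
    // first occurrence "assigns" a fresh cell id; repeats return the cached one
    *constants.entry(v).or_insert_with(|| {
        let cell = *next_cell; // stands in for assign_constant(...)
        *next_cell += 1;
        cell
    })
}

fn main() {
    let mut constants = HashMap::new();
    let mut next = 0;
    assert_eq!(assign_cached(&mut constants, 7, &mut next), 0);
    assert_eq!(assign_cached(&mut constants, 7, &mut next), 0); // deduplicated
    assert_eq!(assign_cached(&mut constants, 9, &mut next), 1);
}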

View File

@@ -8,12 +8,14 @@ use crate::graph::quantize_float;
use crate::graph::scale_to_multiplier;
use crate::graph::{GraphCircuit, GraphSettings};
use crate::pfsys::create_proof_circuit;
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
use crate::pfsys::verify_proof_circuit;
use crate::pfsys::TranscriptType;
use crate::tensor::TensorType;
use crate::CheckMode;
use crate::Commitments;
use console_error_panic_hook;
use halo2_proofs::plonk::*;
use halo2_proofs::poly::commitment::{CommitmentScheme, ParamsProver};
use halo2_proofs::poly::ipa::multiopen::{ProverIPA, VerifierIPA};
@@ -33,11 +35,10 @@ use halo2curves::bn256::{Bn256, Fr, G1Affine};
use halo2curves::ff::{FromUniformBytes, PrimeField};
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::str::FromStr;
use wasm_bindgen::prelude::*;
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
use console_error_panic_hook;
#[cfg(feature = "web")]
pub use wasm_bindgen_rayon::init_thread_pool;
@@ -336,8 +337,94 @@ pub fn verify(
let orig_n = 1 << circuit_settings.run_args.logrows;
let commitment = circuit_settings.run_args.commitment.into();
let mut reader = std::io::BufReader::new(&srs[..]);
let result = match circuit_settings.run_args.commitment {
let result = match commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
Commitments::IPA => {
let params: ParamsIPA<_> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let strategy = IPASingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
};
match result {
Ok(_) => Ok(true),
Err(e) => Err(JsError::new(&format!("{}", e))),
}
}
#[wasm_bindgen]
#[allow(non_snake_case)]
/// Verify an aggregate proof in the browser using wasm
pub fn verifyAggr(
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
vk: wasm_bindgen::Clamped<Vec<u8>>,
logrows: u64,
srs: wasm_bindgen::Clamped<Vec<u8>>,
commitment: &str,
) -> Result<bool, JsError> {
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof_js[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
let mut reader = std::io::BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
(),
)
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
let commit = Commitments::from_str(commitment).map_err(|e| JsError::new(&format!("{}", e)))?;
let orig_n = 1 << logrows;
let mut reader = std::io::BufReader::new(&srs[..]);
let result = match commit {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
@@ -436,8 +523,9 @@ pub fn prove(
// read in kzg params
let mut reader = std::io::BufReader::new(&srs[..]);
let commitment = circuit.settings().run_args.commitment.into();
// creates and verifies the proof
let proof = match circuit.settings().run_args.commitment {
let proof = match commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)

View File

@@ -122,7 +122,7 @@ mod native_tests {
let settings: GraphSettings = serde_json::from_str(&settings).unwrap();
let logrows = settings.run_args.logrows;
download_srs(logrows, settings.run_args.commitment);
download_srs(logrows, settings.run_args.commitment.into());
}
fn mv_test_(test_dir: &str, test: &str) {
@@ -200,7 +200,7 @@ mod native_tests {
"1l_tiny_div",
];
const TESTS: [&str; 91] = [
const TESTS: [&str; 92] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
@@ -296,6 +296,7 @@ mod native_tests {
"reducel1",
"reducel2", // 89
"1l_lppool",
"lstm_large", // 91
];
const WASM_TESTS: [&str; 46] = [
@@ -534,7 +535,7 @@ mod native_tests {
}
});
seq!(N in 0..=90 {
seq!(N in 0..=91 {
#(#[test_case(TESTS[N])])*
#[ignore]
@@ -622,7 +623,7 @@ mod native_tests {
#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
// currently variable output rank is not supported in ONNX
if test != "gather_nd" {
if test != "gather_nd" && test != "lstm_large" {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
@@ -1970,7 +1971,7 @@ mod native_tests {
.expect("failed to parse settings file");
// get_srs for the graph_settings_num_instances
download_srs(1, graph_settings.run_args.commitment);
download_srs(1, graph_settings.run_args.commitment.into());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([

View File

@@ -56,33 +56,38 @@ mod py_tests {
// source .env/bin/activate
// pip install -r requirements.txt
// maturin develop --release --features python-bindings
// first install tf2onnx, as it has a protobuf conflict with onnx
let status = Command::new("pip")
.args(["install", "tf2onnx==1.16.1"])
.status()
.expect("failed to execute process");
assert!(status.success());
// now install torch, pandas, numpy, seaborn, notebook, and the rest of the pinned deps
let status = Command::new("pip")
.args([
"install",
"torch-geometric==2.5.0",
"torch==2.0.1",
"torchvision==0.15.2",
"pandas==2.0.3",
"numpy==1.23",
"seaborn==0.12.2",
"jupyter==1.0.0",
"onnx==1.14.0",
"kaggle==1.5.15",
"py-solc-x==1.1.1",
"web3==6.5.0",
"librosa==0.10.0.post2",
"keras==2.12.0",
"tensorflow==2.12.0",
"tensorflow-datasets==4.9.3",
"tf2onnx==1.14.0",
"pytorch-lightning==2.0.6",
"torch-geometric==2.5.2",
"torch==2.2.2",
"torchvision==0.17.2",
"pandas==2.2.1",
"numpy==1.26.4",
"seaborn==0.13.2",
"notebook==7.1.2",
"nbconvert==7.16.3",
"onnx==1.16.0",
"kaggle==1.6.8",
"py-solc-x==2.0.2",
"web3==6.16.0",
"librosa==0.10.1",
"keras==3.1.1",
"tensorflow==2.16.1",
"tensorflow-datasets==4.9.4",
"pytorch-lightning==2.2.1",
"sk2torch==1.2.0",
"scikit-learn==1.3.1",
"xgboost==1.7.6",
"hummingbird-ml==0.4.9",
"lightgbm==4.0.0",
"scikit-learn==1.4.1.post1",
"xgboost==2.0.3",
"hummingbird-ml==0.4.11",
"lightgbm==4.3.0",
])
.status()
.expect("failed to execute process");

View File

@@ -11,7 +11,7 @@ mod wasm32 {
bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian,
feltToFloat, feltToInt, feltToLittleEndian, genPk, genVk, genWitness, inputValidation,
pkValidation, poseidonHash, proofValidation, prove, settingsValidation, srsValidation,
u8_array_to_u128_le, verify, vkValidation, witnessValidation,
u8_array_to_u128_le, verify, verifyAggr, vkValidation, witnessValidation,
};
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::{Fr, G1Affine};
@@ -27,10 +27,29 @@ mod wasm32 {
pub const NETWORK: &[u8] = include_bytes!("../tests/wasm/network.onnx");
pub const INPUT: &[u8] = include_bytes!("../tests/wasm/input.json");
pub const PROOF: &[u8] = include_bytes!("../tests/wasm/proof.json");
pub const PROOF_AGGR: &[u8] = include_bytes!("../tests/wasm/proof_aggr.json");
pub const SETTINGS: &[u8] = include_bytes!("../tests/wasm/settings.json");
pub const PK: &[u8] = include_bytes!("../tests/wasm/pk.key");
pub const VK: &[u8] = include_bytes!("../tests/wasm/vk.key");
pub const VK_AGGR: &[u8] = include_bytes!("../tests/wasm/vk_aggr.key");
pub const SRS: &[u8] = include_bytes!("../tests/wasm/kzg");
pub const SRS1: &[u8] = include_bytes!("../tests/wasm/kzg1.srs");
#[wasm_bindgen_test]
async fn can_verify_aggr() {
let value = verifyAggr(
wasm_bindgen::Clamped(PROOF_AGGR.to_vec()),
wasm_bindgen::Clamped(VK_AGGR.to_vec()),
21,
wasm_bindgen::Clamped(SRS1.to_vec()),
"kzg",
)
.map_err(|_| "failed")
.unwrap();
// should not fail
assert!(value);
}
#[wasm_bindgen_test]
async fn verify_encode_verifier_calldata() {

BIN
tests/wasm/kzg1.srs Normal file

Binary file not shown.

Binary file not shown.

3075
tests/wasm/proof_aggr.json Normal file

File diff suppressed because one or more lines are too long

View File

@@ -24,8 +24,7 @@
"param_visibility": "Private",
"div_rebasing": false,
"rebase_frac_zero_constants": false,
"check_mode": "UNSAFE",
"commitment": "KZG"
"check_mode": "UNSAFE"
},
"num_rows": 16,
"total_dynamic_col_size": 0,

BIN
tests/wasm/vk_aggr.key Normal file

Binary file not shown.