mirror of https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00

Compare commits: v18.0.0 ... ac/fix-gpu (21 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 77a0a5f1eb | |
| | b2e5150c52 | |
| | 99f741304a | |
| | 20ac99fdbf | |
| | 532fa65e93 | |
| | cfe5db545c | |
| | 21ad56aea1 | |
| | 4ed7e0fd29 | |
| | 05d1f10615 | |
| | 9a8c754e45 | |
| | d82766d413 | |
| | 820a80122b | |
| | 9c64e42bd3 | |
| | 27b5e5dde3 | |
| | 83c4afce3b | |
| | 50740a22df | |
| | a2624f6303 | |
| | fc5be4f949 | |
| | d0ba505baa | |
| | f35688917d | |
| | 7ae541ed35 | |
.github/workflows/benchmarks.yml (vendored): 22 changes

```diff
@@ -8,6 +8,8 @@ on:
 jobs:
 
   bench_poseidon:
+    permissions:
+      contents: read
     runs-on: self-hosted
     steps:
       - uses: actions/checkout@v4
@@ -22,6 +24,8 @@ jobs:
         run: cargo bench --verbose --bench poseidon
 
   bench_einsum_accum_matmul:
+    permissions:
+      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
@@ -37,6 +41,8 @@ jobs:
         run: cargo bench --verbose --bench accum_einsum_matmul
 
   bench_accum_matmul_relu:
+    permissions:
+      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
@@ -52,6 +58,8 @@ jobs:
         run: cargo bench --verbose --bench accum_matmul_relu
 
   bench_accum_matmul_relu_overflow:
+    permissions:
+      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
@@ -67,6 +75,8 @@ jobs:
         run: cargo bench --verbose --bench accum_matmul_relu_overflow
 
   bench_relu:
+    permissions:
+      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
@@ -82,6 +92,8 @@ jobs:
         run: cargo bench --verbose --bench relu
 
   bench_accum_dot:
+    permissions:
+      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
@@ -97,6 +109,8 @@ jobs:
         run: cargo bench --verbose --bench accum_dot
 
   bench_accum_conv:
+    permissions:
+      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
@@ -112,6 +126,8 @@ jobs:
         run: cargo bench --verbose --bench accum_conv
 
   bench_accum_sumpool:
+    permissions:
+      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
@@ -127,6 +143,8 @@ jobs:
         run: cargo bench --verbose --bench accum_sumpool
 
   bench_pairwise_add:
+    permissions:
+      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
@@ -142,6 +160,8 @@ jobs:
         run: cargo bench --verbose --bench pairwise_add
 
   bench_accum_sum:
+    permissions:
+      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
@@ -157,6 +177,8 @@ jobs:
         run: cargo bench --verbose --bench accum_sum
 
   bench_pairwise_pow:
+    permissions:
+      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
```
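Every change in this file is the same two-line addition, repeated once per benchmark job. A minimal sketch of the pattern (the job name `bench_example` is hypothetical; everything else mirrors the hunks above): each job explicitly grants its default GITHUB_TOKEN read-only access to the repository, so a compromised benchmark step cannot push or publish.

```yaml
jobs:
  bench_example:          # hypothetical job name; the real jobs are bench_poseidon, bench_relu, etc.
    permissions:
      contents: read      # GITHUB_TOKEN limited to reading repo contents
    runs-on: self-hosted
    steps:
      - uses: actions/checkout@v4
      - name: Run one benchmark
        run: cargo bench --verbose --bench poseidon
```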
.github/workflows/engine.yml (vendored): 78 changes

```diff
@@ -15,7 +15,12 @@ defaults:
     working-directory: .
 jobs:
   publish-wasm-bindings:
+    permissions:
+      contents: read
+      packages: write
     name: publish-wasm-bindings
+    env:
+      RELEASE_TAG: ${{ github.ref_name }}
     runs-on: ubuntu-latest
     if: startsWith(github.ref, 'refs/tags/')
     steps:
@@ -44,41 +49,41 @@ jobs:
           wasm-opt --version
       - name: Build wasm files for both web and nodejs compilation targets
         run: |
           wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
           wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
       - name: Create package.json in pkg folder
         shell: bash
         env:
           RELEASE_TAG: ${{ github.ref_name }}
         run: |
-          echo '{
-            "name": "@ezkljs/engine",
-            "version": "${RELEASE_TAG}",
-            "dependencies": {
-              "@types/json-bigint": "^1.0.1",
-              "json-bigint": "^1.0.0"
-            },
-            "files": [
-              "nodejs/ezkl_bg.wasm",
-              "nodejs/ezkl.js",
-              "nodejs/ezkl.d.ts",
-              "nodejs/package.json",
-              "nodejs/utils.js",
-              "web/ezkl_bg.wasm",
-              "web/ezkl.js",
-              "web/ezkl.d.ts",
-              "web/snippets/**/*",
-              "web/package.json",
-              "web/utils.js",
-              "ezkl.d.ts"
-            ],
-            "main": "nodejs/ezkl.js",
-            "module": "web/ezkl.js",
-            "types": "nodejs/ezkl.d.ts",
-            "sideEffects": [
-              "web/snippets/*"
-            ]
-          }' > pkg/package.json
+          cat > pkg/package.json << EOF
+          {
+            "name": "@ezkljs/engine",
+            "version": "$RELEASE_TAG",
+            "dependencies": {
+              "@types/json-bigint": "^1.0.1",
+              "json-bigint": "^1.0.0"
+            },
+            "files": [
+              "nodejs/ezkl_bg.wasm",
+              "nodejs/ezkl.js",
+              "nodejs/ezkl.d.ts",
+              "nodejs/package.json",
+              "nodejs/utils.js",
+              "web/ezkl_bg.wasm",
+              "web/ezkl.js",
+              "web/ezkl.d.ts",
+              "web/snippets/**/*",
+              "web/package.json",
+              "web/utils.js",
+              "ezkl.d.ts"
+            ],
+            "main": "nodejs/ezkl.js",
+            "module": "web/ezkl.js",
+            "types": "nodejs/ezkl.d.ts",
+            "sideEffects": [
+              "web/snippets/*"
+            ]
+          }
+          EOF
 
       - name: Replace memory definition in nodejs
         run: |
@@ -186,9 +191,14 @@ jobs:
 
 
   in-browser-evm-ver-publish:
+    permissions:
+      contents: read
+      packages: write
     name: publish-in-browser-evm-verifier-package
     needs: [publish-wasm-bindings]
     runs-on: ubuntu-latest
+    env:
+      RELEASE_TAG: ${{ github.ref_name }}
     if: startsWith(github.ref, 'refs/tags/')
     steps:
       - uses: actions/checkout@v4
@@ -196,10 +206,8 @@ jobs:
           persist-credentials: false
       - name: Update version in package.json
         shell: bash
-        env:
-          RELEASE_TAG: ${{ github.ref_name }}
         run: |
-          sed -i "s|\"version\": \".*\"|\"version\": \"${RELEASE_TAG}\"|" in-browser-evm-verifier/package.json
+          sed -i "s|\"version\": \".*\"|\"version\": \"$RELEASE_TAG\"|" in-browser-evm-verifier/package.json
       - name: Prepare tag and fetch package integrity
         run: |
           CLEANED_TAG=${RELEASE_TAG} # Get the tag from ref_name
```
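The package.json rewrite above fixes a shell-quoting bug: inside `echo '...'` the single quotes suppress parameter expansion, so the published package carried the literal string `${RELEASE_TAG}` as its version. An unquoted heredoc delimiter keeps expansion enabled. A minimal sketch of the difference (a hypothetical standalone step, not part of the workflow):

```yaml
- name: Demonstrate the quoting difference (hypothetical step)
  env:
    RELEASE_TAG: v18.0.1
  run: |
    # single quotes: no expansion, broken.json contains the literal text ${RELEASE_TAG}
    echo '{ "version": "${RELEASE_TAG}" }' > broken.json
    # unquoted heredoc delimiter: expansion stays on, fixed.json contains v18.0.1
    cat > fixed.json << EOF
    { "version": "$RELEASE_TAG" }
    EOF
```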
.github/workflows/large-tests.yml (vendored): 2 changes

```diff
@@ -6,6 +6,8 @@ on:
         description: "Test scenario tags"
 jobs:
   large-tests:
+    permissions:
+      contents: read
     runs-on: kaiju
     steps:
       - uses: actions/checkout@v4
```
.github/workflows/pypi-gpu.yml (vendored): 12 changes

```diff
@@ -18,10 +18,15 @@ defaults:
 jobs:
 
   linux:
+    permissions:
+      contents: read
+      packages: write
     runs-on: GPU
     strategy:
       matrix:
         target: [x86_64]
+    env:
+      RELEASE_TAG: ${{ github.ref_name }}
     steps:
       - uses: actions/checkout@v4
         with:
@@ -38,6 +43,9 @@ jobs:
           sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig >pyproject.toml
           sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
 
+      - name: rename ezkl.pyi to ezkl-gpu.pyi
+        run: mv ezkl.pyi ezkl-gpu.pyi
+
       - uses: actions-rs/toolchain@v1
         with:
           toolchain: nightly-2023-06-27
@@ -46,8 +54,6 @@ jobs:
 
       - name: Set Cargo.toml version to match github tag
         shell: bash
-        env:
-          RELEASE_TAG: ${{ github.ref_name }}
         run: |
           mv Cargo.toml Cargo.toml.orig
           sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
@@ -73,7 +79,7 @@ jobs:
           pip install ezkl-gpu --no-index --find-links dist --force-reinstall
 
       - name: Upload wheels
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: wheels
           path: dist
```
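The version-stamping steps rely on bash pattern substitution: `${RELEASE_TAG//v}` deletes every `v` in the tag, turning a tag like `v18.0.1` into `18.0.1` before sed writes it over the `0.0.0` placeholder. A sketch of the idiom in isolation (the step name and file are illustrative):

```yaml
- name: Stamp the release version (hypothetical step)
  env:
    RELEASE_TAG: ${{ github.ref_name }}   # e.g. v18.0.1
  run: |
    mv pyproject.toml pyproject.toml.orig
    # ${RELEASE_TAG//v} is bash pattern substitution: remove all occurrences of "v";
    # \\ in the YAML block reaches bash as \\, which double quotes collapse to \,
    # so sed matches the literal string 0.0.0
    sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
```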
.github/workflows/pypi.yml (vendored): 152 changes

```diff
@@ -16,11 +16,15 @@ defaults:
 
 jobs:
   macos:
+    permissions:
+      contents: read
     runs-on: macos-latest
     if: startsWith(github.ref, 'refs/tags/')
     strategy:
       matrix:
         target: [x86_64, universal2-apple-darwin]
+    env:
+      RELEASE_TAG: ${{ github.ref_name }}
     steps:
       - uses: actions/checkout@v4
         with:
@@ -30,10 +34,14 @@ jobs:
           python-version: 3.12
           architecture: x64
 
+      - name: Set pyproject.toml version to match github tag
+        shell: bash
+        run: |
+          mv pyproject.toml pyproject.toml.orig
+          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
+
       - name: Set Cargo.toml version to match github tag
         shell: bash
         env:
           RELEASE_TAG: ${{ github.ref_name }}
         run: |
           mv Cargo.toml Cargo.toml.orig
           sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
@@ -47,6 +55,13 @@ jobs:
           components: rustfmt, clippy
 
+      - name: Build wheels
+        if: matrix.target == 'universal2-apple-darwin'
+        uses: PyO3/maturin-action@v1
+        with:
+          target: ${{ matrix.target }}
+          args: --release --out dist --features python-bindings
       - name: Build wheels
+        if: matrix.target == 'x86_64'
         uses: PyO3/maturin-action@v1
         with:
           target: ${{ matrix.target }}
@@ -58,12 +73,14 @@ jobs:
           python -c "import ezkl"
 
       - name: Upload wheels
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
-          name: wheels
+          name: dist-macos-${{ matrix.target }}
           path: dist
 
   windows:
+    permissions:
+      contents: read
     runs-on: windows-latest
     if: startsWith(github.ref, 'refs/tags/')
     strategy:
@@ -78,6 +95,14 @@ jobs:
           python-version: 3.12
           architecture: ${{ matrix.target }}
 
+      - name: Set pyproject.toml version to match github tag
+        shell: bash
+        env:
+          RELEASE_TAG: ${{ github.ref_name }}
+        run: |
+          mv pyproject.toml pyproject.toml.orig
+          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
+
       - name: Set Cargo.toml version to match github tag
         shell: bash
         env:
@@ -105,12 +130,14 @@ jobs:
           python -c "import ezkl"
 
       - name: Upload wheels
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
-          name: wheels
+          name: dist-windows-${{ matrix.target }}
           path: dist
 
   linux:
+    permissions:
+      contents: read
     runs-on: ubuntu-latest
     if: startsWith(github.ref, 'refs/tags/')
     strategy:
@@ -125,6 +152,14 @@ jobs:
           python-version: 3.12
           architecture: x64
 
+      - name: Set pyproject.toml version to match github tag
+        shell: bash
+        env:
+          RELEASE_TAG: ${{ github.ref_name }}
+        run: |
+          mv pyproject.toml pyproject.toml.orig
+          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
+
       - name: Set Cargo.toml version to match github tag
         shell: bash
         env:
@@ -135,7 +170,6 @@ jobs:
           mv Cargo.lock Cargo.lock.orig
           sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
 
-
       - name: Install required libraries
         shell: bash
         run: |
@@ -169,63 +203,14 @@ jobs:
           python -c "import ezkl"
 
       - name: Upload wheels
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
-          name: wheels
+          name: dist-linux-${{ matrix.target }}
           path: dist
 
-  # There's a problem with the maturin-action toolchain for arm arch leading to failed builds
-  # linux-cross:
-  #   runs-on: ubuntu-latest
-  #   strategy:
-  #     matrix:
-  #       target: [aarch64, armv7]
-  #   steps:
-  #     - uses: actions/checkout@v4
-  #     - uses: actions/setup-python@v4
-  #       with:
-  #         python-version: 3.12
-
-  #     - name: Install cross-compilation tools for aarch64
-  #       if: matrix.target == 'aarch64'
-  #       run: |
-  #         sudo apt-get update
-  #         sudo apt-get install -y gcc make gcc-aarch64-linux-gnu binutils-aarch64-linux-gnu libc6-dev-arm64-cross libusb-1.0-0-dev libatomic1-arm64-cross
-
-  #     - name: Install cross-compilation tools for armv7
-  #       if: matrix.target == 'armv7'
-  #       run: |
-  #         sudo apt-get update
-  #         sudo apt-get install -y gcc make gcc-arm-linux-gnueabihf binutils-arm-linux-gnueabihf libc6-dev-armhf-cross libusb-1.0-0-dev libatomic1-armhf-cross
-
-  #     - name: Build wheels
-  #       uses: PyO3/maturin-action@v1
-  #       with:
-  #         target: ${{ matrix.target }}
-  #         manylinux: auto
-  #         args: --release --out dist --features python-bindings
-
-  #     - uses: uraimo/run-on-arch-action@v2.5.0
-  #       name: Install built wheel
-  #       with:
-  #         arch: ${{ matrix.target }}
-  #         distro: ubuntu20.04
-  #         githubToken: ${{ github.token }}
-  #         install: |
-  #           apt-get update
-  #           apt-get install -y --no-install-recommends python3 python3-pip
-  #           pip3 install -U pip
-  #         run: |
-  #           pip3 install ezkl --no-index --find-links dist/ --force-reinstall
-  #           python3 -c "import ezkl"
-
-  #     - name: Upload wheels
-  #       uses: actions/upload-artifact@v3
-  #       with:
-  #         name: wheels
-  #         path: dist
-
   musllinux:
+    permissions:
+      contents: read
     runs-on: ubuntu-latest
     if: startsWith(github.ref, 'refs/tags/')
     strategy:
@@ -258,6 +243,7 @@ jobs:
           sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
           mv Cargo.lock Cargo.lock.orig
           sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
+
       - name: Install required libraries
         shell: bash
         run: |
@@ -285,12 +271,14 @@ jobs:
           python3 -c "import ezkl"
 
       - name: Upload wheels
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
-          name: wheels
+          name: dist-musllinux-${{ matrix.target }}
           path: dist
 
   musllinux-cross:
+    permissions:
+      contents: read
     runs-on: ubuntu-latest
     if: startsWith(github.ref, 'refs/tags/')
     strategy:
@@ -306,6 +294,14 @@ jobs:
         with:
           python-version: 3.12
 
+      - name: Set pyproject.toml version to match github tag
+        shell: bash
+        env:
+          RELEASE_TAG: ${{ github.ref_name }}
+        run: |
+          mv pyproject.toml pyproject.toml.orig
+          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
+
       - name: Set Cargo.toml version to match github tag
         shell: bash
         env:
@@ -338,9 +334,9 @@ jobs:
           python3 -c "import ezkl"
 
       - name: Upload wheels
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
-          name: wheels
+          name: dist-musllinux-${{ matrix.platform.target }}
           path: dist
 
   pypi-publish:
@@ -349,35 +345,33 @@ jobs:
     permissions:
       id-token: write
     if: "startsWith(github.ref, 'refs/tags/')"
     # TODO: Uncomment if linux-cross is working
     # needs: [ macos, windows, linux, linux-cross, musllinux, musllinux-cross ]
     needs: [macos, windows, linux, musllinux, musllinux-cross]
     steps:
       - uses: actions/download-artifact@v3
         with:
-          name: wheels
+          pattern: dist-*
+          merge-multiple: true
           path: dist
       - name: List Files
         run: ls -R
 
+      # Both publish steps will fail if there is no trusted publisher setup
+      # On failure the publish step will then simply continue to the next one
 
-      # # publishes to TestPyPI
-      # - name: Publish package distribution to TestPyPI
-      #   uses: pypa/gh-action-pypi-publish@unstable/v1
-      #   with:
-      #     repository-url: https://test.pypi.org/legacy/
-      #     packages-dir: ./
 
       # publishes to PyPI
       - name: Publish package distributions to PyPI
+        continue-on-error: true
         uses: pypa/gh-action-pypi-publish@unstable/v1
         with:
           packages-dir: ./
 
+      # publishes to TestPyPI
+      - name: Publish package distribution to TestPyPI
+        continue-on-error: true
+        uses: pypa/gh-action-pypi-publish@unstable/v1
+        with:
+          repository-url: https://test.pypi.org/legacy/
+          packages-dir: ./
+
   doc-publish:
+    permissions:
+      contents: read
     name: Trigger ReadTheDocs Build
     runs-on: ubuntu-latest
     needs: pypi-publish
@@ -390,4 +384,4 @@ jobs:
         with:
           webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
           webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
-          commit_ref: ${{ github.ref_name }}
+          commit_ref: ${{ github.ref_name }}
```
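The artifact changes across these jobs follow the upload-artifact v4 migration: v4 no longer lets several jobs upload into one shared `wheels` artifact, so each matrix leg uploads under a unique name and the publish job gathers them by glob. A sketch of the two halves, assuming the download side also moves to download-artifact@v4 (the compare view above still shows `@v3` on that line, but `pattern` and `merge-multiple` are v4 inputs):

```yaml
# in each build job: one artifact per matrix leg
- name: Upload wheels
  uses: actions/upload-artifact@v4
  with:
    name: dist-linux-${{ matrix.target }}   # unique per leg
    path: dist

# in the publish job: collect every dist-* artifact into one directory
- uses: actions/download-artifact@v4
  with:
    pattern: dist-*
    merge-multiple: true    # flatten all matching artifacts into a single path
    path: dist
```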
.github/workflows/release.yml (vendored): 18 changes

```diff
@@ -10,6 +10,9 @@ on:
       - "*"
 jobs:
   create-release:
+    permissions:
+      contents: read
+      packages: write
     name: create-release
     runs-on: ubuntu-22.04
     if: startsWith(github.ref, 'refs/tags/')
@@ -33,6 +36,9 @@ jobs:
           tag_name: ${{ env.EZKL_VERSION }}
 
   build-release-gpu:
+    permissions:
+      contents: read
+      packages: write
     name: build-release-gpu
     needs: ["create-release"]
     runs-on: GPU
@@ -94,6 +100,10 @@ jobs:
           asset_content_type: application/octet-stream
 
   build-release:
+    permissions:
+      contents: read
+      packages: write
+      issues: write
     name: build-release
     needs: ["create-release"]
     runs-on: ${{ matrix.os }}
@@ -186,14 +196,18 @@ jobs:
           echo "target flag is: ${{ env.TARGET_FLAGS }}"
           echo "target dir is: ${{ env.TARGET_DIR }}"
 
-      - name: Build release binary (no asm)
-        if: matrix.build != 'linux-gnu'
+      - name: Build release binary (no asm or metal)
+        if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
         run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
 
       - name: Build release binary (asm)
         if: matrix.build == 'linux-gnu'
         run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
 
+      - name: Build release binary (metal)
+        if: matrix.build == 'macos-aarch64'
+        run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal
+
       - name: Strip release binary
         if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'
         run: strip "target/${{ matrix.target }}/release/ezkl"
```
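The release build now picks a binary flavor per platform with mutually exclusive `if:` guards over the matrix value, one step per Cargo feature set. A simplified sketch of the pattern (dropping the `env.CARGO` / `TARGET_FLAGS` indirection the workflow uses):

```yaml
- name: Build release binary (no asm or metal)
  if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
  run: cargo build --release

- name: Build release binary (asm)
  if: matrix.build == 'linux-gnu'
  run: cargo build --release --features asm

- name: Build release binary (metal)
  if: matrix.build == 'macos-aarch64'
  run: cargo build --release --features macos-metal
```

Exactly one of the three guards matches any given `matrix.build`, so each matrix leg builds the binary once, with the feature set its platform supports.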
.github/workflows/rust.yml (vendored): 436 changes

```diff
@@ -19,8 +19,10 @@ env:
   CARGO_TERM_COLOR: always
 
 jobs:
 
   fr-age-test:
     needs: [build, library-tests, docs, python-tests, python-integration-tests]
+    permissions:
+      contents: read
     runs-on: large-self-hosted
     steps:
       - uses: actions/checkout@v4
@@ -31,12 +33,17 @@ jobs:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
+      - uses: baptiste0928/cargo-install@v1
+        with:
+          crate: cargo-nextest
+          locked: true
       - name: fr age Mock
-        run: cargo test --release --verbose tests::large_mock_::large_tests_6_expects -- --include-ignored
+        run: cargo nextest run --release --verbose tests::large_mock_::large_tests_6_expects -- --include-ignored
 
   build:
+    permissions:
+      contents: read
     runs-on: ubuntu-latest
 
     steps:
       - uses: actions/checkout@v4
         with:
@@ -50,6 +57,8 @@ jobs:
         run: cargo build --verbose
 
   docs:
+    permissions:
+      contents: read
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
@@ -64,6 +73,8 @@ jobs:
         run: cargo doc --verbose
 
   library-tests:
+    permissions:
+      contents: read
     runs-on: ubuntu-latest-32-cores
     steps:
       - uses: actions/checkout@v4
@@ -92,8 +103,8 @@ jobs:
   #     ENABLE_ICICLE_GPU: true
   #   steps:
   #     - uses: actions/checkout@v4
-  #     with:
-  #       persist-credentials: false
+  #       with:
+  #         persist-credentials: false
   #     - uses: actions-rs/toolchain@v1
   #       with:
   #         toolchain: nightly-2024-07-18
@@ -124,6 +135,8 @@ jobs:
   #       run: cargo nextest run conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored
 
   ultra-overflow-tests_og-lookup:
+    permissions:
+      contents: read
     runs-on: non-gpu
     steps:
       - uses: actions/checkout@v4
@@ -150,15 +163,17 @@ jobs:
       # - name: Conv overflow (wasi)
       #   run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
       - name: lookup overflow
         run: cargo nextest run --release lookup_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
       - name: Matmul overflow
-        run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
+        run: RUST_LOG=debug cargo nextest run --release matmul_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
       - name: Conv overflow
-        run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
+        run: RUST_LOG=debug cargo nextest run --release conv_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
       - name: Conv + relu overflow
         run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
 
   ultra-overflow-tests:
+    permissions:
+      contents: read
     runs-on: non-gpu
     steps:
       - uses: actions/checkout@v4
@@ -185,15 +200,17 @@ jobs:
       # - name: Conv overflow (wasi)
       #   run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
       - name: lookup overflow
-        run: cargo nextest run --release lookup_ultra_overflow --no-capture -- --include-ignored
+        run: cargo nextest run lookup_ultra_overflow --no-capture -- --include-ignored
       - name: Matmul overflow
         run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture -- --include-ignored
       - name: Conv overflow
-        run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture -- --include-ignored
+        run: RUST_LOG=debug cargo nextest run --release conv_col_ultra_overflow --no-capture -- --include-ignored
       - name: Conv + relu overflow
         run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture -- --include-ignored
 
   model-serialization:
+    permissions:
+      contents: read
     runs-on: ubuntu-latest-16-cores
     steps:
       - uses: actions/checkout@v4
@@ -212,7 +229,9 @@ jobs:
         run: cargo nextest run native_tests::tests::model_serialization_different_binaries_ --test-threads 1
 
   wasm32-tests:
-    runs-on: ubuntu-latest
+    permissions:
+      contents: read
+    runs-on: non-gpu
     steps:
       - uses: actions/checkout@v4
         with:
```
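These jobs standardize on cargo-nextest as the test runner: the binary is installed once per job through baptiste0928/cargo-install (which caches builds), and each step runs a filtered slice of the suite with its own thread budget. A minimal sketch of the pattern (the filter shown is one real example from the hunks; the step name is illustrative):

```yaml
- uses: baptiste0928/cargo-install@v1
  with:
    crate: cargo-nextest
    locked: true          # respect the crate's own Cargo.lock
- name: Run one filtered test group
  run: cargo nextest run --verbose tests::mock_public_outputs_ --test-threads 32
```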
```diff
@@ -225,7 +244,7 @@ jobs:
       - uses: jetli/wasm-pack-action@v0.4.0
         with:
           # Pin to version 0.12.1
-          version: 'v0.12.1'
+          version: "v0.12.1"
       - uses: nanasess/setup-chromedriver@v2
         # with:
         #   chromedriver-version: "115.0.5790.102"
@@ -239,8 +258,9 @@ jobs:
         run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
 
   mock-proving-tests:
+    permissions:
+      contents: read
     runs-on: non-gpu
     needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
       - uses: actions/checkout@v4
         with:
@@ -255,55 +275,57 @@ jobs:
           crate: cargo-nextest
           locked: true
       # - name: The Worm Mock
-      #   run: cargo nextest run --release --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
-      - name: public outputs and bounded lookup log
-        run: cargo nextest run --release --verbose tests::mock_bounded_lookup_log --test-threads 32
-      - name: public outputs and tolerance > 0
-        run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
-      - name: public outputs + batch size == 10
-        run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
-      - name: kzg inputs
-        run: cargo nextest run --release --verbose tests::mock_kzg_input_::t --test-threads 32
-      - name: kzg params
-        run: cargo nextest run --release --verbose tests::mock_kzg_params_::t --test-threads 32
-      - name: kzg outputs
-        run: cargo nextest run --release --verbose tests::mock_kzg_output_::t --test-threads 32
-      - name: kzg inputs + params + outputs
-        run: cargo nextest run --release --verbose tests::mock_kzg_all_::t --test-threads 32
-      - name: Mock fixed inputs
-        run: cargo nextest run --release --verbose tests::mock_fixed_inputs_ --test-threads 32
-      - name: Mock fixed outputs
-        run: cargo nextest run --release --verbose tests::mock_fixed_outputs --test-threads 32
-      - name: Mock accuracy calibration
-        run: cargo nextest run --release --verbose tests::mock_accuracy_cal_tests::a
-      - name: hashed inputs
-        run: cargo nextest run --release --verbose tests::mock_hashed_input_::t --test-threads 32
-      - name: hashed params
-        run: cargo nextest run --release --verbose tests::mock_hashed_params_::t --test-threads 32
-      - name: hashed params public inputs
-        run: cargo nextest run --release --verbose tests::mock_hashed_params_public_inputs_::t --test-threads 32
-      - name: hashed outputs
-        run: cargo nextest run --release --verbose tests::mock_hashed_output_::t --test-threads 32
-      - name: hashed inputs + params + outputs
-        run: cargo nextest run --release --verbose tests::mock_hashed_all_::t --test-threads 32
-      - name: hashed inputs + fixed params
-        run: cargo nextest run --release --verbose tests::mock_hashed_output_fixed_params_::t --test-threads 32
+      #   run: cargo nextest run --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
       - name: MNIST Gan Mock
-        run: cargo nextest run --release --verbose tests::large_mock_::large_tests_4_expects -- --include-ignored
+        run: cargo nextest run --verbose tests::large_mock_::large_tests_4_expects -- --include-ignored
       - name: NanoGPT Mock
-        run: cargo nextest run --release --verbose tests::large_mock_::large_tests_1_expects -- --include-ignored
+        run: cargo nextest run --verbose tests::large_mock_::large_tests_1_expects -- --include-ignored
       - name: Self Attention Mock
-        run: cargo nextest run --release --verbose tests::large_mock_::large_tests_0_expects -- --include-ignored
+        run: cargo nextest run --verbose tests::large_mock_::large_tests_0_expects -- --include-ignored
       - name: Multihead Attention Mock
-        run: cargo nextest run --release --verbose tests::large_mock_::large_tests_2_expects -- --include-ignored
+        run: cargo nextest run --verbose tests::large_mock_::large_tests_2_expects -- --include-ignored
       - name: public outputs
-        run: cargo nextest run --release --verbose tests::mock_public_outputs_ --test-threads 32
+        run: cargo nextest run --verbose tests::mock_public_outputs_ --test-threads 32
       - name: public inputs
-        run: cargo nextest run --release --verbose tests::mock_public_inputs_ --test-threads 32
+        run: cargo nextest run --verbose tests::mock_public_inputs_ --test-threads 32
       - name: fixed params
-        run: cargo nextest run --release --verbose tests::mock_fixed_params_ --test-threads 32
+        run: cargo nextest run --verbose tests::mock_fixed_params_ --test-threads 32
+      - name: public outputs and bounded lookup log
+        run: cargo nextest run --verbose tests::mock_bounded_lookup_log --test-threads 32
+      - name: public outputs and tolerance > 0
+        run: cargo nextest run --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
+      - name: public outputs + batch size == 10
+        run: cargo nextest run --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
+      - name: kzg inputs
+        run: cargo nextest run --verbose tests::mock_kzg_input_::t --test-threads 32
+      - name: kzg params
+        run: cargo nextest run --verbose tests::mock_kzg_params_::t --test-threads 32
+      - name: kzg outputs
+        run: cargo nextest run --verbose tests::mock_kzg_output_::t --test-threads 32
+      - name: kzg inputs + params + outputs
+        run: cargo nextest run --verbose tests::mock_kzg_all_::t --test-threads 32
+      - name: Mock fixed inputs
+        run: cargo nextest run --verbose tests::mock_fixed_inputs_ --test-threads 32
+      - name: Mock fixed outputs
+        run: cargo nextest run --verbose tests::mock_fixed_outputs --test-threads 32
+      - name: Mock accuracy calibration
+        run: cargo nextest run --verbose tests::mock_accuracy_cal_tests::a
+      - name: hashed inputs
+        run: cargo nextest run --verbose tests::mock_hashed_input_::t --test-threads 32
+      - name: hashed params
+        run: cargo nextest run --verbose tests::mock_hashed_params_::t --test-threads 32
+      - name: hashed params public inputs
+        run: cargo nextest run --verbose tests::mock_hashed_params_public_inputs_::t --test-threads 32
+      - name: hashed outputs
+        run: cargo nextest run --verbose tests::mock_hashed_output_::t --test-threads 32
+      - name: hashed inputs + params + outputs
+        run: cargo nextest run --verbose tests::mock_hashed_all_::t --test-threads 32
+      - name: hashed inputs + fixed params
+        run: cargo nextest run --verbose tests::mock_hashed_output_fixed_params_::t --test-threads 32
 
   prove-and-verify-evm-tests:
+    permissions:
+      contents: read
     runs-on: non-gpu
     needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
@@ -342,7 +364,7 @@ jobs:
           NODE_ENV: development
       - name: Build wasm package for nodejs target.
         run: |
-          wasm-pack build --release --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
+          wasm-pack build --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
       - name: Replace memory definition in nodejs
         run: |
           sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" in-browser-evm-verifier/nodejs/ezkl.js
@@ -356,37 +378,73 @@ jobs:
       - name: Install Anvil
         run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
       - name: KZG prove and verify tests (EVM + reusable verifier + col-overflow)
-        run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --test-threads 1
+        run: cargo nextest run --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --test-threads 1
       - name: KZG prove and verify tests (EVM + kzg all)
-        run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
+        run: cargo nextest run --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM + kzg inputs)
-        run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_input_prove_and_verify --test-threads 1
+        run: cargo nextest run --verbose tests_evm::kzg_evm_kzg_input_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM + kzg params)
-        run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_params_prove_and_verify --test-threads 1
+        run: cargo nextest run --verbose tests_evm::kzg_evm_kzg_params_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM + on chain inputs)
-        run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_prove_and_verify --test-threads 1
+        run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_input_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM + on chain outputs)
-        run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_prove_and_verify --test-threads 1
+        run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_output_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM + on chain inputs & outputs)
-        run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_prove_and_verify --test-threads 1
+        run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_input_output_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM + on chain inputs & kzg outputs + params)
-        run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_kzg_output_kzg_params_prove_and_verify --test-threads 1
+        run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_input_kzg_output_kzg_params_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM + on chain outputs & kzg inputs + params)
-        run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_kzg_input_kzg_params_prove_and_verify --test-threads 1
+        run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_output_kzg_input_kzg_params_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM + on chain all kzg)
-        run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_all_kzg_params_prove_and_verify --test-threads 1
+        run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_all_kzg_params_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM + on chain inputs & outputs hashes)
-        run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_hashed_prove_and_verify --test-threads 1
+        run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_input_output_hashed_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM)
-        run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify --test-threads 1
+        run: cargo nextest run --verbose tests_evm::kzg_evm_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM + hashed inputs)
-        run: cargo nextest run --release --verbose tests_evm::kzg_evm_hashed_input_prove_and_verify --test-threads 1
+        run: cargo nextest run --verbose tests_evm::kzg_evm_hashed_input_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM + hashed params)
-        run: cargo nextest run --release --verbose tests_evm::kzg_evm_hashed_params_prove_and_verify --test-threads 1
+        run: cargo nextest run --verbose tests_evm::kzg_evm_hashed_params_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM + hashed outputs)
-        run: cargo nextest run --release --verbose tests_evm::kzg_evm_hashed_output_prove_and_verify --test-threads 1
+        run: cargo nextest run --verbose tests_evm::kzg_evm_hashed_output_prove_and_verify --test-threads 1
 
+  # prove-and-verify-tests-metal:
+  #   permissions:
+  #     contents: read
+  #   runs-on: macos-13
+  #   # needs: [build, library-tests, docs]
+  #   steps:
+  #     - uses: actions/checkout@v4
+  #       with:
+  #         persist-credentials: false
+  #     - uses: actions-rs/toolchain@v1
+  #       with:
+  #         toolchain: nightly-2024-07-18
+  #         override: true
+  #         components: rustfmt, clippy
+  #     - uses: jetli/wasm-pack-action@v0.4.0
+  #       with:
+  #         # Pin to version 0.12.1
+  #         version: 'v0.12.1'
+  #     - name: Add rust-src
+  #       run: rustup component add rust-src --toolchain nightly-2024-07-18
+  #     - uses: actions/checkout@v3
+  #       with:
+  #         persist-credentials: false
+  #     - name: Use pnpm 8
+  #       uses: pnpm/action-setup@v2
+  #       with:
+  #         version: 8
+  #     - uses: baptiste0928/cargo-install@v1
+  #       with:
+  #         crate: cargo-nextest
+  #         locked: true
+  #     - name: KZG prove and verify tests (public outputs)
+  #       run: cargo nextest run --features macos-metal --verbose tests::kzg_prove_and_verify_::t --no-capture
+
   prove-and-verify-tests:
+    permissions:
+      contents: read
     runs-on: non-gpu
     needs: [build, library-tests, docs]
     steps:
```
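The wasm test jobs repeat a two-step build-and-patch sequence: compile the nodejs bindings with wasm-pack, then rewrite line 3 of the generated `ezkl.js` so the module is instantiated with a shared, growable `WebAssembly.Memory`. Both steps below are taken verbatim from the hunks; the stated reason for the patch (the generated default memory is not shared, which the threaded wasm build needs) is an inference, not something the diff spells out.

```yaml
- name: Build wasm package for nodejs target
  run: |
    wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
  run: |
    # overwrite line 3 of the generated JS with a shared memory import
    sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
```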
```diff
@@ -401,7 +459,7 @@ jobs:
       - uses: jetli/wasm-pack-action@v0.4.0
         with:
           # Pin to version 0.12.1
-          version: 'v0.12.1'
+          version: "v0.12.1"
       - name: Add wasm32-unknown-unknown target
         run: rustup target add wasm32-unknown-unknown
 
@@ -431,40 +489,40 @@ jobs:
           locked: true
       - name: Build wasm package for nodejs target.
         run: |
-          wasm-pack build --release --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
+          wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
       - name: Replace memory definition in nodejs
         run: |
           sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
       - name: KZG prove and verify tests (public outputs + column overflow)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w
+        run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w
       - name: KZG prove and verify tests (public outputs + fixed params + column overflow)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
+        run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
       - name: KZG prove and verify tests (hashed inputs + column overflow)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_hashed_inputs_
+        run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_hashed_inputs_
       - name: KZG prove and verify tests (public outputs)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_tight_lookup_::t
+        run: cargo nextest run --verbose tests::kzg_prove_and_verify_tight_lookup_::t
       - name: IPA prove and verify tests
-        run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_::t --test-threads 1
+        run: cargo nextest run --verbose tests::ipa_prove_and_verify_::t --test-threads 1
       - name: IPA prove and verify tests (ipa outputs)
-        run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_ipa_output
+        run: cargo nextest run --verbose tests::ipa_prove_and_verify_ipa_output
       - name: KZG prove and verify tests single inner col
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_single_col
+        run: cargo nextest run --verbose tests::kzg_prove_and_verify_single_col
       - name: KZG prove and verify tests triple inner col
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_triple_col
+        run: cargo nextest run --verbose tests::kzg_prove_and_verify_triple_col
       - name: KZG prove and verify tests quadruple inner col
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_quadruple_col
+        run: cargo nextest run --verbose tests::kzg_prove_and_verify_quadruple_col
       - name: KZG prove and verify tests octuple inner col
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_octuple_col --test-threads 8
+        run: cargo nextest run --verbose tests::kzg_prove_and_verify_octuple_col --test-threads 8
       - name: KZG prove and verify tests (kzg outputs)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output
+        run: cargo nextest run --verbose tests::kzg_prove_and_verify_kzg_output
       - name: KZG prove and verify tests (public outputs)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t
+        run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t
       - name: KZG prove and verify tests (public inputs)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input
+        run: cargo nextest run --verbose tests::kzg_prove_and_verify_public_input
       - name: KZG prove and verify tests (fixed params)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params
+        run: cargo nextest run --verbose tests::kzg_prove_and_verify_fixed_params
       - name: KZG prove and verify tests (hashed outputs)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed
+        run: cargo nextest run --verbose tests::kzg_prove_and_verify_hashed
 
 # prove-and-verify-tests-gpu:
 #   runs-on: GPU
@@ -472,8 +530,8 @@ jobs:
 #     ENABLE_ICICLE_GPU: true
 #   steps:
 #     - uses: actions/checkout@v4
-#     with:
-#       persist-credentials: false
+#       with:
+#         persist-credentials: false
 #     - uses: actions-rs/toolchain@v1
 #       with:
 #         toolchain: nightly-2024-07-18
@@ -487,23 +545,25 @@ jobs:
 #       crate: cargo-nextest
 #       locked: true
 #     - name: KZG prove and verify tests (kzg outputs)
-#       run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
+#       run: cargo nextest run --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
 #     - name: KZG prove and verify tests (public outputs + column overflow)
-#       run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
+#       run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
 #     - name: KZG prove and verify tests (public outputs + fixed params + column overflow)
-#       run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
+#       run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
 #     - name: KZG prove and verify tests (public outputs)
-#       run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
+#       run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
 #     - name: KZG prove and verify tests (public outputs + column overflow)
-#       run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
+#       run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
 #     - name: KZG prove and verify tests (public inputs)
-#       run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
+#       run: cargo nextest run --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
 #     - name: KZG prove and verify tests (fixed params)
-#       run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
+#       run: cargo nextest run --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
 #     - name: KZG prove and verify tests (hashed outputs)
-#       run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
+#       run: cargo nextest run --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
 
   prove-and-verify-mock-aggr-tests:
+    permissions:
+      contents: read
     runs-on: self-hosted
     needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
@@ -520,7 +580,7 @@ jobs:
           crate: cargo-nextest
           locked: true
       - name: Mock aggr tests (KZG)
-        run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
+        run: cargo nextest run --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
 
@@ -528,8 +588,8 @@ jobs:
 #     ENABLE_ICICLE_GPU: true
 #   steps:
 #     - uses: actions/checkout@v4
-#     with:
-#       persist-credentials: false
+#       with:
+#         persist-credentials: false
 #     - uses: actions-rs/toolchain@v1
 #       with:
 #         toolchain: nightly-2024-07-18
@@ -543,6 +603,8 @@ jobs:
 #       run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
 
   prove-and-verify-aggr-tests:
+    permissions:
+      contents: read
     runs-on: large-self-hosted
     needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
@@ -559,9 +621,11 @@ jobs:
           crate: cargo-nextest
           locked: true
       - name: KZG tests
-        run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
+        run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
 
   prove-and-verify-aggr-evm-tests:
+    permissions:
+      contents: read
     runs-on: large-self-hosted
     needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
@@ -582,9 +646,11 @@ jobs:
       - name: Install Anvil
         run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
       - name: KZG prove and verify aggr tests
-        run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
+        run: cargo nextest run --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
 
   examples:
+    permissions:
+      contents: read
     runs-on: ubuntu-latest-32-cores
     needs: [build, library-tests, docs]
     steps:
@@ -604,6 +670,8 @@ jobs:
         run: cargo nextest run --release tests_examples
 
   python-tests:
+    permissions:
+      contents: read
     runs-on: non-gpu
     needs: [build, library-tests, docs]
     steps:
@@ -627,11 +695,13 @@ jobs:
       - name: Install Anvil
         run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
       - name: Build python ezkl
-        run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
+        run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --profile=test-runs
       - name: Run pytest
         run: source .env/bin/activate; pip install pytest-asyncio; pytest -vv
 
   accuracy-measurement-tests:
+    permissions:
+      contents: read
     runs-on: non-gpu
     needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
@@ -653,17 +723,19 @@ jobs:
       - name: Setup Virtual Env and Install python dependencies
        run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
       - name: Build python ezkl
-        run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
+        run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --profile=test-runs
       - name: Public inputs
-        run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_inputs_
+        run: source .env/bin/activate; cargo nextest run --verbose tests::accuracy_measurement_public_inputs_
       - name: fixed params
-        run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_fixed_params_
+        run: source .env/bin/activate; cargo nextest run --verbose tests::accuracy_measurement_fixed_params_
       - name: Public outputs
-        run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_outputs_
+        run: source .env/bin/activate; cargo nextest run --verbose tests::accuracy_measurement_public_outputs_
      - name: Public outputs + resources
-        run: source .env/bin/activate; cargo nextest run --release --verbose tests::resources_accuracy_measurement_public_outputs_
+        run: source .env/bin/activate; cargo nextest run --verbose tests::resources_accuracy_measurement_public_outputs_
 
   python-integration-tests:
+    permissions:
+      contents: read
     runs-on: large-self-hosted
     services:
       # Label used to access the service container
@@ -708,7 +780,11 @@ jobs:
       - name: Setup Virtual Env and Install python dependencies
         run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
       - name: Build python ezkl
-        run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
+        run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --profile=test-runs
+      - name: All notebooks
+        run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
+      - name: Voice tutorial
+        run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
       - name: Neural bow
         run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
       - name: Felt conversion
```
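The Python jobs switch maturin from `--release` to `--profile=test-runs`. The flag selects a custom Cargo profile by name; the profile's definition does not appear in this diff, so the commented TOML below is an assumption about what such a profile could look like, not the repository's actual settings:

```yaml
# Cargo.toml must define the profile, e.g. (hypothetical definition):
#   [profile.test-runs]
#   inherits = "release"
#   debug = true
- name: Build python ezkl
  run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --profile=test-runs
```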
```diff
@@ -726,87 +802,87 @@ jobs:
       # # now dump the contents of the file into a file called kaggle.json
       #   echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json
       #   chmod 600 /home/ubuntu/.kaggle/kaggle.json
-      - name: All notebooks
-        run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
-      - name: Voice tutorial
-        run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
       - name: NBEATS tutorial
         run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
       # - name: Reusable verifier tutorial
       #   run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_
 
   ios-integration-tests:
+    permissions:
+      contents: read
     runs-on: macos-latest
     steps:
       - uses: actions/checkout@v4
         with:
           persist-credentials: false
       - uses: actions-rs/toolchain@v1
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
         with:
           crate: cargo-nextest
           locked: true
       - name: Run ios tests
         run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
 
   swift-package-tests:
+    permissions:
+      contents: read
     runs-on: macos-latest
     needs: [ios-integration-tests]
 
     steps:
       - uses: actions/checkout@v4
         with:
           persist-credentials: false
       - uses: actions-rs/toolchain@v1
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - name: Build EzklCoreBindings
         run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
 
       - name: Clone ezkl-swift- repository
         run: |
           git clone https://github.com/zkonduit/ezkl-swift-package.git
 
       - name: Copy EzklCoreBindings
         run: |
           rm -rf ezkl-swift-package/Sources/EzklCoreBindings
           cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
 
       - name: Copy Test Files
         run: |
           rm -rf ezkl-swift-package/Tests/EzklAssets/
           mkdir -p ezkl-swift-package/Tests/EzklAssets/
           cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
           cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
           cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
           cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
 
       - name: Set up Xcode environment
         run: |
           sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
           sudo xcodebuild -license accept
 
       - name: Run Package Tests
         run: |
           cd ezkl-swift-package
           xcodebuild test \
             -scheme EzklPackage \
             -destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
             -resultBundlePath ../testResults
 
       - name: Run Example App Tests
         run: |
           cd ezkl-swift-package/Example
           xcodebuild test \
             -project Example.xcodeproj \
             -scheme EzklApp \
             -destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
             -parallel-testing-enabled NO \
             -resultBundlePath ../../exampleTestResults \
             -skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
```
3
.github/workflows/static-analysis.yml
vendored
@@ -8,8 +8,9 @@ on:

jobs:
  analyze:
    permissions:
      contents: read
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
        with:
3
.github/workflows/swift-pm.yml
vendored
@@ -9,6 +9,9 @@ on:

jobs:
  build-and-update:
    permissions:
      contents: read
      packages: write
    runs-on: macos-latest
    env:
      EZKL_SWIFT_PACKAGE_REPO: github.com/zkonduit/ezkl-swift-package.git
130
Cargo.lock
generated
@@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
version = 4

[[package]]
name = "addr2line"
@@ -944,7 +944,7 @@ dependencies = [
 "bitflags 2.5.0",
 "cexpr",
 "clang-sys",
 "itertools 0.12.1",
 "itertools 0.11.0",
 "lazy_static",
 "lazycell",
 "log",
@@ -1760,7 +1760,7 @@ checksum = "a650a461c6a8ff1ef205ed9a2ad56579309853fecefc2423f73dced342f92258"
[[package]]
name = "ecc"
version = "0.1.0"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
dependencies = [
 "integer",
 "num-bigint",
@@ -1835,6 +1835,16 @@ dependencies = [
 "syn 2.0.90",
]

[[package]]
name = "env_filter"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0"
dependencies = [
 "log",
 "regex",
]

[[package]]
name = "env_logger"
version = "0.10.2"
@@ -1848,6 +1858,19 @@ dependencies = [
 "termcolor",
]

[[package]]
name = "env_logger"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0"
dependencies = [
 "anstream",
 "anstyle",
 "env_filter",
 "humantime",
 "log",
]

[[package]]
name = "equivalent"
version = "1.0.1"
@@ -1923,7 +1946,7 @@ dependencies = [
 "console_error_panic_hook",
 "criterion 0.5.1",
 "ecc",
 "env_logger",
 "env_logger 0.10.2",
 "ethabi",
 "foundry-compilers",
 "gag",
@@ -1931,7 +1954,7 @@
 "halo2_gadgets",
 "halo2_proofs",
 "halo2_solidity_verifier",
 "halo2curves 0.7.0",
 "halo2curves 0.7.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
 "hex",
 "indicatif",
 "instant",
@@ -1939,20 +1962,17 @@
 "lazy_static",
 "log",
 "maybe-rayon",
 "metal",
 "mimalloc",
 "mnist",
 "num",
 "objc",
 "openssl",
 "pg_bigdecimal",
 "portable-atomic",
 "pyo3",
 "pyo3-async-runtimes",
 "pyo3-log",
 "pyo3-stub-gen",
 "rand 0.8.5",
 "regex",
 "reqwest",
 "semver 1.0.22",
 "seq-macro",
@@ -2377,7 +2397,7 @@ dependencies = [
[[package]]
name = "halo2_gadgets"
version = "0.2.0"
source = "git+https://github.com/zkonduit/halo2#d7ecad83c7439fa1cb450ee4a89c2d0b45604ceb"
source = "git+https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d"
dependencies = [
 "arrayvec 0.7.4",
 "bitvec",
@@ -2394,14 +2414,14 @@ dependencies = [
[[package]]
name = "halo2_proofs"
version = "0.3.0"
source = "git+https://github.com/zkonduit/halo2#ee4e1a09ebdb1f79f797685b78951c6034c430a6#ee4e1a09ebdb1f79f797685b78951c6034c430a6"
source = "git+https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d#f441c920be45f8f05d2c06a173d82e8885a5ed4d"
dependencies = [
 "bincode",
 "blake2b_simd",
 "env_logger",
 "env_logger 0.10.2",
 "ff",
 "group",
 "halo2curves 0.7.0",
 "halo2curves 0.7.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
 "icicle-bn254",
 "icicle-core",
 "icicle-cuda-runtime",
@@ -2409,6 +2429,7 @@ dependencies = [
 "lazy_static",
 "log",
 "maybe-rayon",
 "mopro-msm",
 "rand_chacha",
 "rand_core 0.6.4",
 "rustc-hash 2.0.0",
@@ -2494,6 +2515,36 @@ dependencies = [
 "subtle",
]

[[package]]
name = "halo2curves"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d380afeef3f1d4d3245b76895172018cfb087d9976a7cabcd5597775b2933e07"
dependencies = [
 "blake2",
 "digest 0.10.7",
 "ff",
 "group",
 "halo2derive 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
 "hex",
 "lazy_static",
 "num-bigint",
 "num-integer",
 "num-traits",
 "pairing",
 "pasta_curves",
 "paste",
 "rand 0.8.5",
 "rand_core 0.6.4",
 "rayon",
 "serde",
 "serde_arrays",
 "sha2",
 "static_assertions",
 "subtle",
 "unroll",
]

[[package]]
name = "halo2curves"
version = "0.7.0"
@@ -2503,7 +2554,7 @@ dependencies = [
 "digest 0.10.7",
 "ff",
 "group",
 "halo2derive",
 "halo2derive 0.1.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
 "hex",
 "lazy_static",
 "num-bigint",
@@ -2523,6 +2574,20 @@ dependencies = [
 "unroll",
]

[[package]]
name = "halo2derive"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdb99e7492b4f5ff469d238db464131b86c2eaac814a78715acba369f64d2c76"
dependencies = [
 "num-bigint",
 "num-integer",
 "num-traits",
 "proc-macro2",
 "quote",
 "syn 1.0.109",
]

[[package]]
name = "halo2derive"
version = "0.1.0"
@@ -2539,7 +2604,7 @@ dependencies = [
[[package]]
name = "halo2wrong"
version = "0.1.0"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
dependencies = [
 "halo2_proofs",
 "num-bigint",
@@ -2890,7 +2955,7 @@ dependencies = [
[[package]]
name = "integer"
version = "0.1.0"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
dependencies = [
 "maingate",
 "num-bigint",
@@ -3201,7 +3266,7 @@ dependencies = [
[[package]]
name = "maingate"
version = "0.1.0"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
dependencies = [
 "halo2wrong",
 "num-bigint",
@@ -3283,7 +3348,8 @@ dependencies = [
[[package]]
name = "metal"
version = "0.29.0"
source = "git+https://github.com/gfx-rs/metal-rs#0e1918b34689c4b8cd13a43372f9898680547ee9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21"
dependencies = [
 "bitflags 2.5.0",
 "block",
@@ -3354,6 +3420,28 @@ dependencies = [
 "byteorder",
]

[[package]]
name = "mopro-msm"
version = "0.1.0"
source = "git+https://github.com/zkonduit/metal-msm-gpu-acceleration.git#be5f647b1a6c1a6ea9024390744a2b4d87f5d002"
dependencies = [
 "bincode",
 "env_logger 0.11.6",
 "halo2curves 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
 "instant",
 "itertools 0.13.0",
 "lazy_static",
 "log",
 "metal",
 "objc",
 "once_cell",
 "rand 0.8.5",
 "rayon",
 "serde",
 "thiserror",
 "walkdir",
]

[[package]]
name = "native-tls"
version = "0.2.11"
@@ -3587,9 +3675,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"

[[package]]
name = "openssl-src"
version = "300.2.3+3.2.1"
version = "300.4.1+3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cff92b6f71555b61bb9315f7c64da3ca43d87531622120fea0195fc761b4843"
checksum = "faa4eac4138c62414b5622d1b31c5c304f34b406b013c079c2bbc652fdd6678c"
dependencies = [
 "cc",
]
@@ -5142,7 +5230,7 @@ checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c"
[[package]]
name = "snark-verifier"
version = "0.1.1"
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/chunked-mv-lookup#8762701ab8fa04e7d243a346030afd85633ec970"
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac%2Fchunked-mv-lookup#8762701ab8fa04e7d243a346030afd85633ec970"
dependencies = [
 "ecc",
 "halo2_proofs",
@@ -6146,7 +6234,7 @@ dependencies = [
[[package]]
name = "uniffi_testing"
version = "0.28.0"
source = "git+https://github.com/ElusAegis/uniffi-rs?branch=feat/testing-feature-build-fix#4684b9e7da2d9c964c2b3a73883947aab7370a06"
source = "git+https://github.com/ElusAegis/uniffi-rs?branch=feat%2Ftesting-feature-build-fix#4684b9e7da2d9c964c2b3a73883947aab7370a06"
dependencies = [
 "anyhow",
 "camino",
21
Cargo.toml
@@ -40,7 +40,6 @@ maybe-rayon = { version = "0.1.1", default-features = false }
bincode = { version = "1.3.3", default-features = false }
unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = { version = "1.6.0", optional = true }
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand", optional = true }
semver = { version = "1.0.22", optional = true }

@@ -74,7 +73,6 @@ tokio-postgres = { version = "0.7.10", optional = true }
pg_bigdecimal = { version = "0.1.5", optional = true }
lazy_static = { version = "1.4.0", optional = true }
colored_json = { version = "3.0.1", default-features = false, optional = true }
regex = { version = "1", default-features = false, optional = true }
tokio = { version = "1.35.0", default-features = false, features = [
    "macros",
    "rt-multi-thread",
@@ -91,7 +89,6 @@ pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", ver
pyo3-log = { version = "0.12.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
objc = { version = "0.2.4", optional = true }
mimalloc = { version = "0.1", optional = true }
pyo3-stub-gen = { version = "0.6.0", optional = true }
@@ -245,16 +242,14 @@ ezkl = [
    "dep:indicatif",
    "dep:gag",
    "dep:reqwest",
    "dep:openssl",
    "dep:tokio-postgres",
    "dep:pg_bigdecimal",
    "dep:lazy_static",
    "dep:regex",
    "dep:tokio",
    "dep:openssl",
    "dep:mimalloc",
    "dep:chrono",
    "dep:sha256",
    "dep:portable-atomic",
    "dep:clap_complete",
    "dep:halo2_solidity_verifier",
    "dep:semver",
@@ -277,13 +272,15 @@ icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
no-banner = []
no-update = []

macos-metal = ["halo2_proofs/macos"]
ios-metal = ["halo2_proofs/ios"]

[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#ee4e1a09ebdb1f79f797685b78951c6034c430a6", package = "halo2_proofs" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }

[patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#ee4e1a09ebdb1f79f797685b78951c6034c430a6", package = "halo2_proofs" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }


[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
@@ -292,9 +289,13 @@ uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "fea
rustflags = ["-C", "relocation-model=pic"]
lto = "fat"
codegen-units = 1
# panic = "abort"
#panic = "abort"


[profile.test-runs]
inherits = "dev"
opt-level = 3

[package.metadata.wasm-pack.profile.release]
wasm-opt = [
    "-O4",
@@ -150,6 +150,13 @@ Ezkl is unaudited, beta software undergoing rapid development. There may be bugs

> NOTE: Because operations are quantized when they are converted from an onnx file to a zk-circuit, outputs in python and ezkl may differ slightly.


### Advanced security topics

Check out `docs/advanced_security` for more advanced information on potential threat vectors.


### no warranty

Copyright (c) 2024 Zkonduit Inc. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
@@ -10,6 +10,7 @@ use rand::Rng;

// Assuming these are your types
#[derive(Clone)]
#[allow(dead_code)]
enum ValType {
    Constant(F),
    AssignedConstant(usize, F),
41
docs/advanced_security/public_commitments.md
Normal file
@@ -0,0 +1,41 @@
## EZKL Security Note: Public Commitments and Low-Entropy Data

> **Disclaimer:** this is a more technical post that requires some prior knowledge of how ZK proving systems like Halo2 operate, and in particular of how their APIs are constructed. For background reading we highly recommend the [Halo2 book](https://zcash.github.io/halo2/) and [Halo2 Club](https://halo2.club/).

## Overview of commitments in EZKL

A common design pattern in a zero-knowledge (zk) application is thus:
- A prover has some data which is used within a circuit.
- This data, as it may be high-dimensional or somewhat private, is pre-committed to using some hash function.
- The zk-circuit which forms the core of the application then proves (paraphrasing) a statement of the form:
> "I know some data D which, when hashed, corresponds to the pre-committed value H + whatever else the circuit is proving over D."
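To make the pattern concrete, here is a minimal Python sketch, with SHA-256 standing in for a snark-friendly hash such as Poseidon. None of the names below are part of the ezkl API; this only shows the shape of the statement being proven.

```python
# Toy model of the commit-then-prove pattern (illustrative only).
import hashlib
import json

def commit(data):
    """The pre-commitment H that the prover publishes ahead of time."""
    return hashlib.sha256(json.dumps(data).encode()).hexdigest()

secret_data = [3, 1, 4, 1, 5]   # the private data D
H = commit(secret_data)         # shared with the verifier

def statement(D):
    # 1. D hashes to the pre-committed value H
    # 2. ...plus whatever else the circuit proves over D (here: a toy property)
    return commit(D) == H and sum(D) == 14

assert statement(secret_data)
```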
From our own experience, we've implemented such patterns using snark-friendly hash functions like [Poseidon](https://www.poseidon-hash.info/), for which there is a relatively well vetted [implementation](https://docs.rs/halo2_gadgets/latest/halo2_gadgets/poseidon/index.html) in Halo2. Even then, these hash functions can introduce lots of overhead and can be very expensive to generate proofs for if the dimensionality of the data D is large.

You can also implement such a pattern using Halo2's `Fixed` columns _if privacy preservation of the pre-image is not necessary_. These are Halo2 columns (i.e. in reality just polynomials) that are left unblinded (unlike the blinded `Advice` columns), and whose commitments are shared with the verifier by way of the verifying key for the application's zk-circuit. These commitments are much lower cost to generate than implementing a hashing function, such as Poseidon, within a circuit.

> **Note:** Blinding is the process whereby a certain set of the final elements (i.e. rows) of a Halo2 column are set to random field elements. This is the mechanism by which Halo2 achieves its zero-knowledge properties for `Advice` columns. By contrast, `Fixed` columns aren't zero-knowledge, in that they are vulnerable to dictionary attacks in the same manner a hash function is. Given some set of known or popular data D, an attacker can attempt to recover the pre-image of a hash by running D through the hash function to see if the outputs match a public commitment. These attacks aren't possible on blinded `Advice` columns.

> **Further Note:** Without blinding, an attacker with access to M proofs, each of which contains an evaluation of the polynomial at a different point, can more easily recover an unblinded column's pre-image. Each proof generates a new query and evaluation of the polynomial represented by the column, so with repetition a clearer picture of the column's pre-image emerges. Unblinded columns should therefore only be used for privacy preservation, in the manner of a hash, if the number of proofs generated against a fixed set of values is limited. More formally, if M independent and _unique_ queries are generated, and M equals the degree + 1 of the polynomial represented by the column (i.e. the unique Lagrange interpolation of the values in the column), then the column's pre-image can be recovered exactly. As the number of logrows K increases, more queries are required to recover the pre-image (2^K unique queries are required). This assumes that the entries in the column are not structured; if they are, the number of queries required to recover the pre-image is reduced (e.g. if all rows above a certain point are known to be nil).
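The claim in the note above can be checked in a few lines of Python over a toy prime field: once an observer has degree + 1 distinct evaluations, Lagrange interpolation pins the polynomial, and hence the column, down exactly. This is an illustration only; the modulus, query points, and values are stand-ins, not ezkl code.

```python
# Toy demonstration: recovering an unblinded column from degree + 1 evaluations.
P = 2**61 - 1  # small prime modulus, stand-in for the BN254 scalar field

def lagrange_eval(points, x):
    """Evaluate the unique degree < len(points) polynomial through `points` at `x`."""
    total = 0
    for i, (xi, yi) in enumerate(points):
        num, den = 1, 1
        for j, (xj, _) in enumerate(points):
            if i != j:
                num = num * ((x - xj) % P) % P
                den = den * ((xi - xj) % P) % P
        total = (total + yi * num * pow(den, P - 2, P)) % P  # Fermat inverse
    return total

# the "column": 4 secret values, i.e. a degree-3 polynomial in Lagrange basis
secret_column = [5, 42, 7, 99]

def poly_eval(x):
    # the prover-side polynomial, queried once per proof
    return lagrange_eval(list(enumerate(secret_column)), x)

# an observer who has seen 4 (= degree + 1) distinct query/evaluation pairs
# can reconstruct every entry of the column
queries = [(x, poly_eval(x)) for x in (11, 22, 33, 44)]
recovered = [lagrange_eval(queries, i) for i in range(4)]
assert recovered == secret_column
```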
The annoyance in using `Fixed` columns comes from the fact that they require generating a new verifying key every time a new set of commitments is generated.

> **Example:** Say an application leverages a zero-knowledge circuit to prove the correct execution of a neural network. Every week the neural network is finetuned or retrained on new data. If the architecture remains the same, then committing to the new network parameters, along with a new proof of performance on a test set, would be an ideal setup. If we leverage `Fixed` columns to commit to the model parameters, each new commitment will require re-generating a verifying key and sharing the new key with the verifier(s). This is not ideal UX and can become expensive if the verifier is deployed on-chain.

An ideal commitment would thus have the low cost of a `Fixed` column but wouldn't require regenerating a new verifying key for each new commitment.

### Unblinded Advice Columns

A first step in designing such a commitment is to allow for optionally unblinded `Advice` columns within the Halo2 API. These won't be included in the verifying key, AND are blinded with a constant factor `1` -- such that if someone knows the pre-image to the commitment, they can recover it by running it through the corresponding polynomial commitment scheme (in ezkl's case [KZG commitments](https://dankradfeist.de/ethereum/2020/06/16/kate-polynomial-commitments.html)).

This is implemented using the `polycommit` visibility parameter in the ezkl API.
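For reference, the notebook changes elsewhere in this changeset configure this visibility through the Python bindings roughly as follows; the file paths here are placeholders.

```python
import ezkl

run_args = ezkl.PyRunArgs()
# commit to the inputs via an unblinded advice column (a KZG commitment)
run_args.input_visibility = "polycommit"
# parameters are baked into the circuit as fixed columns
run_args.param_visibility = "fixed"
# reveal the outputs to the verifier
run_args.output_visibility = "public"

res = ezkl.gen_settings("network.onnx", "settings.json", py_run_args=run_args)
assert res == True
```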
## The Vulnerability of Public Commitments

Public commitments in EZKL (both Poseidon-hashed inputs and KZG commitments) can be vulnerable to brute-force attacks when input data has low entropy. A malicious actor could reveal committed data by searching through possible input values, compromising privacy in applications like anonymous credentials. This is particularly relevant when input data comes from known finite sets (e.g., names, dates).

Example Risk: In an anonymous credential system using EZKL for ID verification, an attacker could match hashed outputs against a database of common identifying information to deanonymize users.
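As a rough illustration of the attack surface (again with SHA-256 standing in for Poseidon or a KZG commitment, and a hypothetical date-of-birth field):

```python
# Toy dictionary attack on a low-entropy commitment (illustrative only).
import hashlib

def commit(value: str) -> str:
    return hashlib.sha256(value.encode()).hexdigest()

# the victim commits to a low-entropy value: a birth date
public_commitment = commit("1987-06-05")

# the attacker enumerates the small candidate set (~24k dates here)
candidates = (
    f"{y:04d}-{m:02d}-{d:02d}"
    for y in range(1940, 2010)
    for m in range(1, 13)
    for d in range(1, 29)
)
recovered = next(c for c in candidates if commit(c) == public_commitment)
assert recovered == "1987-06-05"
```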
22
docs/advanced_security/quantization_backdoors.md
Normal file
@@ -0,0 +1,22 @@
# EZKL Security Note: Quantization-Induced Model Backdoors

> Note: this only affects a situation where a party separate from an application's developer has access to the model's weights and can modify them. This is a common scenario in adversarial machine learning research, but can be less common in real-world applications. If you're building your models in house and deploying them yourself, this is less of a concern. If you're building a permissionless system where anyone can submit models, this is more of a concern.

Models processed through EZKL's quantization step can harbor backdoors that are dormant in the original full-precision model but activate during quantization. These backdoors force specific outputs when triggered, with impact varying by application.

Key Factors:

- Larger models increase attack feasibility through more parameter capacity
- Smaller quantization scales facilitate attacks by allowing greater weight modifications
- Rebase ratio of 1 enables exploitation of convolutional layer consistency

Limitations:

- Attack effectiveness depends on calibration settings and internal rescaling operations.
- Further research needed on backdoor persistence through witness/proof stages.
- Can be mitigated by evaluating the quantized model (using `ezkl gen-witness`) rather than relying on the evaluation of the original model, as sketched below.
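A sketch of that mitigation, assuming the ezkl Python bindings and placeholder file paths; the exact `gen_witness` signature and the `outputs` field name may differ between ezkl versions, so treat this as pseudocode.

```python
# Evaluate the *quantized* circuit, not the float model (sketch, assumptions above).
import json
import ezkl

# run the quantized circuit on the inputs; the witness records its outputs
ezkl.gen_witness("input.json", "network.compiled", "witness.json")

with open("witness.json") as f:
    witness = json.load(f)

# score or sanity-check these rescaled outputs directly, rather than trusting
# the float model: a backdoor planted at quantization shows up here, not there
print(witness["outputs"])  # field name assumed
```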
References:

1. [Quantization Backdoors to Deep Learning Commercial Frameworks (Ma et al., 2021)](https://arxiv.org/abs/2108.09187)
2. [Planting Undetectable Backdoors in Machine Learning Models (Goldwasser et al., 2022)](https://arxiv.org/abs/2204.06974)
@@ -77,6 +77,7 @@
   "outputs": [],
   "source": [
    "gip_run_args = ezkl.PyRunArgs()\n",
    "gip_run_args.ignore_range_check_inputs_outputs = True\n",
    "gip_run_args.input_visibility = \"polycommit\" # matrix and generalized inverse commitments\n",
    "gip_run_args.output_visibility = \"fixed\" # no parameters used\n",
    "gip_run_args.param_visibility = \"fixed\" # should be Tensor(True)"
@@ -335,9 +336,9 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
}
@@ -308,8 +308,11 @@
   "compiled_filename = os.path.join('lol.compiled')\n",
   "settings_filename = os.path.join('settings.json')\n",
   "\n",
   "run_args = ezkl.PyRunArgs()\n",
   "run_args.decomp_legs = 4\n",
   "\n",
   "# Generate settings using ezkl\n",
   "res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
   "res = ezkl.gen_settings(onnx_filename, settings_filename, py_run_args=run_args)\n",
   "\n",
   "assert res == True\n",
   "\n",
@@ -152,9 +152,11 @@
  "metadata": {},
  "outputs": [],
  "source": [
   "!RUST_LOG=trace\n",
   "# TODO: Dictionary outputs\n",
   "res = ezkl.gen_settings(model_path, settings_path)\n",
   "run_args = ezkl.PyRunArgs()\n",
   "# logrows\n",
   "run_args.logrows = 20\n",
   "\n",
   "res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
   "assert res == True\n"
  ]
 },
@@ -302,7 +304,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
@@ -167,6 +167,8 @@
   "run_args = ezkl.PyRunArgs()\n",
   "# \"hashed/private\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
   "run_args.input_visibility = \"hashed/private/0\"\n",
   "# as the inputs are felts we turn off input range checks\n",
   "run_args.ignore_range_check_inputs_outputs = True\n",
   "# we set it to fix the set we want to check membership for\n",
   "run_args.param_visibility = \"fixed\"\n",
   "# the output is public -- set membership fails if it is not = 0\n",
@@ -519,4 +521,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
}
@@ -204,6 +204,7 @@
   "run_args = ezkl.PyRunArgs()\n",
   "# \"polycommit\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
   "run_args.input_visibility = \"polycommit\"\n",
   "run_args.ignore_range_check_inputs_outputs = True\n",
   "# the parameters are public\n",
   "run_args.param_visibility = \"fixed\"\n",
   "# the output is public (this is the inequality test)\n",
@@ -514,4 +515,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
}
@@ -20,7 +20,7 @@
 },
 {
  "cell_type": "code",
  "execution_count": 2,
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -60,7 +60,7 @@
 },
 {
  "cell_type": "code",
  "execution_count": 3,
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -94,7 +94,7 @@
 },
 {
  "cell_type": "code",
  "execution_count": 4,
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -134,7 +134,7 @@
 },
 {
  "cell_type": "code",
  "execution_count": 44,
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -183,7 +183,7 @@
 },
 {
  "cell_type": "code",
  "execution_count": 6,
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -201,6 +201,7 @@
   "run_args.input_visibility = \"public\"\n",
   "run_args.param_visibility = \"private\"\n",
   "run_args.output_visibility = \"public\"\n",
   "run_args.decomp_legs=6\n",
   "run_args.num_inner_cols = 1\n",
   "run_args.variables = [(\"batch_size\", 1)]"
  ]
42
examples/onnx/integer_div/gen.py
Normal file
@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, x):
        return x // 3


circuit = MyModel()

x = torch.randint(0, 10, (1, 2, 2, 8))

out = circuit(x)

print(x)
print(out)
print(x / 3)

torch.onnx.export(circuit, x, "network.onnx",
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=17,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                'output': {0: 'batch_size'}})


d1 = ((x).detach().numpy()).reshape([-1]).tolist()

data = dict(
    input_data=[d1],
)

# Serialize data into file:
json.dump(data, open("input.json", 'w'))
1
examples/onnx/integer_div/input.json
Normal file
@@ -0,0 +1 @@
{"input_data": [[3, 4, 0, 9, 2, 6, 2, 5, 1, 5, 3, 5, 5, 7, 0, 2, 6, 1, 4, 4, 1, 9, 7, 7, 5, 8, 2, 0, 1, 5, 9, 8]]}
BIN
examples/onnx/integer_div/network.onnx
Normal file
Binary file not shown.
@@ -1,7 +1,11 @@
// ignore file if compiling for wasm

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc::MiMalloc;

#[global_allocator]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
static GLOBAL: MiMalloc = MiMalloc;

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::{CommandFactory, Parser};
@@ -8,7 +8,6 @@ use crate::circuit::InputType;
use crate::circuit::{CheckMode, Tolerance};
use crate::commands::*;
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::graph::modules::POSEIDON_LEN_GRAPH;
use crate::graph::TestDataSource;
use crate::graph::{
    quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
@@ -207,6 +206,9 @@ struct PyRunArgs {
    /// bool: Should the circuit use unbounded lookups for log
    #[pyo3(get, set)]
    pub bounded_log_lookup: bool,
    /// bool: Should the circuit use range checks for inputs and outputs (set to false if the input is a felt)
    #[pyo3(get, set)]
    pub ignore_range_check_inputs_outputs: bool,
}

/// default instantiation of PyRunArgs
@@ -239,6 +241,7 @@ impl From<PyRunArgs> for RunArgs {
            commitment: Some(py_run_args.commitment.into()),
            decomp_base: py_run_args.decomp_base,
            decomp_legs: py_run_args.decomp_legs,
            ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
        }
    }
}
@@ -263,6 +266,7 @@ impl Into<PyRunArgs> for RunArgs {
            commitment: self.commitment.into(),
            decomp_base: self.decomp_base,
            decomp_legs: self.decomp_legs,
            ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
        }
    }
}
@@ -573,10 +577,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
        .map(crate::pfsys::string_to_field::<Fr>)
        .collect::<Vec<_>>();

    let output =
        PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>::run(
            message.clone(),
        )
    let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
        .map_err(|_| PyIOError::new_err("Failed to run poseidon"))?;

    let hash = output[0]
@@ -8,10 +8,7 @@ use crate::{
        Module,
    },
    fieldutils::{felt_to_integer_rep, integer_rep_to_felt},
    graph::{
        modules::POSEIDON_LEN_GRAPH, quantize_float, scale_to_multiplier, GraphCircuit,
        GraphSettings,
    },
    graph::{quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings},
};
use console_error_panic_hook;
use halo2_proofs::{
@@ -231,10 +228,7 @@ pub fn poseidonHash(
    let message: Vec<Fr> = serde_json::from_slice(&message[..])
        .map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?;

    let output =
        PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>::run(
            message.clone(),
        )
    let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
        .map_err(|e| JsError::new(&format!("{}", e)))?;

    Ok(wasm_bindgen::Clamped(serde_json::to_vec(&output).map_err(
@@ -8,13 +8,11 @@ pub mod poseidon_params;
pub mod spec;

// This chip adds a set of advice columns to the gadget Chip to store the inputs of the hash
use halo2_gadgets::poseidon::{primitives::*, Hash, Pow5Chip, Pow5Config};
use halo2_proofs::arithmetic::Field;
use halo2_gadgets::poseidon::{
    primitives::VariableLength, primitives::*, Hash, Pow5Chip, Pow5Config,
};
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use halo2_proofs::{circuit::*, plonk::*};
// use maybe_rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator};
use maybe_rayon::prelude::ParallelIterator;
use maybe_rayon::slice::ParallelSlice;

use std::marker::PhantomData;

@@ -40,22 +38,17 @@ pub struct PoseidonConfig<const WIDTH: usize, const RATE: usize> {
    pub pow5_config: Pow5Config<Fp, WIDTH, RATE>,
}

type InputAssignments = (Vec<AssignedCell<Fp, Fp>>, AssignedCell<Fp, Fp>);
type InputAssignments = Vec<AssignedCell<Fp, Fp>>;

/// PoseidonChip is a wrapper around the Pow5Chip that adds a set of advice columns to the gadget Chip to store the inputs of the hash
#[derive(Debug, Clone)]
pub struct PoseidonChip<
    S: Spec<Fp, WIDTH, RATE> + Sync,
    const WIDTH: usize,
    const RATE: usize,
    const L: usize,
> {
pub struct PoseidonChip<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> {
    config: PoseidonConfig<WIDTH, RATE>,
    _marker: PhantomData<S>,
}

impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, const L: usize>
    PoseidonChip<S, WIDTH, RATE, L>
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize>
    PoseidonChip<S, WIDTH, RATE>
{
    /// Creates a new PoseidonChip
    pub fn configure_with_cols(
@@ -82,8 +75,8 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
    }
}

impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, const L: usize>
    PoseidonChip<S, WIDTH, RATE, L>
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize>
    PoseidonChip<S, WIDTH, RATE>
{
    /// Configuration of the PoseidonChip
    pub fn configure_with_optional_instance(
@@ -113,8 +106,8 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
    }
}

impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, const L: usize>
    Module<Fp> for PoseidonChip<S, WIDTH, RATE, L>
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Module<Fp>
    for PoseidonChip<S, WIDTH, RATE>
{
    type Config = PoseidonConfig<WIDTH, RATE>;
    type InputAssignments = InputAssignments;
@@ -170,7 +163,10 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
        message: &[ValTensor<Fp>],
        constants: &mut ConstantsMap<Fp>,
    ) -> Result<Self::InputAssignments, ModuleError> {
        assert_eq!(message.len(), 1);
        if message.len() != 1 {
            return Err(ModuleError::InputWrongLength(message.len()));
        }

        let message = message[0].clone();

        let start_time = instant::Instant::now();
@@ -180,95 +176,81 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
        let res = layouter.assign_region(
            || "load message",
            |mut region| {
                let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> =
                    match &message {
                        ValTensor::Value { inner: v, .. } => {
                            v.iter()
                                .enumerate()
                                .map(|(i, value)| {
                                    let x = i % WIDTH;
                                    let y = i / WIDTH;
                let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, _> = match &message {
                    ValTensor::Value { inner: v, .. } => v
                        .iter()
                        .enumerate()
                        .map(|(i, value)| {
                            let x = i % WIDTH;
                            let y = i / WIDTH;

                                    match value {
                                        ValType::Value(v) => region
                                            .assign_advice(
                                                || format!("load message_{}", i),
                                                self.config.hash_inputs[x],
                                                y,
                                                || *v,
                                            )
                                            .map_err(|e| e.into()),
                                        ValType::PrevAssigned(v)
                                        | ValType::AssignedConstant(v, ..) => Ok(v.clone()),
                                        ValType::Constant(f) => {
                                            if local_constants.contains_key(f) {
                                                Ok(constants
                                                    .get(f)
                                                    .unwrap()
                                                    .assigned_cell()
                                                    .ok_or(ModuleError::ConstantNotAssigned)?)
                                            } else {
                                                let res = region.assign_advice_from_constant(
                                                    || format!("load message_{}", i),
                                                    self.config.hash_inputs[x],
                                                    y,
                                                    *f,
                                                )?;

                                                constants.insert(
                                                    *f,
                                                    ValType::AssignedConstant(res.clone(), *f),
                                                );

                                                Ok(res)
                                            }
                                        }
                                        e => Err(ModuleError::WrongInputType(
                                            format!("{:?}", e),
                                            "PrevAssigned".to_string(),
                                        )),
                                    }
                                })
                                .collect()
                        }
                        ValTensor::Instance {
                            dims,
                            inner: col,
                            idx,
                            initial_offset,
                            ..
                        } => {
                            // this should never ever fail
                            let num_elems = dims[*idx].iter().product::<usize>();
                            (0..num_elems)
                                .map(|i| {
                                    let x = i % WIDTH;
                                    let y = i / WIDTH;
                                    region.assign_advice_from_instance(
                                        || "pub input anchor",
                                        *col,
                                        initial_offset + i,
                            match value {
                                ValType::Value(v) => region
                                    .assign_advice(
                                        || format!("load message_{}", i),
                                        self.config.hash_inputs[x],
                                        y,
                                        || *v,
                                    )
                                })
                                .collect::<Result<Vec<_>, _>>()
                                .map_err(|e| e.into())
                        }
                    };
                                    .map_err(|e| e.into()),
                                ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
                                    Ok(v.clone())
                                }
                                ValType::Constant(f) => {
                                    if local_constants.contains_key(f) {
                                        Ok(constants
                                            .get(f)
                                            .unwrap()
                                            .assigned_cell()
                                            .ok_or(ModuleError::ConstantNotAssigned)?)
                                    } else {
                                        let res = region.assign_advice_from_constant(
                                            || format!("load message_{}", i),
                                            self.config.hash_inputs[x],
                                            y,
                                            *f,
                                        )?;

                let offset = message.len() / WIDTH + 1;
                                        constants
                                            .insert(*f, ValType::AssignedConstant(res.clone(), *f));

                let zero_val = region
                    .assign_advice_from_constant(
                        || "",
                        self.config.hash_inputs[0],
                        offset,
                        Fp::ZERO,
                    )
                    .unwrap();
                                        Ok(res)
                                    }
                                }
                                e => Err(ModuleError::WrongInputType(
                                    format!("{:?}", e),
                                    "AssignedValue".to_string(),
                                )),
                            }
                        })
                        .collect(),
                    ValTensor::Instance {
                        dims,
                        inner: col,
                        idx,
                        initial_offset,
                        ..
                    } => {
                        // this should never ever fail
                        let num_elems = dims[*idx].iter().product::<usize>();
                        (0..num_elems)
                            .map(|i| {
                                let x = i % WIDTH;
                                let y = i / WIDTH;
                                region.assign_advice_from_instance(
                                    || "pub input anchor",
                                    *col,
                                    initial_offset + i,
                                    self.config.hash_inputs[x],
                                    y,
                                )
                            })
                            .collect::<Result<Vec<_>, _>>()
                            .map_err(|e| e.into())
                    }
                };

                Ok((assigned_message?, zero_val))
                Ok(assigned_message?)
            },
        );
        log::trace!(
@@ -289,7 +271,13 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
        row_offset: usize,
        constants: &mut ConstantsMap<Fp>,
    ) -> Result<ValTensor<Fp>, ModuleError> {
        let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
        let input_cells = self.layout_inputs(layouter, input, constants)?;

        // empty hash case
        if input_cells.is_empty() {
            return Ok(input[0].clone());
        }

        // extract the values from the input cells
        let mut assigned_input: Tensor<ValType<Fp>> =
            input_cells.iter().map(|e| ValType::from(e.clone())).into();
@@ -297,52 +285,25 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con

        let start_time = instant::Instant::now();

        let mut one_iter = false;
        // do the Tree dance baby
        while input_cells.len() > 1 || !one_iter {
            let hashes: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> = input_cells
                .chunks(L)
                .enumerate()
                .map(|(i, block)| {
                    let _start_time = instant::Instant::now();
        let pow5_chip = Pow5Chip::construct(self.config.pow5_config.clone());
        // initialize the hasher
        let hasher = Hash::<_, _, S, VariableLength, WIDTH, RATE>::init(
            pow5_chip,
            layouter.namespace(|| "block_hasher"),
        )?;

                    let mut block = block.to_vec();
                    let remainder = block.len() % L;

                    if remainder != 0 {
                        block.extend(vec![zero_val.clone(); L - remainder]);
                    }

                    let pow5_chip = Pow5Chip::construct(self.config.pow5_config.clone());
                    // initialize the hasher
                    let hasher = Hash::<_, _, S, ConstantLength<L>, WIDTH, RATE>::init(
                        pow5_chip,
                        layouter.namespace(|| "block_hasher"),
                    )?;

                    let hash = hasher.hash(
                        layouter.namespace(|| "hash"),
                        block.to_vec().try_into().map_err(|_| Error::Synthesis)?,
                    );

                    if i == 0 {
                        log::trace!("block (L={:?}) took: {:?}", L, _start_time.elapsed());
                    }

                    hash
                })
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| e.into());

            log::trace!("hashes (N={:?}) took: {:?}", len, start_time.elapsed());
            one_iter = true;
            input_cells = hashes?;
        }
        let hash: AssignedCell<Fp, Fp> = hasher.hash(
            layouter.namespace(|| "hash"),
            input_cells
                .to_vec()
                .try_into()
                .map_err(|_| Error::Synthesis)?,
        )?;

        let duration = start_time.elapsed();
        log::trace!("layout (N={:?}) took: {:?}", len, duration);

        let result = Tensor::from(input_cells.iter().map(|e| ValType::from(e.clone())));
        let result = Tensor::from(vec![ValType::from(hash.clone())].into_iter());

        let output = match result[0].clone() {
            ValType::PrevAssigned(v) => v,
@@ -381,69 +342,59 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con

    ///
    fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, ModuleError> {
        let mut hash_inputs = message;

        let len = hash_inputs.len();
        let len = message.len();
        if len == 0 {
            return Ok(vec![vec![]]);
        }

        let start_time = instant::Instant::now();

        let mut one_iter = false;
        // do the Tree dance baby
        while hash_inputs.len() > 1 || !one_iter {
            let hashes: Vec<Fp> = hash_inputs
                .par_chunks(L)
                .map(|block| {
                    let mut block = block.to_vec();
                    let remainder = block.len() % L;

                    if remainder != 0 {
                        block.extend(vec![Fp::ZERO; L - remainder].iter());
                    }

                    let block_len = block.len();

                    let message = block
                        .try_into()
                        .map_err(|_| ModuleError::InputWrongLength(block_len))?;

                    Ok(halo2_gadgets::poseidon::primitives::Hash::<
                        _,
                        S,
                        ConstantLength<L>,
                        { WIDTH },
                        { RATE },
                    >::init()
                    .hash(message))
                })
                .collect::<Result<Vec<_>, ModuleError>>()?;
            one_iter = true;
            hash_inputs = hashes;
        }
        let hash = halo2_gadgets::poseidon::primitives::Hash::<
            _,
            S,
            VariableLength,
            { WIDTH },
            { RATE },
        >::init()
        .hash(message);

        let duration = start_time.elapsed();
        log::trace!("run (N={:?}) took: {:?}", len, duration);

        Ok(vec![hash_inputs])
        Ok(vec![vec![hash]])
    }

    fn num_rows(mut input_len: usize) -> usize {
    fn num_rows(input_len: usize) -> usize {
        // this was determined by running the circuit and looking at the number of constraints
        // in the test called hash_for_a_range_of_input_sizes, then regressing in python to find the slope
        let fixed_cost: usize = 41 * L;
        // import numpy as np
        // from scipy import stats

        let mut num_rows = 0;
        // x = np.array([32, 64, 96, 128, 160, 192])
        // y = np.array([1298, 2594, 3890, 5186, 6482, 7778])

        loop {
            // the number of times the input_len is divisible by L
            let num_chunks = input_len / L + 1;
            num_rows += num_chunks * fixed_cost;
            if num_chunks == 1 {
                break;
            }
            input_len = num_chunks;
        }
        // slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)

        num_rows
        // print(f"slope: {slope}")
        // print(f"intercept: {intercept}")
        // print(f"R^2: {r_value**2}")

        // # Predict for any x
        // def predict(x):
        //     return slope * x + intercept

        // # Test prediction
        // test_x = 256
        // print(f"Predicted value for x={test_x}: {predict(test_x)}")
        // our output:
        // slope: 40.5
        // intercept: 2.0
        // R^2: 1.0
        // Predicted value for x=256: 10370.0
        let fixed_cost: usize = 41 * input_len;

        // the cost of the hash function is linear with the number of inputs
        fixed_cost + 2
    }
}

@@ -470,12 +421,12 @@ mod tests {
    const RATE: usize = POSEIDON_RATE;
    const R: usize = 240;

    struct HashCircuit<S: Spec<Fp, WIDTH, RATE>, const L: usize> {
    struct HashCircuit<S: Spec<Fp, WIDTH, RATE>> {
        message: ValTensor<Fp>,
        _spec: PhantomData<S>,
    }

    impl<S: Spec<Fp, WIDTH, RATE>, const L: usize> Circuit<Fp> for HashCircuit<S, L> {
    impl<S: Spec<Fp, WIDTH, RATE>> Circuit<Fp> for HashCircuit<S> {
        type Config = PoseidonConfig<WIDTH, RATE>;
        type FloorPlanner = ModulePlanner;
        type Params = ();
@@ -491,7 +442,7 @@ mod tests {
        }

        fn configure(meta: &mut ConstraintSystem<Fp>) -> PoseidonConfig<WIDTH, RATE> {
            PoseidonChip::<PoseidonSpec, WIDTH, RATE, L>::configure(meta, ())
            PoseidonChip::<PoseidonSpec, WIDTH, RATE>::configure(meta, ())
        }

        fn synthesize(
@@ -499,7 +450,7 @@ mod tests {
            config: PoseidonConfig<WIDTH, RATE>,
            mut layouter: impl Layouter<Fp>,
        ) -> Result<(), Error> {
            let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, L> = PoseidonChip::new(config);
            let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE> = PoseidonChip::new(config);
            chip.layout(
                &mut layouter,
                &[self.message.clone()],
@@ -511,18 +462,33 @@ mod tests {
        }
    }

    #[test]
    fn poseidon_hash_empty() {
        let message = [];
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.to_vec()).unwrap();
        let mut message: Tensor<ValType<Fp>> =
            message.into_iter().map(|m| Value::known(m).into()).into();
        let k = 9;
        let circuit = HashCircuit::<PoseidonSpec> {
            message: message.into(),
            _spec: PhantomData,
        };
        let prover = halo2_proofs::dev::MockProver::run(k, &circuit, vec![vec![]]).unwrap();
        assert_eq!(prover.verify(), Ok(()))
    }

    #[test]
    fn poseidon_hash() {
        let rng = rand::rngs::OsRng;

        let message = [Fp::random(rng), Fp::random(rng)];
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 2>::run(message.to_vec()).unwrap();
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.to_vec()).unwrap();

        let mut message: Tensor<ValType<Fp>> =
            message.into_iter().map(|m| Value::known(m).into()).into();

        let k = 9;
        let circuit = HashCircuit::<PoseidonSpec, 2> {
        let circuit = HashCircuit::<PoseidonSpec> {
            message: message.into(),
            _spec: PhantomData,
        };
@@ -535,13 +501,13 @@ mod tests {
        let rng = rand::rngs::OsRng;

        let message = [Fp::random(rng), Fp::random(rng), Fp::random(rng)];
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 3>::run(message.to_vec()).unwrap();
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.to_vec()).unwrap();

        let mut message: Tensor<ValType<Fp>> =
            message.into_iter().map(|m| Value::known(m).into()).into();

        let k = 9;
        let circuit = HashCircuit::<PoseidonSpec, 3> {
        let circuit = HashCircuit::<PoseidonSpec> {
            message: message.into(),
            _spec: PhantomData,
        };
@@ -557,23 +523,21 @@ mod tests {
        #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
        env_logger::init();

        {
            let i = 32;
        for i in (32..128).step_by(32) {
            // print a bunch of new lines
            println!(
            log::info!(
                "i is {} -------------------------------------------------",
                i
            );

            let message: Vec<Fp> = (0..i).map(|_| Fp::random(rng)).collect::<Vec<_>>();
            let output =
                PoseidonChip::<PoseidonSpec, WIDTH, RATE, 32>::run(message.clone()).unwrap();
            let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.clone()).unwrap();

            let mut message: Tensor<ValType<Fp>> =
                message.into_iter().map(|m| Value::known(m).into()).into();

            let k = 17;
            let circuit = HashCircuit::<PoseidonSpec, 32> {
            let circuit = HashCircuit::<PoseidonSpec> {
                message: message.into(),
                _spec: PhantomData,
            };
@@ -590,13 +554,13 @@ mod tests {

        let mut message: Vec<Fp> = (0..2048).map(|_| Fp::random(rng)).collect::<Vec<_>>();

        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 25>::run(message.clone()).unwrap();
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.clone()).unwrap();

        let mut message: Tensor<ValType<Fp>> =
            message.into_iter().map(|m| Value::known(m).into()).into();

        let k = 17;
        let circuit = HashCircuit::<PoseidonSpec, 25> {
        let circuit = HashCircuit::<PoseidonSpec> {
            message: message.into(),
            _spec: PhantomData,
        };
@@ -17,7 +17,6 @@ pub enum BaseOp {
    Sub,
    SumInit,
    Sum,
    IsBoolean,
}

/// Matches a [BaseOp] to an operation over inputs
@@ -34,7 +33,6 @@ impl BaseOp {
            BaseOp::Add => a + b,
            BaseOp::Sub => a - b,
            BaseOp::Mult => a * b,
            BaseOp::IsBoolean => b,
            _ => panic!("nonaccum_f called on accumulating operation"),
        }
    }
@@ -74,7 +72,6 @@ impl BaseOp {
            BaseOp::Mult => "MULT",
            BaseOp::Sum => "SUM",
            BaseOp::SumInit => "SUMINIT",
            BaseOp::IsBoolean => "ISBOOLEAN",
        }
    }

@@ -90,7 +87,6 @@ impl BaseOp {
            BaseOp::Mult => (0, 1),
            BaseOp::Sum => (-1, 2),
            BaseOp::SumInit => (0, 1),
            BaseOp::IsBoolean => (0, 1),
        }
    }

@@ -106,7 +102,6 @@ impl BaseOp {
            BaseOp::Mult => 2,
            BaseOp::Sum => 1,
            BaseOp::SumInit => 1,
            BaseOp::IsBoolean => 0,
        }
    }

@@ -122,7 +117,6 @@ impl BaseOp {
            BaseOp::SumInit => 0,
            BaseOp::CumProd => 1,
            BaseOp::CumProdInit => 0,
            BaseOp::IsBoolean => 0,
        }
    }
}
@@ -2,7 +2,7 @@ use std::str::FromStr;
|
||||
|
||||
use halo2_proofs::{
|
||||
circuit::Layouter,
|
||||
plonk::{ConstraintSystem, Constraints, Expression, Selector},
|
||||
plonk::{ConstraintSystem, Constraints, Expression, Selector, TableColumn},
|
||||
poly::Rotation,
|
||||
};
|
||||
use log::debug;
|
||||
@@ -341,6 +341,8 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
/// Activate sanity checks
|
||||
pub check_mode: CheckMode,
|
||||
_marker: PhantomData<F>,
|
||||
/// shared table inputs
|
||||
pub shared_table_inputs: Vec<TableColumn>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
@@ -353,6 +355,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
shuffles: Shuffles::dummy(col_size, num_inner_cols),
|
||||
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
|
||||
check_mode: CheckMode::SAFE,
|
||||
shared_table_inputs: vec![],
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
@@ -391,7 +394,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector());
|
||||
nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector());
|
||||
nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector());
|
||||
nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -425,24 +427,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
// Get output expressions for each input channel
|
||||
let (rotation_offset, rng) = base_op.query_offset_rng();
|
||||
|
||||
let constraints = match base_op {
|
||||
BaseOp::IsBoolean => {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
|
||||
.expect("non accum: output query failed");
|
||||
let constraints = {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
|
||||
.expect("non accum: output query failed");
|
||||
|
||||
let output = expected_output[base_op.constraint_idx()].clone();
|
||||
|
||||
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
|
||||
}
|
||||
_ => {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
|
||||
.expect("non accum: output query failed");
|
||||
|
||||
let res = base_op.nonaccum_f((qis[0].clone(), qis[1].clone()));
|
||||
vec![expected_output[base_op.constraint_idx()].clone() - res]
|
||||
}
|
||||
let res = base_op.nonaccum_f((qis[0].clone(), qis[1].clone()));
|
||||
vec![expected_output[base_op.constraint_idx()].clone() - res]
|
||||
};
|
||||
|
||||
Constraints::with_selector(selector, constraints)
|
||||
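For reference, the booleanity gate removed in this hunk constrained output · (output − 1) = 0, which vanishes exactly on {0, 1}; the boolean_identity change further down swaps the dedicated selector for a range check over (0, 1), which accepts the same set. A minimal standalone sketch of the predicate (plain Rust, not the circuit API):

// Sketch: the set accepted by b * (b - 1) = 0 is exactly {0, 1},
// the same set a range check over (0, 1) accepts.
fn booleanity_holds(b: i64) -> bool {
    b * (b - 1) == 0
}

fn main() {
    assert!(booleanity_holds(0));
    assert!(booleanity_holds(1));
    assert!(!booleanity_holds(2));
}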
@@ -497,6 +488,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
dynamic_lookups: DynamicLookups::default(),
shuffles: Shuffles::default(),
range_checks: RangeChecks::default(),
shared_table_inputs: vec![],
check_mode,
_marker: PhantomData,
}
@@ -527,21 +519,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
return Err(CircuitError::WrongColumnType(output.name().to_string()));
}

// we borrow mutably twice so we need to do this dance

let table = if !self.static_lookups.tables.contains_key(nl) {
// as all tables have the same input we see if there's another table who's input we can reuse
let table = if let Some(table) = self.static_lookups.tables.values().next() {
Table::<F>::configure(
cs,
lookup_range,
logrows,
nl,
Some(table.table_inputs.clone()),
)
} else {
Table::<F>::configure(cs, lookup_range, logrows, nl, None)
};
let table =
Table::<F>::configure(cs, lookup_range, logrows, nl, &mut self.shared_table_inputs);
self.static_lookups.tables.insert(nl.clone(), table.clone());
table
} else {
@@ -592,9 +572,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
// this is 0 if the index is the same as the column index (starting from 1)

let col_expr = sel.clone()
* table
* (table
.selector_constructor
.get_expr_at_idx(col_idx, synthetic_sel);
.get_expr_at_idx(col_idx, synthetic_sel));

let multiplier =
table.selector_constructor.get_selector_val_at_idx(col_idx);
@@ -626,6 +606,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
res
});
}

// add a degree-k custom constraint of the following form to the range check and
// static lookup configuration.
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 − 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
cs.create_gate("range_check_on_sel", |cs| {
let synthetic_sel = match len {
1 => Expression::Constant(F::from(1)),
_ => match index {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
},
};

let range_check_on_synthetic_sel = match len {
1 => Expression::Constant(F::from(0)),
_ => {
let mut initial_expr = Expression::Constant(F::from(1));
for i in 0..len {
initial_expr = initial_expr
* (synthetic_sel.clone()
- Expression::Constant(F::from(i as u64)))
}
initial_expr
}
};

let sel = cs.query_selector(multi_col_selector);

Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
});

self.static_lookups
.selectors
.insert((nl.clone(), x, y), multi_col_selector);
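Written out, the gate added above enforces 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · (𝑠𝑒𝑙 − 0)(𝑠𝑒𝑙 − 1)⋯(𝑠𝑒𝑙 − (𝑙𝑒𝑛 − 1)) = 0: whenever the selector is active, the synthetic selector must land on one of the values 0, …, len − 1, since the product vanishes only on that set. In the len == 1 case the polynomial is replaced by the constant 0, so the constraint is vacuous and no advice cell needs to be queried.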
@@ -866,7 +880,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
let range_check = if let std::collections::btree_map::Entry::Vacant(e) =
self.range_checks.ranges.entry(range)
{
// as all tables have the same input we see if there's another table who's input we can reuse
let range_check = RangeCheck::<F>::configure(cs, range, logrows);
e.insert(range_check.clone());
range_check
@@ -904,9 +917,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
let default_x = range_check.get_first_element(col_idx);

let col_expr = sel.clone()
* range_check
* (range_check
.selector_constructor
.get_expr_at_idx(col_idx, synthetic_sel);
.get_expr_at_idx(col_idx, synthetic_sel));

let multiplier = range_check
.selector_constructor
@@ -929,6 +942,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
res
});
}

// add a degree-k custom constraint of the following form to the range check and
// static lookup configuration.
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 − 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
cs.create_gate("range_check_on_sel", |cs| {
let synthetic_sel = match len {
1 => Expression::Constant(F::from(1)),
_ => match index {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
},
};

let range_check_on_synthetic_sel = match len {
1 => Expression::Constant(F::from(0)),
_ => {
let mut initial_expr = Expression::Constant(F::from(1));
for i in 0..len {
initial_expr = initial_expr
* (synthetic_sel.clone()
- Expression::Constant(F::from(i as u64)))
}
initial_expr
}
};

let sel = cs.query_selector(multi_col_selector);

Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
});

self.range_checks
.selectors
.insert((range, x, y), multi_col_selector);

@@ -25,7 +25,7 @@ pub enum CircuitError {
/// This operation is unsupported
#[error("unsupported operation in graph")]
UnsupportedOp,
///
/// Invalid einsum expression
#[error("invalid einsum expression")]
InvalidEinsum,
/// Flush error
@@ -103,4 +103,10 @@ pub enum CircuitError {
#[error("an element is missing from the shuffled version of the tensor")]
/// An element is missing from the shuffled version of the tensor
MissingShuffleElement,
/// Visibility has not been set
#[error("visibility has not been set")]
UnsetVisibility,
/// A decomposition base overflowed
#[error("decomposition base overflowed")]
DecompositionBaseOverflow,
}

@@ -76,7 +76,10 @@ pub enum HybridOp {
output_scale: utils::F32,
axes: Vec<usize>,
},
RangeCheck(Tolerance),
Output {
tol: Tolerance,
decomp: bool,
},
Greater,
GreaterEqual,
Less,
@@ -178,7 +181,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
input_scale, output_scale, axes
)
}
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
HybridOp::Output { tol, decomp } => {
format!("OUTPUT (tol={:?}, decomp={})", tol, decomp)
}
HybridOp::Greater => "GREATER".to_string(),
HybridOp::GreaterEqual => "GREATEREQUAL".to_string(),
HybridOp::Less => "LESS".to_string(),
@@ -314,12 +319,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
*output_scale,
axes,
)?,
HybridOp::RangeCheck(tol) => layouts::range_check_percent(
HybridOp::Output { tol, decomp } => layouts::output(
config,
region,
values[..].try_into()?,
tol.scale,
tol.val,
*decomp,
)?,
HybridOp::Greater => layouts::greater(config, region, values[..].try_into()?)?,
HybridOp::GreaterEqual => {

@@ -11,7 +11,6 @@ use log::{error, trace};
use maybe_rayon::{
iter::IntoParallelRefIterator,
prelude::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator},
slice::ParallelSliceMut,
};

use self::tensor::{create_constant_tensor, create_zero_tensor};
@@ -75,7 +74,7 @@ fn optimum_convex_function<F: PrimeField + TensorType + PartialOrd + std::hash::
region: &mut RegionCtx<F>,
x: &ValTensor<F>,
f: impl Fn(&BaseConfig<F>, &mut RegionCtx<F>, &ValTensor<F>) -> Result<ValTensor<F>, CircuitError>,
) -> Result<(), CircuitError> {
) -> Result<ValTensor<F>, CircuitError> {
let one = create_constant_tensor(F::from(1), 1);

let f_x = f(config, region, x)?;
@@ -87,22 +86,17 @@ fn optimum_convex_function<F: PrimeField + TensorType + PartialOrd + std::hash::
let f_x_minus_1 = f(config, region, &x_minus_1)?;

// because the function is convex, the result should be the minimum of the three
// not that we offset the x by 1 to get the next value
// f(x) <= f(x+1) and f(x) <= f(x-1)
// note that we offset the x by 1 to get the next value
// f(x) <= f(x+1) and f(x) < f(x-1)
// the result is 1 if the function is optimal solely because of the convexity of the function
// the distances can be equal but this is only possible if f(x) and f(x+1) are both optimal (or f(x) and f(x-1)).
// the distances can be equal but this is only possible if f(x) and f(x+1) are both optimal, but if (f(x) = f(x + 1))
// f(x+1) is not smaller than f(x + 1 - 1) = f(x) and thus f(x) is unique
let f_x_is_opt_rhs = less_equal(config, region, &[f_x.clone(), f_x_plus_1.clone()])?;
let f_x_is_opt_lhs = less_equal(config, region, &[f_x.clone(), f_x_minus_1.clone()])?;
let f_x_is_opt_lhs = less(config, region, &[f_x.clone(), f_x_minus_1.clone()])?;

let is_opt = and(config, region, &[f_x_is_opt_lhs, f_x_is_opt_rhs])?;

let mut comparison_unit = create_constant_tensor(F::ONE, is_opt.len());
comparison_unit.reshape(is_opt.dims())?;

// assert that the result is 1
enforce_equality(config, region, &[is_opt, comparison_unit])?;

Ok(())
Ok(is_opt)
}
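optimum_convex_function now returns the indicator tensor instead of asserting it in place, so callers like recip can mask out points before enforcing equality. A standalone sketch of the optimality predicate the comments above describe (assumed shape of the check, plain Rust rather than the circuit ops):

/// For a convex f over the integers, x is the unique minimizer iff
/// f(x) <= f(x + 1) and f(x) < f(x - 1): if f(x) == f(x + 1), then x + 1
/// fails the strict left-hand test, so only one point passes both.
fn is_optimum(f: impl Fn(i64) -> i64 + Copy, x: i64) -> bool {
    f(x) <= f(x + 1) && f(x) < f(x - 1)
}

fn main() {
    // err_func-style objective: distance of y*y from a rescaled input
    let f = |y: i64| (y * y - 10).abs();
    assert!(is_optimum(f, 3));
    assert!(!is_optimum(f, 4));
}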

/// Err is less than some constant
@@ -160,13 +154,15 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
.into()
};
claimed_output.reshape(input_dims)?;
let claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
region.increment(claimed_output.len());

// here we decompose and extract the sign of the input
let sign = sign(config, region, &[claimed_output.clone()])?;
// implicitly check if the prover provided output is within range
let claimed_output = identity(config, region, &[claimed_output], true)?;
// check if x is too large only if the decomp would support overflow in the previous op
if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
if F::from_u128(IntegerRep::MAX as u128)
< F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
{
// here we decompose and extract the sign of the input
let sign = sign(config, region, &[claimed_output.clone()])?;

let abs_value = pairwise(
config,
region,
@@ -226,9 +222,9 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
.into()
};
claimed_output.reshape(input_dims)?;
let claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
region.increment(claimed_output.len());

// implicitly check if the prover provided output is within range
let claimed_output = identity(config, region, &[claimed_output], true)?;
// divide by input_scale
let zero_inverse_val =
tensor::ops::nonlinearities::zero_recip(felt_to_integer_rep(output_scale) as f64)[0];
@@ -259,10 +255,12 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
BaseOp::Mult,
)?;

// here we decompose and extract the sign of the input
let sign = sign(config, region, &[masked_output.clone()])?;
// check if x is too large only if the decomp would support overflow in the previous op
if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
if F::from_u128(IntegerRep::MAX as u128)
< F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
{
// here we decompose and extract the sign of the input
let sign = sign(config, region, &[masked_output.clone()])?;
let abs_value = pairwise(
config,
region,
@@ -290,7 +288,14 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
Ok(distance)
};

optimum_convex_function(config, region, &claimed_output, err_func)?;
// we need to add 1 to the points where it is zero to ignore the cvx opt conditions at those points
let mut is_opt = optimum_convex_function(config, region, &claimed_output, err_func)?;
is_opt = pairwise(config, region, &[is_opt, equal_zero_mask], BaseOp::Add)?;

let mut comparison_unit = create_constant_tensor(F::ONE, is_opt.len());
comparison_unit.reshape(is_opt.dims())?;
// assert that the result is 1
enforce_equality(config, region, &[is_opt, comparison_unit])?;

Ok(claimed_output)
}
@@ -344,12 +349,8 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
.into()
};
claimed_output.reshape(input_dims)?;
let claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
region.increment(claimed_output.len());

// force the output to be positive or zero
// force the output to be positive or zero, also implicitly checks that the output is in range
let claimed_output = abs(config, region, &[claimed_output.clone()])?;

// rescaled input
let rescaled_input = pairwise(config, region, &[input.clone(), unit_scale], BaseOp::Mult)?;

@@ -362,7 +363,13 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
Ok(distance)
};

optimum_convex_function(config, region, &claimed_output, err_func)?;
let is_opt = optimum_convex_function(config, region, &claimed_output, err_func)?;

let mut comparison_unit = create_constant_tensor(F::ONE, is_opt.len());
comparison_unit.reshape(is_opt.dims())?;

// assert that the result is 1
enforce_equality(config, region, &[is_opt, comparison_unit])?;

Ok(claimed_output)
}
@@ -899,11 +906,23 @@ pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
Ok(output)
}

#[derive(Debug, Clone, Copy)]
/// Determines how to handle collisions in sorting.
pub enum SortCollisionMode {
/// Do not sort (no rule)
Unsorted,
/// Sort by smallest index first
SmallestIndexFirst,
/// Sort by largest index first on collision
LargestIndexFirst,
}

fn _sort_ascending<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError> {
collision_handling: SortCollisionMode,
) -> Result<(ValTensor<F>, ValTensor<F>), CircuitError> {
let mut input = values[0].clone();
input.flatten();

@@ -911,7 +930,7 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(

let sorted = if is_assigned {
let mut int_evals = input.int_evals()?;
int_evals.par_sort_unstable_by(|a, b| a.cmp(b));
int_evals.sort_unstable();
int_evals
.par_iter()
.map(|x| Value::known(integer_rep_to_felt(*x)))
@@ -924,21 +943,73 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
};

let assigned_sort = region.assign(&config.custom_gates.inputs[0], &sorted.into())?;

region.increment(assigned_sort.len());
// assert that this is a permutation/shuffle
let indices = shuffles(
config,
region,
&[assigned_sort.clone()],
&[input.clone()],
collision_handling,
)?;

let window_a = assigned_sort.get_slice(&[0..assigned_sort.len() - 1])?;
let window_b = assigned_sort.get_slice(&[1..assigned_sort.len()])?;

let is_greater = greater_equal(config, region, &[window_b.clone(), window_a.clone()])?;
let unit = create_unit_tensor(is_greater.len());
let indices_a = indices.get_slice(&[0..indices.len() - 1])?;
let indices_b = indices.get_slice(&[1..indices.len()])?;

enforce_equality(config, region, &[unit, is_greater])?;
let unit = create_unit_tensor(window_a.len());

// assert that this is a permutation/shuffle
shuffles(config, region, &[assigned_sort.clone()], &[input.clone()])?;
match collision_handling {
SortCollisionMode::Unsorted => {
let is_greater = greater_equal(config, region, &[window_b.clone(), window_a.clone()])?;
enforce_equality(config, region, &[unit, is_greater])?;
}
SortCollisionMode::SmallestIndexFirst => {
let is_greater = greater(config, region, &[window_b.clone(), window_a.clone()])?;
let is_equal = equals(config, region, &[window_b.clone(), window_a.clone()])?;
let is_greater_indices =
greater(config, region, &[indices_b.clone(), indices_a.clone()])?;

Ok(assigned_sort)
let is_equal_and_is_greater_indices =
and(config, region, &[is_equal, is_greater_indices])?;

let is_greater_or_is_equal_and_is_greater_indices = or(
config,
region,
&[is_greater, is_equal_and_is_greater_indices],
)?;

enforce_equality(
config,
region,
&[unit, is_greater_or_is_equal_and_is_greater_indices],
)?;
}
SortCollisionMode::LargestIndexFirst => {
let is_greater = greater(config, region, &[window_b.clone(), window_a.clone()])?;
let is_equal = equals(config, region, &[window_b.clone(), window_a.clone()])?;
let is_lesser_indices = less(config, region, &[indices_b.clone(), indices_a.clone()])?;

let is_equal_and_is_lesser_indices =
and(config, region, &[is_equal, is_lesser_indices])?;

let is_greater_or_is_equal_and_is_greater_indices = or(
config,
region,
&[is_greater, is_equal_and_is_lesser_indices],
)?;

enforce_equality(
config,
region,
&[unit, is_greater_or_is_equal_and_is_greater_indices],
)?;
}
}

Ok((assigned_sort, indices))
}
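Each match arm above pins down one adjacent-pair ordering over the claimed sort and its index column. A plain-Rust sketch of the predicate per collision mode (names are illustrative, not the ezkl API):

enum Mode {
    Unsorted,
    SmallestIndexFirst,
    LargestIndexFirst,
}

// adjacent entries (value, claimed source index) of the claimed sort
fn adjacent_ok(mode: &Mode, (va, ia): (i64, usize), (vb, ib): (i64, usize)) -> bool {
    match mode {
        // values only need to be non-decreasing
        Mode::Unsorted => vb >= va,
        // ties broken by the smaller input index coming first
        Mode::SmallestIndexFirst => vb > va || (vb == va && ib > ia),
        // ties broken by the larger input index coming first
        Mode::LargestIndexFirst => vb > va || (vb == va && ib < ia),
    }
}

fn main() {
    assert!(adjacent_ok(&Mode::SmallestIndexFirst, (5, 0), (5, 2)));
    assert!(!adjacent_ok(&Mode::LargestIndexFirst, (5, 0), (5, 2)));
}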

/// Returns top K values.
@@ -949,7 +1020,7 @@ fn _select_topk<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
k: usize,
largest: bool,
) -> Result<ValTensor<F>, CircuitError> {
let mut sorted = _sort_ascending(config, region, values)?;
let mut sorted = _sort_ascending(config, region, values, SortCollisionMode::Unsorted)?.0;
if largest {
sorted.reverse()?;
}
@@ -1177,8 +1248,6 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash

region.enable(Some(lookup_selector), z)?;

// region.enable(Some(lookup_selector), z)?;

Ok(())
})
.collect::<Result<Vec<_>, CircuitError>>()?;
@@ -1198,12 +1267,13 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
/// 4. index_output is (typically) a prover generated witness committed to in an advice column
/// 5. value_output is (typically) a prover generated witness committed to in an advice column
/// 6. Given the above, and given the fixed index_input , we go through every (index_input, value_input) pair and ascertain that it is contained in the input.
/// Given the fixed incrementing index index_input, we avoid multiplicity in the output by leveraging this surrogate index: if index_output isn't matched to the exact value where for `index_input=index_output` -> `value_input=value_output`, then the lookup fails
/// 7. Given the fixed incrementing index index_input, we avoid multiplicity in the output by leveraging this surrogate index: if index_output isn't matched to the exact value where for `index_input=index_output` -> `value_input=value_output`, then the lookup fails
pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
output: &[ValTensor<F>; 1],
input: &[ValTensor<F>; 1],
collision_handling: SortCollisionMode,
) -> Result<ValTensor<F>, CircuitError> {
let shuffle_index = region.shuffle_index();
let (output, input) = (output[0].clone(), input[0].clone());
@@ -1247,13 +1317,26 @@ pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
.iter()
.map(|x| {
// Find all positions of the current element
let positions: Vec<usize> = input
let mut positions: Vec<usize> = input
.iter()
.enumerate()
.filter(|(_, y)| *y == x)
.map(|(i, _)| i)
.collect();

match collision_handling {
SortCollisionMode::Unsorted => {}
SortCollisionMode::SmallestIndexFirst => {
// Sort the positions by the index of the input element
positions.sort_unstable_by(|a, b| input[*a].cmp(&input[*b]));
}

SortCollisionMode::LargestIndexFirst => {
// Sort the positions by the index of the input element
positions.reverse();
}
}

// Find the first unused position for this element
let pos = positions
.iter()
@@ -1326,7 +1409,7 @@ pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
region.increment_shuffle_index(1);
region.increment(output_len);

Ok(output)
Ok(claimed_index_output)
}
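A standalone sketch of what the surrogate index described in the doc comment buys (illustrative, not the circuit code): each output position claims the input position it came from, so a duplicated value cannot be matched to the same input row twice.

fn check_shuffle(input: &[i64], output: &[i64], claimed_idx: &[usize]) -> bool {
    let mut used = vec![false; input.len()];
    output.len() == input.len()
        && claimed_idx.len() == output.len()
        && claimed_idx.iter().zip(output).all(|(&i, &v)| {
            // the (claimed index, value) pair must exist in the input,
            // and each input position may be claimed at most once
            input.get(i) == Some(&v) && !std::mem::replace(&mut used[i], true)
        })
}

fn main() {
    assert!(check_shuffle(&[7, 7, 3], &[3, 7, 7], &[2, 0, 1]));
    // both 7s claiming input position 0 is rejected
    assert!(!check_shuffle(&[7, 7, 3], &[3, 7, 7], &[2, 0, 0]));
}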

/// One hot accumulated layout
@@ -1795,11 +1878,18 @@ pub(crate) fn get_missing_set_elements<
region,
&[input_and_claimed_output.clone()],
&[fullset.clone()],
SortCollisionMode::Unsorted,
)?;

if ordered {
// assert that the claimed output is sorted
claimed_output = _sort_ascending(config, region, &[claimed_output])?;
claimed_output = _sort_ascending(
config,
region,
&[claimed_output],
SortCollisionMode::Unsorted,
)?
.0;
}

Ok(claimed_output)
@@ -2566,9 +2656,9 @@ pub fn mean_of_squares_axes<F: PrimeField + TensorType + PartialOrd + std::hash:
let squared = pow(config, region, values, 2)?;
let sum_squared = sum_axes(config, region, &[squared], axes)?;

let dividand: usize = values[0].len() / sum_squared.len();
let dividend: usize = values[0].len() / sum_squared.len();

let mean_squared = div(config, region, &[sum_squared], F::from(dividand as u64))?;
let mean_squared = div(config, region, &[sum_squared], F::from(dividend as u64))?;
Ok(mean_squared)
}

@@ -2977,7 +3067,7 @@ pub fn xor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let lhs_and_rhs_not = and(config, region, &[lhs, rhs_not.clone()])?;
let lhs_not_and_rhs = and(config, region, &[rhs, lhs_not])?;

// we can safely use add and not OR here because we know that lhs_and_rhs_not and lhs_not_and_rhs are =1 at different incices
// we can safely use add and not OR here because we know that lhs_and_rhs_not and lhs_not_and_rhs are =1 at different indices
let res: ValTensor<F> = pairwise(
config,
region,
@@ -3254,11 +3344,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
.map(|(i, d)| {
let d = padding[i].0 + d + padding[i].1;
d.checked_sub(pool_dims[i])
.ok_or_else(|| TensorError::Overflow("conv".to_string()))?
.ok_or_else(|| TensorError::Overflow("max_pool".to_string()))?
.checked_div(stride[i])
.ok_or_else(|| TensorError::Overflow("conv".to_string()))?
.ok_or_else(|| TensorError::Overflow("max_pool".to_string()))?
.checked_add(1)
.ok_or_else(|| TensorError::Overflow("conv".to_string()))
.ok_or_else(|| TensorError::Overflow("max_pool".to_string()))
})
.collect::<Result<Vec<_>, TensorError>>()?;

@@ -3917,11 +4007,24 @@ pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
decomp: bool,
) -> Result<ValTensor<F>, CircuitError> {
let mut output = values[0].clone();
if !output.all_prev_assigned() {
output = region.assign(&config.custom_gates.output, &values[0])?;
region.increment(output.len());
// checks they are in range
if decomp {
output = decompose(
config,
region,
&[output.clone()],
&region.base(),
&region.legs(),
)?
.1;
} else {
output = region.assign(&config.custom_gates.output, &values[0])?;
region.increment(output.len());
}
}

Ok(output)
@@ -3942,23 +4045,8 @@ pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd + std::ha
} else {
values[0].clone()
};
// Enable the selectors
if !region.is_dummy() {
(0..output.len())
.map(|j| {
let index = region.linear_coord() - j - 1;

let (x, y, z) = config.custom_gates.output.cartesian_coord(index);
let selector = config
.custom_gates
.selectors
.get(&(BaseOp::IsBoolean, x, y));

region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, CircuitError>>()?;
}
range_check(config, region, values, &(0, 1))?;

Ok(output)
}
@@ -4209,9 +4297,11 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&[values[0].clone(), assigned_argmax.clone()],
)?;

let max_val = max(config, region, &[values[0].clone()])?;
let (sorted_val, indices) =
_sort_ascending(config, region, values, SortCollisionMode::LargestIndexFirst)?;

enforce_equality(config, region, &[claimed_val, max_val])?;
enforce_equality(config, region, &[claimed_val, sorted_val.last()?])?;
enforce_equality(config, region, &[assigned_argmax.clone(), indices.last()?])?;

Ok(assigned_argmax)
}
@@ -4245,9 +4335,14 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region,
&[values[0].clone(), assigned_argmin.clone()],
)?;
let min_val = min(config, region, &[values[0].clone()])?;

enforce_equality(config, region, &[claimed_val, min_val])?;
let (min_val, indices) = _sort_ascending(
config,
region,
values,
SortCollisionMode::SmallestIndexFirst,
)?;
enforce_equality(config, region, &[claimed_val, min_val.first()?])?;
enforce_equality(config, region, &[assigned_argmin.clone(), indices.first()?])?;

Ok(assigned_argmin)
}
@@ -4362,7 +4457,11 @@ pub(crate) fn max<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError> {
Ok(_sort_ascending(config, region, values)?.last()?)
Ok(
_sort_ascending(config, region, values, SortCollisionMode::Unsorted)?
.0
.last()?,
)
}

/// min layout
@@ -4371,7 +4470,11 @@ pub(crate) fn min<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError> {
Ok(_sort_ascending(config, region, values)?.first()?)
Ok(
_sort_ascending(config, region, values, SortCollisionMode::Unsorted)?
.0
.first()?,
)
}

/// floor layout
@@ -4411,7 +4514,7 @@ pub fn floor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
legs: usize,
) -> Result<ValTensor<F>, CircuitError> {
// decompose with base scale and then set the last element to zero
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?.0;
// set the last element to zero and then recompose, we don't actually need to assign here
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
let zero = ValType::Constant(F::ZERO);
@@ -4524,7 +4627,7 @@ pub fn ceil<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
legs: usize,
) -> Result<ValTensor<F>, CircuitError> {
// decompose with base scale and then set the last element to zero
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?.0;
// set the last element to zero and then recompose, we don't actually need to assign here
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
let zero = ValType::Constant(F::ZERO);
@@ -4678,7 +4781,7 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
.into()
};
claimed_output.reshape(input.dims())?;
region.assign(&config.custom_gates.output, &claimed_output)?;
let claimed_output = identity(&config, region, &[claimed_output], true)?;
region.increment(claimed_output.len());

let pow2_of_claimed_output = nonlinearity(
@@ -4924,7 +5027,7 @@ pub fn round<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
legs: usize,
) -> Result<ValTensor<F>, CircuitError> {
// decompose with base scale and then set the last element to zero
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?.0;
// set the last element to zero and then recompose, we don't actually need to assign here
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
let zero = ValType::Constant(F::ZERO);
@@ -4942,6 +5045,7 @@ pub fn round<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
1,
);
let assigned_midway_point = region.assign(&config.custom_gates.inputs[1], &midway_point)?;
region.increment(assigned_midway_point.len());

let dims = decomposition.dims().to_vec();
let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec();
@@ -5067,7 +5171,7 @@ pub fn round_half_to_even<F: PrimeField + TensorType + PartialOrd + std::hash::H
legs: usize,
) -> Result<ValTensor<F>, CircuitError> {
// decompose with base scale and then set the last element to zero
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?.0;
// set the last element to zero and then recompose, we don't actually need to assign here
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
let zero = ValType::Constant(F::ZERO);
@@ -5175,58 +5279,64 @@ pub(crate) fn recompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has
values: &[ValTensor<F>; 1],
base: &usize,
) -> Result<ValTensor<F>, CircuitError> {
let input = values[0].clone();
let mut input = values[0].clone();

let first_dims = input.dims().to_vec()[..input.dims().len() - 1].to_vec();
let num_first_dims = first_dims.iter().product::<usize>();
let n = input.dims().last().unwrap() - 1;

let is_assigned = !input.all_prev_assigned();
if !input.all_prev_assigned() {
input = region.assign(&config.custom_gates.inputs[0], &input)?;
region.increment(input.len());
}

let bases: ValTensor<F> = Tensor::from(
(0..n)
.rev()
.map(|x| ValType::Constant(integer_rep_to_felt(base.pow(x as u32) as IntegerRep))),
)
// to force the bases to be assigned
if input.is_singleton() {
input.reshape(&[1])?;
}

let mut bases: ValTensor<F> = Tensor::from({
(0..num_first_dims)
.flat_map(|_| {
(0..n).rev().map(|x| {
let base = (*base).checked_pow(x as u32);
if let Some(base) = base {
Ok(ValType::Constant(integer_rep_to_felt(base as IntegerRep)))
} else {
Err(CircuitError::DecompositionBaseOverflow)
}
})
})
.collect::<Result<Vec<_>, CircuitError>>()?
.into_iter()
})
.into();
let mut bases_dims = first_dims.clone();
bases_dims.push(n);
bases.reshape(&bases_dims)?;

// multiply and sum the values
let mut output: Tensor<Tensor<ValType<F>>> = Tensor::new(None, &first_dims)?;
// equation needs to be constructed as ij,j->i but for arbitrary n dims we need to construct this dynamically
// indices should map in order of the alphabet
// start with lhs
let lhs = ASCII_ALPHABET.chars().take(input.dims().len()).join("");
let rhs = ASCII_ALPHABET.chars().take(input.dims().len() - 1).join("");

let cartesian_coord = first_dims
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let equation = format!("{},{}->{}", lhs, lhs, rhs);

let inner_loop_function =
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
let coord = cartesian_coord[i].clone();
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let mut sliced_input = input.get_slice(&slice)?;
sliced_input.flatten();
let mut sign_slice = first_dims.iter().map(|x| 0..*x).collect::<Vec<_>>();
sign_slice.push(0..1);
let mut rest_slice = first_dims.iter().map(|x| 0..*x).collect::<Vec<_>>();
rest_slice.push(1..n + 1);

if !is_assigned {
sliced_input = region.assign(&config.custom_gates.inputs[0], &sliced_input)?;
}
let sign = input.get_slice(&sign_slice)?;
let rest = input.get_slice(&rest_slice)?;

// get the sign bit and make sure it is valid
let sign = sliced_input.first()?;
let rest = sliced_input.get_slice(&[1..sliced_input.len()])?;
// now add the rhs
let prod_recomp = einsum(config, region, &[rest.clone(), bases], &equation)?;
let mut signed_recomp = pairwise(config, region, &[prod_recomp, sign], BaseOp::Mult)?;
signed_recomp.reshape(&first_dims)?;

let prod_decomp = dot(config, region, &[rest, bases.clone()])?;

let signed_decomp = pairwise(config, region, &[prod_decomp, sign], BaseOp::Mult)?;

Ok(signed_decomp.get_inner_tensor()?.clone())
};

region.apply_in_loop(&mut output, inner_loop_function)?;

let mut combined_output = output.combine()?;

combined_output.reshape(&first_dims)?;

Ok(combined_output.into())
Ok(signed_recomp.into())
}

pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
@@ -5235,24 +5345,35 @@ pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has
values: &[ValTensor<F>; 1],
base: &usize,
n: &usize,
) -> Result<ValTensor<F>, CircuitError> {
) -> Result<(ValTensor<F>, ValTensor<F>), CircuitError> {
let mut input = values[0].clone();

let is_assigned = !input.all_prev_assigned();

if !is_assigned {
if !input.all_prev_assigned() {
input = region.assign(&config.custom_gates.inputs[0], &input)?;
}

let mut bases: ValTensor<F> = Tensor::from(
// repeat it input.len() times
(0..input.len()).flat_map(|_| {
(0..*n)
.rev()
.map(|x| ValType::Constant(integer_rep_to_felt(base.pow(x as u32) as IntegerRep)))
}),
)
// to force the bases to be assigned
if input.is_singleton() {
input.reshape(&[1])?;
}

let mut bases: ValTensor<F> = Tensor::from({
(0..input.len())
.flat_map(|_| {
(0..*n).rev().map(|x| {
let base = (*base).checked_pow(x as u32);
if let Some(base) = base {
Ok(ValType::Constant(integer_rep_to_felt(base as IntegerRep)))
} else {
Err(CircuitError::DecompositionBaseOverflow)
}
})
})
.collect::<Result<Vec<_>, CircuitError>>()?
.into_iter()
})
.into();

let mut bases_dims = input.dims().to_vec();
bases_dims.push(*n);
bases.reshape(&bases_dims)?;
@@ -5271,7 +5392,7 @@ pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has

claimed_output.into()
};
region.assign(&config.custom_gates.output, &claimed_output)?;
let claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
region.increment(claimed_output.len());

let input_slice = input.dims().iter().map(|x| 0..*x).collect::<Vec<_>>();
@@ -5316,9 +5437,9 @@ pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has

let signed_decomp = pairwise(config, region, &[prod_decomp, sign], BaseOp::Mult)?;

enforce_equality(config, region, &[input, signed_decomp])?;
enforce_equality(config, region, &[input.clone(), signed_decomp])?;

Ok(claimed_output)
Ok((claimed_output, input))
}
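Spelled out, the representation both recompose and decompose now enforce is, for base b and n legs, x = s · (d₀·b^(n−1) + d₁·b^(n−2) + ⋯ + d_(n−1)·b⁰), where the sign s is stored as the first element of each decomposition and the digits dᵢ follow most-significant first (the bases run (0..n).rev()). The checked_pow guard surfaces a base overflow as CircuitError::DecompositionBaseOverflow instead of wrapping silently.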
|
||||
pub(crate) fn sign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
@@ -5326,7 +5447,7 @@ pub(crate) fn sign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let mut decomp = decompose(config, region, values, ®ion.base(), ®ion.legs())?;
|
||||
let mut decomp = decompose(config, region, values, ®ion.base(), ®ion.legs())?.0;
|
||||
// get every n elements now, which correspond to the sign bit
|
||||
decomp.get_every_n(region.legs() + 1)?;
|
||||
decomp.reshape(values[0].dims())?;
|
||||
@@ -5608,7 +5729,7 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::range_check_percent;
|
||||
/// use ezkl::circuit::ops::layouts::output;
|
||||
/// use ezkl::tensor::val::ValTensor;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
@@ -5626,28 +5747,32 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// Some(&[101, 201, 302, 403, 503, 603]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap());
|
||||
/// let result = range_check_percent::<Fp>(&dummy_config, &mut dummy_region, &[x, y], 1024.0.into(), 1.0).unwrap();
|
||||
/// let result = output::<Fp>(&dummy_config, &mut dummy_region, &[x, y], 1024.0.into(), 1.0, false).unwrap();
|
||||
/// ```
|
||||
pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
pub fn output<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
scale: utils::F32,
|
||||
tol: f32,
|
||||
decomp: bool,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
if tol == 0.0 {
|
||||
// regular equality constraint
|
||||
return enforce_equality(config, region, values);
|
||||
}
|
||||
|
||||
let mut values = [values[0].clone(), values[1].clone()];
|
||||
|
||||
values[0] = region.assign(&config.custom_gates.inputs[0], &values[0])?;
|
||||
values[1] = region.assign(&config.custom_gates.inputs[1], &values[1])?;
|
||||
let total_assigned_0 = values[0].len();
|
||||
let total_assigned_1 = values[1].len();
|
||||
let total_assigned = std::cmp::max(total_assigned_0, total_assigned_1);
|
||||
region.increment(total_assigned);
|
||||
if !values[0].all_prev_assigned() {
|
||||
// range check the outputs
|
||||
values[0] = layouts::identity(config, region, &[values[0].clone()], decomp)?;
|
||||
}
|
||||
|
||||
if !values[1].all_prev_assigned() {
|
||||
// range check the outputs
|
||||
values[1] = layouts::identity(config, region, &[values[1].clone()], decomp)?;
|
||||
}
|
||||
|
||||
if tol == 0.0 {
|
||||
// regular equality constraint
|
||||
return enforce_equality(config, region, &[values[0].clone(), values[1].clone()]);
|
||||
}
|
||||
|
||||
// Calculate the difference between the expected output and actual output
|
||||
let diff = pairwise(config, region, &values, BaseOp::Sub)?;
|
||||
|
||||
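A standalone sketch of the percent-tolerance relation the nonzero-tol path goes on to encode via diff (assumed semantics of Tolerance, not the circuit code):

// values pass if they differ by at most `tol` percent of the expected value
fn within_tolerance(expected: f64, actual: f64, tol_percent: f64) -> bool {
    (expected - actual).abs() <= (tol_percent / 100.0) * expected.abs()
}

fn main() {
    assert!(within_tolerance(100.0, 99.5, 1.0));
    assert!(!within_tolerance(100.0, 98.0, 1.0));
}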
@@ -159,6 +159,8 @@ pub struct Input {
pub scale: crate::Scale,
///
pub datum_type: InputType,
/// decomp check
pub decomp: bool,
}

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
@@ -196,6 +198,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
config,
region,
values[..].try_into()?,
self.decomp,
)?)),
}
} else {
@@ -251,20 +254,26 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
///
#[serde(skip)]
pub pre_assigned_val: Option<ValTensor<F>>,
///
pub decomp: bool,
}

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
///
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>, decomp: bool) -> Self {
Self {
quantized_values,
raw_values,
pre_assigned_val: None,
decomp,
}
}
/// Rebase the scale of the constant
pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), CircuitError> {
let visibility = self.quantized_values.visibility().unwrap();
let visibility = match self.quantized_values.visibility() {
Some(v) => v,
None => return Err(CircuitError::UnsetVisibility),
};
self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
Ok(())
}
@@ -308,7 +317,12 @@ impl<
self.quantized_values.clone().try_into()?
};
// we gotta constrain it once if its used multiple times
Ok(Some(layouts::identity(config, region, &[value])?))
Ok(Some(layouts::identity(
config,
region,
&[value],
self.decomp,
)?))
}

fn clone_dyn(&self) -> Box<dyn Op<F>> {

@@ -252,6 +252,12 @@ impl<
)?,
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
if values.len() != 1 {
return Err(TensorError::DimError(
"GatherElements only accepts single inputs".to_string(),
)
.into());
}
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
@@ -269,6 +275,12 @@ impl<
}
PolyOp::ScatterElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
if values.len() != 2 {
return Err(TensorError::DimError(
"ScatterElements requires two inputs".to_string(),
)
.into());
}
tensor::ops::scatter(
values[0].get_inner_tensor()?,
idx,
@@ -311,7 +323,9 @@ impl<
PolyOp::Mult => {
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
}
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Identity { .. } => {
layouts::identity(config, region, values[..].try_into()?, false)?
}
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
if values.len() != 1 {

@@ -132,21 +132,16 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
(first_element, op_f.output[0])
}

///
/// calculates the column size given the number of rows and reserved blinding rows
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(logrows as u32) - reserved_blinding_rows
}

///
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(bits as u32) - reserved_blinding_rows
}
}

///
pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
// number of cols needed to store the range
(range_len / (col_size as IntegerRep)) as usize + 1
(range_len / col_size as IntegerRep) as usize + 1
}

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
@@ -168,7 +163,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
range: Range,
logrows: usize,
nonlinearity: &LookupOp,
preexisting_inputs: Option<Vec<TableColumn>>,
preexisting_inputs: &mut Vec<TableColumn>,
) -> Table<F> {
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
let col_size = Self::cal_col_size(logrows, factors);
@@ -177,28 +172,28 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {

debug!("table range: {:?}", range);

let table_inputs = preexisting_inputs.unwrap_or_else(|| {
let mut cols = vec![];
for _ in 0..num_cols {
cols.push(cs.lookup_table_column());
// validate enough columns are provided to store the range
if preexisting_inputs.len() < num_cols {
// add columns to match the required number of columns
let diff = num_cols - preexisting_inputs.len();
for _ in 0..diff {
preexisting_inputs.push(cs.lookup_table_column());
}
cols
});

let num_cols = table_inputs.len();
}

let num_cols = preexisting_inputs.len();
if num_cols > 1 {
warn!("Using {} columns for non-linearity table.", num_cols);
}

let table_outputs = table_inputs
let table_outputs = preexisting_inputs
.iter()
.map(|_| cs.lookup_table_column())
.collect::<Vec<_>>();

Table {
nonlinearity: nonlinearity.clone(),
table_inputs,
table_inputs: preexisting_inputs.clone(),
table_outputs,
is_assigned: false,
selector_constructor: SelectorConstructor::new(num_cols),
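The new signature threads one shared, growable pool of table columns through every call: existing columns are reused and only the shortfall is allocated. The growth step in isolation (a sketch, not the ezkl API):

fn ensure_columns<C>(pool: &mut Vec<C>, num_cols: usize, mut alloc: impl FnMut() -> C) {
    // reuse what exists; allocate only the shortfall
    for _ in pool.len()..num_cols {
        pool.push(alloc());
    }
}

fn main() {
    let mut pool: Vec<u32> = vec![1, 2];
    let mut next = 3;
    ensure_columns(&mut pool, 4, || { let c = next; next += 1; c });
    assert_eq!(pool, vec![1, 2, 3, 4]);
}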
@@ -355,16 +350,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0)
}

///
/// calculates the column size
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(logrows as u32) - reserved_blinding_rows
}

///
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(bits as u32) - reserved_blinding_rows
}

/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in

@@ -1813,6 +1813,7 @@ mod shuffle {
&mut region,
&self.inputs[i],
&self.references[i],
layouts::SortCollisionMode::Unsorted,
)
.map_err(|_| Error::Synthesis)?;
}
@@ -1998,7 +1999,7 @@ mod add_with_overflow_and_poseidon {
let base = BaseConfig::configure(cs, &[a, b], &output, CheckMode::SAFE);
VarTensor::constant_cols(cs, K, 2, false);

let poseidon = PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::configure(cs, ());
let poseidon = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::configure(cs, ());

MyCircuitConfig { base, poseidon }
}
@@ -2008,7 +2009,7 @@ mod add_with_overflow_and_poseidon {
mut config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, WIDTH> =
let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE> =
PoseidonChip::new(config.poseidon.clone());

let assigned_inputs_a =
@@ -2043,11 +2044,9 @@ mod add_with_overflow_and_poseidon {
let b = (0..LEN)
.map(|i| halo2curves::bn256::Fr::from(i as u64 + 1))
.collect::<Vec<_>>();
let commitment_a =
PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(a.clone()).unwrap()[0][0];
let commitment_a = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(a.clone()).unwrap()[0][0];

let commitment_b =
PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(b.clone()).unwrap()[0][0];
let commitment_b = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(b.clone()).unwrap()[0][0];

// parameters
let a = Tensor::from(a.into_iter().map(Value::known));
@@ -2069,13 +2068,11 @@ mod add_with_overflow_and_poseidon {
let b = (0..LEN)
.map(|i| halo2curves::bn256::Fr::from(i as u64 + 1))
.collect::<Vec<_>>();
let commitment_a = PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(a.clone())
.unwrap()[0][0]
+ Fr::one();
let commitment_a =
PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(a.clone()).unwrap()[0][0] + Fr::one();

let commitment_b = PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(b.clone())
.unwrap()[0][0]
+ Fr::one();
let commitment_b =
PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(b.clone()).unwrap()[0][0] + Fr::one();

// parameters
let a = Tensor::from(a.into_iter().map(Value::known));

@@ -517,6 +517,7 @@ pub async fn deploy_da_verifier_via_solidity(
}
}

#[allow(clippy::too_many_arguments)]
async fn deploy_multi_da_contract(
client: EthersClient,
contract_instance_offset: usize,

@@ -118,7 +118,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
} => gen_srs_cmd(
srs_path,
logrows as u32,
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
commitment.unwrap_or_else(|| Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
),
Commands::GetSrs {
srs_path,
@@ -1535,7 +1535,8 @@ pub(crate) async fn create_evm_data_attestation(
trace!("params computed");

// if input is not provided, we just instantiate dummy input data
let data = GraphData::from_path(input).unwrap_or(GraphData::new(DataSource::File(vec![])));
let data =
GraphData::from_path(input).unwrap_or_else(|_| GraphData::new(DataSource::File(vec![])));

// The number of input and output instances we attest to for the single call data attestation
let mut input_len = None;
@@ -2126,6 +2127,7 @@ pub(crate) fn mock_aggregate(
Ok(String::new())
}

#[allow(clippy::too_many_arguments)]
pub(crate) fn setup_aggregate(
sample_snarks: Vec<PathBuf>,
vk_path: PathBuf,

@@ -9,6 +9,8 @@ pub type IntegerRep = i128;
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
if x >= 0 {
F::from_u128(x as u128)
} else if x == IntegerRep::MIN {
-F::from_u128(x.saturating_neg() as u128) - F::ONE
} else {
-F::from_u128(x.saturating_neg() as u128)
}
@@ -32,6 +34,9 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
/// Converts a PrimeField element to an i64.
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
if x > F::from_u128(IntegerRep::MAX as u128) {
if x == -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
return IntegerRep::MIN;
}
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
@@ -51,7 +56,7 @@ mod test {
use halo2curves::pasta::Fp as F;

#[test]
fn test_conv() {
fn integerreptofelt() {
let res: F = integer_rep_to_felt(-15);
assert_eq!(res, -F::from(15));

@@ -73,4 +78,20 @@ mod test {
assert_eq!(x, xf);
}
}

#[test]
fn felttointegerrepmin() {
let x = IntegerRep::MIN;
let fieldx: F = integer_rep_to_felt::<F>(x);
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
assert_eq!(x, xf);
}

#[test]
fn felttointegerrepmax() {
let x = IntegerRep::MAX;
let fieldx: F = integer_rep_to_felt::<F>(x);
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
assert_eq!(x, xf);
}
}
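The new IntegerRep::MIN branch and its round-trip tests exist because i128::MIN has no positive counterpart. A standalone sketch of the edge case (plain Rust, no field arithmetic):

fn main() {
    // -2^127 cannot be negated within i128, so a naive -x would overflow
    assert_eq!(i128::MIN.checked_neg(), None);
    // saturating_neg() lands on i128::MAX, hence the encoding -(MAX) - 1 = -2^127
    assert_eq!(i128::MIN.saturating_neg(), i128::MAX);
    assert_eq!(-i128::MAX - 1, i128::MIN);
}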
||||
|
||||
@@ -11,6 +11,12 @@ pub enum GraphError {
|
||||
/// Shape mismatch in circuit construction
|
||||
#[error("invalid dimensions used for node {0} ({1})")]
|
||||
InvalidDims(usize, String),
|
||||
/// Non scalar power
|
||||
#[error("we only support scalar powers")]
|
||||
NonScalarPower,
|
||||
/// Non scalar base for exponentiation
|
||||
#[error("we only support scalar bases for exponentiation")]
|
||||
NonScalarBase,
|
||||
/// Wrong method was called to configure an op
|
||||
#[error("wrong method was called to configure node {0} ({1})")]
|
||||
WrongMethod(usize, String),
|
||||
@@ -113,13 +119,13 @@ pub enum GraphError {
|
||||
/// Missing input for a node
|
||||
#[error("missing input for node {0}")]
|
||||
MissingInput(usize),
|
||||
///
|
||||
/// Ranges can only be constant
|
||||
#[error("range only supports constant inputs in a zk circuit")]
|
||||
NonConstantRange,
|
||||
///
|
||||
/// Trilu diagonal must be constant
|
||||
#[error("trilu only supports constant diagonals in a zk circuit")]
|
||||
NonConstantTrilu,
|
||||
///
|
||||
/// The witness was too short
|
||||
#[error("insufficient witness values to generate a fixed output")]
|
||||
InsufficientWitnessValues,
|
||||
/// Missing scale
|
||||
@@ -143,4 +149,13 @@ pub enum GraphError {
|
||||
/// Invalid RunArg
|
||||
#[error("invalid RunArgs: {0}")]
|
||||
InvalidRunArgs(String),
|
||||
/// Only nearest neighbor interpolation is supported
|
||||
#[error("only nearest neighbor interpolation is supported")]
|
||||
InvalidInterpolation,
|
||||
/// Node has a missing output
|
||||
#[error("node {0} has a missing output")]
|
||||
MissingOutput(usize),
|
||||
/// Inssuficient advice columns
|
||||
#[error("insuficcient advice columns (need {0} at least)")]
|
||||
InsufficientAdviceColumns(usize),
|
||||
}
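
Since GraphError carries thiserror-style #[error] attributes, each variant formats through its message string. A hedged sketch of how a caller might surface the new InsufficientAdviceColumns variant (assuming the derive shown above):

// Sketch: surfacing the new advice-column check as an error message.
fn check_columns(advices: usize) -> Result<(), GraphError> {
    if advices < 3 {
        return Err(GraphError::InsufficientAdviceColumns(3));
    }
    Ok(())
}

// Displaying the error prints: "insufficient advice columns (need 3 at least)"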

@@ -24,6 +24,7 @@ use tract_onnx::tract_core::{
    tract_data::{prelude::Tensor as TractTensor, TVec},
    value::TValue,
};

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::tract_num_traits::ToPrimitive;

@@ -31,30 +32,95 @@ type Decimals = u8;
type Call = String;
type RPCUrl = String;

-///
+/// Represents different types of values that can be stored in a file source.
+/// Used for handling various input types in zero-knowledge proofs.
#[derive(Clone, Debug, PartialOrd, PartialEq)]
pub enum FileSourceInner {
-    /// Inner elements of float inputs coming from a file
+    /// Floating point value (64-bit)
    Float(f64),
-    /// Inner elements of bool inputs coming from a file
+    /// Boolean value
    Bool(bool),
-    /// Inner elements of inputs coming from a witness
+    /// Field element value for direct use in circuits
    Field(Fp),
}

impl FileSourceInner {
-    ///
+    /// Returns true if the value is a floating point number
    pub fn is_float(&self) -> bool {
        matches!(self, FileSourceInner::Float(_))
    }
-    ///
+    /// Returns true if the value is a boolean
    pub fn is_bool(&self) -> bool {
        matches!(self, FileSourceInner::Bool(_))
    }
-    ///
+    /// Returns true if the value is a field element
    pub fn is_field(&self) -> bool {
        matches!(self, FileSourceInner::Field(_))
    }

    /// Creates a new floating point value
    pub fn new_float(f: f64) -> Self {
        FileSourceInner::Float(f)
    }

    /// Creates a new field element value
    pub fn new_field(f: Fp) -> Self {
        FileSourceInner::Field(f)
    }

    /// Creates a new boolean value
    pub fn new_bool(f: bool) -> Self {
        FileSourceInner::Bool(f)
    }

    /// Adjusts the value according to the specified input type
    ///
    /// # Arguments
    /// * `input_type` - Type specification to convert the value to
    pub fn as_type(&mut self, input_type: &InputType) {
        match self {
            FileSourceInner::Float(f) => input_type.roundtrip(f),
            FileSourceInner::Bool(_) => assert!(matches!(input_type, InputType::Bool)),
            FileSourceInner::Field(_) => {}
        }
    }

    /// Converts the value to a field element using the appropriate scaling
    ///
    /// # Arguments
    /// * `scale` - Scaling factor for floating point conversion
    pub fn to_field(&self, scale: crate::Scale) -> Fp {
        match self {
            FileSourceInner::Float(f) => {
                integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap())
            }
            FileSourceInner::Bool(f) => {
                if *f {
                    Fp::one()
                } else {
                    Fp::zero()
                }
            }
            FileSourceInner::Field(f) => *f,
        }
    }

    /// Converts the value to a floating point number
    pub fn to_float(&self) -> f64 {
        match self {
            FileSourceInner::Float(f) => *f,
            FileSourceInner::Bool(f) => {
                if *f {
                    1.0
                } else {
                    0.0
                }
            }
            FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
        }
    }
}
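
To make the scaling in to_field concrete: with scale = 7 the multiplier is 2^7 = 128, so a float input of 0.5 quantizes to round(0.5 * 128) = 64 before being embedded in the field. A hedged sketch of that arithmetic on plain integers (not the ezkl types):

// Sketch of the quantization step inside to_field.
fn quantize(value: f64, scale: i32) -> i128 {
    let mult = f64::powi(2.0, scale); // scale_to_multiplier semantics: 2^scale
    (value * mult).round() as i128
}

fn main() {
    assert_eq!(quantize(0.5, 7), 64);
    assert_eq!(quantize(-0.25, 7), -32); // negatives embed via field negation, as above
}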

impl Serialize for FileSourceInner {
@@ -70,8 +136,8 @@ impl Serialize for FileSourceInner {
    }
}

-// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
-// UNTAGGED ENUMS WONT WORK :( as highlighted here:
+// Deserialization implementation for FileSourceInner.
+// Uses JSON deserialization to handle the different variants.
impl<'de> Deserialize<'de> for FileSourceInner {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
@@ -98,70 +164,16 @@ impl<'de> Deserialize<'de> for FileSourceInner {
    }
}

-/// Elements of inputs coming from a file
+/// A collection of input values from a file source.
+/// Organized as a vector of vectors where each inner vector represents a row/entry.
pub type FileSource = Vec<Vec<FileSourceInner>>;

-impl FileSourceInner {
-    /// Create a new FileSourceInner
-    pub fn new_float(f: f64) -> Self {
-        FileSourceInner::Float(f)
-    }
-    /// Create a new FileSourceInner
-    pub fn new_field(f: Fp) -> Self {
-        FileSourceInner::Field(f)
-    }
-    /// Create a new FileSourceInner
-    pub fn new_bool(f: bool) -> Self {
-        FileSourceInner::Bool(f)
-    }
-
-    ///
-    pub fn as_type(&mut self, input_type: &InputType) {
-        match self {
-            FileSourceInner::Float(f) => input_type.roundtrip(f),
-            FileSourceInner::Bool(_) => assert!(matches!(input_type, InputType::Bool)),
-            FileSourceInner::Field(_) => {}
-        }
-    }
-
-    /// Convert to a field element
-    pub fn to_field(&self, scale: crate::Scale) -> Fp {
-        match self {
-            FileSourceInner::Float(f) => {
-                integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap())
-            }
-            FileSourceInner::Bool(f) => {
-                if *f {
-                    Fp::one()
-                } else {
-                    Fp::zero()
-                }
-            }
-            FileSourceInner::Field(f) => *f,
-        }
-    }
-    /// Convert to a float
-    pub fn to_float(&self) -> f64 {
-        match self {
-            FileSourceInner::Float(f) => *f,
-            FileSourceInner::Bool(f) => {
-                if *f {
-                    1.0
-                } else {
-                    0.0
-                }
-            }
-            FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
-        }
-    }
-}

-/// Call type for attested inputs on-chain
+/// Represents different types of calls for fetching on-chain data
#[derive(Clone, Debug, PartialOrd, PartialEq)]
pub enum Calls {
-    /// Vector of calls to accounts, each returning an attested data point
+    /// Multiple calls to different accounts, each returning individual values
    Multiple(Vec<CallsToAccount>),
-    /// Single call to account, returning an array of attested data points
+    /// Single call returning an array of values
    Single(CallToAccount),
}

@@ -170,32 +182,6 @@ impl Default for Calls {
        Calls::Multiple(Vec::new())
    }
}
-/// Inner elements of inputs/outputs coming from on-chain
-#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
-pub struct OnChainSource {
-    /// Calls to accounts
-    pub calls: Calls,
-    /// RPC url
-    pub rpc: RPCUrl,
-}
-
-impl OnChainSource {
-    /// Create a new OnChainSource with multiple calls
-    pub fn new_multiple(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
-        OnChainSource {
-            calls: Calls::Multiple(calls),
-            rpc,
-        }
-    }
-
-    /// Create a new OnChainSource with a single call
-    pub fn new_single(call: CallToAccount, rpc: RPCUrl) -> Self {
-        OnChainSource {
-            calls: Calls::Single(call),
-            rpc,
-        }
-    }
-}

impl Serialize for Calls {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
@@ -217,7 +203,6 @@ impl<'de> Deserialize<'de> for Calls {
        D: Deserializer<'de>,
    {
        let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;

        let multiple_try: Result<Vec<CallsToAccount>, _> = serde_json::from_str(this_json.get());
        if let Ok(t) = multiple_try {
            return Ok(Calls::Multiple(t));
@@ -227,111 +212,52 @@ impl<'de> Deserialize<'de> for Calls {
            return Ok(Calls::Single(t));
        }

-        Err(serde::de::Error::custom(
-            "failed to deserialize FileSourceInner",
-        ))
+        Err(serde::de::Error::custom("failed to deserialize Calls"))
    }
}
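
The pattern above — buffer the raw JSON once, then try each candidate shape in turn — is the workaround for serde's untagged-enum limitations that this module uses throughout. A self-contained sketch of the technique with toy types (requires serde_json's "raw_value" feature; Wire is illustrative, not an ezkl type):

use serde::{Deserialize, Deserializer};

#[derive(Debug)]
enum Wire {
    Numbers(Vec<f64>),
    Text(String),
}

impl<'de> Deserialize<'de> for Wire {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Buffer the raw JSON so the same bytes can be parsed several ways.
        let raw: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
        if let Ok(v) = serde_json::from_str::<Vec<f64>>(raw.get()) {
            return Ok(Wire::Numbers(v));
        }
        if let Ok(s) = serde_json::from_str::<String>(raw.get()) {
            return Ok(Wire::Text(s));
        }
        Err(serde::de::Error::custom("failed to deserialize Wire"))
    }
}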

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
-/// Inner elements of inputs/outputs coming from postgres DB
+/// Configuration for accessing on-chain data sources
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct PostgresSource {
    /// postgres host
    pub host: RPCUrl,
    /// user to connect to postgres
    pub user: String,
    /// password to connect to postgres
    pub password: String,
    /// query to execute
    pub query: String,
    /// dbname
    pub dbname: String,
    /// port
    pub port: String,
}

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl PostgresSource {
    /// Create a new PostgresSource
    pub fn new(
        host: RPCUrl,
        port: String,
        user: String,
        query: String,
        dbname: String,
        password: String,
    ) -> Self {
        PostgresSource {
            host,
            user,
            password,
            query,
            dbname,
            port,
        }
    }

    /// Fetch data from postgres
    pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
        // clone to move into thread
        let user = self.user.clone();
        let host = self.host.clone();
        let query = self.query.clone();
        let dbname = self.dbname.clone();
        let port = self.port.clone();
        let password = self.password.clone();

        let config = if password.is_empty() {
            format!(
                "host={} user={} dbname={} port={}",
                host, user, dbname, port
            )
        } else {
            format!(
                "host={} user={} dbname={} port={} password={}",
                host, user, dbname, port, password
            )
        };

        let mut client = Client::connect(&config).await?;
        let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
        // extract rows from query
        for row in client.query(&query, &[]).await? {
            // extract features from row
            for i in 0..row.len() {
                res.push(row.get(i));
            }
        }
        Ok(vec![res])
    }

    /// Fetch data from postgres and format it as a FileSource
    pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
        Ok(self
            .fetch()
            .await?
            .iter()
            .map(|d| {
                d.iter()
                    .map(|d| {
                        FileSourceInner::Float(
                            d.n.as_ref()
                                .unwrap()
                                .to_f64()
                                .ok_or("could not convert decimal to f64")
                                .unwrap(),
                        )
                    })
                    .collect()
            })
            .collect())
    }
pub struct OnChainSource {
    /// Call specifications for fetching data
    pub calls: Calls,
    /// RPC endpoint URL for accessing the chain
    pub rpc: RPCUrl,
}

impl OnChainSource {
    /// Creates a new OnChainSource with multiple calls
    ///
    /// # Arguments
    /// * `calls` - Vector of call specifications
    /// * `rpc` - RPC endpoint URL
    pub fn new_multiple(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
        OnChainSource {
            calls: Calls::Multiple(calls),
            rpc,
        }
    }

    /// Creates a new OnChainSource with a single call
    ///
    /// # Arguments
    /// * `call` - Call specification
    /// * `rpc` - RPC endpoint URL
    pub fn new_single(call: CallToAccount, rpc: RPCUrl) -> Self {
        OnChainSource {
            calls: Calls::Single(call),
            rpc,
        }
    }

    #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
-    /// Create dummy local on-chain data to test the OnChain data source
+    /// Creates test data for the OnChain data source.
+    /// Used for testing and development purposes.
    ///
    /// # Arguments
    /// * `data` - Sample file data to use
    /// * `scales` - Scaling factors for each input
    /// * `shapes` - Shapes of the input tensors
    /// * `rpc` - Optional RPC endpoint override
    pub async fn test_from_file_data(
        data: &FileSource,
        scales: Vec<crate::Scale>,
@@ -398,48 +324,40 @@ impl OnChainSource {
    }
}
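
A hedged usage sketch of the two constructors above, per the CallsToAccount and CallToAccount definitions that follow (the addresses, call data, and RPC URL are placeholders, not real contract calls):

// Sketch: constructing the two OnChainSource shapes.
fn example() {
    let rpc = "http://localhost:8545".to_string();

    // Several accounts, each call returning an individual attested value.
    let multi = OnChainSource::new_multiple(
        vec![CallsToAccount {
            call_data: vec![("0xabcdef".to_string(), 7u8)],
            address: "0x0000000000000000000000000000000000000000".to_string(),
        }],
        rpc.clone(),
    );

    // One call returning an array of attested values.
    let single = OnChainSource::new_single(
        CallToAccount {
            call_data: "0xabcdef".to_string(),
            decimals: 7,
            address: "0x0000000000000000000000000000000000000000".to_string(),
            len: 3,
        },
        rpc,
    );
    let _ = (multi, single);
}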

-/// Defines the view only calls to accounts to fetch the on-chain input data.
-/// This data will be included as part of the first elements in the publicInputs
-/// for the sol evm verifier and will be verified by verifyWithDataAttestation.sol
+/// Specification for view-only calls to fetch on-chain data.
+/// Used for data attestation in smart contract verification.
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct CallsToAccount {
-    /// A vector of tuples, where index 0 of each tuple
-    /// is the byte string representing the ABI encoded function call to
-    /// read the data from the address. This call must return a single
-    /// elementary type (<https://docs.soliditylang.org/en/v0.8.20/abi-spec.html#types>).
-    /// The second index of the tuple is the number of decimals for f32 conversion.
-    /// We don't support dynamic types currently.
+    /// Vector of (call data, decimals) pairs:
+    /// call_data: ABI-encoded function call;
+    /// decimals: number of decimal places for float conversion
    pub call_data: Vec<(Call, Decimals)>,
-    /// Address of the contract to read the data from.
+    /// Contract address to call
    pub address: String,
}

-/// Defines a view only call to an account to fetch the on-chain input data.
-/// This data will be included as part of the first elements in the publicInputs
-/// for the sol evm verifier and will be verified by verifyWithDataAttestation.sol
+/// Specification for a single view-only call returning an array
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct CallToAccount {
-    /// The call_data is a byte string representing the ABI encoded function call to
-    /// read the data from the address. This call must return a single array of integers that
-    /// can be safely cast to the int128 type in solidity.
+    /// ABI-encoded function call data
    pub call_data: Call,
-    /// The number of decimals for f32 conversion of all of the elements returned from the
-    /// call.
+    /// Number of decimal places for float conversion
    pub decimals: Decimals,
-    /// Address of the contract to read the data from.
+    /// Contract address to call
    pub address: String,
-    /// The number of elements returned from the call.
+    /// Expected length of the returned array
    pub len: usize,
}

-/// Enum that defines the source of the inputs/outputs to the EZKL model
+/// Represents different sources of input/output data for the EZKL model
#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq)]
#[serde(untagged)]
pub enum DataSource {
-    /// .json File data source.
+    /// Data from a JSON file containing arrays of values
    File(FileSource),
-    /// On-chain data source. The first element is the calls to the account, and the second is the RPC url.
+    /// Data fetched from blockchain contracts
    OnChain(OnChainSource),
-    /// Postgres DB
+    /// Data from a PostgreSQL database
    #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
    DB(PostgresSource),
}
@@ -482,8 +400,7 @@ impl From<OnChainSource> for DataSource {
    }
}

-// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
-// UNTAGGED ENUMS WONT WORK :( as highlighted here:
+// Note: Always use JSON serialization for untagged enums.
impl<'de> Deserialize<'de> for DataSource {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
@@ -491,15 +408,19 @@ impl<'de> Deserialize<'de> for DataSource {
    {
        let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;

        // Try deserializing as FileSource first
        let first_try: Result<FileSource, _> = serde_json::from_str(this_json.get());

        if let Ok(t) = first_try {
            return Ok(DataSource::File(t));
        }

        // Try deserializing as OnChainSource
        let second_try: Result<OnChainSource, _> = serde_json::from_str(this_json.get());
        if let Ok(t) = second_try {
            return Ok(DataSource::OnChain(t));
        }

        // Try deserializing as PostgresSource if the feature is enabled
        #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
        {
            let third_try: Result<PostgresSource, _> = serde_json::from_str(this_json.get());
@@ -512,22 +433,29 @@ impl<'de> Deserialize<'de> for DataSource {
        }
    }
}
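
Because DataSource is untagged, the variants are distinguished purely by JSON shape, in the try-order shown above. A hedged sketch of inputs that would land on the first two branches (the on-chain object mirrors the OnChainSource fields; values are placeholders):

// Sketch: which JSON shapes map to which DataSource variant.
fn example() {
    // File: nested arrays of values.
    let file: DataSource = serde_json::from_str("[[0.05, 0.07]]").unwrap();

    // OnChain: an object with "calls" and "rpc" fields.
    let onchain: DataSource = serde_json::from_str(
        r#"{"calls": [{"call_data": [["0xabcdef", 7]], "address": "0x00"}], "rpc": "http://localhost:8545"}"#,
    )
    .unwrap();

    let _ = (file, onchain);
}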

-/// Input to graph as a datasource
-/// Always use JSON serialization for GraphData. Seriously.
+/// Container for input and output data for graph computations.
+///
+/// Important: always use JSON serialization for GraphData to handle enum variants correctly.
#[derive(Clone, Debug, Deserialize, Default, PartialEq, Serialize)]
pub struct GraphData {
-    /// Inputs to the model / computational graph (can be empty vectors if inputs are coming from on-chain).
+    /// Input data for the model/graph.
+    /// Can be empty if inputs come from on-chain sources.
    pub input_data: DataSource,
-    /// Outputs of the model / computational graph (can be empty vectors if outputs are coming from on-chain).
+    /// Optional output data for the model/graph.
+    /// Can be empty if outputs come from on-chain sources.
    pub output_data: Option<DataSource>,
}

impl UnwindSafe for GraphData {}

impl GraphData {
    // not wasm
    #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
-    /// Convert the input data to tract data
+    /// Converts the input data to tract's tensor format
    ///
    /// # Arguments
    /// * `shapes` - Expected shapes for each input tensor
    /// * `datum_types` - Expected data types for each input
    pub fn to_tract_data(
        &self,
        shapes: &[Vec<usize>],
@@ -556,9 +484,14 @@ impl GraphData {
        Ok(inputs)
    }

    // not wasm
    #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
-    /// Convert the tract data to tract data
+    /// Converts tract tensor data into GraphData format
    ///
    /// # Arguments
    /// * `tensors` - Array of tract tensors to convert
    ///
    /// # Returns
    /// A new GraphData instance containing the converted tensor data
    pub fn from_tract_data(tensors: &[TractTensor]) -> Result<Self, GraphError> {
        use tract_onnx::prelude::DatumType;

@@ -584,7 +517,10 @@ impl GraphData {
        })
    }

    /// Creates a new GraphData instance with the given input data
    ///
    /// # Arguments
    /// * `input_data` - The input data source
    pub fn new(input_data: DataSource) -> Self {
        GraphData {
            input_data,
@@ -592,7 +528,13 @@ impl GraphData {
        }
    }

-    /// Load the model input from a file
+    /// Loads graph input data from a file
    ///
    /// # Arguments
    /// * `path` - Path to the input file
    ///
    /// # Returns
    /// A new GraphData instance containing the loaded data
    pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
        let reader = std::fs::File::open(&path).map_err(|e| {
            GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
@@ -606,23 +548,35 @@ impl GraphData {
        Ok(graph_input)
    }

-    /// Save the model input to a file
+    /// Saves the graph data to a file
    ///
    /// # Arguments
    /// * `path` - Path where to save the data
    pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
        let file = std::fs::File::create(path.clone()).map_err(|e| {
            GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
        })?;
        // buf writer
        let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
        serde_json::to_writer(writer, self)?;
        Ok(())
    }

    /// Splits the input data into multiple batches based on input shapes
    ///
    /// # Arguments
    /// * `input_shapes` - Vector of shapes for each input tensor
    ///
    /// # Returns
    /// Vector of GraphData instances, one for each batch
    ///
    /// # Errors
    /// Returns an error if:
    /// - Data is from an on-chain source
    /// - Input size is not evenly divisible by batch size
    pub async fn split_into_batches(
        &self,
        input_shapes: Vec<Vec<usize>>,
    ) -> Result<Vec<Self>, GraphError> {
        // split input data into batches
        let mut batched_inputs = vec![];

        let iterable = match self {
@@ -646,10 +600,12 @@ impl GraphData {
            } => data.fetch_and_format_as_file().await?,
        };

        // Process each input tensor according to its shape
        for (i, shape) in input_shapes.iter().enumerate() {
            // ensure the input is evenly divisible by batch_size
            let input_size = shape.clone().iter().product::<usize>();
            let input = &iterable[i];

            // Validate input size is divisible by batch size
            if input.len() % input_size != 0 {
                return Err(GraphError::InvalidDims(
                    0,
@@ -657,6 +613,8 @@ impl GraphData {
                    .to_string(),
                ));
            }

            // Split input into batches
            let mut batches = vec![];
            for batch in input.chunks(input_size) {
                batches.push(batch.to_vec());
@@ -664,18 +622,18 @@ impl GraphData {
            batched_inputs.push(batches);
        }

-        // now merge all the batches for each input into a vector of batches
-        // first assert each input has the same number of batches
+        // Merge batches across inputs
        let num_batches = if batched_inputs.is_empty() {
            0
        } else {
            let num_batches = batched_inputs[0].len();
            // Verify all inputs have the same number of batches
            for input in batched_inputs.iter() {
                assert_eq!(input.len(), num_batches);
            }
            num_batches
        };
-        // now merge the batches

        let mut input_batches = vec![];
        for i in 0..num_batches {
            let mut batch = vec![];
@@ -685,11 +643,12 @@ impl GraphData {
            input_batches.push(DataSource::File(batch));
        }

        // Ensure at least one batch exists
        if input_batches.is_empty() {
            input_batches.push(DataSource::File(vec![vec![]]));
        }

-        // create a new GraphWitness for each batch
+        // Create a GraphData instance for each batch
        let batches = input_batches
            .into_iter()
            .map(GraphData::new)
@@ -701,6 +660,7 @@ impl GraphData {
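
To make the batching concrete: for an input of shape [1, 3], input_size = 3, so six flattened values split into two batches of three; a length of seven would trip the divisibility error above. A hedged standalone sketch of the chunking arithmetic:

// Sketch of the chunking logic inside split_into_batches, on plain vectors.
fn split(values: &[f64], input_size: usize) -> Result<Vec<Vec<f64>>, String> {
    if values.len() % input_size != 0 {
        return Err("input not evenly divisible by batch size".to_string());
    }
    Ok(values.chunks(input_size).map(|c| c.to_vec()).collect())
}

fn main() {
    let vals = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let batches = split(&vals, 3).unwrap(); // shape [1, 3] => input_size 3
    assert_eq!(batches.len(), 2);
    assert!(split(&vals[..5], 3).is_err()); // 5 values don't divide into batches of 3
}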

#[cfg(feature = "python-bindings")]
impl ToPyObject for CallsToAccount {
    /// Converts CallsToAccount to a Python object
    fn to_object(&self, py: Python) -> PyObject {
        let dict = PyDict::new(py);
        dict.set_item("account", &self.address).unwrap();
@@ -709,6 +669,165 @@ impl ToPyObject for CallsToAccount {
    }
}

// Additional Python bindings for various types...

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_postgres_source_new() {
        #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
        {
            let source = PostgresSource::new(
                "localhost".to_string(),
                "5432".to_string(),
                "user".to_string(),
                "SELECT * FROM table".to_string(),
                "database".to_string(),
                "password".to_string(),
            );

            assert_eq!(source.host, "localhost");
            assert_eq!(source.port, "5432");
            assert_eq!(source.user, "user");
            assert_eq!(source.query, "SELECT * FROM table");
            assert_eq!(source.dbname, "database");
            assert_eq!(source.password, "password");
        }
    }

    #[test]
    fn test_data_source_serialization_round_trip() {
        // Test backwards compatibility with the old format
        let source = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]);
        let serialized = serde_json::to_string(&source).unwrap();
        const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#;
        assert_eq!(serialized, JSON);

        let expect = serde_json::from_str::<DataSource>(JSON)
            .map_err(|e| e.to_string())
            .unwrap();
        assert_eq!(expect, source);
    }

    #[test]
    fn test_graph_input_serialization_round_trip() {
        // Test serialization/deserialization of graph input
        let file = GraphData::new(DataSource::from(vec![vec![
            0.05326242372393608,
            0.07497056573629379,
            0.05235547572374344,
        ]]));

        let serialized = serde_json::to_string(&file).unwrap();
        const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#;
        assert_eq!(serialized, JSON);

        let graph_input3 = serde_json::from_str::<GraphData>(JSON)
            .map_err(|e| e.to_string())
            .unwrap();
        assert_eq!(graph_input3, file);
    }

    #[test]
    fn test_python_compat() {
        // Test compatibility with mclbn256 library serialization
        let source = Fp::from_raw([18445520602771460712, 838677322461845011, 3079992810, 0]);
        let original_addr = "0x000000000000000000000000b794f5ea0ba39494ce839613fffba74279579268";
        assert_eq!(format!("{:?}", source), original_addr);
    }
}

/// Source data from a PostgreSQL database
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct PostgresSource {
    /// Database host address
    pub host: RPCUrl,
    /// Database user name
    pub user: String,
    /// Database password
    pub password: String,
    /// SQL query to execute
    pub query: String,
    /// Database name
    pub dbname: String,
    /// Database port
    pub port: String,
}

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl PostgresSource {
    /// Creates a new PostgreSQL data source
    pub fn new(
        host: RPCUrl,
        port: String,
        user: String,
        query: String,
        dbname: String,
        password: String,
    ) -> Self {
        PostgresSource {
            host,
            user,
            password,
            query,
            dbname,
            port,
        }
    }

    /// Fetches data from the PostgreSQL database
    pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
        // Configuration string
        let config = if self.password.is_empty() {
            format!(
                "host={} user={} dbname={} port={}",
                self.host, self.user, self.dbname, self.port
            )
        } else {
            format!(
                "host={} user={} dbname={} port={} password={}",
                self.host, self.user, self.dbname, self.port, self.password
            )
        };

        let mut client = Client::connect(&config).await?;
        let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();

        // Extract rows from the query
        for row in client.query(&self.query, &[]).await? {
            for i in 0..row.len() {
                res.push(row.get(i));
            }
        }
        Ok(vec![res])
    }

    /// Fetches and formats data as a FileSource
    pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
        Ok(self
            .fetch()
            .await?
            .iter()
            .map(|d| {
                d.iter()
                    .map(|d| {
                        FileSourceInner::Float(
                            d.n.as_ref()
                                .unwrap()
                                .to_f64()
                                .ok_or("could not convert decimal to f64")
                                .unwrap(),
                        )
                    })
                    .collect()
            })
            .collect())
    }
}
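
A hedged usage sketch of the relocated PostgresSource (connection details are placeholders; fetch needs a reachable Postgres instance and an async runtime):

// Sketch: constructing and querying a PostgresSource.
async fn example() -> Result<(), GraphError> {
    let source = PostgresSource::new(
        "localhost".to_string(),                // host
        "5432".to_string(),                     // port
        "user".to_string(),                     // user
        "SELECT value FROM inputs".to_string(), // query
        "ezkl_db".to_string(),                  // dbname
        "".to_string(),                         // empty password is omitted from the config string
    );
    // Rows come back as PgNumeric and are formatted into FileSourceInner::Float values.
    let file_data = source.fetch_and_format_as_file().await?;
    println!("fetched {} row group(s)", file_data.len());
    Ok(())
}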

#[cfg(feature = "python-bindings")]
impl ToPyObject for CallToAccount {
    fn to_object(&self, py: Python) -> PyObject {
@@ -743,6 +862,7 @@ impl ToPyObject for DataSource {
                .unwrap();
                dict.to_object(py)
            }
            #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
            DataSource::DB(source) => {
                let dict = PyDict::new(py);
                dict.set_item("host", &source.host).unwrap();
@@ -767,57 +887,3 @@ impl ToPyObject for FileSourceInner {
        }
    }
}

-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    // this is for backwards compatibility with the old format
-    fn test_data_source_serialization_round_trip() {
-        let source = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]);
-
-        let serialized = serde_json::to_string(&source).unwrap();
-
-        const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#;
-
-        assert_eq!(serialized, JSON);
-
-        let expect = serde_json::from_str::<DataSource>(JSON)
-            .map_err(|e| e.to_string())
-            .unwrap();
-
-        assert_eq!(expect, source);
-    }
-
-    #[test]
-    // this is for backwards compatibility with the old format
-    fn test_graph_input_serialization_round_trip() {
-        let file = GraphData::new(DataSource::from(vec![vec![
-            0.05326242372393608,
-            0.07497056573629379,
-            0.05235547572374344,
-        ]]));
-
-        let serialized = serde_json::to_string(&file).unwrap();
-
-        const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#;
-
-        assert_eq!(serialized, JSON);
-
-        let graph_input3 = serde_json::from_str::<GraphData>(JSON)
-            .map_err(|e| e.to_string())
-            .unwrap();
-        assert_eq!(graph_input3, file);
-    }
-
-    // test for compatibility with the serialized elements from the mclbn256 library
-    #[test]
-    fn test_python_compat() {
-        let source = Fp::from_raw([18445520602771460712, 838677322461845011, 3079992810, 0]);
-
-        let original_addr = "0x000000000000000000000000b794f5ea0ba39494ce839613fffba74279579268";
-
-        assert_eq!(format!("{:?}", source), original_addr);
-    }
-}

@@ -619,11 +619,6 @@ impl GraphSettings {
    }
}

-    ///
-    pub fn uses_modules(&self) -> bool {
-        !self.module_sizes.max_constraints() > 0
-    }
-
    /// if any visibility is encrypted or hashed
    pub fn module_requires_fixed(&self) -> bool {
        self.run_args.input_visibility.is_hashed()
@@ -766,7 +761,7 @@ pub struct TestOnChainData {
    pub data: std::path::PathBuf,
    /// rpc endpoint
    pub rpc: Option<String>,
-    ///
+    /// data sources for the on-chain data
    pub data_sources: TestSources,
}

@@ -954,7 +949,7 @@ impl GraphCircuit {
            DataSource::File(file_data) => {
                self.load_file_data(file_data, &shapes, scales, input_types)
            }
-            _ => unreachable!("cannot load from on-chain data"),
+            _ => Err(GraphError::OnChainDataSource),
        }
    }

@@ -384,8 +384,7 @@ pub struct ParsedNodes {
impl ParsedNodes {
    /// Returns the number of the computational graph's inputs
    pub fn num_inputs(&self) -> usize {
-        let input_nodes = self.inputs.iter();
-        input_nodes.len()
+        self.inputs.len()
    }

    /// Input types
@@ -425,8 +424,7 @@ impl ParsedNodes {

    /// Returns the number of the computational graph's outputs
    pub fn num_outputs(&self) -> usize {
-        let output_nodes = self.outputs.iter();
-        output_nodes.len()
+        self.outputs.len()
    }

    /// Returns shapes of the computational graph's outputs
@@ -634,6 +632,10 @@ impl Model {

        for (i, id) in model.clone().inputs.iter().enumerate() {
            let input = model.node_mut(id.node);

            if input.outputs.len() == 0 {
                return Err(GraphError::MissingOutput(id.node));
            }
            let mut fact: InferenceFact = input.outputs[0].fact.clone();

            for (i, x) in fact.clone().shape.dims().enumerate() {
@@ -906,6 +908,7 @@ impl Model {
            n.opkind = SupportedOp::Input(Input {
                scale,
                datum_type: inp.datum_type,
                decomp: !run_args.ignore_range_check_inputs_outputs,
            });
            input_idx += 1;
            n.out_scale = scale;
@@ -1016,6 +1019,10 @@ impl Model {
        let required_lookups = settings.required_lookups.clone();
        let required_range_checks = settings.required_range_checks.clone();

        if vars.advices.len() < 3 {
            return Err(GraphError::InsufficientAdviceColumns(3));
        }

        let mut base_gate = PolyConfig::configure(
            meta,
            vars.advices[0..2].try_into()?,
@@ -1035,6 +1042,10 @@ impl Model {
        }

        if settings.requires_dynamic_lookup() {
            if vars.advices.len() < 6 {
                return Err(GraphError::InsufficientAdviceColumns(6));
            }

            base_gate.configure_dynamic_lookup(
                meta,
                vars.advices[0..3].try_into()?,
@@ -1043,6 +1054,9 @@ impl Model {
        }

        if settings.requires_shuffle() {
            if vars.advices.len() < 6 {
                return Err(GraphError::InsufficientAdviceColumns(6));
            }
            base_gate.configure_shuffles(
                meta,
                vars.advices[0..3].try_into()?,
@@ -1061,6 +1075,7 @@ impl Model {
    /// * `vars` - The variables for the circuit.
    /// * `witnessed_outputs` - The values to compare against.
    /// * `constants` - The constants for the circuit.
    #[allow(clippy::too_many_arguments)]
    pub fn layout(
        &self,
        mut config: ModelConfig,
@@ -1131,8 +1146,8 @@ impl Model {
            .iter()
            .enumerate()
            .map(|(i, output)| {
-                let mut tolerance = run_args.tolerance;
-                tolerance.scale = scale_to_multiplier(output_scales[i]).into();
+                let mut tol: crate::circuit::Tolerance = run_args.tolerance;
+                tol.scale = scale_to_multiplier(output_scales[i]).into();

                let comparators = if run_args.output_visibility == Visibility::Public {
                    let res = vars
@@ -1155,7 +1170,10 @@ impl Model {
                    .layout(
                        &mut thread_safe_region,
                        &[output.clone(), comparators],
-                        Box::new(HybridOp::RangeCheck(tolerance)),
+                        Box::new(HybridOp::Output {
+                            tol,
+                            decomp: !run_args.ignore_range_check_inputs_outputs,
+                        }),
                    )
                    .map_err(|e| e.into())
            })
@@ -1432,13 +1450,16 @@ impl Model {
                .into();
            comparator.reshape(output.dims())?;

-            let mut tolerance = run_args.tolerance;
-            tolerance.scale = scale_to_multiplier(output_scales[i]).into();
+            let mut tol = run_args.tolerance;
+            tol.scale = scale_to_multiplier(output_scales[i]).into();

            dummy_config.layout(
                &mut region,
                &[output.clone(), comparator],
-                Box::new(HybridOp::RangeCheck(tolerance)),
+                Box::new(HybridOp::Output {
+                    tol,
+                    decomp: !run_args.ignore_range_check_inputs_outputs,
+                }),
            )
        })
        .collect::<Result<Vec<_>, _>>();
@@ -1460,7 +1481,7 @@ impl Model {
        .iter()
        .map(|x| {
            x.get_felt_evals()
-                .unwrap_or(Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
+                .unwrap_or_else(|_| Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
        })
        .collect();

@@ -1530,6 +1551,7 @@ impl Model {
        let mut op = crate::circuit::Constant::new(
            c.quantized_values.clone(),
            c.raw_values.clone(),
            c.decomp,
        );
        op.pre_assign(consts[const_idx].clone());
        n.opkind = SupportedOp::Constant(op);

@@ -14,14 +14,11 @@ use serde::{Deserialize, Serialize};
use super::errors::GraphError;
use super::{VarVisibility, Visibility};

/// poseidon len to hash in tree
pub const POSEIDON_LEN_GRAPH: usize = 32;
-/// Poseidon number of instances
-pub const POSEIDON_INSTANCES: usize = 1;

/// Poseidon module type
-pub type ModulePoseidon =
-    PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>;
+pub type ModulePoseidon = PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>;
/// Poseidon module config
pub type ModulePoseidonConfig = PoseidonConfig<POSEIDON_WIDTH, POSEIDON_RATE>;

@@ -284,7 +281,6 @@ impl GraphModules {
            log::error!("Poseidon config not initialized");
            return Err(Error::Synthesis);
        }
-        // If the module is encrypted, then we need to encrypt the inputs
    }

    Ok(())

@@ -1,10 +1,19 @@
// Import dependencies for scaling operations
use super::scale_to_multiplier;

// Import ONNX-specific utilities when the EZKL feature is enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::utilities::node_output_shapes;

// Import scale management types for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::VarScales;

// Import visibility settings for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::Visibility;

// Import operation types for the different circuit components
use crate::circuit::hybrid::HybridOp;
use crate::circuit::lookup::LookupOp;
use crate::circuit::poly::PolyOp;
@@ -13,28 +22,49 @@ use crate::circuit::Constant;
use crate::circuit::Input;
use crate::circuit::Op;
use crate::circuit::Unknown;

// Import graph error types for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::errors::GraphError;

// Import ONNX operation conversion utilities
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::new_op_from_onnx;

// Import tensor error handling
use crate::tensor::TensorError;

// Import curve-specific field type
use halo2curves::bn256::Fr as Fp;

// Import logging for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use log::trace;

// Import serialization traits
use serde::Deserialize;
use serde::Serialize;

// Import data structures for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::collections::BTreeMap;

// Import formatting traits for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::fmt;

// Import table display formatting for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tabled::Tabled;

// Import ONNX-specific types and traits for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::{
    self,
    prelude::{Node as OnnxNode, SymbolValues, TypedFact, TypedOp},
};

/// Helper function to format vectors for display
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
    if !v.is_empty() {
@@ -44,29 +74,35 @@ fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
    }
}

/// Helper function to format operation kinds for display
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn display_opkind(v: &SupportedOp) -> String {
    v.as_string()
}

-/// A wrapper for an operation that has been rescaled.
+/// A wrapper for an operation that has been rescaled to handle different precision requirements.
+/// This enables operations to work with inputs that have been scaled to different fixed-point representations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Rescaled {
-    /// The operation that has to be rescaled.
+    /// The underlying operation that needs to be rescaled
    pub inner: Box<SupportedOp>,
-    /// The scale of the operation's inputs.
+    /// Vector of (index, scale) pairs defining how each input should be scaled
    pub scale: Vec<(usize, u128)>,
}

/// Implementation of the Op trait for Rescaled operations
impl Op<Fp> for Rescaled {
    /// Convert to Any type for runtime type checking
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Get the string representation of the operation
    fn as_string(&self) -> String {
        format!("RESCALED INPUT ({})", self.inner.as_string())
    }

    /// Calculate the output scale based on input scales
    fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
        let in_scales = in_scales
            .into_iter()
@@ -77,6 +113,7 @@ impl Op<Fp> for Rescaled {
        Op::<Fp>::out_scale(&*self.inner, in_scales)
    }

    /// Layout the operation in the circuit
    fn layout(
        &self,
        config: &mut crate::circuit::BaseConfig<Fp>,
@@ -93,28 +130,40 @@ impl Op<Fp> for Rescaled {
        self.inner.layout(config, region, res)
    }

    /// Create a cloned boxed copy of this operation
    fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
-        Box::new(self.clone()) // Forward to the derive(Clone) impl
+        Box::new(self.clone())
    }
}

-/// A wrapper for an operation that has been rescaled.
+/// A wrapper for operations that require scale rebasing.
+/// This handles cases where operation scales need to be adjusted to a target scale
+/// while preserving the numerical relationships.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RebaseScale {
-    /// The operation that has to be rescaled.
+    /// The operation that needs to be rescaled
    pub inner: Box<SupportedOp>,
-    /// rebase op
+    /// Operation used for rebasing, typically division
    pub rebase_op: HybridOp,
-    /// scale being rebased to
+    /// Scale that we're rebasing to
    pub target_scale: i32,
-    /// The original scale of the operation's inputs.
+    /// Original scale of the operation's inputs before rebasing
    pub original_scale: i32,
-    /// multiplier
+    /// Scaling multiplier used in rebasing
    pub multiplier: f64,
}

impl RebaseScale {
    /// Creates a rebased version of an operation if needed
    ///
    /// # Arguments
    /// * `inner` - Operation to potentially rebase
    /// * `global_scale` - Base scale for the system
    /// * `op_out_scale` - Current output scale of the operation
    /// * `scale_rebase_multiplier` - Factor determining when rebasing should occur
    ///
    /// # Returns
    /// Original or rebased operation depending on scale relationships
    pub fn rebase(
        inner: SupportedOp,
        global_scale: crate::Scale,
@@ -155,7 +204,15 @@ impl RebaseScale {
        }
    }

    /// Creates a rebased operation with increased scale
    ///
    /// # Arguments
    /// * `inner` - Operation to potentially rebase
    /// * `target_scale` - Scale to rebase to
    /// * `op_out_scale` - Current output scale of the operation
    ///
    /// # Returns
    /// Original or rebased operation with increased scale
    pub fn rebase_up(
        inner: SupportedOp,
        target_scale: crate::Scale,
@@ -192,10 +249,12 @@ impl RebaseScale {
}
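
The rebasing multiplier ties the two scales together: scales are log-base-2 exponents, so moving an operation's output from op_out_scale down to target_scale divides by 2^(op_out_scale - target_scale), which is what the division rebase_op applies. A hedged sketch of that arithmetic on plain floats:

// Sketch of the scale-rebasing arithmetic; scales are log2 exponents.
fn rebase_multiplier(op_out_scale: i32, target_scale: i32) -> f64 {
    f64::powi(2.0, op_out_scale - target_scale)
}

fn main() {
    // An op producing values at scale 14, rebased to a global scale of 7,
    // is divided by 2^7 = 128.
    assert_eq!(rebase_multiplier(14, 7), 128.0);
}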

impl Op<Fp> for RebaseScale {
    /// Convert to Any type for runtime type checking
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Get the string representation of the operation
    fn as_string(&self) -> String {
        format!(
            "REBASED (div={:?}, rebasing_op={}) ({})",
@@ -205,10 +264,12 @@ impl Op<Fp> for RebaseScale {
        )
    }

    /// Calculate the output scale based on input scales
    fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
        Ok(self.target_scale)
    }

    /// Layout the operation in the circuit
    fn layout(
        &self,
        config: &mut crate::circuit::BaseConfig<Fp>,
@@ -222,34 +283,40 @@ impl Op<Fp> for RebaseScale {
        self.rebase_op.layout(config, region, &[original_res])
    }

    /// Create a cloned boxed copy of this operation
    fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
-        Box::new(self.clone()) // Forward to the derive(Clone) impl
+        Box::new(self.clone())
    }
}

-/// A single operation in a [crate::graph::Model].
+/// Represents all supported operation types in the circuit.
+/// Each variant encapsulates a different type of operation with specific behavior.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SupportedOp {
-    /// A linear operation.
+    /// Linear operations (polynomial-based)
    Linear(PolyOp),
-    /// A nonlinear operation.
+    /// Nonlinear operations requiring lookup tables
    Nonlinear(LookupOp),
-    /// A hybrid operation.
+    /// Mixed operations combining different approaches
    Hybrid(HybridOp),
-    ///
+    /// Input values to the circuit
    Input(Input),
-    ///
+    /// Constant values in the circuit
    Constant(Constant<Fp>),
-    ///
+    /// Placeholder for unsupported operations
    Unknown(Unknown),
-    ///
+    /// Operations requiring rescaling of inputs
    Rescaled(Rescaled),
-    ///
+    /// Operations requiring scale rebasing
    RebaseScale(RebaseScale),
}
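
A hedged sketch of how downstream code typically branches on these variants, in the spirit of the accessors defined below:

// Sketch: dispatching on SupportedOp variants.
fn describe(op: &SupportedOp) -> &'static str {
    match op {
        SupportedOp::Linear(_) => "polynomial gate",
        SupportedOp::Nonlinear(_) => "lookup table",
        SupportedOp::Hybrid(_) => "mixed strategy",
        SupportedOp::Input(_) => "circuit input",
        SupportedOp::Constant(_) => "circuit constant",
        SupportedOp::Unknown(_) => "unsupported op",
        SupportedOp::Rescaled(_) => "rescaled inputs",
        SupportedOp::RebaseScale(_) => "rebased scale",
    }
}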

impl SupportedOp {
    /// Checks if the operation is a lookup operation
    ///
    /// # Returns
    /// * `true` if the operation requires a lookup table
    /// * `false` otherwise
    pub fn is_lookup(&self) -> bool {
        match self {
            SupportedOp::Nonlinear(_) => true,
@@ -257,7 +324,12 @@ impl SupportedOp {
            _ => false,
        }
    }

    /// Returns the input operation if this is an input
    ///
    /// # Returns
    /// * `Some(Input)` if this is an input operation
    /// * `None` otherwise
    pub fn get_input(&self) -> Option<Input> {
        match self {
            SupportedOp::Input(op) => Some(op.clone()),
@@ -265,7 +337,11 @@ impl SupportedOp {
        }
    }

    /// Returns a reference to the rebased operation if this is a rebased operation
    ///
    /// # Returns
    /// * `Some(&RebaseScale)` if this is a rebased operation
    /// * `None` otherwise
    pub fn get_rebased(&self) -> Option<&RebaseScale> {
        match self {
            SupportedOp::RebaseScale(op) => Some(op),
@@ -273,7 +349,11 @@ impl SupportedOp {
        }
    }

    /// Returns a reference to the lookup operation if this is a lookup operation
    ///
    /// # Returns
    /// * `Some(&LookupOp)` if this is a lookup operation
    /// * `None` otherwise
    pub fn get_lookup(&self) -> Option<&LookupOp> {
        match self {
            SupportedOp::Nonlinear(op) => Some(op),
@@ -281,7 +361,11 @@ impl SupportedOp {
        }
    }

    /// Returns a reference to the constant if this is a constant
    ///
    /// # Returns
    /// * `Some(&Constant)` if this is a constant
    /// * `None` otherwise
    pub fn get_constant(&self) -> Option<&Constant<Fp>> {
        match self {
            SupportedOp::Constant(op) => Some(op),
@@ -289,7 +373,11 @@ impl SupportedOp {
        }
    }

    /// Returns a mutable reference to the constant if this is a constant
    ///
    /// # Returns
    /// * `Some(&mut Constant)` if this is a constant
    /// * `None` otherwise
    pub fn get_mutable_constant(&mut self) -> Option<&mut Constant<Fp>> {
        match self {
            SupportedOp::Constant(op) => Some(op),
@@ -297,18 +385,19 @@ impl SupportedOp {
        }
    }

    /// Creates a homogeneously rescaled version of this operation if needed.
    /// Only available with the EZKL feature enabled.
    #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
    fn homogenous_rescale(
        &self,
        in_scales: Vec<crate::Scale>,
    ) -> Result<Box<dyn Op<Fp>>, GraphError> {
        let inputs_to_scale = self.requires_homogenous_input_scales();
        // creates a rescaled op if the inputs are not homogenous
        let op = self.clone_dyn();
        super::homogenize_input_scales(op, in_scales, inputs_to_scale)
    }

-    /// Since each associated value of `SupportedOp` implements `Op`, let's define a helper method to retrieve it.
+    /// Returns a reference to the underlying Op implementation
    fn as_op(&self) -> &dyn Op<Fp> {
        match self {
            SupportedOp::Linear(op) => op,
@@ -322,9 +411,10 @@ impl SupportedOp {
        }
    }

-    /// check if is the identity operation
+    /// Checks if this is an identity operation
    ///
    /// # Returns
-    /// * `true` if the operation is the identity operation
+    /// * `true` if this operation passes input through unchanged
    /// * `false` otherwise
    pub fn is_identity(&self) -> bool {
        match self {
@@ -361,9 +451,11 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
    if let Some(op) = value.as_any().downcast_ref::<Unknown>() {
        return SupportedOp::Unknown(op.clone());
    };

    if let Some(op) = value.as_any().downcast_ref::<Rescaled>() {
        return SupportedOp::Rescaled(op.clone());
    };

    if let Some(op) = value.as_any().downcast_ref::<RebaseScale>() {
        return SupportedOp::RebaseScale(op.clone());
    };
@@ -375,6 +467,7 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
}

impl Op<Fp> for SupportedOp {
    /// Layout this operation in the circuit
    fn layout(
        &self,
        config: &mut crate::circuit::BaseConfig<Fp>,
@@ -384,54 +477,61 @@ impl Op<Fp> for SupportedOp {
        self.as_op().layout(config, region, values)
    }

    /// Check if this is an input operation
    fn is_input(&self) -> bool {
        self.as_op().is_input()
    }

    /// Check if this is a constant operation
    fn is_constant(&self) -> bool {
        self.as_op().is_constant()
    }

    /// Get which inputs require homogeneous scales
    fn requires_homogenous_input_scales(&self) -> Vec<usize> {
        self.as_op().requires_homogenous_input_scales()
    }

    /// Create a clone of this operation
    fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
        self.as_op().clone_dyn()
    }

    /// Get the string representation
    fn as_string(&self) -> String {
        self.as_op().as_string()
    }

    /// Convert to Any type
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Calculate the output scale from input scales
    fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
        self.as_op().out_scale(in_scales)
    }
}

-/// A node's input is a tensor from another node's output.
+/// Represents a connection to another node's output.
+/// The first element is the node index, the second is the output slot index.
pub type Outlet = (usize, usize);

-/// A single operation in a [crate::graph::Model].
+/// Represents a single computational node in the circuit graph.
+/// Contains all information needed to execute and connect operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Node {
-    /// [Op] i.e. what operation this node represents.
+    /// The operation this node performs
    pub opkind: SupportedOp,
-    /// The denominator in the fixed point representation for the node's output. Tensors of differing scales should not be combined.
+    /// Fixed point scale factor for this node's output
    pub out_scale: i32,
    // Usually there is a simple in and out shape of the node as an operator. For example, an Affine node has three input_shapes (one for the input, weight, and bias),
    // but in_dim is [in], out_dim is [out]
-    /// The indices of the node's inputs.
+    /// Connections to other nodes' outputs that serve as inputs
    pub inputs: Vec<Outlet>,
-    /// Dimensions of output.
+    /// Shape of this node's output tensor
    pub out_dims: Vec<usize>,
-    /// The node's unique identifier.
+    /// Unique identifier for this node
    pub idx: usize,
-    /// The node's num of uses
+    /// Number of times this node's output is used
    pub num_uses: usize,
}

@@ -469,12 +569,19 @@ impl PartialEq for Node {
}

impl Node {
-    /// Converts a tract [OnnxNode] into an ezkl [Node].
-    /// # Arguments:
-    /// * `node` - [OnnxNode]
-    /// * `other_nodes` - [BTreeMap] of other previously initialized [Node]s in the computational graph.
-    /// * `public_params` - flag if parameters of model are public
-    /// * `idx` - The node's unique identifier.
+    /// Creates a new Node from an ONNX node.
+    /// Only available when the EZKL feature is enabled.
+    ///
+    /// # Arguments
+    /// * `node` - Source ONNX node
+    /// * `other_nodes` - Map of existing nodes in the graph
+    /// * `scales` - Scale factors for variables
+    /// * `idx` - Unique identifier for this node
+    /// * `symbol_values` - ONNX symbol values
+    /// * `run_args` - Runtime configuration arguments
+    ///
+    /// # Returns
+    /// New Node instance or error if creation fails
    #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
    #[allow(clippy::too_many_arguments)]
    pub fn new(
@@ -612,16 +719,14 @@ impl Node {
        })
    }

-    /// check if it is a softmax node
+    /// Check if this node performs a softmax operation
    pub fn is_softmax(&self) -> bool {
-        if let SupportedOp::Hybrid(HybridOp::Softmax { .. }) = self.opkind {
-            true
-        } else {
-            false
-        }
+        matches!(self.opkind, SupportedOp::Hybrid(HybridOp::Softmax { .. }))
    }
}

/// Helper function to rescale constants that are only used once.
/// Only available when the EZKL feature is enabled.
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn rescale_const_with_single_use(
    constant: &mut Constant<Fp>,

@@ -44,11 +44,11 @@ use tract_onnx::tract_hir::{
    tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
};

-/// Quantizes an iterable of f32s to a [Tensor] of i32s using a fixed point representation.
+/// Quantizes an iterable of f64s to a [Tensor] of IntegerReps using a fixed point representation.
/// NAN gets mapped to 0. INFINITY and NEG_INFINITY error out.
-/// Arguments
+///
+/// # Arguments
-/// * `vec` - the vector to quantize.
-/// * `dims` - the dimensionality of the resulting [Tensor].
+/// * `elem` - the element to quantize.
/// * `shift` - offset used in the fixed point representation.
/// * `scale` - `2^scale` used in the fixed point representation.
pub fn quantize_float(
@@ -59,7 +59,7 @@ pub fn quantize_float(
    let mult = scale_to_multiplier(scale);
    let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation

-    if *elem > max_value {
+    if *elem > max_value || *elem < -max_value {
        return Err(TensorError::SigBitTruncationError);
}
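
The symmetric bound matters for negative inputs: before this change a large negative float passed the single-sided check and silently truncated. A hedged standalone sketch of the guarded quantization (toy error type, not the ezkl TensorError):

// Sketch of the symmetric overflow guard added above.
type IntegerRep = i128;

fn quantize(elem: f64, shift: f64, scale: i32) -> Result<IntegerRep, String> {
    let mult = f64::powi(2.0, scale);
    let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round();
    // reject both large positive and large negative inputs
    if elem > max_value || elem < -max_value {
        return Err("significant-bit truncation".to_string());
    }
    Ok((elem * mult + shift).round() as IntegerRep)
}

fn main() {
    assert!(quantize(0.5, 0.0, 7).is_ok());
    assert!(quantize(-1e40, 0.0, 7).is_err()); // caught by the new lower bound
}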

@@ -85,7 +85,7 @@ pub fn scale_to_multiplier(scale: crate::Scale) -> f64 {
    f64::powf(2., scale as f64)
}

-/// Converts a scale (log base 2) to a fixed point multiplier.
+/// Converts a fixed point multiplier to a scale (log base 2).
pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
    mult.log2().round() as crate::Scale
}
|
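The corrected doc comment matters because the two helpers are inverses: one maps a log-base-2 scale to the multiplier `2^scale`, the other recovers the scale via `log2`. A quick standalone check of the round trip (plain `i32` in place of `crate::Scale`):

```rust
fn scale_to_multiplier(scale: i32) -> f64 {
    f64::powf(2., scale as f64)
}

fn multiplier_to_scale(mult: f64) -> i32 {
    mult.log2().round() as i32
}

fn main() {
    for scale in [0, 7, 13] {
        assert_eq!(multiplier_to_scale(scale_to_multiplier(scale)), scale);
    }
    assert_eq!(scale_to_multiplier(7), 128.0);
}
```
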
@@ -228,10 +228,7 @@ pub fn extract_tensor_value(
.iter()
.map(|x| match x.to_i64() {
Ok(v) => Ok(v as f32),
Err(_) => match x.to_i64() {
Ok(v) => Ok(v as f32),
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
},
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
})
.collect();

@@ -277,11 +274,9 @@ pub fn new_op_from_onnx(
symbol_values: &SymbolValues,
run_args: &crate::RunArgs,
) -> Result<(SupportedOp, Vec<usize>), GraphError> {
use std::f64::consts::E;

use tract_onnx::tract_core::ops::array::Trilu;

use crate::circuit::InputType;
use std::f64::consts::E;
use tract_onnx::tract_core::ops::array::Trilu;

let input_scales = inputs
.iter()
@@ -312,6 +307,9 @@ pub fn new_op_from_onnx(
let mut deleted_indices = vec![];
let node = match node.op().name().as_ref() {
"ShiftLeft" => {
if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "shift left".to_string()));
};
// load shift amount
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
@@ -324,10 +322,13 @@ pub fn new_op_from_onnx(
out_scale: Some(input_scales[0] - raw_values[0] as i32),
})
} else {
return Err(GraphError::OpMismatch(idx, "ShiftLeft".to_string()));
return Err(GraphError::OpMismatch(idx, "shift left".to_string()));
}
}
"ShiftRight" => {
if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "shift right".to_string()));
};
// load shift amount
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
@@ -340,7 +341,7 @@ pub fn new_op_from_onnx(
out_scale: Some(input_scales[0] + raw_values[0] as i32),
})
} else {
return Err(GraphError::OpMismatch(idx, "ShiftRight".to_string()));
return Err(GraphError::OpMismatch(idx, "shift right".to_string()));
}
}
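Note the scale bookkeeping in these two arms: rather than moving bits, the ops reinterpret the same integer at a different scale. Lowering the output scale by `k` (`input_scales[0] - k` for ShiftLeft) multiplies the represented value by `2^k`; raising it (`+ k` for ShiftRight) divides. A small sketch of why, on plain integers (values here are illustrative):

```rust
fn main() {
    // A fixed-point value: integer q interpreted at scale s means q / 2^s.
    let q: i64 = 192; // represents 1.5 at scale 7 (192 / 128)
    let s: i32 = 7;
    let k = 2;

    // "Shift left by k" (multiply by 2^k) without touching the integer:
    // reinterpret q at scale s - k.
    let value = q as f64 / 2f64.powi(s - k); // 192 / 32 = 6.0 = 1.5 * 4
    assert_eq!(value, 6.0);

    // "Shift right by k" (divide by 2^k) raises the scale instead.
    let value = q as f64 / 2f64.powi(s + k); // 192 / 512 = 0.375 = 1.5 / 4
    assert_eq!(value, 0.375);
}
```
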
"MultiBroadcastTo" => {
@@ -363,7 +364,10 @@ pub fn new_op_from_onnx(
}
}

assert_eq!(input_ops.len(), 3, "Range requires 3 inputs");
if input_ops.len() != 3 {
return Err(GraphError::InvalidDims(idx, "range".to_string()));
}

let input_ops = input_ops
.iter()
.map(|x| x.get_constant().ok_or(GraphError::NonConstantRange))
@@ -378,7 +382,11 @@ pub fn new_op_from_onnx(
// Quantize the raw value (integers)
let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed)?;

let c = crate::circuit::ops::Constant::new(quantized_value, raw_value);
let c = crate::circuit::ops::Constant::new(
quantized_value,
raw_value,
!run_args.ignore_range_check_inputs_outputs,
);
// Create a constant op
SupportedOp::Constant(c)
}
@@ -419,6 +427,10 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
if inputs[0].out_dims().is_empty() || inputs[0].out_dims()[0].len() <= axis {
return Err(GraphError::InvalidDims(idx, "gather".to_string()));
}

op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::Gather {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| {
@@ -436,6 +448,7 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: false,
}));
inputs[1].bump_scale(0);
}
@@ -447,8 +460,17 @@ pub fn new_op_from_onnx(
"Topk" => {
let op = load_op::<Topk>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;

if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
};

// if param_visibility.is_public() {
let k = if let Some(c) = inputs[1].opkind().get_mutable_constant() {
if c.raw_values.len() != 1 {
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
}

inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
c.raw_values.map(|x| x as usize)[0]
@@ -488,6 +510,10 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "scatter elements".to_string()));
}

op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -499,6 +525,7 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -522,6 +549,9 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "scatter nd".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
@@ -532,6 +562,7 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -555,6 +586,9 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "gather nd".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
@@ -566,6 +600,7 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -589,6 +624,9 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "gather elements".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -600,6 +638,7 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -674,7 +713,11 @@ pub fn new_op_from_onnx(
constant_scale,
&run_args.param_visibility,
)?;
let c = crate::circuit::ops::Constant::new(quantized_value, raw_value);
let c = crate::circuit::ops::Constant::new(
quantized_value,
raw_value,
run_args.ignore_range_check_inputs_outputs,
);
// Create a constant op
SupportedOp::Constant(c)
}
@@ -684,7 +727,9 @@ pub fn new_op_from_onnx(
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes: Vec<usize> = op.axes.into_iter().collect();
assert_eq!(axes.len(), 1, "only support argmax over one axis");
if axes.len() != 1 {
return Err(GraphError::InvalidDims(idx, "argmax".to_string()));
}

SupportedOp::Hybrid(HybridOp::ReduceArgMax { dim: axes[0] })
}
@@ -694,7 +739,9 @@ pub fn new_op_from_onnx(
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes: Vec<usize> = op.axes.into_iter().collect();
assert_eq!(axes.len(), 1, "only support argmin over one axis");
if axes.len() != 1 {
return Err(GraphError::InvalidDims(idx, "argmin".to_string()));
}

SupportedOp::Hybrid(HybridOp::ReduceArgMin { dim: axes[0] })
}
@@ -803,6 +850,9 @@ pub fn new_op_from_onnx(
}
}
"Recip" => {
if inputs.len() != 1 {
return Err(GraphError::InvalidDims(idx, "recip".to_string()));
};
let in_scale = input_scales[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
// If the input scale is larger than the params scale
@@ -846,6 +896,9 @@ pub fn new_op_from_onnx(
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Rsqrt" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "rsqrt".to_string()));
};
let in_scale = input_scales[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
SupportedOp::Hybrid(HybridOp::Rsqrt {
@@ -927,13 +980,19 @@ pub fn new_op_from_onnx(
DatumType::F64 => (scales.input, InputType::F64),
_ => return Err(GraphError::UnsupportedDataType(idx, format!("{:?}", dt))),
};
SupportedOp::Input(crate::circuit::ops::Input { scale, datum_type })
SupportedOp::Input(crate::circuit::ops::Input {
scale,
datum_type,
decomp: !run_args.ignore_range_check_inputs_outputs,
})
}
"Cast" => {
let op = load_op::<Cast>(node.op(), idx, node.op().name().to_string())?;
let dt = op.to;

assert_eq!(input_scales.len(), 1);
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "cast".to_string()));
};

match dt {
DatumType::Bool
@@ -983,6 +1042,11 @@ pub fn new_op_from_onnx(

if const_idx.len() == 1 {
let const_idx = const_idx[0];

if inputs.len() <= const_idx {
return Err(GraphError::InvalidDims(idx, "mul".to_string()));
}

if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
if c.raw_values.len() == 1 && c.raw_values[0] < 1. {
// if not divisible by 2 then we need to add a range check
@@ -1057,6 +1121,9 @@ pub fn new_op_from_onnx(
return Err(GraphError::OpMismatch(idx, "softmax".to_string()));
}
};
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "softmax".to_string()));
}

let in_scale = input_scales[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
@@ -1096,22 +1163,42 @@ pub fn new_op_from_onnx(
pool_dims: kernel_shape.to_vec(),
})
}
"Ceil" => SupportedOp::Hybrid(HybridOp::Ceil {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"Floor" => SupportedOp::Hybrid(HybridOp::Floor {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"Round" => SupportedOp::Hybrid(HybridOp::Round {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"RoundHalfToEven" => SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"Ceil" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "ceil".to_string()));
}
SupportedOp::Hybrid(HybridOp::Ceil {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"Floor" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "floor".to_string()));
}
SupportedOp::Hybrid(HybridOp::Floor {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"Round" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "round".to_string()));
}
SupportedOp::Hybrid(HybridOp::Round {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"RoundHalfToEven" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "roundhalftoeven".to_string()));
}
SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"Sign" => SupportedOp::Linear(PolyOp::Sign),
"Pow" => {
// Extract the slope layer hyperparams from a const
@@ -1121,7 +1208,9 @@ pub fn new_op_from_onnx(
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar pow")
return Err(GraphError::NonScalarPower);
} else if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
}

let exponent = c.raw_values[0];
@@ -1138,7 +1227,9 @@ pub fn new_op_from_onnx(
inputs[0].decrement_use();
deleted_indices.push(0);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar base")
return Err(GraphError::NonScalarBase);
} else if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
}

let base = c.raw_values[0];
@@ -1148,10 +1239,14 @@ pub fn new_op_from_onnx(
base: base.into(),
})
} else {
unimplemented!("only support constant base or pow for now")
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
}
}
"Div" => {
if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "div".to_string()));
}

let const_idx = inputs
.iter()
.enumerate()
@@ -1159,14 +1254,15 @@ pub fn new_op_from_onnx(
.map(|(i, _)| i)
.collect::<Vec<_>>();

if const_idx.len() > 1 {
if const_idx.len() > 1 || const_idx.is_empty() {
return Err(GraphError::InvalidDims(idx, "div".to_string()));
}

let const_idx = const_idx[0];

if const_idx != 1 {
unimplemented!("only support div with constant as second input")
return Err(GraphError::MisformedParams(
"only support div with constant as second input".to_string(),
));
}

if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
@@ -1176,14 +1272,28 @@ pub fn new_op_from_onnx(
// get the non constant index
let denom = c.raw_values[0];

SupportedOp::Hybrid(HybridOp::Div {
let op = SupportedOp::Hybrid(HybridOp::Div {
denom: denom.into(),
})
});

// if the input is scale 0 we re up to the max scale
if input_scales[0] == 0 {
SupportedOp::Rescaled(Rescaled {
inner: Box::new(op),
scale: vec![(0, scale_to_multiplier(scales.get_max()) as u128)],
})
} else {
op
}
} else {
unimplemented!("only support non zero divisors of size 1")
return Err(GraphError::MisformedParams(
"only support non zero divisors of size 1".to_string(),
));
}
} else {
unimplemented!("only support div with constant as second input")
return Err(GraphError::MisformedParams(
"only support div with constant as second input".to_string(),
));
}
}
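One subtlety in the rewritten `Div` arm: when the numerator arrives at scale 0 (an integer input), dividing by a constant would leave no fractional resolution, so the op is wrapped in a `Rescaled` that first scales the input up to the maximum scale. A standalone sketch of the idea (types simplified; ezkl's `Rescaled` wrapper differs):

```rust
// Sketch: dividing a scale-0 (integer) input by a constant loses all
// fractional precision unless the input is first rescaled upward.
fn main() {
    let denom = 3.0;
    let max_scale: i32 = 7;

    // Naive scale-0 division: 7 / 3 truncates to 2.
    let x: i64 = 7;
    assert_eq!(x / denom as i64, 2);

    // Rescale first (multiply by 2^max_scale), then divide: the result
    // is a fixed-point number at scale 7 with real resolution.
    let rescaled = x * (1 << max_scale); // 7 * 128 = 896
    let q = (rescaled as f64 / denom).round() as i64; // 299
    let recovered = q as f64 / (1 << max_scale) as f64;
    assert!((recovered - 7.0 / 3.0).abs() < 0.01);
}
```
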
"Cube" => SupportedOp::Linear(PolyOp::Pow(3)),
@@ -1323,7 +1433,7 @@ pub fn new_op_from_onnx(
if !resize_node.contains("interpolator: Nearest")
&& !resize_node.contains("nearest: Floor")
{
unimplemented!("Only nearest neighbor interpolation is supported")
return Err(GraphError::InvalidInterpolation);
}
// check if optional scale factor is present
if inputs.len() != 2 && inputs.len() != 3 {
@@ -1427,6 +1537,10 @@ pub fn new_op_from_onnx(
SupportedOp::Linear(PolyOp::Reshape(output_shape))
}
"Flatten" => {
if inputs.len() != 1 || inputs[0].out_dims().is_empty() {
return Err(GraphError::InvalidDims(idx, "flatten".to_string()));
};

let new_dims: Vec<usize> = vec![inputs[0].out_dims()[0].iter().product::<usize>()];
SupportedOp::Linear(PolyOp::Flatten(new_dims))
}
@@ -1500,12 +1614,10 @@ pub fn homogenize_input_scales(
input_scales: Vec<crate::Scale>,
inputs_to_scale: Vec<usize>,
) -> Result<Box<dyn Op<Fp>>, GraphError> {
let relevant_input_scales = input_scales
.clone()
.into_iter()
.enumerate()
.filter(|(idx, _)| inputs_to_scale.contains(idx))
.map(|(_, scale)| scale)
let relevant_input_scales = inputs_to_scale
.iter()
.filter(|idx| input_scales.len() > **idx)
.map(|&idx| input_scales[idx])
.collect_vec();

if inputs_to_scale.is_empty() {
@@ -1546,10 +1658,30 @@ pub fn homogenize_input_scales(
}

#[cfg(test)]
/// tests for the utility module
pub mod tests {

use super::*;

// quantization tests
#[test]
fn test_quantize_tensor() {
let tensor: Tensor<f32> = (0..10).map(|x| x as f32).into();
let reference: Tensor<Fp> = (0..10).map(|x| x.into()).into();
let scale = 0;
let visibility = &Visibility::Public;
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility).unwrap();
assert_eq!(quantized.len(), 10);
assert_eq!(quantized, reference);
}

#[test]
fn test_quantize_edge_cases() {
assert_eq!(quantize_float(&f64::NAN, 0.0, 0).unwrap(), 0);
assert!(quantize_float(&f64::INFINITY, 0.0, 0).is_err());
assert!(quantize_float(&f64::NEG_INFINITY, 0.0, 0).is_err());
}

#[test]
fn test_flatten_valtensors() {
let tensor1: Tensor<Fp> = (0..10).map(|x| x.into()).into();

@@ -11,35 +11,34 @@ use log::debug;
use pyo3::{
exceptions::PyValueError, FromPyObject, IntoPy, PyObject, PyResult, Python, ToPyObject,
};

use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;

use self::errors::GraphError;

use super::*;

/// Label enum to track whether model input, model parameters, and model output are public, private, or hashed
/// Defines the visibility level of values within the zero-knowledge circuit
/// Controls how values are handled during proof generation and verification
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum Visibility {
/// Mark an item as private to the prover (not in the proof submitted for verification)
/// Value is private to the prover and not included in proof
#[default]
Private,
/// Mark an item as public (sent in the proof submitted for verification)
/// Value is public and included in proof for verification
Public,
/// Mark an item as publicly committed to (hash sent in the proof submitted for verification)
/// Value is hashed and the hash is included in proof
Hashed {
/// Whether the hash is used as an instance (sent in the proof submitted for verification)
/// if false the hash is used as an advice (not in the proof submitted for verification) and is then sent to the computational graph
/// if true the hash is used as an instance (sent in the proof submitted for verification) the *inputs* to the hashing function are then sent to the computational graph
/// Controls how the hash is handled in proof
/// true - hash is included directly in proof (public)
/// false - hash is used as advice and passed to computational graph
hash_is_public: bool,
///
/// Specifies which outputs this hash affects
outlets: Vec<usize>,
},
/// Mark an item as publicly committed to (KZG commitment sent in the proof submitted for verification)
/// Value is committed using KZG commitment scheme
KZGCommit,
/// assigned as a constant in the circuit
/// Value is assigned as a constant in the circuit
Fixed,
}

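The `From<&str>` impl a few hunks below parses strings such as `hashed/public` or `hashed/private/0,1`, where the trailing segment lists the affected outlets. A reduced sketch of that parse, using a simplified enum rather than ezkl's full `Visibility`:

```rust
// Simplified stand-in for ezkl's Visibility string parsing.
#[derive(Debug, PartialEq)]
enum Vis {
    Private,
    Public,
    Hashed { hash_is_public: bool, outlets: Vec<usize> },
}

fn parse(s: &str) -> Vis {
    if s.contains("hashed/private") {
        // Split on the last '/' to recover the outlet list.
        let (_, outlets) = s.split_at(s.rfind('/').unwrap());
        let outlets = outlets
            .trim_start_matches('/')
            .split(',')
            .filter_map(|o| o.parse().ok())
            .collect();
        Vis::Hashed { hash_is_public: false, outlets }
    } else if s.contains("hashed") {
        Vis::Hashed { hash_is_public: true, outlets: vec![] }
    } else if s == "public" {
        Vis::Public
    } else {
        Vis::Private
    }
}

fn main() {
    assert_eq!(
        parse("hashed/private/0,1"),
        Vis::Hashed { hash_is_public: false, outlets: vec![0, 1] }
    );
    assert_eq!(parse("public"), Vis::Public);
}
```
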
@@ -66,15 +65,17 @@ impl Display for Visibility {

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Visibility {
/// Converts visibility to command line flags
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}

impl<'a> From<&'a str> for Visibility {
/// Converts string representation to Visibility
fn from(s: &'a str) -> Self {
if s.contains("hashed/private") {
// split on last occurrence of '/'
// Split on last occurrence of '/'
let (_, outlets) = s.split_at(s.rfind('/').unwrap());
let outlets = outlets
.trim_start_matches('/')
@@ -106,8 +107,8 @@ impl<'a> From<&'a str> for Visibility {
}

#[cfg(feature = "python-bindings")]
/// Converts Visibility into a PyObject (Required for Visibility to be compatible with Python)
impl IntoPy<PyObject> for Visibility {
/// Converts Visibility to Python object
fn into_py(self, py: Python) -> PyObject {
match self {
Visibility::Private => "private".to_object(py),
@@ -134,14 +135,13 @@ impl IntoPy<PyObject> for Visibility {
}

#[cfg(feature = "python-bindings")]
/// Obtains Visibility from PyObject (Required for Visibility to be compatible with Python)
impl<'source> FromPyObject<'source> for Visibility {
/// Extracts Visibility from Python object
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
let strval = String::extract_bound(ob)?;
let strval = strval.as_str();

if strval.contains("hashed/private") {
// split on last occurence of '/'
let (_, outlets) = strval.split_at(strval.rfind('/').unwrap());
let outlets = outlets
.trim_start_matches('/')
@@ -174,29 +174,32 @@ impl<'source> FromPyObject<'source> for Visibility {
}

impl Visibility {
#[allow(missing_docs)]
/// Returns true if visibility is Fixed
pub fn is_fixed(&self) -> bool {
matches!(&self, Visibility::Fixed)
}
#[allow(missing_docs)]

/// Returns true if visibility is Private or hashed private
pub fn is_private(&self) -> bool {
matches!(&self, Visibility::Private) || self.is_hashed_private()
}

#[allow(missing_docs)]
/// Returns true if visibility is Public
pub fn is_public(&self) -> bool {
matches!(&self, Visibility::Public)
}
#[allow(missing_docs)]

/// Returns true if visibility involves hashing
pub fn is_hashed(&self) -> bool {
matches!(&self, Visibility::Hashed { .. })
}
#[allow(missing_docs)]

/// Returns true if visibility uses KZG commitment
pub fn is_polycommit(&self) -> bool {
matches!(&self, Visibility::KZGCommit)
}

#[allow(missing_docs)]
/// Returns true if visibility is hashed with public hash
pub fn is_hashed_public(&self) -> bool {
if let Visibility::Hashed {
hash_is_public: true,
@@ -207,7 +210,8 @@ impl Visibility {
}
false
}
#[allow(missing_docs)]

/// Returns true if visibility is hashed with private hash
pub fn is_hashed_private(&self) -> bool {
if let Visibility::Hashed {
hash_is_public: false,
@@ -219,11 +223,12 @@ impl Visibility {
false
}

#[allow(missing_docs)]
/// Returns true if visibility requires additional processing
pub fn requires_processing(&self) -> bool {
matches!(&self, Visibility::Hashed { .. }) | matches!(&self, Visibility::KZGCommit)
}
#[allow(missing_docs)]

/// Returns vector of output indices that this visibility setting affects
pub fn overwrites_inputs(&self) -> Vec<usize> {
if let Visibility::Hashed { outlets, .. } = self {
return outlets.clone();
@@ -232,14 +237,14 @@ impl Visibility {
}
}

/// Represents the scale of the model input, model parameters.
/// Manages scaling factors for different parts of the model
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)]
pub struct VarScales {
///
/// Scale factor for input values
pub input: crate::Scale,
///
/// Scale factor for parameter values
pub params: crate::Scale,
///
/// Multiplier for scale rebasing
pub rebase_multiplier: u32,
}

@@ -250,17 +255,17 @@ impl std::fmt::Display for VarScales {
}

impl VarScales {
///
/// Returns maximum scale value
pub fn get_max(&self) -> crate::Scale {
std::cmp::max(self.input, self.params)
}

///
/// Returns minimum scale value
pub fn get_min(&self) -> crate::Scale {
std::cmp::min(self.input, self.params)
}

/// Place in [VarScales] struct.
/// Creates VarScales from runtime arguments
pub fn from_args(args: &RunArgs) -> Self {
Self {
input: args.input_scale,
@@ -270,16 +275,17 @@ impl VarScales {
}
}

/// Represents whether the model input, model parameters, and model output are Public or Private to the prover.
/// Controls visibility settings for different parts of the model
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, PartialOrd)]
pub struct VarVisibility {
/// Input to the model or computational graph
/// Visibility of model inputs
pub input: Visibility,
/// Parameters, such as weights and biases, in the model
/// Visibility of model parameters (weights, biases)
pub params: Visibility,
/// Output of the model or computational graph
/// Visibility of model outputs
pub output: Visibility,
}

impl std::fmt::Display for VarVisibility {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
@@ -301,8 +307,7 @@ impl Default for VarVisibility {
}

impl VarVisibility {
/// Read from cli args whether the model input, model parameters, and model output are Public or Private to the prover.
/// Place in [VarVisibility] struct.
/// Creates visibility settings from runtime arguments
pub fn from_args(args: &RunArgs) -> Result<Self, GraphError> {
let input_vis = &args.input_visibility;
let params_vis = &args.param_visibility;
@@ -313,17 +318,17 @@ impl VarVisibility {
}

if !output_vis.is_public()
& !params_vis.is_public()
& !input_vis.is_public()
& !output_vis.is_fixed()
& !params_vis.is_fixed()
& !input_vis.is_fixed()
& !output_vis.is_hashed()
& !params_vis.is_hashed()
& !input_vis.is_hashed()
& !output_vis.is_polycommit()
& !params_vis.is_polycommit()
& !input_vis.is_polycommit()
&& !params_vis.is_public()
&& !input_vis.is_public()
&& !output_vis.is_fixed()
&& !params_vis.is_fixed()
&& !input_vis.is_fixed()
&& !output_vis.is_hashed()
&& !params_vis.is_hashed()
&& !input_vis.is_hashed()
&& !output_vis.is_polycommit()
&& !params_vis.is_polycommit()
&& !input_vis.is_polycommit()
{
return Err(GraphError::Visibility);
}
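The change in this hunk swaps bitwise `&` for logical `&&` across the visibility checks. On `bool` both produce the same value, but `&` evaluates both sides unconditionally while `&&` short-circuits, which is the conventional (and marginally cheaper) form. A tiny demonstration:

```rust
fn noisy(val: bool, log: &mut Vec<&'static str>, tag: &'static str) -> bool {
    log.push(tag); // records that this operand was evaluated
    val
}

fn main() {
    let mut log = Vec::new();
    // Bitwise `&`: both operands run even though the first is false.
    let _ = noisy(false, &mut log, "a") & noisy(true, &mut log, "b");
    assert_eq!(log, ["a", "b"]);

    log.clear();
    // Logical `&&`: the second operand is skipped.
    let _ = noisy(false, &mut log, "a") && noisy(true, &mut log, "b");
    assert_eq!(log, ["a"]);
}
```
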
@@ -335,17 +340,17 @@ impl VarVisibility {
}
}

/// A wrapper for holding all columns that will be assigned to by a model.
/// Container for circuit columns used by a model
#[derive(Clone, Debug)]
pub struct ModelVars<F: PrimeField + TensorType + PartialOrd> {
#[allow(missing_docs)]
/// Advice columns for circuit assignments
pub advices: Vec<VarTensor>,
#[allow(missing_docs)]
/// Optional instance column for public inputs
pub instance: Option<ValTensor<F>>,
}

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
/// Get instance col
/// Gets reference to instance column if it exists
pub fn get_instance_col(&self) -> Option<&Column<Instance>> {
if let Some(instance) = &self.instance {
match instance {
@@ -357,14 +362,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
}
}

/// Set the initial instance offset
/// Sets initial offset for instance values
pub fn set_initial_instance_offset(&mut self, offset: usize) {
if let Some(instance) = &mut self.instance {
instance.set_initial_instance_offset(offset);
}
}

/// Get the total instance len
/// Gets total length of instance data
pub fn get_instance_len(&self) -> usize {
if let Some(instance) = &self.instance {
instance.get_total_instance_len()
@@ -373,21 +378,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
}
}

/// Increment the instance offset
/// Increments instance index
pub fn increment_instance_idx(&mut self) {
if let Some(instance) = &mut self.instance {
instance.increment_idx();
}
}

/// Reset the instance offset
/// Sets instance index to specific value
pub fn set_instance_idx(&mut self, val: usize) {
if let Some(instance) = &mut self.instance {
instance.set_idx(val);
}
}

/// Get the instance offset
/// Gets current instance index
pub fn get_instance_idx(&self) -> usize {
if let Some(instance) = &self.instance {
instance.get_idx()
@@ -396,7 +401,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
}
}

///
/// Initializes instance column with specified dimensions and scale
pub fn instantiate_instance(
&mut self,
cs: &mut ConstraintSystem<F>,
@@ -417,7 +422,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
};
}

/// Allocate all columns that will be assigned to by a model.
/// Creates new ModelVars with allocated columns based on settings
pub fn new(cs: &mut ConstraintSystem<F>, params: &GraphSettings) -> Self {
debug!("number of blinding factors: {}", cs.blinding_factors());

324 src/lib.rs
@@ -28,6 +28,9 @@

//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
use log::warn;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc as _;

/// Error type
// #[cfg_attr(not(feature = "ezkl"), derive(uniffi::Error))]
@@ -99,7 +102,7 @@ use circuit::{table::Range, CheckMode, Tolerance};
use clap::Args;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use fieldutils::IntegerRep;
use graph::Visibility;
use graph::{Visibility, MAX_PUBLIC_SRS};
use halo2_proofs::poly::{
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
};
@@ -165,7 +168,6 @@ pub mod srs_sha;
pub mod tensor;
#[cfg(feature = "ios-bindings")]
uniffi::setup_scaffolding!();

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use lazy_static::lazy_static;

@@ -180,11 +182,9 @@ lazy_static! {
.unwrap_or("8000".to_string())
.parse()
.unwrap();

/// The serialization format for the keys
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
.unwrap_or("raw-bytes".to_string());

}

#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
@@ -266,76 +266,102 @@ impl From<String> for Commitments {
}

/// Parameters specific to a proving run
///
/// RunArgs contains all configuration parameters needed to control the proving process,
/// including scaling factors, visibility settings, and circuit parameters.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, PartialOrd)]
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
derive(Args, ToFlags)
)]
pub struct RunArgs {
/// The tolerance for error on model outputs
/// Error tolerance for model outputs
/// Only applicable when outputs are public
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other))]
pub tolerance: Tolerance,
/// The denominator in the fixed point representation used when quantizing inputs
/// Fixed point scaling factor for quantizing inputs
/// Higher values provide more precision but increase circuit complexity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other))]
pub input_scale: Scale,
/// The denominator in the fixed point representation used when quantizing parameters
/// Fixed point scaling factor for quantizing parameters
/// Higher values provide more precision but increase circuit complexity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "7", value_hint = clap::ValueHint::Other))]
pub param_scale: Scale,
/// if the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this a more advanced parameter, use with caution)
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))]
/// Scale rebase threshold multiplier
/// When scale exceeds input_scale * multiplier, it is rebased to input_scale
/// Advanced parameter that should be used with caution
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))]
pub scale_rebase_multiplier: u32,
/// The min and max elements in the lookup table input column
/// Range for lookup table input column values
/// Specified as (min, max) pair
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'B', long, value_parser = parse_key_val::<IntegerRep, IntegerRep>, default_value = "-32768->32768"))]
pub lookup_range: Range,
/// The log_2 number of rows
/// Log2 of the number of rows in the circuit
/// Controls circuit size and proving time
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other))]
pub logrows: u32,
/// The log_2 number of rows
/// Number of inner columns per block
/// Affects circuit layout and efficiency
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'N', long, default_value = "2", value_hint = clap::ValueHint::Other))]
pub num_inner_cols: usize,
/// Hand-written parser for graph variables, eg. batch_size=1
/// Graph variables for parameterizing the computation
/// Format: "name->value", e.g. "batch_size->1"
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other))]
pub variables: Vec<(String, usize)>,
/// Flags whether inputs are public, private, fixed, hashed, polycommit
/// Visibility setting for input values
/// Controls whether inputs are public or private in the circuit
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
pub input_visibility: Visibility,
/// Flags whether outputs are public, private, fixed, hashed, polycommit
/// Visibility setting for output values
/// Controls whether outputs are public or private in the circuit
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "public", value_hint = clap::ValueHint::Other))]
pub output_visibility: Visibility,
/// Flags whether params are fixed, private, hashed, polycommit
/// Visibility setting for parameters
/// Controls how parameters are handled in the circuit
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
pub param_visibility: Visibility,
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
/// Should constants with 0.0 fraction be rebased to scale 0
/// Whether to rebase constants with zero fractional part to scale 0
/// Can improve efficiency for integer constants
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
pub rebase_frac_zero_constants: bool,
/// check mode (safe, unsafe, etc)
/// Circuit checking mode
/// Controls level of constraint verification
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "unsafe", value_hint = clap::ValueHint::Other))]
pub check_mode: CheckMode,
/// commitment scheme
/// Commitment scheme for circuit proving
/// Affects proof size and verification time
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other))]
pub commitment: Option<Commitments>,
/// the base used for decompositions
/// Base for number decomposition
/// Must be a power of 2
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "16384", value_hint = clap::ValueHint::Other))]
pub decomp_base: usize,
/// Number of decomposition legs
/// Controls decomposition granularity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "2", value_hint = clap::ValueHint::Other))]
/// the number of legs used for decompositions
pub decomp_legs: usize,
/// Whether to use bounded lookup for logarithm computation
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
/// use unbounded lookup for the log
pub bounded_log_lookup: bool,
/// Range check inputs and outputs (turn off if the inputs are felts)
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
pub ignore_range_check_inputs_outputs: bool,
}

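Since `RunArgs` implements `Default` (next hunk), callers typically override only the fields they care about with struct-update syntax. A hedged usage sketch; it assumes the `ezkl` crate as a dependency and the public field names shown above:

```rust
// Usage sketch, not taken verbatim from the repo docs.
fn main() {
    let args = ezkl::RunArgs {
        input_scale: 11,
        logrows: 19,
        ..Default::default()
    };
    assert!(args.validate().is_ok());
}
```
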
impl Default for RunArgs {
/// Creates a new RunArgs instance with default values
///
/// Default configuration is optimized for common use cases
/// while maintaining reasonable proving time and circuit size
fn default() -> Self {
Self {
bounded_log_lookup: false,
@@ -355,54 +381,138 @@ impl Default for RunArgs {
commitment: None,
decomp_base: 16384,
decomp_legs: 2,
ignore_range_check_inputs_outputs: false,
}
}
}

impl RunArgs {
/// Validates the RunArgs configuration
///
/// Performs comprehensive validation of all parameters to ensure they are within
/// acceptable ranges and follow required constraints. Returns accumulated errors
/// if any validations fail.
///
/// # Returns
/// - Ok(()) if all validations pass
/// - Err(String) with detailed error message if any validation fails
pub fn validate(&self) -> Result<(), String> {
let mut errors = Vec::new();

// Visibility validations
if self.param_visibility == Visibility::Public {
return Err(
"params cannot be public instances, you are probably trying to use `fixed` or `kzgcommit`"
.into(),
errors.push(
"Parameters cannot be public instances. Use 'fixed' or 'kzgcommit' instead"
.to_string(),
);
}
if self.scale_rebase_multiplier < 1 {
return Err("scale_rebase_multiplier must be >= 1".into());
}
if self.lookup_range.0 > self.lookup_range.1 {
return Err("lookup_range min is greater than max".into());
}
if self.logrows < 1 {
return Err("logrows must be >= 1".into());
}
if self.num_inner_cols < 1 {
return Err("num_inner_cols must be >= 1".into());
}

if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
return Err("tolerance > 0.0 requires output_visibility to be public".into());
errors.push("Non-zero tolerance requires output_visibility to be public".to_string());
}

// Scale validations
if self.scale_rebase_multiplier < 1 {
errors.push("scale_rebase_multiplier must be >= 1".to_string());
}

// if any of the scales are too small
if self.input_scale < 8 || self.param_scale < 8 {
warn!("low scale values (<8) may impact precision");
}

// Lookup range validations
if self.lookup_range.0 > self.lookup_range.1 {
errors.push(format!(
"Invalid lookup range: min ({}) is greater than max ({})",
self.lookup_range.0, self.lookup_range.1
));
}

// Size validations
if self.logrows < 1 {
errors.push("logrows must be >= 1".to_string());
}

if self.num_inner_cols < 1 {
errors.push("num_inner_cols must be >= 1".to_string());
}

let batch_size = self.variables.iter().find(|(name, _)| name == "batch_size");
if let Some(batch_size) = batch_size {
if batch_size.1 == 0 {
errors.push("'batch_size' cannot be 0".to_string());
}
}

// Decomposition validations
if self.decomp_base == 0 {
errors.push("decomp_base cannot be 0".to_string());
}

if self.decomp_legs == 0 {
errors.push("decomp_legs cannot be 0".to_string());
}

// Performance validations
if self.logrows > MAX_PUBLIC_SRS {
warn!("logrows exceeds maximum public SRS size");
}

// Validate tolerance is non-negative
if self.tolerance.val < 0.0 {
errors.push("tolerance cannot be negative".to_string());
}

// Performance warnings
if self.input_scale > 20 || self.param_scale > 20 {
warn!("High scale values (>20) may impact performance");
}

if errors.is_empty() {
Ok(())
} else {
Err(errors.join("\n"))
}
Ok(())
}

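The rewritten `validate` switches from fail-fast returns to accumulating every violation and joining them, so a caller sees all problems at once. The pattern in isolation:

```rust
fn validate(logrows: u32, lookup_range: (i64, i64)) -> Result<(), String> {
    // Collect every violation instead of returning on the first one.
    let mut errors = Vec::new();
    if logrows < 1 {
        errors.push("logrows must be >= 1".to_string());
    }
    if lookup_range.0 > lookup_range.1 {
        errors.push(format!(
            "Invalid lookup range: min ({}) is greater than max ({})",
            lookup_range.0, lookup_range.1
        ));
    }
    if errors.is_empty() {
        Ok(())
    } else {
        Err(errors.join("\n"))
    }
}

fn main() {
    let err = validate(0, (100, -100)).unwrap_err();
    // Both failures are reported in one message.
    assert_eq!(err.lines().count(), 2);
}
```
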
/// Export the ezkl configuration as json
/// Exports the configuration as JSON
///
/// Serializes the RunArgs instance to a JSON string
///
/// # Returns
/// * `Ok(String)` containing JSON representation
/// * `Err` if serialization fails
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
return Err(Box::new(e));
}
};
Ok(serialized)
let res = serde_json::to_string(&self)?;
Ok(res)
}
/// Parse an ezkl configuration from a json

/// Parses configuration from JSON
///
/// Deserializes a RunArgs instance from a JSON string
///
/// # Arguments
/// * `arg_json` - JSON string containing configuration
///
/// # Returns
/// * `Ok(RunArgs)` if parsing succeeds
/// * `Err` if parsing fails
pub fn from_json(arg_json: &str) -> Result<Self, serde_json::Error> {
serde_json::from_str(arg_json)
}
}

/// Parse a single key-value pair
// Additional helper functions for the module

/// Parses a key-value pair from a string in the format "key->value"
///
/// # Arguments
/// * `s` - Input string in the format "key->value"
///
/// # Returns
/// * `Ok((T, U))` - Parsed key and value
/// * `Err` - If parsing fails
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn parse_key_val<T, U>(
s: &str,
@@ -415,14 +525,15 @@ where
{
let pos = s
.find("->")
.ok_or_else(|| format!("invalid x->y: no `->` found in `{s}`"))?;
let a = s[..pos].parse()?;
let b = s[pos + 2..].parse()?;
Ok((a, b))
.ok_or_else(|| format!("invalid KEY->VALUE: no `->` found in `{s}`"))?;
Ok((s[..pos].parse()?, s[pos + 2..].parse()?))
}

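`parse_key_val` backs the `-V` flag and the lookup-range default shown earlier ("batch_size->1", "-32768->32768"). A standalone sketch of the same split-on-`->` parse:

```rust
use std::error::Error;
use std::str::FromStr;

// Same shape as the helper above: split once on "->" and parse both halves.
fn parse_key_val<T, U>(s: &str) -> Result<(T, U), Box<dyn Error>>
where
    T: FromStr,
    T::Err: Error + 'static,
    U: FromStr,
    U::Err: Error + 'static,
{
    let pos = s
        .find("->")
        .ok_or_else(|| format!("invalid KEY->VALUE: no `->` found in `{s}`"))?;
    Ok((s[..pos].parse()?, s[pos + 2..].parse()?))
}

fn main() -> Result<(), Box<dyn Error>> {
    let (key, val): (String, usize) = parse_key_val("batch_size->1")?;
    assert_eq!((key.as_str(), val), ("batch_size", 1));
    // The first "->" is the delimiter, so a leading minus sign still parses.
    let (lo, hi): (i64, i64) = parse_key_val("-32768->32768")?;
    assert_eq!((lo, hi), (-32768, 32768));
    Ok(())
}
```
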
/// Check if the version string matches the artifact version
/// If the version string does not match the artifact version, log a warning
/// Verifies that a version string matches the expected artifact version
/// Logs warnings for version mismatches or unversioned artifacts
///
/// # Arguments
/// * `artifact_version` - Version string from the artifact
pub fn check_version_string_matches(artifact_version: &str) {
if artifact_version == "0.0.0"
|| artifact_version == "source - no compatibility guaranteed"
@@ -447,3 +558,98 @@ pub fn check_version_string_matches(artifact_version: &str) {
);
}
}

#[cfg(test)]
#[allow(clippy::field_reassign_with_default)]
mod tests {
use super::*;

#[test]
fn test_valid_default_args() {
let args = RunArgs::default();
assert!(args.validate().is_ok());
}

#[test]
fn test_invalid_param_visibility() {
let mut args = RunArgs::default();
args.param_visibility = Visibility::Public;
let err = args.validate().unwrap_err();
assert!(err.contains("Parameters cannot be public instances"));
}

#[test]
fn test_invalid_scale_rebase() {
let mut args = RunArgs::default();
args.scale_rebase_multiplier = 0;
let err = args.validate().unwrap_err();
assert!(err.contains("scale_rebase_multiplier must be >= 1"));
}

#[test]
fn test_invalid_lookup_range() {
let mut args = RunArgs::default();
args.lookup_range = (100, -100);
let err = args.validate().unwrap_err();
assert!(err.contains("Invalid lookup range"));
}

#[test]
fn test_invalid_logrows() {
let mut args = RunArgs::default();
args.logrows = 0;
let err = args.validate().unwrap_err();
assert!(err.contains("logrows must be >= 1"));
}

#[test]
fn test_invalid_inner_cols() {
let mut args = RunArgs::default();
args.num_inner_cols = 0;
let err = args.validate().unwrap_err();
assert!(err.contains("num_inner_cols must be >= 1"));
}

#[test]
fn test_invalid_tolerance() {
let mut args = RunArgs::default();
args.tolerance.val = 1.0;
args.output_visibility = Visibility::Private;
let err = args.validate().unwrap_err();
assert!(err.contains("Non-zero tolerance requires output_visibility to be public"));
}

#[test]
fn test_negative_tolerance() {
let mut args = RunArgs::default();
args.tolerance.val = -1.0;
let err = args.validate().unwrap_err();
assert!(err.contains("tolerance cannot be negative"));
}

#[test]
fn test_zero_batch_size() {
let mut args = RunArgs::default();
args.variables = vec![("batch_size".to_string(), 0)];
let err = args.validate().unwrap_err();
assert!(err.contains("'batch_size' cannot be 0"));
}

#[test]
fn test_json_serialization() {
let args = RunArgs::default();
let json = args.as_json().unwrap();
let deserialized = RunArgs::from_json(&json).unwrap();
assert_eq!(args, deserialized);
}

#[test]
fn test_multiple_validation_errors() {
let mut args = RunArgs::default();
args.logrows = 0;
args.lookup_range = (100, -100);
let err = args.validate().unwrap_err();
// Should contain multiple error messages
assert!(err.matches("\n").count() >= 1);
}
}

@@ -133,7 +133,6 @@ pub fn aggregate<'a>(
.collect_vec()
}));

// loader.ctx().constrain_equal(cell_0, cell_1)
let mut transcript = PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, snark.proof());
let proof = PlonkSuccinctVerifier::read_proof(svk, &protocol, &instances, &mut transcript)
.map_err(|_| plonk::Error::Synthesis)?;
@@ -309,11 +308,11 @@ impl AggregationCircuit {
})
}

///
/// Number of limbs used for decomposition
pub fn num_limbs() -> usize {
LIMBS
}
///
/// Number of bits used for decomposition
pub fn num_bits() -> usize {
BITS
}

@@ -353,6 +353,7 @@ where
C::ScalarExt: Serialize + DeserializeOwned,
{
/// Create a new application snark from proof and instance variables ready for aggregation
#[allow(clippy::too_many_arguments)]
pub fn new(
protocol: Option<PlonkProtocol<C>>,
instances: Vec<Vec<F>>,
@@ -528,7 +529,6 @@ pub fn create_keys<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
disable_selector_compression: bool,
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
C: Circuit<Scheme::Scalar>,
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
{
// Real proof
@@ -794,7 +794,6 @@ pub fn load_vk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
params: <C as Circuit<Scheme::Scalar>>::Params,
) -> Result<VerifyingKey<Scheme::Curve>, PfsysError>
where
C: Circuit<Scheme::Scalar>,
Scheme::Curve: SerdeObject + CurveAffine,
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
{
@@ -817,7 +816,6 @@ pub fn load_pk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
params: <C as Circuit<Scheme::Scalar>>::Params,
) -> Result<ProvingKey<Scheme::Curve>, PfsysError>
where
C: Circuit<Scheme::Scalar>,
Scheme::Curve: SerdeObject + CurveAffine,
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
{

@@ -38,4 +38,10 @@ pub enum TensorError {
/// Decomposition error
#[error("decomposition error: {0}")]
DecompositionError(#[from] DecompositionError),
/// Invalid argument
#[error("invalid argument: {0}")]
InvalidArgument(String),
/// Index out of bounds
#[error("index {0} out of bounds for dimension {1}")]
IndexOutOfBounds(usize, usize),
}

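The two new `TensorError` variants use `thiserror`'s `#[error]` attribute to derive their `Display` messages. A minimal sketch of how those format, assuming the `thiserror` crate as a dependency:

```rust
use thiserror::Error;

// Reduced version of the two variants added above.
#[derive(Debug, Error)]
enum TensorError {
    #[error("invalid argument: {0}")]
    InvalidArgument(String),
    #[error("index {0} out of bounds for dimension {1}")]
    IndexOutOfBounds(usize, usize),
}

fn main() {
    let e = TensorError::IndexOutOfBounds(12, 10);
    assert_eq!(e.to_string(), "index 12 out of bounds for dimension 10");
    let e = TensorError::InvalidArgument("empty shape".into());
    assert_eq!(e.to_string(), "invalid argument: empty shape");
}
```
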
@@ -24,9 +24,6 @@ use std::path::PathBuf;
pub use val::*;
pub use var::*;

#[cfg(feature = "metal")]
use instant::Instant;

use crate::{
circuit::utils,
fieldutils::{integer_rep_to_felt, IntegerRep},
@@ -40,8 +37,6 @@ use halo2_proofs::{
poly::Rotation,
};
use itertools::Itertools;
#[cfg(feature = "metal")]
use metal::{Device, MTLResourceOptions, MTLSize};
use std::error::Error;
use std::fmt::Debug;
use std::io::Read;
@@ -49,31 +44,6 @@ use std::iter::Iterator;
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
use std::{cmp::max, ops::Rem};

#[cfg(feature = "metal")]
use std::collections::HashMap;

#[cfg(feature = "metal")]
const LIB_DATA: &[u8] = include_bytes!("metal/tensor_ops.metallib");

#[cfg(feature = "metal")]
lazy_static::lazy_static! {
static ref DEVICE: Device = Device::system_default().expect("no device found");

static ref LIB: metal::Library = DEVICE.new_library_with_data(LIB_DATA).unwrap();

static ref QUEUE: metal::CommandQueue = DEVICE.new_command_queue();

static ref PIPELINES: HashMap<String, metal::ComputePipelineState> = {
let mut map = HashMap::new();
for name in ["add", "sub", "mul"] {
let function = LIB.get_function(name, None).unwrap();
let pipeline = DEVICE.new_compute_pipeline_state_with_function(&function).unwrap();
map.insert(name.to_string(), pipeline);
}
map
};
}

/// The (inner) type of tensor elements.
pub trait TensorType: Clone + Debug + 'static {
/// Returns the zero value.
@@ -833,6 +803,12 @@ impl<T: Clone + TensorType> Tensor<T> {
num_repeats: usize,
initial_offset: usize,
) -> Result<Tensor<T>, TensorError> {
if n == 0 {
return Err(TensorError::InvalidArgument(
"Cannot duplicate every 0th element".to_string(),
));
}

let mut inner: Vec<T> = Vec::with_capacity(self.inner.len());
let mut offset = initial_offset;
for (i, elem) in self.inner.clone().into_iter().enumerate() {
@@ -862,11 +838,17 @@ impl<T: Clone + TensorType> Tensor<T> {
num_repeats: usize,
initial_offset: usize,
) -> Result<Tensor<T>, TensorError> {
if n == 0 {
return Err(TensorError::InvalidArgument(
"Cannot remove every 0th element".to_string(),
));
}

// Pre-calculate capacity to avoid reallocations
let estimated_size = self.inner.len() - (self.inner.len() / n) * num_repeats;
let mut inner = Vec::with_capacity(estimated_size);

// Use iterator directly instead of creating intermediate collections
let mut i = 0;
while i < self.inner.len() {
// Add the current element
@@ -885,7 +867,6 @@ impl<T: Clone + TensorType> Tensor<T> {
}

/// Remove indices
/// WARN: assumes indices are in ascending order for speed
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
@@ -912,7 +893,11 @@ impl<T: Clone + TensorType> Tensor<T> {
}
// remove indices
for elem in indices.iter().rev() {
inner.remove(*elem);
if *elem < self.len() {
inner.remove(*elem);
} else {
return Err(TensorError::IndexOutOfBounds(*elem, self.len()));
}
}

Tensor::new(Some(&inner), &[inner.len()])
||||
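The hunk above replaces an unchecked inner.remove(*elem) with a bounds check that surfaces the new IndexOutOfBounds error. A standalone sketch of the same reverse-order, bounds-checked removal (String stands in for TensorError):

    fn remove_indices<T>(mut inner: Vec<T>, indices: &[usize]) -> Result<Vec<T>, String> {
        // Iterating in reverse keeps earlier indices valid while later ones are removed.
        for &elem in indices.iter().rev() {
            if elem < inner.len() {
                inner.remove(elem);
            } else {
                return Err(format!("index {} out of bounds for dimension {}", elem, inner.len()));
            }
        }
        Ok(inner)
    }

    fn main() {
        let v = remove_indices(vec![10, 20, 30, 40], &[1, 3]).unwrap();
        assert_eq!(v, vec![10, 30]);
        assert!(remove_indices(vec![10, 20], &[5]).is_err());
    }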
@@ -1404,10 +1389,6 @@ impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync> Ad
        let lhs = self.expand(&broadcasted_shape).unwrap();
        let rhs = rhs.expand(&broadcasted_shape).unwrap();

        #[cfg(feature = "metal")]
        let res = metal_tensor_op(&lhs, &rhs, "add");

        #[cfg(not(feature = "metal"))]
        let res = {
            let mut res: Tensor<T> = lhs
                .par_iter()
@@ -1505,10 +1486,6 @@ impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Su
        let lhs = self.expand(&broadcasted_shape).unwrap();
        let rhs = rhs.expand(&broadcasted_shape).unwrap();

        #[cfg(feature = "metal")]
        let res = metal_tensor_op(&lhs, &rhs, "sub");

        #[cfg(not(feature = "metal"))]
        let res = {
            let mut res: Tensor<T> = lhs
                .par_iter()
@@ -1576,10 +1553,6 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mu
        let lhs = self.expand(&broadcasted_shape).unwrap();
        let rhs = rhs.expand(&broadcasted_shape).unwrap();

        #[cfg(feature = "metal")]
        let res = metal_tensor_op(&lhs, &rhs, "mul");

        #[cfg(not(feature = "metal"))]
        let res = {
            let mut res: Tensor<T> = lhs
                .par_iter()
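All three hunks above remove the same two-arm shape: a feature-gated metal path and a par_iter CPU fallback, each binding the same res. A minimal sketch of that compile-time dispatch pattern; metal_add here is a stub standing in for the real kernel launch, and plain iter replaces rayon's par_iter to keep the sketch dependency-free:

    // Stand-in for the GPU kernel dispatch; only compiled with --features metal.
    #[cfg(feature = "metal")]
    fn metal_add(lhs: &[i64], rhs: &[i64]) -> Vec<i64> {
        lhs.iter().zip(rhs).map(|(a, b)| a + b).collect()
    }

    fn add(lhs: &[i64], rhs: &[i64]) -> Vec<i64> {
        #[cfg(feature = "metal")]
        let res = metal_add(lhs, rhs);

        // CPU fallback, selected when the feature is off.
        #[cfg(not(feature = "metal"))]
        let res = lhs.iter().zip(rhs).map(|(a, b)| a + b).collect();

        res
    }

    fn main() {
        assert_eq!(add(&[1, 2], &[3, 4]), vec![4, 6]);
    }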
@@ -1685,7 +1658,9 @@ impl<T: TensorType + Div<Output = T> + std::marker::Send + std::marker::Sync> Di
    }

// implement remainder
impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Rem for Tensor<T> {
impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + PartialEq> Rem
    for Tensor<T>
{
    type Output = Result<Tensor<T>, TensorError>;

    /// Elementwise remainder of a tensor with another tensor.
@@ -1714,9 +1689,25 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Re
        let mut lhs = self.expand(&broadcasted_shape).unwrap();
        let rhs = rhs.expand(&broadcasted_shape).unwrap();

        lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
            *o = o.clone() % r;
        });
        lhs.par_iter_mut()
            .zip(rhs)
            .map(|(o, r)| {
                if let Some(zero) = T::zero() {
                    if r != zero {
                        *o = o.clone() % r;
                        Ok(())
                    } else {
                        Err(TensorError::InvalidArgument(
                            "Cannot divide by zero in remainder".to_string(),
                        ))
                    }
                } else {
                    Err(TensorError::InvalidArgument(
                        "Undefined zero value".to_string(),
                    ))
                }
            })
            .collect::<Result<Vec<_>, _>>()?;

        Ok(lhs)
    }
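The new Rem body makes the elementwise % fallible instead of panicking on a zero divisor. The same guard over plain slices, with String standing in for TensorError:

    fn checked_rem(lhs: &[i64], rhs: &[i64]) -> Result<Vec<i64>, String> {
        lhs.iter()
            .zip(rhs)
            .map(|(o, r)| {
                if *r != 0 {
                    Ok(o % r)
                } else {
                    Err("Cannot divide by zero in remainder".to_string())
                }
            })
            .collect()
    }

    fn main() {
        assert_eq!(checked_rem(&[7, 9], &[4, 5]).unwrap(), vec![3, 4]);
        assert!(checked_rem(&[7], &[0]).is_err());
    }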
@@ -1751,7 +1742,6 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Re
/// assert_eq!(c, vec![2, 3]);
///
/// ```

pub fn get_broadcasted_shape(
    shape_a: &[usize],
    shape_b: &[usize],
@@ -1759,20 +1749,21 @@ pub fn get_broadcasted_shape(
    let num_dims_a = shape_a.len();
    let num_dims_b = shape_b.len();

    match (num_dims_a, num_dims_b) {
        (a, b) if a == b => {
            let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
            for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
                let max_dim = dim_a.max(dim_b);
                broadcasted_shape.push(*max_dim);
            }
            Ok(broadcasted_shape)
    if num_dims_a == num_dims_b {
        let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
        for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
            let max_dim = dim_a.max(dim_b);
            broadcasted_shape.push(*max_dim);
        }
        (a, b) if a < b => Ok(shape_b.to_vec()),
        (a, b) if a > b => Ok(shape_a.to_vec()),
        _ => Err(TensorError::DimError(
        Ok(broadcasted_shape)
    } else if num_dims_a < num_dims_b {
        Ok(shape_b.to_vec())
    } else if num_dims_a > num_dims_b {
        Ok(shape_a.to_vec())
    } else {
        Err(TensorError::DimError(
            "Unknown condition for broadcasting".to_string(),
        )),
        ))
    }
}
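The rewrite replaces the match-with-guards by explicit if/else branches with the same semantics: equal ranks take the per-dimension maximum, otherwise the higher-rank shape wins. An infallible sketch of that rule (the real function returns a Result whose error arm is unreachable):

    fn get_broadcasted_shape(shape_a: &[usize], shape_b: &[usize]) -> Vec<usize> {
        if shape_a.len() == shape_b.len() {
            // Equal rank: take the larger extent in each dimension.
            shape_a.iter().zip(shape_b).map(|(a, b)| *a.max(b)).collect()
        } else if shape_a.len() > shape_b.len() {
            shape_a.to_vec()
        } else {
            shape_b.to_vec()
        }
    }

    fn main() {
        assert_eq!(get_broadcasted_shape(&[2, 1], &[2, 3]), vec![2, 3]);
        assert_eq!(get_broadcasted_shape(&[4, 2, 3], &[2, 3]), vec![4, 2, 3]);
    }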
////////////////////////
@@ -1811,66 +1802,4 @@ mod tests {
        let b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
        assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
    }

    #[test]
    #[cfg(feature = "metal")]
    fn tensor_metal_int() {
        let a = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
        let b = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
        let c = metal_tensor_op(&a, &b, "add");
        assert_eq!(c, Tensor::new(Some(&[2, 4, 6, 8]), &[2, 2]).unwrap());

        let c = metal_tensor_op(&a, &b, "sub");
        assert_eq!(c, Tensor::new(Some(&[0, 0, 0, 0]), &[2, 2]).unwrap());

        let c = metal_tensor_op(&a, &b, "mul");
        assert_eq!(c, Tensor::new(Some(&[1, 4, 9, 16]), &[2, 2]).unwrap());
    }

    #[test]
    #[cfg(feature = "metal")]
    fn tensor_metal_felt() {
        use halo2curves::bn256::Fr;

        let a = Tensor::<Fr>::new(
            Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
            &[2, 2],
        )
        .unwrap();
        let b = Tensor::<Fr>::new(
            Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
            &[2, 2],
        )
        .unwrap();

        let c = metal_tensor_op(&a, &b, "add");
        assert_eq!(
            c,
            Tensor::<Fr>::new(
                Some(&[Fr::from(2), Fr::from(4), Fr::from(6), Fr::from(8)]),
                &[2, 2],
            )
            .unwrap()
        );

        let c = metal_tensor_op(&a, &b, "sub");
        assert_eq!(
            c,
            Tensor::<Fr>::new(
                Some(&[Fr::from(0), Fr::from(0), Fr::from(0), Fr::from(0)]),
                &[2, 2],
            )
            .unwrap()
        );

        let c = metal_tensor_op(&a, &b, "mul");
        assert_eq!(
            c,
            Tensor::<Fr>::new(
                Some(&[Fr::from(1), Fr::from(4), Fr::from(9), Fr::from(16)]),
                &[2, 2],
            )
            .unwrap()
        );
    }
}

@@ -385,6 +385,12 @@ pub fn resize<T: TensorType + Send + Sync>(
pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync>(
    t: &[Tensor<T>],
) -> Result<Tensor<T>, TensorError> {
    if t.len() == 1 {
        return Ok(t[0].clone());
    } else if t.len() == 0 {
        return Err(TensorError::DimMismatch("add".to_string()));
    }

    // calculate value of output
    let mut output: Tensor<T> = t[0].clone();

@@ -433,6 +439,11 @@ pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sy
pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync>(
    t: &[Tensor<T>],
) -> Result<Tensor<T>, TensorError> {
    if t.len() == 1 {
        return Ok(t[0].clone());
    } else if t.len() == 0 {
        return Err(TensorError::DimMismatch("sub".to_string()));
    }
    // calculate value of output
    let mut output: Tensor<T> = t[0].clone();

@@ -479,6 +490,11 @@ pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sy
pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync>(
    t: &[Tensor<T>],
) -> Result<Tensor<T>, TensorError> {
    if t.len() == 1 {
        return Ok(t[0].clone());
    } else if t.len() == 0 {
        return Err(TensorError::DimMismatch("mult".to_string()));
    }
    // calculate value of output
    let mut output: Tensor<T> = t[0].clone();

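Each of the add, sub and mult hunks above inserts the same guard: a single operand is returned as-is and an empty slice is rejected before the accumulation loop runs. A sketch of that shape over plain vectors, with String standing in for TensorError::DimMismatch:

    fn reduce_add(t: &[Vec<i64>]) -> Result<Vec<i64>, String> {
        if t.len() == 1 {
            return Ok(t[0].clone());
        } else if t.is_empty() {
            return Err("add: expected at least one tensor".to_string());
        }
        // Accumulate the remaining operands into the first one.
        let mut output = t[0].clone();
        for next in &t[1..] {
            for (o, n) in output.iter_mut().zip(next) {
                *o += n;
            }
        }
        Ok(output)
    }

    fn main() {
        assert_eq!(reduce_add(&[vec![1, 2], vec![3, 4]]).unwrap(), vec![4, 6]);
        assert!(reduce_add(&[]).is_err());
    }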
File diff suppressed because it is too large
@@ -5,33 +5,35 @@ use log::{debug, error, warn};
use crate::circuit::{region::ConstantsMap, CheckMode};

use super::*;
/// A wrapper around Halo2's `Column<Fixed>` or `Column<Advice>`.
/// Typically assign [ValTensor]s to [VarTensor]s when laying out a circuit.
/// A wrapper around Halo2's Column types that represents a tensor of variables in the circuit.
/// VarTensors are used to store and manage circuit columns, typically for assigning ValTensor
/// values during circuit layout. The tensor organizes storage into blocks of columns, where each
/// block contains multiple columns and each column contains multiple rows.
#[derive(Clone, Default, Debug, PartialEq, Eq)]
pub enum VarTensor {
    /// A VarTensor for holding Advice values, which are assigned at proving time.
    Advice {
        /// Vec of Advice columns, we have [[xx][xx][xx]...] where each inner vec is xx columns
        inner: Vec<Vec<Column<Advice>>>,
        ///
        /// The number of columns in each inner block
        num_inner_cols: usize,
        /// Number of rows available to be used in each column of the storage
        col_size: usize,
    },
    /// Dummy var
    /// A placeholder tensor used for testing or temporary storage
    Dummy {
        ///
        /// The number of columns in each inner block
        num_inner_cols: usize,
        /// Number of rows available to be used in each column of the storage
        col_size: usize,
    },
    /// Empty var
    /// An empty tensor with no storage
    #[default]
    Empty,
}

impl VarTensor {
    /// name of the tensor
    /// Returns the name of the tensor variant as a static string
    pub fn name(&self) -> &'static str {
        match self {
            VarTensor::Advice { .. } => "Advice",
@@ -40,22 +42,35 @@ impl VarTensor {
        }
    }

    ///
    /// Returns true if the tensor is an Advice variant
    pub fn is_advice(&self) -> bool {
        matches!(self, VarTensor::Advice { .. })
    }

    /// Calculates the maximum number of usable rows in the constraint system
    ///
    /// # Arguments
    /// * `cs` - The constraint system
    /// * `logrows` - Log base 2 of the total number of rows (including system and blinding rows)
    ///
    /// # Returns
    /// The maximum number of usable rows after accounting for blinding factors
    pub fn max_rows<F: PrimeField>(cs: &ConstraintSystem<F>, logrows: usize) -> usize {
        let base = 2u32;
        base.pow(logrows as u32) as usize - cs.blinding_factors() - 1
    }

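The max_rows arithmetic is worth spelling out: usable rows are 2^logrows minus the blinding factors minus one reserved row. A sketch with an illustrative blinding count (the real value comes from cs.blinding_factors()):

    fn max_rows(logrows: u32, blinding_factors: usize) -> usize {
        // In the real code blinding_factors comes from cs.blinding_factors().
        2usize.pow(logrows) - blinding_factors - 1
    }

    fn main() {
        // With logrows = 17 and, say, 5 blinding factors: 2^17 - 5 - 1 = 131066 usable rows.
        assert_eq!(max_rows(17, 5), 131_066);
    }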
    /// Create a new VarTensor::Advice that is unblinded
    /// Arguments
    /// * `cs` - The constraint system
    /// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
    /// * `capacity` - The number of advice cells to allocate
    /// Creates a new VarTensor::Advice with unblinded columns. Unblinded columns are used when
    /// the values do not need to be hidden in the proof.
    ///
    /// # Arguments
    /// * `cs` - The constraint system to create columns in
    /// * `logrows` - Log base 2 of the total number of rows
    /// * `num_inner_cols` - Number of columns in each inner block
    /// * `capacity` - Total number of advice cells to allocate
    ///
    /// # Returns
    /// A new VarTensor::Advice with unblinded columns enabled for equality constraints
    pub fn new_unblinded_advice<F: PrimeField>(
        cs: &mut ConstraintSystem<F>,
        logrows: usize,
@@ -93,11 +108,17 @@ impl VarTensor {
        }
    }

    /// Create a new VarTensor::Advice
    /// Arguments
    /// * `cs` - The constraint system
    /// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
    /// * `capacity` - The number of advice cells to allocate
    /// Creates a new VarTensor::Advice with standard (blinded) columns, used when
    /// the values need to be hidden in the proof.
    ///
    /// # Arguments
    /// * `cs` - The constraint system to create columns in
    /// * `logrows` - Log base 2 of the total number of rows
    /// * `num_inner_cols` - Number of columns in each inner block
    /// * `capacity` - Total number of advice cells to allocate
    ///
    /// # Returns
    /// A new VarTensor::Advice with blinded columns enabled for equality constraints
    pub fn new_advice<F: PrimeField>(
        cs: &mut ConstraintSystem<F>,
        logrows: usize,
@@ -133,11 +154,17 @@ impl VarTensor {
        }
    }

    /// Initializes fixed columns to support the VarTensor::Advice
    /// Arguments
    /// * `cs` - The constraint system
    /// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
    /// * `capacity` - The number of advice cells to allocate
    /// Initializes fixed columns in the constraint system to support the VarTensor::Advice
    /// Fixed columns are used for constant values that are known at circuit creation time.
    ///
    /// # Arguments
    /// * `cs` - The constraint system to create columns in
    /// * `logrows` - Log base 2 of the total number of rows
    /// * `num_constants` - Number of constant values needed
    /// * `module_requires_fixed` - Whether the module requires at least one fixed column
    ///
    /// # Returns
    /// The number of fixed columns created
    pub fn constant_cols<F: PrimeField>(
        cs: &mut ConstraintSystem<F>,
        logrows: usize,
@@ -169,7 +196,14 @@ impl VarTensor {
        modulo
    }

    /// Create a new VarTensor::Dummy
    /// Creates a new dummy VarTensor for testing or temporary storage
    ///
    /// # Arguments
    /// * `logrows` - Log base 2 of the total number of rows
    /// * `num_inner_cols` - Number of columns in each inner block
    ///
    /// # Returns
    /// A new VarTensor::Dummy with the specified dimensions
    pub fn dummy(logrows: usize, num_inner_cols: usize) -> Self {
        let base = 2u32;
        let max_rows = base.pow(logrows as u32) as usize - 6;
@@ -179,7 +213,7 @@ impl VarTensor {
        }
    }

    /// Gets the dims of the object the VarTensor represents
    /// Returns the number of blocks in the tensor
    pub fn num_blocks(&self) -> usize {
        match self {
            VarTensor::Advice { inner, .. } => inner.len(),
@@ -187,7 +221,7 @@ impl VarTensor {
        }
    }

    /// Num inner cols
    /// Returns the number of columns in each inner block
    pub fn num_inner_cols(&self) -> usize {
        match self {
            VarTensor::Advice { num_inner_cols, .. } | VarTensor::Dummy { num_inner_cols, .. } => {
@@ -197,7 +231,7 @@ impl VarTensor {
        }
    }

    /// Total number of columns
    /// Returns the total number of columns across all blocks
    pub fn num_cols(&self) -> usize {
        match self {
            VarTensor::Advice { inner, .. } => inner[0].len() * inner.len(),
@@ -205,7 +239,7 @@ impl VarTensor {
        }
    }

    /// Gets the size of each column
    /// Returns the maximum number of rows in each column
    pub fn col_size(&self) -> usize {
        match self {
            VarTensor::Advice { col_size, .. } | VarTensor::Dummy { col_size, .. } => *col_size,
@@ -213,7 +247,7 @@ impl VarTensor {
        }
    }

    /// Gets the size of each column
    /// Returns the total size of each block (num_inner_cols * col_size)
    pub fn block_size(&self) -> usize {
        match self {
            VarTensor::Advice {
@@ -230,7 +264,13 @@ impl VarTensor {
        }
    }

    /// Take a linear coordinate and output the (column, row) position in the storage block.
    /// Converts a linear coordinate to (block, column, row) coordinates in the storage
    ///
    /// # Arguments
    /// * `linear_coord` - The linear index to convert
    ///
    /// # Returns
    /// A tuple of (block_index, column_index, row_index)
    pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize, usize) {
        // x indexes over blocks of size num_inner_cols
        let x = linear_coord / self.block_size();
@@ -243,7 +283,17 @@ impl VarTensor {
}

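cartesian_coord is a mixed-radix decomposition of the linear index over (block, inner column, row). One plausible standalone version, assuming values fill each block column by column; the exact y/z ordering depends on how the real layout interleaves inner columns, so treat this as a sketch:

    fn cartesian_coord(linear: usize, num_inner_cols: usize, col_size: usize) -> (usize, usize, usize) {
        let block_size = num_inner_cols * col_size;
        let x = linear / block_size;              // block index
        let y = (linear % block_size) / col_size; // column within the block
        let z = linear % col_size;                // row within the column
        (x, y, z)
    }

    fn main() {
        // 2 inner columns of 4 rows each -> blocks of 8 cells.
        assert_eq!(cartesian_coord(0, 2, 4), (0, 0, 0));
        assert_eq!(cartesian_coord(5, 2, 4), (0, 1, 1));
        assert_eq!(cartesian_coord(9, 2, 4), (1, 0, 1));
    }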
impl VarTensor {
    /// Retrieve the value of a specific cell in the tensor.
    /// Queries a range of cells in the tensor during circuit synthesis
    ///
    /// # Arguments
    /// * `meta` - Virtual cells accessor
    /// * `x` - Block index
    /// * `y` - Column index within block
    /// * `z` - Starting row offset
    /// * `rng` - Number of consecutive rows to query
    ///
    /// # Returns
    /// A tensor of expressions representing the queried cells
    pub fn query_rng<F: PrimeField>(
        &self,
        meta: &mut VirtualCells<'_, F>,
@@ -268,7 +318,16 @@ impl VarTensor {
        }
    }

    /// Retrieve the value of a specific block at an offset in the tensor.
    /// Queries an entire block of cells at a given offset
    ///
    /// # Arguments
    /// * `meta` - Virtual cells accessor
    /// * `x` - Block index
    /// * `z` - Row offset
    /// * `rng` - Number of consecutive rows to query
    ///
    /// # Returns
    /// A tensor of expressions representing the queried block
    pub fn query_whole_block<F: PrimeField>(
        &self,
        meta: &mut VirtualCells<'_, F>,
@@ -293,7 +352,16 @@ impl VarTensor {
        }
    }

    /// Assigns a constant value to a specific cell in the tensor.
    /// Assigns a constant value to a specific cell in the tensor
    ///
    /// # Arguments
    /// * `region` - The region to assign values in
    /// * `offset` - Base offset for the assignment
    /// * `coord` - Coordinate within the tensor
    /// * `constant` - The constant value to assign
    ///
    /// # Returns
    /// The assigned cell or an error if assignment fails
    pub fn assign_constant<F: PrimeField + TensorType + PartialOrd>(
        &self,
        region: &mut Region<F>,
@@ -313,7 +381,17 @@ impl VarTensor {
        }
    }

    /// Assigns [ValTensor] to the columns of the inner tensor.
    /// Assigns values from a ValTensor to this tensor, excluding specified positions
    ///
    /// # Arguments
    /// * `region` - The region to assign values in
    /// * `offset` - Base offset for assignments
    /// * `values` - The ValTensor containing values to assign
    /// * `omissions` - Set of positions to skip during assignment
    /// * `constants` - Map for tracking constant assignments
    ///
    /// # Returns
    /// The assigned ValTensor or an error if assignment fails
    pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
        &self,
        region: &mut Region<F>,
@@ -344,7 +422,16 @@ impl VarTensor {
        Ok(res)
    }

    /// Assigns [ValTensor] to the columns of the inner tensor.
    /// Assigns values from a ValTensor to this tensor
    ///
    /// # Arguments
    /// * `region` - The region to assign values in
    /// * `offset` - Base offset for assignments
    /// * `values` - The ValTensor containing values to assign
    /// * `constants` - Map for tracking constant assignments
    ///
    /// # Returns
    /// The assigned ValTensor or an error if assignment fails
    pub fn assign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
        &self,
        region: &mut Region<F>,
@@ -396,14 +483,23 @@ impl VarTensor {
        Ok(res)
    }

    /// Helper function to get the remaining size of the column
    /// Returns the remaining available space in a column for assignments
    ///
    /// # Arguments
    /// * `offset` - Current offset in the column
    /// * `values` - The ValTensor to check space for
    ///
    /// # Returns
    /// The number of rows that need to be flushed or an error if space is insufficient
    pub fn get_column_flush<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
        &self,
        offset: usize,
        values: &ValTensor<F>,
    ) -> Result<usize, halo2_proofs::plonk::Error> {
        if values.len() > self.col_size() {
            error!("Values are too large for the column");
            error!(
                "There are too many values to flush for this column size, try setting the logrows to a higher value (eg. --logrows 22 on the cli)"
            );
            return Err(halo2_proofs::plonk::Error::Synthesis);
        }

@@ -427,8 +523,16 @@ impl VarTensor {
        Ok(flush_len)
    }

    /// Assigns [ValTensor] to the columns of the inner tensor. Whereby the values are assigned to a single column, without overflowing.
    /// So for instance if we are assigning 10 values and we are at index 18 of the column, and the columns are of length 20, we skip the last 2 values of current column and start from the beginning of the next column.
    /// Assigns values to a single column, avoiding column overflow by flushing to the next column if needed
    ///
    /// # Arguments
    /// * `region` - The region to assign values in
    /// * `offset` - Base offset for assignments
    /// * `values` - The ValTensor containing values to assign
    /// * `constants` - Map for tracking constant assignments
    ///
    /// # Returns
    /// A tuple of (assigned ValTensor, number of rows flushed) or an error if assignment fails
    pub fn assign_exact_column<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
        &self,
        region: &mut Region<F>,
@@ -443,8 +547,17 @@ impl VarTensor {
        Ok((assigned_vals, flush_len))
    }

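The worked example in the doc comment above (10 values at row 18 of a 20-row column skip the last 2 rows and restart in the next column) reduces to a small computation. A simplified model of get_column_flush:

    fn column_flush(offset: usize, num_values: usize, col_size: usize) -> Result<usize, String> {
        if num_values > col_size {
            return Err("values are too large for the column".to_string());
        }
        let remaining = col_size - (offset % col_size);
        // If the values fit in the current column no rows are flushed.
        Ok(if num_values <= remaining { 0 } else { remaining })
    }

    fn main() {
        // 10 values at row 18 of a 20-row column: flush the 2 unused rows.
        assert_eq!(column_flush(18, 10, 20), Ok(2));
        // 2 values at row 18 fit exactly: nothing to flush.
        assert_eq!(column_flush(18, 2, 20), Ok(0));
    }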
    /// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
    /// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
    /// Assigns values with duplication in dummy mode, used for testing and simulation
    ///
    /// # Arguments
    /// * `row` - Starting row for assignment
    /// * `offset` - Base offset for assignments
    /// * `values` - The ValTensor containing values to assign
    /// * `single_inner_col` - Whether to treat as a single column
    /// * `constants` - Map for tracking constant assignments
    ///
    /// # Returns
    /// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
    pub fn dummy_assign_with_duplication<
        F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
    >(
@@ -494,7 +607,16 @@ impl VarTensor {
        }
    }

    /// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
    /// Assigns values with duplication but without enforcing constraints between duplicated values
    ///
    /// # Arguments
    /// * `region` - The region to assign values in
    /// * `offset` - Base offset for assignments
    /// * `values` - The ValTensor containing values to assign
    /// * `constants` - Map for tracking constant assignments
    ///
    /// # Returns
    /// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
    pub fn assign_with_duplication_unconstrained<
        F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
    >(
@@ -533,8 +655,18 @@ impl VarTensor {
        }
    }

    /// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
    /// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
    /// Assigns values with duplication and enforces equality constraints between duplicated values
    ///
    /// # Arguments
    /// * `region` - The region to assign values in
    /// * `row` - Starting row for assignment
    /// * `offset` - Base offset for assignments
    /// * `values` - The ValTensor containing values to assign
    /// * `check_mode` - Mode for checking equality constraints
    /// * `constants` - Map for tracking constant assignments
    ///
    /// # Returns
    /// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
    pub fn assign_with_duplication_constrained<
        F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
    >(
@@ -613,6 +745,17 @@ impl VarTensor {
        }
    }

    /// Assigns a single value to the tensor. This is a helper function used by other assignment methods.
    ///
    /// # Arguments
    /// * `region` - The region to assign values in
    /// * `offset` - Base offset for the assignment
    /// * `k` - The value to assign
    /// * `coord` - The coordinate where to assign the value
    /// * `constants` - Map for tracking constant assignments
    ///
    /// # Returns
    /// The assigned value or an error if assignment fails
    fn assign_value<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
        &self,
        region: &mut Region<F>,
@@ -623,24 +766,28 @@ impl VarTensor {
    ) -> Result<ValType<F>, halo2_proofs::plonk::Error> {
        let (x, y, z) = self.cartesian_coord(offset + coord);
        let res = match k {
            // Handle direct value assignment
            ValType::Value(v) => match &self {
                VarTensor::Advice { inner: advices, .. } => {
                    ValType::PrevAssigned(region.assign_advice(|| "k", advices[x][y], z, || v)?)
                }
                _ => unimplemented!(),
            },
            // Handle copying previously assigned value
            ValType::PrevAssigned(v) => match &self {
                VarTensor::Advice { inner: advices, .. } => {
                    ValType::PrevAssigned(v.copy_advice(|| "k", region, advices[x][y], z)?)
                }
                _ => unimplemented!(),
            },
            // Handle copying previously assigned constant
            ValType::AssignedConstant(v, val) => match &self {
                VarTensor::Advice { inner: advices, .. } => {
                    ValType::AssignedConstant(v.copy_advice(|| "k", region, advices[x][y], z)?, val)
                }
                _ => unimplemented!(),
            },
            // Handle assigning evaluated value
            ValType::AssignedValue(v) => match &self {
                VarTensor::Advice { inner: advices, .. } => ValType::PrevAssigned(
                    region
@@ -649,6 +796,7 @@ impl VarTensor {
                ),
                _ => unimplemented!(),
            },
            // Handle constant value assignment with caching
            ValType::Constant(v) => {
                if let std::collections::hash_map::Entry::Vacant(e) = constants.entry(v) {
                    let value = ValType::AssignedConstant(

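The duplication variants above wrap a long run of values across column boundaries by repeating the last cell of a full column as the first cell of the next one; in the constrained variant the two copies are tied by an equality constraint. A plain-Vec simulation of that layout:

    fn assign_with_duplication(values: &[u64], col_size: usize) -> Vec<Vec<u64>> {
        let mut cols: Vec<Vec<u64>> = vec![vec![]];
        for &v in values {
            let col = cols.last_mut().unwrap();
            col.push(v);
            // When a column fills up, carry its last value into the next column;
            // in the circuit the two copies are linked by a copy constraint.
            if col.len() == col_size {
                let carried = *col.last().unwrap();
                cols.push(vec![carried]);
            }
        }
        cols
    }

    fn main() {
        let cols = assign_with_duplication(&[1, 2, 3, 4, 5], 3);
        assert_eq!(cols, vec![vec![1, 2, 3], vec![3, 4, 5], vec![5]]);
    }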
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -28,11 +28,12 @@
    "commitment": "KZG",
    "decomp_base": 128,
    "decomp_legs": 2,
    "bounded_log_lookup": false
    "bounded_log_lookup": false,
    "ignore_range_check_inputs_outputs": false
  },
  "num_rows": 46,
  "total_assignments": 92,
  "total_const_size": 3,
  "num_rows": 236,
  "total_assignments": 472,
  "total_const_size": 4,
  "total_dynamic_col_size": 0,
  "max_dynamic_input_len": 0,
  "num_dynamic_lookups": 0,

Binary file not shown.
@@ -1,7 +1,6 @@
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[cfg(test)]
mod native_tests {

    use ezkl::circuit::Tolerance;
    use ezkl::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
    // use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
@@ -23,6 +22,8 @@ mod native_tests {
    static COMPILE_WASM: Once = Once::new();
    static ENV_SETUP: Once = Once::new();

    const TEST_BINARY: &str = "test-runs/ezkl";

    //Sure to run this once
    #[derive(Debug)]
    #[allow(dead_code)]
@@ -75,9 +76,8 @@ mod native_tests {
        });
    }

    ///
    #[allow(dead_code)]
    pub fn init_wasm() {
    fn init_wasm() {
        COMPILE_WASM.call_once(|| {
            build_wasm_ezkl();
        });
@@ -104,7 +104,7 @@ mod native_tests {

    fn download_srs(logrows: u32, commitment: Commitments) {
        // if does not exist, download it
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "get-srs",
                "--logrows",
@@ -206,62 +206,62 @@ mod native_tests {
        "1l_tiny_div",
    ];

    const TESTS: [&str; 98] = [
        "1l_mlp", //0
        "1l_slice",
        "1l_concat",
        "1l_flatten",
    const TESTS: [&str; 99] = [
        "1l_mlp", //0
        "1l_slice", //1
        "1l_concat", //2
        "1l_flatten", //3
        // "1l_average",
        "1l_div",
        "1l_pad", // 5
        "1l_reshape",
        "1l_eltwise_div",
        "1l_sigmoid",
        "1l_sqrt",
        "1l_softmax", //10
        "1l_div", //4
        "1l_pad", // 5
        "1l_reshape", //6
        "1l_eltwise_div", //7
        "1l_sigmoid", //8
        "1l_sqrt", //9
        "1l_softmax", //10
        // "1l_instance_norm",
        "1l_batch_norm",
        "1l_prelu",
        "1l_leakyrelu",
        "1l_gelu_noappx",
        "1l_batch_norm", //11
        "1l_prelu", //12
        "1l_leakyrelu", //13
        "1l_gelu_noappx", //14
        // "1l_gelu_tanh_appx",
        "1l_relu", //15
        "1l_downsample",
        "1l_tanh",
        "2l_relu_sigmoid_small",
        "2l_relu_fc",
        "2l_relu_small", //20
        "2l_relu_sigmoid",
        "1l_conv",
        "2l_sigmoid_small",
        "2l_relu_sigmoid_conv",
        "3l_relu_conv_fc", //25
        "4l_relu_conv_fc",
        "1l_erf",
        "1l_var",
        "1l_elu",
        "min", //30
        "max",
        "1l_max_pool",
        "1l_conv_transpose",
        "1l_upsample",
        "1l_identity", //35
        "idolmodel", // too big evm
        "trig", // too big evm
        "prelu_gmm",
        "lstm",
        "rnn", //40
        "quantize_dequantize",
        "1l_where",
        "boolean",
        "boolean_identity",
        "decision_tree", // 45
        "random_forest",
        "gradient_boosted_trees",
        "1l_topk",
        "xgboost",
        "lightgbm", //50
        "hummingbird_decision_tree",
        "1l_relu", //15
        "1l_downsample", //16
        "1l_tanh", //17
        "2l_relu_sigmoid_small", //18
        "2l_relu_fc", //19
        "2l_relu_small", //20
        "2l_relu_sigmoid", //21
        "1l_conv", //22
        "2l_sigmoid_small", //23
        "2l_relu_sigmoid_conv", //24
        "3l_relu_conv_fc", //25
        "4l_relu_conv_fc", //26
        "1l_erf", //27
        "1l_var", //28
        "1l_elu", //29
        "min", //30
        "max", //31
        "1l_max_pool", //32
        "1l_conv_transpose", //33
        "1l_upsample", //34
        "1l_identity", //35
        "idolmodel", // too big evm
        "trig", // too big evm
        "prelu_gmm", //38
        "lstm", //39
        "rnn", //40
        "quantize_dequantize", //41
        "1l_where", //42
        "boolean", //43
        "boolean_identity", //44
        "decision_tree", // 45
        "random_forest", //46
        "gradient_boosted_trees", //47
        "1l_topk", //48
        "xgboost", //49
        "lightgbm", //50
        "hummingbird_decision_tree", //51
        "oh_decision_tree",
        "linear_svc",
        "gather_elements",
@@ -309,6 +309,7 @@ mod native_tests {
        "log", // 95
        "exp", // 96
        "general_exp", // 97
        "integer_div", // 98
    ];

    const WASM_TESTS: [&str; 46] = [
@@ -547,7 +548,7 @@ mod native_tests {
        }
    });

    seq!(N in 0..=97 {
    seq!(N in 0..=98 {

        #(#[test_case(TESTS[N])])*
        #[ignore]
@@ -628,7 +629,7 @@ mod native_tests {
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
            // gen random number between 0.0 and 1.0
            let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
            mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false, Some(8194), Some(5));
            mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false, Some(32776), Some(5));
            test_dir.close().unwrap();
        }

@@ -982,7 +983,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
            mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false, None, None);
            mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false, None, Some(5));
            test_dir.close().unwrap();
        }
    });
@@ -1556,7 +1557,7 @@ mod native_tests {
            .save(format!("{}/{}/witness_bad.json", test_dir, example_name).into())
            .unwrap();

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "mock",
                "-W",
@@ -1568,7 +1569,7 @@ mod native_tests {
            .expect("failed to execute process");
        assert!(status.success());

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "mock",
                "-W",
@@ -1580,7 +1581,7 @@ mod native_tests {
            .expect("failed to execute process");
        assert!(status.success());

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "mock",
                "-W",
@@ -1592,7 +1593,7 @@ mod native_tests {
            .expect("failed to execute process");
        assert!(!status.success());
    } else {
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "mock",
                "-W",
@@ -1641,6 +1642,11 @@ mod native_tests {
            format!("--commitment={}", commitment),
        ];

        // if output-visibility is fixed set --range-check-inputs-outputs to False
        if output_visibility == "fixed" {
            args.push("--ignore-range-check-inputs-outputs".to_string());
        }

        if let Some(decomp_base) = decomp_base {
            args.push(format!("--decomp-base={}", decomp_base));
        }
@@ -1653,7 +1659,7 @@ mod native_tests {
            args.push("--bounded-log-lookup".to_string());
        }

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(args)
            .status()
            .expect("failed to execute process");
@@ -1683,7 +1689,7 @@ mod native_tests {
            calibrate_args.push(scales);
        }

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(calibrate_args)
            .status()
            .expect("failed to execute process");
@@ -1707,7 +1713,7 @@ mod native_tests {
            *tolerance = 0.0;
        }

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "compile-circuit",
                "-M",
@@ -1724,7 +1730,7 @@ mod native_tests {
            .expect("failed to execute process");
        assert!(status.success());

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "gen-witness",
                "-D",
@@ -1792,7 +1798,7 @@ mod native_tests {

    // Mock prove (fast, but does not cover some potential issues)
    fn render_circuit(test_dir: &str, example_name: String) {
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "render-circuit",
                "-M",
@@ -1823,7 +1829,7 @@ mod native_tests {
            Commitments::KZG,
            2,
        );
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "mock-aggregate",
                "--logrows=23",
@@ -1861,7 +1867,7 @@ mod native_tests {

        download_srs(23, commitment);
        // now setup-aggregate
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "setup-aggregate",
                "--sample-snarks",
@@ -1877,7 +1883,7 @@ mod native_tests {
            .expect("failed to execute process");
        assert!(status.success());

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "aggregate",
                "--logrows=23",
@@ -1892,7 +1898,7 @@ mod native_tests {
            .status()
            .expect("failed to execute process");
        assert!(status.success());
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "verify-aggr",
                "--logrows=23",
@@ -1942,7 +1948,7 @@ mod native_tests {
        let private_key = format!("--private-key={}", *ANVIL_DEFAULT_PRIVATE_KEY);

        // create encoded calldata
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "encode-evm-calldata",
                "--proof-path",
@@ -1964,7 +1970,7 @@ mod native_tests {

        let args = build_args(base_args, &sol_arg);

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(args)
            .status()
            .expect("failed to execute process");
@@ -1980,7 +1986,7 @@ mod native_tests {
            private_key.as_str(),
        ];

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(&args)
            .status()
            .expect("failed to execute process");
@@ -2002,14 +2008,14 @@ mod native_tests {
            rpc_arg.as_str(),
        ];

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(&base_args)
            .status()
            .expect("failed to execute process");
        assert!(status.success());
        // As sanity check, add example that should fail.
        base_args[2] = PF_FAILURE_AGGR;
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(base_args)
            .status()
            .expect("failed to execute process");
@@ -2060,7 +2066,7 @@ mod native_tests {

        init_params(settings_path.clone().into());

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "setup",
                "-M",
@@ -2075,7 +2081,7 @@ mod native_tests {
            .expect("failed to execute process");
        assert!(status.success());

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "prove",
                "-W",
@@ -2093,7 +2099,7 @@ mod native_tests {
            .expect("failed to execute process");
        assert!(status.success());

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "swap-proof-commitments",
                "--proof-path",
@@ -2105,7 +2111,7 @@ mod native_tests {
            .expect("failed to execute process");
        assert!(status.success());

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "verify",
                format!("--settings-path={}", settings_path).as_str(),
@@ -2128,7 +2134,7 @@ mod native_tests {
        // get_srs for the graph_settings_num_instances
        download_srs(1, graph_settings.run_args.commitment.into());

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "verify",
                format!("--settings-path={}", settings_path).as_str(),
@@ -2178,7 +2184,7 @@ mod native_tests {
        let settings_arg = format!("--settings-path={}", settings_path);

        // create encoded calldata
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "encode-evm-calldata",
                "--proof-path",
@@ -2198,7 +2204,7 @@ mod native_tests {
        args.push("--sol-code-path");
        args.push(sol_arg.as_str());

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(&args)
            .status()
            .expect("failed to execute process");
@@ -2210,7 +2216,7 @@ mod native_tests {
        args.push("--sol-code-path");
        args.push(sol_arg.as_str());

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(&args)
            .status()
            .expect("failed to execute process");
@@ -2232,14 +2238,14 @@ mod native_tests {
            deployed_addr_arg.as_str(),
        ];

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(&args)
            .status()
            .expect("failed to execute process");
        assert!(status.success());
        // As sanity check, add example that should fail.
        args[2] = PF_FAILURE;
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(args)
            .status()
            .expect("failed to execute process");
@@ -2247,6 +2253,7 @@ mod native_tests {
    }

    // prove-serialize-verify, the usual full path
    #[allow(clippy::too_many_arguments)]
    fn kzg_evm_prove_and_verify_reusable_verifier(
        num_inner_columns: usize,
        test_dir: &str,
@@ -2297,7 +2304,7 @@ mod native_tests {
            "--reusable",
        ];

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(&args)
            .status()
            .expect("failed to execute process");
@@ -2312,7 +2319,7 @@ mod native_tests {
            "-C=verifier/reusable",
        ];

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(&args)
            .status()
            .expect("failed to execute process");
@@ -2341,7 +2348,7 @@ mod native_tests {
            &sol_arg_vk,
        ];

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(&args)
            .status()
            .expect("failed to execute process");
@@ -2356,7 +2363,7 @@ mod native_tests {
            "-C=vka",
        ];

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(&args)
            .status()
            .expect("failed to execute process");
@@ -2369,7 +2376,7 @@ mod native_tests {
        let deployed_addr_arg_vk = format!("--addr-vk={}", addr_vk);

        // create encoded calldata
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "encode-evm-calldata",
                "--proof-path",
@@ -2392,7 +2399,7 @@ mod native_tests {
            deployed_addr_arg_vk.as_str(),
        ];

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(&args)
            .status()
            .expect("failed to execute process");
@@ -2425,7 +2432,7 @@ mod native_tests {
        // Verify the modified proof (should fail)
        let mut args_mod = args.clone();
        args_mod[2] = &modified_pf_arg;
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(&args_mod)
            .status()
            .expect("failed to execute process");
@@ -2503,7 +2510,7 @@ mod native_tests {
        let test_input_source = format!("--input-source={}", input_source);
        let test_output_source = format!("--output-source={}", output_source);

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "setup",
                "-M",
@@ -2518,7 +2525,7 @@ mod native_tests {
        assert!(status.success());

        // generate the witness, passing the vk path to generate the necessary kzg commits
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "gen-witness",
                "-D",
@@ -2575,7 +2582,7 @@ mod native_tests {
        }
        input.save(data_path.clone().into()).unwrap();

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "setup-test-evm-data",
                "-D",
@@ -2593,7 +2600,7 @@ mod native_tests {
            assert!(status.success());
        }

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "prove",
                "-W",
@@ -2614,7 +2621,7 @@ mod native_tests {
        let settings_arg = format!("--settings-path={}", settings_path);

        // create encoded calldata
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "encode-evm-calldata",
                "--proof-path",
@@ -2633,7 +2640,7 @@ mod native_tests {
        args.push("--sol-code-path");
        args.push(sol_arg.as_str());

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(&args)
            .status()
            .expect("failed to execute process");
@@ -2654,7 +2661,7 @@ mod native_tests {
        args.push("--sol-code-path");
        args.push(sol_arg.as_str());

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(&args)
            .status()
            .expect("failed to execute process");
@@ -2677,7 +2684,7 @@ mod native_tests {
            create_da_args.push(test_on_chain_data_path.as_str());
        }

        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(&create_da_args)
            .status()
            .expect("failed to execute process");
@@ -2690,7 +2697,7 @@ mod native_tests {
        };

        let addr_path_da_arg = format!("--addr-path={}/{}/addr_da.txt", test_dir, example_name);
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "deploy-evm-da",
                format!("--settings-path={}", settings_path).as_str(),
@@ -2728,14 +2735,14 @@ mod native_tests {
            deployed_addr_da_arg.as_str(),
            rpc_arg.as_str(),
        ];
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(&args)
            .status()
            .expect("failed to execute process");
        assert!(status.success());
        // Create a new set of test on chain data only for the on-chain input source
        if input_source != "file" || output_source != "file" {
            let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
            let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
                .args([
                    "setup-test-evm-data",
                    "-D",
@@ -2762,7 +2769,7 @@ mod native_tests {
            test_on_chain_data_path.as_str(),
            rpc_arg.as_str(),
        ];
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(&args)
            .status()
            .expect("failed to execute process");
@@ -2778,7 +2785,7 @@ mod native_tests {
            deployed_addr_da_arg.as_str(),
            rpc_arg.as_str(),
        ];
        let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args(args)
            .status()
            .expect("failed to execute process");
@@ -2789,18 +2796,28 @@ mod native_tests {
        #[cfg(feature = "icicle")]
        let args = [
            "build",
            "--release",
            "--profile=test-runs",
            "--bin",
            "ezkl",
            "--features",
            "icicle",
        ];
        #[cfg(not(feature = "icicle"))]
        let args = ["build", "--release", "--bin", "ezkl"];
        #[cfg(feature = "macos-metal")]
        let args = [
            "build",
            "--profile=test-runs",
            "--bin",
            "ezkl",
            "--features",
            "macos-metal",
        ];
        // not macos-metal and not icicle
        #[cfg(all(not(feature = "icicle"), not(feature = "macos-metal")))]
        let args = ["build", "--profile=test-runs", "--bin", "ezkl"];
        #[cfg(not(feature = "mv-lookup"))]
        let args = [
            "build",
            "--release",
            "--profile=test-runs",
            "--bin",
            "ezkl",
            "--no-default-features",
@@ -2821,7 +2838,7 @@ mod native_tests {
        let status = Command::new("wasm-pack")
            .args([
                "build",
                "--release",
                "--profile=test-runs",
                "--target",
                "nodejs",
                "--out-dir",

@@ -72,11 +72,10 @@ mod py_tests {
        "torchtext==0.17.2",
        "torchvision==0.17.2",
        "pandas==2.2.1",
        "numpy==1.26.4",
        "seaborn==0.13.2",
        "notebook==7.1.2",
        "nbconvert==7.16.3",
        "onnx==1.16.0",
        "onnx==1.17.0",
        "kaggle==1.6.8",
        "py-solc-x==2.0.3",
        "web3==7.5.0",
@@ -90,12 +89,13 @@ mod py_tests {
        "xgboost==2.0.3",
        "hummingbird-ml==0.4.11",
        "lightgbm==4.3.0",
        "numpy==1.26.4",
    ])
    .status()
    .expect("failed to execute process");
    assert!(status.success());
    let status = Command::new("pip")
        .args(["install", "numpy==1.23"])
        .args(["install", "numpy==1.26.4"])
        .status()
        .expect("failed to execute process");

@@ -126,10 +126,10 @@ mod py_tests {
    }

    const TESTS: [&str; 35] = [
        "ezkl_demo_batch.ipynb", // 0
        "proof_splitting.ipynb", // 1
        "variance.ipynb", // 2
        "mnist_gan.ipynb", // 3
        "mnist_gan.ipynb", // 0
        "ezkl_demo_batch.ipynb", // 1
        "proof_splitting.ipynb", // 2
        "variance.ipynb", // 3
        "keras_simple_demo.ipynb", // 4
        "mnist_gan_proof_splitting.ipynb", // 5
        "hashed_vis.ipynb", // 6

@@ -59,7 +59,7 @@ def test_poseidon_hash():
    message = [ezkl.float_to_felt(x, 7) for x in message]
    res = ezkl.poseidon_hash(message)
    assert ezkl.felt_to_big_endian(
        res[0]) == "0x0da7e5e5c8877242fa699f586baf770d731defd54f952d4adeb85047a0e32f45"
        res[0]) == "0x2369898875588bf49b6539376b09705ea69aee318a58e6fcc1e68fc3e7ad81ab"

@@ -11,7 +11,6 @@ mod wasm32 {
    use ezkl::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
    use ezkl::circuit::modules::poseidon::PoseidonChip;
    use ezkl::circuit::modules::Module;
    use ezkl::graph::modules::POSEIDON_LEN_GRAPH;
    use ezkl::graph::GraphCircuit;
    use ezkl::graph::{GraphSettings, GraphWitness};
    use ezkl::pfsys;
@@ -227,11 +226,9 @@ mod wasm32 {
        let hash: Vec<Vec<Fr>> = serde_json::from_slice(&hash[..]).unwrap();

        let reference_hash =
            PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>::run(
                message.clone(),
            )
            .map_err(|_| "failed")
            .unwrap();
            PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
                .map_err(|_| "failed")
                .unwrap();

        assert_eq!(hash, reference_hash)
    }
