Compare commits

...

18 Commits

Author SHA1 Message Date
github-actions[bot]
66233f1c3a ci: update version string in docs 2025-02-05 09:33:03 +00:00
dante
9a8c754e45 fix: use onnx convention when integer dividing (#925) 2025-02-05 09:32:44 +00:00
dante
d82766d413 fix: force prover det on argmax/min for collisions (#923) 2025-02-04 12:08:34 +00:00
dante
820a80122b fix: range-check graph input and outputs (#921) 2025-02-04 02:33:27 +00:00
dante
9c64e42bd3 docs: improve quality + code quality fixes (#920) 2025-01-31 10:48:25 +00:00
dante
27b5e5dde3 fix: make flushing err more informative (#919) 2025-01-28 14:53:05 -05:00
dante
83c4afce3b fix: version interpolation in npm publishing (#917) 2025-01-27 23:20:58 -05:00
dante
50740a22df fix: patch pypi whl version labels (#916) 2025-01-27 20:25:03 -05:00
dante
a2624f6303 fix: strict cvx opt bounds to stop prover non-det (#914) 2025-01-24 08:48:50 -05:00
dante
fc5be4f949 fix: syn-sel should be range-checked when overflow (#913) 2025-01-23 12:26:31 -05:00
dante
d0ba505baa fix: node parsing should not panic (#912) 2025-01-22 08:02:29 -05:00
dante
f35688917d fix: rm macos metal bindings from python (#911) 2025-01-21 00:36:57 -05:00
Artem
7ae541ed35 feat: metal acceleration for MSM solving (#909)
---------

Co-authored-by: dante <45801863+alexander-camuto@users.noreply.github.com>
2025-01-20 22:17:24 -05:00
dante
675628cd08 fix!: shuffle argument should include an incrementing index (#904)
BREAKING CHANGE: pk and vk will not be backwards compatible
2025-01-17 09:19:10 -05:00
Artem
4fe7290240 fix: rust ci issue with updating swift pm testing files (#908) 2025-01-14 12:00:55 -05:00
dante
3e027db9b6 fix: apply zizmor suggestions to CI (#906)
---------

Co-authored-by: Jseam <hello.jseam@gmail.com>
2025-01-14 12:00:31 -05:00
Artem
e566acc22a fix: swift pm ci issue with updating testing files (#905) 2025-01-13 18:08:04 -05:00
dante
75ea99e81d fix: eager exec of ok_or error prints (#903) 2025-01-11 13:50:57 -05:00
66 changed files with 3416 additions and 1832 deletions

View File

@@ -6,22 +6,15 @@ on:
description: "Test scenario tags"
jobs:
bench_elgamal:
runs-on: self-hosted
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
components: rustfmt, clippy
- name: Bench elgamal
run: cargo bench --verbose --bench elgamal
bench_poseidon:
permissions:
contents: read
runs-on: self-hosted
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -31,10 +24,14 @@ jobs:
run: cargo bench --verbose --bench poseidon
bench_einsum_accum_matmul:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -44,10 +41,14 @@ jobs:
run: cargo bench --verbose --bench accum_einsum_matmul
bench_accum_matmul_relu:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -57,10 +58,14 @@ jobs:
run: cargo bench --verbose --bench accum_matmul_relu
bench_accum_matmul_relu_overflow:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -70,10 +75,14 @@ jobs:
run: cargo bench --verbose --bench accum_matmul_relu_overflow
bench_relu:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -83,10 +92,14 @@ jobs:
run: cargo bench --verbose --bench relu
bench_accum_dot:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -96,10 +109,14 @@ jobs:
run: cargo bench --verbose --bench accum_dot
bench_accum_conv:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -109,10 +126,14 @@ jobs:
run: cargo bench --verbose --bench accum_conv
bench_accum_sumpool:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -122,10 +143,14 @@ jobs:
run: cargo bench --verbose --bench accum_sumpool
bench_pairwise_add:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -135,10 +160,14 @@ jobs:
run: cargo bench --verbose --bench pairwise_add
bench_accum_sum:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
@@ -148,10 +177,14 @@ jobs:
run: cargo bench --verbose --bench accum_sum
bench_pairwise_pow:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27

View File

@@ -15,11 +15,18 @@ defaults:
working-directory: .
jobs:
publish-wasm-bindings:
permissions:
contents: read
packages: write
name: publish-wasm-bindings
env:
RELEASE_TAG: ${{ github.ref_name }}
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -40,43 +47,39 @@ jobs:
curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz | tar xzf -
export PATH=$PATH:$PWD/binaryen-version_116/bin
wasm-opt --version
- name: Build wasm files for both web and nodejs compilation targets
run: |
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
- name: Create package.json in pkg folder
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
echo '{
"name": "@ezkljs/engine",
"version": "${{ github.ref_name }}",
"dependencies": {
"@types/json-bigint": "^1.0.1",
"json-bigint": "^1.0.0"
},
"files": [
"nodejs/ezkl_bg.wasm",
"nodejs/ezkl.js",
"nodejs/ezkl.d.ts",
"nodejs/package.json",
"nodejs/utils.js",
"web/ezkl_bg.wasm",
"web/ezkl.js",
"web/ezkl.d.ts",
"web/snippets/**/*",
"web/package.json",
"web/utils.js",
"ezkl.d.ts"
],
"main": "nodejs/ezkl.js",
"module": "web/ezkl.js",
"types": "nodejs/ezkl.d.ts",
"sideEffects": [
"web/snippets/*"
]
}' > pkg/package.json
cat > pkg/package.json << EOF
{
"name": "@ezkljs/engine",
"version": "$RELEASE_TAG",
"dependencies": {
"@types/json-bigint": "^1.0.1",
"json-bigint": "^1.0.0"
},
"files": [
"nodejs/ezkl_bg.wasm",
"nodejs/ezkl.js",
"nodejs/ezkl.d.ts",
"nodejs/package.json",
"nodejs/utils.js",
"web/ezkl_bg.wasm",
"web/ezkl.js",
"web/ezkl.d.ts",
"web/snippets/**/*",
"web/package.json",
"web/utils.js",
"ezkl.d.ts"
],
"main": "nodejs/ezkl.js",
"module": "web/ezkl.js",
"types": "nodejs/ezkl.d.ts",
"sideEffects": [
"web/snippets/*"
]
}
EOF
- name: Replace memory definition in nodejs
run: |
@@ -184,21 +187,26 @@ jobs:
in-browser-evm-ver-publish:
permissions:
contents: read
packages: write
name: publish-in-browser-evm-verifier-package
needs: [publish-wasm-bindings]
runs-on: ubuntu-latest
env:
RELEASE_TAG: ${{ github.ref_name }}
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
sed -i "s|\"version\": \".*\"|\"version\": \"$RELEASE_TAG\"|" in-browser-evm-verifier/package.json
- name: Prepare tag and fetch package integrity
run: |
CLEANED_TAG=${{ github.ref_name }} # Get the tag from ref_name
CLEANED_TAG=${RELEASE_TAG} # Get the tag from ref_name
CLEANED_TAG="${CLEANED_TAG#v}" # Remove leading 'v'
echo "CLEANED_TAG=${CLEANED_TAG}" >> $GITHUB_ENV # Set it as an environment variable for later steps
ENGINE_INTEGRITY=$(npm view @ezkljs/engine@$CLEANED_TAG dist.integrity)

View File

@@ -6,9 +6,13 @@ on:
description: "Test scenario tags"
jobs:
large-tests:
permissions:
contents: read
runs-on: kaiju
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18

View File

@@ -18,12 +18,19 @@ defaults:
jobs:
linux:
permissions:
contents: read
packages: write
runs-on: GPU
strategy:
matrix:
target: [x86_64]
env:
RELEASE_TAG: ${{ github.ref_name }}
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
@@ -44,8 +51,6 @@ jobs:
- name: Set Cargo.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv Cargo.toml Cargo.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml

View File

@@ -16,22 +16,32 @@ defaults:
jobs:
macos:
permissions:
contents: read
runs-on: macos-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
matrix:
target: [x86_64, universal2-apple-darwin]
env:
RELEASE_TAG: ${{ github.ref_name }}
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag
shell: bash
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv Cargo.toml Cargo.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
@@ -45,6 +55,13 @@ jobs:
components: rustfmt, clippy
- name: Build wheels
if: matrix.target == 'universal2-apple-darwin'
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.target }}
args: --release --out dist --features python-bindings
- name: Build wheels
if: matrix.target == 'x86_64'
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.target }}
@@ -62,6 +79,8 @@ jobs:
path: dist
windows:
permissions:
contents: read
runs-on: windows-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
@@ -69,11 +88,21 @@ jobs:
target: [x64, x86]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
architecture: ${{ matrix.target }}
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -107,6 +136,8 @@ jobs:
path: dist
linux:
permissions:
contents: read
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
@@ -114,11 +145,21 @@ jobs:
target: [x86_64]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -129,7 +170,6 @@ jobs:
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- name: Install required libraries
shell: bash
run: |
@@ -168,58 +208,9 @@ jobs:
name: wheels
path: dist
# There's a problem with the maturin-action toolchain for arm arch leading to failed builds
# linux-cross:
# runs-on: ubuntu-latest
# strategy:
# matrix:
# target: [aarch64, armv7]
# steps:
# - uses: actions/checkout@v4
# - uses: actions/setup-python@v4
# with:
# python-version: 3.12
# - name: Install cross-compilation tools for aarch64
# if: matrix.target == 'aarch64'
# run: |
# sudo apt-get update
# sudo apt-get install -y gcc make gcc-aarch64-linux-gnu binutils-aarch64-linux-gnu libc6-dev-arm64-cross libusb-1.0-0-dev libatomic1-arm64-cross
# - name: Install cross-compilation tools for armv7
# if: matrix.target == 'armv7'
# run: |
# sudo apt-get update
# sudo apt-get install -y gcc make gcc-arm-linux-gnueabihf binutils-arm-linux-gnueabihf libc6-dev-armhf-cross libusb-1.0-0-dev libatomic1-armhf-cross
# - name: Build wheels
# uses: PyO3/maturin-action@v1
# with:
# target: ${{ matrix.target }}
# manylinux: auto
# args: --release --out dist --features python-bindings
# - uses: uraimo/run-on-arch-action@v2.5.0
# name: Install built wheel
# with:
# arch: ${{ matrix.target }}
# distro: ubuntu20.04
# githubToken: ${{ github.token }}
# install: |
# apt-get update
# apt-get install -y --no-install-recommends python3 python3-pip
# pip3 install -U pip
# run: |
# pip3 install ezkl --no-index --find-links dist/ --force-reinstall
# python3 -c "import ezkl"
# - name: Upload wheels
# uses: actions/upload-artifact@v3
# with:
# name: wheels
# path: dist
musllinux:
permissions:
contents: read
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
@@ -228,6 +219,8 @@ jobs:
- x86_64-unknown-linux-musl
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
@@ -250,6 +243,7 @@ jobs:
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- name: Install required libraries
shell: bash
run: |
@@ -283,6 +277,8 @@ jobs:
path: dist
musllinux-cross:
permissions:
contents: read
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
@@ -292,10 +288,20 @@ jobs:
arch: aarch64
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: 3.12
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -339,8 +345,6 @@ jobs:
permissions:
id-token: write
if: "startsWith(github.ref, 'refs/tags/')"
# TODO: Uncomment if linux-cross is working
# needs: [ macos, windows, linux, linux-cross, musllinux, musllinux-cross ]
needs: [macos, windows, linux, musllinux, musllinux-cross]
steps:
- uses: actions/download-artifact@v3
@@ -348,35 +352,34 @@ jobs:
name: wheels
- name: List Files
run: ls -R
# Both publish steps will fail if there is no trusted publisher setup
# On failure the publish step will then simply continue to the next one
# # publishes to TestPyPI
# - name: Publish package distribution to TestPyPI
# uses: pypa/gh-action-pypi-publish@unstable/v1
# with:
# repository-url: https://test.pypi.org/legacy/
# packages-dir: ./
# publishes to PyPI
- name: Publish package distributions to PyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@unstable/v1
with:
packages-dir: ./
# publishes to TestPyPI
- name: Publish package distribution to TestPyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@unstable/v1
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ./
doc-publish:
permissions:
contents: read
name: Trigger ReadTheDocs Build
runs-on: ubuntu-latest
needs: pypi-publish
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Trigger RTDs build
uses: dfm/rtds-action@v1
with:
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
commit_ref: ${{ github.ref_name }}
commit_ref: ${{ github.ref_name }}

View File

@@ -10,6 +10,9 @@ on:
- "*"
jobs:
create-release:
permissions:
contents: read
packages: write
name: create-release
runs-on: ubuntu-22.04
if: startsWith(github.ref, 'refs/tags/')
@@ -33,6 +36,9 @@ jobs:
tag_name: ${{ env.EZKL_VERSION }}
build-release-gpu:
permissions:
contents: read
packages: write
name: build-release-gpu
needs: ["create-release"]
runs-on: GPU
@@ -50,6 +56,9 @@ jobs:
components: rustfmt, clippy
- name: Checkout repo
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Get release version from tag
shell: bash
@@ -91,6 +100,10 @@ jobs:
asset_content_type: application/octet-stream
build-release:
permissions:
contents: read
packages: write
issues: write
name: build-release
needs: ["create-release"]
runs-on: ${{ matrix.os }}
@@ -132,6 +145,8 @@ jobs:
steps:
- name: Checkout repo
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Get release version from tag
shell: bash
@@ -181,14 +196,18 @@ jobs:
echo "target flag is: ${{ env.TARGET_FLAGS }}"
echo "target dir is: ${{ env.TARGET_DIR }}"
- name: Build release binary (no asm)
if: matrix.build != 'linux-gnu'
- name: Build release binary (no asm or metal)
if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
- name: Build release binary (asm)
if: matrix.build == 'linux-gnu'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
- name: Build release binary (metal)
if: matrix.build == 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal
- name: Strip release binary
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'
run: strip "target/${{ matrix.target }}/release/ezkl"

View File

@@ -19,11 +19,35 @@ env:
CARGO_TERM_COLOR: always
jobs:
build:
runs-on: ubuntu-latest
fr-age-test:
needs: [build, library-tests, docs, python-tests, python-integration-tests]
permissions:
contents: read
runs-on: large-self-hosted
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: fr age Mock
run: cargo nextest run --release --verbose tests::large_mock_::large_tests_6_expects -- --include-ignored
build:
permissions:
contents: read
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -33,9 +57,13 @@ jobs:
run: cargo build --verbose
docs:
permissions:
contents: read
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -45,9 +73,13 @@ jobs:
run: cargo doc --verbose
library-tests:
permissions:
contents: read
runs-on: ubuntu-latest-32-cores
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -71,6 +103,8 @@ jobs:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
@@ -101,9 +135,13 @@ jobs:
# run: cargo nextest run conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored
ultra-overflow-tests_og-lookup:
permissions:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -125,18 +163,22 @@ jobs:
# - name: Conv overflow (wasi)
# run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
- name: lookup overflow
run: cargo nextest run --release lookup_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
run: cargo nextest run --release lookup_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
- name: Matmul overflow
run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
run: RUST_LOG=debug cargo nextest run --release matmul_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
- name: Conv overflow
run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
run: RUST_LOG=debug cargo nextest run --release conv_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
- name: Conv + relu overflow
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
ultra-overflow-tests:
permissions:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -158,18 +200,22 @@ jobs:
# - name: Conv overflow (wasi)
# run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
- name: lookup overflow
run: cargo nextest run --release lookup_ultra_overflow --no-capture -- --include-ignored
run: cargo nextest run lookup_ultra_overflow --no-capture -- --include-ignored
- name: Matmul overflow
run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture -- --include-ignored
- name: Conv overflow
run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture -- --include-ignored
run: RUST_LOG=debug cargo nextest run --release conv_col_ultra_overflow --no-capture -- --include-ignored
- name: Conv + relu overflow
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture -- --include-ignored
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture -- --include-ignored
model-serialization:
permissions:
contents: read
runs-on: ubuntu-latest-16-cores
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -183,9 +229,13 @@ jobs:
run: cargo nextest run native_tests::tests::model_serialization_different_binaries_ --test-threads 1
wasm32-tests:
runs-on: ubuntu-latest
permissions:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -194,7 +244,7 @@ jobs:
- uses: jetli/wasm-pack-action@v0.4.0
with:
# Pin to version 0.12.1
version: 'v0.12.1'
version: "v0.12.1"
- uses: nanasess/setup-chromedriver@v2
# with:
# chromedriver-version: "115.0.5790.102"
@@ -208,10 +258,13 @@ jobs:
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
mock-proving-tests:
permissions:
contents: read
runs-on: non-gpu
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -222,59 +275,63 @@ jobs:
crate: cargo-nextest
locked: true
# - name: The Worm Mock
# run: cargo nextest run --release --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
- name: public outputs and bounded lookup log
run: cargo nextest run --release --verbose tests::mock_bounded_lookup_log --test-threads 32
- name: public outputs and tolerance > 0
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
- name: kzg inputs
run: cargo nextest run --release --verbose tests::mock_kzg_input_::t --test-threads 32
- name: kzg params
run: cargo nextest run --release --verbose tests::mock_kzg_params_::t --test-threads 32
- name: kzg outputs
run: cargo nextest run --release --verbose tests::mock_kzg_output_::t --test-threads 32
- name: kzg inputs + params + outputs
run: cargo nextest run --release --verbose tests::mock_kzg_all_::t --test-threads 32
- name: Mock fixed inputs
run: cargo nextest run --release --verbose tests::mock_fixed_inputs_ --test-threads 32
- name: Mock fixed outputs
run: cargo nextest run --release --verbose tests::mock_fixed_outputs --test-threads 32
- name: Mock accuracy calibration
run: cargo nextest run --release --verbose tests::mock_accuracy_cal_tests::a
- name: hashed inputs
run: cargo nextest run --release --verbose tests::mock_hashed_input_::t --test-threads 32
- name: hashed params
run: cargo nextest run --release --verbose tests::mock_hashed_params_::t --test-threads 32
- name: hashed params public inputs
run: cargo nextest run --release --verbose tests::mock_hashed_params_public_inputs_::t --test-threads 32
- name: hashed outputs
run: cargo nextest run --release --verbose tests::mock_hashed_output_::t --test-threads 32
- name: hashed inputs + params + outputs
run: cargo nextest run --release --verbose tests::mock_hashed_all_::t --test-threads 32
- name: hashed inputs + fixed params
run: cargo nextest run --release --verbose tests::mock_hashed_output_fixed_params_::t --test-threads 32
# run: cargo nextest run --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
- name: MNIST Gan Mock
run: cargo nextest run --release --verbose tests::large_mock_::large_tests_4_expects -- --include-ignored
run: cargo nextest run --verbose tests::large_mock_::large_tests_4_expects -- --include-ignored
- name: NanoGPT Mock
run: cargo nextest run --release --verbose tests::large_mock_::large_tests_1_expects -- --include-ignored
run: cargo nextest run --verbose tests::large_mock_::large_tests_1_expects -- --include-ignored
- name: Self Attention Mock
run: cargo nextest run --release --verbose tests::large_mock_::large_tests_0_expects -- --include-ignored
run: cargo nextest run --verbose tests::large_mock_::large_tests_0_expects -- --include-ignored
- name: Multihead Attention Mock
run: cargo nextest run --release --verbose tests::large_mock_::large_tests_2_expects -- --include-ignored
run: cargo nextest run --verbose tests::large_mock_::large_tests_2_expects -- --include-ignored
- name: public outputs
run: cargo nextest run --release --verbose tests::mock_public_outputs_ --test-threads 32
run: cargo nextest run --verbose tests::mock_public_outputs_ --test-threads 32
- name: public inputs
run: cargo nextest run --release --verbose tests::mock_public_inputs_ --test-threads 32
run: cargo nextest run --verbose tests::mock_public_inputs_ --test-threads 32
- name: fixed params
run: cargo nextest run --release --verbose tests::mock_fixed_params_ --test-threads 32
run: cargo nextest run --verbose tests::mock_fixed_params_ --test-threads 32
- name: public outputs and bounded lookup log
run: cargo nextest run --verbose tests::mock_bounded_lookup_log --test-threads 32
- name: public outputs and tolerance > 0
run: cargo nextest run --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10
run: cargo nextest run --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
- name: kzg inputs
run: cargo nextest run --verbose tests::mock_kzg_input_::t --test-threads 32
- name: kzg params
run: cargo nextest run --verbose tests::mock_kzg_params_::t --test-threads 32
- name: kzg outputs
run: cargo nextest run --verbose tests::mock_kzg_output_::t --test-threads 32
- name: kzg inputs + params + outputs
run: cargo nextest run --verbose tests::mock_kzg_all_::t --test-threads 32
- name: Mock fixed inputs
run: cargo nextest run --verbose tests::mock_fixed_inputs_ --test-threads 32
- name: Mock fixed outputs
run: cargo nextest run --verbose tests::mock_fixed_outputs --test-threads 32
- name: Mock accuracy calibration
run: cargo nextest run --verbose tests::mock_accuracy_cal_tests::a
- name: hashed inputs
run: cargo nextest run --verbose tests::mock_hashed_input_::t --test-threads 32
- name: hashed params
run: cargo nextest run --verbose tests::mock_hashed_params_::t --test-threads 32
- name: hashed params public inputs
run: cargo nextest run --verbose tests::mock_hashed_params_public_inputs_::t --test-threads 32
- name: hashed outputs
run: cargo nextest run --verbose tests::mock_hashed_output_::t --test-threads 32
- name: hashed inputs + params + outputs
run: cargo nextest run --verbose tests::mock_hashed_all_::t --test-threads 32
- name: hashed inputs + fixed params
run: cargo nextest run --verbose tests::mock_hashed_output_fixed_params_::t --test-threads 32
prove-and-verify-evm-tests:
permissions:
contents: read
runs-on: non-gpu
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -285,6 +342,8 @@ jobs:
crate: cargo-nextest
locked: true
- uses: actions/checkout@v3
with:
persist-credentials: false
- name: Use pnpm 8
uses: pnpm/action-setup@v2
with:
@@ -305,7 +364,7 @@ jobs:
NODE_ENV: development
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --release --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
wasm-pack build --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" in-browser-evm-verifier/nodejs/ezkl.js
@@ -319,41 +378,79 @@ jobs:
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: KZG prove and verify tests (EVM + reusable verifier + col-overflow)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --test-threads 1
run: cargo nextest run --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --test-threads 1
- name: KZG prove and verify tests (EVM + kzg all)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
run: cargo nextest run --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + kzg inputs)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_input_prove_and_verify --test-threads 1
run: cargo nextest run --verbose tests_evm::kzg_evm_kzg_input_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + kzg params)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_params_prove_and_verify --test-threads 1
run: cargo nextest run --verbose tests_evm::kzg_evm_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_prove_and_verify --test-threads 1
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_input_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain outputs)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_prove_and_verify --test-threads 1
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_output_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs & outputs)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_prove_and_verify --test-threads 1
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_input_output_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs & kzg outputs + params)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_kzg_output_kzg_params_prove_and_verify --test-threads 1
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_input_kzg_output_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain outputs & kzg inputs + params)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_kzg_input_kzg_params_prove_and_verify --test-threads 1
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_output_kzg_input_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain all kzg)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_all_kzg_params_prove_and_verify --test-threads 1
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_all_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs & outputs hashes)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_hashed_prove_and_verify --test-threads 1
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_input_output_hashed_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify --test-threads 1
run: cargo nextest run --verbose tests_evm::kzg_evm_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + hashed inputs)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_hashed_input_prove_and_verify --test-threads 1
run: cargo nextest run --verbose tests_evm::kzg_evm_hashed_input_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + hashed params)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_hashed_params_prove_and_verify --test-threads 1
run: cargo nextest run --verbose tests_evm::kzg_evm_hashed_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + hashed outputs)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_hashed_output_prove_and_verify --test-threads 1
run: cargo nextest run --verbose tests_evm::kzg_evm_hashed_output_prove_and_verify --test-threads 1
# prove-and-verify-tests-metal:
# permissions:
# contents: read
# runs-on: macos-13
# # needs: [build, library-tests, docs]
# steps:
# - uses: actions/checkout@v4
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - uses: jetli/wasm-pack-action@v0.4.0
# with:
# # Pin to version 0.12.1
# version: 'v0.12.1'
# - name: Add rust-src
# run: rustup component add rust-src --toolchain nightly-2024-07-18
# - uses: actions/checkout@v3
# with:
# persist-credentials: false
# - name: Use pnpm 8
# uses: pnpm/action-setup@v2
# with:
# version: 8
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
# - name: KZG prove and verify tests (public outputs)
# run: cargo nextest run --features macos-metal --verbose tests::kzg_prove_and_verify_::t --no-capture
prove-and-verify-tests:
permissions:
contents: read
runs-on: non-gpu
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -362,13 +459,15 @@ jobs:
- uses: jetli/wasm-pack-action@v0.4.0
with:
# Pin to version 0.12.1
version: 'v0.12.1'
version: "v0.12.1"
- name: Add wasm32-unknown-unknown target
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
with:
persist-credentials: false
- name: Use pnpm 8
uses: pnpm/action-setup@v2
with:
@@ -390,40 +489,40 @@ jobs:
locked: true
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --release --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
- name: KZG prove and verify tests (hashed inputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_hashed_inputs_
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_hashed_inputs_
- name: KZG prove and verify tests (public outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_tight_lookup_::t
run: cargo nextest run --verbose tests::kzg_prove_and_verify_tight_lookup_::t
- name: IPA prove and verify tests
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_::t --test-threads 1
run: cargo nextest run --verbose tests::ipa_prove_and_verify_::t --test-threads 1
- name: IPA prove and verify tests (ipa outputs)
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_ipa_output
run: cargo nextest run --verbose tests::ipa_prove_and_verify_ipa_output
- name: KZG prove and verify tests single inner col
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_single_col
run: cargo nextest run --verbose tests::kzg_prove_and_verify_single_col
- name: KZG prove and verify tests triple inner col
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_triple_col
run: cargo nextest run --verbose tests::kzg_prove_and_verify_triple_col
- name: KZG prove and verify tests quadruple inner col
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_quadruple_col
run: cargo nextest run --verbose tests::kzg_prove_and_verify_quadruple_col
- name: KZG prove and verify tests octuple inner col
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_octuple_col --test-threads 8
run: cargo nextest run --verbose tests::kzg_prove_and_verify_octuple_col --test-threads 8
- name: KZG prove and verify tests (kzg outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output
run: cargo nextest run --verbose tests::kzg_prove_and_verify_kzg_output
- name: KZG prove and verify tests (public outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t
run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t
- name: KZG prove and verify tests (public inputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input
run: cargo nextest run --verbose tests::kzg_prove_and_verify_public_input
- name: KZG prove and verify tests (fixed params)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params
run: cargo nextest run --verbose tests::kzg_prove_and_verify_fixed_params
- name: KZG prove and verify tests (hashed outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed
run: cargo nextest run --verbose tests::kzg_prove_and_verify_hashed
# prove-and-verify-tests-gpu:
# runs-on: GPU
@@ -431,6 +530,8 @@ jobs:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
@@ -444,27 +545,31 @@ jobs:
# crate: cargo-nextest
# locked: true
# - name: KZG prove and verify tests (kzg outputs)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
# - name: KZG prove and verify tests (public outputs + column overflow)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
# - name: KZG prove and verify tests (public outputs + fixed params + column overflow)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
# - name: KZG prove and verify tests (public outputs)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
# - name: KZG prove and verify tests (public outputs + column overflow)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
# - name: KZG prove and verify tests (public inputs)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
# - name: KZG prove and verify tests (fixed params)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
# - name: KZG prove and verify tests (hashed outputs)
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
prove-and-verify-mock-aggr-tests:
permissions:
contents: read
runs-on: self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -475,7 +580,7 @@ jobs:
crate: cargo-nextest
locked: true
- name: Mock aggr tests (KZG)
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
run: cargo nextest run --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
# prove-and-verify-aggr-tests-gpu:
# runs-on: GPU
@@ -483,6 +588,8 @@ jobs:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
@@ -496,10 +603,14 @@ jobs:
# run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
prove-and-verify-aggr-tests:
permissions:
contents: read
runs-on: large-self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -510,13 +621,17 @@ jobs:
crate: cargo-nextest
locked: true
- name: KZG tests
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
prove-and-verify-aggr-evm-tests:
permissions:
contents: read
runs-on: large-self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -531,13 +646,17 @@ jobs:
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: KZG prove and verify aggr tests
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
run: cargo nextest run --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
examples:
permissions:
contents: read
runs-on: ubuntu-latest-32-cores
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
@@ -551,10 +670,14 @@ jobs:
run: cargo nextest run --release tests_examples
python-tests:
permissions:
contents: read
runs-on: non-gpu
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: "3.12"
@@ -572,15 +695,19 @@ jobs:
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --profile=test-runs
- name: Run pytest
run: source .env/bin/activate; pip install pytest-asyncio; pytest -vv
accuracy-measurement-tests:
runs-on: ubuntu-latest-32-cores
permissions:
contents: read
runs-on: non-gpu
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: "3.12"
@@ -596,17 +723,19 @@ jobs:
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --profile=test-runs
- name: Public inputs
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_inputs_
run: source .env/bin/activate; cargo nextest run --verbose tests::accuracy_measurement_public_inputs_
- name: fixed params
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_fixed_params_
run: source .env/bin/activate; cargo nextest run --verbose tests::accuracy_measurement_fixed_params_
- name: Public outputs
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_outputs_
run: source .env/bin/activate; cargo nextest run --verbose tests::accuracy_measurement_public_outputs_
- name: Public outputs + resources
run: source .env/bin/activate; cargo nextest run --release --verbose tests::resources_accuracy_measurement_public_outputs_
run: source .env/bin/activate; cargo nextest run --verbose tests::resources_accuracy_measurement_public_outputs_
python-integration-tests:
permissions:
contents: read
runs-on: large-self-hosted
services:
# Label used to access the service container
@@ -628,6 +757,8 @@ jobs:
- 5432:5432
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: "3.11"
@@ -649,7 +780,11 @@ jobs:
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --profile=test-runs
- name: All notebooks
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
- name: Voice tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: Neural bow
run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
- name: Felt conversion
@@ -667,74 +802,87 @@ jobs:
# # now dump the contents of the file into a file called kaggle.json
# echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
- name: All notebooks
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
- name: Voice tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
# - name: Reusable verifier tutorial
# run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_
# run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_
ios-integration-tests:
runs-on: macos-latest
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Run ios tests
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
permissions:
contents: read
runs-on: macos-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Run ios tests
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
swift-package-tests:
runs-on: macos-latest
needs: [ios-integration-tests]
permissions:
contents: read
runs-on: macos-latest
needs: [ios-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Build EzklCoreBindings
run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Build EzklCoreBindings
run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
- name: Clone ezkl-swift- repository
run: |
git clone https://github.com/zkonduit/ezkl-swift-package.git
- name: Clone ezkl-swift- repository
run: |
git clone https://github.com/zkonduit/ezkl-swift-package.git
- name: Copy EzklCoreBindings
run: |
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
- name: Copy EzklCoreBindings
run: |
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
- name: Set up Xcode environment
run: |
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license accept
- name: Copy Test Files
run: |
rm -rf ezkl-swift-package/Tests/EzklAssets/
mkdir -p ezkl-swift-package/Tests/EzklAssets/
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
- name: Run Package Tests
run: |
cd ezkl-swift-package
xcodebuild test \
-scheme EzklPackage \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-resultBundlePath ../testResults
- name: Set up Xcode environment
run: |
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license accept
- name: Run Example App Tests
run: |
cd ezkl-swift-package/Example
xcodebuild test \
-project Example.xcodeproj \
-scheme EzklApp \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-parallel-testing-enabled NO \
-resultBundlePath ../../exampleTestResults \
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
- name: Run Package Tests
run: |
cd ezkl-swift-package
xcodebuild test \
-scheme EzklPackage \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-resultBundlePath ../testResults
- name: Run Example App Tests
run: |
cd ezkl-swift-package/Example
xcodebuild test \
-project Example.xcodeproj \
-scheme EzklApp \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-parallel-testing-enabled NO \
-resultBundlePath ../../exampleTestResults \
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder

33
.github/workflows/static-analysis.yml vendored Normal file
View File

@@ -0,0 +1,33 @@
name: Static Analysis
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
analyze:
permissions:
contents: read
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
# Run Zizmor static analysis
- name: Install Zizmor
run: cargo install --locked zizmor
- name: Run Zizmor Analysis
run: zizmor .

View File

@@ -9,18 +9,24 @@ on:
jobs:
build-and-update:
permissions:
contents: read
packages: write
runs-on: macos-latest
env:
EZKL_SWIFT_PACKAGE_REPO: github.com/zkonduit/ezkl-swift-package.git
RELEASE_TAG: ${{ github.ref_name }}
steps:
- name: Checkout EZKL
uses: actions/checkout@v3
with:
persist-credentials: false
- name: Extract TAG from github.ref_name
run: |
# github.ref_name is provided by GitHub Actions and contains the tag name directly.
TAG="${{ github.ref_name }}"
TAG="${RELEASE_TAG}"
echo "Original TAG: $TAG"
# Remove leading 'v' if present to match the Swift Package Manager version format.
NEW_TAG=${TAG#v}
@@ -47,7 +53,8 @@ jobs:
- name: Copy Test Files
run: |
rm -rf ezkl-swift-package/Tests/EzklAssets/*
rm -rf ezkl-swift-package/Tests/EzklAssets/
mkdir -p ezkl-swift-package/Tests/EzklAssets/
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
@@ -105,7 +112,6 @@ jobs:
cd ezkl-swift-package
git add Sources/EzklCoreBindings Tests/EzklAssets
git commit -m "Automatically updated EzklCoreBindings for EZKL"
if ! git push origin; then
echo "::error::Failed to push changes to ${{ env.EZKL_SWIFT_PACKAGE_REPO }}. Please ensure that EZKL_PORTER_TOKEN has the correct permissions."
exit 1
@@ -115,7 +121,6 @@ jobs:
run: |
cd ezkl-swift-package
source $GITHUB_ENV
# Tag the latest commit on the current branch
if git rev-parse "$TAG" >/dev/null 2>&1; then
echo "Tag $TAG already exists locally. Skipping tag creation."

View File

@@ -12,6 +12,8 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Bump version and push tag
id: tag_version
uses: mathieudutour/github-tag-action@v6.2

128
Cargo.lock generated
View File

@@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
version = 4
[[package]]
name = "addr2line"
@@ -1760,7 +1760,7 @@ checksum = "a650a461c6a8ff1ef205ed9a2ad56579309853fecefc2423f73dced342f92258"
[[package]]
name = "ecc"
version = "0.1.0"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
dependencies = [
"integer",
"num-bigint",
@@ -1835,6 +1835,16 @@ dependencies = [
"syn 2.0.90",
]
[[package]]
name = "env_filter"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0"
dependencies = [
"log",
"regex",
]
[[package]]
name = "env_logger"
version = "0.10.2"
@@ -1848,6 +1858,19 @@ dependencies = [
"termcolor",
]
[[package]]
name = "env_logger"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0"
dependencies = [
"anstream",
"anstyle",
"env_filter",
"humantime",
"log",
]
[[package]]
name = "equivalent"
version = "1.0.1"
@@ -1923,7 +1946,7 @@ dependencies = [
"console_error_panic_hook",
"criterion 0.5.1",
"ecc",
"env_logger",
"env_logger 0.10.2",
"ethabi",
"foundry-compilers",
"gag",
@@ -1931,7 +1954,7 @@ dependencies = [
"halo2_gadgets",
"halo2_proofs",
"halo2_solidity_verifier",
"halo2curves 0.7.0",
"halo2curves 0.7.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
"hex",
"indicatif",
"instant",
@@ -1939,20 +1962,17 @@ dependencies = [
"lazy_static",
"log",
"maybe-rayon",
"metal",
"mimalloc",
"mnist",
"num",
"objc",
"openssl",
"pg_bigdecimal",
"portable-atomic",
"pyo3",
"pyo3-async-runtimes",
"pyo3-log",
"pyo3-stub-gen",
"rand 0.8.5",
"regex",
"reqwest",
"semver 1.0.22",
"seq-macro",
@@ -2394,14 +2414,14 @@ dependencies = [
[[package]]
name = "halo2_proofs"
version = "0.3.0"
source = "git+https://github.com/zkonduit/halo2#ee4e1a09ebdb1f79f797685b78951c6034c430a6#ee4e1a09ebdb1f79f797685b78951c6034c430a6"
source = "git+https://github.com/zkonduit/halo2#bf9d0057a82443be48c4779bbe14961c18fb5996#bf9d0057a82443be48c4779bbe14961c18fb5996"
dependencies = [
"bincode",
"blake2b_simd",
"env_logger",
"env_logger 0.10.2",
"ff",
"group",
"halo2curves 0.7.0",
"halo2curves 0.7.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
"icicle-bn254",
"icicle-core",
"icicle-cuda-runtime",
@@ -2409,6 +2429,7 @@ dependencies = [
"lazy_static",
"log",
"maybe-rayon",
"mopro-msm",
"rand_chacha",
"rand_core 0.6.4",
"rustc-hash 2.0.0",
@@ -2494,6 +2515,36 @@ dependencies = [
"subtle",
]
[[package]]
name = "halo2curves"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d380afeef3f1d4d3245b76895172018cfb087d9976a7cabcd5597775b2933e07"
dependencies = [
"blake2",
"digest 0.10.7",
"ff",
"group",
"halo2derive 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
"hex",
"lazy_static",
"num-bigint",
"num-integer",
"num-traits",
"pairing",
"pasta_curves",
"paste",
"rand 0.8.5",
"rand_core 0.6.4",
"rayon",
"serde",
"serde_arrays",
"sha2",
"static_assertions",
"subtle",
"unroll",
]
[[package]]
name = "halo2curves"
version = "0.7.0"
@@ -2503,7 +2554,7 @@ dependencies = [
"digest 0.10.7",
"ff",
"group",
"halo2derive",
"halo2derive 0.1.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
"hex",
"lazy_static",
"num-bigint",
@@ -2523,6 +2574,20 @@ dependencies = [
"unroll",
]
[[package]]
name = "halo2derive"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdb99e7492b4f5ff469d238db464131b86c2eaac814a78715acba369f64d2c76"
dependencies = [
"num-bigint",
"num-integer",
"num-traits",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "halo2derive"
version = "0.1.0"
@@ -2539,7 +2604,7 @@ dependencies = [
[[package]]
name = "halo2wrong"
version = "0.1.0"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
dependencies = [
"halo2_proofs",
"num-bigint",
@@ -2890,7 +2955,7 @@ dependencies = [
[[package]]
name = "integer"
version = "0.1.0"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
dependencies = [
"maingate",
"num-bigint",
@@ -3074,7 +3139,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if",
"windows-targets 0.48.5",
"windows-targets 0.52.6",
]
[[package]]
@@ -3201,7 +3266,7 @@ dependencies = [
[[package]]
name = "maingate"
version = "0.1.0"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
dependencies = [
"halo2wrong",
"num-bigint",
@@ -3283,7 +3348,8 @@ dependencies = [
[[package]]
name = "metal"
version = "0.29.0"
source = "git+https://github.com/gfx-rs/metal-rs#0e1918b34689c4b8cd13a43372f9898680547ee9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21"
dependencies = [
"bitflags 2.5.0",
"block",
@@ -3354,6 +3420,28 @@ dependencies = [
"byteorder",
]
[[package]]
name = "mopro-msm"
version = "0.1.0"
source = "git+https://github.com/zkonduit/metal-msm-gpu-acceleration.git#be5f647b1a6c1a6ea9024390744a2b4d87f5d002"
dependencies = [
"bincode",
"env_logger 0.11.6",
"halo2curves 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
"instant",
"itertools 0.13.0",
"lazy_static",
"log",
"metal",
"objc",
"once_cell",
"rand 0.8.5",
"rayon",
"serde",
"thiserror",
"walkdir",
]
[[package]]
name = "native-tls"
version = "0.2.11"
@@ -3587,9 +3675,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
[[package]]
name = "openssl-src"
version = "300.2.3+3.2.1"
version = "300.4.1+3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cff92b6f71555b61bb9315f7c64da3ca43d87531622120fea0195fc761b4843"
checksum = "faa4eac4138c62414b5622d1b31c5c304f34b406b013c079c2bbc652fdd6678c"
dependencies = [
"cc",
]
@@ -5142,7 +5230,7 @@ checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c"
[[package]]
name = "snark-verifier"
version = "0.1.1"
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/chunked-mv-lookup#8762701ab8fa04e7d243a346030afd85633ec970"
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac%2Fchunked-mv-lookup#8762701ab8fa04e7d243a346030afd85633ec970"
dependencies = [
"ecc",
"halo2_proofs",
@@ -6146,7 +6234,7 @@ dependencies = [
[[package]]
name = "uniffi_testing"
version = "0.28.0"
source = "git+https://github.com/ElusAegis/uniffi-rs?branch=feat/testing-feature-build-fix#4684b9e7da2d9c964c2b3a73883947aab7370a06"
source = "git+https://github.com/ElusAegis/uniffi-rs?branch=feat%2Ftesting-feature-build-fix#4684b9e7da2d9c964c2b3a73883947aab7370a06"
dependencies = [
"anyhow",
"camino",

View File

@@ -40,7 +40,6 @@ maybe-rayon = { version = "0.1.1", default-features = false }
bincode = { version = "1.3.3", default-features = false }
unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = { version = "1.6.0", optional = true }
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand", optional = true }
semver = { version = "1.0.22", optional = true }
@@ -74,7 +73,6 @@ tokio-postgres = { version = "0.7.10", optional = true }
pg_bigdecimal = { version = "0.1.5", optional = true }
lazy_static = { version = "1.4.0", optional = true }
colored_json = { version = "3.0.1", default-features = false, optional = true }
regex = { version = "1", default-features = false, optional = true }
tokio = { version = "1.35.0", default-features = false, features = [
"macros",
"rt-multi-thread",
@@ -91,7 +89,6 @@ pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", ver
pyo3-log = { version = "0.12.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
objc = { version = "0.2.4", optional = true }
mimalloc = { version = "0.1", optional = true }
pyo3-stub-gen = { version = "0.6.0", optional = true }
@@ -245,16 +242,14 @@ ezkl = [
"dep:indicatif",
"dep:gag",
"dep:reqwest",
"dep:openssl",
"dep:tokio-postgres",
"dep:pg_bigdecimal",
"dep:lazy_static",
"dep:regex",
"dep:tokio",
"dep:openssl",
"dep:mimalloc",
"dep:chrono",
"dep:sha256",
"dep:portable-atomic",
"dep:clap_complete",
"dep:halo2_solidity_verifier",
"dep:semver",
@@ -277,13 +272,14 @@ icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
no-banner = []
no-update = []
macos-metal = ["halo2_proofs/macos"]
ios-metal = ["halo2_proofs/ios"]
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#ee4e1a09ebdb1f79f797685b78951c6034c430a6", package = "halo2_proofs" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2#bf9d0057a82443be48c4779bbe14961c18fb5996", package = "halo2_proofs" }
[patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#ee4e1a09ebdb1f79f797685b78951c6034c430a6", package = "halo2_proofs" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2#bf9d0057a82443be48c4779bbe14961c18fb5996", package = "halo2_proofs" }
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
@@ -292,9 +288,13 @@ uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "fea
rustflags = ["-C", "relocation-model=pic"]
lto = "fat"
codegen-units = 1
# panic = "abort"
#panic = "abort"
[profile.test-runs]
inherits = "dev"
opt-level = 3
[package.metadata.wasm-pack.profile.release]
wasm-opt = [
"-O4",

View File

@@ -10,6 +10,7 @@ use rand::Rng;
// Assuming these are your types
#[derive(Clone)]
#[allow(dead_code)]
enum ValType {
Constant(F),
AssignedConstant(usize, F),

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '0.0.0'
release = '18.1.11'
version = release

View File

@@ -77,6 +77,7 @@
"outputs": [],
"source": [
"gip_run_args = ezkl.PyRunArgs()\n",
"gip_run_args.ignore_range_check_inputs_outputs = True\n",
"gip_run_args.input_visibility = \"polycommit\" # matrix and generalized inverse commitments\n",
"gip_run_args.output_visibility = \"fixed\" # no parameters used\n",
"gip_run_args.param_visibility = \"fixed\" # should be Tensor(True)"
@@ -335,9 +336,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -308,8 +308,11 @@
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.decomp_legs = 4\n",
"\n",
"# Generate settings using ezkl\n",
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
"res = ezkl.gen_settings(onnx_filename, settings_filename, py_run_args=run_args)\n",
"\n",
"assert res == True\n",
"\n",

View File

@@ -453,18 +453,18 @@
"outputs": [],
"source": [
"# now mock aggregate the proofs\n",
"proofs = []\n",
"for i in range(3):\n",
" proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
" proofs.append(proof_path)\n",
"# proofs = []\n",
"# for i in range(3):\n",
"# proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
"# proofs.append(proof_path)\n",
"\n",
"ezkl.mock_aggregate(proofs, logrows=23, split_proofs = True)"
"# ezkl.mock_aggregate(proofs, logrows=26, split_proofs = True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ezkl",
"display_name": ".env",
"language": "python",
"name": "python3"
},
@@ -478,7 +478,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.12.7"
},
"orig_nbformat": 4
},

View File

@@ -152,9 +152,11 @@
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"run_args = ezkl.PyRunArgs()\n",
"# logrows\n",
"run_args.logrows = 20\n",
"\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True\n"
]
},
@@ -302,7 +304,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.9.13"
}
},
"nbformat": 4,

View File

@@ -167,6 +167,8 @@
"run_args = ezkl.PyRunArgs()\n",
"# \"hashed/private\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
"run_args.input_visibility = \"hashed/private/0\"\n",
"# as the inputs are felts we turn off input range checks\n",
"run_args.ignore_range_check_inputs_outputs = True\n",
"# we set it to fix the set we want to check membership for\n",
"run_args.param_visibility = \"fixed\"\n",
"# the output is public -- set membership fails if it is not = 0\n",
@@ -519,4 +521,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -204,6 +204,7 @@
"run_args = ezkl.PyRunArgs()\n",
"# \"polycommit\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
"run_args.input_visibility = \"polycommit\"\n",
"run_args.ignore_range_check_inputs_outputs = True\n",
"# the parameters are public\n",
"run_args.param_visibility = \"fixed\"\n",
"# the output is public (this is the inequality test)\n",
@@ -514,4 +515,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -60,7 +60,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -94,7 +94,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -134,7 +134,7 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -183,7 +183,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -201,6 +201,7 @@
"run_args.input_visibility = \"public\"\n",
"run_args.param_visibility = \"private\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.decomp_legs=6\n",
"run_args.num_inner_cols = 1\n",
"run_args.variables = [(\"batch_size\", 1)]"
]

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
return x // 3
circuit = MyModel()
x = torch.randint(0, 10, (1, 2, 2, 8))
out = circuit(x)
print(x)
print(out)
print(x/3)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[3, 4, 0, 9, 2, 6, 2, 5, 1, 5, 3, 5, 5, 7, 0, 2, 6, 1, 4, 4, 1, 9, 7, 7, 5, 8, 2, 0, 1, 5, 9, 8]]}

Binary file not shown.

View File

@@ -1,7 +1,11 @@
// ignore file if compiling for wasm
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc::MiMalloc;
#[global_allocator]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
static GLOBAL: MiMalloc = MiMalloc;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::{CommandFactory, Parser};

View File

@@ -207,6 +207,9 @@ struct PyRunArgs {
/// bool: Should the circuit use unbounded lookups for log
#[pyo3(get, set)]
pub bounded_log_lookup: bool,
/// bool: Should the circuit use range checks for inputs and outputs (set to false if the input is a felt)
#[pyo3(get, set)]
pub ignore_range_check_inputs_outputs: bool,
}
/// default instantiation of PyRunArgs
@@ -239,6 +242,7 @@ impl From<PyRunArgs> for RunArgs {
commitment: Some(py_run_args.commitment.into()),
decomp_base: py_run_args.decomp_base,
decomp_legs: py_run_args.decomp_legs,
ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
}
}
}
@@ -263,6 +267,7 @@ impl Into<PyRunArgs> for RunArgs {
commitment: self.commitment.into(),
decomp_base: self.decomp_base,
decomp_legs: self.decomp_legs,
ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
}
}
}

View File

@@ -100,9 +100,6 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
for input in hash_inputs.iter().take(WIDTH) {
meta.enable_equality(*input);
}
meta.enable_constant(rc_b[0]);
Self::configure_with_cols(
@@ -152,9 +149,6 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
for input in hash_inputs.iter().take(WIDTH) {
meta.enable_equality(*input);
}
meta.enable_constant(rc_b[0]);
let instance = meta.instance_column();
@@ -176,7 +170,10 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
message: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, ModuleError> {
assert_eq!(message.len(), 1);
if message.len() != 1 {
return Err(ModuleError::InputWrongLength(message.len()));
}
let message = message[0].clone();
let start_time = instant::Instant::now();
@@ -231,7 +228,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
}
e => Err(ModuleError::WrongInputType(
format!("{:?}", e),
"PrevAssigned".to_string(),
"AssignedValue".to_string(),
)),
}
})
@@ -296,6 +293,12 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, ModuleError> {
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
// empty hash case
if input_cells.is_empty() {
return Ok(input[0].clone());
}
// extract the values from the input cells
let mut assigned_input: Tensor<ValType<Fp>> =
input_cells.iter().map(|e| ValType::from(e.clone())).into();
@@ -517,6 +520,21 @@ mod tests {
}
}
#[test]
fn poseidon_hash_empty() {
let message = [];
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 2>::run(message.to_vec()).unwrap();
let mut message: Tensor<ValType<Fp>> =
message.into_iter().map(|m| Value::known(m).into()).into();
let k = 9;
let circuit = HashCircuit::<PoseidonSpec, 2> {
message: message.into(),
_spec: PhantomData,
};
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
assert_eq!(prover.verify(), Ok(()))
}
#[test]
fn poseidon_hash() {
let rng = rand::rngs::OsRng;

View File

@@ -17,7 +17,6 @@ pub enum BaseOp {
Sub,
SumInit,
Sum,
IsBoolean,
}
/// Matches a [BaseOp] to an operation over inputs
@@ -34,7 +33,6 @@ impl BaseOp {
BaseOp::Add => a + b,
BaseOp::Sub => a - b,
BaseOp::Mult => a * b,
BaseOp::IsBoolean => b,
_ => panic!("nonaccum_f called on accumulating operation"),
}
}
@@ -74,7 +72,6 @@ impl BaseOp {
BaseOp::Mult => "MULT",
BaseOp::Sum => "SUM",
BaseOp::SumInit => "SUMINIT",
BaseOp::IsBoolean => "ISBOOLEAN",
}
}
@@ -90,7 +87,6 @@ impl BaseOp {
BaseOp::Mult => (0, 1),
BaseOp::Sum => (-1, 2),
BaseOp::SumInit => (0, 1),
BaseOp::IsBoolean => (0, 1),
}
}
@@ -106,7 +102,6 @@ impl BaseOp {
BaseOp::Mult => 2,
BaseOp::Sum => 1,
BaseOp::SumInit => 1,
BaseOp::IsBoolean => 0,
}
}
@@ -122,7 +117,6 @@ impl BaseOp {
BaseOp::SumInit => 0,
BaseOp::CumProd => 1,
BaseOp::CumProdInit => 0,
BaseOp::IsBoolean => 0,
}
}
}

View File

@@ -2,7 +2,7 @@ use std::str::FromStr;
use halo2_proofs::{
circuit::Layouter,
plonk::{ConstraintSystem, Constraints, Expression, Selector},
plonk::{ConstraintSystem, Constraints, Expression, Selector, TableColumn},
poly::Rotation,
};
use log::debug;
@@ -215,15 +215,16 @@ impl DynamicLookups {
/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct Shuffles {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub input_selectors: BTreeMap<(usize, (usize, usize)), Selector>,
/// Selectors for the dynamic lookup tables
pub reference_selectors: Vec<Selector>,
pub output_selectors: Vec<Selector>,
/// Inputs:
pub inputs: Vec<VarTensor>,
/// tables
pub references: Vec<VarTensor>,
pub outputs: Vec<VarTensor>,
}
impl Shuffles {
@@ -234,9 +235,13 @@ impl Shuffles {
Self {
input_selectors: BTreeMap::new(),
reference_selectors: vec![],
inputs: vec![dummy_var.clone(), dummy_var.clone()],
references: vec![single_col_dummy_var.clone(), single_col_dummy_var.clone()],
output_selectors: vec![],
inputs: vec![dummy_var.clone(), dummy_var.clone(), dummy_var.clone()],
outputs: vec![
single_col_dummy_var.clone(),
single_col_dummy_var.clone(),
single_col_dummy_var.clone(),
],
}
}
}
@@ -336,6 +341,8 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
/// Activate sanity checks
pub check_mode: CheckMode,
_marker: PhantomData<F>,
/// shared table inputs
pub shared_table_inputs: Vec<TableColumn>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
@@ -348,6 +355,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
shuffles: Shuffles::dummy(col_size, num_inner_cols),
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
check_mode: CheckMode::SAFE,
shared_table_inputs: vec![],
_marker: PhantomData,
}
}
@@ -374,13 +382,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
if inputs[0].num_cols() != output.num_cols() {
log::warn!("input and output shapes do not match");
}
if inputs[0].num_inner_cols() != inputs[1].num_inner_cols() {
log::warn!("input number of inner columns do not match");
}
if inputs[0].num_inner_cols() != output.num_inner_cols() {
log::warn!("input and output number of inner columns do not match");
}
for i in 0..output.num_blocks() {
for j in 0..output.num_inner_cols() {
nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector());
}
}
@@ -414,24 +427,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
// Get output expressions for each input channel
let (rotation_offset, rng) = base_op.query_offset_rng();
let constraints = match base_op {
BaseOp::IsBoolean => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("non accum: output query failed");
let constraints = {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
.expect("non accum: output query failed");
let output = expected_output[base_op.constraint_idx()].clone();
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
}
_ => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
.expect("non accum: output query failed");
let res = base_op.nonaccum_f((qis[0].clone(), qis[1].clone()));
vec![expected_output[base_op.constraint_idx()].clone() - res]
}
let res = base_op.nonaccum_f((qis[0].clone(), qis[1].clone()));
vec![expected_output[base_op.constraint_idx()].clone() - res]
};
Constraints::with_selector(selector, constraints)
@@ -486,6 +488,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
dynamic_lookups: DynamicLookups::default(),
shuffles: Shuffles::default(),
range_checks: RangeChecks::default(),
shared_table_inputs: vec![],
check_mode,
_marker: PhantomData,
}
@@ -516,21 +519,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
return Err(CircuitError::WrongColumnType(output.name().to_string()));
}
// we borrow mutably twice so we need to do this dance
let table = if !self.static_lookups.tables.contains_key(nl) {
// as all tables have the same input we see if there's another table who's input we can reuse
let table = if let Some(table) = self.static_lookups.tables.values().next() {
Table::<F>::configure(
cs,
lookup_range,
logrows,
nl,
Some(table.table_inputs.clone()),
)
} else {
Table::<F>::configure(cs, lookup_range, logrows, nl, None)
};
let table =
Table::<F>::configure(cs, lookup_range, logrows, nl, &mut self.shared_table_inputs);
self.static_lookups.tables.insert(nl.clone(), table.clone());
table
} else {
@@ -581,9 +572,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
// this is 0 if the index is the same as the column index (starting from 1)
let col_expr = sel.clone()
* table
* (table
.selector_constructor
.get_expr_at_idx(col_idx, synthetic_sel);
.get_expr_at_idx(col_idx, synthetic_sel));
let multiplier =
table.selector_constructor.get_selector_val_at_idx(col_idx);
@@ -615,6 +606,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
res
});
}
// add a degree-k custom constraint of the following form to the range check and
// static lookup configuration.
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
cs.create_gate("range_check_on_sel", |cs| {
let synthetic_sel = match len {
1 => Expression::Constant(F::from(1)),
_ => match index {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
},
};
let range_check_on_synthetic_sel = match len {
1 => Expression::Constant(F::from(0)),
_ => {
let mut initial_expr = Expression::Constant(F::from(1));
for i in 0..len {
initial_expr = initial_expr
* (synthetic_sel.clone()
- Expression::Constant(F::from(i as u64)))
}
initial_expr
}
};
let sel = cs.query_selector(multi_col_selector);
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
});
self.static_lookups
.selectors
.insert((nl.clone(), x, y), multi_col_selector);
@@ -740,8 +765,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
pub fn configure_shuffles(
&mut self,
cs: &mut ConstraintSystem<F>,
inputs: &[VarTensor; 2],
references: &[VarTensor; 2],
inputs: &[VarTensor; 3],
outputs: &[VarTensor; 3],
) -> Result<(), CircuitError>
where
F: Field,
@@ -752,14 +777,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
}
}
for t in references.iter() {
for t in outputs.iter() {
if !t.is_advice() || t.num_inner_cols() > 1 {
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
}
}
// assert all tables have the same number of blocks
if references
if outputs
.iter()
.map(|t| t.num_blocks())
.collect::<Vec<_>>()
@@ -767,23 +792,23 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
.any(|w| w[0] != w[1])
{
return Err(CircuitError::WrongDynamicColumnType(
"references inner cols".to_string(),
"outputs inner cols".to_string(),
));
}
let one = Expression::Constant(F::ONE);
for q in 0..references[0].num_blocks() {
let s_reference = cs.complex_selector();
for q in 0..outputs[0].num_blocks() {
let s_output = cs.complex_selector();
for x in 0..inputs[0].num_blocks() {
for y in 0..inputs[0].num_inner_cols() {
let s_input = cs.complex_selector();
cs.lookup_any("lookup", |cs| {
cs.lookup_any("shuffle", |cs| {
let s_inputq = cs.query_selector(s_input);
let mut expression = vec![];
let s_referenceq = cs.query_selector(s_reference);
let s_outputq = cs.query_selector(s_output);
let mut input_queries = vec![one.clone()];
for input in inputs {
@@ -795,9 +820,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
});
}
let mut ref_queries = vec![one.clone()];
for reference in references {
ref_queries.push(match reference {
let mut output_queries = vec![one.clone()];
for output in outputs {
output_queries.push(match output {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[q][0], Rotation(0))
}
@@ -806,7 +831,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
}
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
let rhs = output_queries.into_iter().map(|c| c * s_outputq.clone());
expression.extend(lhs.zip(rhs));
expression
@@ -817,13 +842,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
.or_insert(s_input);
}
}
self.shuffles.reference_selectors.push(s_reference);
self.shuffles.output_selectors.push(s_output);
}
// if we haven't previously initialized the input/output, do so now
if self.shuffles.references.is_empty() {
debug!("assigning shuffles reference");
self.shuffles.references = references.to_vec();
if self.shuffles.outputs.is_empty() {
debug!("assigning shuffles output");
self.shuffles.outputs = outputs.to_vec();
}
if self.shuffles.inputs.is_empty() {
debug!("assigning shuffles input");
@@ -855,7 +880,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
let range_check = if let std::collections::btree_map::Entry::Vacant(e) =
self.range_checks.ranges.entry(range)
{
// as all tables have the same input we see if there's another table who's input we can reuse
let range_check = RangeCheck::<F>::configure(cs, range, logrows);
e.insert(range_check.clone());
range_check
@@ -893,9 +917,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
let default_x = range_check.get_first_element(col_idx);
let col_expr = sel.clone()
* range_check
* (range_check
.selector_constructor
.get_expr_at_idx(col_idx, synthetic_sel);
.get_expr_at_idx(col_idx, synthetic_sel));
let multiplier = range_check
.selector_constructor
@@ -918,6 +942,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
res
});
}
// add a degree-k custom constraint of the following form to the range check and
// static lookup configuration.
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
cs.create_gate("range_check_on_sel", |cs| {
let synthetic_sel = match len {
1 => Expression::Constant(F::from(1)),
_ => match index {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
},
};
let range_check_on_synthetic_sel = match len {
1 => Expression::Constant(F::from(0)),
_ => {
let mut initial_expr = Expression::Constant(F::from(1));
for i in 0..len {
initial_expr = initial_expr
* (synthetic_sel.clone()
- Expression::Constant(F::from(i as u64)))
}
initial_expr
}
};
let sel = cs.query_selector(multi_col_selector);
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
});
self.range_checks
.selectors
.insert((range, x, y), multi_col_selector);

View File

@@ -25,7 +25,7 @@ pub enum CircuitError {
/// This operation is unsupported
#[error("unsupported operation in graph")]
UnsupportedOp,
///
/// Invalid einsum expression
#[error("invalid einsum expression")]
InvalidEinsum,
/// Flush error
@@ -100,4 +100,13 @@ pub enum CircuitError {
#[error("invalid input type {0}")]
/// Invalid input type
InvalidInputType(String),
#[error("an element is missing from the shuffled version of the tensor")]
/// An element is missing from the shuffled version of the tensor
MissingShuffleElement,
/// Visibility has not been set
#[error("visibility has not been set")]
UnsetVisibility,
/// A decomposition base overflowed
#[error("decomposition base overflowed")]
DecompositionBaseOverflow,
}

View File

@@ -1,7 +1,7 @@
use super::*;
use crate::{
circuit::{layouts, utils, Tolerance},
fieldutils::integer_rep_to_felt,
fieldutils::{integer_rep_to_felt, IntegerRep},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorType, ValTensor},
};
@@ -76,7 +76,10 @@ pub enum HybridOp {
output_scale: utils::F32,
axes: Vec<usize>,
},
RangeCheck(Tolerance),
Output {
tol: Tolerance,
decomp: bool,
},
Greater,
GreaterEqual,
Less,
@@ -178,7 +181,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
input_scale, output_scale, axes
)
}
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
HybridOp::Output { tol, decomp } => {
format!("OUTPUT (tol={:?}, decomp={})", tol, decomp)
}
HybridOp::Greater => "GREATER".to_string(),
HybridOp::GreaterEqual => "GREATEREQUAL".to_string(),
HybridOp::Less => "LESS".to_string(),
@@ -250,8 +255,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
config,
region,
values[..].try_into()?,
integer_rep_to_felt(input_scale.0 as i128),
integer_rep_to_felt(output_scale.0 as i128),
integer_rep_to_felt(input_scale.0 as IntegerRep),
integer_rep_to_felt(output_scale.0 as IntegerRep),
)?,
HybridOp::Div { denom, .. } => {
if denom.0.fract() == 0.0 {
@@ -259,7 +264,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
config,
region,
values[..].try_into()?,
integer_rep_to_felt(denom.0 as i128),
integer_rep_to_felt(denom.0 as IntegerRep),
)?
} else {
layouts::nonlinearity(
@@ -314,12 +319,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
*output_scale,
axes,
)?,
HybridOp::RangeCheck(tol) => layouts::range_check_percent(
HybridOp::Output { tol, decomp } => layouts::output(
config,
region,
values[..].try_into()?,
tol.scale,
tol.val,
*decomp,
)?,
HybridOp::Greater => layouts::greater(config, region, values[..].try_into()?)?,
HybridOp::GreaterEqual => {

File diff suppressed because it is too large Load Diff

View File

@@ -159,6 +159,8 @@ pub struct Input {
pub scale: crate::Scale,
///
pub datum_type: InputType,
/// decomp check
pub decomp: bool,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
@@ -196,6 +198,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
config,
region,
values[..].try_into()?,
self.decomp,
)?)),
}
} else {
@@ -251,20 +254,26 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
///
#[serde(skip)]
pub pre_assigned_val: Option<ValTensor<F>>,
///
pub decomp: bool,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
///
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>, decomp: bool) -> Self {
Self {
quantized_values,
raw_values,
pre_assigned_val: None,
decomp,
}
}
/// Rebase the scale of the constant
pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), CircuitError> {
let visibility = self.quantized_values.visibility().unwrap();
let visibility = match self.quantized_values.visibility() {
Some(v) => v,
None => return Err(CircuitError::UnsetVisibility),
};
self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
Ok(())
}
@@ -308,7 +317,12 @@ impl<
self.quantized_values.clone().try_into()?
};
// we gotta constrain it once if its used multiple times
Ok(Some(layouts::identity(config, region, &[value])?))
Ok(Some(layouts::identity(
config,
region,
&[value],
self.decomp,
)?))
}
fn clone_dyn(&self) -> Box<dyn Op<F>> {

View File

@@ -252,6 +252,12 @@ impl<
)?,
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
if values.len() != 1 {
return Err(TensorError::DimError(
"GatherElements only accepts single inputs".to_string(),
)
.into());
}
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
@@ -269,6 +275,12 @@ impl<
}
PolyOp::ScatterElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
if values.len() != 2 {
return Err(TensorError::DimError(
"ScatterElements requires two inputs".to_string(),
)
.into());
}
tensor::ops::scatter(
values[0].get_inner_tensor()?,
idx,
@@ -311,7 +323,9 @@ impl<
PolyOp::Mult => {
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
}
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Identity { .. } => {
layouts::identity(config, region, values[..].try_into()?, false)?
}
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
if values.len() != 1 {

View File

@@ -132,21 +132,16 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
(first_element, op_f.output[0])
}
///
/// calculates the column size given the number of rows and reserved blinding rows
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(logrows as u32) - reserved_blinding_rows
}
///
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(bits as u32) - reserved_blinding_rows
}
}
///
pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
// number of cols needed to store the range
(range_len / (col_size as IntegerRep)) as usize + 1
(range_len / col_size as IntegerRep) as usize + 1
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
@@ -168,7 +163,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
range: Range,
logrows: usize,
nonlinearity: &LookupOp,
preexisting_inputs: Option<Vec<TableColumn>>,
preexisting_inputs: &mut Vec<TableColumn>,
) -> Table<F> {
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
let col_size = Self::cal_col_size(logrows, factors);
@@ -177,28 +172,28 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
debug!("table range: {:?}", range);
let table_inputs = preexisting_inputs.unwrap_or_else(|| {
let mut cols = vec![];
for _ in 0..num_cols {
cols.push(cs.lookup_table_column());
// validate enough columns are provided to store the range
if preexisting_inputs.len() < num_cols {
// add columns to match the required number of columns
let diff = num_cols - preexisting_inputs.len();
for _ in 0..diff {
preexisting_inputs.push(cs.lookup_table_column());
}
cols
});
let num_cols = table_inputs.len();
}
let num_cols = preexisting_inputs.len();
if num_cols > 1 {
warn!("Using {} columns for non-linearity table.", num_cols);
}
let table_outputs = table_inputs
let table_outputs = preexisting_inputs
.iter()
.map(|_| cs.lookup_table_column())
.collect::<Vec<_>>();
Table {
nonlinearity: nonlinearity.clone(),
table_inputs,
table_inputs: preexisting_inputs.clone(),
table_outputs,
is_assigned: false,
selector_constructor: SelectorConstructor::new(num_cols),
@@ -355,16 +350,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0)
}
///
/// calculates the column size
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(logrows as u32) - reserved_blinding_rows
}
///
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(bits as u32) - reserved_blinding_rows
}
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in

View File

@@ -1040,6 +1040,10 @@ mod conv {
let a = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
let b = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
let output = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
// column for constants
let _constant = VarTensor::constant_cols(cs, K, 8, false);
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
}
@@ -1171,7 +1175,7 @@ mod conv_col_ultra_overflow {
use super::*;
const K: usize = 4;
const K: usize = 6;
const LEN: usize = 10;
#[derive(Clone)]
@@ -1191,9 +1195,10 @@ mod conv_col_ultra_overflow {
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
let _constant = VarTensor::constant_cols(cs, K, LEN * LEN * LEN * LEN, false);
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
}
@@ -1776,13 +1781,18 @@ mod shuffle {
let d = VarTensor::new_advice(cs, K, 1, LEN);
let e = VarTensor::new_advice(cs, K, 1, LEN);
let f: VarTensor = VarTensor::new_advice(cs, K, 1, LEN);
let _constant = VarTensor::constant_cols(cs, K, LEN * NUM_LOOP, false);
let mut config =
Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
config
.configure_shuffles(cs, &[a.clone(), b.clone()], &[d.clone(), e.clone()])
.configure_shuffles(
cs,
&[a.clone(), b.clone(), c.clone()],
&[d.clone(), e.clone(), f.clone()],
)
.unwrap();
config
}
@@ -1803,6 +1813,7 @@ mod shuffle {
&mut region,
&self.inputs[i],
&self.references[i],
layouts::SortCollisionMode::Unsorted,
)
.map_err(|_| Error::Synthesis)?;
}

View File

@@ -83,7 +83,7 @@ pub const DEFAULT_VK_SOL: &str = "vk.sol";
/// Default VK abi path
pub const DEFAULT_VK_ABI: &str = "vk.abi";
/// Default scale rebase multipliers for calibration
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,10";
/// Default use reduced srs for verification
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
/// Default only check for range check rebase

View File

@@ -517,6 +517,7 @@ pub async fn deploy_da_verifier_via_solidity(
}
}
#[allow(clippy::too_many_arguments)]
async fn deploy_multi_da_contract(
client: EthersClient,
contract_instance_offset: usize,

View File

@@ -118,7 +118,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
} => gen_srs_cmd(
srs_path,
logrows as u32,
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
commitment.unwrap_or_else(|| Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
),
Commands::GetSrs {
srs_path,
@@ -1535,7 +1535,8 @@ pub(crate) async fn create_evm_data_attestation(
trace!("params computed");
// if input is not provided, we just instantiate dummy input data
let data = GraphData::from_path(input).unwrap_or(GraphData::new(DataSource::File(vec![])));
let data =
GraphData::from_path(input).unwrap_or_else(|_| GraphData::new(DataSource::File(vec![])));
// The number of input and output instances we attest to for the single call data attestation
let mut input_len = None;
@@ -2126,6 +2127,7 @@ pub(crate) fn mock_aggregate(
Ok(String::new())
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn setup_aggregate(
sample_snarks: Vec<PathBuf>,
vk_path: PathBuf,

View File

@@ -5,10 +5,12 @@ use halo2curves::ff::PrimeField;
/// Integer representation of a PrimeField element.
pub type IntegerRep = i128;
/// Converts an i64 to a PrimeField element.
/// Converts an integer rep to a PrimeField element.
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
if x >= 0 {
F::from_u128(x as u128)
} else if x == IntegerRep::MIN {
-F::from_u128(x.saturating_neg() as u128) - F::ONE
} else {
-F::from_u128(x.saturating_neg() as u128)
}
@@ -32,6 +34,9 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
/// Converts a PrimeField element to an i64.
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
if x > F::from_u128(IntegerRep::MAX as u128) {
if x == -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
return IntegerRep::MIN;
}
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
@@ -51,7 +56,7 @@ mod test {
use halo2curves::pasta::Fp as F;
#[test]
fn test_conv() {
fn integerreptofelt() {
let res: F = integer_rep_to_felt(-15);
assert_eq!(res, -F::from(15));
@@ -69,8 +74,24 @@ mod test {
fn felttointegerrep() {
for x in -(2_i128.pow(16))..(2_i128.pow(16)) {
let fieldx: F = integer_rep_to_felt::<F>(x);
let xf: i128 = felt_to_integer_rep::<F>(fieldx);
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
assert_eq!(x, xf);
}
}
#[test]
fn felttointegerrepmin() {
let x = IntegerRep::MIN;
let fieldx: F = integer_rep_to_felt::<F>(x);
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
assert_eq!(x, xf);
}
#[test]
fn felttointegerrepmax() {
let x = IntegerRep::MAX;
let fieldx: F = integer_rep_to_felt::<F>(x);
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
assert_eq!(x, xf);
}
}

View File

@@ -11,6 +11,12 @@ pub enum GraphError {
/// Shape mismatch in circuit construction
#[error("invalid dimensions used for node {0} ({1})")]
InvalidDims(usize, String),
/// Non scalar power
#[error("we only support scalar powers")]
NonScalarPower,
/// Non scalar base for exponentiation
#[error("we only support scalar bases for exponentiation")]
NonScalarBase,
/// Wrong method was called to configure an op
#[error("wrong method was called to configure node {0} ({1})")]
WrongMethod(usize, String),
@@ -113,13 +119,13 @@ pub enum GraphError {
/// Missing input for a node
#[error("missing input for node {0}")]
MissingInput(usize),
///
/// Ranges can only be constant
#[error("range only supports constant inputs in a zk circuit")]
NonConstantRange,
///
/// Trilu diagonal must be constant
#[error("trilu only supports constant diagonals in a zk circuit")]
NonConstantTrilu,
///
/// The witness was too short
#[error("insufficient witness values to generate a fixed output")]
InsufficientWitnessValues,
/// Missing scale
@@ -143,4 +149,13 @@ pub enum GraphError {
/// Invalid RunArg
#[error("invalid RunArgs: {0}")]
InvalidRunArgs(String),
/// Only nearest neighbor interpolation is supported
#[error("only nearest neighbor interpolation is supported")]
InvalidInterpolation,
/// Node has a missing output
#[error("node {0} has a missing output")]
MissingOutput(usize),
/// Inssuficient advice columns
#[error("insuficcient advice columns (need {0} at least)")]
InsufficientAdviceColumns(usize),
}

View File

@@ -14,7 +14,6 @@ use pyo3::prelude::*;
use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::io::BufReader;
use std::io::BufWriter;
@@ -25,6 +24,7 @@ use tract_onnx::tract_core::{
tract_data::{prelude::Tensor as TractTensor, TVec},
value::TValue,
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::tract_num_traits::ToPrimitive;
@@ -32,30 +32,95 @@ type Decimals = u8;
type Call = String;
type RPCUrl = String;
///
/// Represents different types of values that can be stored in a file source
/// Used for handling various input types in zero-knowledge proofs
#[derive(Clone, Debug, PartialOrd, PartialEq)]
pub enum FileSourceInner {
/// Inner elements of float inputs coming from a file
/// Floating point value (64-bit)
Float(f64),
/// Inner elements of bool inputs coming from a file
/// Boolean value
Bool(bool),
/// Inner elements of inputs coming from a witness
/// Field element value for direct use in circuits
Field(Fp),
}
impl FileSourceInner {
///
/// Returns true if the value is a floating point number
pub fn is_float(&self) -> bool {
matches!(self, FileSourceInner::Float(_))
}
///
/// Returns true if the value is a boolean
pub fn is_bool(&self) -> bool {
matches!(self, FileSourceInner::Bool(_))
}
///
/// Returns true if the value is a field element
pub fn is_field(&self) -> bool {
matches!(self, FileSourceInner::Field(_))
}
/// Creates a new floating point value
pub fn new_float(f: f64) -> Self {
FileSourceInner::Float(f)
}
/// Creates a new field element value
pub fn new_field(f: Fp) -> Self {
FileSourceInner::Field(f)
}
/// Creates a new boolean value
pub fn new_bool(f: bool) -> Self {
FileSourceInner::Bool(f)
}
/// Adjusts the value according to the specified input type
///
/// # Arguments
/// * `input_type` - Type specification to convert the value to
pub fn as_type(&mut self, input_type: &InputType) {
match self {
FileSourceInner::Float(f) => input_type.roundtrip(f),
FileSourceInner::Bool(_) => assert!(matches!(input_type, InputType::Bool)),
FileSourceInner::Field(_) => {}
}
}
/// Converts the value to a field element using appropriate scaling
///
/// # Arguments
/// * `scale` - Scaling factor for floating point conversion
pub fn to_field(&self, scale: crate::Scale) -> Fp {
match self {
FileSourceInner::Float(f) => {
integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap())
}
FileSourceInner::Bool(f) => {
if *f {
Fp::one()
} else {
Fp::zero()
}
}
FileSourceInner::Field(f) => *f,
}
}
/// Converts the value to a floating point number
pub fn to_float(&self) -> f64 {
match self {
FileSourceInner::Float(f) => *f,
FileSourceInner::Bool(f) => {
if *f {
1.0
} else {
0.0
}
}
FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
}
}
}
impl Serialize for FileSourceInner {
@@ -71,8 +136,8 @@ impl Serialize for FileSourceInner {
}
}
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
// Deserialization implementation for FileSourceInner
// Uses JSON deserialization to handle the different variants
impl<'de> Deserialize<'de> for FileSourceInner {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
@@ -99,70 +164,16 @@ impl<'de> Deserialize<'de> for FileSourceInner {
}
}
/// Elements of inputs coming from a file
/// A collection of input values from a file source
/// Organized as a vector of vectors where each inner vector represents a row/entry
pub type FileSource = Vec<Vec<FileSourceInner>>;
impl FileSourceInner {
/// Create a new FileSourceInner
pub fn new_float(f: f64) -> Self {
FileSourceInner::Float(f)
}
/// Create a new FileSourceInner
pub fn new_field(f: Fp) -> Self {
FileSourceInner::Field(f)
}
/// Create a new FileSourceInner
pub fn new_bool(f: bool) -> Self {
FileSourceInner::Bool(f)
}
///
pub fn as_type(&mut self, input_type: &InputType) {
match self {
FileSourceInner::Float(f) => input_type.roundtrip(f),
FileSourceInner::Bool(_) => assert!(matches!(input_type, InputType::Bool)),
FileSourceInner::Field(_) => {}
}
}
/// Convert to a field element
pub fn to_field(&self, scale: crate::Scale) -> Fp {
match self {
FileSourceInner::Float(f) => {
integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap())
}
FileSourceInner::Bool(f) => {
if *f {
Fp::one()
} else {
Fp::zero()
}
}
FileSourceInner::Field(f) => *f,
}
}
/// Convert to a float
pub fn to_float(&self) -> f64 {
match self {
FileSourceInner::Float(f) => *f,
FileSourceInner::Bool(f) => {
if *f {
1.0
} else {
0.0
}
}
FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
}
}
}
/// Call type for attested inputs on-chain
/// Represents different types of calls for fetching on-chain data
#[derive(Clone, Debug, PartialOrd, PartialEq)]
pub enum Calls {
/// Vector of calls to accounts, each returning an attested data point
/// Multiple calls to different accounts, each returning individual values
Multiple(Vec<CallsToAccount>),
/// Single call to account, returning an array of attested data points
/// Single call returning an array of values
Single(CallToAccount),
}
@@ -171,32 +182,6 @@ impl Default for Calls {
Calls::Multiple(Vec::new())
}
}
/// Inner elements of inputs/outputs coming from on-chain
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct OnChainSource {
/// Calls to accounts
pub calls: Calls,
/// RPC url
pub rpc: RPCUrl,
}
impl OnChainSource {
/// Create a new OnChainSource with multiple calls
pub fn new_multiple(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
OnChainSource {
calls: Calls::Multiple(calls),
rpc,
}
}
/// Create a new OnChainSource with a single call
pub fn new_single(call: CallToAccount, rpc: RPCUrl) -> Self {
OnChainSource {
calls: Calls::Single(call),
rpc,
}
}
}
impl Serialize for Calls {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
@@ -218,7 +203,6 @@ impl<'de> Deserialize<'de> for Calls {
D: Deserializer<'de>,
{
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
let multiple_try: Result<Vec<CallsToAccount>, _> = serde_json::from_str(this_json.get());
if let Ok(t) = multiple_try {
return Ok(Calls::Multiple(t));
@@ -228,111 +212,52 @@ impl<'de> Deserialize<'de> for Calls {
return Ok(Calls::Single(t));
}
Err(serde::de::Error::custom(
"failed to deserialize FileSourceInner",
))
Err(serde::de::Error::custom("failed to deserialize Calls"))
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Inner elements of inputs/outputs coming from postgres DB
/// Configuration for accessing on-chain data sources
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct PostgresSource {
/// postgres host
pub host: RPCUrl,
/// user to connect to postgres
pub user: String,
/// password to connect to postgres
pub password: String,
/// query to execute
pub query: String,
/// dbname
pub dbname: String,
/// port
pub port: String,
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl PostgresSource {
/// Create a new PostgresSource
pub fn new(
host: RPCUrl,
port: String,
user: String,
query: String,
dbname: String,
password: String,
) -> Self {
PostgresSource {
host,
user,
password,
query,
dbname,
port,
}
}
/// Fetch data from postgres
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
// clone to move into thread
let user = self.user.clone();
let host = self.host.clone();
let query = self.query.clone();
let dbname = self.dbname.clone();
let port = self.port.clone();
let password = self.password.clone();
let config = if password.is_empty() {
format!(
"host={} user={} dbname={} port={}",
host, user, dbname, port
)
} else {
format!(
"host={} user={} dbname={} port={} password={}",
host, user, dbname, port, password
)
};
let mut client = Client::connect(&config).await?;
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// extract rows from query
for row in client.query(&query, &[]).await? {
// extract features from row
for i in 0..row.len() {
res.push(row.get(i));
}
}
Ok(vec![res])
}
/// Fetch data from postgres and format it as a FileSource
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
Ok(self
.fetch()
.await?
.iter()
.map(|d| {
d.iter()
.map(|d| {
FileSourceInner::Float(
d.n.as_ref()
.unwrap()
.to_f64()
.ok_or("could not convert decimal to f64")
.unwrap(),
)
})
.collect()
})
.collect())
}
pub struct OnChainSource {
/// Call specifications for fetching data
pub calls: Calls,
/// RPC endpoint URL for accessing the chain
pub rpc: RPCUrl,
}
impl OnChainSource {
/// Creates a new OnChainSource with multiple calls
///
/// # Arguments
/// * `calls` - Vector of call specifications
/// * `rpc` - RPC endpoint URL
pub fn new_multiple(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
OnChainSource {
calls: Calls::Multiple(calls),
rpc,
}
}
/// Creates a new OnChainSource with a single call
///
/// # Arguments
/// * `call` - Call specification
/// * `rpc` - RPC endpoint URL
pub fn new_single(call: CallToAccount, rpc: RPCUrl) -> Self {
OnChainSource {
calls: Calls::Single(call),
rpc,
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Create dummy local on-chain data to test the OnChain data source
/// Creates test data for the OnChain data source
/// Used for testing and development purposes
///
/// # Arguments
/// * `data` - Sample file data to use
/// * `scales` - Scaling factors for each input
/// * `shapes` - Shapes of the input tensors
/// * `rpc` - Optional RPC endpoint override
pub async fn test_from_file_data(
data: &FileSource,
scales: Vec<crate::Scale>,
@@ -399,48 +324,40 @@ impl OnChainSource {
}
}
/// Defines the view only calls to accounts to fetch the on-chain input data.
/// This data will be included as part of the first elements in the publicInputs
/// for the sol evm verifier and will be verifyWithDataAttestation.sol
/// Specification for view-only calls to fetch on-chain data
/// Used for data attestation in smart contract verification
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct CallsToAccount {
/// A vector of tuples, where index 0 of tuples
/// are the byte strings representing the ABI encoded function calls to
/// read the data from the address. This call must return a single
/// elementary type (<https://docs.soliditylang.org/en/v0.8.20/abi-spec.html#types>).
/// The second index of the tuple is the number of decimals for f32 conversion.
/// We don't support dynamic types currently.
/// Vector of (call data, decimals) pairs
/// call_data: ABI-encoded function call
/// decimals: Number of decimal places for float conversion
pub call_data: Vec<(Call, Decimals)>,
/// Address of the contract to read the data from.
/// Contract address to call
pub address: String,
}
/// Defines a view only call to accounts to fetch the on-chain input data.
/// This data will be included as part of the first elements in the publicInputs
/// for the sol evm verifier and will be verifyWithDataAttestation.sol
/// Specification for a single view-only call returning an array
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct CallToAccount {
/// The call_data is a byte strings representing the ABI encoded function call to
/// read the data from the address. This call must return a single array of integers that can be
/// be safely cast to the int128 type in solidity.
/// ABI-encoded function call data
pub call_data: Call,
/// The number of decimals for f32 conversion of all of the elements returned from the
/// call.
/// Number of decimal places for float conversion
pub decimals: Decimals,
/// Address of the contract to read the data from.
/// Contract address to call
pub address: String,
/// The number of elements returned from the call.
/// Expected length of returned array
pub len: usize,
}
/// Enum that defines source of the inputs/outputs to the EZKL model
/// Represents different sources of input/output data for the EZKL model
#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq)]
#[serde(untagged)]
pub enum DataSource {
/// .json File data source.
/// Data from a JSON file containing arrays of values
File(FileSource),
/// On-chain data source. The first element is the calls to the account, and the second is the RPC url.
/// Data fetched from blockchain contracts
OnChain(OnChainSource),
/// Postgres DB
/// Data from a PostgreSQL database
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
DB(PostgresSource),
}
@@ -483,8 +400,7 @@ impl From<OnChainSource> for DataSource {
}
}
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
// Note: Always use JSON serialization for untagged enums
impl<'de> Deserialize<'de> for DataSource {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
@@ -492,15 +408,19 @@ impl<'de> Deserialize<'de> for DataSource {
{
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
// Try deserializing as FileSource first
let first_try: Result<FileSource, _> = serde_json::from_str(this_json.get());
if let Ok(t) = first_try {
return Ok(DataSource::File(t));
}
// Try deserializing as OnChainSource
let second_try: Result<OnChainSource, _> = serde_json::from_str(this_json.get());
if let Ok(t) = second_try {
return Ok(DataSource::OnChain(t));
}
// Try deserializing as PostgresSource if feature enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
{
let third_try: Result<PostgresSource, _> = serde_json::from_str(this_json.get());
@@ -513,22 +433,29 @@ impl<'de> Deserialize<'de> for DataSource {
}
}
/// Input to graph as a datasource
/// Always use JSON serialization for GraphData. Seriously.
#[derive(Clone, Debug, Deserialize, Default, PartialEq)]
/// Container for input and output data for graph computations
///
/// Important: Always use JSON serialization for GraphData to handle enum variants correctly
#[derive(Clone, Debug, Deserialize, Default, PartialEq, Serialize)]
pub struct GraphData {
/// Inputs to the model / computational graph (can be empty vectors if inputs are coming from on-chain).
/// Input data for the model/graph
/// Can be empty if inputs come from on-chain sources
pub input_data: DataSource,
/// Outputs of the model / computational graph (can be empty vectors if outputs are coming from on-chain).
/// Optional output data for the model/graph
/// Can be empty if outputs come from on-chain sources
pub output_data: Option<DataSource>,
}
impl UnwindSafe for GraphData {}
impl GraphData {
// not wasm
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Convert the input data to tract data
/// Converts the input data to tract's tensor format
///
/// # Arguments
/// * `shapes` - Expected shapes for each input tensor
/// * `datum_types` - Expected data types for each input
pub fn to_tract_data(
&self,
shapes: &[Vec<usize>],
@@ -557,9 +484,14 @@ impl GraphData {
Ok(inputs)
}
// not wasm
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Convert the tract data to tract data
/// Converts tract tensor data into GraphData format
///
/// # Arguments
/// * `tensors` - Array of tract tensors to convert
///
/// # Returns
/// A new GraphData instance containing the converted tensor data
pub fn from_tract_data(tensors: &[TractTensor]) -> Result<Self, GraphError> {
use tract_onnx::prelude::DatumType;
@@ -585,7 +517,10 @@ impl GraphData {
})
}
/// Creates a new GraphData instance with given input data
///
/// # Arguments
/// * `input_data` - The input data source
pub fn new(input_data: DataSource) -> Self {
GraphData {
input_data,
@@ -593,7 +528,13 @@ impl GraphData {
}
}
/// Load the model input from a file
/// Loads graph input data from a file
///
/// # Arguments
/// * `path` - Path to the input file
///
/// # Returns
/// A new GraphData instance containing the loaded data
pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
let reader = std::fs::File::open(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
@@ -607,23 +548,35 @@ impl GraphData {
Ok(graph_input)
}
/// Save the model input to a file
/// Saves the graph data to a file
///
/// # Arguments
/// * `path` - Path where to save the data
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
let file = std::fs::File::create(path.clone()).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
// buf writer
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::to_writer(writer, self)?;
Ok(())
}
/// Splits the input data into multiple batches based on input shapes
///
/// # Arguments
/// * `input_shapes` - Vector of shapes for each input tensor
///
/// # Returns
/// Vector of GraphData instances, one for each batch
///
/// # Errors
/// Returns error if:
/// - Data is from on-chain source
/// - Input size is not evenly divisible by batch size
pub async fn split_into_batches(
&self,
input_shapes: Vec<Vec<usize>>,
) -> Result<Vec<Self>, GraphError> {
// split input data into batches
let mut batched_inputs = vec![];
let iterable = match self {
@@ -647,10 +600,12 @@ impl GraphData {
} => data.fetch_and_format_as_file().await?,
};
// Process each input tensor according to its shape
for (i, shape) in input_shapes.iter().enumerate() {
// ensure the input is evenly divisible by batch_size
let input_size = shape.clone().iter().product::<usize>();
let input = &iterable[i];
// Validate input size is divisible by batch size
if input.len() % input_size != 0 {
return Err(GraphError::InvalidDims(
0,
@@ -658,6 +613,8 @@ impl GraphData {
.to_string(),
));
}
// Split input into batches
let mut batches = vec![];
for batch in input.chunks(input_size) {
batches.push(batch.to_vec());
@@ -665,18 +622,18 @@ impl GraphData {
batched_inputs.push(batches);
}
// now merge all the batches for each input into a vector of batches
// first assert each input has the same number of batches
// Merge batches across inputs
let num_batches = if batched_inputs.is_empty() {
0
} else {
let num_batches = batched_inputs[0].len();
// Verify all inputs have same number of batches
for input in batched_inputs.iter() {
assert_eq!(input.len(), num_batches);
}
num_batches
};
// now merge the batches
let mut input_batches = vec![];
for i in 0..num_batches {
let mut batch = vec![];
@@ -686,11 +643,12 @@ impl GraphData {
input_batches.push(DataSource::File(batch));
}
// Ensure at least one batch exists
if input_batches.is_empty() {
input_batches.push(DataSource::File(vec![vec![]]));
}
// create a new GraphWitness for each batch
// Create GraphData instance for each batch
let batches = input_batches
.into_iter()
.map(GraphData::new)
@@ -702,6 +660,7 @@ impl GraphData {
#[cfg(feature = "python-bindings")]
impl ToPyObject for CallsToAccount {
/// Converts CallsToAccount to Python object
fn to_object(&self, py: Python) -> PyObject {
let dict = PyDict::new(py);
dict.set_item("account", &self.address).unwrap();
@@ -710,6 +669,165 @@ impl ToPyObject for CallsToAccount {
}
}
// Additional Python bindings for various types...
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_postgres_source_new() {
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
{
let source = PostgresSource::new(
"localhost".to_string(),
"5432".to_string(),
"user".to_string(),
"SELECT * FROM table".to_string(),
"database".to_string(),
"password".to_string(),
);
assert_eq!(source.host, "localhost");
assert_eq!(source.port, "5432");
assert_eq!(source.user, "user");
assert_eq!(source.query, "SELECT * FROM table");
assert_eq!(source.dbname, "database");
assert_eq!(source.password, "password");
}
}
#[test]
fn test_data_source_serialization_round_trip() {
// Test backwards compatibility with old format
let source = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]);
let serialized = serde_json::to_string(&source).unwrap();
const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#;
assert_eq!(serialized, JSON);
let expect = serde_json::from_str::<DataSource>(JSON)
.map_err(|e| e.to_string())
.unwrap();
assert_eq!(expect, source);
}
#[test]
fn test_graph_input_serialization_round_trip() {
// Test serialization/deserialization of graph input
let file = GraphData::new(DataSource::from(vec![vec![
0.05326242372393608,
0.07497056573629379,
0.05235547572374344,
]]));
let serialized = serde_json::to_string(&file).unwrap();
const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#;
assert_eq!(serialized, JSON);
let graph_input3 = serde_json::from_str::<GraphData>(JSON)
.map_err(|e| e.to_string())
.unwrap();
assert_eq!(graph_input3, file);
}
#[test]
fn test_python_compat() {
// Test compatibility with mclbn256 library serialization
let source = Fp::from_raw([18445520602771460712, 838677322461845011, 3079992810, 0]);
let original_addr = "0x000000000000000000000000b794f5ea0ba39494ce839613fffba74279579268";
assert_eq!(format!("{:?}", source), original_addr);
}
}
/// Source data from a PostgreSQL database
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct PostgresSource {
/// Database host address
pub host: RPCUrl,
/// Database user name
pub user: String,
/// Database password
pub password: String,
/// SQL query to execute
pub query: String,
/// Database name
pub dbname: String,
/// Database port
pub port: String,
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl PostgresSource {
/// Creates a new PostgreSQL data source
pub fn new(
host: RPCUrl,
port: String,
user: String,
query: String,
dbname: String,
password: String,
) -> Self {
PostgresSource {
host,
user,
password,
query,
dbname,
port,
}
}
/// Fetches data from the PostgreSQL database
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
// Configuration string
let config = if self.password.is_empty() {
format!(
"host={} user={} dbname={} port={}",
self.host, self.user, self.dbname, self.port
)
} else {
format!(
"host={} user={} dbname={} port={} password={}",
self.host, self.user, self.dbname, self.port, self.password
)
};
let mut client = Client::connect(&config).await?;
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// Extract rows from query
for row in client.query(&self.query, &[]).await? {
for i in 0..row.len() {
res.push(row.get(i));
}
}
Ok(vec![res])
}
/// Fetches and formats data as FileSource
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
Ok(self
.fetch()
.await?
.iter()
.map(|d| {
d.iter()
.map(|d| {
FileSourceInner::Float(
d.n.as_ref()
.unwrap()
.to_f64()
.ok_or("could not convert decimal to f64")
.unwrap(),
)
})
.collect()
})
.collect())
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for CallToAccount {
fn to_object(&self, py: Python) -> PyObject {
@@ -744,6 +862,7 @@ impl ToPyObject for DataSource {
.unwrap();
dict.to_object(py)
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
DataSource::DB(source) => {
let dict = PyDict::new(py);
dict.set_item("host", &source.host).unwrap();
@@ -768,69 +887,3 @@ impl ToPyObject for FileSourceInner {
}
}
}
impl Serialize for GraphData {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut state = serializer.serialize_struct("GraphData", 4)?;
state.serialize_field("input_data", &self.input_data)?;
state.serialize_field("output_data", &self.output_data)?;
state.end()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
// this is for backwards compatibility with the old format
fn test_data_source_serialization_round_trip() {
let source = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]);
let serialized = serde_json::to_string(&source).unwrap();
const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#;
assert_eq!(serialized, JSON);
let expect = serde_json::from_str::<DataSource>(JSON)
.map_err(|e| e.to_string())
.unwrap();
assert_eq!(expect, source);
}
#[test]
// this is for backwards compatibility with the old format
fn test_graph_input_serialization_round_trip() {
let file = GraphData::new(DataSource::from(vec![vec![
0.05326242372393608,
0.07497056573629379,
0.05235547572374344,
]]));
let serialized = serde_json::to_string(&file).unwrap();
const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#;
assert_eq!(serialized, JSON);
let graph_input3 = serde_json::from_str::<GraphData>(JSON)
.map_err(|e| e.to_string())
.unwrap();
assert_eq!(graph_input3, file);
}
// test for the compatibility with the serialized elements from the mclbn256 library
#[test]
fn test_python_compat() {
let source = Fp::from_raw([18445520602771460712, 838677322461845011, 3079992810, 0]);
let original_addr = "0x000000000000000000000000b794f5ea0ba39494ce839613fffba74279579268";
assert_eq!(format!("{:?}", source), original_addr);
}
}

View File

@@ -281,7 +281,7 @@ impl GraphWitness {
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
let witness: GraphWitness =
serde_json::from_reader(reader).map_err(|e| Into::<GraphError>::into(e))?;
serde_json::from_reader(reader).map_err(Into::<GraphError>::into)?;
// check versions match
crate::check_version_string_matches(witness.version.as_deref().unwrap_or(""));
@@ -619,11 +619,6 @@ impl GraphSettings {
}
}
///
pub fn uses_modules(&self) -> bool {
!self.module_sizes.max_constraints() > 0
}
/// if any visibility is encrypted or hashed
pub fn module_requires_fixed(&self) -> bool {
self.run_args.input_visibility.is_hashed()
@@ -766,7 +761,7 @@ pub struct TestOnChainData {
pub data: std::path::PathBuf,
/// rpc endpoint
pub rpc: Option<String>,
///
/// data sources for the on chain data
pub data_sources: TestSources,
}
@@ -954,7 +949,7 @@ impl GraphCircuit {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
_ => unreachable!("cannot load from on-chain data"),
_ => Err(GraphError::OnChainDataSource),
}
}

View File

@@ -384,8 +384,7 @@ pub struct ParsedNodes {
impl ParsedNodes {
/// Returns the number of the computational graph's inputs
pub fn num_inputs(&self) -> usize {
let input_nodes = self.inputs.iter();
input_nodes.len()
self.inputs.len()
}
/// Input types
@@ -425,8 +424,7 @@ impl ParsedNodes {
/// Returns the number of the computational graph's outputs
pub fn num_outputs(&self) -> usize {
let output_nodes = self.outputs.iter();
output_nodes.len()
self.outputs.len()
}
/// Returns shapes of the computational graph's outputs
@@ -630,12 +628,14 @@ impl Model {
let mut model = tract_onnx::onnx().model_for_read(reader)?;
let variables: std::collections::HashMap<String, usize> =
std::collections::HashMap::from_iter(
variables.into_iter().map(|(k, v)| (k.clone(), *v)),
);
std::collections::HashMap::from_iter(variables.iter().map(|(k, v)| (k.clone(), *v)));
for (i, id) in model.clone().inputs.iter().enumerate() {
let input = model.node_mut(id.node);
if input.outputs.len() == 0 {
return Err(GraphError::MissingOutput(id.node));
}
let mut fact: InferenceFact = input.outputs[0].fact.clone();
for (i, x) in fact.clone().shape.dims().enumerate() {
@@ -908,6 +908,7 @@ impl Model {
n.opkind = SupportedOp::Input(Input {
scale,
datum_type: inp.datum_type,
decomp: !run_args.ignore_range_check_inputs_outputs,
});
input_idx += 1;
n.out_scale = scale;
@@ -1018,6 +1019,10 @@ impl Model {
let required_lookups = settings.required_lookups.clone();
let required_range_checks = settings.required_range_checks.clone();
if vars.advices.len() < 3 {
return Err(GraphError::InsufficientAdviceColumns(3));
}
let mut base_gate = PolyConfig::configure(
meta,
vars.advices[0..2].try_into()?,
@@ -1037,6 +1042,10 @@ impl Model {
}
if settings.requires_dynamic_lookup() {
if vars.advices.len() < 6 {
return Err(GraphError::InsufficientAdviceColumns(6));
}
base_gate.configure_dynamic_lookup(
meta,
vars.advices[0..3].try_into()?,
@@ -1045,10 +1054,13 @@ impl Model {
}
if settings.requires_shuffle() {
if vars.advices.len() < 6 {
return Err(GraphError::InsufficientAdviceColumns(6));
}
base_gate.configure_shuffles(
meta,
vars.advices[0..2].try_into()?,
vars.advices[3..5].try_into()?,
vars.advices[0..3].try_into()?,
vars.advices[3..6].try_into()?,
)?;
}
@@ -1063,6 +1075,7 @@ impl Model {
/// * `vars` - The variables for the circuit.
/// * `witnessed_outputs` - The values to compare against.
/// * `constants` - The constants for the circuit.
#[allow(clippy::too_many_arguments)]
pub fn layout(
&self,
mut config: ModelConfig,
@@ -1133,8 +1146,8 @@ impl Model {
.iter()
.enumerate()
.map(|(i, output)| {
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
let mut tol: crate::circuit::Tolerance = run_args.tolerance;
tol.scale = scale_to_multiplier(output_scales[i]).into();
let comparators = if run_args.output_visibility == Visibility::Public {
let res = vars
@@ -1157,7 +1170,10 @@ impl Model {
.layout(
&mut thread_safe_region,
&[output.clone(), comparators],
Box::new(HybridOp::RangeCheck(tolerance)),
Box::new(HybridOp::Output {
tol,
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
)
.map_err(|e| e.into())
})
@@ -1434,13 +1450,16 @@ impl Model {
.into();
comparator.reshape(output.dims())?;
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
let mut tol = run_args.tolerance;
tol.scale = scale_to_multiplier(output_scales[i]).into();
dummy_config.layout(
&mut region,
&[output.clone(), comparator],
Box::new(HybridOp::RangeCheck(tolerance)),
Box::new(HybridOp::Output {
tol,
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
)
})
.collect::<Result<Vec<_>, _>>();
@@ -1462,7 +1481,7 @@ impl Model {
.iter()
.map(|x| {
x.get_felt_evals()
.unwrap_or(Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
.unwrap_or_else(|_| Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
})
.collect();
@@ -1532,6 +1551,7 @@ impl Model {
let mut op = crate::circuit::Constant::new(
c.quantized_values.clone(),
c.raw_values.clone(),
c.decomp,
);
op.pre_assign(consts[const_idx].clone());
n.opkind = SupportedOp::Constant(op);

View File

@@ -284,7 +284,6 @@ impl GraphModules {
log::error!("Poseidon config not initialized");
return Err(Error::Synthesis);
}
// If the module is encrypted, then we need to encrypt the inputs
}
Ok(())

View File

@@ -1,10 +1,19 @@
// Import dependencies for scaling operations
use super::scale_to_multiplier;
// Import ONNX-specific utilities when EZKL feature is enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::utilities::node_output_shapes;
// Import scale management types for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::VarScales;
// Import visibility settings for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::Visibility;
// Import operation types for different circuit components
use crate::circuit::hybrid::HybridOp;
use crate::circuit::lookup::LookupOp;
use crate::circuit::poly::PolyOp;
@@ -13,28 +22,49 @@ use crate::circuit::Constant;
use crate::circuit::Input;
use crate::circuit::Op;
use crate::circuit::Unknown;
// Import graph error types for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::errors::GraphError;
// Import ONNX operation conversion utilities
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::new_op_from_onnx;
// Import tensor error handling
use crate::tensor::TensorError;
// Import curve-specific field type
use halo2curves::bn256::Fr as Fp;
// Import logging for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use log::trace;
// Import serialization traits
use serde::Deserialize;
use serde::Serialize;
// Import data structures for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::collections::BTreeMap;
// Import formatting traits for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::fmt;
// Import table display formatting for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tabled::Tabled;
// Import ONNX-specific types and traits for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::{
self,
prelude::{Node as OnnxNode, SymbolValues, TypedFact, TypedOp},
};
/// Helper function to format vectors for display
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
if !v.is_empty() {
@@ -44,29 +74,35 @@ fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
}
}
/// Helper function to format operation kinds for display
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn display_opkind(v: &SupportedOp) -> String {
v.as_string()
}
/// A wrapper for an operation that has been rescaled.
/// A wrapper for an operation that has been rescaled to handle different precision requirements.
/// This enables operations to work with inputs that have been scaled to different fixed-point representations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Rescaled {
/// The operation that has to be rescaled.
/// The underlying operation that needs to be rescaled
pub inner: Box<SupportedOp>,
/// The scale of the operation's inputs.
/// Vector of (index, scale) pairs defining how each input should be scaled
pub scale: Vec<(usize, u128)>,
}
/// Implementation of the Op trait for Rescaled operations
impl Op<Fp> for Rescaled {
/// Convert to Any type for runtime type checking
fn as_any(&self) -> &dyn std::any::Any {
self
}
/// Get string representation of the operation
fn as_string(&self) -> String {
format!("RESCALED INPUT ({})", self.inner.as_string())
}
/// Calculate output scale based on input scales
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let in_scales = in_scales
.into_iter()
@@ -77,6 +113,7 @@ impl Op<Fp> for Rescaled {
Op::<Fp>::out_scale(&*self.inner, in_scales)
}
/// Layout the operation in the circuit
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
@@ -93,28 +130,40 @@ impl Op<Fp> for Rescaled {
self.inner.layout(config, region, res)
}
/// Create a cloned boxed copy of this operation
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
Box::new(self.clone()) // Forward to the derive(Clone) impl
Box::new(self.clone())
}
}
/// A wrapper for an operation that has been rescaled.
/// A wrapper for operations that require scale rebasing
/// This handles cases where operation scales need to be adjusted to a target scale
/// while preserving the numerical relationships
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RebaseScale {
/// The operation that has to be rescaled.
/// The operation that needs to be rescaled
pub inner: Box<SupportedOp>,
/// rebase op
/// Operation used for rebasing, typically division
pub rebase_op: HybridOp,
/// scale being rebased to
/// Scale that we're rebasing to
pub target_scale: i32,
/// The original scale of the operation's inputs.
/// Original scale of operation's inputs before rebasing
pub original_scale: i32,
/// multiplier
/// Scaling multiplier used in rebasing
pub multiplier: f64,
}
impl RebaseScale {
/// Creates a rebased version of an operation if needed
///
/// # Arguments
/// * `inner` - Operation to potentially rebase
/// * `global_scale` - Base scale for the system
/// * `op_out_scale` - Current output scale of the operation
/// * `scale_rebase_multiplier` - Factor determining when rebasing should occur
///
/// # Returns
/// Original or rebased operation depending on scale relationships
pub fn rebase(
inner: SupportedOp,
global_scale: crate::Scale,
@@ -155,7 +204,15 @@ impl RebaseScale {
}
}
/// Creates a rebased operation with increased scale
///
/// # Arguments
/// * `inner` - Operation to potentially rebase
/// * `target_scale` - Scale to rebase to
/// * `op_out_scale` - Current output scale of the operation
///
/// # Returns
/// Original or rebased operation with increased scale
pub fn rebase_up(
inner: SupportedOp,
target_scale: crate::Scale,
@@ -192,10 +249,12 @@ impl RebaseScale {
}
impl Op<Fp> for RebaseScale {
/// Convert to Any type for runtime type checking
fn as_any(&self) -> &dyn std::any::Any {
self
}
/// Get string representation of the operation
fn as_string(&self) -> String {
format!(
"REBASED (div={:?}, rebasing_op={}) ({})",
@@ -205,10 +264,12 @@ impl Op<Fp> for RebaseScale {
)
}
/// Calculate output scale based on input scales
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(self.target_scale)
}
/// Layout the operation in the circuit
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
@@ -222,34 +283,40 @@ impl Op<Fp> for RebaseScale {
self.rebase_op.layout(config, region, &[original_res])
}
/// Create a cloned boxed copy of this operation
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
Box::new(self.clone()) // Forward to the derive(Clone) impl
Box::new(self.clone())
}
}
/// A single operation in a [crate::graph::Model].
/// Represents all supported operation types in the circuit
/// Each variant encapsulates a different type of operation with specific behavior
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SupportedOp {
/// A linear operation.
/// Linear operations (polynomial-based)
Linear(PolyOp),
/// A nonlinear operation.
/// Nonlinear operations requiring lookup tables
Nonlinear(LookupOp),
/// A hybrid operation.
/// Mixed operations combining different approaches
Hybrid(HybridOp),
///
/// Input values to the circuit
Input(Input),
///
/// Constant values in the circuit
Constant(Constant<Fp>),
///
/// Placeholder for unsupported operations
Unknown(Unknown),
///
/// Operations requiring rescaling of inputs
Rescaled(Rescaled),
///
/// Operations requiring scale rebasing
RebaseScale(RebaseScale),
}
impl SupportedOp {
/// Checks if the operation is a lookup operation
///
/// # Returns
/// * `true` if operation requires lookup table
/// * `false` otherwise
pub fn is_lookup(&self) -> bool {
match self {
SupportedOp::Nonlinear(_) => true,
@@ -257,7 +324,12 @@ impl SupportedOp {
_ => false,
}
}
/// Returns input operation if this is an input
///
/// # Returns
/// * `Some(Input)` if this is an input operation
/// * `None` otherwise
pub fn get_input(&self) -> Option<Input> {
match self {
SupportedOp::Input(op) => Some(op.clone()),
@@ -265,7 +337,11 @@ impl SupportedOp {
}
}
/// Returns reference to rebased operation if this is a rebased operation
///
/// # Returns
/// * `Some(&RebaseScale)` if this is a rebased operation
/// * `None` otherwise
pub fn get_rebased(&self) -> Option<&RebaseScale> {
match self {
SupportedOp::RebaseScale(op) => Some(op),
@@ -273,7 +349,11 @@ impl SupportedOp {
}
}
/// Returns reference to lookup operation if this is a lookup operation
///
/// # Returns
/// * `Some(&LookupOp)` if this is a lookup operation
/// * `None` otherwise
pub fn get_lookup(&self) -> Option<&LookupOp> {
match self {
SupportedOp::Nonlinear(op) => Some(op),
@@ -281,7 +361,11 @@ impl SupportedOp {
}
}
/// Returns reference to constant if this is a constant
///
/// # Returns
/// * `Some(&Constant)` if this is a constant
/// * `None` otherwise
pub fn get_constant(&self) -> Option<&Constant<Fp>> {
match self {
SupportedOp::Constant(op) => Some(op),
@@ -289,7 +373,11 @@ impl SupportedOp {
}
}
/// Returns mutable reference to constant if this is a constant
///
/// # Returns
/// * `Some(&mut Constant)` if this is a constant
/// * `None` otherwise
pub fn get_mutable_constant(&mut self) -> Option<&mut Constant<Fp>> {
match self {
SupportedOp::Constant(op) => Some(op),
@@ -297,18 +385,19 @@ impl SupportedOp {
}
}
/// Creates a homogeneously rescaled version of this operation if needed
/// Only available with EZKL feature enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn homogenous_rescale(
&self,
in_scales: Vec<crate::Scale>,
) -> Result<Box<dyn Op<Fp>>, GraphError> {
let inputs_to_scale = self.requires_homogenous_input_scales();
// creates a rescaled op if the inputs are not homogenous
let op = self.clone_dyn();
super::homogenize_input_scales(op, in_scales, inputs_to_scale)
}
/// Since each associated value of `SupportedOp` implements `Op`, let's define a helper method to retrieve it.
/// Returns reference to underlying Op implementation
fn as_op(&self) -> &dyn Op<Fp> {
match self {
SupportedOp::Linear(op) => op,
@@ -322,9 +411,10 @@ impl SupportedOp {
}
}
/// check if is the identity operation
/// Checks if this is an identity operation
///
/// # Returns
/// * `true` if the operation is the identity operation
/// * `true` if this operation passes input through unchanged
/// * `false` otherwise
pub fn is_identity(&self) -> bool {
match self {
@@ -361,9 +451,11 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
if let Some(op) = value.as_any().downcast_ref::<Unknown>() {
return SupportedOp::Unknown(op.clone());
};
if let Some(op) = value.as_any().downcast_ref::<Rescaled>() {
return SupportedOp::Rescaled(op.clone());
};
if let Some(op) = value.as_any().downcast_ref::<RebaseScale>() {
return SupportedOp::RebaseScale(op.clone());
};
@@ -375,6 +467,7 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
}
impl Op<Fp> for SupportedOp {
/// Layout this operation in the circuit
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
@@ -384,54 +477,61 @@ impl Op<Fp> for SupportedOp {
self.as_op().layout(config, region, values)
}
/// Check if this is an input operation
fn is_input(&self) -> bool {
self.as_op().is_input()
}
/// Check if this is a constant operation
fn is_constant(&self) -> bool {
self.as_op().is_constant()
}
/// Get which inputs require homogeneous scales
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
self.as_op().requires_homogenous_input_scales()
}
/// Create a clone of this operation
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
self.as_op().clone_dyn()
}
/// Get string representation
fn as_string(&self) -> String {
self.as_op().as_string()
}
/// Convert to Any type
fn as_any(&self) -> &dyn std::any::Any {
self
}
/// Calculate output scale from input scales
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
self.as_op().out_scale(in_scales)
}
}
/// A node's input is a tensor from another node's output.
/// Represents a connection to another node's output
/// First element is node index, second is output slot index
pub type Outlet = (usize, usize);
/// A single operation in a [crate::graph::Model].
/// Represents a single computational node in the circuit graph
/// Contains all information needed to execute and connect operations
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Node {
/// [Op] i.e what operation this node represents.
/// The operation this node performs
pub opkind: SupportedOp,
/// The denominator in the fixed point representation for the node's output. Tensors of differing scales should not be combined.
/// Fixed point scale factor for this node's output
pub out_scale: i32,
// Usually there is a simple in and out shape of the node as an operator. For example, an Affine node has three input_shapes (one for the input, weight, and bias),
// but in_dim is [in], out_dim is [out]
/// The indices of the node's inputs.
/// Connections to other nodes' outputs that serve as inputs
pub inputs: Vec<Outlet>,
/// Dimensions of output.
/// Shape of this node's output tensor
pub out_dims: Vec<usize>,
/// The node's unique identifier.
/// Unique identifier for this node
pub idx: usize,
/// The node's num of uses
/// Number of times this node's output is used
pub num_uses: usize,
}
@@ -469,12 +569,19 @@ impl PartialEq for Node {
}
impl Node {
/// Converts a tract [OnnxNode] into an ezkl [Node].
/// # Arguments:
/// * `node` - [OnnxNode]
/// * `other_nodes` - [BTreeMap] of other previously initialized [Node]s in the computational graph.
/// * `public_params` - flag if parameters of model are public
/// * `idx` - The node's unique identifier.
/// Creates a new Node from an ONNX node
/// Only available when EZKL feature is enabled
///
/// # Arguments
/// * `node` - Source ONNX node
/// * `other_nodes` - Map of existing nodes in the graph
/// * `scales` - Scale factors for variables
/// * `idx` - Unique identifier for this node
/// * `symbol_values` - ONNX symbol values
/// * `run_args` - Runtime configuration arguments
///
/// # Returns
/// New Node instance or error if creation fails
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[allow(clippy::too_many_arguments)]
pub fn new(
@@ -612,16 +719,14 @@ impl Node {
})
}
/// check if it is a softmax node
/// Check if this node performs softmax operation
pub fn is_softmax(&self) -> bool {
if let SupportedOp::Hybrid(HybridOp::Softmax { .. }) = self.opkind {
true
} else {
false
}
matches!(self.opkind, SupportedOp::Hybrid(HybridOp::Softmax { .. }))
}
}
/// Helper function to rescale constants that are only used once
/// Only available when EZKL feature is enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn rescale_const_with_single_use(
constant: &mut Constant<Fp>,

View File

@@ -44,11 +44,11 @@ use tract_onnx::tract_hir::{
tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
};
/// Quantizes an iterable of f32s to a [Tensor] of i32s using a fixed point representation.
/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
/// NAN gets mapped to 0. INFINITY and NEG_INFINITY error out.
/// Arguments
///
/// * `vec` - the vector to quantize.
/// * `dims` - the dimensionality of the resulting [Tensor].
/// * `elem` - the element to quantize.
/// * `shift` - offset used in the fixed point representation.
/// * `scale` - `2^scale` used in the fixed point representation.
pub fn quantize_float(
@@ -59,7 +59,7 @@ pub fn quantize_float(
let mult = scale_to_multiplier(scale);
let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
if *elem > max_value {
if *elem > max_value || *elem < -max_value {
return Err(TensorError::SigBitTruncationError);
}
@@ -85,7 +85,7 @@ pub fn scale_to_multiplier(scale: crate::Scale) -> f64 {
f64::powf(2., scale as f64)
}
/// Converts a scale (log base 2) to a fixed point multiplier.
/// Converts a fixed point multiplier to a scale (log base 2).
pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
mult.log2().round() as crate::Scale
}
@@ -228,10 +228,7 @@ pub fn extract_tensor_value(
.iter()
.map(|x| match x.to_i64() {
Ok(v) => Ok(v as f32),
Err(_) => match x.to_i64() {
Ok(v) => Ok(v as f32),
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
},
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
})
.collect();
@@ -277,11 +274,9 @@ pub fn new_op_from_onnx(
symbol_values: &SymbolValues,
run_args: &crate::RunArgs,
) -> Result<(SupportedOp, Vec<usize>), GraphError> {
use std::f64::consts::E;
use tract_onnx::tract_core::ops::array::Trilu;
use crate::circuit::InputType;
use std::f64::consts::E;
use tract_onnx::tract_core::ops::array::Trilu;
let input_scales = inputs
.iter()
@@ -312,6 +307,9 @@ pub fn new_op_from_onnx(
let mut deleted_indices = vec![];
let node = match node.op().name().as_ref() {
"ShiftLeft" => {
if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "shift left".to_string()));
};
// load shift amount
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
@@ -324,10 +322,13 @@ pub fn new_op_from_onnx(
out_scale: Some(input_scales[0] - raw_values[0] as i32),
})
} else {
return Err(GraphError::OpMismatch(idx, "ShiftLeft".to_string()));
return Err(GraphError::OpMismatch(idx, "shift left".to_string()));
}
}
"ShiftRight" => {
if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "shift right".to_string()));
};
// load shift amount
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
@@ -340,7 +341,7 @@ pub fn new_op_from_onnx(
out_scale: Some(input_scales[0] + raw_values[0] as i32),
})
} else {
return Err(GraphError::OpMismatch(idx, "ShiftRight".to_string()));
return Err(GraphError::OpMismatch(idx, "shift right".to_string()));
}
}
"MultiBroadcastTo" => {
@@ -363,7 +364,10 @@ pub fn new_op_from_onnx(
}
}
assert_eq!(input_ops.len(), 3, "Range requires 3 inputs");
if input_ops.len() != 3 {
return Err(GraphError::InvalidDims(idx, "range".to_string()));
}
let input_ops = input_ops
.iter()
.map(|x| x.get_constant().ok_or(GraphError::NonConstantRange))
@@ -378,7 +382,11 @@ pub fn new_op_from_onnx(
// Quantize the raw value (integers)
let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed)?;
let c = crate::circuit::ops::Constant::new(quantized_value, raw_value);
let c = crate::circuit::ops::Constant::new(
quantized_value,
raw_value,
!run_args.ignore_range_check_inputs_outputs,
);
// Create a constant op
SupportedOp::Constant(c)
}
@@ -419,6 +427,10 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
if inputs[0].out_dims().is_empty() || inputs[0].out_dims()[0].len() <= axis {
return Err(GraphError::InvalidDims(idx, "gather".to_string()));
}
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::Gather {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| {
@@ -436,6 +448,7 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: false,
}));
inputs[1].bump_scale(0);
}
@@ -447,8 +460,17 @@ pub fn new_op_from_onnx(
"Topk" => {
let op = load_op::<Topk>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;
if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
};
// if param_visibility.is_public() {
let k = if let Some(c) = inputs[1].opkind().get_mutable_constant() {
if c.raw_values.len() != 1 {
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
}
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
c.raw_values.map(|x| x as usize)[0]
@@ -488,6 +510,10 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "scatter elements".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -499,6 +525,7 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -522,6 +549,9 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "scatter nd".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
@@ -532,6 +562,7 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -555,6 +586,9 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "gather nd".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
@@ -566,6 +600,7 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -589,6 +624,9 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "gather elements".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -600,6 +638,7 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -674,7 +713,11 @@ pub fn new_op_from_onnx(
constant_scale,
&run_args.param_visibility,
)?;
let c = crate::circuit::ops::Constant::new(quantized_value, raw_value);
let c = crate::circuit::ops::Constant::new(
quantized_value,
raw_value,
run_args.ignore_range_check_inputs_outputs,
);
// Create a constant op
SupportedOp::Constant(c)
}
@@ -684,7 +727,9 @@ pub fn new_op_from_onnx(
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes: Vec<usize> = op.axes.into_iter().collect();
assert_eq!(axes.len(), 1, "only support argmax over one axis");
if axes.len() != 1 {
return Err(GraphError::InvalidDims(idx, "argmax".to_string()));
}
SupportedOp::Hybrid(HybridOp::ReduceArgMax { dim: axes[0] })
}
@@ -694,7 +739,9 @@ pub fn new_op_from_onnx(
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes: Vec<usize> = op.axes.into_iter().collect();
assert_eq!(axes.len(), 1, "only support argmin over one axis");
if axes.len() != 1 {
return Err(GraphError::InvalidDims(idx, "argmin".to_string()));
}
SupportedOp::Hybrid(HybridOp::ReduceArgMin { dim: axes[0] })
}
@@ -803,6 +850,9 @@ pub fn new_op_from_onnx(
}
}
"Recip" => {
if inputs.len() != 1 {
return Err(GraphError::InvalidDims(idx, "recip".to_string()));
};
let in_scale = input_scales[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
// If the input scale is larger than the params scale
@@ -846,6 +896,9 @@ pub fn new_op_from_onnx(
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Rsqrt" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "rsqrt".to_string()));
};
let in_scale = input_scales[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
SupportedOp::Hybrid(HybridOp::Rsqrt {
@@ -927,13 +980,19 @@ pub fn new_op_from_onnx(
DatumType::F64 => (scales.input, InputType::F64),
_ => return Err(GraphError::UnsupportedDataType(idx, format!("{:?}", dt))),
};
SupportedOp::Input(crate::circuit::ops::Input { scale, datum_type })
SupportedOp::Input(crate::circuit::ops::Input {
scale,
datum_type,
decomp: !run_args.ignore_range_check_inputs_outputs,
})
}
"Cast" => {
let op = load_op::<Cast>(node.op(), idx, node.op().name().to_string())?;
let dt = op.to;
assert_eq!(input_scales.len(), 1);
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "cast".to_string()));
};
match dt {
DatumType::Bool
@@ -983,6 +1042,11 @@ pub fn new_op_from_onnx(
if const_idx.len() == 1 {
let const_idx = const_idx[0];
if inputs.len() <= const_idx {
return Err(GraphError::InvalidDims(idx, "mul".to_string()));
}
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
if c.raw_values.len() == 1 && c.raw_values[0] < 1. {
// if not divisible by 2 then we need to add a range check
@@ -1057,6 +1121,9 @@ pub fn new_op_from_onnx(
return Err(GraphError::OpMismatch(idx, "softmax".to_string()));
}
};
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "softmax".to_string()));
}
let in_scale = input_scales[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
@@ -1096,22 +1163,42 @@ pub fn new_op_from_onnx(
pool_dims: kernel_shape.to_vec(),
})
}
"Ceil" => SupportedOp::Hybrid(HybridOp::Ceil {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"Floor" => SupportedOp::Hybrid(HybridOp::Floor {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"Round" => SupportedOp::Hybrid(HybridOp::Round {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"RoundHalfToEven" => SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"Ceil" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "ceil".to_string()));
}
SupportedOp::Hybrid(HybridOp::Ceil {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"Floor" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "floor".to_string()));
}
SupportedOp::Hybrid(HybridOp::Floor {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"Round" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "round".to_string()));
}
SupportedOp::Hybrid(HybridOp::Round {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"RoundHalfToEven" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "roundhalftoeven".to_string()));
}
SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"Sign" => SupportedOp::Linear(PolyOp::Sign),
"Pow" => {
// Extract the slope layer hyperparams from a const
@@ -1121,7 +1208,9 @@ pub fn new_op_from_onnx(
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar pow")
return Err(GraphError::NonScalarPower);
} else if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
}
let exponent = c.raw_values[0];
@@ -1138,7 +1227,9 @@ pub fn new_op_from_onnx(
inputs[0].decrement_use();
deleted_indices.push(0);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar base")
return Err(GraphError::NonScalarBase);
} else if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
}
let base = c.raw_values[0];
@@ -1148,10 +1239,14 @@ pub fn new_op_from_onnx(
base: base.into(),
})
} else {
unimplemented!("only support constant base or pow for now")
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
}
}
"Div" => {
if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "div".to_string()));
}
let const_idx = inputs
.iter()
.enumerate()
@@ -1159,14 +1254,15 @@ pub fn new_op_from_onnx(
.map(|(i, _)| i)
.collect::<Vec<_>>();
if const_idx.len() > 1 {
if const_idx.len() > 1 || const_idx.is_empty() {
return Err(GraphError::InvalidDims(idx, "div".to_string()));
}
let const_idx = const_idx[0];
if const_idx != 1 {
unimplemented!("only support div with constant as second input")
return Err(GraphError::MisformedParams(
"only support div with constant as second input".to_string(),
));
}
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
@@ -1176,14 +1272,28 @@ pub fn new_op_from_onnx(
// get the non constant index
let denom = c.raw_values[0];
SupportedOp::Hybrid(HybridOp::Div {
let op = SupportedOp::Hybrid(HybridOp::Div {
denom: denom.into(),
})
});
// if the input is scale 0 we re up to the max scale
if input_scales[0] == 0 {
SupportedOp::Rescaled(Rescaled {
inner: Box::new(op),
scale: vec![(0, scale_to_multiplier(scales.get_max()) as u128)],
})
} else {
op
}
} else {
unimplemented!("only support non zero divisors of size 1")
return Err(GraphError::MisformedParams(
"only support non zero divisors of size 1".to_string(),
));
}
} else {
unimplemented!("only support div with constant as second input")
return Err(GraphError::MisformedParams(
"only support div with constant as second input".to_string(),
));
}
}
"Cube" => SupportedOp::Linear(PolyOp::Pow(3)),
@@ -1323,7 +1433,7 @@ pub fn new_op_from_onnx(
if !resize_node.contains("interpolator: Nearest")
&& !resize_node.contains("nearest: Floor")
{
unimplemented!("Only nearest neighbor interpolation is supported")
return Err(GraphError::InvalidInterpolation);
}
// check if optional scale factor is present
if inputs.len() != 2 && inputs.len() != 3 {
@@ -1427,6 +1537,10 @@ pub fn new_op_from_onnx(
SupportedOp::Linear(PolyOp::Reshape(output_shape))
}
"Flatten" => {
if inputs.len() != 1 || inputs[0].out_dims().is_empty() {
return Err(GraphError::InvalidDims(idx, "flatten".to_string()));
};
let new_dims: Vec<usize> = vec![inputs[0].out_dims()[0].iter().product::<usize>()];
SupportedOp::Linear(PolyOp::Flatten(new_dims))
}
@@ -1500,12 +1614,10 @@ pub fn homogenize_input_scales(
input_scales: Vec<crate::Scale>,
inputs_to_scale: Vec<usize>,
) -> Result<Box<dyn Op<Fp>>, GraphError> {
let relevant_input_scales = input_scales
.clone()
.into_iter()
.enumerate()
.filter(|(idx, _)| inputs_to_scale.contains(idx))
.map(|(_, scale)| scale)
let relevant_input_scales = inputs_to_scale
.iter()
.filter(|idx| input_scales.len() > **idx)
.map(|&idx| input_scales[idx])
.collect_vec();
if inputs_to_scale.is_empty() {
@@ -1546,10 +1658,30 @@ pub fn homogenize_input_scales(
}
#[cfg(test)]
/// tests for the utility module
pub mod tests {
use super::*;
// quantization tests
#[test]
fn test_quantize_tensor() {
let tensor: Tensor<f32> = (0..10).map(|x| x as f32).into();
let reference: Tensor<Fp> = (0..10).map(|x| x.into()).into();
let scale = 0;
let visibility = &Visibility::Public;
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility).unwrap();
assert_eq!(quantized.len(), 10);
assert_eq!(quantized, reference);
}
#[test]
fn test_quantize_edge_cases() {
assert_eq!(quantize_float(&f64::NAN, 0.0, 0).unwrap(), 0);
assert!(quantize_float(&f64::INFINITY, 0.0, 0).is_err());
assert!(quantize_float(&f64::NEG_INFINITY, 0.0, 0).is_err());
}
#[test]
fn test_flatten_valtensors() {
let tensor1: Tensor<Fp> = (0..10).map(|x| x.into()).into();

View File

@@ -11,35 +11,34 @@ use log::debug;
use pyo3::{
exceptions::PyValueError, FromPyObject, IntoPy, PyObject, PyResult, Python, ToPyObject,
};
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
use self::errors::GraphError;
use super::*;
/// Label enum to track whether model input, model parameters, and model output are public, private, or hashed
/// Defines the visibility level of values within the zero-knowledge circuit
/// Controls how values are handled during proof generation and verification
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum Visibility {
/// Mark an item as private to the prover (not in the proof submitted for verification)
/// Value is private to the prover and not included in proof
#[default]
Private,
/// Mark an item as public (sent in the proof submitted for verification)
/// Value is public and included in proof for verification
Public,
/// Mark an item as publicly committed to (hash sent in the proof submitted for verification)
/// Value is hashed and the hash is included in proof
Hashed {
/// Whether the hash is used as an instance (sent in the proof submitted for verification)
/// if false the hash is used as an advice (not in the proof submitted for verification) and is then sent to the computational graph
/// if true the hash is used as an instance (sent in the proof submitted for verification) the *inputs* to the hashing function are then sent to the computational graph
/// Controls how the hash is handled in proof
/// true - hash is included directly in proof (public)
/// false - hash is used as advice and passed to computational graph
hash_is_public: bool,
///
/// Specifies which outputs this hash affects
outlets: Vec<usize>,
},
/// Mark an item as publicly committed to (KZG commitment sent in the proof submitted for verification)
/// Value is committed using KZG commitment scheme
KZGCommit,
/// assigned as a constant in the circuit
/// Value is assigned as a constant in the circuit
Fixed,
}
@@ -66,15 +65,17 @@ impl Display for Visibility {
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Visibility {
/// Converts visibility to command line flags
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl<'a> From<&'a str> for Visibility {
/// Converts string representation to Visibility
fn from(s: &'a str) -> Self {
if s.contains("hashed/private") {
// split on last occurrence of '/'
// Split on last occurrence of '/'
let (_, outlets) = s.split_at(s.rfind('/').unwrap());
let outlets = outlets
.trim_start_matches('/')
@@ -106,8 +107,8 @@ impl<'a> From<&'a str> for Visibility {
}
#[cfg(feature = "python-bindings")]
/// Converts Visibility into a PyObject (Required for Visibility to be compatible with Python)
impl IntoPy<PyObject> for Visibility {
/// Converts Visibility to Python object
fn into_py(self, py: Python) -> PyObject {
match self {
Visibility::Private => "private".to_object(py),
@@ -134,14 +135,13 @@ impl IntoPy<PyObject> for Visibility {
}
#[cfg(feature = "python-bindings")]
/// Obtains Visibility from PyObject (Required for Visibility to be compatible with Python)
impl<'source> FromPyObject<'source> for Visibility {
/// Extracts Visibility from Python object
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
let strval = String::extract_bound(ob)?;
let strval = strval.as_str();
if strval.contains("hashed/private") {
// split on last occurence of '/'
let (_, outlets) = strval.split_at(strval.rfind('/').unwrap());
let outlets = outlets
.trim_start_matches('/')
@@ -174,29 +174,32 @@ impl<'source> FromPyObject<'source> for Visibility {
}
impl Visibility {
#[allow(missing_docs)]
/// Returns true if visibility is Fixed
pub fn is_fixed(&self) -> bool {
matches!(&self, Visibility::Fixed)
}
#[allow(missing_docs)]
/// Returns true if visibility is Private or hashed private
pub fn is_private(&self) -> bool {
matches!(&self, Visibility::Private) || self.is_hashed_private()
}
#[allow(missing_docs)]
/// Returns true if visibility is Public
pub fn is_public(&self) -> bool {
matches!(&self, Visibility::Public)
}
#[allow(missing_docs)]
/// Returns true if visibility involves hashing
pub fn is_hashed(&self) -> bool {
matches!(&self, Visibility::Hashed { .. })
}
#[allow(missing_docs)]
/// Returns true if visibility uses KZG commitment
pub fn is_polycommit(&self) -> bool {
matches!(&self, Visibility::KZGCommit)
}
#[allow(missing_docs)]
/// Returns true if visibility is hashed with public hash
pub fn is_hashed_public(&self) -> bool {
if let Visibility::Hashed {
hash_is_public: true,
@@ -207,7 +210,8 @@ impl Visibility {
}
false
}
#[allow(missing_docs)]
/// Returns true if visibility is hashed with private hash
pub fn is_hashed_private(&self) -> bool {
if let Visibility::Hashed {
hash_is_public: false,
@@ -219,11 +223,12 @@ impl Visibility {
false
}
#[allow(missing_docs)]
/// Returns true if visibility requires additional processing
pub fn requires_processing(&self) -> bool {
matches!(&self, Visibility::Hashed { .. }) | matches!(&self, Visibility::KZGCommit)
}
#[allow(missing_docs)]
/// Returns vector of output indices that this visibility setting affects
pub fn overwrites_inputs(&self) -> Vec<usize> {
if let Visibility::Hashed { outlets, .. } = self {
return outlets.clone();
@@ -232,14 +237,14 @@ impl Visibility {
}
}
/// Represents the scale of the model input, model parameters.
/// Manages scaling factors for different parts of the model
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)]
pub struct VarScales {
///
/// Scale factor for input values
pub input: crate::Scale,
///
/// Scale factor for parameter values
pub params: crate::Scale,
///
/// Multiplier for scale rebasing
pub rebase_multiplier: u32,
}
@@ -250,17 +255,17 @@ impl std::fmt::Display for VarScales {
}
impl VarScales {
///
/// Returns maximum scale value
pub fn get_max(&self) -> crate::Scale {
std::cmp::max(self.input, self.params)
}
///
/// Returns minimum scale value
pub fn get_min(&self) -> crate::Scale {
std::cmp::min(self.input, self.params)
}
/// Place in [VarScales] struct.
/// Creates VarScales from runtime arguments
pub fn from_args(args: &RunArgs) -> Self {
Self {
input: args.input_scale,
@@ -270,16 +275,17 @@ impl VarScales {
}
}
/// Represents whether the model input, model parameters, and model output are Public or Private to the prover.
/// Controls visibility settings for different parts of the model
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, PartialOrd)]
pub struct VarVisibility {
/// Input to the model or computational graph
/// Visibility of model inputs
pub input: Visibility,
/// Parameters, such as weights and biases, in the model
/// Visibility of model parameters (weights, biases)
pub params: Visibility,
/// Output of the model or computational graph
/// Visibility of model outputs
pub output: Visibility,
}
impl std::fmt::Display for VarVisibility {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
@@ -301,8 +307,7 @@ impl Default for VarVisibility {
}
impl VarVisibility {
/// Read from cli args whether the model input, model parameters, and model output are Public or Private to the prover.
/// Place in [VarVisibility] struct.
/// Creates visibility settings from runtime arguments
pub fn from_args(args: &RunArgs) -> Result<Self, GraphError> {
let input_vis = &args.input_visibility;
let params_vis = &args.param_visibility;
@@ -313,17 +318,17 @@ impl VarVisibility {
}
if !output_vis.is_public()
& !params_vis.is_public()
& !input_vis.is_public()
& !output_vis.is_fixed()
& !params_vis.is_fixed()
& !input_vis.is_fixed()
& !output_vis.is_hashed()
& !params_vis.is_hashed()
& !input_vis.is_hashed()
& !output_vis.is_polycommit()
& !params_vis.is_polycommit()
& !input_vis.is_polycommit()
&& !params_vis.is_public()
&& !input_vis.is_public()
&& !output_vis.is_fixed()
&& !params_vis.is_fixed()
&& !input_vis.is_fixed()
&& !output_vis.is_hashed()
&& !params_vis.is_hashed()
&& !input_vis.is_hashed()
&& !output_vis.is_polycommit()
&& !params_vis.is_polycommit()
&& !input_vis.is_polycommit()
{
return Err(GraphError::Visibility);
}
@@ -335,17 +340,17 @@ impl VarVisibility {
}
}
/// A wrapper for holding all columns that will be assigned to by a model.
/// Container for circuit columns used by a model
#[derive(Clone, Debug)]
pub struct ModelVars<F: PrimeField + TensorType + PartialOrd> {
#[allow(missing_docs)]
/// Advice columns for circuit assignments
pub advices: Vec<VarTensor>,
#[allow(missing_docs)]
/// Optional instance column for public inputs
pub instance: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
/// Get instance col
/// Gets reference to instance column if it exists
pub fn get_instance_col(&self) -> Option<&Column<Instance>> {
if let Some(instance) = &self.instance {
match instance {
@@ -357,14 +362,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
}
}
/// Set the initial instance offset
/// Sets initial offset for instance values
pub fn set_initial_instance_offset(&mut self, offset: usize) {
if let Some(instance) = &mut self.instance {
instance.set_initial_instance_offset(offset);
}
}
/// Get the total instance len
/// Gets total length of instance data
pub fn get_instance_len(&self) -> usize {
if let Some(instance) = &self.instance {
instance.get_total_instance_len()
@@ -373,21 +378,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
}
}
/// Increment the instance offset
/// Increments instance index
pub fn increment_instance_idx(&mut self) {
if let Some(instance) = &mut self.instance {
instance.increment_idx();
}
}
/// Reset the instance offset
/// Sets instance index to specific value
pub fn set_instance_idx(&mut self, val: usize) {
if let Some(instance) = &mut self.instance {
instance.set_idx(val);
}
}
/// Get the instance offset
/// Gets current instance index
pub fn get_instance_idx(&self) -> usize {
if let Some(instance) = &self.instance {
instance.get_idx()
@@ -396,7 +401,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
}
}
///
/// Initializes instance column with specified dimensions and scale
pub fn instantiate_instance(
&mut self,
cs: &mut ConstraintSystem<F>,
@@ -417,7 +422,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
};
}
/// Allocate all columns that will be assigned to by a model.
/// Creates new ModelVars with allocated columns based on settings
pub fn new(cs: &mut ConstraintSystem<F>, params: &GraphSettings) -> Self {
debug!("number of blinding factors: {}", cs.blinding_factors());
@@ -435,7 +440,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
.collect_vec();
if requires_dynamic_lookup || requires_shuffle {
let num_cols = if requires_dynamic_lookup { 3 } else { 2 };
let num_cols = 3;
for _ in 0..num_cols {
let dynamic_lookup =
VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_and_shuffle_size);

View File

@@ -28,6 +28,9 @@
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
use log::warn;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc as _;
/// Error type
// #[cfg_attr(not(feature = "ezkl"), derive(uniffi::Error))]
@@ -99,7 +102,7 @@ use circuit::{table::Range, CheckMode, Tolerance};
use clap::Args;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use fieldutils::IntegerRep;
use graph::Visibility;
use graph::{Visibility, MAX_PUBLIC_SRS};
use halo2_proofs::poly::{
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
};
@@ -165,7 +168,6 @@ pub mod srs_sha;
pub mod tensor;
#[cfg(feature = "ios-bindings")]
uniffi::setup_scaffolding!();
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use lazy_static::lazy_static;
@@ -180,11 +182,9 @@ lazy_static! {
.unwrap_or("8000".to_string())
.parse()
.unwrap();
/// The serialization format for the keys
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
.unwrap_or("raw-bytes".to_string());
}
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
@@ -266,76 +266,102 @@ impl From<String> for Commitments {
}
/// Parameters specific to a proving run
///
/// RunArgs contains all configuration parameters needed to control the proving process,
/// including scaling factors, visibility settings, and circuit parameters.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, PartialOrd)]
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
derive(Args, ToFlags)
)]
pub struct RunArgs {
/// The tolerance for error on model outputs
/// Error tolerance for model outputs
/// Only applicable when outputs are public
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other))]
pub tolerance: Tolerance,
/// The denominator in the fixed point representation used when quantizing inputs
/// Fixed point scaling factor for quantizing inputs
/// Higher values provide more precision but increase circuit complexity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other))]
pub input_scale: Scale,
/// The denominator in the fixed point representation used when quantizing parameters
/// Fixed point scaling factor for quantizing parameters
/// Higher values provide more precision but increase circuit complexity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "7", value_hint = clap::ValueHint::Other))]
pub param_scale: Scale,
/// if the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this a more advanced parameter, use with caution)
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))]
/// Scale rebase threshold multiplier
/// When scale exceeds input_scale * multiplier, it is rebased to input_scale
/// Advanced parameter that should be used with caution
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))]
pub scale_rebase_multiplier: u32,
/// The min and max elements in the lookup table input column
/// Range for lookup table input column values
/// Specified as (min, max) pair
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'B', long, value_parser = parse_key_val::<IntegerRep, IntegerRep>, default_value = "-32768->32768"))]
pub lookup_range: Range,
/// The log_2 number of rows
/// Log2 of the number of rows in the circuit
/// Controls circuit size and proving time
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other))]
pub logrows: u32,
/// The log_2 number of rows
/// Number of inner columns per block
/// Affects circuit layout and efficiency
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'N', long, default_value = "2", value_hint = clap::ValueHint::Other))]
pub num_inner_cols: usize,
/// Hand-written parser for graph variables, eg. batch_size=1
/// Graph variables for parameterizing the computation
/// Format: "name->value", e.g. "batch_size->1"
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other))]
pub variables: Vec<(String, usize)>,
/// Flags whether inputs are public, private, fixed, hashed, polycommit
/// Visibility setting for input values
/// Controls whether inputs are public or private in the circuit
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
pub input_visibility: Visibility,
/// Flags whether outputs are public, private, fixed, hashed, polycommit
/// Visibility setting for output values
/// Controls whether outputs are public or private in the circuit
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "public", value_hint = clap::ValueHint::Other))]
pub output_visibility: Visibility,
/// Flags whether params are fixed, private, hashed, polycommit
/// Visibility setting for parameters
/// Controls how parameters are handled in the circuit
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
pub param_visibility: Visibility,
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
/// Should constants with 0.0 fraction be rebased to scale 0
/// Whether to rebase constants with zero fractional part to scale 0
/// Can improve efficiency for integer constants
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
pub rebase_frac_zero_constants: bool,
/// check mode (safe, unsafe, etc)
/// Circuit checking mode
/// Controls level of constraint verification
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "unsafe", value_hint = clap::ValueHint::Other))]
pub check_mode: CheckMode,
/// commitment scheme
/// Commitment scheme for circuit proving
/// Affects proof size and verification time
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other))]
pub commitment: Option<Commitments>,
/// the base used for decompositions
/// Base for number decomposition
/// Must be a power of 2
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "16384", value_hint = clap::ValueHint::Other))]
pub decomp_base: usize,
/// Number of decomposition legs
/// Controls decomposition granularity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "2", value_hint = clap::ValueHint::Other))]
/// the number of legs used for decompositions
pub decomp_legs: usize,
/// Whether to use bounded lookup for logarithm computation
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
/// use unbounded lookup for the log
pub bounded_log_lookup: bool,
/// Range check inputs and outputs (turn off if the inputs are felts)
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
pub ignore_range_check_inputs_outputs: bool,
}
impl Default for RunArgs {
/// Creates a new RunArgs instance with default values
///
/// Default configuration is optimized for common use cases
/// while maintaining reasonable proving time and circuit size
fn default() -> Self {
Self {
bounded_log_lookup: false,
@@ -355,54 +381,138 @@ impl Default for RunArgs {
commitment: None,
decomp_base: 16384,
decomp_legs: 2,
ignore_range_check_inputs_outputs: false,
}
}
}
impl RunArgs {
/// Validates the RunArgs configuration
///
/// Performs comprehensive validation of all parameters to ensure they are within
/// acceptable ranges and follow required constraints. Returns accumulated errors
/// if any validations fail.
///
/// # Returns
/// - Ok(()) if all validations pass
/// - Err(String) with detailed error message if any validation fails
pub fn validate(&self) -> Result<(), String> {
let mut errors = Vec::new();
// Visibility validations
if self.param_visibility == Visibility::Public {
return Err(
"params cannot be public instances, you are probably trying to use `fixed` or `kzgcommit`"
.into(),
errors.push(
"Parameters cannot be public instances. Use 'fixed' or 'kzgcommit' instead"
.to_string(),
);
}
if self.scale_rebase_multiplier < 1 {
return Err("scale_rebase_multiplier must be >= 1".into());
}
if self.lookup_range.0 > self.lookup_range.1 {
return Err("lookup_range min is greater than max".into());
}
if self.logrows < 1 {
return Err("logrows must be >= 1".into());
}
if self.num_inner_cols < 1 {
return Err("num_inner_cols must be >= 1".into());
}
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
return Err("tolerance > 0.0 requires output_visibility to be public".into());
errors.push("Non-zero tolerance requires output_visibility to be public".to_string());
}
// Scale validations
if self.scale_rebase_multiplier < 1 {
errors.push("scale_rebase_multiplier must be >= 1".to_string());
}
// if any of the scales are too small
if self.input_scale < 8 || self.param_scale < 8 {
warn!("low scale values (<8) may impact precision");
}
// Lookup range validations
if self.lookup_range.0 > self.lookup_range.1 {
errors.push(format!(
"Invalid lookup range: min ({}) is greater than max ({})",
self.lookup_range.0, self.lookup_range.1
));
}
// Size validations
if self.logrows < 1 {
errors.push("logrows must be >= 1".to_string());
}
if self.num_inner_cols < 1 {
errors.push("num_inner_cols must be >= 1".to_string());
}
let batch_size = self.variables.iter().find(|(name, _)| name == "batch_size");
if let Some(batch_size) = batch_size {
if batch_size.1 == 0 {
errors.push("'batch_size' cannot be 0".to_string());
}
}
// Decomposition validations
if self.decomp_base == 0 {
errors.push("decomp_base cannot be 0".to_string());
}
if self.decomp_legs == 0 {
errors.push("decomp_legs cannot be 0".to_string());
}
// Performance validations
if self.logrows > MAX_PUBLIC_SRS {
warn!("logrows exceeds maximum public SRS size");
}
// Validate tolerance is non-negative
if self.tolerance.val < 0.0 {
errors.push("tolerance cannot be negative".to_string());
}
// Performance warnings
if self.input_scale > 20 || self.param_scale > 20 {
warn!("High scale values (>20) may impact performance");
}
if errors.is_empty() {
Ok(())
} else {
Err(errors.join("\n"))
}
Ok(())
}
/// Export the ezkl configuration as json
/// Exports the configuration as JSON
///
/// Serializes the RunArgs instance to a JSON string
///
/// # Returns
/// * `Ok(String)` containing JSON representation
/// * `Err` if serialization fails
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
return Err(Box::new(e));
}
};
Ok(serialized)
let res = serde_json::to_string(&self)?;
Ok(res)
}
/// Parse an ezkl configuration from a json
/// Parses configuration from JSON
///
/// Deserializes a RunArgs instance from a JSON string
///
/// # Arguments
/// * `arg_json` - JSON string containing configuration
///
/// # Returns
/// * `Ok(RunArgs)` if parsing succeeds
/// * `Err` if parsing fails
pub fn from_json(arg_json: &str) -> Result<Self, serde_json::Error> {
serde_json::from_str(arg_json)
}
}
/// Parse a single key-value pair
// Additional helper functions for the module
/// Parses a key-value pair from a string in the format "key->value"
///
/// # Arguments
/// * `s` - Input string in the format "key->value"
///
/// # Returns
/// * `Ok((T, U))` - Parsed key and value
/// * `Err` - If parsing fails
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn parse_key_val<T, U>(
s: &str,
@@ -415,14 +525,15 @@ where
{
let pos = s
.find("->")
.ok_or_else(|| format!("invalid x->y: no `->` found in `{s}`"))?;
let a = s[..pos].parse()?;
let b = s[pos + 2..].parse()?;
Ok((a, b))
.ok_or_else(|| format!("invalid KEY->VALUE: no `->` found in `{s}`"))?;
Ok((s[..pos].parse()?, s[pos + 2..].parse()?))
}
/// Check if the version string matches the artifact version
/// If the version string does not match the artifact version, log a warning
/// Verifies that a version string matches the expected artifact version
/// Logs warnings for version mismatches or unversioned artifacts
///
/// # Arguments
/// * `artifact_version` - Version string from the artifact
pub fn check_version_string_matches(artifact_version: &str) {
if artifact_version == "0.0.0"
|| artifact_version == "source - no compatibility guaranteed"
@@ -447,3 +558,98 @@ pub fn check_version_string_matches(artifact_version: &str) {
);
}
}
#[cfg(test)]
#[allow(clippy::field_reassign_with_default)]
mod tests {
use super::*;
#[test]
fn test_valid_default_args() {
let args = RunArgs::default();
assert!(args.validate().is_ok());
}
#[test]
fn test_invalid_param_visibility() {
let mut args = RunArgs::default();
args.param_visibility = Visibility::Public;
let err = args.validate().unwrap_err();
assert!(err.contains("Parameters cannot be public instances"));
}
#[test]
fn test_invalid_scale_rebase() {
let mut args = RunArgs::default();
args.scale_rebase_multiplier = 0;
let err = args.validate().unwrap_err();
assert!(err.contains("scale_rebase_multiplier must be >= 1"));
}
#[test]
fn test_invalid_lookup_range() {
let mut args = RunArgs::default();
args.lookup_range = (100, -100);
let err = args.validate().unwrap_err();
assert!(err.contains("Invalid lookup range"));
}
#[test]
fn test_invalid_logrows() {
let mut args = RunArgs::default();
args.logrows = 0;
let err = args.validate().unwrap_err();
assert!(err.contains("logrows must be >= 1"));
}
#[test]
fn test_invalid_inner_cols() {
let mut args = RunArgs::default();
args.num_inner_cols = 0;
let err = args.validate().unwrap_err();
assert!(err.contains("num_inner_cols must be >= 1"));
}
#[test]
fn test_invalid_tolerance() {
let mut args = RunArgs::default();
args.tolerance.val = 1.0;
args.output_visibility = Visibility::Private;
let err = args.validate().unwrap_err();
assert!(err.contains("Non-zero tolerance requires output_visibility to be public"));
}
#[test]
fn test_negative_tolerance() {
let mut args = RunArgs::default();
args.tolerance.val = -1.0;
let err = args.validate().unwrap_err();
assert!(err.contains("tolerance cannot be negative"));
}
#[test]
fn test_zero_batch_size() {
let mut args = RunArgs::default();
args.variables = vec![("batch_size".to_string(), 0)];
let err = args.validate().unwrap_err();
assert!(err.contains("'batch_size' cannot be 0"));
}
#[test]
fn test_json_serialization() {
let args = RunArgs::default();
let json = args.as_json().unwrap();
let deserialized = RunArgs::from_json(&json).unwrap();
assert_eq!(args, deserialized);
}
#[test]
fn test_multiple_validation_errors() {
let mut args = RunArgs::default();
args.logrows = 0;
args.lookup_range = (100, -100);
let err = args.validate().unwrap_err();
// Should contain multiple error messages
assert!(err.matches("\n").count() >= 1);
}
}

View File

@@ -133,7 +133,6 @@ pub fn aggregate<'a>(
.collect_vec()
}));
// loader.ctx().constrain_equal(cell_0, cell_1)
let mut transcript = PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, snark.proof());
let proof = PlonkSuccinctVerifier::read_proof(svk, &protocol, &instances, &mut transcript)
.map_err(|_| plonk::Error::Synthesis)?;
@@ -309,11 +308,11 @@ impl AggregationCircuit {
})
}
///
/// Number of limbs used for decomposition
pub fn num_limbs() -> usize {
LIMBS
}
///
/// Number of bits used for decomposition
pub fn num_bits() -> usize {
BITS
}

View File

@@ -353,6 +353,7 @@ where
C::ScalarExt: Serialize + DeserializeOwned,
{
/// Create a new application snark from proof and instance variables ready for aggregation
#[allow(clippy::too_many_arguments)]
pub fn new(
protocol: Option<PlonkProtocol<C>>,
instances: Vec<Vec<F>>,
@@ -528,7 +529,6 @@ pub fn create_keys<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
disable_selector_compression: bool,
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
C: Circuit<Scheme::Scalar>,
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
{
// Real proof
@@ -794,7 +794,6 @@ pub fn load_vk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
params: <C as Circuit<Scheme::Scalar>>::Params,
) -> Result<VerifyingKey<Scheme::Curve>, PfsysError>
where
C: Circuit<Scheme::Scalar>,
Scheme::Curve: SerdeObject + CurveAffine,
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
{
@@ -817,7 +816,6 @@ pub fn load_pk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
params: <C as Circuit<Scheme::Scalar>>::Params,
) -> Result<ProvingKey<Scheme::Curve>, PfsysError>
where
C: Circuit<Scheme::Scalar>,
Scheme::Curve: SerdeObject + CurveAffine,
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
{

View File

@@ -38,4 +38,10 @@ pub enum TensorError {
/// Decomposition error
#[error("decomposition error: {0}")]
DecompositionError(#[from] DecompositionError),
/// Invalid argument
#[error("invalid argument: {0}")]
InvalidArgument(String),
/// Index out of bounds
#[error("index {0} out of bounds for dimension {1}")]
IndexOutOfBounds(usize, usize),
}

View File

@@ -24,9 +24,6 @@ use std::path::PathBuf;
pub use val::*;
pub use var::*;
#[cfg(feature = "metal")]
use instant::Instant;
use crate::{
circuit::utils,
fieldutils::{integer_rep_to_felt, IntegerRep},
@@ -40,8 +37,6 @@ use halo2_proofs::{
poly::Rotation,
};
use itertools::Itertools;
#[cfg(feature = "metal")]
use metal::{Device, MTLResourceOptions, MTLSize};
use std::error::Error;
use std::fmt::Debug;
use std::io::Read;
@@ -49,31 +44,6 @@ use std::iter::Iterator;
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
use std::{cmp::max, ops::Rem};
#[cfg(feature = "metal")]
use std::collections::HashMap;
#[cfg(feature = "metal")]
const LIB_DATA: &[u8] = include_bytes!("metal/tensor_ops.metallib");
#[cfg(feature = "metal")]
lazy_static::lazy_static! {
static ref DEVICE: Device = Device::system_default().expect("no device found");
static ref LIB: metal::Library = DEVICE.new_library_with_data(LIB_DATA).unwrap();
static ref QUEUE: metal::CommandQueue = DEVICE.new_command_queue();
static ref PIPELINES: HashMap<String, metal::ComputePipelineState> = {
let mut map = HashMap::new();
for name in ["add", "sub", "mul"] {
let function = LIB.get_function(name, None).unwrap();
let pipeline = DEVICE.new_compute_pipeline_state_with_function(&function).unwrap();
map.insert(name.to_string(), pipeline);
}
map
};
}
/// The (inner) type of tensor elements.
pub trait TensorType: Clone + Debug + 'static {
/// Returns the zero value.
@@ -833,6 +803,12 @@ impl<T: Clone + TensorType> Tensor<T> {
num_repeats: usize,
initial_offset: usize,
) -> Result<Tensor<T>, TensorError> {
if n == 0 {
return Err(TensorError::InvalidArgument(
"Cannot duplicate every 0th element".to_string(),
));
}
let mut inner: Vec<T> = Vec::with_capacity(self.inner.len());
let mut offset = initial_offset;
for (i, elem) in self.inner.clone().into_iter().enumerate() {
@@ -862,11 +838,17 @@ impl<T: Clone + TensorType> Tensor<T> {
num_repeats: usize,
initial_offset: usize,
) -> Result<Tensor<T>, TensorError> {
if n == 0 {
return Err(TensorError::InvalidArgument(
"Cannot remove every 0th element".to_string(),
));
}
// Pre-calculate capacity to avoid reallocations
let estimated_size = self.inner.len() - (self.inner.len() / n) * num_repeats;
let mut inner = Vec::with_capacity(estimated_size);
// Use iterator directly instead of creating intermediate collections
// Use iterator directly instead of creating intermediate collectionsif
let mut i = 0;
while i < self.inner.len() {
// Add the current element
@@ -885,7 +867,6 @@ impl<T: Clone + TensorType> Tensor<T> {
}
/// Remove indices
/// WARN: assumes indices are in ascending order for speed
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
@@ -912,7 +893,11 @@ impl<T: Clone + TensorType> Tensor<T> {
}
// remove indices
for elem in indices.iter().rev() {
inner.remove(*elem);
if *elem < self.len() {
inner.remove(*elem);
} else {
return Err(TensorError::IndexOutOfBounds(*elem, self.len()));
}
}
Tensor::new(Some(&inner), &[inner.len()])
@@ -1404,10 +1389,6 @@ impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync> Ad
let lhs = self.expand(&broadcasted_shape).unwrap();
let rhs = rhs.expand(&broadcasted_shape).unwrap();
#[cfg(feature = "metal")]
let res = metal_tensor_op(&lhs, &rhs, "add");
#[cfg(not(feature = "metal"))]
let res = {
let mut res: Tensor<T> = lhs
.par_iter()
@@ -1505,10 +1486,6 @@ impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Su
let lhs = self.expand(&broadcasted_shape).unwrap();
let rhs = rhs.expand(&broadcasted_shape).unwrap();
#[cfg(feature = "metal")]
let res = metal_tensor_op(&lhs, &rhs, "sub");
#[cfg(not(feature = "metal"))]
let res = {
let mut res: Tensor<T> = lhs
.par_iter()
@@ -1576,10 +1553,6 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mu
let lhs = self.expand(&broadcasted_shape).unwrap();
let rhs = rhs.expand(&broadcasted_shape).unwrap();
#[cfg(feature = "metal")]
let res = metal_tensor_op(&lhs, &rhs, "mul");
#[cfg(not(feature = "metal"))]
let res = {
let mut res: Tensor<T> = lhs
.par_iter()
@@ -1685,7 +1658,9 @@ impl<T: TensorType + Div<Output = T> + std::marker::Send + std::marker::Sync> Di
}
// implement remainder
impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Rem for Tensor<T> {
impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + PartialEq> Rem
for Tensor<T>
{
type Output = Result<Tensor<T>, TensorError>;
/// Elementwise remainder of a tensor with another tensor.
@@ -1714,9 +1689,25 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Re
let mut lhs = self.expand(&broadcasted_shape).unwrap();
let rhs = rhs.expand(&broadcasted_shape).unwrap();
lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
*o = o.clone() % r;
});
lhs.par_iter_mut()
.zip(rhs)
.map(|(o, r)| {
if let Some(zero) = T::zero() {
if r != zero {
*o = o.clone() % r;
Ok(())
} else {
Err(TensorError::InvalidArgument(
"Cannot divide by zero in remainder".to_string(),
))
}
} else {
Err(TensorError::InvalidArgument(
"Undefined zero value".to_string(),
))
}
})
.collect::<Result<Vec<_>, _>>()?;
Ok(lhs)
}
@@ -1751,7 +1742,6 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Re
/// assert_eq!(c, vec![2, 3]);
///
/// ```
pub fn get_broadcasted_shape(
shape_a: &[usize],
shape_b: &[usize],
@@ -1759,20 +1749,21 @@ pub fn get_broadcasted_shape(
let num_dims_a = shape_a.len();
let num_dims_b = shape_b.len();
match (num_dims_a, num_dims_b) {
(a, b) if a == b => {
let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
let max_dim = dim_a.max(dim_b);
broadcasted_shape.push(*max_dim);
}
Ok(broadcasted_shape)
if num_dims_a == num_dims_b {
let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
let max_dim = dim_a.max(dim_b);
broadcasted_shape.push(*max_dim);
}
(a, b) if a < b => Ok(shape_b.to_vec()),
(a, b) if a > b => Ok(shape_a.to_vec()),
_ => Err(TensorError::DimError(
Ok(broadcasted_shape)
} else if num_dims_a < num_dims_b {
Ok(shape_b.to_vec())
} else if num_dims_a > num_dims_b {
Ok(shape_a.to_vec())
} else {
Err(TensorError::DimError(
"Unknown condition for broadcasting".to_string(),
)),
))
}
}
////////////////////////
@@ -1811,66 +1802,4 @@ mod tests {
let b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
}
#[test]
#[cfg(feature = "metal")]
fn tensor_metal_int() {
let a = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
let b = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
let c = metal_tensor_op(&a, &b, "add");
assert_eq!(c, Tensor::new(Some(&[2, 4, 6, 8]), &[2, 2]).unwrap());
let c = metal_tensor_op(&a, &b, "sub");
assert_eq!(c, Tensor::new(Some(&[0, 0, 0, 0]), &[2, 2]).unwrap());
let c = metal_tensor_op(&a, &b, "mul");
assert_eq!(c, Tensor::new(Some(&[1, 4, 9, 16]), &[2, 2]).unwrap());
}
#[test]
#[cfg(feature = "metal")]
fn tensor_metal_felt() {
use halo2curves::bn256::Fr;
let a = Tensor::<Fr>::new(
Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
&[2, 2],
)
.unwrap();
let b = Tensor::<Fr>::new(
Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
&[2, 2],
)
.unwrap();
let c = metal_tensor_op(&a, &b, "add");
assert_eq!(
c,
Tensor::<Fr>::new(
Some(&[Fr::from(2), Fr::from(4), Fr::from(6), Fr::from(8)]),
&[2, 2],
)
.unwrap()
);
let c = metal_tensor_op(&a, &b, "sub");
assert_eq!(
c,
Tensor::<Fr>::new(
Some(&[Fr::from(0), Fr::from(0), Fr::from(0), Fr::from(0)]),
&[2, 2],
)
.unwrap()
);
let c = metal_tensor_op(&a, &b, "mul");
assert_eq!(
c,
Tensor::<Fr>::new(
Some(&[Fr::from(1), Fr::from(4), Fr::from(9), Fr::from(16)]),
&[2, 2],
)
.unwrap()
);
}
}

View File

@@ -27,7 +27,7 @@ pub fn get_rep(
n: usize,
) -> Result<Vec<IntegerRep>, DecompositionError> {
// check if x is too large
if x.abs() > (base.pow(n as u32) as IntegerRep) - 1 {
if (*x).abs() > ((base as i128).pow(n as u32)) - 1 {
return Err(DecompositionError::TooLarge(*x, base, n));
}
let mut rep = vec![0; n + 1];
@@ -43,8 +43,8 @@ pub fn get_rep(
let mut x = x.abs();
//
for i in (1..rep.len()).rev() {
rep[i] = x % base as i128;
x /= base as i128;
rep[i] = x % base as IntegerRep;
x /= base as IntegerRep;
}
Ok(rep)
@@ -127,7 +127,7 @@ pub fn decompose(
.flatten()
.collect::<Vec<IntegerRep>>();
let output = Tensor::<i128>::new(Some(&resp), &dims)?;
let output = Tensor::<IntegerRep>::new(Some(&resp), &dims)?;
Ok(output)
}
@@ -385,6 +385,12 @@ pub fn resize<T: TensorType + Send + Sync>(
pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync>(
t: &[Tensor<T>],
) -> Result<Tensor<T>, TensorError> {
if t.len() == 1 {
return Ok(t[0].clone());
} else if t.len() == 0 {
return Err(TensorError::DimMismatch("add".to_string()));
}
// calculate value of output
let mut output: Tensor<T> = t[0].clone();
@@ -433,6 +439,11 @@ pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sy
pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync>(
t: &[Tensor<T>],
) -> Result<Tensor<T>, TensorError> {
if t.len() == 1 {
return Ok(t[0].clone());
} else if t.len() == 0 {
return Err(TensorError::DimMismatch("sub".to_string()));
}
// calculate value of output
let mut output: Tensor<T> = t[0].clone();
@@ -479,6 +490,11 @@ pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sy
pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync>(
t: &[Tensor<T>],
) -> Result<Tensor<T>, TensorError> {
if t.len() == 1 {
return Ok(t[0].clone());
} else if t.len() == 0 {
return Err(TensorError::DimMismatch("mult".to_string()));
}
// calculate value of output
let mut output: Tensor<T> = t[0].clone();

File diff suppressed because it is too large Load Diff

View File

@@ -5,33 +5,35 @@ use log::{debug, error, warn};
use crate::circuit::{region::ConstantsMap, CheckMode};
use super::*;
/// A wrapper around Halo2's `Column<Fixed>` or `Column<Advice>`.
/// Typically assign [ValTensor]s to [VarTensor]s when laying out a circuit.
/// A wrapper around Halo2's Column types that represents a tensor of variables in the circuit.
/// VarTensors are used to store and manage circuit columns, typically for assigning ValTensor
/// values during circuit layout. The tensor organizes storage into blocks of columns, where each
/// block contains multiple columns and each column contains multiple rows.
#[derive(Clone, Default, Debug, PartialEq, Eq)]
pub enum VarTensor {
/// A VarTensor for holding Advice values, which are assigned at proving time.
Advice {
/// Vec of Advice columns, we have [[xx][xx][xx]...] where each inner vec is xx columns
inner: Vec<Vec<Column<Advice>>>,
///
/// The number of columns in each inner block
num_inner_cols: usize,
/// Number of rows available to be used in each column of the storage
col_size: usize,
},
/// Dummy var
/// A placeholder tensor used for testing or temporary storage
Dummy {
///
/// The number of columns in each inner block
num_inner_cols: usize,
/// Number of rows available to be used in each column of the storage
col_size: usize,
},
/// Empty var
/// An empty tensor with no storage
#[default]
Empty,
}
impl VarTensor {
/// name of the tensor
/// Returns the name of the tensor variant as a static string
pub fn name(&self) -> &'static str {
match self {
VarTensor::Advice { .. } => "Advice",
@@ -40,22 +42,35 @@ impl VarTensor {
}
}
///
/// Returns true if the tensor is an Advice variant
pub fn is_advice(&self) -> bool {
matches!(self, VarTensor::Advice { .. })
}
/// Calculates the maximum number of usable rows in the constraint system
///
/// # Arguments
/// * `cs` - The constraint system
/// * `logrows` - Log base 2 of the total number of rows (including system and blinding rows)
///
/// # Returns
/// The maximum number of usable rows after accounting for blinding factors
pub fn max_rows<F: PrimeField>(cs: &ConstraintSystem<F>, logrows: usize) -> usize {
let base = 2u32;
base.pow(logrows as u32) as usize - cs.blinding_factors() - 1
}
/// Create a new VarTensor::Advice that is unblinded
/// Arguments
/// * `cs` - The constraint system
/// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
/// * `capacity` - The number of advice cells to allocate
/// Creates a new VarTensor::Advice with unblinded columns. Unblinded columns are used when
/// the values do not need to be hidden in the proof.
///
/// # Arguments
/// * `cs` - The constraint system to create columns in
/// * `logrows` - Log base 2 of the total number of rows
/// * `num_inner_cols` - Number of columns in each inner block
/// * `capacity` - Total number of advice cells to allocate
///
/// # Returns
/// A new VarTensor::Advice with unblinded columns enabled for equality constraints
pub fn new_unblinded_advice<F: PrimeField>(
cs: &mut ConstraintSystem<F>,
logrows: usize,
@@ -93,11 +108,17 @@ impl VarTensor {
}
}
/// Create a new VarTensor::Advice
/// Arguments
/// * `cs` - The constraint system
/// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
/// * `capacity` - The number of advice cells to allocate
/// Creates a new VarTensor::Advice with standard (blinded) columns, used when
/// the values need to be hidden in the proof.
///
/// # Arguments
/// * `cs` - The constraint system to create columns in
/// * `logrows` - Log base 2 of the total number of rows
/// * `num_inner_cols` - Number of columns in each inner block
/// * `capacity` - Total number of advice cells to allocate
///
/// # Returns
/// A new VarTensor::Advice with blinded columns enabled for equality constraints
pub fn new_advice<F: PrimeField>(
cs: &mut ConstraintSystem<F>,
logrows: usize,
@@ -133,11 +154,17 @@ impl VarTensor {
}
}
/// Initializes fixed columns to support the VarTensor::Advice
/// Arguments
/// * `cs` - The constraint system
/// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
/// * `capacity` - The number of advice cells to allocate
/// Initializes fixed columns in the constraint system to support the VarTensor::Advice
/// Fixed columns are used for constant values that are known at circuit creation time.
///
/// # Arguments
/// * `cs` - The constraint system to create columns in
/// * `logrows` - Log base 2 of the total number of rows
/// * `num_constants` - Number of constant values needed
/// * `module_requires_fixed` - Whether the module requires at least one fixed column
///
/// # Returns
/// The number of fixed columns created
pub fn constant_cols<F: PrimeField>(
cs: &mut ConstraintSystem<F>,
logrows: usize,
@@ -169,7 +196,14 @@ impl VarTensor {
modulo
}
/// Create a new VarTensor::Dummy
/// Creates a new dummy VarTensor for testing or temporary storage
///
/// # Arguments
/// * `logrows` - Log base 2 of the total number of rows
/// * `num_inner_cols` - Number of columns in each inner block
///
/// # Returns
/// A new VarTensor::Dummy with the specified dimensions
pub fn dummy(logrows: usize, num_inner_cols: usize) -> Self {
let base = 2u32;
let max_rows = base.pow(logrows as u32) as usize - 6;
@@ -179,7 +213,7 @@ impl VarTensor {
}
}
/// Gets the dims of the object the VarTensor represents
/// Returns the number of blocks in the tensor
pub fn num_blocks(&self) -> usize {
match self {
VarTensor::Advice { inner, .. } => inner.len(),
@@ -187,7 +221,7 @@ impl VarTensor {
}
}
/// Num inner cols
/// Returns the number of columns in each inner block
pub fn num_inner_cols(&self) -> usize {
match self {
VarTensor::Advice { num_inner_cols, .. } | VarTensor::Dummy { num_inner_cols, .. } => {
@@ -197,7 +231,7 @@ impl VarTensor {
}
}
/// Total number of columns
/// Returns the total number of columns across all blocks
pub fn num_cols(&self) -> usize {
match self {
VarTensor::Advice { inner, .. } => inner[0].len() * inner.len(),
@@ -205,7 +239,7 @@ impl VarTensor {
}
}
/// Gets the size of each column
/// Returns the maximum number of rows in each column
pub fn col_size(&self) -> usize {
match self {
VarTensor::Advice { col_size, .. } | VarTensor::Dummy { col_size, .. } => *col_size,
@@ -213,7 +247,7 @@ impl VarTensor {
}
}
/// Gets the size of each column
/// Returns the total size of each block (num_inner_cols * col_size)
pub fn block_size(&self) -> usize {
match self {
VarTensor::Advice {
@@ -230,7 +264,13 @@ impl VarTensor {
}
}
/// Take a linear coordinate and output the (column, row) position in the storage block.
/// Converts a linear coordinate to (block, column, row) coordinates in the storage
///
/// # Arguments
/// * `linear_coord` - The linear index to convert
///
/// # Returns
/// A tuple of (block_index, column_index, row_index)
pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize, usize) {
// x indexes over blocks of size num_inner_cols
let x = linear_coord / self.block_size();
@@ -243,7 +283,17 @@ impl VarTensor {
}
impl VarTensor {
/// Retrieve the value of a specific cell in the tensor.
/// Queries a range of cells in the tensor during circuit synthesis
///
/// # Arguments
/// * `meta` - Virtual cells accessor
/// * `x` - Block index
/// * `y` - Column index within block
/// * `z` - Starting row offset
/// * `rng` - Number of consecutive rows to query
///
/// # Returns
/// A tensor of expressions representing the queried cells
pub fn query_rng<F: PrimeField>(
&self,
meta: &mut VirtualCells<'_, F>,
@@ -268,7 +318,16 @@ impl VarTensor {
}
}
/// Retrieve the value of a specific block at an offset in the tensor.
/// Queries an entire block of cells at a given offset
///
/// # Arguments
/// * `meta` - Virtual cells accessor
/// * `x` - Block index
/// * `z` - Row offset
/// * `rng` - Number of consecutive rows to query
///
/// # Returns
/// A tensor of expressions representing the queried block
pub fn query_whole_block<F: PrimeField>(
&self,
meta: &mut VirtualCells<'_, F>,
@@ -293,7 +352,16 @@ impl VarTensor {
}
}
/// Assigns a constant value to a specific cell in the tensor.
/// Assigns a constant value to a specific cell in the tensor
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `offset` - Base offset for the assignment
/// * `coord` - Coordinate within the tensor
/// * `constant` - The constant value to assign
///
/// # Returns
/// The assigned cell or an error if assignment fails
pub fn assign_constant<F: PrimeField + TensorType + PartialOrd>(
&self,
region: &mut Region<F>,
@@ -313,7 +381,17 @@ impl VarTensor {
}
}
/// Assigns [ValTensor] to the columns of the inner tensor.
/// Assigns values from a ValTensor to this tensor, excluding specified positions
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `offset` - Base offset for assignments
/// * `values` - The ValTensor containing values to assign
/// * `omissions` - Set of positions to skip during assignment
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// The assigned ValTensor or an error if assignment fails
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
@@ -344,7 +422,16 @@ impl VarTensor {
Ok(res)
}
/// Assigns [ValTensor] to the columns of the inner tensor.
/// Assigns values from a ValTensor to this tensor
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `offset` - Base offset for assignments
/// * `values` - The ValTensor containing values to assign
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// The assigned ValTensor or an error if assignment fails
pub fn assign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
@@ -396,14 +483,23 @@ impl VarTensor {
Ok(res)
}
/// Helper function to get the remaining size of the column
/// Returns the remaining available space in a column for assignments
///
/// # Arguments
/// * `offset` - Current offset in the column
/// * `values` - The ValTensor to check space for
///
/// # Returns
/// The number of rows that need to be flushed or an error if space is insufficient
pub fn get_column_flush<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
offset: usize,
values: &ValTensor<F>,
) -> Result<usize, halo2_proofs::plonk::Error> {
if values.len() > self.col_size() {
error!("Values are too large for the column");
error!(
"There are too many values to flush for this column size, try setting the logrows to a higher value (eg. --logrows 22 on the cli)"
);
return Err(halo2_proofs::plonk::Error::Synthesis);
}
@@ -427,8 +523,16 @@ impl VarTensor {
Ok(flush_len)
}
/// Assigns [ValTensor] to the columns of the inner tensor. Whereby the values are assigned to a single column, without overflowing.
/// So for instance if we are assigning 10 values and we are at index 18 of the column, and the columns are of length 20, we skip the last 2 values of current column and start from the beginning of the next column.
/// Assigns values to a single column, avoiding column overflow by flushing to the next column if needed
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `offset` - Base offset for assignments
/// * `values` - The ValTensor containing values to assign
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// A tuple of (assigned ValTensor, number of rows flushed) or an error if assignment fails
pub fn assign_exact_column<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
@@ -443,8 +547,17 @@ impl VarTensor {
Ok((assigned_vals, flush_len))
}
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
/// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
/// Assigns values with duplication in dummy mode, used for testing and simulation
///
/// # Arguments
/// * `row` - Starting row for assignment
/// * `offset` - Base offset for assignments
/// * `values` - The ValTensor containing values to assign
/// * `single_inner_col` - Whether to treat as a single column
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
pub fn dummy_assign_with_duplication<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
@@ -494,7 +607,16 @@ impl VarTensor {
}
}
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
/// Assigns values with duplication but without enforcing constraints between duplicated values
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `offset` - Base offset for assignments
/// * `values` - The ValTensor containing values to assign
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
pub fn assign_with_duplication_unconstrained<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
@@ -533,8 +655,18 @@ impl VarTensor {
}
}
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
/// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
/// Assigns values with duplication and enforces equality constraints between duplicated values
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `row` - Starting row for assignment
/// * `offset` - Base offset for assignments
/// * `values` - The ValTensor containing values to assign
/// * `check_mode` - Mode for checking equality constraints
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
pub fn assign_with_duplication_constrained<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
@@ -558,7 +690,7 @@ impl VarTensor {
// duplicates every nth element to adjust for column overflow
let v = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
let mut res: ValTensor<F> = {
let mut res: ValTensor<F> =
v.enum_map(|coord, k| {
let step = self.num_inner_cols();
@@ -579,12 +711,18 @@ impl VarTensor {
prev_cell = Some(cell.clone());
} else if coord > 0 && at_beginning_of_column {
if let Some(prev_cell) = prev_cell.as_ref() {
let cell = cell.cell().ok_or({
let cell = if let Some(cell) = cell.cell() {
cell
} else {
error!("Error getting cell: {:?}", (x,y));
halo2_proofs::plonk::Error::Synthesis})?;
let prev_cell = prev_cell.cell().ok_or({
error!("Error getting cell: {:?}", (x,y));
halo2_proofs::plonk::Error::Synthesis})?;
return Err(halo2_proofs::plonk::Error::Synthesis);
};
let prev_cell = if let Some(prev_cell) = prev_cell.cell() {
prev_cell
} else {
error!("Error getting prev cell: {:?}", (x,y));
return Err(halo2_proofs::plonk::Error::Synthesis);
};
region.constrain_equal(prev_cell,cell)?;
} else {
error!("Previous cell was not set");
@@ -594,7 +732,8 @@ impl VarTensor {
Ok(cell)
})?.into()};
})?.into();
let total_used_len = res.len();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
@@ -606,6 +745,17 @@ impl VarTensor {
}
}
/// Assigns a single value to the tensor. This is a helper function used by other assignment methods.
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `offset` - Base offset for the assignment
/// * `k` - The value to assign
/// * `coord` - The coordinate where to assign the value
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// The assigned value or an error if assignment fails
fn assign_value<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
@@ -616,24 +766,28 @@ impl VarTensor {
) -> Result<ValType<F>, halo2_proofs::plonk::Error> {
let (x, y, z) = self.cartesian_coord(offset + coord);
let res = match k {
// Handle direct value assignment
ValType::Value(v) => match &self {
VarTensor::Advice { inner: advices, .. } => {
ValType::PrevAssigned(region.assign_advice(|| "k", advices[x][y], z, || v)?)
}
_ => unimplemented!(),
},
// Handle copying previously assigned value
ValType::PrevAssigned(v) => match &self {
VarTensor::Advice { inner: advices, .. } => {
ValType::PrevAssigned(v.copy_advice(|| "k", region, advices[x][y], z)?)
}
_ => unimplemented!(),
},
// Handle copying previously assigned constant
ValType::AssignedConstant(v, val) => match &self {
VarTensor::Advice { inner: advices, .. } => {
ValType::AssignedConstant(v.copy_advice(|| "k", region, advices[x][y], z)?, val)
}
_ => unimplemented!(),
},
// Handle assigning evaluated value
ValType::AssignedValue(v) => match &self {
VarTensor::Advice { inner: advices, .. } => ValType::PrevAssigned(
region
@@ -642,6 +796,7 @@ impl VarTensor {
),
_ => unimplemented!(),
},
// Handle constant value assignment with caching
ValType::Constant(v) => {
if let std::collections::hash_map::Entry::Vacant(e) = constants.entry(v) {
let value = ValType::AssignedConstant(

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@@ -28,11 +28,12 @@
"commitment": "KZG",
"decomp_base": 128,
"decomp_legs": 2,
"bounded_log_lookup": false
"bounded_log_lookup": false,
"ignore_range_check_inputs_outputs": false
},
"num_rows": 46,
"total_assignments": 92,
"total_const_size": 3,
"num_rows": 236,
"total_assignments": 472,
"total_const_size": 4,
"total_dynamic_col_size": 0,
"max_dynamic_input_len": 0,
"num_dynamic_lookups": 0,

Binary file not shown.

View File

@@ -1,7 +1,6 @@
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[cfg(test)]
mod native_tests {
use ezkl::circuit::Tolerance;
use ezkl::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
@@ -23,6 +22,8 @@ mod native_tests {
static COMPILE_WASM: Once = Once::new();
static ENV_SETUP: Once = Once::new();
const TEST_BINARY: &str = "test-runs/ezkl";
//Sure to run this once
#[derive(Debug)]
#[allow(dead_code)]
@@ -75,9 +76,8 @@ mod native_tests {
});
}
///
#[allow(dead_code)]
pub fn init_wasm() {
fn init_wasm() {
COMPILE_WASM.call_once(|| {
build_wasm_ezkl();
});
@@ -104,7 +104,7 @@ mod native_tests {
fn download_srs(logrows: u32, commitment: Commitments) {
// if does not exist, download it
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"get-srs",
"--logrows",
@@ -187,13 +187,14 @@ mod native_tests {
const PF_FAILURE_AGGR: &str = "examples/test_failure_aggr_proof.json";
const LARGE_TESTS: [&str; 6] = [
const LARGE_TESTS: [&str; 7] = [
"self_attention",
"nanoGPT",
"multihead_attention",
"mobilenet",
"mnist_gan",
"smallworm",
"fr_age",
];
const ACCURACY_CAL_TESTS: [&str; 6] = [
@@ -205,62 +206,62 @@ mod native_tests {
"1l_tiny_div",
];
const TESTS: [&str; 98] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
"1l_flatten",
const TESTS: [&str; 99] = [
"1l_mlp", //0
"1l_slice", //1
"1l_concat", //2
"1l_flatten", //3
// "1l_average",
"1l_div",
"1l_pad", // 5
"1l_reshape",
"1l_eltwise_div",
"1l_sigmoid",
"1l_sqrt",
"1l_softmax", //10
"1l_div", //4
"1l_pad", // 5
"1l_reshape", //6
"1l_eltwise_div", //7
"1l_sigmoid", //8
"1l_sqrt", //9
"1l_softmax", //10
// "1l_instance_norm",
"1l_batch_norm",
"1l_prelu",
"1l_leakyrelu",
"1l_gelu_noappx",
"1l_batch_norm", //11
"1l_prelu", //12
"1l_leakyrelu", //13
"1l_gelu_noappx", //14
// "1l_gelu_tanh_appx",
"1l_relu", //15
"1l_downsample",
"1l_tanh",
"2l_relu_sigmoid_small",
"2l_relu_fc",
"2l_relu_small", //20
"2l_relu_sigmoid",
"1l_conv",
"2l_sigmoid_small",
"2l_relu_sigmoid_conv",
"3l_relu_conv_fc", //25
"4l_relu_conv_fc",
"1l_erf",
"1l_var",
"1l_elu",
"min", //30
"max",
"1l_max_pool",
"1l_conv_transpose",
"1l_upsample",
"1l_identity", //35
"idolmodel", // too big evm
"trig", // too big evm
"prelu_gmm",
"lstm",
"rnn", //40
"quantize_dequantize",
"1l_where",
"boolean",
"boolean_identity",
"decision_tree", // 45
"random_forest",
"gradient_boosted_trees",
"1l_topk",
"xgboost",
"lightgbm", //50
"hummingbird_decision_tree",
"1l_relu", //15
"1l_downsample", //16
"1l_tanh", //17
"2l_relu_sigmoid_small", //18
"2l_relu_fc", //19
"2l_relu_small", //20
"2l_relu_sigmoid", //21
"1l_conv", //22
"2l_sigmoid_small", //23
"2l_relu_sigmoid_conv", //24
"3l_relu_conv_fc", //25
"4l_relu_conv_fc", //26
"1l_erf", //27
"1l_var", //28
"1l_elu", //29
"min", //30
"max", //31
"1l_max_pool", //32
"1l_conv_transpose", //33
"1l_upsample", //34
"1l_identity", //35
"idolmodel", // too big evm
"trig", // too big evm
"prelu_gmm", //38
"lstm", //39
"rnn", //40
"quantize_dequantize", //41
"1l_where", //42
"boolean", //43
"boolean_identity", //44
"decision_tree", // 45
"random_forest", //46
"gradient_boosted_trees", //47
"1l_topk", //48
"xgboost", //49
"lightgbm", //50
"hummingbird_decision_tree", //51
"oh_decision_tree",
"linear_svc",
"gather_elements",
@@ -308,6 +309,7 @@ mod native_tests {
"log", // 95
"exp", // 96
"general_exp", // 97
"integer_div", // 98
];
const WASM_TESTS: [&str; 46] = [
@@ -395,29 +397,29 @@ mod native_tests {
const TESTS_AGGR: [&str; 3] = ["1l_mlp", "1l_flatten", "1l_average"];
const TESTS_EVM: [&str; 23] = [
"1l_mlp",
"1l_flatten",
"1l_average",
"1l_reshape",
"1l_sigmoid",
"1l_div",
"1l_sqrt",
"1l_prelu",
"1l_var",
"1l_leakyrelu",
"1l_gelu_noappx",
"1l_relu",
"1l_tanh",
"2l_relu_sigmoid_small",
"2l_relu_small",
"min",
"max",
"1l_max_pool",
"idolmodel",
"1l_identity",
"lstm",
"rnn",
"quantize_dequantize",
"1l_mlp", // 0
"1l_flatten", // 1
"1l_average", // 2
"1l_reshape", // 3
"1l_sigmoid", // 4
"1l_div", // 5
"1l_sqrt", // 6
"1l_prelu", // 7
"1l_var", // 8
"1l_leakyrelu", // 9
"1l_gelu_noappx", // 10
"1l_relu", // 11
"1l_tanh", // 12
"2l_relu_sigmoid_small", // 13
"2l_relu_small", // 14
"min", // 15
"max", // 16
"1l_max_pool", // 17
"idolmodel", // 18
"1l_identity", // 19
"lstm", // 20
"rnn", // 21
"quantize_dequantize", // 22
];
const TESTS_EVM_AGGR: [&str; 18] = [
@@ -541,12 +543,12 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false);
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
});
seq!(N in 0..=97 {
seq!(N in 0..=98 {
#(#[test_case(TESTS[N])])*
#[ignore]
@@ -606,7 +608,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -616,7 +618,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true, Some(8194), Some(4));
test_dir.close().unwrap();
}
@@ -627,7 +629,7 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// gen random number between 0.0 and 1.0
let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false, Some(32776), Some(5));
test_dir.close().unwrap();
}
@@ -642,7 +644,7 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0, false);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
}
@@ -652,7 +654,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -661,7 +663,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -670,7 +672,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -679,7 +681,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -688,7 +690,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -697,7 +699,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -706,7 +708,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -716,7 +718,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -726,7 +728,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -735,7 +737,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -745,7 +747,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -754,7 +756,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -764,7 +766,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -774,7 +776,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -784,7 +786,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -794,7 +796,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// needs an extra row for the large model
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false);
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -804,7 +806,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// needs an extra row for the large model
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false);
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -963,7 +965,7 @@ mod native_tests {
});
seq!(N in 0..=5 {
seq!(N in 0..=6 {
#(#[test_case(LARGE_TESTS[N])])*
#[ignore]
@@ -981,7 +983,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false);
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false, None, Some(5));
test_dir.close().unwrap();
}
});
@@ -1459,6 +1461,8 @@ mod native_tests {
scales_to_use: Option<Vec<u32>>,
tolerance: f32,
bounded_lookup_log: bool,
decomp_base: Option<usize>,
decomp_legs: Option<usize>,
) {
let mut tolerance = tolerance;
gen_circuit_settings_and_witness(
@@ -1475,6 +1479,8 @@ mod native_tests {
Commitments::KZG,
2,
bounded_lookup_log,
decomp_base,
decomp_legs,
);
if tolerance > 0.0 {
@@ -1551,7 +1557,7 @@ mod native_tests {
.save(format!("{}/{}/witness_bad.json", test_dir, example_name).into())
.unwrap();
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"mock",
"-W",
@@ -1563,7 +1569,7 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"mock",
"-W",
@@ -1575,7 +1581,7 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"mock",
"-W",
@@ -1587,7 +1593,7 @@ mod native_tests {
.expect("failed to execute process");
assert!(!status.success());
} else {
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"mock",
"-W",
@@ -1616,6 +1622,8 @@ mod native_tests {
commitment: Commitments,
lookup_safety_margin: usize,
bounded_lookup_log: bool,
decomp_base: Option<usize>,
decomp_legs: Option<usize>,
) {
let mut args = vec![
"gen-settings".to_string(),
@@ -1634,11 +1642,24 @@ mod native_tests {
format!("--commitment={}", commitment),
];
// if output-visibility is fixed set --range-check-inputs-outputs to False
if output_visibility == "fixed" {
args.push("--ignore-range-check-inputs-outputs".to_string());
}
if let Some(decomp_base) = decomp_base {
args.push(format!("--decomp-base={}", decomp_base));
}
if let Some(decomp_legs) = decomp_legs {
args.push(format!("--decomp-legs={}", decomp_legs));
}
if bounded_lookup_log {
args.push("--bounded-log-lookup".to_string());
}
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(args)
.status()
.expect("failed to execute process");
@@ -1668,7 +1689,7 @@ mod native_tests {
calibrate_args.push(scales);
}
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(calibrate_args)
.status()
.expect("failed to execute process");
@@ -1692,7 +1713,7 @@ mod native_tests {
*tolerance = 0.0;
}
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"compile-circuit",
"-M",
@@ -1709,7 +1730,7 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"gen-witness",
"-D",
@@ -1751,6 +1772,8 @@ mod native_tests {
Commitments::KZG,
2,
false,
None,
None,
);
println!(
@@ -1775,7 +1798,7 @@ mod native_tests {
// Mock prove (fast, but does not cover some potential issues)
fn render_circuit(test_dir: &str, example_name: String) {
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"render-circuit",
"-M",
@@ -1806,7 +1829,7 @@ mod native_tests {
Commitments::KZG,
2,
);
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"mock-aggregate",
"--logrows=23",
@@ -1844,7 +1867,7 @@ mod native_tests {
download_srs(23, commitment);
// now setup-aggregate
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"setup-aggregate",
"--sample-snarks",
@@ -1860,7 +1883,7 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"aggregate",
"--logrows=23",
@@ -1875,7 +1898,7 @@ mod native_tests {
.status()
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"verify-aggr",
"--logrows=23",
@@ -1925,7 +1948,7 @@ mod native_tests {
let private_key = format!("--private-key={}", *ANVIL_DEFAULT_PRIVATE_KEY);
// create encoded calldata
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"encode-evm-calldata",
"--proof-path",
@@ -1947,7 +1970,7 @@ mod native_tests {
let args = build_args(base_args, &sol_arg);
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(args)
.status()
.expect("failed to execute process");
@@ -1963,7 +1986,7 @@ mod native_tests {
private_key.as_str(),
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&args)
.status()
.expect("failed to execute process");
@@ -1985,14 +2008,14 @@ mod native_tests {
rpc_arg.as_str(),
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&base_args)
.status()
.expect("failed to execute process");
assert!(status.success());
// As sanity check, add example that should fail.
base_args[2] = PF_FAILURE_AGGR;
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(base_args)
.status()
.expect("failed to execute process");
@@ -2035,13 +2058,15 @@ mod native_tests {
commitment,
lookup_safety_margin,
false,
None,
None,
);
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
init_params(settings_path.clone().into());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"setup",
"-M",
@@ -2056,7 +2081,7 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"prove",
"-W",
@@ -2074,7 +2099,7 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"swap-proof-commitments",
"--proof-path",
@@ -2086,7 +2111,7 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"verify",
format!("--settings-path={}", settings_path).as_str(),
@@ -2109,7 +2134,7 @@ mod native_tests {
// get_srs for the graph_settings_num_instances
download_srs(1, graph_settings.run_args.commitment.into());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"verify",
format!("--settings-path={}", settings_path).as_str(),
@@ -2159,7 +2184,7 @@ mod native_tests {
let settings_arg = format!("--settings-path={}", settings_path);
// create encoded calldata
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"encode-evm-calldata",
"--proof-path",
@@ -2179,7 +2204,7 @@ mod native_tests {
args.push("--sol-code-path");
args.push(sol_arg.as_str());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&args)
.status()
.expect("failed to execute process");
@@ -2191,7 +2216,7 @@ mod native_tests {
args.push("--sol-code-path");
args.push(sol_arg.as_str());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&args)
.status()
.expect("failed to execute process");
@@ -2213,14 +2238,14 @@ mod native_tests {
deployed_addr_arg.as_str(),
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&args)
.status()
.expect("failed to execute process");
assert!(status.success());
// As sanity check, add example that should fail.
args[2] = PF_FAILURE;
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(args)
.status()
.expect("failed to execute process");
@@ -2228,6 +2253,7 @@ mod native_tests {
}
// prove-serialize-verify, the usual full path
#[allow(clippy::too_many_arguments)]
fn kzg_evm_prove_and_verify_reusable_verifier(
num_inner_columns: usize,
test_dir: &str,
@@ -2278,7 +2304,7 @@ mod native_tests {
"--reusable",
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&args)
.status()
.expect("failed to execute process");
@@ -2293,7 +2319,7 @@ mod native_tests {
"-C=verifier/reusable",
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&args)
.status()
.expect("failed to execute process");
@@ -2322,7 +2348,7 @@ mod native_tests {
&sol_arg_vk,
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&args)
.status()
.expect("failed to execute process");
@@ -2337,7 +2363,7 @@ mod native_tests {
"-C=vka",
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&args)
.status()
.expect("failed to execute process");
@@ -2350,7 +2376,7 @@ mod native_tests {
let deployed_addr_arg_vk = format!("--addr-vk={}", addr_vk);
// create encoded calldata
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"encode-evm-calldata",
"--proof-path",
@@ -2373,7 +2399,7 @@ mod native_tests {
deployed_addr_arg_vk.as_str(),
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&args)
.status()
.expect("failed to execute process");
@@ -2406,7 +2432,7 @@ mod native_tests {
// Verify the modified proof (should fail)
let mut args_mod = args.clone();
args_mod[2] = &modified_pf_arg;
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&args_mod)
.status()
.expect("failed to execute process");
@@ -2467,6 +2493,8 @@ mod native_tests {
Commitments::KZG,
2,
false,
None,
None,
);
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);
@@ -2482,7 +2510,7 @@ mod native_tests {
let test_input_source = format!("--input-source={}", input_source);
let test_output_source = format!("--output-source={}", output_source);
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"setup",
"-M",
@@ -2497,7 +2525,7 @@ mod native_tests {
assert!(status.success());
// generate the witness, passing the vk path to generate the necessary kzg commits
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"gen-witness",
"-D",
@@ -2554,7 +2582,7 @@ mod native_tests {
}
input.save(data_path.clone().into()).unwrap();
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"setup-test-evm-data",
"-D",
@@ -2572,7 +2600,7 @@ mod native_tests {
assert!(status.success());
}
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"prove",
"-W",
@@ -2593,7 +2621,7 @@ mod native_tests {
let settings_arg = format!("--settings-path={}", settings_path);
// create encoded calldata
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"encode-evm-calldata",
"--proof-path",
@@ -2612,7 +2640,7 @@ mod native_tests {
args.push("--sol-code-path");
args.push(sol_arg.as_str());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&args)
.status()
.expect("failed to execute process");
@@ -2633,7 +2661,7 @@ mod native_tests {
args.push("--sol-code-path");
args.push(sol_arg.as_str());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&args)
.status()
.expect("failed to execute process");
@@ -2656,7 +2684,7 @@ mod native_tests {
create_da_args.push(test_on_chain_data_path.as_str());
}
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&create_da_args)
.status()
.expect("failed to execute process");
@@ -2669,7 +2697,7 @@ mod native_tests {
};
let addr_path_da_arg = format!("--addr-path={}/{}/addr_da.txt", test_dir, example_name);
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"deploy-evm-da",
format!("--settings-path={}", settings_path).as_str(),
@@ -2707,14 +2735,14 @@ mod native_tests {
deployed_addr_da_arg.as_str(),
rpc_arg.as_str(),
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&args)
.status()
.expect("failed to execute process");
assert!(status.success());
// Create a new set of test on chain data only for the on-chain input source
if input_source != "file" || output_source != "file" {
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"setup-test-evm-data",
"-D",
@@ -2741,7 +2769,7 @@ mod native_tests {
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&args)
.status()
.expect("failed to execute process");
@@ -2757,7 +2785,7 @@ mod native_tests {
deployed_addr_da_arg.as_str(),
rpc_arg.as_str(),
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(args)
.status()
.expect("failed to execute process");
@@ -2768,18 +2796,28 @@ mod native_tests {
#[cfg(feature = "icicle")]
let args = [
"build",
"--release",
"--profile=test-runs",
"--bin",
"ezkl",
"--features",
"icicle",
];
#[cfg(not(feature = "icicle"))]
let args = ["build", "--release", "--bin", "ezkl"];
#[cfg(feature = "macos-metal")]
let args = [
"build",
"--profile=test-runs",
"--bin",
"ezkl",
"--features",
"macos-metal",
];
// not macos-metal and not icicle
#[cfg(all(not(feature = "icicle"), not(feature = "macos-metal")))]
let args = ["build", "--profile=test-runs", "--bin", "ezkl"];
#[cfg(not(feature = "mv-lookup"))]
let args = [
"build",
"--release",
"--profile=test-runs",
"--bin",
"ezkl",
"--no-default-features",
@@ -2800,7 +2838,7 @@ mod native_tests {
let status = Command::new("wasm-pack")
.args([
"build",
"--release",
"--profile=test-runs",
"--target",
"nodejs",
"--out-dir",

View File

@@ -72,11 +72,10 @@ mod py_tests {
"torchtext==0.17.2",
"torchvision==0.17.2",
"pandas==2.2.1",
"numpy==1.26.4",
"seaborn==0.13.2",
"notebook==7.1.2",
"nbconvert==7.16.3",
"onnx==1.16.0",
"onnx==1.17.0",
"kaggle==1.6.8",
"py-solc-x==2.0.3",
"web3==7.5.0",
@@ -90,12 +89,13 @@ mod py_tests {
"xgboost==2.0.3",
"hummingbird-ml==0.4.11",
"lightgbm==4.3.0",
"numpy==1.26.4",
])
.status()
.expect("failed to execute process");
assert!(status.success());
let status = Command::new("pip")
.args(["install", "numpy==1.23"])
.args(["install", "numpy==1.26.4"])
.status()
.expect("failed to execute process");
@@ -126,10 +126,10 @@ mod py_tests {
}
const TESTS: [&str; 35] = [
"ezkl_demo_batch.ipynb", // 0
"proof_splitting.ipynb", // 1
"variance.ipynb", // 2
"mnist_gan.ipynb", // 3
"mnist_gan.ipynb", // 0
"ezkl_demo_batch.ipynb", // 1
"proof_splitting.ipynb", // 2
"variance.ipynb", // 3
"keras_simple_demo.ipynb", // 4
"mnist_gan_proof_splitting.ipynb", // 5
"hashed_vis.ipynb", // 6

View File

@@ -873,6 +873,7 @@ def get_examples():
'linear_regression',
"mnist_gan",
"smallworm",
"fr_age"
]
examples = []
for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):