mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
1 Commits
v18.1.9
...
ac/patch-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
73f5f95aad |
72
.github/workflows/engine.yml
vendored
72
.github/workflows/engine.yml
vendored
@@ -19,8 +19,6 @@ jobs:
|
||||
contents: read
|
||||
packages: write
|
||||
name: publish-wasm-bindings
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
@@ -47,39 +45,43 @@ jobs:
|
||||
curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz | tar xzf -
|
||||
export PATH=$PATH:$PWD/binaryen-version_116/bin
|
||||
wasm-opt --version
|
||||
- name: Build wasm files for both web and nodejs compilation targets
|
||||
run: |
|
||||
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
|
||||
wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
|
||||
- name: Create package.json in pkg folder
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
cat > pkg/package.json << EOF
|
||||
{
|
||||
"name": "@ezkljs/engine",
|
||||
"version": "$RELEASE_TAG",
|
||||
"dependencies": {
|
||||
"@types/json-bigint": "^1.0.1",
|
||||
"json-bigint": "^1.0.0"
|
||||
},
|
||||
"files": [
|
||||
"nodejs/ezkl_bg.wasm",
|
||||
"nodejs/ezkl.js",
|
||||
"nodejs/ezkl.d.ts",
|
||||
"nodejs/package.json",
|
||||
"nodejs/utils.js",
|
||||
"web/ezkl_bg.wasm",
|
||||
"web/ezkl.js",
|
||||
"web/ezkl.d.ts",
|
||||
"web/snippets/**/*",
|
||||
"web/package.json",
|
||||
"web/utils.js",
|
||||
"ezkl.d.ts"
|
||||
],
|
||||
"main": "nodejs/ezkl.js",
|
||||
"module": "web/ezkl.js",
|
||||
"types": "nodejs/ezkl.d.ts",
|
||||
"sideEffects": [
|
||||
"web/snippets/*"
|
||||
]
|
||||
}
|
||||
EOF
|
||||
echo '{
|
||||
"name": "@ezkljs/engine",
|
||||
"version": "${RELEASE_TAG}",
|
||||
"dependencies": {
|
||||
"@types/json-bigint": "^1.0.1",
|
||||
"json-bigint": "^1.0.0"
|
||||
},
|
||||
"files": [
|
||||
"nodejs/ezkl_bg.wasm",
|
||||
"nodejs/ezkl.js",
|
||||
"nodejs/ezkl.d.ts",
|
||||
"nodejs/package.json",
|
||||
"nodejs/utils.js",
|
||||
"web/ezkl_bg.wasm",
|
||||
"web/ezkl.js",
|
||||
"web/ezkl.d.ts",
|
||||
"web/snippets/**/*",
|
||||
"web/package.json",
|
||||
"web/utils.js",
|
||||
"ezkl.d.ts"
|
||||
],
|
||||
"main": "nodejs/ezkl.js",
|
||||
"module": "web/ezkl.js",
|
||||
"types": "nodejs/ezkl.d.ts",
|
||||
"sideEffects": [
|
||||
"web/snippets/*"
|
||||
]
|
||||
}' > pkg/package.json
|
||||
|
||||
- name: Replace memory definition in nodejs
|
||||
run: |
|
||||
@@ -193,8 +195,6 @@ jobs:
|
||||
name: publish-in-browser-evm-verifier-package
|
||||
needs: [publish-wasm-bindings]
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -202,8 +202,10 @@ jobs:
|
||||
persist-credentials: false
|
||||
- name: Update version in package.json
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
sed -i "s|\"version\": \".*\"|\"version\": \"$RELEASE_TAG\"|" in-browser-evm-verifier/package.json
|
||||
sed -i "s|\"version\": \".*\"|\"version\": \"${RELEASE_TAG}\"|" in-browser-evm-verifier/package.json
|
||||
- name: Prepare tag and fetch package integrity
|
||||
run: |
|
||||
CLEANED_TAG=${RELEASE_TAG} # Get the tag from ref_name
|
||||
|
||||
4
.github/workflows/pypi-gpu.yml
vendored
4
.github/workflows/pypi-gpu.yml
vendored
@@ -25,8 +25,6 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
target: [x86_64]
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -51,6 +49,8 @@ jobs:
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv Cargo.toml Cargo.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
|
||||
|
||||
109
.github/workflows/pypi.yml
vendored
109
.github/workflows/pypi.yml
vendored
@@ -23,8 +23,6 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
target: [x86_64, universal2-apple-darwin]
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -34,14 +32,10 @@ jobs:
|
||||
python-version: 3.12
|
||||
architecture: x64
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv Cargo.toml Cargo.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
|
||||
@@ -95,14 +89,6 @@ jobs:
|
||||
python-version: 3.12
|
||||
architecture: ${{ matrix.target }}
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
@@ -152,14 +138,6 @@ jobs:
|
||||
python-version: 3.12
|
||||
architecture: x64
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
@@ -170,6 +148,7 @@ jobs:
|
||||
mv Cargo.lock Cargo.lock.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
|
||||
|
||||
|
||||
- name: Install required libraries
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -208,6 +187,57 @@ jobs:
|
||||
name: wheels
|
||||
path: dist
|
||||
|
||||
# There's a problem with the maturin-action toolchain for arm arch leading to failed builds
|
||||
# linux-cross:
|
||||
# runs-on: ubuntu-latest
|
||||
# strategy:
|
||||
# matrix:
|
||||
# target: [aarch64, armv7]
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - uses: actions/setup-python@v4
|
||||
# with:
|
||||
# python-version: 3.12
|
||||
|
||||
# - name: Install cross-compilation tools for aarch64
|
||||
# if: matrix.target == 'aarch64'
|
||||
# run: |
|
||||
# sudo apt-get update
|
||||
# sudo apt-get install -y gcc make gcc-aarch64-linux-gnu binutils-aarch64-linux-gnu libc6-dev-arm64-cross libusb-1.0-0-dev libatomic1-arm64-cross
|
||||
|
||||
# - name: Install cross-compilation tools for armv7
|
||||
# if: matrix.target == 'armv7'
|
||||
# run: |
|
||||
# sudo apt-get update
|
||||
# sudo apt-get install -y gcc make gcc-arm-linux-gnueabihf binutils-arm-linux-gnueabihf libc6-dev-armhf-cross libusb-1.0-0-dev libatomic1-armhf-cross
|
||||
|
||||
# - name: Build wheels
|
||||
# uses: PyO3/maturin-action@v1
|
||||
# with:
|
||||
# target: ${{ matrix.target }}
|
||||
# manylinux: auto
|
||||
# args: --release --out dist --features python-bindings
|
||||
|
||||
# - uses: uraimo/run-on-arch-action@v2.5.0
|
||||
# name: Install built wheel
|
||||
# with:
|
||||
# arch: ${{ matrix.target }}
|
||||
# distro: ubuntu20.04
|
||||
# githubToken: ${{ github.token }}
|
||||
# install: |
|
||||
# apt-get update
|
||||
# apt-get install -y --no-install-recommends python3 python3-pip
|
||||
# pip3 install -U pip
|
||||
# run: |
|
||||
# pip3 install ezkl --no-index --find-links dist/ --force-reinstall
|
||||
# python3 -c "import ezkl"
|
||||
|
||||
# - name: Upload wheels
|
||||
# uses: actions/upload-artifact@v3
|
||||
# with:
|
||||
# name: wheels
|
||||
# path: dist
|
||||
|
||||
musllinux:
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -243,7 +273,6 @@ jobs:
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
|
||||
mv Cargo.lock Cargo.lock.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
|
||||
|
||||
- name: Install required libraries
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -294,14 +323,6 @@ jobs:
|
||||
with:
|
||||
python-version: 3.12
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
@@ -345,6 +366,8 @@ jobs:
|
||||
permissions:
|
||||
id-token: write
|
||||
if: "startsWith(github.ref, 'refs/tags/')"
|
||||
# TODO: Uncomment if linux-cross is working
|
||||
# needs: [ macos, windows, linux, linux-cross, musllinux, musllinux-cross ]
|
||||
needs: [macos, windows, linux, musllinux, musllinux-cross]
|
||||
steps:
|
||||
- uses: actions/download-artifact@v3
|
||||
@@ -352,20 +375,24 @@ jobs:
|
||||
name: wheels
|
||||
- name: List Files
|
||||
run: ls -R
|
||||
|
||||
# # publishes to TestPyPI
|
||||
# - name: Publish package distribution to TestPyPI
|
||||
# uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
# with:
|
||||
# repository-url: https://test.pypi.org/legacy/
|
||||
# packages-dir: ./
|
||||
|
||||
# Both publish steps will fail if there is no trusted publisher setup
|
||||
# On failure the publish step will then simply continue to the next one
|
||||
|
||||
# publishes to PyPI
|
||||
- name: Publish package distributions to PyPI
|
||||
continue-on-error: true
|
||||
uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
with:
|
||||
packages-dir: ./
|
||||
|
||||
# publishes to TestPyPI
|
||||
- name: Publish package distribution to TestPyPI
|
||||
continue-on-error: true
|
||||
uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
with:
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
packages-dir: ./
|
||||
|
||||
doc-publish:
|
||||
permissions:
|
||||
@@ -382,4 +409,4 @@ jobs:
|
||||
with:
|
||||
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
|
||||
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
|
||||
commit_ref: ${{ github.ref_name }}
|
||||
commit_ref: ${{ github.ref_name }}
|
||||
|
||||
354
.github/workflows/rust.yml
vendored
354
.github/workflows/rust.yml
vendored
@@ -19,8 +19,8 @@ env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
|
||||
fr-age-test:
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: large-self-hosted
|
||||
@@ -33,12 +33,8 @@ jobs:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: fr age Mock
|
||||
run: cargo nextest run --release --verbose tests::large_mock_::large_tests_6_expects -- --include-ignored
|
||||
run: cargo test --release --verbose tests::large_mock_::large_tests_6_expects -- --include-ignored
|
||||
|
||||
build:
|
||||
permissions:
|
||||
@@ -103,8 +99,8 @@ jobs:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
@@ -163,13 +159,13 @@ jobs:
|
||||
# - name: Conv overflow (wasi)
|
||||
# run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
|
||||
- name: lookup overflow
|
||||
run: cargo nextest run --release lookup_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
|
||||
run: cargo nextest run --release lookup_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
|
||||
- name: Matmul overflow
|
||||
run: RUST_LOG=debug cargo nextest run --release matmul_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
|
||||
run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
|
||||
- name: Conv overflow
|
||||
run: RUST_LOG=debug cargo nextest run --release conv_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
|
||||
run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
|
||||
- name: Conv + relu overflow
|
||||
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
|
||||
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
|
||||
|
||||
ultra-overflow-tests:
|
||||
permissions:
|
||||
@@ -200,13 +196,13 @@ jobs:
|
||||
# - name: Conv overflow (wasi)
|
||||
# run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
|
||||
- name: lookup overflow
|
||||
run: cargo nextest run lookup_ultra_overflow --no-capture -- --include-ignored
|
||||
run: cargo nextest run --release lookup_ultra_overflow --no-capture -- --include-ignored
|
||||
- name: Matmul overflow
|
||||
run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture -- --include-ignored
|
||||
- name: Conv overflow
|
||||
run: RUST_LOG=debug cargo nextest run --release conv_col_ultra_overflow --no-capture -- --include-ignored
|
||||
run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture -- --include-ignored
|
||||
- name: Conv + relu overflow
|
||||
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture -- --include-ignored
|
||||
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture -- --include-ignored
|
||||
|
||||
model-serialization:
|
||||
permissions:
|
||||
@@ -231,7 +227,7 @@ jobs:
|
||||
wasm32-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -244,7 +240,7 @@ jobs:
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
with:
|
||||
# Pin to version 0.12.1
|
||||
version: "v0.12.1"
|
||||
version: 'v0.12.1'
|
||||
- uses: nanasess/setup-chromedriver@v2
|
||||
# with:
|
||||
# chromedriver-version: "115.0.5790.102"
|
||||
@@ -261,6 +257,7 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -275,53 +272,53 @@ jobs:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
# - name: The Worm Mock
|
||||
# run: cargo nextest run --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
|
||||
- name: MNIST Gan Mock
|
||||
run: cargo nextest run --verbose tests::large_mock_::large_tests_4_expects -- --include-ignored
|
||||
- name: NanoGPT Mock
|
||||
run: cargo nextest run --verbose tests::large_mock_::large_tests_1_expects -- --include-ignored
|
||||
- name: Self Attention Mock
|
||||
run: cargo nextest run --verbose tests::large_mock_::large_tests_0_expects -- --include-ignored
|
||||
- name: Multihead Attention Mock
|
||||
run: cargo nextest run --verbose tests::large_mock_::large_tests_2_expects -- --include-ignored
|
||||
- name: public outputs
|
||||
run: cargo nextest run --verbose tests::mock_public_outputs_ --test-threads 32
|
||||
- name: public inputs
|
||||
run: cargo nextest run --verbose tests::mock_public_inputs_ --test-threads 32
|
||||
- name: fixed params
|
||||
run: cargo nextest run --verbose tests::mock_fixed_params_ --test-threads 32
|
||||
# run: cargo nextest run --release --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
|
||||
- name: public outputs and bounded lookup log
|
||||
run: cargo nextest run --verbose tests::mock_bounded_lookup_log --test-threads 32
|
||||
run: cargo nextest run --release --verbose tests::mock_bounded_lookup_log --test-threads 32
|
||||
- name: public outputs and tolerance > 0
|
||||
run: cargo nextest run --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
|
||||
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
|
||||
- name: public outputs + batch size == 10
|
||||
run: cargo nextest run --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
|
||||
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
|
||||
- name: kzg inputs
|
||||
run: cargo nextest run --verbose tests::mock_kzg_input_::t --test-threads 32
|
||||
run: cargo nextest run --release --verbose tests::mock_kzg_input_::t --test-threads 32
|
||||
- name: kzg params
|
||||
run: cargo nextest run --verbose tests::mock_kzg_params_::t --test-threads 32
|
||||
run: cargo nextest run --release --verbose tests::mock_kzg_params_::t --test-threads 32
|
||||
- name: kzg outputs
|
||||
run: cargo nextest run --verbose tests::mock_kzg_output_::t --test-threads 32
|
||||
run: cargo nextest run --release --verbose tests::mock_kzg_output_::t --test-threads 32
|
||||
- name: kzg inputs + params + outputs
|
||||
run: cargo nextest run --verbose tests::mock_kzg_all_::t --test-threads 32
|
||||
run: cargo nextest run --release --verbose tests::mock_kzg_all_::t --test-threads 32
|
||||
- name: Mock fixed inputs
|
||||
run: cargo nextest run --verbose tests::mock_fixed_inputs_ --test-threads 32
|
||||
run: cargo nextest run --release --verbose tests::mock_fixed_inputs_ --test-threads 32
|
||||
- name: Mock fixed outputs
|
||||
run: cargo nextest run --verbose tests::mock_fixed_outputs --test-threads 32
|
||||
run: cargo nextest run --release --verbose tests::mock_fixed_outputs --test-threads 32
|
||||
- name: Mock accuracy calibration
|
||||
run: cargo nextest run --verbose tests::mock_accuracy_cal_tests::a
|
||||
run: cargo nextest run --release --verbose tests::mock_accuracy_cal_tests::a
|
||||
- name: hashed inputs
|
||||
run: cargo nextest run --verbose tests::mock_hashed_input_::t --test-threads 32
|
||||
run: cargo nextest run --release --verbose tests::mock_hashed_input_::t --test-threads 32
|
||||
- name: hashed params
|
||||
run: cargo nextest run --verbose tests::mock_hashed_params_::t --test-threads 32
|
||||
run: cargo nextest run --release --verbose tests::mock_hashed_params_::t --test-threads 32
|
||||
- name: hashed params public inputs
|
||||
run: cargo nextest run --verbose tests::mock_hashed_params_public_inputs_::t --test-threads 32
|
||||
run: cargo nextest run --release --verbose tests::mock_hashed_params_public_inputs_::t --test-threads 32
|
||||
- name: hashed outputs
|
||||
run: cargo nextest run --verbose tests::mock_hashed_output_::t --test-threads 32
|
||||
run: cargo nextest run --release --verbose tests::mock_hashed_output_::t --test-threads 32
|
||||
- name: hashed inputs + params + outputs
|
||||
run: cargo nextest run --verbose tests::mock_hashed_all_::t --test-threads 32
|
||||
run: cargo nextest run --release --verbose tests::mock_hashed_all_::t --test-threads 32
|
||||
- name: hashed inputs + fixed params
|
||||
run: cargo nextest run --verbose tests::mock_hashed_output_fixed_params_::t --test-threads 32
|
||||
run: cargo nextest run --release --verbose tests::mock_hashed_output_fixed_params_::t --test-threads 32
|
||||
- name: MNIST Gan Mock
|
||||
run: cargo nextest run --release --verbose tests::large_mock_::large_tests_4_expects -- --include-ignored
|
||||
- name: NanoGPT Mock
|
||||
run: cargo nextest run --release --verbose tests::large_mock_::large_tests_1_expects -- --include-ignored
|
||||
- name: Self Attention Mock
|
||||
run: cargo nextest run --release --verbose tests::large_mock_::large_tests_0_expects -- --include-ignored
|
||||
- name: Multihead Attention Mock
|
||||
run: cargo nextest run --release --verbose tests::large_mock_::large_tests_2_expects -- --include-ignored
|
||||
- name: public outputs
|
||||
run: cargo nextest run --release --verbose tests::mock_public_outputs_ --test-threads 32
|
||||
- name: public inputs
|
||||
run: cargo nextest run --release --verbose tests::mock_public_inputs_ --test-threads 32
|
||||
- name: fixed params
|
||||
run: cargo nextest run --release --verbose tests::mock_fixed_params_ --test-threads 32
|
||||
|
||||
prove-and-verify-evm-tests:
|
||||
permissions:
|
||||
@@ -364,7 +361,7 @@ jobs:
|
||||
NODE_ENV: development
|
||||
- name: Build wasm package for nodejs target.
|
||||
run: |
|
||||
wasm-pack build --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
|
||||
wasm-pack build --release --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
|
||||
- name: Replace memory definition in nodejs
|
||||
run: |
|
||||
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" in-browser-evm-verifier/nodejs/ezkl.js
|
||||
@@ -378,35 +375,35 @@ jobs:
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
|
||||
- name: KZG prove and verify tests (EVM + reusable verifier + col-overflow)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --test-threads 1
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + kzg all)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + kzg inputs)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_kzg_input_prove_and_verify --test-threads 1
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_input_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + kzg params)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_kzg_params_prove_and_verify --test-threads 1
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_params_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + on chain inputs)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_input_prove_and_verify --test-threads 1
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + on chain outputs)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_output_prove_and_verify --test-threads 1
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + on chain inputs & outputs)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_input_output_prove_and_verify --test-threads 1
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + on chain inputs & kzg outputs + params)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_input_kzg_output_kzg_params_prove_and_verify --test-threads 1
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_kzg_output_kzg_params_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + on chain outputs & kzg inputs + params)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_output_kzg_input_kzg_params_prove_and_verify --test-threads 1
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_kzg_input_kzg_params_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + on chain all kzg)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_all_kzg_params_prove_and_verify --test-threads 1
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_all_kzg_params_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + on chain inputs & outputs hashes)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_on_chain_input_output_hashed_prove_and_verify --test-threads 1
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_hashed_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_prove_and_verify --test-threads 1
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + hashed inputs)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_hashed_input_prove_and_verify --test-threads 1
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_hashed_input_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + hashed params)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_hashed_params_prove_and_verify --test-threads 1
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_hashed_params_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + hashed outputs)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_hashed_output_prove_and_verify --test-threads 1
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_hashed_output_prove_and_verify --test-threads 1
|
||||
|
||||
# prove-and-verify-tests-metal:
|
||||
# permissions:
|
||||
@@ -440,7 +437,8 @@ jobs:
|
||||
# crate: cargo-nextest
|
||||
# locked: true
|
||||
# - name: KZG prove and verify tests (public outputs)
|
||||
# run: cargo nextest run --features macos-metal --verbose tests::kzg_prove_and_verify_::t --no-capture
|
||||
# run: cargo nextest run --release --features macos-metal --verbose tests::kzg_prove_and_verify_::t --no-capture
|
||||
|
||||
|
||||
prove-and-verify-tests:
|
||||
permissions:
|
||||
@@ -459,7 +457,7 @@ jobs:
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
with:
|
||||
# Pin to version 0.12.1
|
||||
version: "v0.12.1"
|
||||
version: 'v0.12.1'
|
||||
- name: Add wasm32-unknown-unknown target
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
@@ -489,40 +487,40 @@ jobs:
|
||||
locked: true
|
||||
- name: Build wasm package for nodejs target.
|
||||
run: |
|
||||
wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
|
||||
wasm-pack build --release --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
|
||||
- name: Replace memory definition in nodejs
|
||||
run: |
|
||||
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
|
||||
- name: KZG prove and verify tests (public outputs + column overflow)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w
|
||||
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
|
||||
- name: KZG prove and verify tests (hashed inputs + column overflow)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_hashed_inputs_
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_hashed_inputs_
|
||||
- name: KZG prove and verify tests (public outputs)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_tight_lookup_::t
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_tight_lookup_::t
|
||||
- name: IPA prove and verify tests
|
||||
run: cargo nextest run --verbose tests::ipa_prove_and_verify_::t --test-threads 1
|
||||
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_::t --test-threads 1
|
||||
- name: IPA prove and verify tests (ipa outputs)
|
||||
run: cargo nextest run --verbose tests::ipa_prove_and_verify_ipa_output
|
||||
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_ipa_output
|
||||
- name: KZG prove and verify tests single inner col
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_single_col
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_single_col
|
||||
- name: KZG prove and verify tests triple inner col
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_triple_col
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_triple_col
|
||||
- name: KZG prove and verify tests quadruple inner col
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_quadruple_col
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_quadruple_col
|
||||
- name: KZG prove and verify tests octuple inner col
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_octuple_col --test-threads 8
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_octuple_col --test-threads 8
|
||||
- name: KZG prove and verify tests (kzg outputs)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_kzg_output
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output
|
||||
- name: KZG prove and verify tests (public outputs)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t
|
||||
- name: KZG prove and verify tests (public inputs)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_public_input
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input
|
||||
- name: KZG prove and verify tests (fixed params)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_fixed_params
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params
|
||||
- name: KZG prove and verify tests (hashed outputs)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_hashed
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed
|
||||
|
||||
# prove-and-verify-tests-gpu:
|
||||
# runs-on: GPU
|
||||
@@ -530,8 +528,8 @@ jobs:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
@@ -545,21 +543,21 @@ jobs:
|
||||
# crate: cargo-nextest
|
||||
# locked: true
|
||||
# - name: KZG prove and verify tests (kzg outputs)
|
||||
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (public outputs + column overflow)
|
||||
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (public outputs + fixed params + column overflow)
|
||||
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (public outputs)
|
||||
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (public outputs + column overflow)
|
||||
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (public inputs)
|
||||
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (fixed params)
|
||||
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (hashed outputs)
|
||||
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
|
||||
|
||||
prove-and-verify-mock-aggr-tests:
|
||||
permissions:
|
||||
@@ -580,7 +578,7 @@ jobs:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Mock aggr tests (KZG)
|
||||
run: cargo nextest run --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
|
||||
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
|
||||
|
||||
# prove-and-verify-aggr-tests-gpu:
|
||||
# runs-on: GPU
|
||||
@@ -588,8 +586,8 @@ jobs:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
@@ -621,7 +619,7 @@ jobs:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: KZG tests
|
||||
run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
|
||||
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
|
||||
|
||||
prove-and-verify-aggr-evm-tests:
|
||||
permissions:
|
||||
@@ -646,7 +644,7 @@ jobs:
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
|
||||
- name: KZG prove and verify aggr tests
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
|
||||
|
||||
examples:
|
||||
permissions:
|
||||
@@ -695,7 +693,7 @@ jobs:
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --profile=test-runs
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
|
||||
- name: Run pytest
|
||||
run: source .env/bin/activate; pip install pytest-asyncio; pytest -vv
|
||||
|
||||
@@ -723,15 +721,15 @@ jobs:
|
||||
- name: Setup Virtual Env and Install python dependencies
|
||||
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --profile=test-runs
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
|
||||
- name: Public inputs
|
||||
run: source .env/bin/activate; cargo nextest run --verbose tests::accuracy_measurement_public_inputs_
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_inputs_
|
||||
- name: fixed params
|
||||
run: source .env/bin/activate; cargo nextest run --verbose tests::accuracy_measurement_fixed_params_
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_fixed_params_
|
||||
- name: Public outputs
|
||||
run: source .env/bin/activate; cargo nextest run --verbose tests::accuracy_measurement_public_outputs_
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_outputs_
|
||||
- name: Public outputs + resources
|
||||
run: source .env/bin/activate; cargo nextest run --verbose tests::resources_accuracy_measurement_public_outputs_
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::resources_accuracy_measurement_public_outputs_
|
||||
|
||||
python-integration-tests:
|
||||
permissions:
|
||||
@@ -780,11 +778,7 @@ jobs:
|
||||
- name: Setup Virtual Env and Install python dependencies
|
||||
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --profile=test-runs
|
||||
- name: All notebooks
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
|
||||
- name: Voice tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
|
||||
- name: Neural bow
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
|
||||
- name: Felt conversion
|
||||
@@ -802,87 +796,91 @@ jobs:
|
||||
# # now dump the contents of the file into a file called kaggle.json
|
||||
# echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json
|
||||
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
|
||||
- name: All notebooks
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
|
||||
- name: Voice tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
|
||||
- name: NBEATS tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
|
||||
# - name: Reusable verifier tutorial
|
||||
# run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_
|
||||
# run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_
|
||||
|
||||
ios-integration-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Run ios tests
|
||||
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Run ios tests
|
||||
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
|
||||
|
||||
swift-package-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
needs: [ios-integration-tests]
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
needs: [ios-integration-tests]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Build EzklCoreBindings
|
||||
run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Build EzklCoreBindings
|
||||
run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
|
||||
|
||||
- name: Clone ezkl-swift- repository
|
||||
run: |
|
||||
git clone https://github.com/zkonduit/ezkl-swift-package.git
|
||||
- name: Clone ezkl-swift- repository
|
||||
run: |
|
||||
git clone https://github.com/zkonduit/ezkl-swift-package.git
|
||||
|
||||
- name: Copy EzklCoreBindings
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
|
||||
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
|
||||
- name: Copy EzklCoreBindings
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
|
||||
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
|
||||
|
||||
- name: Copy Test Files
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Tests/EzklAssets/
|
||||
mkdir -p ezkl-swift-package/Tests/EzklAssets/
|
||||
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
|
||||
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
|
||||
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
|
||||
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
|
||||
- name: Copy Test Files
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Tests/EzklAssets/
|
||||
mkdir -p ezkl-swift-package/Tests/EzklAssets/
|
||||
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
|
||||
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
|
||||
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
|
||||
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
|
||||
|
||||
- name: Set up Xcode environment
|
||||
run: |
|
||||
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
|
||||
sudo xcodebuild -license accept
|
||||
- name: Set up Xcode environment
|
||||
run: |
|
||||
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
|
||||
sudo xcodebuild -license accept
|
||||
|
||||
- name: Run Package Tests
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
xcodebuild test \
|
||||
-scheme EzklPackage \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-resultBundlePath ../testResults
|
||||
- name: Run Package Tests
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
xcodebuild test \
|
||||
-scheme EzklPackage \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-resultBundlePath ../testResults
|
||||
|
||||
- name: Run Example App Tests
|
||||
run: |
|
||||
cd ezkl-swift-package/Example
|
||||
xcodebuild test \
|
||||
-project Example.xcodeproj \
|
||||
-scheme EzklApp \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-parallel-testing-enabled NO \
|
||||
-resultBundlePath ../../exampleTestResults \
|
||||
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
|
||||
- name: Run Example App Tests
|
||||
run: |
|
||||
cd ezkl-swift-package/Example
|
||||
xcodebuild test \
|
||||
-project Example.xcodeproj \
|
||||
-scheme EzklApp \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-parallel-testing-enabled NO \
|
||||
-resultBundlePath ../../exampleTestResults \
|
||||
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
|
||||
22
Cargo.lock
generated
22
Cargo.lock
generated
@@ -1,6 +1,6 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 4
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "addr2line"
|
||||
@@ -1760,7 +1760,7 @@ checksum = "a650a461c6a8ff1ef205ed9a2ad56579309853fecefc2423f73dced342f92258"
|
||||
[[package]]
|
||||
name = "ecc"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
dependencies = [
|
||||
"integer",
|
||||
"num-bigint",
|
||||
@@ -1968,11 +1968,13 @@ dependencies = [
|
||||
"objc",
|
||||
"openssl",
|
||||
"pg_bigdecimal",
|
||||
"portable-atomic",
|
||||
"pyo3",
|
||||
"pyo3-async-runtimes",
|
||||
"pyo3-log",
|
||||
"pyo3-stub-gen",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"semver 1.0.22",
|
||||
"seq-macro",
|
||||
@@ -2604,7 +2606,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2wrong"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
dependencies = [
|
||||
"halo2_proofs",
|
||||
"num-bigint",
|
||||
@@ -2955,7 +2957,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "integer"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
dependencies = [
|
||||
"maingate",
|
||||
"num-bigint",
|
||||
@@ -3139,7 +3141,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-targets 0.52.6",
|
||||
"windows-targets 0.48.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3266,7 +3268,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "maingate"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
dependencies = [
|
||||
"halo2wrong",
|
||||
"num-bigint",
|
||||
@@ -3675,9 +3677,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-src"
|
||||
version = "300.4.1+3.4.0"
|
||||
version = "300.2.3+3.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "faa4eac4138c62414b5622d1b31c5c304f34b406b013c079c2bbc652fdd6678c"
|
||||
checksum = "5cff92b6f71555b61bb9315f7c64da3ca43d87531622120fea0195fc761b4843"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
@@ -5230,7 +5232,7 @@ checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c"
|
||||
[[package]]
|
||||
name = "snark-verifier"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac%2Fchunked-mv-lookup#8762701ab8fa04e7d243a346030afd85633ec970"
|
||||
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/chunked-mv-lookup#8762701ab8fa04e7d243a346030afd85633ec970"
|
||||
dependencies = [
|
||||
"ecc",
|
||||
"halo2_proofs",
|
||||
@@ -6234,7 +6236,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "uniffi_testing"
|
||||
version = "0.28.0"
|
||||
source = "git+https://github.com/ElusAegis/uniffi-rs?branch=feat%2Ftesting-feature-build-fix#4684b9e7da2d9c964c2b3a73883947aab7370a06"
|
||||
source = "git+https://github.com/ElusAegis/uniffi-rs?branch=feat/testing-feature-build-fix#4684b9e7da2d9c964c2b3a73883947aab7370a06"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"camino",
|
||||
|
||||
12
Cargo.toml
12
Cargo.toml
@@ -40,6 +40,7 @@ maybe-rayon = { version = "0.1.1", default-features = false }
|
||||
bincode = { version = "1.3.3", default-features = false }
|
||||
unzip-n = "0.1.2"
|
||||
num = "0.4.1"
|
||||
portable-atomic = { version = "1.6.0", optional = true }
|
||||
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand", optional = true }
|
||||
semver = { version = "1.0.22", optional = true }
|
||||
|
||||
@@ -73,6 +74,7 @@ tokio-postgres = { version = "0.7.10", optional = true }
|
||||
pg_bigdecimal = { version = "0.1.5", optional = true }
|
||||
lazy_static = { version = "1.4.0", optional = true }
|
||||
colored_json = { version = "3.0.1", default-features = false, optional = true }
|
||||
regex = { version = "1", default-features = false, optional = true }
|
||||
tokio = { version = "1.35.0", default-features = false, features = [
|
||||
"macros",
|
||||
"rt-multi-thread",
|
||||
@@ -242,14 +244,16 @@ ezkl = [
|
||||
"dep:indicatif",
|
||||
"dep:gag",
|
||||
"dep:reqwest",
|
||||
"dep:openssl",
|
||||
"dep:tokio-postgres",
|
||||
"dep:pg_bigdecimal",
|
||||
"dep:lazy_static",
|
||||
"dep:regex",
|
||||
"dep:tokio",
|
||||
"dep:openssl",
|
||||
"dep:mimalloc",
|
||||
"dep:chrono",
|
||||
"dep:sha256",
|
||||
"dep:portable-atomic",
|
||||
"dep:clap_complete",
|
||||
"dep:halo2_solidity_verifier",
|
||||
"dep:semver",
|
||||
@@ -288,13 +292,9 @@ uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "fea
|
||||
rustflags = ["-C", "relocation-model=pic"]
|
||||
lto = "fat"
|
||||
codegen-units = 1
|
||||
#panic = "abort"
|
||||
# panic = "abort"
|
||||
|
||||
|
||||
[profile.test-runs]
|
||||
inherits = "dev"
|
||||
opt-level = 3
|
||||
|
||||
[package.metadata.wasm-pack.profile.release]
|
||||
wasm-opt = [
|
||||
"-O4",
|
||||
|
||||
@@ -10,7 +10,6 @@ use rand::Rng;
|
||||
|
||||
// Assuming these are your types
|
||||
#[derive(Clone)]
|
||||
#[allow(dead_code)]
|
||||
enum ValType {
|
||||
Constant(F),
|
||||
AssignedConstant(usize, F),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ezkl
|
||||
|
||||
project = 'ezkl'
|
||||
release = '18.1.9'
|
||||
release = '0.0.0'
|
||||
version = release
|
||||
|
||||
|
||||
|
||||
@@ -77,7 +77,6 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gip_run_args = ezkl.PyRunArgs()\n",
|
||||
"gip_run_args.ignore_range_check_inputs_outputs = True\n",
|
||||
"gip_run_args.input_visibility = \"polycommit\" # matrix and generalized inverse commitments\n",
|
||||
"gip_run_args.output_visibility = \"fixed\" # no parameters used\n",
|
||||
"gip_run_args.param_visibility = \"fixed\" # should be Tensor(True)"
|
||||
@@ -336,9 +335,9 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.13"
|
||||
"version": "3.9.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -308,11 +308,8 @@
|
||||
"compiled_filename = os.path.join('lol.compiled')\n",
|
||||
"settings_filename = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"run_args = ezkl.PyRunArgs()\n",
|
||||
"run_args.decomp_legs = 4\n",
|
||||
"\n",
|
||||
"# Generate settings using ezkl\n",
|
||||
"res = ezkl.gen_settings(onnx_filename, settings_filename, py_run_args=run_args)\n",
|
||||
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
|
||||
@@ -167,8 +167,6 @@
|
||||
"run_args = ezkl.PyRunArgs()\n",
|
||||
"# \"hashed/private\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
|
||||
"run_args.input_visibility = \"hashed/private/0\"\n",
|
||||
"# as the inputs are felts we turn off input range checks\n",
|
||||
"run_args.ignore_range_check_inputs_outputs = True\n",
|
||||
"# we set it to fix the set we want to check membership for\n",
|
||||
"run_args.param_visibility = \"fixed\"\n",
|
||||
"# the output is public -- set membership fails if it is not = 0\n",
|
||||
@@ -521,4 +519,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
|
||||
@@ -204,7 +204,6 @@
|
||||
"run_args = ezkl.PyRunArgs()\n",
|
||||
"# \"polycommit\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
|
||||
"run_args.input_visibility = \"polycommit\"\n",
|
||||
"run_args.ignore_range_check_inputs_outputs = True\n",
|
||||
"# the parameters are public\n",
|
||||
"run_args.param_visibility = \"fixed\"\n",
|
||||
"# the output is public (this is the inequality test)\n",
|
||||
@@ -515,4 +514,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -60,7 +60,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -94,7 +94,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -134,7 +134,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 44,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -183,7 +183,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -201,7 +201,6 @@
|
||||
"run_args.input_visibility = \"public\"\n",
|
||||
"run_args.param_visibility = \"private\"\n",
|
||||
"run_args.output_visibility = \"public\"\n",
|
||||
"run_args.decomp_legs=6\n",
|
||||
"run_args.num_inner_cols = 1\n",
|
||||
"run_args.variables = [(\"batch_size\", 1)]"
|
||||
]
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
// ignore file if compiling for wasm
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use mimalloc::MiMalloc;
|
||||
|
||||
#[global_allocator]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
static GLOBAL: MiMalloc = MiMalloc;
|
||||
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use clap::{CommandFactory, Parser};
|
||||
|
||||
@@ -207,9 +207,6 @@ struct PyRunArgs {
|
||||
/// bool: Should the circuit use unbounded lookups for log
|
||||
#[pyo3(get, set)]
|
||||
pub bounded_log_lookup: bool,
|
||||
/// bool: Should the circuit use range checks for inputs and outputs (set to false if the input is a felt)
|
||||
#[pyo3(get, set)]
|
||||
pub ignore_range_check_inputs_outputs: bool,
|
||||
}
|
||||
|
||||
/// default instantiation of PyRunArgs
|
||||
@@ -242,7 +239,6 @@ impl From<PyRunArgs> for RunArgs {
|
||||
commitment: Some(py_run_args.commitment.into()),
|
||||
decomp_base: py_run_args.decomp_base,
|
||||
decomp_legs: py_run_args.decomp_legs,
|
||||
ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -267,7 +263,6 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
commitment: self.commitment.into(),
|
||||
decomp_base: self.decomp_base,
|
||||
decomp_legs: self.decomp_legs,
|
||||
ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,10 +170,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
message: &[ValTensor<Fp>],
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<Self::InputAssignments, ModuleError> {
|
||||
if message.len() != 1 {
|
||||
return Err(ModuleError::InputWrongLength(message.len()));
|
||||
}
|
||||
|
||||
assert_eq!(message.len(), 1);
|
||||
let message = message[0].clone();
|
||||
|
||||
let start_time = instant::Instant::now();
|
||||
@@ -228,7 +225,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
}
|
||||
e => Err(ModuleError::WrongInputType(
|
||||
format!("{:?}", e),
|
||||
"AssignedValue".to_string(),
|
||||
"PrevAssigned".to_string(),
|
||||
)),
|
||||
}
|
||||
})
|
||||
@@ -293,12 +290,6 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<ValTensor<Fp>, ModuleError> {
|
||||
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
|
||||
|
||||
// empty hash case
|
||||
if input_cells.is_empty() {
|
||||
return Ok(input[0].clone());
|
||||
}
|
||||
|
||||
// extract the values from the input cells
|
||||
let mut assigned_input: Tensor<ValType<Fp>> =
|
||||
input_cells.iter().map(|e| ValType::from(e.clone())).into();
|
||||
@@ -520,21 +511,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn poseidon_hash_empty() {
|
||||
let message = [];
|
||||
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 2>::run(message.to_vec()).unwrap();
|
||||
let mut message: Tensor<ValType<Fp>> =
|
||||
message.into_iter().map(|m| Value::known(m).into()).into();
|
||||
let k = 9;
|
||||
let circuit = HashCircuit::<PoseidonSpec, 2> {
|
||||
message: message.into(),
|
||||
_spec: PhantomData,
|
||||
};
|
||||
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
|
||||
assert_eq!(prover.verify(), Ok(()))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn poseidon_hash() {
|
||||
let rng = rand::rngs::OsRng;
|
||||
|
||||
@@ -17,6 +17,7 @@ pub enum BaseOp {
|
||||
Sub,
|
||||
SumInit,
|
||||
Sum,
|
||||
IsBoolean,
|
||||
}
|
||||
|
||||
/// Matches a [BaseOp] to an operation over inputs
|
||||
@@ -33,6 +34,7 @@ impl BaseOp {
|
||||
BaseOp::Add => a + b,
|
||||
BaseOp::Sub => a - b,
|
||||
BaseOp::Mult => a * b,
|
||||
BaseOp::IsBoolean => b,
|
||||
_ => panic!("nonaccum_f called on accumulating operation"),
|
||||
}
|
||||
}
|
||||
@@ -72,6 +74,7 @@ impl BaseOp {
|
||||
BaseOp::Mult => "MULT",
|
||||
BaseOp::Sum => "SUM",
|
||||
BaseOp::SumInit => "SUMINIT",
|
||||
BaseOp::IsBoolean => "ISBOOLEAN",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,6 +90,7 @@ impl BaseOp {
|
||||
BaseOp::Mult => (0, 1),
|
||||
BaseOp::Sum => (-1, 2),
|
||||
BaseOp::SumInit => (0, 1),
|
||||
BaseOp::IsBoolean => (0, 1),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,6 +106,7 @@ impl BaseOp {
|
||||
BaseOp::Mult => 2,
|
||||
BaseOp::Sum => 1,
|
||||
BaseOp::SumInit => 1,
|
||||
BaseOp::IsBoolean => 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,6 +122,7 @@ impl BaseOp {
|
||||
BaseOp::SumInit => 0,
|
||||
BaseOp::CumProd => 1,
|
||||
BaseOp::CumProdInit => 0,
|
||||
BaseOp::IsBoolean => 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::str::FromStr;
|
||||
|
||||
use halo2_proofs::{
|
||||
circuit::Layouter,
|
||||
plonk::{ConstraintSystem, Constraints, Expression, Selector, TableColumn},
|
||||
plonk::{ConstraintSystem, Constraints, Expression, Selector},
|
||||
poly::Rotation,
|
||||
};
|
||||
use log::debug;
|
||||
@@ -341,8 +341,6 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
/// Activate sanity checks
|
||||
pub check_mode: CheckMode,
|
||||
_marker: PhantomData<F>,
|
||||
/// shared table inputs
|
||||
pub shared_table_inputs: Vec<TableColumn>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
@@ -355,7 +353,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
shuffles: Shuffles::dummy(col_size, num_inner_cols),
|
||||
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
|
||||
check_mode: CheckMode::SAFE,
|
||||
shared_table_inputs: vec![],
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
@@ -394,6 +391,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector());
|
||||
nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector());
|
||||
nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector());
|
||||
nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -427,13 +425,24 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
// Get output expressions for each input channel
|
||||
let (rotation_offset, rng) = base_op.query_offset_rng();
|
||||
|
||||
let constraints = {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
|
||||
.expect("non accum: output query failed");
|
||||
let constraints = match base_op {
|
||||
BaseOp::IsBoolean => {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
|
||||
.expect("non accum: output query failed");
|
||||
|
||||
let res = base_op.nonaccum_f((qis[0].clone(), qis[1].clone()));
|
||||
vec![expected_output[base_op.constraint_idx()].clone() - res]
|
||||
let output = expected_output[base_op.constraint_idx()].clone();
|
||||
|
||||
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
|
||||
}
|
||||
_ => {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
|
||||
.expect("non accum: output query failed");
|
||||
|
||||
let res = base_op.nonaccum_f((qis[0].clone(), qis[1].clone()));
|
||||
vec![expected_output[base_op.constraint_idx()].clone() - res]
|
||||
}
|
||||
};
|
||||
|
||||
Constraints::with_selector(selector, constraints)
|
||||
@@ -488,7 +497,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
dynamic_lookups: DynamicLookups::default(),
|
||||
shuffles: Shuffles::default(),
|
||||
range_checks: RangeChecks::default(),
|
||||
shared_table_inputs: vec![],
|
||||
check_mode,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
@@ -519,9 +527,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
return Err(CircuitError::WrongColumnType(output.name().to_string()));
|
||||
}
|
||||
|
||||
// we borrow mutably twice so we need to do this dance
|
||||
|
||||
let table = if !self.static_lookups.tables.contains_key(nl) {
|
||||
let table =
|
||||
Table::<F>::configure(cs, lookup_range, logrows, nl, &mut self.shared_table_inputs);
|
||||
// as all tables have the same input we see if there's another table who's input we can reuse
|
||||
let table = if let Some(table) = self.static_lookups.tables.values().next() {
|
||||
Table::<F>::configure(
|
||||
cs,
|
||||
lookup_range,
|
||||
logrows,
|
||||
nl,
|
||||
Some(table.table_inputs.clone()),
|
||||
)
|
||||
} else {
|
||||
Table::<F>::configure(cs, lookup_range, logrows, nl, None)
|
||||
};
|
||||
self.static_lookups.tables.insert(nl.clone(), table.clone());
|
||||
table
|
||||
} else {
|
||||
@@ -572,9 +592,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
// this is 0 if the index is the same as the column index (starting from 1)
|
||||
|
||||
let col_expr = sel.clone()
|
||||
* (table
|
||||
* table
|
||||
.selector_constructor
|
||||
.get_expr_at_idx(col_idx, synthetic_sel));
|
||||
.get_expr_at_idx(col_idx, synthetic_sel);
|
||||
|
||||
let multiplier =
|
||||
table.selector_constructor.get_selector_val_at_idx(col_idx);
|
||||
@@ -606,40 +626,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
res
|
||||
});
|
||||
}
|
||||
|
||||
// add a degree-k custom constraint of the following form to the range check and
|
||||
// static lookup configuration.
|
||||
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 − 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
|
||||
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
|
||||
cs.create_gate("range_check_on_sel", |cs| {
|
||||
let synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(1)),
|
||||
_ => match index {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
};
|
||||
|
||||
let range_check_on_synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(0)),
|
||||
_ => {
|
||||
let mut initial_expr = Expression::Constant(F::from(1));
|
||||
for i in 0..len {
|
||||
initial_expr = initial_expr
|
||||
* (synthetic_sel.clone()
|
||||
- Expression::Constant(F::from(i as u64)))
|
||||
}
|
||||
initial_expr
|
||||
}
|
||||
};
|
||||
|
||||
let sel = cs.query_selector(multi_col_selector);
|
||||
|
||||
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
|
||||
});
|
||||
|
||||
self.static_lookups
|
||||
.selectors
|
||||
.insert((nl.clone(), x, y), multi_col_selector);
|
||||
@@ -880,6 +866,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
let range_check = if let std::collections::btree_map::Entry::Vacant(e) =
|
||||
self.range_checks.ranges.entry(range)
|
||||
{
|
||||
// as all tables have the same input we see if there's another table who's input we can reuse
|
||||
let range_check = RangeCheck::<F>::configure(cs, range, logrows);
|
||||
e.insert(range_check.clone());
|
||||
range_check
|
||||
@@ -917,9 +904,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
let default_x = range_check.get_first_element(col_idx);
|
||||
|
||||
let col_expr = sel.clone()
|
||||
* (range_check
|
||||
* range_check
|
||||
.selector_constructor
|
||||
.get_expr_at_idx(col_idx, synthetic_sel));
|
||||
.get_expr_at_idx(col_idx, synthetic_sel);
|
||||
|
||||
let multiplier = range_check
|
||||
.selector_constructor
|
||||
@@ -942,40 +929,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
res
|
||||
});
|
||||
}
|
||||
|
||||
// add a degree-k custom constraint of the following form to the range check and
|
||||
// static lookup configuration.
|
||||
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 − 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
|
||||
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
|
||||
cs.create_gate("range_check_on_sel", |cs| {
|
||||
let synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(1)),
|
||||
_ => match index {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
};
|
||||
|
||||
let range_check_on_synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(0)),
|
||||
_ => {
|
||||
let mut initial_expr = Expression::Constant(F::from(1));
|
||||
for i in 0..len {
|
||||
initial_expr = initial_expr
|
||||
* (synthetic_sel.clone()
|
||||
- Expression::Constant(F::from(i as u64)))
|
||||
}
|
||||
initial_expr
|
||||
}
|
||||
};
|
||||
|
||||
let sel = cs.query_selector(multi_col_selector);
|
||||
|
||||
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
|
||||
});
|
||||
|
||||
self.range_checks
|
||||
.selectors
|
||||
.insert((range, x, y), multi_col_selector);
|
||||
|
||||
@@ -25,7 +25,7 @@ pub enum CircuitError {
|
||||
/// This operation is unsupported
|
||||
#[error("unsupported operation in graph")]
|
||||
UnsupportedOp,
|
||||
/// Invalid einsum expression
|
||||
///
|
||||
#[error("invalid einsum expression")]
|
||||
InvalidEinsum,
|
||||
/// Flush error
|
||||
@@ -103,10 +103,4 @@ pub enum CircuitError {
|
||||
#[error("an element is missing from the shuffled version of the tensor")]
|
||||
/// An element is missing from the shuffled version of the tensor
|
||||
MissingShuffleElement,
|
||||
/// Visibility has not been set
|
||||
#[error("visibility has not been set")]
|
||||
UnsetVisibility,
|
||||
/// A decomposition base overflowed
|
||||
#[error("decomposition base overflowed")]
|
||||
DecompositionBaseOverflow,
|
||||
}
|
||||
|
||||
@@ -76,10 +76,7 @@ pub enum HybridOp {
|
||||
output_scale: utils::F32,
|
||||
axes: Vec<usize>,
|
||||
},
|
||||
Output {
|
||||
tol: Tolerance,
|
||||
decomp: bool,
|
||||
},
|
||||
RangeCheck(Tolerance),
|
||||
Greater,
|
||||
GreaterEqual,
|
||||
Less,
|
||||
@@ -181,9 +178,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
input_scale, output_scale, axes
|
||||
)
|
||||
}
|
||||
HybridOp::Output { tol, decomp } => {
|
||||
format!("OUTPUT (tol={:?}, decomp={})", tol, decomp)
|
||||
}
|
||||
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
|
||||
HybridOp::Greater => "GREATER".to_string(),
|
||||
HybridOp::GreaterEqual => "GREATEREQUAL".to_string(),
|
||||
HybridOp::Less => "LESS".to_string(),
|
||||
@@ -319,13 +314,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
*output_scale,
|
||||
axes,
|
||||
)?,
|
||||
HybridOp::Output { tol, decomp } => layouts::output(
|
||||
HybridOp::RangeCheck(tol) => layouts::range_check_percent(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
tol.scale,
|
||||
tol.val,
|
||||
*decomp,
|
||||
)?,
|
||||
HybridOp::Greater => layouts::greater(config, region, values[..].try_into()?)?,
|
||||
HybridOp::GreaterEqual => {
|
||||
|
||||
@@ -75,7 +75,7 @@ fn optimum_convex_function<F: PrimeField + TensorType + PartialOrd + std::hash::
|
||||
region: &mut RegionCtx<F>,
|
||||
x: &ValTensor<F>,
|
||||
f: impl Fn(&BaseConfig<F>, &mut RegionCtx<F>, &ValTensor<F>) -> Result<ValTensor<F>, CircuitError>,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
) -> Result<(), CircuitError> {
|
||||
let one = create_constant_tensor(F::from(1), 1);
|
||||
|
||||
let f_x = f(config, region, x)?;
|
||||
@@ -87,17 +87,22 @@ fn optimum_convex_function<F: PrimeField + TensorType + PartialOrd + std::hash::
|
||||
let f_x_minus_1 = f(config, region, &x_minus_1)?;
|
||||
|
||||
// because the function is convex, the result should be the minimum of the three
|
||||
// note that we offset the x by 1 to get the next value
|
||||
// f(x) <= f(x+1) and f(x) < f(x-1)
|
||||
// not that we offset the x by 1 to get the next value
|
||||
// f(x) <= f(x+1) and f(x) <= f(x-1)
|
||||
// the result is 1 if the function is optimal solely because of the convexity of the function
|
||||
// the distances can be equal but this is only possible if f(x) and f(x+1) are both optimal, but if (f(x) = f(x + 1))
|
||||
// f(x+1) is not smaller than f(x + 1 - 1) = f(x) and thus f(x) is unique
|
||||
// the distances can be equal but this is only possible if f(x) and f(x+1) are both optimal (or f(x) and f(x-1)).
|
||||
let f_x_is_opt_rhs = less_equal(config, region, &[f_x.clone(), f_x_plus_1.clone()])?;
|
||||
let f_x_is_opt_lhs = less(config, region, &[f_x.clone(), f_x_minus_1.clone()])?;
|
||||
let f_x_is_opt_lhs = less_equal(config, region, &[f_x.clone(), f_x_minus_1.clone()])?;
|
||||
|
||||
let is_opt = and(config, region, &[f_x_is_opt_lhs, f_x_is_opt_rhs])?;
|
||||
|
||||
Ok(is_opt)
|
||||
let mut comparison_unit = create_constant_tensor(F::ONE, is_opt.len());
|
||||
comparison_unit.reshape(is_opt.dims())?;
|
||||
|
||||
// assert that the result is 1
|
||||
enforce_equality(config, region, &[is_opt, comparison_unit])?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Err is less than some constant
|
||||
@@ -155,13 +160,13 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input_dims)?;
|
||||
// implicitly check if the prover provided output is within range
|
||||
let claimed_output = identity(config, region, &[claimed_output], true)?;
|
||||
let claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
// here we decompose and extract the sign of the input
|
||||
let sign = sign(config, region, &[claimed_output.clone()])?;
|
||||
// check if x is too large only if the decomp would support overflow in the previous op
|
||||
if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
|
||||
// here we decompose and extract the sign of the input
|
||||
let sign = sign(config, region, &[claimed_output.clone()])?;
|
||||
|
||||
let abs_value = pairwise(
|
||||
config,
|
||||
region,
|
||||
@@ -221,9 +226,9 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input_dims)?;
|
||||
let claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
// implicitly check if the prover provided output is within range
|
||||
let claimed_output = identity(config, region, &[claimed_output], true)?;
|
||||
// divide by input_scale
|
||||
let zero_inverse_val =
|
||||
tensor::ops::nonlinearities::zero_recip(felt_to_integer_rep(output_scale) as f64)[0];
|
||||
@@ -254,10 +259,10 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
// here we decompose and extract the sign of the input
|
||||
let sign = sign(config, region, &[masked_output.clone()])?;
|
||||
// check if x is too large only if the decomp would support overflow in the previous op
|
||||
if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
|
||||
// here we decompose and extract the sign of the input
|
||||
let sign = sign(config, region, &[masked_output.clone()])?;
|
||||
let abs_value = pairwise(
|
||||
config,
|
||||
region,
|
||||
@@ -285,14 +290,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
Ok(distance)
|
||||
};
|
||||
|
||||
// we need to add 1 to the points where it is zero to ignore the cvx opt conditions at those points
|
||||
let mut is_opt = optimum_convex_function(config, region, &claimed_output, err_func)?;
|
||||
is_opt = pairwise(config, region, &[is_opt, equal_zero_mask], BaseOp::Add)?;
|
||||
|
||||
let mut comparison_unit = create_constant_tensor(F::ONE, is_opt.len());
|
||||
comparison_unit.reshape(is_opt.dims())?;
|
||||
// assert that the result is 1
|
||||
enforce_equality(config, region, &[is_opt, comparison_unit])?;
|
||||
optimum_convex_function(config, region, &claimed_output, err_func)?;
|
||||
|
||||
Ok(claimed_output)
|
||||
}
|
||||
@@ -346,8 +344,12 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input_dims)?;
|
||||
// force the output to be positive or zero, also implicitly checks that the ouput is in range
|
||||
let claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
// force the output to be positive or zero
|
||||
let claimed_output = abs(config, region, &[claimed_output.clone()])?;
|
||||
|
||||
// rescaled input
|
||||
let rescaled_input = pairwise(config, region, &[input.clone(), unit_scale], BaseOp::Mult)?;
|
||||
|
||||
@@ -360,13 +362,7 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
Ok(distance)
|
||||
};
|
||||
|
||||
let is_opt = optimum_convex_function(config, region, &claimed_output, err_func)?;
|
||||
|
||||
let mut comparison_unit = create_constant_tensor(F::ONE, is_opt.len());
|
||||
comparison_unit.reshape(is_opt.dims())?;
|
||||
|
||||
// assert that the result is 1
|
||||
enforce_equality(config, region, &[is_opt, comparison_unit])?;
|
||||
optimum_convex_function(config, region, &claimed_output, err_func)?;
|
||||
|
||||
Ok(claimed_output)
|
||||
}
|
||||
@@ -1181,6 +1177,8 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
|
||||
|
||||
region.enable(Some(lookup_selector), z)?;
|
||||
|
||||
// region.enable(Some(lookup_selector), z)?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, CircuitError>>()?;
|
||||
@@ -1200,7 +1198,7 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
|
||||
/// 4. index_output is (typically) a prover generated witness committed to in an advice column
|
||||
/// 5. value_output is (typically) a prover generated witness committed to in an advice column
|
||||
/// 6. Given the above, and given the fixed index_input , we go through every (index_input, value_input) pair and ascertain that it is contained in the input.
|
||||
/// 7. Given the fixed incrementing index index_input, we avoid multiplicity in the output by leveraging this surrogate index: if index_output isn't matched to the exact value where for `index_input=index_output` -> `value_input=value_output`, then the lookup fails
|
||||
/// Given the fixed incrementing index index_input, we avoid multiplicity in the output by leveraging this surrogate index: if index_output isn't matched to the exact value where for `index_input=index_output` -> `value_input=value_output`, then the lookup fails
|
||||
pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
@@ -2979,7 +2977,7 @@ pub fn xor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
let lhs_and_rhs_not = and(config, region, &[lhs, rhs_not.clone()])?;
|
||||
let lhs_not_and_rhs = and(config, region, &[rhs, lhs_not])?;
|
||||
|
||||
// we can safely use add and not OR here because we know that lhs_and_rhs_not and lhs_not_and_rhs are =1 at different indices
|
||||
// we can safely use add and not OR here because we know that lhs_and_rhs_not and lhs_not_and_rhs are =1 at different incices
|
||||
let res: ValTensor<F> = pairwise(
|
||||
config,
|
||||
region,
|
||||
@@ -3256,11 +3254,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
.map(|(i, d)| {
|
||||
let d = padding[i].0 + d + padding[i].1;
|
||||
d.checked_sub(pool_dims[i])
|
||||
.ok_or_else(|| TensorError::Overflow("max_pool".to_string()))?
|
||||
.ok_or_else(|| TensorError::Overflow("conv".to_string()))?
|
||||
.checked_div(stride[i])
|
||||
.ok_or_else(|| TensorError::Overflow("max_pool".to_string()))?
|
||||
.ok_or_else(|| TensorError::Overflow("conv".to_string()))?
|
||||
.checked_add(1)
|
||||
.ok_or_else(|| TensorError::Overflow("max_pool".to_string()))
|
||||
.ok_or_else(|| TensorError::Overflow("conv".to_string()))
|
||||
})
|
||||
.collect::<Result<Vec<_>, TensorError>>()?;
|
||||
|
||||
@@ -3919,24 +3917,11 @@ pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
decomp: bool,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let mut output = values[0].clone();
|
||||
if !output.all_prev_assigned() {
|
||||
// checks they are in range
|
||||
if decomp {
|
||||
output = decompose(
|
||||
config,
|
||||
region,
|
||||
&[output.clone()],
|
||||
®ion.base(),
|
||||
®ion.legs(),
|
||||
)?
|
||||
.1;
|
||||
} else {
|
||||
output = region.assign(&config.custom_gates.output, &values[0])?;
|
||||
region.increment(output.len());
|
||||
}
|
||||
output = region.assign(&config.custom_gates.output, &values[0])?;
|
||||
region.increment(output.len());
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
@@ -3957,8 +3942,23 @@ pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd + std::ha
|
||||
} else {
|
||||
values[0].clone()
|
||||
};
|
||||
// Enable the selectors
|
||||
if !region.is_dummy() {
|
||||
(0..output.len())
|
||||
.map(|j| {
|
||||
let index = region.linear_coord() - j - 1;
|
||||
|
||||
range_check(config, region, values, &(0, 1))?;
|
||||
let (x, y, z) = config.custom_gates.output.cartesian_coord(index);
|
||||
let selector = config
|
||||
.custom_gates
|
||||
.selectors
|
||||
.get(&(BaseOp::IsBoolean, x, y));
|
||||
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, CircuitError>>()?;
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
@@ -4411,7 +4411,7 @@ pub fn floor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
legs: usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// decompose with base scale and then set the last element to zero
|
||||
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?.0;
|
||||
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
|
||||
// set the last element to zero and then recompose, we don't actually need to assign here
|
||||
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
|
||||
let zero = ValType::Constant(F::ZERO);
|
||||
@@ -4524,7 +4524,7 @@ pub fn ceil<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
legs: usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// decompose with base scale and then set the last element to zero
|
||||
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?.0;
|
||||
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
|
||||
// set the last element to zero and then recompose, we don't actually need to assign here
|
||||
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
|
||||
let zero = ValType::Constant(F::ZERO);
|
||||
@@ -4678,7 +4678,7 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input.dims())?;
|
||||
let claimed_output = identity(&config, region, &[claimed_output], true)?;
|
||||
region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
let pow2_of_claimed_output = nonlinearity(
|
||||
@@ -4924,7 +4924,7 @@ pub fn round<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
legs: usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// decompose with base scale and then set the last element to zero
|
||||
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?.0;
|
||||
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
|
||||
// set the last element to zero and then recompose, we don't actually need to assign here
|
||||
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
|
||||
let zero = ValType::Constant(F::ZERO);
|
||||
@@ -4942,7 +4942,6 @@ pub fn round<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
1,
|
||||
);
|
||||
let assigned_midway_point = region.assign(&config.custom_gates.inputs[1], &midway_point)?;
|
||||
region.increment(assigned_midway_point.len());
|
||||
|
||||
let dims = decomposition.dims().to_vec();
|
||||
let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec();
|
||||
@@ -5068,7 +5067,7 @@ pub fn round_half_to_even<F: PrimeField + TensorType + PartialOrd + std::hash::H
|
||||
legs: usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// decompose with base scale and then set the last element to zero
|
||||
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?.0;
|
||||
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
|
||||
// set the last element to zero and then recompose, we don't actually need to assign here
|
||||
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
|
||||
let zero = ValType::Constant(F::ZERO);
|
||||
@@ -5176,64 +5175,58 @@ pub(crate) fn recompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has
|
||||
values: &[ValTensor<F>; 1],
|
||||
base: &usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let mut input = values[0].clone();
|
||||
let input = values[0].clone();
|
||||
|
||||
let first_dims = input.dims().to_vec()[..input.dims().len() - 1].to_vec();
|
||||
let num_first_dims = first_dims.iter().product::<usize>();
|
||||
let n = input.dims().last().unwrap() - 1;
|
||||
|
||||
if !input.all_prev_assigned() {
|
||||
input = region.assign(&config.custom_gates.inputs[0], &input)?;
|
||||
region.increment(input.len());
|
||||
}
|
||||
let is_assigned = !input.all_prev_assigned();
|
||||
|
||||
// to force the bases to be assigned
|
||||
if input.is_singleton() {
|
||||
input.reshape(&[1])?;
|
||||
}
|
||||
|
||||
let mut bases: ValTensor<F> = Tensor::from({
|
||||
(0..num_first_dims)
|
||||
.flat_map(|_| {
|
||||
(0..n).rev().map(|x| {
|
||||
let base = (*base).checked_pow(x as u32);
|
||||
if let Some(base) = base {
|
||||
Ok(ValType::Constant(integer_rep_to_felt(base as IntegerRep)))
|
||||
} else {
|
||||
Err(CircuitError::DecompositionBaseOverflow)
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, CircuitError>>()?
|
||||
.into_iter()
|
||||
})
|
||||
let bases: ValTensor<F> = Tensor::from(
|
||||
(0..n)
|
||||
.rev()
|
||||
.map(|x| ValType::Constant(integer_rep_to_felt(base.pow(x as u32) as IntegerRep))),
|
||||
)
|
||||
.into();
|
||||
let mut bases_dims = first_dims.clone();
|
||||
bases_dims.push(n);
|
||||
bases.reshape(&bases_dims)?;
|
||||
|
||||
// equation needs to be constructed as ij,j->i but for arbitrary n dims we need to construct this dynamically
|
||||
// indices should map in order of the alphabet
|
||||
// start with lhs
|
||||
let lhs = ASCII_ALPHABET.chars().take(input.dims().len()).join("");
|
||||
let rhs = ASCII_ALPHABET.chars().take(input.dims().len() - 1).join("");
|
||||
// multiply and sum the values
|
||||
let mut output: Tensor<Tensor<ValType<F>>> = Tensor::new(None, &first_dims)?;
|
||||
|
||||
let equation = format!("{},{}->{}", lhs, lhs, rhs);
|
||||
let cartesian_coord = first_dims
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut sign_slice = first_dims.iter().map(|x| 0..*x).collect::<Vec<_>>();
|
||||
sign_slice.push(0..1);
|
||||
let mut rest_slice = first_dims.iter().map(|x| 0..*x).collect::<Vec<_>>();
|
||||
rest_slice.push(1..n + 1);
|
||||
let inner_loop_function =
|
||||
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let mut sliced_input = input.get_slice(&slice)?;
|
||||
sliced_input.flatten();
|
||||
|
||||
let sign = input.get_slice(&sign_slice)?;
|
||||
let rest = input.get_slice(&rest_slice)?;
|
||||
if !is_assigned {
|
||||
sliced_input = region.assign(&config.custom_gates.inputs[0], &sliced_input)?;
|
||||
}
|
||||
|
||||
// now add the rhs
|
||||
let prod_recomp = einsum(config, region, &[rest.clone(), bases], &equation)?;
|
||||
let mut signed_recomp = pairwise(config, region, &[prod_recomp, sign], BaseOp::Mult)?;
|
||||
signed_recomp.reshape(&first_dims)?;
|
||||
// get the sign bit and make sure it is valid
|
||||
let sign = sliced_input.first()?;
|
||||
let rest = sliced_input.get_slice(&[1..sliced_input.len()])?;
|
||||
|
||||
Ok(signed_recomp.into())
|
||||
let prod_decomp = dot(config, region, &[rest, bases.clone()])?;
|
||||
|
||||
let signed_decomp = pairwise(config, region, &[prod_decomp, sign], BaseOp::Mult)?;
|
||||
|
||||
Ok(signed_decomp.get_inner_tensor()?.clone())
|
||||
};
|
||||
|
||||
region.apply_in_loop(&mut output, inner_loop_function)?;
|
||||
|
||||
let mut combined_output = output.combine()?;
|
||||
|
||||
combined_output.reshape(&first_dims)?;
|
||||
|
||||
Ok(combined_output.into())
|
||||
}
|
||||
|
||||
pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
@@ -5242,35 +5235,24 @@ pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has
|
||||
values: &[ValTensor<F>; 1],
|
||||
base: &usize,
|
||||
n: &usize,
|
||||
) -> Result<(ValTensor<F>, ValTensor<F>), CircuitError> {
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let mut input = values[0].clone();
|
||||
|
||||
if !input.all_prev_assigned() {
|
||||
let is_assigned = !input.all_prev_assigned();
|
||||
|
||||
if !is_assigned {
|
||||
input = region.assign(&config.custom_gates.inputs[0], &input)?;
|
||||
}
|
||||
|
||||
// to force the bases to be assigned
|
||||
if input.is_singleton() {
|
||||
input.reshape(&[1])?;
|
||||
}
|
||||
|
||||
let mut bases: ValTensor<F> = Tensor::from({
|
||||
(0..input.len())
|
||||
.flat_map(|_| {
|
||||
(0..*n).rev().map(|x| {
|
||||
let base = (*base).checked_pow(x as u32);
|
||||
if let Some(base) = base {
|
||||
Ok(ValType::Constant(integer_rep_to_felt(base as IntegerRep)))
|
||||
} else {
|
||||
Err(CircuitError::DecompositionBaseOverflow)
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, CircuitError>>()?
|
||||
.into_iter()
|
||||
})
|
||||
let mut bases: ValTensor<F> = Tensor::from(
|
||||
// repeat it input.len() times
|
||||
(0..input.len()).flat_map(|_| {
|
||||
(0..*n)
|
||||
.rev()
|
||||
.map(|x| ValType::Constant(integer_rep_to_felt(base.pow(x as u32) as IntegerRep)))
|
||||
}),
|
||||
)
|
||||
.into();
|
||||
|
||||
let mut bases_dims = input.dims().to_vec();
|
||||
bases_dims.push(*n);
|
||||
bases.reshape(&bases_dims)?;
|
||||
@@ -5289,7 +5271,7 @@ pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has
|
||||
|
||||
claimed_output.into()
|
||||
};
|
||||
let claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
let input_slice = input.dims().iter().map(|x| 0..*x).collect::<Vec<_>>();
|
||||
@@ -5334,9 +5316,9 @@ pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has
|
||||
|
||||
let signed_decomp = pairwise(config, region, &[prod_decomp, sign], BaseOp::Mult)?;
|
||||
|
||||
enforce_equality(config, region, &[input.clone(), signed_decomp])?;
|
||||
enforce_equality(config, region, &[input, signed_decomp])?;
|
||||
|
||||
Ok((claimed_output, input))
|
||||
Ok(claimed_output)
|
||||
}
|
||||
|
||||
pub(crate) fn sign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
@@ -5344,7 +5326,7 @@ pub(crate) fn sign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let mut decomp = decompose(config, region, values, ®ion.base(), ®ion.legs())?.0;
|
||||
let mut decomp = decompose(config, region, values, ®ion.base(), ®ion.legs())?;
|
||||
// get every n elements now, which correspond to the sign bit
|
||||
decomp.get_every_n(region.legs() + 1)?;
|
||||
decomp.reshape(values[0].dims())?;
|
||||
@@ -5626,7 +5608,7 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::output;
|
||||
/// use ezkl::circuit::ops::layouts::range_check_percent;
|
||||
/// use ezkl::tensor::val::ValTensor;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
@@ -5644,33 +5626,29 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// Some(&[101, 201, 302, 403, 503, 603]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap());
|
||||
/// let result = output::<Fp>(&dummy_config, &mut dummy_region, &[x, y], 1024.0.into(), 1.0, false).unwrap();
|
||||
/// let result = range_check_percent::<Fp>(&dummy_config, &mut dummy_region, &[x, y], 1024.0.into(), 1.0).unwrap();
|
||||
/// ```
|
||||
pub fn output<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
scale: utils::F32,
|
||||
tol: f32,
|
||||
decomp: bool,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let mut values = [values[0].clone(), values[1].clone()];
|
||||
|
||||
if !values[0].all_prev_assigned() {
|
||||
// range check the outputs
|
||||
values[0] = layouts::identity(config, region, &[values[0].clone()], decomp)?;
|
||||
}
|
||||
|
||||
if !values[1].all_prev_assigned() {
|
||||
// range check the outputs
|
||||
values[1] = layouts::identity(config, region, &[values[1].clone()], decomp)?;
|
||||
}
|
||||
|
||||
if tol == 0.0 {
|
||||
// regular equality constraint
|
||||
return enforce_equality(config, region, &[values[0].clone(), values[1].clone()]);
|
||||
return enforce_equality(config, region, values);
|
||||
}
|
||||
|
||||
let mut values = [values[0].clone(), values[1].clone()];
|
||||
|
||||
values[0] = region.assign(&config.custom_gates.inputs[0], &values[0])?;
|
||||
values[1] = region.assign(&config.custom_gates.inputs[1], &values[1])?;
|
||||
let total_assigned_0 = values[0].len();
|
||||
let total_assigned_1 = values[1].len();
|
||||
let total_assigned = std::cmp::max(total_assigned_0, total_assigned_1);
|
||||
region.increment(total_assigned);
|
||||
|
||||
// Calculate the difference between the expected output and actual output
|
||||
let diff = pairwise(config, region, &values, BaseOp::Sub)?;
|
||||
|
||||
|
||||
@@ -159,8 +159,6 @@ pub struct Input {
|
||||
pub scale: crate::Scale,
|
||||
///
|
||||
pub datum_type: InputType,
|
||||
/// decomp check
|
||||
pub decomp: bool,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
|
||||
@@ -198,7 +196,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
self.decomp,
|
||||
)?)),
|
||||
}
|
||||
} else {
|
||||
@@ -254,26 +251,20 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
|
||||
///
|
||||
#[serde(skip)]
|
||||
pub pre_assigned_val: Option<ValTensor<F>>,
|
||||
///
|
||||
pub decomp: bool,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
|
||||
///
|
||||
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>, decomp: bool) -> Self {
|
||||
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
|
||||
Self {
|
||||
quantized_values,
|
||||
raw_values,
|
||||
pre_assigned_val: None,
|
||||
decomp,
|
||||
}
|
||||
}
|
||||
/// Rebase the scale of the constant
|
||||
pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), CircuitError> {
|
||||
let visibility = match self.quantized_values.visibility() {
|
||||
Some(v) => v,
|
||||
None => return Err(CircuitError::UnsetVisibility),
|
||||
};
|
||||
let visibility = self.quantized_values.visibility().unwrap();
|
||||
self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -317,12 +308,7 @@ impl<
|
||||
self.quantized_values.clone().try_into()?
|
||||
};
|
||||
// we gotta constrain it once if its used multiple times
|
||||
Ok(Some(layouts::identity(
|
||||
config,
|
||||
region,
|
||||
&[value],
|
||||
self.decomp,
|
||||
)?))
|
||||
Ok(Some(layouts::identity(config, region, &[value])?))
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn Op<F>> {
|
||||
|
||||
@@ -252,12 +252,6 @@ impl<
|
||||
)?,
|
||||
PolyOp::GatherElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
if values.len() != 1 {
|
||||
return Err(TensorError::DimError(
|
||||
"GatherElements only accepts single inputs".to_string(),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
|
||||
} else {
|
||||
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
|
||||
@@ -275,12 +269,6 @@ impl<
|
||||
}
|
||||
PolyOp::ScatterElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
if values.len() != 2 {
|
||||
return Err(TensorError::DimError(
|
||||
"ScatterElements requires two inputs".to_string(),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
tensor::ops::scatter(
|
||||
values[0].get_inner_tensor()?,
|
||||
idx,
|
||||
@@ -323,9 +311,7 @@ impl<
|
||||
PolyOp::Mult => {
|
||||
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
|
||||
}
|
||||
PolyOp::Identity { .. } => {
|
||||
layouts::identity(config, region, values[..].try_into()?, false)?
|
||||
}
|
||||
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
|
||||
PolyOp::Pad(p) => {
|
||||
if values.len() != 1 {
|
||||
|
||||
@@ -132,16 +132,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
(first_element, op_f.output[0])
|
||||
}
|
||||
|
||||
/// calculates the column size given the number of rows and reserved blinding rows
|
||||
///
|
||||
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(logrows as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
///
|
||||
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(bits as u32) - reserved_blinding_rows
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
|
||||
// number of cols needed to store the range
|
||||
(range_len / col_size as IntegerRep) as usize + 1
|
||||
(range_len / (col_size as IntegerRep)) as usize + 1
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
@@ -163,7 +168,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
range: Range,
|
||||
logrows: usize,
|
||||
nonlinearity: &LookupOp,
|
||||
preexisting_inputs: &mut Vec<TableColumn>,
|
||||
preexisting_inputs: Option<Vec<TableColumn>>,
|
||||
) -> Table<F> {
|
||||
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
|
||||
let col_size = Self::cal_col_size(logrows, factors);
|
||||
@@ -172,28 +177,28 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
|
||||
debug!("table range: {:?}", range);
|
||||
|
||||
// validate enough columns are provided to store the range
|
||||
if preexisting_inputs.len() < num_cols {
|
||||
// add columns to match the required number of columns
|
||||
let diff = num_cols - preexisting_inputs.len();
|
||||
for _ in 0..diff {
|
||||
preexisting_inputs.push(cs.lookup_table_column());
|
||||
let table_inputs = preexisting_inputs.unwrap_or_else(|| {
|
||||
let mut cols = vec![];
|
||||
for _ in 0..num_cols {
|
||||
cols.push(cs.lookup_table_column());
|
||||
}
|
||||
}
|
||||
cols
|
||||
});
|
||||
|
||||
let num_cols = table_inputs.len();
|
||||
|
||||
let num_cols = preexisting_inputs.len();
|
||||
if num_cols > 1 {
|
||||
warn!("Using {} columns for non-linearity table.", num_cols);
|
||||
}
|
||||
|
||||
let table_outputs = preexisting_inputs
|
||||
let table_outputs = table_inputs
|
||||
.iter()
|
||||
.map(|_| cs.lookup_table_column())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Table {
|
||||
nonlinearity: nonlinearity.clone(),
|
||||
table_inputs: preexisting_inputs.clone(),
|
||||
table_inputs,
|
||||
table_outputs,
|
||||
is_assigned: false,
|
||||
selector_constructor: SelectorConstructor::new(num_cols),
|
||||
@@ -350,11 +355,16 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0)
|
||||
}
|
||||
|
||||
/// calculates the column size
|
||||
///
|
||||
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(logrows as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
///
|
||||
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(bits as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
/// get column index given input
|
||||
pub fn get_col_index(&self, input: F) -> F {
|
||||
// range is split up into chunks of size col_size, find the chunk that input is in
|
||||
|
||||
@@ -517,7 +517,6 @@ pub async fn deploy_da_verifier_via_solidity(
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn deploy_multi_da_contract(
|
||||
client: EthersClient,
|
||||
contract_instance_offset: usize,
|
||||
|
||||
@@ -118,7 +118,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
} => gen_srs_cmd(
|
||||
srs_path,
|
||||
logrows as u32,
|
||||
commitment.unwrap_or_else(|| Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
|
||||
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
|
||||
),
|
||||
Commands::GetSrs {
|
||||
srs_path,
|
||||
@@ -1535,8 +1535,7 @@ pub(crate) async fn create_evm_data_attestation(
|
||||
trace!("params computed");
|
||||
|
||||
// if input is not provided, we just instantiate dummy input data
|
||||
let data =
|
||||
GraphData::from_path(input).unwrap_or_else(|_| GraphData::new(DataSource::File(vec![])));
|
||||
let data = GraphData::from_path(input).unwrap_or(GraphData::new(DataSource::File(vec![])));
|
||||
|
||||
// The number of input and output instances we attest to for the single call data attestation
|
||||
let mut input_len = None;
|
||||
@@ -2127,7 +2126,6 @@ pub(crate) fn mock_aggregate(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn setup_aggregate(
|
||||
sample_snarks: Vec<PathBuf>,
|
||||
vk_path: PathBuf,
|
||||
|
||||
@@ -9,8 +9,6 @@ pub type IntegerRep = i128;
|
||||
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
|
||||
if x >= 0 {
|
||||
F::from_u128(x as u128)
|
||||
} else if x == IntegerRep::MIN {
|
||||
-F::from_u128(x.saturating_neg() as u128) - F::ONE
|
||||
} else {
|
||||
-F::from_u128(x.saturating_neg() as u128)
|
||||
}
|
||||
@@ -34,9 +32,6 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
|
||||
/// Converts a PrimeField element to an i64.
|
||||
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
|
||||
if x > F::from_u128(IntegerRep::MAX as u128) {
|
||||
if x == -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
|
||||
return IntegerRep::MIN;
|
||||
}
|
||||
let rep = (-x).to_repr();
|
||||
let negtmp: &[u8] = rep.as_ref();
|
||||
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
|
||||
@@ -56,7 +51,7 @@ mod test {
|
||||
use halo2curves::pasta::Fp as F;
|
||||
|
||||
#[test]
|
||||
fn integerreptofelt() {
|
||||
fn test_conv() {
|
||||
let res: F = integer_rep_to_felt(-15);
|
||||
assert_eq!(res, -F::from(15));
|
||||
|
||||
@@ -78,20 +73,4 @@ mod test {
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn felttointegerrepmin() {
|
||||
let x = IntegerRep::MIN;
|
||||
let fieldx: F = integer_rep_to_felt::<F>(x);
|
||||
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn felttointegerrepmax() {
|
||||
let x = IntegerRep::MAX;
|
||||
let fieldx: F = integer_rep_to_felt::<F>(x);
|
||||
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,12 +11,6 @@ pub enum GraphError {
|
||||
/// Shape mismatch in circuit construction
|
||||
#[error("invalid dimensions used for node {0} ({1})")]
|
||||
InvalidDims(usize, String),
|
||||
/// Non scalar power
|
||||
#[error("we only support scalar powers")]
|
||||
NonScalarPower,
|
||||
/// Non scalar base for exponentiation
|
||||
#[error("we only support scalar bases for exponentiation")]
|
||||
NonScalarBase,
|
||||
/// Wrong method was called to configure an op
|
||||
#[error("wrong method was called to configure node {0} ({1})")]
|
||||
WrongMethod(usize, String),
|
||||
@@ -119,13 +113,13 @@ pub enum GraphError {
|
||||
/// Missing input for a node
|
||||
#[error("missing input for node {0}")]
|
||||
MissingInput(usize),
|
||||
/// Ranges can only be constant
|
||||
///
|
||||
#[error("range only supports constant inputs in a zk circuit")]
|
||||
NonConstantRange,
|
||||
/// Trilu diagonal must be constant
|
||||
///
|
||||
#[error("trilu only supports constant diagonals in a zk circuit")]
|
||||
NonConstantTrilu,
|
||||
/// The witness was too short
|
||||
///
|
||||
#[error("insufficient witness values to generate a fixed output")]
|
||||
InsufficientWitnessValues,
|
||||
/// Missing scale
|
||||
@@ -149,13 +143,4 @@ pub enum GraphError {
|
||||
/// Invalid RunArg
|
||||
#[error("invalid RunArgs: {0}")]
|
||||
InvalidRunArgs(String),
|
||||
/// Only nearest neighbor interpolation is supported
|
||||
#[error("only nearest neighbor interpolation is supported")]
|
||||
InvalidInterpolation,
|
||||
/// Node has a missing output
|
||||
#[error("node {0} has a missing output")]
|
||||
MissingOutput(usize),
|
||||
/// Inssuficient advice columns
|
||||
#[error("insuficcient advice columns (need {0} at least)")]
|
||||
InsufficientAdviceColumns(usize),
|
||||
}
|
||||
|
||||
@@ -24,7 +24,6 @@ use tract_onnx::tract_core::{
|
||||
tract_data::{prelude::Tensor as TractTensor, TVec},
|
||||
value::TValue,
|
||||
};
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_hir::tract_num_traits::ToPrimitive;
|
||||
|
||||
@@ -32,95 +31,30 @@ type Decimals = u8;
|
||||
type Call = String;
|
||||
type RPCUrl = String;
|
||||
|
||||
/// Represents different types of values that can be stored in a file source
|
||||
/// Used for handling various input types in zero-knowledge proofs
|
||||
///
|
||||
#[derive(Clone, Debug, PartialOrd, PartialEq)]
|
||||
pub enum FileSourceInner {
|
||||
/// Floating point value (64-bit)
|
||||
/// Inner elements of float inputs coming from a file
|
||||
Float(f64),
|
||||
/// Boolean value
|
||||
/// Inner elements of bool inputs coming from a file
|
||||
Bool(bool),
|
||||
/// Field element value for direct use in circuits
|
||||
/// Inner elements of inputs coming from a witness
|
||||
Field(Fp),
|
||||
}
|
||||
|
||||
impl FileSourceInner {
|
||||
/// Returns true if the value is a floating point number
|
||||
///
|
||||
pub fn is_float(&self) -> bool {
|
||||
matches!(self, FileSourceInner::Float(_))
|
||||
}
|
||||
|
||||
/// Returns true if the value is a boolean
|
||||
///
|
||||
pub fn is_bool(&self) -> bool {
|
||||
matches!(self, FileSourceInner::Bool(_))
|
||||
}
|
||||
|
||||
/// Returns true if the value is a field element
|
||||
///
|
||||
pub fn is_field(&self) -> bool {
|
||||
matches!(self, FileSourceInner::Field(_))
|
||||
}
|
||||
|
||||
/// Creates a new floating point value
|
||||
pub fn new_float(f: f64) -> Self {
|
||||
FileSourceInner::Float(f)
|
||||
}
|
||||
|
||||
/// Creates a new field element value
|
||||
pub fn new_field(f: Fp) -> Self {
|
||||
FileSourceInner::Field(f)
|
||||
}
|
||||
|
||||
/// Creates a new boolean value
|
||||
pub fn new_bool(f: bool) -> Self {
|
||||
FileSourceInner::Bool(f)
|
||||
}
|
||||
|
||||
/// Adjusts the value according to the specified input type
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input_type` - Type specification to convert the value to
|
||||
pub fn as_type(&mut self, input_type: &InputType) {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => input_type.roundtrip(f),
|
||||
FileSourceInner::Bool(_) => assert!(matches!(input_type, InputType::Bool)),
|
||||
FileSourceInner::Field(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts the value to a field element using appropriate scaling
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `scale` - Scaling factor for floating point conversion
|
||||
pub fn to_field(&self, scale: crate::Scale) -> Fp {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => {
|
||||
integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap())
|
||||
}
|
||||
FileSourceInner::Bool(f) => {
|
||||
if *f {
|
||||
Fp::one()
|
||||
} else {
|
||||
Fp::zero()
|
||||
}
|
||||
}
|
||||
FileSourceInner::Field(f) => *f,
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts the value to a floating point number
|
||||
pub fn to_float(&self) -> f64 {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => *f,
|
||||
FileSourceInner::Bool(f) => {
|
||||
if *f {
|
||||
1.0
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for FileSourceInner {
|
||||
@@ -136,8 +70,8 @@ impl Serialize for FileSourceInner {
|
||||
}
|
||||
}
|
||||
|
||||
// Deserialization implementation for FileSourceInner
|
||||
// Uses JSON deserialization to handle the different variants
|
||||
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
|
||||
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
|
||||
impl<'de> Deserialize<'de> for FileSourceInner {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
@@ -164,16 +98,70 @@ impl<'de> Deserialize<'de> for FileSourceInner {
|
||||
}
|
||||
}
|
||||
|
||||
/// A collection of input values from a file source
|
||||
/// Organized as a vector of vectors where each inner vector represents a row/entry
|
||||
/// Elements of inputs coming from a file
|
||||
pub type FileSource = Vec<Vec<FileSourceInner>>;
|
||||
|
||||
/// Represents different types of calls for fetching on-chain data
|
||||
impl FileSourceInner {
|
||||
/// Create a new FileSourceInner
|
||||
pub fn new_float(f: f64) -> Self {
|
||||
FileSourceInner::Float(f)
|
||||
}
|
||||
/// Create a new FileSourceInner
|
||||
pub fn new_field(f: Fp) -> Self {
|
||||
FileSourceInner::Field(f)
|
||||
}
|
||||
/// Create a new FileSourceInner
|
||||
pub fn new_bool(f: bool) -> Self {
|
||||
FileSourceInner::Bool(f)
|
||||
}
|
||||
|
||||
///
|
||||
pub fn as_type(&mut self, input_type: &InputType) {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => input_type.roundtrip(f),
|
||||
FileSourceInner::Bool(_) => assert!(matches!(input_type, InputType::Bool)),
|
||||
FileSourceInner::Field(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to a field element
|
||||
pub fn to_field(&self, scale: crate::Scale) -> Fp {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => {
|
||||
integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap())
|
||||
}
|
||||
FileSourceInner::Bool(f) => {
|
||||
if *f {
|
||||
Fp::one()
|
||||
} else {
|
||||
Fp::zero()
|
||||
}
|
||||
}
|
||||
FileSourceInner::Field(f) => *f,
|
||||
}
|
||||
}
|
||||
/// Convert to a float
|
||||
pub fn to_float(&self) -> f64 {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => *f,
|
||||
FileSourceInner::Bool(f) => {
|
||||
if *f {
|
||||
1.0
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Call type for attested inputs on-chain
|
||||
#[derive(Clone, Debug, PartialOrd, PartialEq)]
|
||||
pub enum Calls {
|
||||
/// Multiple calls to different accounts, each returning individual values
|
||||
/// Vector of calls to accounts, each returning an attested data point
|
||||
Multiple(Vec<CallsToAccount>),
|
||||
/// Single call returning an array of values
|
||||
/// Single call to account, returning an array of attested data points
|
||||
Single(CallToAccount),
|
||||
}
|
||||
|
||||
@@ -182,6 +170,32 @@ impl Default for Calls {
|
||||
Calls::Multiple(Vec::new())
|
||||
}
|
||||
}
|
||||
/// Inner elements of inputs/outputs coming from on-chain
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct OnChainSource {
|
||||
/// Calls to accounts
|
||||
pub calls: Calls,
|
||||
/// RPC url
|
||||
pub rpc: RPCUrl,
|
||||
}
|
||||
|
||||
impl OnChainSource {
|
||||
/// Create a new OnChainSource with multiple calls
|
||||
pub fn new_multiple(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
|
||||
OnChainSource {
|
||||
calls: Calls::Multiple(calls),
|
||||
rpc,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new OnChainSource with a single call
|
||||
pub fn new_single(call: CallToAccount, rpc: RPCUrl) -> Self {
|
||||
OnChainSource {
|
||||
calls: Calls::Single(call),
|
||||
rpc,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Calls {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
@@ -203,6 +217,7 @@ impl<'de> Deserialize<'de> for Calls {
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
|
||||
|
||||
let multiple_try: Result<Vec<CallsToAccount>, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = multiple_try {
|
||||
return Ok(Calls::Multiple(t));
|
||||
@@ -212,52 +227,111 @@ impl<'de> Deserialize<'de> for Calls {
|
||||
return Ok(Calls::Single(t));
|
||||
}
|
||||
|
||||
Err(serde::de::Error::custom("failed to deserialize Calls"))
|
||||
Err(serde::de::Error::custom(
|
||||
"failed to deserialize FileSourceInner",
|
||||
))
|
||||
}
|
||||
}
|
||||
/// Configuration for accessing on-chain data sources
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Inner elements of inputs/outputs coming from postgres DB
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct OnChainSource {
|
||||
/// Call specifications for fetching data
|
||||
pub calls: Calls,
|
||||
/// RPC endpoint URL for accessing the chain
|
||||
pub rpc: RPCUrl,
|
||||
pub struct PostgresSource {
|
||||
/// postgres host
|
||||
pub host: RPCUrl,
|
||||
/// user to connect to postgres
|
||||
pub user: String,
|
||||
/// password to connect to postgres
|
||||
pub password: String,
|
||||
/// query to execute
|
||||
pub query: String,
|
||||
/// dbname
|
||||
pub dbname: String,
|
||||
/// port
|
||||
pub port: String,
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl PostgresSource {
|
||||
/// Create a new PostgresSource
|
||||
pub fn new(
|
||||
host: RPCUrl,
|
||||
port: String,
|
||||
user: String,
|
||||
query: String,
|
||||
dbname: String,
|
||||
password: String,
|
||||
) -> Self {
|
||||
PostgresSource {
|
||||
host,
|
||||
user,
|
||||
password,
|
||||
query,
|
||||
dbname,
|
||||
port,
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch data from postgres
|
||||
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
|
||||
// clone to move into thread
|
||||
let user = self.user.clone();
|
||||
let host = self.host.clone();
|
||||
let query = self.query.clone();
|
||||
let dbname = self.dbname.clone();
|
||||
let port = self.port.clone();
|
||||
let password = self.password.clone();
|
||||
|
||||
let config = if password.is_empty() {
|
||||
format!(
|
||||
"host={} user={} dbname={} port={}",
|
||||
host, user, dbname, port
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"host={} user={} dbname={} port={} password={}",
|
||||
host, user, dbname, port, password
|
||||
)
|
||||
};
|
||||
|
||||
let mut client = Client::connect(&config).await?;
|
||||
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
|
||||
// extract rows from query
|
||||
for row in client.query(&query, &[]).await? {
|
||||
// extract features from row
|
||||
for i in 0..row.len() {
|
||||
res.push(row.get(i));
|
||||
}
|
||||
}
|
||||
Ok(vec![res])
|
||||
}
|
||||
|
||||
/// Fetch data from postgres and format it as a FileSource
|
||||
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
|
||||
Ok(self
|
||||
.fetch()
|
||||
.await?
|
||||
.iter()
|
||||
.map(|d| {
|
||||
d.iter()
|
||||
.map(|d| {
|
||||
FileSourceInner::Float(
|
||||
d.n.as_ref()
|
||||
.unwrap()
|
||||
.to_f64()
|
||||
.ok_or("could not convert decimal to f64")
|
||||
.unwrap(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl OnChainSource {
|
||||
/// Creates a new OnChainSource with multiple calls
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `calls` - Vector of call specifications
|
||||
/// * `rpc` - RPC endpoint URL
|
||||
pub fn new_multiple(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
|
||||
OnChainSource {
|
||||
calls: Calls::Multiple(calls),
|
||||
rpc,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new OnChainSource with a single call
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `call` - Call specification
|
||||
/// * `rpc` - RPC endpoint URL
|
||||
pub fn new_single(call: CallToAccount, rpc: RPCUrl) -> Self {
|
||||
OnChainSource {
|
||||
calls: Calls::Single(call),
|
||||
rpc,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Creates test data for the OnChain data source
|
||||
/// Used for testing and development purposes
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `data` - Sample file data to use
|
||||
/// * `scales` - Scaling factors for each input
|
||||
/// * `shapes` - Shapes of the input tensors
|
||||
/// * `rpc` - Optional RPC endpoint override
|
||||
/// Create dummy local on-chain data to test the OnChain data source
|
||||
pub async fn test_from_file_data(
|
||||
data: &FileSource,
|
||||
scales: Vec<crate::Scale>,
|
||||
@@ -324,40 +398,48 @@ impl OnChainSource {
|
||||
}
|
||||
}
|
||||
|
||||
/// Specification for view-only calls to fetch on-chain data
|
||||
/// Used for data attestation in smart contract verification
|
||||
/// Defines the view only calls to accounts to fetch the on-chain input data.
|
||||
/// This data will be included as part of the first elements in the publicInputs
|
||||
/// for the sol evm verifier and will be verifyWithDataAttestation.sol
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct CallsToAccount {
|
||||
/// Vector of (call data, decimals) pairs
|
||||
/// call_data: ABI-encoded function call
|
||||
/// decimals: Number of decimal places for float conversion
|
||||
/// A vector of tuples, where index 0 of tuples
|
||||
/// are the byte strings representing the ABI encoded function calls to
|
||||
/// read the data from the address. This call must return a single
|
||||
/// elementary type (<https://docs.soliditylang.org/en/v0.8.20/abi-spec.html#types>).
|
||||
/// The second index of the tuple is the number of decimals for f32 conversion.
|
||||
/// We don't support dynamic types currently.
|
||||
pub call_data: Vec<(Call, Decimals)>,
|
||||
/// Contract address to call
|
||||
/// Address of the contract to read the data from.
|
||||
pub address: String,
|
||||
}
|
||||
|
||||
/// Specification for a single view-only call returning an array
|
||||
/// Defines a view only call to accounts to fetch the on-chain input data.
|
||||
/// This data will be included as part of the first elements in the publicInputs
|
||||
/// for the sol evm verifier and will be verifyWithDataAttestation.sol
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct CallToAccount {
|
||||
/// ABI-encoded function call data
|
||||
/// The call_data is a byte strings representing the ABI encoded function call to
|
||||
/// read the data from the address. This call must return a single array of integers that can be
|
||||
/// be safely cast to the int128 type in solidity.
|
||||
pub call_data: Call,
|
||||
/// Number of decimal places for float conversion
|
||||
/// The number of decimals for f32 conversion of all of the elements returned from the
|
||||
/// call.
|
||||
pub decimals: Decimals,
|
||||
/// Contract address to call
|
||||
/// Address of the contract to read the data from.
|
||||
pub address: String,
|
||||
/// Expected length of returned array
|
||||
/// The number of elements returned from the call.
|
||||
pub len: usize,
|
||||
}
|
||||
|
||||
/// Represents different sources of input/output data for the EZKL model
|
||||
/// Enum that defines source of the inputs/outputs to the EZKL model
|
||||
#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum DataSource {
|
||||
/// Data from a JSON file containing arrays of values
|
||||
/// .json File data source.
|
||||
File(FileSource),
|
||||
/// Data fetched from blockchain contracts
|
||||
/// On-chain data source. The first element is the calls to the account, and the second is the RPC url.
|
||||
OnChain(OnChainSource),
|
||||
/// Data from a PostgreSQL database
|
||||
/// Postgres DB
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
DB(PostgresSource),
|
||||
}
|
||||
@@ -400,7 +482,8 @@ impl From<OnChainSource> for DataSource {
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Always use JSON serialization for untagged enums
|
||||
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
|
||||
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
|
||||
impl<'de> Deserialize<'de> for DataSource {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
@@ -408,19 +491,15 @@ impl<'de> Deserialize<'de> for DataSource {
|
||||
{
|
||||
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
|
||||
|
||||
// Try deserializing as FileSource first
|
||||
let first_try: Result<FileSource, _> = serde_json::from_str(this_json.get());
|
||||
|
||||
if let Ok(t) = first_try {
|
||||
return Ok(DataSource::File(t));
|
||||
}
|
||||
|
||||
// Try deserializing as OnChainSource
|
||||
let second_try: Result<OnChainSource, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = second_try {
|
||||
return Ok(DataSource::OnChain(t));
|
||||
}
|
||||
|
||||
// Try deserializing as PostgresSource if feature enabled
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
{
|
||||
let third_try: Result<PostgresSource, _> = serde_json::from_str(this_json.get());
|
||||
@@ -433,29 +512,22 @@ impl<'de> Deserialize<'de> for DataSource {
|
||||
}
|
||||
}
|
||||
|
||||
/// Container for input and output data for graph computations
|
||||
///
|
||||
/// Important: Always use JSON serialization for GraphData to handle enum variants correctly
|
||||
/// Input to graph as a datasource
|
||||
/// Always use JSON serialization for GraphData. Seriously.
|
||||
#[derive(Clone, Debug, Deserialize, Default, PartialEq, Serialize)]
|
||||
pub struct GraphData {
|
||||
/// Input data for the model/graph
|
||||
/// Can be empty if inputs come from on-chain sources
|
||||
/// Inputs to the model / computational graph (can be empty vectors if inputs are coming from on-chain).
|
||||
pub input_data: DataSource,
|
||||
|
||||
/// Optional output data for the model/graph
|
||||
/// Can be empty if outputs come from on-chain sources
|
||||
/// Outputs of the model / computational graph (can be empty vectors if outputs are coming from on-chain).
|
||||
pub output_data: Option<DataSource>,
|
||||
}
|
||||
|
||||
impl UnwindSafe for GraphData {}
|
||||
|
||||
impl GraphData {
|
||||
// not wasm
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Converts the input data to tract's tensor format
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `shapes` - Expected shapes for each input tensor
|
||||
/// * `datum_types` - Expected data types for each input
|
||||
/// Convert the input data to tract data
|
||||
pub fn to_tract_data(
|
||||
&self,
|
||||
shapes: &[Vec<usize>],
|
||||
@@ -484,14 +556,9 @@ impl GraphData {
|
||||
Ok(inputs)
|
||||
}
|
||||
|
||||
// not wasm
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Converts tract tensor data into GraphData format
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tensors` - Array of tract tensors to convert
|
||||
///
|
||||
/// # Returns
|
||||
/// A new GraphData instance containing the converted tensor data
|
||||
/// Convert the tract data to tract data
|
||||
pub fn from_tract_data(tensors: &[TractTensor]) -> Result<Self, GraphError> {
|
||||
use tract_onnx::prelude::DatumType;
|
||||
|
||||
@@ -517,10 +584,7 @@ impl GraphData {
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates a new GraphData instance with given input data
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input_data` - The input data source
|
||||
pub fn new(input_data: DataSource) -> Self {
|
||||
GraphData {
|
||||
input_data,
|
||||
@@ -528,13 +592,7 @@ impl GraphData {
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads graph input data from a file
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - Path to the input file
|
||||
///
|
||||
/// # Returns
|
||||
/// A new GraphData instance containing the loaded data
|
||||
/// Load the model input from a file
|
||||
pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
|
||||
let reader = std::fs::File::open(&path).map_err(|e| {
|
||||
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
|
||||
@@ -548,35 +606,23 @@ impl GraphData {
|
||||
Ok(graph_input)
|
||||
}
|
||||
|
||||
/// Saves the graph data to a file
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - Path where to save the data
|
||||
/// Save the model input to a file
|
||||
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
|
||||
let file = std::fs::File::create(path.clone()).map_err(|e| {
|
||||
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
|
||||
})?;
|
||||
// buf writer
|
||||
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
serde_json::to_writer(writer, self)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Splits the input data into multiple batches based on input shapes
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input_shapes` - Vector of shapes for each input tensor
|
||||
///
|
||||
/// # Returns
|
||||
/// Vector of GraphData instances, one for each batch
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns error if:
|
||||
/// - Data is from on-chain source
|
||||
/// - Input size is not evenly divisible by batch size
|
||||
pub async fn split_into_batches(
|
||||
&self,
|
||||
input_shapes: Vec<Vec<usize>>,
|
||||
) -> Result<Vec<Self>, GraphError> {
|
||||
// split input data into batches
|
||||
let mut batched_inputs = vec![];
|
||||
|
||||
let iterable = match self {
|
||||
@@ -600,12 +646,10 @@ impl GraphData {
|
||||
} => data.fetch_and_format_as_file().await?,
|
||||
};
|
||||
|
||||
// Process each input tensor according to its shape
|
||||
for (i, shape) in input_shapes.iter().enumerate() {
|
||||
// ensure the input is evenly divisible by batch_size
|
||||
let input_size = shape.clone().iter().product::<usize>();
|
||||
let input = &iterable[i];
|
||||
|
||||
// Validate input size is divisible by batch size
|
||||
if input.len() % input_size != 0 {
|
||||
return Err(GraphError::InvalidDims(
|
||||
0,
|
||||
@@ -613,8 +657,6 @@ impl GraphData {
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Split input into batches
|
||||
let mut batches = vec![];
|
||||
for batch in input.chunks(input_size) {
|
||||
batches.push(batch.to_vec());
|
||||
@@ -622,18 +664,18 @@ impl GraphData {
|
||||
batched_inputs.push(batches);
|
||||
}
|
||||
|
||||
// Merge batches across inputs
|
||||
// now merge all the batches for each input into a vector of batches
|
||||
// first assert each input has the same number of batches
|
||||
let num_batches = if batched_inputs.is_empty() {
|
||||
0
|
||||
} else {
|
||||
let num_batches = batched_inputs[0].len();
|
||||
// Verify all inputs have same number of batches
|
||||
for input in batched_inputs.iter() {
|
||||
assert_eq!(input.len(), num_batches);
|
||||
}
|
||||
num_batches
|
||||
};
|
||||
|
||||
// now merge the batches
|
||||
let mut input_batches = vec![];
|
||||
for i in 0..num_batches {
|
||||
let mut batch = vec![];
|
||||
@@ -643,12 +685,11 @@ impl GraphData {
|
||||
input_batches.push(DataSource::File(batch));
|
||||
}
|
||||
|
||||
// Ensure at least one batch exists
|
||||
if input_batches.is_empty() {
|
||||
input_batches.push(DataSource::File(vec![vec![]]));
|
||||
}
|
||||
|
||||
// Create GraphData instance for each batch
|
||||
// create a new GraphWitness for each batch
|
||||
let batches = input_batches
|
||||
.into_iter()
|
||||
.map(GraphData::new)
|
||||
@@ -660,7 +701,6 @@ impl GraphData {
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for CallsToAccount {
|
||||
/// Converts CallsToAccount to Python object
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("account", &self.address).unwrap();
|
||||
@@ -669,165 +709,6 @@ impl ToPyObject for CallsToAccount {
|
||||
}
|
||||
}
|
||||
|
||||
// Additional Python bindings for various types...
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_postgres_source_new() {
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
{
|
||||
let source = PostgresSource::new(
|
||||
"localhost".to_string(),
|
||||
"5432".to_string(),
|
||||
"user".to_string(),
|
||||
"SELECT * FROM table".to_string(),
|
||||
"database".to_string(),
|
||||
"password".to_string(),
|
||||
);
|
||||
|
||||
assert_eq!(source.host, "localhost");
|
||||
assert_eq!(source.port, "5432");
|
||||
assert_eq!(source.user, "user");
|
||||
assert_eq!(source.query, "SELECT * FROM table");
|
||||
assert_eq!(source.dbname, "database");
|
||||
assert_eq!(source.password, "password");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_data_source_serialization_round_trip() {
|
||||
// Test backwards compatibility with old format
|
||||
let source = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]);
|
||||
let serialized = serde_json::to_string(&source).unwrap();
|
||||
const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#;
|
||||
assert_eq!(serialized, JSON);
|
||||
|
||||
let expect = serde_json::from_str::<DataSource>(JSON)
|
||||
.map_err(|e| e.to_string())
|
||||
.unwrap();
|
||||
assert_eq!(expect, source);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_graph_input_serialization_round_trip() {
|
||||
// Test serialization/deserialization of graph input
|
||||
let file = GraphData::new(DataSource::from(vec![vec![
|
||||
0.05326242372393608,
|
||||
0.07497056573629379,
|
||||
0.05235547572374344,
|
||||
]]));
|
||||
|
||||
let serialized = serde_json::to_string(&file).unwrap();
|
||||
const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#;
|
||||
assert_eq!(serialized, JSON);
|
||||
|
||||
let graph_input3 = serde_json::from_str::<GraphData>(JSON)
|
||||
.map_err(|e| e.to_string())
|
||||
.unwrap();
|
||||
assert_eq!(graph_input3, file);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_python_compat() {
|
||||
// Test compatibility with mclbn256 library serialization
|
||||
let source = Fp::from_raw([18445520602771460712, 838677322461845011, 3079992810, 0]);
|
||||
let original_addr = "0x000000000000000000000000b794f5ea0ba39494ce839613fffba74279579268";
|
||||
assert_eq!(format!("{:?}", source), original_addr);
|
||||
}
|
||||
}
|
||||
|
||||
/// Source data from a PostgreSQL database
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct PostgresSource {
|
||||
/// Database host address
|
||||
pub host: RPCUrl,
|
||||
/// Database user name
|
||||
pub user: String,
|
||||
/// Database password
|
||||
pub password: String,
|
||||
/// SQL query to execute
|
||||
pub query: String,
|
||||
/// Database name
|
||||
pub dbname: String,
|
||||
/// Database port
|
||||
pub port: String,
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl PostgresSource {
|
||||
/// Creates a new PostgreSQL data source
|
||||
pub fn new(
|
||||
host: RPCUrl,
|
||||
port: String,
|
||||
user: String,
|
||||
query: String,
|
||||
dbname: String,
|
||||
password: String,
|
||||
) -> Self {
|
||||
PostgresSource {
|
||||
host,
|
||||
user,
|
||||
password,
|
||||
query,
|
||||
dbname,
|
||||
port,
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetches data from the PostgreSQL database
|
||||
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
|
||||
// Configuration string
|
||||
let config = if self.password.is_empty() {
|
||||
format!(
|
||||
"host={} user={} dbname={} port={}",
|
||||
self.host, self.user, self.dbname, self.port
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"host={} user={} dbname={} port={} password={}",
|
||||
self.host, self.user, self.dbname, self.port, self.password
|
||||
)
|
||||
};
|
||||
|
||||
let mut client = Client::connect(&config).await?;
|
||||
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
|
||||
|
||||
// Extract rows from query
|
||||
for row in client.query(&self.query, &[]).await? {
|
||||
for i in 0..row.len() {
|
||||
res.push(row.get(i));
|
||||
}
|
||||
}
|
||||
Ok(vec![res])
|
||||
}
|
||||
|
||||
/// Fetches and formats data as FileSource
|
||||
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
|
||||
Ok(self
|
||||
.fetch()
|
||||
.await?
|
||||
.iter()
|
||||
.map(|d| {
|
||||
d.iter()
|
||||
.map(|d| {
|
||||
FileSourceInner::Float(
|
||||
d.n.as_ref()
|
||||
.unwrap()
|
||||
.to_f64()
|
||||
.ok_or("could not convert decimal to f64")
|
||||
.unwrap(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for CallToAccount {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
@@ -862,7 +743,6 @@ impl ToPyObject for DataSource {
|
||||
.unwrap();
|
||||
dict.to_object(py)
|
||||
}
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
DataSource::DB(source) => {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("host", &source.host).unwrap();
|
||||
@@ -887,3 +767,57 @@ impl ToPyObject for FileSourceInner {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
// this is for backwards compatibility with the old format
|
||||
fn test_data_source_serialization_round_trip() {
|
||||
let source = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]);
|
||||
|
||||
let serialized = serde_json::to_string(&source).unwrap();
|
||||
|
||||
const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#;
|
||||
|
||||
assert_eq!(serialized, JSON);
|
||||
|
||||
let expect = serde_json::from_str::<DataSource>(JSON)
|
||||
.map_err(|e| e.to_string())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(expect, source);
|
||||
}
|
||||
|
||||
#[test]
|
||||
// this is for backwards compatibility with the old format
|
||||
fn test_graph_input_serialization_round_trip() {
|
||||
let file = GraphData::new(DataSource::from(vec![vec![
|
||||
0.05326242372393608,
|
||||
0.07497056573629379,
|
||||
0.05235547572374344,
|
||||
]]));
|
||||
|
||||
let serialized = serde_json::to_string(&file).unwrap();
|
||||
|
||||
const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#;
|
||||
|
||||
assert_eq!(serialized, JSON);
|
||||
|
||||
let graph_input3 = serde_json::from_str::<GraphData>(JSON)
|
||||
.map_err(|e| e.to_string())
|
||||
.unwrap();
|
||||
assert_eq!(graph_input3, file);
|
||||
}
|
||||
|
||||
// test for the compatibility with the serialized elements from the mclbn256 library
|
||||
#[test]
|
||||
fn test_python_compat() {
|
||||
let source = Fp::from_raw([18445520602771460712, 838677322461845011, 3079992810, 0]);
|
||||
|
||||
let original_addr = "0x000000000000000000000000b794f5ea0ba39494ce839613fffba74279579268";
|
||||
|
||||
assert_eq!(format!("{:?}", source), original_addr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -619,6 +619,11 @@ impl GraphSettings {
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
pub fn uses_modules(&self) -> bool {
|
||||
!self.module_sizes.max_constraints() > 0
|
||||
}
|
||||
|
||||
/// if any visibility is encrypted or hashed
|
||||
pub fn module_requires_fixed(&self) -> bool {
|
||||
self.run_args.input_visibility.is_hashed()
|
||||
@@ -761,7 +766,7 @@ pub struct TestOnChainData {
|
||||
pub data: std::path::PathBuf,
|
||||
/// rpc endpoint
|
||||
pub rpc: Option<String>,
|
||||
/// data sources for the on chain data
|
||||
///
|
||||
pub data_sources: TestSources,
|
||||
}
|
||||
|
||||
@@ -949,7 +954,7 @@ impl GraphCircuit {
|
||||
DataSource::File(file_data) => {
|
||||
self.load_file_data(file_data, &shapes, scales, input_types)
|
||||
}
|
||||
_ => Err(GraphError::OnChainDataSource),
|
||||
_ => unreachable!("cannot load from on-chain data"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -384,7 +384,8 @@ pub struct ParsedNodes {
|
||||
impl ParsedNodes {
|
||||
/// Returns the number of the computational graph's inputs
|
||||
pub fn num_inputs(&self) -> usize {
|
||||
self.inputs.len()
|
||||
let input_nodes = self.inputs.iter();
|
||||
input_nodes.len()
|
||||
}
|
||||
|
||||
/// Input types
|
||||
@@ -424,7 +425,8 @@ impl ParsedNodes {
|
||||
|
||||
/// Returns the number of the computational graph's outputs
|
||||
pub fn num_outputs(&self) -> usize {
|
||||
self.outputs.len()
|
||||
let output_nodes = self.outputs.iter();
|
||||
output_nodes.len()
|
||||
}
|
||||
|
||||
/// Returns shapes of the computational graph's outputs
|
||||
@@ -632,10 +634,6 @@ impl Model {
|
||||
|
||||
for (i, id) in model.clone().inputs.iter().enumerate() {
|
||||
let input = model.node_mut(id.node);
|
||||
|
||||
if input.outputs.len() == 0 {
|
||||
return Err(GraphError::MissingOutput(id.node));
|
||||
}
|
||||
let mut fact: InferenceFact = input.outputs[0].fact.clone();
|
||||
|
||||
for (i, x) in fact.clone().shape.dims().enumerate() {
|
||||
@@ -908,7 +906,6 @@ impl Model {
|
||||
n.opkind = SupportedOp::Input(Input {
|
||||
scale,
|
||||
datum_type: inp.datum_type,
|
||||
decomp: !run_args.ignore_range_check_inputs_outputs,
|
||||
});
|
||||
input_idx += 1;
|
||||
n.out_scale = scale;
|
||||
@@ -1019,10 +1016,6 @@ impl Model {
|
||||
let required_lookups = settings.required_lookups.clone();
|
||||
let required_range_checks = settings.required_range_checks.clone();
|
||||
|
||||
if vars.advices.len() < 3 {
|
||||
return Err(GraphError::InsufficientAdviceColumns(3));
|
||||
}
|
||||
|
||||
let mut base_gate = PolyConfig::configure(
|
||||
meta,
|
||||
vars.advices[0..2].try_into()?,
|
||||
@@ -1042,10 +1035,6 @@ impl Model {
|
||||
}
|
||||
|
||||
if settings.requires_dynamic_lookup() {
|
||||
if vars.advices.len() < 6 {
|
||||
return Err(GraphError::InsufficientAdviceColumns(6));
|
||||
}
|
||||
|
||||
base_gate.configure_dynamic_lookup(
|
||||
meta,
|
||||
vars.advices[0..3].try_into()?,
|
||||
@@ -1054,9 +1043,6 @@ impl Model {
|
||||
}
|
||||
|
||||
if settings.requires_shuffle() {
|
||||
if vars.advices.len() < 6 {
|
||||
return Err(GraphError::InsufficientAdviceColumns(6));
|
||||
}
|
||||
base_gate.configure_shuffles(
|
||||
meta,
|
||||
vars.advices[0..3].try_into()?,
|
||||
@@ -1075,7 +1061,6 @@ impl Model {
|
||||
/// * `vars` - The variables for the circuit.
|
||||
/// * `witnessed_outputs` - The values to compare against.
|
||||
/// * `constants` - The constants for the circuit.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn layout(
|
||||
&self,
|
||||
mut config: ModelConfig,
|
||||
@@ -1146,8 +1131,8 @@ impl Model {
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, output)| {
|
||||
let mut tol: crate::circuit::Tolerance = run_args.tolerance;
|
||||
tol.scale = scale_to_multiplier(output_scales[i]).into();
|
||||
let mut tolerance = run_args.tolerance;
|
||||
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
|
||||
|
||||
let comparators = if run_args.output_visibility == Visibility::Public {
|
||||
let res = vars
|
||||
@@ -1170,10 +1155,7 @@ impl Model {
|
||||
.layout(
|
||||
&mut thread_safe_region,
|
||||
&[output.clone(), comparators],
|
||||
Box::new(HybridOp::Output {
|
||||
tol,
|
||||
decomp: !run_args.ignore_range_check_inputs_outputs,
|
||||
}),
|
||||
Box::new(HybridOp::RangeCheck(tolerance)),
|
||||
)
|
||||
.map_err(|e| e.into())
|
||||
})
|
||||
@@ -1450,16 +1432,13 @@ impl Model {
|
||||
.into();
|
||||
comparator.reshape(output.dims())?;
|
||||
|
||||
let mut tol = run_args.tolerance;
|
||||
tol.scale = scale_to_multiplier(output_scales[i]).into();
|
||||
let mut tolerance = run_args.tolerance;
|
||||
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
|
||||
|
||||
dummy_config.layout(
|
||||
&mut region,
|
||||
&[output.clone(), comparator],
|
||||
Box::new(HybridOp::Output {
|
||||
tol,
|
||||
decomp: !run_args.ignore_range_check_inputs_outputs,
|
||||
}),
|
||||
Box::new(HybridOp::RangeCheck(tolerance)),
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>();
|
||||
@@ -1481,7 +1460,7 @@ impl Model {
|
||||
.iter()
|
||||
.map(|x| {
|
||||
x.get_felt_evals()
|
||||
.unwrap_or_else(|_| Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
|
||||
.unwrap_or(Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -1551,7 +1530,6 @@ impl Model {
|
||||
let mut op = crate::circuit::Constant::new(
|
||||
c.quantized_values.clone(),
|
||||
c.raw_values.clone(),
|
||||
c.decomp,
|
||||
);
|
||||
op.pre_assign(consts[const_idx].clone());
|
||||
n.opkind = SupportedOp::Constant(op);
|
||||
|
||||
@@ -284,6 +284,7 @@ impl GraphModules {
|
||||
log::error!("Poseidon config not initialized");
|
||||
return Err(Error::Synthesis);
|
||||
}
|
||||
// If the module is encrypted, then we need to encrypt the inputs
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -1,19 +1,10 @@
|
||||
// Import dependencies for scaling operations
|
||||
use super::scale_to_multiplier;
|
||||
|
||||
// Import ONNX-specific utilities when EZKL feature is enabled
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use super::utilities::node_output_shapes;
|
||||
|
||||
// Import scale management types for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use super::VarScales;
|
||||
|
||||
// Import visibility settings for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use super::Visibility;
|
||||
|
||||
// Import operation types for different circuit components
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::poly::PolyOp;
|
||||
@@ -22,49 +13,28 @@ use crate::circuit::Constant;
|
||||
use crate::circuit::Input;
|
||||
use crate::circuit::Op;
|
||||
use crate::circuit::Unknown;
|
||||
|
||||
// Import graph error types for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::graph::errors::GraphError;
|
||||
|
||||
// Import ONNX operation conversion utilities
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::graph::new_op_from_onnx;
|
||||
|
||||
// Import tensor error handling
|
||||
use crate::tensor::TensorError;
|
||||
|
||||
// Import curve-specific field type
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
|
||||
// Import logging for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use log::trace;
|
||||
|
||||
// Import serialization traits
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
|
||||
// Import data structures for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
// Import formatting traits for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use std::fmt;
|
||||
|
||||
// Import table display formatting for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tabled::Tabled;
|
||||
|
||||
// Import ONNX-specific types and traits for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::{
|
||||
self,
|
||||
prelude::{Node as OnnxNode, SymbolValues, TypedFact, TypedOp},
|
||||
};
|
||||
|
||||
/// Helper function to format vectors for display
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
|
||||
if !v.is_empty() {
|
||||
@@ -74,35 +44,29 @@ fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to format operation kinds for display
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn display_opkind(v: &SupportedOp) -> String {
|
||||
v.as_string()
|
||||
}
|
||||
|
||||
/// A wrapper for an operation that has been rescaled to handle different precision requirements.
|
||||
/// This enables operations to work with inputs that have been scaled to different fixed-point representations.
|
||||
/// A wrapper for an operation that has been rescaled.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Rescaled {
|
||||
/// The underlying operation that needs to be rescaled
|
||||
/// The operation that has to be rescaled.
|
||||
pub inner: Box<SupportedOp>,
|
||||
/// Vector of (index, scale) pairs defining how each input should be scaled
|
||||
/// The scale of the operation's inputs.
|
||||
pub scale: Vec<(usize, u128)>,
|
||||
}
|
||||
|
||||
/// Implementation of the Op trait for Rescaled operations
|
||||
impl Op<Fp> for Rescaled {
|
||||
/// Convert to Any type for runtime type checking
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
/// Get string representation of the operation
|
||||
fn as_string(&self) -> String {
|
||||
format!("RESCALED INPUT ({})", self.inner.as_string())
|
||||
}
|
||||
|
||||
/// Calculate output scale based on input scales
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let in_scales = in_scales
|
||||
.into_iter()
|
||||
@@ -113,7 +77,6 @@ impl Op<Fp> for Rescaled {
|
||||
Op::<Fp>::out_scale(&*self.inner, in_scales)
|
||||
}
|
||||
|
||||
/// Layout the operation in the circuit
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
@@ -130,40 +93,28 @@ impl Op<Fp> for Rescaled {
|
||||
self.inner.layout(config, region, res)
|
||||
}
|
||||
|
||||
/// Create a cloned boxed copy of this operation
|
||||
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
|
||||
Box::new(self.clone())
|
||||
Box::new(self.clone()) // Forward to the derive(Clone) impl
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper for operations that require scale rebasing
|
||||
/// This handles cases where operation scales need to be adjusted to a target scale
|
||||
/// while preserving the numerical relationships
|
||||
/// A wrapper for an operation that has been rescaled.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct RebaseScale {
|
||||
/// The operation that needs to be rescaled
|
||||
/// The operation that has to be rescaled.
|
||||
pub inner: Box<SupportedOp>,
|
||||
/// Operation used for rebasing, typically division
|
||||
/// rebase op
|
||||
pub rebase_op: HybridOp,
|
||||
/// Scale that we're rebasing to
|
||||
/// scale being rebased to
|
||||
pub target_scale: i32,
|
||||
/// Original scale of operation's inputs before rebasing
|
||||
/// The original scale of the operation's inputs.
|
||||
pub original_scale: i32,
|
||||
/// Scaling multiplier used in rebasing
|
||||
/// multiplier
|
||||
pub multiplier: f64,
|
||||
}
|
||||
|
||||
impl RebaseScale {
|
||||
/// Creates a rebased version of an operation if needed
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `inner` - Operation to potentially rebase
|
||||
/// * `global_scale` - Base scale for the system
|
||||
/// * `op_out_scale` - Current output scale of the operation
|
||||
/// * `scale_rebase_multiplier` - Factor determining when rebasing should occur
|
||||
///
|
||||
/// # Returns
|
||||
/// Original or rebased operation depending on scale relationships
|
||||
pub fn rebase(
|
||||
inner: SupportedOp,
|
||||
global_scale: crate::Scale,
|
||||
@@ -204,15 +155,7 @@ impl RebaseScale {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a rebased operation with increased scale
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `inner` - Operation to potentially rebase
|
||||
/// * `target_scale` - Scale to rebase to
|
||||
/// * `op_out_scale` - Current output scale of the operation
|
||||
///
|
||||
/// # Returns
|
||||
/// Original or rebased operation with increased scale
|
||||
pub fn rebase_up(
|
||||
inner: SupportedOp,
|
||||
target_scale: crate::Scale,
|
||||
@@ -249,12 +192,10 @@ impl RebaseScale {
|
||||
}
|
||||
|
||||
impl Op<Fp> for RebaseScale {
|
||||
/// Convert to Any type for runtime type checking
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
/// Get string representation of the operation
|
||||
fn as_string(&self) -> String {
|
||||
format!(
|
||||
"REBASED (div={:?}, rebasing_op={}) ({})",
|
||||
@@ -264,12 +205,10 @@ impl Op<Fp> for RebaseScale {
|
||||
)
|
||||
}
|
||||
|
||||
/// Calculate output scale based on input scales
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
Ok(self.target_scale)
|
||||
}
|
||||
|
||||
/// Layout the operation in the circuit
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
@@ -283,40 +222,34 @@ impl Op<Fp> for RebaseScale {
|
||||
self.rebase_op.layout(config, region, &[original_res])
|
||||
}
|
||||
|
||||
/// Create a cloned boxed copy of this operation
|
||||
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
|
||||
Box::new(self.clone())
|
||||
Box::new(self.clone()) // Forward to the derive(Clone) impl
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents all supported operation types in the circuit
|
||||
/// Each variant encapsulates a different type of operation with specific behavior
|
||||
/// A single operation in a [crate::graph::Model].
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum SupportedOp {
|
||||
/// Linear operations (polynomial-based)
|
||||
/// A linear operation.
|
||||
Linear(PolyOp),
|
||||
/// Nonlinear operations requiring lookup tables
|
||||
/// A nonlinear operation.
|
||||
Nonlinear(LookupOp),
|
||||
/// Mixed operations combining different approaches
|
||||
/// A hybrid operation.
|
||||
Hybrid(HybridOp),
|
||||
/// Input values to the circuit
|
||||
///
|
||||
Input(Input),
|
||||
/// Constant values in the circuit
|
||||
///
|
||||
Constant(Constant<Fp>),
|
||||
/// Placeholder for unsupported operations
|
||||
///
|
||||
Unknown(Unknown),
|
||||
/// Operations requiring rescaling of inputs
|
||||
///
|
||||
Rescaled(Rescaled),
|
||||
/// Operations requiring scale rebasing
|
||||
///
|
||||
RebaseScale(RebaseScale),
|
||||
}
|
||||
|
||||
impl SupportedOp {
|
||||
/// Checks if the operation is a lookup operation
|
||||
///
|
||||
/// # Returns
|
||||
/// * `true` if operation requires lookup table
|
||||
/// * `false` otherwise
|
||||
pub fn is_lookup(&self) -> bool {
|
||||
match self {
|
||||
SupportedOp::Nonlinear(_) => true,
|
||||
@@ -324,12 +257,7 @@ impl SupportedOp {
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns input operation if this is an input
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Some(Input)` if this is an input operation
|
||||
/// * `None` otherwise
|
||||
pub fn get_input(&self) -> Option<Input> {
|
||||
match self {
|
||||
SupportedOp::Input(op) => Some(op.clone()),
|
||||
@@ -337,11 +265,7 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns reference to rebased operation if this is a rebased operation
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Some(&RebaseScale)` if this is a rebased operation
|
||||
/// * `None` otherwise
|
||||
pub fn get_rebased(&self) -> Option<&RebaseScale> {
|
||||
match self {
|
||||
SupportedOp::RebaseScale(op) => Some(op),
|
||||
@@ -349,11 +273,7 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns reference to lookup operation if this is a lookup operation
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Some(&LookupOp)` if this is a lookup operation
|
||||
/// * `None` otherwise
|
||||
pub fn get_lookup(&self) -> Option<&LookupOp> {
|
||||
match self {
|
||||
SupportedOp::Nonlinear(op) => Some(op),
|
||||
@@ -361,11 +281,7 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns reference to constant if this is a constant
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Some(&Constant)` if this is a constant
|
||||
/// * `None` otherwise
|
||||
pub fn get_constant(&self) -> Option<&Constant<Fp>> {
|
||||
match self {
|
||||
SupportedOp::Constant(op) => Some(op),
|
||||
@@ -373,11 +289,7 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns mutable reference to constant if this is a constant
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Some(&mut Constant)` if this is a constant
|
||||
/// * `None` otherwise
|
||||
pub fn get_mutable_constant(&mut self) -> Option<&mut Constant<Fp>> {
|
||||
match self {
|
||||
SupportedOp::Constant(op) => Some(op),
|
||||
@@ -385,19 +297,18 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a homogeneously rescaled version of this operation if needed
|
||||
/// Only available with EZKL feature enabled
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn homogenous_rescale(
|
||||
&self,
|
||||
in_scales: Vec<crate::Scale>,
|
||||
) -> Result<Box<dyn Op<Fp>>, GraphError> {
|
||||
let inputs_to_scale = self.requires_homogenous_input_scales();
|
||||
// creates a rescaled op if the inputs are not homogenous
|
||||
let op = self.clone_dyn();
|
||||
super::homogenize_input_scales(op, in_scales, inputs_to_scale)
|
||||
}
|
||||
|
||||
/// Returns reference to underlying Op implementation
|
||||
/// Since each associated value of `SupportedOp` implements `Op`, let's define a helper method to retrieve it.
|
||||
fn as_op(&self) -> &dyn Op<Fp> {
|
||||
match self {
|
||||
SupportedOp::Linear(op) => op,
|
||||
@@ -411,10 +322,9 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if this is an identity operation
|
||||
///
|
||||
/// check if is the identity operation
|
||||
/// # Returns
|
||||
/// * `true` if this operation passes input through unchanged
|
||||
/// * `true` if the operation is the identity operation
|
||||
/// * `false` otherwise
|
||||
pub fn is_identity(&self) -> bool {
|
||||
match self {
|
||||
@@ -451,11 +361,9 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
|
||||
if let Some(op) = value.as_any().downcast_ref::<Unknown>() {
|
||||
return SupportedOp::Unknown(op.clone());
|
||||
};
|
||||
|
||||
if let Some(op) = value.as_any().downcast_ref::<Rescaled>() {
|
||||
return SupportedOp::Rescaled(op.clone());
|
||||
};
|
||||
|
||||
if let Some(op) = value.as_any().downcast_ref::<RebaseScale>() {
|
||||
return SupportedOp::RebaseScale(op.clone());
|
||||
};
|
||||
@@ -467,7 +375,6 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
|
||||
}
|
||||
|
||||
impl Op<Fp> for SupportedOp {
|
||||
/// Layout this operation in the circuit
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
@@ -477,61 +384,54 @@ impl Op<Fp> for SupportedOp {
|
||||
self.as_op().layout(config, region, values)
|
||||
}
|
||||
|
||||
/// Check if this is an input operation
|
||||
fn is_input(&self) -> bool {
|
||||
self.as_op().is_input()
|
||||
}
|
||||
|
||||
/// Check if this is a constant operation
|
||||
fn is_constant(&self) -> bool {
|
||||
self.as_op().is_constant()
|
||||
}
|
||||
|
||||
/// Get which inputs require homogeneous scales
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
self.as_op().requires_homogenous_input_scales()
|
||||
}
|
||||
|
||||
/// Create a clone of this operation
|
||||
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
|
||||
self.as_op().clone_dyn()
|
||||
}
|
||||
|
||||
/// Get string representation
|
||||
fn as_string(&self) -> String {
|
||||
self.as_op().as_string()
|
||||
}
|
||||
|
||||
/// Convert to Any type
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
/// Calculate output scale from input scales
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
self.as_op().out_scale(in_scales)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a connection to another node's output
|
||||
/// First element is node index, second is output slot index
|
||||
/// A node's input is a tensor from another node's output.
|
||||
pub type Outlet = (usize, usize);
|
||||
|
||||
/// Represents a single computational node in the circuit graph
|
||||
/// Contains all information needed to execute and connect operations
|
||||
/// A single operation in a [crate::graph::Model].
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Node {
|
||||
/// The operation this node performs
|
||||
/// [Op] i.e what operation this node represents.
|
||||
pub opkind: SupportedOp,
|
||||
/// Fixed point scale factor for this node's output
|
||||
/// The denominator in the fixed point representation for the node's output. Tensors of differing scales should not be combined.
|
||||
pub out_scale: i32,
|
||||
/// Connections to other nodes' outputs that serve as inputs
|
||||
// Usually there is a simple in and out shape of the node as an operator. For example, an Affine node has three input_shapes (one for the input, weight, and bias),
|
||||
// but in_dim is [in], out_dim is [out]
|
||||
/// The indices of the node's inputs.
|
||||
pub inputs: Vec<Outlet>,
|
||||
/// Shape of this node's output tensor
|
||||
/// Dimensions of output.
|
||||
pub out_dims: Vec<usize>,
|
||||
/// Unique identifier for this node
|
||||
/// The node's unique identifier.
|
||||
pub idx: usize,
|
||||
/// Number of times this node's output is used
|
||||
/// The node's num of uses
|
||||
pub num_uses: usize,
|
||||
}
|
||||
|
||||
@@ -569,19 +469,12 @@ impl PartialEq for Node {
|
||||
}
|
||||
|
||||
impl Node {
|
||||
/// Creates a new Node from an ONNX node
|
||||
/// Only available when EZKL feature is enabled
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `node` - Source ONNX node
|
||||
/// * `other_nodes` - Map of existing nodes in the graph
|
||||
/// * `scales` - Scale factors for variables
|
||||
/// * `idx` - Unique identifier for this node
|
||||
/// * `symbol_values` - ONNX symbol values
|
||||
/// * `run_args` - Runtime configuration arguments
|
||||
///
|
||||
/// # Returns
|
||||
/// New Node instance or error if creation fails
|
||||
/// Converts a tract [OnnxNode] into an ezkl [Node].
|
||||
/// # Arguments:
|
||||
/// * `node` - [OnnxNode]
|
||||
/// * `other_nodes` - [BTreeMap] of other previously initialized [Node]s in the computational graph.
|
||||
/// * `public_params` - flag if parameters of model are public
|
||||
/// * `idx` - The node's unique identifier.
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
@@ -719,14 +612,16 @@ impl Node {
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if this node performs softmax operation
|
||||
/// check if it is a softmax node
|
||||
pub fn is_softmax(&self) -> bool {
|
||||
matches!(self.opkind, SupportedOp::Hybrid(HybridOp::Softmax { .. }))
|
||||
if let SupportedOp::Hybrid(HybridOp::Softmax { .. }) = self.opkind {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to rescale constants that are only used once
|
||||
/// Only available when EZKL feature is enabled
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn rescale_const_with_single_use(
|
||||
constant: &mut Constant<Fp>,
|
||||
|
||||
@@ -44,11 +44,11 @@ use tract_onnx::tract_hir::{
|
||||
tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
|
||||
};
|
||||
|
||||
/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
|
||||
/// NAN gets mapped to 0. INFINITY and NEG_INFINITY error out.
|
||||
/// Quantizes an iterable of f32s to a [Tensor] of i32s using a fixed point representation.
|
||||
/// Arguments
|
||||
///
|
||||
/// * `elem` - the element to quantize.
|
||||
/// * `vec` - the vector to quantize.
|
||||
/// * `dims` - the dimensionality of the resulting [Tensor].
|
||||
/// * `shift` - offset used in the fixed point representation.
|
||||
/// * `scale` - `2^scale` used in the fixed point representation.
|
||||
pub fn quantize_float(
|
||||
@@ -59,7 +59,7 @@ pub fn quantize_float(
|
||||
let mult = scale_to_multiplier(scale);
|
||||
let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
|
||||
|
||||
if *elem > max_value || *elem < -max_value {
|
||||
if *elem > max_value {
|
||||
return Err(TensorError::SigBitTruncationError);
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ pub fn scale_to_multiplier(scale: crate::Scale) -> f64 {
|
||||
f64::powf(2., scale as f64)
|
||||
}
|
||||
|
||||
/// Converts a fixed point multiplier to a scale (log base 2).
|
||||
/// Converts a scale (log base 2) to a fixed point multiplier.
|
||||
pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
|
||||
mult.log2().round() as crate::Scale
|
||||
}
|
||||
@@ -228,7 +228,10 @@ pub fn extract_tensor_value(
|
||||
.iter()
|
||||
.map(|x| match x.to_i64() {
|
||||
Ok(v) => Ok(v as f32),
|
||||
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
|
||||
Err(_) => match x.to_i64() {
|
||||
Ok(v) => Ok(v as f32),
|
||||
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -309,9 +312,6 @@ pub fn new_op_from_onnx(
|
||||
let mut deleted_indices = vec![];
|
||||
let node = match node.op().name().as_ref() {
|
||||
"ShiftLeft" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "shift left".to_string()));
|
||||
};
|
||||
// load shift amount
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
@@ -324,13 +324,10 @@ pub fn new_op_from_onnx(
|
||||
out_scale: Some(input_scales[0] - raw_values[0] as i32),
|
||||
})
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "shift left".to_string()));
|
||||
return Err(GraphError::OpMismatch(idx, "ShiftLeft".to_string()));
|
||||
}
|
||||
}
|
||||
"ShiftRight" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "shift right".to_string()));
|
||||
};
|
||||
// load shift amount
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
@@ -343,7 +340,7 @@ pub fn new_op_from_onnx(
|
||||
out_scale: Some(input_scales[0] + raw_values[0] as i32),
|
||||
})
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "shift right".to_string()));
|
||||
return Err(GraphError::OpMismatch(idx, "ShiftRight".to_string()));
|
||||
}
|
||||
}
|
||||
"MultiBroadcastTo" => {
|
||||
@@ -366,10 +363,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
|
||||
if input_ops.len() != 3 {
|
||||
return Err(GraphError::InvalidDims(idx, "range".to_string()));
|
||||
}
|
||||
|
||||
assert_eq!(input_ops.len(), 3, "Range requires 3 inputs");
|
||||
let input_ops = input_ops
|
||||
.iter()
|
||||
.map(|x| x.get_constant().ok_or(GraphError::NonConstantRange))
|
||||
@@ -384,11 +378,7 @@ pub fn new_op_from_onnx(
|
||||
// Quantize the raw value (integers)
|
||||
let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed)?;
|
||||
|
||||
let c = crate::circuit::ops::Constant::new(
|
||||
quantized_value,
|
||||
raw_value,
|
||||
!run_args.ignore_range_check_inputs_outputs,
|
||||
);
|
||||
let c = crate::circuit::ops::Constant::new(quantized_value, raw_value);
|
||||
// Create a constant op
|
||||
SupportedOp::Constant(c)
|
||||
}
|
||||
@@ -429,10 +419,6 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
if inputs[0].out_dims().is_empty() || inputs[0].out_dims()[0].len() <= axis {
|
||||
return Err(GraphError::InvalidDims(idx, "gather".to_string()));
|
||||
}
|
||||
|
||||
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::Gather {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| {
|
||||
@@ -450,7 +436,6 @@ pub fn new_op_from_onnx(
|
||||
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
|
||||
scale: 0,
|
||||
datum_type: InputType::TDim,
|
||||
decomp: false,
|
||||
}));
|
||||
inputs[1].bump_scale(0);
|
||||
}
|
||||
@@ -462,17 +447,8 @@ pub fn new_op_from_onnx(
|
||||
"Topk" => {
|
||||
let op = load_op::<Topk>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axis = op.axis;
|
||||
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
|
||||
};
|
||||
|
||||
// if param_visibility.is_public() {
|
||||
let k = if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
if c.raw_values.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
|
||||
}
|
||||
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
c.raw_values.map(|x| x as usize)[0]
|
||||
@@ -512,10 +488,6 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "scatter elements".to_string()));
|
||||
}
|
||||
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
@@ -527,7 +499,6 @@ pub fn new_op_from_onnx(
|
||||
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
|
||||
scale: 0,
|
||||
datum_type: InputType::TDim,
|
||||
decomp: !run_args.ignore_range_check_inputs_outputs,
|
||||
}));
|
||||
inputs[1].bump_scale(0);
|
||||
}
|
||||
@@ -551,9 +522,6 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "scatter nd".to_string()));
|
||||
}
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
})
|
||||
@@ -564,7 +532,6 @@ pub fn new_op_from_onnx(
|
||||
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
|
||||
scale: 0,
|
||||
datum_type: InputType::TDim,
|
||||
decomp: !run_args.ignore_range_check_inputs_outputs,
|
||||
}));
|
||||
inputs[1].bump_scale(0);
|
||||
}
|
||||
@@ -588,9 +555,6 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "gather nd".to_string()));
|
||||
}
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
|
||||
batch_dims,
|
||||
indices: Some(c.raw_values.map(|x| x as usize)),
|
||||
@@ -602,7 +566,6 @@ pub fn new_op_from_onnx(
|
||||
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
|
||||
scale: 0,
|
||||
datum_type: InputType::TDim,
|
||||
decomp: !run_args.ignore_range_check_inputs_outputs,
|
||||
}));
|
||||
inputs[1].bump_scale(0);
|
||||
}
|
||||
@@ -626,9 +589,6 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "gather elements".to_string()));
|
||||
}
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
@@ -640,7 +600,6 @@ pub fn new_op_from_onnx(
|
||||
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
|
||||
scale: 0,
|
||||
datum_type: InputType::TDim,
|
||||
decomp: !run_args.ignore_range_check_inputs_outputs,
|
||||
}));
|
||||
inputs[1].bump_scale(0);
|
||||
}
|
||||
@@ -715,11 +674,7 @@ pub fn new_op_from_onnx(
|
||||
constant_scale,
|
||||
&run_args.param_visibility,
|
||||
)?;
|
||||
let c = crate::circuit::ops::Constant::new(
|
||||
quantized_value,
|
||||
raw_value,
|
||||
run_args.ignore_range_check_inputs_outputs,
|
||||
);
|
||||
let c = crate::circuit::ops::Constant::new(quantized_value, raw_value);
|
||||
// Create a constant op
|
||||
SupportedOp::Constant(c)
|
||||
}
|
||||
@@ -729,9 +684,7 @@ pub fn new_op_from_onnx(
|
||||
};
|
||||
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes: Vec<usize> = op.axes.into_iter().collect();
|
||||
if axes.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "argmax".to_string()));
|
||||
}
|
||||
assert_eq!(axes.len(), 1, "only support argmax over one axis");
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::ReduceArgMax { dim: axes[0] })
|
||||
}
|
||||
@@ -741,9 +694,7 @@ pub fn new_op_from_onnx(
|
||||
};
|
||||
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes: Vec<usize> = op.axes.into_iter().collect();
|
||||
if axes.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "argmin".to_string()));
|
||||
}
|
||||
assert_eq!(axes.len(), 1, "only support argmin over one axis");
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::ReduceArgMin { dim: axes[0] })
|
||||
}
|
||||
@@ -852,9 +803,6 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
"Recip" => {
|
||||
if inputs.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "recip".to_string()));
|
||||
};
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
// If the input scale is larger than the params scale
|
||||
@@ -898,9 +846,6 @@ pub fn new_op_from_onnx(
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Rsqrt" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "rsqrt".to_string()));
|
||||
};
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
SupportedOp::Hybrid(HybridOp::Rsqrt {
|
||||
@@ -982,19 +927,13 @@ pub fn new_op_from_onnx(
|
||||
DatumType::F64 => (scales.input, InputType::F64),
|
||||
_ => return Err(GraphError::UnsupportedDataType(idx, format!("{:?}", dt))),
|
||||
};
|
||||
SupportedOp::Input(crate::circuit::ops::Input {
|
||||
scale,
|
||||
datum_type,
|
||||
decomp: !run_args.ignore_range_check_inputs_outputs,
|
||||
})
|
||||
SupportedOp::Input(crate::circuit::ops::Input { scale, datum_type })
|
||||
}
|
||||
"Cast" => {
|
||||
let op = load_op::<Cast>(node.op(), idx, node.op().name().to_string())?;
|
||||
let dt = op.to;
|
||||
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "cast".to_string()));
|
||||
};
|
||||
assert_eq!(input_scales.len(), 1);
|
||||
|
||||
match dt {
|
||||
DatumType::Bool
|
||||
@@ -1044,11 +983,6 @@ pub fn new_op_from_onnx(
|
||||
|
||||
if const_idx.len() == 1 {
|
||||
let const_idx = const_idx[0];
|
||||
|
||||
if inputs.len() <= const_idx {
|
||||
return Err(GraphError::InvalidDims(idx, "mul".to_string()));
|
||||
}
|
||||
|
||||
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
|
||||
if c.raw_values.len() == 1 && c.raw_values[0] < 1. {
|
||||
// if not divisible by 2 then we need to add a range check
|
||||
@@ -1123,9 +1057,6 @@ pub fn new_op_from_onnx(
|
||||
return Err(GraphError::OpMismatch(idx, "softmax".to_string()));
|
||||
}
|
||||
};
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "softmax".to_string()));
|
||||
}
|
||||
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
@@ -1165,42 +1096,22 @@ pub fn new_op_from_onnx(
|
||||
pool_dims: kernel_shape.to_vec(),
|
||||
})
|
||||
}
|
||||
"Ceil" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "ceil".to_string()));
|
||||
}
|
||||
SupportedOp::Hybrid(HybridOp::Ceil {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
})
|
||||
}
|
||||
"Floor" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "floor".to_string()));
|
||||
}
|
||||
SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
})
|
||||
}
|
||||
"Round" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "round".to_string()));
|
||||
}
|
||||
SupportedOp::Hybrid(HybridOp::Round {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
})
|
||||
}
|
||||
"RoundHalfToEven" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "roundhalftoeven".to_string()));
|
||||
}
|
||||
SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
})
|
||||
}
|
||||
"Ceil" => SupportedOp::Hybrid(HybridOp::Ceil {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Floor" => SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Round" => SupportedOp::Hybrid(HybridOp::Round {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"RoundHalfToEven" => SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Sign" => SupportedOp::Linear(PolyOp::Sign),
|
||||
"Pow" => {
|
||||
// Extract the slope layer hyperparams from a const
|
||||
@@ -1210,9 +1121,7 @@ pub fn new_op_from_onnx(
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.len() > 1 {
|
||||
return Err(GraphError::NonScalarPower);
|
||||
} else if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
|
||||
unimplemented!("only support scalar pow")
|
||||
}
|
||||
|
||||
let exponent = c.raw_values[0];
|
||||
@@ -1229,9 +1138,7 @@ pub fn new_op_from_onnx(
|
||||
inputs[0].decrement_use();
|
||||
deleted_indices.push(0);
|
||||
if c.raw_values.len() > 1 {
|
||||
return Err(GraphError::NonScalarBase);
|
||||
} else if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
|
||||
unimplemented!("only support scalar base")
|
||||
}
|
||||
|
||||
let base = c.raw_values[0];
|
||||
@@ -1241,14 +1148,10 @@ pub fn new_op_from_onnx(
|
||||
base: base.into(),
|
||||
})
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
|
||||
unimplemented!("only support constant base or pow for now")
|
||||
}
|
||||
}
|
||||
"Div" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "div".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
@@ -1256,15 +1159,14 @@ pub fn new_op_from_onnx(
|
||||
.map(|(i, _)| i)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if const_idx.len() > 1 || const_idx.is_empty() {
|
||||
if const_idx.len() > 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "div".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = const_idx[0];
|
||||
|
||||
if const_idx != 1 {
|
||||
return Err(GraphError::MisformedParams(
|
||||
"only support div with constant as second input".to_string(),
|
||||
));
|
||||
unimplemented!("only support div with constant as second input")
|
||||
}
|
||||
|
||||
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
|
||||
@@ -1278,14 +1180,10 @@ pub fn new_op_from_onnx(
|
||||
denom: denom.into(),
|
||||
})
|
||||
} else {
|
||||
return Err(GraphError::MisformedParams(
|
||||
"only support non zero divisors of size 1".to_string(),
|
||||
));
|
||||
unimplemented!("only support non zero divisors of size 1")
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::MisformedParams(
|
||||
"only support div with constant as second input".to_string(),
|
||||
));
|
||||
unimplemented!("only support div with constant as second input")
|
||||
}
|
||||
}
|
||||
"Cube" => SupportedOp::Linear(PolyOp::Pow(3)),
|
||||
@@ -1425,7 +1323,7 @@ pub fn new_op_from_onnx(
|
||||
if !resize_node.contains("interpolator: Nearest")
|
||||
&& !resize_node.contains("nearest: Floor")
|
||||
{
|
||||
return Err(GraphError::InvalidInterpolation);
|
||||
unimplemented!("Only nearest neighbor interpolation is supported")
|
||||
}
|
||||
// check if optional scale factor is present
|
||||
if inputs.len() != 2 && inputs.len() != 3 {
|
||||
@@ -1529,10 +1427,6 @@ pub fn new_op_from_onnx(
|
||||
SupportedOp::Linear(PolyOp::Reshape(output_shape))
|
||||
}
|
||||
"Flatten" => {
|
||||
if inputs.len() != 1 || inputs[0].out_dims().is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "flatten".to_string()));
|
||||
};
|
||||
|
||||
let new_dims: Vec<usize> = vec![inputs[0].out_dims()[0].iter().product::<usize>()];
|
||||
SupportedOp::Linear(PolyOp::Flatten(new_dims))
|
||||
}
|
||||
@@ -1606,10 +1500,12 @@ pub fn homogenize_input_scales(
|
||||
input_scales: Vec<crate::Scale>,
|
||||
inputs_to_scale: Vec<usize>,
|
||||
) -> Result<Box<dyn Op<Fp>>, GraphError> {
|
||||
let relevant_input_scales = inputs_to_scale
|
||||
.iter()
|
||||
.filter(|idx| input_scales.len() > **idx)
|
||||
.map(|&idx| input_scales[idx])
|
||||
let relevant_input_scales = input_scales
|
||||
.clone()
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter(|(idx, _)| inputs_to_scale.contains(idx))
|
||||
.map(|(_, scale)| scale)
|
||||
.collect_vec();
|
||||
|
||||
if inputs_to_scale.is_empty() {
|
||||
@@ -1650,30 +1546,10 @@ pub fn homogenize_input_scales(
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
/// tests for the utility module
|
||||
pub mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
// quantization tests
|
||||
#[test]
|
||||
fn test_quantize_tensor() {
|
||||
let tensor: Tensor<f32> = (0..10).map(|x| x as f32).into();
|
||||
let reference: Tensor<Fp> = (0..10).map(|x| x.into()).into();
|
||||
let scale = 0;
|
||||
let visibility = &Visibility::Public;
|
||||
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility).unwrap();
|
||||
assert_eq!(quantized.len(), 10);
|
||||
assert_eq!(quantized, reference);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantize_edge_cases() {
|
||||
assert_eq!(quantize_float(&f64::NAN, 0.0, 0).unwrap(), 0);
|
||||
assert!(quantize_float(&f64::INFINITY, 0.0, 0).is_err());
|
||||
assert!(quantize_float(&f64::NEG_INFINITY, 0.0, 0).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flatten_valtensors() {
|
||||
let tensor1: Tensor<Fp> = (0..10).map(|x| x.into()).into();
|
||||
|
||||
@@ -11,34 +11,35 @@ use log::debug;
|
||||
use pyo3::{
|
||||
exceptions::PyValueError, FromPyObject, IntoPy, PyObject, PyResult, Python, ToPyObject,
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
use self::errors::GraphError;
|
||||
|
||||
use super::*;
|
||||
|
||||
/// Defines the visibility level of values within the zero-knowledge circuit
|
||||
/// Controls how values are handled during proof generation and verification
|
||||
/// Label enum to track whether model input, model parameters, and model output are public, private, or hashed
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Default)]
|
||||
pub enum Visibility {
|
||||
/// Value is private to the prover and not included in proof
|
||||
/// Mark an item as private to the prover (not in the proof submitted for verification)
|
||||
#[default]
|
||||
Private,
|
||||
/// Value is public and included in proof for verification
|
||||
/// Mark an item as public (sent in the proof submitted for verification)
|
||||
Public,
|
||||
/// Value is hashed and the hash is included in proof
|
||||
/// Mark an item as publicly committed to (hash sent in the proof submitted for verification)
|
||||
Hashed {
|
||||
/// Controls how the hash is handled in proof
|
||||
/// true - hash is included directly in proof (public)
|
||||
/// false - hash is used as advice and passed to computational graph
|
||||
/// Whether the hash is used as an instance (sent in the proof submitted for verification)
|
||||
/// if false the hash is used as an advice (not in the proof submitted for verification) and is then sent to the computational graph
|
||||
/// if true the hash is used as an instance (sent in the proof submitted for verification) the *inputs* to the hashing function are then sent to the computational graph
|
||||
hash_is_public: bool,
|
||||
/// Specifies which outputs this hash affects
|
||||
///
|
||||
outlets: Vec<usize>,
|
||||
},
|
||||
/// Value is committed using KZG commitment scheme
|
||||
/// Mark an item as publicly committed to (KZG commitment sent in the proof submitted for verification)
|
||||
KZGCommit,
|
||||
/// Value is assigned as a constant in the circuit
|
||||
/// assigned as a constant in the circuit
|
||||
Fixed,
|
||||
}
|
||||
|
||||
@@ -65,17 +66,15 @@ impl Display for Visibility {
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl ToFlags for Visibility {
|
||||
/// Converts visibility to command line flags
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{}", self)]
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a str> for Visibility {
|
||||
/// Converts string representation to Visibility
|
||||
fn from(s: &'a str) -> Self {
|
||||
if s.contains("hashed/private") {
|
||||
// Split on last occurrence of '/'
|
||||
// split on last occurrence of '/'
|
||||
let (_, outlets) = s.split_at(s.rfind('/').unwrap());
|
||||
let outlets = outlets
|
||||
.trim_start_matches('/')
|
||||
@@ -107,8 +106,8 @@ impl<'a> From<&'a str> for Visibility {
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts Visibility into a PyObject (Required for Visibility to be compatible with Python)
|
||||
impl IntoPy<PyObject> for Visibility {
|
||||
/// Converts Visibility to Python object
|
||||
fn into_py(self, py: Python) -> PyObject {
|
||||
match self {
|
||||
Visibility::Private => "private".to_object(py),
|
||||
@@ -135,13 +134,14 @@ impl IntoPy<PyObject> for Visibility {
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains Visibility from PyObject (Required for Visibility to be compatible with Python)
|
||||
impl<'source> FromPyObject<'source> for Visibility {
|
||||
/// Extracts Visibility from Python object
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
|
||||
let strval = String::extract_bound(ob)?;
|
||||
let strval = strval.as_str();
|
||||
|
||||
if strval.contains("hashed/private") {
|
||||
// split on last occurence of '/'
|
||||
let (_, outlets) = strval.split_at(strval.rfind('/').unwrap());
|
||||
let outlets = outlets
|
||||
.trim_start_matches('/')
|
||||
@@ -174,32 +174,29 @@ impl<'source> FromPyObject<'source> for Visibility {
|
||||
}
|
||||
|
||||
impl Visibility {
|
||||
/// Returns true if visibility is Fixed
|
||||
#[allow(missing_docs)]
|
||||
pub fn is_fixed(&self) -> bool {
|
||||
matches!(&self, Visibility::Fixed)
|
||||
}
|
||||
|
||||
/// Returns true if visibility is Private or hashed private
|
||||
#[allow(missing_docs)]
|
||||
pub fn is_private(&self) -> bool {
|
||||
matches!(&self, Visibility::Private) || self.is_hashed_private()
|
||||
}
|
||||
|
||||
/// Returns true if visibility is Public
|
||||
#[allow(missing_docs)]
|
||||
pub fn is_public(&self) -> bool {
|
||||
matches!(&self, Visibility::Public)
|
||||
}
|
||||
|
||||
/// Returns true if visibility involves hashing
|
||||
#[allow(missing_docs)]
|
||||
pub fn is_hashed(&self) -> bool {
|
||||
matches!(&self, Visibility::Hashed { .. })
|
||||
}
|
||||
|
||||
/// Returns true if visibility uses KZG commitment
|
||||
#[allow(missing_docs)]
|
||||
pub fn is_polycommit(&self) -> bool {
|
||||
matches!(&self, Visibility::KZGCommit)
|
||||
}
|
||||
|
||||
/// Returns true if visibility is hashed with public hash
|
||||
#[allow(missing_docs)]
|
||||
pub fn is_hashed_public(&self) -> bool {
|
||||
if let Visibility::Hashed {
|
||||
hash_is_public: true,
|
||||
@@ -210,8 +207,7 @@ impl Visibility {
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Returns true if visibility is hashed with private hash
|
||||
#[allow(missing_docs)]
|
||||
pub fn is_hashed_private(&self) -> bool {
|
||||
if let Visibility::Hashed {
|
||||
hash_is_public: false,
|
||||
@@ -223,12 +219,11 @@ impl Visibility {
|
||||
false
|
||||
}
|
||||
|
||||
/// Returns true if visibility requires additional processing
|
||||
#[allow(missing_docs)]
|
||||
pub fn requires_processing(&self) -> bool {
|
||||
matches!(&self, Visibility::Hashed { .. }) | matches!(&self, Visibility::KZGCommit)
|
||||
}
|
||||
|
||||
/// Returns vector of output indices that this visibility setting affects
|
||||
#[allow(missing_docs)]
|
||||
pub fn overwrites_inputs(&self) -> Vec<usize> {
|
||||
if let Visibility::Hashed { outlets, .. } = self {
|
||||
return outlets.clone();
|
||||
@@ -237,14 +232,14 @@ impl Visibility {
|
||||
}
|
||||
}
|
||||
|
||||
/// Manages scaling factors for different parts of the model
|
||||
/// Represents the scale of the model input, model parameters.
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)]
|
||||
pub struct VarScales {
|
||||
/// Scale factor for input values
|
||||
///
|
||||
pub input: crate::Scale,
|
||||
/// Scale factor for parameter values
|
||||
///
|
||||
pub params: crate::Scale,
|
||||
/// Multiplier for scale rebasing
|
||||
///
|
||||
pub rebase_multiplier: u32,
|
||||
}
|
||||
|
||||
@@ -255,17 +250,17 @@ impl std::fmt::Display for VarScales {
|
||||
}
|
||||
|
||||
impl VarScales {
|
||||
/// Returns maximum scale value
|
||||
///
|
||||
pub fn get_max(&self) -> crate::Scale {
|
||||
std::cmp::max(self.input, self.params)
|
||||
}
|
||||
|
||||
/// Returns minimum scale value
|
||||
///
|
||||
pub fn get_min(&self) -> crate::Scale {
|
||||
std::cmp::min(self.input, self.params)
|
||||
}
|
||||
|
||||
/// Creates VarScales from runtime arguments
|
||||
/// Place in [VarScales] struct.
|
||||
pub fn from_args(args: &RunArgs) -> Self {
|
||||
Self {
|
||||
input: args.input_scale,
|
||||
@@ -275,17 +270,16 @@ impl VarScales {
|
||||
}
|
||||
}
|
||||
|
||||
/// Controls visibility settings for different parts of the model
|
||||
/// Represents whether the model input, model parameters, and model output are Public or Private to the prover.
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, PartialOrd)]
|
||||
pub struct VarVisibility {
|
||||
/// Visibility of model inputs
|
||||
/// Input to the model or computational graph
|
||||
pub input: Visibility,
|
||||
/// Visibility of model parameters (weights, biases)
|
||||
/// Parameters, such as weights and biases, in the model
|
||||
pub params: Visibility,
|
||||
/// Visibility of model outputs
|
||||
/// Output of the model or computational graph
|
||||
pub output: Visibility,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for VarVisibility {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(
|
||||
@@ -307,7 +301,8 @@ impl Default for VarVisibility {
|
||||
}
|
||||
|
||||
impl VarVisibility {
|
||||
/// Creates visibility settings from runtime arguments
|
||||
/// Read from cli args whether the model input, model parameters, and model output are Public or Private to the prover.
|
||||
/// Place in [VarVisibility] struct.
|
||||
pub fn from_args(args: &RunArgs) -> Result<Self, GraphError> {
|
||||
let input_vis = &args.input_visibility;
|
||||
let params_vis = &args.param_visibility;
|
||||
@@ -318,17 +313,17 @@ impl VarVisibility {
|
||||
}
|
||||
|
||||
if !output_vis.is_public()
|
||||
&& !params_vis.is_public()
|
||||
&& !input_vis.is_public()
|
||||
&& !output_vis.is_fixed()
|
||||
&& !params_vis.is_fixed()
|
||||
&& !input_vis.is_fixed()
|
||||
&& !output_vis.is_hashed()
|
||||
&& !params_vis.is_hashed()
|
||||
&& !input_vis.is_hashed()
|
||||
&& !output_vis.is_polycommit()
|
||||
&& !params_vis.is_polycommit()
|
||||
&& !input_vis.is_polycommit()
|
||||
& !params_vis.is_public()
|
||||
& !input_vis.is_public()
|
||||
& !output_vis.is_fixed()
|
||||
& !params_vis.is_fixed()
|
||||
& !input_vis.is_fixed()
|
||||
& !output_vis.is_hashed()
|
||||
& !params_vis.is_hashed()
|
||||
& !input_vis.is_hashed()
|
||||
& !output_vis.is_polycommit()
|
||||
& !params_vis.is_polycommit()
|
||||
& !input_vis.is_polycommit()
|
||||
{
|
||||
return Err(GraphError::Visibility);
|
||||
}
|
||||
@@ -340,17 +335,17 @@ impl VarVisibility {
|
||||
}
|
||||
}
|
||||
|
||||
/// Container for circuit columns used by a model
|
||||
/// A wrapper for holding all columns that will be assigned to by a model.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ModelVars<F: PrimeField + TensorType + PartialOrd> {
|
||||
/// Advice columns for circuit assignments
|
||||
#[allow(missing_docs)]
|
||||
pub advices: Vec<VarTensor>,
|
||||
/// Optional instance column for public inputs
|
||||
#[allow(missing_docs)]
|
||||
pub instance: Option<ValTensor<F>>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
/// Gets reference to instance column if it exists
|
||||
/// Get instance col
|
||||
pub fn get_instance_col(&self) -> Option<&Column<Instance>> {
|
||||
if let Some(instance) = &self.instance {
|
||||
match instance {
|
||||
@@ -362,14 +357,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets initial offset for instance values
|
||||
/// Set the initial instance offset
|
||||
pub fn set_initial_instance_offset(&mut self, offset: usize) {
|
||||
if let Some(instance) = &mut self.instance {
|
||||
instance.set_initial_instance_offset(offset);
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets total length of instance data
|
||||
/// Get the total instance len
|
||||
pub fn get_instance_len(&self) -> usize {
|
||||
if let Some(instance) = &self.instance {
|
||||
instance.get_total_instance_len()
|
||||
@@ -378,21 +373,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Increments instance index
|
||||
/// Increment the instance offset
|
||||
pub fn increment_instance_idx(&mut self) {
|
||||
if let Some(instance) = &mut self.instance {
|
||||
instance.increment_idx();
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets instance index to specific value
|
||||
/// Reset the instance offset
|
||||
pub fn set_instance_idx(&mut self, val: usize) {
|
||||
if let Some(instance) = &mut self.instance {
|
||||
instance.set_idx(val);
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets current instance index
|
||||
/// Get the instance offset
|
||||
pub fn get_instance_idx(&self) -> usize {
|
||||
if let Some(instance) = &self.instance {
|
||||
instance.get_idx()
|
||||
@@ -401,7 +396,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Initializes instance column with specified dimensions and scale
|
||||
///
|
||||
pub fn instantiate_instance(
|
||||
&mut self,
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
@@ -422,7 +417,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
};
|
||||
}
|
||||
|
||||
/// Creates new ModelVars with allocated columns based on settings
|
||||
/// Allocate all columns that will be assigned to by a model.
|
||||
pub fn new(cs: &mut ConstraintSystem<F>, params: &GraphSettings) -> Self {
|
||||
debug!("number of blinding factors: {}", cs.blinding_factors());
|
||||
|
||||
|
||||
310
src/lib.rs
310
src/lib.rs
@@ -28,9 +28,6 @@
|
||||
|
||||
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
|
||||
//!
|
||||
use log::warn;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use mimalloc as _;
|
||||
|
||||
/// Error type
|
||||
// #[cfg_attr(not(feature = "ezkl"), derive(uniffi::Error))]
|
||||
@@ -102,7 +99,7 @@ use circuit::{table::Range, CheckMode, Tolerance};
|
||||
use clap::Args;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use fieldutils::IntegerRep;
|
||||
use graph::{Visibility, MAX_PUBLIC_SRS};
|
||||
use graph::Visibility;
|
||||
use halo2_proofs::poly::{
|
||||
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
|
||||
};
|
||||
@@ -168,6 +165,7 @@ pub mod srs_sha;
|
||||
pub mod tensor;
|
||||
#[cfg(feature = "ios-bindings")]
|
||||
uniffi::setup_scaffolding!();
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
@@ -182,9 +180,11 @@ lazy_static! {
|
||||
.unwrap_or("8000".to_string())
|
||||
.parse()
|
||||
.unwrap();
|
||||
|
||||
/// The serialization format for the keys
|
||||
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
|
||||
.unwrap_or("raw-bytes".to_string());
|
||||
|
||||
}
|
||||
|
||||
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
|
||||
@@ -266,102 +266,76 @@ impl From<String> for Commitments {
|
||||
}
|
||||
|
||||
/// Parameters specific to a proving run
|
||||
///
|
||||
/// RunArgs contains all configuration parameters needed to control the proving process,
|
||||
/// including scaling factors, visibility settings, and circuit parameters.
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
derive(Args, ToFlags)
|
||||
)]
|
||||
pub struct RunArgs {
|
||||
/// Error tolerance for model outputs
|
||||
/// Only applicable when outputs are public
|
||||
/// The tolerance for error on model outputs
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other))]
|
||||
pub tolerance: Tolerance,
|
||||
/// Fixed point scaling factor for quantizing inputs
|
||||
/// Higher values provide more precision but increase circuit complexity
|
||||
/// The denominator in the fixed point representation used when quantizing inputs
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other))]
|
||||
pub input_scale: Scale,
|
||||
/// Fixed point scaling factor for quantizing parameters
|
||||
/// Higher values provide more precision but increase circuit complexity
|
||||
/// The denominator in the fixed point representation used when quantizing parameters
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "7", value_hint = clap::ValueHint::Other))]
|
||||
pub param_scale: Scale,
|
||||
/// Scale rebase threshold multiplier
|
||||
/// When scale exceeds input_scale * multiplier, it is rebased to input_scale
|
||||
/// Advanced parameter that should be used with caution
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))]
|
||||
/// if the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this a more advanced parameter, use with caution)
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))]
|
||||
pub scale_rebase_multiplier: u32,
|
||||
/// Range for lookup table input column values
|
||||
/// Specified as (min, max) pair
|
||||
/// The min and max elements in the lookup table input column
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'B', long, value_parser = parse_key_val::<IntegerRep, IntegerRep>, default_value = "-32768->32768"))]
|
||||
pub lookup_range: Range,
|
||||
/// Log2 of the number of rows in the circuit
|
||||
/// Controls circuit size and proving time
|
||||
/// The log_2 number of rows
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other))]
|
||||
pub logrows: u32,
|
||||
/// Number of inner columns per block
|
||||
/// Affects circuit layout and efficiency
|
||||
/// The log_2 number of rows
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'N', long, default_value = "2", value_hint = clap::ValueHint::Other))]
|
||||
pub num_inner_cols: usize,
|
||||
/// Graph variables for parameterizing the computation
|
||||
/// Format: "name->value", e.g. "batch_size->1"
|
||||
/// Hand-written parser for graph variables, eg. batch_size=1
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other))]
|
||||
pub variables: Vec<(String, usize)>,
|
||||
/// Visibility setting for input values
|
||||
/// Controls whether inputs are public or private in the circuit
|
||||
/// Flags whether inputs are public, private, fixed, hashed, polycommit
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
|
||||
pub input_visibility: Visibility,
|
||||
/// Visibility setting for output values
|
||||
/// Controls whether outputs are public or private in the circuit
|
||||
/// Flags whether outputs are public, private, fixed, hashed, polycommit
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "public", value_hint = clap::ValueHint::Other))]
|
||||
pub output_visibility: Visibility,
|
||||
/// Visibility setting for parameters
|
||||
/// Controls how parameters are handled in the circuit
|
||||
/// Flags whether params are fixed, private, hashed, polycommit
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
|
||||
pub param_visibility: Visibility,
|
||||
/// Whether to rebase constants with zero fractional part to scale 0
|
||||
/// Can improve efficiency for integer constants
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
/// Should constants with 0.0 fraction be rebased to scale 0
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
pub rebase_frac_zero_constants: bool,
|
||||
/// Circuit checking mode
|
||||
/// Controls level of constraint verification
|
||||
/// check mode (safe, unsafe, etc)
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "unsafe", value_hint = clap::ValueHint::Other))]
|
||||
pub check_mode: CheckMode,
|
||||
/// Commitment scheme for circuit proving
|
||||
/// Affects proof size and verification time
|
||||
/// commitment scheme
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other))]
|
||||
pub commitment: Option<Commitments>,
|
||||
/// Base for number decomposition
|
||||
/// Must be a power of 2
|
||||
/// the base used for decompositions
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "16384", value_hint = clap::ValueHint::Other))]
|
||||
pub decomp_base: usize,
|
||||
/// Number of decomposition legs
|
||||
/// Controls decomposition granularity
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "2", value_hint = clap::ValueHint::Other))]
|
||||
/// the number of legs used for decompositions
|
||||
pub decomp_legs: usize,
|
||||
/// Whether to use bounded lookup for logarithm computation
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
/// use unbounded lookup for the log
|
||||
pub bounded_log_lookup: bool,
|
||||
/// Range check inputs and outputs (turn off if the inputs are felts)
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
pub ignore_range_check_inputs_outputs: bool,
|
||||
}
|
||||
|
||||
impl Default for RunArgs {
|
||||
/// Creates a new RunArgs instance with default values
|
||||
///
|
||||
/// Default configuration is optimized for common use cases
|
||||
/// while maintaining reasonable proving time and circuit size
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
bounded_log_lookup: false,
|
||||
@@ -381,138 +355,54 @@ impl Default for RunArgs {
|
||||
commitment: None,
|
||||
decomp_base: 16384,
|
||||
decomp_legs: 2,
|
||||
ignore_range_check_inputs_outputs: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RunArgs {
|
||||
/// Validates the RunArgs configuration
|
||||
///
|
||||
/// Performs comprehensive validation of all parameters to ensure they are within
|
||||
/// acceptable ranges and follow required constraints. Returns accumulated errors
|
||||
/// if any validations fail.
|
||||
///
|
||||
/// # Returns
|
||||
/// - Ok(()) if all validations pass
|
||||
/// - Err(String) with detailed error message if any validation fails
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// Visibility validations
|
||||
if self.param_visibility == Visibility::Public {
|
||||
errors.push(
|
||||
"Parameters cannot be public instances. Use 'fixed' or 'kzgcommit' instead"
|
||||
.to_string(),
|
||||
return Err(
|
||||
"params cannot be public instances, you are probably trying to use `fixed` or `kzgcommit`"
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
|
||||
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
|
||||
errors.push("Non-zero tolerance requires output_visibility to be public".to_string());
|
||||
}
|
||||
|
||||
// Scale validations
|
||||
if self.scale_rebase_multiplier < 1 {
|
||||
errors.push("scale_rebase_multiplier must be >= 1".to_string());
|
||||
return Err("scale_rebase_multiplier must be >= 1".into());
|
||||
}
|
||||
|
||||
// if any of the scales are too small
|
||||
if self.input_scale < 8 || self.param_scale < 8 {
|
||||
warn!("low scale values (<8) may impact precision");
|
||||
}
|
||||
|
||||
// Lookup range validations
|
||||
if self.lookup_range.0 > self.lookup_range.1 {
|
||||
errors.push(format!(
|
||||
"Invalid lookup range: min ({}) is greater than max ({})",
|
||||
self.lookup_range.0, self.lookup_range.1
|
||||
));
|
||||
return Err("lookup_range min is greater than max".into());
|
||||
}
|
||||
|
||||
// Size validations
|
||||
if self.logrows < 1 {
|
||||
errors.push("logrows must be >= 1".to_string());
|
||||
return Err("logrows must be >= 1".into());
|
||||
}
|
||||
|
||||
if self.num_inner_cols < 1 {
|
||||
errors.push("num_inner_cols must be >= 1".to_string());
|
||||
return Err("num_inner_cols must be >= 1".into());
|
||||
}
|
||||
|
||||
let batch_size = self.variables.iter().find(|(name, _)| name == "batch_size");
|
||||
if let Some(batch_size) = batch_size {
|
||||
if batch_size.1 == 0 {
|
||||
errors.push("'batch_size' cannot be 0".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Decomposition validations
|
||||
if self.decomp_base == 0 {
|
||||
errors.push("decomp_base cannot be 0".to_string());
|
||||
}
|
||||
|
||||
if self.decomp_legs == 0 {
|
||||
errors.push("decomp_legs cannot be 0".to_string());
|
||||
}
|
||||
|
||||
// Performance validations
|
||||
if self.logrows > MAX_PUBLIC_SRS {
|
||||
warn!("logrows exceeds maximum public SRS size");
|
||||
}
|
||||
|
||||
// Validate tolerance is non-negative
|
||||
if self.tolerance.val < 0.0 {
|
||||
errors.push("tolerance cannot be negative".to_string());
|
||||
}
|
||||
|
||||
// Performance warnings
|
||||
if self.input_scale > 20 || self.param_scale > 20 {
|
||||
warn!("High scale values (>20) may impact performance");
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(errors.join("\n"))
|
||||
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
|
||||
return Err("tolerance > 0.0 requires output_visibility to be public".into());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Exports the configuration as JSON
|
||||
///
|
||||
/// Serializes the RunArgs instance to a JSON string
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(String)` containing JSON representation
|
||||
/// * `Err` if serialization fails
|
||||
/// Export the ezkl configuration as json
|
||||
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let res = serde_json::to_string(&self)?;
|
||||
Ok(res)
|
||||
let serialized = match serde_json::to_string(&self) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err(Box::new(e));
|
||||
}
|
||||
};
|
||||
Ok(serialized)
|
||||
}
|
||||
|
||||
/// Parses configuration from JSON
|
||||
///
|
||||
/// Deserializes a RunArgs instance from a JSON string
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `arg_json` - JSON string containing configuration
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(RunArgs)` if parsing succeeds
|
||||
/// * `Err` if parsing fails
|
||||
/// Parse an ezkl configuration from a json
|
||||
pub fn from_json(arg_json: &str) -> Result<Self, serde_json::Error> {
|
||||
serde_json::from_str(arg_json)
|
||||
}
|
||||
}
|
||||
|
||||
// Additional helper functions for the module
|
||||
|
||||
/// Parses a key-value pair from a string in the format "key->value"
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `s` - Input string in the format "key->value"
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok((T, U))` - Parsed key and value
|
||||
/// * `Err` - If parsing fails
|
||||
/// Parse a single key-value pair
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn parse_key_val<T, U>(
|
||||
s: &str,
|
||||
@@ -525,15 +415,14 @@ where
|
||||
{
|
||||
let pos = s
|
||||
.find("->")
|
||||
.ok_or_else(|| format!("invalid KEY->VALUE: no `->` found in `{s}`"))?;
|
||||
Ok((s[..pos].parse()?, s[pos + 2..].parse()?))
|
||||
.ok_or_else(|| format!("invalid x->y: no `->` found in `{s}`"))?;
|
||||
let a = s[..pos].parse()?;
|
||||
let b = s[pos + 2..].parse()?;
|
||||
Ok((a, b))
|
||||
}
|
||||
|
||||
/// Verifies that a version string matches the expected artifact version
|
||||
/// Logs warnings for version mismatches or unversioned artifacts
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `artifact_version` - Version string from the artifact
|
||||
/// Check if the version string matches the artifact version
|
||||
/// If the version string does not match the artifact version, log a warning
|
||||
pub fn check_version_string_matches(artifact_version: &str) {
|
||||
if artifact_version == "0.0.0"
|
||||
|| artifact_version == "source - no compatibility guaranteed"
|
||||
@@ -558,98 +447,3 @@ pub fn check_version_string_matches(artifact_version: &str) {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::field_reassign_with_default)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_valid_default_args() {
|
||||
let args = RunArgs::default();
|
||||
assert!(args.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_param_visibility() {
|
||||
let mut args = RunArgs::default();
|
||||
args.param_visibility = Visibility::Public;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("Parameters cannot be public instances"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_scale_rebase() {
|
||||
let mut args = RunArgs::default();
|
||||
args.scale_rebase_multiplier = 0;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("scale_rebase_multiplier must be >= 1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_lookup_range() {
|
||||
let mut args = RunArgs::default();
|
||||
args.lookup_range = (100, -100);
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("Invalid lookup range"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_logrows() {
|
||||
let mut args = RunArgs::default();
|
||||
args.logrows = 0;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("logrows must be >= 1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_inner_cols() {
|
||||
let mut args = RunArgs::default();
|
||||
args.num_inner_cols = 0;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("num_inner_cols must be >= 1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_tolerance() {
|
||||
let mut args = RunArgs::default();
|
||||
args.tolerance.val = 1.0;
|
||||
args.output_visibility = Visibility::Private;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("Non-zero tolerance requires output_visibility to be public"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_negative_tolerance() {
|
||||
let mut args = RunArgs::default();
|
||||
args.tolerance.val = -1.0;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("tolerance cannot be negative"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zero_batch_size() {
|
||||
let mut args = RunArgs::default();
|
||||
args.variables = vec![("batch_size".to_string(), 0)];
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("'batch_size' cannot be 0"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_serialization() {
|
||||
let args = RunArgs::default();
|
||||
let json = args.as_json().unwrap();
|
||||
let deserialized = RunArgs::from_json(&json).unwrap();
|
||||
assert_eq!(args, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_validation_errors() {
|
||||
let mut args = RunArgs::default();
|
||||
args.logrows = 0;
|
||||
args.lookup_range = (100, -100);
|
||||
let err = args.validate().unwrap_err();
|
||||
// Should contain multiple error messages
|
||||
assert!(err.matches("\n").count() >= 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,6 +133,7 @@ pub fn aggregate<'a>(
|
||||
.collect_vec()
|
||||
}));
|
||||
|
||||
// loader.ctx().constrain_equal(cell_0, cell_1)
|
||||
let mut transcript = PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, snark.proof());
|
||||
let proof = PlonkSuccinctVerifier::read_proof(svk, &protocol, &instances, &mut transcript)
|
||||
.map_err(|_| plonk::Error::Synthesis)?;
|
||||
@@ -308,11 +309,11 @@ impl AggregationCircuit {
|
||||
})
|
||||
}
|
||||
|
||||
/// Number of limbs used for decomposition
|
||||
///
|
||||
pub fn num_limbs() -> usize {
|
||||
LIMBS
|
||||
}
|
||||
/// Number of bits used for decomposition
|
||||
///
|
||||
pub fn num_bits() -> usize {
|
||||
BITS
|
||||
}
|
||||
|
||||
@@ -353,7 +353,6 @@ where
|
||||
C::ScalarExt: Serialize + DeserializeOwned,
|
||||
{
|
||||
/// Create a new application snark from proof and instance variables ready for aggregation
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
protocol: Option<PlonkProtocol<C>>,
|
||||
instances: Vec<Vec<F>>,
|
||||
@@ -529,6 +528,7 @@ pub fn create_keys<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
disable_selector_compression: bool,
|
||||
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
|
||||
{
|
||||
// Real proof
|
||||
@@ -794,6 +794,7 @@ pub fn load_vk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
params: <C as Circuit<Scheme::Scalar>>::Params,
|
||||
) -> Result<VerifyingKey<Scheme::Curve>, PfsysError>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
Scheme::Curve: SerdeObject + CurveAffine,
|
||||
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
|
||||
{
|
||||
@@ -816,6 +817,7 @@ pub fn load_pk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
params: <C as Circuit<Scheme::Scalar>>::Params,
|
||||
) -> Result<ProvingKey<Scheme::Curve>, PfsysError>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
Scheme::Curve: SerdeObject + CurveAffine,
|
||||
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
|
||||
{
|
||||
|
||||
@@ -38,10 +38,4 @@ pub enum TensorError {
|
||||
/// Decomposition error
|
||||
#[error("decomposition error: {0}")]
|
||||
DecompositionError(#[from] DecompositionError),
|
||||
/// Invalid argument
|
||||
#[error("invalid argument: {0}")]
|
||||
InvalidArgument(String),
|
||||
/// Index out of bounds
|
||||
#[error("index {0} out of bounds for dimension {1}")]
|
||||
IndexOutOfBounds(usize, usize),
|
||||
}
|
||||
|
||||
@@ -803,12 +803,6 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
num_repeats: usize,
|
||||
initial_offset: usize,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
if n == 0 {
|
||||
return Err(TensorError::InvalidArgument(
|
||||
"Cannot duplicate every 0th element".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut inner: Vec<T> = Vec::with_capacity(self.inner.len());
|
||||
let mut offset = initial_offset;
|
||||
for (i, elem) in self.inner.clone().into_iter().enumerate() {
|
||||
@@ -838,17 +832,11 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
num_repeats: usize,
|
||||
initial_offset: usize,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
if n == 0 {
|
||||
return Err(TensorError::InvalidArgument(
|
||||
"Cannot remove every 0th element".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Pre-calculate capacity to avoid reallocations
|
||||
let estimated_size = self.inner.len() - (self.inner.len() / n) * num_repeats;
|
||||
let mut inner = Vec::with_capacity(estimated_size);
|
||||
|
||||
// Use iterator directly instead of creating intermediate collectionsif
|
||||
// Use iterator directly instead of creating intermediate collections
|
||||
let mut i = 0;
|
||||
while i < self.inner.len() {
|
||||
// Add the current element
|
||||
@@ -867,6 +855,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
}
|
||||
|
||||
/// Remove indices
|
||||
/// WARN: assumes indices are in ascending order for speed
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
@@ -893,11 +882,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
}
|
||||
// remove indices
|
||||
for elem in indices.iter().rev() {
|
||||
if *elem < self.len() {
|
||||
inner.remove(*elem);
|
||||
} else {
|
||||
return Err(TensorError::IndexOutOfBounds(*elem, self.len()));
|
||||
}
|
||||
inner.remove(*elem);
|
||||
}
|
||||
|
||||
Tensor::new(Some(&inner), &[inner.len()])
|
||||
@@ -1658,9 +1643,7 @@ impl<T: TensorType + Div<Output = T> + std::marker::Send + std::marker::Sync> Di
|
||||
}
|
||||
|
||||
// implement remainder
|
||||
impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + PartialEq> Rem
|
||||
for Tensor<T>
|
||||
{
|
||||
impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Rem for Tensor<T> {
|
||||
type Output = Result<Tensor<T>, TensorError>;
|
||||
|
||||
/// Elementwise remainder of a tensor with another tensor.
|
||||
@@ -1689,25 +1672,9 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
|
||||
let mut lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let rhs = rhs.expand(&broadcasted_shape).unwrap();
|
||||
|
||||
lhs.par_iter_mut()
|
||||
.zip(rhs)
|
||||
.map(|(o, r)| {
|
||||
if let Some(zero) = T::zero() {
|
||||
if r != zero {
|
||||
*o = o.clone() % r;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(TensorError::InvalidArgument(
|
||||
"Cannot divide by zero in remainder".to_string(),
|
||||
))
|
||||
}
|
||||
} else {
|
||||
Err(TensorError::InvalidArgument(
|
||||
"Undefined zero value".to_string(),
|
||||
))
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
|
||||
*o = o.clone() % r;
|
||||
});
|
||||
|
||||
Ok(lhs)
|
||||
}
|
||||
@@ -1742,6 +1709,7 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
|
||||
/// assert_eq!(c, vec![2, 3]);
|
||||
///
|
||||
/// ```
|
||||
|
||||
pub fn get_broadcasted_shape(
|
||||
shape_a: &[usize],
|
||||
shape_b: &[usize],
|
||||
@@ -1749,21 +1717,20 @@ pub fn get_broadcasted_shape(
|
||||
let num_dims_a = shape_a.len();
|
||||
let num_dims_b = shape_b.len();
|
||||
|
||||
if num_dims_a == num_dims_b {
|
||||
let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
|
||||
for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
|
||||
let max_dim = dim_a.max(dim_b);
|
||||
broadcasted_shape.push(*max_dim);
|
||||
match (num_dims_a, num_dims_b) {
|
||||
(a, b) if a == b => {
|
||||
let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
|
||||
for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
|
||||
let max_dim = dim_a.max(dim_b);
|
||||
broadcasted_shape.push(*max_dim);
|
||||
}
|
||||
Ok(broadcasted_shape)
|
||||
}
|
||||
Ok(broadcasted_shape)
|
||||
} else if num_dims_a < num_dims_b {
|
||||
Ok(shape_b.to_vec())
|
||||
} else if num_dims_a > num_dims_b {
|
||||
Ok(shape_a.to_vec())
|
||||
} else {
|
||||
Err(TensorError::DimError(
|
||||
(a, b) if a < b => Ok(shape_b.to_vec()),
|
||||
(a, b) if a > b => Ok(shape_a.to_vec()),
|
||||
_ => Err(TensorError::DimError(
|
||||
"Unknown condition for broadcasting".to_string(),
|
||||
))
|
||||
)),
|
||||
}
|
||||
}
|
||||
////////////////////////
|
||||
|
||||
@@ -385,12 +385,6 @@ pub fn resize<T: TensorType + Send + Sync>(
|
||||
pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync>(
|
||||
t: &[Tensor<T>],
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
if t.len() == 1 {
|
||||
return Ok(t[0].clone());
|
||||
} else if t.len() == 0 {
|
||||
return Err(TensorError::DimMismatch("add".to_string()));
|
||||
}
|
||||
|
||||
// calculate value of output
|
||||
let mut output: Tensor<T> = t[0].clone();
|
||||
|
||||
@@ -439,11 +433,6 @@ pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sy
|
||||
pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync>(
|
||||
t: &[Tensor<T>],
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
if t.len() == 1 {
|
||||
return Ok(t[0].clone());
|
||||
} else if t.len() == 0 {
|
||||
return Err(TensorError::DimMismatch("sub".to_string()));
|
||||
}
|
||||
// calculate value of output
|
||||
let mut output: Tensor<T> = t[0].clone();
|
||||
|
||||
@@ -490,11 +479,6 @@ pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sy
|
||||
pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync>(
|
||||
t: &[Tensor<T>],
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
if t.len() == 1 {
|
||||
return Ok(t[0].clone());
|
||||
} else if t.len() == 0 {
|
||||
return Err(TensorError::DimMismatch("mult".to_string()));
|
||||
}
|
||||
// calculate value of output
|
||||
let mut output: Tensor<T> = t[0].clone();
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,35 +5,33 @@ use log::{debug, error, warn};
|
||||
use crate::circuit::{region::ConstantsMap, CheckMode};
|
||||
|
||||
use super::*;
|
||||
/// A wrapper around Halo2's Column types that represents a tensor of variables in the circuit.
|
||||
/// VarTensors are used to store and manage circuit columns, typically for assigning ValTensor
|
||||
/// values during circuit layout. The tensor organizes storage into blocks of columns, where each
|
||||
/// block contains multiple columns and each column contains multiple rows.
|
||||
/// A wrapper around Halo2's `Column<Fixed>` or `Column<Advice>`.
|
||||
/// Typically assign [ValTensor]s to [VarTensor]s when laying out a circuit.
|
||||
#[derive(Clone, Default, Debug, PartialEq, Eq)]
|
||||
pub enum VarTensor {
|
||||
/// A VarTensor for holding Advice values, which are assigned at proving time.
|
||||
Advice {
|
||||
/// Vec of Advice columns, we have [[xx][xx][xx]...] where each inner vec is xx columns
|
||||
inner: Vec<Vec<Column<Advice>>>,
|
||||
/// The number of columns in each inner block
|
||||
///
|
||||
num_inner_cols: usize,
|
||||
/// Number of rows available to be used in each column of the storage
|
||||
col_size: usize,
|
||||
},
|
||||
/// A placeholder tensor used for testing or temporary storage
|
||||
/// Dummy var
|
||||
Dummy {
|
||||
/// The number of columns in each inner block
|
||||
///
|
||||
num_inner_cols: usize,
|
||||
/// Number of rows available to be used in each column of the storage
|
||||
col_size: usize,
|
||||
},
|
||||
/// An empty tensor with no storage
|
||||
/// Empty var
|
||||
#[default]
|
||||
Empty,
|
||||
}
|
||||
|
||||
impl VarTensor {
|
||||
/// Returns the name of the tensor variant as a static string
|
||||
/// name of the tensor
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
VarTensor::Advice { .. } => "Advice",
|
||||
@@ -42,35 +40,22 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the tensor is an Advice variant
|
||||
///
|
||||
pub fn is_advice(&self) -> bool {
|
||||
matches!(self, VarTensor::Advice { .. })
|
||||
}
|
||||
|
||||
/// Calculates the maximum number of usable rows in the constraint system
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cs` - The constraint system
|
||||
/// * `logrows` - Log base 2 of the total number of rows (including system and blinding rows)
|
||||
///
|
||||
/// # Returns
|
||||
/// The maximum number of usable rows after accounting for blinding factors
|
||||
pub fn max_rows<F: PrimeField>(cs: &ConstraintSystem<F>, logrows: usize) -> usize {
|
||||
let base = 2u32;
|
||||
base.pow(logrows as u32) as usize - cs.blinding_factors() - 1
|
||||
}
|
||||
|
||||
/// Creates a new VarTensor::Advice with unblinded columns. Unblinded columns are used when
|
||||
/// the values do not need to be hidden in the proof.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cs` - The constraint system to create columns in
|
||||
/// * `logrows` - Log base 2 of the total number of rows
|
||||
/// * `num_inner_cols` - Number of columns in each inner block
|
||||
/// * `capacity` - Total number of advice cells to allocate
|
||||
///
|
||||
/// # Returns
|
||||
/// A new VarTensor::Advice with unblinded columns enabled for equality constraints
|
||||
/// Create a new VarTensor::Advice that is unblinded
|
||||
/// Arguments
|
||||
/// * `cs` - The constraint system
|
||||
/// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
|
||||
/// * `capacity` - The number of advice cells to allocate
|
||||
pub fn new_unblinded_advice<F: PrimeField>(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
logrows: usize,
|
||||
@@ -108,17 +93,11 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new VarTensor::Advice with standard (blinded) columns, used when
|
||||
/// the values need to be hidden in the proof.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cs` - The constraint system to create columns in
|
||||
/// * `logrows` - Log base 2 of the total number of rows
|
||||
/// * `num_inner_cols` - Number of columns in each inner block
|
||||
/// * `capacity` - Total number of advice cells to allocate
|
||||
///
|
||||
/// # Returns
|
||||
/// A new VarTensor::Advice with blinded columns enabled for equality constraints
|
||||
/// Create a new VarTensor::Advice
|
||||
/// Arguments
|
||||
/// * `cs` - The constraint system
|
||||
/// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
|
||||
/// * `capacity` - The number of advice cells to allocate
|
||||
pub fn new_advice<F: PrimeField>(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
logrows: usize,
|
||||
@@ -154,17 +133,11 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Initializes fixed columns in the constraint system to support the VarTensor::Advice
|
||||
/// Fixed columns are used for constant values that are known at circuit creation time.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cs` - The constraint system to create columns in
|
||||
/// * `logrows` - Log base 2 of the total number of rows
|
||||
/// * `num_constants` - Number of constant values needed
|
||||
/// * `module_requires_fixed` - Whether the module requires at least one fixed column
|
||||
///
|
||||
/// # Returns
|
||||
/// The number of fixed columns created
|
||||
/// Initializes fixed columns to support the VarTensor::Advice
|
||||
/// Arguments
|
||||
/// * `cs` - The constraint system
|
||||
/// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
|
||||
/// * `capacity` - The number of advice cells to allocate
|
||||
pub fn constant_cols<F: PrimeField>(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
logrows: usize,
|
||||
@@ -196,14 +169,7 @@ impl VarTensor {
|
||||
modulo
|
||||
}
|
||||
|
||||
/// Creates a new dummy VarTensor for testing or temporary storage
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `logrows` - Log base 2 of the total number of rows
|
||||
/// * `num_inner_cols` - Number of columns in each inner block
|
||||
///
|
||||
/// # Returns
|
||||
/// A new VarTensor::Dummy with the specified dimensions
|
||||
/// Create a new VarTensor::Dummy
|
||||
pub fn dummy(logrows: usize, num_inner_cols: usize) -> Self {
|
||||
let base = 2u32;
|
||||
let max_rows = base.pow(logrows as u32) as usize - 6;
|
||||
@@ -213,7 +179,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of blocks in the tensor
|
||||
/// Gets the dims of the object the VarTensor represents
|
||||
pub fn num_blocks(&self) -> usize {
|
||||
match self {
|
||||
VarTensor::Advice { inner, .. } => inner.len(),
|
||||
@@ -221,7 +187,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of columns in each inner block
|
||||
/// Num inner cols
|
||||
pub fn num_inner_cols(&self) -> usize {
|
||||
match self {
|
||||
VarTensor::Advice { num_inner_cols, .. } | VarTensor::Dummy { num_inner_cols, .. } => {
|
||||
@@ -231,7 +197,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the total number of columns across all blocks
|
||||
/// Total number of columns
|
||||
pub fn num_cols(&self) -> usize {
|
||||
match self {
|
||||
VarTensor::Advice { inner, .. } => inner[0].len() * inner.len(),
|
||||
@@ -239,7 +205,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the maximum number of rows in each column
|
||||
/// Gets the size of each column
|
||||
pub fn col_size(&self) -> usize {
|
||||
match self {
|
||||
VarTensor::Advice { col_size, .. } | VarTensor::Dummy { col_size, .. } => *col_size,
|
||||
@@ -247,7 +213,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the total size of each block (num_inner_cols * col_size)
|
||||
/// Gets the size of each column
|
||||
pub fn block_size(&self) -> usize {
|
||||
match self {
|
||||
VarTensor::Advice {
|
||||
@@ -264,13 +230,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a linear coordinate to (block, column, row) coordinates in the storage
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `linear_coord` - The linear index to convert
|
||||
///
|
||||
/// # Returns
|
||||
/// A tuple of (block_index, column_index, row_index)
|
||||
/// Take a linear coordinate and output the (column, row) position in the storage block.
|
||||
pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize, usize) {
|
||||
// x indexes over blocks of size num_inner_cols
|
||||
let x = linear_coord / self.block_size();
|
||||
@@ -283,17 +243,7 @@ impl VarTensor {
|
||||
}
|
||||
|
||||
impl VarTensor {
|
||||
/// Queries a range of cells in the tensor during circuit synthesis
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `meta` - Virtual cells accessor
|
||||
/// * `x` - Block index
|
||||
/// * `y` - Column index within block
|
||||
/// * `z` - Starting row offset
|
||||
/// * `rng` - Number of consecutive rows to query
|
||||
///
|
||||
/// # Returns
|
||||
/// A tensor of expressions representing the queried cells
|
||||
/// Retrieve the value of a specific cell in the tensor.
|
||||
pub fn query_rng<F: PrimeField>(
|
||||
&self,
|
||||
meta: &mut VirtualCells<'_, F>,
|
||||
@@ -318,16 +268,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Queries an entire block of cells at a given offset
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `meta` - Virtual cells accessor
|
||||
/// * `x` - Block index
|
||||
/// * `z` - Row offset
|
||||
/// * `rng` - Number of consecutive rows to query
|
||||
///
|
||||
/// # Returns
|
||||
/// A tensor of expressions representing the queried block
|
||||
/// Retrieve the value of a specific block at an offset in the tensor.
|
||||
pub fn query_whole_block<F: PrimeField>(
|
||||
&self,
|
||||
meta: &mut VirtualCells<'_, F>,
|
||||
@@ -352,16 +293,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns a constant value to a specific cell in the tensor
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for the assignment
|
||||
/// * `coord` - Coordinate within the tensor
|
||||
/// * `constant` - The constant value to assign
|
||||
///
|
||||
/// # Returns
|
||||
/// The assigned cell or an error if assignment fails
|
||||
/// Assigns a constant value to a specific cell in the tensor.
|
||||
pub fn assign_constant<F: PrimeField + TensorType + PartialOrd>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
@@ -381,17 +313,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns values from a ValTensor to this tensor, excluding specified positions
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `omissions` - Set of positions to skip during assignment
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// The assigned ValTensor or an error if assignment fails
|
||||
/// Assigns [ValTensor] to the columns of the inner tensor.
|
||||
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
@@ -422,16 +344,7 @@ impl VarTensor {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Assigns values from a ValTensor to this tensor
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// The assigned ValTensor or an error if assignment fails
|
||||
/// Assigns [ValTensor] to the columns of the inner tensor.
|
||||
pub fn assign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
@@ -483,23 +396,14 @@ impl VarTensor {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Returns the remaining available space in a column for assignments
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `offset` - Current offset in the column
|
||||
/// * `values` - The ValTensor to check space for
|
||||
///
|
||||
/// # Returns
|
||||
/// The number of rows that need to be flushed or an error if space is insufficient
|
||||
/// Helper function to get the remaining size of the column
|
||||
pub fn get_column_flush<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
offset: usize,
|
||||
values: &ValTensor<F>,
|
||||
) -> Result<usize, halo2_proofs::plonk::Error> {
|
||||
if values.len() > self.col_size() {
|
||||
error!(
|
||||
"There are too many values to flush for this column size, try setting the logrows to a higher value (eg. --logrows 22 on the cli)"
|
||||
);
|
||||
error!("Values are too large for the column");
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
}
|
||||
|
||||
@@ -523,16 +427,8 @@ impl VarTensor {
|
||||
Ok(flush_len)
|
||||
}
|
||||
|
||||
/// Assigns values to a single column, avoiding column overflow by flushing to the next column if needed
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// A tuple of (assigned ValTensor, number of rows flushed) or an error if assignment fails
|
||||
/// Assigns [ValTensor] to the columns of the inner tensor. Whereby the values are assigned to a single column, without overflowing.
|
||||
/// So for instance if we are assigning 10 values and we are at index 18 of the column, and the columns are of length 20, we skip the last 2 values of current column and start from the beginning of the next column.
|
||||
pub fn assign_exact_column<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
@@ -547,17 +443,8 @@ impl VarTensor {
|
||||
Ok((assigned_vals, flush_len))
|
||||
}
|
||||
|
||||
/// Assigns values with duplication in dummy mode, used for testing and simulation
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `row` - Starting row for assignment
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `single_inner_col` - Whether to treat as a single column
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
|
||||
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
|
||||
/// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
|
||||
pub fn dummy_assign_with_duplication<
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
|
||||
>(
|
||||
@@ -607,16 +494,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns values with duplication but without enforcing constraints between duplicated values
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
|
||||
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
|
||||
pub fn assign_with_duplication_unconstrained<
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
|
||||
>(
|
||||
@@ -655,18 +533,8 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns values with duplication and enforces equality constraints between duplicated values
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `row` - Starting row for assignment
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `check_mode` - Mode for checking equality constraints
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
|
||||
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
|
||||
/// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
|
||||
pub fn assign_with_duplication_constrained<
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
|
||||
>(
|
||||
@@ -745,17 +613,6 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns a single value to the tensor. This is a helper function used by other assignment methods.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for the assignment
|
||||
/// * `k` - The value to assign
|
||||
/// * `coord` - The coordinate where to assign the value
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// The assigned value or an error if assignment fails
|
||||
fn assign_value<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
@@ -766,28 +623,24 @@ impl VarTensor {
|
||||
) -> Result<ValType<F>, halo2_proofs::plonk::Error> {
|
||||
let (x, y, z) = self.cartesian_coord(offset + coord);
|
||||
let res = match k {
|
||||
// Handle direct value assignment
|
||||
ValType::Value(v) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
ValType::PrevAssigned(region.assign_advice(|| "k", advices[x][y], z, || v)?)
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
// Handle copying previously assigned value
|
||||
ValType::PrevAssigned(v) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
ValType::PrevAssigned(v.copy_advice(|| "k", region, advices[x][y], z)?)
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
// Handle copying previously assigned constant
|
||||
ValType::AssignedConstant(v, val) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
ValType::AssignedConstant(v.copy_advice(|| "k", region, advices[x][y], z)?, val)
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
// Handle assigning evaluated value
|
||||
ValType::AssignedValue(v) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => ValType::PrevAssigned(
|
||||
region
|
||||
@@ -796,7 +649,6 @@ impl VarTensor {
|
||||
),
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
// Handle constant value assignment with caching
|
||||
ValType::Constant(v) => {
|
||||
if let std::collections::hash_map::Entry::Vacant(e) = constants.entry(v) {
|
||||
let value = ValType::AssignedConstant(
|
||||
|
||||
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -28,12 +28,11 @@
|
||||
"commitment": "KZG",
|
||||
"decomp_base": 128,
|
||||
"decomp_legs": 2,
|
||||
"bounded_log_lookup": false,
|
||||
"ignore_range_check_inputs_outputs": false
|
||||
"bounded_log_lookup": false
|
||||
},
|
||||
"num_rows": 236,
|
||||
"total_assignments": 472,
|
||||
"total_const_size": 4,
|
||||
"num_rows": 46,
|
||||
"total_assignments": 92,
|
||||
"total_const_size": 3,
|
||||
"total_dynamic_col_size": 0,
|
||||
"max_dynamic_input_len": 0,
|
||||
"num_dynamic_lookups": 0,
|
||||
|
||||
Binary file not shown.
@@ -1,6 +1,7 @@
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
#[cfg(test)]
|
||||
mod native_tests {
|
||||
|
||||
use ezkl::circuit::Tolerance;
|
||||
use ezkl::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
|
||||
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
|
||||
@@ -22,8 +23,6 @@ mod native_tests {
|
||||
static COMPILE_WASM: Once = Once::new();
|
||||
static ENV_SETUP: Once = Once::new();
|
||||
|
||||
const TEST_BINARY: &str = "test-runs/ezkl";
|
||||
|
||||
//Sure to run this once
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
@@ -76,8 +75,9 @@ mod native_tests {
|
||||
});
|
||||
}
|
||||
|
||||
///
|
||||
#[allow(dead_code)]
|
||||
fn init_wasm() {
|
||||
pub fn init_wasm() {
|
||||
COMPILE_WASM.call_once(|| {
|
||||
build_wasm_ezkl();
|
||||
});
|
||||
@@ -104,7 +104,7 @@ mod native_tests {
|
||||
|
||||
fn download_srs(logrows: u32, commitment: Commitments) {
|
||||
// if does not exist, download it
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"get-srs",
|
||||
"--logrows",
|
||||
@@ -207,61 +207,61 @@ mod native_tests {
|
||||
];
|
||||
|
||||
const TESTS: [&str; 98] = [
|
||||
"1l_mlp", //0
|
||||
"1l_slice", //1
|
||||
"1l_concat", //2
|
||||
"1l_flatten", //3
|
||||
"1l_mlp", //0
|
||||
"1l_slice",
|
||||
"1l_concat",
|
||||
"1l_flatten",
|
||||
// "1l_average",
|
||||
"1l_div", //4
|
||||
"1l_pad", // 5
|
||||
"1l_reshape", //6
|
||||
"1l_eltwise_div", //7
|
||||
"1l_sigmoid", //8
|
||||
"1l_sqrt", //9
|
||||
"1l_softmax", //10
|
||||
"1l_div",
|
||||
"1l_pad", // 5
|
||||
"1l_reshape",
|
||||
"1l_eltwise_div",
|
||||
"1l_sigmoid",
|
||||
"1l_sqrt",
|
||||
"1l_softmax", //10
|
||||
// "1l_instance_norm",
|
||||
"1l_batch_norm", //11
|
||||
"1l_prelu", //12
|
||||
"1l_leakyrelu", //13
|
||||
"1l_gelu_noappx", //14
|
||||
"1l_batch_norm",
|
||||
"1l_prelu",
|
||||
"1l_leakyrelu",
|
||||
"1l_gelu_noappx",
|
||||
// "1l_gelu_tanh_appx",
|
||||
"1l_relu", //15
|
||||
"1l_downsample", //16
|
||||
"1l_tanh", //17
|
||||
"2l_relu_sigmoid_small", //18
|
||||
"2l_relu_fc", //19
|
||||
"2l_relu_small", //20
|
||||
"2l_relu_sigmoid", //21
|
||||
"1l_conv", //22
|
||||
"2l_sigmoid_small", //23
|
||||
"2l_relu_sigmoid_conv", //24
|
||||
"3l_relu_conv_fc", //25
|
||||
"4l_relu_conv_fc", //26
|
||||
"1l_erf", //27
|
||||
"1l_var", //28
|
||||
"1l_elu", //29
|
||||
"min", //30
|
||||
"max", //31
|
||||
"1l_max_pool", //32
|
||||
"1l_conv_transpose", //33
|
||||
"1l_upsample", //34
|
||||
"1l_identity", //35
|
||||
"idolmodel", // too big evm
|
||||
"trig", // too big evm
|
||||
"prelu_gmm", //38
|
||||
"lstm", //39
|
||||
"rnn", //40
|
||||
"quantize_dequantize", //41
|
||||
"1l_where", //42
|
||||
"boolean", //43
|
||||
"boolean_identity", //44
|
||||
"decision_tree", // 45
|
||||
"random_forest", //46
|
||||
"gradient_boosted_trees", //47
|
||||
"1l_topk", //48
|
||||
"xgboost", //49
|
||||
"lightgbm", //50
|
||||
"hummingbird_decision_tree", //51
|
||||
"1l_relu", //15
|
||||
"1l_downsample",
|
||||
"1l_tanh",
|
||||
"2l_relu_sigmoid_small",
|
||||
"2l_relu_fc",
|
||||
"2l_relu_small", //20
|
||||
"2l_relu_sigmoid",
|
||||
"1l_conv",
|
||||
"2l_sigmoid_small",
|
||||
"2l_relu_sigmoid_conv",
|
||||
"3l_relu_conv_fc", //25
|
||||
"4l_relu_conv_fc",
|
||||
"1l_erf",
|
||||
"1l_var",
|
||||
"1l_elu",
|
||||
"min", //30
|
||||
"max",
|
||||
"1l_max_pool",
|
||||
"1l_conv_transpose",
|
||||
"1l_upsample",
|
||||
"1l_identity", //35
|
||||
"idolmodel", // too big evm
|
||||
"trig", // too big evm
|
||||
"prelu_gmm",
|
||||
"lstm",
|
||||
"rnn", //40
|
||||
"quantize_dequantize",
|
||||
"1l_where",
|
||||
"boolean",
|
||||
"boolean_identity",
|
||||
"decision_tree", // 45
|
||||
"random_forest",
|
||||
"gradient_boosted_trees",
|
||||
"1l_topk",
|
||||
"xgboost",
|
||||
"lightgbm", //50
|
||||
"hummingbird_decision_tree",
|
||||
"oh_decision_tree",
|
||||
"linear_svc",
|
||||
"gather_elements",
|
||||
@@ -628,7 +628,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// gen random number between 0.0 and 1.0
|
||||
let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false, Some(32776), Some(5));
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false, Some(8194), Some(5));
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -982,7 +982,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false, None, Some(5));
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -1556,7 +1556,7 @@ mod native_tests {
|
||||
.save(format!("{}/{}/witness_bad.json", test_dir, example_name).into())
|
||||
.unwrap();
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"mock",
|
||||
"-W",
|
||||
@@ -1568,7 +1568,7 @@ mod native_tests {
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"mock",
|
||||
"-W",
|
||||
@@ -1580,7 +1580,7 @@ mod native_tests {
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"mock",
|
||||
"-W",
|
||||
@@ -1592,7 +1592,7 @@ mod native_tests {
|
||||
.expect("failed to execute process");
|
||||
assert!(!status.success());
|
||||
} else {
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"mock",
|
||||
"-W",
|
||||
@@ -1641,11 +1641,6 @@ mod native_tests {
|
||||
format!("--commitment={}", commitment),
|
||||
];
|
||||
|
||||
// if output-visibility is fixed set --range-check-inputs-outputs to False
|
||||
if output_visibility == "fixed" {
|
||||
args.push("--ignore-range-check-inputs-outputs".to_string());
|
||||
}
|
||||
|
||||
if let Some(decomp_base) = decomp_base {
|
||||
args.push(format!("--decomp-base={}", decomp_base));
|
||||
}
|
||||
@@ -1658,7 +1653,7 @@ mod native_tests {
|
||||
args.push("--bounded-log-lookup".to_string());
|
||||
}
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -1688,7 +1683,7 @@ mod native_tests {
|
||||
calibrate_args.push(scales);
|
||||
}
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(calibrate_args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -1712,7 +1707,7 @@ mod native_tests {
|
||||
*tolerance = 0.0;
|
||||
}
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"compile-circuit",
|
||||
"-M",
|
||||
@@ -1729,7 +1724,7 @@ mod native_tests {
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"gen-witness",
|
||||
"-D",
|
||||
@@ -1797,7 +1792,7 @@ mod native_tests {
|
||||
|
||||
// Mock prove (fast, but does not cover some potential issues)
|
||||
fn render_circuit(test_dir: &str, example_name: String) {
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"render-circuit",
|
||||
"-M",
|
||||
@@ -1828,7 +1823,7 @@ mod native_tests {
|
||||
Commitments::KZG,
|
||||
2,
|
||||
);
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"mock-aggregate",
|
||||
"--logrows=23",
|
||||
@@ -1866,7 +1861,7 @@ mod native_tests {
|
||||
|
||||
download_srs(23, commitment);
|
||||
// now setup-aggregate
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"setup-aggregate",
|
||||
"--sample-snarks",
|
||||
@@ -1882,7 +1877,7 @@ mod native_tests {
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"aggregate",
|
||||
"--logrows=23",
|
||||
@@ -1897,7 +1892,7 @@ mod native_tests {
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"verify-aggr",
|
||||
"--logrows=23",
|
||||
@@ -1947,7 +1942,7 @@ mod native_tests {
|
||||
let private_key = format!("--private-key={}", *ANVIL_DEFAULT_PRIVATE_KEY);
|
||||
|
||||
// create encoded calldata
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"encode-evm-calldata",
|
||||
"--proof-path",
|
||||
@@ -1969,7 +1964,7 @@ mod native_tests {
|
||||
|
||||
let args = build_args(base_args, &sol_arg);
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -1985,7 +1980,7 @@ mod native_tests {
|
||||
private_key.as_str(),
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -2007,14 +2002,14 @@ mod native_tests {
|
||||
rpc_arg.as_str(),
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&base_args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
// As sanity check, add example that should fail.
|
||||
base_args[2] = PF_FAILURE_AGGR;
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(base_args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -2065,7 +2060,7 @@ mod native_tests {
|
||||
|
||||
init_params(settings_path.clone().into());
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"setup",
|
||||
"-M",
|
||||
@@ -2080,7 +2075,7 @@ mod native_tests {
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"prove",
|
||||
"-W",
|
||||
@@ -2098,7 +2093,7 @@ mod native_tests {
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"swap-proof-commitments",
|
||||
"--proof-path",
|
||||
@@ -2110,7 +2105,7 @@ mod native_tests {
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"verify",
|
||||
format!("--settings-path={}", settings_path).as_str(),
|
||||
@@ -2133,7 +2128,7 @@ mod native_tests {
|
||||
// get_srs for the graph_settings_num_instances
|
||||
download_srs(1, graph_settings.run_args.commitment.into());
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"verify",
|
||||
format!("--settings-path={}", settings_path).as_str(),
|
||||
@@ -2183,7 +2178,7 @@ mod native_tests {
|
||||
let settings_arg = format!("--settings-path={}", settings_path);
|
||||
|
||||
// create encoded calldata
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"encode-evm-calldata",
|
||||
"--proof-path",
|
||||
@@ -2203,7 +2198,7 @@ mod native_tests {
|
||||
args.push("--sol-code-path");
|
||||
args.push(sol_arg.as_str());
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -2215,7 +2210,7 @@ mod native_tests {
|
||||
args.push("--sol-code-path");
|
||||
args.push(sol_arg.as_str());
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -2237,14 +2232,14 @@ mod native_tests {
|
||||
deployed_addr_arg.as_str(),
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
// As sanity check, add example that should fail.
|
||||
args[2] = PF_FAILURE;
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -2252,7 +2247,6 @@ mod native_tests {
|
||||
}
|
||||
|
||||
// prove-serialize-verify, the usual full path
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn kzg_evm_prove_and_verify_reusable_verifier(
|
||||
num_inner_columns: usize,
|
||||
test_dir: &str,
|
||||
@@ -2303,7 +2297,7 @@ mod native_tests {
|
||||
"--reusable",
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -2318,7 +2312,7 @@ mod native_tests {
|
||||
"-C=verifier/reusable",
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -2347,7 +2341,7 @@ mod native_tests {
|
||||
&sol_arg_vk,
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -2362,7 +2356,7 @@ mod native_tests {
|
||||
"-C=vka",
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -2375,7 +2369,7 @@ mod native_tests {
|
||||
let deployed_addr_arg_vk = format!("--addr-vk={}", addr_vk);
|
||||
|
||||
// create encoded calldata
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"encode-evm-calldata",
|
||||
"--proof-path",
|
||||
@@ -2398,7 +2392,7 @@ mod native_tests {
|
||||
deployed_addr_arg_vk.as_str(),
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -2431,7 +2425,7 @@ mod native_tests {
|
||||
// Verify the modified proof (should fail)
|
||||
let mut args_mod = args.clone();
|
||||
args_mod[2] = &modified_pf_arg;
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args_mod)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -2509,7 +2503,7 @@ mod native_tests {
|
||||
let test_input_source = format!("--input-source={}", input_source);
|
||||
let test_output_source = format!("--output-source={}", output_source);
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"setup",
|
||||
"-M",
|
||||
@@ -2524,7 +2518,7 @@ mod native_tests {
|
||||
assert!(status.success());
|
||||
|
||||
// generate the witness, passing the vk path to generate the necessary kzg commits
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"gen-witness",
|
||||
"-D",
|
||||
@@ -2581,7 +2575,7 @@ mod native_tests {
|
||||
}
|
||||
input.save(data_path.clone().into()).unwrap();
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"setup-test-evm-data",
|
||||
"-D",
|
||||
@@ -2599,7 +2593,7 @@ mod native_tests {
|
||||
assert!(status.success());
|
||||
}
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"prove",
|
||||
"-W",
|
||||
@@ -2620,7 +2614,7 @@ mod native_tests {
|
||||
let settings_arg = format!("--settings-path={}", settings_path);
|
||||
|
||||
// create encoded calldata
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"encode-evm-calldata",
|
||||
"--proof-path",
|
||||
@@ -2639,7 +2633,7 @@ mod native_tests {
|
||||
args.push("--sol-code-path");
|
||||
args.push(sol_arg.as_str());
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -2660,7 +2654,7 @@ mod native_tests {
|
||||
args.push("--sol-code-path");
|
||||
args.push(sol_arg.as_str());
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -2683,7 +2677,7 @@ mod native_tests {
|
||||
create_da_args.push(test_on_chain_data_path.as_str());
|
||||
}
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&create_da_args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -2696,7 +2690,7 @@ mod native_tests {
|
||||
};
|
||||
|
||||
let addr_path_da_arg = format!("--addr-path={}/{}/addr_da.txt", test_dir, example_name);
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"deploy-evm-da",
|
||||
format!("--settings-path={}", settings_path).as_str(),
|
||||
@@ -2734,14 +2728,14 @@ mod native_tests {
|
||||
deployed_addr_da_arg.as_str(),
|
||||
rpc_arg.as_str(),
|
||||
];
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
// Create a new set of test on chain data only for the on-chain input source
|
||||
if input_source != "file" || output_source != "file" {
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"setup-test-evm-data",
|
||||
"-D",
|
||||
@@ -2768,7 +2762,7 @@ mod native_tests {
|
||||
test_on_chain_data_path.as_str(),
|
||||
rpc_arg.as_str(),
|
||||
];
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -2784,7 +2778,7 @@ mod native_tests {
|
||||
deployed_addr_da_arg.as_str(),
|
||||
rpc_arg.as_str(),
|
||||
];
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
@@ -2795,28 +2789,21 @@ mod native_tests {
|
||||
#[cfg(feature = "icicle")]
|
||||
let args = [
|
||||
"build",
|
||||
"--profile=test-runs",
|
||||
"--release",
|
||||
"--bin",
|
||||
"ezkl",
|
||||
"--features",
|
||||
"icicle",
|
||||
];
|
||||
#[cfg(feature = "macos-metal")]
|
||||
let args = [
|
||||
"build",
|
||||
"--profile=test-runs",
|
||||
"--bin",
|
||||
"ezkl",
|
||||
"--features",
|
||||
"macos-metal",
|
||||
];
|
||||
let args = ["build", "--release", "--bin", "ezkl", "--features", "macos-metal"];
|
||||
// not macos-metal and not icicle
|
||||
#[cfg(all(not(feature = "icicle"), not(feature = "macos-metal")))]
|
||||
let args = ["build", "--profile=test-runs", "--bin", "ezkl"];
|
||||
let args = ["build", "--release", "--bin", "ezkl"];
|
||||
#[cfg(not(feature = "mv-lookup"))]
|
||||
let args = [
|
||||
"build",
|
||||
"--profile=test-runs",
|
||||
"--release",
|
||||
"--bin",
|
||||
"ezkl",
|
||||
"--no-default-features",
|
||||
@@ -2837,7 +2824,7 @@ mod native_tests {
|
||||
let status = Command::new("wasm-pack")
|
||||
.args([
|
||||
"build",
|
||||
"--profile=test-runs",
|
||||
"--release",
|
||||
"--target",
|
||||
"nodejs",
|
||||
"--out-dir",
|
||||
|
||||
@@ -72,10 +72,11 @@ mod py_tests {
|
||||
"torchtext==0.17.2",
|
||||
"torchvision==0.17.2",
|
||||
"pandas==2.2.1",
|
||||
"numpy==1.26.4",
|
||||
"seaborn==0.13.2",
|
||||
"notebook==7.1.2",
|
||||
"nbconvert==7.16.3",
|
||||
"onnx==1.17.0",
|
||||
"onnx==1.16.0",
|
||||
"kaggle==1.6.8",
|
||||
"py-solc-x==2.0.3",
|
||||
"web3==7.5.0",
|
||||
@@ -89,13 +90,12 @@ mod py_tests {
|
||||
"xgboost==2.0.3",
|
||||
"hummingbird-ml==0.4.11",
|
||||
"lightgbm==4.3.0",
|
||||
"numpy==1.26.4",
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
let status = Command::new("pip")
|
||||
.args(["install", "numpy==1.26.4"])
|
||||
.args(["install", "numpy==1.23"])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
|
||||
@@ -126,10 +126,10 @@ mod py_tests {
|
||||
}
|
||||
|
||||
const TESTS: [&str; 35] = [
|
||||
"mnist_gan.ipynb", // 0
|
||||
"ezkl_demo_batch.ipynb", // 1
|
||||
"proof_splitting.ipynb", // 2
|
||||
"variance.ipynb", // 3
|
||||
"ezkl_demo_batch.ipynb", // 0
|
||||
"proof_splitting.ipynb", // 1
|
||||
"variance.ipynb", // 2
|
||||
"mnist_gan.ipynb", // 3
|
||||
"keras_simple_demo.ipynb", // 4
|
||||
"mnist_gan_proof_splitting.ipynb", // 5
|
||||
"hashed_vis.ipynb", // 6
|
||||
|
||||
Reference in New Issue
Block a user