mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
6 Commits
v17.1.5
...
ac/patch-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5290045f06 | ||
|
|
a0078bef6a | ||
|
|
d0ba505baa | ||
|
|
f35688917d | ||
|
|
7ae541ed35 | ||
|
|
675628cd08 |
22
.github/workflows/benchmarks.yml
vendored
22
.github/workflows/benchmarks.yml
vendored
@@ -8,6 +8,8 @@ on:
|
||||
jobs:
|
||||
|
||||
bench_poseidon:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -22,6 +24,8 @@ jobs:
|
||||
run: cargo bench --verbose --bench poseidon
|
||||
|
||||
bench_einsum_accum_matmul:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
@@ -37,6 +41,8 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_einsum_matmul
|
||||
|
||||
bench_accum_matmul_relu:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
@@ -52,6 +58,8 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_matmul_relu
|
||||
|
||||
bench_accum_matmul_relu_overflow:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
@@ -67,6 +75,8 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_matmul_relu_overflow
|
||||
|
||||
bench_relu:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
@@ -82,6 +92,8 @@ jobs:
|
||||
run: cargo bench --verbose --bench relu
|
||||
|
||||
bench_accum_dot:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
@@ -97,6 +109,8 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_dot
|
||||
|
||||
bench_accum_conv:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
@@ -112,6 +126,8 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_conv
|
||||
|
||||
bench_accum_sumpool:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
@@ -127,6 +143,8 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_sumpool
|
||||
|
||||
bench_pairwise_add:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
@@ -142,6 +160,8 @@ jobs:
|
||||
run: cargo bench --verbose --bench pairwise_add
|
||||
|
||||
bench_accum_sum:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
@@ -157,6 +177,8 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_sum
|
||||
|
||||
bench_pairwise_pow:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
|
||||
6
.github/workflows/engine.yml
vendored
6
.github/workflows/engine.yml
vendored
@@ -15,6 +15,9 @@ defaults:
|
||||
working-directory: .
|
||||
jobs:
|
||||
publish-wasm-bindings:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
name: publish-wasm-bindings
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
@@ -186,6 +189,9 @@ jobs:
|
||||
|
||||
|
||||
in-browser-evm-ver-publish:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
name: publish-in-browser-evm-verifier-package
|
||||
needs: [publish-wasm-bindings]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
2
.github/workflows/large-tests.yml
vendored
2
.github/workflows/large-tests.yml
vendored
@@ -6,6 +6,8 @@ on:
|
||||
description: "Test scenario tags"
|
||||
jobs:
|
||||
large-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: kaiju
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
3
.github/workflows/pypi-gpu.yml
vendored
3
.github/workflows/pypi-gpu.yml
vendored
@@ -18,6 +18,9 @@ defaults:
|
||||
jobs:
|
||||
|
||||
linux:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
runs-on: GPU
|
||||
strategy:
|
||||
matrix:
|
||||
|
||||
19
.github/workflows/pypi.yml
vendored
19
.github/workflows/pypi.yml
vendored
@@ -16,6 +16,8 @@ defaults:
|
||||
|
||||
jobs:
|
||||
macos:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
strategy:
|
||||
@@ -47,6 +49,13 @@ jobs:
|
||||
components: rustfmt, clippy
|
||||
|
||||
- name: Build wheels
|
||||
if: matrix.target == 'universal2-apple-darwin'
|
||||
uses: PyO3/maturin-action@v1
|
||||
with:
|
||||
target: ${{ matrix.target }}
|
||||
args: --release --out dist --features python-bindings
|
||||
- name: Build wheels
|
||||
if: matrix.target == 'x86_64'
|
||||
uses: PyO3/maturin-action@v1
|
||||
with:
|
||||
target: ${{ matrix.target }}
|
||||
@@ -64,6 +73,8 @@ jobs:
|
||||
path: dist
|
||||
|
||||
windows:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: windows-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
strategy:
|
||||
@@ -111,6 +122,8 @@ jobs:
|
||||
path: dist
|
||||
|
||||
linux:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
strategy:
|
||||
@@ -226,6 +239,8 @@ jobs:
|
||||
# path: dist
|
||||
|
||||
musllinux:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
strategy:
|
||||
@@ -291,6 +306,8 @@ jobs:
|
||||
path: dist
|
||||
|
||||
musllinux-cross:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
strategy:
|
||||
@@ -378,6 +395,8 @@ jobs:
|
||||
packages-dir: ./
|
||||
|
||||
doc-publish:
|
||||
permissions:
|
||||
contents: read
|
||||
name: Trigger ReadTheDocs Build
|
||||
runs-on: ubuntu-latest
|
||||
needs: pypi-publish
|
||||
|
||||
18
.github/workflows/release.yml
vendored
18
.github/workflows/release.yml
vendored
@@ -10,6 +10,9 @@ on:
|
||||
- "*"
|
||||
jobs:
|
||||
create-release:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
name: create-release
|
||||
runs-on: ubuntu-22.04
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
@@ -33,6 +36,9 @@ jobs:
|
||||
tag_name: ${{ env.EZKL_VERSION }}
|
||||
|
||||
build-release-gpu:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
name: build-release-gpu
|
||||
needs: ["create-release"]
|
||||
runs-on: GPU
|
||||
@@ -94,6 +100,10 @@ jobs:
|
||||
asset_content_type: application/octet-stream
|
||||
|
||||
build-release:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
issues: write
|
||||
name: build-release
|
||||
needs: ["create-release"]
|
||||
runs-on: ${{ matrix.os }}
|
||||
@@ -186,14 +196,18 @@ jobs:
|
||||
echo "target flag is: ${{ env.TARGET_FLAGS }}"
|
||||
echo "target dir is: ${{ env.TARGET_DIR }}"
|
||||
|
||||
- name: Build release binary (no asm)
|
||||
if: matrix.build != 'linux-gnu'
|
||||
- name: Build release binary (no asm or metal)
|
||||
if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
|
||||
|
||||
- name: Build release binary (asm)
|
||||
if: matrix.build == 'linux-gnu'
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
|
||||
|
||||
- name: Build release binary (metal)
|
||||
if: matrix.build == 'macos-aarch64'
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal
|
||||
|
||||
- name: Strip release binary
|
||||
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'
|
||||
run: strip "target/${{ matrix.target }}/release/ezkl"
|
||||
|
||||
99
.github/workflows/rust.yml
vendored
99
.github/workflows/rust.yml
vendored
@@ -19,9 +19,27 @@ env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
fr-age-test:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: large-self-hosted
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: fr age Mock
|
||||
run: cargo test --release --verbose tests::large_mock_::large_tests_6_expects -- --include-ignored
|
||||
|
||||
build:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -35,6 +53,8 @@ jobs:
|
||||
run: cargo build --verbose
|
||||
|
||||
docs:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -49,6 +69,8 @@ jobs:
|
||||
run: cargo doc --verbose
|
||||
|
||||
library-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -109,6 +131,8 @@ jobs:
|
||||
# run: cargo nextest run conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored
|
||||
|
||||
ultra-overflow-tests_og-lookup:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -144,6 +168,8 @@ jobs:
|
||||
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
|
||||
|
||||
ultra-overflow-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -179,6 +205,8 @@ jobs:
|
||||
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture -- --include-ignored
|
||||
|
||||
model-serialization:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest-16-cores
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -197,6 +225,8 @@ jobs:
|
||||
run: cargo nextest run native_tests::tests::model_serialization_different_binaries_ --test-threads 1
|
||||
|
||||
wasm32-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -224,6 +254,8 @@ jobs:
|
||||
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
|
||||
|
||||
mock-proving-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
@@ -289,6 +321,8 @@ jobs:
|
||||
run: cargo nextest run --release --verbose tests::mock_fixed_params_ --test-threads 32
|
||||
|
||||
prove-and-verify-evm-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
@@ -371,7 +405,44 @@ jobs:
|
||||
- name: KZG prove and verify tests (EVM + hashed outputs)
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_hashed_output_prove_and_verify --test-threads 1
|
||||
|
||||
# prove-and-verify-tests-metal:
|
||||
# permissions:
|
||||
# contents: read
|
||||
# runs-on: macos-13
|
||||
# # needs: [build, library-tests, docs]
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
# override: true
|
||||
# components: rustfmt, clippy
|
||||
# - uses: jetli/wasm-pack-action@v0.4.0
|
||||
# with:
|
||||
# # Pin to version 0.12.1
|
||||
# version: 'v0.12.1'
|
||||
# - name: Add rust-src
|
||||
# run: rustup component add rust-src --toolchain nightly-2024-07-18
|
||||
# - uses: actions/checkout@v3
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - name: Use pnpm 8
|
||||
# uses: pnpm/action-setup@v2
|
||||
# with:
|
||||
# version: 8
|
||||
# - uses: baptiste0928/cargo-install@v1
|
||||
# with:
|
||||
# crate: cargo-nextest
|
||||
# locked: true
|
||||
# - name: KZG prove and verify tests (public outputs)
|
||||
# run: cargo nextest run --release --features macos-metal --verbose tests::kzg_prove_and_verify_::t --no-capture
|
||||
|
||||
|
||||
prove-and-verify-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs]
|
||||
steps:
|
||||
@@ -489,6 +560,8 @@ jobs:
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
|
||||
|
||||
prove-and-verify-mock-aggr-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
@@ -528,6 +601,8 @@ jobs:
|
||||
# run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
|
||||
|
||||
prove-and-verify-aggr-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: large-self-hosted
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
@@ -547,6 +622,8 @@ jobs:
|
||||
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
|
||||
|
||||
prove-and-verify-aggr-evm-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: large-self-hosted
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
@@ -570,6 +647,8 @@ jobs:
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
|
||||
|
||||
examples:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
needs: [build, library-tests, docs]
|
||||
steps:
|
||||
@@ -589,6 +668,8 @@ jobs:
|
||||
run: cargo nextest run --release tests_examples
|
||||
|
||||
python-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs]
|
||||
steps:
|
||||
@@ -617,7 +698,9 @@ jobs:
|
||||
run: source .env/bin/activate; pip install pytest-asyncio; pytest -vv
|
||||
|
||||
accuracy-measurement-tests:
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -649,6 +732,8 @@ jobs:
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::resources_accuracy_measurement_public_outputs_
|
||||
|
||||
python-integration-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: large-self-hosted
|
||||
services:
|
||||
# Label used to access the service container
|
||||
@@ -694,6 +779,8 @@ jobs:
|
||||
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
|
||||
- name: Voice tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
|
||||
- name: Neural bow
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
|
||||
- name: Felt conversion
|
||||
@@ -713,14 +800,14 @@ jobs:
|
||||
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
|
||||
- name: All notebooks
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
|
||||
- name: Voice tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
|
||||
- name: NBEATS tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
|
||||
# - name: Reusable verifier tutorial
|
||||
# run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_
|
||||
|
||||
ios-integration-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -739,6 +826,8 @@ jobs:
|
||||
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
|
||||
|
||||
swift-package-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
needs: [ios-integration-tests]
|
||||
|
||||
|
||||
3
.github/workflows/static-analysis.yml
vendored
3
.github/workflows/static-analysis.yml
vendored
@@ -8,8 +8,9 @@ on:
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
|
||||
3
.github/workflows/swift-pm.yml
vendored
3
.github/workflows/swift-pm.yml
vendored
@@ -9,6 +9,9 @@ on:
|
||||
|
||||
jobs:
|
||||
build-and-update:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
runs-on: macos-latest
|
||||
env:
|
||||
EZKL_SWIFT_PACKAGE_REPO: github.com/zkonduit/ezkl-swift-package.git
|
||||
|
||||
108
Cargo.lock
generated
108
Cargo.lock
generated
@@ -1835,6 +1835,16 @@ dependencies = [
|
||||
"syn 2.0.90",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_filter"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0"
|
||||
dependencies = [
|
||||
"log",
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.10.2"
|
||||
@@ -1848,6 +1858,19 @@ dependencies = [
|
||||
"termcolor",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.11.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
"env_filter",
|
||||
"humantime",
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.1"
|
||||
@@ -1923,7 +1946,7 @@ dependencies = [
|
||||
"console_error_panic_hook",
|
||||
"criterion 0.5.1",
|
||||
"ecc",
|
||||
"env_logger",
|
||||
"env_logger 0.10.2",
|
||||
"ethabi",
|
||||
"foundry-compilers",
|
||||
"gag",
|
||||
@@ -1931,7 +1954,7 @@ dependencies = [
|
||||
"halo2_gadgets",
|
||||
"halo2_proofs",
|
||||
"halo2_solidity_verifier",
|
||||
"halo2curves 0.7.0",
|
||||
"halo2curves 0.7.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
|
||||
"hex",
|
||||
"indicatif",
|
||||
"instant",
|
||||
@@ -1939,7 +1962,6 @@ dependencies = [
|
||||
"lazy_static",
|
||||
"log",
|
||||
"maybe-rayon",
|
||||
"metal",
|
||||
"mimalloc",
|
||||
"mnist",
|
||||
"num",
|
||||
@@ -2394,14 +2416,14 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_proofs"
|
||||
version = "0.3.0"
|
||||
source = "git+https://github.com/zkonduit/halo2#ee4e1a09ebdb1f79f797685b78951c6034c430a6#ee4e1a09ebdb1f79f797685b78951c6034c430a6"
|
||||
source = "git+https://github.com/zkonduit/halo2#bf9d0057a82443be48c4779bbe14961c18fb5996#bf9d0057a82443be48c4779bbe14961c18fb5996"
|
||||
dependencies = [
|
||||
"bincode",
|
||||
"blake2b_simd",
|
||||
"env_logger",
|
||||
"env_logger 0.10.2",
|
||||
"ff",
|
||||
"group",
|
||||
"halo2curves 0.7.0",
|
||||
"halo2curves 0.7.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
|
||||
"icicle-bn254",
|
||||
"icicle-core",
|
||||
"icicle-cuda-runtime",
|
||||
@@ -2409,6 +2431,7 @@ dependencies = [
|
||||
"lazy_static",
|
||||
"log",
|
||||
"maybe-rayon",
|
||||
"mopro-msm",
|
||||
"rand_chacha",
|
||||
"rand_core 0.6.4",
|
||||
"rustc-hash 2.0.0",
|
||||
@@ -2497,13 +2520,14 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2curves"
|
||||
version = "0.7.0"
|
||||
source = "git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851#b753a832e92d5c86c5c997327a9cf9de86a18851"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d380afeef3f1d4d3245b76895172018cfb087d9976a7cabcd5597775b2933e07"
|
||||
dependencies = [
|
||||
"blake2",
|
||||
"digest 0.10.7",
|
||||
"ff",
|
||||
"group",
|
||||
"halo2derive",
|
||||
"halo2derive 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"hex",
|
||||
"lazy_static",
|
||||
"num-bigint",
|
||||
@@ -2523,6 +2547,49 @@ dependencies = [
|
||||
"unroll",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "halo2curves"
|
||||
version = "0.7.0"
|
||||
source = "git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851#b753a832e92d5c86c5c997327a9cf9de86a18851"
|
||||
dependencies = [
|
||||
"blake2",
|
||||
"digest 0.10.7",
|
||||
"ff",
|
||||
"group",
|
||||
"halo2derive 0.1.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
|
||||
"hex",
|
||||
"lazy_static",
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"pairing",
|
||||
"pasta_curves",
|
||||
"paste",
|
||||
"rand 0.8.5",
|
||||
"rand_core 0.6.4",
|
||||
"rayon",
|
||||
"serde",
|
||||
"serde_arrays",
|
||||
"sha2",
|
||||
"static_assertions",
|
||||
"subtle",
|
||||
"unroll",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "halo2derive"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bdb99e7492b4f5ff469d238db464131b86c2eaac814a78715acba369f64d2c76"
|
||||
dependencies = [
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "halo2derive"
|
||||
version = "0.1.0"
|
||||
@@ -3283,7 +3350,8 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "metal"
|
||||
version = "0.29.0"
|
||||
source = "git+https://github.com/gfx-rs/metal-rs#0e1918b34689c4b8cd13a43372f9898680547ee9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21"
|
||||
dependencies = [
|
||||
"bitflags 2.5.0",
|
||||
"block",
|
||||
@@ -3354,6 +3422,28 @@ dependencies = [
|
||||
"byteorder",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mopro-msm"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/metal-msm-gpu-acceleration.git#be5f647b1a6c1a6ea9024390744a2b4d87f5d002"
|
||||
dependencies = [
|
||||
"bincode",
|
||||
"env_logger 0.11.6",
|
||||
"halo2curves 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"instant",
|
||||
"itertools 0.13.0",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"metal",
|
||||
"objc",
|
||||
"once_cell",
|
||||
"rand 0.8.5",
|
||||
"rayon",
|
||||
"serde",
|
||||
"thiserror",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "native-tls"
|
||||
version = "0.2.11"
|
||||
|
||||
@@ -91,7 +91,6 @@ pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", ver
|
||||
pyo3-log = { version = "0.12.0", default-features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
|
||||
tabled = { version = "0.12.0", optional = true }
|
||||
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
|
||||
objc = { version = "0.2.4", optional = true }
|
||||
mimalloc = { version = "0.1", optional = true }
|
||||
pyo3-stub-gen = { version = "0.6.0", optional = true }
|
||||
@@ -277,13 +276,14 @@ icicle = ["halo2_proofs/icicle_gpu"]
|
||||
empty-cmd = []
|
||||
no-banner = []
|
||||
no-update = []
|
||||
|
||||
macos-metal = ["halo2_proofs/macos"]
|
||||
ios-metal = ["halo2_proofs/ios"]
|
||||
|
||||
[patch.'https://github.com/zkonduit/halo2']
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#ee4e1a09ebdb1f79f797685b78951c6034c430a6", package = "halo2_proofs" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#bf9d0057a82443be48c4779bbe14961c18fb5996", package = "halo2_proofs" }
|
||||
|
||||
[patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#ee4e1a09ebdb1f79f797685b78951c6034c430a6", package = "halo2_proofs" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#bf9d0057a82443be48c4779bbe14961c18fb5996", package = "halo2_proofs" }
|
||||
|
||||
[patch.crates-io]
|
||||
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
|
||||
|
||||
@@ -453,18 +453,18 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# now mock aggregate the proofs\n",
|
||||
"proofs = []\n",
|
||||
"for i in range(3):\n",
|
||||
" proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
|
||||
" proofs.append(proof_path)\n",
|
||||
"# proofs = []\n",
|
||||
"# for i in range(3):\n",
|
||||
"# proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
|
||||
"# proofs.append(proof_path)\n",
|
||||
"\n",
|
||||
"ezkl.mock_aggregate(proofs, logrows=23, split_proofs = True)"
|
||||
"# ezkl.mock_aggregate(proofs, logrows=26, split_proofs = True)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "ezkl",
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -478,7 +478,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
"version": "3.12.7"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
|
||||
1
examples/onnx/fr_age/input.json
Normal file
1
examples/onnx/fr_age/input.json
Normal file
File diff suppressed because one or more lines are too long
BIN
examples/onnx/fr_age/network.onnx
Normal file
BIN
examples/onnx/fr_age/network.onnx
Normal file
Binary file not shown.
@@ -100,9 +100,6 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
|
||||
for input in hash_inputs.iter().take(WIDTH) {
|
||||
meta.enable_equality(*input);
|
||||
}
|
||||
meta.enable_constant(rc_b[0]);
|
||||
|
||||
Self::configure_with_cols(
|
||||
@@ -152,9 +149,6 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
|
||||
for input in hash_inputs.iter().take(WIDTH) {
|
||||
meta.enable_equality(*input);
|
||||
}
|
||||
meta.enable_constant(rc_b[0]);
|
||||
|
||||
let instance = meta.instance_column();
|
||||
|
||||
@@ -215,15 +215,16 @@ impl DynamicLookups {
|
||||
|
||||
/// A struct representing the selectors for the dynamic lookup tables
|
||||
#[derive(Clone, Debug, Default)]
|
||||
|
||||
pub struct Shuffles {
|
||||
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
|
||||
pub input_selectors: BTreeMap<(usize, (usize, usize)), Selector>,
|
||||
/// Selectors for the dynamic lookup tables
|
||||
pub reference_selectors: Vec<Selector>,
|
||||
pub output_selectors: Vec<Selector>,
|
||||
/// Inputs:
|
||||
pub inputs: Vec<VarTensor>,
|
||||
/// tables
|
||||
pub references: Vec<VarTensor>,
|
||||
pub outputs: Vec<VarTensor>,
|
||||
}
|
||||
|
||||
impl Shuffles {
|
||||
@@ -234,9 +235,13 @@ impl Shuffles {
|
||||
|
||||
Self {
|
||||
input_selectors: BTreeMap::new(),
|
||||
reference_selectors: vec![],
|
||||
inputs: vec![dummy_var.clone(), dummy_var.clone()],
|
||||
references: vec![single_col_dummy_var.clone(), single_col_dummy_var.clone()],
|
||||
output_selectors: vec![],
|
||||
inputs: vec![dummy_var.clone(), dummy_var.clone(), dummy_var.clone()],
|
||||
outputs: vec![
|
||||
single_col_dummy_var.clone(),
|
||||
single_col_dummy_var.clone(),
|
||||
single_col_dummy_var.clone(),
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -374,6 +379,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
if inputs[0].num_cols() != output.num_cols() {
|
||||
log::warn!("input and output shapes do not match");
|
||||
}
|
||||
if inputs[0].num_inner_cols() != inputs[1].num_inner_cols() {
|
||||
log::warn!("input number of inner columns do not match");
|
||||
}
|
||||
if inputs[0].num_inner_cols() != output.num_inner_cols() {
|
||||
log::warn!("input and output number of inner columns do not match");
|
||||
}
|
||||
|
||||
for i in 0..output.num_blocks() {
|
||||
for j in 0..output.num_inner_cols() {
|
||||
@@ -581,9 +592,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
// this is 0 if the index is the same as the column index (starting from 1)
|
||||
|
||||
let col_expr = sel.clone()
|
||||
* table
|
||||
* (table
|
||||
.selector_constructor
|
||||
.get_expr_at_idx(col_idx, synthetic_sel);
|
||||
.get_expr_at_idx(col_idx, synthetic_sel));
|
||||
|
||||
let multiplier =
|
||||
table.selector_constructor.get_selector_val_at_idx(col_idx);
|
||||
@@ -615,6 +626,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
res
|
||||
});
|
||||
}
|
||||
|
||||
// add a degree-k custom constraint of the following form to the range check and
|
||||
// static lookup configuration.
|
||||
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 − 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
|
||||
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
|
||||
cs.create_gate("range_check_on_sel", |cs| {
|
||||
let synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(1)),
|
||||
_ => match index {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
};
|
||||
|
||||
let range_check_on_synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(0)),
|
||||
_ => {
|
||||
let mut initial_expr = Expression::Constant(F::from(1));
|
||||
for i in 0..len {
|
||||
initial_expr = initial_expr
|
||||
* (synthetic_sel.clone()
|
||||
- Expression::Constant(F::from(i as u64)))
|
||||
}
|
||||
initial_expr
|
||||
}
|
||||
};
|
||||
|
||||
let sel = cs.query_selector(multi_col_selector);
|
||||
|
||||
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
|
||||
});
|
||||
|
||||
self.static_lookups
|
||||
.selectors
|
||||
.insert((nl.clone(), x, y), multi_col_selector);
|
||||
@@ -740,8 +785,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
pub fn configure_shuffles(
|
||||
&mut self,
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
inputs: &[VarTensor; 2],
|
||||
references: &[VarTensor; 2],
|
||||
inputs: &[VarTensor; 3],
|
||||
outputs: &[VarTensor; 3],
|
||||
) -> Result<(), CircuitError>
|
||||
where
|
||||
F: Field,
|
||||
@@ -752,14 +797,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
}
|
||||
}
|
||||
|
||||
for t in references.iter() {
|
||||
for t in outputs.iter() {
|
||||
if !t.is_advice() || t.num_inner_cols() > 1 {
|
||||
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// assert all tables have the same number of blocks
|
||||
if references
|
||||
if outputs
|
||||
.iter()
|
||||
.map(|t| t.num_blocks())
|
||||
.collect::<Vec<_>>()
|
||||
@@ -767,23 +812,23 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
.any(|w| w[0] != w[1])
|
||||
{
|
||||
return Err(CircuitError::WrongDynamicColumnType(
|
||||
"references inner cols".to_string(),
|
||||
"outputs inner cols".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let one = Expression::Constant(F::ONE);
|
||||
|
||||
for q in 0..references[0].num_blocks() {
|
||||
let s_reference = cs.complex_selector();
|
||||
for q in 0..outputs[0].num_blocks() {
|
||||
let s_output = cs.complex_selector();
|
||||
|
||||
for x in 0..inputs[0].num_blocks() {
|
||||
for y in 0..inputs[0].num_inner_cols() {
|
||||
let s_input = cs.complex_selector();
|
||||
|
||||
cs.lookup_any("lookup", |cs| {
|
||||
cs.lookup_any("shuffle", |cs| {
|
||||
let s_inputq = cs.query_selector(s_input);
|
||||
let mut expression = vec![];
|
||||
let s_referenceq = cs.query_selector(s_reference);
|
||||
let s_outputq = cs.query_selector(s_output);
|
||||
let mut input_queries = vec![one.clone()];
|
||||
|
||||
for input in inputs {
|
||||
@@ -795,9 +840,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
});
|
||||
}
|
||||
|
||||
let mut ref_queries = vec![one.clone()];
|
||||
for reference in references {
|
||||
ref_queries.push(match reference {
|
||||
let mut output_queries = vec![one.clone()];
|
||||
for output in outputs {
|
||||
output_queries.push(match output {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[q][0], Rotation(0))
|
||||
}
|
||||
@@ -806,7 +851,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
}
|
||||
|
||||
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
|
||||
let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
|
||||
let rhs = output_queries.into_iter().map(|c| c * s_outputq.clone());
|
||||
expression.extend(lhs.zip(rhs));
|
||||
|
||||
expression
|
||||
@@ -817,13 +862,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
.or_insert(s_input);
|
||||
}
|
||||
}
|
||||
self.shuffles.reference_selectors.push(s_reference);
|
||||
self.shuffles.output_selectors.push(s_output);
|
||||
}
|
||||
|
||||
// if we haven't previously initialized the input/output, do so now
|
||||
if self.shuffles.references.is_empty() {
|
||||
debug!("assigning shuffles reference");
|
||||
self.shuffles.references = references.to_vec();
|
||||
if self.shuffles.outputs.is_empty() {
|
||||
debug!("assigning shuffles output");
|
||||
self.shuffles.outputs = outputs.to_vec();
|
||||
}
|
||||
if self.shuffles.inputs.is_empty() {
|
||||
debug!("assigning shuffles input");
|
||||
@@ -893,9 +938,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
let default_x = range_check.get_first_element(col_idx);
|
||||
|
||||
let col_expr = sel.clone()
|
||||
* range_check
|
||||
* (range_check
|
||||
.selector_constructor
|
||||
.get_expr_at_idx(col_idx, synthetic_sel);
|
||||
.get_expr_at_idx(col_idx, synthetic_sel));
|
||||
|
||||
let multiplier = range_check
|
||||
.selector_constructor
|
||||
@@ -918,6 +963,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
res
|
||||
});
|
||||
}
|
||||
|
||||
// add a degree-k custom constraint of the following form to the range check and
|
||||
// static lookup configuration.
|
||||
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 − 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
|
||||
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
|
||||
cs.create_gate("range_check_on_sel", |cs| {
|
||||
let synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(1)),
|
||||
_ => match index {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
};
|
||||
|
||||
let range_check_on_synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(0)),
|
||||
_ => {
|
||||
let mut initial_expr = Expression::Constant(F::from(1));
|
||||
for i in 0..len {
|
||||
initial_expr = initial_expr
|
||||
* (synthetic_sel.clone()
|
||||
- Expression::Constant(F::from(i as u64)))
|
||||
}
|
||||
initial_expr
|
||||
}
|
||||
};
|
||||
|
||||
let sel = cs.query_selector(multi_col_selector);
|
||||
|
||||
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
|
||||
});
|
||||
|
||||
self.range_checks
|
||||
.selectors
|
||||
.insert((range, x, y), multi_col_selector);
|
||||
|
||||
@@ -100,4 +100,7 @@ pub enum CircuitError {
|
||||
#[error("invalid input type {0}")]
|
||||
/// Invalid input type
|
||||
InvalidInputType(String),
|
||||
#[error("an element is missing from the shuffled version of the tensor")]
|
||||
/// An element is missing from the shuffled version of the tensor
|
||||
MissingShuffleElement,
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::*;
|
||||
use crate::{
|
||||
circuit::{layouts, utils, Tolerance},
|
||||
fieldutils::integer_rep_to_felt,
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorType, ValTensor},
|
||||
};
|
||||
@@ -250,8 +250,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(input_scale.0 as i128),
|
||||
integer_rep_to_felt(output_scale.0 as i128),
|
||||
integer_rep_to_felt(input_scale.0 as IntegerRep),
|
||||
integer_rep_to_felt(output_scale.0 as IntegerRep),
|
||||
)?,
|
||||
HybridOp::Div { denom, .. } => {
|
||||
if denom.0.fract() == 0.0 {
|
||||
@@ -259,7 +259,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(denom.0 as i128),
|
||||
integer_rep_to_felt(denom.0 as IntegerRep),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
|
||||
@@ -43,7 +43,7 @@ const ASCII_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz";
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
|
||||
/// &[3, 3],
|
||||
@@ -160,9 +160,26 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input_dims)?;
|
||||
region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
let claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
// here we decompose and extract the sign of the input
|
||||
let sign = sign(config, region, &[claimed_output.clone()])?;
|
||||
// check if x is too large only if the decomp would support overflow in the previous op
|
||||
if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
|
||||
let abs_value = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), sign],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
let max_val = create_constant_tensor(integer_rep_to_felt(IntegerRep::MAX), 1);
|
||||
let less_than_max = less(config, region, &[abs_value.clone(), max_val])?;
|
||||
// assert the result is 1
|
||||
let comparison_unit = create_constant_tensor(F::ONE, less_than_max.len());
|
||||
enforce_equality(config, region, &[abs_value, comparison_unit])?;
|
||||
}
|
||||
|
||||
let product = pairwise(
|
||||
config,
|
||||
region,
|
||||
@@ -235,6 +252,30 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&[equal_zero_mask.clone(), equal_inverse_mask],
|
||||
)?;
|
||||
|
||||
let masked_output = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), not_equal_zero_mask.clone()],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
// here we decompose and extract the sign of the input
|
||||
let sign = sign(config, region, &[masked_output.clone()])?;
|
||||
// check if x is too large only if the decomp would support overflow in the previous op
|
||||
if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
|
||||
let abs_value = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), sign],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
let max_val = create_constant_tensor(integer_rep_to_felt(IntegerRep::MAX), 1);
|
||||
let less_than_max = less(config, region, &[abs_value.clone(), max_val])?;
|
||||
// assert the result is 1
|
||||
let comparison_unit = create_constant_tensor(F::ONE, less_than_max.len());
|
||||
enforce_equality(config, region, &[abs_value, comparison_unit])?;
|
||||
}
|
||||
|
||||
let err_func = |config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
x: &ValTensor<F>|
|
||||
@@ -266,7 +307,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 9]),
|
||||
/// &[3, 3],
|
||||
@@ -338,7 +379,7 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
|
||||
/// &[3, 3],
|
||||
@@ -377,7 +418,7 @@ pub fn rsqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::layouts::dot;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
|
||||
@@ -404,22 +445,6 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
|
||||
let mut values = values.clone();
|
||||
|
||||
// this section has been optimized to death, don't mess with it
|
||||
let mut removal_indices = values[0].get_const_zero_indices();
|
||||
let second_zero_indices = values[1].get_const_zero_indices();
|
||||
removal_indices.extend(second_zero_indices);
|
||||
removal_indices.par_sort_unstable();
|
||||
removal_indices.dedup();
|
||||
|
||||
// if empty return a const
|
||||
if removal_indices.len() == values[0].len() {
|
||||
return Ok(create_zero_tensor(1));
|
||||
}
|
||||
|
||||
// is already sorted
|
||||
values[0].remove_indices(&mut removal_indices, true)?;
|
||||
values[1].remove_indices(&mut removal_indices, true)?;
|
||||
|
||||
let mut inputs = vec![];
|
||||
let block_width = config.custom_gates.output.num_inner_cols();
|
||||
|
||||
@@ -492,7 +517,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// // matmul case
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
@@ -906,7 +931,6 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
let window_b = assigned_sort.get_slice(&[1..assigned_sort.len()])?;
|
||||
|
||||
let is_greater = greater_equal(config, region, &[window_b.clone(), window_a.clone()])?;
|
||||
|
||||
let unit = create_unit_tensor(is_greater.len());
|
||||
|
||||
enforce_equality(config, region, &[unit, is_greater])?;
|
||||
@@ -945,7 +969,7 @@ fn _select_topk<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -1168,30 +1192,38 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
|
||||
}
|
||||
|
||||
/// Shuffle arg
|
||||
/// 1. the input is a set of pairs (index_input, value_input) -- looked up against a (dynamic) set of values (index_output, value_output).
|
||||
/// 2. index_input is copy constrained to values in a fixed column and is thus a fixed set of incrementing values over the input
|
||||
/// 3. value_input is just the input that we are ascertaining is shuffled
|
||||
/// 4. index_output is (typically) a prover generated witness committed to in an advice column
|
||||
/// 5. value_output is (typically) a prover generated witness committed to in an advice column
|
||||
/// 6. Given the above, and given the fixed index_input , we go through every (index_input, value_input) pair and ascertain that it is contained in the input.
|
||||
/// Given the fixed incrementing index index_input, we avoid multiplicity in the output by leveraging this surrogate index: if index_output isn't matched to the exact value where for `index_input=index_output` -> `value_input=value_output`, then the lookup fails
|
||||
pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
output: &[ValTensor<F>; 1],
|
||||
input: &[ValTensor<F>; 1],
|
||||
reference: &[ValTensor<F>; 1],
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let shuffle_index = region.shuffle_index();
|
||||
let (input, reference) = (input[0].clone(), reference[0].clone());
|
||||
let (output, input) = (output[0].clone(), input[0].clone());
|
||||
|
||||
// assert input and reference are same length
|
||||
if input.len() != reference.len() {
|
||||
if output.len() != input.len() {
|
||||
return Err(CircuitError::MismatchedShuffleLength(
|
||||
output.len(),
|
||||
input.len(),
|
||||
reference.len(),
|
||||
));
|
||||
}
|
||||
|
||||
let (reference, flush_len_ref) =
|
||||
region.assign_shuffle(&config.shuffles.references[0], &reference)?;
|
||||
let reference_len = reference.len();
|
||||
let (output, flush_len_ref) = region.assign_shuffle(&config.shuffles.outputs[0], &output)?;
|
||||
let output_len = output.len();
|
||||
let input = region.assign(&config.shuffles.inputs[0], &input)?;
|
||||
|
||||
// now create a vartensor of constants for the shuffle index
|
||||
let index = create_constant_tensor(F::from(shuffle_index as u64), reference_len);
|
||||
let (index, flush_len_index) = region.assign_shuffle(&config.shuffles.references[1], &index)?;
|
||||
let index = create_constant_tensor(F::from(shuffle_index as u64), output_len);
|
||||
let (index, flush_len_index) = region.assign_shuffle(&config.shuffles.outputs[1], &index)?;
|
||||
region.assign(&config.shuffles.inputs[1], &index)?;
|
||||
|
||||
if flush_len_index != flush_len_ref {
|
||||
return Err(CircuitError::MismatchedShuffleLength(
|
||||
@@ -1200,18 +1232,71 @@ pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
));
|
||||
}
|
||||
|
||||
let input = region.assign(&config.shuffles.inputs[0], &input)?;
|
||||
region.assign(&config.shuffles.inputs[1], &index)?;
|
||||
// now found the position of each element of the reference to the input
|
||||
|
||||
let is_known = !output.any_unknowns()? && !input.any_unknowns()?;
|
||||
|
||||
let claimed_index_output = if is_known {
|
||||
let input = input.int_evals()?;
|
||||
let output = output.int_evals()?;
|
||||
|
||||
// Keep track of which positions we've used for each value
|
||||
let mut used_positions: HashMap<usize, bool> = HashMap::new();
|
||||
|
||||
let index_output = output
|
||||
.iter()
|
||||
.map(|x| {
|
||||
// Find all positions of the current element
|
||||
let positions: Vec<usize> = input
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, y)| *y == x)
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
// Find the first unused position for this element
|
||||
let pos = positions
|
||||
.iter()
|
||||
.find(|&&i| !used_positions.get(&i).unwrap_or(&false))
|
||||
.unwrap_or(&0);
|
||||
|
||||
// Mark this position as used
|
||||
used_positions.insert(*pos, true);
|
||||
|
||||
Ok::<_, CircuitError>(Value::known(F::from(*pos as u64)))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// Verify all indices were used exactly once
|
||||
if !used_positions.values().all(|&x| x) {
|
||||
return Err(CircuitError::MissingShuffleElement);
|
||||
}
|
||||
|
||||
Tensor::from(index_output.into_iter()).into()
|
||||
} else {
|
||||
Tensor::new(
|
||||
Some(&vec![Value::<F>::unknown(); output_len]),
|
||||
&[output_len],
|
||||
)?
|
||||
.into()
|
||||
};
|
||||
|
||||
region.assign_shuffle(&config.shuffles.outputs[2], &claimed_index_output)?;
|
||||
|
||||
// the incrementing index is the set of numbered values for the input tensor 0...n, and is FIXED
|
||||
let incrementing_index: ValTensor<F> =
|
||||
Tensor::from((0..output.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();
|
||||
region.assign(&config.shuffles.inputs[2], &incrementing_index)?;
|
||||
|
||||
let mut shuffle_block = 0;
|
||||
|
||||
if !region.is_dummy() {
|
||||
(0..reference_len)
|
||||
(0..output_len)
|
||||
.map(|i| {
|
||||
let (x, _, z) = config.shuffles.references[0]
|
||||
let (x, _, z) = config.shuffles.outputs[0]
|
||||
.cartesian_coord(region.combined_dynamic_shuffle_coord() + i + flush_len_ref);
|
||||
shuffle_block = x;
|
||||
let ref_selector = config.shuffles.reference_selectors[shuffle_block];
|
||||
let ref_selector = config.shuffles.output_selectors[shuffle_block];
|
||||
region.enable(Some(&ref_selector), z)?;
|
||||
Ok(())
|
||||
})
|
||||
@@ -1220,7 +1305,7 @@ pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
|
||||
if !region.is_dummy() {
|
||||
// Enable the selectors
|
||||
(0..reference_len)
|
||||
(0..output_len)
|
||||
.map(|i| {
|
||||
let (x, y, z) =
|
||||
config.custom_gates.inputs[0].cartesian_coord(region.linear_coord() + i);
|
||||
@@ -1237,11 +1322,11 @@ pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
.collect::<Result<Vec<_>, CircuitError>>()?;
|
||||
}
|
||||
|
||||
region.increment_shuffle_col_coord(reference_len + flush_len_ref);
|
||||
region.increment_shuffle_col_coord(output_len + flush_len_ref);
|
||||
region.increment_shuffle_index(1);
|
||||
region.increment(reference_len);
|
||||
region.increment(output_len);
|
||||
|
||||
Ok(input)
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// One hot accumulated layout
|
||||
@@ -1698,6 +1783,8 @@ pub(crate) fn get_missing_set_elements<
|
||||
// assign the claimed output
|
||||
claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
// input and claimed output should be the shuffles of fullset
|
||||
// concatentate input and claimed output
|
||||
let input_and_claimed_output = input.concat(claimed_output.clone())?;
|
||||
@@ -1891,7 +1978,7 @@ pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd + std::hash::Ha
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -1989,7 +2076,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2144,7 +2231,7 @@ fn axes_wise_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2180,7 +2267,7 @@ pub fn prod_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2216,7 +2303,7 @@ pub fn sum_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2258,7 +2345,7 @@ pub fn argmax_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
@@ -2294,7 +2381,7 @@ pub fn max_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2340,7 +2427,7 @@ pub fn argmin_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2381,25 +2468,6 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
lhs.expand(&broadcasted_shape)?;
|
||||
rhs.expand(&broadcasted_shape)?;
|
||||
|
||||
// original values
|
||||
let orig_lhs = lhs.clone();
|
||||
let orig_rhs = rhs.clone();
|
||||
|
||||
let first_zero_indices = HashSet::from_iter(lhs.get_const_zero_indices());
|
||||
let second_zero_indices = HashSet::from_iter(rhs.get_const_zero_indices());
|
||||
|
||||
let removal_indices = match op {
|
||||
BaseOp::Add | BaseOp::Mult => {
|
||||
// join the zero indices
|
||||
first_zero_indices
|
||||
.union(&second_zero_indices)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
BaseOp::Sub => second_zero_indices.clone(),
|
||||
_ => return Err(CircuitError::UnsupportedOp),
|
||||
};
|
||||
|
||||
if lhs.len() != rhs.len() {
|
||||
return Err(CircuitError::DimMismatch(format!(
|
||||
"pairwise {} layout",
|
||||
@@ -2411,11 +2479,7 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, input)| {
|
||||
let res = region.assign_with_omissions(
|
||||
&config.custom_gates.inputs[i],
|
||||
input,
|
||||
&removal_indices,
|
||||
)?;
|
||||
let res = region.assign(&config.custom_gates.inputs[i], input)?;
|
||||
|
||||
Ok(res.get_inner()?)
|
||||
})
|
||||
@@ -2434,12 +2498,8 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
halo2_proofs::plonk::Error::Synthesis
|
||||
})?;
|
||||
|
||||
let assigned_len = op_result.len() - removal_indices.len();
|
||||
let mut output = region.assign_with_omissions(
|
||||
&config.custom_gates.output,
|
||||
&op_result.into(),
|
||||
&removal_indices,
|
||||
)?;
|
||||
let assigned_len = op_result.len();
|
||||
let mut output = region.assign(&config.custom_gates.output, &op_result.into())?;
|
||||
|
||||
// Enable the selectors
|
||||
if !region.is_dummy() {
|
||||
@@ -2457,44 +2517,6 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
}
|
||||
region.increment(assigned_len);
|
||||
|
||||
let a_tensor = orig_lhs.get_inner_tensor()?;
|
||||
let b_tensor = orig_rhs.get_inner_tensor()?;
|
||||
|
||||
// infill the zero indices with the correct values from values[0] or values[1]
|
||||
if !removal_indices.is_empty() {
|
||||
output
|
||||
.get_inner_tensor_mut()?
|
||||
.par_enum_map_mut_filtered(&removal_indices, |i| {
|
||||
let val = match op {
|
||||
BaseOp::Add => {
|
||||
let a_is_null = first_zero_indices.contains(&i);
|
||||
let b_is_null = second_zero_indices.contains(&i);
|
||||
|
||||
if a_is_null && b_is_null {
|
||||
ValType::Constant(F::ZERO)
|
||||
} else if a_is_null {
|
||||
b_tensor[i].clone()
|
||||
} else {
|
||||
a_tensor[i].clone()
|
||||
}
|
||||
}
|
||||
BaseOp::Sub => {
|
||||
let a_is_null = first_zero_indices.contains(&i);
|
||||
// by default b is null in this case for sub
|
||||
if a_is_null {
|
||||
ValType::Constant(F::ZERO)
|
||||
} else {
|
||||
a_tensor[i].clone()
|
||||
}
|
||||
}
|
||||
BaseOp::Mult => ValType::Constant(F::ZERO),
|
||||
// can safely panic as the prior check ensures this is not called
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok::<_, TensorError>(val)
|
||||
})?;
|
||||
}
|
||||
|
||||
output.reshape(&broadcasted_shape)?;
|
||||
|
||||
let end = global_start.elapsed();
|
||||
@@ -2522,7 +2544,7 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2579,7 +2601,7 @@ pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128, 2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 12, 6, 4, 5, 6]),
|
||||
@@ -2607,7 +2629,8 @@ pub fn greater<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
|
||||
let diff = pairwise(config, region, &[lhs, rhs], BaseOp::Sub)?;
|
||||
let sign = sign(config, region, &[diff])?;
|
||||
equals(config, region, &[sign, create_unit_tensor(1)])
|
||||
let eq = equals(config, region, &[sign, create_unit_tensor(1)])?;
|
||||
Ok(eq)
|
||||
}
|
||||
|
||||
/// Greater equals than operation.
|
||||
@@ -2626,7 +2649,7 @@ pub fn greater<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128, 2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
@@ -2676,7 +2699,7 @@ pub fn greater_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128, 2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 0, 5, 4, 5, 1]),
|
||||
@@ -2717,7 +2740,7 @@ pub fn less<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 0, 5, 4, 5, 1]),
|
||||
@@ -2758,7 +2781,7 @@ pub fn less_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1, 1, 1, 1, 0]),
|
||||
@@ -2802,7 +2825,7 @@ pub fn and<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1, 1, 1, 1, 0]),
|
||||
@@ -2850,7 +2873,7 @@ pub fn or<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1, 1, 1, 1, 0]),
|
||||
@@ -2881,6 +2904,7 @@ pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd + std::hash::H
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let values = values[0].clone();
|
||||
let values_inverse = values.inverse()?;
|
||||
|
||||
let product_values_and_invert = pairwise(
|
||||
config,
|
||||
region,
|
||||
@@ -2924,7 +2948,7 @@ pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd + std::hash::H
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1, 1, 1, 1, 0]),
|
||||
@@ -2980,7 +3004,7 @@ pub fn xor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1, 1, 1, 1, 0]),
|
||||
@@ -3023,7 +3047,7 @@ pub fn not<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let mask = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 0, 1, 0, 1, 0]),
|
||||
@@ -3081,7 +3105,7 @@ pub fn iff<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
@@ -3113,7 +3137,7 @@ pub fn neg<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
@@ -3195,7 +3219,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
@@ -3249,32 +3273,30 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
output
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.map(|(flat_index, o)| {
|
||||
let coord = &cartesian_coord[flat_index];
|
||||
let (b, i) = (coord[0], coord[1]);
|
||||
let inner_loop_function = |idx: usize, region: &mut RegionCtx<F>| {
|
||||
let coord = &cartesian_coord[idx];
|
||||
let (b, i) = (coord[0], coord[1]);
|
||||
|
||||
let mut slice = vec![b..b + 1, i..i + 1];
|
||||
slice.extend(
|
||||
coord[2..]
|
||||
.iter()
|
||||
.zip(stride.iter())
|
||||
.zip(pool_dims.iter())
|
||||
.map(|((c, s), k)| {
|
||||
let start = c * s;
|
||||
let end = start + k;
|
||||
start..end
|
||||
}),
|
||||
);
|
||||
let mut slice = vec![b..b + 1, i..i + 1];
|
||||
slice.extend(
|
||||
coord[2..]
|
||||
.iter()
|
||||
.zip(stride.iter())
|
||||
.zip(pool_dims.iter())
|
||||
.map(|((c, s), k)| {
|
||||
let start = c * s;
|
||||
let end = start + k;
|
||||
start..end
|
||||
}),
|
||||
);
|
||||
|
||||
let slice = padded_image.get_slice(&slice)?;
|
||||
let max_w = max(config, region, &[slice])?;
|
||||
*o = max_w.get_inner_tensor()?[0].clone();
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, CircuitError>>()?;
|
||||
let slice = padded_image.get_slice(&slice)?;
|
||||
let max_w = max(config, region, &[slice])?;
|
||||
|
||||
Ok::<_, CircuitError>(max_w.get_inner_tensor()?[0].clone())
|
||||
};
|
||||
|
||||
region.apply_in_loop(&mut output, inner_loop_function)?;
|
||||
|
||||
let res: ValTensor<F> = output.into();
|
||||
|
||||
@@ -3296,7 +3318,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let c = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap());
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
@@ -3524,7 +3546,7 @@ pub fn deconv<
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
|
||||
@@ -4248,7 +4270,7 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 2, 3, 0]),
|
||||
/// &[1, 1, 2, 2],
|
||||
@@ -4301,7 +4323,7 @@ pub fn max_comp<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 2, 3, 0]),
|
||||
/// &[1, 1, 2, 2],
|
||||
@@ -4340,10 +4362,7 @@ pub(crate) fn max<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let input_len = values[0].len();
|
||||
_sort_ascending(config, region, values)?
|
||||
.get_slice(&[input_len - 1..input_len])
|
||||
.map_err(|e| e.into())
|
||||
Ok(_sort_ascending(config, region, values)?.last()?)
|
||||
}
|
||||
|
||||
/// min layout
|
||||
@@ -4352,9 +4371,7 @@ pub(crate) fn min<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
_sort_ascending(config, region, values)?
|
||||
.get_slice(&[0..1])
|
||||
.map_err(|e| e.into())
|
||||
Ok(_sort_ascending(config, region, values)?.first()?)
|
||||
}
|
||||
|
||||
/// floor layout
|
||||
@@ -4377,7 +4394,7 @@ pub(crate) fn min<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, -2, -3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
@@ -4489,7 +4506,7 @@ pub fn floor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, -2, 3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
@@ -4602,7 +4619,7 @@ pub fn ceil<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, 2, 3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
@@ -4889,7 +4906,7 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, -2, 3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
@@ -5032,7 +5049,7 @@ pub fn round<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, -2, 3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
@@ -5267,7 +5284,24 @@ pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has
|
||||
let rest = claimed_output.get_slice(&rest_slice)?;
|
||||
|
||||
let sign = range_check(config, region, &[sign], &(-1, 1))?;
|
||||
let rest = range_check(config, region, &[rest], &(0, (*base - 1) as i128))?;
|
||||
|
||||
// isZero(input) * sign == 0.
|
||||
{
|
||||
let is_zero = equals_zero(config, region, &[input.clone()])?;
|
||||
// take the product of the sign and is_zero
|
||||
let sign_is_zero = pairwise(config, region, &[sign.clone(), is_zero], BaseOp::Mult)?;
|
||||
// constrain the sign_is_zero to be 0
|
||||
enforce_equality(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
sign_is_zero.clone(),
|
||||
create_constant_tensor(F::ZERO, sign_is_zero.len()),
|
||||
],
|
||||
)?;
|
||||
}
|
||||
|
||||
let rest = range_check(config, region, &[rest], &(0, (*base - 1) as IntegerRep))?;
|
||||
|
||||
// equation needs to be constructed as ij,ij->i but for arbitrary n dims we need to construct this dynamically
|
||||
// indices should map in order of the alphabet
|
||||
@@ -5531,7 +5565,7 @@ pub(crate) fn percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 2, 3, 2, 2, 0]),
|
||||
@@ -5582,7 +5616,7 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[100, 200, 300, 400, 500, 600]),
|
||||
|
||||
@@ -132,21 +132,16 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
(first_element, op_f.output[0])
|
||||
}
|
||||
|
||||
///
|
||||
/// calculates the column size given the number of rows and reserved blinding rows
|
||||
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(logrows as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
///
|
||||
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(bits as u32) - reserved_blinding_rows
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
|
||||
// number of cols needed to store the range
|
||||
(range_len / (col_size as IntegerRep)) as usize + 1
|
||||
(range_len / col_size as IntegerRep) as usize + 1
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
@@ -355,16 +350,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0)
|
||||
}
|
||||
|
||||
///
|
||||
/// calculates the column size
|
||||
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(logrows as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
///
|
||||
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(bits as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
/// get column index given input
|
||||
pub fn get_col_index(&self, input: F) -> F {
|
||||
// range is split up into chunks of size col_size, find the chunk that input is in
|
||||
|
||||
@@ -1040,6 +1040,10 @@ mod conv {
|
||||
let a = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
|
||||
|
||||
// column for constants
|
||||
let _constant = VarTensor::constant_cols(cs, K, 8, false);
|
||||
|
||||
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
|
||||
}
|
||||
|
||||
@@ -1171,7 +1175,7 @@ mod conv_col_ultra_overflow {
|
||||
|
||||
use super::*;
|
||||
|
||||
const K: usize = 4;
|
||||
const K: usize = 6;
|
||||
const LEN: usize = 10;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -1191,9 +1195,10 @@ mod conv_col_ultra_overflow {
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN * LEN * LEN * LEN, false);
|
||||
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
|
||||
}
|
||||
|
||||
@@ -1776,13 +1781,18 @@ mod shuffle {
|
||||
|
||||
let d = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let e = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let f: VarTensor = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN * NUM_LOOP, false);
|
||||
|
||||
let mut config =
|
||||
Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
|
||||
config
|
||||
.configure_shuffles(cs, &[a.clone(), b.clone()], &[d.clone(), e.clone()])
|
||||
.configure_shuffles(
|
||||
cs,
|
||||
&[a.clone(), b.clone(), c.clone()],
|
||||
&[d.clone(), e.clone(), f.clone()],
|
||||
)
|
||||
.unwrap();
|
||||
config
|
||||
}
|
||||
|
||||
@@ -83,7 +83,7 @@ pub const DEFAULT_VK_SOL: &str = "vk.sol";
|
||||
/// Default VK abi path
|
||||
pub const DEFAULT_VK_ABI: &str = "vk.abi";
|
||||
/// Default scale rebase multipliers for calibration
|
||||
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
|
||||
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,10";
|
||||
/// Default use reduced srs for verification
|
||||
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
|
||||
/// Default only check for range check rebase
|
||||
|
||||
@@ -5,7 +5,7 @@ use halo2curves::ff::PrimeField;
|
||||
/// Integer representation of a PrimeField element.
|
||||
pub type IntegerRep = i128;
|
||||
|
||||
/// Converts an i64 to a PrimeField element.
|
||||
/// Converts an integer rep to a PrimeField element.
|
||||
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
|
||||
if x >= 0 {
|
||||
F::from_u128(x as u128)
|
||||
@@ -69,7 +69,7 @@ mod test {
|
||||
fn felttointegerrep() {
|
||||
for x in -(2_i128.pow(16))..(2_i128.pow(16)) {
|
||||
let fieldx: F = integer_rep_to_felt::<F>(x);
|
||||
let xf: i128 = felt_to_integer_rep::<F>(fieldx);
|
||||
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,12 @@ pub enum GraphError {
|
||||
/// Shape mismatch in circuit construction
|
||||
#[error("invalid dimensions used for node {0} ({1})")]
|
||||
InvalidDims(usize, String),
|
||||
/// Non scalar power
|
||||
#[error("we only support scalar powers")]
|
||||
NonScalarPower,
|
||||
/// Non scalar base for exponentiation
|
||||
#[error("we only support scalar bases for exponentiation")]
|
||||
NonScalarBase,
|
||||
/// Wrong method was called to configure an op
|
||||
#[error("wrong method was called to configure node {0} ({1})")]
|
||||
WrongMethod(usize, String),
|
||||
@@ -143,4 +149,7 @@ pub enum GraphError {
|
||||
/// Invalid RunArg
|
||||
#[error("invalid RunArgs: {0}")]
|
||||
InvalidRunArgs(String),
|
||||
/// Only nearest neighbor interpolation is supported
|
||||
#[error("only nearest neighbor interpolation is supported")]
|
||||
InvalidInterpolation,
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ use pyo3::prelude::*;
|
||||
use pyo3::types::PyDict;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::io::BufReader;
|
||||
use std::io::BufWriter;
|
||||
@@ -515,7 +514,7 @@ impl<'de> Deserialize<'de> for DataSource {
|
||||
|
||||
/// Input to graph as a datasource
|
||||
/// Always use JSON serialization for GraphData. Seriously.
|
||||
#[derive(Clone, Debug, Deserialize, Default, PartialEq)]
|
||||
#[derive(Clone, Debug, Deserialize, Default, PartialEq, Serialize)]
|
||||
pub struct GraphData {
|
||||
/// Inputs to the model / computational graph (can be empty vectors if inputs are coming from on-chain).
|
||||
pub input_data: DataSource,
|
||||
@@ -769,18 +768,6 @@ impl ToPyObject for FileSourceInner {
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for GraphData {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut state = serializer.serialize_struct("GraphData", 4)?;
|
||||
state.serialize_field("input_data", &self.input_data)?;
|
||||
state.serialize_field("output_data", &self.output_data)?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -1045,8 +1045,8 @@ impl Model {
|
||||
if settings.requires_shuffle() {
|
||||
base_gate.configure_shuffles(
|
||||
meta,
|
||||
vars.advices[0..2].try_into()?,
|
||||
vars.advices[3..5].try_into()?,
|
||||
vars.advices[0..3].try_into()?,
|
||||
vars.advices[3..6].try_into()?,
|
||||
)?;
|
||||
}
|
||||
|
||||
|
||||
@@ -44,11 +44,10 @@ use tract_onnx::tract_hir::{
|
||||
tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
|
||||
};
|
||||
|
||||
/// Quantizes an iterable of f32s to a [Tensor] of i32s using a fixed point representation.
|
||||
/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
|
||||
/// Arguments
|
||||
///
|
||||
/// * `vec` - the vector to quantize.
|
||||
/// * `dims` - the dimensionality of the resulting [Tensor].
|
||||
/// * `elem` - the element to quantize.
|
||||
/// * `shift` - offset used in the fixed point representation.
|
||||
/// * `scale` - `2^scale` used in the fixed point representation.
|
||||
pub fn quantize_float(
|
||||
@@ -85,7 +84,7 @@ pub fn scale_to_multiplier(scale: crate::Scale) -> f64 {
|
||||
f64::powf(2., scale as f64)
|
||||
}
|
||||
|
||||
/// Converts a scale (log base 2) to a fixed point multiplier.
|
||||
/// Converts a fixed point multiplier to a scale (log base 2).
|
||||
pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
|
||||
mult.log2().round() as crate::Scale
|
||||
}
|
||||
@@ -312,6 +311,9 @@ pub fn new_op_from_onnx(
|
||||
let mut deleted_indices = vec![];
|
||||
let node = match node.op().name().as_ref() {
|
||||
"ShiftLeft" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "shift left".to_string()));
|
||||
};
|
||||
// load shift amount
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
@@ -324,10 +326,13 @@ pub fn new_op_from_onnx(
|
||||
out_scale: Some(input_scales[0] - raw_values[0] as i32),
|
||||
})
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "ShiftLeft".to_string()));
|
||||
return Err(GraphError::OpMismatch(idx, "shift left".to_string()));
|
||||
}
|
||||
}
|
||||
"ShiftRight" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "shift right".to_string()));
|
||||
};
|
||||
// load shift amount
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
@@ -340,7 +345,7 @@ pub fn new_op_from_onnx(
|
||||
out_scale: Some(input_scales[0] + raw_values[0] as i32),
|
||||
})
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "ShiftRight".to_string()));
|
||||
return Err(GraphError::OpMismatch(idx, "shift right".to_string()));
|
||||
}
|
||||
}
|
||||
"MultiBroadcastTo" => {
|
||||
@@ -363,7 +368,10 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(input_ops.len(), 3, "Range requires 3 inputs");
|
||||
if input_ops.len() != 3 {
|
||||
return Err(GraphError::InvalidDims(idx, "range".to_string()));
|
||||
}
|
||||
|
||||
let input_ops = input_ops
|
||||
.iter()
|
||||
.map(|x| x.get_constant().ok_or(GraphError::NonConstantRange))
|
||||
@@ -419,6 +427,10 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
if inputs[0].out_dims().is_empty() || inputs[0].out_dims()[0].len() <= axis {
|
||||
return Err(GraphError::InvalidDims(idx, "gather".to_string()));
|
||||
}
|
||||
|
||||
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::Gather {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| {
|
||||
@@ -447,8 +459,17 @@ pub fn new_op_from_onnx(
|
||||
"Topk" => {
|
||||
let op = load_op::<Topk>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axis = op.axis;
|
||||
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
|
||||
};
|
||||
|
||||
// if param_visibility.is_public() {
|
||||
let k = if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
if c.raw_values.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
|
||||
}
|
||||
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
c.raw_values.map(|x| x as usize)[0]
|
||||
@@ -488,6 +509,10 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "scatter elements".to_string()));
|
||||
}
|
||||
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
@@ -522,6 +547,9 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "scatter nd".to_string()));
|
||||
}
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
})
|
||||
@@ -555,6 +583,9 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "gather nd".to_string()));
|
||||
}
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
|
||||
batch_dims,
|
||||
indices: Some(c.raw_values.map(|x| x as usize)),
|
||||
@@ -589,6 +620,9 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "gather elements".to_string()));
|
||||
}
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
@@ -684,7 +718,9 @@ pub fn new_op_from_onnx(
|
||||
};
|
||||
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes: Vec<usize> = op.axes.into_iter().collect();
|
||||
assert_eq!(axes.len(), 1, "only support argmax over one axis");
|
||||
if axes.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "argmax".to_string()));
|
||||
}
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::ReduceArgMax { dim: axes[0] })
|
||||
}
|
||||
@@ -694,7 +730,9 @@ pub fn new_op_from_onnx(
|
||||
};
|
||||
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes: Vec<usize> = op.axes.into_iter().collect();
|
||||
assert_eq!(axes.len(), 1, "only support argmin over one axis");
|
||||
if axes.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "argmin".to_string()));
|
||||
}
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::ReduceArgMin { dim: axes[0] })
|
||||
}
|
||||
@@ -803,6 +841,9 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
"Recip" => {
|
||||
if inputs.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "recip".to_string()));
|
||||
};
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
// If the input scale is larger than the params scale
|
||||
@@ -846,6 +887,9 @@ pub fn new_op_from_onnx(
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Rsqrt" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "rsqrt".to_string()));
|
||||
};
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
SupportedOp::Hybrid(HybridOp::Rsqrt {
|
||||
@@ -933,7 +977,9 @@ pub fn new_op_from_onnx(
|
||||
let op = load_op::<Cast>(node.op(), idx, node.op().name().to_string())?;
|
||||
let dt = op.to;
|
||||
|
||||
assert_eq!(input_scales.len(), 1);
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "cast".to_string()));
|
||||
};
|
||||
|
||||
match dt {
|
||||
DatumType::Bool
|
||||
@@ -983,6 +1029,11 @@ pub fn new_op_from_onnx(
|
||||
|
||||
if const_idx.len() == 1 {
|
||||
let const_idx = const_idx[0];
|
||||
|
||||
if inputs.len() <= const_idx {
|
||||
return Err(GraphError::InvalidDims(idx, "mul".to_string()));
|
||||
}
|
||||
|
||||
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
|
||||
if c.raw_values.len() == 1 && c.raw_values[0] < 1. {
|
||||
// if not divisible by 2 then we need to add a range check
|
||||
@@ -1057,6 +1108,9 @@ pub fn new_op_from_onnx(
|
||||
return Err(GraphError::OpMismatch(idx, "softmax".to_string()));
|
||||
}
|
||||
};
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "softmax".to_string()));
|
||||
}
|
||||
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
@@ -1096,22 +1150,42 @@ pub fn new_op_from_onnx(
|
||||
pool_dims: kernel_shape.to_vec(),
|
||||
})
|
||||
}
|
||||
"Ceil" => SupportedOp::Hybrid(HybridOp::Ceil {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Floor" => SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Round" => SupportedOp::Hybrid(HybridOp::Round {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"RoundHalfToEven" => SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Ceil" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "ceil".to_string()));
|
||||
}
|
||||
SupportedOp::Hybrid(HybridOp::Ceil {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
})
|
||||
}
|
||||
"Floor" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "floor".to_string()));
|
||||
}
|
||||
SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
})
|
||||
}
|
||||
"Round" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "round".to_string()));
|
||||
}
|
||||
SupportedOp::Hybrid(HybridOp::Round {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
})
|
||||
}
|
||||
"RoundHalfToEven" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "roundhalftoeven".to_string()));
|
||||
}
|
||||
SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
})
|
||||
}
|
||||
"Sign" => SupportedOp::Linear(PolyOp::Sign),
|
||||
"Pow" => {
|
||||
// Extract the slope layer hyperparams from a const
|
||||
@@ -1121,7 +1195,9 @@ pub fn new_op_from_onnx(
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.len() > 1 {
|
||||
unimplemented!("only support scalar pow")
|
||||
return Err(GraphError::NonScalarPower);
|
||||
} else if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
|
||||
}
|
||||
|
||||
let exponent = c.raw_values[0];
|
||||
@@ -1138,7 +1214,9 @@ pub fn new_op_from_onnx(
|
||||
inputs[0].decrement_use();
|
||||
deleted_indices.push(0);
|
||||
if c.raw_values.len() > 1 {
|
||||
unimplemented!("only support scalar base")
|
||||
return Err(GraphError::NonScalarBase);
|
||||
} else if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
|
||||
}
|
||||
|
||||
let base = c.raw_values[0];
|
||||
@@ -1148,10 +1226,14 @@ pub fn new_op_from_onnx(
|
||||
base: base.into(),
|
||||
})
|
||||
} else {
|
||||
unimplemented!("only support constant base or pow for now")
|
||||
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
|
||||
}
|
||||
}
|
||||
"Div" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "div".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
@@ -1159,14 +1241,15 @@ pub fn new_op_from_onnx(
|
||||
.map(|(i, _)| i)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if const_idx.len() > 1 {
|
||||
if const_idx.len() > 1 || const_idx.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "div".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = const_idx[0];
|
||||
|
||||
if const_idx != 1 {
|
||||
unimplemented!("only support div with constant as second input")
|
||||
return Err(GraphError::MisformedParams(
|
||||
"only support div with constant as second input".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
|
||||
@@ -1180,10 +1263,14 @@ pub fn new_op_from_onnx(
|
||||
denom: denom.into(),
|
||||
})
|
||||
} else {
|
||||
unimplemented!("only support non zero divisors of size 1")
|
||||
return Err(GraphError::MisformedParams(
|
||||
"only support non zero divisors of size 1".to_string(),
|
||||
));
|
||||
}
|
||||
} else {
|
||||
unimplemented!("only support div with constant as second input")
|
||||
return Err(GraphError::MisformedParams(
|
||||
"only support div with constant as second input".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
"Cube" => SupportedOp::Linear(PolyOp::Pow(3)),
|
||||
@@ -1323,7 +1410,7 @@ pub fn new_op_from_onnx(
|
||||
if !resize_node.contains("interpolator: Nearest")
|
||||
&& !resize_node.contains("nearest: Floor")
|
||||
{
|
||||
unimplemented!("Only nearest neighbor interpolation is supported")
|
||||
return Err(GraphError::InvalidInterpolation);
|
||||
}
|
||||
// check if optional scale factor is present
|
||||
if inputs.len() != 2 && inputs.len() != 3 {
|
||||
@@ -1427,6 +1514,10 @@ pub fn new_op_from_onnx(
|
||||
SupportedOp::Linear(PolyOp::Reshape(output_shape))
|
||||
}
|
||||
"Flatten" => {
|
||||
if inputs.len() != 1 || inputs[0].out_dims().is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "flatten".to_string()));
|
||||
};
|
||||
|
||||
let new_dims: Vec<usize> = vec![inputs[0].out_dims()[0].iter().product::<usize>()];
|
||||
SupportedOp::Linear(PolyOp::Flatten(new_dims))
|
||||
}
|
||||
@@ -1546,6 +1637,7 @@ pub fn homogenize_input_scales(
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
/// tests for the utility module
|
||||
pub mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -435,7 +435,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
.collect_vec();
|
||||
|
||||
if requires_dynamic_lookup || requires_shuffle {
|
||||
let num_cols = if requires_dynamic_lookup { 3 } else { 2 };
|
||||
let num_cols = 3;
|
||||
for _ in 0..num_cols {
|
||||
let dynamic_lookup =
|
||||
VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_and_shuffle_size);
|
||||
|
||||
@@ -24,9 +24,6 @@ use std::path::PathBuf;
|
||||
pub use val::*;
|
||||
pub use var::*;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use instant::Instant;
|
||||
|
||||
use crate::{
|
||||
circuit::utils,
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
@@ -40,8 +37,6 @@ use halo2_proofs::{
|
||||
poly::Rotation,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
#[cfg(feature = "metal")]
|
||||
use metal::{Device, MTLResourceOptions, MTLSize};
|
||||
use std::error::Error;
|
||||
use std::fmt::Debug;
|
||||
use std::io::Read;
|
||||
@@ -49,31 +44,6 @@ use std::iter::Iterator;
|
||||
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
|
||||
use std::{cmp::max, ops::Rem};
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
const LIB_DATA: &[u8] = include_bytes!("metal/tensor_ops.metallib");
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
lazy_static::lazy_static! {
|
||||
static ref DEVICE: Device = Device::system_default().expect("no device found");
|
||||
|
||||
static ref LIB: metal::Library = DEVICE.new_library_with_data(LIB_DATA).unwrap();
|
||||
|
||||
static ref QUEUE: metal::CommandQueue = DEVICE.new_command_queue();
|
||||
|
||||
static ref PIPELINES: HashMap<String, metal::ComputePipelineState> = {
|
||||
let mut map = HashMap::new();
|
||||
for name in ["add", "sub", "mul"] {
|
||||
let function = LIB.get_function(name, None).unwrap();
|
||||
let pipeline = DEVICE.new_compute_pipeline_state_with_function(&function).unwrap();
|
||||
map.insert(name.to_string(), pipeline);
|
||||
}
|
||||
map
|
||||
};
|
||||
}
|
||||
|
||||
/// The (inner) type of tensor elements.
|
||||
pub trait TensorType: Clone + Debug + 'static {
|
||||
/// Returns the zero value.
|
||||
@@ -1404,10 +1374,6 @@ impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync> Ad
|
||||
let lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let rhs = rhs.expand(&broadcasted_shape).unwrap();
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
let res = metal_tensor_op(&lhs, &rhs, "add");
|
||||
|
||||
#[cfg(not(feature = "metal"))]
|
||||
let res = {
|
||||
let mut res: Tensor<T> = lhs
|
||||
.par_iter()
|
||||
@@ -1505,10 +1471,6 @@ impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Su
|
||||
let lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let rhs = rhs.expand(&broadcasted_shape).unwrap();
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
let res = metal_tensor_op(&lhs, &rhs, "sub");
|
||||
|
||||
#[cfg(not(feature = "metal"))]
|
||||
let res = {
|
||||
let mut res: Tensor<T> = lhs
|
||||
.par_iter()
|
||||
@@ -1576,10 +1538,6 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mu
|
||||
let lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let rhs = rhs.expand(&broadcasted_shape).unwrap();
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
let res = metal_tensor_op(&lhs, &rhs, "mul");
|
||||
|
||||
#[cfg(not(feature = "metal"))]
|
||||
let res = {
|
||||
let mut res: Tensor<T> = lhs
|
||||
.par_iter()
|
||||
@@ -1811,66 +1769,4 @@ mod tests {
|
||||
let b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
|
||||
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "metal")]
|
||||
fn tensor_metal_int() {
|
||||
let a = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
|
||||
let b = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
|
||||
let c = metal_tensor_op(&a, &b, "add");
|
||||
assert_eq!(c, Tensor::new(Some(&[2, 4, 6, 8]), &[2, 2]).unwrap());
|
||||
|
||||
let c = metal_tensor_op(&a, &b, "sub");
|
||||
assert_eq!(c, Tensor::new(Some(&[0, 0, 0, 0]), &[2, 2]).unwrap());
|
||||
|
||||
let c = metal_tensor_op(&a, &b, "mul");
|
||||
assert_eq!(c, Tensor::new(Some(&[1, 4, 9, 16]), &[2, 2]).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "metal")]
|
||||
fn tensor_metal_felt() {
|
||||
use halo2curves::bn256::Fr;
|
||||
|
||||
let a = Tensor::<Fr>::new(
|
||||
Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
|
||||
&[2, 2],
|
||||
)
|
||||
.unwrap();
|
||||
let b = Tensor::<Fr>::new(
|
||||
Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
|
||||
&[2, 2],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let c = metal_tensor_op(&a, &b, "add");
|
||||
assert_eq!(
|
||||
c,
|
||||
Tensor::<Fr>::new(
|
||||
Some(&[Fr::from(2), Fr::from(4), Fr::from(6), Fr::from(8)]),
|
||||
&[2, 2],
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
let c = metal_tensor_op(&a, &b, "sub");
|
||||
assert_eq!(
|
||||
c,
|
||||
Tensor::<Fr>::new(
|
||||
Some(&[Fr::from(0), Fr::from(0), Fr::from(0), Fr::from(0)]),
|
||||
&[2, 2],
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
let c = metal_tensor_op(&a, &b, "mul");
|
||||
assert_eq!(
|
||||
c,
|
||||
Tensor::<Fr>::new(
|
||||
Some(&[Fr::from(1), Fr::from(4), Fr::from(9), Fr::from(16)]),
|
||||
&[2, 2],
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ pub fn get_rep(
|
||||
n: usize,
|
||||
) -> Result<Vec<IntegerRep>, DecompositionError> {
|
||||
// check if x is too large
|
||||
if x.abs() > (base.pow(n as u32) as IntegerRep) - 1 {
|
||||
if (*x).abs() > ((base as i128).pow(n as u32)) - 1 {
|
||||
return Err(DecompositionError::TooLarge(*x, base, n));
|
||||
}
|
||||
let mut rep = vec![0; n + 1];
|
||||
@@ -43,8 +43,8 @@ pub fn get_rep(
|
||||
let mut x = x.abs();
|
||||
//
|
||||
for i in (1..rep.len()).rev() {
|
||||
rep[i] = x % base as i128;
|
||||
x /= base as i128;
|
||||
rep[i] = x % base as IntegerRep;
|
||||
x /= base as IntegerRep;
|
||||
}
|
||||
|
||||
Ok(rep)
|
||||
@@ -127,7 +127,7 @@ pub fn decompose(
|
||||
.flatten()
|
||||
.collect::<Vec<IntegerRep>>();
|
||||
|
||||
let output = Tensor::<i128>::new(Some(&resp), &dims)?;
|
||||
let output = Tensor::<IntegerRep>::new(Some(&resp), &dims)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.
@@ -187,13 +187,14 @@ mod native_tests {
|
||||
|
||||
const PF_FAILURE_AGGR: &str = "examples/test_failure_aggr_proof.json";
|
||||
|
||||
const LARGE_TESTS: [&str; 6] = [
|
||||
const LARGE_TESTS: [&str; 7] = [
|
||||
"self_attention",
|
||||
"nanoGPT",
|
||||
"multihead_attention",
|
||||
"mobilenet",
|
||||
"mnist_gan",
|
||||
"smallworm",
|
||||
"fr_age",
|
||||
];
|
||||
|
||||
const ACCURACY_CAL_TESTS: [&str; 6] = [
|
||||
@@ -395,29 +396,29 @@ mod native_tests {
|
||||
const TESTS_AGGR: [&str; 3] = ["1l_mlp", "1l_flatten", "1l_average"];
|
||||
|
||||
const TESTS_EVM: [&str; 23] = [
|
||||
"1l_mlp",
|
||||
"1l_flatten",
|
||||
"1l_average",
|
||||
"1l_reshape",
|
||||
"1l_sigmoid",
|
||||
"1l_div",
|
||||
"1l_sqrt",
|
||||
"1l_prelu",
|
||||
"1l_var",
|
||||
"1l_leakyrelu",
|
||||
"1l_gelu_noappx",
|
||||
"1l_relu",
|
||||
"1l_tanh",
|
||||
"2l_relu_sigmoid_small",
|
||||
"2l_relu_small",
|
||||
"min",
|
||||
"max",
|
||||
"1l_max_pool",
|
||||
"idolmodel",
|
||||
"1l_identity",
|
||||
"lstm",
|
||||
"rnn",
|
||||
"quantize_dequantize",
|
||||
"1l_mlp", // 0
|
||||
"1l_flatten", // 1
|
||||
"1l_average", // 2
|
||||
"1l_reshape", // 3
|
||||
"1l_sigmoid", // 4
|
||||
"1l_div", // 5
|
||||
"1l_sqrt", // 6
|
||||
"1l_prelu", // 7
|
||||
"1l_var", // 8
|
||||
"1l_leakyrelu", // 9
|
||||
"1l_gelu_noappx", // 10
|
||||
"1l_relu", // 11
|
||||
"1l_tanh", // 12
|
||||
"2l_relu_sigmoid_small", // 13
|
||||
"2l_relu_small", // 14
|
||||
"min", // 15
|
||||
"max", // 16
|
||||
"1l_max_pool", // 17
|
||||
"idolmodel", // 18
|
||||
"1l_identity", // 19
|
||||
"lstm", // 20
|
||||
"rnn", // 21
|
||||
"quantize_dequantize", // 22
|
||||
];
|
||||
|
||||
const TESTS_EVM_AGGR: [&str; 18] = [
|
||||
@@ -541,7 +542,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -606,7 +607,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -616,7 +617,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true, Some(8194), Some(4));
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -627,7 +628,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// gen random number between 0.0 and 1.0
|
||||
let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false, Some(8194), Some(5));
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -642,7 +643,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let large_batch_dir = &format!("large_batches_{}", test);
|
||||
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0, false);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
}
|
||||
@@ -652,7 +653,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -661,7 +662,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -670,7 +671,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -679,7 +680,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -688,7 +689,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -697,7 +698,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -706,7 +707,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -716,7 +717,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -726,7 +727,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -735,7 +736,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -745,7 +746,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -754,7 +755,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -764,7 +765,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -774,7 +775,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -784,7 +785,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -794,7 +795,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -804,7 +805,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -963,7 +964,7 @@ mod native_tests {
|
||||
|
||||
});
|
||||
|
||||
seq!(N in 0..=5 {
|
||||
seq!(N in 0..=6 {
|
||||
|
||||
#(#[test_case(LARGE_TESTS[N])])*
|
||||
#[ignore]
|
||||
@@ -981,7 +982,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -1459,6 +1460,8 @@ mod native_tests {
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
tolerance: f32,
|
||||
bounded_lookup_log: bool,
|
||||
decomp_base: Option<usize>,
|
||||
decomp_legs: Option<usize>,
|
||||
) {
|
||||
let mut tolerance = tolerance;
|
||||
gen_circuit_settings_and_witness(
|
||||
@@ -1475,6 +1478,8 @@ mod native_tests {
|
||||
Commitments::KZG,
|
||||
2,
|
||||
bounded_lookup_log,
|
||||
decomp_base,
|
||||
decomp_legs,
|
||||
);
|
||||
|
||||
if tolerance > 0.0 {
|
||||
@@ -1616,6 +1621,8 @@ mod native_tests {
|
||||
commitment: Commitments,
|
||||
lookup_safety_margin: usize,
|
||||
bounded_lookup_log: bool,
|
||||
decomp_base: Option<usize>,
|
||||
decomp_legs: Option<usize>,
|
||||
) {
|
||||
let mut args = vec![
|
||||
"gen-settings".to_string(),
|
||||
@@ -1634,6 +1641,14 @@ mod native_tests {
|
||||
format!("--commitment={}", commitment),
|
||||
];
|
||||
|
||||
if let Some(decomp_base) = decomp_base {
|
||||
args.push(format!("--decomp-base={}", decomp_base));
|
||||
}
|
||||
|
||||
if let Some(decomp_legs) = decomp_legs {
|
||||
args.push(format!("--decomp-legs={}", decomp_legs));
|
||||
}
|
||||
|
||||
if bounded_lookup_log {
|
||||
args.push("--bounded-log-lookup".to_string());
|
||||
}
|
||||
@@ -1751,6 +1766,8 @@ mod native_tests {
|
||||
Commitments::KZG,
|
||||
2,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
println!(
|
||||
@@ -2035,6 +2052,8 @@ mod native_tests {
|
||||
commitment,
|
||||
lookup_safety_margin,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
|
||||
@@ -2467,6 +2486,8 @@ mod native_tests {
|
||||
Commitments::KZG,
|
||||
2,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);
|
||||
@@ -2774,7 +2795,10 @@ mod native_tests {
|
||||
"--features",
|
||||
"icicle",
|
||||
];
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
#[cfg(feature = "macos-metal")]
|
||||
let args = ["build", "--release", "--bin", "ezkl", "--features", "macos-metal"];
|
||||
// not macos-metal and not icicle
|
||||
#[cfg(all(not(feature = "icicle"), not(feature = "macos-metal")))]
|
||||
let args = ["build", "--release", "--bin", "ezkl"];
|
||||
#[cfg(not(feature = "mv-lookup"))]
|
||||
let args = [
|
||||
|
||||
@@ -72,11 +72,10 @@ mod py_tests {
|
||||
"torchtext==0.17.2",
|
||||
"torchvision==0.17.2",
|
||||
"pandas==2.2.1",
|
||||
"numpy==1.26.4",
|
||||
"seaborn==0.13.2",
|
||||
"notebook==7.1.2",
|
||||
"nbconvert==7.16.3",
|
||||
"onnx==1.16.0",
|
||||
"onnx==1.17.0",
|
||||
"kaggle==1.6.8",
|
||||
"py-solc-x==2.0.3",
|
||||
"web3==7.5.0",
|
||||
@@ -90,12 +89,13 @@ mod py_tests {
|
||||
"xgboost==2.0.3",
|
||||
"hummingbird-ml==0.4.11",
|
||||
"lightgbm==4.3.0",
|
||||
"numpy==1.26.4",
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
let status = Command::new("pip")
|
||||
.args(["install", "numpy==1.23"])
|
||||
.args(["install", "numpy==1.26.4"])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
|
||||
|
||||
@@ -873,6 +873,7 @@ def get_examples():
|
||||
'linear_regression',
|
||||
"mnist_gan",
|
||||
"smallworm",
|
||||
"fr_age"
|
||||
]
|
||||
examples = []
|
||||
for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):
|
||||
|
||||
Reference in New Issue
Block a user