mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
2 Commits
v18.1.8
...
ac/patch-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5290045f06 | ||
|
|
a0078bef6a |
72
.github/workflows/engine.yml
vendored
72
.github/workflows/engine.yml
vendored
@@ -19,8 +19,6 @@ jobs:
|
||||
contents: read
|
||||
packages: write
|
||||
name: publish-wasm-bindings
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
@@ -47,39 +45,43 @@ jobs:
|
||||
curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz | tar xzf -
|
||||
export PATH=$PATH:$PWD/binaryen-version_116/bin
|
||||
wasm-opt --version
|
||||
- name: Build wasm files for both web and nodejs compilation targets
|
||||
run: |
|
||||
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
|
||||
wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
|
||||
- name: Create package.json in pkg folder
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
cat > pkg/package.json << EOF
|
||||
{
|
||||
"name": "@ezkljs/engine",
|
||||
"version": "$RELEASE_TAG",
|
||||
"dependencies": {
|
||||
"@types/json-bigint": "^1.0.1",
|
||||
"json-bigint": "^1.0.0"
|
||||
},
|
||||
"files": [
|
||||
"nodejs/ezkl_bg.wasm",
|
||||
"nodejs/ezkl.js",
|
||||
"nodejs/ezkl.d.ts",
|
||||
"nodejs/package.json",
|
||||
"nodejs/utils.js",
|
||||
"web/ezkl_bg.wasm",
|
||||
"web/ezkl.js",
|
||||
"web/ezkl.d.ts",
|
||||
"web/snippets/**/*",
|
||||
"web/package.json",
|
||||
"web/utils.js",
|
||||
"ezkl.d.ts"
|
||||
],
|
||||
"main": "nodejs/ezkl.js",
|
||||
"module": "web/ezkl.js",
|
||||
"types": "nodejs/ezkl.d.ts",
|
||||
"sideEffects": [
|
||||
"web/snippets/*"
|
||||
]
|
||||
}
|
||||
EOF
|
||||
echo '{
|
||||
"name": "@ezkljs/engine",
|
||||
"version": "${RELEASE_TAG}",
|
||||
"dependencies": {
|
||||
"@types/json-bigint": "^1.0.1",
|
||||
"json-bigint": "^1.0.0"
|
||||
},
|
||||
"files": [
|
||||
"nodejs/ezkl_bg.wasm",
|
||||
"nodejs/ezkl.js",
|
||||
"nodejs/ezkl.d.ts",
|
||||
"nodejs/package.json",
|
||||
"nodejs/utils.js",
|
||||
"web/ezkl_bg.wasm",
|
||||
"web/ezkl.js",
|
||||
"web/ezkl.d.ts",
|
||||
"web/snippets/**/*",
|
||||
"web/package.json",
|
||||
"web/utils.js",
|
||||
"ezkl.d.ts"
|
||||
],
|
||||
"main": "nodejs/ezkl.js",
|
||||
"module": "web/ezkl.js",
|
||||
"types": "nodejs/ezkl.d.ts",
|
||||
"sideEffects": [
|
||||
"web/snippets/*"
|
||||
]
|
||||
}' > pkg/package.json
|
||||
|
||||
- name: Replace memory definition in nodejs
|
||||
run: |
|
||||
@@ -193,8 +195,6 @@ jobs:
|
||||
name: publish-in-browser-evm-verifier-package
|
||||
needs: [publish-wasm-bindings]
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -202,8 +202,10 @@ jobs:
|
||||
persist-credentials: false
|
||||
- name: Update version in package.json
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
sed -i "s|\"version\": \".*\"|\"version\": \"$RELEASE_TAG\"|" in-browser-evm-verifier/package.json
|
||||
sed -i "s|\"version\": \".*\"|\"version\": \"${RELEASE_TAG}\"|" in-browser-evm-verifier/package.json
|
||||
- name: Prepare tag and fetch package integrity
|
||||
run: |
|
||||
CLEANED_TAG=${RELEASE_TAG} # Get the tag from ref_name
|
||||
|
||||
4
.github/workflows/pypi-gpu.yml
vendored
4
.github/workflows/pypi-gpu.yml
vendored
@@ -25,8 +25,6 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
target: [x86_64]
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -51,6 +49,8 @@ jobs:
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv Cargo.toml Cargo.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
|
||||
|
||||
109
.github/workflows/pypi.yml
vendored
109
.github/workflows/pypi.yml
vendored
@@ -23,8 +23,6 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
target: [x86_64, universal2-apple-darwin]
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -34,14 +32,10 @@ jobs:
|
||||
python-version: 3.12
|
||||
architecture: x64
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv Cargo.toml Cargo.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
|
||||
@@ -95,14 +89,6 @@ jobs:
|
||||
python-version: 3.12
|
||||
architecture: ${{ matrix.target }}
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
@@ -152,14 +138,6 @@ jobs:
|
||||
python-version: 3.12
|
||||
architecture: x64
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
@@ -170,6 +148,7 @@ jobs:
|
||||
mv Cargo.lock Cargo.lock.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
|
||||
|
||||
|
||||
- name: Install required libraries
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -208,6 +187,57 @@ jobs:
|
||||
name: wheels
|
||||
path: dist
|
||||
|
||||
# There's a problem with the maturin-action toolchain for arm arch leading to failed builds
|
||||
# linux-cross:
|
||||
# runs-on: ubuntu-latest
|
||||
# strategy:
|
||||
# matrix:
|
||||
# target: [aarch64, armv7]
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - uses: actions/setup-python@v4
|
||||
# with:
|
||||
# python-version: 3.12
|
||||
|
||||
# - name: Install cross-compilation tools for aarch64
|
||||
# if: matrix.target == 'aarch64'
|
||||
# run: |
|
||||
# sudo apt-get update
|
||||
# sudo apt-get install -y gcc make gcc-aarch64-linux-gnu binutils-aarch64-linux-gnu libc6-dev-arm64-cross libusb-1.0-0-dev libatomic1-arm64-cross
|
||||
|
||||
# - name: Install cross-compilation tools for armv7
|
||||
# if: matrix.target == 'armv7'
|
||||
# run: |
|
||||
# sudo apt-get update
|
||||
# sudo apt-get install -y gcc make gcc-arm-linux-gnueabihf binutils-arm-linux-gnueabihf libc6-dev-armhf-cross libusb-1.0-0-dev libatomic1-armhf-cross
|
||||
|
||||
# - name: Build wheels
|
||||
# uses: PyO3/maturin-action@v1
|
||||
# with:
|
||||
# target: ${{ matrix.target }}
|
||||
# manylinux: auto
|
||||
# args: --release --out dist --features python-bindings
|
||||
|
||||
# - uses: uraimo/run-on-arch-action@v2.5.0
|
||||
# name: Install built wheel
|
||||
# with:
|
||||
# arch: ${{ matrix.target }}
|
||||
# distro: ubuntu20.04
|
||||
# githubToken: ${{ github.token }}
|
||||
# install: |
|
||||
# apt-get update
|
||||
# apt-get install -y --no-install-recommends python3 python3-pip
|
||||
# pip3 install -U pip
|
||||
# run: |
|
||||
# pip3 install ezkl --no-index --find-links dist/ --force-reinstall
|
||||
# python3 -c "import ezkl"
|
||||
|
||||
# - name: Upload wheels
|
||||
# uses: actions/upload-artifact@v3
|
||||
# with:
|
||||
# name: wheels
|
||||
# path: dist
|
||||
|
||||
musllinux:
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -243,7 +273,6 @@ jobs:
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
|
||||
mv Cargo.lock Cargo.lock.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
|
||||
|
||||
- name: Install required libraries
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -294,14 +323,6 @@ jobs:
|
||||
with:
|
||||
python-version: 3.12
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
@@ -345,6 +366,8 @@ jobs:
|
||||
permissions:
|
||||
id-token: write
|
||||
if: "startsWith(github.ref, 'refs/tags/')"
|
||||
# TODO: Uncomment if linux-cross is working
|
||||
# needs: [ macos, windows, linux, linux-cross, musllinux, musllinux-cross ]
|
||||
needs: [macos, windows, linux, musllinux, musllinux-cross]
|
||||
steps:
|
||||
- uses: actions/download-artifact@v3
|
||||
@@ -352,20 +375,24 @@ jobs:
|
||||
name: wheels
|
||||
- name: List Files
|
||||
run: ls -R
|
||||
|
||||
# # publishes to TestPyPI
|
||||
# - name: Publish package distribution to TestPyPI
|
||||
# uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
# with:
|
||||
# repository-url: https://test.pypi.org/legacy/
|
||||
# packages-dir: ./
|
||||
|
||||
# Both publish steps will fail if there is no trusted publisher setup
|
||||
# On failure the publish step will then simply continue to the next one
|
||||
|
||||
# publishes to PyPI
|
||||
- name: Publish package distributions to PyPI
|
||||
continue-on-error: true
|
||||
uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
with:
|
||||
packages-dir: ./
|
||||
|
||||
# publishes to TestPyPI
|
||||
- name: Publish package distribution to TestPyPI
|
||||
continue-on-error: true
|
||||
uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
with:
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
packages-dir: ./
|
||||
|
||||
doc-publish:
|
||||
permissions:
|
||||
@@ -382,4 +409,4 @@ jobs:
|
||||
with:
|
||||
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
|
||||
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
|
||||
commit_ref: ${{ github.ref_name }}
|
||||
commit_ref: ${{ github.ref_name }}
|
||||
|
||||
158
.github/workflows/rust.yml
vendored
158
.github/workflows/rust.yml
vendored
@@ -19,6 +19,7 @@ env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
|
||||
fr-age-test:
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -98,8 +99,8 @@ jobs:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
@@ -239,7 +240,7 @@ jobs:
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
with:
|
||||
# Pin to version 0.12.1
|
||||
version: "v0.12.1"
|
||||
version: 'v0.12.1'
|
||||
- uses: nanasess/setup-chromedriver@v2
|
||||
# with:
|
||||
# chromedriver-version: "115.0.5790.102"
|
||||
@@ -438,6 +439,7 @@ jobs:
|
||||
# - name: KZG prove and verify tests (public outputs)
|
||||
# run: cargo nextest run --release --features macos-metal --verbose tests::kzg_prove_and_verify_::t --no-capture
|
||||
|
||||
|
||||
prove-and-verify-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -455,7 +457,7 @@ jobs:
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
with:
|
||||
# Pin to version 0.12.1
|
||||
version: "v0.12.1"
|
||||
version: 'v0.12.1'
|
||||
- name: Add wasm32-unknown-unknown target
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
@@ -526,8 +528,8 @@ jobs:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
@@ -584,8 +586,8 @@ jobs:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
@@ -801,84 +803,84 @@ jobs:
|
||||
- name: NBEATS tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
|
||||
# - name: Reusable verifier tutorial
|
||||
# run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_
|
||||
# run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_
|
||||
|
||||
ios-integration-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Run ios tests
|
||||
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Run ios tests
|
||||
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
|
||||
|
||||
swift-package-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
needs: [ios-integration-tests]
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
needs: [ios-integration-tests]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Build EzklCoreBindings
|
||||
run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Build EzklCoreBindings
|
||||
run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
|
||||
|
||||
- name: Clone ezkl-swift- repository
|
||||
run: |
|
||||
git clone https://github.com/zkonduit/ezkl-swift-package.git
|
||||
- name: Clone ezkl-swift- repository
|
||||
run: |
|
||||
git clone https://github.com/zkonduit/ezkl-swift-package.git
|
||||
|
||||
- name: Copy EzklCoreBindings
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
|
||||
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
|
||||
- name: Copy EzklCoreBindings
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
|
||||
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
|
||||
|
||||
- name: Copy Test Files
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Tests/EzklAssets/
|
||||
mkdir -p ezkl-swift-package/Tests/EzklAssets/
|
||||
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
|
||||
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
|
||||
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
|
||||
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
|
||||
- name: Copy Test Files
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Tests/EzklAssets/
|
||||
mkdir -p ezkl-swift-package/Tests/EzklAssets/
|
||||
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
|
||||
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
|
||||
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
|
||||
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
|
||||
|
||||
- name: Set up Xcode environment
|
||||
run: |
|
||||
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
|
||||
sudo xcodebuild -license accept
|
||||
- name: Set up Xcode environment
|
||||
run: |
|
||||
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
|
||||
sudo xcodebuild -license accept
|
||||
|
||||
- name: Run Package Tests
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
xcodebuild test \
|
||||
-scheme EzklPackage \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-resultBundlePath ../testResults
|
||||
- name: Run Package Tests
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
xcodebuild test \
|
||||
-scheme EzklPackage \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-resultBundlePath ../testResults
|
||||
|
||||
- name: Run Example App Tests
|
||||
run: |
|
||||
cd ezkl-swift-package/Example
|
||||
xcodebuild test \
|
||||
-project Example.xcodeproj \
|
||||
-scheme EzklApp \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-parallel-testing-enabled NO \
|
||||
-resultBundlePath ../../exampleTestResults \
|
||||
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
|
||||
- name: Run Example App Tests
|
||||
run: |
|
||||
cd ezkl-swift-package/Example
|
||||
xcodebuild test \
|
||||
-project Example.xcodeproj \
|
||||
-scheme EzklApp \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-parallel-testing-enabled NO \
|
||||
-resultBundlePath ../../exampleTestResults \
|
||||
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
|
||||
22
Cargo.lock
generated
22
Cargo.lock
generated
@@ -1,6 +1,6 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 4
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "addr2line"
|
||||
@@ -1760,7 +1760,7 @@ checksum = "a650a461c6a8ff1ef205ed9a2ad56579309853fecefc2423f73dced342f92258"
|
||||
[[package]]
|
||||
name = "ecc"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
dependencies = [
|
||||
"integer",
|
||||
"num-bigint",
|
||||
@@ -1968,11 +1968,13 @@ dependencies = [
|
||||
"objc",
|
||||
"openssl",
|
||||
"pg_bigdecimal",
|
||||
"portable-atomic",
|
||||
"pyo3",
|
||||
"pyo3-async-runtimes",
|
||||
"pyo3-log",
|
||||
"pyo3-stub-gen",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"semver 1.0.22",
|
||||
"seq-macro",
|
||||
@@ -2604,7 +2606,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2wrong"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
dependencies = [
|
||||
"halo2_proofs",
|
||||
"num-bigint",
|
||||
@@ -2955,7 +2957,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "integer"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
dependencies = [
|
||||
"maingate",
|
||||
"num-bigint",
|
||||
@@ -3139,7 +3141,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-targets 0.52.6",
|
||||
"windows-targets 0.48.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3266,7 +3268,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "maingate"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
dependencies = [
|
||||
"halo2wrong",
|
||||
"num-bigint",
|
||||
@@ -3675,9 +3677,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-src"
|
||||
version = "300.4.1+3.4.0"
|
||||
version = "300.2.3+3.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "faa4eac4138c62414b5622d1b31c5c304f34b406b013c079c2bbc652fdd6678c"
|
||||
checksum = "5cff92b6f71555b61bb9315f7c64da3ca43d87531622120fea0195fc761b4843"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
@@ -5230,7 +5232,7 @@ checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c"
|
||||
[[package]]
|
||||
name = "snark-verifier"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac%2Fchunked-mv-lookup#8762701ab8fa04e7d243a346030afd85633ec970"
|
||||
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/chunked-mv-lookup#8762701ab8fa04e7d243a346030afd85633ec970"
|
||||
dependencies = [
|
||||
"ecc",
|
||||
"halo2_proofs",
|
||||
@@ -6234,7 +6236,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "uniffi_testing"
|
||||
version = "0.28.0"
|
||||
source = "git+https://github.com/ElusAegis/uniffi-rs?branch=feat%2Ftesting-feature-build-fix#4684b9e7da2d9c964c2b3a73883947aab7370a06"
|
||||
source = "git+https://github.com/ElusAegis/uniffi-rs?branch=feat/testing-feature-build-fix#4684b9e7da2d9c964c2b3a73883947aab7370a06"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"camino",
|
||||
|
||||
@@ -40,6 +40,7 @@ maybe-rayon = { version = "0.1.1", default-features = false }
|
||||
bincode = { version = "1.3.3", default-features = false }
|
||||
unzip-n = "0.1.2"
|
||||
num = "0.4.1"
|
||||
portable-atomic = { version = "1.6.0", optional = true }
|
||||
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand", optional = true }
|
||||
semver = { version = "1.0.22", optional = true }
|
||||
|
||||
@@ -73,6 +74,7 @@ tokio-postgres = { version = "0.7.10", optional = true }
|
||||
pg_bigdecimal = { version = "0.1.5", optional = true }
|
||||
lazy_static = { version = "1.4.0", optional = true }
|
||||
colored_json = { version = "3.0.1", default-features = false, optional = true }
|
||||
regex = { version = "1", default-features = false, optional = true }
|
||||
tokio = { version = "1.35.0", default-features = false, features = [
|
||||
"macros",
|
||||
"rt-multi-thread",
|
||||
@@ -242,14 +244,16 @@ ezkl = [
|
||||
"dep:indicatif",
|
||||
"dep:gag",
|
||||
"dep:reqwest",
|
||||
"dep:openssl",
|
||||
"dep:tokio-postgres",
|
||||
"dep:pg_bigdecimal",
|
||||
"dep:lazy_static",
|
||||
"dep:regex",
|
||||
"dep:tokio",
|
||||
"dep:openssl",
|
||||
"dep:mimalloc",
|
||||
"dep:chrono",
|
||||
"dep:sha256",
|
||||
"dep:portable-atomic",
|
||||
"dep:clap_complete",
|
||||
"dep:halo2_solidity_verifier",
|
||||
"dep:semver",
|
||||
|
||||
@@ -10,7 +10,6 @@ use rand::Rng;
|
||||
|
||||
// Assuming these are your types
|
||||
#[derive(Clone)]
|
||||
#[allow(dead_code)]
|
||||
enum ValType {
|
||||
Constant(F),
|
||||
AssignedConstant(usize, F),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ezkl
|
||||
|
||||
project = 'ezkl'
|
||||
release = '18.1.8'
|
||||
release = '0.0.0'
|
||||
version = release
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
// ignore file if compiling for wasm
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use mimalloc::MiMalloc;
|
||||
|
||||
#[global_allocator]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
static GLOBAL: MiMalloc = MiMalloc;
|
||||
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use clap::{CommandFactory, Parser};
|
||||
|
||||
@@ -170,10 +170,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
message: &[ValTensor<Fp>],
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<Self::InputAssignments, ModuleError> {
|
||||
if message.len() != 1 {
|
||||
return Err(ModuleError::InputWrongLength(message.len()));
|
||||
}
|
||||
|
||||
assert_eq!(message.len(), 1);
|
||||
let message = message[0].clone();
|
||||
|
||||
let start_time = instant::Instant::now();
|
||||
@@ -228,7 +225,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
}
|
||||
e => Err(ModuleError::WrongInputType(
|
||||
format!("{:?}", e),
|
||||
"AssignedValue".to_string(),
|
||||
"PrevAssigned".to_string(),
|
||||
)),
|
||||
}
|
||||
})
|
||||
@@ -293,12 +290,6 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<ValTensor<Fp>, ModuleError> {
|
||||
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
|
||||
|
||||
// empty hash case
|
||||
if input_cells.is_empty() {
|
||||
return Ok(input[0].clone());
|
||||
}
|
||||
|
||||
// extract the values from the input cells
|
||||
let mut assigned_input: Tensor<ValType<Fp>> =
|
||||
input_cells.iter().map(|e| ValType::from(e.clone())).into();
|
||||
@@ -520,21 +511,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn poseidon_hash_empty() {
|
||||
let message = [];
|
||||
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 2>::run(message.to_vec()).unwrap();
|
||||
let mut message: Tensor<ValType<Fp>> =
|
||||
message.into_iter().map(|m| Value::known(m).into()).into();
|
||||
let k = 9;
|
||||
let circuit = HashCircuit::<PoseidonSpec, 2> {
|
||||
message: message.into(),
|
||||
_spec: PhantomData,
|
||||
};
|
||||
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
|
||||
assert_eq!(prover.verify(), Ok(()))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn poseidon_hash() {
|
||||
let rng = rand::rngs::OsRng;
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::str::FromStr;
|
||||
|
||||
use halo2_proofs::{
|
||||
circuit::Layouter,
|
||||
plonk::{ConstraintSystem, Constraints, Expression, Selector, TableColumn},
|
||||
plonk::{ConstraintSystem, Constraints, Expression, Selector},
|
||||
poly::Rotation,
|
||||
};
|
||||
use log::debug;
|
||||
@@ -341,8 +341,6 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
/// Activate sanity checks
|
||||
pub check_mode: CheckMode,
|
||||
_marker: PhantomData<F>,
|
||||
/// shared table inputs
|
||||
pub shared_table_inputs: Vec<TableColumn>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
@@ -355,7 +353,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
shuffles: Shuffles::dummy(col_size, num_inner_cols),
|
||||
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
|
||||
check_mode: CheckMode::SAFE,
|
||||
shared_table_inputs: vec![],
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
@@ -500,7 +497,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
dynamic_lookups: DynamicLookups::default(),
|
||||
shuffles: Shuffles::default(),
|
||||
range_checks: RangeChecks::default(),
|
||||
shared_table_inputs: vec![],
|
||||
check_mode,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
@@ -531,9 +527,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
return Err(CircuitError::WrongColumnType(output.name().to_string()));
|
||||
}
|
||||
|
||||
// we borrow mutably twice so we need to do this dance
|
||||
|
||||
let table = if !self.static_lookups.tables.contains_key(nl) {
|
||||
let table =
|
||||
Table::<F>::configure(cs, lookup_range, logrows, nl, &mut self.shared_table_inputs);
|
||||
// as all tables have the same input we see if there's another table who's input we can reuse
|
||||
let table = if let Some(table) = self.static_lookups.tables.values().next() {
|
||||
Table::<F>::configure(
|
||||
cs,
|
||||
lookup_range,
|
||||
logrows,
|
||||
nl,
|
||||
Some(table.table_inputs.clone()),
|
||||
)
|
||||
} else {
|
||||
Table::<F>::configure(cs, lookup_range, logrows, nl, None)
|
||||
};
|
||||
self.static_lookups.tables.insert(nl.clone(), table.clone());
|
||||
table
|
||||
} else {
|
||||
@@ -892,6 +900,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
let range_check = if let std::collections::btree_map::Entry::Vacant(e) =
|
||||
self.range_checks.ranges.entry(range)
|
||||
{
|
||||
// as all tables have the same input we see if there's another table who's input we can reuse
|
||||
let range_check = RangeCheck::<F>::configure(cs, range, logrows);
|
||||
e.insert(range_check.clone());
|
||||
range_check
|
||||
|
||||
@@ -25,7 +25,7 @@ pub enum CircuitError {
|
||||
/// This operation is unsupported
|
||||
#[error("unsupported operation in graph")]
|
||||
UnsupportedOp,
|
||||
/// Invalid einsum expression
|
||||
///
|
||||
#[error("invalid einsum expression")]
|
||||
InvalidEinsum,
|
||||
/// Flush error
|
||||
@@ -103,7 +103,4 @@ pub enum CircuitError {
|
||||
#[error("an element is missing from the shuffled version of the tensor")]
|
||||
/// An element is missing from the shuffled version of the tensor
|
||||
MissingShuffleElement,
|
||||
/// Visibility has not been set
|
||||
#[error("visibility has not been set")]
|
||||
UnsetVisibility,
|
||||
}
|
||||
|
||||
@@ -75,7 +75,7 @@ fn optimum_convex_function<F: PrimeField + TensorType + PartialOrd + std::hash::
|
||||
region: &mut RegionCtx<F>,
|
||||
x: &ValTensor<F>,
|
||||
f: impl Fn(&BaseConfig<F>, &mut RegionCtx<F>, &ValTensor<F>) -> Result<ValTensor<F>, CircuitError>,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
) -> Result<(), CircuitError> {
|
||||
let one = create_constant_tensor(F::from(1), 1);
|
||||
|
||||
let f_x = f(config, region, x)?;
|
||||
@@ -87,17 +87,22 @@ fn optimum_convex_function<F: PrimeField + TensorType + PartialOrd + std::hash::
|
||||
let f_x_minus_1 = f(config, region, &x_minus_1)?;
|
||||
|
||||
// because the function is convex, the result should be the minimum of the three
|
||||
// note that we offset the x by 1 to get the next value
|
||||
// f(x) <= f(x+1) and f(x) < f(x-1)
|
||||
// not that we offset the x by 1 to get the next value
|
||||
// f(x) <= f(x+1) and f(x) <= f(x-1)
|
||||
// the result is 1 if the function is optimal solely because of the convexity of the function
|
||||
// the distances can be equal but this is only possible if f(x) and f(x+1) are both optimal, but if (f(x) = f(x + 1))
|
||||
// f(x+1) is not smaller than f(x + 1 - 1) = f(x) and thus f(x) is unique
|
||||
// the distances can be equal but this is only possible if f(x) and f(x+1) are both optimal (or f(x) and f(x-1)).
|
||||
let f_x_is_opt_rhs = less_equal(config, region, &[f_x.clone(), f_x_plus_1.clone()])?;
|
||||
let f_x_is_opt_lhs = less(config, region, &[f_x.clone(), f_x_minus_1.clone()])?;
|
||||
let f_x_is_opt_lhs = less_equal(config, region, &[f_x.clone(), f_x_minus_1.clone()])?;
|
||||
|
||||
let is_opt = and(config, region, &[f_x_is_opt_lhs, f_x_is_opt_rhs])?;
|
||||
|
||||
Ok(is_opt)
|
||||
let mut comparison_unit = create_constant_tensor(F::ONE, is_opt.len());
|
||||
comparison_unit.reshape(is_opt.dims())?;
|
||||
|
||||
// assert that the result is 1
|
||||
enforce_equality(config, region, &[is_opt, comparison_unit])?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Err is less than some constant
|
||||
@@ -285,14 +290,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
Ok(distance)
|
||||
};
|
||||
|
||||
// we need to add 1 to the points where it is zero to ignore the cvx opt conditions at those points
|
||||
let mut is_opt = optimum_convex_function(config, region, &claimed_output, err_func)?;
|
||||
is_opt = pairwise(config, region, &[is_opt, equal_zero_mask], BaseOp::Add)?;
|
||||
|
||||
let mut comparison_unit = create_constant_tensor(F::ONE, is_opt.len());
|
||||
comparison_unit.reshape(is_opt.dims())?;
|
||||
// assert that the result is 1
|
||||
enforce_equality(config, region, &[is_opt, comparison_unit])?;
|
||||
optimum_convex_function(config, region, &claimed_output, err_func)?;
|
||||
|
||||
Ok(claimed_output)
|
||||
}
|
||||
@@ -364,13 +362,7 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
Ok(distance)
|
||||
};
|
||||
|
||||
let is_opt = optimum_convex_function(config, region, &claimed_output, err_func)?;
|
||||
|
||||
let mut comparison_unit = create_constant_tensor(F::ONE, is_opt.len());
|
||||
comparison_unit.reshape(is_opt.dims())?;
|
||||
|
||||
// assert that the result is 1
|
||||
enforce_equality(config, region, &[is_opt, comparison_unit])?;
|
||||
optimum_convex_function(config, region, &claimed_output, err_func)?;
|
||||
|
||||
Ok(claimed_output)
|
||||
}
|
||||
@@ -1185,6 +1177,8 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
|
||||
|
||||
region.enable(Some(lookup_selector), z)?;
|
||||
|
||||
// region.enable(Some(lookup_selector), z)?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, CircuitError>>()?;
|
||||
@@ -1204,7 +1198,7 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
|
||||
/// 4. index_output is (typically) a prover generated witness committed to in an advice column
|
||||
/// 5. value_output is (typically) a prover generated witness committed to in an advice column
|
||||
/// 6. Given the above, and given the fixed index_input , we go through every (index_input, value_input) pair and ascertain that it is contained in the input.
|
||||
/// 7. Given the fixed incrementing index index_input, we avoid multiplicity in the output by leveraging this surrogate index: if index_output isn't matched to the exact value where for `index_input=index_output` -> `value_input=value_output`, then the lookup fails
|
||||
/// Given the fixed incrementing index index_input, we avoid multiplicity in the output by leveraging this surrogate index: if index_output isn't matched to the exact value where for `index_input=index_output` -> `value_input=value_output`, then the lookup fails
|
||||
pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
@@ -2983,7 +2977,7 @@ pub fn xor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
let lhs_and_rhs_not = and(config, region, &[lhs, rhs_not.clone()])?;
|
||||
let lhs_not_and_rhs = and(config, region, &[rhs, lhs_not])?;
|
||||
|
||||
// we can safely use add and not OR here because we know that lhs_and_rhs_not and lhs_not_and_rhs are =1 at different indices
|
||||
// we can safely use add and not OR here because we know that lhs_and_rhs_not and lhs_not_and_rhs are =1 at different incices
|
||||
let res: ValTensor<F> = pairwise(
|
||||
config,
|
||||
region,
|
||||
@@ -3260,11 +3254,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
.map(|(i, d)| {
|
||||
let d = padding[i].0 + d + padding[i].1;
|
||||
d.checked_sub(pool_dims[i])
|
||||
.ok_or_else(|| TensorError::Overflow("max_pool".to_string()))?
|
||||
.ok_or_else(|| TensorError::Overflow("conv".to_string()))?
|
||||
.checked_div(stride[i])
|
||||
.ok_or_else(|| TensorError::Overflow("max_pool".to_string()))?
|
||||
.ok_or_else(|| TensorError::Overflow("conv".to_string()))?
|
||||
.checked_add(1)
|
||||
.ok_or_else(|| TensorError::Overflow("max_pool".to_string()))
|
||||
.ok_or_else(|| TensorError::Overflow("conv".to_string()))
|
||||
})
|
||||
.collect::<Result<Vec<_>, TensorError>>()?;
|
||||
|
||||
@@ -4948,7 +4942,6 @@ pub fn round<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
1,
|
||||
);
|
||||
let assigned_midway_point = region.assign(&config.custom_gates.inputs[1], &midway_point)?;
|
||||
region.increment(assigned_midway_point.len());
|
||||
|
||||
let dims = decomposition.dims().to_vec();
|
||||
let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec();
|
||||
@@ -5214,7 +5207,6 @@ pub(crate) fn recompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has
|
||||
|
||||
if !is_assigned {
|
||||
sliced_input = region.assign(&config.custom_gates.inputs[0], &sliced_input)?;
|
||||
region.increment(sliced_input.len());
|
||||
}
|
||||
|
||||
// get the sign bit and make sure it is valid
|
||||
|
||||
@@ -264,10 +264,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
|
||||
}
|
||||
/// Rebase the scale of the constant
|
||||
pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), CircuitError> {
|
||||
let visibility = match self.quantized_values.visibility() {
|
||||
Some(v) => v,
|
||||
None => return Err(CircuitError::UnsetVisibility),
|
||||
};
|
||||
let visibility = self.quantized_values.visibility().unwrap();
|
||||
self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -252,12 +252,6 @@ impl<
|
||||
)?,
|
||||
PolyOp::GatherElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
if values.len() != 1 {
|
||||
return Err(TensorError::DimError(
|
||||
"GatherElements only accepts single inputs".to_string(),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
|
||||
} else {
|
||||
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
|
||||
@@ -275,12 +269,6 @@ impl<
|
||||
}
|
||||
PolyOp::ScatterElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
if values.len() != 2 {
|
||||
return Err(TensorError::DimError(
|
||||
"ScatterElements requires two inputs".to_string(),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
tensor::ops::scatter(
|
||||
values[0].get_inner_tensor()?,
|
||||
idx,
|
||||
|
||||
@@ -163,7 +163,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
range: Range,
|
||||
logrows: usize,
|
||||
nonlinearity: &LookupOp,
|
||||
preexisting_inputs: &mut Vec<TableColumn>,
|
||||
preexisting_inputs: Option<Vec<TableColumn>>,
|
||||
) -> Table<F> {
|
||||
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
|
||||
let col_size = Self::cal_col_size(logrows, factors);
|
||||
@@ -172,28 +172,28 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
|
||||
debug!("table range: {:?}", range);
|
||||
|
||||
// validate enough columns are provided to store the range
|
||||
if preexisting_inputs.len() < num_cols {
|
||||
// add columns to match the required number of columns
|
||||
let diff = num_cols - preexisting_inputs.len();
|
||||
for _ in 0..diff {
|
||||
preexisting_inputs.push(cs.lookup_table_column());
|
||||
let table_inputs = preexisting_inputs.unwrap_or_else(|| {
|
||||
let mut cols = vec![];
|
||||
for _ in 0..num_cols {
|
||||
cols.push(cs.lookup_table_column());
|
||||
}
|
||||
}
|
||||
cols
|
||||
});
|
||||
|
||||
let num_cols = table_inputs.len();
|
||||
|
||||
let num_cols = preexisting_inputs.len();
|
||||
if num_cols > 1 {
|
||||
warn!("Using {} columns for non-linearity table.", num_cols);
|
||||
}
|
||||
|
||||
let table_outputs = preexisting_inputs
|
||||
let table_outputs = table_inputs
|
||||
.iter()
|
||||
.map(|_| cs.lookup_table_column())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Table {
|
||||
nonlinearity: nonlinearity.clone(),
|
||||
table_inputs: preexisting_inputs.clone(),
|
||||
table_inputs,
|
||||
table_outputs,
|
||||
is_assigned: false,
|
||||
selector_constructor: SelectorConstructor::new(num_cols),
|
||||
|
||||
@@ -517,7 +517,6 @@ pub async fn deploy_da_verifier_via_solidity(
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn deploy_multi_da_contract(
|
||||
client: EthersClient,
|
||||
contract_instance_offset: usize,
|
||||
|
||||
@@ -118,7 +118,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
} => gen_srs_cmd(
|
||||
srs_path,
|
||||
logrows as u32,
|
||||
commitment.unwrap_or_else(|| Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
|
||||
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
|
||||
),
|
||||
Commands::GetSrs {
|
||||
srs_path,
|
||||
@@ -1535,8 +1535,7 @@ pub(crate) async fn create_evm_data_attestation(
|
||||
trace!("params computed");
|
||||
|
||||
// if input is not provided, we just instantiate dummy input data
|
||||
let data =
|
||||
GraphData::from_path(input).unwrap_or_else(|_| GraphData::new(DataSource::File(vec![])));
|
||||
let data = GraphData::from_path(input).unwrap_or(GraphData::new(DataSource::File(vec![])));
|
||||
|
||||
// The number of input and output instances we attest to for the single call data attestation
|
||||
let mut input_len = None;
|
||||
@@ -2127,7 +2126,6 @@ pub(crate) fn mock_aggregate(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn setup_aggregate(
|
||||
sample_snarks: Vec<PathBuf>,
|
||||
vk_path: PathBuf,
|
||||
|
||||
@@ -9,8 +9,6 @@ pub type IntegerRep = i128;
|
||||
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
|
||||
if x >= 0 {
|
||||
F::from_u128(x as u128)
|
||||
} else if x == IntegerRep::MIN {
|
||||
-F::from_u128(x.saturating_neg() as u128) - F::ONE
|
||||
} else {
|
||||
-F::from_u128(x.saturating_neg() as u128)
|
||||
}
|
||||
@@ -34,9 +32,6 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
|
||||
/// Converts a PrimeField element to an i64.
|
||||
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
|
||||
if x > F::from_u128(IntegerRep::MAX as u128) {
|
||||
if x == -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
|
||||
return IntegerRep::MIN;
|
||||
}
|
||||
let rep = (-x).to_repr();
|
||||
let negtmp: &[u8] = rep.as_ref();
|
||||
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
|
||||
@@ -56,7 +51,7 @@ mod test {
|
||||
use halo2curves::pasta::Fp as F;
|
||||
|
||||
#[test]
|
||||
fn integerreptofelt() {
|
||||
fn test_conv() {
|
||||
let res: F = integer_rep_to_felt(-15);
|
||||
assert_eq!(res, -F::from(15));
|
||||
|
||||
@@ -78,20 +73,4 @@ mod test {
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn felttointegerrepmin() {
|
||||
let x = IntegerRep::MIN;
|
||||
let fieldx: F = integer_rep_to_felt::<F>(x);
|
||||
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn felttointegerrepmax() {
|
||||
let x = IntegerRep::MAX;
|
||||
let fieldx: F = integer_rep_to_felt::<F>(x);
|
||||
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,13 +119,13 @@ pub enum GraphError {
|
||||
/// Missing input for a node
|
||||
#[error("missing input for node {0}")]
|
||||
MissingInput(usize),
|
||||
/// Ranges can only be constant
|
||||
///
|
||||
#[error("range only supports constant inputs in a zk circuit")]
|
||||
NonConstantRange,
|
||||
/// Trilu diagonal must be constant
|
||||
///
|
||||
#[error("trilu only supports constant diagonals in a zk circuit")]
|
||||
NonConstantTrilu,
|
||||
/// The witness was too short
|
||||
///
|
||||
#[error("insufficient witness values to generate a fixed output")]
|
||||
InsufficientWitnessValues,
|
||||
/// Missing scale
|
||||
@@ -152,10 +152,4 @@ pub enum GraphError {
|
||||
/// Only nearest neighbor interpolation is supported
|
||||
#[error("only nearest neighbor interpolation is supported")]
|
||||
InvalidInterpolation,
|
||||
/// Node has a missing output
|
||||
#[error("node {0} has a missing output")]
|
||||
MissingOutput(usize),
|
||||
/// Inssuficient advice columns
|
||||
#[error("insuficcient advice columns (need {0} at least)")]
|
||||
InsufficientAdviceColumns(usize),
|
||||
}
|
||||
|
||||
@@ -24,7 +24,6 @@ use tract_onnx::tract_core::{
|
||||
tract_data::{prelude::Tensor as TractTensor, TVec},
|
||||
value::TValue,
|
||||
};
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_hir::tract_num_traits::ToPrimitive;
|
||||
|
||||
@@ -32,95 +31,30 @@ type Decimals = u8;
|
||||
type Call = String;
|
||||
type RPCUrl = String;
|
||||
|
||||
/// Represents different types of values that can be stored in a file source
|
||||
/// Used for handling various input types in zero-knowledge proofs
|
||||
///
|
||||
#[derive(Clone, Debug, PartialOrd, PartialEq)]
|
||||
pub enum FileSourceInner {
|
||||
/// Floating point value (64-bit)
|
||||
/// Inner elements of float inputs coming from a file
|
||||
Float(f64),
|
||||
/// Boolean value
|
||||
/// Inner elements of bool inputs coming from a file
|
||||
Bool(bool),
|
||||
/// Field element value for direct use in circuits
|
||||
/// Inner elements of inputs coming from a witness
|
||||
Field(Fp),
|
||||
}
|
||||
|
||||
impl FileSourceInner {
|
||||
/// Returns true if the value is a floating point number
|
||||
///
|
||||
pub fn is_float(&self) -> bool {
|
||||
matches!(self, FileSourceInner::Float(_))
|
||||
}
|
||||
|
||||
/// Returns true if the value is a boolean
|
||||
///
|
||||
pub fn is_bool(&self) -> bool {
|
||||
matches!(self, FileSourceInner::Bool(_))
|
||||
}
|
||||
|
||||
/// Returns true if the value is a field element
|
||||
///
|
||||
pub fn is_field(&self) -> bool {
|
||||
matches!(self, FileSourceInner::Field(_))
|
||||
}
|
||||
|
||||
/// Creates a new floating point value
|
||||
pub fn new_float(f: f64) -> Self {
|
||||
FileSourceInner::Float(f)
|
||||
}
|
||||
|
||||
/// Creates a new field element value
|
||||
pub fn new_field(f: Fp) -> Self {
|
||||
FileSourceInner::Field(f)
|
||||
}
|
||||
|
||||
/// Creates a new boolean value
|
||||
pub fn new_bool(f: bool) -> Self {
|
||||
FileSourceInner::Bool(f)
|
||||
}
|
||||
|
||||
/// Adjusts the value according to the specified input type
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input_type` - Type specification to convert the value to
|
||||
pub fn as_type(&mut self, input_type: &InputType) {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => input_type.roundtrip(f),
|
||||
FileSourceInner::Bool(_) => assert!(matches!(input_type, InputType::Bool)),
|
||||
FileSourceInner::Field(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts the value to a field element using appropriate scaling
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `scale` - Scaling factor for floating point conversion
|
||||
pub fn to_field(&self, scale: crate::Scale) -> Fp {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => {
|
||||
integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap())
|
||||
}
|
||||
FileSourceInner::Bool(f) => {
|
||||
if *f {
|
||||
Fp::one()
|
||||
} else {
|
||||
Fp::zero()
|
||||
}
|
||||
}
|
||||
FileSourceInner::Field(f) => *f,
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts the value to a floating point number
|
||||
pub fn to_float(&self) -> f64 {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => *f,
|
||||
FileSourceInner::Bool(f) => {
|
||||
if *f {
|
||||
1.0
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for FileSourceInner {
|
||||
@@ -136,8 +70,8 @@ impl Serialize for FileSourceInner {
|
||||
}
|
||||
}
|
||||
|
||||
// Deserialization implementation for FileSourceInner
|
||||
// Uses JSON deserialization to handle the different variants
|
||||
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
|
||||
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
|
||||
impl<'de> Deserialize<'de> for FileSourceInner {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
@@ -164,16 +98,70 @@ impl<'de> Deserialize<'de> for FileSourceInner {
|
||||
}
|
||||
}
|
||||
|
||||
/// A collection of input values from a file source
|
||||
/// Organized as a vector of vectors where each inner vector represents a row/entry
|
||||
/// Elements of inputs coming from a file
|
||||
pub type FileSource = Vec<Vec<FileSourceInner>>;
|
||||
|
||||
/// Represents different types of calls for fetching on-chain data
|
||||
impl FileSourceInner {
|
||||
/// Create a new FileSourceInner
|
||||
pub fn new_float(f: f64) -> Self {
|
||||
FileSourceInner::Float(f)
|
||||
}
|
||||
/// Create a new FileSourceInner
|
||||
pub fn new_field(f: Fp) -> Self {
|
||||
FileSourceInner::Field(f)
|
||||
}
|
||||
/// Create a new FileSourceInner
|
||||
pub fn new_bool(f: bool) -> Self {
|
||||
FileSourceInner::Bool(f)
|
||||
}
|
||||
|
||||
///
|
||||
pub fn as_type(&mut self, input_type: &InputType) {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => input_type.roundtrip(f),
|
||||
FileSourceInner::Bool(_) => assert!(matches!(input_type, InputType::Bool)),
|
||||
FileSourceInner::Field(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to a field element
|
||||
pub fn to_field(&self, scale: crate::Scale) -> Fp {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => {
|
||||
integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap())
|
||||
}
|
||||
FileSourceInner::Bool(f) => {
|
||||
if *f {
|
||||
Fp::one()
|
||||
} else {
|
||||
Fp::zero()
|
||||
}
|
||||
}
|
||||
FileSourceInner::Field(f) => *f,
|
||||
}
|
||||
}
|
||||
/// Convert to a float
|
||||
pub fn to_float(&self) -> f64 {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => *f,
|
||||
FileSourceInner::Bool(f) => {
|
||||
if *f {
|
||||
1.0
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Call type for attested inputs on-chain
|
||||
#[derive(Clone, Debug, PartialOrd, PartialEq)]
|
||||
pub enum Calls {
|
||||
/// Multiple calls to different accounts, each returning individual values
|
||||
/// Vector of calls to accounts, each returning an attested data point
|
||||
Multiple(Vec<CallsToAccount>),
|
||||
/// Single call returning an array of values
|
||||
/// Single call to account, returning an array of attested data points
|
||||
Single(CallToAccount),
|
||||
}
|
||||
|
||||
@@ -182,6 +170,32 @@ impl Default for Calls {
|
||||
Calls::Multiple(Vec::new())
|
||||
}
|
||||
}
|
||||
/// Inner elements of inputs/outputs coming from on-chain
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct OnChainSource {
|
||||
/// Calls to accounts
|
||||
pub calls: Calls,
|
||||
/// RPC url
|
||||
pub rpc: RPCUrl,
|
||||
}
|
||||
|
||||
impl OnChainSource {
|
||||
/// Create a new OnChainSource with multiple calls
|
||||
pub fn new_multiple(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
|
||||
OnChainSource {
|
||||
calls: Calls::Multiple(calls),
|
||||
rpc,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new OnChainSource with a single call
|
||||
pub fn new_single(call: CallToAccount, rpc: RPCUrl) -> Self {
|
||||
OnChainSource {
|
||||
calls: Calls::Single(call),
|
||||
rpc,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Calls {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
@@ -203,6 +217,7 @@ impl<'de> Deserialize<'de> for Calls {
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
|
||||
|
||||
let multiple_try: Result<Vec<CallsToAccount>, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = multiple_try {
|
||||
return Ok(Calls::Multiple(t));
|
||||
@@ -212,52 +227,111 @@ impl<'de> Deserialize<'de> for Calls {
|
||||
return Ok(Calls::Single(t));
|
||||
}
|
||||
|
||||
Err(serde::de::Error::custom("failed to deserialize Calls"))
|
||||
Err(serde::de::Error::custom(
|
||||
"failed to deserialize FileSourceInner",
|
||||
))
|
||||
}
|
||||
}
|
||||
/// Configuration for accessing on-chain data sources
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Inner elements of inputs/outputs coming from postgres DB
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct OnChainSource {
|
||||
/// Call specifications for fetching data
|
||||
pub calls: Calls,
|
||||
/// RPC endpoint URL for accessing the chain
|
||||
pub rpc: RPCUrl,
|
||||
pub struct PostgresSource {
|
||||
/// postgres host
|
||||
pub host: RPCUrl,
|
||||
/// user to connect to postgres
|
||||
pub user: String,
|
||||
/// password to connect to postgres
|
||||
pub password: String,
|
||||
/// query to execute
|
||||
pub query: String,
|
||||
/// dbname
|
||||
pub dbname: String,
|
||||
/// port
|
||||
pub port: String,
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl PostgresSource {
|
||||
/// Create a new PostgresSource
|
||||
pub fn new(
|
||||
host: RPCUrl,
|
||||
port: String,
|
||||
user: String,
|
||||
query: String,
|
||||
dbname: String,
|
||||
password: String,
|
||||
) -> Self {
|
||||
PostgresSource {
|
||||
host,
|
||||
user,
|
||||
password,
|
||||
query,
|
||||
dbname,
|
||||
port,
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch data from postgres
|
||||
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
|
||||
// clone to move into thread
|
||||
let user = self.user.clone();
|
||||
let host = self.host.clone();
|
||||
let query = self.query.clone();
|
||||
let dbname = self.dbname.clone();
|
||||
let port = self.port.clone();
|
||||
let password = self.password.clone();
|
||||
|
||||
let config = if password.is_empty() {
|
||||
format!(
|
||||
"host={} user={} dbname={} port={}",
|
||||
host, user, dbname, port
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"host={} user={} dbname={} port={} password={}",
|
||||
host, user, dbname, port, password
|
||||
)
|
||||
};
|
||||
|
||||
let mut client = Client::connect(&config).await?;
|
||||
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
|
||||
// extract rows from query
|
||||
for row in client.query(&query, &[]).await? {
|
||||
// extract features from row
|
||||
for i in 0..row.len() {
|
||||
res.push(row.get(i));
|
||||
}
|
||||
}
|
||||
Ok(vec![res])
|
||||
}
|
||||
|
||||
/// Fetch data from postgres and format it as a FileSource
|
||||
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
|
||||
Ok(self
|
||||
.fetch()
|
||||
.await?
|
||||
.iter()
|
||||
.map(|d| {
|
||||
d.iter()
|
||||
.map(|d| {
|
||||
FileSourceInner::Float(
|
||||
d.n.as_ref()
|
||||
.unwrap()
|
||||
.to_f64()
|
||||
.ok_or("could not convert decimal to f64")
|
||||
.unwrap(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl OnChainSource {
|
||||
/// Creates a new OnChainSource with multiple calls
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `calls` - Vector of call specifications
|
||||
/// * `rpc` - RPC endpoint URL
|
||||
pub fn new_multiple(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
|
||||
OnChainSource {
|
||||
calls: Calls::Multiple(calls),
|
||||
rpc,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new OnChainSource with a single call
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `call` - Call specification
|
||||
/// * `rpc` - RPC endpoint URL
|
||||
pub fn new_single(call: CallToAccount, rpc: RPCUrl) -> Self {
|
||||
OnChainSource {
|
||||
calls: Calls::Single(call),
|
||||
rpc,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Creates test data for the OnChain data source
|
||||
/// Used for testing and development purposes
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `data` - Sample file data to use
|
||||
/// * `scales` - Scaling factors for each input
|
||||
/// * `shapes` - Shapes of the input tensors
|
||||
/// * `rpc` - Optional RPC endpoint override
|
||||
/// Create dummy local on-chain data to test the OnChain data source
|
||||
pub async fn test_from_file_data(
|
||||
data: &FileSource,
|
||||
scales: Vec<crate::Scale>,
|
||||
@@ -324,40 +398,48 @@ impl OnChainSource {
|
||||
}
|
||||
}
|
||||
|
||||
/// Specification for view-only calls to fetch on-chain data
|
||||
/// Used for data attestation in smart contract verification
|
||||
/// Defines the view only calls to accounts to fetch the on-chain input data.
|
||||
/// This data will be included as part of the first elements in the publicInputs
|
||||
/// for the sol evm verifier and will be verifyWithDataAttestation.sol
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct CallsToAccount {
|
||||
/// Vector of (call data, decimals) pairs
|
||||
/// call_data: ABI-encoded function call
|
||||
/// decimals: Number of decimal places for float conversion
|
||||
/// A vector of tuples, where index 0 of tuples
|
||||
/// are the byte strings representing the ABI encoded function calls to
|
||||
/// read the data from the address. This call must return a single
|
||||
/// elementary type (<https://docs.soliditylang.org/en/v0.8.20/abi-spec.html#types>).
|
||||
/// The second index of the tuple is the number of decimals for f32 conversion.
|
||||
/// We don't support dynamic types currently.
|
||||
pub call_data: Vec<(Call, Decimals)>,
|
||||
/// Contract address to call
|
||||
/// Address of the contract to read the data from.
|
||||
pub address: String,
|
||||
}
|
||||
|
||||
/// Specification for a single view-only call returning an array
|
||||
/// Defines a view only call to accounts to fetch the on-chain input data.
|
||||
/// This data will be included as part of the first elements in the publicInputs
|
||||
/// for the sol evm verifier and will be verifyWithDataAttestation.sol
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct CallToAccount {
|
||||
/// ABI-encoded function call data
|
||||
/// The call_data is a byte strings representing the ABI encoded function call to
|
||||
/// read the data from the address. This call must return a single array of integers that can be
|
||||
/// be safely cast to the int128 type in solidity.
|
||||
pub call_data: Call,
|
||||
/// Number of decimal places for float conversion
|
||||
/// The number of decimals for f32 conversion of all of the elements returned from the
|
||||
/// call.
|
||||
pub decimals: Decimals,
|
||||
/// Contract address to call
|
||||
/// Address of the contract to read the data from.
|
||||
pub address: String,
|
||||
/// Expected length of returned array
|
||||
/// The number of elements returned from the call.
|
||||
pub len: usize,
|
||||
}
|
||||
|
||||
/// Represents different sources of input/output data for the EZKL model
|
||||
/// Enum that defines source of the inputs/outputs to the EZKL model
|
||||
#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum DataSource {
|
||||
/// Data from a JSON file containing arrays of values
|
||||
/// .json File data source.
|
||||
File(FileSource),
|
||||
/// Data fetched from blockchain contracts
|
||||
/// On-chain data source. The first element is the calls to the account, and the second is the RPC url.
|
||||
OnChain(OnChainSource),
|
||||
/// Data from a PostgreSQL database
|
||||
/// Postgres DB
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
DB(PostgresSource),
|
||||
}
|
||||
@@ -400,7 +482,8 @@ impl From<OnChainSource> for DataSource {
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Always use JSON serialization for untagged enums
|
||||
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
|
||||
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
|
||||
impl<'de> Deserialize<'de> for DataSource {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
@@ -408,19 +491,15 @@ impl<'de> Deserialize<'de> for DataSource {
|
||||
{
|
||||
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
|
||||
|
||||
// Try deserializing as FileSource first
|
||||
let first_try: Result<FileSource, _> = serde_json::from_str(this_json.get());
|
||||
|
||||
if let Ok(t) = first_try {
|
||||
return Ok(DataSource::File(t));
|
||||
}
|
||||
|
||||
// Try deserializing as OnChainSource
|
||||
let second_try: Result<OnChainSource, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = second_try {
|
||||
return Ok(DataSource::OnChain(t));
|
||||
}
|
||||
|
||||
// Try deserializing as PostgresSource if feature enabled
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
{
|
||||
let third_try: Result<PostgresSource, _> = serde_json::from_str(this_json.get());
|
||||
@@ -433,29 +512,22 @@ impl<'de> Deserialize<'de> for DataSource {
|
||||
}
|
||||
}
|
||||
|
||||
/// Container for input and output data for graph computations
|
||||
///
|
||||
/// Important: Always use JSON serialization for GraphData to handle enum variants correctly
|
||||
/// Input to graph as a datasource
|
||||
/// Always use JSON serialization for GraphData. Seriously.
|
||||
#[derive(Clone, Debug, Deserialize, Default, PartialEq, Serialize)]
|
||||
pub struct GraphData {
|
||||
/// Input data for the model/graph
|
||||
/// Can be empty if inputs come from on-chain sources
|
||||
/// Inputs to the model / computational graph (can be empty vectors if inputs are coming from on-chain).
|
||||
pub input_data: DataSource,
|
||||
|
||||
/// Optional output data for the model/graph
|
||||
/// Can be empty if outputs come from on-chain sources
|
||||
/// Outputs of the model / computational graph (can be empty vectors if outputs are coming from on-chain).
|
||||
pub output_data: Option<DataSource>,
|
||||
}
|
||||
|
||||
impl UnwindSafe for GraphData {}
|
||||
|
||||
impl GraphData {
|
||||
// not wasm
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Converts the input data to tract's tensor format
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `shapes` - Expected shapes for each input tensor
|
||||
/// * `datum_types` - Expected data types for each input
|
||||
/// Convert the input data to tract data
|
||||
pub fn to_tract_data(
|
||||
&self,
|
||||
shapes: &[Vec<usize>],
|
||||
@@ -484,14 +556,9 @@ impl GraphData {
|
||||
Ok(inputs)
|
||||
}
|
||||
|
||||
// not wasm
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Converts tract tensor data into GraphData format
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tensors` - Array of tract tensors to convert
|
||||
///
|
||||
/// # Returns
|
||||
/// A new GraphData instance containing the converted tensor data
|
||||
/// Convert the tract data to tract data
|
||||
pub fn from_tract_data(tensors: &[TractTensor]) -> Result<Self, GraphError> {
|
||||
use tract_onnx::prelude::DatumType;
|
||||
|
||||
@@ -517,10 +584,7 @@ impl GraphData {
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates a new GraphData instance with given input data
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input_data` - The input data source
|
||||
pub fn new(input_data: DataSource) -> Self {
|
||||
GraphData {
|
||||
input_data,
|
||||
@@ -528,13 +592,7 @@ impl GraphData {
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads graph input data from a file
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - Path to the input file
|
||||
///
|
||||
/// # Returns
|
||||
/// A new GraphData instance containing the loaded data
|
||||
/// Load the model input from a file
|
||||
pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
|
||||
let reader = std::fs::File::open(&path).map_err(|e| {
|
||||
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
|
||||
@@ -548,35 +606,23 @@ impl GraphData {
|
||||
Ok(graph_input)
|
||||
}
|
||||
|
||||
/// Saves the graph data to a file
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - Path where to save the data
|
||||
/// Save the model input to a file
|
||||
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
|
||||
let file = std::fs::File::create(path.clone()).map_err(|e| {
|
||||
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
|
||||
})?;
|
||||
// buf writer
|
||||
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
serde_json::to_writer(writer, self)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Splits the input data into multiple batches based on input shapes
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input_shapes` - Vector of shapes for each input tensor
|
||||
///
|
||||
/// # Returns
|
||||
/// Vector of GraphData instances, one for each batch
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns error if:
|
||||
/// - Data is from on-chain source
|
||||
/// - Input size is not evenly divisible by batch size
|
||||
pub async fn split_into_batches(
|
||||
&self,
|
||||
input_shapes: Vec<Vec<usize>>,
|
||||
) -> Result<Vec<Self>, GraphError> {
|
||||
// split input data into batches
|
||||
let mut batched_inputs = vec![];
|
||||
|
||||
let iterable = match self {
|
||||
@@ -600,12 +646,10 @@ impl GraphData {
|
||||
} => data.fetch_and_format_as_file().await?,
|
||||
};
|
||||
|
||||
// Process each input tensor according to its shape
|
||||
for (i, shape) in input_shapes.iter().enumerate() {
|
||||
// ensure the input is evenly divisible by batch_size
|
||||
let input_size = shape.clone().iter().product::<usize>();
|
||||
let input = &iterable[i];
|
||||
|
||||
// Validate input size is divisible by batch size
|
||||
if input.len() % input_size != 0 {
|
||||
return Err(GraphError::InvalidDims(
|
||||
0,
|
||||
@@ -613,8 +657,6 @@ impl GraphData {
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Split input into batches
|
||||
let mut batches = vec![];
|
||||
for batch in input.chunks(input_size) {
|
||||
batches.push(batch.to_vec());
|
||||
@@ -622,18 +664,18 @@ impl GraphData {
|
||||
batched_inputs.push(batches);
|
||||
}
|
||||
|
||||
// Merge batches across inputs
|
||||
// now merge all the batches for each input into a vector of batches
|
||||
// first assert each input has the same number of batches
|
||||
let num_batches = if batched_inputs.is_empty() {
|
||||
0
|
||||
} else {
|
||||
let num_batches = batched_inputs[0].len();
|
||||
// Verify all inputs have same number of batches
|
||||
for input in batched_inputs.iter() {
|
||||
assert_eq!(input.len(), num_batches);
|
||||
}
|
||||
num_batches
|
||||
};
|
||||
|
||||
// now merge the batches
|
||||
let mut input_batches = vec![];
|
||||
for i in 0..num_batches {
|
||||
let mut batch = vec![];
|
||||
@@ -643,12 +685,11 @@ impl GraphData {
|
||||
input_batches.push(DataSource::File(batch));
|
||||
}
|
||||
|
||||
// Ensure at least one batch exists
|
||||
if input_batches.is_empty() {
|
||||
input_batches.push(DataSource::File(vec![vec![]]));
|
||||
}
|
||||
|
||||
// Create GraphData instance for each batch
|
||||
// create a new GraphWitness for each batch
|
||||
let batches = input_batches
|
||||
.into_iter()
|
||||
.map(GraphData::new)
|
||||
@@ -660,7 +701,6 @@ impl GraphData {
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for CallsToAccount {
|
||||
/// Converts CallsToAccount to Python object
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("account", &self.address).unwrap();
|
||||
@@ -669,165 +709,6 @@ impl ToPyObject for CallsToAccount {
|
||||
}
|
||||
}
|
||||
|
||||
// Additional Python bindings for various types...
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_postgres_source_new() {
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
{
|
||||
let source = PostgresSource::new(
|
||||
"localhost".to_string(),
|
||||
"5432".to_string(),
|
||||
"user".to_string(),
|
||||
"SELECT * FROM table".to_string(),
|
||||
"database".to_string(),
|
||||
"password".to_string(),
|
||||
);
|
||||
|
||||
assert_eq!(source.host, "localhost");
|
||||
assert_eq!(source.port, "5432");
|
||||
assert_eq!(source.user, "user");
|
||||
assert_eq!(source.query, "SELECT * FROM table");
|
||||
assert_eq!(source.dbname, "database");
|
||||
assert_eq!(source.password, "password");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_data_source_serialization_round_trip() {
|
||||
// Test backwards compatibility with old format
|
||||
let source = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]);
|
||||
let serialized = serde_json::to_string(&source).unwrap();
|
||||
const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#;
|
||||
assert_eq!(serialized, JSON);
|
||||
|
||||
let expect = serde_json::from_str::<DataSource>(JSON)
|
||||
.map_err(|e| e.to_string())
|
||||
.unwrap();
|
||||
assert_eq!(expect, source);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_graph_input_serialization_round_trip() {
|
||||
// Test serialization/deserialization of graph input
|
||||
let file = GraphData::new(DataSource::from(vec![vec![
|
||||
0.05326242372393608,
|
||||
0.07497056573629379,
|
||||
0.05235547572374344,
|
||||
]]));
|
||||
|
||||
let serialized = serde_json::to_string(&file).unwrap();
|
||||
const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#;
|
||||
assert_eq!(serialized, JSON);
|
||||
|
||||
let graph_input3 = serde_json::from_str::<GraphData>(JSON)
|
||||
.map_err(|e| e.to_string())
|
||||
.unwrap();
|
||||
assert_eq!(graph_input3, file);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_python_compat() {
|
||||
// Test compatibility with mclbn256 library serialization
|
||||
let source = Fp::from_raw([18445520602771460712, 838677322461845011, 3079992810, 0]);
|
||||
let original_addr = "0x000000000000000000000000b794f5ea0ba39494ce839613fffba74279579268";
|
||||
assert_eq!(format!("{:?}", source), original_addr);
|
||||
}
|
||||
}
|
||||
|
||||
/// Source data from a PostgreSQL database
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct PostgresSource {
|
||||
/// Database host address
|
||||
pub host: RPCUrl,
|
||||
/// Database user name
|
||||
pub user: String,
|
||||
/// Database password
|
||||
pub password: String,
|
||||
/// SQL query to execute
|
||||
pub query: String,
|
||||
/// Database name
|
||||
pub dbname: String,
|
||||
/// Database port
|
||||
pub port: String,
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl PostgresSource {
|
||||
/// Creates a new PostgreSQL data source
|
||||
pub fn new(
|
||||
host: RPCUrl,
|
||||
port: String,
|
||||
user: String,
|
||||
query: String,
|
||||
dbname: String,
|
||||
password: String,
|
||||
) -> Self {
|
||||
PostgresSource {
|
||||
host,
|
||||
user,
|
||||
password,
|
||||
query,
|
||||
dbname,
|
||||
port,
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetches data from the PostgreSQL database
|
||||
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
|
||||
// Configuration string
|
||||
let config = if self.password.is_empty() {
|
||||
format!(
|
||||
"host={} user={} dbname={} port={}",
|
||||
self.host, self.user, self.dbname, self.port
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"host={} user={} dbname={} port={} password={}",
|
||||
self.host, self.user, self.dbname, self.port, self.password
|
||||
)
|
||||
};
|
||||
|
||||
let mut client = Client::connect(&config).await?;
|
||||
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
|
||||
|
||||
// Extract rows from query
|
||||
for row in client.query(&self.query, &[]).await? {
|
||||
for i in 0..row.len() {
|
||||
res.push(row.get(i));
|
||||
}
|
||||
}
|
||||
Ok(vec![res])
|
||||
}
|
||||
|
||||
/// Fetches and formats data as FileSource
|
||||
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
|
||||
Ok(self
|
||||
.fetch()
|
||||
.await?
|
||||
.iter()
|
||||
.map(|d| {
|
||||
d.iter()
|
||||
.map(|d| {
|
||||
FileSourceInner::Float(
|
||||
d.n.as_ref()
|
||||
.unwrap()
|
||||
.to_f64()
|
||||
.ok_or("could not convert decimal to f64")
|
||||
.unwrap(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for CallToAccount {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
@@ -862,7 +743,6 @@ impl ToPyObject for DataSource {
|
||||
.unwrap();
|
||||
dict.to_object(py)
|
||||
}
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
DataSource::DB(source) => {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("host", &source.host).unwrap();
|
||||
@@ -887,3 +767,57 @@ impl ToPyObject for FileSourceInner {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
// this is for backwards compatibility with the old format
|
||||
fn test_data_source_serialization_round_trip() {
|
||||
let source = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]);
|
||||
|
||||
let serialized = serde_json::to_string(&source).unwrap();
|
||||
|
||||
const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#;
|
||||
|
||||
assert_eq!(serialized, JSON);
|
||||
|
||||
let expect = serde_json::from_str::<DataSource>(JSON)
|
||||
.map_err(|e| e.to_string())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(expect, source);
|
||||
}
|
||||
|
||||
#[test]
|
||||
// this is for backwards compatibility with the old format
|
||||
fn test_graph_input_serialization_round_trip() {
|
||||
let file = GraphData::new(DataSource::from(vec![vec![
|
||||
0.05326242372393608,
|
||||
0.07497056573629379,
|
||||
0.05235547572374344,
|
||||
]]));
|
||||
|
||||
let serialized = serde_json::to_string(&file).unwrap();
|
||||
|
||||
const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#;
|
||||
|
||||
assert_eq!(serialized, JSON);
|
||||
|
||||
let graph_input3 = serde_json::from_str::<GraphData>(JSON)
|
||||
.map_err(|e| e.to_string())
|
||||
.unwrap();
|
||||
assert_eq!(graph_input3, file);
|
||||
}
|
||||
|
||||
// test for the compatibility with the serialized elements from the mclbn256 library
|
||||
#[test]
|
||||
fn test_python_compat() {
|
||||
let source = Fp::from_raw([18445520602771460712, 838677322461845011, 3079992810, 0]);
|
||||
|
||||
let original_addr = "0x000000000000000000000000b794f5ea0ba39494ce839613fffba74279579268";
|
||||
|
||||
assert_eq!(format!("{:?}", source), original_addr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -619,6 +619,11 @@ impl GraphSettings {
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
pub fn uses_modules(&self) -> bool {
|
||||
!self.module_sizes.max_constraints() > 0
|
||||
}
|
||||
|
||||
/// if any visibility is encrypted or hashed
|
||||
pub fn module_requires_fixed(&self) -> bool {
|
||||
self.run_args.input_visibility.is_hashed()
|
||||
@@ -761,7 +766,7 @@ pub struct TestOnChainData {
|
||||
pub data: std::path::PathBuf,
|
||||
/// rpc endpoint
|
||||
pub rpc: Option<String>,
|
||||
/// data sources for the on chain data
|
||||
///
|
||||
pub data_sources: TestSources,
|
||||
}
|
||||
|
||||
@@ -949,7 +954,7 @@ impl GraphCircuit {
|
||||
DataSource::File(file_data) => {
|
||||
self.load_file_data(file_data, &shapes, scales, input_types)
|
||||
}
|
||||
_ => Err(GraphError::OnChainDataSource),
|
||||
_ => unreachable!("cannot load from on-chain data"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -384,7 +384,8 @@ pub struct ParsedNodes {
|
||||
impl ParsedNodes {
|
||||
/// Returns the number of the computational graph's inputs
|
||||
pub fn num_inputs(&self) -> usize {
|
||||
self.inputs.len()
|
||||
let input_nodes = self.inputs.iter();
|
||||
input_nodes.len()
|
||||
}
|
||||
|
||||
/// Input types
|
||||
@@ -424,7 +425,8 @@ impl ParsedNodes {
|
||||
|
||||
/// Returns the number of the computational graph's outputs
|
||||
pub fn num_outputs(&self) -> usize {
|
||||
self.outputs.len()
|
||||
let output_nodes = self.outputs.iter();
|
||||
output_nodes.len()
|
||||
}
|
||||
|
||||
/// Returns shapes of the computational graph's outputs
|
||||
@@ -632,10 +634,6 @@ impl Model {
|
||||
|
||||
for (i, id) in model.clone().inputs.iter().enumerate() {
|
||||
let input = model.node_mut(id.node);
|
||||
|
||||
if input.outputs.len() == 0 {
|
||||
return Err(GraphError::MissingOutput(id.node));
|
||||
}
|
||||
let mut fact: InferenceFact = input.outputs[0].fact.clone();
|
||||
|
||||
for (i, x) in fact.clone().shape.dims().enumerate() {
|
||||
@@ -1018,10 +1016,6 @@ impl Model {
|
||||
let required_lookups = settings.required_lookups.clone();
|
||||
let required_range_checks = settings.required_range_checks.clone();
|
||||
|
||||
if vars.advices.len() < 3 {
|
||||
return Err(GraphError::InsufficientAdviceColumns(3));
|
||||
}
|
||||
|
||||
let mut base_gate = PolyConfig::configure(
|
||||
meta,
|
||||
vars.advices[0..2].try_into()?,
|
||||
@@ -1041,10 +1035,6 @@ impl Model {
|
||||
}
|
||||
|
||||
if settings.requires_dynamic_lookup() {
|
||||
if vars.advices.len() < 6 {
|
||||
return Err(GraphError::InsufficientAdviceColumns(6));
|
||||
}
|
||||
|
||||
base_gate.configure_dynamic_lookup(
|
||||
meta,
|
||||
vars.advices[0..3].try_into()?,
|
||||
@@ -1053,9 +1043,6 @@ impl Model {
|
||||
}
|
||||
|
||||
if settings.requires_shuffle() {
|
||||
if vars.advices.len() < 6 {
|
||||
return Err(GraphError::InsufficientAdviceColumns(6));
|
||||
}
|
||||
base_gate.configure_shuffles(
|
||||
meta,
|
||||
vars.advices[0..3].try_into()?,
|
||||
@@ -1074,7 +1061,6 @@ impl Model {
|
||||
/// * `vars` - The variables for the circuit.
|
||||
/// * `witnessed_outputs` - The values to compare against.
|
||||
/// * `constants` - The constants for the circuit.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn layout(
|
||||
&self,
|
||||
mut config: ModelConfig,
|
||||
@@ -1474,7 +1460,7 @@ impl Model {
|
||||
.iter()
|
||||
.map(|x| {
|
||||
x.get_felt_evals()
|
||||
.unwrap_or_else(|_| Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
|
||||
.unwrap_or(Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
@@ -284,6 +284,7 @@ impl GraphModules {
|
||||
log::error!("Poseidon config not initialized");
|
||||
return Err(Error::Synthesis);
|
||||
}
|
||||
// If the module is encrypted, then we need to encrypt the inputs
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -1,19 +1,10 @@
|
||||
// Import dependencies for scaling operations
|
||||
use super::scale_to_multiplier;
|
||||
|
||||
// Import ONNX-specific utilities when EZKL feature is enabled
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use super::utilities::node_output_shapes;
|
||||
|
||||
// Import scale management types for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use super::VarScales;
|
||||
|
||||
// Import visibility settings for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use super::Visibility;
|
||||
|
||||
// Import operation types for different circuit components
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::poly::PolyOp;
|
||||
@@ -22,49 +13,28 @@ use crate::circuit::Constant;
|
||||
use crate::circuit::Input;
|
||||
use crate::circuit::Op;
|
||||
use crate::circuit::Unknown;
|
||||
|
||||
// Import graph error types for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::graph::errors::GraphError;
|
||||
|
||||
// Import ONNX operation conversion utilities
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::graph::new_op_from_onnx;
|
||||
|
||||
// Import tensor error handling
|
||||
use crate::tensor::TensorError;
|
||||
|
||||
// Import curve-specific field type
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
|
||||
// Import logging for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use log::trace;
|
||||
|
||||
// Import serialization traits
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
|
||||
// Import data structures for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
// Import formatting traits for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use std::fmt;
|
||||
|
||||
// Import table display formatting for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tabled::Tabled;
|
||||
|
||||
// Import ONNX-specific types and traits for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::{
|
||||
self,
|
||||
prelude::{Node as OnnxNode, SymbolValues, TypedFact, TypedOp},
|
||||
};
|
||||
|
||||
/// Helper function to format vectors for display
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
|
||||
if !v.is_empty() {
|
||||
@@ -74,35 +44,29 @@ fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to format operation kinds for display
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn display_opkind(v: &SupportedOp) -> String {
|
||||
v.as_string()
|
||||
}
|
||||
|
||||
/// A wrapper for an operation that has been rescaled to handle different precision requirements.
|
||||
/// This enables operations to work with inputs that have been scaled to different fixed-point representations.
|
||||
/// A wrapper for an operation that has been rescaled.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Rescaled {
|
||||
/// The underlying operation that needs to be rescaled
|
||||
/// The operation that has to be rescaled.
|
||||
pub inner: Box<SupportedOp>,
|
||||
/// Vector of (index, scale) pairs defining how each input should be scaled
|
||||
/// The scale of the operation's inputs.
|
||||
pub scale: Vec<(usize, u128)>,
|
||||
}
|
||||
|
||||
/// Implementation of the Op trait for Rescaled operations
|
||||
impl Op<Fp> for Rescaled {
|
||||
/// Convert to Any type for runtime type checking
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
/// Get string representation of the operation
|
||||
fn as_string(&self) -> String {
|
||||
format!("RESCALED INPUT ({})", self.inner.as_string())
|
||||
}
|
||||
|
||||
/// Calculate output scale based on input scales
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let in_scales = in_scales
|
||||
.into_iter()
|
||||
@@ -113,7 +77,6 @@ impl Op<Fp> for Rescaled {
|
||||
Op::<Fp>::out_scale(&*self.inner, in_scales)
|
||||
}
|
||||
|
||||
/// Layout the operation in the circuit
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
@@ -130,40 +93,28 @@ impl Op<Fp> for Rescaled {
|
||||
self.inner.layout(config, region, res)
|
||||
}
|
||||
|
||||
/// Create a cloned boxed copy of this operation
|
||||
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
|
||||
Box::new(self.clone())
|
||||
Box::new(self.clone()) // Forward to the derive(Clone) impl
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper for operations that require scale rebasing
|
||||
/// This handles cases where operation scales need to be adjusted to a target scale
|
||||
/// while preserving the numerical relationships
|
||||
/// A wrapper for an operation that has been rescaled.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct RebaseScale {
|
||||
/// The operation that needs to be rescaled
|
||||
/// The operation that has to be rescaled.
|
||||
pub inner: Box<SupportedOp>,
|
||||
/// Operation used for rebasing, typically division
|
||||
/// rebase op
|
||||
pub rebase_op: HybridOp,
|
||||
/// Scale that we're rebasing to
|
||||
/// scale being rebased to
|
||||
pub target_scale: i32,
|
||||
/// Original scale of operation's inputs before rebasing
|
||||
/// The original scale of the operation's inputs.
|
||||
pub original_scale: i32,
|
||||
/// Scaling multiplier used in rebasing
|
||||
/// multiplier
|
||||
pub multiplier: f64,
|
||||
}
|
||||
|
||||
impl RebaseScale {
|
||||
/// Creates a rebased version of an operation if needed
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `inner` - Operation to potentially rebase
|
||||
/// * `global_scale` - Base scale for the system
|
||||
/// * `op_out_scale` - Current output scale of the operation
|
||||
/// * `scale_rebase_multiplier` - Factor determining when rebasing should occur
|
||||
///
|
||||
/// # Returns
|
||||
/// Original or rebased operation depending on scale relationships
|
||||
pub fn rebase(
|
||||
inner: SupportedOp,
|
||||
global_scale: crate::Scale,
|
||||
@@ -204,15 +155,7 @@ impl RebaseScale {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a rebased operation with increased scale
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `inner` - Operation to potentially rebase
|
||||
/// * `target_scale` - Scale to rebase to
|
||||
/// * `op_out_scale` - Current output scale of the operation
|
||||
///
|
||||
/// # Returns
|
||||
/// Original or rebased operation with increased scale
|
||||
pub fn rebase_up(
|
||||
inner: SupportedOp,
|
||||
target_scale: crate::Scale,
|
||||
@@ -249,12 +192,10 @@ impl RebaseScale {
|
||||
}
|
||||
|
||||
impl Op<Fp> for RebaseScale {
|
||||
/// Convert to Any type for runtime type checking
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
/// Get string representation of the operation
|
||||
fn as_string(&self) -> String {
|
||||
format!(
|
||||
"REBASED (div={:?}, rebasing_op={}) ({})",
|
||||
@@ -264,12 +205,10 @@ impl Op<Fp> for RebaseScale {
|
||||
)
|
||||
}
|
||||
|
||||
/// Calculate output scale based on input scales
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
Ok(self.target_scale)
|
||||
}
|
||||
|
||||
/// Layout the operation in the circuit
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
@@ -283,40 +222,34 @@ impl Op<Fp> for RebaseScale {
|
||||
self.rebase_op.layout(config, region, &[original_res])
|
||||
}
|
||||
|
||||
/// Create a cloned boxed copy of this operation
|
||||
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
|
||||
Box::new(self.clone())
|
||||
Box::new(self.clone()) // Forward to the derive(Clone) impl
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents all supported operation types in the circuit
|
||||
/// Each variant encapsulates a different type of operation with specific behavior
|
||||
/// A single operation in a [crate::graph::Model].
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum SupportedOp {
|
||||
/// Linear operations (polynomial-based)
|
||||
/// A linear operation.
|
||||
Linear(PolyOp),
|
||||
/// Nonlinear operations requiring lookup tables
|
||||
/// A nonlinear operation.
|
||||
Nonlinear(LookupOp),
|
||||
/// Mixed operations combining different approaches
|
||||
/// A hybrid operation.
|
||||
Hybrid(HybridOp),
|
||||
/// Input values to the circuit
|
||||
///
|
||||
Input(Input),
|
||||
/// Constant values in the circuit
|
||||
///
|
||||
Constant(Constant<Fp>),
|
||||
/// Placeholder for unsupported operations
|
||||
///
|
||||
Unknown(Unknown),
|
||||
/// Operations requiring rescaling of inputs
|
||||
///
|
||||
Rescaled(Rescaled),
|
||||
/// Operations requiring scale rebasing
|
||||
///
|
||||
RebaseScale(RebaseScale),
|
||||
}
|
||||
|
||||
impl SupportedOp {
|
||||
/// Checks if the operation is a lookup operation
|
||||
///
|
||||
/// # Returns
|
||||
/// * `true` if operation requires lookup table
|
||||
/// * `false` otherwise
|
||||
pub fn is_lookup(&self) -> bool {
|
||||
match self {
|
||||
SupportedOp::Nonlinear(_) => true,
|
||||
@@ -324,12 +257,7 @@ impl SupportedOp {
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns input operation if this is an input
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Some(Input)` if this is an input operation
|
||||
/// * `None` otherwise
|
||||
pub fn get_input(&self) -> Option<Input> {
|
||||
match self {
|
||||
SupportedOp::Input(op) => Some(op.clone()),
|
||||
@@ -337,11 +265,7 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns reference to rebased operation if this is a rebased operation
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Some(&RebaseScale)` if this is a rebased operation
|
||||
/// * `None` otherwise
|
||||
pub fn get_rebased(&self) -> Option<&RebaseScale> {
|
||||
match self {
|
||||
SupportedOp::RebaseScale(op) => Some(op),
|
||||
@@ -349,11 +273,7 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns reference to lookup operation if this is a lookup operation
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Some(&LookupOp)` if this is a lookup operation
|
||||
/// * `None` otherwise
|
||||
pub fn get_lookup(&self) -> Option<&LookupOp> {
|
||||
match self {
|
||||
SupportedOp::Nonlinear(op) => Some(op),
|
||||
@@ -361,11 +281,7 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns reference to constant if this is a constant
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Some(&Constant)` if this is a constant
|
||||
/// * `None` otherwise
|
||||
pub fn get_constant(&self) -> Option<&Constant<Fp>> {
|
||||
match self {
|
||||
SupportedOp::Constant(op) => Some(op),
|
||||
@@ -373,11 +289,7 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns mutable reference to constant if this is a constant
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Some(&mut Constant)` if this is a constant
|
||||
/// * `None` otherwise
|
||||
pub fn get_mutable_constant(&mut self) -> Option<&mut Constant<Fp>> {
|
||||
match self {
|
||||
SupportedOp::Constant(op) => Some(op),
|
||||
@@ -385,19 +297,18 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a homogeneously rescaled version of this operation if needed
|
||||
/// Only available with EZKL feature enabled
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn homogenous_rescale(
|
||||
&self,
|
||||
in_scales: Vec<crate::Scale>,
|
||||
) -> Result<Box<dyn Op<Fp>>, GraphError> {
|
||||
let inputs_to_scale = self.requires_homogenous_input_scales();
|
||||
// creates a rescaled op if the inputs are not homogenous
|
||||
let op = self.clone_dyn();
|
||||
super::homogenize_input_scales(op, in_scales, inputs_to_scale)
|
||||
}
|
||||
|
||||
/// Returns reference to underlying Op implementation
|
||||
/// Since each associated value of `SupportedOp` implements `Op`, let's define a helper method to retrieve it.
|
||||
fn as_op(&self) -> &dyn Op<Fp> {
|
||||
match self {
|
||||
SupportedOp::Linear(op) => op,
|
||||
@@ -411,10 +322,9 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if this is an identity operation
|
||||
///
|
||||
/// check if is the identity operation
|
||||
/// # Returns
|
||||
/// * `true` if this operation passes input through unchanged
|
||||
/// * `true` if the operation is the identity operation
|
||||
/// * `false` otherwise
|
||||
pub fn is_identity(&self) -> bool {
|
||||
match self {
|
||||
@@ -451,11 +361,9 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
|
||||
if let Some(op) = value.as_any().downcast_ref::<Unknown>() {
|
||||
return SupportedOp::Unknown(op.clone());
|
||||
};
|
||||
|
||||
if let Some(op) = value.as_any().downcast_ref::<Rescaled>() {
|
||||
return SupportedOp::Rescaled(op.clone());
|
||||
};
|
||||
|
||||
if let Some(op) = value.as_any().downcast_ref::<RebaseScale>() {
|
||||
return SupportedOp::RebaseScale(op.clone());
|
||||
};
|
||||
@@ -467,7 +375,6 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
|
||||
}
|
||||
|
||||
impl Op<Fp> for SupportedOp {
|
||||
/// Layout this operation in the circuit
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
@@ -477,61 +384,54 @@ impl Op<Fp> for SupportedOp {
|
||||
self.as_op().layout(config, region, values)
|
||||
}
|
||||
|
||||
/// Check if this is an input operation
|
||||
fn is_input(&self) -> bool {
|
||||
self.as_op().is_input()
|
||||
}
|
||||
|
||||
/// Check if this is a constant operation
|
||||
fn is_constant(&self) -> bool {
|
||||
self.as_op().is_constant()
|
||||
}
|
||||
|
||||
/// Get which inputs require homogeneous scales
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
self.as_op().requires_homogenous_input_scales()
|
||||
}
|
||||
|
||||
/// Create a clone of this operation
|
||||
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
|
||||
self.as_op().clone_dyn()
|
||||
}
|
||||
|
||||
/// Get string representation
|
||||
fn as_string(&self) -> String {
|
||||
self.as_op().as_string()
|
||||
}
|
||||
|
||||
/// Convert to Any type
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
/// Calculate output scale from input scales
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
self.as_op().out_scale(in_scales)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a connection to another node's output
|
||||
/// First element is node index, second is output slot index
|
||||
/// A node's input is a tensor from another node's output.
|
||||
pub type Outlet = (usize, usize);
|
||||
|
||||
/// Represents a single computational node in the circuit graph
|
||||
/// Contains all information needed to execute and connect operations
|
||||
/// A single operation in a [crate::graph::Model].
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Node {
|
||||
/// The operation this node performs
|
||||
/// [Op] i.e what operation this node represents.
|
||||
pub opkind: SupportedOp,
|
||||
/// Fixed point scale factor for this node's output
|
||||
/// The denominator in the fixed point representation for the node's output. Tensors of differing scales should not be combined.
|
||||
pub out_scale: i32,
|
||||
/// Connections to other nodes' outputs that serve as inputs
|
||||
// Usually there is a simple in and out shape of the node as an operator. For example, an Affine node has three input_shapes (one for the input, weight, and bias),
|
||||
// but in_dim is [in], out_dim is [out]
|
||||
/// The indices of the node's inputs.
|
||||
pub inputs: Vec<Outlet>,
|
||||
/// Shape of this node's output tensor
|
||||
/// Dimensions of output.
|
||||
pub out_dims: Vec<usize>,
|
||||
/// Unique identifier for this node
|
||||
/// The node's unique identifier.
|
||||
pub idx: usize,
|
||||
/// Number of times this node's output is used
|
||||
/// The node's num of uses
|
||||
pub num_uses: usize,
|
||||
}
|
||||
|
||||
@@ -569,19 +469,12 @@ impl PartialEq for Node {
|
||||
}
|
||||
|
||||
impl Node {
|
||||
/// Creates a new Node from an ONNX node
|
||||
/// Only available when EZKL feature is enabled
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `node` - Source ONNX node
|
||||
/// * `other_nodes` - Map of existing nodes in the graph
|
||||
/// * `scales` - Scale factors for variables
|
||||
/// * `idx` - Unique identifier for this node
|
||||
/// * `symbol_values` - ONNX symbol values
|
||||
/// * `run_args` - Runtime configuration arguments
|
||||
///
|
||||
/// # Returns
|
||||
/// New Node instance or error if creation fails
|
||||
/// Converts a tract [OnnxNode] into an ezkl [Node].
|
||||
/// # Arguments:
|
||||
/// * `node` - [OnnxNode]
|
||||
/// * `other_nodes` - [BTreeMap] of other previously initialized [Node]s in the computational graph.
|
||||
/// * `public_params` - flag if parameters of model are public
|
||||
/// * `idx` - The node's unique identifier.
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
@@ -719,14 +612,16 @@ impl Node {
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if this node performs softmax operation
|
||||
/// check if it is a softmax node
|
||||
pub fn is_softmax(&self) -> bool {
|
||||
matches!(self.opkind, SupportedOp::Hybrid(HybridOp::Softmax { .. }))
|
||||
if let SupportedOp::Hybrid(HybridOp::Softmax { .. }) = self.opkind {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to rescale constants that are only used once
|
||||
/// Only available when EZKL feature is enabled
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn rescale_const_with_single_use(
|
||||
constant: &mut Constant<Fp>,
|
||||
|
||||
@@ -45,7 +45,6 @@ use tract_onnx::tract_hir::{
|
||||
};
|
||||
|
||||
/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
|
||||
/// NAN gets mapped to 0. INFINITY and NEG_INFINITY error out.
|
||||
/// Arguments
|
||||
///
|
||||
/// * `elem` - the element to quantize.
|
||||
@@ -59,7 +58,7 @@ pub fn quantize_float(
|
||||
let mult = scale_to_multiplier(scale);
|
||||
let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
|
||||
|
||||
if *elem > max_value || *elem < -max_value {
|
||||
if *elem > max_value {
|
||||
return Err(TensorError::SigBitTruncationError);
|
||||
}
|
||||
|
||||
@@ -228,7 +227,10 @@ pub fn extract_tensor_value(
|
||||
.iter()
|
||||
.map(|x| match x.to_i64() {
|
||||
Ok(v) => Ok(v as f32),
|
||||
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
|
||||
Err(_) => match x.to_i64() {
|
||||
Ok(v) => Ok(v as f32),
|
||||
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -1589,10 +1591,12 @@ pub fn homogenize_input_scales(
|
||||
input_scales: Vec<crate::Scale>,
|
||||
inputs_to_scale: Vec<usize>,
|
||||
) -> Result<Box<dyn Op<Fp>>, GraphError> {
|
||||
let relevant_input_scales = inputs_to_scale
|
||||
.iter()
|
||||
.filter(|idx| input_scales.len() > **idx)
|
||||
.map(|&idx| input_scales[idx])
|
||||
let relevant_input_scales = input_scales
|
||||
.clone()
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter(|(idx, _)| inputs_to_scale.contains(idx))
|
||||
.map(|(_, scale)| scale)
|
||||
.collect_vec();
|
||||
|
||||
if inputs_to_scale.is_empty() {
|
||||
@@ -1638,25 +1642,6 @@ pub mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
// quantization tests
|
||||
#[test]
|
||||
fn test_quantize_tensor() {
|
||||
let tensor: Tensor<f32> = (0..10).map(|x| x as f32).into();
|
||||
let reference: Tensor<Fp> = (0..10).map(|x| x.into()).into();
|
||||
let scale = 0;
|
||||
let visibility = &Visibility::Public;
|
||||
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility).unwrap();
|
||||
assert_eq!(quantized.len(), 10);
|
||||
assert_eq!(quantized, reference);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantize_edge_cases() {
|
||||
assert_eq!(quantize_float(&f64::NAN, 0.0, 0).unwrap(), 0);
|
||||
assert!(quantize_float(&f64::INFINITY, 0.0, 0).is_err());
|
||||
assert!(quantize_float(&f64::NEG_INFINITY, 0.0, 0).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flatten_valtensors() {
|
||||
let tensor1: Tensor<Fp> = (0..10).map(|x| x.into()).into();
|
||||
|
||||
@@ -11,34 +11,35 @@ use log::debug;
|
||||
use pyo3::{
|
||||
exceptions::PyValueError, FromPyObject, IntoPy, PyObject, PyResult, Python, ToPyObject,
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
use self::errors::GraphError;
|
||||
|
||||
use super::*;
|
||||
|
||||
/// Defines the visibility level of values within the zero-knowledge circuit
|
||||
/// Controls how values are handled during proof generation and verification
|
||||
/// Label enum to track whether model input, model parameters, and model output are public, private, or hashed
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Default)]
|
||||
pub enum Visibility {
|
||||
/// Value is private to the prover and not included in proof
|
||||
/// Mark an item as private to the prover (not in the proof submitted for verification)
|
||||
#[default]
|
||||
Private,
|
||||
/// Value is public and included in proof for verification
|
||||
/// Mark an item as public (sent in the proof submitted for verification)
|
||||
Public,
|
||||
/// Value is hashed and the hash is included in proof
|
||||
/// Mark an item as publicly committed to (hash sent in the proof submitted for verification)
|
||||
Hashed {
|
||||
/// Controls how the hash is handled in proof
|
||||
/// true - hash is included directly in proof (public)
|
||||
/// false - hash is used as advice and passed to computational graph
|
||||
/// Whether the hash is used as an instance (sent in the proof submitted for verification)
|
||||
/// if false the hash is used as an advice (not in the proof submitted for verification) and is then sent to the computational graph
|
||||
/// if true the hash is used as an instance (sent in the proof submitted for verification) the *inputs* to the hashing function are then sent to the computational graph
|
||||
hash_is_public: bool,
|
||||
/// Specifies which outputs this hash affects
|
||||
///
|
||||
outlets: Vec<usize>,
|
||||
},
|
||||
/// Value is committed using KZG commitment scheme
|
||||
/// Mark an item as publicly committed to (KZG commitment sent in the proof submitted for verification)
|
||||
KZGCommit,
|
||||
/// Value is assigned as a constant in the circuit
|
||||
/// assigned as a constant in the circuit
|
||||
Fixed,
|
||||
}
|
||||
|
||||
@@ -65,17 +66,15 @@ impl Display for Visibility {
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl ToFlags for Visibility {
|
||||
/// Converts visibility to command line flags
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{}", self)]
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a str> for Visibility {
|
||||
/// Converts string representation to Visibility
|
||||
fn from(s: &'a str) -> Self {
|
||||
if s.contains("hashed/private") {
|
||||
// Split on last occurrence of '/'
|
||||
// split on last occurrence of '/'
|
||||
let (_, outlets) = s.split_at(s.rfind('/').unwrap());
|
||||
let outlets = outlets
|
||||
.trim_start_matches('/')
|
||||
@@ -107,8 +106,8 @@ impl<'a> From<&'a str> for Visibility {
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts Visibility into a PyObject (Required for Visibility to be compatible with Python)
|
||||
impl IntoPy<PyObject> for Visibility {
|
||||
/// Converts Visibility to Python object
|
||||
fn into_py(self, py: Python) -> PyObject {
|
||||
match self {
|
||||
Visibility::Private => "private".to_object(py),
|
||||
@@ -135,13 +134,14 @@ impl IntoPy<PyObject> for Visibility {
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains Visibility from PyObject (Required for Visibility to be compatible with Python)
|
||||
impl<'source> FromPyObject<'source> for Visibility {
|
||||
/// Extracts Visibility from Python object
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
|
||||
let strval = String::extract_bound(ob)?;
|
||||
let strval = strval.as_str();
|
||||
|
||||
if strval.contains("hashed/private") {
|
||||
// split on last occurence of '/'
|
||||
let (_, outlets) = strval.split_at(strval.rfind('/').unwrap());
|
||||
let outlets = outlets
|
||||
.trim_start_matches('/')
|
||||
@@ -174,32 +174,29 @@ impl<'source> FromPyObject<'source> for Visibility {
|
||||
}
|
||||
|
||||
impl Visibility {
|
||||
/// Returns true if visibility is Fixed
|
||||
#[allow(missing_docs)]
|
||||
pub fn is_fixed(&self) -> bool {
|
||||
matches!(&self, Visibility::Fixed)
|
||||
}
|
||||
|
||||
/// Returns true if visibility is Private or hashed private
|
||||
#[allow(missing_docs)]
|
||||
pub fn is_private(&self) -> bool {
|
||||
matches!(&self, Visibility::Private) || self.is_hashed_private()
|
||||
}
|
||||
|
||||
/// Returns true if visibility is Public
|
||||
#[allow(missing_docs)]
|
||||
pub fn is_public(&self) -> bool {
|
||||
matches!(&self, Visibility::Public)
|
||||
}
|
||||
|
||||
/// Returns true if visibility involves hashing
|
||||
#[allow(missing_docs)]
|
||||
pub fn is_hashed(&self) -> bool {
|
||||
matches!(&self, Visibility::Hashed { .. })
|
||||
}
|
||||
|
||||
/// Returns true if visibility uses KZG commitment
|
||||
#[allow(missing_docs)]
|
||||
pub fn is_polycommit(&self) -> bool {
|
||||
matches!(&self, Visibility::KZGCommit)
|
||||
}
|
||||
|
||||
/// Returns true if visibility is hashed with public hash
|
||||
#[allow(missing_docs)]
|
||||
pub fn is_hashed_public(&self) -> bool {
|
||||
if let Visibility::Hashed {
|
||||
hash_is_public: true,
|
||||
@@ -210,8 +207,7 @@ impl Visibility {
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Returns true if visibility is hashed with private hash
|
||||
#[allow(missing_docs)]
|
||||
pub fn is_hashed_private(&self) -> bool {
|
||||
if let Visibility::Hashed {
|
||||
hash_is_public: false,
|
||||
@@ -223,12 +219,11 @@ impl Visibility {
|
||||
false
|
||||
}
|
||||
|
||||
/// Returns true if visibility requires additional processing
|
||||
#[allow(missing_docs)]
|
||||
pub fn requires_processing(&self) -> bool {
|
||||
matches!(&self, Visibility::Hashed { .. }) | matches!(&self, Visibility::KZGCommit)
|
||||
}
|
||||
|
||||
/// Returns vector of output indices that this visibility setting affects
|
||||
#[allow(missing_docs)]
|
||||
pub fn overwrites_inputs(&self) -> Vec<usize> {
|
||||
if let Visibility::Hashed { outlets, .. } = self {
|
||||
return outlets.clone();
|
||||
@@ -237,14 +232,14 @@ impl Visibility {
|
||||
}
|
||||
}
|
||||
|
||||
/// Manages scaling factors for different parts of the model
|
||||
/// Represents the scale of the model input, model parameters.
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)]
|
||||
pub struct VarScales {
|
||||
/// Scale factor for input values
|
||||
///
|
||||
pub input: crate::Scale,
|
||||
/// Scale factor for parameter values
|
||||
///
|
||||
pub params: crate::Scale,
|
||||
/// Multiplier for scale rebasing
|
||||
///
|
||||
pub rebase_multiplier: u32,
|
||||
}
|
||||
|
||||
@@ -255,17 +250,17 @@ impl std::fmt::Display for VarScales {
|
||||
}
|
||||
|
||||
impl VarScales {
|
||||
/// Returns maximum scale value
|
||||
///
|
||||
pub fn get_max(&self) -> crate::Scale {
|
||||
std::cmp::max(self.input, self.params)
|
||||
}
|
||||
|
||||
/// Returns minimum scale value
|
||||
///
|
||||
pub fn get_min(&self) -> crate::Scale {
|
||||
std::cmp::min(self.input, self.params)
|
||||
}
|
||||
|
||||
/// Creates VarScales from runtime arguments
|
||||
/// Place in [VarScales] struct.
|
||||
pub fn from_args(args: &RunArgs) -> Self {
|
||||
Self {
|
||||
input: args.input_scale,
|
||||
@@ -275,17 +270,16 @@ impl VarScales {
|
||||
}
|
||||
}
|
||||
|
||||
/// Controls visibility settings for different parts of the model
|
||||
/// Represents whether the model input, model parameters, and model output are Public or Private to the prover.
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, PartialOrd)]
|
||||
pub struct VarVisibility {
|
||||
/// Visibility of model inputs
|
||||
/// Input to the model or computational graph
|
||||
pub input: Visibility,
|
||||
/// Visibility of model parameters (weights, biases)
|
||||
/// Parameters, such as weights and biases, in the model
|
||||
pub params: Visibility,
|
||||
/// Visibility of model outputs
|
||||
/// Output of the model or computational graph
|
||||
pub output: Visibility,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for VarVisibility {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(
|
||||
@@ -307,7 +301,8 @@ impl Default for VarVisibility {
|
||||
}
|
||||
|
||||
impl VarVisibility {
|
||||
/// Creates visibility settings from runtime arguments
|
||||
/// Read from cli args whether the model input, model parameters, and model output are Public or Private to the prover.
|
||||
/// Place in [VarVisibility] struct.
|
||||
pub fn from_args(args: &RunArgs) -> Result<Self, GraphError> {
|
||||
let input_vis = &args.input_visibility;
|
||||
let params_vis = &args.param_visibility;
|
||||
@@ -318,17 +313,17 @@ impl VarVisibility {
|
||||
}
|
||||
|
||||
if !output_vis.is_public()
|
||||
&& !params_vis.is_public()
|
||||
&& !input_vis.is_public()
|
||||
&& !output_vis.is_fixed()
|
||||
&& !params_vis.is_fixed()
|
||||
&& !input_vis.is_fixed()
|
||||
&& !output_vis.is_hashed()
|
||||
&& !params_vis.is_hashed()
|
||||
&& !input_vis.is_hashed()
|
||||
&& !output_vis.is_polycommit()
|
||||
&& !params_vis.is_polycommit()
|
||||
&& !input_vis.is_polycommit()
|
||||
& !params_vis.is_public()
|
||||
& !input_vis.is_public()
|
||||
& !output_vis.is_fixed()
|
||||
& !params_vis.is_fixed()
|
||||
& !input_vis.is_fixed()
|
||||
& !output_vis.is_hashed()
|
||||
& !params_vis.is_hashed()
|
||||
& !input_vis.is_hashed()
|
||||
& !output_vis.is_polycommit()
|
||||
& !params_vis.is_polycommit()
|
||||
& !input_vis.is_polycommit()
|
||||
{
|
||||
return Err(GraphError::Visibility);
|
||||
}
|
||||
@@ -340,17 +335,17 @@ impl VarVisibility {
|
||||
}
|
||||
}
|
||||
|
||||
/// Container for circuit columns used by a model
|
||||
/// A wrapper for holding all columns that will be assigned to by a model.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ModelVars<F: PrimeField + TensorType + PartialOrd> {
|
||||
/// Advice columns for circuit assignments
|
||||
#[allow(missing_docs)]
|
||||
pub advices: Vec<VarTensor>,
|
||||
/// Optional instance column for public inputs
|
||||
#[allow(missing_docs)]
|
||||
pub instance: Option<ValTensor<F>>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
/// Gets reference to instance column if it exists
|
||||
/// Get instance col
|
||||
pub fn get_instance_col(&self) -> Option<&Column<Instance>> {
|
||||
if let Some(instance) = &self.instance {
|
||||
match instance {
|
||||
@@ -362,14 +357,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets initial offset for instance values
|
||||
/// Set the initial instance offset
|
||||
pub fn set_initial_instance_offset(&mut self, offset: usize) {
|
||||
if let Some(instance) = &mut self.instance {
|
||||
instance.set_initial_instance_offset(offset);
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets total length of instance data
|
||||
/// Get the total instance len
|
||||
pub fn get_instance_len(&self) -> usize {
|
||||
if let Some(instance) = &self.instance {
|
||||
instance.get_total_instance_len()
|
||||
@@ -378,21 +373,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Increments instance index
|
||||
/// Increment the instance offset
|
||||
pub fn increment_instance_idx(&mut self) {
|
||||
if let Some(instance) = &mut self.instance {
|
||||
instance.increment_idx();
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets instance index to specific value
|
||||
/// Reset the instance offset
|
||||
pub fn set_instance_idx(&mut self, val: usize) {
|
||||
if let Some(instance) = &mut self.instance {
|
||||
instance.set_idx(val);
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets current instance index
|
||||
/// Get the instance offset
|
||||
pub fn get_instance_idx(&self) -> usize {
|
||||
if let Some(instance) = &self.instance {
|
||||
instance.get_idx()
|
||||
@@ -401,7 +396,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Initializes instance column with specified dimensions and scale
|
||||
///
|
||||
pub fn instantiate_instance(
|
||||
&mut self,
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
@@ -422,7 +417,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
};
|
||||
}
|
||||
|
||||
/// Creates new ModelVars with allocated columns based on settings
|
||||
/// Allocate all columns that will be assigned to by a model.
|
||||
pub fn new(cs: &mut ConstraintSystem<F>, params: &GraphSettings) -> Self {
|
||||
debug!("number of blinding factors: {}", cs.blinding_factors());
|
||||
|
||||
|
||||
303
src/lib.rs
303
src/lib.rs
@@ -28,9 +28,6 @@
|
||||
|
||||
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
|
||||
//!
|
||||
use log::warn;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use mimalloc as _;
|
||||
|
||||
/// Error type
|
||||
// #[cfg_attr(not(feature = "ezkl"), derive(uniffi::Error))]
|
||||
@@ -102,7 +99,7 @@ use circuit::{table::Range, CheckMode, Tolerance};
|
||||
use clap::Args;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use fieldutils::IntegerRep;
|
||||
use graph::{Visibility, MAX_PUBLIC_SRS};
|
||||
use graph::Visibility;
|
||||
use halo2_proofs::poly::{
|
||||
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
|
||||
};
|
||||
@@ -168,6 +165,7 @@ pub mod srs_sha;
|
||||
pub mod tensor;
|
||||
#[cfg(feature = "ios-bindings")]
|
||||
uniffi::setup_scaffolding!();
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
@@ -182,9 +180,11 @@ lazy_static! {
|
||||
.unwrap_or("8000".to_string())
|
||||
.parse()
|
||||
.unwrap();
|
||||
|
||||
/// The serialization format for the keys
|
||||
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
|
||||
.unwrap_or("raw-bytes".to_string());
|
||||
|
||||
}
|
||||
|
||||
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
|
||||
@@ -266,96 +266,76 @@ impl From<String> for Commitments {
|
||||
}
|
||||
|
||||
/// Parameters specific to a proving run
|
||||
///
|
||||
/// RunArgs contains all configuration parameters needed to control the proving process,
|
||||
/// including scaling factors, visibility settings, and circuit parameters.
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
derive(Args, ToFlags)
|
||||
)]
|
||||
pub struct RunArgs {
|
||||
/// Error tolerance for model outputs
|
||||
/// Only applicable when outputs are public
|
||||
/// The tolerance for error on model outputs
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other))]
|
||||
pub tolerance: Tolerance,
|
||||
/// Fixed point scaling factor for quantizing inputs
|
||||
/// Higher values provide more precision but increase circuit complexity
|
||||
/// The denominator in the fixed point representation used when quantizing inputs
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other))]
|
||||
pub input_scale: Scale,
|
||||
/// Fixed point scaling factor for quantizing parameters
|
||||
/// Higher values provide more precision but increase circuit complexity
|
||||
/// The denominator in the fixed point representation used when quantizing parameters
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "7", value_hint = clap::ValueHint::Other))]
|
||||
pub param_scale: Scale,
|
||||
/// Scale rebase threshold multiplier
|
||||
/// When scale exceeds input_scale * multiplier, it is rebased to input_scale
|
||||
/// Advanced parameter that should be used with caution
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))]
|
||||
/// if the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this a more advanced parameter, use with caution)
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))]
|
||||
pub scale_rebase_multiplier: u32,
|
||||
/// Range for lookup table input column values
|
||||
/// Specified as (min, max) pair
|
||||
/// The min and max elements in the lookup table input column
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'B', long, value_parser = parse_key_val::<IntegerRep, IntegerRep>, default_value = "-32768->32768"))]
|
||||
pub lookup_range: Range,
|
||||
/// Log2 of the number of rows in the circuit
|
||||
/// Controls circuit size and proving time
|
||||
/// The log_2 number of rows
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other))]
|
||||
pub logrows: u32,
|
||||
/// Number of inner columns per block
|
||||
/// Affects circuit layout and efficiency
|
||||
/// The log_2 number of rows
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'N', long, default_value = "2", value_hint = clap::ValueHint::Other))]
|
||||
pub num_inner_cols: usize,
|
||||
/// Graph variables for parameterizing the computation
|
||||
/// Format: "name->value", e.g. "batch_size->1"
|
||||
/// Hand-written parser for graph variables, eg. batch_size=1
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other))]
|
||||
pub variables: Vec<(String, usize)>,
|
||||
/// Visibility setting for input values
|
||||
/// Controls whether inputs are public or private in the circuit
|
||||
/// Flags whether inputs are public, private, fixed, hashed, polycommit
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
|
||||
pub input_visibility: Visibility,
|
||||
/// Visibility setting for output values
|
||||
/// Controls whether outputs are public or private in the circuit
|
||||
/// Flags whether outputs are public, private, fixed, hashed, polycommit
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "public", value_hint = clap::ValueHint::Other))]
|
||||
pub output_visibility: Visibility,
|
||||
/// Visibility setting for parameters
|
||||
/// Controls how parameters are handled in the circuit
|
||||
/// Flags whether params are fixed, private, hashed, polycommit
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
|
||||
pub param_visibility: Visibility,
|
||||
/// Whether to rebase constants with zero fractional part to scale 0
|
||||
/// Can improve efficiency for integer constants
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
/// Should constants with 0.0 fraction be rebased to scale 0
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
pub rebase_frac_zero_constants: bool,
|
||||
/// Circuit checking mode
|
||||
/// Controls level of constraint verification
|
||||
/// check mode (safe, unsafe, etc)
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "unsafe", value_hint = clap::ValueHint::Other))]
|
||||
pub check_mode: CheckMode,
|
||||
/// Commitment scheme for circuit proving
|
||||
/// Affects proof size and verification time
|
||||
/// commitment scheme
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other))]
|
||||
pub commitment: Option<Commitments>,
|
||||
/// Base for number decomposition
|
||||
/// Must be a power of 2
|
||||
/// the base used for decompositions
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "16384", value_hint = clap::ValueHint::Other))]
|
||||
pub decomp_base: usize,
|
||||
/// Number of decomposition legs
|
||||
/// Controls decomposition granularity
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "2", value_hint = clap::ValueHint::Other))]
|
||||
/// the number of legs used for decompositions
|
||||
pub decomp_legs: usize,
|
||||
/// Whether to use bounded lookup for logarithm computation
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
/// use unbounded lookup for the log
|
||||
pub bounded_log_lookup: bool,
|
||||
}
|
||||
|
||||
impl Default for RunArgs {
|
||||
/// Creates a new RunArgs instance with default values
|
||||
///
|
||||
/// Default configuration is optimized for common use cases
|
||||
/// while maintaining reasonable proving time and circuit size
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
bounded_log_lookup: false,
|
||||
@@ -380,132 +360,49 @@ impl Default for RunArgs {
|
||||
}
|
||||
|
||||
impl RunArgs {
|
||||
/// Validates the RunArgs configuration
|
||||
///
|
||||
/// Performs comprehensive validation of all parameters to ensure they are within
|
||||
/// acceptable ranges and follow required constraints. Returns accumulated errors
|
||||
/// if any validations fail.
|
||||
///
|
||||
/// # Returns
|
||||
/// - Ok(()) if all validations pass
|
||||
/// - Err(String) with detailed error message if any validation fails
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// Visibility validations
|
||||
if self.param_visibility == Visibility::Public {
|
||||
errors.push(
|
||||
"Parameters cannot be public instances. Use 'fixed' or 'kzgcommit' instead"
|
||||
.to_string(),
|
||||
return Err(
|
||||
"params cannot be public instances, you are probably trying to use `fixed` or `kzgcommit`"
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
|
||||
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
|
||||
errors.push("Non-zero tolerance requires output_visibility to be public".to_string());
|
||||
}
|
||||
|
||||
// Scale validations
|
||||
if self.scale_rebase_multiplier < 1 {
|
||||
errors.push("scale_rebase_multiplier must be >= 1".to_string());
|
||||
return Err("scale_rebase_multiplier must be >= 1".into());
|
||||
}
|
||||
|
||||
// if any of the scales are too small
|
||||
if self.input_scale < 8 || self.param_scale < 8 {
|
||||
warn!("low scale values (<8) may impact precision");
|
||||
}
|
||||
|
||||
// Lookup range validations
|
||||
if self.lookup_range.0 > self.lookup_range.1 {
|
||||
errors.push(format!(
|
||||
"Invalid lookup range: min ({}) is greater than max ({})",
|
||||
self.lookup_range.0, self.lookup_range.1
|
||||
));
|
||||
return Err("lookup_range min is greater than max".into());
|
||||
}
|
||||
|
||||
// Size validations
|
||||
if self.logrows < 1 {
|
||||
errors.push("logrows must be >= 1".to_string());
|
||||
return Err("logrows must be >= 1".into());
|
||||
}
|
||||
|
||||
if self.num_inner_cols < 1 {
|
||||
errors.push("num_inner_cols must be >= 1".to_string());
|
||||
return Err("num_inner_cols must be >= 1".into());
|
||||
}
|
||||
|
||||
let batch_size = self.variables.iter().find(|(name, _)| name == "batch_size");
|
||||
if let Some(batch_size) = batch_size {
|
||||
if batch_size.1 == 0 {
|
||||
errors.push("'batch_size' cannot be 0".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Decomposition validations
|
||||
if self.decomp_base == 0 {
|
||||
errors.push("decomp_base cannot be 0".to_string());
|
||||
}
|
||||
|
||||
if self.decomp_legs == 0 {
|
||||
errors.push("decomp_legs cannot be 0".to_string());
|
||||
}
|
||||
|
||||
// Performance validations
|
||||
if self.logrows > MAX_PUBLIC_SRS {
|
||||
warn!("logrows exceeds maximum public SRS size");
|
||||
}
|
||||
|
||||
// Validate tolerance is non-negative
|
||||
if self.tolerance.val < 0.0 {
|
||||
errors.push("tolerance cannot be negative".to_string());
|
||||
}
|
||||
|
||||
// Performance warnings
|
||||
if self.input_scale > 20 || self.param_scale > 20 {
|
||||
warn!("High scale values (>20) may impact performance");
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(errors.join("\n"))
|
||||
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
|
||||
return Err("tolerance > 0.0 requires output_visibility to be public".into());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Exports the configuration as JSON
|
||||
///
|
||||
/// Serializes the RunArgs instance to a JSON string
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(String)` containing JSON representation
|
||||
/// * `Err` if serialization fails
|
||||
/// Export the ezkl configuration as json
|
||||
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let res = serde_json::to_string(&self)?;
|
||||
Ok(res)
|
||||
let serialized = match serde_json::to_string(&self) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err(Box::new(e));
|
||||
}
|
||||
};
|
||||
Ok(serialized)
|
||||
}
|
||||
|
||||
/// Parses configuration from JSON
|
||||
///
|
||||
/// Deserializes a RunArgs instance from a JSON string
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `arg_json` - JSON string containing configuration
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(RunArgs)` if parsing succeeds
|
||||
/// * `Err` if parsing fails
|
||||
/// Parse an ezkl configuration from a json
|
||||
pub fn from_json(arg_json: &str) -> Result<Self, serde_json::Error> {
|
||||
serde_json::from_str(arg_json)
|
||||
}
|
||||
}
|
||||
|
||||
// Additional helper functions for the module
|
||||
|
||||
/// Parses a key-value pair from a string in the format "key->value"
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `s` - Input string in the format "key->value"
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok((T, U))` - Parsed key and value
|
||||
/// * `Err` - If parsing fails
|
||||
/// Parse a single key-value pair
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn parse_key_val<T, U>(
|
||||
s: &str,
|
||||
@@ -518,15 +415,14 @@ where
|
||||
{
|
||||
let pos = s
|
||||
.find("->")
|
||||
.ok_or_else(|| format!("invalid KEY->VALUE: no `->` found in `{s}`"))?;
|
||||
Ok((s[..pos].parse()?, s[pos + 2..].parse()?))
|
||||
.ok_or_else(|| format!("invalid x->y: no `->` found in `{s}`"))?;
|
||||
let a = s[..pos].parse()?;
|
||||
let b = s[pos + 2..].parse()?;
|
||||
Ok((a, b))
|
||||
}
|
||||
|
||||
/// Verifies that a version string matches the expected artifact version
|
||||
/// Logs warnings for version mismatches or unversioned artifacts
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `artifact_version` - Version string from the artifact
|
||||
/// Check if the version string matches the artifact version
|
||||
/// If the version string does not match the artifact version, log a warning
|
||||
pub fn check_version_string_matches(artifact_version: &str) {
|
||||
if artifact_version == "0.0.0"
|
||||
|| artifact_version == "source - no compatibility guaranteed"
|
||||
@@ -551,98 +447,3 @@ pub fn check_version_string_matches(artifact_version: &str) {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::field_reassign_with_default)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_valid_default_args() {
|
||||
let args = RunArgs::default();
|
||||
assert!(args.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_param_visibility() {
|
||||
let mut args = RunArgs::default();
|
||||
args.param_visibility = Visibility::Public;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("Parameters cannot be public instances"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_scale_rebase() {
|
||||
let mut args = RunArgs::default();
|
||||
args.scale_rebase_multiplier = 0;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("scale_rebase_multiplier must be >= 1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_lookup_range() {
|
||||
let mut args = RunArgs::default();
|
||||
args.lookup_range = (100, -100);
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("Invalid lookup range"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_logrows() {
|
||||
let mut args = RunArgs::default();
|
||||
args.logrows = 0;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("logrows must be >= 1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_inner_cols() {
|
||||
let mut args = RunArgs::default();
|
||||
args.num_inner_cols = 0;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("num_inner_cols must be >= 1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_tolerance() {
|
||||
let mut args = RunArgs::default();
|
||||
args.tolerance.val = 1.0;
|
||||
args.output_visibility = Visibility::Private;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("Non-zero tolerance requires output_visibility to be public"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_negative_tolerance() {
|
||||
let mut args = RunArgs::default();
|
||||
args.tolerance.val = -1.0;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("tolerance cannot be negative"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zero_batch_size() {
|
||||
let mut args = RunArgs::default();
|
||||
args.variables = vec![("batch_size".to_string(), 0)];
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("'batch_size' cannot be 0"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_serialization() {
|
||||
let args = RunArgs::default();
|
||||
let json = args.as_json().unwrap();
|
||||
let deserialized = RunArgs::from_json(&json).unwrap();
|
||||
assert_eq!(args, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_validation_errors() {
|
||||
let mut args = RunArgs::default();
|
||||
args.logrows = 0;
|
||||
args.lookup_range = (100, -100);
|
||||
let err = args.validate().unwrap_err();
|
||||
// Should contain multiple error messages
|
||||
assert!(err.matches("\n").count() >= 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,6 +133,7 @@ pub fn aggregate<'a>(
|
||||
.collect_vec()
|
||||
}));
|
||||
|
||||
// loader.ctx().constrain_equal(cell_0, cell_1)
|
||||
let mut transcript = PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, snark.proof());
|
||||
let proof = PlonkSuccinctVerifier::read_proof(svk, &protocol, &instances, &mut transcript)
|
||||
.map_err(|_| plonk::Error::Synthesis)?;
|
||||
@@ -308,11 +309,11 @@ impl AggregationCircuit {
|
||||
})
|
||||
}
|
||||
|
||||
/// Number of limbs used for decomposition
|
||||
///
|
||||
pub fn num_limbs() -> usize {
|
||||
LIMBS
|
||||
}
|
||||
/// Number of bits used for decomposition
|
||||
///
|
||||
pub fn num_bits() -> usize {
|
||||
BITS
|
||||
}
|
||||
|
||||
@@ -353,7 +353,6 @@ where
|
||||
C::ScalarExt: Serialize + DeserializeOwned,
|
||||
{
|
||||
/// Create a new application snark from proof and instance variables ready for aggregation
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
protocol: Option<PlonkProtocol<C>>,
|
||||
instances: Vec<Vec<F>>,
|
||||
@@ -529,6 +528,7 @@ pub fn create_keys<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
disable_selector_compression: bool,
|
||||
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
|
||||
{
|
||||
// Real proof
|
||||
@@ -794,6 +794,7 @@ pub fn load_vk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
params: <C as Circuit<Scheme::Scalar>>::Params,
|
||||
) -> Result<VerifyingKey<Scheme::Curve>, PfsysError>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
Scheme::Curve: SerdeObject + CurveAffine,
|
||||
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
|
||||
{
|
||||
@@ -816,6 +817,7 @@ pub fn load_pk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
params: <C as Circuit<Scheme::Scalar>>::Params,
|
||||
) -> Result<ProvingKey<Scheme::Curve>, PfsysError>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
Scheme::Curve: SerdeObject + CurveAffine,
|
||||
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
|
||||
{
|
||||
|
||||
@@ -38,10 +38,4 @@ pub enum TensorError {
|
||||
/// Decomposition error
|
||||
#[error("decomposition error: {0}")]
|
||||
DecompositionError(#[from] DecompositionError),
|
||||
/// Invalid argument
|
||||
#[error("invalid argument: {0}")]
|
||||
InvalidArgument(String),
|
||||
/// Index out of bounds
|
||||
#[error("index {0} out of bounds for dimension {1}")]
|
||||
IndexOutOfBounds(usize, usize),
|
||||
}
|
||||
|
||||
@@ -803,12 +803,6 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
num_repeats: usize,
|
||||
initial_offset: usize,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
if n == 0 {
|
||||
return Err(TensorError::InvalidArgument(
|
||||
"Cannot duplicate every 0th element".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut inner: Vec<T> = Vec::with_capacity(self.inner.len());
|
||||
let mut offset = initial_offset;
|
||||
for (i, elem) in self.inner.clone().into_iter().enumerate() {
|
||||
@@ -838,17 +832,11 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
num_repeats: usize,
|
||||
initial_offset: usize,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
if n == 0 {
|
||||
return Err(TensorError::InvalidArgument(
|
||||
"Cannot remove every 0th element".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Pre-calculate capacity to avoid reallocations
|
||||
let estimated_size = self.inner.len() - (self.inner.len() / n) * num_repeats;
|
||||
let mut inner = Vec::with_capacity(estimated_size);
|
||||
|
||||
// Use iterator directly instead of creating intermediate collectionsif
|
||||
// Use iterator directly instead of creating intermediate collections
|
||||
let mut i = 0;
|
||||
while i < self.inner.len() {
|
||||
// Add the current element
|
||||
@@ -867,6 +855,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
}
|
||||
|
||||
/// Remove indices
|
||||
/// WARN: assumes indices are in ascending order for speed
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
@@ -893,11 +882,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
}
|
||||
// remove indices
|
||||
for elem in indices.iter().rev() {
|
||||
if *elem < self.len() {
|
||||
inner.remove(*elem);
|
||||
} else {
|
||||
return Err(TensorError::IndexOutOfBounds(*elem, self.len()));
|
||||
}
|
||||
inner.remove(*elem);
|
||||
}
|
||||
|
||||
Tensor::new(Some(&inner), &[inner.len()])
|
||||
@@ -1658,9 +1643,7 @@ impl<T: TensorType + Div<Output = T> + std::marker::Send + std::marker::Sync> Di
|
||||
}
|
||||
|
||||
// implement remainder
|
||||
impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + PartialEq> Rem
|
||||
for Tensor<T>
|
||||
{
|
||||
impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Rem for Tensor<T> {
|
||||
type Output = Result<Tensor<T>, TensorError>;
|
||||
|
||||
/// Elementwise remainder of a tensor with another tensor.
|
||||
@@ -1689,25 +1672,9 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
|
||||
let mut lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let rhs = rhs.expand(&broadcasted_shape).unwrap();
|
||||
|
||||
lhs.par_iter_mut()
|
||||
.zip(rhs)
|
||||
.map(|(o, r)| {
|
||||
if let Some(zero) = T::zero() {
|
||||
if r != zero {
|
||||
*o = o.clone() % r;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(TensorError::InvalidArgument(
|
||||
"Cannot divide by zero in remainder".to_string(),
|
||||
))
|
||||
}
|
||||
} else {
|
||||
Err(TensorError::InvalidArgument(
|
||||
"Undefined zero value".to_string(),
|
||||
))
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
|
||||
*o = o.clone() % r;
|
||||
});
|
||||
|
||||
Ok(lhs)
|
||||
}
|
||||
@@ -1742,6 +1709,7 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
|
||||
/// assert_eq!(c, vec![2, 3]);
|
||||
///
|
||||
/// ```
|
||||
|
||||
pub fn get_broadcasted_shape(
|
||||
shape_a: &[usize],
|
||||
shape_b: &[usize],
|
||||
@@ -1749,21 +1717,20 @@ pub fn get_broadcasted_shape(
|
||||
let num_dims_a = shape_a.len();
|
||||
let num_dims_b = shape_b.len();
|
||||
|
||||
if num_dims_a == num_dims_b {
|
||||
let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
|
||||
for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
|
||||
let max_dim = dim_a.max(dim_b);
|
||||
broadcasted_shape.push(*max_dim);
|
||||
match (num_dims_a, num_dims_b) {
|
||||
(a, b) if a == b => {
|
||||
let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
|
||||
for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
|
||||
let max_dim = dim_a.max(dim_b);
|
||||
broadcasted_shape.push(*max_dim);
|
||||
}
|
||||
Ok(broadcasted_shape)
|
||||
}
|
||||
Ok(broadcasted_shape)
|
||||
} else if num_dims_a < num_dims_b {
|
||||
Ok(shape_b.to_vec())
|
||||
} else if num_dims_a > num_dims_b {
|
||||
Ok(shape_a.to_vec())
|
||||
} else {
|
||||
Err(TensorError::DimError(
|
||||
(a, b) if a < b => Ok(shape_b.to_vec()),
|
||||
(a, b) if a > b => Ok(shape_a.to_vec()),
|
||||
_ => Err(TensorError::DimError(
|
||||
"Unknown condition for broadcasting".to_string(),
|
||||
))
|
||||
)),
|
||||
}
|
||||
}
|
||||
////////////////////////
|
||||
|
||||
@@ -385,12 +385,6 @@ pub fn resize<T: TensorType + Send + Sync>(
|
||||
pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync>(
|
||||
t: &[Tensor<T>],
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
if t.len() == 1 {
|
||||
return Ok(t[0].clone());
|
||||
} else if t.len() == 0 {
|
||||
return Err(TensorError::DimMismatch("add".to_string()));
|
||||
}
|
||||
|
||||
// calculate value of output
|
||||
let mut output: Tensor<T> = t[0].clone();
|
||||
|
||||
@@ -439,11 +433,6 @@ pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sy
|
||||
pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync>(
|
||||
t: &[Tensor<T>],
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
if t.len() == 1 {
|
||||
return Ok(t[0].clone());
|
||||
} else if t.len() == 0 {
|
||||
return Err(TensorError::DimMismatch("sub".to_string()));
|
||||
}
|
||||
// calculate value of output
|
||||
let mut output: Tensor<T> = t[0].clone();
|
||||
|
||||
@@ -490,11 +479,6 @@ pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sy
|
||||
pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync>(
|
||||
t: &[Tensor<T>],
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
if t.len() == 1 {
|
||||
return Ok(t[0].clone());
|
||||
} else if t.len() == 0 {
|
||||
return Err(TensorError::DimMismatch("mult".to_string()));
|
||||
}
|
||||
// calculate value of output
|
||||
let mut output: Tensor<T> = t[0].clone();
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,35 +5,33 @@ use log::{debug, error, warn};
|
||||
use crate::circuit::{region::ConstantsMap, CheckMode};
|
||||
|
||||
use super::*;
|
||||
/// A wrapper around Halo2's Column types that represents a tensor of variables in the circuit.
|
||||
/// VarTensors are used to store and manage circuit columns, typically for assigning ValTensor
|
||||
/// values during circuit layout. The tensor organizes storage into blocks of columns, where each
|
||||
/// block contains multiple columns and each column contains multiple rows.
|
||||
/// A wrapper around Halo2's `Column<Fixed>` or `Column<Advice>`.
|
||||
/// Typically assign [ValTensor]s to [VarTensor]s when laying out a circuit.
|
||||
#[derive(Clone, Default, Debug, PartialEq, Eq)]
|
||||
pub enum VarTensor {
|
||||
/// A VarTensor for holding Advice values, which are assigned at proving time.
|
||||
Advice {
|
||||
/// Vec of Advice columns, we have [[xx][xx][xx]...] where each inner vec is xx columns
|
||||
inner: Vec<Vec<Column<Advice>>>,
|
||||
/// The number of columns in each inner block
|
||||
///
|
||||
num_inner_cols: usize,
|
||||
/// Number of rows available to be used in each column of the storage
|
||||
col_size: usize,
|
||||
},
|
||||
/// A placeholder tensor used for testing or temporary storage
|
||||
/// Dummy var
|
||||
Dummy {
|
||||
/// The number of columns in each inner block
|
||||
///
|
||||
num_inner_cols: usize,
|
||||
/// Number of rows available to be used in each column of the storage
|
||||
col_size: usize,
|
||||
},
|
||||
/// An empty tensor with no storage
|
||||
/// Empty var
|
||||
#[default]
|
||||
Empty,
|
||||
}
|
||||
|
||||
impl VarTensor {
|
||||
/// Returns the name of the tensor variant as a static string
|
||||
/// name of the tensor
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
VarTensor::Advice { .. } => "Advice",
|
||||
@@ -42,35 +40,22 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the tensor is an Advice variant
|
||||
///
|
||||
pub fn is_advice(&self) -> bool {
|
||||
matches!(self, VarTensor::Advice { .. })
|
||||
}
|
||||
|
||||
/// Calculates the maximum number of usable rows in the constraint system
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cs` - The constraint system
|
||||
/// * `logrows` - Log base 2 of the total number of rows (including system and blinding rows)
|
||||
///
|
||||
/// # Returns
|
||||
/// The maximum number of usable rows after accounting for blinding factors
|
||||
pub fn max_rows<F: PrimeField>(cs: &ConstraintSystem<F>, logrows: usize) -> usize {
|
||||
let base = 2u32;
|
||||
base.pow(logrows as u32) as usize - cs.blinding_factors() - 1
|
||||
}
|
||||
|
||||
/// Creates a new VarTensor::Advice with unblinded columns. Unblinded columns are used when
|
||||
/// the values do not need to be hidden in the proof.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cs` - The constraint system to create columns in
|
||||
/// * `logrows` - Log base 2 of the total number of rows
|
||||
/// * `num_inner_cols` - Number of columns in each inner block
|
||||
/// * `capacity` - Total number of advice cells to allocate
|
||||
///
|
||||
/// # Returns
|
||||
/// A new VarTensor::Advice with unblinded columns enabled for equality constraints
|
||||
/// Create a new VarTensor::Advice that is unblinded
|
||||
/// Arguments
|
||||
/// * `cs` - The constraint system
|
||||
/// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
|
||||
/// * `capacity` - The number of advice cells to allocate
|
||||
pub fn new_unblinded_advice<F: PrimeField>(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
logrows: usize,
|
||||
@@ -108,17 +93,11 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new VarTensor::Advice with standard (blinded) columns, used when
|
||||
/// the values need to be hidden in the proof.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cs` - The constraint system to create columns in
|
||||
/// * `logrows` - Log base 2 of the total number of rows
|
||||
/// * `num_inner_cols` - Number of columns in each inner block
|
||||
/// * `capacity` - Total number of advice cells to allocate
|
||||
///
|
||||
/// # Returns
|
||||
/// A new VarTensor::Advice with blinded columns enabled for equality constraints
|
||||
/// Create a new VarTensor::Advice
|
||||
/// Arguments
|
||||
/// * `cs` - The constraint system
|
||||
/// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
|
||||
/// * `capacity` - The number of advice cells to allocate
|
||||
pub fn new_advice<F: PrimeField>(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
logrows: usize,
|
||||
@@ -154,17 +133,11 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Initializes fixed columns in the constraint system to support the VarTensor::Advice
|
||||
/// Fixed columns are used for constant values that are known at circuit creation time.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cs` - The constraint system to create columns in
|
||||
/// * `logrows` - Log base 2 of the total number of rows
|
||||
/// * `num_constants` - Number of constant values needed
|
||||
/// * `module_requires_fixed` - Whether the module requires at least one fixed column
|
||||
///
|
||||
/// # Returns
|
||||
/// The number of fixed columns created
|
||||
/// Initializes fixed columns to support the VarTensor::Advice
|
||||
/// Arguments
|
||||
/// * `cs` - The constraint system
|
||||
/// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
|
||||
/// * `capacity` - The number of advice cells to allocate
|
||||
pub fn constant_cols<F: PrimeField>(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
logrows: usize,
|
||||
@@ -196,14 +169,7 @@ impl VarTensor {
|
||||
modulo
|
||||
}
|
||||
|
||||
/// Creates a new dummy VarTensor for testing or temporary storage
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `logrows` - Log base 2 of the total number of rows
|
||||
/// * `num_inner_cols` - Number of columns in each inner block
|
||||
///
|
||||
/// # Returns
|
||||
/// A new VarTensor::Dummy with the specified dimensions
|
||||
/// Create a new VarTensor::Dummy
|
||||
pub fn dummy(logrows: usize, num_inner_cols: usize) -> Self {
|
||||
let base = 2u32;
|
||||
let max_rows = base.pow(logrows as u32) as usize - 6;
|
||||
@@ -213,7 +179,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of blocks in the tensor
|
||||
/// Gets the dims of the object the VarTensor represents
|
||||
pub fn num_blocks(&self) -> usize {
|
||||
match self {
|
||||
VarTensor::Advice { inner, .. } => inner.len(),
|
||||
@@ -221,7 +187,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of columns in each inner block
|
||||
/// Num inner cols
|
||||
pub fn num_inner_cols(&self) -> usize {
|
||||
match self {
|
||||
VarTensor::Advice { num_inner_cols, .. } | VarTensor::Dummy { num_inner_cols, .. } => {
|
||||
@@ -231,7 +197,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the total number of columns across all blocks
|
||||
/// Total number of columns
|
||||
pub fn num_cols(&self) -> usize {
|
||||
match self {
|
||||
VarTensor::Advice { inner, .. } => inner[0].len() * inner.len(),
|
||||
@@ -239,7 +205,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the maximum number of rows in each column
|
||||
/// Gets the size of each column
|
||||
pub fn col_size(&self) -> usize {
|
||||
match self {
|
||||
VarTensor::Advice { col_size, .. } | VarTensor::Dummy { col_size, .. } => *col_size,
|
||||
@@ -247,7 +213,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the total size of each block (num_inner_cols * col_size)
|
||||
/// Gets the size of each column
|
||||
pub fn block_size(&self) -> usize {
|
||||
match self {
|
||||
VarTensor::Advice {
|
||||
@@ -264,13 +230,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a linear coordinate to (block, column, row) coordinates in the storage
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `linear_coord` - The linear index to convert
|
||||
///
|
||||
/// # Returns
|
||||
/// A tuple of (block_index, column_index, row_index)
|
||||
/// Take a linear coordinate and output the (column, row) position in the storage block.
|
||||
pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize, usize) {
|
||||
// x indexes over blocks of size num_inner_cols
|
||||
let x = linear_coord / self.block_size();
|
||||
@@ -283,17 +243,7 @@ impl VarTensor {
|
||||
}
|
||||
|
||||
impl VarTensor {
|
||||
/// Queries a range of cells in the tensor during circuit synthesis
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `meta` - Virtual cells accessor
|
||||
/// * `x` - Block index
|
||||
/// * `y` - Column index within block
|
||||
/// * `z` - Starting row offset
|
||||
/// * `rng` - Number of consecutive rows to query
|
||||
///
|
||||
/// # Returns
|
||||
/// A tensor of expressions representing the queried cells
|
||||
/// Retrieve the value of a specific cell in the tensor.
|
||||
pub fn query_rng<F: PrimeField>(
|
||||
&self,
|
||||
meta: &mut VirtualCells<'_, F>,
|
||||
@@ -318,16 +268,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Queries an entire block of cells at a given offset
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `meta` - Virtual cells accessor
|
||||
/// * `x` - Block index
|
||||
/// * `z` - Row offset
|
||||
/// * `rng` - Number of consecutive rows to query
|
||||
///
|
||||
/// # Returns
|
||||
/// A tensor of expressions representing the queried block
|
||||
/// Retrieve the value of a specific block at an offset in the tensor.
|
||||
pub fn query_whole_block<F: PrimeField>(
|
||||
&self,
|
||||
meta: &mut VirtualCells<'_, F>,
|
||||
@@ -352,16 +293,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns a constant value to a specific cell in the tensor
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for the assignment
|
||||
/// * `coord` - Coordinate within the tensor
|
||||
/// * `constant` - The constant value to assign
|
||||
///
|
||||
/// # Returns
|
||||
/// The assigned cell or an error if assignment fails
|
||||
/// Assigns a constant value to a specific cell in the tensor.
|
||||
pub fn assign_constant<F: PrimeField + TensorType + PartialOrd>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
@@ -381,17 +313,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns values from a ValTensor to this tensor, excluding specified positions
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `omissions` - Set of positions to skip during assignment
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// The assigned ValTensor or an error if assignment fails
|
||||
/// Assigns [ValTensor] to the columns of the inner tensor.
|
||||
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
@@ -422,16 +344,7 @@ impl VarTensor {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Assigns values from a ValTensor to this tensor
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// The assigned ValTensor or an error if assignment fails
|
||||
/// Assigns [ValTensor] to the columns of the inner tensor.
|
||||
pub fn assign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
@@ -483,23 +396,14 @@ impl VarTensor {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Returns the remaining available space in a column for assignments
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `offset` - Current offset in the column
|
||||
/// * `values` - The ValTensor to check space for
|
||||
///
|
||||
/// # Returns
|
||||
/// The number of rows that need to be flushed or an error if space is insufficient
|
||||
/// Helper function to get the remaining size of the column
|
||||
pub fn get_column_flush<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
offset: usize,
|
||||
values: &ValTensor<F>,
|
||||
) -> Result<usize, halo2_proofs::plonk::Error> {
|
||||
if values.len() > self.col_size() {
|
||||
error!(
|
||||
"There are too many values to flush for this column size, try setting the logrows to a higher value (eg. --logrows 22 on the cli)"
|
||||
);
|
||||
error!("Values are too large for the column");
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
}
|
||||
|
||||
@@ -523,16 +427,8 @@ impl VarTensor {
|
||||
Ok(flush_len)
|
||||
}
|
||||
|
||||
/// Assigns values to a single column, avoiding column overflow by flushing to the next column if needed
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// A tuple of (assigned ValTensor, number of rows flushed) or an error if assignment fails
|
||||
/// Assigns [ValTensor] to the columns of the inner tensor. Whereby the values are assigned to a single column, without overflowing.
|
||||
/// So for instance if we are assigning 10 values and we are at index 18 of the column, and the columns are of length 20, we skip the last 2 values of current column and start from the beginning of the next column.
|
||||
pub fn assign_exact_column<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
@@ -547,17 +443,8 @@ impl VarTensor {
|
||||
Ok((assigned_vals, flush_len))
|
||||
}
|
||||
|
||||
/// Assigns values with duplication in dummy mode, used for testing and simulation
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `row` - Starting row for assignment
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `single_inner_col` - Whether to treat as a single column
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
|
||||
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
|
||||
/// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
|
||||
pub fn dummy_assign_with_duplication<
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
|
||||
>(
|
||||
@@ -607,16 +494,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns values with duplication but without enforcing constraints between duplicated values
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
|
||||
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
|
||||
pub fn assign_with_duplication_unconstrained<
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
|
||||
>(
|
||||
@@ -655,18 +533,8 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns values with duplication and enforces equality constraints between duplicated values
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `row` - Starting row for assignment
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `check_mode` - Mode for checking equality constraints
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
|
||||
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
|
||||
/// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
|
||||
pub fn assign_with_duplication_constrained<
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
|
||||
>(
|
||||
@@ -745,17 +613,6 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns a single value to the tensor. This is a helper function used by other assignment methods.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for the assignment
|
||||
/// * `k` - The value to assign
|
||||
/// * `coord` - The coordinate where to assign the value
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// The assigned value or an error if assignment fails
|
||||
fn assign_value<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
@@ -766,28 +623,24 @@ impl VarTensor {
|
||||
) -> Result<ValType<F>, halo2_proofs::plonk::Error> {
|
||||
let (x, y, z) = self.cartesian_coord(offset + coord);
|
||||
let res = match k {
|
||||
// Handle direct value assignment
|
||||
ValType::Value(v) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
ValType::PrevAssigned(region.assign_advice(|| "k", advices[x][y], z, || v)?)
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
// Handle copying previously assigned value
|
||||
ValType::PrevAssigned(v) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
ValType::PrevAssigned(v.copy_advice(|| "k", region, advices[x][y], z)?)
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
// Handle copying previously assigned constant
|
||||
ValType::AssignedConstant(v, val) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
ValType::AssignedConstant(v.copy_advice(|| "k", region, advices[x][y], z)?, val)
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
// Handle assigning evaluated value
|
||||
ValType::AssignedValue(v) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => ValType::PrevAssigned(
|
||||
region
|
||||
@@ -796,7 +649,6 @@ impl VarTensor {
|
||||
),
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
// Handle constant value assignment with caching
|
||||
ValType::Constant(v) => {
|
||||
if let std::collections::hash_map::Entry::Vacant(e) = constants.entry(v) {
|
||||
let value = ValType::AssignedConstant(
|
||||
|
||||
@@ -75,8 +75,9 @@ mod native_tests {
|
||||
});
|
||||
}
|
||||
|
||||
///
|
||||
#[allow(dead_code)]
|
||||
fn init_wasm() {
|
||||
pub fn init_wasm() {
|
||||
COMPILE_WASM.call_once(|| {
|
||||
build_wasm_ezkl();
|
||||
});
|
||||
@@ -2246,7 +2247,6 @@ mod native_tests {
|
||||
}
|
||||
|
||||
// prove-serialize-verify, the usual full path
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn kzg_evm_prove_and_verify_reusable_verifier(
|
||||
num_inner_columns: usize,
|
||||
test_dir: &str,
|
||||
@@ -2796,14 +2796,7 @@ mod native_tests {
|
||||
"icicle",
|
||||
];
|
||||
#[cfg(feature = "macos-metal")]
|
||||
let args = [
|
||||
"build",
|
||||
"--release",
|
||||
"--bin",
|
||||
"ezkl",
|
||||
"--features",
|
||||
"macos-metal",
|
||||
];
|
||||
let args = ["build", "--release", "--bin", "ezkl", "--features", "macos-metal"];
|
||||
// not macos-metal and not icicle
|
||||
#[cfg(all(not(feature = "icicle"), not(feature = "macos-metal")))]
|
||||
let args = ["build", "--release", "--bin", "ezkl"];
|
||||
|
||||
Reference in New Issue
Block a user