mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
27 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
089fc6846d | ||
|
|
9c64e42bd3 | ||
|
|
27b5e5dde3 | ||
|
|
83c4afce3b | ||
|
|
50740a22df | ||
|
|
a2624f6303 | ||
|
|
fc5be4f949 | ||
|
|
d0ba505baa | ||
|
|
f35688917d | ||
|
|
7ae541ed35 | ||
|
|
675628cd08 | ||
|
|
4fe7290240 | ||
|
|
3e027db9b6 | ||
|
|
e566acc22a | ||
|
|
75ea99e81d | ||
|
|
c5354c382d | ||
|
|
bdcba5ca61 | ||
|
|
6752a05f19 | ||
|
|
03aefb85eb | ||
|
|
e86caca8b6 | ||
|
|
c839a30ae6 | ||
|
|
352812b9ac | ||
|
|
d48d0b0b3e | ||
|
|
8b223354cc | ||
|
|
caa6ef8e16 | ||
|
|
c4354c10a5 | ||
|
|
c1ce8c88d0 |
55
.github/workflows/benchmarks.yml
vendored
55
.github/workflows/benchmarks.yml
vendored
@@ -6,22 +6,15 @@ on:
|
||||
description: "Test scenario tags"
|
||||
|
||||
jobs:
|
||||
bench_elgamal:
|
||||
runs-on: self-hosted
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Bench elgamal
|
||||
run: cargo bench --verbose --bench elgamal
|
||||
|
||||
bench_poseidon:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -31,10 +24,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench poseidon
|
||||
|
||||
bench_einsum_accum_matmul:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -44,10 +41,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_einsum_matmul
|
||||
|
||||
bench_accum_matmul_relu:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -57,10 +58,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_matmul_relu
|
||||
|
||||
bench_accum_matmul_relu_overflow:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -70,10 +75,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_matmul_relu_overflow
|
||||
|
||||
bench_relu:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -83,10 +92,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench relu
|
||||
|
||||
bench_accum_dot:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -96,10 +109,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_dot
|
||||
|
||||
bench_accum_conv:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -109,10 +126,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_conv
|
||||
|
||||
bench_accum_sumpool:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -122,10 +143,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_sumpool
|
||||
|
||||
bench_pairwise_add:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -135,10 +160,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench pairwise_add
|
||||
|
||||
bench_accum_sum:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -148,10 +177,14 @@ jobs:
|
||||
run: cargo bench --verbose --bench accum_sum
|
||||
|
||||
bench_pairwise_pow:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
|
||||
84
.github/workflows/engine.yml
vendored
84
.github/workflows/engine.yml
vendored
@@ -15,11 +15,18 @@ defaults:
|
||||
working-directory: .
|
||||
jobs:
|
||||
publish-wasm-bindings:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
name: publish-wasm-bindings
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -40,43 +47,39 @@ jobs:
|
||||
curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz | tar xzf -
|
||||
export PATH=$PATH:$PWD/binaryen-version_116/bin
|
||||
wasm-opt --version
|
||||
- name: Build wasm files for both web and nodejs compilation targets
|
||||
run: |
|
||||
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
|
||||
wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
|
||||
- name: Create package.json in pkg folder
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
echo '{
|
||||
"name": "@ezkljs/engine",
|
||||
"version": "${{ github.ref_name }}",
|
||||
"dependencies": {
|
||||
"@types/json-bigint": "^1.0.1",
|
||||
"json-bigint": "^1.0.0"
|
||||
},
|
||||
"files": [
|
||||
"nodejs/ezkl_bg.wasm",
|
||||
"nodejs/ezkl.js",
|
||||
"nodejs/ezkl.d.ts",
|
||||
"nodejs/package.json",
|
||||
"nodejs/utils.js",
|
||||
"web/ezkl_bg.wasm",
|
||||
"web/ezkl.js",
|
||||
"web/ezkl.d.ts",
|
||||
"web/snippets/**/*",
|
||||
"web/package.json",
|
||||
"web/utils.js",
|
||||
"ezkl.d.ts"
|
||||
],
|
||||
"main": "nodejs/ezkl.js",
|
||||
"module": "web/ezkl.js",
|
||||
"types": "nodejs/ezkl.d.ts",
|
||||
"sideEffects": [
|
||||
"web/snippets/*"
|
||||
]
|
||||
}' > pkg/package.json
|
||||
cat > pkg/package.json << EOF
|
||||
{
|
||||
"name": "@ezkljs/engine",
|
||||
"version": "$RELEASE_TAG",
|
||||
"dependencies": {
|
||||
"@types/json-bigint": "^1.0.1",
|
||||
"json-bigint": "^1.0.0"
|
||||
},
|
||||
"files": [
|
||||
"nodejs/ezkl_bg.wasm",
|
||||
"nodejs/ezkl.js",
|
||||
"nodejs/ezkl.d.ts",
|
||||
"nodejs/package.json",
|
||||
"nodejs/utils.js",
|
||||
"web/ezkl_bg.wasm",
|
||||
"web/ezkl.js",
|
||||
"web/ezkl.d.ts",
|
||||
"web/snippets/**/*",
|
||||
"web/package.json",
|
||||
"web/utils.js",
|
||||
"ezkl.d.ts"
|
||||
],
|
||||
"main": "nodejs/ezkl.js",
|
||||
"module": "web/ezkl.js",
|
||||
"types": "nodejs/ezkl.d.ts",
|
||||
"sideEffects": [
|
||||
"web/snippets/*"
|
||||
]
|
||||
}
|
||||
EOF
|
||||
|
||||
- name: Replace memory definition in nodejs
|
||||
run: |
|
||||
@@ -184,21 +187,26 @@ jobs:
|
||||
|
||||
|
||||
in-browser-evm-ver-publish:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
name: publish-in-browser-evm-verifier-package
|
||||
needs: [publish-wasm-bindings]
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Update version in package.json
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
|
||||
sed -i "s|\"version\": \".*\"|\"version\": \"$RELEASE_TAG\"|" in-browser-evm-verifier/package.json
|
||||
- name: Prepare tag and fetch package integrity
|
||||
run: |
|
||||
CLEANED_TAG=${{ github.ref_name }} # Get the tag from ref_name
|
||||
CLEANED_TAG=${RELEASE_TAG} # Get the tag from ref_name
|
||||
CLEANED_TAG="${CLEANED_TAG#v}" # Remove leading 'v'
|
||||
echo "CLEANED_TAG=${CLEANED_TAG}" >> $GITHUB_ENV # Set it as an environment variable for later steps
|
||||
ENGINE_INTEGRITY=$(npm view @ezkljs/engine@$CLEANED_TAG dist.integrity)
|
||||
|
||||
4
.github/workflows/large-tests.yml
vendored
4
.github/workflows/large-tests.yml
vendored
@@ -6,9 +6,13 @@ on:
|
||||
description: "Test scenario tags"
|
||||
jobs:
|
||||
large-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: kaiju
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
|
||||
10
.github/workflows/pypi-gpu.yml
vendored
10
.github/workflows/pypi-gpu.yml
vendored
@@ -18,12 +18,19 @@ defaults:
|
||||
jobs:
|
||||
|
||||
linux:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
runs-on: GPU
|
||||
strategy:
|
||||
matrix:
|
||||
target: [x86_64]
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
@@ -34,6 +41,7 @@ jobs:
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig >pyproject.toml
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
@@ -43,8 +51,6 @@ jobs:
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv Cargo.toml Cargo.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
|
||||
|
||||
148
.github/workflows/pypi.yml
vendored
148
.github/workflows/pypi.yml
vendored
@@ -16,22 +16,32 @@ defaults:
|
||||
|
||||
jobs:
|
||||
macos:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
strategy:
|
||||
matrix:
|
||||
target: [x86_64, universal2-apple-darwin]
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
architecture: x64
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv Cargo.toml Cargo.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
|
||||
@@ -45,6 +55,13 @@ jobs:
|
||||
components: rustfmt, clippy
|
||||
|
||||
- name: Build wheels
|
||||
if: matrix.target == 'universal2-apple-darwin'
|
||||
uses: PyO3/maturin-action@v1
|
||||
with:
|
||||
target: ${{ matrix.target }}
|
||||
args: --release --out dist --features python-bindings
|
||||
- name: Build wheels
|
||||
if: matrix.target == 'x86_64'
|
||||
uses: PyO3/maturin-action@v1
|
||||
with:
|
||||
target: ${{ matrix.target }}
|
||||
@@ -62,6 +79,8 @@ jobs:
|
||||
path: dist
|
||||
|
||||
windows:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: windows-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
strategy:
|
||||
@@ -69,11 +88,21 @@ jobs:
|
||||
target: [x64, x86]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
architecture: ${{ matrix.target }}
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
@@ -107,6 +136,8 @@ jobs:
|
||||
path: dist
|
||||
|
||||
linux:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
strategy:
|
||||
@@ -114,11 +145,21 @@ jobs:
|
||||
target: [x86_64]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
architecture: x64
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
@@ -129,7 +170,6 @@ jobs:
|
||||
mv Cargo.lock Cargo.lock.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
|
||||
|
||||
|
||||
- name: Install required libraries
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -168,58 +208,9 @@ jobs:
|
||||
name: wheels
|
||||
path: dist
|
||||
|
||||
# There's a problem with the maturin-action toolchain for arm arch leading to failed builds
|
||||
# linux-cross:
|
||||
# runs-on: ubuntu-latest
|
||||
# strategy:
|
||||
# matrix:
|
||||
# target: [aarch64, armv7]
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - uses: actions/setup-python@v4
|
||||
# with:
|
||||
# python-version: 3.12
|
||||
|
||||
# - name: Install cross-compilation tools for aarch64
|
||||
# if: matrix.target == 'aarch64'
|
||||
# run: |
|
||||
# sudo apt-get update
|
||||
# sudo apt-get install -y gcc make gcc-aarch64-linux-gnu binutils-aarch64-linux-gnu libc6-dev-arm64-cross libusb-1.0-0-dev libatomic1-arm64-cross
|
||||
|
||||
# - name: Install cross-compilation tools for armv7
|
||||
# if: matrix.target == 'armv7'
|
||||
# run: |
|
||||
# sudo apt-get update
|
||||
# sudo apt-get install -y gcc make gcc-arm-linux-gnueabihf binutils-arm-linux-gnueabihf libc6-dev-armhf-cross libusb-1.0-0-dev libatomic1-armhf-cross
|
||||
|
||||
# - name: Build wheels
|
||||
# uses: PyO3/maturin-action@v1
|
||||
# with:
|
||||
# target: ${{ matrix.target }}
|
||||
# manylinux: auto
|
||||
# args: --release --out dist --features python-bindings
|
||||
|
||||
# - uses: uraimo/run-on-arch-action@v2.5.0
|
||||
# name: Install built wheel
|
||||
# with:
|
||||
# arch: ${{ matrix.target }}
|
||||
# distro: ubuntu20.04
|
||||
# githubToken: ${{ github.token }}
|
||||
# install: |
|
||||
# apt-get update
|
||||
# apt-get install -y --no-install-recommends python3 python3-pip
|
||||
# pip3 install -U pip
|
||||
# run: |
|
||||
# pip3 install ezkl --no-index --find-links dist/ --force-reinstall
|
||||
# python3 -c "import ezkl"
|
||||
|
||||
# - name: Upload wheels
|
||||
# uses: actions/upload-artifact@v3
|
||||
# with:
|
||||
# name: wheels
|
||||
# path: dist
|
||||
|
||||
musllinux:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
strategy:
|
||||
@@ -228,11 +219,21 @@ jobs:
|
||||
- x86_64-unknown-linux-musl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
architecture: x64
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
@@ -276,6 +277,8 @@ jobs:
|
||||
path: dist
|
||||
|
||||
musllinux-cross:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
strategy:
|
||||
@@ -285,10 +288,20 @@ jobs:
|
||||
arch: aarch64
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
|
||||
- name: Set pyproject.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
mv pyproject.toml pyproject.toml.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
|
||||
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
env:
|
||||
@@ -332,8 +345,6 @@ jobs:
|
||||
permissions:
|
||||
id-token: write
|
||||
if: "startsWith(github.ref, 'refs/tags/')"
|
||||
# TODO: Uncomment if linux-cross is working
|
||||
# needs: [ macos, windows, linux, linux-cross, musllinux, musllinux-cross ]
|
||||
needs: [macos, windows, linux, musllinux, musllinux-cross]
|
||||
steps:
|
||||
- uses: actions/download-artifact@v3
|
||||
@@ -341,35 +352,34 @@ jobs:
|
||||
name: wheels
|
||||
- name: List Files
|
||||
run: ls -R
|
||||
|
||||
# Both publish steps will fail if there is no trusted publisher setup
|
||||
# On failure the publish step will then simply continue to the next one
|
||||
|
||||
# # publishes to TestPyPI
|
||||
# - name: Publish package distribution to TestPyPI
|
||||
# uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
# with:
|
||||
# repository-url: https://test.pypi.org/legacy/
|
||||
# packages-dir: ./
|
||||
|
||||
# publishes to PyPI
|
||||
- name: Publish package distributions to PyPI
|
||||
continue-on-error: true
|
||||
uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
with:
|
||||
packages-dir: ./
|
||||
|
||||
# publishes to TestPyPI
|
||||
- name: Publish package distribution to TestPyPI
|
||||
continue-on-error: true
|
||||
uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
with:
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
packages-dir: ./
|
||||
|
||||
doc-publish:
|
||||
permissions:
|
||||
contents: read
|
||||
name: Trigger ReadTheDocs Build
|
||||
runs-on: ubuntu-latest
|
||||
needs: pypi-publish
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Trigger RTDs build
|
||||
uses: dfm/rtds-action@v1
|
||||
with:
|
||||
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
|
||||
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
|
||||
commit_ref: ${{ github.ref_name }}
|
||||
commit_ref: ${{ github.ref_name }}
|
||||
23
.github/workflows/release.yml
vendored
23
.github/workflows/release.yml
vendored
@@ -10,6 +10,9 @@ on:
|
||||
- "*"
|
||||
jobs:
|
||||
create-release:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
name: create-release
|
||||
runs-on: ubuntu-22.04
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
@@ -33,6 +36,9 @@ jobs:
|
||||
tag_name: ${{ env.EZKL_VERSION }}
|
||||
|
||||
build-release-gpu:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
name: build-release-gpu
|
||||
needs: ["create-release"]
|
||||
runs-on: GPU
|
||||
@@ -50,6 +56,9 @@ jobs:
|
||||
components: rustfmt, clippy
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- name: Get release version from tag
|
||||
shell: bash
|
||||
@@ -91,6 +100,10 @@ jobs:
|
||||
asset_content_type: application/octet-stream
|
||||
|
||||
build-release:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
issues: write
|
||||
name: build-release
|
||||
needs: ["create-release"]
|
||||
runs-on: ${{ matrix.os }}
|
||||
@@ -132,6 +145,8 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Get release version from tag
|
||||
shell: bash
|
||||
@@ -181,14 +196,18 @@ jobs:
|
||||
echo "target flag is: ${{ env.TARGET_FLAGS }}"
|
||||
echo "target dir is: ${{ env.TARGET_DIR }}"
|
||||
|
||||
- name: Build release binary (no asm)
|
||||
if: matrix.build != 'linux-gnu'
|
||||
- name: Build release binary (no asm or metal)
|
||||
if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
|
||||
|
||||
- name: Build release binary (asm)
|
||||
if: matrix.build == 'linux-gnu'
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
|
||||
|
||||
- name: Build release binary (metal)
|
||||
if: matrix.build == 'macos-aarch64'
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal
|
||||
|
||||
- name: Strip release binary
|
||||
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'
|
||||
run: strip "target/${{ matrix.target }}/release/ezkl"
|
||||
|
||||
268
.github/workflows/rust.yml
vendored
268
.github/workflows/rust.yml
vendored
@@ -19,11 +19,30 @@ env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
fr-age-test:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: large-self-hosted
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: fr age Mock
|
||||
run: cargo test --release --verbose tests::large_mock_::large_tests_6_expects -- --include-ignored
|
||||
|
||||
build:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -33,9 +52,13 @@ jobs:
|
||||
run: cargo build --verbose
|
||||
|
||||
docs:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -45,9 +68,13 @@ jobs:
|
||||
run: cargo doc --verbose
|
||||
|
||||
library-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -71,6 +98,8 @@ jobs:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
@@ -101,9 +130,13 @@ jobs:
|
||||
# run: cargo nextest run conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored
|
||||
|
||||
ultra-overflow-tests_og-lookup:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -134,9 +167,13 @@ jobs:
|
||||
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored
|
||||
|
||||
ultra-overflow-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -167,9 +204,13 @@ jobs:
|
||||
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture -- --include-ignored
|
||||
|
||||
model-serialization:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest-16-cores
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -183,9 +224,13 @@ jobs:
|
||||
run: cargo nextest run native_tests::tests::model_serialization_different_binaries_ --test-threads 1
|
||||
|
||||
wasm32-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -194,7 +239,7 @@ jobs:
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
with:
|
||||
# Pin to version 0.12.1
|
||||
version: 'v0.12.1'
|
||||
version: "v0.12.1"
|
||||
- uses: nanasess/setup-chromedriver@v2
|
||||
# with:
|
||||
# chromedriver-version: "115.0.5790.102"
|
||||
@@ -208,10 +253,14 @@ jobs:
|
||||
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
|
||||
|
||||
mock-proving-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -271,10 +320,14 @@ jobs:
|
||||
run: cargo nextest run --release --verbose tests::mock_fixed_params_ --test-threads 32
|
||||
|
||||
prove-and-verify-evm-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -285,6 +338,8 @@ jobs:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Use pnpm 8
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
@@ -349,11 +404,49 @@ jobs:
|
||||
- name: KZG prove and verify tests (EVM + hashed outputs)
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_hashed_output_prove_and_verify --test-threads 1
|
||||
|
||||
# prove-and-verify-tests-metal:
|
||||
# permissions:
|
||||
# contents: read
|
||||
# runs-on: macos-13
|
||||
# # needs: [build, library-tests, docs]
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
# override: true
|
||||
# components: rustfmt, clippy
|
||||
# - uses: jetli/wasm-pack-action@v0.4.0
|
||||
# with:
|
||||
# # Pin to version 0.12.1
|
||||
# version: 'v0.12.1'
|
||||
# - name: Add rust-src
|
||||
# run: rustup component add rust-src --toolchain nightly-2024-07-18
|
||||
# - uses: actions/checkout@v3
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - name: Use pnpm 8
|
||||
# uses: pnpm/action-setup@v2
|
||||
# with:
|
||||
# version: 8
|
||||
# - uses: baptiste0928/cargo-install@v1
|
||||
# with:
|
||||
# crate: cargo-nextest
|
||||
# locked: true
|
||||
# - name: KZG prove and verify tests (public outputs)
|
||||
# run: cargo nextest run --release --features macos-metal --verbose tests::kzg_prove_and_verify_::t --no-capture
|
||||
|
||||
prove-and-verify-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -362,13 +455,15 @@ jobs:
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
with:
|
||||
# Pin to version 0.12.1
|
||||
version: 'v0.12.1'
|
||||
version: "v0.12.1"
|
||||
- name: Add wasm32-unknown-unknown target
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Use pnpm 8
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
@@ -431,6 +526,8 @@ jobs:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
@@ -461,10 +558,14 @@ jobs:
|
||||
# run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
|
||||
|
||||
prove-and-verify-mock-aggr-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: self-hosted
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -483,6 +584,8 @@ jobs:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
@@ -496,10 +599,14 @@ jobs:
|
||||
# run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
|
||||
|
||||
prove-and-verify-aggr-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: large-self-hosted
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -513,10 +620,14 @@ jobs:
|
||||
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
|
||||
|
||||
prove-and-verify-aggr-evm-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: large-self-hosted
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -534,10 +645,14 @@ jobs:
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
|
||||
|
||||
examples:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
needs: [build, library-tests, docs]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -551,10 +666,14 @@ jobs:
|
||||
run: cargo nextest run --release tests_examples
|
||||
|
||||
python-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.12"
|
||||
@@ -577,10 +696,14 @@ jobs:
|
||||
run: source .env/bin/activate; pip install pytest-asyncio; pytest -vv
|
||||
|
||||
accuracy-measurement-tests:
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.12"
|
||||
@@ -607,6 +730,8 @@ jobs:
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::resources_accuracy_measurement_public_outputs_
|
||||
|
||||
python-integration-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: large-self-hosted
|
||||
services:
|
||||
# Label used to access the service container
|
||||
@@ -628,6 +753,8 @@ jobs:
|
||||
- 5432:5432
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.11"
|
||||
@@ -650,6 +777,8 @@ jobs:
|
||||
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
|
||||
- name: Voice tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
|
||||
- name: Neural bow
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
|
||||
- name: Felt conversion
|
||||
@@ -669,72 +798,87 @@ jobs:
|
||||
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
|
||||
- name: All notebooks
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
|
||||
- name: Voice tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
|
||||
- name: NBEATS tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
|
||||
# - name: Reusable verifier tutorial
|
||||
# run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_
|
||||
# run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_
|
||||
|
||||
ios-integration-tests:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Run ios tests
|
||||
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Run ios tests
|
||||
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
|
||||
|
||||
swift-package-tests:
|
||||
runs-on: macos-latest
|
||||
needs: [ios-integration-tests]
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: macos-latest
|
||||
needs: [ios-integration-tests]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Build EzklCoreBindings
|
||||
run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Build EzklCoreBindings
|
||||
run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
|
||||
|
||||
- name: Clone ezkl-swift- repository
|
||||
run: |
|
||||
git clone https://github.com/zkonduit/ezkl-swift-package.git
|
||||
- name: Clone ezkl-swift- repository
|
||||
run: |
|
||||
git clone https://github.com/zkonduit/ezkl-swift-package.git
|
||||
|
||||
- name: Copy EzklCoreBindings
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
|
||||
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
|
||||
- name: Copy EzklCoreBindings
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
|
||||
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
|
||||
|
||||
- name: Set up Xcode environment
|
||||
run: |
|
||||
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
|
||||
sudo xcodebuild -license accept
|
||||
- name: Copy Test Files
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Tests/EzklAssets/
|
||||
mkdir -p ezkl-swift-package/Tests/EzklAssets/
|
||||
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
|
||||
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
|
||||
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
|
||||
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
|
||||
|
||||
- name: Run Package Tests
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
xcodebuild test \
|
||||
-scheme EzklPackage \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-resultBundlePath ../testResults
|
||||
- name: Set up Xcode environment
|
||||
run: |
|
||||
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
|
||||
sudo xcodebuild -license accept
|
||||
|
||||
- name: Run Example App Tests
|
||||
run: |
|
||||
cd ezkl-swift-package/Example
|
||||
xcodebuild test \
|
||||
-project Example.xcodeproj \
|
||||
-scheme EzklApp \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-parallel-testing-enabled NO \
|
||||
-resultBundlePath ../../exampleTestResults \
|
||||
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
|
||||
- name: Run Package Tests
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
xcodebuild test \
|
||||
-scheme EzklPackage \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-resultBundlePath ../testResults
|
||||
|
||||
- name: Run Example App Tests
|
||||
run: |
|
||||
cd ezkl-swift-package/Example
|
||||
xcodebuild test \
|
||||
-project Example.xcodeproj \
|
||||
-scheme EzklApp \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-parallel-testing-enabled NO \
|
||||
-resultBundlePath ../../exampleTestResults \
|
||||
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
|
||||
|
||||
33
.github/workflows/static-analysis.yml
vendored
Normal file
33
.github/workflows/static-analysis.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
name: Static Analysis
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
|
||||
# Run Zizmor static analysis
|
||||
|
||||
- name: Install Zizmor
|
||||
run: cargo install --locked zizmor
|
||||
|
||||
- name: Run Zizmor Analysis
|
||||
run: zizmor .
|
||||
|
||||
|
||||
|
||||
134
.github/workflows/swift-pm.yml
vendored
Normal file
134
.github/workflows/swift-pm.yml
vendored
Normal file
@@ -0,0 +1,134 @@
|
||||
name: Build and Publish EZKL iOS SPM package
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
# Only support SemVer versioning tags
|
||||
- 'v[0-9]+.[0-9]+.[0-9]+'
|
||||
- '[0-9]+.[0-9]+.[0-9]+'
|
||||
|
||||
jobs:
|
||||
build-and-update:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
runs-on: macos-latest
|
||||
env:
|
||||
EZKL_SWIFT_PACKAGE_REPO: github.com/zkonduit/ezkl-swift-package.git
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
|
||||
steps:
|
||||
- name: Checkout EZKL
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Extract TAG from github.ref_name
|
||||
run: |
|
||||
# github.ref_name is provided by GitHub Actions and contains the tag name directly.
|
||||
TAG="${RELEASE_TAG}"
|
||||
echo "Original TAG: $TAG"
|
||||
# Remove leading 'v' if present to match the Swift Package Manager version format.
|
||||
NEW_TAG=${TAG#v}
|
||||
echo "Stripped TAG: $NEW_TAG"
|
||||
echo "TAG=$NEW_TAG" >> $GITHUB_ENV
|
||||
|
||||
- name: Install Rust (nightly)
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly
|
||||
override: true
|
||||
|
||||
- name: Build EzklCoreBindings
|
||||
run: CONFIGURATION=release cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
|
||||
|
||||
- name: Clone ezkl-swift-package repository
|
||||
run: |
|
||||
git clone https://${{ env.EZKL_SWIFT_PACKAGE_REPO }}
|
||||
|
||||
- name: Copy EzklCoreBindings
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
|
||||
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
|
||||
|
||||
- name: Copy Test Files
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Tests/EzklAssets/
|
||||
mkdir -p ezkl-swift-package/Tests/EzklAssets/
|
||||
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
|
||||
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
|
||||
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
|
||||
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
|
||||
|
||||
- name: Check for changes
|
||||
id: check_changes
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
if git diff --quiet Sources/EzklCoreBindings Tests/EzklAssets; then
|
||||
echo "no_changes=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "no_changes=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Set up Xcode environment
|
||||
if: steps.check_changes.outputs.no_changes == 'false'
|
||||
run: |
|
||||
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
|
||||
sudo xcodebuild -license accept
|
||||
|
||||
- name: Run Package Tests
|
||||
if: steps.check_changes.outputs.no_changes == 'false'
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
xcodebuild test \
|
||||
-scheme EzklPackage \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-resultBundlePath ../testResults
|
||||
|
||||
- name: Run Example App Tests
|
||||
if: steps.check_changes.outputs.no_changes == 'false'
|
||||
run: |
|
||||
cd ezkl-swift-package/Example
|
||||
xcodebuild test \
|
||||
-project Example.xcodeproj \
|
||||
-scheme EzklApp \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-parallel-testing-enabled NO \
|
||||
-resultBundlePath ../../exampleTestResults \
|
||||
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
|
||||
|
||||
- name: Setup Git
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
git config user.name "GitHub Action"
|
||||
git config user.email "action@github.com"
|
||||
git remote set-url origin https://zkonduit:${EZKL_SWIFT_PACKAGE_REPO_TOKEN}@${{ env.EZKL_SWIFT_PACKAGE_REPO }}
|
||||
env:
|
||||
EZKL_SWIFT_PACKAGE_REPO_TOKEN: ${{ secrets.EZKL_PORTER_TOKEN }}
|
||||
|
||||
- name: Commit and Push Changes
|
||||
if: steps.check_changes.outputs.no_changes == 'false'
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
git add Sources/EzklCoreBindings Tests/EzklAssets
|
||||
git commit -m "Automatically updated EzklCoreBindings for EZKL"
|
||||
if ! git push origin; then
|
||||
echo "::error::Failed to push changes to ${{ env.EZKL_SWIFT_PACKAGE_REPO }}. Please ensure that EZKL_PORTER_TOKEN has the correct permissions."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Tag the latest commit
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
source $GITHUB_ENV
|
||||
# Tag the latest commit on the current branch
|
||||
if git rev-parse "$TAG" >/dev/null 2>&1; then
|
||||
echo "Tag $TAG already exists locally. Skipping tag creation."
|
||||
else
|
||||
git tag "$TAG"
|
||||
fi
|
||||
|
||||
if ! git push origin "$TAG"; then
|
||||
echo "::error::Failed to push tag '$TAG' to ${{ env.EZKL_SWIFT_PACKAGE_REPO }}. Please ensure EZKL_PORTER_TOKEN has correct permissions."
|
||||
exit 1
|
||||
fi
|
||||
2
.github/workflows/tagging.yml
vendored
2
.github/workflows/tagging.yml
vendored
@@ -12,6 +12,8 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Bump version and push tag
|
||||
id: tag_version
|
||||
uses: mathieudutour/github-tag-action@v6.2
|
||||
|
||||
85
.github/workflows/update-ios-package.yml
vendored
85
.github/workflows/update-ios-package.yml
vendored
@@ -1,85 +0,0 @@
|
||||
name: Build and Publish EZKL iOS SPM package
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: "The tag to release"
|
||||
required: true
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
jobs:
|
||||
build-and-update:
|
||||
runs-on: macos-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout EZKL
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install Rust
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly
|
||||
override: true
|
||||
|
||||
- name: Build EzklCoreBindings
|
||||
run: CONFIGURATION=release cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
|
||||
|
||||
- name: Clone ezkl-swift-package repository
|
||||
run: |
|
||||
git clone https://github.com/zkonduit/ezkl-swift-package.git
|
||||
|
||||
- name: Copy EzklCoreBindings
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
|
||||
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
|
||||
|
||||
- name: Copy Test Files
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Tests/EzklAssets/*
|
||||
|
||||
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
|
||||
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
|
||||
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
|
||||
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
|
||||
|
||||
- name: Set up Xcode environment
|
||||
run: |
|
||||
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
|
||||
sudo xcodebuild -license accept
|
||||
|
||||
- name: Run Package Tests
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
xcodebuild test \
|
||||
-scheme EzklPackage \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-resultBundlePath ../testResults
|
||||
|
||||
- name: Run Example App Tests
|
||||
run: |
|
||||
cd ezkl-swift-package/Example
|
||||
xcodebuild test \
|
||||
-project Example.xcodeproj \
|
||||
-scheme EzklApp \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-parallel-testing-enabled NO \
|
||||
-resultBundlePath ../../exampleTestResults \
|
||||
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
|
||||
|
||||
- name: Commit and Push Changes to feat/ezkl-direct-integration
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
git config user.name "GitHub Action"
|
||||
git config user.email "action@github.com"
|
||||
git add Sources/EzklCoreBindings
|
||||
git add Tests/EzklAssets
|
||||
git commit -m "Automatically updated EzklCoreBindings for EZKL"
|
||||
git tag ${{ github.event.inputs.tag }}
|
||||
git remote set-url origin https://zkonduit:${EZKL_PORTER_TOKEN}@github.com/zkonduit/ezkl-swift-package.git
|
||||
git push origin
|
||||
git push origin tag ${{ github.event.inputs.tag }}
|
||||
env:
|
||||
EZKL_PORTER_TOKEN: ${{ secrets.EZKL_PORTER_TOKEN }}
|
||||
130
Cargo.lock
generated
130
Cargo.lock
generated
@@ -1,6 +1,6 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
version = 4
|
||||
|
||||
[[package]]
|
||||
name = "addr2line"
|
||||
@@ -1760,7 +1760,7 @@ checksum = "a650a461c6a8ff1ef205ed9a2ad56579309853fecefc2423f73dced342f92258"
|
||||
[[package]]
|
||||
name = "ecc"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
dependencies = [
|
||||
"integer",
|
||||
"num-bigint",
|
||||
@@ -1835,6 +1835,16 @@ dependencies = [
|
||||
"syn 2.0.90",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_filter"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0"
|
||||
dependencies = [
|
||||
"log",
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.10.2"
|
||||
@@ -1848,6 +1858,19 @@ dependencies = [
|
||||
"termcolor",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.11.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
"env_filter",
|
||||
"humantime",
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.1"
|
||||
@@ -1923,7 +1946,7 @@ dependencies = [
|
||||
"console_error_panic_hook",
|
||||
"criterion 0.5.1",
|
||||
"ecc",
|
||||
"env_logger",
|
||||
"env_logger 0.10.2",
|
||||
"ethabi",
|
||||
"foundry-compilers",
|
||||
"gag",
|
||||
@@ -1931,7 +1954,7 @@ dependencies = [
|
||||
"halo2_gadgets",
|
||||
"halo2_proofs",
|
||||
"halo2_solidity_verifier",
|
||||
"halo2curves 0.7.0",
|
||||
"halo2curves 0.7.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
|
||||
"hex",
|
||||
"indicatif",
|
||||
"instant",
|
||||
@@ -1939,20 +1962,17 @@ dependencies = [
|
||||
"lazy_static",
|
||||
"log",
|
||||
"maybe-rayon",
|
||||
"metal",
|
||||
"mimalloc",
|
||||
"mnist",
|
||||
"num",
|
||||
"objc",
|
||||
"openssl",
|
||||
"pg_bigdecimal",
|
||||
"portable-atomic",
|
||||
"pyo3",
|
||||
"pyo3-async-runtimes",
|
||||
"pyo3-log",
|
||||
"pyo3-stub-gen",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"semver 1.0.22",
|
||||
"seq-macro",
|
||||
@@ -2377,7 +2397,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_gadgets"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b"
|
||||
source = "git+https://github.com/zkonduit/halo2#d7ecad83c7439fa1cb450ee4a89c2d0b45604ceb"
|
||||
dependencies = [
|
||||
"arrayvec 0.7.4",
|
||||
"bitvec",
|
||||
@@ -2394,14 +2414,14 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_proofs"
|
||||
version = "0.3.0"
|
||||
source = "git+https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b#0654e92bdf725fd44d849bfef3643870a8c7d50b"
|
||||
source = "git+https://github.com/zkonduit/halo2#bf9d0057a82443be48c4779bbe14961c18fb5996#bf9d0057a82443be48c4779bbe14961c18fb5996"
|
||||
dependencies = [
|
||||
"bincode",
|
||||
"blake2b_simd",
|
||||
"env_logger",
|
||||
"env_logger 0.10.2",
|
||||
"ff",
|
||||
"group",
|
||||
"halo2curves 0.7.0",
|
||||
"halo2curves 0.7.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
|
||||
"icicle-bn254",
|
||||
"icicle-core",
|
||||
"icicle-cuda-runtime",
|
||||
@@ -2409,6 +2429,7 @@ dependencies = [
|
||||
"lazy_static",
|
||||
"log",
|
||||
"maybe-rayon",
|
||||
"mopro-msm",
|
||||
"rand_chacha",
|
||||
"rand_core 0.6.4",
|
||||
"rustc-hash 2.0.0",
|
||||
@@ -2494,6 +2515,36 @@ dependencies = [
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "halo2curves"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d380afeef3f1d4d3245b76895172018cfb087d9976a7cabcd5597775b2933e07"
|
||||
dependencies = [
|
||||
"blake2",
|
||||
"digest 0.10.7",
|
||||
"ff",
|
||||
"group",
|
||||
"halo2derive 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"hex",
|
||||
"lazy_static",
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"pairing",
|
||||
"pasta_curves",
|
||||
"paste",
|
||||
"rand 0.8.5",
|
||||
"rand_core 0.6.4",
|
||||
"rayon",
|
||||
"serde",
|
||||
"serde_arrays",
|
||||
"sha2",
|
||||
"static_assertions",
|
||||
"subtle",
|
||||
"unroll",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "halo2curves"
|
||||
version = "0.7.0"
|
||||
@@ -2503,7 +2554,7 @@ dependencies = [
|
||||
"digest 0.10.7",
|
||||
"ff",
|
||||
"group",
|
||||
"halo2derive",
|
||||
"halo2derive 0.1.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
|
||||
"hex",
|
||||
"lazy_static",
|
||||
"num-bigint",
|
||||
@@ -2523,6 +2574,20 @@ dependencies = [
|
||||
"unroll",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "halo2derive"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bdb99e7492b4f5ff469d238db464131b86c2eaac814a78715acba369f64d2c76"
|
||||
dependencies = [
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "halo2derive"
|
||||
version = "0.1.0"
|
||||
@@ -2539,7 +2604,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2wrong"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
dependencies = [
|
||||
"halo2_proofs",
|
||||
"num-bigint",
|
||||
@@ -2890,7 +2955,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "integer"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
dependencies = [
|
||||
"maingate",
|
||||
"num-bigint",
|
||||
@@ -3074,7 +3139,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-targets 0.48.5",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3201,7 +3266,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "maingate"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
|
||||
dependencies = [
|
||||
"halo2wrong",
|
||||
"num-bigint",
|
||||
@@ -3283,7 +3348,8 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "metal"
|
||||
version = "0.29.0"
|
||||
source = "git+https://github.com/gfx-rs/metal-rs#0e1918b34689c4b8cd13a43372f9898680547ee9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21"
|
||||
dependencies = [
|
||||
"bitflags 2.5.0",
|
||||
"block",
|
||||
@@ -3354,6 +3420,28 @@ dependencies = [
|
||||
"byteorder",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mopro-msm"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/metal-msm-gpu-acceleration.git#be5f647b1a6c1a6ea9024390744a2b4d87f5d002"
|
||||
dependencies = [
|
||||
"bincode",
|
||||
"env_logger 0.11.6",
|
||||
"halo2curves 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"instant",
|
||||
"itertools 0.13.0",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"metal",
|
||||
"objc",
|
||||
"once_cell",
|
||||
"rand 0.8.5",
|
||||
"rayon",
|
||||
"serde",
|
||||
"thiserror",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "native-tls"
|
||||
version = "0.2.11"
|
||||
@@ -3587,9 +3675,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-src"
|
||||
version = "300.2.3+3.2.1"
|
||||
version = "300.4.1+3.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5cff92b6f71555b61bb9315f7c64da3ca43d87531622120fea0195fc761b4843"
|
||||
checksum = "faa4eac4138c62414b5622d1b31c5c304f34b406b013c079c2bbc652fdd6678c"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
@@ -5142,7 +5230,7 @@ checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c"
|
||||
[[package]]
|
||||
name = "snark-verifier"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/chunked-mv-lookup#8762701ab8fa04e7d243a346030afd85633ec970"
|
||||
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac%2Fchunked-mv-lookup#8762701ab8fa04e7d243a346030afd85633ec970"
|
||||
dependencies = [
|
||||
"ecc",
|
||||
"halo2_proofs",
|
||||
@@ -6146,7 +6234,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "uniffi_testing"
|
||||
version = "0.28.0"
|
||||
source = "git+https://github.com/ElusAegis/uniffi-rs?branch=feat/testing-feature-build-fix#4684b9e7da2d9c964c2b3a73883947aab7370a06"
|
||||
source = "git+https://github.com/ElusAegis/uniffi-rs?branch=feat%2Ftesting-feature-build-fix#4684b9e7da2d9c964c2b3a73883947aab7370a06"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"camino",
|
||||
|
||||
19
Cargo.toml
19
Cargo.toml
@@ -40,7 +40,6 @@ maybe-rayon = { version = "0.1.1", default-features = false }
|
||||
bincode = { version = "1.3.3", default-features = false }
|
||||
unzip-n = "0.1.2"
|
||||
num = "0.4.1"
|
||||
portable-atomic = { version = "1.6.0", optional = true }
|
||||
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand", optional = true }
|
||||
semver = { version = "1.0.22", optional = true }
|
||||
|
||||
@@ -74,7 +73,6 @@ tokio-postgres = { version = "0.7.10", optional = true }
|
||||
pg_bigdecimal = { version = "0.1.5", optional = true }
|
||||
lazy_static = { version = "1.4.0", optional = true }
|
||||
colored_json = { version = "3.0.1", default-features = false, optional = true }
|
||||
regex = { version = "1", default-features = false, optional = true }
|
||||
tokio = { version = "1.35.0", default-features = false, features = [
|
||||
"macros",
|
||||
"rt-multi-thread",
|
||||
@@ -91,7 +89,6 @@ pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", ver
|
||||
pyo3-log = { version = "0.12.0", default-features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
|
||||
tabled = { version = "0.12.0", optional = true }
|
||||
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
|
||||
objc = { version = "0.2.4", optional = true }
|
||||
mimalloc = { version = "0.1", optional = true }
|
||||
pyo3-stub-gen = { version = "0.6.0", optional = true }
|
||||
@@ -147,6 +144,10 @@ shellexpand = "3.1.0"
|
||||
runner = 'wasm-bindgen-test-runner'
|
||||
|
||||
|
||||
[[bench]]
|
||||
name = "zero_finder"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "accum_dot"
|
||||
harness = false
|
||||
@@ -241,16 +242,14 @@ ezkl = [
|
||||
"dep:indicatif",
|
||||
"dep:gag",
|
||||
"dep:reqwest",
|
||||
"dep:openssl",
|
||||
"dep:tokio-postgres",
|
||||
"dep:pg_bigdecimal",
|
||||
"dep:lazy_static",
|
||||
"dep:regex",
|
||||
"dep:tokio",
|
||||
"dep:openssl",
|
||||
"dep:mimalloc",
|
||||
"dep:chrono",
|
||||
"dep:sha256",
|
||||
"dep:portable-atomic",
|
||||
"dep:clap_complete",
|
||||
"dep:halo2_solidity_verifier",
|
||||
"dep:semver",
|
||||
@@ -273,10 +272,14 @@ icicle = ["halo2_proofs/icicle_gpu"]
|
||||
empty-cmd = []
|
||||
no-banner = []
|
||||
no-update = []
|
||||
|
||||
macos-metal = ["halo2_proofs/macos"]
|
||||
ios-metal = ["halo2_proofs/ios"]
|
||||
|
||||
[patch.'https://github.com/zkonduit/halo2']
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b", package = "halo2_proofs" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#bf9d0057a82443be48c4779bbe14961c18fb5996", package = "halo2_proofs" }
|
||||
|
||||
[patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#bf9d0057a82443be48c4779bbe14961c18fb5996", package = "halo2_proofs" }
|
||||
|
||||
[patch.crates-io]
|
||||
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
|
||||
|
||||
117
benches/zero_finder.rs
Normal file
117
benches/zero_finder.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
use std::thread;
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
use halo2curves::{bn256::Fr as F, ff::Field};
|
||||
use maybe_rayon::{
|
||||
iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator},
|
||||
slice::ParallelSlice,
|
||||
};
|
||||
use rand::Rng;
|
||||
|
||||
// Assuming these are your types
|
||||
#[derive(Clone)]
|
||||
#[allow(dead_code)]
|
||||
enum ValType {
|
||||
Constant(F),
|
||||
AssignedConstant(usize, F),
|
||||
Other,
|
||||
}
|
||||
|
||||
// Helper to generate test data
|
||||
fn generate_test_data(size: usize, zero_probability: f64) -> Vec<ValType> {
|
||||
let mut rng = rand::thread_rng();
|
||||
(0..size)
|
||||
.map(|_i| {
|
||||
if rng.gen::<f64>() < zero_probability {
|
||||
ValType::Constant(F::ZERO)
|
||||
} else {
|
||||
ValType::Constant(F::ONE) // Or some other non-zero value
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn bench_zero_finding(c: &mut Criterion) {
|
||||
let sizes = [
|
||||
1_000, // 1K
|
||||
10_000, // 10K
|
||||
100_000, // 100K
|
||||
256 * 256 * 2, // Our specific case
|
||||
1_000_000, // 1M
|
||||
10_000_000, // 10M
|
||||
];
|
||||
|
||||
let zero_probability = 0.1; // 10% zeros
|
||||
|
||||
let mut group = c.benchmark_group("zero_finding");
|
||||
group.sample_size(10); // Adjust based on your needs
|
||||
|
||||
for &size in &sizes {
|
||||
let data = generate_test_data(size, zero_probability);
|
||||
|
||||
// Benchmark sequential version
|
||||
group.bench_function(format!("sequential_{}", size), |b| {
|
||||
b.iter(|| {
|
||||
let result = data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, e)| match e {
|
||||
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
|
||||
(*r == F::ZERO).then_some(i)
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
black_box(result)
|
||||
})
|
||||
});
|
||||
|
||||
// Benchmark parallel version
|
||||
group.bench_function(format!("parallel_{}", size), |b| {
|
||||
b.iter(|| {
|
||||
let result = data
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, e)| match e {
|
||||
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
|
||||
(*r == F::ZERO).then_some(i)
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
black_box(result)
|
||||
})
|
||||
});
|
||||
|
||||
// Benchmark chunked parallel version
|
||||
group.bench_function(format!("chunked_parallel_{}", size), |b| {
|
||||
b.iter(|| {
|
||||
let num_cores = thread::available_parallelism()
|
||||
.map(|n| n.get())
|
||||
.unwrap_or(1);
|
||||
let chunk_size = (size / num_cores).max(100);
|
||||
|
||||
let result = data
|
||||
.par_chunks(chunk_size)
|
||||
.enumerate()
|
||||
.flat_map(|(chunk_idx, chunk)| {
|
||||
chunk
|
||||
.par_iter() // Make sure we use par_iter() here
|
||||
.enumerate()
|
||||
.filter_map(move |(i, e)| match e {
|
||||
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
|
||||
(*r == F::ZERO).then_some(chunk_idx * chunk_size + i)
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
black_box(result)
|
||||
})
|
||||
});
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(benches, bench_zero_finding);
|
||||
criterion_main!(benches);
|
||||
@@ -1,7 +1,7 @@
|
||||
import ezkl
|
||||
|
||||
project = 'ezkl'
|
||||
release = '0.0.0'
|
||||
release = '18.1.8'
|
||||
version = release
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -54,7 +54,7 @@
|
||||
" gip_run_args.param_scale = 19\n",
|
||||
" gip_run_args.logrows = 8\n",
|
||||
" run_args = ezkl.gen_settings(py_run_args=gip_run_args)\n",
|
||||
" ezkl.get_srs(commitment=ezkl.PyCommitments.KZG)\n",
|
||||
" await ezkl.get_srs(commitment=ezkl.PyCommitments.KZG)\n",
|
||||
" ezkl.compile_circuit()\n",
|
||||
" res = await ezkl.gen_witness()\n",
|
||||
" print(res)\n",
|
||||
|
||||
@@ -1,279 +1,284 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Linear Regression\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
|
||||
"This notebook showcases how to do so using the `hummingbird-ml` python package ! "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "95613ee9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import torch\n",
|
||||
"import ezkl\n",
|
||||
"import json\n",
|
||||
"from hummingbird.ml import convert\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# here we create and (potentially train a model)\n",
|
||||
"\n",
|
||||
"# make sure you have the dependencies required here already installed\n",
|
||||
"import numpy as np\n",
|
||||
"from sklearn.linear_model import LinearRegression\n",
|
||||
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
|
||||
"# y = 1 * x_0 + 2 * x_1 + 3\n",
|
||||
"y = np.dot(X, np.array([1, 2])) + 3\n",
|
||||
"reg = LinearRegression().fit(X, y)\n",
|
||||
"reg.score(X, y)\n",
|
||||
"\n",
|
||||
"circuit = convert(reg, \"torch\", X[:1]).model\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b37637c4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_path = os.path.join('network.onnx')\n",
|
||||
"compiled_model_path = os.path.join('network.compiled')\n",
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"settings_path = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"witness_path = os.path.join('witness.json')\n",
|
||||
"data_path = os.path.join('input.json')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "82db373a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"# export to onnx format\n",
|
||||
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
|
||||
"\n",
|
||||
"# Input to the model\n",
|
||||
"shape = X.shape[1:]\n",
|
||||
"x = torch.rand(1, *shape, requires_grad=True)\n",
|
||||
"torch_out = circuit(x)\n",
|
||||
"# Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" # model input (or a tuple for multiple inputs)\n",
|
||||
" x,\n",
|
||||
" # where to save the model (can be a file or file-like object)\n",
|
||||
" \"network.onnx\",\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=10, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names=['input'], # the model's input names\n",
|
||||
" output_names=['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
|
||||
" 'output': {0: 'batch_size'}})\n",
|
||||
"\n",
|
||||
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_shapes=[shape],\n",
|
||||
" input_data=[d],\n",
|
||||
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
|
||||
"\n",
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(\"input.json\", 'w'))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d5e374a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!RUST_LOG=trace\n",
|
||||
"# TODO: Dictionary outputs\n",
|
||||
"res = ezkl.gen_settings(model_path, settings_path)\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cal_path = os.path.join(\"calibration.json\")\n",
|
||||
"\n",
|
||||
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_data = [data_array])\n",
|
||||
"\n",
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3aa4f090",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8b74dcee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "18c8b7c7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b1c561a8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
|
||||
"# WE GOT KEYS\n",
|
||||
"# WE GOT CIRCUIT PARAMETERS\n",
|
||||
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_model_path,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c384cbc8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# GENERATE A PROOF\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"res = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_model_path,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \n",
|
||||
" \"single\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"print(res)\n",
|
||||
"assert os.path.isfile(proof_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "76f00d41",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# VERIFY IT\n",
|
||||
"\n",
|
||||
"res = ezkl.verify(\n",
|
||||
" proof_path,\n",
|
||||
" settings_path,\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Linear Regression\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
|
||||
"This notebook showcases how to do so using the `hummingbird-ml` python package ! "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "95613ee9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import torch\n",
|
||||
"import ezkl\n",
|
||||
"import json\n",
|
||||
"from hummingbird.ml import convert\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# here we create and (potentially train a model)\n",
|
||||
"\n",
|
||||
"# make sure you have the dependencies required here already installed\n",
|
||||
"import numpy as np\n",
|
||||
"from sklearn.linear_model import LinearRegression\n",
|
||||
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
|
||||
"# y = 1 * x_0 + 2 * x_1 + 3\n",
|
||||
"y = np.dot(X, np.array([1, 2])) + 3\n",
|
||||
"reg = LinearRegression().fit(X, y)\n",
|
||||
"reg.score(X, y)\n",
|
||||
"\n",
|
||||
"circuit = convert(reg, \"torch\", X[:1]).model\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b37637c4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_path = os.path.join('network.onnx')\n",
|
||||
"compiled_model_path = os.path.join('network.compiled')\n",
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"settings_path = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"witness_path = os.path.join('witness.json')\n",
|
||||
"data_path = os.path.join('input.json')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "82db373a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"# export to onnx format\n",
|
||||
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
|
||||
"\n",
|
||||
"# Input to the model\n",
|
||||
"shape = X.shape[1:]\n",
|
||||
"x = torch.rand(1, *shape, requires_grad=True)\n",
|
||||
"torch_out = circuit(x)\n",
|
||||
"# Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" # model input (or a tuple for multiple inputs)\n",
|
||||
" x,\n",
|
||||
" # where to save the model (can be a file or file-like object)\n",
|
||||
" \"network.onnx\",\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=10, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names=['input'], # the model's input names\n",
|
||||
" output_names=['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
|
||||
" 'output': {0: 'batch_size'}})\n",
|
||||
"\n",
|
||||
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_shapes=[shape],\n",
|
||||
" input_data=[d],\n",
|
||||
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
|
||||
"\n",
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(\"input.json\", 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# note that you can also call the following function to generate random data for the model\n",
|
||||
"# it is functionally equivalent to the code above\n",
|
||||
"ezkl.gen_random_data()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d5e374a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!RUST_LOG=trace\n",
|
||||
"# TODO: Dictionary outputs\n",
|
||||
"res = ezkl.gen_settings(model_path, settings_path)\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cal_path = os.path.join(\"calibration.json\")\n",
|
||||
"\n",
|
||||
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_data = [data_array])\n",
|
||||
"\n",
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3aa4f090",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8b74dcee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "18c8b7c7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b1c561a8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
|
||||
"# WE GOT KEYS\n",
|
||||
"# WE GOT CIRCUIT PARAMETERS\n",
|
||||
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_model_path,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c384cbc8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# GENERATE A PROOF\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"res = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_model_path,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \n",
|
||||
" \"single\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"print(res)\n",
|
||||
"assert os.path.isfile(proof_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "76f00d41",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# VERIFY IT\n",
|
||||
"\n",
|
||||
"res = ezkl.verify(\n",
|
||||
" proof_path,\n",
|
||||
" settings_path,\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
|
||||
@@ -1,456 +1,459 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Mean of ERC20 transfer amounts\n",
|
||||
"\n",
|
||||
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
|
||||
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
|
||||
"\n",
|
||||
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import getpass\n",
|
||||
"import json\n",
|
||||
"import time\n",
|
||||
"import subprocess\n",
|
||||
"\n",
|
||||
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
|
||||
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
|
||||
"os.system(\"chmod +x shovel\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
|
||||
"\n",
|
||||
"# create a config.json file with the following contents\n",
|
||||
"config = {\n",
|
||||
" \"pg_url\": \"$PG_URL\",\n",
|
||||
" \"eth_sources\": [\n",
|
||||
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
|
||||
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
|
||||
" ],\n",
|
||||
" \"integrations\": [{\n",
|
||||
" \"name\": \"usdc_transfer\",\n",
|
||||
" \"enabled\": True,\n",
|
||||
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
|
||||
" \"table\": {\n",
|
||||
" \"name\": \"usdc\",\n",
|
||||
" \"columns\": [\n",
|
||||
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
|
||||
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"block\": [\n",
|
||||
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
|
||||
" {\n",
|
||||
" \"name\": \"log_addr\",\n",
|
||||
" \"column\": \"log_addr\",\n",
|
||||
" \"filter_op\": \"contains\",\n",
|
||||
" \"filter_arg\": [\n",
|
||||
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
|
||||
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" \"event\": {\n",
|
||||
" \"name\": \"Transfer\",\n",
|
||||
" \"type\": \"event\",\n",
|
||||
" \"anonymous\": False,\n",
|
||||
" \"inputs\": [\n",
|
||||
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
|
||||
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
|
||||
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" }]\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# write the config to a file\n",
|
||||
"with open(\"config.json\", \"w\") as f:\n",
|
||||
" f.write(json.dumps(config))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# print the two env variables\n",
|
||||
"os.system(\"echo $PG_URL\")\n",
|
||||
"\n",
|
||||
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel is now installed. starting:\")\n",
|
||||
"\n",
|
||||
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
|
||||
"proc = subprocess.Popen(command)\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel started.\")\n",
|
||||
"\n",
|
||||
"time.sleep(10)\n",
|
||||
"\n",
|
||||
"# after we've fetched some data -- kill the process\n",
|
||||
"proc.terminate()\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "2wIAHwqH2_mo"
|
||||
},
|
||||
"source": [
|
||||
"**Import Dependencies**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "9Byiv2Nc2MsK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"import ezkl\n",
|
||||
"import torch\n",
|
||||
"import datetime\n",
|
||||
"import pandas as pd\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import logging\n",
|
||||
"# # uncomment for more descriptive logging \n",
|
||||
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
|
||||
"logging.basicConfig(format=FORMAT)\n",
|
||||
"logging.getLogger().setLevel(logging.DEBUG)\n",
|
||||
"\n",
|
||||
"print(\"ezkl version: \", ezkl.__version__)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "osjj-0Ta3E8O"
|
||||
},
|
||||
"source": [
|
||||
"**Create Computational Graph**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "x1vl9ZXF3EEW",
|
||||
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch import nn\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Model(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Model, self).__init__()\n",
|
||||
"\n",
|
||||
" # x is a time series \n",
|
||||
" def forward(self, x):\n",
|
||||
" return [torch.mean(x)]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"circuit = Model()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"x = 0.1*torch.rand(1,*[1,5], requires_grad=True)\n",
|
||||
"\n",
|
||||
"# # print(torch.__version__)\n",
|
||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"\n",
|
||||
"print(device)\n",
|
||||
"\n",
|
||||
"circuit.to(device)\n",
|
||||
"\n",
|
||||
"# Flips the neural net into inference mode\n",
|
||||
"circuit.eval()\n",
|
||||
"\n",
|
||||
"# Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" x, # model input (or a tuple for multiple inputs)\n",
|
||||
" \"lol.onnx\", # where to save the model (can be a file or file-like object)\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=11, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names = ['input'], # the model's input names\n",
|
||||
" output_names = ['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
|
||||
" 'output' : {0 : 'batch_size'}})\n",
|
||||
"\n",
|
||||
"# export(circuit, input_shape=[1, 20])\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "E3qCeX-X5xqd"
|
||||
},
|
||||
"source": [
|
||||
"**Set Data Source and Get Data**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "6RAMplxk5xPk",
|
||||
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"# make an input.json file from the df above\n",
|
||||
"input_filename = os.path.join('input.json')\n",
|
||||
"\n",
|
||||
"pg_input_file = dict(input_data = {\n",
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
"json_formatted_str = json.dumps(pg_input_file, indent=2)\n",
|
||||
"print(json_formatted_str)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# this corresponds to 4 batches\n",
|
||||
"calibration_filename = os.path.join('calibration.json')\n",
|
||||
"\n",
|
||||
"pg_cal_file = dict(input_data = {\n",
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump( pg_cal_file, open(calibration_filename, 'w' ))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "eLJ7oirQ_HQR"
|
||||
},
|
||||
"source": [
|
||||
"**EZKL Workflow**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "rNw0C9QL6W88"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import subprocess\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"onnx_filename = os.path.join('lol.onnx')\n",
|
||||
"compiled_filename = os.path.join('lol.compiled')\n",
|
||||
"settings_filename = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"# Generate settings using ezkl\n",
|
||||
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
|
||||
"\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "4MmE9SX66_Il",
|
||||
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# generate settings\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# show the settings.json\n",
|
||||
"with open(\"settings.json\") as f:\n",
|
||||
" data = json.load(f)\n",
|
||||
" json_formatted_str = json.dumps(data, indent=2)\n",
|
||||
"\n",
|
||||
" print(json_formatted_str)\n",
|
||||
"\n",
|
||||
"assert os.path.exists(\"settings.json\")\n",
|
||||
"assert os.path.exists(\"input.json\")\n",
|
||||
"assert os.path.exists(\"lol.onnx\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "fULvvnK7_CMb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# setup the proof\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_filename,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_filename)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"# generate the witness\n",
|
||||
"res = await ezkl.gen_witness(\n",
|
||||
" input_filename,\n",
|
||||
" compiled_filename,\n",
|
||||
" witness_path\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Oog3j6Kd-Wed",
|
||||
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# prove the zk circuit\n",
|
||||
"# GENERATE A PROOF\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"proof = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_filename,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \"single\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"proved\")\n",
|
||||
"\n",
|
||||
"assert os.path.isfile(proof_path)\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
}
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Mean of ERC20 transfer amounts\n",
|
||||
"\n",
|
||||
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
|
||||
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
|
||||
"\n",
|
||||
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import getpass\n",
|
||||
"import json\n",
|
||||
"import time\n",
|
||||
"import subprocess\n",
|
||||
"\n",
|
||||
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
|
||||
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
|
||||
"os.system(\"chmod +x shovel\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
|
||||
"\n",
|
||||
"# create a config.json file with the following contents\n",
|
||||
"config = {\n",
|
||||
" \"pg_url\": \"$PG_URL\",\n",
|
||||
" \"eth_sources\": [\n",
|
||||
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
|
||||
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
|
||||
" ],\n",
|
||||
" \"integrations\": [{\n",
|
||||
" \"name\": \"usdc_transfer\",\n",
|
||||
" \"enabled\": True,\n",
|
||||
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
|
||||
" \"table\": {\n",
|
||||
" \"name\": \"usdc\",\n",
|
||||
" \"columns\": [\n",
|
||||
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
|
||||
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"block\": [\n",
|
||||
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
|
||||
" {\n",
|
||||
" \"name\": \"log_addr\",\n",
|
||||
" \"column\": \"log_addr\",\n",
|
||||
" \"filter_op\": \"contains\",\n",
|
||||
" \"filter_arg\": [\n",
|
||||
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
|
||||
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" \"event\": {\n",
|
||||
" \"name\": \"Transfer\",\n",
|
||||
" \"type\": \"event\",\n",
|
||||
" \"anonymous\": False,\n",
|
||||
" \"inputs\": [\n",
|
||||
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
|
||||
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
|
||||
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" }]\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# write the config to a file\n",
|
||||
"with open(\"config.json\", \"w\") as f:\n",
|
||||
" f.write(json.dumps(config))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# print the two env variables\n",
|
||||
"os.system(\"echo $PG_URL\")\n",
|
||||
"\n",
|
||||
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel is now installed. starting:\")\n",
|
||||
"\n",
|
||||
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
|
||||
"proc = subprocess.Popen(command)\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel started.\")\n",
|
||||
"\n",
|
||||
"time.sleep(10)\n",
|
||||
"\n",
|
||||
"# after we've fetched some data -- kill the process\n",
|
||||
"proc.terminate()\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "2wIAHwqH2_mo"
|
||||
},
|
||||
"source": [
|
||||
"**Import Dependencies**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "9Byiv2Nc2MsK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"import ezkl\n",
|
||||
"import torch\n",
|
||||
"import datetime\n",
|
||||
"import pandas as pd\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import logging\n",
|
||||
"# # uncomment for more descriptive logging \n",
|
||||
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
|
||||
"logging.basicConfig(format=FORMAT)\n",
|
||||
"logging.getLogger().setLevel(logging.DEBUG)\n",
|
||||
"\n",
|
||||
"print(\"ezkl version: \", ezkl.__version__)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "osjj-0Ta3E8O"
|
||||
},
|
||||
"source": [
|
||||
"**Create Computational Graph**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "x1vl9ZXF3EEW",
|
||||
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch import nn\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Model(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Model, self).__init__()\n",
|
||||
"\n",
|
||||
" # x is a time series \n",
|
||||
" def forward(self, x):\n",
|
||||
" return [torch.mean(x)]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"circuit = Model()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"x = 0.1*torch.rand(1,*[1,5], requires_grad=True)\n",
|
||||
"\n",
|
||||
"# # print(torch.__version__)\n",
|
||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"\n",
|
||||
"print(device)\n",
|
||||
"\n",
|
||||
"circuit.to(device)\n",
|
||||
"\n",
|
||||
"# Flips the neural net into inference mode\n",
|
||||
"circuit.eval()\n",
|
||||
"\n",
|
||||
"# Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" x, # model input (or a tuple for multiple inputs)\n",
|
||||
" \"lol.onnx\", # where to save the model (can be a file or file-like object)\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=11, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names = ['input'], # the model's input names\n",
|
||||
" output_names = ['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
|
||||
" 'output' : {0 : 'batch_size'}})\n",
|
||||
"\n",
|
||||
"# export(circuit, input_shape=[1, 20])\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "E3qCeX-X5xqd"
|
||||
},
|
||||
"source": [
|
||||
"**Set Data Source and Get Data**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "6RAMplxk5xPk",
|
||||
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"# make an input.json file from the df above\n",
|
||||
"input_filename = os.path.join('input.json')\n",
|
||||
"\n",
|
||||
"pg_input_file = dict(input_data = {\n",
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
"json_formatted_str = json.dumps(pg_input_file, indent=2)\n",
|
||||
"print(json_formatted_str)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# this corresponds to 4 batches\n",
|
||||
"calibration_filename = os.path.join('calibration.json')\n",
|
||||
"\n",
|
||||
"pg_cal_file = dict(input_data = {\n",
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump( pg_cal_file, open(calibration_filename, 'w' ))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "eLJ7oirQ_HQR"
|
||||
},
|
||||
"source": [
|
||||
"**EZKL Workflow**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "rNw0C9QL6W88"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import subprocess\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"onnx_filename = os.path.join('lol.onnx')\n",
|
||||
"compiled_filename = os.path.join('lol.compiled')\n",
|
||||
"settings_filename = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"# Generate settings using ezkl\n",
|
||||
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"await ezkl.get_srs(settings_filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "4MmE9SX66_Il",
|
||||
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# generate settings\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# show the settings.json\n",
|
||||
"with open(\"settings.json\") as f:\n",
|
||||
" data = json.load(f)\n",
|
||||
" json_formatted_str = json.dumps(data, indent=2)\n",
|
||||
"\n",
|
||||
" print(json_formatted_str)\n",
|
||||
"\n",
|
||||
"assert os.path.exists(\"settings.json\")\n",
|
||||
"assert os.path.exists(\"input.json\")\n",
|
||||
"assert os.path.exists(\"lol.onnx\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "fULvvnK7_CMb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# setup the proof\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_filename,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_filename)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"# generate the witness\n",
|
||||
"res = await ezkl.gen_witness(\n",
|
||||
" input_filename,\n",
|
||||
" compiled_filename,\n",
|
||||
" witness_path\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Oog3j6Kd-Wed",
|
||||
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# prove the zk circuit\n",
|
||||
"# GENERATE A PROOF\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"proof = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_filename,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \"single\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"proved\")\n",
|
||||
"\n",
|
||||
"assert os.path.isfile(proof_path)\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
|
||||
@@ -453,18 +453,18 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# now mock aggregate the proofs\n",
|
||||
"proofs = []\n",
|
||||
"for i in range(3):\n",
|
||||
" proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
|
||||
" proofs.append(proof_path)\n",
|
||||
"# proofs = []\n",
|
||||
"# for i in range(3):\n",
|
||||
"# proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
|
||||
"# proofs.append(proof_path)\n",
|
||||
"\n",
|
||||
"ezkl.mock_aggregate(proofs, logrows=23, split_proofs = True)"
|
||||
"# ezkl.mock_aggregate(proofs, logrows=26, split_proofs = True)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "ezkl",
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -478,7 +478,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
"version": "3.12.7"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
|
||||
1
examples/onnx/fr_age/input.json
Normal file
1
examples/onnx/fr_age/input.json
Normal file
File diff suppressed because one or more lines are too long
BIN
examples/onnx/fr_age/network.onnx
Normal file
BIN
examples/onnx/fr_age/network.onnx
Normal file
Binary file not shown.
@@ -12,6 +12,7 @@ asyncio_mode = "auto"
|
||||
|
||||
[project]
|
||||
name = "ezkl"
|
||||
version = "0.0.0"
|
||||
requires-python = ">=3.7"
|
||||
classifiers = [
|
||||
"Programming Language :: Rust",
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
// ignore file if compiling for wasm
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use mimalloc::MiMalloc;
|
||||
|
||||
#[global_allocator]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
||||
static GLOBAL: MiMalloc = MiMalloc;
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use clap::{CommandFactory, Parser};
|
||||
|
||||
@@ -938,6 +938,45 @@ fn gen_settings(
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Generates random data for the model
|
||||
///
|
||||
/// Arguments
|
||||
/// ---------
|
||||
/// model: str
|
||||
/// Path to the onnx file
|
||||
///
|
||||
/// output: str
|
||||
/// Path to create the data file
|
||||
///
|
||||
/// seed: int
|
||||
/// Random seed to use for generated data
|
||||
///
|
||||
/// variables
|
||||
/// Returns
|
||||
/// -------
|
||||
/// bool
|
||||
///
|
||||
#[pyfunction(signature = (
|
||||
model=PathBuf::from(DEFAULT_MODEL),
|
||||
output=PathBuf::from(DEFAULT_SETTINGS),
|
||||
variables=Vec::from([("batch_size".to_string(), 1)]),
|
||||
seed=DEFAULT_SEED.parse().unwrap(),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn gen_random_data(
|
||||
model: PathBuf,
|
||||
output: PathBuf,
|
||||
variables: Vec<(String, usize)>,
|
||||
seed: u64,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::gen_random_data(model, output, variables, seed).map_err(|e| {
|
||||
let err_str = format!("Failed to generate settings: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Calibrates the circuit settings
|
||||
///
|
||||
/// Arguments
|
||||
@@ -2055,6 +2094,7 @@ fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(get_srs, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(gen_witness, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(gen_settings, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(gen_random_data, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(calibrate_settings, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(aggregate, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(mock_aggregate, m)?)?;
|
||||
|
||||
@@ -141,10 +141,11 @@ pub(crate) fn gen_vk(
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to create verifying key: {}", e)))?;
|
||||
|
||||
let mut serialized_vk = Vec::new();
|
||||
vk.write(&mut serialized_vk, halo2_proofs::SerdeFormat::RawBytes)
|
||||
.map_err(|e| {
|
||||
EZKLError::InternalError(format!("Failed to serialize verifying key: {}", e))
|
||||
})?;
|
||||
vk.write(
|
||||
&mut serialized_vk,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize verifying key: {}", e)))?;
|
||||
|
||||
Ok(serialized_vk)
|
||||
}
|
||||
@@ -165,7 +166,7 @@ pub(crate) fn gen_pk(
|
||||
let mut reader = BufReader::new(&vk[..]);
|
||||
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
circuit.settings().clone(),
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
|
||||
@@ -197,7 +198,7 @@ pub(crate) fn verify(
|
||||
let mut reader = BufReader::new(&vk[..]);
|
||||
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
circuit_settings.clone(),
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
|
||||
@@ -277,7 +278,7 @@ pub(crate) fn verify_aggr(
|
||||
let mut reader = BufReader::new(&vk[..]);
|
||||
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
(),
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
|
||||
@@ -365,7 +366,7 @@ pub(crate) fn prove(
|
||||
let mut reader = BufReader::new(&pk[..]);
|
||||
let pk = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
circuit.settings().clone(),
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;
|
||||
@@ -487,7 +488,7 @@ pub(crate) fn vk_validation(vk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKL
|
||||
let mut reader = BufReader::new(&vk[..]);
|
||||
let _ = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
circuit_settings,
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
|
||||
@@ -504,7 +505,7 @@ pub(crate) fn pk_validation(pk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKL
|
||||
let mut reader = BufReader::new(&pk[..]);
|
||||
let _ = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
circuit_settings,
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;
|
||||
|
||||
@@ -100,9 +100,6 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
|
||||
for input in hash_inputs.iter().take(WIDTH) {
|
||||
meta.enable_equality(*input);
|
||||
}
|
||||
meta.enable_constant(rc_b[0]);
|
||||
|
||||
Self::configure_with_cols(
|
||||
@@ -152,9 +149,6 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
|
||||
for input in hash_inputs.iter().take(WIDTH) {
|
||||
meta.enable_equality(*input);
|
||||
}
|
||||
meta.enable_constant(rc_b[0]);
|
||||
|
||||
let instance = meta.instance_column();
|
||||
@@ -176,7 +170,10 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
message: &[ValTensor<Fp>],
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<Self::InputAssignments, ModuleError> {
|
||||
assert_eq!(message.len(), 1);
|
||||
if message.len() != 1 {
|
||||
return Err(ModuleError::InputWrongLength(message.len()));
|
||||
}
|
||||
|
||||
let message = message[0].clone();
|
||||
|
||||
let start_time = instant::Instant::now();
|
||||
@@ -231,7 +228,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
}
|
||||
e => Err(ModuleError::WrongInputType(
|
||||
format!("{:?}", e),
|
||||
"PrevAssigned".to_string(),
|
||||
"AssignedValue".to_string(),
|
||||
)),
|
||||
}
|
||||
})
|
||||
@@ -296,6 +293,12 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<ValTensor<Fp>, ModuleError> {
|
||||
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
|
||||
|
||||
// empty hash case
|
||||
if input_cells.is_empty() {
|
||||
return Ok(input[0].clone());
|
||||
}
|
||||
|
||||
// extract the values from the input cells
|
||||
let mut assigned_input: Tensor<ValType<Fp>> =
|
||||
input_cells.iter().map(|e| ValType::from(e.clone())).into();
|
||||
@@ -517,6 +520,21 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn poseidon_hash_empty() {
|
||||
let message = [];
|
||||
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 2>::run(message.to_vec()).unwrap();
|
||||
let mut message: Tensor<ValType<Fp>> =
|
||||
message.into_iter().map(|m| Value::known(m).into()).into();
|
||||
let k = 9;
|
||||
let circuit = HashCircuit::<PoseidonSpec, 2> {
|
||||
message: message.into(),
|
||||
_spec: PhantomData,
|
||||
};
|
||||
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
|
||||
assert_eq!(prover.verify(), Ok(()))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn poseidon_hash() {
|
||||
let rng = rand::rngs::OsRng;
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::str::FromStr;
|
||||
|
||||
use halo2_proofs::{
|
||||
circuit::Layouter,
|
||||
plonk::{ConstraintSystem, Constraints, Expression, Selector},
|
||||
plonk::{ConstraintSystem, Constraints, Expression, Selector, TableColumn},
|
||||
poly::Rotation,
|
||||
};
|
||||
use log::debug;
|
||||
@@ -75,6 +75,16 @@ impl FromStr for CheckMode {
|
||||
}
|
||||
}
|
||||
|
||||
impl CheckMode {
|
||||
/// Returns the value of the check mode
|
||||
pub fn is_safe(&self) -> bool {
|
||||
match self {
|
||||
CheckMode::SAFE => true,
|
||||
CheckMode::UNSAFE => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
/// An enum representing the tolerance we can accept for the accumulated arguments, either absolute or percentage
|
||||
#[derive(Clone, Default, Debug, PartialEq, PartialOrd, Serialize, Deserialize, Copy)]
|
||||
@@ -205,15 +215,16 @@ impl DynamicLookups {
|
||||
|
||||
/// A struct representing the selectors for the dynamic lookup tables
|
||||
#[derive(Clone, Debug, Default)]
|
||||
|
||||
pub struct Shuffles {
|
||||
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
|
||||
pub input_selectors: BTreeMap<(usize, (usize, usize)), Selector>,
|
||||
/// Selectors for the dynamic lookup tables
|
||||
pub reference_selectors: Vec<Selector>,
|
||||
pub output_selectors: Vec<Selector>,
|
||||
/// Inputs:
|
||||
pub inputs: Vec<VarTensor>,
|
||||
/// tables
|
||||
pub references: Vec<VarTensor>,
|
||||
pub outputs: Vec<VarTensor>,
|
||||
}
|
||||
|
||||
impl Shuffles {
|
||||
@@ -224,9 +235,13 @@ impl Shuffles {
|
||||
|
||||
Self {
|
||||
input_selectors: BTreeMap::new(),
|
||||
reference_selectors: vec![],
|
||||
inputs: vec![dummy_var.clone(), dummy_var.clone()],
|
||||
references: vec![single_col_dummy_var.clone(), single_col_dummy_var.clone()],
|
||||
output_selectors: vec![],
|
||||
inputs: vec![dummy_var.clone(), dummy_var.clone(), dummy_var.clone()],
|
||||
outputs: vec![
|
||||
single_col_dummy_var.clone(),
|
||||
single_col_dummy_var.clone(),
|
||||
single_col_dummy_var.clone(),
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -326,6 +341,8 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
/// Activate sanity checks
|
||||
pub check_mode: CheckMode,
|
||||
_marker: PhantomData<F>,
|
||||
/// shared table inputs
|
||||
pub shared_table_inputs: Vec<TableColumn>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
@@ -338,6 +355,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
shuffles: Shuffles::dummy(col_size, num_inner_cols),
|
||||
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
|
||||
check_mode: CheckMode::SAFE,
|
||||
shared_table_inputs: vec![],
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
@@ -364,6 +382,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
if inputs[0].num_cols() != output.num_cols() {
|
||||
log::warn!("input and output shapes do not match");
|
||||
}
|
||||
if inputs[0].num_inner_cols() != inputs[1].num_inner_cols() {
|
||||
log::warn!("input number of inner columns do not match");
|
||||
}
|
||||
if inputs[0].num_inner_cols() != output.num_inner_cols() {
|
||||
log::warn!("input and output number of inner columns do not match");
|
||||
}
|
||||
|
||||
for i in 0..output.num_blocks() {
|
||||
for j in 0..output.num_inner_cols() {
|
||||
@@ -476,6 +500,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
dynamic_lookups: DynamicLookups::default(),
|
||||
shuffles: Shuffles::default(),
|
||||
range_checks: RangeChecks::default(),
|
||||
shared_table_inputs: vec![],
|
||||
check_mode,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
@@ -506,21 +531,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
return Err(CircuitError::WrongColumnType(output.name().to_string()));
|
||||
}
|
||||
|
||||
// we borrow mutably twice so we need to do this dance
|
||||
|
||||
let table = if !self.static_lookups.tables.contains_key(nl) {
|
||||
// as all tables have the same input we see if there's another table who's input we can reuse
|
||||
let table = if let Some(table) = self.static_lookups.tables.values().next() {
|
||||
Table::<F>::configure(
|
||||
cs,
|
||||
lookup_range,
|
||||
logrows,
|
||||
nl,
|
||||
Some(table.table_inputs.clone()),
|
||||
)
|
||||
} else {
|
||||
Table::<F>::configure(cs, lookup_range, logrows, nl, None)
|
||||
};
|
||||
let table =
|
||||
Table::<F>::configure(cs, lookup_range, logrows, nl, &mut self.shared_table_inputs);
|
||||
self.static_lookups.tables.insert(nl.clone(), table.clone());
|
||||
table
|
||||
} else {
|
||||
@@ -571,9 +584,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
// this is 0 if the index is the same as the column index (starting from 1)
|
||||
|
||||
let col_expr = sel.clone()
|
||||
* table
|
||||
* (table
|
||||
.selector_constructor
|
||||
.get_expr_at_idx(col_idx, synthetic_sel);
|
||||
.get_expr_at_idx(col_idx, synthetic_sel));
|
||||
|
||||
let multiplier =
|
||||
table.selector_constructor.get_selector_val_at_idx(col_idx);
|
||||
@@ -605,6 +618,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
res
|
||||
});
|
||||
}
|
||||
|
||||
// add a degree-k custom constraint of the following form to the range check and
|
||||
// static lookup configuration.
|
||||
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 − 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
|
||||
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
|
||||
cs.create_gate("range_check_on_sel", |cs| {
|
||||
let synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(1)),
|
||||
_ => match index {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
};
|
||||
|
||||
let range_check_on_synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(0)),
|
||||
_ => {
|
||||
let mut initial_expr = Expression::Constant(F::from(1));
|
||||
for i in 0..len {
|
||||
initial_expr = initial_expr
|
||||
* (synthetic_sel.clone()
|
||||
- Expression::Constant(F::from(i as u64)))
|
||||
}
|
||||
initial_expr
|
||||
}
|
||||
};
|
||||
|
||||
let sel = cs.query_selector(multi_col_selector);
|
||||
|
||||
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
|
||||
});
|
||||
|
||||
self.static_lookups
|
||||
.selectors
|
||||
.insert((nl.clone(), x, y), multi_col_selector);
|
||||
@@ -730,8 +777,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
pub fn configure_shuffles(
|
||||
&mut self,
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
inputs: &[VarTensor; 2],
|
||||
references: &[VarTensor; 2],
|
||||
inputs: &[VarTensor; 3],
|
||||
outputs: &[VarTensor; 3],
|
||||
) -> Result<(), CircuitError>
|
||||
where
|
||||
F: Field,
|
||||
@@ -742,14 +789,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
}
|
||||
}
|
||||
|
||||
for t in references.iter() {
|
||||
for t in outputs.iter() {
|
||||
if !t.is_advice() || t.num_inner_cols() > 1 {
|
||||
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// assert all tables have the same number of blocks
|
||||
if references
|
||||
if outputs
|
||||
.iter()
|
||||
.map(|t| t.num_blocks())
|
||||
.collect::<Vec<_>>()
|
||||
@@ -757,23 +804,23 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
.any(|w| w[0] != w[1])
|
||||
{
|
||||
return Err(CircuitError::WrongDynamicColumnType(
|
||||
"references inner cols".to_string(),
|
||||
"outputs inner cols".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let one = Expression::Constant(F::ONE);
|
||||
|
||||
for q in 0..references[0].num_blocks() {
|
||||
let s_reference = cs.complex_selector();
|
||||
for q in 0..outputs[0].num_blocks() {
|
||||
let s_output = cs.complex_selector();
|
||||
|
||||
for x in 0..inputs[0].num_blocks() {
|
||||
for y in 0..inputs[0].num_inner_cols() {
|
||||
let s_input = cs.complex_selector();
|
||||
|
||||
cs.lookup_any("lookup", |cs| {
|
||||
cs.lookup_any("shuffle", |cs| {
|
||||
let s_inputq = cs.query_selector(s_input);
|
||||
let mut expression = vec![];
|
||||
let s_referenceq = cs.query_selector(s_reference);
|
||||
let s_outputq = cs.query_selector(s_output);
|
||||
let mut input_queries = vec![one.clone()];
|
||||
|
||||
for input in inputs {
|
||||
@@ -785,9 +832,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
});
|
||||
}
|
||||
|
||||
let mut ref_queries = vec![one.clone()];
|
||||
for reference in references {
|
||||
ref_queries.push(match reference {
|
||||
let mut output_queries = vec![one.clone()];
|
||||
for output in outputs {
|
||||
output_queries.push(match output {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[q][0], Rotation(0))
|
||||
}
|
||||
@@ -796,7 +843,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
}
|
||||
|
||||
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
|
||||
let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
|
||||
let rhs = output_queries.into_iter().map(|c| c * s_outputq.clone());
|
||||
expression.extend(lhs.zip(rhs));
|
||||
|
||||
expression
|
||||
@@ -807,13 +854,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
.or_insert(s_input);
|
||||
}
|
||||
}
|
||||
self.shuffles.reference_selectors.push(s_reference);
|
||||
self.shuffles.output_selectors.push(s_output);
|
||||
}
|
||||
|
||||
// if we haven't previously initialized the input/output, do so now
|
||||
if self.shuffles.references.is_empty() {
|
||||
debug!("assigning shuffles reference");
|
||||
self.shuffles.references = references.to_vec();
|
||||
if self.shuffles.outputs.is_empty() {
|
||||
debug!("assigning shuffles output");
|
||||
self.shuffles.outputs = outputs.to_vec();
|
||||
}
|
||||
if self.shuffles.inputs.is_empty() {
|
||||
debug!("assigning shuffles input");
|
||||
@@ -845,7 +892,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
let range_check = if let std::collections::btree_map::Entry::Vacant(e) =
|
||||
self.range_checks.ranges.entry(range)
|
||||
{
|
||||
// as all tables have the same input we see if there's another table who's input we can reuse
|
||||
let range_check = RangeCheck::<F>::configure(cs, range, logrows);
|
||||
e.insert(range_check.clone());
|
||||
range_check
|
||||
@@ -883,9 +929,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
let default_x = range_check.get_first_element(col_idx);
|
||||
|
||||
let col_expr = sel.clone()
|
||||
* range_check
|
||||
* (range_check
|
||||
.selector_constructor
|
||||
.get_expr_at_idx(col_idx, synthetic_sel);
|
||||
.get_expr_at_idx(col_idx, synthetic_sel));
|
||||
|
||||
let multiplier = range_check
|
||||
.selector_constructor
|
||||
@@ -908,6 +954,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
res
|
||||
});
|
||||
}
|
||||
|
||||
// add a degree-k custom constraint of the following form to the range check and
|
||||
// static lookup configuration.
|
||||
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 − 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
|
||||
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
|
||||
cs.create_gate("range_check_on_sel", |cs| {
|
||||
let synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(1)),
|
||||
_ => match index {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[x][y], Rotation(0))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
};
|
||||
|
||||
let range_check_on_synthetic_sel = match len {
|
||||
1 => Expression::Constant(F::from(0)),
|
||||
_ => {
|
||||
let mut initial_expr = Expression::Constant(F::from(1));
|
||||
for i in 0..len {
|
||||
initial_expr = initial_expr
|
||||
* (synthetic_sel.clone()
|
||||
- Expression::Constant(F::from(i as u64)))
|
||||
}
|
||||
initial_expr
|
||||
}
|
||||
};
|
||||
|
||||
let sel = cs.query_selector(multi_col_selector);
|
||||
|
||||
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
|
||||
});
|
||||
|
||||
self.range_checks
|
||||
.selectors
|
||||
.insert((range, x, y), multi_col_selector);
|
||||
|
||||
@@ -25,7 +25,7 @@ pub enum CircuitError {
|
||||
/// This operation is unsupported
|
||||
#[error("unsupported operation in graph")]
|
||||
UnsupportedOp,
|
||||
///
|
||||
/// Invalid einsum expression
|
||||
#[error("invalid einsum expression")]
|
||||
InvalidEinsum,
|
||||
/// Flush error
|
||||
@@ -100,4 +100,10 @@ pub enum CircuitError {
|
||||
#[error("invalid input type {0}")]
|
||||
/// Invalid input type
|
||||
InvalidInputType(String),
|
||||
#[error("an element is missing from the shuffled version of the tensor")]
|
||||
/// An element is missing from the shuffled version of the tensor
|
||||
MissingShuffleElement,
|
||||
/// Visibility has not been set
|
||||
#[error("visibility has not been set")]
|
||||
UnsetVisibility,
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::*;
|
||||
use crate::{
|
||||
circuit::{layouts, utils, Tolerance},
|
||||
fieldutils::integer_rep_to_felt,
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorType, ValTensor},
|
||||
};
|
||||
@@ -250,8 +250,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(input_scale.0 as i128),
|
||||
integer_rep_to_felt(output_scale.0 as i128),
|
||||
integer_rep_to_felt(input_scale.0 as IntegerRep),
|
||||
integer_rep_to_felt(output_scale.0 as IntegerRep),
|
||||
)?,
|
||||
HybridOp::Div { denom, .. } => {
|
||||
if denom.0.fract() == 0.0 {
|
||||
@@ -259,7 +259,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(denom.0 as i128),
|
||||
integer_rep_to_felt(denom.0 as IntegerRep),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -264,7 +264,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
|
||||
}
|
||||
/// Rebase the scale of the constant
|
||||
pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), CircuitError> {
|
||||
let visibility = self.quantized_values.visibility().unwrap();
|
||||
let visibility = match self.quantized_values.visibility() {
|
||||
Some(v) => v,
|
||||
None => return Err(CircuitError::UnsetVisibility),
|
||||
};
|
||||
self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -252,6 +252,12 @@ impl<
|
||||
)?,
|
||||
PolyOp::GatherElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
if values.len() != 1 {
|
||||
return Err(TensorError::DimError(
|
||||
"GatherElements only accepts single inputs".to_string(),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
|
||||
} else {
|
||||
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
|
||||
@@ -269,6 +275,12 @@ impl<
|
||||
}
|
||||
PolyOp::ScatterElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
if values.len() != 2 {
|
||||
return Err(TensorError::DimError(
|
||||
"ScatterElements requires two inputs".to_string(),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
tensor::ops::scatter(
|
||||
values[0].get_inner_tensor()?,
|
||||
idx,
|
||||
|
||||
@@ -671,22 +671,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
}
|
||||
|
||||
/// Assign a valtensor to a vartensor with duplication
|
||||
pub fn assign_with_duplication(
|
||||
pub fn assign_with_duplication_unconstrained(
|
||||
&mut self,
|
||||
var: &VarTensor,
|
||||
values: &ValTensor<F>,
|
||||
check_mode: &crate::circuit::CheckMode,
|
||||
single_inner_col: bool,
|
||||
) -> Result<(ValTensor<F>, usize), Error> {
|
||||
if let Some(region) = &self.region {
|
||||
// duplicates every nth element to adjust for column overflow
|
||||
let (res, len) = var.assign_with_duplication(
|
||||
let (res, len) = var.assign_with_duplication_unconstrained(
|
||||
&mut region.borrow_mut(),
|
||||
self.row,
|
||||
self.linear_coord,
|
||||
values,
|
||||
check_mode,
|
||||
single_inner_col,
|
||||
&mut self.assigned_constants,
|
||||
)?;
|
||||
Ok((res, len))
|
||||
@@ -695,7 +690,37 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
self.row,
|
||||
self.linear_coord,
|
||||
values,
|
||||
single_inner_col,
|
||||
false,
|
||||
&mut self.assigned_constants,
|
||||
)?;
|
||||
Ok((values.clone(), len))
|
||||
}
|
||||
}
|
||||
|
||||
/// Assign a valtensor to a vartensor with duplication
|
||||
pub fn assign_with_duplication_constrained(
|
||||
&mut self,
|
||||
var: &VarTensor,
|
||||
values: &ValTensor<F>,
|
||||
check_mode: &crate::circuit::CheckMode,
|
||||
) -> Result<(ValTensor<F>, usize), Error> {
|
||||
if let Some(region) = &self.region {
|
||||
// duplicates every nth element to adjust for column overflow
|
||||
let (res, len) = var.assign_with_duplication_constrained(
|
||||
&mut region.borrow_mut(),
|
||||
self.row,
|
||||
self.linear_coord,
|
||||
values,
|
||||
check_mode,
|
||||
&mut self.assigned_constants,
|
||||
)?;
|
||||
Ok((res, len))
|
||||
} else {
|
||||
let (_, len) = var.dummy_assign_with_duplication(
|
||||
self.row,
|
||||
self.linear_coord,
|
||||
values,
|
||||
true,
|
||||
&mut self.assigned_constants,
|
||||
)?;
|
||||
Ok((values.clone(), len))
|
||||
|
||||
@@ -132,21 +132,16 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
(first_element, op_f.output[0])
|
||||
}
|
||||
|
||||
///
|
||||
/// calculates the column size given the number of rows and reserved blinding rows
|
||||
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(logrows as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
///
|
||||
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(bits as u32) - reserved_blinding_rows
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
|
||||
// number of cols needed to store the range
|
||||
(range_len / (col_size as IntegerRep)) as usize + 1
|
||||
(range_len / col_size as IntegerRep) as usize + 1
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
@@ -168,7 +163,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
range: Range,
|
||||
logrows: usize,
|
||||
nonlinearity: &LookupOp,
|
||||
preexisting_inputs: Option<Vec<TableColumn>>,
|
||||
preexisting_inputs: &mut Vec<TableColumn>,
|
||||
) -> Table<F> {
|
||||
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
|
||||
let col_size = Self::cal_col_size(logrows, factors);
|
||||
@@ -177,28 +172,28 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
|
||||
debug!("table range: {:?}", range);
|
||||
|
||||
let table_inputs = preexisting_inputs.unwrap_or_else(|| {
|
||||
let mut cols = vec![];
|
||||
for _ in 0..num_cols {
|
||||
cols.push(cs.lookup_table_column());
|
||||
// validate enough columns are provided to store the range
|
||||
if preexisting_inputs.len() < num_cols {
|
||||
// add columns to match the required number of columns
|
||||
let diff = num_cols - preexisting_inputs.len();
|
||||
for _ in 0..diff {
|
||||
preexisting_inputs.push(cs.lookup_table_column());
|
||||
}
|
||||
cols
|
||||
});
|
||||
|
||||
let num_cols = table_inputs.len();
|
||||
}
|
||||
|
||||
let num_cols = preexisting_inputs.len();
|
||||
if num_cols > 1 {
|
||||
warn!("Using {} columns for non-linearity table.", num_cols);
|
||||
}
|
||||
|
||||
let table_outputs = table_inputs
|
||||
let table_outputs = preexisting_inputs
|
||||
.iter()
|
||||
.map(|_| cs.lookup_table_column())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Table {
|
||||
nonlinearity: nonlinearity.clone(),
|
||||
table_inputs,
|
||||
table_inputs: preexisting_inputs.clone(),
|
||||
table_outputs,
|
||||
is_assigned: false,
|
||||
selector_constructor: SelectorConstructor::new(num_cols),
|
||||
@@ -355,16 +350,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0)
|
||||
}
|
||||
|
||||
///
|
||||
/// calculates the column size
|
||||
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(logrows as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
///
|
||||
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
|
||||
2usize.pow(bits as u32) - reserved_blinding_rows
|
||||
}
|
||||
|
||||
/// get column index given input
|
||||
pub fn get_col_index(&self, input: F) -> F {
|
||||
// range is split up into chunks of size col_size, find the chunk that input is in
|
||||
|
||||
@@ -1040,6 +1040,10 @@ mod conv {
|
||||
let a = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
|
||||
|
||||
// column for constants
|
||||
let _constant = VarTensor::constant_cols(cs, K, 8, false);
|
||||
|
||||
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
|
||||
}
|
||||
|
||||
@@ -1171,7 +1175,7 @@ mod conv_col_ultra_overflow {
|
||||
|
||||
use super::*;
|
||||
|
||||
const K: usize = 4;
|
||||
const K: usize = 6;
|
||||
const LEN: usize = 10;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -1191,9 +1195,10 @@ mod conv_col_ultra_overflow {
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN * LEN * LEN * LEN, false);
|
||||
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
|
||||
}
|
||||
|
||||
@@ -1776,13 +1781,18 @@ mod shuffle {
|
||||
|
||||
let d = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let e = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let f: VarTensor = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN * NUM_LOOP, false);
|
||||
|
||||
let mut config =
|
||||
Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
|
||||
config
|
||||
.configure_shuffles(cs, &[a.clone(), b.clone()], &[d.clone(), e.clone()])
|
||||
.configure_shuffles(
|
||||
cs,
|
||||
&[a.clone(), b.clone(), c.clone()],
|
||||
&[d.clone(), e.clone(), f.clone()],
|
||||
)
|
||||
.unwrap();
|
||||
config
|
||||
}
|
||||
|
||||
@@ -83,13 +83,15 @@ pub const DEFAULT_VK_SOL: &str = "vk.sol";
|
||||
/// Default VK abi path
|
||||
pub const DEFAULT_VK_ABI: &str = "vk.abi";
|
||||
/// Default scale rebase multipliers for calibration
|
||||
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
|
||||
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,10";
|
||||
/// Default use reduced srs for verification
|
||||
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
|
||||
/// Default only check for range check rebase
|
||||
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";
|
||||
/// Default commitment
|
||||
pub const DEFAULT_COMMITMENT: &str = "kzg";
|
||||
/// Default seed used to generate random data
|
||||
pub const DEFAULT_SEED: &str = "21242";
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
|
||||
@@ -422,7 +424,21 @@ pub enum Commands {
|
||||
#[clap(flatten)]
|
||||
args: RunArgs,
|
||||
},
|
||||
|
||||
/// Generate random data for a model
|
||||
GenRandomData {
|
||||
/// The path to the .onnx model file
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
|
||||
model: Option<PathBuf>,
|
||||
/// The path to the .json data file
|
||||
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
|
||||
data: Option<PathBuf>,
|
||||
/// Hand-written parser for graph variables, eg. batch_size=1
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'V', long, value_parser = crate::parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other))]
|
||||
variables: Vec<(String, usize)>,
|
||||
/// random seed for reproducibility (optional)
|
||||
#[arg(long, value_hint = clap::ValueHint::Other, default_value = DEFAULT_SEED)]
|
||||
seed: u64,
|
||||
},
|
||||
/// Calibrates the proving scale, lookup bits and logrows from a circuit settings file.
|
||||
CalibrateSettings {
|
||||
/// The path to the .json calibration data file.
|
||||
|
||||
@@ -488,7 +488,7 @@ pub async fn deploy_da_verifier_via_solidity(
|
||||
}
|
||||
}
|
||||
|
||||
let contract = match call_to_account {
|
||||
match call_to_account {
|
||||
Some(call) => {
|
||||
deploy_single_da_contract(
|
||||
client,
|
||||
@@ -514,10 +514,10 @@ pub async fn deploy_da_verifier_via_solidity(
|
||||
)
|
||||
.await
|
||||
}
|
||||
};
|
||||
return contract;
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn deploy_multi_da_contract(
|
||||
client: EthersClient,
|
||||
contract_instance_offset: usize,
|
||||
@@ -630,7 +630,7 @@ async fn deploy_single_da_contract(
|
||||
// bytes memory _callData,
|
||||
PackedSeqToken(call_data.as_ref()),
|
||||
// uint256 _decimals,
|
||||
WordToken(B256::from(decimals).into()),
|
||||
WordToken(B256::from(decimals)),
|
||||
// uint[] memory _scales,
|
||||
DynSeqToken(
|
||||
scales
|
||||
|
||||
@@ -65,6 +65,8 @@ use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use tabled::Tabled;
|
||||
use thiserror::Error;
|
||||
use tract_onnx::prelude::IntoTensor;
|
||||
use tract_onnx::prelude::Tensor as TractTensor;
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
@@ -116,7 +118,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
} => gen_srs_cmd(
|
||||
srs_path,
|
||||
logrows as u32,
|
||||
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
|
||||
commitment.unwrap_or_else(|| Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
|
||||
),
|
||||
Commands::GetSrs {
|
||||
srs_path,
|
||||
@@ -134,6 +136,17 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
args,
|
||||
),
|
||||
Commands::GenRandomData {
|
||||
model,
|
||||
data,
|
||||
variables,
|
||||
seed,
|
||||
} => gen_random_data(
|
||||
model.unwrap_or(DEFAULT_MODEL.into()),
|
||||
data.unwrap_or(DEFAULT_DATA.into()),
|
||||
variables,
|
||||
seed,
|
||||
),
|
||||
Commands::CalibrateSettings {
|
||||
model,
|
||||
settings_path,
|
||||
@@ -828,6 +841,71 @@ pub(crate) fn gen_circuit_settings(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
/// Generate a circuit settings file
|
||||
pub(crate) fn gen_random_data(
|
||||
model_path: PathBuf,
|
||||
data_path: PathBuf,
|
||||
variables: Vec<(String, usize)>,
|
||||
seed: u64,
|
||||
) -> Result<String, EZKLError> {
|
||||
let mut file = std::fs::File::open(&model_path).map_err(|e| {
|
||||
crate::graph::errors::GraphError::ReadWriteFileError(
|
||||
model_path.display().to_string(),
|
||||
e.to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let (tract_model, _symbol_values) = Model::load_onnx_using_tract(&mut file, &variables)?;
|
||||
|
||||
let input_facts = tract_model
|
||||
.input_outlets()
|
||||
.map_err(|e| EZKLError::from(e.to_string()))?
|
||||
.iter()
|
||||
.map(|&i| tract_model.outlet_fact(i))
|
||||
.collect::<tract_onnx::prelude::TractResult<Vec<_>>>()
|
||||
.map_err(|e| EZKLError::from(e.to_string()))?;
|
||||
|
||||
/// Generates a random tensor of a given size and type.
|
||||
fn random(
|
||||
sizes: &[usize],
|
||||
datum_type: tract_onnx::prelude::DatumType,
|
||||
seed: u64,
|
||||
) -> TractTensor {
|
||||
use rand::{Rng, SeedableRng};
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
|
||||
|
||||
let mut tensor = TractTensor::zero::<f32>(sizes).unwrap();
|
||||
let slice = tensor.as_slice_mut::<f32>().unwrap();
|
||||
slice.iter_mut().for_each(|x| *x = rng.gen());
|
||||
tensor.cast_to_dt(datum_type).unwrap().into_owned()
|
||||
}
|
||||
|
||||
fn tensor_for_fact(fact: &tract_onnx::prelude::TypedFact, seed: u64) -> TractTensor {
|
||||
if let Some(value) = &fact.konst {
|
||||
return value.clone().into_tensor();
|
||||
}
|
||||
|
||||
random(
|
||||
fact.shape
|
||||
.as_concrete()
|
||||
.expect("Expected concrete shape, found: {fact:?}"),
|
||||
fact.datum_type,
|
||||
seed,
|
||||
)
|
||||
}
|
||||
|
||||
let generated = input_facts
|
||||
.iter()
|
||||
.map(|v| tensor_for_fact(v, seed))
|
||||
.collect_vec();
|
||||
|
||||
let data = GraphData::from_tract_data(&generated)?;
|
||||
|
||||
data.save(data_path)?;
|
||||
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
// not for wasm targets
|
||||
pub(crate) fn init_spinner() -> ProgressBar {
|
||||
let pb = indicatif::ProgressBar::new_spinner();
|
||||
@@ -1457,7 +1535,8 @@ pub(crate) async fn create_evm_data_attestation(
|
||||
trace!("params computed");
|
||||
|
||||
// if input is not provided, we just instantiate dummy input data
|
||||
let data = GraphData::from_path(input).unwrap_or(GraphData::new(DataSource::File(vec![])));
|
||||
let data =
|
||||
GraphData::from_path(input).unwrap_or_else(|_| GraphData::new(DataSource::File(vec![])));
|
||||
|
||||
// The number of input and output instances we attest to for the single call data attestation
|
||||
let mut input_len = None;
|
||||
@@ -2048,6 +2127,7 @@ pub(crate) fn mock_aggregate(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn setup_aggregate(
|
||||
sample_snarks: Vec<PathBuf>,
|
||||
vk_path: PathBuf,
|
||||
|
||||
@@ -5,10 +5,12 @@ use halo2curves::ff::PrimeField;
|
||||
/// Integer representation of a PrimeField element.
|
||||
pub type IntegerRep = i128;
|
||||
|
||||
/// Converts an i64 to a PrimeField element.
|
||||
/// Converts an integer rep to a PrimeField element.
|
||||
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
|
||||
if x >= 0 {
|
||||
F::from_u128(x as u128)
|
||||
} else if x == IntegerRep::MIN {
|
||||
-F::from_u128(x.saturating_neg() as u128) - F::ONE
|
||||
} else {
|
||||
-F::from_u128(x.saturating_neg() as u128)
|
||||
}
|
||||
@@ -32,6 +34,9 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
|
||||
/// Converts a PrimeField element to an i64.
|
||||
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
|
||||
if x > F::from_u128(IntegerRep::MAX as u128) {
|
||||
if x == -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
|
||||
return IntegerRep::MIN;
|
||||
}
|
||||
let rep = (-x).to_repr();
|
||||
let negtmp: &[u8] = rep.as_ref();
|
||||
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
|
||||
@@ -51,7 +56,7 @@ mod test {
|
||||
use halo2curves::pasta::Fp as F;
|
||||
|
||||
#[test]
|
||||
fn test_conv() {
|
||||
fn integerreptofelt() {
|
||||
let res: F = integer_rep_to_felt(-15);
|
||||
assert_eq!(res, -F::from(15));
|
||||
|
||||
@@ -69,8 +74,24 @@ mod test {
|
||||
fn felttointegerrep() {
|
||||
for x in -(2_i128.pow(16))..(2_i128.pow(16)) {
|
||||
let fieldx: F = integer_rep_to_felt::<F>(x);
|
||||
let xf: i128 = felt_to_integer_rep::<F>(fieldx);
|
||||
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn felttointegerrepmin() {
|
||||
let x = IntegerRep::MIN;
|
||||
let fieldx: F = integer_rep_to_felt::<F>(x);
|
||||
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn felttointegerrepmax() {
|
||||
let x = IntegerRep::MAX;
|
||||
let fieldx: F = integer_rep_to_felt::<F>(x);
|
||||
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,12 @@ pub enum GraphError {
|
||||
/// Shape mismatch in circuit construction
|
||||
#[error("invalid dimensions used for node {0} ({1})")]
|
||||
InvalidDims(usize, String),
|
||||
/// Non scalar power
|
||||
#[error("we only support scalar powers")]
|
||||
NonScalarPower,
|
||||
/// Non scalar base for exponentiation
|
||||
#[error("we only support scalar bases for exponentiation")]
|
||||
NonScalarBase,
|
||||
/// Wrong method was called to configure an op
|
||||
#[error("wrong method was called to configure node {0} ({1})")]
|
||||
WrongMethod(usize, String),
|
||||
@@ -113,13 +119,13 @@ pub enum GraphError {
|
||||
/// Missing input for a node
|
||||
#[error("missing input for node {0}")]
|
||||
MissingInput(usize),
|
||||
///
|
||||
/// Ranges can only be constant
|
||||
#[error("range only supports constant inputs in a zk circuit")]
|
||||
NonConstantRange,
|
||||
///
|
||||
/// Trilu diagonal must be constant
|
||||
#[error("trilu only supports constant diagonals in a zk circuit")]
|
||||
NonConstantTrilu,
|
||||
///
|
||||
/// The witness was too short
|
||||
#[error("insufficient witness values to generate a fixed output")]
|
||||
InsufficientWitnessValues,
|
||||
/// Missing scale
|
||||
@@ -143,4 +149,13 @@ pub enum GraphError {
|
||||
/// Invalid RunArg
|
||||
#[error("invalid RunArgs: {0}")]
|
||||
InvalidRunArgs(String),
|
||||
/// Only nearest neighbor interpolation is supported
|
||||
#[error("only nearest neighbor interpolation is supported")]
|
||||
InvalidInterpolation,
|
||||
/// Node has a missing output
|
||||
#[error("node {0} has a missing output")]
|
||||
MissingOutput(usize),
|
||||
/// Inssuficient advice columns
|
||||
#[error("insuficcient advice columns (need {0} at least)")]
|
||||
InsufficientAdviceColumns(usize),
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ use pyo3::prelude::*;
|
||||
use pyo3::types::PyDict;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::io::BufReader;
|
||||
use std::io::BufWriter;
|
||||
@@ -25,6 +24,7 @@ use tract_onnx::tract_core::{
|
||||
tract_data::{prelude::Tensor as TractTensor, TVec},
|
||||
value::TValue,
|
||||
};
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_hir::tract_num_traits::ToPrimitive;
|
||||
|
||||
@@ -32,30 +32,95 @@ type Decimals = u8;
|
||||
type Call = String;
|
||||
type RPCUrl = String;
|
||||
|
||||
///
|
||||
/// Represents different types of values that can be stored in a file source
|
||||
/// Used for handling various input types in zero-knowledge proofs
|
||||
#[derive(Clone, Debug, PartialOrd, PartialEq)]
|
||||
pub enum FileSourceInner {
|
||||
/// Inner elements of float inputs coming from a file
|
||||
/// Floating point value (64-bit)
|
||||
Float(f64),
|
||||
/// Inner elements of bool inputs coming from a file
|
||||
/// Boolean value
|
||||
Bool(bool),
|
||||
/// Inner elements of inputs coming from a witness
|
||||
/// Field element value for direct use in circuits
|
||||
Field(Fp),
|
||||
}
|
||||
|
||||
impl FileSourceInner {
|
||||
///
|
||||
/// Returns true if the value is a floating point number
|
||||
pub fn is_float(&self) -> bool {
|
||||
matches!(self, FileSourceInner::Float(_))
|
||||
}
|
||||
///
|
||||
|
||||
/// Returns true if the value is a boolean
|
||||
pub fn is_bool(&self) -> bool {
|
||||
matches!(self, FileSourceInner::Bool(_))
|
||||
}
|
||||
///
|
||||
|
||||
/// Returns true if the value is a field element
|
||||
pub fn is_field(&self) -> bool {
|
||||
matches!(self, FileSourceInner::Field(_))
|
||||
}
|
||||
|
||||
/// Creates a new floating point value
|
||||
pub fn new_float(f: f64) -> Self {
|
||||
FileSourceInner::Float(f)
|
||||
}
|
||||
|
||||
/// Creates a new field element value
|
||||
pub fn new_field(f: Fp) -> Self {
|
||||
FileSourceInner::Field(f)
|
||||
}
|
||||
|
||||
/// Creates a new boolean value
|
||||
pub fn new_bool(f: bool) -> Self {
|
||||
FileSourceInner::Bool(f)
|
||||
}
|
||||
|
||||
/// Adjusts the value according to the specified input type
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input_type` - Type specification to convert the value to
|
||||
pub fn as_type(&mut self, input_type: &InputType) {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => input_type.roundtrip(f),
|
||||
FileSourceInner::Bool(_) => assert!(matches!(input_type, InputType::Bool)),
|
||||
FileSourceInner::Field(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts the value to a field element using appropriate scaling
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `scale` - Scaling factor for floating point conversion
|
||||
pub fn to_field(&self, scale: crate::Scale) -> Fp {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => {
|
||||
integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap())
|
||||
}
|
||||
FileSourceInner::Bool(f) => {
|
||||
if *f {
|
||||
Fp::one()
|
||||
} else {
|
||||
Fp::zero()
|
||||
}
|
||||
}
|
||||
FileSourceInner::Field(f) => *f,
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts the value to a floating point number
|
||||
pub fn to_float(&self) -> f64 {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => *f,
|
||||
FileSourceInner::Bool(f) => {
|
||||
if *f {
|
||||
1.0
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for FileSourceInner {
|
||||
@@ -71,8 +136,8 @@ impl Serialize for FileSourceInner {
|
||||
}
|
||||
}
|
||||
|
||||
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
|
||||
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
|
||||
// Deserialization implementation for FileSourceInner
|
||||
// Uses JSON deserialization to handle the different variants
|
||||
impl<'de> Deserialize<'de> for FileSourceInner {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
@@ -99,70 +164,16 @@ impl<'de> Deserialize<'de> for FileSourceInner {
|
||||
}
|
||||
}
|
||||
|
||||
/// Elements of inputs coming from a file
|
||||
/// A collection of input values from a file source
|
||||
/// Organized as a vector of vectors where each inner vector represents a row/entry
|
||||
pub type FileSource = Vec<Vec<FileSourceInner>>;
|
||||
|
||||
impl FileSourceInner {
|
||||
/// Create a new FileSourceInner
|
||||
pub fn new_float(f: f64) -> Self {
|
||||
FileSourceInner::Float(f)
|
||||
}
|
||||
/// Create a new FileSourceInner
|
||||
pub fn new_field(f: Fp) -> Self {
|
||||
FileSourceInner::Field(f)
|
||||
}
|
||||
/// Create a new FileSourceInner
|
||||
pub fn new_bool(f: bool) -> Self {
|
||||
FileSourceInner::Bool(f)
|
||||
}
|
||||
|
||||
///
|
||||
pub fn as_type(&mut self, input_type: &InputType) {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => input_type.roundtrip(f),
|
||||
FileSourceInner::Bool(_) => assert!(matches!(input_type, InputType::Bool)),
|
||||
FileSourceInner::Field(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to a field element
|
||||
pub fn to_field(&self, scale: crate::Scale) -> Fp {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => {
|
||||
integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap())
|
||||
}
|
||||
FileSourceInner::Bool(f) => {
|
||||
if *f {
|
||||
Fp::one()
|
||||
} else {
|
||||
Fp::zero()
|
||||
}
|
||||
}
|
||||
FileSourceInner::Field(f) => *f,
|
||||
}
|
||||
}
|
||||
/// Convert to a float
|
||||
pub fn to_float(&self) -> f64 {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => *f,
|
||||
FileSourceInner::Bool(f) => {
|
||||
if *f {
|
||||
1.0
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Call type for attested inputs on-chain
|
||||
/// Represents different types of calls for fetching on-chain data
|
||||
#[derive(Clone, Debug, PartialOrd, PartialEq)]
|
||||
pub enum Calls {
|
||||
/// Vector of calls to accounts, each returning an attested data point
|
||||
/// Multiple calls to different accounts, each returning individual values
|
||||
Multiple(Vec<CallsToAccount>),
|
||||
/// Single call to account, returning an array of attested data points
|
||||
/// Single call returning an array of values
|
||||
Single(CallToAccount),
|
||||
}
|
||||
|
||||
@@ -171,32 +182,6 @@ impl Default for Calls {
|
||||
Calls::Multiple(Vec::new())
|
||||
}
|
||||
}
|
||||
/// Inner elements of inputs/outputs coming from on-chain
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct OnChainSource {
|
||||
/// Calls to accounts
|
||||
pub calls: Calls,
|
||||
/// RPC url
|
||||
pub rpc: RPCUrl,
|
||||
}
|
||||
|
||||
impl OnChainSource {
|
||||
/// Create a new OnChainSource with multiple calls
|
||||
pub fn new_multiple(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
|
||||
OnChainSource {
|
||||
calls: Calls::Multiple(calls),
|
||||
rpc,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new OnChainSource with a single call
|
||||
pub fn new_single(call: CallToAccount, rpc: RPCUrl) -> Self {
|
||||
OnChainSource {
|
||||
calls: Calls::Single(call),
|
||||
rpc,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Calls {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
@@ -218,7 +203,6 @@ impl<'de> Deserialize<'de> for Calls {
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
|
||||
|
||||
let multiple_try: Result<Vec<CallsToAccount>, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = multiple_try {
|
||||
return Ok(Calls::Multiple(t));
|
||||
@@ -228,111 +212,52 @@ impl<'de> Deserialize<'de> for Calls {
|
||||
return Ok(Calls::Single(t));
|
||||
}
|
||||
|
||||
Err(serde::de::Error::custom(
|
||||
"failed to deserialize FileSourceInner",
|
||||
))
|
||||
Err(serde::de::Error::custom("failed to deserialize Calls"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Inner elements of inputs/outputs coming from postgres DB
|
||||
/// Configuration for accessing on-chain data sources
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct PostgresSource {
|
||||
/// postgres host
|
||||
pub host: RPCUrl,
|
||||
/// user to connect to postgres
|
||||
pub user: String,
|
||||
/// password to connect to postgres
|
||||
pub password: String,
|
||||
/// query to execute
|
||||
pub query: String,
|
||||
/// dbname
|
||||
pub dbname: String,
|
||||
/// port
|
||||
pub port: String,
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl PostgresSource {
|
||||
/// Create a new PostgresSource
|
||||
pub fn new(
|
||||
host: RPCUrl,
|
||||
port: String,
|
||||
user: String,
|
||||
query: String,
|
||||
dbname: String,
|
||||
password: String,
|
||||
) -> Self {
|
||||
PostgresSource {
|
||||
host,
|
||||
user,
|
||||
password,
|
||||
query,
|
||||
dbname,
|
||||
port,
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch data from postgres
|
||||
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
|
||||
// clone to move into thread
|
||||
let user = self.user.clone();
|
||||
let host = self.host.clone();
|
||||
let query = self.query.clone();
|
||||
let dbname = self.dbname.clone();
|
||||
let port = self.port.clone();
|
||||
let password = self.password.clone();
|
||||
|
||||
let config = if password.is_empty() {
|
||||
format!(
|
||||
"host={} user={} dbname={} port={}",
|
||||
host, user, dbname, port
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"host={} user={} dbname={} port={} password={}",
|
||||
host, user, dbname, port, password
|
||||
)
|
||||
};
|
||||
|
||||
let mut client = Client::connect(&config).await?;
|
||||
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
|
||||
// extract rows from query
|
||||
for row in client.query(&query, &[]).await? {
|
||||
// extract features from row
|
||||
for i in 0..row.len() {
|
||||
res.push(row.get(i));
|
||||
}
|
||||
}
|
||||
Ok(vec![res])
|
||||
}
|
||||
|
||||
/// Fetch data from postgres and format it as a FileSource
|
||||
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
|
||||
Ok(self
|
||||
.fetch()
|
||||
.await?
|
||||
.iter()
|
||||
.map(|d| {
|
||||
d.iter()
|
||||
.map(|d| {
|
||||
FileSourceInner::Float(
|
||||
d.n.as_ref()
|
||||
.unwrap()
|
||||
.to_f64()
|
||||
.ok_or("could not convert decimal to f64")
|
||||
.unwrap(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
pub struct OnChainSource {
|
||||
/// Call specifications for fetching data
|
||||
pub calls: Calls,
|
||||
/// RPC endpoint URL for accessing the chain
|
||||
pub rpc: RPCUrl,
|
||||
}
|
||||
|
||||
impl OnChainSource {
|
||||
/// Creates a new OnChainSource with multiple calls
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `calls` - Vector of call specifications
|
||||
/// * `rpc` - RPC endpoint URL
|
||||
pub fn new_multiple(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
|
||||
OnChainSource {
|
||||
calls: Calls::Multiple(calls),
|
||||
rpc,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new OnChainSource with a single call
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `call` - Call specification
|
||||
/// * `rpc` - RPC endpoint URL
|
||||
pub fn new_single(call: CallToAccount, rpc: RPCUrl) -> Self {
|
||||
OnChainSource {
|
||||
calls: Calls::Single(call),
|
||||
rpc,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Create dummy local on-chain data to test the OnChain data source
|
||||
/// Creates test data for the OnChain data source
|
||||
/// Used for testing and development purposes
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `data` - Sample file data to use
|
||||
/// * `scales` - Scaling factors for each input
|
||||
/// * `shapes` - Shapes of the input tensors
|
||||
/// * `rpc` - Optional RPC endpoint override
|
||||
pub async fn test_from_file_data(
|
||||
data: &FileSource,
|
||||
scales: Vec<crate::Scale>,
|
||||
@@ -399,48 +324,40 @@ impl OnChainSource {
|
||||
}
|
||||
}
|
||||
|
||||
/// Defines the view only calls to accounts to fetch the on-chain input data.
|
||||
/// This data will be included as part of the first elements in the publicInputs
|
||||
/// for the sol evm verifier and will be verifyWithDataAttestation.sol
|
||||
/// Specification for view-only calls to fetch on-chain data
|
||||
/// Used for data attestation in smart contract verification
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct CallsToAccount {
|
||||
/// A vector of tuples, where index 0 of tuples
|
||||
/// are the byte strings representing the ABI encoded function calls to
|
||||
/// read the data from the address. This call must return a single
|
||||
/// elementary type (<https://docs.soliditylang.org/en/v0.8.20/abi-spec.html#types>).
|
||||
/// The second index of the tuple is the number of decimals for f32 conversion.
|
||||
/// We don't support dynamic types currently.
|
||||
/// Vector of (call data, decimals) pairs
|
||||
/// call_data: ABI-encoded function call
|
||||
/// decimals: Number of decimal places for float conversion
|
||||
pub call_data: Vec<(Call, Decimals)>,
|
||||
/// Address of the contract to read the data from.
|
||||
/// Contract address to call
|
||||
pub address: String,
|
||||
}
|
||||
|
||||
/// Defines a view only call to accounts to fetch the on-chain input data.
|
||||
/// This data will be included as part of the first elements in the publicInputs
|
||||
/// for the sol evm verifier and will be verifyWithDataAttestation.sol
|
||||
/// Specification for a single view-only call returning an array
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct CallToAccount {
|
||||
/// The call_data is a byte strings representing the ABI encoded function call to
|
||||
/// read the data from the address. This call must return a single array of integers that can be
|
||||
/// be safely cast to the int128 type in solidity.
|
||||
/// ABI-encoded function call data
|
||||
pub call_data: Call,
|
||||
/// The number of decimals for f32 conversion of all of the elements returned from the
|
||||
/// call.
|
||||
/// Number of decimal places for float conversion
|
||||
pub decimals: Decimals,
|
||||
/// Address of the contract to read the data from.
|
||||
/// Contract address to call
|
||||
pub address: String,
|
||||
/// The number of elements returned from the call.
|
||||
/// Expected length of returned array
|
||||
pub len: usize,
|
||||
}
|
||||
/// Enum that defines source of the inputs/outputs to the EZKL model
|
||||
|
||||
/// Represents different sources of input/output data for the EZKL model
|
||||
#[derive(Clone, Debug, Serialize, PartialOrd, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum DataSource {
|
||||
/// .json File data source.
|
||||
/// Data from a JSON file containing arrays of values
|
||||
File(FileSource),
|
||||
/// On-chain data source. The first element is the calls to the account, and the second is the RPC url.
|
||||
/// Data fetched from blockchain contracts
|
||||
OnChain(OnChainSource),
|
||||
/// Postgres DB
|
||||
/// Data from a PostgreSQL database
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
DB(PostgresSource),
|
||||
}
|
||||
@@ -483,8 +400,7 @@ impl From<OnChainSource> for DataSource {
|
||||
}
|
||||
}
|
||||
|
||||
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
|
||||
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
|
||||
// Note: Always use JSON serialization for untagged enums
|
||||
impl<'de> Deserialize<'de> for DataSource {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
@@ -492,15 +408,19 @@ impl<'de> Deserialize<'de> for DataSource {
|
||||
{
|
||||
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
|
||||
|
||||
// Try deserializing as FileSource first
|
||||
let first_try: Result<FileSource, _> = serde_json::from_str(this_json.get());
|
||||
|
||||
if let Ok(t) = first_try {
|
||||
return Ok(DataSource::File(t));
|
||||
}
|
||||
|
||||
// Try deserializing as OnChainSource
|
||||
let second_try: Result<OnChainSource, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = second_try {
|
||||
return Ok(DataSource::OnChain(t));
|
||||
}
|
||||
|
||||
// Try deserializing as PostgresSource if feature enabled
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
{
|
||||
let third_try: Result<PostgresSource, _> = serde_json::from_str(this_json.get());
|
||||
@@ -513,22 +433,29 @@ impl<'de> Deserialize<'de> for DataSource {
|
||||
}
|
||||
}
|
||||
|
||||
/// Input to graph as a datasource
|
||||
/// Always use JSON serialization for GraphData. Seriously.
|
||||
#[derive(Clone, Debug, Deserialize, Default, PartialEq)]
|
||||
/// Container for input and output data for graph computations
|
||||
///
|
||||
/// Important: Always use JSON serialization for GraphData to handle enum variants correctly
|
||||
#[derive(Clone, Debug, Deserialize, Default, PartialEq, Serialize)]
|
||||
pub struct GraphData {
|
||||
/// Inputs to the model / computational graph (can be empty vectors if inputs are coming from on-chain).
|
||||
/// Input data for the model/graph
|
||||
/// Can be empty if inputs come from on-chain sources
|
||||
pub input_data: DataSource,
|
||||
/// Outputs of the model / computational graph (can be empty vectors if outputs are coming from on-chain).
|
||||
|
||||
/// Optional output data for the model/graph
|
||||
/// Can be empty if outputs come from on-chain sources
|
||||
pub output_data: Option<DataSource>,
|
||||
}
|
||||
|
||||
impl UnwindSafe for GraphData {}
|
||||
|
||||
impl GraphData {
|
||||
// not wasm
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Convert the input data to tract data
|
||||
/// Converts the input data to tract's tensor format
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `shapes` - Expected shapes for each input tensor
|
||||
/// * `datum_types` - Expected data types for each input
|
||||
pub fn to_tract_data(
|
||||
&self,
|
||||
shapes: &[Vec<usize>],
|
||||
@@ -557,7 +484,43 @@ impl GraphData {
|
||||
Ok(inputs)
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Converts tract tensor data into GraphData format
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tensors` - Array of tract tensors to convert
|
||||
///
|
||||
/// # Returns
|
||||
/// A new GraphData instance containing the converted tensor data
|
||||
pub fn from_tract_data(tensors: &[TractTensor]) -> Result<Self, GraphError> {
|
||||
use tract_onnx::prelude::DatumType;
|
||||
|
||||
let mut input_data = vec![];
|
||||
for tensor in tensors {
|
||||
match tensor.datum_type() {
|
||||
tract_onnx::prelude::DatumType::Bool => {
|
||||
let tensor = tensor.to_array_view::<bool>()?;
|
||||
let tensor = tensor.iter().map(|e| FileSourceInner::Bool(*e)).collect();
|
||||
input_data.push(tensor);
|
||||
}
|
||||
_ => {
|
||||
let cast_tensor = tensor.cast_to_dt(DatumType::F64)?;
|
||||
let tensor = cast_tensor.to_array_view::<f64>()?;
|
||||
let tensor = tensor.iter().map(|e| FileSourceInner::Float(*e)).collect();
|
||||
input_data.push(tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(GraphData {
|
||||
input_data: DataSource::File(input_data),
|
||||
output_data: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates a new GraphData instance with given input data
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input_data` - The input data source
|
||||
pub fn new(input_data: DataSource) -> Self {
|
||||
GraphData {
|
||||
input_data,
|
||||
@@ -565,7 +528,13 @@ impl GraphData {
|
||||
}
|
||||
}
|
||||
|
||||
/// Load the model input from a file
|
||||
/// Loads graph input data from a file
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - Path to the input file
|
||||
///
|
||||
/// # Returns
|
||||
/// A new GraphData instance containing the loaded data
|
||||
pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
|
||||
let reader = std::fs::File::open(&path).map_err(|e| {
|
||||
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
|
||||
@@ -579,23 +548,35 @@ impl GraphData {
|
||||
Ok(graph_input)
|
||||
}
|
||||
|
||||
/// Save the model input to a file
|
||||
/// Saves the graph data to a file
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `path` - Path where to save the data
|
||||
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
|
||||
let file = std::fs::File::create(path.clone()).map_err(|e| {
|
||||
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
|
||||
})?;
|
||||
// buf writer
|
||||
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
serde_json::to_writer(writer, self)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Splits the input data into multiple batches based on input shapes
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `input_shapes` - Vector of shapes for each input tensor
|
||||
///
|
||||
/// # Returns
|
||||
/// Vector of GraphData instances, one for each batch
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns error if:
|
||||
/// - Data is from on-chain source
|
||||
/// - Input size is not evenly divisible by batch size
|
||||
pub async fn split_into_batches(
|
||||
&self,
|
||||
input_shapes: Vec<Vec<usize>>,
|
||||
) -> Result<Vec<Self>, GraphError> {
|
||||
// split input data into batches
|
||||
let mut batched_inputs = vec![];
|
||||
|
||||
let iterable = match self {
|
||||
@@ -619,10 +600,12 @@ impl GraphData {
|
||||
} => data.fetch_and_format_as_file().await?,
|
||||
};
|
||||
|
||||
// Process each input tensor according to its shape
|
||||
for (i, shape) in input_shapes.iter().enumerate() {
|
||||
// ensure the input is evenly divisible by batch_size
|
||||
let input_size = shape.clone().iter().product::<usize>();
|
||||
let input = &iterable[i];
|
||||
|
||||
// Validate input size is divisible by batch size
|
||||
if input.len() % input_size != 0 {
|
||||
return Err(GraphError::InvalidDims(
|
||||
0,
|
||||
@@ -630,6 +613,8 @@ impl GraphData {
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Split input into batches
|
||||
let mut batches = vec![];
|
||||
for batch in input.chunks(input_size) {
|
||||
batches.push(batch.to_vec());
|
||||
@@ -637,18 +622,18 @@ impl GraphData {
|
||||
batched_inputs.push(batches);
|
||||
}
|
||||
|
||||
// now merge all the batches for each input into a vector of batches
|
||||
// first assert each input has the same number of batches
|
||||
// Merge batches across inputs
|
||||
let num_batches = if batched_inputs.is_empty() {
|
||||
0
|
||||
} else {
|
||||
let num_batches = batched_inputs[0].len();
|
||||
// Verify all inputs have same number of batches
|
||||
for input in batched_inputs.iter() {
|
||||
assert_eq!(input.len(), num_batches);
|
||||
}
|
||||
num_batches
|
||||
};
|
||||
// now merge the batches
|
||||
|
||||
let mut input_batches = vec![];
|
||||
for i in 0..num_batches {
|
||||
let mut batch = vec![];
|
||||
@@ -658,11 +643,12 @@ impl GraphData {
|
||||
input_batches.push(DataSource::File(batch));
|
||||
}
|
||||
|
||||
// Ensure at least one batch exists
|
||||
if input_batches.is_empty() {
|
||||
input_batches.push(DataSource::File(vec![vec![]]));
|
||||
}
|
||||
|
||||
// create a new GraphWitness for each batch
|
||||
// Create GraphData instance for each batch
|
||||
let batches = input_batches
|
||||
.into_iter()
|
||||
.map(GraphData::new)
|
||||
@@ -674,6 +660,7 @@ impl GraphData {
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for CallsToAccount {
|
||||
/// Converts CallsToAccount to Python object
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("account", &self.address).unwrap();
|
||||
@@ -682,6 +669,165 @@ impl ToPyObject for CallsToAccount {
|
||||
}
|
||||
}
|
||||
|
||||
// Additional Python bindings for various types...
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_postgres_source_new() {
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
{
|
||||
let source = PostgresSource::new(
|
||||
"localhost".to_string(),
|
||||
"5432".to_string(),
|
||||
"user".to_string(),
|
||||
"SELECT * FROM table".to_string(),
|
||||
"database".to_string(),
|
||||
"password".to_string(),
|
||||
);
|
||||
|
||||
assert_eq!(source.host, "localhost");
|
||||
assert_eq!(source.port, "5432");
|
||||
assert_eq!(source.user, "user");
|
||||
assert_eq!(source.query, "SELECT * FROM table");
|
||||
assert_eq!(source.dbname, "database");
|
||||
assert_eq!(source.password, "password");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_data_source_serialization_round_trip() {
|
||||
// Test backwards compatibility with old format
|
||||
let source = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]);
|
||||
let serialized = serde_json::to_string(&source).unwrap();
|
||||
const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#;
|
||||
assert_eq!(serialized, JSON);
|
||||
|
||||
let expect = serde_json::from_str::<DataSource>(JSON)
|
||||
.map_err(|e| e.to_string())
|
||||
.unwrap();
|
||||
assert_eq!(expect, source);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_graph_input_serialization_round_trip() {
|
||||
// Test serialization/deserialization of graph input
|
||||
let file = GraphData::new(DataSource::from(vec![vec![
|
||||
0.05326242372393608,
|
||||
0.07497056573629379,
|
||||
0.05235547572374344,
|
||||
]]));
|
||||
|
||||
let serialized = serde_json::to_string(&file).unwrap();
|
||||
const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#;
|
||||
assert_eq!(serialized, JSON);
|
||||
|
||||
let graph_input3 = serde_json::from_str::<GraphData>(JSON)
|
||||
.map_err(|e| e.to_string())
|
||||
.unwrap();
|
||||
assert_eq!(graph_input3, file);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_python_compat() {
|
||||
// Test compatibility with mclbn256 library serialization
|
||||
let source = Fp::from_raw([18445520602771460712, 838677322461845011, 3079992810, 0]);
|
||||
let original_addr = "0x000000000000000000000000b794f5ea0ba39494ce839613fffba74279579268";
|
||||
assert_eq!(format!("{:?}", source), original_addr);
|
||||
}
|
||||
}
|
||||
|
||||
/// Source data from a PostgreSQL database
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct PostgresSource {
|
||||
/// Database host address
|
||||
pub host: RPCUrl,
|
||||
/// Database user name
|
||||
pub user: String,
|
||||
/// Database password
|
||||
pub password: String,
|
||||
/// SQL query to execute
|
||||
pub query: String,
|
||||
/// Database name
|
||||
pub dbname: String,
|
||||
/// Database port
|
||||
pub port: String,
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl PostgresSource {
|
||||
/// Creates a new PostgreSQL data source
|
||||
pub fn new(
|
||||
host: RPCUrl,
|
||||
port: String,
|
||||
user: String,
|
||||
query: String,
|
||||
dbname: String,
|
||||
password: String,
|
||||
) -> Self {
|
||||
PostgresSource {
|
||||
host,
|
||||
user,
|
||||
password,
|
||||
query,
|
||||
dbname,
|
||||
port,
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetches data from the PostgreSQL database
|
||||
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
|
||||
// Configuration string
|
||||
let config = if self.password.is_empty() {
|
||||
format!(
|
||||
"host={} user={} dbname={} port={}",
|
||||
self.host, self.user, self.dbname, self.port
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"host={} user={} dbname={} port={} password={}",
|
||||
self.host, self.user, self.dbname, self.port, self.password
|
||||
)
|
||||
};
|
||||
|
||||
let mut client = Client::connect(&config).await?;
|
||||
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
|
||||
|
||||
// Extract rows from query
|
||||
for row in client.query(&self.query, &[]).await? {
|
||||
for i in 0..row.len() {
|
||||
res.push(row.get(i));
|
||||
}
|
||||
}
|
||||
Ok(vec![res])
|
||||
}
|
||||
|
||||
/// Fetches and formats data as FileSource
|
||||
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
|
||||
Ok(self
|
||||
.fetch()
|
||||
.await?
|
||||
.iter()
|
||||
.map(|d| {
|
||||
d.iter()
|
||||
.map(|d| {
|
||||
FileSourceInner::Float(
|
||||
d.n.as_ref()
|
||||
.unwrap()
|
||||
.to_f64()
|
||||
.ok_or("could not convert decimal to f64")
|
||||
.unwrap(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for CallToAccount {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
@@ -716,6 +862,7 @@ impl ToPyObject for DataSource {
|
||||
.unwrap();
|
||||
dict.to_object(py)
|
||||
}
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
DataSource::DB(source) => {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("host", &source.host).unwrap();
|
||||
@@ -740,69 +887,3 @@ impl ToPyObject for FileSourceInner {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for GraphData {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut state = serializer.serialize_struct("GraphData", 4)?;
|
||||
state.serialize_field("input_data", &self.input_data)?;
|
||||
state.serialize_field("output_data", &self.output_data)?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
// this is for backwards compatibility with the old format
|
||||
fn test_data_source_serialization_round_trip() {
|
||||
let source = DataSource::from(vec![vec![0.053_262_424, 0.074_970_566, 0.052_355_476]]);
|
||||
|
||||
let serialized = serde_json::to_string(&source).unwrap();
|
||||
|
||||
const JSON: &str = r#"[[0.053262424,0.074970566,0.052355476]]"#;
|
||||
|
||||
assert_eq!(serialized, JSON);
|
||||
|
||||
let expect = serde_json::from_str::<DataSource>(JSON)
|
||||
.map_err(|e| e.to_string())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(expect, source);
|
||||
}
|
||||
|
||||
#[test]
|
||||
// this is for backwards compatibility with the old format
|
||||
fn test_graph_input_serialization_round_trip() {
|
||||
let file = GraphData::new(DataSource::from(vec![vec![
|
||||
0.05326242372393608,
|
||||
0.07497056573629379,
|
||||
0.05235547572374344,
|
||||
]]));
|
||||
|
||||
let serialized = serde_json::to_string(&file).unwrap();
|
||||
|
||||
const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#;
|
||||
|
||||
assert_eq!(serialized, JSON);
|
||||
|
||||
let graph_input3 = serde_json::from_str::<GraphData>(JSON)
|
||||
.map_err(|e| e.to_string())
|
||||
.unwrap();
|
||||
assert_eq!(graph_input3, file);
|
||||
}
|
||||
|
||||
// test for the compatibility with the serialized elements from the mclbn256 library
|
||||
#[test]
|
||||
fn test_python_compat() {
|
||||
let source = Fp::from_raw([18445520602771460712, 838677322461845011, 3079992810, 0]);
|
||||
|
||||
let original_addr = "0x000000000000000000000000b794f5ea0ba39494ce839613fffba74279579268";
|
||||
|
||||
assert_eq!(format!("{:?}", source), original_addr);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -280,7 +280,13 @@ impl GraphWitness {
|
||||
})?;
|
||||
|
||||
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
serde_json::from_reader(reader).map_err(|e| e.into())
|
||||
let witness: GraphWitness =
|
||||
serde_json::from_reader(reader).map_err(Into::<GraphError>::into)?;
|
||||
|
||||
// check versions match
|
||||
crate::check_version_string_matches(witness.version.as_deref().unwrap_or(""));
|
||||
|
||||
Ok(witness)
|
||||
}
|
||||
|
||||
/// Save the model input to a file
|
||||
@@ -572,10 +578,14 @@ impl GraphSettings {
|
||||
// buf reader
|
||||
let reader =
|
||||
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
|
||||
serde_json::from_reader(reader).map_err(|e| {
|
||||
let settings: GraphSettings = serde_json::from_reader(reader).map_err(|e| {
|
||||
error!("failed to load settings file at {}", e);
|
||||
std::io::Error::new(std::io::ErrorKind::Other, e)
|
||||
})
|
||||
})?;
|
||||
|
||||
crate::check_version_string_matches(&settings.version);
|
||||
|
||||
Ok(settings)
|
||||
}
|
||||
|
||||
/// Export the ezkl configuration as json
|
||||
@@ -609,11 +619,6 @@ impl GraphSettings {
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
pub fn uses_modules(&self) -> bool {
|
||||
!self.module_sizes.max_constraints() > 0
|
||||
}
|
||||
|
||||
/// if any visibility is encrypted or hashed
|
||||
pub fn module_requires_fixed(&self) -> bool {
|
||||
self.run_args.input_visibility.is_hashed()
|
||||
@@ -697,6 +702,9 @@ impl GraphCircuit {
|
||||
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
let result: GraphCircuit = bincode::deserialize_from(reader)?;
|
||||
|
||||
// check the versions matche
|
||||
crate::check_version_string_matches(&result.core.settings.version);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
@@ -753,7 +761,7 @@ pub struct TestOnChainData {
|
||||
pub data: std::path::PathBuf,
|
||||
/// rpc endpoint
|
||||
pub rpc: Option<String>,
|
||||
///
|
||||
/// data sources for the on chain data
|
||||
pub data_sources: TestSources,
|
||||
}
|
||||
|
||||
@@ -941,7 +949,7 @@ impl GraphCircuit {
|
||||
DataSource::File(file_data) => {
|
||||
self.load_file_data(file_data, &shapes, scales, input_types)
|
||||
}
|
||||
_ => unreachable!("cannot load from on-chain data"),
|
||||
_ => Err(GraphError::OnChainDataSource),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -384,8 +384,7 @@ pub struct ParsedNodes {
|
||||
impl ParsedNodes {
|
||||
/// Returns the number of the computational graph's inputs
|
||||
pub fn num_inputs(&self) -> usize {
|
||||
let input_nodes = self.inputs.iter();
|
||||
input_nodes.len()
|
||||
self.inputs.len()
|
||||
}
|
||||
|
||||
/// Input types
|
||||
@@ -425,8 +424,7 @@ impl ParsedNodes {
|
||||
|
||||
/// Returns the number of the computational graph's outputs
|
||||
pub fn num_outputs(&self) -> usize {
|
||||
let output_nodes = self.outputs.iter();
|
||||
output_nodes.len()
|
||||
self.outputs.len()
|
||||
}
|
||||
|
||||
/// Returns shapes of the computational graph's outputs
|
||||
@@ -621,19 +619,23 @@ impl Model {
|
||||
/// * `scale` - The scale to use for quantization.
|
||||
/// * `public_params` - Whether to make the params public.
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn load_onnx_using_tract(
|
||||
pub(crate) fn load_onnx_using_tract(
|
||||
reader: &mut dyn std::io::Read,
|
||||
run_args: &RunArgs,
|
||||
variables: &[(String, usize)],
|
||||
) -> Result<TractResult, GraphError> {
|
||||
use tract_onnx::tract_hir::internal::GenericFactoid;
|
||||
|
||||
let mut model = tract_onnx::onnx().model_for_read(reader)?;
|
||||
|
||||
let variables: std::collections::HashMap<String, usize> =
|
||||
std::collections::HashMap::from_iter(run_args.variables.clone());
|
||||
std::collections::HashMap::from_iter(variables.iter().map(|(k, v)| (k.clone(), *v)));
|
||||
|
||||
for (i, id) in model.clone().inputs.iter().enumerate() {
|
||||
let input = model.node_mut(id.node);
|
||||
|
||||
if input.outputs.len() == 0 {
|
||||
return Err(GraphError::MissingOutput(id.node));
|
||||
}
|
||||
let mut fact: InferenceFact = input.outputs[0].fact.clone();
|
||||
|
||||
for (i, x) in fact.clone().shape.dims().enumerate() {
|
||||
@@ -655,7 +657,7 @@ impl Model {
|
||||
}
|
||||
|
||||
let mut symbol_values = SymbolValues::default();
|
||||
for (symbol, value) in run_args.variables.iter() {
|
||||
for (symbol, value) in variables.iter() {
|
||||
let symbol = model.symbols.sym(symbol);
|
||||
symbol_values = symbol_values.with(&symbol, *value as i64);
|
||||
debug!("set {} to {}", symbol, value);
|
||||
@@ -683,7 +685,7 @@ impl Model {
|
||||
) -> Result<ParsedNodes, GraphError> {
|
||||
let start_time = instant::Instant::now();
|
||||
|
||||
let (model, symbol_values) = Self::load_onnx_using_tract(reader, run_args)?;
|
||||
let (model, symbol_values) = Self::load_onnx_using_tract(reader, &run_args.variables)?;
|
||||
|
||||
let scales = VarScales::from_args(run_args);
|
||||
let nodes = Self::nodes_from_graph(
|
||||
@@ -964,7 +966,7 @@ impl Model {
|
||||
GraphError::ReadWriteFileError(model_path.display().to_string(), e.to_string())
|
||||
})?;
|
||||
|
||||
let (model, _) = Model::load_onnx_using_tract(&mut file, run_args)?;
|
||||
let (model, _) = Model::load_onnx_using_tract(&mut file, &run_args.variables)?;
|
||||
|
||||
let datum_types: Vec<DatumType> = model
|
||||
.input_outlets()?
|
||||
@@ -1016,6 +1018,10 @@ impl Model {
|
||||
let required_lookups = settings.required_lookups.clone();
|
||||
let required_range_checks = settings.required_range_checks.clone();
|
||||
|
||||
if vars.advices.len() < 3 {
|
||||
return Err(GraphError::InsufficientAdviceColumns(3));
|
||||
}
|
||||
|
||||
let mut base_gate = PolyConfig::configure(
|
||||
meta,
|
||||
vars.advices[0..2].try_into()?,
|
||||
@@ -1035,6 +1041,10 @@ impl Model {
|
||||
}
|
||||
|
||||
if settings.requires_dynamic_lookup() {
|
||||
if vars.advices.len() < 6 {
|
||||
return Err(GraphError::InsufficientAdviceColumns(6));
|
||||
}
|
||||
|
||||
base_gate.configure_dynamic_lookup(
|
||||
meta,
|
||||
vars.advices[0..3].try_into()?,
|
||||
@@ -1043,10 +1053,13 @@ impl Model {
|
||||
}
|
||||
|
||||
if settings.requires_shuffle() {
|
||||
if vars.advices.len() < 6 {
|
||||
return Err(GraphError::InsufficientAdviceColumns(6));
|
||||
}
|
||||
base_gate.configure_shuffles(
|
||||
meta,
|
||||
vars.advices[0..2].try_into()?,
|
||||
vars.advices[3..5].try_into()?,
|
||||
vars.advices[0..3].try_into()?,
|
||||
vars.advices[3..6].try_into()?,
|
||||
)?;
|
||||
}
|
||||
|
||||
@@ -1061,6 +1074,7 @@ impl Model {
|
||||
/// * `vars` - The variables for the circuit.
|
||||
/// * `witnessed_outputs` - The values to compare against.
|
||||
/// * `constants` - The constants for the circuit.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn layout(
|
||||
&self,
|
||||
mut config: ModelConfig,
|
||||
@@ -1226,6 +1240,7 @@ impl Model {
|
||||
values.iter().map(|v| v.dims()).collect_vec()
|
||||
);
|
||||
|
||||
let start = instant::Instant::now();
|
||||
match &node {
|
||||
NodeType::Node(n) => {
|
||||
let res = if node.is_constant() && node.num_uses() == 1 {
|
||||
@@ -1363,6 +1378,7 @@ impl Model {
|
||||
results.insert(*idx, full_results);
|
||||
}
|
||||
}
|
||||
debug!("------------ layout of {} took {:?}", idx, start.elapsed());
|
||||
}
|
||||
|
||||
// we do this so we can support multiple passes of the same model and have deterministic results (Non-assigned inputs etc... etc...)
|
||||
@@ -1458,7 +1474,7 @@ impl Model {
|
||||
.iter()
|
||||
.map(|x| {
|
||||
x.get_felt_evals()
|
||||
.unwrap_or(Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
|
||||
.unwrap_or_else(|_| Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
@@ -284,7 +284,6 @@ impl GraphModules {
|
||||
log::error!("Poseidon config not initialized");
|
||||
return Err(Error::Synthesis);
|
||||
}
|
||||
// If the module is encrypted, then we need to encrypt the inputs
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
// Import dependencies for scaling operations
|
||||
use super::scale_to_multiplier;
|
||||
|
||||
// Import ONNX-specific utilities when EZKL feature is enabled
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use super::utilities::node_output_shapes;
|
||||
|
||||
// Import scale management types for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use super::VarScales;
|
||||
|
||||
// Import visibility settings for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use super::Visibility;
|
||||
|
||||
// Import operation types for different circuit components
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::poly::PolyOp;
|
||||
@@ -13,28 +22,49 @@ use crate::circuit::Constant;
|
||||
use crate::circuit::Input;
|
||||
use crate::circuit::Op;
|
||||
use crate::circuit::Unknown;
|
||||
|
||||
// Import graph error types for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::graph::errors::GraphError;
|
||||
|
||||
// Import ONNX operation conversion utilities
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::graph::new_op_from_onnx;
|
||||
|
||||
// Import tensor error handling
|
||||
use crate::tensor::TensorError;
|
||||
|
||||
// Import curve-specific field type
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
|
||||
// Import logging for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use log::trace;
|
||||
|
||||
// Import serialization traits
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
|
||||
// Import data structures for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
// Import formatting traits for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use std::fmt;
|
||||
|
||||
// Import table display formatting for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tabled::Tabled;
|
||||
|
||||
// Import ONNX-specific types and traits for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::{
|
||||
self,
|
||||
prelude::{Node as OnnxNode, SymbolValues, TypedFact, TypedOp},
|
||||
};
|
||||
|
||||
/// Helper function to format vectors for display
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
|
||||
if !v.is_empty() {
|
||||
@@ -44,29 +74,35 @@ fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to format operation kinds for display
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn display_opkind(v: &SupportedOp) -> String {
|
||||
v.as_string()
|
||||
}
|
||||
|
||||
/// A wrapper for an operation that has been rescaled.
|
||||
/// A wrapper for an operation that has been rescaled to handle different precision requirements.
|
||||
/// This enables operations to work with inputs that have been scaled to different fixed-point representations.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Rescaled {
|
||||
/// The operation that has to be rescaled.
|
||||
/// The underlying operation that needs to be rescaled
|
||||
pub inner: Box<SupportedOp>,
|
||||
/// The scale of the operation's inputs.
|
||||
/// Vector of (index, scale) pairs defining how each input should be scaled
|
||||
pub scale: Vec<(usize, u128)>,
|
||||
}
|
||||
|
||||
/// Implementation of the Op trait for Rescaled operations
|
||||
impl Op<Fp> for Rescaled {
|
||||
/// Convert to Any type for runtime type checking
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
/// Get string representation of the operation
|
||||
fn as_string(&self) -> String {
|
||||
format!("RESCALED INPUT ({})", self.inner.as_string())
|
||||
}
|
||||
|
||||
/// Calculate output scale based on input scales
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let in_scales = in_scales
|
||||
.into_iter()
|
||||
@@ -77,6 +113,7 @@ impl Op<Fp> for Rescaled {
|
||||
Op::<Fp>::out_scale(&*self.inner, in_scales)
|
||||
}
|
||||
|
||||
/// Layout the operation in the circuit
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
@@ -93,28 +130,40 @@ impl Op<Fp> for Rescaled {
|
||||
self.inner.layout(config, region, res)
|
||||
}
|
||||
|
||||
/// Create a cloned boxed copy of this operation
|
||||
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
|
||||
Box::new(self.clone()) // Forward to the derive(Clone) impl
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper for an operation that has been rescaled.
|
||||
/// A wrapper for operations that require scale rebasing
|
||||
/// This handles cases where operation scales need to be adjusted to a target scale
|
||||
/// while preserving the numerical relationships
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct RebaseScale {
|
||||
/// The operation that has to be rescaled.
|
||||
/// The operation that needs to be rescaled
|
||||
pub inner: Box<SupportedOp>,
|
||||
/// rebase op
|
||||
/// Operation used for rebasing, typically division
|
||||
pub rebase_op: HybridOp,
|
||||
/// scale being rebased to
|
||||
/// Scale that we're rebasing to
|
||||
pub target_scale: i32,
|
||||
/// The original scale of the operation's inputs.
|
||||
/// Original scale of operation's inputs before rebasing
|
||||
pub original_scale: i32,
|
||||
/// multiplier
|
||||
/// Scaling multiplier used in rebasing
|
||||
pub multiplier: f64,
|
||||
}
|
||||
|
||||
impl RebaseScale {
|
||||
/// Creates a rebased version of an operation if needed
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `inner` - Operation to potentially rebase
|
||||
/// * `global_scale` - Base scale for the system
|
||||
/// * `op_out_scale` - Current output scale of the operation
|
||||
/// * `scale_rebase_multiplier` - Factor determining when rebasing should occur
|
||||
///
|
||||
/// # Returns
|
||||
/// Original or rebased operation depending on scale relationships
|
||||
pub fn rebase(
|
||||
inner: SupportedOp,
|
||||
global_scale: crate::Scale,
|
||||
@@ -155,7 +204,15 @@ impl RebaseScale {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a rebased operation with increased scale
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `inner` - Operation to potentially rebase
|
||||
/// * `target_scale` - Scale to rebase to
|
||||
/// * `op_out_scale` - Current output scale of the operation
|
||||
///
|
||||
/// # Returns
|
||||
/// Original or rebased operation with increased scale
|
||||
pub fn rebase_up(
|
||||
inner: SupportedOp,
|
||||
target_scale: crate::Scale,
|
||||
@@ -192,10 +249,12 @@ impl RebaseScale {
|
||||
}
|
||||
|
||||
impl Op<Fp> for RebaseScale {
|
||||
/// Convert to Any type for runtime type checking
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
/// Get string representation of the operation
|
||||
fn as_string(&self) -> String {
|
||||
format!(
|
||||
"REBASED (div={:?}, rebasing_op={}) ({})",
|
||||
@@ -205,10 +264,12 @@ impl Op<Fp> for RebaseScale {
|
||||
)
|
||||
}
|
||||
|
||||
/// Calculate output scale based on input scales
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
Ok(self.target_scale)
|
||||
}
|
||||
|
||||
/// Layout the operation in the circuit
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
@@ -222,34 +283,40 @@ impl Op<Fp> for RebaseScale {
|
||||
self.rebase_op.layout(config, region, &[original_res])
|
||||
}
|
||||
|
||||
/// Create a cloned boxed copy of this operation
|
||||
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
|
||||
Box::new(self.clone()) // Forward to the derive(Clone) impl
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// A single operation in a [crate::graph::Model].
|
||||
/// Represents all supported operation types in the circuit
|
||||
/// Each variant encapsulates a different type of operation with specific behavior
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum SupportedOp {
|
||||
/// A linear operation.
|
||||
/// Linear operations (polynomial-based)
|
||||
Linear(PolyOp),
|
||||
/// A nonlinear operation.
|
||||
/// Nonlinear operations requiring lookup tables
|
||||
Nonlinear(LookupOp),
|
||||
/// A hybrid operation.
|
||||
/// Mixed operations combining different approaches
|
||||
Hybrid(HybridOp),
|
||||
///
|
||||
/// Input values to the circuit
|
||||
Input(Input),
|
||||
///
|
||||
/// Constant values in the circuit
|
||||
Constant(Constant<Fp>),
|
||||
///
|
||||
/// Placeholder for unsupported operations
|
||||
Unknown(Unknown),
|
||||
///
|
||||
/// Operations requiring rescaling of inputs
|
||||
Rescaled(Rescaled),
|
||||
///
|
||||
/// Operations requiring scale rebasing
|
||||
RebaseScale(RebaseScale),
|
||||
}
|
||||
|
||||
impl SupportedOp {
|
||||
/// Checks if the operation is a lookup operation
|
||||
///
|
||||
/// # Returns
|
||||
/// * `true` if operation requires lookup table
|
||||
/// * `false` otherwise
|
||||
pub fn is_lookup(&self) -> bool {
|
||||
match self {
|
||||
SupportedOp::Nonlinear(_) => true,
|
||||
@@ -257,7 +324,12 @@ impl SupportedOp {
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns input operation if this is an input
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Some(Input)` if this is an input operation
|
||||
/// * `None` otherwise
|
||||
pub fn get_input(&self) -> Option<Input> {
|
||||
match self {
|
||||
SupportedOp::Input(op) => Some(op.clone()),
|
||||
@@ -265,7 +337,11 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns reference to rebased operation if this is a rebased operation
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Some(&RebaseScale)` if this is a rebased operation
|
||||
/// * `None` otherwise
|
||||
pub fn get_rebased(&self) -> Option<&RebaseScale> {
|
||||
match self {
|
||||
SupportedOp::RebaseScale(op) => Some(op),
|
||||
@@ -273,7 +349,11 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns reference to lookup operation if this is a lookup operation
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Some(&LookupOp)` if this is a lookup operation
|
||||
/// * `None` otherwise
|
||||
pub fn get_lookup(&self) -> Option<&LookupOp> {
|
||||
match self {
|
||||
SupportedOp::Nonlinear(op) => Some(op),
|
||||
@@ -281,7 +361,11 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns reference to constant if this is a constant
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Some(&Constant)` if this is a constant
|
||||
/// * `None` otherwise
|
||||
pub fn get_constant(&self) -> Option<&Constant<Fp>> {
|
||||
match self {
|
||||
SupportedOp::Constant(op) => Some(op),
|
||||
@@ -289,7 +373,11 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns mutable reference to constant if this is a constant
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Some(&mut Constant)` if this is a constant
|
||||
/// * `None` otherwise
|
||||
pub fn get_mutable_constant(&mut self) -> Option<&mut Constant<Fp>> {
|
||||
match self {
|
||||
SupportedOp::Constant(op) => Some(op),
|
||||
@@ -297,18 +385,19 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a homogeneously rescaled version of this operation if needed
|
||||
/// Only available with EZKL feature enabled
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn homogenous_rescale(
|
||||
&self,
|
||||
in_scales: Vec<crate::Scale>,
|
||||
) -> Result<Box<dyn Op<Fp>>, GraphError> {
|
||||
let inputs_to_scale = self.requires_homogenous_input_scales();
|
||||
// creates a rescaled op if the inputs are not homogenous
|
||||
let op = self.clone_dyn();
|
||||
super::homogenize_input_scales(op, in_scales, inputs_to_scale)
|
||||
}
|
||||
|
||||
/// Since each associated value of `SupportedOp` implements `Op`, let's define a helper method to retrieve it.
|
||||
/// Returns reference to underlying Op implementation
|
||||
fn as_op(&self) -> &dyn Op<Fp> {
|
||||
match self {
|
||||
SupportedOp::Linear(op) => op,
|
||||
@@ -322,9 +411,10 @@ impl SupportedOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// check if is the identity operation
|
||||
/// Checks if this is an identity operation
|
||||
///
|
||||
/// # Returns
|
||||
/// * `true` if the operation is the identity operation
|
||||
/// * `true` if this operation passes input through unchanged
|
||||
/// * `false` otherwise
|
||||
pub fn is_identity(&self) -> bool {
|
||||
match self {
|
||||
@@ -361,9 +451,11 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
|
||||
if let Some(op) = value.as_any().downcast_ref::<Unknown>() {
|
||||
return SupportedOp::Unknown(op.clone());
|
||||
};
|
||||
|
||||
if let Some(op) = value.as_any().downcast_ref::<Rescaled>() {
|
||||
return SupportedOp::Rescaled(op.clone());
|
||||
};
|
||||
|
||||
if let Some(op) = value.as_any().downcast_ref::<RebaseScale>() {
|
||||
return SupportedOp::RebaseScale(op.clone());
|
||||
};
|
||||
@@ -375,6 +467,7 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
|
||||
}
|
||||
|
||||
impl Op<Fp> for SupportedOp {
|
||||
/// Layout this operation in the circuit
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
@@ -384,54 +477,61 @@ impl Op<Fp> for SupportedOp {
|
||||
self.as_op().layout(config, region, values)
|
||||
}
|
||||
|
||||
/// Check if this is an input operation
|
||||
fn is_input(&self) -> bool {
|
||||
self.as_op().is_input()
|
||||
}
|
||||
|
||||
/// Check if this is a constant operation
|
||||
fn is_constant(&self) -> bool {
|
||||
self.as_op().is_constant()
|
||||
}
|
||||
|
||||
/// Get which inputs require homogeneous scales
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
self.as_op().requires_homogenous_input_scales()
|
||||
}
|
||||
|
||||
/// Create a clone of this operation
|
||||
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
|
||||
self.as_op().clone_dyn()
|
||||
}
|
||||
|
||||
/// Get string representation
|
||||
fn as_string(&self) -> String {
|
||||
self.as_op().as_string()
|
||||
}
|
||||
|
||||
/// Convert to Any type
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
/// Calculate output scale from input scales
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
self.as_op().out_scale(in_scales)
|
||||
}
|
||||
}
|
||||
|
||||
/// A node's input is a tensor from another node's output.
|
||||
/// Represents a connection to another node's output
|
||||
/// First element is node index, second is output slot index
|
||||
pub type Outlet = (usize, usize);
|
||||
|
||||
/// A single operation in a [crate::graph::Model].
|
||||
/// Represents a single computational node in the circuit graph
|
||||
/// Contains all information needed to execute and connect operations
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Node {
|
||||
/// [Op] i.e what operation this node represents.
|
||||
/// The operation this node performs
|
||||
pub opkind: SupportedOp,
|
||||
/// The denominator in the fixed point representation for the node's output. Tensors of differing scales should not be combined.
|
||||
/// Fixed point scale factor for this node's output
|
||||
pub out_scale: i32,
|
||||
// Usually there is a simple in and out shape of the node as an operator. For example, an Affine node has three input_shapes (one for the input, weight, and bias),
|
||||
// but in_dim is [in], out_dim is [out]
|
||||
/// The indices of the node's inputs.
|
||||
/// Connections to other nodes' outputs that serve as inputs
|
||||
pub inputs: Vec<Outlet>,
|
||||
/// Dimensions of output.
|
||||
/// Shape of this node's output tensor
|
||||
pub out_dims: Vec<usize>,
|
||||
/// The node's unique identifier.
|
||||
/// Unique identifier for this node
|
||||
pub idx: usize,
|
||||
/// The node's num of uses
|
||||
/// Number of times this node's output is used
|
||||
pub num_uses: usize,
|
||||
}
|
||||
|
||||
@@ -469,12 +569,19 @@ impl PartialEq for Node {
|
||||
}
|
||||
|
||||
impl Node {
|
||||
/// Converts a tract [OnnxNode] into an ezkl [Node].
|
||||
/// # Arguments:
|
||||
/// * `node` - [OnnxNode]
|
||||
/// * `other_nodes` - [BTreeMap] of other previously initialized [Node]s in the computational graph.
|
||||
/// * `public_params` - flag if parameters of model are public
|
||||
/// * `idx` - The node's unique identifier.
|
||||
/// Creates a new Node from an ONNX node
|
||||
/// Only available when EZKL feature is enabled
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `node` - Source ONNX node
|
||||
/// * `other_nodes` - Map of existing nodes in the graph
|
||||
/// * `scales` - Scale factors for variables
|
||||
/// * `idx` - Unique identifier for this node
|
||||
/// * `symbol_values` - ONNX symbol values
|
||||
/// * `run_args` - Runtime configuration arguments
|
||||
///
|
||||
/// # Returns
|
||||
/// New Node instance or error if creation fails
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
@@ -612,16 +719,14 @@ impl Node {
|
||||
})
|
||||
}
|
||||
|
||||
/// check if it is a softmax node
|
||||
/// Check if this node performs softmax operation
|
||||
pub fn is_softmax(&self) -> bool {
|
||||
if let SupportedOp::Hybrid(HybridOp::Softmax { .. }) = self.opkind {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
matches!(self.opkind, SupportedOp::Hybrid(HybridOp::Softmax { .. }))
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to rescale constants that are only used once
|
||||
/// Only available when EZKL feature is enabled
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn rescale_const_with_single_use(
|
||||
constant: &mut Constant<Fp>,
|
||||
|
||||
@@ -44,11 +44,11 @@ use tract_onnx::tract_hir::{
|
||||
tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
|
||||
};
|
||||
|
||||
/// Quantizes an iterable of f32s to a [Tensor] of i32s using a fixed point representation.
|
||||
/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
|
||||
/// NAN gets mapped to 0. INFINITY and NEG_INFINITY error out.
|
||||
/// Arguments
|
||||
///
|
||||
/// * `vec` - the vector to quantize.
|
||||
/// * `dims` - the dimensionality of the resulting [Tensor].
|
||||
/// * `elem` - the element to quantize.
|
||||
/// * `shift` - offset used in the fixed point representation.
|
||||
/// * `scale` - `2^scale` used in the fixed point representation.
|
||||
pub fn quantize_float(
|
||||
@@ -59,7 +59,7 @@ pub fn quantize_float(
|
||||
let mult = scale_to_multiplier(scale);
|
||||
let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
|
||||
|
||||
if *elem > max_value {
|
||||
if *elem > max_value || *elem < -max_value {
|
||||
return Err(TensorError::SigBitTruncationError);
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ pub fn scale_to_multiplier(scale: crate::Scale) -> f64 {
|
||||
f64::powf(2., scale as f64)
|
||||
}
|
||||
|
||||
/// Converts a scale (log base 2) to a fixed point multiplier.
|
||||
/// Converts a fixed point multiplier to a scale (log base 2).
|
||||
pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
|
||||
mult.log2().round() as crate::Scale
|
||||
}
|
||||
@@ -142,8 +142,6 @@ use tract_onnx::prelude::SymbolValues;
|
||||
pub fn extract_tensor_value(
|
||||
input: Arc<tract_onnx::prelude::Tensor>,
|
||||
) -> Result<Tensor<f32>, GraphError> {
|
||||
use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
|
||||
|
||||
let dt = input.datum_type();
|
||||
let dims = input.shape().to_vec();
|
||||
|
||||
@@ -156,7 +154,7 @@ pub fn extract_tensor_value(
|
||||
match dt {
|
||||
DatumType::F16 => {
|
||||
let vec = input.as_slice::<tract_onnx::prelude::f16>()?.to_vec();
|
||||
let cast: Vec<f32> = vec.par_iter().map(|x| (*x).into()).collect();
|
||||
let cast: Vec<f32> = vec.iter().map(|x| (*x).into()).collect();
|
||||
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
|
||||
}
|
||||
DatumType::F32 => {
|
||||
@@ -165,61 +163,61 @@ pub fn extract_tensor_value(
|
||||
}
|
||||
DatumType::F64 => {
|
||||
let vec = input.as_slice::<f64>()?.to_vec();
|
||||
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
|
||||
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
|
||||
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
|
||||
}
|
||||
DatumType::I64 => {
|
||||
// Generally a shape or hyperparam
|
||||
let vec = input.as_slice::<i64>()?.to_vec();
|
||||
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
|
||||
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
|
||||
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
|
||||
}
|
||||
DatumType::I32 => {
|
||||
// Generally a shape or hyperparam
|
||||
let vec = input.as_slice::<i32>()?.to_vec();
|
||||
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
|
||||
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
|
||||
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
|
||||
}
|
||||
DatumType::I16 => {
|
||||
// Generally a shape or hyperparam
|
||||
let vec = input.as_slice::<i16>()?.to_vec();
|
||||
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
|
||||
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
|
||||
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
|
||||
}
|
||||
DatumType::I8 => {
|
||||
// Generally a shape or hyperparam
|
||||
let vec = input.as_slice::<i8>()?.to_vec();
|
||||
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
|
||||
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
|
||||
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
|
||||
}
|
||||
DatumType::U8 => {
|
||||
// Generally a shape or hyperparam
|
||||
let vec = input.as_slice::<u8>()?.to_vec();
|
||||
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
|
||||
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
|
||||
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
|
||||
}
|
||||
DatumType::U16 => {
|
||||
// Generally a shape or hyperparam
|
||||
let vec = input.as_slice::<u16>()?.to_vec();
|
||||
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
|
||||
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
|
||||
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
|
||||
}
|
||||
DatumType::U32 => {
|
||||
// Generally a shape or hyperparam
|
||||
let vec = input.as_slice::<u32>()?.to_vec();
|
||||
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
|
||||
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
|
||||
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
|
||||
}
|
||||
DatumType::U64 => {
|
||||
// Generally a shape or hyperparam
|
||||
let vec = input.as_slice::<u64>()?.to_vec();
|
||||
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
|
||||
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
|
||||
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
|
||||
}
|
||||
DatumType::Bool => {
|
||||
// Generally a shape or hyperparam
|
||||
let vec = input.as_slice::<bool>()?.to_vec();
|
||||
let cast: Vec<f32> = vec.par_iter().map(|x| *x as usize as f32).collect();
|
||||
let cast: Vec<f32> = vec.iter().map(|x| *x as usize as f32).collect();
|
||||
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
|
||||
}
|
||||
DatumType::TDim => {
|
||||
@@ -227,13 +225,10 @@ pub fn extract_tensor_value(
|
||||
let vec = input.as_slice::<tract_onnx::prelude::TDim>()?.to_vec();
|
||||
|
||||
let cast: Result<Vec<f32>, GraphError> = vec
|
||||
.par_iter()
|
||||
.iter()
|
||||
.map(|x| match x.to_i64() {
|
||||
Ok(v) => Ok(v as f32),
|
||||
Err(_) => match x.to_i64() {
|
||||
Ok(v) => Ok(v as f32),
|
||||
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
|
||||
},
|
||||
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -314,6 +309,9 @@ pub fn new_op_from_onnx(
|
||||
let mut deleted_indices = vec![];
|
||||
let node = match node.op().name().as_ref() {
|
||||
"ShiftLeft" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "shift left".to_string()));
|
||||
};
|
||||
// load shift amount
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
@@ -326,10 +324,13 @@ pub fn new_op_from_onnx(
|
||||
out_scale: Some(input_scales[0] - raw_values[0] as i32),
|
||||
})
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "ShiftLeft".to_string()));
|
||||
return Err(GraphError::OpMismatch(idx, "shift left".to_string()));
|
||||
}
|
||||
}
|
||||
"ShiftRight" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "shift right".to_string()));
|
||||
};
|
||||
// load shift amount
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
@@ -342,7 +343,7 @@ pub fn new_op_from_onnx(
|
||||
out_scale: Some(input_scales[0] + raw_values[0] as i32),
|
||||
})
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "ShiftRight".to_string()));
|
||||
return Err(GraphError::OpMismatch(idx, "shift right".to_string()));
|
||||
}
|
||||
}
|
||||
"MultiBroadcastTo" => {
|
||||
@@ -365,7 +366,10 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(input_ops.len(), 3, "Range requires 3 inputs");
|
||||
if input_ops.len() != 3 {
|
||||
return Err(GraphError::InvalidDims(idx, "range".to_string()));
|
||||
}
|
||||
|
||||
let input_ops = input_ops
|
||||
.iter()
|
||||
.map(|x| x.get_constant().ok_or(GraphError::NonConstantRange))
|
||||
@@ -421,6 +425,10 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
if inputs[0].out_dims().is_empty() || inputs[0].out_dims()[0].len() <= axis {
|
||||
return Err(GraphError::InvalidDims(idx, "gather".to_string()));
|
||||
}
|
||||
|
||||
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::Gather {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| {
|
||||
@@ -449,8 +457,17 @@ pub fn new_op_from_onnx(
|
||||
"Topk" => {
|
||||
let op = load_op::<Topk>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axis = op.axis;
|
||||
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
|
||||
};
|
||||
|
||||
// if param_visibility.is_public() {
|
||||
let k = if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
if c.raw_values.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
|
||||
}
|
||||
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
c.raw_values.map(|x| x as usize)[0]
|
||||
@@ -490,6 +507,10 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "scatter elements".to_string()));
|
||||
}
|
||||
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
@@ -524,6 +545,9 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "scatter nd".to_string()));
|
||||
}
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
})
|
||||
@@ -557,6 +581,9 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "gather nd".to_string()));
|
||||
}
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
|
||||
batch_dims,
|
||||
indices: Some(c.raw_values.map(|x| x as usize)),
|
||||
@@ -591,6 +618,9 @@ pub fn new_op_from_onnx(
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "gather elements".to_string()));
|
||||
}
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
@@ -686,7 +716,9 @@ pub fn new_op_from_onnx(
|
||||
};
|
||||
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes: Vec<usize> = op.axes.into_iter().collect();
|
||||
assert_eq!(axes.len(), 1, "only support argmax over one axis");
|
||||
if axes.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "argmax".to_string()));
|
||||
}
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::ReduceArgMax { dim: axes[0] })
|
||||
}
|
||||
@@ -696,7 +728,9 @@ pub fn new_op_from_onnx(
|
||||
};
|
||||
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes: Vec<usize> = op.axes.into_iter().collect();
|
||||
assert_eq!(axes.len(), 1, "only support argmin over one axis");
|
||||
if axes.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "argmin".to_string()));
|
||||
}
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::ReduceArgMin { dim: axes[0] })
|
||||
}
|
||||
@@ -805,6 +839,9 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
"Recip" => {
|
||||
if inputs.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "recip".to_string()));
|
||||
};
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
// If the input scale is larger than the params scale
|
||||
@@ -848,6 +885,9 @@ pub fn new_op_from_onnx(
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Rsqrt" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "rsqrt".to_string()));
|
||||
};
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
SupportedOp::Hybrid(HybridOp::Rsqrt {
|
||||
@@ -935,7 +975,9 @@ pub fn new_op_from_onnx(
|
||||
let op = load_op::<Cast>(node.op(), idx, node.op().name().to_string())?;
|
||||
let dt = op.to;
|
||||
|
||||
assert_eq!(input_scales.len(), 1);
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "cast".to_string()));
|
||||
};
|
||||
|
||||
match dt {
|
||||
DatumType::Bool
|
||||
@@ -985,6 +1027,11 @@ pub fn new_op_from_onnx(
|
||||
|
||||
if const_idx.len() == 1 {
|
||||
let const_idx = const_idx[0];
|
||||
|
||||
if inputs.len() <= const_idx {
|
||||
return Err(GraphError::InvalidDims(idx, "mul".to_string()));
|
||||
}
|
||||
|
||||
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
|
||||
if c.raw_values.len() == 1 && c.raw_values[0] < 1. {
|
||||
// if not divisible by 2 then we need to add a range check
|
||||
@@ -1059,6 +1106,9 @@ pub fn new_op_from_onnx(
|
||||
return Err(GraphError::OpMismatch(idx, "softmax".to_string()));
|
||||
}
|
||||
};
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "softmax".to_string()));
|
||||
}
|
||||
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
@@ -1098,22 +1148,42 @@ pub fn new_op_from_onnx(
|
||||
pool_dims: kernel_shape.to_vec(),
|
||||
})
|
||||
}
|
||||
"Ceil" => SupportedOp::Hybrid(HybridOp::Ceil {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Floor" => SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Round" => SupportedOp::Hybrid(HybridOp::Round {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"RoundHalfToEven" => SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Ceil" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "ceil".to_string()));
|
||||
}
|
||||
SupportedOp::Hybrid(HybridOp::Ceil {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
})
|
||||
}
|
||||
"Floor" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "floor".to_string()));
|
||||
}
|
||||
SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
})
|
||||
}
|
||||
"Round" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "round".to_string()));
|
||||
}
|
||||
SupportedOp::Hybrid(HybridOp::Round {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
})
|
||||
}
|
||||
"RoundHalfToEven" => {
|
||||
if input_scales.len() != 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "roundhalftoeven".to_string()));
|
||||
}
|
||||
SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
})
|
||||
}
|
||||
"Sign" => SupportedOp::Linear(PolyOp::Sign),
|
||||
"Pow" => {
|
||||
// Extract the slope layer hyperparams from a const
|
||||
@@ -1123,7 +1193,9 @@ pub fn new_op_from_onnx(
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
if c.raw_values.len() > 1 {
|
||||
unimplemented!("only support scalar pow")
|
||||
return Err(GraphError::NonScalarPower);
|
||||
} else if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
|
||||
}
|
||||
|
||||
let exponent = c.raw_values[0];
|
||||
@@ -1136,26 +1208,30 @@ pub fn new_op_from_onnx(
|
||||
a: crate::circuit::utils::F32(exponent),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
if let Some(c) = inputs[0].opkind().get_mutable_constant() {
|
||||
inputs[0].decrement_use();
|
||||
deleted_indices.push(0);
|
||||
if c.raw_values.len() > 1 {
|
||||
unimplemented!("only support scalar base")
|
||||
}
|
||||
|
||||
let base = c.raw_values[0];
|
||||
|
||||
SupportedOp::Nonlinear(LookupOp::Exp {
|
||||
scale: scale_to_multiplier(input_scales[1]).into(),
|
||||
base: base.into(),
|
||||
})
|
||||
} else {
|
||||
unimplemented!("only support constant base or pow for now")
|
||||
} else if let Some(c) = inputs[0].opkind().get_mutable_constant() {
|
||||
inputs[0].decrement_use();
|
||||
deleted_indices.push(0);
|
||||
if c.raw_values.len() > 1 {
|
||||
return Err(GraphError::NonScalarBase);
|
||||
} else if c.raw_values.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
|
||||
}
|
||||
|
||||
let base = c.raw_values[0];
|
||||
|
||||
SupportedOp::Nonlinear(LookupOp::Exp {
|
||||
scale: scale_to_multiplier(input_scales[1]).into(),
|
||||
base: base.into(),
|
||||
})
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
|
||||
}
|
||||
}
|
||||
"Div" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "div".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
@@ -1163,14 +1239,15 @@ pub fn new_op_from_onnx(
|
||||
.map(|(i, _)| i)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if const_idx.len() > 1 {
|
||||
if const_idx.len() > 1 || const_idx.is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "div".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = const_idx[0];
|
||||
|
||||
if const_idx != 1 {
|
||||
unimplemented!("only support div with constant as second input")
|
||||
return Err(GraphError::MisformedParams(
|
||||
"only support div with constant as second input".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
|
||||
@@ -1184,10 +1261,14 @@ pub fn new_op_from_onnx(
|
||||
denom: denom.into(),
|
||||
})
|
||||
} else {
|
||||
unimplemented!("only support non zero divisors of size 1")
|
||||
return Err(GraphError::MisformedParams(
|
||||
"only support non zero divisors of size 1".to_string(),
|
||||
));
|
||||
}
|
||||
} else {
|
||||
unimplemented!("only support div with constant as second input")
|
||||
return Err(GraphError::MisformedParams(
|
||||
"only support div with constant as second input".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
"Cube" => SupportedOp::Linear(PolyOp::Pow(3)),
|
||||
@@ -1327,7 +1408,7 @@ pub fn new_op_from_onnx(
|
||||
if !resize_node.contains("interpolator: Nearest")
|
||||
&& !resize_node.contains("nearest: Floor")
|
||||
{
|
||||
unimplemented!("Only nearest neighbor interpolation is supported")
|
||||
return Err(GraphError::InvalidInterpolation);
|
||||
}
|
||||
// check if optional scale factor is present
|
||||
if inputs.len() != 2 && inputs.len() != 3 {
|
||||
@@ -1431,6 +1512,10 @@ pub fn new_op_from_onnx(
|
||||
SupportedOp::Linear(PolyOp::Reshape(output_shape))
|
||||
}
|
||||
"Flatten" => {
|
||||
if inputs.len() != 1 || inputs[0].out_dims().is_empty() {
|
||||
return Err(GraphError::InvalidDims(idx, "flatten".to_string()));
|
||||
};
|
||||
|
||||
let new_dims: Vec<usize> = vec![inputs[0].out_dims()[0].iter().product::<usize>()];
|
||||
SupportedOp::Linear(PolyOp::Flatten(new_dims))
|
||||
}
|
||||
@@ -1504,12 +1589,10 @@ pub fn homogenize_input_scales(
|
||||
input_scales: Vec<crate::Scale>,
|
||||
inputs_to_scale: Vec<usize>,
|
||||
) -> Result<Box<dyn Op<Fp>>, GraphError> {
|
||||
let relevant_input_scales = input_scales
|
||||
.clone()
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter(|(idx, _)| inputs_to_scale.contains(idx))
|
||||
.map(|(_, scale)| scale)
|
||||
let relevant_input_scales = inputs_to_scale
|
||||
.iter()
|
||||
.filter(|idx| input_scales.len() > **idx)
|
||||
.map(|&idx| input_scales[idx])
|
||||
.collect_vec();
|
||||
|
||||
if inputs_to_scale.is_empty() {
|
||||
@@ -1550,10 +1633,30 @@ pub fn homogenize_input_scales(
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
/// tests for the utility module
|
||||
pub mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
// quantization tests
|
||||
#[test]
|
||||
fn test_quantize_tensor() {
|
||||
let tensor: Tensor<f32> = (0..10).map(|x| x as f32).into();
|
||||
let reference: Tensor<Fp> = (0..10).map(|x| x.into()).into();
|
||||
let scale = 0;
|
||||
let visibility = &Visibility::Public;
|
||||
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility).unwrap();
|
||||
assert_eq!(quantized.len(), 10);
|
||||
assert_eq!(quantized, reference);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantize_edge_cases() {
|
||||
assert_eq!(quantize_float(&f64::NAN, 0.0, 0).unwrap(), 0);
|
||||
assert!(quantize_float(&f64::INFINITY, 0.0, 0).is_err());
|
||||
assert!(quantize_float(&f64::NEG_INFINITY, 0.0, 0).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flatten_valtensors() {
|
||||
let tensor1: Tensor<Fp> = (0..10).map(|x| x.into()).into();
|
||||
|
||||
@@ -11,35 +11,34 @@ use log::debug;
|
||||
use pyo3::{
|
||||
exceptions::PyValueError, FromPyObject, IntoPy, PyObject, PyResult, Python, ToPyObject,
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
use self::errors::GraphError;
|
||||
|
||||
use super::*;
|
||||
|
||||
/// Label enum to track whether model input, model parameters, and model output are public, private, or hashed
|
||||
/// Defines the visibility level of values within the zero-knowledge circuit
|
||||
/// Controls how values are handled during proof generation and verification
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Default)]
|
||||
pub enum Visibility {
|
||||
/// Mark an item as private to the prover (not in the proof submitted for verification)
|
||||
/// Value is private to the prover and not included in proof
|
||||
#[default]
|
||||
Private,
|
||||
/// Mark an item as public (sent in the proof submitted for verification)
|
||||
/// Value is public and included in proof for verification
|
||||
Public,
|
||||
/// Mark an item as publicly committed to (hash sent in the proof submitted for verification)
|
||||
/// Value is hashed and the hash is included in proof
|
||||
Hashed {
|
||||
/// Whether the hash is used as an instance (sent in the proof submitted for verification)
|
||||
/// if false the hash is used as an advice (not in the proof submitted for verification) and is then sent to the computational graph
|
||||
/// if true the hash is used as an instance (sent in the proof submitted for verification) the *inputs* to the hashing function are then sent to the computational graph
|
||||
/// Controls how the hash is handled in proof
|
||||
/// true - hash is included directly in proof (public)
|
||||
/// false - hash is used as advice and passed to computational graph
|
||||
hash_is_public: bool,
|
||||
///
|
||||
/// Specifies which outputs this hash affects
|
||||
outlets: Vec<usize>,
|
||||
},
|
||||
/// Mark an item as publicly committed to (KZG commitment sent in the proof submitted for verification)
|
||||
/// Value is committed using KZG commitment scheme
|
||||
KZGCommit,
|
||||
/// assigned as a constant in the circuit
|
||||
/// Value is assigned as a constant in the circuit
|
||||
Fixed,
|
||||
}
|
||||
|
||||
@@ -66,15 +65,17 @@ impl Display for Visibility {
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl ToFlags for Visibility {
|
||||
/// Converts visibility to command line flags
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{}", self)]
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a str> for Visibility {
|
||||
/// Converts string representation to Visibility
|
||||
fn from(s: &'a str) -> Self {
|
||||
if s.contains("hashed/private") {
|
||||
// split on last occurrence of '/'
|
||||
// Split on last occurrence of '/'
|
||||
let (_, outlets) = s.split_at(s.rfind('/').unwrap());
|
||||
let outlets = outlets
|
||||
.trim_start_matches('/')
|
||||
@@ -106,8 +107,8 @@ impl<'a> From<&'a str> for Visibility {
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts Visibility into a PyObject (Required for Visibility to be compatible with Python)
|
||||
impl IntoPy<PyObject> for Visibility {
|
||||
/// Converts Visibility to Python object
|
||||
fn into_py(self, py: Python) -> PyObject {
|
||||
match self {
|
||||
Visibility::Private => "private".to_object(py),
|
||||
@@ -134,14 +135,13 @@ impl IntoPy<PyObject> for Visibility {
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains Visibility from PyObject (Required for Visibility to be compatible with Python)
|
||||
impl<'source> FromPyObject<'source> for Visibility {
|
||||
/// Extracts Visibility from Python object
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
|
||||
let strval = String::extract_bound(ob)?;
|
||||
let strval = strval.as_str();
|
||||
|
||||
if strval.contains("hashed/private") {
|
||||
// split on last occurence of '/'
|
||||
let (_, outlets) = strval.split_at(strval.rfind('/').unwrap());
|
||||
let outlets = outlets
|
||||
.trim_start_matches('/')
|
||||
@@ -174,29 +174,32 @@ impl<'source> FromPyObject<'source> for Visibility {
|
||||
}
|
||||
|
||||
impl Visibility {
|
||||
#[allow(missing_docs)]
|
||||
/// Returns true if visibility is Fixed
|
||||
pub fn is_fixed(&self) -> bool {
|
||||
matches!(&self, Visibility::Fixed)
|
||||
}
|
||||
#[allow(missing_docs)]
|
||||
|
||||
/// Returns true if visibility is Private or hashed private
|
||||
pub fn is_private(&self) -> bool {
|
||||
matches!(&self, Visibility::Private) || self.is_hashed_private()
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
/// Returns true if visibility is Public
|
||||
pub fn is_public(&self) -> bool {
|
||||
matches!(&self, Visibility::Public)
|
||||
}
|
||||
#[allow(missing_docs)]
|
||||
|
||||
/// Returns true if visibility involves hashing
|
||||
pub fn is_hashed(&self) -> bool {
|
||||
matches!(&self, Visibility::Hashed { .. })
|
||||
}
|
||||
#[allow(missing_docs)]
|
||||
|
||||
/// Returns true if visibility uses KZG commitment
|
||||
pub fn is_polycommit(&self) -> bool {
|
||||
matches!(&self, Visibility::KZGCommit)
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
/// Returns true if visibility is hashed with public hash
|
||||
pub fn is_hashed_public(&self) -> bool {
|
||||
if let Visibility::Hashed {
|
||||
hash_is_public: true,
|
||||
@@ -207,7 +210,8 @@ impl Visibility {
|
||||
}
|
||||
false
|
||||
}
|
||||
#[allow(missing_docs)]
|
||||
|
||||
/// Returns true if visibility is hashed with private hash
|
||||
pub fn is_hashed_private(&self) -> bool {
|
||||
if let Visibility::Hashed {
|
||||
hash_is_public: false,
|
||||
@@ -219,11 +223,12 @@ impl Visibility {
|
||||
false
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
/// Returns true if visibility requires additional processing
|
||||
pub fn requires_processing(&self) -> bool {
|
||||
matches!(&self, Visibility::Hashed { .. }) | matches!(&self, Visibility::KZGCommit)
|
||||
}
|
||||
#[allow(missing_docs)]
|
||||
|
||||
/// Returns vector of output indices that this visibility setting affects
|
||||
pub fn overwrites_inputs(&self) -> Vec<usize> {
|
||||
if let Visibility::Hashed { outlets, .. } = self {
|
||||
return outlets.clone();
|
||||
@@ -232,14 +237,14 @@ impl Visibility {
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the scale of the model input, model parameters.
|
||||
/// Manages scaling factors for different parts of the model
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)]
|
||||
pub struct VarScales {
|
||||
///
|
||||
/// Scale factor for input values
|
||||
pub input: crate::Scale,
|
||||
///
|
||||
/// Scale factor for parameter values
|
||||
pub params: crate::Scale,
|
||||
///
|
||||
/// Multiplier for scale rebasing
|
||||
pub rebase_multiplier: u32,
|
||||
}
|
||||
|
||||
@@ -250,17 +255,17 @@ impl std::fmt::Display for VarScales {
|
||||
}
|
||||
|
||||
impl VarScales {
|
||||
///
|
||||
/// Returns maximum scale value
|
||||
pub fn get_max(&self) -> crate::Scale {
|
||||
std::cmp::max(self.input, self.params)
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns minimum scale value
|
||||
pub fn get_min(&self) -> crate::Scale {
|
||||
std::cmp::min(self.input, self.params)
|
||||
}
|
||||
|
||||
/// Place in [VarScales] struct.
|
||||
/// Creates VarScales from runtime arguments
|
||||
pub fn from_args(args: &RunArgs) -> Self {
|
||||
Self {
|
||||
input: args.input_scale,
|
||||
@@ -270,16 +275,17 @@ impl VarScales {
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents whether the model input, model parameters, and model output are Public or Private to the prover.
|
||||
/// Controls visibility settings for different parts of the model
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, PartialOrd)]
|
||||
pub struct VarVisibility {
|
||||
/// Input to the model or computational graph
|
||||
/// Visibility of model inputs
|
||||
pub input: Visibility,
|
||||
/// Parameters, such as weights and biases, in the model
|
||||
/// Visibility of model parameters (weights, biases)
|
||||
pub params: Visibility,
|
||||
/// Output of the model or computational graph
|
||||
/// Visibility of model outputs
|
||||
pub output: Visibility,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for VarVisibility {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(
|
||||
@@ -301,8 +307,7 @@ impl Default for VarVisibility {
|
||||
}
|
||||
|
||||
impl VarVisibility {
|
||||
/// Read from cli args whether the model input, model parameters, and model output are Public or Private to the prover.
|
||||
/// Place in [VarVisibility] struct.
|
||||
/// Creates visibility settings from runtime arguments
|
||||
pub fn from_args(args: &RunArgs) -> Result<Self, GraphError> {
|
||||
let input_vis = &args.input_visibility;
|
||||
let params_vis = &args.param_visibility;
|
||||
@@ -313,17 +318,17 @@ impl VarVisibility {
|
||||
}
|
||||
|
||||
if !output_vis.is_public()
|
||||
& !params_vis.is_public()
|
||||
& !input_vis.is_public()
|
||||
& !output_vis.is_fixed()
|
||||
& !params_vis.is_fixed()
|
||||
& !input_vis.is_fixed()
|
||||
& !output_vis.is_hashed()
|
||||
& !params_vis.is_hashed()
|
||||
& !input_vis.is_hashed()
|
||||
& !output_vis.is_polycommit()
|
||||
& !params_vis.is_polycommit()
|
||||
& !input_vis.is_polycommit()
|
||||
&& !params_vis.is_public()
|
||||
&& !input_vis.is_public()
|
||||
&& !output_vis.is_fixed()
|
||||
&& !params_vis.is_fixed()
|
||||
&& !input_vis.is_fixed()
|
||||
&& !output_vis.is_hashed()
|
||||
&& !params_vis.is_hashed()
|
||||
&& !input_vis.is_hashed()
|
||||
&& !output_vis.is_polycommit()
|
||||
&& !params_vis.is_polycommit()
|
||||
&& !input_vis.is_polycommit()
|
||||
{
|
||||
return Err(GraphError::Visibility);
|
||||
}
|
||||
@@ -335,17 +340,17 @@ impl VarVisibility {
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper for holding all columns that will be assigned to by a model.
|
||||
/// Container for circuit columns used by a model
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ModelVars<F: PrimeField + TensorType + PartialOrd> {
|
||||
#[allow(missing_docs)]
|
||||
/// Advice columns for circuit assignments
|
||||
pub advices: Vec<VarTensor>,
|
||||
#[allow(missing_docs)]
|
||||
/// Optional instance column for public inputs
|
||||
pub instance: Option<ValTensor<F>>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
/// Get instance col
|
||||
/// Gets reference to instance column if it exists
|
||||
pub fn get_instance_col(&self) -> Option<&Column<Instance>> {
|
||||
if let Some(instance) = &self.instance {
|
||||
match instance {
|
||||
@@ -357,14 +362,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the initial instance offset
|
||||
/// Sets initial offset for instance values
|
||||
pub fn set_initial_instance_offset(&mut self, offset: usize) {
|
||||
if let Some(instance) = &mut self.instance {
|
||||
instance.set_initial_instance_offset(offset);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the total instance len
|
||||
/// Gets total length of instance data
|
||||
pub fn get_instance_len(&self) -> usize {
|
||||
if let Some(instance) = &self.instance {
|
||||
instance.get_total_instance_len()
|
||||
@@ -373,21 +378,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Increment the instance offset
|
||||
/// Increments instance index
|
||||
pub fn increment_instance_idx(&mut self) {
|
||||
if let Some(instance) = &mut self.instance {
|
||||
instance.increment_idx();
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset the instance offset
|
||||
/// Sets instance index to specific value
|
||||
pub fn set_instance_idx(&mut self, val: usize) {
|
||||
if let Some(instance) = &mut self.instance {
|
||||
instance.set_idx(val);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the instance offset
|
||||
/// Gets current instance index
|
||||
pub fn get_instance_idx(&self) -> usize {
|
||||
if let Some(instance) = &self.instance {
|
||||
instance.get_idx()
|
||||
@@ -396,7 +401,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// Initializes instance column with specified dimensions and scale
|
||||
pub fn instantiate_instance(
|
||||
&mut self,
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
@@ -417,7 +422,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
};
|
||||
}
|
||||
|
||||
/// Allocate all columns that will be assigned to by a model.
|
||||
/// Creates new ModelVars with allocated columns based on settings
|
||||
pub fn new(cs: &mut ConstraintSystem<F>, params: &GraphSettings) -> Self {
|
||||
debug!("number of blinding factors: {}", cs.blinding_factors());
|
||||
|
||||
@@ -435,7 +440,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
.collect_vec();
|
||||
|
||||
if requires_dynamic_lookup || requires_shuffle {
|
||||
let num_cols = if requires_dynamic_lookup { 3 } else { 2 };
|
||||
let num_cols = 3;
|
||||
for _ in 0..num_cols {
|
||||
let dynamic_lookup =
|
||||
VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_and_shuffle_size);
|
||||
|
||||
340
src/lib.rs
340
src/lib.rs
@@ -28,6 +28,9 @@
|
||||
|
||||
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
|
||||
//!
|
||||
use log::warn;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use mimalloc as _;
|
||||
|
||||
/// Error type
|
||||
// #[cfg_attr(not(feature = "ezkl"), derive(uniffi::Error))]
|
||||
@@ -99,7 +102,7 @@ use circuit::{table::Range, CheckMode, Tolerance};
|
||||
use clap::Args;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use fieldutils::IntegerRep;
|
||||
use graph::Visibility;
|
||||
use graph::{Visibility, MAX_PUBLIC_SRS};
|
||||
use halo2_proofs::poly::{
|
||||
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
|
||||
};
|
||||
@@ -165,7 +168,6 @@ pub mod srs_sha;
|
||||
pub mod tensor;
|
||||
#[cfg(feature = "ios-bindings")]
|
||||
uniffi::setup_scaffolding!();
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
@@ -180,11 +182,9 @@ lazy_static! {
|
||||
.unwrap_or("8000".to_string())
|
||||
.parse()
|
||||
.unwrap();
|
||||
|
||||
/// The serialization format for the keys
|
||||
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
|
||||
.unwrap_or("raw-bytes".to_string());
|
||||
|
||||
}
|
||||
|
||||
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
|
||||
@@ -266,76 +266,96 @@ impl From<String> for Commitments {
|
||||
}
|
||||
|
||||
/// Parameters specific to a proving run
|
||||
///
|
||||
/// RunArgs contains all configuration parameters needed to control the proving process,
|
||||
/// including scaling factors, visibility settings, and circuit parameters.
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
derive(Args, ToFlags)
|
||||
)]
|
||||
pub struct RunArgs {
|
||||
/// The tolerance for error on model outputs
|
||||
/// Error tolerance for model outputs
|
||||
/// Only applicable when outputs are public
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other))]
|
||||
pub tolerance: Tolerance,
|
||||
/// The denominator in the fixed point representation used when quantizing inputs
|
||||
/// Fixed point scaling factor for quantizing inputs
|
||||
/// Higher values provide more precision but increase circuit complexity
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other))]
|
||||
pub input_scale: Scale,
|
||||
/// The denominator in the fixed point representation used when quantizing parameters
|
||||
/// Fixed point scaling factor for quantizing parameters
|
||||
/// Higher values provide more precision but increase circuit complexity
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "7", value_hint = clap::ValueHint::Other))]
|
||||
pub param_scale: Scale,
|
||||
/// if the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this a more advanced parameter, use with caution)
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))]
|
||||
/// Scale rebase threshold multiplier
|
||||
/// When scale exceeds input_scale * multiplier, it is rebased to input_scale
|
||||
/// Advanced parameter that should be used with caution
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))]
|
||||
pub scale_rebase_multiplier: u32,
|
||||
/// The min and max elements in the lookup table input column
|
||||
/// Range for lookup table input column values
|
||||
/// Specified as (min, max) pair
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'B', long, value_parser = parse_key_val::<IntegerRep, IntegerRep>, default_value = "-32768->32768"))]
|
||||
pub lookup_range: Range,
|
||||
/// The log_2 number of rows
|
||||
/// Log2 of the number of rows in the circuit
|
||||
/// Controls circuit size and proving time
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other))]
|
||||
pub logrows: u32,
|
||||
/// The log_2 number of rows
|
||||
/// Number of inner columns per block
|
||||
/// Affects circuit layout and efficiency
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'N', long, default_value = "2", value_hint = clap::ValueHint::Other))]
|
||||
pub num_inner_cols: usize,
|
||||
/// Hand-written parser for graph variables, eg. batch_size=1
|
||||
/// Graph variables for parameterizing the computation
|
||||
/// Format: "name->value", e.g. "batch_size->1"
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other))]
|
||||
pub variables: Vec<(String, usize)>,
|
||||
/// Flags whether inputs are public, private, fixed, hashed, polycommit
|
||||
/// Visibility setting for input values
|
||||
/// Controls whether inputs are public or private in the circuit
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
|
||||
pub input_visibility: Visibility,
|
||||
/// Flags whether outputs are public, private, fixed, hashed, polycommit
|
||||
/// Visibility setting for output values
|
||||
/// Controls whether outputs are public or private in the circuit
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "public", value_hint = clap::ValueHint::Other))]
|
||||
pub output_visibility: Visibility,
|
||||
/// Flags whether params are fixed, private, hashed, polycommit
|
||||
/// Visibility setting for parameters
|
||||
/// Controls how parameters are handled in the circuit
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
|
||||
pub param_visibility: Visibility,
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
/// Should constants with 0.0 fraction be rebased to scale 0
|
||||
/// Whether to rebase constants with zero fractional part to scale 0
|
||||
/// Can improve efficiency for integer constants
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
pub rebase_frac_zero_constants: bool,
|
||||
/// check mode (safe, unsafe, etc)
|
||||
/// Circuit checking mode
|
||||
/// Controls level of constraint verification
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "unsafe", value_hint = clap::ValueHint::Other))]
|
||||
pub check_mode: CheckMode,
|
||||
/// commitment scheme
|
||||
/// Commitment scheme for circuit proving
|
||||
/// Affects proof size and verification time
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other))]
|
||||
pub commitment: Option<Commitments>,
|
||||
/// the base used for decompositions
|
||||
/// Base for number decomposition
|
||||
/// Must be a power of 2
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "16384", value_hint = clap::ValueHint::Other))]
|
||||
pub decomp_base: usize,
|
||||
/// Number of decomposition legs
|
||||
/// Controls decomposition granularity
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "2", value_hint = clap::ValueHint::Other))]
|
||||
/// the number of legs used for decompositions
|
||||
pub decomp_legs: usize,
|
||||
/// Whether to use bounded lookup for logarithm computation
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
/// use unbounded lookup for the log
|
||||
pub bounded_log_lookup: bool,
|
||||
}
|
||||
|
||||
impl Default for RunArgs {
|
||||
/// Creates a new RunArgs instance with default values
|
||||
///
|
||||
/// Default configuration is optimized for common use cases
|
||||
/// while maintaining reasonable proving time and circuit size
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
bounded_log_lookup: false,
|
||||
@@ -360,49 +380,132 @@ impl Default for RunArgs {
|
||||
}
|
||||
|
||||
impl RunArgs {
|
||||
/// Validates the RunArgs configuration
|
||||
///
|
||||
/// Performs comprehensive validation of all parameters to ensure they are within
|
||||
/// acceptable ranges and follow required constraints. Returns accumulated errors
|
||||
/// if any validations fail.
|
||||
///
|
||||
/// # Returns
|
||||
/// - Ok(()) if all validations pass
|
||||
/// - Err(String) with detailed error message if any validation fails
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// Visibility validations
|
||||
if self.param_visibility == Visibility::Public {
|
||||
return Err(
|
||||
"params cannot be public instances, you are probably trying to use `fixed` or `kzgcommit`"
|
||||
.into(),
|
||||
errors.push(
|
||||
"Parameters cannot be public instances. Use 'fixed' or 'kzgcommit' instead"
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
if self.scale_rebase_multiplier < 1 {
|
||||
return Err("scale_rebase_multiplier must be >= 1".into());
|
||||
}
|
||||
if self.lookup_range.0 > self.lookup_range.1 {
|
||||
return Err("lookup_range min is greater than max".into());
|
||||
}
|
||||
if self.logrows < 1 {
|
||||
return Err("logrows must be >= 1".into());
|
||||
}
|
||||
if self.num_inner_cols < 1 {
|
||||
return Err("num_inner_cols must be >= 1".into());
|
||||
}
|
||||
|
||||
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
|
||||
return Err("tolerance > 0.0 requires output_visibility to be public".into());
|
||||
errors.push("Non-zero tolerance requires output_visibility to be public".to_string());
|
||||
}
|
||||
|
||||
// Scale validations
|
||||
if self.scale_rebase_multiplier < 1 {
|
||||
errors.push("scale_rebase_multiplier must be >= 1".to_string());
|
||||
}
|
||||
|
||||
// if any of the scales are too small
|
||||
if self.input_scale < 8 || self.param_scale < 8 {
|
||||
warn!("low scale values (<8) may impact precision");
|
||||
}
|
||||
|
||||
// Lookup range validations
|
||||
if self.lookup_range.0 > self.lookup_range.1 {
|
||||
errors.push(format!(
|
||||
"Invalid lookup range: min ({}) is greater than max ({})",
|
||||
self.lookup_range.0, self.lookup_range.1
|
||||
));
|
||||
}
|
||||
|
||||
// Size validations
|
||||
if self.logrows < 1 {
|
||||
errors.push("logrows must be >= 1".to_string());
|
||||
}
|
||||
|
||||
if self.num_inner_cols < 1 {
|
||||
errors.push("num_inner_cols must be >= 1".to_string());
|
||||
}
|
||||
|
||||
let batch_size = self.variables.iter().find(|(name, _)| name == "batch_size");
|
||||
if let Some(batch_size) = batch_size {
|
||||
if batch_size.1 == 0 {
|
||||
errors.push("'batch_size' cannot be 0".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Decomposition validations
|
||||
if self.decomp_base == 0 {
|
||||
errors.push("decomp_base cannot be 0".to_string());
|
||||
}
|
||||
|
||||
if self.decomp_legs == 0 {
|
||||
errors.push("decomp_legs cannot be 0".to_string());
|
||||
}
|
||||
|
||||
// Performance validations
|
||||
if self.logrows > MAX_PUBLIC_SRS {
|
||||
warn!("logrows exceeds maximum public SRS size");
|
||||
}
|
||||
|
||||
// Validate tolerance is non-negative
|
||||
if self.tolerance.val < 0.0 {
|
||||
errors.push("tolerance cannot be negative".to_string());
|
||||
}
|
||||
|
||||
// Performance warnings
|
||||
if self.input_scale > 20 || self.param_scale > 20 {
|
||||
warn!("High scale values (>20) may impact performance");
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(errors.join("\n"))
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Export the ezkl configuration as json
|
||||
/// Exports the configuration as JSON
|
||||
///
|
||||
/// Serializes the RunArgs instance to a JSON string
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(String)` containing JSON representation
|
||||
/// * `Err` if serialization fails
|
||||
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let serialized = match serde_json::to_string(&self) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err(Box::new(e));
|
||||
}
|
||||
};
|
||||
Ok(serialized)
|
||||
let res = serde_json::to_string(&self)?;
|
||||
Ok(res)
|
||||
}
|
||||
/// Parse an ezkl configuration from a json
|
||||
|
||||
/// Parses configuration from JSON
|
||||
///
|
||||
/// Deserializes a RunArgs instance from a JSON string
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `arg_json` - JSON string containing configuration
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(RunArgs)` if parsing succeeds
|
||||
/// * `Err` if parsing fails
|
||||
pub fn from_json(arg_json: &str) -> Result<Self, serde_json::Error> {
|
||||
serde_json::from_str(arg_json)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a single key-value pair
|
||||
// Additional helper functions for the module
|
||||
|
||||
/// Parses a key-value pair from a string in the format "key->value"
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `s` - Input string in the format "key->value"
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok((T, U))` - Parsed key and value
|
||||
/// * `Err` - If parsing fails
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn parse_key_val<T, U>(
|
||||
s: &str,
|
||||
@@ -415,8 +518,131 @@ where
|
||||
{
|
||||
let pos = s
|
||||
.find("->")
|
||||
.ok_or_else(|| format!("invalid x->y: no `->` found in `{s}`"))?;
|
||||
let a = s[..pos].parse()?;
|
||||
let b = s[pos + 2..].parse()?;
|
||||
Ok((a, b))
|
||||
.ok_or_else(|| format!("invalid KEY->VALUE: no `->` found in `{s}`"))?;
|
||||
Ok((s[..pos].parse()?, s[pos + 2..].parse()?))
|
||||
}
|
||||
|
||||
/// Verifies that a version string matches the expected artifact version
|
||||
/// Logs warnings for version mismatches or unversioned artifacts
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `artifact_version` - Version string from the artifact
|
||||
pub fn check_version_string_matches(artifact_version: &str) {
|
||||
if artifact_version == "0.0.0"
|
||||
|| artifact_version == "source - no compatibility guaranteed"
|
||||
|| artifact_version.is_empty()
|
||||
{
|
||||
log::warn!("Artifact version is 0.0.0, skipping version check");
|
||||
return;
|
||||
}
|
||||
|
||||
let version = crate::version();
|
||||
|
||||
if version == "source - no compatibility guaranteed" {
|
||||
log::warn!("Compiled source version is not guaranteed to match artifact version");
|
||||
return;
|
||||
}
|
||||
|
||||
if version != artifact_version {
|
||||
log::warn!(
|
||||
"Version mismatch: CLI version is {} but artifact version is {}",
|
||||
version,
|
||||
artifact_version
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::field_reassign_with_default)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_valid_default_args() {
|
||||
let args = RunArgs::default();
|
||||
assert!(args.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_param_visibility() {
|
||||
let mut args = RunArgs::default();
|
||||
args.param_visibility = Visibility::Public;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("Parameters cannot be public instances"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_scale_rebase() {
|
||||
let mut args = RunArgs::default();
|
||||
args.scale_rebase_multiplier = 0;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("scale_rebase_multiplier must be >= 1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_lookup_range() {
|
||||
let mut args = RunArgs::default();
|
||||
args.lookup_range = (100, -100);
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("Invalid lookup range"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_logrows() {
|
||||
let mut args = RunArgs::default();
|
||||
args.logrows = 0;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("logrows must be >= 1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_inner_cols() {
|
||||
let mut args = RunArgs::default();
|
||||
args.num_inner_cols = 0;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("num_inner_cols must be >= 1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_tolerance() {
|
||||
let mut args = RunArgs::default();
|
||||
args.tolerance.val = 1.0;
|
||||
args.output_visibility = Visibility::Private;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("Non-zero tolerance requires output_visibility to be public"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_negative_tolerance() {
|
||||
let mut args = RunArgs::default();
|
||||
args.tolerance.val = -1.0;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("tolerance cannot be negative"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zero_batch_size() {
|
||||
let mut args = RunArgs::default();
|
||||
args.variables = vec![("batch_size".to_string(), 0)];
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("'batch_size' cannot be 0"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_serialization() {
|
||||
let args = RunArgs::default();
|
||||
let json = args.as_json().unwrap();
|
||||
let deserialized = RunArgs::from_json(&json).unwrap();
|
||||
assert_eq!(args, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_validation_errors() {
|
||||
let mut args = RunArgs::default();
|
||||
args.logrows = 0;
|
||||
args.lookup_range = (100, -100);
|
||||
let err = args.validate().unwrap_err();
|
||||
// Should contain multiple error messages
|
||||
assert!(err.matches("\n").count() >= 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,7 +133,6 @@ pub fn aggregate<'a>(
|
||||
.collect_vec()
|
||||
}));
|
||||
|
||||
// loader.ctx().constrain_equal(cell_0, cell_1)
|
||||
let mut transcript = PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, snark.proof());
|
||||
let proof = PlonkSuccinctVerifier::read_proof(svk, &protocol, &instances, &mut transcript)
|
||||
.map_err(|_| plonk::Error::Synthesis)?;
|
||||
@@ -309,11 +308,11 @@ impl AggregationCircuit {
|
||||
})
|
||||
}
|
||||
|
||||
///
|
||||
/// Number of limbs used for decomposition
|
||||
pub fn num_limbs() -> usize {
|
||||
LIMBS
|
||||
}
|
||||
///
|
||||
/// Number of bits used for decomposition
|
||||
pub fn num_bits() -> usize {
|
||||
BITS
|
||||
}
|
||||
|
||||
@@ -353,6 +353,7 @@ where
|
||||
C::ScalarExt: Serialize + DeserializeOwned,
|
||||
{
|
||||
/// Create a new application snark from proof and instance variables ready for aggregation
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
protocol: Option<PlonkProtocol<C>>,
|
||||
instances: Vec<Vec<F>>,
|
||||
@@ -528,7 +529,6 @@ pub fn create_keys<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
disable_selector_compression: bool,
|
||||
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
|
||||
{
|
||||
// Real proof
|
||||
@@ -794,7 +794,6 @@ pub fn load_vk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
params: <C as Circuit<Scheme::Scalar>>::Params,
|
||||
) -> Result<VerifyingKey<Scheme::Curve>, PfsysError>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
Scheme::Curve: SerdeObject + CurveAffine,
|
||||
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
|
||||
{
|
||||
@@ -817,11 +816,11 @@ pub fn load_pk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
params: <C as Circuit<Scheme::Scalar>>::Params,
|
||||
) -> Result<ProvingKey<Scheme::Curve>, PfsysError>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
Scheme::Curve: SerdeObject + CurveAffine,
|
||||
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
|
||||
{
|
||||
debug!("loading proving key from {:?}", path);
|
||||
let start = instant::Instant::now();
|
||||
let f = File::open(path.clone()).map_err(|e| PfsysError::LoadPk(format!("{}", e)))?;
|
||||
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
let pk = ProvingKey::<Scheme::Curve>::read::<_, C>(
|
||||
@@ -830,7 +829,8 @@ where
|
||||
params,
|
||||
)
|
||||
.map_err(|e| PfsysError::LoadPk(format!("{}", e)))?;
|
||||
info!("loaded proving key ✅");
|
||||
let elapsed = start.elapsed();
|
||||
info!("loaded proving key in {:?}", elapsed);
|
||||
Ok(pk)
|
||||
}
|
||||
|
||||
|
||||
@@ -38,4 +38,10 @@ pub enum TensorError {
|
||||
/// Decomposition error
|
||||
#[error("decomposition error: {0}")]
|
||||
DecompositionError(#[from] DecompositionError),
|
||||
/// Invalid argument
|
||||
#[error("invalid argument: {0}")]
|
||||
InvalidArgument(String),
|
||||
/// Index out of bounds
|
||||
#[error("index {0} out of bounds for dimension {1}")]
|
||||
IndexOutOfBounds(usize, usize),
|
||||
}
|
||||
|
||||
@@ -24,9 +24,6 @@ use std::path::PathBuf;
|
||||
pub use val::*;
|
||||
pub use var::*;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use instant::Instant;
|
||||
|
||||
use crate::{
|
||||
circuit::utils,
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
@@ -40,8 +37,6 @@ use halo2_proofs::{
|
||||
poly::Rotation,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
#[cfg(feature = "metal")]
|
||||
use metal::{Device, MTLResourceOptions, MTLSize};
|
||||
use std::error::Error;
|
||||
use std::fmt::Debug;
|
||||
use std::io::Read;
|
||||
@@ -49,31 +44,6 @@ use std::iter::Iterator;
|
||||
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
|
||||
use std::{cmp::max, ops::Rem};
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
const LIB_DATA: &[u8] = include_bytes!("metal/tensor_ops.metallib");
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
lazy_static::lazy_static! {
|
||||
static ref DEVICE: Device = Device::system_default().expect("no device found");
|
||||
|
||||
static ref LIB: metal::Library = DEVICE.new_library_with_data(LIB_DATA).unwrap();
|
||||
|
||||
static ref QUEUE: metal::CommandQueue = DEVICE.new_command_queue();
|
||||
|
||||
static ref PIPELINES: HashMap<String, metal::ComputePipelineState> = {
|
||||
let mut map = HashMap::new();
|
||||
for name in ["add", "sub", "mul"] {
|
||||
let function = LIB.get_function(name, None).unwrap();
|
||||
let pipeline = DEVICE.new_compute_pipeline_state_with_function(&function).unwrap();
|
||||
map.insert(name.to_string(), pipeline);
|
||||
}
|
||||
map
|
||||
};
|
||||
}
|
||||
|
||||
/// The (inner) type of tensor elements.
|
||||
pub trait TensorType: Clone + Debug + 'static {
|
||||
/// Returns the zero value.
|
||||
@@ -638,42 +608,44 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
where
|
||||
T: Send + Sync,
|
||||
{
|
||||
if indices.is_empty() {
|
||||
// Fast path: empty indices or full tensor slice
|
||||
if indices.is_empty()
|
||||
|| indices.iter().map(|x| x.end - x.start).collect::<Vec<_>>() == self.dims
|
||||
{
|
||||
return Ok(self.clone());
|
||||
}
|
||||
|
||||
// Validate dimensions
|
||||
if self.dims.len() < indices.len() {
|
||||
return Err(TensorError::DimError(format!(
|
||||
"The dimensionality of the slice {:?} is greater than the tensor's {:?}",
|
||||
indices, self.dims
|
||||
)));
|
||||
} else if indices.iter().map(|x| x.end - x.start).collect::<Vec<_>>() == self.dims {
|
||||
// else if slice is the same as dims, return self
|
||||
return Ok(self.clone());
|
||||
}
|
||||
|
||||
// if indices weren't specified we fill them in as required
|
||||
let mut full_indices = indices.to_vec();
|
||||
// Pre-allocate the full indices vector with capacity
|
||||
let mut full_indices = Vec::with_capacity(self.dims.len());
|
||||
full_indices.extend_from_slice(indices);
|
||||
|
||||
for i in 0..(self.dims.len() - indices.len()) {
|
||||
full_indices.push(0..self.dims()[indices.len() + i])
|
||||
}
|
||||
// Fill remaining dimensions
|
||||
full_indices.extend((indices.len()..self.dims.len()).map(|i| 0..self.dims[i]));
|
||||
|
||||
let cartesian_coord: Vec<Vec<usize>> = full_indices
|
||||
// Pre-calculate total size and allocate result vector
|
||||
let total_size: usize = full_indices
|
||||
.iter()
|
||||
.cloned()
|
||||
.multi_cartesian_product()
|
||||
.collect();
|
||||
|
||||
let res: Vec<T> = cartesian_coord
|
||||
.par_iter()
|
||||
.map(|e| {
|
||||
let index = self.get_index(e);
|
||||
self[index].clone()
|
||||
})
|
||||
.collect();
|
||||
.map(|range| range.end - range.start)
|
||||
.product();
|
||||
let mut res = Vec::with_capacity(total_size);
|
||||
|
||||
// Calculate new dimensions once
|
||||
let dims: Vec<usize> = full_indices.iter().map(|e| e.end - e.start).collect();
|
||||
|
||||
// Use iterator directly without collecting into intermediate Vec
|
||||
for coord in full_indices.iter().cloned().multi_cartesian_product() {
|
||||
let index = self.get_index(&coord);
|
||||
res.push(self[index].clone());
|
||||
}
|
||||
|
||||
Tensor::new(Some(&res), &dims)
|
||||
}
|
||||
|
||||
@@ -831,7 +803,13 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
num_repeats: usize,
|
||||
initial_offset: usize,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
let mut inner: Vec<T> = vec![];
|
||||
if n == 0 {
|
||||
return Err(TensorError::InvalidArgument(
|
||||
"Cannot duplicate every 0th element".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut inner: Vec<T> = Vec::with_capacity(self.inner.len());
|
||||
let mut offset = initial_offset;
|
||||
for (i, elem) in self.inner.clone().into_iter().enumerate() {
|
||||
if (i + offset + 1) % n == 0 {
|
||||
@@ -860,20 +838,28 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
num_repeats: usize,
|
||||
initial_offset: usize,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
let mut inner: Vec<T> = vec![];
|
||||
let mut indices_to_remove = std::collections::HashSet::new();
|
||||
for i in 0..self.inner.len() {
|
||||
if (i + initial_offset + 1) % n == 0 {
|
||||
for j in 1..(1 + num_repeats) {
|
||||
indices_to_remove.insert(i + j);
|
||||
}
|
||||
}
|
||||
if n == 0 {
|
||||
return Err(TensorError::InvalidArgument(
|
||||
"Cannot remove every 0th element".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let old_inner = self.inner.clone();
|
||||
for (i, elem) in old_inner.into_iter().enumerate() {
|
||||
if !indices_to_remove.contains(&i) {
|
||||
inner.push(elem.clone());
|
||||
// Pre-calculate capacity to avoid reallocations
|
||||
let estimated_size = self.inner.len() - (self.inner.len() / n) * num_repeats;
|
||||
let mut inner = Vec::with_capacity(estimated_size);
|
||||
|
||||
// Use iterator directly instead of creating intermediate collectionsif
|
||||
let mut i = 0;
|
||||
while i < self.inner.len() {
|
||||
// Add the current element
|
||||
inner.push(self.inner[i].clone());
|
||||
|
||||
// If this is an nth position (accounting for offset)
|
||||
if (i + initial_offset + 1) % n == 0 {
|
||||
// Skip the next num_repeats elements
|
||||
i += num_repeats + 1;
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -881,7 +867,6 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
}
|
||||
|
||||
/// Remove indices
|
||||
/// WARN: assumes indices are in ascending order for speed
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
@@ -908,7 +893,11 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
}
|
||||
// remove indices
|
||||
for elem in indices.iter().rev() {
|
||||
inner.remove(*elem);
|
||||
if *elem < self.len() {
|
||||
inner.remove(*elem);
|
||||
} else {
|
||||
return Err(TensorError::IndexOutOfBounds(*elem, self.len()));
|
||||
}
|
||||
}
|
||||
|
||||
Tensor::new(Some(&inner), &[inner.len()])
|
||||
@@ -1400,10 +1389,6 @@ impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync> Ad
|
||||
let lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let rhs = rhs.expand(&broadcasted_shape).unwrap();
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
let res = metal_tensor_op(&lhs, &rhs, "add");
|
||||
|
||||
#[cfg(not(feature = "metal"))]
|
||||
let res = {
|
||||
let mut res: Tensor<T> = lhs
|
||||
.par_iter()
|
||||
@@ -1501,10 +1486,6 @@ impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Su
|
||||
let lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let rhs = rhs.expand(&broadcasted_shape).unwrap();
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
let res = metal_tensor_op(&lhs, &rhs, "sub");
|
||||
|
||||
#[cfg(not(feature = "metal"))]
|
||||
let res = {
|
||||
let mut res: Tensor<T> = lhs
|
||||
.par_iter()
|
||||
@@ -1572,10 +1553,6 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mu
|
||||
let lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let rhs = rhs.expand(&broadcasted_shape).unwrap();
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
let res = metal_tensor_op(&lhs, &rhs, "mul");
|
||||
|
||||
#[cfg(not(feature = "metal"))]
|
||||
let res = {
|
||||
let mut res: Tensor<T> = lhs
|
||||
.par_iter()
|
||||
@@ -1681,7 +1658,9 @@ impl<T: TensorType + Div<Output = T> + std::marker::Send + std::marker::Sync> Di
|
||||
}
|
||||
|
||||
// implement remainder
|
||||
impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Rem for Tensor<T> {
|
||||
impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + PartialEq> Rem
|
||||
for Tensor<T>
|
||||
{
|
||||
type Output = Result<Tensor<T>, TensorError>;
|
||||
|
||||
/// Elementwise remainder of a tensor with another tensor.
|
||||
@@ -1710,9 +1689,25 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Re
|
||||
let mut lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let rhs = rhs.expand(&broadcasted_shape).unwrap();
|
||||
|
||||
lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
|
||||
*o = o.clone() % r;
|
||||
});
|
||||
lhs.par_iter_mut()
|
||||
.zip(rhs)
|
||||
.map(|(o, r)| {
|
||||
if let Some(zero) = T::zero() {
|
||||
if r != zero {
|
||||
*o = o.clone() % r;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(TensorError::InvalidArgument(
|
||||
"Cannot divide by zero in remainder".to_string(),
|
||||
))
|
||||
}
|
||||
} else {
|
||||
Err(TensorError::InvalidArgument(
|
||||
"Undefined zero value".to_string(),
|
||||
))
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
Ok(lhs)
|
||||
}
|
||||
@@ -1747,7 +1742,6 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Re
|
||||
/// assert_eq!(c, vec![2, 3]);
|
||||
///
|
||||
/// ```
|
||||
|
||||
pub fn get_broadcasted_shape(
|
||||
shape_a: &[usize],
|
||||
shape_b: &[usize],
|
||||
@@ -1755,20 +1749,21 @@ pub fn get_broadcasted_shape(
|
||||
let num_dims_a = shape_a.len();
|
||||
let num_dims_b = shape_b.len();
|
||||
|
||||
match (num_dims_a, num_dims_b) {
|
||||
(a, b) if a == b => {
|
||||
let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
|
||||
for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
|
||||
let max_dim = dim_a.max(dim_b);
|
||||
broadcasted_shape.push(*max_dim);
|
||||
}
|
||||
Ok(broadcasted_shape)
|
||||
if num_dims_a == num_dims_b {
|
||||
let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
|
||||
for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
|
||||
let max_dim = dim_a.max(dim_b);
|
||||
broadcasted_shape.push(*max_dim);
|
||||
}
|
||||
(a, b) if a < b => Ok(shape_b.to_vec()),
|
||||
(a, b) if a > b => Ok(shape_a.to_vec()),
|
||||
_ => Err(TensorError::DimError(
|
||||
Ok(broadcasted_shape)
|
||||
} else if num_dims_a < num_dims_b {
|
||||
Ok(shape_b.to_vec())
|
||||
} else if num_dims_a > num_dims_b {
|
||||
Ok(shape_a.to_vec())
|
||||
} else {
|
||||
Err(TensorError::DimError(
|
||||
"Unknown condition for broadcasting".to_string(),
|
||||
)),
|
||||
))
|
||||
}
|
||||
}
|
||||
////////////////////////
|
||||
@@ -1807,66 +1802,4 @@ mod tests {
|
||||
let b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
|
||||
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "metal")]
|
||||
fn tensor_metal_int() {
|
||||
let a = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
|
||||
let b = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
|
||||
let c = metal_tensor_op(&a, &b, "add");
|
||||
assert_eq!(c, Tensor::new(Some(&[2, 4, 6, 8]), &[2, 2]).unwrap());
|
||||
|
||||
let c = metal_tensor_op(&a, &b, "sub");
|
||||
assert_eq!(c, Tensor::new(Some(&[0, 0, 0, 0]), &[2, 2]).unwrap());
|
||||
|
||||
let c = metal_tensor_op(&a, &b, "mul");
|
||||
assert_eq!(c, Tensor::new(Some(&[1, 4, 9, 16]), &[2, 2]).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "metal")]
|
||||
fn tensor_metal_felt() {
|
||||
use halo2curves::bn256::Fr;
|
||||
|
||||
let a = Tensor::<Fr>::new(
|
||||
Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
|
||||
&[2, 2],
|
||||
)
|
||||
.unwrap();
|
||||
let b = Tensor::<Fr>::new(
|
||||
Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
|
||||
&[2, 2],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let c = metal_tensor_op(&a, &b, "add");
|
||||
assert_eq!(
|
||||
c,
|
||||
Tensor::<Fr>::new(
|
||||
Some(&[Fr::from(2), Fr::from(4), Fr::from(6), Fr::from(8)]),
|
||||
&[2, 2],
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
let c = metal_tensor_op(&a, &b, "sub");
|
||||
assert_eq!(
|
||||
c,
|
||||
Tensor::<Fr>::new(
|
||||
Some(&[Fr::from(0), Fr::from(0), Fr::from(0), Fr::from(0)]),
|
||||
&[2, 2],
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
let c = metal_tensor_op(&a, &b, "mul");
|
||||
assert_eq!(
|
||||
c,
|
||||
Tensor::<Fr>::new(
|
||||
Some(&[Fr::from(1), Fr::from(4), Fr::from(9), Fr::from(16)]),
|
||||
&[2, 2],
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ pub fn get_rep(
|
||||
n: usize,
|
||||
) -> Result<Vec<IntegerRep>, DecompositionError> {
|
||||
// check if x is too large
|
||||
if x.abs() > (base.pow(n as u32) as IntegerRep) - 1 {
|
||||
if (*x).abs() > ((base as i128).pow(n as u32)) - 1 {
|
||||
return Err(DecompositionError::TooLarge(*x, base, n));
|
||||
}
|
||||
let mut rep = vec![0; n + 1];
|
||||
@@ -43,8 +43,8 @@ pub fn get_rep(
|
||||
let mut x = x.abs();
|
||||
//
|
||||
for i in (1..rep.len()).rev() {
|
||||
rep[i] = x % base as i128;
|
||||
x /= base as i128;
|
||||
rep[i] = x % base as IntegerRep;
|
||||
x /= base as IntegerRep;
|
||||
}
|
||||
|
||||
Ok(rep)
|
||||
@@ -127,7 +127,7 @@ pub fn decompose(
|
||||
.flatten()
|
||||
.collect::<Vec<IntegerRep>>();
|
||||
|
||||
let output = Tensor::<i128>::new(Some(&resp), &dims)?;
|
||||
let output = Tensor::<IntegerRep>::new(Some(&resp), &dims)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
@@ -385,6 +385,12 @@ pub fn resize<T: TensorType + Send + Sync>(
|
||||
pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync>(
|
||||
t: &[Tensor<T>],
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
if t.len() == 1 {
|
||||
return Ok(t[0].clone());
|
||||
} else if t.len() == 0 {
|
||||
return Err(TensorError::DimMismatch("add".to_string()));
|
||||
}
|
||||
|
||||
// calculate value of output
|
||||
let mut output: Tensor<T> = t[0].clone();
|
||||
|
||||
@@ -433,6 +439,11 @@ pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sy
|
||||
pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync>(
|
||||
t: &[Tensor<T>],
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
if t.len() == 1 {
|
||||
return Ok(t[0].clone());
|
||||
} else if t.len() == 0 {
|
||||
return Err(TensorError::DimMismatch("sub".to_string()));
|
||||
}
|
||||
// calculate value of output
|
||||
let mut output: Tensor<T> = t[0].clone();
|
||||
|
||||
@@ -479,6 +490,11 @@ pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sy
|
||||
pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync>(
|
||||
t: &[Tensor<T>],
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
if t.len() == 1 {
|
||||
return Ok(t[0].clone());
|
||||
} else if t.len() == 0 {
|
||||
return Err(TensorError::DimMismatch("mult".to_string()));
|
||||
}
|
||||
// calculate value of output
|
||||
let mut output: Tensor<T> = t[0].clone();
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,33 +5,35 @@ use log::{debug, error, warn};
|
||||
use crate::circuit::{region::ConstantsMap, CheckMode};
|
||||
|
||||
use super::*;
|
||||
/// A wrapper around Halo2's `Column<Fixed>` or `Column<Advice>`.
|
||||
/// Typically assign [ValTensor]s to [VarTensor]s when laying out a circuit.
|
||||
/// A wrapper around Halo2's Column types that represents a tensor of variables in the circuit.
|
||||
/// VarTensors are used to store and manage circuit columns, typically for assigning ValTensor
|
||||
/// values during circuit layout. The tensor organizes storage into blocks of columns, where each
|
||||
/// block contains multiple columns and each column contains multiple rows.
|
||||
#[derive(Clone, Default, Debug, PartialEq, Eq)]
|
||||
pub enum VarTensor {
|
||||
/// A VarTensor for holding Advice values, which are assigned at proving time.
|
||||
Advice {
|
||||
/// Vec of Advice columns, we have [[xx][xx][xx]...] where each inner vec is xx columns
|
||||
inner: Vec<Vec<Column<Advice>>>,
|
||||
///
|
||||
/// The number of columns in each inner block
|
||||
num_inner_cols: usize,
|
||||
/// Number of rows available to be used in each column of the storage
|
||||
col_size: usize,
|
||||
},
|
||||
/// Dummy var
|
||||
/// A placeholder tensor used for testing or temporary storage
|
||||
Dummy {
|
||||
///
|
||||
/// The number of columns in each inner block
|
||||
num_inner_cols: usize,
|
||||
/// Number of rows available to be used in each column of the storage
|
||||
col_size: usize,
|
||||
},
|
||||
/// Empty var
|
||||
/// An empty tensor with no storage
|
||||
#[default]
|
||||
Empty,
|
||||
}
|
||||
|
||||
impl VarTensor {
|
||||
/// name of the tensor
|
||||
/// Returns the name of the tensor variant as a static string
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
VarTensor::Advice { .. } => "Advice",
|
||||
@@ -40,22 +42,35 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns true if the tensor is an Advice variant
|
||||
pub fn is_advice(&self) -> bool {
|
||||
matches!(self, VarTensor::Advice { .. })
|
||||
}
|
||||
|
||||
/// Calculates the maximum number of usable rows in the constraint system
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cs` - The constraint system
|
||||
/// * `logrows` - Log base 2 of the total number of rows (including system and blinding rows)
|
||||
///
|
||||
/// # Returns
|
||||
/// The maximum number of usable rows after accounting for blinding factors
|
||||
pub fn max_rows<F: PrimeField>(cs: &ConstraintSystem<F>, logrows: usize) -> usize {
|
||||
let base = 2u32;
|
||||
base.pow(logrows as u32) as usize - cs.blinding_factors() - 1
|
||||
}
|
||||
|
||||
/// Create a new VarTensor::Advice that is unblinded
|
||||
/// Arguments
|
||||
/// * `cs` - The constraint system
|
||||
/// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
|
||||
/// * `capacity` - The number of advice cells to allocate
|
||||
/// Creates a new VarTensor::Advice with unblinded columns. Unblinded columns are used when
|
||||
/// the values do not need to be hidden in the proof.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cs` - The constraint system to create columns in
|
||||
/// * `logrows` - Log base 2 of the total number of rows
|
||||
/// * `num_inner_cols` - Number of columns in each inner block
|
||||
/// * `capacity` - Total number of advice cells to allocate
|
||||
///
|
||||
/// # Returns
|
||||
/// A new VarTensor::Advice with unblinded columns enabled for equality constraints
|
||||
pub fn new_unblinded_advice<F: PrimeField>(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
logrows: usize,
|
||||
@@ -93,11 +108,17 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new VarTensor::Advice
|
||||
/// Arguments
|
||||
/// * `cs` - The constraint system
|
||||
/// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
|
||||
/// * `capacity` - The number of advice cells to allocate
|
||||
/// Creates a new VarTensor::Advice with standard (blinded) columns, used when
|
||||
/// the values need to be hidden in the proof.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cs` - The constraint system to create columns in
|
||||
/// * `logrows` - Log base 2 of the total number of rows
|
||||
/// * `num_inner_cols` - Number of columns in each inner block
|
||||
/// * `capacity` - Total number of advice cells to allocate
|
||||
///
|
||||
/// # Returns
|
||||
/// A new VarTensor::Advice with blinded columns enabled for equality constraints
|
||||
pub fn new_advice<F: PrimeField>(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
logrows: usize,
|
||||
@@ -133,11 +154,17 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Initializes fixed columns to support the VarTensor::Advice
|
||||
/// Arguments
|
||||
/// * `cs` - The constraint system
|
||||
/// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
|
||||
/// * `capacity` - The number of advice cells to allocate
|
||||
/// Initializes fixed columns in the constraint system to support the VarTensor::Advice
|
||||
/// Fixed columns are used for constant values that are known at circuit creation time.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cs` - The constraint system to create columns in
|
||||
/// * `logrows` - Log base 2 of the total number of rows
|
||||
/// * `num_constants` - Number of constant values needed
|
||||
/// * `module_requires_fixed` - Whether the module requires at least one fixed column
|
||||
///
|
||||
/// # Returns
|
||||
/// The number of fixed columns created
|
||||
pub fn constant_cols<F: PrimeField>(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
logrows: usize,
|
||||
@@ -169,7 +196,14 @@ impl VarTensor {
|
||||
modulo
|
||||
}
|
||||
|
||||
/// Create a new VarTensor::Dummy
|
||||
/// Creates a new dummy VarTensor for testing or temporary storage
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `logrows` - Log base 2 of the total number of rows
|
||||
/// * `num_inner_cols` - Number of columns in each inner block
|
||||
///
|
||||
/// # Returns
|
||||
/// A new VarTensor::Dummy with the specified dimensions
|
||||
pub fn dummy(logrows: usize, num_inner_cols: usize) -> Self {
|
||||
let base = 2u32;
|
||||
let max_rows = base.pow(logrows as u32) as usize - 6;
|
||||
@@ -179,7 +213,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the dims of the object the VarTensor represents
|
||||
/// Returns the number of blocks in the tensor
|
||||
pub fn num_blocks(&self) -> usize {
|
||||
match self {
|
||||
VarTensor::Advice { inner, .. } => inner.len(),
|
||||
@@ -187,7 +221,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Num inner cols
|
||||
/// Returns the number of columns in each inner block
|
||||
pub fn num_inner_cols(&self) -> usize {
|
||||
match self {
|
||||
VarTensor::Advice { num_inner_cols, .. } | VarTensor::Dummy { num_inner_cols, .. } => {
|
||||
@@ -197,7 +231,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Total number of columns
|
||||
/// Returns the total number of columns across all blocks
|
||||
pub fn num_cols(&self) -> usize {
|
||||
match self {
|
||||
VarTensor::Advice { inner, .. } => inner[0].len() * inner.len(),
|
||||
@@ -205,7 +239,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the size of each column
|
||||
/// Returns the maximum number of rows in each column
|
||||
pub fn col_size(&self) -> usize {
|
||||
match self {
|
||||
VarTensor::Advice { col_size, .. } | VarTensor::Dummy { col_size, .. } => *col_size,
|
||||
@@ -213,7 +247,7 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the size of each column
|
||||
/// Returns the total size of each block (num_inner_cols * col_size)
|
||||
pub fn block_size(&self) -> usize {
|
||||
match self {
|
||||
VarTensor::Advice {
|
||||
@@ -230,7 +264,13 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Take a linear coordinate and output the (column, row) position in the storage block.
|
||||
/// Converts a linear coordinate to (block, column, row) coordinates in the storage
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `linear_coord` - The linear index to convert
|
||||
///
|
||||
/// # Returns
|
||||
/// A tuple of (block_index, column_index, row_index)
|
||||
pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize, usize) {
|
||||
// x indexes over blocks of size num_inner_cols
|
||||
let x = linear_coord / self.block_size();
|
||||
@@ -243,7 +283,17 @@ impl VarTensor {
|
||||
}
|
||||
|
||||
impl VarTensor {
|
||||
/// Retrieve the value of a specific cell in the tensor.
|
||||
/// Queries a range of cells in the tensor during circuit synthesis
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `meta` - Virtual cells accessor
|
||||
/// * `x` - Block index
|
||||
/// * `y` - Column index within block
|
||||
/// * `z` - Starting row offset
|
||||
/// * `rng` - Number of consecutive rows to query
|
||||
///
|
||||
/// # Returns
|
||||
/// A tensor of expressions representing the queried cells
|
||||
pub fn query_rng<F: PrimeField>(
|
||||
&self,
|
||||
meta: &mut VirtualCells<'_, F>,
|
||||
@@ -268,7 +318,16 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieve the value of a specific block at an offset in the tensor.
|
||||
/// Queries an entire block of cells at a given offset
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `meta` - Virtual cells accessor
|
||||
/// * `x` - Block index
|
||||
/// * `z` - Row offset
|
||||
/// * `rng` - Number of consecutive rows to query
|
||||
///
|
||||
/// # Returns
|
||||
/// A tensor of expressions representing the queried block
|
||||
pub fn query_whole_block<F: PrimeField>(
|
||||
&self,
|
||||
meta: &mut VirtualCells<'_, F>,
|
||||
@@ -293,7 +352,16 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns a constant value to a specific cell in the tensor.
|
||||
/// Assigns a constant value to a specific cell in the tensor
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for the assignment
|
||||
/// * `coord` - Coordinate within the tensor
|
||||
/// * `constant` - The constant value to assign
|
||||
///
|
||||
/// # Returns
|
||||
/// The assigned cell or an error if assignment fails
|
||||
pub fn assign_constant<F: PrimeField + TensorType + PartialOrd>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
@@ -313,7 +381,17 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns [ValTensor] to the columns of the inner tensor.
|
||||
/// Assigns values from a ValTensor to this tensor, excluding specified positions
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `omissions` - Set of positions to skip during assignment
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// The assigned ValTensor or an error if assignment fails
|
||||
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
@@ -344,7 +422,16 @@ impl VarTensor {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Assigns [ValTensor] to the columns of the inner tensor.
|
||||
/// Assigns values from a ValTensor to this tensor
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// The assigned ValTensor or an error if assignment fails
|
||||
pub fn assign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
@@ -396,14 +483,23 @@ impl VarTensor {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Helper function to get the remaining size of the column
|
||||
/// Returns the remaining available space in a column for assignments
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `offset` - Current offset in the column
|
||||
/// * `values` - The ValTensor to check space for
|
||||
///
|
||||
/// # Returns
|
||||
/// The number of rows that need to be flushed or an error if space is insufficient
|
||||
pub fn get_column_flush<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
offset: usize,
|
||||
values: &ValTensor<F>,
|
||||
) -> Result<usize, halo2_proofs::plonk::Error> {
|
||||
if values.len() > self.col_size() {
|
||||
error!("Values are too large for the column");
|
||||
error!(
|
||||
"There are too many values to flush for this column size, try setting the logrows to a higher value (eg. --logrows 22 on the cli)"
|
||||
);
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
}
|
||||
|
||||
@@ -427,8 +523,16 @@ impl VarTensor {
|
||||
Ok(flush_len)
|
||||
}
|
||||
|
||||
/// Assigns [ValTensor] to the columns of the inner tensor. Whereby the values are assigned to a single column, without overflowing.
|
||||
/// So for instance if we are assigning 10 values and we are at index 18 of the column, and the columns are of length 20, we skip the last 2 values of current column and start from the beginning of the next column.
|
||||
/// Assigns values to a single column, avoiding column overflow by flushing to the next column if needed
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// A tuple of (assigned ValTensor, number of rows flushed) or an error if assignment fails
|
||||
pub fn assign_exact_column<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
@@ -443,8 +547,17 @@ impl VarTensor {
|
||||
Ok((assigned_vals, flush_len))
|
||||
}
|
||||
|
||||
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
|
||||
/// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
|
||||
/// Assigns values with duplication in dummy mode, used for testing and simulation
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `row` - Starting row for assignment
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `single_inner_col` - Whether to treat as a single column
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
|
||||
pub fn dummy_assign_with_duplication<
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
|
||||
>(
|
||||
@@ -494,16 +607,75 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
|
||||
/// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
|
||||
pub fn assign_with_duplication<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// Assigns values with duplication but without enforcing constraints between duplicated values
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
|
||||
pub fn assign_with_duplication_unconstrained<
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
|
||||
>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
offset: usize,
|
||||
values: &ValTensor<F>,
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
|
||||
match values {
|
||||
ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
|
||||
ValTensor::Value { inner: v, dims , ..} => {
|
||||
|
||||
let duplication_freq = self.block_size();
|
||||
|
||||
let num_repeats = self.num_inner_cols();
|
||||
|
||||
let duplication_offset = offset;
|
||||
|
||||
// duplicates every nth element to adjust for column overflow
|
||||
let v = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
|
||||
let mut res: ValTensor<F> = {
|
||||
v.enum_map(|coord, k| {
|
||||
let cell = self.assign_value(region, offset, k.clone(), coord, constants)?;
|
||||
Ok::<_, halo2_proofs::plonk::Error>(cell)
|
||||
|
||||
})?.into()};
|
||||
let total_used_len = res.len();
|
||||
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
|
||||
|
||||
res.reshape(dims).unwrap();
|
||||
res.set_scale(values.scale());
|
||||
|
||||
Ok((res, total_used_len))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns values with duplication and enforces equality constraints between duplicated values
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `row` - Starting row for assignment
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `check_mode` - Mode for checking equality constraints
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
|
||||
pub fn assign_with_duplication_constrained<
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
|
||||
>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
row: usize,
|
||||
offset: usize,
|
||||
values: &ValTensor<F>,
|
||||
check_mode: &CheckMode,
|
||||
single_inner_col: bool,
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
|
||||
let mut prev_cell = None;
|
||||
@@ -512,34 +684,16 @@ impl VarTensor {
|
||||
ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
|
||||
ValTensor::Value { inner: v, dims , ..} => {
|
||||
|
||||
let duplication_freq = if single_inner_col {
|
||||
self.col_size()
|
||||
} else {
|
||||
self.block_size()
|
||||
};
|
||||
|
||||
let num_repeats = if single_inner_col {
|
||||
1
|
||||
} else {
|
||||
self.num_inner_cols()
|
||||
};
|
||||
|
||||
let duplication_offset = if single_inner_col {
|
||||
row
|
||||
} else {
|
||||
offset
|
||||
};
|
||||
let duplication_freq = self.col_size();
|
||||
let num_repeats = 1;
|
||||
let duplication_offset = row;
|
||||
|
||||
// duplicates every nth element to adjust for column overflow
|
||||
let v = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
|
||||
let mut res: ValTensor<F> = {
|
||||
let mut res: ValTensor<F> =
|
||||
v.enum_map(|coord, k| {
|
||||
|
||||
let step = if !single_inner_col {
|
||||
1
|
||||
} else {
|
||||
self.num_inner_cols()
|
||||
};
|
||||
let step = self.num_inner_cols();
|
||||
|
||||
let (x, y, z) = self.cartesian_coord(offset + coord * step);
|
||||
if matches!(check_mode, CheckMode::SAFE) && coord > 0 && z == 0 && y == 0 {
|
||||
@@ -549,53 +703,59 @@ impl VarTensor {
|
||||
|
||||
let cell = self.assign_value(region, offset, k.clone(), coord * step, constants)?;
|
||||
|
||||
if single_inner_col {
|
||||
if z == 0 {
|
||||
let at_end_of_column = z == duplication_freq - 1;
|
||||
let at_beginning_of_column = z == 0;
|
||||
|
||||
if at_end_of_column {
|
||||
// if we are at the end of the column, we need to copy the cell to the next column
|
||||
prev_cell = Some(cell.clone());
|
||||
} else if coord > 0 && z == 0 && single_inner_col {
|
||||
} else if coord > 0 && at_beginning_of_column {
|
||||
if let Some(prev_cell) = prev_cell.as_ref() {
|
||||
let cell = cell.cell().ok_or({
|
||||
let cell = if let Some(cell) = cell.cell() {
|
||||
cell
|
||||
} else {
|
||||
error!("Error getting cell: {:?}", (x,y));
|
||||
halo2_proofs::plonk::Error::Synthesis})?;
|
||||
let prev_cell = prev_cell.cell().ok_or({
|
||||
error!("Error getting cell: {:?}", (x,y));
|
||||
halo2_proofs::plonk::Error::Synthesis})?;
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
};
|
||||
let prev_cell = if let Some(prev_cell) = prev_cell.cell() {
|
||||
prev_cell
|
||||
} else {
|
||||
error!("Error getting prev cell: {:?}", (x,y));
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
};
|
||||
region.constrain_equal(prev_cell,cell)?;
|
||||
} else {
|
||||
error!("Error copy-constraining previous value: {:?}", (x,y));
|
||||
error!("Previous cell was not set");
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
}
|
||||
}}
|
||||
}
|
||||
|
||||
Ok(cell)
|
||||
|
||||
})?.into()};
|
||||
})?.into();
|
||||
|
||||
let total_used_len = res.len();
|
||||
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
|
||||
|
||||
res.reshape(dims).unwrap();
|
||||
res.set_scale(values.scale());
|
||||
|
||||
if matches!(check_mode, CheckMode::SAFE) {
|
||||
// during key generation this will be 0 so we use this as a flag to check
|
||||
// TODO: this isn't very safe and would be better to get the phase directly
|
||||
let res_evals = res.int_evals().unwrap();
|
||||
let is_assigned = res_evals
|
||||
.iter()
|
||||
.all(|&x| x == 0);
|
||||
if !is_assigned {
|
||||
assert_eq!(
|
||||
values.int_evals().unwrap(),
|
||||
res_evals
|
||||
)};
|
||||
}
|
||||
|
||||
Ok((res, total_used_len))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns a single value to the tensor. This is a helper function used by other assignment methods.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for the assignment
|
||||
/// * `k` - The value to assign
|
||||
/// * `coord` - The coordinate where to assign the value
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// The assigned value or an error if assignment fails
|
||||
fn assign_value<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
@@ -606,24 +766,28 @@ impl VarTensor {
|
||||
) -> Result<ValType<F>, halo2_proofs::plonk::Error> {
|
||||
let (x, y, z) = self.cartesian_coord(offset + coord);
|
||||
let res = match k {
|
||||
// Handle direct value assignment
|
||||
ValType::Value(v) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
ValType::PrevAssigned(region.assign_advice(|| "k", advices[x][y], z, || v)?)
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
// Handle copying previously assigned value
|
||||
ValType::PrevAssigned(v) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
ValType::PrevAssigned(v.copy_advice(|| "k", region, advices[x][y], z)?)
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
// Handle copying previously assigned constant
|
||||
ValType::AssignedConstant(v, val) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
ValType::AssignedConstant(v.copy_advice(|| "k", region, advices[x][y], z)?, val)
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
// Handle assigning evaluated value
|
||||
ValType::AssignedValue(v) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => ValType::PrevAssigned(
|
||||
region
|
||||
@@ -632,6 +796,7 @@ impl VarTensor {
|
||||
),
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
// Handle constant value assignment with caching
|
||||
ValType::Constant(v) => {
|
||||
if let std::collections::hash_map::Entry::Vacant(e) = constants.entry(v) {
|
||||
let value = ValType::AssignedConstant(
|
||||
|
||||
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.
@@ -75,9 +75,8 @@ mod native_tests {
|
||||
});
|
||||
}
|
||||
|
||||
///
|
||||
#[allow(dead_code)]
|
||||
pub fn init_wasm() {
|
||||
fn init_wasm() {
|
||||
COMPILE_WASM.call_once(|| {
|
||||
build_wasm_ezkl();
|
||||
});
|
||||
@@ -187,13 +186,14 @@ mod native_tests {
|
||||
|
||||
const PF_FAILURE_AGGR: &str = "examples/test_failure_aggr_proof.json";
|
||||
|
||||
const LARGE_TESTS: [&str; 6] = [
|
||||
const LARGE_TESTS: [&str; 7] = [
|
||||
"self_attention",
|
||||
"nanoGPT",
|
||||
"multihead_attention",
|
||||
"mobilenet",
|
||||
"mnist_gan",
|
||||
"smallworm",
|
||||
"fr_age",
|
||||
];
|
||||
|
||||
const ACCURACY_CAL_TESTS: [&str; 6] = [
|
||||
@@ -395,29 +395,29 @@ mod native_tests {
|
||||
const TESTS_AGGR: [&str; 3] = ["1l_mlp", "1l_flatten", "1l_average"];
|
||||
|
||||
const TESTS_EVM: [&str; 23] = [
|
||||
"1l_mlp",
|
||||
"1l_flatten",
|
||||
"1l_average",
|
||||
"1l_reshape",
|
||||
"1l_sigmoid",
|
||||
"1l_div",
|
||||
"1l_sqrt",
|
||||
"1l_prelu",
|
||||
"1l_var",
|
||||
"1l_leakyrelu",
|
||||
"1l_gelu_noappx",
|
||||
"1l_relu",
|
||||
"1l_tanh",
|
||||
"2l_relu_sigmoid_small",
|
||||
"2l_relu_small",
|
||||
"min",
|
||||
"max",
|
||||
"1l_max_pool",
|
||||
"idolmodel",
|
||||
"1l_identity",
|
||||
"lstm",
|
||||
"rnn",
|
||||
"quantize_dequantize",
|
||||
"1l_mlp", // 0
|
||||
"1l_flatten", // 1
|
||||
"1l_average", // 2
|
||||
"1l_reshape", // 3
|
||||
"1l_sigmoid", // 4
|
||||
"1l_div", // 5
|
||||
"1l_sqrt", // 6
|
||||
"1l_prelu", // 7
|
||||
"1l_var", // 8
|
||||
"1l_leakyrelu", // 9
|
||||
"1l_gelu_noappx", // 10
|
||||
"1l_relu", // 11
|
||||
"1l_tanh", // 12
|
||||
"2l_relu_sigmoid_small", // 13
|
||||
"2l_relu_small", // 14
|
||||
"min", // 15
|
||||
"max", // 16
|
||||
"1l_max_pool", // 17
|
||||
"idolmodel", // 18
|
||||
"1l_identity", // 19
|
||||
"lstm", // 20
|
||||
"rnn", // 21
|
||||
"quantize_dequantize", // 22
|
||||
];
|
||||
|
||||
const TESTS_EVM_AGGR: [&str; 18] = [
|
||||
@@ -541,7 +541,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -606,7 +606,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -616,7 +616,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true, Some(8194), Some(4));
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -627,7 +627,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// gen random number between 0.0 and 1.0
|
||||
let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false, Some(8194), Some(5));
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -642,7 +642,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let large_batch_dir = &format!("large_batches_{}", test);
|
||||
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0, false);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
}
|
||||
@@ -652,7 +652,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -661,7 +661,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -670,7 +670,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -679,7 +679,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -688,7 +688,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -697,7 +697,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -706,7 +706,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -716,7 +716,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -726,7 +726,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -735,7 +735,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -745,7 +745,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -754,7 +754,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -764,7 +764,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -774,7 +774,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -784,7 +784,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -794,7 +794,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -804,7 +804,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -963,7 +963,7 @@ mod native_tests {
|
||||
|
||||
});
|
||||
|
||||
seq!(N in 0..=5 {
|
||||
seq!(N in 0..=6 {
|
||||
|
||||
#(#[test_case(LARGE_TESTS[N])])*
|
||||
#[ignore]
|
||||
@@ -981,7 +981,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -1459,6 +1459,8 @@ mod native_tests {
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
tolerance: f32,
|
||||
bounded_lookup_log: bool,
|
||||
decomp_base: Option<usize>,
|
||||
decomp_legs: Option<usize>,
|
||||
) {
|
||||
let mut tolerance = tolerance;
|
||||
gen_circuit_settings_and_witness(
|
||||
@@ -1475,6 +1477,8 @@ mod native_tests {
|
||||
Commitments::KZG,
|
||||
2,
|
||||
bounded_lookup_log,
|
||||
decomp_base,
|
||||
decomp_legs,
|
||||
);
|
||||
|
||||
if tolerance > 0.0 {
|
||||
@@ -1616,6 +1620,8 @@ mod native_tests {
|
||||
commitment: Commitments,
|
||||
lookup_safety_margin: usize,
|
||||
bounded_lookup_log: bool,
|
||||
decomp_base: Option<usize>,
|
||||
decomp_legs: Option<usize>,
|
||||
) {
|
||||
let mut args = vec![
|
||||
"gen-settings".to_string(),
|
||||
@@ -1634,6 +1640,14 @@ mod native_tests {
|
||||
format!("--commitment={}", commitment),
|
||||
];
|
||||
|
||||
if let Some(decomp_base) = decomp_base {
|
||||
args.push(format!("--decomp-base={}", decomp_base));
|
||||
}
|
||||
|
||||
if let Some(decomp_legs) = decomp_legs {
|
||||
args.push(format!("--decomp-legs={}", decomp_legs));
|
||||
}
|
||||
|
||||
if bounded_lookup_log {
|
||||
args.push("--bounded-log-lookup".to_string());
|
||||
}
|
||||
@@ -1751,6 +1765,8 @@ mod native_tests {
|
||||
Commitments::KZG,
|
||||
2,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
println!(
|
||||
@@ -2035,6 +2051,8 @@ mod native_tests {
|
||||
commitment,
|
||||
lookup_safety_margin,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
|
||||
@@ -2228,6 +2246,7 @@ mod native_tests {
|
||||
}
|
||||
|
||||
// prove-serialize-verify, the usual full path
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn kzg_evm_prove_and_verify_reusable_verifier(
|
||||
num_inner_columns: usize,
|
||||
test_dir: &str,
|
||||
@@ -2467,6 +2486,8 @@ mod native_tests {
|
||||
Commitments::KZG,
|
||||
2,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);
|
||||
@@ -2774,7 +2795,17 @@ mod native_tests {
|
||||
"--features",
|
||||
"icicle",
|
||||
];
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
#[cfg(feature = "macos-metal")]
|
||||
let args = [
|
||||
"build",
|
||||
"--release",
|
||||
"--bin",
|
||||
"ezkl",
|
||||
"--features",
|
||||
"macos-metal",
|
||||
];
|
||||
// not macos-metal and not icicle
|
||||
#[cfg(all(not(feature = "icicle"), not(feature = "macos-metal")))]
|
||||
let args = ["build", "--release", "--bin", "ezkl"];
|
||||
#[cfg(not(feature = "mv-lookup"))]
|
||||
let args = [
|
||||
|
||||
@@ -72,11 +72,10 @@ mod py_tests {
|
||||
"torchtext==0.17.2",
|
||||
"torchvision==0.17.2",
|
||||
"pandas==2.2.1",
|
||||
"numpy==1.26.4",
|
||||
"seaborn==0.13.2",
|
||||
"notebook==7.1.2",
|
||||
"nbconvert==7.16.3",
|
||||
"onnx==1.16.0",
|
||||
"onnx==1.17.0",
|
||||
"kaggle==1.6.8",
|
||||
"py-solc-x==2.0.3",
|
||||
"web3==7.5.0",
|
||||
@@ -90,12 +89,13 @@ mod py_tests {
|
||||
"xgboost==2.0.3",
|
||||
"hummingbird-ml==0.4.11",
|
||||
"lightgbm==4.3.0",
|
||||
"numpy==1.26.4",
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
let status = Command::new("pip")
|
||||
.args(["install", "numpy==1.23"])
|
||||
.args(["install", "numpy==1.26.4"])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
|
||||
|
||||
@@ -873,6 +873,7 @@ def get_examples():
|
||||
'linear_regression',
|
||||
"mnist_gan",
|
||||
"smallworm",
|
||||
"fr_age"
|
||||
]
|
||||
examples = []
|
||||
for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):
|
||||
|
||||
Reference in New Issue
Block a user