Mirror of https://github.com/zkonduit/ezkl.git, synced 2026-01-13 16:27:59 -05:00

Compare commits: 1 commit (d4085b75ce)
.github/workflows/benchmarks.yml (vendored, 66 lines changed)
@@ -8,14 +8,12 @@ on:
 jobs:

   bench_poseidon:
-    permissions:
-      contents: read
     runs-on: self-hosted
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
           toolchain: nightly-2023-06-27
           override: true
@@ -24,15 +22,13 @@ jobs:
         run: cargo bench --verbose --bench poseidon

   bench_einsum_accum_matmul:
-    permissions:
-      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
           toolchain: nightly-2023-06-27
           override: true
@@ -41,15 +37,13 @@ jobs:
         run: cargo bench --verbose --bench accum_einsum_matmul

   bench_accum_matmul_relu:
-    permissions:
-      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
           toolchain: nightly-2023-06-27
           override: true
@@ -58,15 +52,13 @@ jobs:
         run: cargo bench --verbose --bench accum_matmul_relu

   bench_accum_matmul_relu_overflow:
-    permissions:
-      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
           toolchain: nightly-2023-06-27
           override: true
@@ -75,15 +67,13 @@ jobs:
         run: cargo bench --verbose --bench accum_matmul_relu_overflow

   bench_relu:
-    permissions:
-      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
           toolchain: nightly-2023-06-27
           override: true
@@ -92,15 +82,13 @@ jobs:
         run: cargo bench --verbose --bench relu

   bench_accum_dot:
-    permissions:
-      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
           toolchain: nightly-2023-06-27
           override: true
@@ -109,15 +97,13 @@ jobs:
         run: cargo bench --verbose --bench accum_dot

   bench_accum_conv:
-    permissions:
-      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
           toolchain: nightly-2023-06-27
           override: true
@@ -126,15 +112,13 @@ jobs:
         run: cargo bench --verbose --bench accum_conv

   bench_accum_sumpool:
-    permissions:
-      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
           toolchain: nightly-2023-06-27
           override: true
@@ -143,15 +127,13 @@ jobs:
         run: cargo bench --verbose --bench accum_sumpool

   bench_pairwise_add:
-    permissions:
-      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
           toolchain: nightly-2023-06-27
           override: true
@@ -160,15 +142,13 @@ jobs:
         run: cargo bench --verbose --bench pairwise_add

   bench_accum_sum:
-    permissions:
-      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
           toolchain: nightly-2023-06-27
           override: true
@@ -177,15 +157,13 @@ jobs:
         run: cargo bench --verbose --bench accum_sum

   bench_pairwise_pow:
-    permissions:
-      contents: read
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
           toolchain: nightly-2023-06-27
           override: true
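Every job in this workflow reduces to `cargo bench --verbose --bench <name>` on the pinned nightly toolchain, with all jobs after the first gated on `bench_poseidon`. A minimal sketch for reproducing one of these benches locally, assuming the repository is checked out and rustup is installed (the toolchain date is taken straight from the workflow):

    # Install the nightly the workflow pins and run one Criterion bench
    rustup toolchain install nightly-2023-06-27
    cargo +nightly-2023-06-27 bench --verbose --bench poseidon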
.github/workflows/engine.yml (vendored, 100 lines changed)
@@ -15,24 +15,19 @@ defaults:
     working-directory: .
 jobs:
   publish-wasm-bindings:
-    permissions:
-      contents: read
-      packages: write
     name: publish-wasm-bindings
     env:
       RELEASE_TAG: ${{ github.ref_name }}
     runs-on: ubuntu-latest
     if: startsWith(github.ref, 'refs/tags/')
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2025-02-17
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
+      - uses: jetli/wasm-pack-action@v0.4.0
         with:
           # Pin to version 0.12.1
           version: 'v0.12.1'
@@ -40,7 +35,7 @@ jobs:
         run: rustup target add wasm32-unknown-unknown

       - name: Add rust-src
-        run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
+        run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
       - name: Install binaryen
         run: |
           set -e
@@ -49,41 +44,41 @@ jobs:
           wasm-opt --version
       - name: Build wasm files for both web and nodejs compilation targets
         run: |
           wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
           wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
       - name: Create package.json in pkg folder
         shell: bash
         env:
           RELEASE_TAG: ${{ github.ref_name }}
         run: |
-          cat > pkg/package.json << EOF
-          {
-            "name": "@ezkljs/engine",
-            "version": "$RELEASE_TAG",
-            "dependencies": {
-              "@types/json-bigint": "^1.0.1",
-              "json-bigint": "^1.0.0"
-            },
-            "files": [
-              "nodejs/ezkl_bg.wasm",
-              "nodejs/ezkl.js",
-              "nodejs/ezkl.d.ts",
-              "nodejs/package.json",
-              "nodejs/utils.js",
-              "web/ezkl_bg.wasm",
-              "web/ezkl.js",
-              "web/ezkl.d.ts",
-              "web/snippets/**/*",
-              "web/package.json",
-              "web/utils.js",
-              "ezkl.d.ts"
-            ],
-            "main": "nodejs/ezkl.js",
-            "module": "web/ezkl.js",
-            "types": "nodejs/ezkl.d.ts",
-            "sideEffects": [
-              "web/snippets/*"
-            ]
-          }
-          EOF
+          echo '{
+            "name": "@ezkljs/engine",
+            "version": "${RELEASE_TAG}",
+            "dependencies": {
+              "@types/json-bigint": "^1.0.1",
+              "json-bigint": "^1.0.0"
+            },
+            "files": [
+              "nodejs/ezkl_bg.wasm",
+              "nodejs/ezkl.js",
+              "nodejs/ezkl.d.ts",
+              "nodejs/package.json",
+              "nodejs/utils.js",
+              "web/ezkl_bg.wasm",
+              "web/ezkl.js",
+              "web/ezkl.d.ts",
+              "web/snippets/**/*",
+              "web/package.json",
+              "web/utils.js",
+              "ezkl.d.ts"
+            ],
+            "main": "nodejs/ezkl.js",
+            "module": "web/ezkl.js",
+            "types": "nodejs/ezkl.d.ts",
+            "sideEffects": [
+              "web/snippets/*"
+            ]
+          }' > pkg/package.json

       - name: Replace memory definition in nodejs
         run: |
@@ -176,7 +171,7 @@ jobs:
           curl -s "https://raw.githubusercontent.com/zkonduit/ezkljs-engine/main/README.md" > ./pkg/README.md

       - name: Set up Node.js
-        uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
+        uses: actions/setup-node@v3
         with:
           node-version: "18.12.1"
           registry-url: "https://registry.npmjs.org"
@@ -191,23 +186,20 @@ jobs:


   in-browser-evm-ver-publish:
-    permissions:
-      contents: read
-      packages: write
     name: publish-in-browser-evm-verifier-package
     needs: [publish-wasm-bindings]
     runs-on: ubuntu-latest
     env:
       RELEASE_TAG: ${{ github.ref_name }}
     if: startsWith(github.ref, 'refs/tags/')
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
       - name: Update version in package.json
         shell: bash
         env:
           RELEASE_TAG: ${{ github.ref_name }}
         run: |
-          sed -i "s|\"version\": \".*\"|\"version\": \"$RELEASE_TAG\"|" in-browser-evm-verifier/package.json
+          sed -i "s|\"version\": \".*\"|\"version\": \"${RELEASE_TAG}\"|" in-browser-evm-verifier/package.json
       - name: Prepare tag and fetch package integrity
         run: |
           CLEANED_TAG=${RELEASE_TAG} # Get the tag from ref_name
@@ -230,13 +222,13 @@ jobs:
           NR==30{$0="  specifier: \"" tag "\""}
           NR==31{$0="  version: \"" tag "\""}
           NR==400{$0="  /@ezkljs/engine@" tag ":"}
           NR==401{$0="  resolution: {integrity: \"" integrity "\"}"} 1' in-browser-evm-verifier/pnpm-lock.yaml > temp.yaml && mv temp.yaml in-browser-evm-verifier/pnpm-lock.yaml
       - name: Use pnpm 8
-        uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
+        uses: pnpm/action-setup@v2
         with:
           version: 8
       - name: Set up Node.js
-        uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
+        uses: actions/setup-node@v3
         with:
           node-version: "18.12.1"
           registry-url: "https://registry.npmjs.org"
@@ -247,4 +239,4 @@ jobs:
           pnpm run build
           pnpm publish --no-git-checks
         env:
           NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
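One behavioral difference hidden in the package.json hunk above: the removed heredoc (`cat > pkg/package.json << EOF`) expands `$RELEASE_TAG`, because an unquoted heredoc delimiter enables parameter expansion, while the added single-quoted `echo '{...}' > pkg/package.json` writes the literal text `${RELEASE_TAG}` into the file. A minimal sketch of the difference, with an illustrative tag value:

    #!/usr/bin/env bash
    RELEASE_TAG=v1.0.0

    # Heredoc with an unquoted delimiter: $RELEASE_TAG is expanded
    cat > expanded.json << EOF
    { "version": "$RELEASE_TAG" }
    EOF
    cat expanded.json   # -> { "version": "v1.0.0" }

    # Single quotes suppress expansion: the literal ${RELEASE_TAG} is written
    echo '{ "version": "${RELEASE_TAG}" }' > literal.json
    cat literal.json    # -> { "version": "${RELEASE_TAG}" }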
.github/workflows/large-tests.yml (vendored, 8 lines changed)
@@ -6,16 +6,14 @@ on:
         description: "Test scenario tags"
 jobs:
   large-tests:
-    permissions:
-      contents: read
     runs-on: kaiju
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2025-02-17
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - name: nanoGPT Mock
.github/workflows/pypi-gpu.yml (vendored, 41 lines changed)
@@ -18,46 +18,41 @@ defaults:
 jobs:

   linux:
-    permissions:
-      contents: read
-      packages: write
     runs-on: GPU
     strategy:
       matrix:
         target: [x86_64]
     env:
       RELEASE_TAG: ${{ github.ref_name }}
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
-      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
+      - uses: actions/setup-python@v4
         with:
           python-version: 3.12
           architecture: x64

-      - name: Set pyproject.toml version to match github tag and rename ezkl to ezkl-gpu
+      - name: Set pyproject.toml version to match github tag
         shell: bash
         run: |
           mv pyproject.toml pyproject.toml.orig
-          sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig > pyproject.toml.tmp
-          sed "s/0\.0\.0/${RELEASE_TAG//v}/" pyproject.toml.tmp > pyproject.toml
+          sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig >pyproject.toml
+          sed "s/0\.0\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml

-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
           toolchain: nightly-2023-06-27
           override: true
           components: rustfmt, clippy

-      - name: Set Cargo.toml version to match github tag and rename ezkl to ezkl-gpu
+      - name: Set Cargo.toml version to match github tag
         shell: bash
         # the ezkl substitution here looks for the first instance of name = "ezkl" and changes it to "ezkl-gpu"
         env:
           RELEASE_TAG: ${{ github.ref_name }}
         run: |
           mv Cargo.toml Cargo.toml.orig
-          sed "0,/name = \"ezkl\"/s/name = \"ezkl\"/name = \"ezkl-gpu\"/" Cargo.toml.orig > Cargo.toml.tmp
-          sed "s/0\.0\.0/${RELEASE_TAG//v}/" Cargo.toml.tmp > Cargo.toml
+          sed "s/0\.0\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
           mv Cargo.lock Cargo.lock.orig
-          sed "s/0\.0\.0/${RELEASE_TAG//v}/" Cargo.lock.orig > Cargo.lock
+          sed "s/0\.0\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock

       - name: Install required libraries
         shell: bash
@@ -65,7 +60,7 @@ jobs:
           sudo apt-get update && sudo apt-get install -y openssl pkg-config libssl-dev

       - name: Build wheels
-        uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
+        uses: PyO3/maturin-action@v1
         with:
           target: ${{ matrix.target }}
           manylinux: auto
@@ -78,7 +73,7 @@ jobs:
           pip install ezkl-gpu --no-index --find-links dist --force-reinstall

       - name: Upload wheels
-        uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
+        uses: actions/upload-artifact@v3
         with:
           name: wheels
           path: dist
@@ -94,7 +89,7 @@ jobs:
     # needs: [ macos, windows, linux, linux-cross, musllinux, musllinux-cross ]
     needs: [linux]
     steps:
-      - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 #v4.1.8
+      - uses: actions/download-artifact@v3
         with:
           name: wheels
       - name: List Files
@@ -106,14 +101,14 @@ jobs:
       # publishes to PyPI
       - name: Publish package distributions to PyPI
         continue-on-error: true
-        uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
+        uses: pypa/gh-action-pypi-publish@unstable/v1
         with:
-          packages-dir: ./wheels
+          packages-dir: ./

       # publishes to TestPyPI
       - name: Publish package distribution to TestPyPI
         continue-on-error: true
-        uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
+        uses: pypa/gh-action-pypi-publish@unstable/v1
         with:
           repository-url: https://test.pypi.org/legacy/
-          packages-dir: ./wheels
+          packages-dir: ./
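Note the shape of the sed changes in the pypi-gpu.yml hunks above: the removed lines chain two sed passes through a .tmp file, so both the ezkl to ezkl-gpu rename and the version stamp survive, while the added lines run both seds against the same .orig input and redirect to the same output, so the second pass overwrites the first. A minimal sketch of why the .tmp chaining matters, using an illustrative version number:

    # Chained: the output of the rename feeds the version substitution
    sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig > pyproject.toml.tmp
    sed "s/0\.0\.0/1.2.3/" pyproject.toml.tmp > pyproject.toml    # keeps both edits

    # Unchained: both passes read .orig, the second redirect clobbers the first result
    sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig > pyproject.toml
    sed "s/0\.0\.0/1.2.3/" pyproject.toml.orig > pyproject.toml   # rename is lost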
.github/workflows/pypi.yml (vendored, 208 lines changed)
@@ -16,53 +16,38 @@ defaults:

 jobs:
   macos:
-    permissions:
-      contents: read
     runs-on: macos-latest
     if: startsWith(github.ref, 'refs/tags/')
     strategy:
       matrix:
         target: [x86_64, universal2-apple-darwin]
     env:
       RELEASE_TAG: ${{ github.ref_name }}
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
-        with:
+      - uses: actions/checkout@v4
+        with:
           persist-credentials: false
-      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
+      - uses: actions/setup-python@v4
         with:
           python-version: 3.12
           architecture: x64

       - name: Set pyproject.toml version to match github tag
         shell: bash
         run: |
           mv pyproject.toml pyproject.toml.orig
           sed "s/0\.0\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml

       - name: Set Cargo.toml version to match github tag
         shell: bash
         env:
           RELEASE_TAG: ${{ github.ref_name }}
         run: |
           mv Cargo.toml Cargo.toml.orig
           sed "s/0\.0\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
           mv Cargo.lock Cargo.lock.orig
           sed "s/0\.0\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock

-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2025-02-17
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy

-      - name: Build wheels
-        if: matrix.target == 'universal2-apple-darwin'
-        uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
-        with:
-          target: ${{ matrix.target }}
-          args: --release --out dist --features python-bindings
       - name: Build wheels
         if: matrix.target == 'x86_64'
-        uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
+        uses: PyO3/maturin-action@v1
         with:
           target: ${{ matrix.target }}
           args: --release --out dist --features python-bindings
@@ -73,36 +58,26 @@ jobs:
           python -c "import ezkl"

       - name: Upload wheels
-        uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
+        uses: actions/upload-artifact@v3
         with:
-          name: dist-macos-${{ matrix.target }}
+          name: wheels
           path: dist

   windows:
-    permissions:
-      contents: read
     runs-on: windows-latest
     if: startsWith(github.ref, 'refs/tags/')
     strategy:
       matrix:
         target: [x64, x86]
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
-        with:
+      - uses: actions/checkout@v4
+        with:
           persist-credentials: false
-      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
+      - uses: actions/setup-python@v4
         with:
           python-version: 3.12
           architecture: ${{ matrix.target }}

       - name: Set pyproject.toml version to match github tag
         shell: bash
         env:
           RELEASE_TAG: ${{ github.ref_name }}
         run: |
           mv pyproject.toml pyproject.toml.orig
           sed "s/0\.0\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml

       - name: Set Cargo.toml version to match github tag
         shell: bash
         env:
@@ -113,14 +88,14 @@ jobs:
           mv Cargo.lock Cargo.lock.orig
           sed "s/0\.0\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock

-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2025-02-17
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy

       - name: Build wheels
-        uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
+        uses: PyO3/maturin-action@v1
         with:
           target: ${{ matrix.target }}
           args: --release --out dist --features python-bindings
@@ -130,36 +105,26 @@ jobs:
           python -c "import ezkl"

       - name: Upload wheels
-        uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
+        uses: actions/upload-artifact@v3
         with:
-          name: dist-windows-${{ matrix.target }}
+          name: wheels
           path: dist

   linux:
-    permissions:
-      contents: read
     runs-on: ubuntu-latest
     if: startsWith(github.ref, 'refs/tags/')
     strategy:
       matrix:
         target: [x86_64]
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
-        with:
+      - uses: actions/checkout@v4
+        with:
           persist-credentials: false
-      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
+      - uses: actions/setup-python@v4
         with:
           python-version: 3.12
           architecture: x64

       - name: Set pyproject.toml version to match github tag
         shell: bash
         env:
           RELEASE_TAG: ${{ github.ref_name }}
         run: |
           mv pyproject.toml pyproject.toml.orig
           sed "s/0\.0\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml

       - name: Set Cargo.toml version to match github tag
         shell: bash
         env:
@@ -170,13 +135,14 @@ jobs:
           mv Cargo.lock Cargo.lock.orig
           sed "s/0\.0\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock

       - name: Install required libraries
         shell: bash
         run: |
           sudo apt-get update && sudo apt-get install -y openssl pkg-config libssl-dev

       - name: Build wheels
-        uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
+        uses: PyO3/maturin-action@v1
         with:
           target: ${{ matrix.target }}
           manylinux: auto
@@ -203,14 +169,63 @@ jobs:
           python -c "import ezkl"

       - name: Upload wheels
-        uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
+        uses: actions/upload-artifact@v3
         with:
-          name: dist-linux-${{ matrix.target }}
+          name: wheels
           path: dist

+  # There's a problem with the maturin-action toolchain for arm arch leading to failed builds
+  # linux-cross:
+  #   runs-on: ubuntu-latest
+  #   strategy:
+  #     matrix:
+  #       target: [aarch64, armv7]
+  #   steps:
+  #     - uses: actions/checkout@v4
+  #     - uses: actions/setup-python@v4
+  #       with:
+  #         python-version: 3.12
+
+  #     - name: Install cross-compilation tools for aarch64
+  #       if: matrix.target == 'aarch64'
+  #       run: |
+  #         sudo apt-get update
+  #         sudo apt-get install -y gcc make gcc-aarch64-linux-gnu binutils-aarch64-linux-gnu libc6-dev-arm64-cross libusb-1.0-0-dev libatomic1-arm64-cross
+
+  #     - name: Install cross-compilation tools for armv7
+  #       if: matrix.target == 'armv7'
+  #       run: |
+  #         sudo apt-get update
+  #         sudo apt-get install -y gcc make gcc-arm-linux-gnueabihf binutils-arm-linux-gnueabihf libc6-dev-armhf-cross libusb-1.0-0-dev libatomic1-armhf-cross
+
+  #     - name: Build wheels
+  #       uses: PyO3/maturin-action@v1
+  #       with:
+  #         target: ${{ matrix.target }}
+  #         manylinux: auto
+  #         args: --release --out dist --features python-bindings
+
+  #     - uses: uraimo/run-on-arch-action@v2.5.0
+  #       name: Install built wheel
+  #       with:
+  #         arch: ${{ matrix.target }}
+  #         distro: ubuntu20.04
+  #         githubToken: ${{ github.token }}
+  #         install: |
+  #           apt-get update
+  #           apt-get install -y --no-install-recommends python3 python3-pip
+  #           pip3 install -U pip
+  #         run: |
+  #           pip3 install ezkl --no-index --find-links dist/ --force-reinstall
+  #           python3 -c "import ezkl"
+
+  #     - name: Upload wheels
+  #       uses: actions/upload-artifact@v3
+  #       with:
+  #         name: wheels
+  #         path: dist

   musllinux:
-    permissions:
-      contents: read
     runs-on: ubuntu-latest
     if: startsWith(github.ref, 'refs/tags/')
     strategy:
@@ -218,10 +233,10 @@ jobs:
       target:
         - x86_64-unknown-linux-musl
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
-        with:
+      - uses: actions/checkout@v4
+        with:
           persist-credentials: false
-      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
+      - uses: actions/setup-python@v4
         with:
           python-version: 3.12
           architecture: x64
@@ -243,14 +258,13 @@ jobs:
           sed "s/0\.0\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
           mv Cargo.lock Cargo.lock.orig
           sed "s/0\.0\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock

       - name: Install required libraries
         shell: bash
         run: |
           sudo apt-get update && sudo apt-get install -y pkg-config libssl-dev

       - name: Build wheels
-        uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
+        uses: PyO3/maturin-action@v1
         with:
           target: ${{ matrix.target }}
           manylinux: musllinux_1_2
@@ -271,14 +285,12 @@ jobs:
           python3 -c "import ezkl"

       - name: Upload wheels
-        uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
+        uses: actions/upload-artifact@v3
         with:
-          name: dist-musllinux-${{ matrix.target }}
+          name: wheels
           path: dist

   musllinux-cross:
-    permissions:
-      contents: read
     runs-on: ubuntu-latest
     if: startsWith(github.ref, 'refs/tags/')
     strategy:
@@ -287,21 +299,13 @@ jobs:
         - target: aarch64-unknown-linux-musl
           arch: aarch64
     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
-        with:
+      - uses: actions/checkout@v4
+        with:
           persist-credentials: false
-      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
+      - uses: actions/setup-python@v4
         with:
           python-version: 3.12

       - name: Set pyproject.toml version to match github tag
         shell: bash
         env:
           RELEASE_TAG: ${{ github.ref_name }}
         run: |
           mv pyproject.toml pyproject.toml.orig
           sed "s/0\.0\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml

       - name: Set Cargo.toml version to match github tag
         shell: bash
         env:
@@ -313,13 +317,13 @@ jobs:
           sed "s/0\.0\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock

       - name: Build wheels
-        uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
+        uses: PyO3/maturin-action@v1
         with:
           target: ${{ matrix.platform.target }}
           manylinux: musllinux_1_2
           args: --release --out dist --features python-bindings

-      - uses: uraimo/run-on-arch-action@5397f9e30a9b62422f302092631c99ae1effcd9e #v2.8.1
+      - uses: uraimo/run-on-arch-action@v2.8.1
         name: Install built wheel
         with:
           arch: ${{ matrix.platform.arch }}
@@ -334,9 +338,9 @@ jobs:
           python3 -c "import ezkl"

       - name: Upload wheels
-        uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
+        uses: actions/upload-artifact@v3
         with:
-          name: dist-musllinux-${{ matrix.platform.target }}
+          name: wheels
           path: dist

   pypi-publish:
@@ -345,43 +349,45 @@ jobs:
     permissions:
       id-token: write
     if: "startsWith(github.ref, 'refs/tags/')"
     # TODO: Uncomment if linux-cross is working
     # needs: [ macos, windows, linux, linux-cross, musllinux, musllinux-cross ]
     needs: [macos, windows, linux, musllinux, musllinux-cross]
     steps:
-      - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 #v4.1.8
+      - uses: actions/download-artifact@v3
         with:
-          pattern: dist-*
-          merge-multiple: true
-          path: wheels
+          name: wheels
       - name: List Files
         run: ls -R

-      # # publishes to TestPyPI
-      # - name: Publish package distribution to TestPyPI
-      #   uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
-      #   with:
-      #     repository-url: https://test.pypi.org/legacy/
-      #     packages-dir: ./
+      # Both publish steps will fail if there is no trusted publisher setup
+      # On failure the publish step will then simply continue to the next one

       # publishes to PyPI
       - name: Publish package distributions to PyPI
-        uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
+        continue-on-error: true
+        uses: pypa/gh-action-pypi-publish@unstable/v1
         with:
-          packages-dir: ./wheels
+          packages-dir: ./

+      # publishes to TestPyPI
+      - name: Publish package distribution to TestPyPI
+        continue-on-error: true
+        uses: pypa/gh-action-pypi-publish@unstable/v1
+        with:
+          repository-url: https://test.pypi.org/legacy/
+          packages-dir: ./

   doc-publish:
     permissions:
       contents: read
     name: Trigger ReadTheDocs Build
     runs-on: ubuntu-latest
     needs: pypi-publish
     steps:
       - uses: actions/checkout@v4
         with:
           persist-credentials: false
       - name: Trigger RTDs build
         uses: dfm/rtds-action@v1
         with:
           webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
           webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
           commit_ref: ${{ github.ref_name }}
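Every version-stamping step in these workflows relies on the bash substitution `${RELEASE_TAG//v}`, the `${var//pattern/replacement}` form with an empty replacement, which deletes every `v` from the tag before it is written into pyproject.toml, Cargo.toml, and Cargo.lock. A quick illustration with a made-up tag:

    RELEASE_TAG=v12.4.0
    echo "${RELEASE_TAG//v}"   # -> 12.4.0 (strips the leading v; would also strip any later v)
    sed "s/0\.0\.0/${RELEASE_TAG//v}/" Cargo.toml.orig > Cargo.toml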
.github/workflows/release.yml (vendored, 50 lines changed)
@@ -10,9 +10,6 @@ on:
       - "*"
 jobs:
   create-release:
-    permissions:
-      contents: read
-      packages: write
     name: create-release
     runs-on: ubuntu-22.04
     if: startsWith(github.ref, 'refs/tags/')
@@ -30,15 +27,12 @@ jobs:

       - name: Create Github Release
         id: create-release
-        uses: softprops/action-gh-release@c95fe1489396fe8a9eb87c0abf8aa5b2ef267fda #v2.2.1
+        uses: softprops/action-gh-release@v1
         with:
           token: ${{ secrets.RELEASE_TOKEN }}
           tag_name: ${{ env.EZKL_VERSION }}

   build-release-gpu:
-    permissions:
-      contents: read
-      packages: write
     name: build-release-gpu
     needs: ["create-release"]
     runs-on: GPU
@@ -49,14 +43,14 @@ jobs:
       RUST_BACKTRACE: 1
       PCRE2_SYS_STATIC: 1
     steps:
-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2025-02-17
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - name: Checkout repo
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
-        with:
+        uses: actions/checkout@v4
+        with:
           persist-credentials: false

@@ -90,7 +84,7 @@ jobs:
           echo "ASSET=build-artifacts/ezkl-linux-gpu.tar.gz" >> $GITHUB_ENV

       - name: Upload release archive
-        uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 #v1.0.2
+        uses: actions/upload-release-asset@v1.0.2
         env:
           GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }}
         with:
@@ -100,10 +94,6 @@ jobs:
           asset_content_type: application/octet-stream

   build-release:
-    permissions:
-      contents: read
-      packages: write
-      issues: write
     name: build-release
     needs: ["create-release"]
     runs-on: ${{ matrix.os }}
@@ -119,33 +109,33 @@ jobs:
         include:
           - build: windows-msvc
             os: windows-latest
-            rust: nightly-2025-02-17
+            rust: nightly-2024-07-18
             target: x86_64-pc-windows-msvc
           - build: macos
             os: macos-13
-            rust: nightly-2025-02-17
+            rust: nightly-2024-07-18
             target: x86_64-apple-darwin
           - build: macos-aarch64
             os: macos-13
-            rust: nightly-2025-02-17
+            rust: nightly-2024-07-18
             target: aarch64-apple-darwin
           - build: linux-musl
             os: ubuntu-22.04
-            rust: nightly-2025-02-17
+            rust: nightly-2024-07-18
             target: x86_64-unknown-linux-musl
           - build: linux-gnu
             os: ubuntu-22.04
-            rust: nightly-2025-02-17
+            rust: nightly-2024-07-18
             target: x86_64-unknown-linux-gnu
           - build: linux-aarch64
             os: ubuntu-22.04
-            rust: nightly-2025-02-17
+            rust: nightly-2024-07-18
             target: aarch64-unknown-linux-gnu

     steps:
       - name: Checkout repo
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
-        with:
+        uses: actions/checkout@v4
+        with:
           persist-credentials: false

       - name: Get release version from tag
@@ -170,7 +160,7 @@ jobs:
           fi

       - name: Install Rust
-        uses: dtolnay/rust-toolchain@4f94fbe7e03939b0e674bcc9ca609a16088f63ff #nightly branch, TODO: update when required
+        uses: dtolnay/rust-toolchain@nightly
         with:
           target: ${{ matrix.target }}

@@ -196,18 +186,14 @@ jobs:
           echo "target flag is: ${{ env.TARGET_FLAGS }}"
           echo "target dir is: ${{ env.TARGET_DIR }}"

-      - name: Build release binary (no asm or metal)
-        if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
+      - name: Build release binary (no asm)
+        if: matrix.build != 'linux-gnu'
         run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry

       - name: Build release binary (asm)
         if: matrix.build == 'linux-gnu'
         run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm

-      - name: Build release binary (metal)
-        if: matrix.build == 'macos-aarch64'
-        run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal

       - name: Strip release binary
         if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'
         run: strip "target/${{ matrix.target }}/release/ezkl"
@@ -233,7 +219,7 @@ jobs:
           echo "ASSET=build-artifacts/ezkl-win.zip" >> $GITHUB_ENV

       - name: Upload release archive
-        uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 #v1.0.2
+        uses: actions/upload-release-asset@v1.0.2
         env:
           GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }}
         with:
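The release matrix above selects one of up to three cargo invocations per target; the diff drops the metal variant and its macos-aarch64 condition. A condensed sketch of the variants as they appear in the hunks, where TARGET_FLAGS stands in for the per-target `--target` flag set earlier in the job:

    # default build (everything except linux-gnu; before this change, also excluding macos-aarch64)
    cargo build --release $TARGET_FLAGS -Z sparse-registry

    # linux-gnu: enable the asm feature
    cargo build --release $TARGET_FLAGS -Z sparse-registry --features asm

    # macos-aarch64, removed in this diff: enable the macos-metal feature
    cargo build --release $TARGET_FLAGS -Z sparse-registry --features macos-metal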
.github/workflows/rust.yml (vendored, 658 lines changed)
File diff suppressed because it is too large.
.github/workflows/static-analysis.yml (vendored, 10 lines changed)
@@ -8,16 +8,15 @@ on:

 jobs:
   analyze:
-    permissions:
-      contents: read
     runs-on: ubuntu-latest

     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
-      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+      - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2025-02-17
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy

@@ -30,3 +29,4 @@ jobs:
         run: zizmor .
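The analyze job ends by running zizmor, a static analyzer for GitHub Actions workflows, against the repository root. A minimal sketch for reproducing the check locally; the installation commands are an assumption and are not taken from the workflow, which only shows the invocation:

    cargo install zizmor   # assumed install path; zizmor is also distributed on PyPI
    zizmor .               # lint the workflows under .github/workflows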
.github/workflows/swift-pm.yml (vendored, 9 lines changed)
@@ -9,9 +9,6 @@ on:

 jobs:
   build-and-update:
-    permissions:
-      contents: read
-      packages: write
     runs-on: macos-latest
     env:
       EZKL_SWIFT_PACKAGE_REPO: github.com/zkonduit/ezkl-swift-package.git
@@ -19,8 +16,8 @@ jobs:

     steps:
       - name: Checkout EZKL
-        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
-        with:
+        uses: actions/checkout@v3
+        with:
           persist-credentials: false

       - name: Extract TAG from github.ref_name
@@ -34,7 +31,7 @@ jobs:
           echo "TAG=$NEW_TAG" >> $GITHUB_ENV

       - name: Install Rust (nightly)
-        uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
+        uses: actions-rs/toolchain@v1
         with:
           toolchain: nightly
           override: true
.github/workflows/tagging.yml (vendored, 6 lines changed)
@@ -11,12 +11,12 @@ jobs:
       contents: write

     steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+      - uses: actions/checkout@v4
         with:
           persist-credentials: false
       - name: Bump version and push tag
         id: tag_version
-        uses: mathieudutour/github-tag-action@a22cf08638b34d5badda920f9daf6e72c477b07b #v6.2
+        uses: mathieudutour/github-tag-action@v6.2
         with:
           github_token: ${{ secrets.GITHUB_TOKEN }}

@@ -46,7 +46,7 @@ jobs:
           git tag $RELEASE_TAG

       - name: Push changes
-        uses: ad-m/github-push-action@77c5b412c50b723d2a4fbc6d71fb5723bcd439aa #master
+        uses: ad-m/github-push-action@master
         env:
           RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
         with:
.gitignore (vendored, 3 lines changed)
@@ -9,7 +9,6 @@ pkg
 !AttestData.sol
 !VerifierBase.sol
 !LoadInstances.sol
-!AttestData.t.sol
 *.pf
 *.vk
 *.pk
@@ -50,5 +49,3 @@ timingData.json
 !tests/assets/vk.key
 docs/python/build
 !tests/assets/vk_aggr.key
-cache
-out
Cargo.lock (generated, 132 lines changed)
@@ -1,6 +1,6 @@
 # This file is automatically @generated by Cargo.
 # It is not intended for manual editing.
-version = 4
+version = 3

 [[package]]
 name = "addr2line"
@@ -944,7 +944,7 @@ dependencies = [
  "bitflags 2.5.0",
  "cexpr",
  "clang-sys",
- "itertools 0.11.0",
+ "itertools 0.12.1",
  "lazy_static",
  "lazycell",
  "log",
@@ -1760,7 +1760,7 @@ checksum = "a650a461c6a8ff1ef205ed9a2ad56579309853fecefc2423f73dced342f92258"
 [[package]]
 name = "ecc"
 version = "0.1.0"
-source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
+source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
 dependencies = [
  "integer",
  "num-bigint",
@@ -1835,16 +1835,6 @@ dependencies = [
  "syn 2.0.90",
 ]

-[[package]]
-name = "env_filter"
-version = "0.1.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0"
-dependencies = [
- "log",
- "regex",
-]
-
 [[package]]
 name = "env_logger"
 version = "0.10.2"
@@ -1858,19 +1848,6 @@ dependencies = [
  "termcolor",
 ]

-[[package]]
-name = "env_logger"
-version = "0.11.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0"
-dependencies = [
- "anstream",
- "anstyle",
- "env_filter",
- "humantime",
- "log",
-]
-
 [[package]]
 name = "equivalent"
 version = "1.0.1"
@@ -1946,7 +1923,7 @@ dependencies = [
  "console_error_panic_hook",
  "criterion 0.5.1",
  "ecc",
- "env_logger 0.10.2",
+ "env_logger",
  "ethabi",
  "foundry-compilers",
  "gag",
@@ -1954,7 +1931,7 @@ dependencies = [
  "halo2_gadgets",
  "halo2_proofs",
  "halo2_solidity_verifier",
- "halo2curves 0.7.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
+ "halo2curves 0.7.0",
  "hex",
  "indicatif",
  "instant",
@@ -1962,17 +1939,20 @@ dependencies = [
  "lazy_static",
  "log",
  "maybe-rayon",
  "metal",
  "mimalloc",
  "mnist",
  "num",
  "objc",
  "openssl",
  "pg_bigdecimal",
+ "portable-atomic",
  "pyo3",
  "pyo3-async-runtimes",
  "pyo3-log",
  "pyo3-stub-gen",
  "rand 0.8.5",
+ "regex",
  "reqwest",
  "semver 1.0.22",
  "seq-macro",
@@ -2397,7 +2377,7 @@ dependencies = [
 [[package]]
 name = "halo2_gadgets"
 version = "0.2.0"
-source = "git+https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d"
+source = "git+https://github.com/zkonduit/halo2#d7ecad83c7439fa1cb450ee4a89c2d0b45604ceb"
 dependencies = [
  "arrayvec 0.7.4",
  "bitvec",
@@ -2414,14 +2394,14 @@ dependencies = [
 [[package]]
 name = "halo2_proofs"
 version = "0.3.0"
-source = "git+https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d#f441c920be45f8f05d2c06a173d82e8885a5ed4d"
+source = "git+https://github.com/zkonduit/halo2#ee4e1a09ebdb1f79f797685b78951c6034c430a6#ee4e1a09ebdb1f79f797685b78951c6034c430a6"
 dependencies = [
  "bincode",
  "blake2b_simd",
- "env_logger 0.10.2",
+ "env_logger",
  "ff",
  "group",
- "halo2curves 0.7.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
+ "halo2curves 0.7.0",
  "icicle-bn254",
  "icicle-core",
  "icicle-cuda-runtime",
@@ -2429,7 +2409,6 @@ dependencies = [
  "lazy_static",
  "log",
  "maybe-rayon",
- "mopro-msm",
  "rand_chacha",
  "rand_core 0.6.4",
  "rustc-hash 2.0.0",
@@ -2441,7 +2420,7 @@ dependencies = [
 [[package]]
 name = "halo2_solidity_verifier"
 version = "0.1.0"
-source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier#80c20d6ab57d3b28b2a28df4b63c30923bde17e1"
+source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier#7def6101d32331182f91483832e4fd293d75f33e"
 dependencies = [
  "askama",
  "blake2b_simd",
@@ -2515,36 +2494,6 @@ dependencies = [
  "subtle",
 ]

-[[package]]
-name = "halo2curves"
-version = "0.7.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d380afeef3f1d4d3245b76895172018cfb087d9976a7cabcd5597775b2933e07"
-dependencies = [
- "blake2",
- "digest 0.10.7",
- "ff",
- "group",
- "halo2derive 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
- "hex",
- "lazy_static",
- "num-bigint",
- "num-integer",
- "num-traits",
- "pairing",
- "pasta_curves",
- "paste",
- "rand 0.8.5",
- "rand_core 0.6.4",
- "rayon",
- "serde",
- "serde_arrays",
- "sha2",
- "static_assertions",
- "subtle",
- "unroll",
-]
-
 [[package]]
 name = "halo2curves"
 version = "0.7.0"
@@ -2554,7 +2503,7 @@ dependencies = [
  "digest 0.10.7",
  "ff",
  "group",
- "halo2derive 0.1.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
+ "halo2derive",
  "hex",
  "lazy_static",
  "num-bigint",
@@ -2574,20 +2523,6 @@ dependencies = [
  "unroll",
 ]

-[[package]]
-name = "halo2derive"
-version = "0.1.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bdb99e7492b4f5ff469d238db464131b86c2eaac814a78715acba369f64d2c76"
-dependencies = [
- "num-bigint",
- "num-integer",
- "num-traits",
- "proc-macro2",
- "quote",
- "syn 1.0.109",
-]
-
 [[package]]
 name = "halo2derive"
 version = "0.1.0"
@@ -2604,7 +2539,7 @@ dependencies = [
 [[package]]
 name = "halo2wrong"
 version = "0.1.0"
-source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
+source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
 dependencies = [
  "halo2_proofs",
  "num-bigint",
@@ -2955,7 +2890,7 @@ dependencies = [
 [[package]]
 name = "integer"
 version = "0.1.0"
-source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
+source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
 dependencies = [
  "maingate",
  "num-bigint",
@@ -3266,7 +3201,7 @@ dependencies = [
 [[package]]
 name = "maingate"
 version = "0.1.0"
-source = "git+https://github.com/zkonduit/halo2wrong?branch=ac%2Fchunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
+source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
 dependencies = [
  "halo2wrong",
  "num-bigint",
@@ -3348,8 +3283,7 @@ dependencies = [
 [[package]]
 name = "metal"
 version = "0.29.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21"
+source = "git+https://github.com/gfx-rs/metal-rs#0e1918b34689c4b8cd13a43372f9898680547ee9"
 dependencies = [
  "bitflags 2.5.0",
  "block",
@@ -3420,28 +3354,6 @@ dependencies = [
  "byteorder",
 ]

-[[package]]
-name = "mopro-msm"
-version = "0.1.0"
-source = "git+https://github.com/zkonduit/metal-msm-gpu-acceleration.git#be5f647b1a6c1a6ea9024390744a2b4d87f5d002"
-dependencies = [
- "bincode",
- "env_logger 0.11.6",
- "halo2curves 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
- "instant",
- "itertools 0.13.0",
- "lazy_static",
- "log",
- "metal",
- "objc",
- "once_cell",
- "rand 0.8.5",
- "rayon",
- "serde",
- "thiserror",
- "walkdir",
-]
-
 [[package]]
 name = "native-tls"
 version = "0.2.11"
@@ -3675,9 +3587,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"

 [[package]]
 name = "openssl-src"
-version = "300.4.1+3.4.0"
+version = "300.2.3+3.2.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "faa4eac4138c62414b5622d1b31c5c304f34b406b013c079c2bbc652fdd6678c"
+checksum = "5cff92b6f71555b61bb9315f7c64da3ca43d87531622120fea0195fc761b4843"
 dependencies = [
  "cc",
 ]
@@ -5230,7 +5142,7 @@ checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c"
 [[package]]
 name = "snark-verifier"
 version = "0.1.1"
-source = "git+https://github.com/zkonduit/snark-verifier?branch=ac%2Fchunked-mv-lookup#8762701ab8fa04e7d243a346030afd85633ec970"
+source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/chunked-mv-lookup#8762701ab8fa04e7d243a346030afd85633ec970"
 dependencies = [
  "ecc",
  "halo2_proofs",
@@ -6234,7 +6146,7 @@ dependencies = [
 [[package]]
 name = "uniffi_testing"
 version = "0.28.0"
-source = "git+https://github.com/ElusAegis/uniffi-rs?branch=feat%2Ftesting-feature-build-fix#4684b9e7da2d9c964c2b3a73883947aab7370a06"
+source = "git+https://github.com/ElusAegis/uniffi-rs?branch=feat/testing-feature-build-fix#4684b9e7da2d9c964c2b3a73883947aab7370a06"
 dependencies = [
  "anyhow",
  "camino",
Cargo.toml (29 lines changed)
@@ -3,7 +3,7 @@ cargo-features = ["profile-rustflags"]
 [package]
 name = "ezkl"
 version = "0.0.0"
-edition = "2024"
+edition = "2021"
 default-run = "ezkl"

 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -40,6 +40,7 @@ maybe-rayon = { version = "0.1.1", default-features = false }
 bincode = { version = "1.3.3", default-features = false }
 unzip-n = "0.1.2"
 num = "0.4.1"
+portable-atomic = { version = "1.6.0", optional = true }
 tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand", optional = true }
 semver = { version = "1.0.22", optional = true }

@@ -73,6 +74,7 @@ tokio-postgres = { version = "0.7.10", optional = true }
 pg_bigdecimal = { version = "0.1.5", optional = true }
 lazy_static = { version = "1.4.0", optional = true }
 colored_json = { version = "3.0.1", default-features = false, optional = true }
+regex = { version = "1", default-features = false, optional = true }
 tokio = { version = "1.35.0", default-features = false, features = [
     "macros",
     "rt-multi-thread",
@@ -89,6 +91,7 @@ pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", ver
 pyo3-log = { version = "0.12.0", default-features = false, optional = true }
 tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
 tabled = { version = "0.12.0", optional = true }
 metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
 objc = { version = "0.2.4", optional = true }
 mimalloc = { version = "0.1", optional = true }
 pyo3-stub-gen = { version = "0.6.0", optional = true }
@@ -242,14 +245,16 @@ ezkl = [
     "dep:indicatif",
     "dep:gag",
     "dep:reqwest",
-    "dep:openssl",
     "dep:tokio-postgres",
     "dep:pg_bigdecimal",
     "dep:lazy_static",
+    "dep:regex",
     "dep:tokio",
+    "dep:openssl",
     "dep:mimalloc",
     "dep:chrono",
     "dep:sha256",
+    "dep:portable-atomic",
     "dep:clap_complete",
     "dep:halo2_solidity_verifier",
     "dep:semver",
@@ -272,15 +277,13 @@ icicle = ["halo2_proofs/icicle_gpu"]
 empty-cmd = []
 no-banner = []
 no-update = []
-macos-metal = ["halo2_proofs/macos"]
-ios-metal = ["halo2_proofs/ios"]


 [patch.'https://github.com/zkonduit/halo2']
-halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }
+halo2_proofs = { git = "https://github.com/zkonduit/halo2#ee4e1a09ebdb1f79f797685b78951c6034c430a6", package = "halo2_proofs" }

 [patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
-halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }
+halo2_proofs = { git = "https://github.com/zkonduit/halo2#ee4e1a09ebdb1f79f797685b78951c6034c430a6", package = "halo2_proofs" }

 [patch.crates-io]
 uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
@@ -289,12 +292,12 @@ uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "fea
 rustflags = ["-C", "relocation-model=pic"]
 lto = "fat"
 codegen-units = 1
-#panic = "abort"
+# panic = "abort"

 [profile.test-runs]
 inherits = "dev"
 opt-level = 3

 [package.metadata.wasm-pack.profile.release]
-wasm-opt = ["-O4", "--flexible-inline-max-function-size", "4294967295"]
+wasm-opt = [
+    "-O4",
+    "--flexible-inline-max-function-size",
+    "4294967295",
+]
@@ -150,13 +150,6 @@ Ezkl is unaudited, beta software undergoing rapid development. There may be bugs

 > NOTE: Because operations are quantized when they are converted from an onnx file to a zk-circuit, outputs in python and ezkl may differ slightly.


-### Advanced security topics

-Check out `docs/advanced_security` for more advanced information on potential threat vectors.


 ### no warranty

 Copyright (c) 2024 Zkonduit Inc. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
@@ -1,312 +0,0 @@
[
  { "inputs": [ { "internalType": "address", "name": "_contractAddresses", "type": "address" }, { "internalType": "bytes", "name": "_callData", "type": "bytes" }, { "internalType": "uint256[]", "name": "_decimals", "type": "uint256[]" }, { "internalType": "uint256[]", "name": "_bits", "type": "uint256[]" }, { "internalType": "uint8", "name": "_instanceOffset", "type": "uint8" } ], "stateMutability": "nonpayable", "type": "constructor" },
  { "inputs": [], "name": "HALF_ORDER", "outputs": [ { "internalType": "uint256", "name": "", "type": "uint256" } ], "stateMutability": "view", "type": "function" },
  { "inputs": [], "name": "ORDER", "outputs": [ { "internalType": "uint256", "name": "", "type": "uint256" } ], "stateMutability": "view", "type": "function" },
  { "inputs": [ { "internalType": "uint256[]", "name": "instances", "type": "uint256[]" } ], "name": "attestData", "outputs": [], "stateMutability": "view", "type": "function" },
  { "inputs": [], "name": "callData", "outputs": [ { "internalType": "bytes", "name": "", "type": "bytes" } ], "stateMutability": "view", "type": "function" },
  { "inputs": [], "name": "contractAddress", "outputs": [ { "internalType": "address", "name": "", "type": "address" } ], "stateMutability": "view", "type": "function" },
  { "inputs": [ { "internalType": "bytes", "name": "encoded", "type": "bytes" } ], "name": "getInstancesCalldata", "outputs": [ { "internalType": "uint256[]", "name": "instances", "type": "uint256[]" } ], "stateMutability": "pure", "type": "function" },
  { "inputs": [ { "internalType": "bytes", "name": "encoded", "type": "bytes" } ], "name": "getInstancesMemory", "outputs": [ { "internalType": "uint256[]", "name": "instances", "type": "uint256[]" } ], "stateMutability": "pure", "type": "function" },
  { "inputs": [ { "internalType": "uint256", "name": "index", "type": "uint256" } ], "name": "getScalars", "outputs": [ { "components": [ { "internalType": "uint256", "name": "decimals", "type": "uint256" }, { "internalType": "uint256", "name": "bits", "type": "uint256" } ], "internalType": "struct DataAttestation.Scalars", "name": "", "type": "tuple" } ], "stateMutability": "view", "type": "function" },
  { "inputs": [], "name": "instanceOffset", "outputs": [ { "internalType": "uint8", "name": "", "type": "uint8" } ], "stateMutability": "view", "type": "function" },
  { "inputs": [ { "internalType": "uint256", "name": "x", "type": "uint256" }, { "internalType": "uint256", "name": "y", "type": "uint256" }, { "internalType": "uint256", "name": "denominator", "type": "uint256" } ], "name": "mulDiv", "outputs": [ { "internalType": "uint256", "name": "result", "type": "uint256" } ], "stateMutability": "pure", "type": "function" },
  { "inputs": [ { "internalType": "int256", "name": "x", "type": "int256" }, { "components": [ { "internalType": "uint256", "name": "decimals", "type": "uint256" }, { "internalType": "uint256", "name": "bits", "type": "uint256" } ], "internalType": "struct DataAttestation.Scalars", "name": "_scalars", "type": "tuple" } ], "name": "quantizeData", "outputs": [ { "internalType": "int256", "name": "quantized_data", "type": "int256" } ], "stateMutability": "pure", "type": "function" },
  { "inputs": [ { "internalType": "address", "name": "target", "type": "address" }, { "internalType": "bytes", "name": "data", "type": "bytes" } ], "name": "staticCall", "outputs": [ { "internalType": "bytes", "name": "", "type": "bytes" } ], "stateMutability": "view", "type": "function" },
  { "inputs": [ { "internalType": "int256", "name": "x", "type": "int256" } ], "name": "toFieldElement", "outputs": [ { "internalType": "uint256", "name": "field_element", "type": "uint256" } ], "stateMutability": "pure", "type": "function" },
  { "inputs": [ { "internalType": "address", "name": "verifier", "type": "address" }, { "internalType": "bytes", "name": "encoded", "type": "bytes" } ], "name": "verifyWithDataAttestation", "outputs": [ { "internalType": "bool", "name": "", "type": "bool" } ], "stateMutability": "view", "type": "function" }
]
167 abis/DataAttestationMulti.json Normal file
@@ -0,0 +1,167 @@
[
  { "inputs": [ { "internalType": "address[]", "name": "_contractAddresses", "type": "address[]" }, { "internalType": "bytes[][]", "name": "_callData", "type": "bytes[][]" }, { "internalType": "uint256[][]", "name": "_decimals", "type": "uint256[][]" }, { "internalType": "uint256[]", "name": "_scales", "type": "uint256[]" }, { "internalType": "uint8", "name": "_instanceOffset", "type": "uint8" }, { "internalType": "address", "name": "_admin", "type": "address" } ], "stateMutability": "nonpayable", "type": "constructor" },
  { "inputs": [ { "internalType": "uint256", "name": "", "type": "uint256" } ], "name": "accountCalls", "outputs": [ { "internalType": "address", "name": "contractAddress", "type": "address" }, { "internalType": "uint256", "name": "callCount", "type": "uint256" } ], "stateMutability": "view", "type": "function" },
  { "inputs": [], "name": "admin", "outputs": [ { "internalType": "address", "name": "", "type": "address" } ], "stateMutability": "view", "type": "function" },
  { "inputs": [], "name": "instanceOffset", "outputs": [ { "internalType": "uint8", "name": "", "type": "uint8" } ], "stateMutability": "view", "type": "function" },
  { "inputs": [ { "internalType": "uint256", "name": "", "type": "uint256" } ], "name": "scales", "outputs": [ { "internalType": "uint256", "name": "", "type": "uint256" } ], "stateMutability": "view", "type": "function" },
  { "inputs": [ { "internalType": "address[]", "name": "_contractAddresses", "type": "address[]" }, { "internalType": "bytes[][]", "name": "_callData", "type": "bytes[][]" }, { "internalType": "uint256[][]", "name": "_decimals", "type": "uint256[][]" } ], "name": "updateAccountCalls", "outputs": [], "stateMutability": "nonpayable", "type": "function" },
  { "inputs": [ { "internalType": "address", "name": "_admin", "type": "address" } ], "name": "updateAdmin", "outputs": [], "stateMutability": "nonpayable", "type": "function" },
  { "inputs": [ { "internalType": "address", "name": "verifier", "type": "address" }, { "internalType": "bytes", "name": "encoded", "type": "bytes" } ], "name": "verifyWithDataAttestation", "outputs": [ { "internalType": "bool", "name": "", "type": "bool" } ], "stateMutability": "view", "type": "function" }
]
147 abis/DataAttestationSingle.json Normal file
@@ -0,0 +1,147 @@
[
  { "inputs": [ { "internalType": "address", "name": "_contractAddresses", "type": "address" }, { "internalType": "bytes", "name": "_callData", "type": "bytes" }, { "internalType": "uint256", "name": "_decimals", "type": "uint256" }, { "internalType": "uint256[]", "name": "_scales", "type": "uint256[]" }, { "internalType": "uint8", "name": "_instanceOffset", "type": "uint8" }, { "internalType": "address", "name": "_admin", "type": "address" } ], "stateMutability": "nonpayable", "type": "constructor" },
  { "inputs": [], "name": "accountCall", "outputs": [ { "internalType": "address", "name": "contractAddress", "type": "address" }, { "internalType": "bytes", "name": "callData", "type": "bytes" }, { "internalType": "uint256", "name": "decimals", "type": "uint256" } ], "stateMutability": "view", "type": "function" },
  { "inputs": [], "name": "admin", "outputs": [ { "internalType": "address", "name": "", "type": "address" } ], "stateMutability": "view", "type": "function" },
  { "inputs": [], "name": "instanceOffset", "outputs": [ { "internalType": "uint8", "name": "", "type": "uint8" } ], "stateMutability": "view", "type": "function" },
  { "inputs": [ { "internalType": "address", "name": "_contractAddresses", "type": "address" }, { "internalType": "bytes", "name": "_callData", "type": "bytes" }, { "internalType": "uint256", "name": "_decimals", "type": "uint256" } ], "name": "updateAccountCalls", "outputs": [], "stateMutability": "nonpayable", "type": "function" },
  { "inputs": [ { "internalType": "address", "name": "_admin", "type": "address" } ], "name": "updateAdmin", "outputs": [], "stateMutability": "nonpayable", "type": "function" },
  { "inputs": [ { "internalType": "address", "name": "verifier", "type": "address" }, { "internalType": "bytes", "name": "encoded", "type": "bytes" } ], "name": "verifyWithDataAttestation", "outputs": [ { "internalType": "bool", "name": "", "type": "bool" } ], "stateMutability": "view", "type": "function" }
]
@@ -73,8 +73,6 @@ impl Circuit<Fr> for MyCircuit {
                    padding: vec![(0, 0)],
                    stride: vec![1; 2],
                    group: 1,
                    data_format: DataFormat::NCHW,
                    kernel_format: KernelFormat::OIHW,
                }),
            )
            .unwrap();

@@ -69,7 +69,6 @@ impl Circuit<Fr> for MyCircuit {
                    stride: vec![1, 1],
                    kernel_shape: vec![2, 2],
                    normalized: false,
                    data_format: DataFormat::NCHW,
                }),
            )
            .unwrap();
@@ -23,6 +23,8 @@ use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;

const L: usize = 10;

#[derive(Clone, Debug)]
struct MyCircuit {
    image: ValTensor<Fr>,
@@ -38,7 +40,7 @@ impl Circuit<Fr> for MyCircuit {
    }

    fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
        PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::configure(cs, ())
        PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, 10>::configure(cs, ())
    }

    fn synthesize(
@@ -46,7 +48,7 @@ impl Circuit<Fr> for MyCircuit {
        config: Self::Config,
        mut layouter: impl Layouter<Fr>,
    ) -> Result<(), Error> {
        let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE> =
        let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L> =
            PoseidonChip::new(config);
        chip.layout(&mut layouter, &[self.image.clone()], 0, &mut HashMap::new())?;
        Ok(())
@@ -57,7 +59,7 @@ fn runposeidon(c: &mut Criterion) {
    let mut group = c.benchmark_group("poseidon");

    for size in [64, 784, 2352, 12288].iter() {
        let k = (PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::num_rows(*size)
        let k = (PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L>::num_rows(*size)
            as f32)
            .log2()
            .ceil() as u32;
@@ -65,7 +67,7 @@ fn runposeidon(c: &mut Criterion) {

        let message = (0..*size).map(|_| Fr::random(OsRng)).collect::<Vec<_>>();
        let _output =
            PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.to_vec())
            PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L>::run(message.to_vec())
                .unwrap();

        let mut image = Tensor::from(message.into_iter().map(Value::known));
@@ -10,7 +10,6 @@ use rand::Rng;

// Assuming these are your types
#[derive(Clone)]
#[allow(dead_code)]
enum ValType {
    Constant(F),
    AssignedConstant(usize, F),
@@ -22,7 +21,7 @@ fn generate_test_data(size: usize, zero_probability: f64) -> Vec<ValType> {
    let mut rng = rand::thread_rng();
    (0..size)
        .map(|_i| {
            if rng.r#gen::<f64>() < zero_probability {
            if rng.gen::<f64>() < zero_probability {
                ValType::Constant(F::ZERO)
            } else {
                ValType::Constant(F::ONE) // Or some other non-zero value
@@ -8,27 +8,21 @@ contract LoadInstances {
     */
    function getInstancesMemory(
        bytes memory encoded
    ) public pure returns (uint256[] memory instances) {
    ) internal pure returns (uint256[] memory instances) {
        bytes4 funcSig;
        uint256 instances_offset;
        uint256 instances_length;
        assembly {
            // fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
            funcSig := mload(add(encoded, 0x20))
        }
        if (funcSig == 0xaf83a18d) {
            instances_offset = 0x64;
        } else if (funcSig == 0x1e8e1e13) {
            instances_offset = 0x44;
        } else {
            revert("Invalid function signature");
        }
        assembly {
            // Fetch instances offset which is 4 + 32 + 32 bytes away from
            // start of encoded for `verifyProof(bytes,uint256[])`,
            // and 4 + 32 + 32 + 32 away for `verifyProof(address,bytes,uint256[])`
            instances_offset := mload(add(encoded, instances_offset))
            instances_offset := mload(
                add(encoded, add(0x44, mul(0x20, eq(funcSig, 0xaf83a18d))))
            )

            instances_length := mload(add(add(encoded, 0x24), instances_offset))
        }
@@ -47,10 +41,6 @@ contract LoadInstances {
            )
        }
    }
        require(
            funcSig == 0xaf83a18d || funcSig == 0x1e8e1e13,
            "Invalid function signature"
        );
    }
    /**
     * @dev Parse the instances array from the Halo2Verifier encoded calldata.
@@ -59,31 +49,23 @@ contract LoadInstances {
     */
    function getInstancesCalldata(
        bytes calldata encoded
    ) public pure returns (uint256[] memory instances) {
    ) internal pure returns (uint256[] memory instances) {
        bytes4 funcSig;
        uint256 instances_offset;
        uint256 instances_length;
        assembly {
            // fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
            funcSig := calldataload(encoded.offset)
        }
        if (funcSig == 0xaf83a18d) {
            instances_offset = 0x44;
        } else if (funcSig == 0x1e8e1e13) {
            instances_offset = 0x24;
        } else {
            revert("Invalid function signature");
        }
        // We need to create a new assembly block in order for solidity
        // to cast the funcSig to a bytes4 type. Otherwise it will load the entire first 32 bytes of the calldata
        // within the block
        assembly {
            // Fetch instances offset which is 4 + 32 + 32 bytes away from
            // start of encoded for `verifyProof(bytes,uint256[])`,
            // and 4 + 32 + 32 + 32 away for `verifyProof(address,bytes,uint256[])`
            instances_offset := calldataload(
                add(encoded.offset, instances_offset)
                add(
                    encoded.offset,
                    add(0x24, mul(0x20, eq(funcSig, 0xaf83a18d)))
                )
            )

            instances_length := calldataload(
@@ -114,7 +96,7 @@ contract LoadInstances {
// The kzg commitments of a given model, all aggregated into a single bytes array.
// At solidity generation time, the commitments are hardcoded into the contract via the COMMITMENT_KZG constant.
// It will be used to check that the proof commitments match the expected commitments.
bytes constant COMMITMENT_KZG = hex"1234";
bytes constant COMMITMENT_KZG = hex"";

contract SwapProofCommitments {
    /**
@@ -131,20 +113,17 @@ contract SwapProofCommitments {
        assembly {
            // fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
            funcSig := calldataload(encoded.offset)
        }
        if (funcSig == 0xaf83a18d) {
            proof_offset = 0x24;
        } else if (funcSig == 0x1e8e1e13) {
            proof_offset = 0x04;
        } else {
            revert("Invalid function signature");
        }
        assembly {
            // Fetch proof offset which is 4 + 32 bytes away from
            // start of encoded for `verifyProof(bytes,uint256[])`,
            // and 4 + 32 + 32 away for `verifyProof(address,bytes,uint256[])`
            proof_offset := calldataload(add(encoded.offset, proof_offset))
            proof_offset := calldataload(
                add(
                    encoded.offset,
                    add(0x04, mul(0x20, eq(funcSig, 0xaf83a18d)))
                )
            )

            proof_length := calldataload(
                add(add(encoded.offset, 0x04), proof_offset)
@@ -175,7 +154,7 @@ contract SwapProofCommitments {
                let wordCommitment := mload(add(commitment, i))
                equal := eq(wordProof, wordCommitment)
                if eq(equal, 0) {
                    break
                    return(0, 0)
                }
            }
        }
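The new `getInstances*` code above swaps the old if/else offset ladder for branchless arithmetic. A minimal Python sketch of the same selection, using the calldata-variant constants from the diff:

```python
# Branchless selection of the instances offset, mirroring
# add(0x24, mul(0x20, eq(funcSig, 0xaf83a18d))) from the assembly above.
def instances_offset(func_sig: int) -> int:
    # eq(...) is 1 for verifyProof(address,bytes,uint256[]) (0xaf83a18d)
    # and 0 for verifyProof(bytes,uint256[]) (0x1e8e1e13).
    return 0x24 + 0x20 * (func_sig == 0xaf83a18d)

assert instances_offset(0xaf83a18d) == 0x44  # extra 32-byte address argument
assert instances_offset(0x1e8e1e13) == 0x24
```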
@@ -184,38 +163,36 @@ contract SwapProofCommitments {
    } /// end checkKzgCommits
}

contract DataAttestation is LoadInstances, SwapProofCommitments {
    // the address of the account to make calls to
    address public immutable contractAddress;

    // the abi encoded function calls to make to the `contractAddress` that returns the attested to data
    bytes public callData;

    struct Scalars {
        // The number of base 10 decimals to scale the data by.
        // For most ERC20 tokens this is 1e18
contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
    /**
     * @notice Struct used to make a view-only call to an account to fetch the data that EZKL reads from.
     * @param contractAddress - the address of the account to make calls to
     * @param callData - the abi encoded function calls to make to the `contractAddress`
     */
    struct AccountCall {
        address contractAddress;
        bytes callData;
        uint256 decimals;
        // The number of fractional bits of the fixed point EZKL data points.
        uint256 bits;
    }
    AccountCall public accountCall;

    Scalars[] private scalars;
    uint[] scales;

    function getScalars(uint256 index) public view returns (Scalars memory) {
        return scalars[index];
    }
    address public admin;

    /**
     * @notice EZKL P value
     * @dev In order to prevent the verifier from accepting two versions of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are strictly less than P.
     * @dev The reason for this is that the assembly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
     */
    uint256 public constant ORDER =
    uint256 constant ORDER =
        uint256(
            0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
        );

    uint256 public constant HALF_ORDER = ORDER >> 1;
    uint256 constant INPUT_LEN = 0;

    uint256 constant OUTPUT_LEN = 0;

    uint8 public instanceOffset;

@@ -227,27 +204,53 @@ contract DataAttestation is LoadInstances, SwapProofCommitments {
    constructor(
        address _contractAddresses,
        bytes memory _callData,
        uint256[] memory _decimals,
        uint[] memory _bits,
        uint8 _instanceOffset
        uint256 _decimals,
        uint[] memory _scales,
        uint8 _instanceOffset,
        address _admin
    ) {
        require(
            _bits.length == _decimals.length,
            "Invalid scalar array lengths"
        );
        for (uint i; i < _bits.length; i++) {
            scalars.push(Scalars(10 ** _decimals[i], 1 << _bits[i]));
        admin = _admin;
        for (uint i; i < _scales.length; i++) {
            scales.push(1 << _scales[i]);
        }
        contractAddress = _contractAddresses;
        callData = _callData;
        populateAccountCalls(_contractAddresses, _callData, _decimals);
        instanceOffset = _instanceOffset;
    }

    function updateAdmin(address _admin) external {
        require(msg.sender == admin, "Only admin can update admin");
        if (_admin == address(0)) {
            revert();
        }
        admin = _admin;
    }

    function updateAccountCalls(
        address _contractAddresses,
        bytes memory _callData,
        uint256 _decimals
    ) external {
        require(msg.sender == admin, "Only admin can update account calls");
        populateAccountCalls(_contractAddresses, _callData, _decimals);
    }

    function populateAccountCalls(
        address _contractAddresses,
        bytes memory _callData,
        uint256 _decimals
    ) internal {
        AccountCall memory _accountCall = accountCall;
        _accountCall.contractAddress = _contractAddresses;
        _accountCall.callData = _callData;
        _accountCall.decimals = 10 ** _decimals;
        accountCall = _accountCall;
    }

    function mulDiv(
        uint256 x,
        uint256 y,
        uint256 denominator
    ) public pure returns (uint256 result) {
    ) internal pure returns (uint256 result) {
        unchecked {
            uint256 prod0;
            uint256 prod1;
@@ -295,28 +298,21 @@ contract DataAttestation is LoadInstances, SwapProofCommitments {
    /**
     * @dev Quantize the data returned from the account calls to the scale used by the EZKL model.
     * @param x - One of the elements of the data returned from the account calls
     * @param _scalars - The scaling factors for the data returned from the account calls.
     * @param _decimals - Number of base 10 decimals to scale the data by.
     * @param _scale - The base 2 scale used to convert the floating point value into a fixed point value.
     *
     */
    function quantizeData(
        int x,
        Scalars memory _scalars
    ) public pure returns (int256 quantized_data) {
        if (_scalars.bits == 1 && _scalars.decimals == 1) {
            return x;
        }
        uint256 _decimals,
        uint256 _scale
    ) internal pure returns (int256 quantized_data) {
        bool neg = x < 0;
        if (neg) x = -x;
        uint output = mulDiv(uint256(x), _scalars.bits, _scalars.decimals);
        if (
            mulmod(uint256(x), _scalars.bits, _scalars.decimals) * 2 >=
            _scalars.decimals
        ) {
        uint output = mulDiv(uint256(x), _scale, _decimals);
        if (mulmod(uint256(x), _scale, _decimals) * 2 >= _decimals) {
            output += 1;
        }
        if (output > HALF_ORDER) {
            revert("Overflow field modulus");
        }
        quantized_data = neg ? -int256(output) : int256(output);
    }
    /**
@@ -328,7 +324,7 @@ contract DataAttestation is LoadInstances, SwapProofCommitments {
    function staticCall(
        address target,
        bytes memory data
    ) public view returns (bytes memory) {
    ) internal view returns (bytes memory) {
        (bool success, bytes memory returndata) = target.staticcall(data);
        if (success) {
            if (returndata.length == 0) {
@@ -349,7 +345,7 @@ contract DataAttestation is LoadInstances, SwapProofCommitments {
     */
    function toFieldElement(
        int256 x
    ) public pure returns (uint256 field_element) {
    ) internal pure returns (uint256 field_element) {
        // The casting down to uint256 is safe because the order is about 2^254, and the value
        // of x ranges from -2^127 to 2^127, so x + int(ORDER) is always positive.
        return uint256(x + int(ORDER)) % ORDER;
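To make the rounding and field conversion concrete, here is a minimal Python sketch of the same arithmetic (Python's arbitrary-precision integers stand in for the contract's 512-bit `mulDiv`; the sample values are purely illustrative):

```python
ORDER = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
HALF_ORDER = ORDER >> 1

def quantize_data(x: int, decimals: int, scale: int) -> int:
    """Round-half-away-from-zero fixed point quantization, mirroring quantizeData."""
    neg = x < 0
    if neg:
        x = -x
    output = (x * scale) // decimals              # mulDiv(x, _scale, _decimals)
    if ((x * scale) % decimals) * 2 >= decimals:  # the mulmod rounding check
        output += 1
    assert output <= HALF_ORDER, "Overflow field modulus"
    return -output if neg else output

def to_field_element(x: int) -> int:
    """Map a signed quantized value into the field [0, ORDER)."""
    return (x + ORDER) % ORDER

# e.g. an on-chain value of 1.23 stored with 2 decimals (123), at scale 2^7:
q = quantize_data(-123, 10**2, 1 << 7)  # -157, since 1.23 * 128 = 157.44
print(to_field_element(q))              # ORDER - 157
```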
@@ -359,16 +355,315 @@ contract DataAttestation is LoadInstances, SwapProofCommitments {
     * @dev Make the account calls to fetch the data that EZKL reads from and attest to the data.
     * @param instances - The public instances to the proof (the data in the proof that is publicly accessible to the verifier).
     */
    function attestData(uint256[] memory instances) public view {
        bytes memory returnData = staticCall(contractAddress, callData);
    function attestData(uint256[] memory instances) internal view {
        require(
            instances.length >= INPUT_LEN + OUTPUT_LEN,
            "Invalid public inputs length"
        );
        AccountCall memory _accountCall = accountCall;
        uint[] memory _scales = scales;
        bytes memory returnData = staticCall(
            _accountCall.contractAddress,
            _accountCall.callData
        );
        int256[] memory x = abi.decode(returnData, (int256[]));
        int output;
        uint fieldElement;
        uint _offset;
        int output = quantizeData(x[0], _accountCall.decimals, _scales[0]);
        uint field_element = toFieldElement(output);
        for (uint i = 0; i < x.length; i++) {
            output = quantizeData(x[i], scalars[i]);
            fieldElement = toFieldElement(output);
            if (fieldElement != instances[i]) {
                revert("Public input does not match");
            if (field_element != instances[i + instanceOffset]) {
                _offset += 1;
            } else {
                break;
            }
        }
        uint length = x.length - _offset;
        for (uint i = 1; i < length; i++) {
            output = quantizeData(x[i], _accountCall.decimals, _scales[i]);
            field_element = toFieldElement(output);
            require(
                field_element == instances[i + instanceOffset + _offset],
                "Public input does not match"
            );
        }
    }

    /**
     * @dev Verify the proof with the data attestation.
     * @param verifier - The address of the verifier contract.
     * @param encoded - The verifier calldata.
     */
    function verifyWithDataAttestation(
        address verifier,
        bytes calldata encoded
    ) public view returns (bool) {
        require(verifier.code.length > 0, "Address: call to non-contract");
        attestData(getInstancesCalldata(encoded));
        // static call the verifier contract to verify the proof
        (bool success, bytes memory returndata) = verifier.staticcall(encoded);

        if (success) {
            return abi.decode(returndata, (bool));
        } else {
            revert("low-level call to verifier failed");
        }
    }
}

// This contract serves as a Data Attestation Verifier for the EZKL model.
// It is designed to read and attest to instances of proofs generated from a specified circuit.
// It is particularly constructed to read only int256 data from specified on-chain contracts' view functions.

// Overview of the contract functionality:
// 1. Initialization: Through the constructor, it sets up the contract calls that the EZKL model will read from.
// 2. Data Quantization: Quantizes the returned data into a scaled fixed-point representation. See the `quantizeData` method for details.
// 3. Static Calls: Makes static calls to fetch data from other contracts. See the `staticCall` method.
// 4. Field Element Conversion: The fixed-point representation is then converted into a field element modulo P using the `toFieldElement` method.
// 5. Data Attestation: The `attestData` method validates that the public instances match the data fetched and processed by the contract.
// 6. Proof Verification: The `verifyWithDataAttestation` method parses the instances out of the encoded calldata, calls the `attestData` method to validate the public instances,
//    then calls the `verifyProof` method to verify the proof on the verifier.
// 6b. Optional KZG Commitment Verification: It also checks the KZG commitments in the proof against the expected commitments using the `checkKzgCommits` method.

contract DataAttestationMulti is LoadInstances, SwapProofCommitments {
    /**
     * @notice Struct used to make view-only calls to accounts to fetch the data that EZKL reads from.
     * @param contractAddress - the address of the account to make calls to
     * @param callData - the abi encoded function calls to make to the `contractAddress`
     */
    struct AccountCall {
        address contractAddress;
        mapping(uint256 => bytes) callData;
        mapping(uint256 => uint256) decimals;
        uint callCount;
    }
    AccountCall[] public accountCalls;

    uint[] public scales;

    address public admin;

    /**
     * @notice EZKL P value
     * @dev In order to prevent the verifier from accepting two versions of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are strictly less than P.
     * @dev The reason for this is that the assembly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
     */
    uint256 constant ORDER =
        uint256(
            0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
        );

    uint256 constant INPUT_CALLS = 0;

    uint256 constant OUTPUT_CALLS = 0;

    uint8 public instanceOffset;

    /**
     * @dev Initialize the contract with account calls the EZKL model will read from.
     * @param _contractAddresses - The calls to all the contracts EZKL reads storage from.
     * @param _callData - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from.
     */
    constructor(
        address[] memory _contractAddresses,
        bytes[][] memory _callData,
        uint256[][] memory _decimals,
        uint[] memory _scales,
        uint8 _instanceOffset,
        address _admin
    ) {
        admin = _admin;
        for (uint i; i < _scales.length; i++) {
            scales.push(1 << _scales[i]);
        }
        populateAccountCalls(_contractAddresses, _callData, _decimals);
        instanceOffset = _instanceOffset;
    }

    function updateAdmin(address _admin) external {
        require(msg.sender == admin, "Only admin can update admin");
        if (_admin == address(0)) {
            revert();
        }
        admin = _admin;
    }

    function updateAccountCalls(
        address[] memory _contractAddresses,
        bytes[][] memory _callData,
        uint256[][] memory _decimals
    ) external {
        require(msg.sender == admin, "Only admin can update account calls");
        populateAccountCalls(_contractAddresses, _callData, _decimals);
    }

    function populateAccountCalls(
        address[] memory _contractAddresses,
        bytes[][] memory _callData,
        uint256[][] memory _decimals
    ) internal {
        require(
            _contractAddresses.length == _callData.length &&
                accountCalls.length == _contractAddresses.length,
            "Invalid input length"
        );
        require(
            _decimals.length == _contractAddresses.length,
            "Invalid number of decimals"
        );
        // fill in the accountCalls storage array
        uint counter = 0;
        for (uint256 i = 0; i < _contractAddresses.length; i++) {
            AccountCall storage accountCall = accountCalls[i];
            accountCall.contractAddress = _contractAddresses[i];
            accountCall.callCount = _callData[i].length;
            for (uint256 j = 0; j < _callData[i].length; j++) {
                accountCall.callData[j] = _callData[i][j];
                accountCall.decimals[j] = 10 ** _decimals[i][j];
            }
            // count the total number of storage reads across all of the accounts
            counter += _callData[i].length;
        }
        require(
            counter == INPUT_CALLS + OUTPUT_CALLS,
            "Invalid number of calls"
        );
    }

    function mulDiv(
        uint256 x,
        uint256 y,
        uint256 denominator
    ) internal pure returns (uint256 result) {
        unchecked {
            uint256 prod0;
            uint256 prod1;
            assembly {
                let mm := mulmod(x, y, not(0))
                prod0 := mul(x, y)
                prod1 := sub(sub(mm, prod0), lt(mm, prod0))
            }

            if (prod1 == 0) {
                return prod0 / denominator;
            }

            require(denominator > prod1, "Math: mulDiv overflow");

            uint256 remainder;
            assembly {
                remainder := mulmod(x, y, denominator)
                prod1 := sub(prod1, gt(remainder, prod0))
                prod0 := sub(prod0, remainder)
            }

            uint256 twos = denominator & (~denominator + 1);
            assembly {
                denominator := div(denominator, twos)
                prod0 := div(prod0, twos)
                twos := add(div(sub(0, twos), twos), 1)
            }

            prod0 |= prod1 * twos;

            uint256 inverse = (3 * denominator) ^ 2;

            inverse *= 2 - denominator * inverse;
            inverse *= 2 - denominator * inverse;
            inverse *= 2 - denominator * inverse;
            inverse *= 2 - denominator * inverse;
            inverse *= 2 - denominator * inverse;
            inverse *= 2 - denominator * inverse;

            result = prod0 * inverse;
            return result;
        }
    }
    /**
     * @dev Quantize the data returned from the account calls to the scale used by the EZKL model.
     * @param data - The data returned from the account calls.
     * @param decimals - The number of decimals the data returned from the account calls has (for floating point representation).
     * @param scale - The scale used to convert the floating point value into a fixed point value.
     */
    function quantizeData(
        bytes memory data,
        uint256 decimals,
        uint256 scale
    ) internal pure returns (int256 quantized_data) {
        int x = abi.decode(data, (int256));
        bool neg = x < 0;
        if (neg) x = -x;
        uint output = mulDiv(uint256(x), scale, decimals);
        if (mulmod(uint256(x), scale, decimals) * 2 >= decimals) {
            output += 1;
        }
        quantized_data = neg ? -int256(output) : int256(output);
    }
    /**
     * @dev Make a static call to the account to fetch the data that EZKL reads from.
     * @param target - The address of the account to make calls to.
     * @param data - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from.
     * @return The data returned from the account calls. (Must come from either a view or pure function. Will throw an error otherwise)
     */
    function staticCall(
        address target,
        bytes memory data
    ) internal view returns (bytes memory) {
        (bool success, bytes memory returndata) = target.staticcall(data);
        if (success) {
            if (returndata.length == 0) {
                require(
                    target.code.length > 0,
                    "Address: call to non-contract"
                );
            }
            return returndata;
        } else {
            revert("Address: low-level call failed");
        }
    }
    /**
     * @dev Convert the fixed point quantized data into a field element.
     * @param x - The quantized data.
     * @return field_element - The field element.
     */
    function toFieldElement(
        int256 x
    ) internal pure returns (uint256 field_element) {
        // The casting down to uint256 is safe because the order is about 2^254, and the value
        // of x ranges from -2^127 to 2^127, so x + int(ORDER) is always positive.
        return uint256(x + int(ORDER)) % ORDER;
    }

    /**
     * @dev Make the account calls to fetch the data that EZKL reads from and attest to the data.
     * @param instances - The public instances to the proof (the data in the proof that is publicly accessible to the verifier).
     */
    function attestData(uint256[] memory instances) internal view {
        require(
            instances.length >= INPUT_CALLS + OUTPUT_CALLS,
            "Invalid public inputs length"
        );
        uint256 _accountCount = accountCalls.length;
        uint counter = 0;
        for (uint8 i = 0; i < _accountCount; ++i) {
            address account = accountCalls[i].contractAddress;
            for (uint8 j = 0; j < accountCalls[i].callCount; j++) {
                bytes memory returnData = staticCall(
                    account,
                    accountCalls[i].callData[j]
                );
                uint256 scale = scales[counter];
                int256 quantized_data = quantizeData(
                    returnData,
                    accountCalls[i].decimals[j],
                    scale
                );
                uint256 field_element = toFieldElement(quantized_data);
                require(
                    field_element == instances[counter + instanceOffset],
                    "Public input does not match"
                );
                counter++;
            }
        }
    }
@@ -1,41 +0,0 @@
## EZKL Security Note: Public Commitments and Low-Entropy Data

> **Disclaimer:** this is a more technical post that requires some prior knowledge of how ZK proving systems like Halo2 operate, and in particular of how these APIs are constructed. For background reading we highly recommend the [Halo2 book](https://zcash.github.io/halo2/) and [Halo2 Club](https://halo2.club/).

## Overview of commitments in EZKL

A common design pattern in a zero knowledge (zk) application is thus:
- A prover has some data which is used within a circuit.
- This data, as it may be high-dimensional or somewhat private, is pre-committed to using some hash function.
- The zk-circuit which forms the core of the application then proves (paraphrasing) a statement of the form:
> "I know some data D which, when hashed, corresponds to the pre-committed value H + whatever else the circuit is proving over D."

From our own experience, we've implemented such patterns using snark-friendly hash functions like [Poseidon](https://www.poseidon-hash.info/), for which there is a relatively well vetted [implementation](https://docs.rs/halo2_gadgets/latest/halo2_gadgets/poseidon/index.html) in Halo2. Even then, these hash functions can introduce lots of overhead and can be very expensive to generate proofs for if the dimensionality of the data D is large.

You can also implement such a pattern using Halo2's `Fixed` columns _if the privacy preservation of the pre-image is not necessary_. These are Halo2 columns (i.e. in reality just polynomials) that are left unblinded (unlike the blinded `Advice` columns), and whose commitments are shared with the verifier by way of the verifying key for the application's zk-circuit. These commitments are much lower cost to generate than implementing a hashing function, such as Poseidon, within a circuit.

> **Note:** Blinding is the process whereby a certain set of the final elements (i.e. rows) of a Halo2 column are set to random field elements. This is the mechanism by which Halo2 achieves its zero knowledge properties for `Advice` columns. By contrast, `Fixed` columns aren't zero-knowledge in that they are vulnerable to dictionary attacks in the same manner a hash function is. Given some set of known or popular data D, an attacker can attempt to recover the pre-image of a hash by running D through the hash function to see if the outputs match a public commitment. These attacks aren't "possible" on blinded `Advice` columns.
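As a concrete illustration, here is a minimal Python sketch of such a dictionary attack against an unsalted commitment (the `commit` function and the candidate set are hypothetical stand-ins for Poseidon or an unblinded column commitment; any deterministic, unsalted commitment is attackable the same way):

```python
import hashlib

def commit(data: bytes) -> bytes:
    # Hypothetical stand-in for the real commitment scheme.
    return hashlib.sha256(data).digest()

public_commitment = commit(b"alice 1990-01-01")  # what the attacker observes

# The attacker enumerates a low-entropy candidate set D...
candidates = [f"{name} {dob}".encode()
              for name in ("alice", "bob", "carol")
              for dob in ("1990-01-01", "1991-02-02")]

# ...and simply re-runs the commitment until one matches.
for d in candidates:
    if commit(d) == public_commitment:
        print("pre-image recovered:", d)
        break
```

The defense is exactly what the note above describes: blinding (or salting), so that equal pre-images no longer produce equal commitments.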
> **Further Note:** Without blinding, with access to `M` proofs, each of which contains an evaluation of the polynomial at a different point, an attacker can more easily recover a non-blinded column's pre-image. This is because each proof generates a new query and evaluation of the polynomial represented by the column, and as such, with repetition, a clearer picture emerges of the column's pre-image. Thus unblinded columns should only be used for privacy preservation, in the manner of a hash, if the number of proofs generated against a fixed set of values is limited. More formally, if M independent and _unique_ queries are generated, and M is equal to the degree + 1 of the polynomial represented by the column (i.e. the unique Lagrange interpolation of the values in the column), then the column's pre-image can be recovered. As such, as the logrows K increases, more queries are required to recover the pre-image (as 2^K unique queries are required). This assumes that the entries in the column are not structured; if they are, the number of queries required to recover the pre-image is reduced (e.g. if all rows above a certain point are known to be nil).
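The "degree + 1" bound is just polynomial interpolation. A column over $2^K$ rows encodes a polynomial $p(X)$ of degree $d = 2^K - 1$; given $M = d + 1$ evaluations at distinct points $x_0, \dots, x_d$, the pre-image is fully determined by Lagrange interpolation:

$$p(X) = \sum_{i=0}^{d} p(x_i) \prod_{j \neq i} \frac{X - x_j}{x_i - x_j}$$

so each additional proof (one more evaluation) strictly narrows the space of columns consistent with what the attacker has seen.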
The annoyance in using `Fixed` columns comes from the fact that they require generating a new verifying key every time a new set of commitments is generated.

> **Example:** Say, for instance, an application leverages a zero-knowledge circuit to prove the correct execution of a neural network. Every week the neural network is finetuned or retrained on new data. If the architecture remains the same, then committing to the new network parameters, along with a new proof of performance on a test set, would be an ideal setup. If we leverage `Fixed` columns to commit to the model parameters, each new commitment will require re-generating a verifying key and sharing the new key with the verifier(s). This is not ideal UX and can become expensive if the verifier is deployed on-chain.

An ideal commitment would thus have the low cost of a `Fixed` column but wouldn't require regenerating a new verifying key for each new commitment.

### Unblinded Advice Columns

A first step in designing such a commitment is to allow for optionally unblinded `Advice` columns within the Halo2 API. These won't be included in the verifying key, AND are blinded with a constant factor `1` -- such that if someone knows the pre-image to the commitment, they can recover it by running it through the corresponding polynomial commitment scheme (in ezkl's case [KZG commitments](https://dankradfeist.de/ethereum/2020/06/16/kate-polynomial-commitments.html)).

This is implemented using the `polycommit` visibility parameter in the ezkl API.

## The Vulnerability of Public Commitments

Public commitments in EZKL (both Poseidon-hashed inputs and KZG commitments) can be vulnerable to brute-force attacks when input data has low entropy. A malicious actor could reveal committed data by searching through possible input values, compromising privacy in applications like anonymous credentials. This is particularly relevant when input data comes from known finite sets (e.g., names, dates).

Example Risk: In an anonymous credential system using EZKL for ID verification, an attacker could match hashed outputs against a database of common identifying information to deanonymize users.
@@ -1,54 +0,0 @@
# EZKL Security Note: Quantization-Activated Model Backdoors

## Model backdoors and provenance

Machine learning models inherently suffer from robustness issues, which can lead to various kinds of attacks, from backdoors to evasion attacks. These vulnerabilities are a direct byproduct of how machine learning models learn and cannot be remediated.

We say a model has a backdoor whenever a specific attacker-chosen trigger in the input leads to the model misbehaving. For instance, if we have an image classifier discriminating cats from dogs, the ability to turn any image of a cat into an image classified as a dog by changing a specific pixel pattern constitutes a backdoor.

Backdoors can be introduced using many different vectors. An attacker can introduce a backdoor using traditional security vulnerabilities. For instance, they could directly alter the file containing model weights or dynamically hack the Python code of the model. In addition, backdoors can be introduced by the training data through a process known as poisoning. In this case, an attacker adds malicious data points to the dataset before the model is trained so that the model learns to associate the backdoor trigger with the intended misbehavior.

All these vectors constitute a whole range of provenance challenges, as virtually any component of an AI system can be an entrypoint for a backdoor. Although provenance is already a concern with traditional code, the issue is exacerbated with AI, as retraining a model is cost-prohibitive. It is thus impractical to translate the "recompile it yourself" thinking to AI.

## Quantization-activated backdoors

Backdoors are a generic concern in AI that is outside the scope of EZKL. However, EZKL may activate a specific subset of backdoors. Several academic papers have demonstrated the possibility, both in theory and in practice, of implanting undetectable and inactive backdoors in a full-precision model that can be reactivated by quantization.

An external attacker may trick the user of an application running EZKL into loading a model containing a quantization backdoor. This backdoor is active in the resulting model and circuit but not in the full-precision model supplied to EZKL, compromising the integrity of the target application and the resulting proof.

### When is this a concern for me as a user?

Any untrusted component in your AI stack may be a backdoor vector. In practice, the most sensitive parts include:

- Datasets downloaded from the web or containing crowdsourced data
- Models downloaded from the web, even after finetuning
- Untrusted software dependencies (well-known frameworks such as PyTorch can typically be considered trusted)
- Any component loaded through an unsafe serialization format, such as Pickle.

Because backdoors are inherent to ML and cannot be eliminated, reviewing the provenance of these sensitive components is especially important.

### Responsibilities of the user and EZKL

As EZKL cannot prevent backdoored models from being used, it is the responsibility of the user to review the provenance of all the components in their AI stack to ensure that no backdoor could have been implanted. EZKL shall not be held responsible for misleading prediction proofs resulting from using a backdoored model or for any harm caused to a system or its users due to a misbehaving model.

### Limitations:

- Attack effectiveness depends on calibration settings and internal rescaling operations.
- Further research is needed on backdoor persistence through witness/proof stages.
- Can be mitigated by evaluating the quantized model (using `ezkl gen-witness`) rather than relying on the evaluation of the original model in pytorch or onnx-runtime, as a difference in evaluation could reveal a backdoor; see the sketch after this list.
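A minimal sketch of that mitigation via the Python bindings (the exact argument names of `ezkl.gen_witness`, the file paths, and the witness layout are assumptions; consult the ezkl docs for the current signature):

```python
import json
import ezkl  # the ezkl Python package

# 1. Run the quantized model inside ezkl to get the circuit's view of the output.
#    Paths are illustrative; the positional signature is assumed here.
ezkl.gen_witness("input.json", "network.compiled", "witness.json")

witness = json.load(open("witness.json"))
quantized_outputs = witness["outputs"]  # assumed witness layout

# 2. Compare against the float outputs produced by pytorch / onnx-runtime
#    on the same input (here assumed to be stored alongside the input data).
float_outputs = json.load(open("input.json"))["output_data"]

# A large gap between the two evaluations on the same input is a red flag
# that quantization changed the model's behavior (possibly a backdoor trigger).
print(quantized_outputs, float_outputs)
```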

References:

1. [Quantization Backdoors to Deep Learning Commercial Frameworks (Ma et al., 2021)](https://arxiv.org/abs/2108.09187)
2. [Planting Undetectable Backdoors in Machine Learning Models (Goldwasser et al., 2022)](https://arxiv.org/abs/2204.06974)
@@ -1,7 +1,7 @@
import ezkl

project = 'ezkl'
release = '21.0.0'
release = '17.1.5'
version = release

@@ -32,7 +32,6 @@ use mnist::*;
use rand::rngs::OsRng;
use std::marker::PhantomData;


mod params;

const K: usize = 20;
@@ -209,8 +208,6 @@ where
            padding: vec![(PADDING, PADDING); 2],
            stride: vec![STRIDE; 2],
            group: 1,
            data_format: DataFormat::NCHW,
            kernel_format: KernelFormat::OIHW,
        };
        let x = config
            .layer_config
File diff suppressed because it is too large
@@ -1,13 +0,0 @@
# download test data
# check if first argument has been set
if [ ! -z "$1" ]; then
    DATA_DIR=$1
else
    DATA_DIR=data
fi

echo "Downloading data to $DATA_DIR"

if [ ! -d "$DATA_DIR/CATDOG" ]; then
    kaggle datasets download tongpython/cat-and-dog -p $DATA_DIR/CATDOG --unzip
fi
@@ -272,21 +272,33 @@
"\n",
"- For file data sources, we store the raw floating point values that eventually get quantized, converted into field elements, and stored in `witness.json` to be consumed by the circuit. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestible by EZKL (aka field elements :-D). \n",
"Here is what the schema for an on-chain data source graph input file should look like for a single call data source:\n",
"Here is what the schema for an on-chain data source graph input file should look like:\n",
" \n",
"```json\n",
"{\n",
" \"input_data\": {\n",
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
" \"calls\": {\n",
" \"call_data\": \"1f3be514000000000000000000000000c6962004f452be9203591991d15f6b388e09e8d00000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000070000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns an array of on-chain data points we are attesting to. \n",
" \"decimals\": 0, // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
" \"address\": \"9A213F53334279C128C37DA962E5472eCD90554f\", // The address of the contract that we are calling to get the data. \n",
" \"len\": 12 // The number of data points returned by the view function (the length of the array)\n",
" }\n",
" }\n",
"}\n",
"```"
" \"input_data\": {\n",
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
" \"calls\": [\n",
" {\n",
" \"call_data\": [\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns a single on-chain data point (we only support uint256 returns for now)\n",
" 7 // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
" ],\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000001\",\n",
" 5\n",
" ],\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000002\",\n",
" 5\n",
" ]\n",
" ],\n",
" \"address\": \"5fbdb2315678afecb367f032d93f642f64180aa3\" // The address of the contract that we are calling to get the data. \n",
" }\n",
" ]\n",
" }\n",
"}"
]
},
{
@@ -295,7 +307,7 @@
"metadata": {},
"outputs": [],
"source": [
"await ezkl.setup_test_evm_data(\n",
"await ezkl.setup_test_evm_witness(\n",
" data_path,\n",
" compiled_model_path,\n",
" # we write the call data to the same file as the input data\n",

@@ -337,7 +337,6 @@
"w3 = Web3(HTTPProvider(RPC_URL))\n",
"\n",
"def test_on_chain_data(res):\n",
" print(f'poseidon_hash: {res[\"processed_outputs\"][\"poseidon_hash\"]}')\n",
" # Step 0: Convert the tensor to a flat list\n",
" data = [int(ezkl.felt_to_big_endian(res['processed_outputs']['poseidon_hash'][0]), 0)]\n",
"\n",
@@ -357,9 +356,6 @@
" arr.push(_numbers[i]);\n",
" }\n",
" }\n",
" function getArr() public view returns (uint[] memory) {\n",
" return arr;\n",
" }\n",
" }\n",
" '''\n",
"\n",
@@ -386,30 +382,31 @@
" contract = w3.eth.contract(address=tx_receipt['contractAddress'], abi=abi)\n",
"\n",
" # Step 4: Interact with the contract\n",
" calldata = contract.functions.getArr().build_transaction()['data'][2:]\n",
" calldata = []\n",
" for i, _ in enumerate(data):\n",
" call = contract.functions.arr(i).build_transaction()\n",
" calldata.append((call['data'][2:], 0))\n",
"\n",
" # Prepare the calls_to_account object\n",
" # If you were calling view functions across multiple contracts,\n",
" # you would have multiple entries in the calls_to_account array,\n",
" # one for each contract.\n",
" decimals = [0] * len(data)\n",
" call_to_account = {\n",
" calls_to_account = [{\n",
" 'call_data': calldata,\n",
" 'decimals': decimals,\n",
" 'address': contract.address[2:], # remove the '0x' prefix\n",
" }\n",
" }]\n",
"\n",
" print(f'call_to_account: {call_to_account}')\n",
" print(f'calls_to_account: {calls_to_account}')\n",
"\n",
" return call_to_account\n",
" return calls_to_account\n",
"\n",
"# Now let's start the Anvil process. You don't need to do this if you are deploying to a non-local chain.\n",
"start_anvil()\n",
"\n",
"# Now let's call our function, passing in the same input tensor we used to export the model 2 cells above.\n",
"call_to_account = test_on_chain_data(res)\n",
"calls_to_account = test_on_chain_data(res)\n",
"\n",
"data = dict(input_data = [data_array], output_data = {'rpc': RPC_URL, 'call': call_to_account })\n",
"data = dict(input_data = [data_array], output_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
"\n",
"# Serialize on-chain data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n",
@@ -637,7 +634,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"display_name": "ezkl",
"language": "python",
"name": "python3"
},
@@ -651,7 +648,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.7"
},
"orig_nbformat": 4
},

@@ -276,21 +276,33 @@
"\n",
"- For file data sources, the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit are stored. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestable by EZKL (aka field elements :-D). \n",
"Here is what the schema for an on-chain data source graph input file should look like for a single call data source:\n",
"Here is what the schema for an on-chain data source graph input file should look like:\n",
" \n",
"```json\n",
"{\n",
"  \"input_data\": {\n",
"    \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
"    \"calls\": {\n",
"      \"call_data\": \"1f3be514000000000000000000000000c6962004f452be9203591991d15f6b388e09e8d00000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000070000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns an array of on-chain data points we are attesting to. \n",
"      \"decimals\": 0, // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
"      \"address\": \"9A213F53334279C128C37DA962E5472eCD90554f\", // The address of the contract that we are calling to get the data. \n",
"      \"len\": 3 // The number of data points returned by the view function (the length of the array)\n",
"    }\n",
"  }\n",
"}\n",
"```"
"  \"input_data\": {\n",
"    \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
"    \"calls\": [\n",
"      {\n",
"        \"call_data\": [\n",
"            [\n",
"                \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns a single on-chain data point (we only support uint256 returns for now)\n",
"                7 // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
"            ],\n",
"            [\n",
"                \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000001\",\n",
"                5\n",
"            ],\n",
"            [\n",
"                \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000002\",\n",
"                5\n",
"            ]\n",
"        ],\n",
"        \"address\": \"5fbdb2315678afecb367f032d93f642f64180aa3\" // The address of the contract that we are calling to get the data. \n",
"      }\n",
"    ]\n",
"  }\n",
"}"
]
},
{
@@ -299,7 +311,7 @@
"metadata": {},
"outputs": [],
"source": [
"await ezkl.setup_test_evm_data(\n",
"await ezkl.setup_test_evm_witness(\n",
"    data_path,\n",
"    compiled_model_path,\n",
"    # we write the call data to the same file as the input data\n",
@@ -325,7 +337,7 @@
"metadata": {},
"outputs": [],
"source": [
"res = await ezkl.get_srs( settings_path)"
"res = await ezkl.get_srs( settings_path)\n"
]
},
{
@@ -336,6 +348,27 @@
"We now need to generate the circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!export RUST_BACKTRACE=1\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -358,27 +391,6 @@
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!export RUST_BACKTRACE=1\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"attachments": {},
"cell_type": "markdown",
@@ -569,7 +581,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"display_name": "ezkl",
"language": "python",
"name": "python3"
},
@@ -583,7 +595,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.9.13"
},
"orig_nbformat": 4
},

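The multi-call schema above can also be assembled programmatically rather than written by hand. Below is a minimal Python sketch, not part of the ezkl API: it assumes a contract whose view function takes a single uint256 index and returns one uint256 per call (selector `71e5ee5f`, as in the example), and it follows the per-call `[calldata, decimals]` pair layout shown in the schema.

```python
import json

# Minimal sketch: build an on-chain data source input file following the
# multi-call schema above. `build_onchain_input` is illustrative, not an
# ezkl helper; the selector and address come from the schema example.
def build_onchain_input(rpc_url, address, num_points, decimals):
    selector = "71e5ee5f"  # the view function's 4-byte selector, per the example
    call_data = [
        # 4-byte selector followed by the 32-byte big-endian index
        [selector + i.to_bytes(32, "big").hex(), decimals]
        for i in range(num_points)
    ]
    return {
        "input_data": {
            "rpc": rpc_url,
            "calls": [{
                "call_data": call_data,
                "address": address,  # no '0x' prefix, matching the schema
            }],
        }
    }

data = build_onchain_input(
    "http://localhost:3030",
    "5fbdb2315678afecb367f032d93f642f64180aa3",
    num_points=3,
    decimals=5,
)
json.dump(data, open("input.json", "w"))
```
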
@@ -77,7 +77,6 @@
"outputs": [],
"source": [
"gip_run_args = ezkl.PyRunArgs()\n",
"gip_run_args.ignore_range_check_inputs_outputs = True\n",
"gip_run_args.input_visibility = \"polycommit\"  # matrix and generalized inverse commitments\n",
"gip_run_args.output_visibility = \"fixed\"   # no parameters used\n",
"gip_run_args.param_visibility = \"fixed\"  # should be Tensor(True)"
@@ -336,9 +335,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}
@@ -308,11 +308,8 @@
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.decomp_legs = 4\n",
"\n",
"# Generate settings using ezkl\n",
"res = ezkl.gen_settings(onnx_filename, settings_filename, py_run_args=run_args)\n",
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
"\n",
"assert res == True\n",
"\n",

@@ -453,18 +453,18 @@
"outputs": [],
"source": [
"# now mock aggregate the proofs\n",
"# proofs = []\n",
"# for i in range(3):\n",
"#     proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
"#     proofs.append(proof_path)\n",
"proofs = []\n",
"for i in range(3):\n",
"    proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
"    proofs.append(proof_path)\n",
"\n",
"# ezkl.mock_aggregate(proofs, logrows=26, split_proofs = True)"
"ezkl.mock_aggregate(proofs, logrows=23, split_proofs = True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"display_name": "ezkl",
"language": "python",
"name": "python3"
},
@@ -478,7 +478,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.12.5"
},
"orig_nbformat": 4
},

@@ -152,11 +152,9 @@
"metadata": {},
"outputs": [],
"source": [
"run_args = ezkl.PyRunArgs()\n",
"# logrows\n",
"run_args.logrows = 20\n",
"\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n"
]
},
@@ -304,7 +302,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.12.2"
}
},
"nbformat": 4,

@@ -220,6 +220,15 @@
"Check that the generated verifiers are identical for all models."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"start_anvil()"
]
},
{
"cell_type": "code",
"execution_count": null,

@@ -167,8 +167,6 @@
"run_args = ezkl.PyRunArgs()\n",
"# \"hashed/private\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
"run_args.input_visibility = \"hashed/private/0\"\n",
"# as the inputs are felts we turn off input range checks\n",
"run_args.ignore_range_check_inputs_outputs = True\n",
"# we set it to fix the set we want to check membership for\n",
"run_args.param_visibility = \"fixed\"\n",
"# the output is public -- set membership fails if it is not = 0\n",
@@ -521,4 +519,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

@@ -204,7 +204,6 @@
"run_args = ezkl.PyRunArgs()\n",
"# \"polycommit\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
"run_args.input_visibility = \"polycommit\"\n",
"run_args.ignore_range_check_inputs_outputs = True\n",
"# the parameters are public\n",
"run_args.param_visibility = \"fixed\"\n",
"# the output is public (this is the inequality test)\n",
@@ -515,4 +514,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

@@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -60,7 +60,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -94,7 +94,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -134,7 +134,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
@@ -183,7 +183,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -201,7 +201,6 @@
"run_args.input_visibility = \"public\"\n",
"run_args.param_visibility = \"private\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.decomp_legs=5\n",
"run_args.num_inner_cols = 1\n",
"run_args.variables = [(\"batch_size\", 1)]"
]
@@ -270,7 +269,7 @@
"{\n",
"  \"input_data\": {\n",
"    \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
"    \"call\": {\n",
"    \"calls\": {\n",
"      \"call_data\": \"1f3be514000000000000000000000000c6962004f452be9203591991d15f6b388e09e8d00000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000070000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns an array of on-chain data points we are attesting to. \n",
"      \"decimals\": 0, // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
"      \"address\": \"9A213F53334279C128C37DA962E5472eCD90554f\", // The address of the contract that we are calling to get the data. \n",
@@ -295,6 +294,7 @@
"import torch\n",
"import requests\n",
"\n",
"# This function counts the decimal places of a floating point number\n",
"def count_decimal_places(num):\n",
"    num_str = str(num)\n",
"    if '.' in num_str:\n",

@@ -302,28 +302,69 @@
"    else:\n",
"        return 0\n",
"\n",
"# setup web3 instance\n",
"w3 = Web3(HTTPProvider(RPC_URL)) \n",
"\n",
"def set_next_block_timestamp(anvil_url, timestamp):\n",
"    # Send the JSON-RPC request to Anvil\n",
"    payload = {\n",
"        \"jsonrpc\": \"2.0\",\n",
"        \"id\": 1,\n",
"        \"method\": \"evm_setNextBlockTimestamp\",\n",
"        \"params\": [timestamp]\n",
"    }\n",
"    response = requests.post(anvil_url, json=payload)\n",
"    if response.status_code == 200:\n",
"        print(f\"Next block timestamp set to: {timestamp}\")\n",
"    else:\n",
"        print(f\"Failed to set next block timestamp: {response.text}\")\n",
"\n",
"def on_chain_data(tensor):\n",
"    # Step 0: Convert the tensor to a flat list\n",
"    data = tensor.view(-1).tolist()\n",
"\n",
"    # Step 1: Prepare the calldata\n",
"    secondsAgo = [len(data) - 1 - i for i in range(len(data))]\n",
"\n",
"    # Step 2: Prepare and compile the contract UniTickAttestor contract\n",
"    contract_source_code = '''\n",
"    // SPDX-License-Identifier: MIT\n",
"    pragma solidity ^0.8.20;\n",
"\n",
"    /// @title Pool state that is not stored\n",
"    /// @notice Contains view functions to provide information about the pool that is computed rather than stored on the\n",
"    /// blockchain. The functions here may have variable gas costs.\n",
"    interface IUniswapV3PoolDerivedState {\n",
"        /// @notice Returns the cumulative tick and liquidity as of each timestamp `secondsAgo` from the current block timestamp\n",
"        /// @dev To get a time weighted average tick or liquidity-in-range, you must call this with two values, one representing\n",
"        /// the beginning of the period and another for the end of the period. E.g., to get the last hour time-weighted average tick,\n",
"        /// you must call it with secondsAgos = [3600, 0].\n",
"        /// log base sqrt(1.0001) of token1 / token0. The TickMath library can be used to go from a tick value to a ratio.\n",
"        /// @dev The time weighted average tick represents the geometric time weighted average price of the pool, in\n",
"        /// @param secondsAgos From how long ago each cumulative tick and liquidity value should be returned\n",
"        /// @return tickCumulatives Cumulative tick values as of each `secondsAgos` from the current block timestamp\n",
"        /// @return secondsPerLiquidityCumulativeX128s Cumulative seconds per liquidity-in-range value as of each `secondsAgos` from the current block\n",
"        /// timestamp\n",
"        function observe(\n",
"            uint32[] calldata secondsAgos\n",
"        ) external view returns (\n",
"            int56[] memory tickCumulatives,\n",
"            uint160[] memory secondsPerLiquidityCumulativeX128s\n",
"        );\n",
"        )\n",
"            external\n",
"            view\n",
"            returns (\n",
"                int56[] memory tickCumulatives,\n",
"                uint160[] memory secondsPerLiquidityCumulativeX128s\n",
"            );\n",
"    }\n",
"\n",
"    /// @title Uniswap Wrapper around `pool.observe` that stores the parameters for fetching and then attesting to historical data\n",
"    /// @notice Provides functions to integrate with V3 pool oracle\n",
"    contract UniTickAttestor {\n",
"        int256[] private cachedTicks;\n",
"\n",
"        /**\n",
"        * @notice Calculates time-weighted means of tick and liquidity for a given Uniswap V3 pool\n",
"        * @param pool Address of the pool that we want to observe\n",
"        * @param secondsAgo Number of seconds in the past from which to calculate the time-weighted means\n",
"        * @return tickCumulatives The cumulative tick values as of each `secondsAgo` from the current block timestamp\n",
"        */\n",
"        function consult(\n",
"            IUniswapV3PoolDerivedState pool,\n",
"            uint32[] memory secondsAgo\n",
@@ -334,21 +375,6 @@
"                tickCumulatives[i] = int256(_ticks[i]);\n",
"            }\n",
"        }\n",
"\n",
"        function cache_price(\n",
"            IUniswapV3PoolDerivedState pool,\n",
"            uint32[] memory secondsAgo\n",
"        ) public {\n",
"            (int56[] memory _ticks,) = pool.observe(secondsAgo);\n",
"            cachedTicks = new int256[](_ticks.length);\n",
"            for (uint256 i = 0; i < _ticks.length; i++) {\n",
"                cachedTicks[i] = int256(_ticks[i]);\n",
"            }\n",
"        }\n",
"\n",
"        function readPriceCache() public view returns (int256[] memory) {\n",
"            return cachedTicks;\n",
"        }\n",
"    }\n",
"    '''\n",
"\n",

@@ -358,44 +384,69 @@
"        \"settings\": {\"outputSelection\": {\"*\": {\"*\": [\"metadata\", \"evm.bytecode\", \"abi\"]}}}\n",
"    })\n",
"\n",
"    # Get bytecode\n",
"    bytecode = compiled_sol['contracts']['UniTickAttestor.sol']['UniTickAttestor']['evm']['bytecode']['object']\n",
"\n",
"    # Get ABI\n",
"    # In production if you are reading from really large contracts you can just use\n",
"    # a stripped down version of the ABI of the contract you are calling, containing only the view functions you will fetch data from.\n",
"    abi = json.loads(compiled_sol['contracts']['UniTickAttestor.sol']['UniTickAttestor']['metadata'])['output']['abi']\n",
"\n",
"    # Deploy contract\n",
"    # Step 3: Deploy the contract\n",
"    UniTickAttestor = w3.eth.contract(abi=abi, bytecode=bytecode)\n",
"    tx_hash = UniTickAttestor.constructor().transact()\n",
"    tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)\n",
"    # If you are deploying to production you can skip the 3 lines of code above and just instantiate the contract like this,\n",
"    # passing the address and abi of the contract you are fetching data from.\n",
"    contract = w3.eth.contract(address=tx_receipt['contractAddress'], abi=abi)\n",
"\n",
"    # Step 4: Store data via cache_price transaction\n",
"    tx_hash = contract.functions.cache_price(\n",
"    # Step 4: Interact with the contract\n",
"    call = contract.functions.consult(\n",
"        # Address of the UniV3 usdc-weth pool 0.005 fee\n",
"        \"0xC6962004f452bE9203591991D15f6b388e09E8D0\",\n",
"        secondsAgo\n",
"    ).transact()\n",
"    tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)\n",
"\n",
"    # Step 5: Prepare calldata for readPriceCache\n",
"    call = contract.functions.readPriceCache().build_transaction()\n",
"    ).build_transaction()\n",
"    result = contract.functions.consult(\n",
"        # Address of the UniV3 usdc-weth pool 0.005 fee\n",
"        \"0xC6962004f452bE9203591991D15f6b388e09E8D0\",\n",
"        secondsAgo\n",
"    ).call()\n",
"    \n",
"    print(f'result: {result}')\n",
"    calldata = call['data'][2:]\n",
"\n",
"    # Get stored data\n",
"    result = contract.functions.readPriceCache().call()\n",
"    print(f'Cached ticks: {result}')\n",
"    time_stamp = w3.eth.get_block('latest')['timestamp']\n",
"\n",
"    decimals = [0] * len(data)\n",
"    print(f'time_stamp: {time_stamp}')\n",
"\n",
"    # Set the next block timestamp using the fetched time_stamp\n",
"    set_next_block_timestamp(RPC_URL, time_stamp)\n",
"\n",
"\n",
"    # Prepare the calls_to_account object\n",
"    # If you were calling view functions across multiple contracts,\n",
"    # you would have multiple entries in the calls_to_account array,\n",
"    # one for each contract.\n",
"    call_to_account = {\n",
"        'call_data': calldata,\n",
"        'decimals': decimals,\n",
"        'address': contract.address[2:],\n",
"        'decimals': 0,\n",
"        'address': contract.address[2:], # remove the '0x' prefix\n",
"        'len': len(data),\n",
"    }\n",
"\n",
"    print(f'call_to_account: {call_to_account}')\n",
"\n",
"    return call_to_account\n",
"\n",
"# Now let's start the Anvil process. You don't need to do this if you are deploying to a non-local chain.\n",
"start_anvil()\n",
"call_to_account = on_chain_data(x)\n",
"\n",
"data = dict(input_data = {'rpc': RPC_URL, 'call': call_to_account })\n",
"# Now let's call our function, passing in the same input tensor we used to export the model 2 cells above.\n",
"calls_to_account = on_chain_data(x)\n",
"\n",
"data = dict(input_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
"\n",
"# Serialize on-chain data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))"
]
},
@@ -640,7 +691,34 @@
"source": [
"# !export RUST_BACKTRACE=1\n",
"\n",
"# print(res)\n",
"calls_to_account = on_chain_data(x)\n",
"\n",
"data = dict(input_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
"\n",
"# Serialize on-chain data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n",
"\n",
"# setup web3 instance\n",
"w3 = Web3(HTTPProvider(RPC_URL)) \n",
"\n",
"time_stamp = w3.eth.get_block('latest')['timestamp']\n",
"\n",
"print(f'time_stamp: {time_stamp}')\n",
"\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"\n",
"res = ezkl.prove(\n",
"        witness_path,\n",
"        compiled_model_path,\n",
"        pk_path,\n",
"        proof_path,\n",
"        \"single\",\n",
"    )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)\n",
"# read the verifier address\n",
"addr_verifier = None\n",

@@ -246,7 +246,7 @@
"metadata": {},
"outputs": [],
"source": [
"ezkl.setup_test_evm_data(\n",
"ezkl.setup_test_evm_witness(\n",
"    data_path,\n",
"    compiled_model_path,\n",
"    # we write the call data to the same file as the input data\n",
@@ -374,6 +374,14 @@
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc888848",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
@@ -517,7 +525,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -531,7 +539,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.2"
}
},
"nbformat": 4,

@@ -1,106 +0,0 @@
{
"input_data": [
[
8761,
7654,
8501,
2404,
6929,
8858,
5946,
3673,
4131,
3854,
8137,
8239,
9038,
6299,
1118,
9737,
208,
7954,
3691,
610,
3468,
3314,
8658,
8366,
2850,
477,
6114,
232,
4601,
7420,
5713,
2936,
6061,
2870,
8421,
177,
7107,
7382,
6115,
5487,
8502,
2559,
1875,
129,
8533,
8201,
8414,
4775,
9817,
3127,
8761,
7654,
8501,
2404,
6929,
8858,
5946,
3673,
4131,
3854,
8137,
8239,
9038,
6299,
1118,
9737,
208,
7954,
3691,
610,
3468,
3314,
8658,
8366,
2850,
477,
6114,
232,
4601,
7420,
5713,
2936,
6061,
2870,
8421,
177,
7107,
7382,
6115,
5487,
8502,
2559,
1875,
129,
8533,
8201,
8414,
4775,
9817,
3127
]
]
}
Binary file not shown.
@@ -1 +0,0 @@
{"run_args":{"input_scale":7,"param_scale":7,"scale_rebase_multiplier":1,"lookup_range":[-32768,32768],"logrows":17,"num_inner_cols":2,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private","rebase_frac_zero_constants":false,"check_mode":"UNSAFE","commitment":"KZG","decomp_base":16384,"decomp_legs":2,"bounded_log_lookup":false,"ignore_range_check_inputs_outputs":false},"num_rows":54,"total_assignments":109,"total_const_size":4,"total_dynamic_col_size":0,"max_dynamic_input_len":0,"num_dynamic_lookups":0,"num_shuffles":0,"total_shuffle_col_size":0,"model_instance_shapes":[[1,1]],"model_output_scales":[7],"model_input_scales":[7],"module_sizes":{"polycommit":[],"poseidon":[0,[0]]},"required_lookups":[],"required_range_checks":[[-1,1],[0,16383]],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null,"timestamp":1739396322131,"input_types":["F32"],"output_types":["F32"]}
File diff suppressed because one or more lines are too long
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.
@@ -1,42 +0,0 @@
from torch import nn
import torch
import json
import numpy as np


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, x):
        return x // 3


circuit = MyModel()

x = torch.randint(0, 10, (1, 2, 2, 8))

out = circuit(x)

print(x)
print(out)
print(x/3)

torch.onnx.export(circuit, x, "network.onnx",
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=17,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                'output': {0: 'batch_size'}})


d1 = ((x).detach().numpy()).reshape([-1]).tolist()

data = dict(
    input_data=[d1],
)

# Serialize data into file:
json.dump(data, open("input.json", 'w'))

@@ -1 +0,0 @@
{"input_data": [[3, 4, 0, 9, 2, 6, 2, 5, 1, 5, 3, 5, 5, 7, 0, 2, 6, 1, 4, 4, 1, 9, 7, 7, 5, 8, 2, 0, 1, 5, 9, 8]]}
Binary file not shown.
4 ezkl.pyi
@@ -706,9 +706,9 @@ def setup_aggregate(sample_snarks:typing.Sequence[str | os.PathLike | pathlib.Pa
"""
...

def setup_test_evm_data(data_path:str | os.PathLike | pathlib.Path,compiled_circuit_path:str | os.PathLike | pathlib.Path,test_data:str | os.PathLike | pathlib.Path,input_source:PyTestDataSource,output_source:PyTestDataSource,rpc_url:typing.Optional[str]) -> typing.Any:
def setup_test_evm_witness(data_path:str | os.PathLike | pathlib.Path,compiled_circuit_path:str | os.PathLike | pathlib.Path,test_data:str | os.PathLike | pathlib.Path,input_source:PyTestDataSource,output_source:PyTestDataSource,rpc_url:typing.Optional[str]) -> typing.Any:
r"""
Setup test evm data
Setup test evm witness

Arguments
---------

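For reference, a usage sketch of the renamed binding as it might appear in a notebook cell. The positional arguments follow the signature above; the `PyTestDataSource` variant names used here are assumptions for illustration, not checked against the ezkl source.

```python
# Hypothetical usage sketch of setup_test_evm_witness, mirroring the
# signature above. PyTestDataSource variant names are assumed, not verified.
import ezkl

res = await ezkl.setup_test_evm_witness(
    "input.json",           # data_path: input data (and where call data is written)
    "network.compiled",     # compiled_circuit_path
    "input.json",           # test_data: file the formatted test data is written to
    ezkl.PyTestDataSource.OnChain,  # input_source (assumed variant name)
    ezkl.PyTestDataSource.File,     # output_source (assumed variant name)
    "http://localhost:3030",        # rpc_url, e.g. a local Anvil node
)
```
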
@@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2025-02-17"
channel = "nightly-2024-07-18"
components = ["rustfmt", "clippy"]

@@ -1,11 +1,7 @@
// ignore file if compiling for wasm

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc::MiMalloc;

#[global_allocator]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
static GLOBAL: MiMalloc = MiMalloc;
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::{CommandFactory, Parser};
@@ -28,8 +24,6 @@ use std::env;
#[tokio::main(flavor = "current_thread")]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub async fn main() {
use log::debug;

let args = Cli::parse();

if let Some(generator) = args.generator {
@@ -44,7 +38,7 @@ pub async fn main() {
} else {
info!("Running with CPU");
}
debug!(
info!(
"command: \n {}",
&command.as_json().to_colored_json_auto().unwrap()
);

@@ -4,10 +4,11 @@ use crate::circuit::modules::poseidon::{
PoseidonChip,
};
use crate::circuit::modules::Module;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::circuit::{CheckMode, Tolerance};
use crate::commands::*;
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::graph::modules::POSEIDON_LEN_GRAPH;
use crate::graph::TestDataSource;
use crate::graph::{
quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
@@ -155,6 +156,9 @@ impl pyo3::ToPyObject for PyG1Affine {
#[derive(Clone)]
#[gen_stub_pyclass]
struct PyRunArgs {
#[pyo3(get, set)]
/// float: The tolerance for error on model outputs
pub tolerance: f32,
#[pyo3(get, set)]
/// int: The denominator in the fixed point representation used when quantizing inputs
pub input_scale: crate::Scale,
@@ -203,9 +207,6 @@ struct PyRunArgs {
/// bool: Should the circuit use unbounded lookups for log
#[pyo3(get, set)]
pub bounded_log_lookup: bool,
/// bool: Should the circuit use range checks for inputs and outputs (set to false if the input is a felt)
#[pyo3(get, set)]
pub ignore_range_check_inputs_outputs: bool,
}

/// default instantiation of PyRunArgs

@@ -222,6 +223,7 @@ impl From<PyRunArgs> for RunArgs {
fn from(py_run_args: PyRunArgs) -> Self {
RunArgs {
bounded_log_lookup: py_run_args.bounded_log_lookup,
tolerance: Tolerance::from(py_run_args.tolerance),
input_scale: py_run_args.input_scale,
param_scale: py_run_args.param_scale,
num_inner_cols: py_run_args.num_inner_cols,
@@ -237,7 +239,6 @@ impl From<PyRunArgs> for RunArgs {
commitment: Some(py_run_args.commitment.into()),
decomp_base: py_run_args.decomp_base,
decomp_legs: py_run_args.decomp_legs,
ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
}
}
}
@@ -246,6 +247,7 @@ impl Into<PyRunArgs> for RunArgs {
fn into(self) -> PyRunArgs {
PyRunArgs {
bounded_log_lookup: self.bounded_log_lookup,
tolerance: self.tolerance.val,
input_scale: self.input_scale,
param_scale: self.param_scale,
num_inner_cols: self.num_inner_cols,
@@ -261,7 +263,6 @@ impl Into<PyRunArgs> for RunArgs {
commitment: self.commitment.into(),
decomp_base: self.decomp_base,
decomp_legs: self.decomp_legs,
ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
}
}
}
@@ -332,8 +333,6 @@ enum PyInputType {
Int,
///
TDim,
///
Unknown,
}

impl From<InputType> for PyInputType {
@@ -345,7 +344,6 @@ impl From<InputType> for PyInputType {
InputType::F64 => PyInputType::F64,
InputType::Int => PyInputType::Int,
InputType::TDim => PyInputType::TDim,
InputType::Unknown => PyInputType::Unknown,
}
}
}
@@ -359,7 +357,6 @@ impl From<PyInputType> for InputType {
PyInputType::F64 => InputType::F64,
PyInputType::Int => InputType::Int,
PyInputType::TDim => InputType::TDim,
PyInputType::Unknown => InputType::Unknown,
}
}
}
@@ -374,7 +371,6 @@ impl FromStr for PyInputType {
"f64" => Ok(PyInputType::F64),
"int" => Ok(PyInputType::Int),
"tdim" => Ok(PyInputType::TDim),
"unknown" => Ok(PyInputType::Unknown),
_ => Err("Invalid value for InputType".to_string()),
}
}

@@ -577,7 +573,10 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
.map(crate::pfsys::string_to_field::<Fr>)
.collect::<Vec<_>>();

let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
let output =
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>::run(
message.clone(),
)
.map_err(|_| PyIOError::new_err("Failed to run poseidon"))?;

let hash = output[0]
@@ -592,7 +591,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
/// Arguments
/// -------
/// message: list[str]
///     List of field elements represented as strings
///     List of field elements represnted as strings
///
/// vk_path: str
///     Path to the verification key
@@ -651,7 +650,7 @@ fn kzg_commit(
/// Arguments
/// -------
/// message: list[str]
///     List of field elements represented as strings
///     List of field elements represnted as strings
///
/// vk_path: str
///     Path to the verification key
@@ -1009,7 +1008,7 @@ fn gen_random_data(
///    bool
///
#[pyfunction(signature = (
data = String::from(DEFAULT_CALIBRATION_FILE),
data = PathBuf::from(DEFAULT_CALIBRATION_FILE),
model = PathBuf::from(DEFAULT_MODEL),
settings = PathBuf::from(DEFAULT_SETTINGS),
target = CalibrationTarget::default(), // default is "resources
@@ -1021,7 +1020,7 @@ fn gen_random_data(
#[gen_stub_pyfunction]
fn calibrate_settings(
py: Python,
data: String,
data: PathBuf,
model: PathBuf,
settings: PathBuf,
target: CalibrationTarget,
@@ -1076,7 +1075,7 @@ fn calibrate_settings(
///     Python object containing the witness values
///
#[pyfunction(signature = (
data=String::from(DEFAULT_DATA),
data=PathBuf::from(DEFAULT_DATA),
model=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
output=PathBuf::from(DEFAULT_WITNESS),
vk_path=None,
@@ -1085,7 +1084,7 @@ fn calibrate_settings(
#[gen_stub_pyfunction]
fn gen_witness(
py: Python,
data: String,
data: PathBuf,
model: PathBuf,
output: Option<PathBuf>,
vk_path: Option<PathBuf>,
@@ -1754,7 +1753,7 @@ fn create_evm_vka(
///     bool
///
#[pyfunction(signature = (
input_data=String::from(DEFAULT_DATA),
input_data=PathBuf::from(DEFAULT_DATA),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE_DA),
abi_path=PathBuf::from(DEFAULT_VERIFIER_DA_ABI),
@@ -1763,7 +1762,7 @@ fn create_evm_vka(
#[gen_stub_pyfunction]
fn create_evm_data_attestation(
py: Python,
input_data: String,
input_data: PathBuf,
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
@@ -1819,12 +1818,12 @@ fn create_evm_data_attestation(
test_data,
input_source,
output_source,
rpc_url=None
rpc_url=None,
))]
#[gen_stub_pyfunction]
fn setup_test_evm_data(
fn setup_test_evm_witness(
py: Python,
data_path: String,
data_path: PathBuf,
compiled_circuit_path: PathBuf,
test_data: PathBuf,
input_source: PyTestDataSource,
@@ -1832,7 +1831,7 @@ fn setup_test_evm_data(
rpc_url: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::setup_test_evm_data(
crate::execute::setup_test_evm_witness(
data_path,
compiled_circuit_path,
test_data,
@@ -1842,7 +1841,7 @@ fn setup_test_evm_data(
)
.await
.map_err(|e| {
let err_str = format!("Failed to run setup_test_evm_data: {}", e);
let err_str = format!("Failed to run setup_test_evm_witness: {}", e);
PyRuntimeError::new_err(err_str)
})?;

@@ -1902,7 +1901,7 @@ fn deploy_evm(
fn deploy_da_evm(
py: Python,
addr_path: PathBuf,
input_data: String,
input_data: PathBuf,
settings_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
@@ -1945,7 +1944,7 @@ fn deploy_da_evm(
///     does the verifier use data attestation ?
///
/// addr_vk: str
///    The address of the separate VK contract (if the verifier key is rendered as a separate contract)
///    The addess of the separate VK contract (if the verifier key is rendered as a separate contract)
/// Returns
/// -------
///     bool
@@ -2107,7 +2106,7 @@ fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;
m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
m.add_function(wrap_pyfunction!(setup_test_evm_data, m)?)?;
m.add_function(wrap_pyfunction!(setup_test_evm_witness, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_data_attestation, m)?)?;
m.add_function(wrap_pyfunction!(encode_evm_calldata, m)?)?;

@@ -8,7 +8,10 @@ use crate::{
Module,
},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt},
graph::{quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings},
graph::{
modules::POSEIDON_LEN_GRAPH, quantize_float, scale_to_multiplier, GraphCircuit,
GraphSettings,
},
};
use console_error_panic_hook;
use halo2_proofs::{
@@ -228,7 +231,10 @@ pub fn poseidonHash(
let message: Vec<Fr> = serde_json::from_slice(&message[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?;

let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
let output =
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>::run(
message.clone(),
)
.map_err(|e| JsError::new(&format!("{}", e)))?;

Ok(wasm_bindgen::Clamped(serde_json::to_vec(&output).map_err(

@@ -1,7 +1,7 @@
/*
An easy-to-use implementation of the Poseidon Hash in the form of a Halo2 Chip. While the Poseidon Hash function
is already implemented in halo2_gadgets, there is no wrapper chip that makes it easy to use in other circuits.
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/zk_prover/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
*/

use std::collections::HashMap;

@@ -1,18 +1,20 @@
/*
An easy-to-use implementation of the Poseidon Hash in the form of a Halo2 Chip. While the Poseidon Hash function
is already implemented in halo2_gadgets, there is no wrapper chip that makes it easy to use in other circuits.
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/zk_prover/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
*/

pub mod poseidon_params;
pub mod spec;

// This chip adds a set of advice columns to the gadget Chip to store the inputs of the hash
use halo2_gadgets::poseidon::{
primitives::VariableLength, primitives::*, Hash, Pow5Chip, Pow5Config,
};
use halo2_gadgets::poseidon::{primitives::*, Hash, Pow5Chip, Pow5Config};
use halo2_proofs::arithmetic::Field;
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use halo2_proofs::{circuit::*, plonk::*};
// use maybe_rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator};
use maybe_rayon::prelude::ParallelIterator;
use maybe_rayon::slice::ParallelSlice;

use std::marker::PhantomData;

@@ -38,17 +40,22 @@ pub struct PoseidonConfig<const WIDTH: usize, const RATE: usize> {
pub pow5_config: Pow5Config<Fp, WIDTH, RATE>,
}

type InputAssignments = Vec<AssignedCell<Fp, Fp>>;
type InputAssignments = (Vec<AssignedCell<Fp, Fp>>, AssignedCell<Fp, Fp>);

/// PoseidonChip is a wrapper around the Pow5Chip that adds a set of advice columns to the gadget Chip to store the inputs of the hash
#[derive(Debug, Clone)]
pub struct PoseidonChip<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> {
pub struct PoseidonChip<
S: Spec<Fp, WIDTH, RATE> + Sync,
const WIDTH: usize,
const RATE: usize,
const L: usize,
> {
config: PoseidonConfig<WIDTH, RATE>,
_marker: PhantomData<S>,
}

impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize>
PoseidonChip<S, WIDTH, RATE>
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, const L: usize>
PoseidonChip<S, WIDTH, RATE, L>
{
/// Creates a new PoseidonChip
pub fn configure_with_cols(
@@ -75,8 +82,8 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize>
}
}

impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize>
PoseidonChip<S, WIDTH, RATE>
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, const L: usize>
PoseidonChip<S, WIDTH, RATE, L>
{
/// Configuration of the PoseidonChip
pub fn configure_with_optional_instance(
@@ -93,6 +100,9 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize>
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();

for input in hash_inputs.iter().take(WIDTH) {
meta.enable_equality(*input);
}
meta.enable_constant(rc_b[0]);

Self::configure_with_cols(
@@ -106,8 +116,8 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize>
}
}

impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Module<Fp>
for PoseidonChip<S, WIDTH, RATE>
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, const L: usize>
Module<Fp> for PoseidonChip<S, WIDTH, RATE, L>
{
type Config = PoseidonConfig<WIDTH, RATE>;
type InputAssignments = InputAssignments;
@@ -142,6 +152,9 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Mod
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();

for input in hash_inputs.iter().take(WIDTH) {
meta.enable_equality(*input);
}
meta.enable_constant(rc_b[0]);

let instance = meta.instance_column();

@@ -163,10 +176,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Mod
message: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, ModuleError> {
if message.len() != 1 {
return Err(ModuleError::InputWrongLength(message.len()));
}

assert_eq!(message.len(), 1);
let message = message[0].clone();

let start_time = instant::Instant::now();
@@ -176,81 +186,95 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Mod
let res = layouter.assign_region(
|| "load message",
|mut region| {
let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, _> = match &message {
ValTensor::Value { inner: v, .. } => v
.iter()
.enumerate()
.map(|(i, value)| {
let x = i % WIDTH;
let y = i / WIDTH;
let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> =
match &message {
ValTensor::Value { inner: v, .. } => {
v.iter()
.enumerate()
.map(|(i, value)| {
let x = i % WIDTH;
let y = i / WIDTH;

match value {
ValType::Value(v) => region
.assign_advice(
|| format!("load message_{}", i),
match value {
ValType::Value(v) => region
.assign_advice(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
|| *v,
)
.map_err(|e| e.into()),
ValType::PrevAssigned(v)
| ValType::AssignedConstant(v, ..) => Ok(v.clone()),
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants
.get(f)
.unwrap()
.assigned_cell()
.ok_or(ModuleError::ConstantNotAssigned)?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;

constants.insert(
*f,
ValType::AssignedConstant(res.clone(), *f),
);

Ok(res)
}
}
e => Err(ModuleError::WrongInputType(
format!("{:?}", e),
"PrevAssigned".to_string(),
)),
}
})
.collect()
}
ValTensor::Instance {
dims,
inner: col,
idx,
initial_offset,
..
} => {
// this should never ever fail
let num_elems = dims[*idx].iter().product::<usize>();
(0..num_elems)
.map(|i| {
let x = i % WIDTH;
let y = i / WIDTH;
region.assign_advice_from_instance(
|| "pub input anchor",
*col,
initial_offset + i,
self.config.hash_inputs[x],
y,
|| *v,
)
.map_err(|e| e.into()),
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
Ok(v.clone())
}
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants
.get(f)
.unwrap()
.assigned_cell()
.ok_or(ModuleError::ConstantNotAssigned)?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;
})
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.into())
}
};

constants
.insert(*f, ValType::AssignedConstant(res.clone(), *f));
let offset = message.len() / WIDTH + 1;

Ok(res)
}
}
e => Err(ModuleError::WrongInputType(
format!("{:?}", e),
"AssignedValue".to_string(),
)),
}
})
.collect(),
ValTensor::Instance {
dims,
inner: col,
idx,
initial_offset,
..
} => {
// this should never ever fail
let num_elems = dims[*idx].iter().product::<usize>();
(0..num_elems)
.map(|i| {
let x = i % WIDTH;
let y = i / WIDTH;
region.assign_advice_from_instance(
|| "pub input anchor",
*col,
initial_offset + i,
self.config.hash_inputs[x],
y,
)
})
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.into())
}
};
let zero_val = region
.assign_advice_from_constant(
|| "",
self.config.hash_inputs[0],
offset,
Fp::ZERO,
)
.unwrap();

Ok(assigned_message?)
Ok((assigned_message?, zero_val))
},
);
log::trace!(

@@ -271,13 +295,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Mod
row_offset: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, ModuleError> {
let input_cells = self.layout_inputs(layouter, input, constants)?;

// empty hash case
if input_cells.is_empty() {
return Ok(input[0].clone());
}

let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
// extract the values from the input cells
let mut assigned_input: Tensor<ValType<Fp>> =
input_cells.iter().map(|e| ValType::from(e.clone())).into();

@@ -285,25 +303,52 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Mod

let start_time = instant::Instant::now();

let pow5_chip = Pow5Chip::construct(self.config.pow5_config.clone());
// initialize the hasher
let hasher = Hash::<_, _, S, VariableLength, WIDTH, RATE>::init(
pow5_chip,
layouter.namespace(|| "block_hasher"),
)?;
let mut one_iter = false;
// do the Tree dance baby
while input_cells.len() > 1 || !one_iter {
let hashes: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> = input_cells
.chunks(L)
.enumerate()
.map(|(i, block)| {
let _start_time = instant::Instant::now();

let hash: AssignedCell<Fp, Fp> = hasher.hash(
layouter.namespace(|| "hash"),
input_cells
.to_vec()
.try_into()
.map_err(|_| Error::Synthesis)?,
)?;
let mut block = block.to_vec();
let remainder = block.len() % L;

if remainder != 0 {
block.extend(vec![zero_val.clone(); L - remainder]);
}

let pow5_chip = Pow5Chip::construct(self.config.pow5_config.clone());
// initialize the hasher
let hasher = Hash::<_, _, S, ConstantLength<L>, WIDTH, RATE>::init(
pow5_chip,
layouter.namespace(|| "block_hasher"),
)?;

let hash = hasher.hash(
layouter.namespace(|| "hash"),
block.to_vec().try_into().map_err(|_| Error::Synthesis)?,
);

if i == 0 {
log::trace!("block (L={:?}) took: {:?}", L, _start_time.elapsed());
}

hash
})
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.into());

log::trace!("hashes (N={:?}) took: {:?}", len, start_time.elapsed());
one_iter = true;
input_cells = hashes?;
}

let duration = start_time.elapsed();
log::trace!("layout (N={:?}) took: {:?}", len, duration);

let result = Tensor::from(vec![ValType::from(hash.clone())].into_iter());
let result = Tensor::from(input_cells.iter().map(|e| ValType::from(e.clone())));

let output = match result[0].clone() {
ValType::PrevAssigned(v) => v,

@@ -342,59 +387,69 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Mod

///
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, ModuleError> {
let len = message.len();
if len == 0 {
return Ok(vec![vec![]]);
}
let mut hash_inputs = message;

let len = hash_inputs.len();

let start_time = instant::Instant::now();

let hash = halo2_gadgets::poseidon::primitives::Hash::<
_,
S,
VariableLength,
{ WIDTH },
{ RATE },
>::init()
.hash(message);
let mut one_iter = false;
// do the Tree dance baby
while hash_inputs.len() > 1 || !one_iter {
let hashes: Vec<Fp> = hash_inputs
.par_chunks(L)
.map(|block| {
let mut block = block.to_vec();
let remainder = block.len() % L;

if remainder != 0 {
block.extend(vec![Fp::ZERO; L - remainder].iter());
}

let block_len = block.len();

let message = block
.try_into()
.map_err(|_| ModuleError::InputWrongLength(block_len))?;

Ok(halo2_gadgets::poseidon::primitives::Hash::<
_,
S,
ConstantLength<L>,
{ WIDTH },
{ RATE },
>::init()
.hash(message))
})
.collect::<Result<Vec<_>, ModuleError>>()?;
one_iter = true;
hash_inputs = hashes;
}

let duration = start_time.elapsed();
log::trace!("run (N={:?}) took: {:?}", len, duration);

Ok(vec![vec![hash]])
Ok(vec![hash_inputs])
}

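The `run` above replaces the old single variable-length Poseidon call with a fixed-arity tree fold: split the inputs into chunks of `L`, zero-pad the trailing chunk, hash each chunk, and repeat on the resulting digests until one remains. A minimal Python sketch of the same folding shape, with a stand-in block hash rather than the real Poseidon primitive:

```python
# Sketch of the chunked tree fold in run() above. `hash_block` stands in
# for the Poseidon ConstantLength<L> primitive; it is not the real hash.
def tree_hash(inputs, L, hash_block, zero=0):
    one_iter = False
    while len(inputs) > 1 or not one_iter:
        digests = []
        for i in range(0, len(inputs), L):
            block = inputs[i:i + L]
            if len(block) % L != 0:
                # zero-pad the trailing block to exactly L elements
                block = block + [zero] * (L - len(block))
            digests.append(hash_block(block))
        inputs = digests
        one_iter = True
    return inputs[0]

# toy example: 10 inputs fold as 10 -> 4 -> 2 -> 1 digests with L = 3
digest = tree_hash(list(range(10)), L=3, hash_block=lambda b: hash(tuple(b)))
```
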
fn num_rows(input_len: usize) -> usize {
|
||||
fn num_rows(mut input_len: usize) -> usize {
|
||||
// this was determined by running the circuit and looking at the number of constraints
|
||||
// in the test called hash_for_a_range_of_input_sizes, then regressing in python to find the slope
|
||||
// import numpy as np
|
||||
// from scipy import stats
|
||||
let fixed_cost: usize = 41 * L;
|
||||
|
||||
// x = np.array([32, 64, 96, 128, 160, 192])
|
||||
// y = np.array([1298, 2594, 3890, 5186, 6482, 7778])
|
||||
let mut num_rows = 0;
|
||||
|
||||
// slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
|
||||
loop {
|
||||
// the number of times the input_len is divisible by L
|
||||
let num_chunks = input_len / L + 1;
|
||||
num_rows += num_chunks * fixed_cost;
|
||||
if num_chunks == 1 {
|
||||
break;
|
||||
}
|
||||
input_len = num_chunks;
|
||||
}
|
||||
|
||||
// print(f"slope: {slope}")
|
||||
// print(f"intercept: {intercept}")
|
||||
// print(f"R^2: {r_value**2}")
|
||||
|
||||
// # Predict for any x
|
||||
// def predict(x):
|
||||
// return slope * x + intercept
|
||||
|
||||
// # Test prediction
|
||||
// test_x = 256
|
||||
// print(f"Predicted value for x={test_x}: {predict(test_x)}")
|
||||
// our output:
|
||||
// slope: 40.5
|
||||
// intercept: 2.0
|
||||
// R^2: 1.0
|
||||
// Predicted value for x=256: 10370.0
|
||||
let fixed_cost: usize = 41 * input_len;
|
||||
|
||||
// the cost of the hash function is linear with the number of inputs
|
||||
fixed_cost + 2
|
||||
num_rows
|
||||
}
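    // Worked example of the estimate above (illustrative, with L = 32):
    //   input_len = 64 -> num_chunks = 64/32 + 1 = 3, num_rows += 3 * (41 * 32) = 3936
    //   input_len = 3  -> num_chunks = 3/32 + 1  = 1, num_rows += 1 * (41 * 32) = 1312, then break
    //   total estimate: 5248 rows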
}

@@ -421,12 +476,12 @@ mod tests {
    const RATE: usize = POSEIDON_RATE;
    const R: usize = 240;

    struct HashCircuit<S: Spec<Fp, WIDTH, RATE>> {
    struct HashCircuit<S: Spec<Fp, WIDTH, RATE>, const L: usize> {
        message: ValTensor<Fp>,
        _spec: PhantomData<S>,
    }

    impl<S: Spec<Fp, WIDTH, RATE>> Circuit<Fp> for HashCircuit<S> {
    impl<S: Spec<Fp, WIDTH, RATE>, const L: usize> Circuit<Fp> for HashCircuit<S, L> {
        type Config = PoseidonConfig<WIDTH, RATE>;
        type FloorPlanner = ModulePlanner;
        type Params = ();
@@ -442,7 +497,7 @@ mod tests {
        }

        fn configure(meta: &mut ConstraintSystem<Fp>) -> PoseidonConfig<WIDTH, RATE> {
            PoseidonChip::<PoseidonSpec, WIDTH, RATE>::configure(meta, ())
            PoseidonChip::<PoseidonSpec, WIDTH, RATE, L>::configure(meta, ())
        }

        fn synthesize(
@@ -450,7 +505,7 @@ mod tests {
            config: PoseidonConfig<WIDTH, RATE>,
            mut layouter: impl Layouter<Fp>,
        ) -> Result<(), Error> {
            let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE> = PoseidonChip::new(config);
            let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, L> = PoseidonChip::new(config);
            chip.layout(
                &mut layouter,
                &[self.message.clone()],
@@ -462,33 +517,18 @@ mod tests {
        }
    }

    #[test]
    fn poseidon_hash_empty() {
        let message = [];
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.to_vec()).unwrap();
        let mut message: Tensor<ValType<Fp>> =
            message.into_iter().map(|m| Value::known(m).into()).into();
        let k = 9;
        let circuit = HashCircuit::<PoseidonSpec> {
            message: message.into(),
            _spec: PhantomData,
        };
        let prover = halo2_proofs::dev::MockProver::run(k, &circuit, vec![vec![]]).unwrap();
        assert_eq!(prover.verify(), Ok(()))
    }

    #[test]
    fn poseidon_hash() {
        let rng = rand::rngs::OsRng;

        let message = [Fp::random(rng), Fp::random(rng)];
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.to_vec()).unwrap();
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 2>::run(message.to_vec()).unwrap();

        let mut message: Tensor<ValType<Fp>> =
            message.into_iter().map(|m| Value::known(m).into()).into();

        let k = 9;
        let circuit = HashCircuit::<PoseidonSpec> {
        let circuit = HashCircuit::<PoseidonSpec, 2> {
            message: message.into(),
            _spec: PhantomData,
        };
@@ -501,13 +541,13 @@ mod tests {
        let rng = rand::rngs::OsRng;

        let message = [Fp::random(rng), Fp::random(rng), Fp::random(rng)];
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.to_vec()).unwrap();
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 3>::run(message.to_vec()).unwrap();

        let mut message: Tensor<ValType<Fp>> =
            message.into_iter().map(|m| Value::known(m).into()).into();

        let k = 9;
        let circuit = HashCircuit::<PoseidonSpec> {
        let circuit = HashCircuit::<PoseidonSpec, 3> {
            message: message.into(),
            _spec: PhantomData,
        };
@@ -523,21 +563,23 @@ mod tests {
        #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
        env_logger::init();

        for i in (32..128).step_by(32) {
        {
            let i = 32;
            // print a bunch of new lines
            log::info!(
            println!(
                "i is {} -------------------------------------------------",
                i
            );

            let message: Vec<Fp> = (0..i).map(|_| Fp::random(rng)).collect::<Vec<_>>();
            let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.clone()).unwrap();
            let output =
                PoseidonChip::<PoseidonSpec, WIDTH, RATE, 32>::run(message.clone()).unwrap();

            let mut message: Tensor<ValType<Fp>> =
                message.into_iter().map(|m| Value::known(m).into()).into();

            let k = 17;
            let circuit = HashCircuit::<PoseidonSpec> {
            let circuit = HashCircuit::<PoseidonSpec, 32> {
                message: message.into(),
                _spec: PhantomData,
            };
@@ -554,13 +596,13 @@ mod tests {

        let mut message: Vec<Fp> = (0..2048).map(|_| Fp::random(rng)).collect::<Vec<_>>();

        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.clone()).unwrap();
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 25>::run(message.clone()).unwrap();

        let mut message: Tensor<ValType<Fp>> =
            message.into_iter().map(|m| Value::known(m).into()).into();

        let k = 17;
        let circuit = HashCircuit::<PoseidonSpec> {
        let circuit = HashCircuit::<PoseidonSpec, 25> {
            message: message.into(),
            _spec: PhantomData,
        };

@@ -17,14 +17,12 @@ pub enum BaseOp {
    Sub,
    SumInit,
    Sum,
    IsBoolean,
}

/// Matches a [BaseOp] to an operation over inputs
impl BaseOp {
    /// forward func for non-accumulating operations
    /// # Panics
    /// Panics if called on an accumulating operation
    /// # Examples
    /// forward func
    pub fn nonaccum_f<
        T: TensorType + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Neg<Output = T>,
    >(
@@ -36,13 +34,12 @@ impl BaseOp {
            BaseOp::Add => a + b,
            BaseOp::Sub => a - b,
            BaseOp::Mult => a * b,
            BaseOp::IsBoolean => b,
            _ => panic!("nonaccum_f called on accumulating operation"),
        }
    }
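    // Example of the forward semantics above (illustrative, for integer inputs):
    // BaseOp::Add.nonaccum_f((2, 3)) == 5 and BaseOp::Sub.nonaccum_f((2, 3)) == -1,
    // while BaseOp::IsBoolean.nonaccum_f((a, b)) just forwards b; booleanity is
    // enforced by the gate, not by this function.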

    /// forward func for accumulating operations
    /// # Panics
    /// Panics if called on a non-accumulating operation
    /// forward func
    pub fn accum_f<
        T: TensorType + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Neg<Output = T>,
    >(
@@ -77,6 +74,7 @@ impl BaseOp {
            BaseOp::Mult => "MULT",
            BaseOp::Sum => "SUM",
            BaseOp::SumInit => "SUMINIT",
            BaseOp::IsBoolean => "ISBOOLEAN",
        }
    }

@@ -92,6 +90,7 @@ impl BaseOp {
            BaseOp::Mult => (0, 1),
            BaseOp::Sum => (-1, 2),
            BaseOp::SumInit => (0, 1),
            BaseOp::IsBoolean => (0, 1),
        }
    }

@@ -107,6 +106,7 @@ impl BaseOp {
            BaseOp::Mult => 2,
            BaseOp::Sum => 1,
            BaseOp::SumInit => 1,
            BaseOp::IsBoolean => 0,
        }
    }

@@ -122,6 +122,7 @@ impl BaseOp {
            BaseOp::SumInit => 0,
            BaseOp::CumProd => 1,
            BaseOp::CumProdInit => 0,
            BaseOp::IsBoolean => 0,
        }
    }
}

@@ -2,7 +2,7 @@ use std::str::FromStr;

use halo2_proofs::{
    circuit::Layouter,
    plonk::{ConstraintSystem, Constraints, Expression, Selector, TableColumn},
    plonk::{ConstraintSystem, Constraints, Expression, Selector},
    poly::Rotation,
};
use log::debug;
@@ -20,6 +20,7 @@ use crate::{
    circuit::{
        ops::base::BaseOp,
        table::{Range, RangeCheck, Table},
        utils,
    },
    tensor::{Tensor, TensorType, ValTensor, VarTensor},
};
@@ -84,6 +85,55 @@ impl CheckMode {
    }
}

#[allow(missing_docs)]
/// A struct representing the tolerance we can accept for the accumulated arguments, either absolute or percentage
#[derive(Clone, Default, Debug, PartialEq, PartialOrd, Serialize, Deserialize, Copy)]
pub struct Tolerance {
    pub val: f32,
    pub scale: utils::F32,
}

impl std::fmt::Display for Tolerance {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:.2}", self.val)
    }
}

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Tolerance {
    /// Convert the struct to a subcommand string
    fn to_flags(&self) -> Vec<String> {
        vec![format!("{}", self)]
    }
}

impl FromStr for Tolerance {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if let Ok(val) = s.parse::<f32>() {
            Ok(Tolerance {
                val,
                scale: utils::F32(1.0),
            })
        } else {
            Err(
                "Invalid tolerance value provided. It should be expressed as a percentage (f32)."
                    .to_string(),
            )
        }
    }
}
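// A minimal usage sketch of the parser above (illustrative):
// #[test]
// fn tolerance_from_str() {
//     let tol: Tolerance = "0.5".parse().unwrap();
//     assert_eq!(tol.val, 0.5);
//     assert_eq!(tol.scale, utils::F32(1.0));
// }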

impl From<f32> for Tolerance {
    fn from(value: f32) -> Self {
        Tolerance {
            val: value,
            scale: utils::F32(1.0),
        }
    }
}

#[cfg(feature = "python-bindings")]
/// Converts CheckMode into a PyObject (Required for CheckMode to be compatible with Python)
impl IntoPy<PyObject> for CheckMode {
@@ -108,6 +158,29 @@ impl<'source> FromPyObject<'source> for CheckMode {
    }
}

#[cfg(feature = "python-bindings")]
/// Converts Tolerance into a PyObject (Required for Tolerance to be compatible with Python)
impl IntoPy<PyObject> for Tolerance {
    fn into_py(self, py: Python) -> PyObject {
        (self.val, self.scale.0).to_object(py)
    }
}

#[cfg(feature = "python-bindings")]
/// Obtains Tolerance from PyObject (Required for Tolerance to be compatible with Python)
impl<'source> FromPyObject<'source> for Tolerance {
    fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
        if let Ok((val, scale)) = <(f32, f32)>::extract_bound(ob) {
            Ok(Tolerance {
                val,
                scale: utils::F32(scale),
            })
        } else {
            Err(PyValueError::new_err("Invalid tolerance value provided. "))
        }
    }
}

/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct DynamicLookups {
@@ -142,16 +215,15 @@ impl DynamicLookups {

/// A struct representing the selectors for the shuffle arguments
#[derive(Clone, Debug, Default)]

pub struct Shuffles {
    /// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
    pub input_selectors: BTreeMap<(usize, (usize, usize)), Selector>,
    /// Selectors for the dynamic lookup tables
    pub output_selectors: Vec<Selector>,
    pub reference_selectors: Vec<Selector>,
    /// Inputs:
    pub inputs: Vec<VarTensor>,
    /// tables
    pub outputs: Vec<VarTensor>,
    pub references: Vec<VarTensor>,
}

impl Shuffles {
@@ -162,13 +234,9 @@ impl Shuffles {

        Self {
            input_selectors: BTreeMap::new(),
            output_selectors: vec![],
            inputs: vec![dummy_var.clone(), dummy_var.clone(), dummy_var.clone()],
            outputs: vec![
                single_col_dummy_var.clone(),
                single_col_dummy_var.clone(),
                single_col_dummy_var.clone(),
            ],
            reference_selectors: vec![],
            inputs: vec![dummy_var.clone(), dummy_var.clone()],
            references: vec![single_col_dummy_var.clone(), single_col_dummy_var.clone()],
        }
    }
}
@@ -268,8 +336,6 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
    /// Activate sanity checks
    pub check_mode: CheckMode,
    _marker: PhantomData<F>,
    /// shared table inputs
    pub shared_table_inputs: Vec<TableColumn>,
}

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
@@ -282,7 +348,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
    shuffles: Shuffles::dummy(col_size, num_inner_cols),
    range_checks: RangeChecks::dummy(col_size, num_inner_cols),
    check_mode: CheckMode::SAFE,
    shared_table_inputs: vec![],
    _marker: PhantomData,
}
}
@@ -309,18 +374,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
if inputs[0].num_cols() != output.num_cols() {
    log::warn!("input and output shapes do not match");
}
if inputs[0].num_inner_cols() != inputs[1].num_inner_cols() {
    log::warn!("input number of inner columns do not match");
}
if inputs[0].num_inner_cols() != output.num_inner_cols() {
    log::warn!("input and output number of inner columns do not match");
}

for i in 0..output.num_blocks() {
    for j in 0..output.num_inner_cols() {
        nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector());
        nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector());
        nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector());
        nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector());
    }
}

@@ -354,13 +414,24 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
// Get output expressions for each input channel
let (rotation_offset, rng) = base_op.query_offset_rng();

let constraints = {
    let expected_output: Tensor<Expression<F>> = output
        .query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
        .expect("non accum: output query failed");
let constraints = match base_op {
    BaseOp::IsBoolean => {
        let expected_output: Tensor<Expression<F>> = output
            .query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
            .expect("non accum: output query failed");

    let res = base_op.nonaccum_f((qis[0].clone(), qis[1].clone()));
    vec![expected_output[base_op.constraint_idx()].clone() - res]
        let output = expected_output[base_op.constraint_idx()].clone();

        vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
    }
    _ => {
        let expected_output: Tensor<Expression<F>> = output
            .query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
            .expect("non accum: output query failed");

        let res = base_op.nonaccum_f((qis[0].clone(), qis[1].clone()));
        vec![expected_output[base_op.constraint_idx()].clone() - res]
    }
};

Constraints::with_selector(selector, constraints)
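// Soundness note on the IsBoolean arm above: over a prime field,
// out * (out - 1) = 0 holds iff out is 0 or 1, so the single degree-2
// constraint pins the queried cell to a bit without needing a second input.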
@@ -415,7 +486,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
    dynamic_lookups: DynamicLookups::default(),
    shuffles: Shuffles::default(),
    range_checks: RangeChecks::default(),
    shared_table_inputs: vec![],
    check_mode,
    _marker: PhantomData,
}
@@ -446,9 +516,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
    return Err(CircuitError::WrongColumnType(output.name().to_string()));
}

// we borrow mutably twice so we need to do this dance

let table = if !self.static_lookups.tables.contains_key(nl) {
    let table =
        Table::<F>::configure(cs, lookup_range, logrows, nl, &mut self.shared_table_inputs);
    // as all tables have the same input we see if there's another table whose input we can reuse
    let table = if let Some(table) = self.static_lookups.tables.values().next() {
        Table::<F>::configure(
            cs,
            lookup_range,
            logrows,
            nl,
            Some(table.table_inputs.clone()),
        )
    } else {
        Table::<F>::configure(cs, lookup_range, logrows, nl, None)
    };
    self.static_lookups.tables.insert(nl.clone(), table.clone());
    table
} else {
@@ -499,9 +581,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
    // this is 0 if the index is the same as the column index (starting from 1)

    let col_expr = sel.clone()
        * (table
        * table
            .selector_constructor
            .get_expr_at_idx(col_idx, synthetic_sel));
            .get_expr_at_idx(col_idx, synthetic_sel);

    let multiplier =
        table.selector_constructor.get_selector_val_at_idx(col_idx);
@@ -533,40 +615,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
    res
});
}

// add a degree-k custom constraint of the following form to the range check and
// static lookup configuration.
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 − 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
cs.create_gate("range_check_on_sel", |cs| {
    let synthetic_sel = match len {
        1 => Expression::Constant(F::from(1)),
        _ => match index {
            VarTensor::Advice { inner: advices, .. } => {
                cs.query_advice(advices[x][y], Rotation(0))
            }
            _ => unreachable!(),
        },
    };

    let range_check_on_synthetic_sel = match len {
        1 => Expression::Constant(F::from(0)),
        _ => {
            let mut initial_expr = Expression::Constant(F::from(1));
            for i in 0..len {
                initial_expr = initial_expr
                    * (synthetic_sel.clone()
                        - Expression::Constant(F::from(i as u64)))
            }
            initial_expr
        }
    };

    let sel = cs.query_selector(multi_col_selector);

    Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
});
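// Worked example of the gate above (illustrative): with len = 3 the enforced
// polynomial is multisel * sel * (sel - 1) * (sel - 2) = 0, so the synthetic
// selector can only take a value in {0, 1, 2}, one slot per overflowed column.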

self.static_lookups
    .selectors
    .insert((nl.clone(), x, y), multi_col_selector);
@@ -692,8 +740,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
pub fn configure_shuffles(
    &mut self,
    cs: &mut ConstraintSystem<F>,
    inputs: &[VarTensor; 3],
    outputs: &[VarTensor; 3],
    inputs: &[VarTensor; 2],
    references: &[VarTensor; 2],
) -> Result<(), CircuitError>
where
    F: Field,
@@ -704,14 +752,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
    }
}

for t in outputs.iter() {
for t in references.iter() {
    if !t.is_advice() || t.num_inner_cols() > 1 {
        return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
    }
}

// assert all tables have the same number of blocks
if outputs
if references
    .iter()
    .map(|t| t.num_blocks())
    .collect::<Vec<_>>()
@@ -719,23 +767,23 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
    .any(|w| w[0] != w[1])
{
    return Err(CircuitError::WrongDynamicColumnType(
        "outputs inner cols".to_string(),
        "references inner cols".to_string(),
    ));
}

let one = Expression::Constant(F::ONE);

for q in 0..outputs[0].num_blocks() {
    let s_output = cs.complex_selector();
for q in 0..references[0].num_blocks() {
    let s_reference = cs.complex_selector();

    for x in 0..inputs[0].num_blocks() {
        for y in 0..inputs[0].num_inner_cols() {
            let s_input = cs.complex_selector();

            cs.lookup_any("shuffle", |cs| {
            cs.lookup_any("lookup", |cs| {
                let s_inputq = cs.query_selector(s_input);
                let mut expression = vec![];
                let s_outputq = cs.query_selector(s_output);
                let s_referenceq = cs.query_selector(s_reference);
                let mut input_queries = vec![one.clone()];

                for input in inputs {
@@ -747,9 +795,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
                });
            }

            let mut output_queries = vec![one.clone()];
            for output in outputs {
                output_queries.push(match output {
            let mut ref_queries = vec![one.clone()];
            for reference in references {
                ref_queries.push(match reference {
                VarTensor::Advice { inner: advices, .. } => {
                    cs.query_advice(advices[q][0], Rotation(0))
                }
@@ -758,7 +806,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
            }

            let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
            let rhs = output_queries.into_iter().map(|c| c * s_outputq.clone());
            let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
            expression.extend(lhs.zip(rhs));

            expression
@@ -769,13 +817,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
            .or_insert(s_input);
        }
    }
    self.shuffles.output_selectors.push(s_output);
    self.shuffles.reference_selectors.push(s_reference);
}
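// Note (illustrative): the shuffle is encoded as a two-sided `lookup_any`: each
// pair in `expression` matches a selected input query against a selected
// reference query, so every enabled input row must appear among the enabled
// reference rows.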

// if we haven't previously initialized the input/output, do so now
if self.shuffles.outputs.is_empty() {
    debug!("assigning shuffles output");
    self.shuffles.outputs = outputs.to_vec();
if self.shuffles.references.is_empty() {
    debug!("assigning shuffles reference");
    self.shuffles.references = references.to_vec();
}
if self.shuffles.inputs.is_empty() {
    debug!("assigning shuffles input");
@@ -807,6 +855,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
let range_check = if let std::collections::btree_map::Entry::Vacant(e) =
    self.range_checks.ranges.entry(range)
{
    // as all tables have the same input we see if there's another table whose input we can reuse
    let range_check = RangeCheck::<F>::configure(cs, range, logrows);
    e.insert(range_check.clone());
    range_check
@@ -844,9 +893,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
let default_x = range_check.get_first_element(col_idx);

let col_expr = sel.clone()
    * (range_check
    * range_check
        .selector_constructor
        .get_expr_at_idx(col_idx, synthetic_sel));
        .get_expr_at_idx(col_idx, synthetic_sel);

let multiplier = range_check
    .selector_constructor
@@ -869,40 +918,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
    res
});
}

// add a degree-k custom constraint of the following form to the range check and
// static lookup configuration.
// 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 · ∏ (𝑠𝑒𝑙 − 𝑖) = 0 where 𝑠𝑒𝑙 is the synthetic_sel, and the product is over the set of overflowed columns
// and 𝑚𝑢𝑙𝑡𝑖𝑠𝑒𝑙 is the selector value at the column index
cs.create_gate("range_check_on_sel", |cs| {
    let synthetic_sel = match len {
        1 => Expression::Constant(F::from(1)),
        _ => match index {
            VarTensor::Advice { inner: advices, .. } => {
                cs.query_advice(advices[x][y], Rotation(0))
            }
            _ => unreachable!(),
        },
    };

    let range_check_on_synthetic_sel = match len {
        1 => Expression::Constant(F::from(0)),
        _ => {
            let mut initial_expr = Expression::Constant(F::from(1));
            for i in 0..len {
                initial_expr = initial_expr
                    * (synthetic_sel.clone()
                        - Expression::Constant(F::from(i as u64)))
            }
            initial_expr
        }
    };

    let sel = cs.query_selector(multi_col_selector);

    Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
});

self.range_checks
    .selectors
    .insert((range, x, y), multi_col_selector);

@@ -25,7 +25,7 @@ pub enum CircuitError {
/// This operation is unsupported
#[error("unsupported operation in graph")]
UnsupportedOp,
/// Invalid einsum expression
///
#[error("invalid einsum expression")]
InvalidEinsum,
/// Flush error
@@ -100,13 +100,4 @@ pub enum CircuitError {
#[error("invalid input type {0}")]
/// Invalid input type
InvalidInputType(String),
#[error("an element is missing from the shuffled version of the tensor")]
/// An element is missing from the shuffled version of the tensor
MissingShuffleElement,
/// Visibility has not been set
#[error("visibility has not been set")]
UnsetVisibility,
/// A decomposition base overflowed
#[error("decomposition base overflowed")]
DecompositionBaseOverflow,
}

@@ -1,9 +1,9 @@
use super::*;
use crate::{
    circuit::{layouts, utils},
    fieldutils::{integer_rep_to_felt, IntegerRep},
    circuit::{layouts, utils, Tolerance},
    fieldutils::integer_rep_to_felt,
    graph::multiplier_to_scale,
    tensor::{self, DataFormat, Tensor, TensorType, ValTensor},
    tensor::{self, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
use serde::{Deserialize, Serialize};
@@ -57,13 +57,11 @@ pub enum HybridOp {
    stride: Vec<usize>,
    kernel_shape: Vec<usize>,
    normalized: bool,
    data_format: DataFormat,
},
MaxPool {
    padding: Vec<(usize, usize)>,
    stride: Vec<usize>,
    pool_dims: Vec<usize>,
    data_format: DataFormat,
},
ReduceMin {
    axes: Vec<usize>,
@@ -78,9 +76,7 @@ pub enum HybridOp {
    output_scale: utils::F32,
    axes: Vec<usize>,
},
Output {
    decomp: bool,
},
RangeCheck(Tolerance),
Greater,
GreaterEqual,
Less,
@@ -155,10 +151,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
    padding,
    stride,
    kernel_shape,
    normalized, data_format
    normalized,
} => format!(
    "SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={}, data_format={:?})",
    padding, stride, kernel_shape, normalized, data_format
    "SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={})",
    padding, stride, kernel_shape, normalized
),
HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
@@ -166,10 +162,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
    padding,
    stride,
    pool_dims,
    data_format,
} => format!(
    "MaxPool (padding={:?}, stride={:?}, pool_dims={:?}, data_format={:?})",
    padding, stride, pool_dims, data_format
    "MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
    padding, stride, pool_dims
),
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim),
@@ -183,9 +178,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
    input_scale, output_scale, axes
)
}
HybridOp::Output { decomp } => {
    format!("OUTPUT (decomp={})", decomp)
}
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
HybridOp::Greater => "GREATER".to_string(),
HybridOp::GreaterEqual => "GREATEREQUAL".to_string(),
HybridOp::Less => "LESS".to_string(),
@@ -241,7 +234,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
    stride,
    kernel_shape,
    normalized,
    data_format,
} => layouts::sumpool(
    config,
    region,
@@ -250,7 +242,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
    stride,
    kernel_shape,
    *normalized,
    *data_format,
)?,
HybridOp::Recip {
    input_scale,
@@ -259,8 +250,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
    config,
    region,
    values[..].try_into()?,
    integer_rep_to_felt(input_scale.0 as IntegerRep),
    integer_rep_to_felt(output_scale.0 as IntegerRep),
    integer_rep_to_felt(input_scale.0 as i128),
    integer_rep_to_felt(output_scale.0 as i128),
)?,
HybridOp::Div { denom, .. } => {
    if denom.0.fract() == 0.0 {
@@ -268,7 +259,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
    config,
    region,
    values[..].try_into()?,
    integer_rep_to_felt(denom.0 as IntegerRep),
    integer_rep_to_felt(denom.0 as i128),
)?
} else {
    layouts::nonlinearity(
@@ -291,7 +282,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
    padding,
    stride,
    pool_dims,
    data_format,
} => layouts::max_pool(
    config,
    region,
@@ -299,7 +289,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
    padding,
    stride,
    pool_dims,
    *data_format,
)?,
HybridOp::ReduceMax { axes } => {
    layouts::max_axes(config, region, values[..].try_into()?, axes)?
@@ -325,9 +314,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
    *output_scale,
    axes,
)?,
HybridOp::Output { decomp } => {
    layouts::output(config, region, values[..].try_into()?, *decomp)?
}
HybridOp::RangeCheck(tol) => layouts::range_check_percent(
    config,
    region,
    values[..].try_into()?,
    tol.scale,
    tol.val,
)?,
HybridOp::Greater => layouts::greater(config, region, values[..].try_into()?)?,
HybridOp::GreaterEqual => {
    layouts::greater_equal(config, region, values[..].try_into()?)?

File diff suppressed because it is too large
@@ -1,8 +1,6 @@
use std::any::Any;

use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::prelude::DatumType;

use crate::{
    graph::quantize_tensor,
@@ -98,8 +96,6 @@ pub enum InputType {
    Int,
    ///
    TDim,
    ///
    Unknown,
}

impl InputType {
@@ -136,7 +132,6 @@ impl InputType {
    let int_input = input.clone().to_i64().unwrap();
    *input = T::from_i64(int_input).unwrap();
}
InputType::Unknown => {}
}
}
}
@@ -157,30 +152,6 @@ impl std::str::FromStr for InputType {
    }
}

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl From<DatumType> for InputType {
    /// # Panics
    /// Panics if the datum type is not supported
    fn from(datum_type: DatumType) -> Self {
        match datum_type {
            DatumType::Bool => InputType::Bool,
            DatumType::F16 => InputType::F16,
            DatumType::F32 => InputType::F32,
            DatumType::F64 => InputType::F64,
            DatumType::I8 => InputType::Int,
            DatumType::I16 => InputType::Int,
            DatumType::I32 => InputType::Int,
            DatumType::I64 => InputType::Int,
            DatumType::U8 => InputType::Int,
            DatumType::U16 => InputType::Int,
            DatumType::U32 => InputType::Int,
            DatumType::U64 => InputType::Int,
            DatumType::TDim => InputType::TDim,
            _ => unimplemented!(),
        }
    }
}

///
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Input {
@@ -188,8 +159,6 @@ pub struct Input {
    pub scale: crate::Scale,
    ///
    pub datum_type: InputType,
    /// decomp check
    pub decomp: bool,
}

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
@@ -227,7 +196,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
    config,
    region,
    values[..].try_into()?,
    self.decomp,
)?)),
}
} else {
@@ -283,26 +251,20 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
    ///
    #[serde(skip)]
    pub pre_assigned_val: Option<ValTensor<F>>,
    ///
    pub decomp: bool,
}

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
    ///
    pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>, decomp: bool) -> Self {
    pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
        Self {
            quantized_values,
            raw_values,
            pre_assigned_val: None,
            decomp,
        }
    }
    /// Rebase the scale of the constant
    pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), CircuitError> {
        let visibility = match self.quantized_values.visibility() {
            Some(v) => v,
            None => return Err(CircuitError::UnsetVisibility),
        };
        let visibility = self.quantized_values.visibility().unwrap();
        self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
        Ok(())
    }
@@ -319,8 +281,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
}

impl<
    F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
    F: PrimeField
        + TensorType
        + PartialOrd
        + std::hash::Hash
        + Serialize
        + for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
    fn as_any(&self) -> &dyn Any {
        self
@@ -341,12 +308,7 @@ impl<
    self.quantized_values.clone().try_into()?
};
// we only need to constrain it once, even if it's used multiple times
Ok(Some(layouts::identity(
    config,
    region,
    &[value],
    self.decomp,
)?))
Ok(Some(layouts::identity(config, region, &[value])?))
}

fn clone_dyn(&self) -> Box<dyn Op<F>> {

@@ -4,7 +4,6 @@ use crate::{
    utils::{self, F32},
},
tensor::{self, Tensor, TensorError},
tensor::{DataFormat, KernelFormat},
};

use super::{base::BaseOp, *};
@@ -44,12 +43,10 @@ pub enum PolyOp {
    padding: Vec<(usize, usize)>,
    stride: Vec<usize>,
    group: usize,
    data_format: DataFormat,
    kernel_format: KernelFormat,
},
Downsample {
    axis: usize,
    stride: isize,
    stride: usize,
    modulo: usize,
},
DeConv {
@@ -57,8 +54,6 @@ pub enum PolyOp {
    output_padding: Vec<usize>,
    stride: Vec<usize>,
    group: usize,
    data_format: DataFormat,
    kernel_format: KernelFormat,
},
Add,
Sub,
@@ -108,8 +103,13 @@ pub enum PolyOp {
}

impl<
    F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for PolyOp
    F: PrimeField
        + TensorType
        + PartialOrd
        + std::hash::Hash
        + Serialize
        + for<'de> Deserialize<'de>,
> Op<F> for PolyOp
{
    /// Returns a reference to the Any trait.
    fn as_any(&self) -> &dyn Any {
@@ -165,12 +165,10 @@ impl<
    stride,
    padding,
    group,
    data_format,
    kernel_format,
} => {
    format!(
        "CONV (stride={:?}, padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
        stride, padding, group, data_format, kernel_format
        "CONV (stride={:?}, padding={:?}, group={})",
        stride, padding, group
    )
}
PolyOp::DeConv {
@@ -178,12 +176,10 @@ impl<
    padding,
    output_padding,
    group,
    data_format,
    kernel_format,
} => {
    format!(
        "DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
        stride, padding, output_padding, group, data_format, kernel_format
        "DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={})",
        stride, padding, output_padding, group
    )
}
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
@@ -246,8 +242,6 @@ impl<
    padding,
    stride,
    group,
    data_format,
    kernel_format,
} => layouts::conv(
    config,
    region,
@@ -255,17 +249,9 @@ impl<
    padding,
    stride,
    *group,
    *data_format,
    *kernel_format,
)?,
PolyOp::GatherElements { dim, constant_idx } => {
    if let Some(idx) = constant_idx {
        if values.len() != 1 {
            return Err(TensorError::DimError(
                "GatherElements only accepts single inputs".to_string(),
            )
            .into());
        }
        tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
    } else {
        layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
@@ -283,12 +269,6 @@ impl<
}
PolyOp::ScatterElements { dim, constant_idx } => {
    if let Some(idx) = constant_idx {
        if values.len() != 2 {
            return Err(TensorError::DimError(
                "ScatterElements requires two inputs".to_string(),
            )
            .into());
        }
        tensor::ops::scatter(
            values[0].get_inner_tensor()?,
            idx,
@@ -317,8 +297,6 @@ impl<
    output_padding,
    stride,
    group,
    data_format,
    kernel_format,
} => layouts::deconv(
    config,
    region,
@@ -327,17 +305,13 @@ impl<
    output_padding,
    stride,
    *group,
    *data_format,
    *kernel_format,
)?,
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
PolyOp::Mult => {
    layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
}
PolyOp::Identity { .. } => {
    layouts::identity(config, region, values[..].try_into()?, false)?
}
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
    if values.len() != 1 {

@@ -132,16 +132,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
    (first_element, op_f.output[0])
}

/// calculates the column size given the number of rows and reserved blinding rows
///
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
    2usize.pow(logrows as u32) - reserved_blinding_rows
}

///
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
    2usize.pow(bits as u32) - reserved_blinding_rows
}
}

///
pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
    // number of cols needed to store the range
    (range_len / col_size as IntegerRep) as usize + 1
    (range_len / (col_size as IntegerRep)) as usize + 1
}
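// Worked example (illustrative, assuming roughly 11 reserved blinding rows):
// with logrows = 17, col_size = 2^17 - 11 = 131061; a range 300_000 wide then
// needs 300_000 / 131_061 + 1 = 3 columns.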

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
@@ -163,7 +168,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
    range: Range,
    logrows: usize,
    nonlinearity: &LookupOp,
    preexisting_inputs: &mut Vec<TableColumn>,
    preexisting_inputs: Option<Vec<TableColumn>>,
) -> Table<F> {
    let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
    let col_size = Self::cal_col_size(logrows, factors);
@@ -172,28 +177,28 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {

    debug!("table range: {:?}", range);

    // validate enough columns are provided to store the range
    if preexisting_inputs.len() < num_cols {
        // add columns to match the required number of columns
        let diff = num_cols - preexisting_inputs.len();
        for _ in 0..diff {
            preexisting_inputs.push(cs.lookup_table_column());
    let table_inputs = preexisting_inputs.unwrap_or_else(|| {
        let mut cols = vec![];
        for _ in 0..num_cols {
            cols.push(cs.lookup_table_column());
        }
    }
        cols
    });

    let num_cols = table_inputs.len();

    let num_cols = preexisting_inputs.len();
    if num_cols > 1 {
        warn!("Using {} columns for non-linearity table.", num_cols);
    }

    let table_outputs = preexisting_inputs
    let table_outputs = table_inputs
        .iter()
        .map(|_| cs.lookup_table_column())
        .collect::<Vec<_>>();

    Table {
        nonlinearity: nonlinearity.clone(),
        table_inputs: preexisting_inputs.clone(),
        table_inputs,
        table_outputs,
        is_assigned: false,
        selector_constructor: SelectorConstructor::new(num_cols),
@@ -350,11 +355,16 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
    integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0)
}

/// calculates the column size
///
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
    2usize.pow(logrows as u32) - reserved_blinding_rows
}

///
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
    2usize.pow(bits as u32) - reserved_blinding_rows
}

/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
    // range is split up into chunks of size col_size, find the chunk that input is in

@@ -1,6 +1,5 @@
use crate::circuit::ops::poly::PolyOp;
use crate::circuit::*;
use crate::tensor::{DataFormat, KernelFormat};
use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor};
use halo2_proofs::{
    circuit::{Layouter, SimpleFloorPlanner, Value},
@@ -1041,10 +1040,6 @@ mod conv {
    let a = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
    let b = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
    let output = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);

    // column for constants
    let _constant = VarTensor::constant_cols(cs, K, 8, false);

    Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
}

@@ -1066,8 +1061,6 @@ mod conv {
    padding: vec![(1, 1); 2],
    stride: vec![2; 2],
    group: 1,
    data_format: DataFormat::default(),
    kernel_format: KernelFormat::default(),
}),
)
.map_err(|_| Error::Synthesis)
@@ -1178,7 +1171,7 @@ mod conv_col_ultra_overflow {

use super::*;

const K: usize = 6;
const K: usize = 4;
const LEN: usize = 10;

#[derive(Clone)]
@@ -1198,10 +1191,9 @@ mod conv_col_ultra_overflow {
}

fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
    let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
    let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
    let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
    let _constant = VarTensor::constant_cols(cs, K, LEN * LEN * LEN * LEN, false);
    let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
    let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
    let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
    Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
}

@@ -1223,8 +1215,6 @@ mod conv_col_ultra_overflow {
    padding: vec![(1, 1); 2],
    stride: vec![2; 2],
    group: 1,
    data_format: DataFormat::default(),
    kernel_format: KernelFormat::default(),
}),
)
.map_err(|_| Error::Synthesis)
@@ -1382,8 +1372,6 @@ mod conv_relu_col_ultra_overflow {
    padding: vec![(1, 1); 2],
    stride: vec![2; 2],
    group: 1,
    data_format: DataFormat::default(),
    kernel_format: KernelFormat::default(),
}),
)
.map_err(|_| Error::Synthesis);
@@ -1788,18 +1776,13 @@ mod shuffle {

    let d = VarTensor::new_advice(cs, K, 1, LEN);
    let e = VarTensor::new_advice(cs, K, 1, LEN);
    let f: VarTensor = VarTensor::new_advice(cs, K, 1, LEN);

    let _constant = VarTensor::constant_cols(cs, K, LEN * NUM_LOOP, false);

    let mut config =
        Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
    config
        .configure_shuffles(
            cs,
            &[a.clone(), b.clone(), c.clone()],
            &[d.clone(), e.clone(), f.clone()],
        )
        .configure_shuffles(cs, &[a.clone(), b.clone()], &[d.clone(), e.clone()])
        .unwrap();
    config
}
@@ -1820,7 +1803,6 @@ mod shuffle {
    &mut region,
    &self.inputs[i],
    &self.references[i],
    layouts::SortCollisionMode::Unsorted,
)
.map_err(|_| Error::Synthesis)?;
}
@@ -2006,7 +1988,7 @@ mod add_with_overflow_and_poseidon {
    let base = BaseConfig::configure(cs, &[a, b], &output, CheckMode::SAFE);
    VarTensor::constant_cols(cs, K, 2, false);

    let poseidon = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::configure(cs, ());
    let poseidon = PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::configure(cs, ());

    MyCircuitConfig { base, poseidon }
}
@@ -2016,7 +1998,7 @@ mod add_with_overflow_and_poseidon {
    mut config: Self::Config,
    mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
    let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE> =
    let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, WIDTH> =
        PoseidonChip::new(config.poseidon.clone());

    let assigned_inputs_a =
@@ -2051,9 +2033,11 @@ mod add_with_overflow_and_poseidon {
    let b = (0..LEN)
        .map(|i| halo2curves::bn256::Fr::from(i as u64 + 1))
        .collect::<Vec<_>>();
    let commitment_a = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(a.clone()).unwrap()[0][0];
    let commitment_a =
        PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(a.clone()).unwrap()[0][0];

    let commitment_b = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(b.clone()).unwrap()[0][0];
    let commitment_b =
        PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(b.clone()).unwrap()[0][0];

    // parameters
    let a = Tensor::from(a.into_iter().map(Value::known));
@@ -2075,11 +2059,13 @@ mod add_with_overflow_and_poseidon {
    let b = (0..LEN)
        .map(|i| halo2curves::bn256::Fr::from(i as u64 + 1))
        .collect::<Vec<_>>();
    let commitment_a =
        PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(a.clone()).unwrap()[0][0] + Fr::one();
    let commitment_a = PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(a.clone())
        .unwrap()[0][0]
        + Fr::one();

    let commitment_b =
        PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(b.clone()).unwrap()[0][0] + Fr::one();
    let commitment_b = PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(b.clone())
        .unwrap()[0][0]
        + Fr::one();

    // parameters
    let a = Tensor::from(a.into_iter().map(Value::known));

@@ -1,6 +1,6 @@
use alloy::primitives::Address as H160;
use clap::{Command, Parser, Subcommand};
use clap_complete::{Generator, Shell, generate};
use clap_complete::{generate, Generator, Shell};
#[cfg(feature = "python-bindings")]
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
@@ -8,7 +8,7 @@ use std::path::PathBuf;
use std::str::FromStr;
use tosubcommand::{ToFlags, ToSubcommand};

use crate::{Commitments, RunArgs, pfsys::ProofType};
use crate::{pfsys::ProofType, Commitments, RunArgs};

use crate::circuit::CheckMode;
use crate::graph::TestDataSource;
@@ -83,7 +83,7 @@ pub const DEFAULT_VK_SOL: &str = "vk.sol";
/// Default VK abi path
pub const DEFAULT_VK_ABI: &str = "vk.abi";
/// Default scale rebase multipliers for calibration
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,10";
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
/// Default use reduced srs for verification
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
/// Default only check for range check rebase
@@ -360,13 +360,8 @@ pub fn get_styles() -> clap::builder::Styles {
}

/// Print completions for the given generator
pub fn print_completions<G: Generator>(r#gen: G, cmd: &mut Command) {
    generate(
        r#gen,
        cmd,
        cmd.get_name().to_string(),
        &mut std::io::stdout(),
    );
pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
    generate(gen, cmd, cmd.get_name().to_string(), &mut std::io::stdout());
}

#[allow(missing_docs)]
@@ -402,7 +397,7 @@ pub enum Commands {
GenWitness {
    /// The path to the .json data file
    #[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
    data: Option<String>,
    data: Option<PathBuf>,
    /// The path to the compiled model file (generated using the compile-circuit command)
    #[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
    compiled_circuit: Option<PathBuf>,
@@ -448,7 +443,7 @@ pub enum Commands {
CalibrateSettings {
    /// The path to the .json calibration data file.
    #[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE, value_hint = clap::ValueHint::FilePath)]
    data: Option<String>,
    data: Option<PathBuf>,
    /// The path to the .onnx model file
    #[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
    model: Option<PathBuf>,
@@ -632,7 +627,7 @@ pub enum Commands {
SetupTestEvmData {
    /// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
    #[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
    data: Option<String>,
    data: Option<PathBuf>,
    /// The path to the compiled model file (generated using the compile-circuit command)
    #[arg(short = 'M', long, value_hint = clap::ValueHint::FilePath)]
    compiled_circuit: Option<PathBuf>,
@@ -651,6 +646,19 @@ pub enum Commands {
    #[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
    output_source: TestDataSource,
},
/// The Data Attestation Verifier contract stores the account calls to fetch data to feed into ezkl. This call data can be updated by an admin account. This tests that the admin account is able to update this call data.
#[command(arg_required_else_help = true)]
TestUpdateAccountCalls {
    /// The path to the verifier contract's address
    #[arg(long, value_hint = clap::ValueHint::Other)]
    addr: H160Flag,
    /// The path to the .json data file.
    #[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
    data: Option<PathBuf>,
    /// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
    #[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
    rpc_url: Option<String>,
},
/// Swaps the positions in the transcript that correspond to commitments
SwapProofCommitments {
    /// The path to the proof file
@@ -763,7 +771,7 @@ pub enum Commands {
    /// view functions that return the data that the network
    /// ingests as inputs.
    #[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
    data: Option<String>,
    data: Option<PathBuf>,
    /// The path to the witness file. This is needed for proof swapping for kzg commitments.
    #[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
    witness: Option<PathBuf>,
@@ -859,7 +867,7 @@ pub enum Commands {
DeployEvmDataAttestation {
    /// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
    #[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
    data: Option<String>,
    data: Option<PathBuf>,
    /// The path to load circuit settings .json file from (generated using the gen-settings command)
    #[arg(long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
    settings_path: Option<PathBuf>,

693 src/eth.rs
File diff suppressed because one or more lines are too long

137 src/execute.rs
@@ -1,30 +1,32 @@
use crate::EZKL_BUF_CAPACITY;
use crate::circuit::CheckMode;
use crate::circuit::region::RegionSettings;
use crate::circuit::CheckMode;
use crate::commands::CalibrationTarget;
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity, fix_da_sol};
use crate::eth::{
deploy_contract_via_solidity, deploy_da_verifier_via_solidity, fix_da_multi_sol,
fix_da_single_sol,
};
#[allow(unused_imports)]
use crate::eth::{get_contract_artifacts, verify_proof_via_solidity};
use crate::graph::input::GraphData;
use crate::graph::input::{Calls, GraphData};
use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model};
use crate::graph::{TestDataSource, TestSources};
use crate::pfsys::evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript};
use crate::pfsys::{
ProofSplitCommit, create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit,
create_keys, load_pk, load_vk, save_params, save_pk, Snark, StrategyType, TranscriptType,
};
use crate::pfsys::{
Snark, StrategyType, TranscriptType, create_keys, load_pk, load_vk, save_params, save_pk,
create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit, ProofSplitCommit,
};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::EZKL_BUF_CAPACITY;
use crate::{commands::*, EZKLError};
use crate::{Commitments, RunArgs};
use crate::{EZKLError, commands::*};
use colored::Colorize;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
use halo2_proofs::plonk::{self, Circuit};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::poly::commitment::{CommitmentScheme, Params};
use halo2_proofs::poly::commitment::{ParamsProver, Verifier};
use halo2_proofs::poly::ipa::commitment::{IPACommitmentScheme, ParamsIPA};
@@ -37,6 +39,7 @@ use halo2_proofs::poly::kzg::strategy::AccumulatorStrategy as KZGAccumulatorStra
use halo2_proofs::poly::kzg::{
commitment::ParamsKZG, strategy::SingleStrategy as KZGSingleStrategy,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer};
use halo2_solidity_verifier;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
@@ -47,12 +50,12 @@ use instant::Instant;
use itertools::Itertools;
use log::debug;
use log::{info, trace, warn};
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde::Serialize;
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::Config;
use snark_verifier::system::halo2::compile;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use snark_verifier::system::halo2::Config;
use std::fs::File;
use std::io::BufWriter;
use std::io::{Cursor, Write};
@@ -115,7 +118,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
} => gen_srs_cmd(
srs_path,
logrows as u32,
commitment.unwrap_or_else(|| Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
),
Commands::GetSrs {
srs_path,
@@ -298,7 +301,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
input_source,
output_source,
} => {
setup_test_evm_data(
setup_test_evm_witness(
data.unwrap_or(DEFAULT_DATA.into()),
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
test_data,
@@ -308,6 +311,11 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
)
.await
}
Commands::TestUpdateAccountCalls {
addr,
data,
rpc_url,
} => test_update_account_calls(addr, data.unwrap_or(DEFAULT_DATA.into()), rpc_url).await,
Commands::SwapProofCommitments {
proof_path,
witness_path,
@@ -508,9 +516,7 @@ fn update_ezkl_binary(version: &Option<String>) -> Result<String, EZKLError> {
.status()
.is_err()
{
log::warn!(
"bash is not installed on this system, trying to run the install script with sh (may fail)"
);
log::warn!("bash is not installed on this system, trying to run the install script with sh (may fail)");
"sh"
} else {
"bash"
@@ -719,7 +725,7 @@ pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, EZKLErr

pub(crate) async fn gen_witness(
compiled_circuit_path: PathBuf,
data: String,
data: PathBuf,
output: Option<PathBuf>,
vk_path: Option<PathBuf>,
srs_path: Option<PathBuf>,
@@ -727,7 +733,7 @@ pub(crate) async fn gen_witness(
// these aren't real values so the sanity checks are mostly meaningless

let mut circuit = GraphCircuit::load(compiled_circuit_path)?;
let data = GraphData::from_str(&data)?;
let data: GraphData = GraphData::from_path(data)?;
let settings = circuit.settings().clone();

let vk = if let Some(vk) = vk_path {
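The hunks above migrate data inputs from stringly-typed `String` arguments to `PathBuf`, and from `GraphData::from_str` to `GraphData::from_path`. A minimal sketch of the pattern, using a hypothetical `Data` type rather than ezkl's real `GraphData`:

use std::path::PathBuf;

// Hypothetical stand-in for ezkl's GraphData; the real type lives in
// crate::graph::input and carries far more structure.
#[derive(Debug)]
struct Data(String);

impl Data {
    // Path-typed constructor: callers can no longer pass an arbitrary
    // string that only later fails to resolve to a file.
    fn from_path(path: PathBuf) -> std::io::Result<Self> {
        Ok(Self(std::fs::read_to_string(path)?))
    }
}

fn main() -> std::io::Result<()> {
    let d = Data::from_path(PathBuf::from("input.json"))?;
    println!("{:?}", d);
    Ok(())
}

Taking `PathBuf` moves the "is this a path?" question into the type signature instead of leaving it to a runtime parse.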
@@ -870,7 +876,7 @@ pub(crate) fn gen_random_data(

let mut tensor = TractTensor::zero::<f32>(sizes).unwrap();
let slice = tensor.as_slice_mut::<f32>().unwrap();
slice.iter_mut().for_each(|x| *x = rng.r#gen());
slice.iter_mut().for_each(|x| *x = rng.gen());
tensor.cast_to_dt(datum_type).unwrap().into_owned()
}

@@ -1038,7 +1044,7 @@ impl AccuracyResults {
#[allow(clippy::too_many_arguments)]
pub(crate) async fn calibrate(
model_path: PathBuf,
data: String,
data: PathBuf,
settings_path: PathBuf,
target: CalibrationTarget,
lookup_safety_margin: f64,
@@ -1052,7 +1058,7 @@ pub(crate) async fn calibrate(

use crate::fieldutils::IntegerRep;

let data = GraphData::from_str(&data)?;
let data = GraphData::from_path(data)?;
// load the pre-generated settings
let settings = GraphSettings::load(&settings_path)?;
// now retrieve the run args
@@ -1516,7 +1522,7 @@ pub(crate) async fn create_evm_data_attestation(
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
input: String,
input: PathBuf,
witness: Option<PathBuf>,
) -> Result<String, EZKLError> {
#[allow(unused_imports)]
@@ -1529,31 +1535,52 @@ pub(crate) async fn create_evm_data_attestation(
trace!("params computed");

// if input is not provided, we just instantiate dummy input data
let data =
GraphData::from_str(&input).unwrap_or_else(|_| GraphData::new(DataSource::File(vec![])));

debug!("data attestation data: {:?}", data);
let data = GraphData::from_path(input).unwrap_or(GraphData::new(DataSource::File(vec![])));

// The number of input and output instances we attest to for the single call data attestation
let mut input_len = None;
let mut output_len = None;

if let Some(DataSource::OnChain(source)) = data.output_data {
let output_data = if let Some(DataSource::OnChain(source)) = data.output_data {
if visibility.output.is_private() {
return Err("private output data is not supported on chain".into());
}
output_len = Some(source.call.decimals.len());
let mut on_chain_output_data = vec![];
match source.calls {
Calls::Multiple(calls) => {
for call in calls {
on_chain_output_data.push(call);
}
}
Calls::Single(call) => {
output_len = Some(call.len);
}
}
Some(on_chain_output_data)
} else {
None
};

if let DataSource::OnChain(source) = data.input_data {
let input_data = if let DataSource::OnChain(source) = data.input_data {
if visibility.input.is_private() {
return Err("private input data is not supported on chain".into());
}
input_len = Some(source.call.decimals.len());
let mut on_chain_input_data = vec![];
match source.calls {
Calls::Multiple(calls) => {
for call in calls {
on_chain_input_data.push(call);
}
}
Calls::Single(call) => {
input_len = Some(call.len);
}
}
Some(on_chain_input_data)
} else {
None
};

// If both model inputs and outputs are attested to then we

// Read the settings file. Look if either the run_args.input_visibility, run_args.output_visibility or run_args.param_visibility is KZGCommit
// if so, then we need to load the witness

@@ -1574,22 +1601,30 @@ pub(crate) async fn create_evm_data_attestation(
None
};

let output: String = fix_da_sol(
commitment_bytes,
input_len.is_none() && output_len.is_none(),
)?;
let mut f = File::create(sol_code_path.clone())?;
let _ = f.write(output.as_bytes());
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestation", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
// if either input_len or output_len is Some then we are in the single call data attestation mode
if input_len.is_some() || output_len.is_some() {
let output = fix_da_single_sol(input_len, output_len)?;
let mut f = File::create(sol_code_path.clone())?;
let _ = f.write(output.as_bytes());
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestationSingle", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
} else {
let output = fix_da_multi_sol(input_data, output_data, commitment_bytes)?;
let mut f = File::create(sol_code_path.clone())?;
let _ = f.write(output.as_bytes());
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestationMulti", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
}

Ok(String::new())
}
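The branch above selects between two generated attestation contracts: `input_len`/`output_len` are only populated for single-call data sources, so their presence picks `DataAttestationSingle`, otherwise `DataAttestationMulti`. A toy restatement of that dispatch (hypothetical helper, not part of ezkl):

// Mirrors the branch in create_evm_data_attestation: lengths are Some(_)
// only when the data came from a single on-chain call.
fn contract_name(input_len: Option<usize>, output_len: Option<usize>) -> &'static str {
    if input_len.is_some() || output_len.is_some() {
        "DataAttestationSingle"
    } else {
        "DataAttestationMulti"
    }
}

fn main() {
    assert_eq!(contract_name(Some(4), None), "DataAttestationSingle");
    assert_eq!(contract_name(None, None), "DataAttestationMulti");
}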

pub(crate) async fn deploy_da_evm(
data: String,
data: PathBuf,
settings_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
@@ -1831,8 +1866,8 @@ pub(crate) fn setup(
Ok(String::new())
}

pub(crate) async fn setup_test_evm_data(
data_path: String,
pub(crate) async fn setup_test_evm_witness(
data_path: PathBuf,
compiled_circuit_path: PathBuf,
test_data: PathBuf,
rpc_url: Option<String>,
@@ -1841,7 +1876,7 @@ pub(crate) async fn setup_test_evm_data(
) -> Result<String, EZKLError> {
use crate::graph::TestOnChainData;

let mut data = GraphData::from_str(&data_path)?;
let mut data = GraphData::from_path(data_path)?;
let mut circuit = GraphCircuit::load(compiled_circuit_path)?;

// if both input and output are from files fail
@@ -1867,6 +1902,17 @@ pub(crate) async fn setup_test_evm_data(
}

use crate::pfsys::ProofType;
pub(crate) async fn test_update_account_calls(
addr: H160Flag,
data: PathBuf,
rpc_url: Option<String>,
) -> Result<String, EZKLError> {
use crate::eth::update_account_calls;

update_account_calls(addr.into(), data, rpc_url.as_deref()).await?;

Ok(String::new())
}

#[allow(clippy::too_many_arguments)]
pub(crate) fn prove(
@@ -2080,7 +2126,6 @@ pub(crate) fn mock_aggregate(
Ok(String::new())
}

#[allow(clippy::too_many_arguments)]
pub(crate) fn setup_aggregate(
sample_snarks: Vec<PathBuf>,
vk_path: PathBuf,

src/fieldutils.rs
@@ -5,12 +5,10 @@ use halo2curves::ff::PrimeField;
/// Integer representation of a PrimeField element.
pub type IntegerRep = i128;

/// Converts an integer rep to a PrimeField element.
/// Converts an i64 to a PrimeField element.
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
if x >= 0 {
F::from_u128(x as u128)
} else if x == IntegerRep::MIN {
-F::from_u128(x.saturating_neg() as u128) - F::ONE
} else {
-F::from_u128(x.saturating_neg() as u128)
}
@@ -34,9 +32,6 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
/// Converts a PrimeField element to an i64.
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
if x > F::from_u128(IntegerRep::MAX as u128) {
if x == -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
return IntegerRep::MIN;
}
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
@@ -56,7 +51,7 @@ mod test {
use halo2curves::pasta::Fp as F;

#[test]
fn integerreptofelt() {
fn test_conv() {
let res: F = integer_rep_to_felt(-15);
assert_eq!(res, -F::from(15));

@@ -74,24 +69,8 @@ mod test {
fn felttointegerrep() {
for x in -(2_i128.pow(16))..(2_i128.pow(16)) {
let fieldx: F = integer_rep_to_felt::<F>(x);
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
let xf: i128 = felt_to_integer_rep::<F>(fieldx);
assert_eq!(x, xf);
}
}

#[test]
fn felttointegerrepmin() {
let x = IntegerRep::MIN;
let fieldx: F = integer_rep_to_felt::<F>(x);
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
assert_eq!(x, xf);
}

#[test]
fn felttointegerrepmax() {
let x = IntegerRep::MAX;
let fieldx: F = integer_rep_to_felt::<F>(x);
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
assert_eq!(x, xf);
}
}
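For readers following `integer_rep_to_felt`/`felt_to_integer_rep` above: negative integers are encoded as field negations, i.e. x < 0 maps to p - |x| modulo the field prime, and elements above the midpoint decode back as negatives. A self-contained sketch of that round trip over a toy 31-bit prime (an assumed analogue of the intent, not the halo2curves field types):

// Toy prime field: the same wrap-around encoding, over a small prime
// instead of the bn256/pasta scalar fields.
const P: i128 = 2_147_483_647; // 2^31 - 1, a Mersenne prime

fn int_to_felt(x: i128) -> i128 {
    // negative values map to p - |x|, i.e. -x in the field
    x.rem_euclid(P)
}

fn felt_to_int(f: i128) -> i128 {
    // elements above (p-1)/2 are read back as negatives
    if f > P / 2 { f - P } else { f }
}

fn main() {
    for x in [-15, -1, 0, 1, 15] {
        assert_eq!(felt_to_int(int_to_felt(x)), x);
    }
}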
src/graph/errors.rs
@@ -11,12 +11,6 @@ pub enum GraphError {
/// Shape mismatch in circuit construction
#[error("invalid dimensions used for node {0} ({1})")]
InvalidDims(usize, String),
/// Non scalar power
#[error("we only support scalar powers")]
NonScalarPower,
/// Non scalar base for exponentiation
#[error("we only support scalar bases for exponentiation")]
NonScalarBase,
/// Wrong method was called to configure an op
#[error("wrong method was called to configure node {0} ({1})")]
WrongMethod(usize, String),
@@ -33,7 +27,7 @@ pub enum GraphError {
#[error("a node is missing required params: {0}")]
MissingParams(String),
/// A node has misformed parameters
#[error("a node has misformed params: {0}")]
#[error("a node is has misformed params: {0}")]
MisformedParams(String),
/// Error in the configuration of the visibility of variables
#[error("there should be at least one set of public variables")]
@@ -119,13 +113,13 @@ pub enum GraphError {
/// Missing input for a node
#[error("missing input for node {0}")]
MissingInput(usize),
/// Ranges can only be constant
///
#[error("range only supports constant inputs in a zk circuit")]
NonConstantRange,
/// Trilu diagonal must be constant
///
#[error("trilu only supports constant diagonals in a zk circuit")]
NonConstantTrilu,
/// The witness was too short
///
#[error("insufficient witness values to generate a fixed output")]
InsufficientWitnessValues,
/// Missing scale
@@ -149,13 +143,4 @@ pub enum GraphError {
/// Invalid RunArg
#[error("invalid RunArgs: {0}")]
InvalidRunArgs(String),
/// Only nearest neighbor interpolation is supported
#[error("only nearest neighbor interpolation is supported")]
InvalidInterpolation,
/// Node has a missing output
#[error("node {0} has a missing output")]
MissingOutput(usize),
/// Insufficient advice columns
#[error("insufficient advice columns (need {0} at least)")]
InsufficientAdviceColumns(usize),
}

File diff suppressed because it is too large

src/graph/mod.rs
@@ -455,10 +455,6 @@ pub struct GraphSettings {
pub num_blinding_factors: Option<usize>,
/// unix time timestamp
pub timestamp: Option<u128>,
/// Model inputs types (if any)
pub input_types: Option<Vec<InputType>>,
/// Model outputs types (if any)
pub output_types: Option<Vec<InputType>>,
}

impl GraphSettings {
@@ -623,6 +619,11 @@ impl GraphSettings {
}
}

///
pub fn uses_modules(&self) -> bool {
!self.module_sizes.max_constraints() > 0
}
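One thing reviewers may want to double-check in the new `uses_modules` body: in Rust, `!` binds tighter than `>`, so `!self.module_sizes.max_constraints() > 0` parses as a bitwise NOT followed by a comparison, not as the negation of `max_constraints() > 0`. A sketch of the difference, assuming the count is an unsigned integer:

fn main() {
    let n: usize = 0;
    // `!` binds tighter than `>`: this is `(!n) > 0`, a bitwise NOT,
    // which is true for every n except usize::MAX.
    assert!(!n > 0);
    // the comparison most readers expect needs parentheses:
    assert!(!(n > 0));
}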
/// if any visibility is encrypted or hashed
pub fn module_requires_fixed(&self) -> bool {
self.run_args.input_visibility.is_hashed()
@@ -765,7 +766,7 @@ pub struct TestOnChainData {
pub data: std::path::PathBuf,
/// rpc endpoint
pub rpc: Option<String>,
/// data sources for the on chain data
///
pub data_sources: TestSources,
}

@@ -953,7 +954,7 @@ impl GraphCircuit {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
_ => Err(GraphError::OnChainDataSource),
_ => unreachable!("cannot load from on-chain data"),
}
}

@@ -1026,11 +1027,24 @@ impl GraphCircuit {
shapes: &Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
) -> Result<Vec<Tensor<Fp>>, GraphError> {
use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend};
use crate::eth::{
evm_quantize_multi, evm_quantize_single, read_on_chain_inputs_multi,
read_on_chain_inputs_single, setup_eth_backend,
};
let (client, client_address) = setup_eth_backend(Some(&source.rpc), None).await?;
let input = read_on_chain_inputs(client.clone(), client_address, &source.call).await?;
let quantized_evm_inputs =
evm_quantize(client, scales, &input, &source.call.decimals).await?;
let quantized_evm_inputs = match source.calls {
input::Calls::Single(call) => {
let (inputs, decimals) =
read_on_chain_inputs_single(client.clone(), client_address, call).await?;

evm_quantize_single(client, scales, &inputs, decimals).await?
}
input::Calls::Multiple(calls) => {
let inputs =
read_on_chain_inputs_multi(client.clone(), client_address, &calls).await?;
evm_quantize_multi(client, scales, &inputs).await?
}
};
// on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector
let mut inputs: Vec<Tensor<Fp>> = vec![];
for (input, shape) in [quantized_evm_inputs].iter().zip(shapes) {
@@ -1431,8 +1445,6 @@ impl GraphCircuit {
let output_scales = self.model().graph.get_output_scales()?;
let input_shapes = self.model().graph.input_shapes()?;
let output_shapes = self.model().graph.output_shapes()?;
let mut input_data = None;
let mut output_data = None;

if matches!(
test_on_chain_data.data_sources.input,
@@ -1443,12 +1455,23 @@ impl GraphCircuit {
return Err(GraphError::OnChainDataSource);
}

input_data = match &data.input_data {
DataSource::File(input_data) => Some(input_data),
let input_data = match &data.input_data {
DataSource::File(input_data) => input_data,
_ => {
return Err(GraphError::MissingDataSource);
return Err(GraphError::OnChainDataSource);
}
};
// Get the flatten length of input_data
// if the input source is a field then set scale to 0

let datam: (Vec<Tensor<Fp>>, OnChainSource) = OnChainSource::test_from_file_data(
input_data,
input_scales,
input_shapes,
test_on_chain_data.rpc.as_deref(),
)
.await?;
data.input_data = datam.1.into();
}
if matches!(
test_on_chain_data.data_sources.output,
@@ -1459,43 +1482,20 @@ impl GraphCircuit {
return Err(GraphError::OnChainDataSource);
}

output_data = match &data.output_data {
Some(DataSource::File(output_data)) => Some(output_data),
let output_data = match &data.output_data {
Some(DataSource::File(output_data)) => output_data,
Some(DataSource::OnChain(_)) => return Err(GraphError::OnChainDataSource),
_ => return Err(GraphError::MissingDataSource),
};
let datum: (Vec<Tensor<Fp>>, OnChainSource) = OnChainSource::test_from_file_data(
output_data,
output_scales,
output_shapes,
test_on_chain_data.rpc.as_deref(),
)
.await?;
data.output_data = Some(datum.1.into());
}
// Merge the input and output data
let mut file_data: Vec<Vec<input::FileSourceInner>> = vec![];
let mut scales: Vec<crate::Scale> = vec![];
let mut shapes: Vec<Vec<usize>> = vec![];
if let Some(input_data) = input_data {
file_data.extend(input_data.clone());
scales.extend(input_scales.clone());
shapes.extend(input_shapes.clone());
}
if let Some(output_data) = output_data {
file_data.extend(output_data.clone());
scales.extend(output_scales.clone());
shapes.extend(output_shapes.clone());
};
// print file data
debug!("file data: {:?}", file_data);

let on_chain_data: OnChainSource = OnChainSource::test_from_file_data(
&file_data,
scales,
shapes,
test_on_chain_data.rpc.as_deref(),
)
.await?;
// Here we update the GraphData struct with the on-chain data
if input_data.is_some() {
data.input_data = on_chain_data.clone().into();
}
if output_data.is_some() {
data.output_data = Some(on_chain_data.into());
}
debug!("test on-chain data: {:?}", data);
// Save the updated GraphData struct to the data_path
data.save(test_on_chain_data.data)?;
Ok(())

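The merge step removed above concatenated the file-sourced input and output tensors (plus their scales and shapes) so a single `OnChainSource` could back both sides. A stripped-down sketch of that concatenation, with hypothetical types:

// Minimal sketch of the old merge: input and output lists are flattened
// into one vector, in input-then-output order.
fn merge<T: Clone>(input: Option<&Vec<T>>, output: Option<&Vec<T>>) -> Vec<T> {
    let mut merged = vec![];
    if let Some(i) = input {
        merged.extend(i.clone());
    }
    if let Some(o) = output {
        merged.extend(o.clone());
    }
    merged
}

fn main() {
    let merged = merge(Some(&vec![1, 2]), Some(&vec![3]));
    assert_eq!(merged, vec![1, 2, 3]);
}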
src/graph/model.rs
@@ -1,6 +1,7 @@
use super::errors::GraphError;
use super::extract_const_quantized_values;
use super::node::*;
use super::scale_to_multiplier;
use super::vars::*;
use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
@@ -378,18 +379,13 @@ pub struct ParsedNodes {
pub nodes: BTreeMap<usize, NodeType>,
inputs: Vec<usize>,
outputs: Vec<Outlet>,
output_types: Vec<InputType>,
}

impl ParsedNodes {
/// Returns the output types of the computational graph.
pub fn get_output_types(&self) -> Vec<InputType> {
self.output_types.clone()
}

/// Returns the number of the computational graph's inputs
pub fn num_inputs(&self) -> usize {
self.inputs.len()
let input_nodes = self.inputs.iter();
input_nodes.len()
}

/// Input types
@@ -429,7 +425,8 @@ impl ParsedNodes {

/// Returns the number of the computational graph's outputs
pub fn num_outputs(&self) -> usize {
self.outputs.len()
let output_nodes = self.outputs.iter();
output_nodes.len()
}

/// Returns shapes of the computational graph's outputs
@@ -496,16 +493,6 @@ impl Model {
Ok(om)
}

/// Gets the input types from the parsed nodes
pub fn get_input_types(&self) -> Result<Vec<InputType>, GraphError> {
self.graph.get_input_types()
}

/// Gets the output types from the parsed nodes
pub fn get_output_types(&self) -> Vec<InputType> {
self.graph.get_output_types()
}

///
pub fn save(&self, path: PathBuf) -> Result<(), GraphError> {
let f = std::fs::File::create(&path).map_err(|e| {
@@ -589,11 +576,6 @@ impl Model {
required_range_checks: res.range_checks.into_iter().collect(),
model_output_scales: self.graph.get_output_scales()?,
model_input_scales: self.graph.get_input_scales(),
input_types: match self.get_input_types() {
Ok(x) => Some(x),
Err(_) => None,
},
output_types: Some(self.get_output_types()),
num_dynamic_lookups: res.num_dynamic_lookups,
total_dynamic_col_size: res.dynamic_lookup_col_coord,
num_shuffles: res.num_shuffles,
@@ -652,10 +634,6 @@ impl Model {

for (i, id) in model.clone().inputs.iter().enumerate() {
let input = model.node_mut(id.node);

if input.outputs.len() == 0 {
return Err(GraphError::MissingOutput(id.node));
}
let mut fact: InferenceFact = input.outputs[0].fact.clone();

for (i, x) in fact.clone().shape.dims().enumerate() {
@@ -724,11 +702,6 @@ impl Model {
nodes,
inputs: model.inputs.iter().map(|o| o.node).collect(),
outputs: model.outputs.iter().map(|o| (o.node, o.slot)).collect(),
output_types: model
.outputs
.iter()
.map(|o| Ok::<InputType, GraphError>(model.outlet_fact(*o)?.datum_type.into()))
.collect::<Result<Vec<_>, GraphError>>()?,
};

let duration = start_time.elapsed();
@@ -887,15 +860,6 @@ impl Model {
nodes: subgraph_nodes,
inputs: model.inputs.iter().map(|o| o.node).collect(),
outputs: model.outputs.iter().map(|o| (o.node, o.slot)).collect(),
output_types: model
.outputs
.iter()
.map(|o| {
Ok::<InputType, GraphError>(
model.outlet_fact(*o)?.datum_type.into(),
)
})
.collect::<Result<Vec<_>, GraphError>>()?,
};

let om = Model {
@@ -942,7 +906,6 @@ impl Model {
n.opkind = SupportedOp::Input(Input {
scale,
datum_type: inp.datum_type,
decomp: !run_args.ignore_range_check_inputs_outputs,
});
input_idx += 1;
n.out_scale = scale;
@@ -1053,10 +1016,6 @@ impl Model {
let required_lookups = settings.required_lookups.clone();
let required_range_checks = settings.required_range_checks.clone();

if vars.advices.len() < 3 {
return Err(GraphError::InsufficientAdviceColumns(3));
}

let mut base_gate = PolyConfig::configure(
meta,
vars.advices[0..2].try_into()?,
@@ -1076,10 +1035,6 @@ impl Model {
}

if settings.requires_dynamic_lookup() {
if vars.advices.len() < 6 {
return Err(GraphError::InsufficientAdviceColumns(6));
}

base_gate.configure_dynamic_lookup(
meta,
vars.advices[0..3].try_into()?,
@@ -1088,13 +1043,10 @@ impl Model {
}

if settings.requires_shuffle() {
if vars.advices.len() < 6 {
return Err(GraphError::InsufficientAdviceColumns(6));
}
base_gate.configure_shuffles(
meta,
vars.advices[0..3].try_into()?,
vars.advices[3..6].try_into()?,
vars.advices[0..2].try_into()?,
vars.advices[3..5].try_into()?,
)?;
}

@@ -1109,7 +1061,6 @@ impl Model {
/// * `vars` - The variables for the circuit.
/// * `witnessed_outputs` - The values to compare against.
/// * `constants` - The constants for the circuit.
#[allow(clippy::too_many_arguments)]
pub fn layout(
&self,
mut config: ModelConfig,
@@ -1172,10 +1123,17 @@ impl Model {
})?;

if run_args.output_visibility.is_public() || run_args.output_visibility.is_fixed() {
let output_scales = self.graph.get_output_scales().map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
let res = outputs
.iter()
.enumerate()
.map(|(i, output)| {
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();

let comparators = if run_args.output_visibility == Visibility::Public {
let res = vars
.instance
@@ -1197,9 +1155,7 @@ impl Model {
.layout(
&mut thread_safe_region,
&[output.clone(), comparators],
Box::new(HybridOp::Output {
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
Box::new(HybridOp::RangeCheck(tolerance)),
)
.map_err(|e| e.into())
})
@@ -1459,9 +1415,11 @@ impl Model {
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;

if self.visibility.output.is_public() || self.visibility.output.is_fixed() {
let output_scales = self.graph.get_output_scales()?;
let res = outputs
.iter()
.map(|output| {
.enumerate()
.map(|(i, output)| {
let mut comparator: ValTensor<Fp> = (0..output.len())
.map(|_| {
if !self.visibility.output.is_fixed() {
@@ -1474,12 +1432,13 @@ impl Model {
.into();
comparator.reshape(output.dims())?;

let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();

dummy_config.layout(
&mut region,
&[output.clone(), comparator],
Box::new(HybridOp::Output {
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
Box::new(HybridOp::RangeCheck(tolerance)),
)
})
.collect::<Result<Vec<_>, _>>();
@@ -1501,7 +1460,7 @@ impl Model {
.iter()
.map(|x| {
x.get_felt_evals()
.unwrap_or_else(|_| Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
.unwrap_or(Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
})
.collect();

@@ -1571,7 +1530,6 @@ impl Model {
let mut op = crate::circuit::Constant::new(
c.quantized_values.clone(),
c.raw_values.clone(),
c.decomp,
);
op.pre_assign(consts[const_idx].clone());
n.opkind = SupportedOp::Constant(op);
@@ -1599,16 +1557,4 @@ impl Model {
}
Ok(instance_shapes)
}

/// Input types of the computational graph's public inputs (if any)
pub fn instance_types(&self) -> Result<Vec<InputType>, GraphError> {
let mut instance_types = vec![];
if self.visibility.input.is_public() {
instance_types.extend(self.graph.get_input_types()?);
}
if self.visibility.output.is_public() {
instance_types.extend(self.graph.get_output_types());
}
Ok(instance_types)
}
}

src/graph/modules.rs
@@ -14,11 +14,14 @@ use serde::{Deserialize, Serialize};
use super::errors::GraphError;
use super::{VarVisibility, Visibility};

/// poseidon len to hash in tree
pub const POSEIDON_LEN_GRAPH: usize = 32;
/// Poseidon number of instances
pub const POSEIDON_INSTANCES: usize = 1;

/// Poseidon module type
pub type ModulePoseidon = PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>;
pub type ModulePoseidon =
    PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>;
/// Poseidon module config
pub type ModulePoseidonConfig = PoseidonConfig<POSEIDON_WIDTH, POSEIDON_RATE>;

@@ -281,6 +284,7 @@ impl GraphModules {
log::error!("Poseidon config not initialized");
return Err(Error::Synthesis);
}
// If the module is encrypted, then we need to encrypt the inputs
}

Ok(())

src/graph/node.rs
@@ -1,19 +1,10 @@
// Import dependencies for scaling operations
use super::scale_to_multiplier;

// Import ONNX-specific utilities when EZKL feature is enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::utilities::node_output_shapes;

// Import scale management types for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::VarScales;

// Import visibility settings for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::Visibility;

// Import operation types for different circuit components
use crate::circuit::hybrid::HybridOp;
use crate::circuit::lookup::LookupOp;
use crate::circuit::poly::PolyOp;
@@ -22,49 +13,28 @@ use crate::circuit::Constant;
use crate::circuit::Input;
use crate::circuit::Op;
use crate::circuit::Unknown;

// Import graph error types for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::errors::GraphError;

// Import ONNX operation conversion utilities
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::new_op_from_onnx;

// Import tensor error handling
use crate::tensor::TensorError;

// Import curve-specific field type
use halo2curves::bn256::Fr as Fp;

// Import logging for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use log::trace;

// Import serialization traits
use serde::Deserialize;
use serde::Serialize;

// Import data structures for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::collections::BTreeMap;

// Import formatting traits for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::fmt;

// Import table display formatting for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tabled::Tabled;

// Import ONNX-specific types and traits for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::{
self,
prelude::{Node as OnnxNode, SymbolValues, TypedFact, TypedOp},
};

/// Helper function to format vectors for display
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
if !v.is_empty() {
@@ -74,35 +44,29 @@ fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
}
}

/// Helper function to format operation kinds for display
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn display_opkind(v: &SupportedOp) -> String {
v.as_string()
}

/// A wrapper for an operation that has been rescaled to handle different precision requirements.
/// This enables operations to work with inputs that have been scaled to different fixed-point representations.
/// A wrapper for an operation that has been rescaled.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Rescaled {
/// The underlying operation that needs to be rescaled
/// The operation that has to be rescaled.
pub inner: Box<SupportedOp>,
/// Vector of (index, scale) pairs defining how each input should be scaled
/// The scale of the operation's inputs.
pub scale: Vec<(usize, u128)>,
}

/// Implementation of the Op trait for Rescaled operations
impl Op<Fp> for Rescaled {
/// Convert to Any type for runtime type checking
fn as_any(&self) -> &dyn std::any::Any {
self
}

/// Get string representation of the operation
fn as_string(&self) -> String {
format!("RESCALED INPUT ({})", self.inner.as_string())
}

/// Calculate output scale based on input scales
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let in_scales = in_scales
.into_iter()
@@ -113,7 +77,6 @@ impl Op<Fp> for Rescaled {
Op::<Fp>::out_scale(&*self.inner, in_scales)
}

/// Layout the operation in the circuit
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
@@ -130,40 +93,28 @@ impl Op<Fp> for Rescaled {
self.inner.layout(config, region, res)
}

/// Create a cloned boxed copy of this operation
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
Box::new(self.clone())
Box::new(self.clone()) // Forward to the derive(Clone) impl
}
}

/// A wrapper for operations that require scale rebasing
/// This handles cases where operation scales need to be adjusted to a target scale
/// while preserving the numerical relationships
/// A wrapper for an operation that has been rescaled.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RebaseScale {
/// The operation that needs to be rescaled
/// The operation that has to be rescaled.
pub inner: Box<SupportedOp>,
/// Operation used for rebasing, typically division
/// rebase op
pub rebase_op: HybridOp,
/// Scale that we're rebasing to
/// scale being rebased to
pub target_scale: i32,
/// Original scale of operation's inputs before rebasing
/// The original scale of the operation's inputs.
pub original_scale: i32,
/// Scaling multiplier used in rebasing
/// multiplier
pub multiplier: f64,
}
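A worked sketch of the bookkeeping `RebaseScale` records (assumed semantics, consistent with the field docs above: `multiplier` relates `original_scale` to `target_scale` as a power of two, and the rebase op divides by it):

// Rebasing a fixed-point value from the op's output scale down to the
// target scale: divide by 2^(original - target).
fn rebase(value_at_original: f64, original_scale: i32, target_scale: i32) -> f64 {
    let multiplier = f64::powi(2.0, original_scale - target_scale);
    value_at_original / multiplier
}

fn main() {
    // 3.5 encoded at scale 8 is 896; rebased to scale 4 it becomes 56,
    // which still decodes to 3.5 (56 / 2^4).
    assert_eq!(rebase(896.0, 8, 4), 56.0);
}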
impl RebaseScale {
/// Creates a rebased version of an operation if needed
///
/// # Arguments
/// * `inner` - Operation to potentially rebase
/// * `global_scale` - Base scale for the system
/// * `op_out_scale` - Current output scale of the operation
/// * `scale_rebase_multiplier` - Factor determining when rebasing should occur
///
/// # Returns
/// Original or rebased operation depending on scale relationships
pub fn rebase(
inner: SupportedOp,
global_scale: crate::Scale,
@@ -204,15 +155,7 @@ impl RebaseScale {
}
}

/// Creates a rebased operation with increased scale
///
/// # Arguments
/// * `inner` - Operation to potentially rebase
/// * `target_scale` - Scale to rebase to
/// * `op_out_scale` - Current output scale of the operation
///
/// # Returns
/// Original or rebased operation with increased scale
pub fn rebase_up(
inner: SupportedOp,
target_scale: crate::Scale,
@@ -249,12 +192,10 @@ impl RebaseScale {
}

impl Op<Fp> for RebaseScale {
/// Convert to Any type for runtime type checking
fn as_any(&self) -> &dyn std::any::Any {
self
}

/// Get string representation of the operation
fn as_string(&self) -> String {
format!(
"REBASED (div={:?}, rebasing_op={}) ({})",
@@ -264,12 +205,10 @@ impl Op<Fp> for RebaseScale {
)
}

/// Calculate output scale based on input scales
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(self.target_scale)
}

/// Layout the operation in the circuit
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
@@ -283,40 +222,34 @@ impl Op<Fp> for RebaseScale {
self.rebase_op.layout(config, region, &[original_res])
}

/// Create a cloned boxed copy of this operation
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
Box::new(self.clone())
Box::new(self.clone()) // Forward to the derive(Clone) impl
}
}

/// Represents all supported operation types in the circuit
/// Each variant encapsulates a different type of operation with specific behavior
/// A single operation in a [crate::graph::Model].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SupportedOp {
/// Linear operations (polynomial-based)
/// A linear operation.
Linear(PolyOp),
/// Nonlinear operations requiring lookup tables
/// A nonlinear operation.
Nonlinear(LookupOp),
/// Mixed operations combining different approaches
/// A hybrid operation.
Hybrid(HybridOp),
/// Input values to the circuit
///
Input(Input),
/// Constant values in the circuit
///
Constant(Constant<Fp>),
/// Placeholder for unsupported operations
///
Unknown(Unknown),
/// Operations requiring rescaling of inputs
///
Rescaled(Rescaled),
/// Operations requiring scale rebasing
///
RebaseScale(RebaseScale),
}

impl SupportedOp {
/// Checks if the operation is a lookup operation
///
/// # Returns
/// * `true` if operation requires lookup table
/// * `false` otherwise
pub fn is_lookup(&self) -> bool {
match self {
SupportedOp::Nonlinear(_) => true,
@@ -324,12 +257,7 @@ impl SupportedOp {
_ => false,
}
}

/// Returns input operation if this is an input
///
/// # Returns
/// * `Some(Input)` if this is an input operation
/// * `None` otherwise
pub fn get_input(&self) -> Option<Input> {
match self {
SupportedOp::Input(op) => Some(op.clone()),
@@ -337,11 +265,7 @@ impl SupportedOp {
}
}

/// Returns reference to rebased operation if this is a rebased operation
///
/// # Returns
/// * `Some(&RebaseScale)` if this is a rebased operation
/// * `None` otherwise
pub fn get_rebased(&self) -> Option<&RebaseScale> {
match self {
SupportedOp::RebaseScale(op) => Some(op),
@@ -349,11 +273,7 @@ impl SupportedOp {
}
}

/// Returns reference to lookup operation if this is a lookup operation
///
/// # Returns
/// * `Some(&LookupOp)` if this is a lookup operation
/// * `None` otherwise
pub fn get_lookup(&self) -> Option<&LookupOp> {
match self {
SupportedOp::Nonlinear(op) => Some(op),
@@ -361,11 +281,7 @@ impl SupportedOp {
}
}

/// Returns reference to constant if this is a constant
///
/// # Returns
/// * `Some(&Constant)` if this is a constant
/// * `None` otherwise
pub fn get_constant(&self) -> Option<&Constant<Fp>> {
match self {
SupportedOp::Constant(op) => Some(op),
@@ -373,11 +289,7 @@ impl SupportedOp {
}
}

/// Returns mutable reference to constant if this is a constant
///
/// # Returns
/// * `Some(&mut Constant)` if this is a constant
/// * `None` otherwise
pub fn get_mutable_constant(&mut self) -> Option<&mut Constant<Fp>> {
match self {
SupportedOp::Constant(op) => Some(op),
@@ -385,19 +297,18 @@ impl SupportedOp {
}
}

/// Creates a homogeneously rescaled version of this operation if needed
/// Only available with EZKL feature enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn homogenous_rescale(
&self,
in_scales: Vec<crate::Scale>,
) -> Result<Box<dyn Op<Fp>>, GraphError> {
let inputs_to_scale = self.requires_homogenous_input_scales();
// creates a rescaled op if the inputs are not homogenous
let op = self.clone_dyn();
super::homogenize_input_scales(op, in_scales, inputs_to_scale)
}

/// Returns reference to underlying Op implementation
/// Since each associated value of `SupportedOp` implements `Op`, let's define a helper method to retrieve it.
fn as_op(&self) -> &dyn Op<Fp> {
match self {
SupportedOp::Linear(op) => op,
@@ -411,10 +322,9 @@ impl SupportedOp {
}
}

/// Checks if this is an identity operation
///
/// check if is the identity operation
/// # Returns
/// * `true` if this operation passes input through unchanged
/// * `true` if the operation is the identity operation
/// * `false` otherwise
pub fn is_identity(&self) -> bool {
match self {
@@ -451,11 +361,9 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
if let Some(op) = value.as_any().downcast_ref::<Unknown>() {
return SupportedOp::Unknown(op.clone());
};

if let Some(op) = value.as_any().downcast_ref::<Rescaled>() {
return SupportedOp::Rescaled(op.clone());
};

if let Some(op) = value.as_any().downcast_ref::<RebaseScale>() {
return SupportedOp::RebaseScale(op.clone());
};
@@ -467,7 +375,6 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
}

impl Op<Fp> for SupportedOp {
/// Layout this operation in the circuit
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
@@ -477,61 +384,54 @@ impl Op<Fp> for SupportedOp {
self.as_op().layout(config, region, values)
}

/// Check if this is an input operation
fn is_input(&self) -> bool {
self.as_op().is_input()
}

/// Check if this is a constant operation
fn is_constant(&self) -> bool {
self.as_op().is_constant()
}

/// Get which inputs require homogeneous scales
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
self.as_op().requires_homogenous_input_scales()
}

/// Create a clone of this operation
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
self.as_op().clone_dyn()
}

/// Get string representation
fn as_string(&self) -> String {
self.as_op().as_string()
}

/// Convert to Any type
fn as_any(&self) -> &dyn std::any::Any {
self
}

/// Calculate output scale from input scales
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
self.as_op().out_scale(in_scales)
}
}

/// Represents a connection to another node's output
/// First element is node index, second is output slot index
/// A node's input is a tensor from another node's output.
pub type Outlet = (usize, usize);

/// Represents a single computational node in the circuit graph
/// Contains all information needed to execute and connect operations
/// A single operation in a [crate::graph::Model].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Node {
/// The operation this node performs
/// [Op] i.e what operation this node represents.
pub opkind: SupportedOp,
/// Fixed point scale factor for this node's output
/// The denominator in the fixed point representation for the node's output. Tensors of differing scales should not be combined.
pub out_scale: i32,
/// Connections to other nodes' outputs that serve as inputs
// Usually there is a simple in and out shape of the node as an operator. For example, an Affine node has three input_shapes (one for the input, weight, and bias),
// but in_dim is [in], out_dim is [out]
/// The indices of the node's inputs.
pub inputs: Vec<Outlet>,
/// Shape of this node's output tensor
/// Dimensions of output.
pub out_dims: Vec<usize>,
/// Unique identifier for this node
/// The node's unique identifier.
pub idx: usize,
/// Number of times this node's output is used
/// The node's num of uses
pub num_uses: usize,
}

@@ -569,19 +469,12 @@ impl PartialEq for Node {
}

impl Node {
/// Creates a new Node from an ONNX node
/// Only available when EZKL feature is enabled
///
/// # Arguments
/// * `node` - Source ONNX node
/// * `other_nodes` - Map of existing nodes in the graph
/// * `scales` - Scale factors for variables
/// * `idx` - Unique identifier for this node
/// * `symbol_values` - ONNX symbol values
/// * `run_args` - Runtime configuration arguments
///
/// # Returns
/// New Node instance or error if creation fails
/// Converts a tract [OnnxNode] into an ezkl [Node].
/// # Arguments:
/// * `node` - [OnnxNode]
/// * `other_nodes` - [BTreeMap] of other previously initialized [Node]s in the computational graph.
/// * `public_params` - flag if parameters of model are public
/// * `idx` - The node's unique identifier.
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[allow(clippy::too_many_arguments)]
pub fn new(
@@ -719,14 +612,16 @@ impl Node {
})
}

/// Check if this node performs softmax operation
/// check if it is a softmax node
pub fn is_softmax(&self) -> bool {
matches!(self.opkind, SupportedOp::Hybrid(HybridOp::Softmax { .. }))
if let SupportedOp::Hybrid(HybridOp::Softmax { .. }) = self.opkind {
true
} else {
false
}
}
}

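The two `is_softmax` bodies in the hunk above are behaviorally identical; `matches!` is shorthand for the if-let form. A minimal illustration with a toy enum:

#[allow(dead_code)]
enum Op {
    Softmax { axis: usize },
    Relu,
}

// Equivalent to: if let Op::Softmax { .. } = op { true } else { false }
fn is_softmax(op: &Op) -> bool {
    matches!(op, Op::Softmax { .. })
}

fn main() {
    assert!(is_softmax(&Op::Softmax { axis: 0 }));
    assert!(!is_softmax(&Op::Relu));
}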
/// Helper function to rescale constants that are only used once
/// Only available when EZKL feature is enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn rescale_const_with_single_use(
constant: &mut Constant<Fp>,

src/graph/utilities.rs
@@ -1,14 +1,14 @@
use super::errors::GraphError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::VarScales;
use super::errors::GraphError;
use super::{Rescaled, SupportedOp, Visibility};
use crate::circuit::Op;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::hybrid::HybridOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::lookup::LookupOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::poly::PolyOp;
use crate::circuit::Op;
use crate::fieldutils::IntegerRep;
use crate::tensor::{Tensor, TensorError, TensorType};
use halo2curves::bn256::Fr as Fp;
@@ -22,7 +22,6 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::ops::{
Downsample,
array::{
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
Slice, Topk,
@@ -32,6 +31,7 @@ use tract_onnx::tract_core::ops::{
einsum::EinSum,
element_wise::ElementWiseOp,
nn::{LeakyRelu, Reduce, Softmax},
Downsample,
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::{
@@ -39,15 +39,16 @@ use tract_onnx::tract_hir::{
ops::array::{Pad, PadMode, TypedConcat},
ops::cnn::PoolSpec,
ops::konst::Const,
ops::nn::DataFormat,
tract_core::ops::cast::Cast,
tract_core::ops::cnn::{MaxPool, SumPool},
tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
};

/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
/// NAN gets mapped to 0. INFINITY and NEG_INFINITY error out.
/// Quantizes an iterable of f32s to a [Tensor] of i32s using a fixed point representation.
/// Arguments
///
/// * `elem` - the element to quantize.
/// * `vec` - the vector to quantize.
/// * `dims` - the dimensionality of the resulting [Tensor].
/// * `shift` - offset used in the fixed point representation.
/// * `scale` - `2^scale` used in the fixed point representation.
pub fn quantize_float(
@@ -58,7 +59,7 @@ pub fn quantize_float(
let mult = scale_to_multiplier(scale);
let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation

if *elem > max_value || *elem < -max_value {
if *elem > max_value {
return Err(TensorError::SigBitTruncationError);
}

@@ -84,7 +85,7 @@ pub fn scale_to_multiplier(scale: crate::Scale) -> f64 {
f64::powf(2., scale as f64)
}

/// Converts a fixed point multiplier to a scale (log base 2).
/// Converts a scale (log base 2) to a fixed point multiplier.
pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
mult.log2().round() as crate::Scale
}
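The fixed-point scheme used by `quantize_float` and `scale_to_multiplier` boils down to q = round(x * 2^scale + shift), with a saturation guard against significant-bit truncation. A self-contained sketch (toy re-implementation under those assumptions, not the ezkl function itself):

// q = round(x * 2^scale + shift); reject values whose magnitude would
// truncate significant bits of the integer representation.
fn quantize(x: f64, shift: f64, scale: i32) -> Result<i128, &'static str> {
    let mult = f64::powi(2.0, scale);
    let max_value = ((i128::MAX as f64 - shift) / mult).round();
    if x > max_value || x < -max_value {
        return Err("significant bit truncation");
    }
    Ok((mult * x + shift).round() as i128)
}

fn main() {
    // 0.5 at scale 7 (multiplier 128) becomes 64
    assert_eq!(quantize(0.5, 0.0, 7), Ok(64));
}

`multiplier_to_scale` is the inverse direction: it recovers scale = round(log2(mult)).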
@@ -227,7 +228,10 @@ pub fn extract_tensor_value(
|
||||
.iter()
|
||||
.map(|x| match x.to_i64() {
|
||||
Ok(v) => Ok(v as f32),
|
||||
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
|
||||
Err(_) => match x.to_i64() {
|
||||
Ok(v) => Ok(v as f32),
|
||||
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -273,10 +277,12 @@ pub fn new_op_from_onnx(
|
||||
symbol_values: &SymbolValues,
|
||||
run_args: &crate::RunArgs,
|
||||
) -> Result<(SupportedOp, Vec<usize>), GraphError> {
|
||||
use crate::circuit::InputType;
|
||||
use std::f64::consts::E;
|
||||
|
||||
use tract_onnx::tract_core::ops::array::Trilu;
|
||||
|
||||
use crate::circuit::InputType;
|
||||
|
||||
let input_scales = inputs
|
||||
.iter()
|
||||
.flat_map(|x| x.out_scales())
|
||||
@@ -306,9 +312,6 @@ pub fn new_op_from_onnx(
|
||||
let mut deleted_indices = vec![];
|
||||
let node = match node.op().name().as_ref() {
|
||||
"ShiftLeft" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "shift left".to_string()));
|
||||
};
|
||||
// load shift amount
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
@@ -321,13 +324,10 @@ pub fn new_op_from_onnx(
|
||||
out_scale: Some(input_scales[0] - raw_values[0] as i32),
|
||||
})
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "shift left".to_string()));
|
||||
return Err(GraphError::OpMismatch(idx, "ShiftLeft".to_string()));
|
||||
}
|
||||
}
|
||||
"ShiftRight" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(GraphError::InvalidDims(idx, "shift right".to_string()));
|
||||
};
|
||||
// load shift amount
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
@@ -340,7 +340,7 @@ pub fn new_op_from_onnx(
|
||||
out_scale: Some(input_scales[0] + raw_values[0] as i32),
|
||||
})
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "shift right".to_string()));
|
||||
return Err(GraphError::OpMismatch(idx, "ShiftRight".to_string()));
|
||||
}
|
||||
}
|
||||
"MultiBroadcastTo" => {
|
||||
@@ -363,10 +363,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
|
||||
if input_ops.len() != 3 {
|
||||
return Err(GraphError::InvalidDims(idx, "range".to_string()));
|
||||
}
|
||||
|
||||
assert_eq!(input_ops.len(), 3, "Range requires 3 inputs");
|
||||
let input_ops = input_ops
|
||||
.iter()
|
||||
.map(|x| x.get_constant().ok_or(GraphError::NonConstantRange))
|
||||
@@ -381,11 +378,7 @@ pub fn new_op_from_onnx(
|
||||
// Quantize the raw value (integers)
|
||||
let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed)?;
|
||||
|
||||
let c = crate::circuit::ops::Constant::new(
|
||||
quantized_value,
|
||||
raw_value,
|
||||
!run_args.ignore_range_check_inputs_outputs,
|
||||
);
|
||||
let c = crate::circuit::ops::Constant::new(quantized_value, raw_value);
|
||||
// Create a constant op
|
||||
SupportedOp::Constant(c)
|
||||
}
@@ -426,10 +419,6 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
if inputs[0].out_dims().is_empty() || inputs[0].out_dims()[0].len() <= axis {
return Err(GraphError::InvalidDims(idx, "gather".to_string()));
}

op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::Gather {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| {
@@ -447,7 +436,6 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: false,
}));
inputs[1].bump_scale(0);
}
@@ -459,17 +447,8 @@ pub fn new_op_from_onnx(
"Topk" => {
let op = load_op::<Topk>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;

if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
};

// if param_visibility.is_public() {
let k = if let Some(c) = inputs[1].opkind().get_mutable_constant() {
if c.raw_values.len() != 1 {
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
}

inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
c.raw_values.map(|x| x as usize)[0]
@@ -509,10 +488,6 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "scatter elements".to_string()));
}

op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -524,7 +499,6 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -548,9 +522,6 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "scatter nd".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
@@ -561,7 +532,6 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -585,9 +555,6 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "gather nd".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
@@ -599,7 +566,6 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -623,9 +589,6 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "gather elements".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -637,7 +600,6 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
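
The Gather, Topk, ScatterElements, ScatterND, GatherND, and GatherElements hunks above all apply one pattern: when the index operand is a graph constant, its raw values are cast to usize, folded into the op as constant_idx (or indices), and the operand is retired via decrement_use plus deleted_indices; otherwise the operand is rewritten as a scale-0 TDim input, whose decomp flag the diff ties to the range-check setting. A hedged sketch of that decision in isolation (the types below are illustrative, not the ezkl API):

enum IndexSource {
    Constant(Vec<usize>),     // indices baked into the op at graph-load time
    Dynamic { decomp: bool }, // scale-0 TDim input, optionally range-checked
}

fn classify_indices(constant: Option<Vec<f32>>, range_check: bool) -> IndexSource {
    match constant {
        Some(raw) => IndexSource::Constant(raw.into_iter().map(|x| x as usize).collect()),
        None => IndexSource::Dynamic { decomp: range_check },
    }
}
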
@@ -712,11 +674,7 @@ pub fn new_op_from_onnx(
constant_scale,
&run_args.param_visibility,
)?;
let c = crate::circuit::ops::Constant::new(
quantized_value,
raw_value,
run_args.ignore_range_check_inputs_outputs,
);
let c = crate::circuit::ops::Constant::new(quantized_value, raw_value);
// Create a constant op
SupportedOp::Constant(c)
}
@@ -726,9 +684,7 @@ pub fn new_op_from_onnx(
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes: Vec<usize> = op.axes.into_iter().collect();
if axes.len() != 1 {
return Err(GraphError::InvalidDims(idx, "argmax".to_string()));
}
assert_eq!(axes.len(), 1, "only support argmax over one axis");

SupportedOp::Hybrid(HybridOp::ReduceArgMax { dim: axes[0] })
}
@@ -738,9 +694,7 @@ pub fn new_op_from_onnx(
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes: Vec<usize> = op.axes.into_iter().collect();
if axes.len() != 1 {
return Err(GraphError::InvalidDims(idx, "argmin".to_string()));
}
assert_eq!(axes.len(), 1, "only support argmin over one axis");

SupportedOp::Hybrid(HybridOp::ReduceArgMin { dim: axes[0] })
}
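
The paired lines in the ArgMax and ArgMin hunks show the two failure styles side by side: a recoverable GraphError::InvalidDims on one side, a panicking assert_eq! on the other. The difference in isolation (a sketch, not the crate's code):

fn single_axis_or_err(axes: &[usize]) -> Result<usize, String> {
    if axes.len() != 1 {
        return Err("argmax/argmin only supported over one axis".to_string());
    }
    Ok(axes[0]) // caller can surface this as a graph-loading error
}

fn single_axis_or_panic(axes: &[usize]) -> usize {
    assert_eq!(axes.len(), 1, "only support argmax over one axis");
    axes[0] // aborts the process on a malformed graph
}
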
@@ -849,9 +803,6 @@ pub fn new_op_from_onnx(
}
}
"Recip" => {
if inputs.len() != 1 {
return Err(GraphError::InvalidDims(idx, "recip".to_string()));
};
let in_scale = input_scales[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
// If the input scale is larger than the params scale
@@ -895,9 +846,6 @@ pub fn new_op_from_onnx(
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Rsqrt" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "rsqrt".to_string()));
};
let in_scale = input_scales[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
SupportedOp::Hybrid(HybridOp::Rsqrt {
@@ -979,19 +927,13 @@ pub fn new_op_from_onnx(
DatumType::F64 => (scales.input, InputType::F64),
_ => return Err(GraphError::UnsupportedDataType(idx, format!("{:?}", dt))),
};
SupportedOp::Input(crate::circuit::ops::Input {
scale,
datum_type,
decomp: !run_args.ignore_range_check_inputs_outputs,
})
SupportedOp::Input(crate::circuit::ops::Input { scale, datum_type })
}
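
The Source hunk above differs only in the Input constructor: one side carries a decomp field wired to !run_args.ignore_range_check_inputs_outputs, the other builds Input { scale, datum_type } without it. A sketch of the flag's apparent meaning under that reading (the struct here is illustrative, not the crate's Input type, and the range-check interpretation is an assumption):

struct InputSketch {
    scale: i32,
    datum_type: String,
    decomp: bool, // assumed: true => quantized input is decomposed and range-checked in-circuit
}

fn make_input(scale: i32, datum_type: &str, ignore_range_check: bool) -> InputSketch {
    InputSketch {
        scale,
        datum_type: datum_type.to_string(),
        decomp: !ignore_range_check, // the CLI flag turns the check off
    }
}
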
"Cast" => {
let op = load_op::<Cast>(node.op(), idx, node.op().name().to_string())?;
let dt = op.to;

if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "cast".to_string()));
};
assert_eq!(input_scales.len(), 1);

match dt {
DatumType::Bool
@@ -1041,11 +983,6 @@ pub fn new_op_from_onnx(

if const_idx.len() == 1 {
let const_idx = const_idx[0];

if inputs.len() <= const_idx {
return Err(GraphError::InvalidDims(idx, "mul".to_string()));
}

if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
if c.raw_values.len() == 1 && c.raw_values[0] < 1. {
// if not divisible by 2 then we need to add a range check
@@ -1120,9 +1057,6 @@ pub fn new_op_from_onnx(
return Err(GraphError::OpMismatch(idx, "softmax".to_string()));
}
};
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "softmax".to_string()));
}

let in_scale = input_scales[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
@@ -1145,6 +1079,13 @@ pub fn new_op_from_onnx(

let pool_spec: &PoolSpec = &sumpool_node.pool_spec;

// only support pytorch type formatting for now
if pool_spec.data_format != DataFormat::NCHW {
return Err(GraphError::MissingParams(
"data in wrong format".to_string(),
));
}

let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, &input_dims[0])?;
let kernel_shape = &pool_spec.kernel_shape;
@@ -1153,45 +1094,24 @@ pub fn new_op_from_onnx(
padding,
stride: stride.to_vec(),
pool_dims: kernel_shape.to_vec(),
data_format: pool_spec.data_format.into(),
})
}
"Ceil" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "ceil".to_string()));
}
SupportedOp::Hybrid(HybridOp::Ceil {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"Floor" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "floor".to_string()));
}
SupportedOp::Hybrid(HybridOp::Floor {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"Round" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "round".to_string()));
}
SupportedOp::Hybrid(HybridOp::Round {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"RoundHalfToEven" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "roundhalftoeven".to_string()));
}
SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"Ceil" => SupportedOp::Hybrid(HybridOp::Ceil {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"Floor" => SupportedOp::Hybrid(HybridOp::Floor {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"Round" => SupportedOp::Hybrid(HybridOp::Round {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"RoundHalfToEven" => SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"Sign" => SupportedOp::Linear(PolyOp::Sign),
"Pow" => {
// Extract the slope layer hyperparams from a const
@@ -1201,9 +1121,7 @@ pub fn new_op_from_onnx(
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.len() > 1 {
return Err(GraphError::NonScalarPower);
} else if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
unimplemented!("only support scalar pow")
}

let exponent = c.raw_values[0];
@@ -1220,9 +1138,7 @@ pub fn new_op_from_onnx(
inputs[0].decrement_use();
deleted_indices.push(0);
if c.raw_values.len() > 1 {
return Err(GraphError::NonScalarBase);
} else if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
unimplemented!("only support scalar base")
}

let base = c.raw_values[0];
@@ -1232,14 +1148,10 @@ pub fn new_op_from_onnx(
base: base.into(),
})
} else {
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
unimplemented!("only support constant base or pow for now")
}
}
"Div" => {
if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "div".to_string()));
}

let const_idx = inputs
.iter()
.enumerate()
@@ -1247,15 +1159,14 @@ pub fn new_op_from_onnx(
.map(|(i, _)| i)
.collect::<Vec<_>>();

if const_idx.len() > 1 || const_idx.is_empty() {
if const_idx.len() > 1 {
return Err(GraphError::InvalidDims(idx, "div".to_string()));
}

let const_idx = const_idx[0];

if const_idx != 1 {
return Err(GraphError::MisformedParams(
"only support div with constant as second input".to_string(),
));
unimplemented!("only support div with constant as second input")
}

if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
@@ -1265,28 +1176,14 @@ pub fn new_op_from_onnx(
// get the non constant index
let denom = c.raw_values[0];

let op = SupportedOp::Hybrid(HybridOp::Div {
SupportedOp::Hybrid(HybridOp::Div {
denom: denom.into(),
});

// if the input is scale 0 we re up to the max scale
if input_scales[0] == 0 {
SupportedOp::Rescaled(Rescaled {
inner: Box::new(op),
scale: vec![(0, scale_to_multiplier(scales.get_max()) as u128)],
})
} else {
op
}
})
} else {
return Err(GraphError::MisformedParams(
"only support non zero divisors of size 1".to_string(),
));
unimplemented!("only support non zero divisors of size 1")
}
} else {
return Err(GraphError::MisformedParams(
"only support div with constant as second input".to_string(),
));
unimplemented!("only support div with constant as second input")
}
}
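
In the Div hunk, one side wraps the HybridOp::Div in a Rescaled op whenever the numerator sits at scale 0, multiplying it up to the graph's maximum scale before dividing. A sketch of the multiplier arithmetic, assuming ezkl's convention scale_to_multiplier(s) = 2^s:

fn scale_to_multiplier(scale: i32) -> f64 {
    2f64.powi(scale)
}

fn div_rescale_factor(input_scale: i32, max_scale: i32) -> u128 {
    if input_scale == 0 {
        scale_to_multiplier(max_scale) as u128 // e.g. max scale 7 -> factor 128
    } else {
        1 // already at working precision, no rescale needed
    }
}
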
"Cube" => SupportedOp::Linear(PolyOp::Pow(3)),
@@ -1307,6 +1204,15 @@ pub fn new_op_from_onnx(
}
}

if ((conv_node.pool_spec.data_format != DataFormat::NCHW)
&& (conv_node.pool_spec.data_format != DataFormat::CHW))
|| (conv_node.kernel_fmt != KernelFormat::OIHW)
{
return Err(GraphError::MisformedParams(
"data or kernel in wrong format".to_string(),
));
}

let pool_spec = &conv_node.pool_spec;

let stride = extract_strides(pool_spec)?;
@@ -1334,8 +1240,6 @@ pub fn new_op_from_onnx(
padding,
stride,
group,
data_format: conv_node.pool_spec.data_format.into(),
kernel_format: conv_node.kernel_fmt.into(),
})
}
"Not" => SupportedOp::Linear(PolyOp::Not),
@@ -1359,6 +1263,14 @@ pub fn new_op_from_onnx(
}
}

if (deconv_node.pool_spec.data_format != DataFormat::NCHW)
|| (deconv_node.kernel_format != KernelFormat::OIHW)
{
return Err(GraphError::MisformedParams(
"data or kernel in wrong format".to_string(),
));
}

let pool_spec = &deconv_node.pool_spec;

let stride = extract_strides(pool_spec)?;
@@ -1384,8 +1296,6 @@ pub fn new_op_from_onnx(
output_padding: deconv_node.adjustments.to_vec(),
stride,
group: deconv_node.group,
data_format: deconv_node.pool_spec.data_format.into(),
kernel_format: deconv_node.kernel_format.into(),
})
}
"Downsample" => {
@@ -1398,7 +1308,7 @@ pub fn new_op_from_onnx(

SupportedOp::Linear(PolyOp::Downsample {
axis: downsample_node.axis,
stride: downsample_node.stride,
stride: downsample_node.stride as usize,
modulo: downsample_node.modulo,
})
}
@@ -1413,7 +1323,7 @@ pub fn new_op_from_onnx(
if !resize_node.contains("interpolator: Nearest")
&& !resize_node.contains("nearest: Floor")
{
return Err(GraphError::InvalidInterpolation);
unimplemented!("Only nearest neighbor interpolation is supported")
}
// check if optional scale factor is present
if inputs.len() != 2 && inputs.len() != 3 {
@@ -1469,6 +1379,13 @@ pub fn new_op_from_onnx(

let pool_spec: &PoolSpec = &sumpool_node.pool_spec;

// only support pytorch type formatting for now
if pool_spec.data_format != DataFormat::NCHW {
return Err(GraphError::MissingParams(
"data in wrong format".to_string(),
));
}

let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, &input_dims[0])?;

@@ -1477,7 +1394,6 @@ pub fn new_op_from_onnx(
stride: stride.to_vec(),
kernel_shape: pool_spec.kernel_shape.to_vec(),
normalized: sumpool_node.normalize,
data_format: pool_spec.data_format.into(),
})
}
"Pad" => {
@@ -1511,10 +1427,6 @@ pub fn new_op_from_onnx(
SupportedOp::Linear(PolyOp::Reshape(output_shape))
}
"Flatten" => {
if inputs.len() != 1 || inputs[0].out_dims().is_empty() {
return Err(GraphError::InvalidDims(idx, "flatten".to_string()));
};

let new_dims: Vec<usize> = vec![inputs[0].out_dims()[0].iter().product::<usize>()];
SupportedOp::Linear(PolyOp::Flatten(new_dims))
}
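
Flatten above is pure shape bookkeeping: every output dimension collapses into the product of the input dims. In isolation:

fn flatten_dims(out_dims: &[usize]) -> Vec<usize> {
    vec![out_dims.iter().product::<usize>()]
}

// flatten_dims(&[2, 3, 4]) == vec![24]
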
@@ -1588,10 +1500,12 @@ pub fn homogenize_input_scales(
input_scales: Vec<crate::Scale>,
inputs_to_scale: Vec<usize>,
) -> Result<Box<dyn Op<Fp>>, GraphError> {
let relevant_input_scales = inputs_to_scale
.iter()
.filter(|idx| input_scales.len() > **idx)
.map(|&idx| input_scales[idx])
let relevant_input_scales = input_scales
.clone()
.into_iter()
.enumerate()
.filter(|(idx, _)| inputs_to_scale.contains(idx))
.map(|(_, scale)| scale)
.collect_vec();

if inputs_to_scale.is_empty() {
@@ -1632,30 +1546,10 @@ pub fn homogenize_input_scales(
}
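
The two selection pipelines in the homogenize_input_scales hunk compute the same set of scales when every entry of inputs_to_scale is in range, but they differ in shape: one walks inputs_to_scale directly and bounds-filters (silently dropping out-of-range indices, output ordered by inputs_to_scale), the other enumerates input_scales and does a contains scan per element (O(n * m), output ordered by input position). Both in isolation, without claiming which side is the newer one:

fn by_target_index(input_scales: &[i32], inputs_to_scale: &[usize]) -> Vec<i32> {
    inputs_to_scale
        .iter()
        .filter(|idx| input_scales.len() > **idx) // drops bad indices
        .map(|&idx| input_scales[idx])
        .collect()
}

fn by_enumeration(input_scales: &[i32], inputs_to_scale: &[usize]) -> Vec<i32> {
    input_scales
        .iter()
        .enumerate()
        .filter(|(idx, _)| inputs_to_scale.contains(idx)) // linear scan per element
        .map(|(_, &scale)| scale)
        .collect()
}
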
#[cfg(test)]
/// tests for the utility module
pub mod tests {

use super::*;

// quantization tests
#[test]
fn test_quantize_tensor() {
let tensor: Tensor<f32> = (0..10).map(|x| x as f32).into();
let reference: Tensor<Fp> = (0..10).map(|x| x.into()).into();
let scale = 0;
let visibility = &Visibility::Public;
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility).unwrap();
assert_eq!(quantized.len(), 10);
assert_eq!(quantized, reference);
}

#[test]
fn test_quantize_edge_cases() {
assert_eq!(quantize_float(&f64::NAN, 0.0, 0).unwrap(), 0);
assert!(quantize_float(&f64::INFINITY, 0.0, 0).is_err());
assert!(quantize_float(&f64::NEG_INFINITY, 0.0, 0).is_err());
}

#[test]
fn test_flatten_valtensors() {
let tensor1: Tensor<Fp> = (0..10).map(|x| x.into()).into();

@@ -11,34 +11,35 @@ use log::debug;
use pyo3::{
exceptions::PyValueError, FromPyObject, IntoPy, PyObject, PyResult, Python, ToPyObject,
};

use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;

use self::errors::GraphError;

use super::*;

/// Defines the visibility level of values within the zero-knowledge circuit
/// Controls how values are handled during proof generation and verification
/// Label enum to track whether model input, model parameters, and model output are public, private, or hashed
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum Visibility {
/// Value is private to the prover and not included in proof
/// Mark an item as private to the prover (not in the proof submitted for verification)
#[default]
Private,
/// Value is public and included in proof for verification
/// Mark an item as public (sent in the proof submitted for verification)
Public,
/// Value is hashed and the hash is included in proof
/// Mark an item as publicly committed to (hash sent in the proof submitted for verification)
Hashed {
/// Controls how the hash is handled in proof
/// true - hash is included directly in proof (public)
/// false - hash is used as advice and passed to computational graph
/// Whether the hash is used as an instance (sent in the proof submitted for verification)
/// if false the hash is used as an advice (not in the proof submitted for verification) and is then sent to the computational graph
/// if true the hash is used as an instance (sent in the proof submitted for verification) the *inputs* to the hashing function are then sent to the computational graph
hash_is_public: bool,
/// Specifies which outputs this hash affects
///
outlets: Vec<usize>,
},
/// Value is committed using KZG commitment scheme
/// Mark an item as publicly committed to (KZG commitment sent in the proof submitted for verification)
KZGCommit,
/// Value is assigned as a constant in the circuit
/// assigned as a constant in the circuit
Fixed,
}
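
Both the From<&str> impl and the FromPyObject impl below recover the outlet list from strings of the form "hashed/private/0,1": split at the last '/', strip it, and parse the comma-separated tail. A standalone sketch (the split(',') continuation after trim_start_matches is an assumption, since the hunks truncate there):

fn parse_outlets(s: &str) -> Vec<usize> {
    let (_, outlets) = s.split_at(s.rfind('/').unwrap());
    outlets
        .trim_start_matches('/')
        .split(',')
        .filter_map(|x| x.parse::<usize>().ok())
        .collect()
}

// parse_outlets("hashed/private/0,1") == vec![0, 1]
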
@@ -65,17 +66,15 @@ impl Display for Visibility {

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Visibility {
/// Converts visibility to command line flags
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}

impl<'a> From<&'a str> for Visibility {
/// Converts string representation to Visibility
fn from(s: &'a str) -> Self {
if s.contains("hashed/private") {
// Split on last occurrence of '/'
// split on last occurrence of '/'
let (_, outlets) = s.split_at(s.rfind('/').unwrap());
let outlets = outlets
.trim_start_matches('/')
@@ -107,8 +106,8 @@ impl<'a> From<&'a str> for Visibility {
}

#[cfg(feature = "python-bindings")]
/// Converts Visibility into a PyObject (Required for Visibility to be compatible with Python)
impl IntoPy<PyObject> for Visibility {
/// Converts Visibility to Python object
fn into_py(self, py: Python) -> PyObject {
match self {
Visibility::Private => "private".to_object(py),
@@ -135,13 +134,14 @@ impl IntoPy<PyObject> for Visibility {
}

#[cfg(feature = "python-bindings")]
/// Obtains Visibility from PyObject (Required for Visibility to be compatible with Python)
impl<'source> FromPyObject<'source> for Visibility {
/// Extracts Visibility from Python object
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
let strval = String::extract_bound(ob)?;
let strval = strval.as_str();

if strval.contains("hashed/private") {
// split on last occurence of '/'
let (_, outlets) = strval.split_at(strval.rfind('/').unwrap());
let outlets = outlets
.trim_start_matches('/')
@@ -174,32 +174,29 @@ impl<'source> FromPyObject<'source> for Visibility {
}

impl Visibility {
/// Returns true if visibility is Fixed
#[allow(missing_docs)]
pub fn is_fixed(&self) -> bool {
matches!(&self, Visibility::Fixed)
}

/// Returns true if visibility is Private or hashed private
#[allow(missing_docs)]
pub fn is_private(&self) -> bool {
matches!(&self, Visibility::Private) || self.is_hashed_private()
}

/// Returns true if visibility is Public
#[allow(missing_docs)]
pub fn is_public(&self) -> bool {
matches!(&self, Visibility::Public)
}

/// Returns true if visibility involves hashing
#[allow(missing_docs)]
pub fn is_hashed(&self) -> bool {
matches!(&self, Visibility::Hashed { .. })
}

/// Returns true if visibility uses KZG commitment
#[allow(missing_docs)]
pub fn is_polycommit(&self) -> bool {
matches!(&self, Visibility::KZGCommit)
}

/// Returns true if visibility is hashed with public hash
#[allow(missing_docs)]
pub fn is_hashed_public(&self) -> bool {
if let Visibility::Hashed {
hash_is_public: true,
@@ -210,8 +207,7 @@ impl Visibility {
}
false
}

/// Returns true if visibility is hashed with private hash
#[allow(missing_docs)]
pub fn is_hashed_private(&self) -> bool {
if let Visibility::Hashed {
hash_is_public: false,
@@ -223,12 +219,11 @@ impl Visibility {
false
}

/// Returns true if visibility requires additional processing
#[allow(missing_docs)]
pub fn requires_processing(&self) -> bool {
matches!(&self, Visibility::Hashed { .. }) | matches!(&self, Visibility::KZGCommit)
}

/// Returns vector of output indices that this visibility setting affects
#[allow(missing_docs)]
pub fn overwrites_inputs(&self) -> Vec<usize> {
if let Visibility::Hashed { outlets, .. } = self {
return outlets.clone();
@@ -237,14 +232,14 @@ impl Visibility {
}
}

/// Manages scaling factors for different parts of the model
/// Represents the scale of the model input, model parameters.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)]
pub struct VarScales {
/// Scale factor for input values
///
pub input: crate::Scale,
/// Scale factor for parameter values
///
pub params: crate::Scale,
/// Multiplier for scale rebasing
///
pub rebase_multiplier: u32,
}

@@ -255,17 +250,17 @@ impl std::fmt::Display for VarScales {
}

impl VarScales {
/// Returns maximum scale value
///
pub fn get_max(&self) -> crate::Scale {
std::cmp::max(self.input, self.params)
}

/// Returns minimum scale value
///
pub fn get_min(&self) -> crate::Scale {
std::cmp::min(self.input, self.params)
}
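
get_max and get_min above bound the graph's working precision: with an input scale of 7 and a param scale of 10, a multiplication lands at scale 17, and rebasing (governed by rebase_multiplier) decides when that gets pulled back toward get_max(). A usage sketch, reading crate::Scale as the i32 alias ezkl uses:

fn demo_scales() {
    let (input, params): (i32, i32) = (7, 10);
    assert_eq!(std::cmp::max(input, params), 10); // what get_max() returns
    assert_eq!(std::cmp::min(input, params), 7);  // what get_min() returns
    // a product of scale-7 and scale-10 operands sits at scale 17 = 7 + 10
}
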

/// Creates VarScales from runtime arguments
/// Place in [VarScales] struct.
pub fn from_args(args: &RunArgs) -> Self {
Self {
input: args.input_scale,
@@ -275,17 +270,16 @@ impl VarScales {
}
}

/// Controls visibility settings for different parts of the model
/// Represents whether the model input, model parameters, and model output are Public or Private to the prover.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, PartialOrd)]
pub struct VarVisibility {
/// Visibility of model inputs
/// Input to the model or computational graph
pub input: Visibility,
/// Visibility of model parameters (weights, biases)
/// Parameters, such as weights and biases, in the model
pub params: Visibility,
/// Visibility of model outputs
/// Output of the model or computational graph
pub output: Visibility,
}

impl std::fmt::Display for VarVisibility {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
@@ -307,7 +301,8 @@ impl Default for VarVisibility {
}

impl VarVisibility {
/// Creates visibility settings from runtime arguments
/// Read from cli args whether the model input, model parameters, and model output are Public or Private to the prover.
/// Place in [VarVisibility] struct.
pub fn from_args(args: &RunArgs) -> Result<Self, GraphError> {
let input_vis = &args.input_visibility;
let params_vis = &args.param_visibility;
@@ -318,17 +313,17 @@ impl VarVisibility {
}

if !output_vis.is_public()
&& !params_vis.is_public()
&& !input_vis.is_public()
&& !output_vis.is_fixed()
&& !params_vis.is_fixed()
&& !input_vis.is_fixed()
&& !output_vis.is_hashed()
&& !params_vis.is_hashed()
&& !input_vis.is_hashed()
&& !output_vis.is_polycommit()
&& !params_vis.is_polycommit()
&& !input_vis.is_polycommit()
& !params_vis.is_public()
& !input_vis.is_public()
& !output_vis.is_fixed()
& !params_vis.is_fixed()
& !input_vis.is_fixed()
& !output_vis.is_hashed()
& !params_vis.is_hashed()
& !input_vis.is_hashed()
& !output_vis.is_polycommit()
& !params_vis.is_polycommit()
& !input_vis.is_polycommit()
{
return Err(GraphError::Visibility);
}
@@ -340,17 +335,17 @@ impl VarVisibility {
}
}

/// Container for circuit columns used by a model
/// A wrapper for holding all columns that will be assigned to by a model.
#[derive(Clone, Debug)]
pub struct ModelVars<F: PrimeField + TensorType + PartialOrd> {
/// Advice columns for circuit assignments
#[allow(missing_docs)]
pub advices: Vec<VarTensor>,
/// Optional instance column for public inputs
#[allow(missing_docs)]
pub instance: Option<ValTensor<F>>,
}

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
/// Gets reference to instance column if it exists
/// Get instance col
pub fn get_instance_col(&self) -> Option<&Column<Instance>> {
if let Some(instance) = &self.instance {
match instance {
@@ -362,14 +357,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
}
}

/// Sets initial offset for instance values
/// Set the initial instance offset
pub fn set_initial_instance_offset(&mut self, offset: usize) {
if let Some(instance) = &mut self.instance {
instance.set_initial_instance_offset(offset);
}
}

/// Gets total length of instance data
/// Get the total instance len
pub fn get_instance_len(&self) -> usize {
if let Some(instance) = &self.instance {
instance.get_total_instance_len()
@@ -378,21 +373,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
}
}

/// Increments instance index
/// Increment the instance offset
pub fn increment_instance_idx(&mut self) {
if let Some(instance) = &mut self.instance {
instance.increment_idx();
}
}

/// Sets instance index to specific value
/// Reset the instance offset
pub fn set_instance_idx(&mut self, val: usize) {
if let Some(instance) = &mut self.instance {
instance.set_idx(val);
}
}

/// Gets current instance index
/// Get the instance offset
pub fn get_instance_idx(&self) -> usize {
if let Some(instance) = &self.instance {
instance.get_idx()
@@ -401,7 +396,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
}
}

/// Initializes instance column with specified dimensions and scale
///
pub fn instantiate_instance(
&mut self,
cs: &mut ConstraintSystem<F>,
@@ -422,7 +417,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
};
}

/// Creates new ModelVars with allocated columns based on settings
/// Allocate all columns that will be assigned to by a model.
pub fn new(cs: &mut ConstraintSystem<F>, params: &GraphSettings) -> Self {
debug!("number of blinding factors: {}", cs.blinding_factors());

@@ -440,7 +435,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
.collect_vec();

if requires_dynamic_lookup || requires_shuffle {
let num_cols = 3;
let num_cols = if requires_dynamic_lookup { 3 } else { 2 };
for _ in 0..num_cols {
let dynamic_lookup =
VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_and_shuffle_size);

src/lib.rs

@@ -28,9 +28,6 @@

//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
use log::warn;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc as _;

/// Error type
// #[cfg_attr(not(feature = "ezkl"), derive(uniffi::Error))]
@@ -97,11 +94,12 @@ impl From<String> for EZKLError {

use std::str::FromStr;

use circuit::{table::Range, CheckMode};
use circuit::{table::Range, CheckMode, Tolerance};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::Args;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use fieldutils::IntegerRep;
use graph::{Visibility, MAX_PUBLIC_SRS};
use graph::Visibility;
use halo2_proofs::poly::{
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
};
@@ -167,6 +165,7 @@ pub mod srs_sha;
pub mod tensor;
#[cfg(feature = "ios-bindings")]
uniffi::setup_scaffolding!();

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use lazy_static::lazy_static;

@@ -181,9 +180,11 @@ lazy_static! {
.unwrap_or("8000".to_string())
.parse()
.unwrap();

/// The serialization format for the keys
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
.unwrap_or("raw-bytes".to_string());

}

#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
@@ -265,101 +266,80 @@ impl From<String> for Commitments {
}

/// Parameters specific to a proving run
///
/// RunArgs contains all configuration parameters needed to control the proving process,
/// including scaling factors, visibility settings, and circuit parameters.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, PartialOrd)]
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
derive(Args, ToFlags)
)]
pub struct RunArgs {
/// Fixed point scaling factor for quantizing inputs
/// Higher values provide more precision but increase circuit complexity
/// The tolerance for error on model outputs
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other))]
pub tolerance: Tolerance,
/// The denominator in the fixed point representation used when quantizing inputs
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other))]
pub input_scale: Scale,
/// Fixed point scaling factor for quantizing parameters
/// Higher values provide more precision but increase circuit complexity
/// The denominator in the fixed point representation used when quantizing parameters
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "7", value_hint = clap::ValueHint::Other))]
pub param_scale: Scale,
/// Scale rebase threshold multiplier
/// When scale exceeds input_scale * multiplier, it is rebased to input_scale
/// Advanced parameter that should be used with caution
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))]
/// if the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this a more advanced parameter, use with caution)
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))]
pub scale_rebase_multiplier: u32,
/// Range for lookup table input column values
/// Specified as (min, max) pair
/// The min and max elements in the lookup table input column
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'B', long, value_parser = parse_key_val::<IntegerRep, IntegerRep>, default_value = "-32768->32768"))]
pub lookup_range: Range,
/// Log2 of the number of rows in the circuit
/// Controls circuit size and proving time
/// The log_2 number of rows
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other))]
pub logrows: u32,
/// Number of inner columns per block
/// Affects circuit layout and efficiency
/// The log_2 number of rows
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'N', long, default_value = "2", value_hint = clap::ValueHint::Other))]
pub num_inner_cols: usize,
/// Graph variables for parameterizing the computation
/// Format: "name->value", e.g. "batch_size->1"
/// Hand-written parser for graph variables, eg. batch_size=1
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other))]
pub variables: Vec<(String, usize)>,
/// Visibility setting for input values
/// Controls whether inputs are public or private in the circuit
/// Flags whether inputs are public, private, fixed, hashed, polycommit
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
pub input_visibility: Visibility,
/// Visibility setting for output values
/// Controls whether outputs are public or private in the circuit
/// Flags whether outputs are public, private, fixed, hashed, polycommit
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "public", value_hint = clap::ValueHint::Other))]
pub output_visibility: Visibility,
/// Visibility setting for parameters
/// Controls how parameters are handled in the circuit
/// Flags whether params are fixed, private, hashed, polycommit
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
pub param_visibility: Visibility,
/// Whether to rebase constants with zero fractional part to scale 0
/// Can improve efficiency for integer constants
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
/// Should constants with 0.0 fraction be rebased to scale 0
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
pub rebase_frac_zero_constants: bool,
/// Circuit checking mode
/// Controls level of constraint verification
/// check mode (safe, unsafe, etc)
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "unsafe", value_hint = clap::ValueHint::Other))]
pub check_mode: CheckMode,
/// Commitment scheme for circuit proving
/// Affects proof size and verification time
/// commitment scheme
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other))]
pub commitment: Option<Commitments>,
/// Base for number decomposition
/// Must be a power of 2
/// the base used for decompositions
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "16384", value_hint = clap::ValueHint::Other))]
pub decomp_base: usize,
/// Number of decomposition legs
/// Controls decomposition granularity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "2", value_hint = clap::ValueHint::Other))]
/// the number of legs used for decompositions
pub decomp_legs: usize,
/// Whether to use bounded lookup for logarithm computation
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
/// use unbounded lookup for the log
pub bounded_log_lookup: bool,
/// Range check inputs and outputs (turn off if the inputs are felts)
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
pub ignore_range_check_inputs_outputs: bool,
}

impl Default for RunArgs {
/// Creates a new RunArgs instance with default values
///
/// Default configuration is optimized for common use cases
/// while maintaining reasonable proving time and circuit size
fn default() -> Self {
Self {
bounded_log_lookup: false,
tolerance: Tolerance::default(),
input_scale: 7,
param_scale: 7,
scale_rebase_multiplier: 1,
@@ -375,139 +355,54 @@ impl Default for RunArgs {
commitment: None,
decomp_base: 16384,
decomp_legs: 2,
ignore_range_check_inputs_outputs: false,
}
}
}

impl RunArgs {
/// Validates the RunArgs configuration
///
/// Performs comprehensive validation of all parameters to ensure they are within
/// acceptable ranges and follow required constraints. Returns accumulated errors
/// if any validations fail.
///
/// # Returns
/// - Ok(()) if all validations pass
/// - Err(String) with detailed error message if any validation fails
pub fn validate(&self) -> Result<(), String> {
let mut errors = Vec::new();

// check if the largest represented integer in the decomposed form overflows IntegerRep
// try it with the largest possible value
let max_decomp = (self.decomp_base as IntegerRep).checked_pow(self.decomp_legs as u32);
if max_decomp.is_none() {
errors.push(format!(
"decomp_base^decomp_legs overflows IntegerRep: {}^{}",
self.decomp_base, self.decomp_legs
));
}

// Visibility validations
if self.param_visibility == Visibility::Public {
errors.push(
"Parameters cannot be public instances. Use 'fixed' or 'kzgcommit' instead"
.to_string(),
return Err(
"params cannot be public instances, you are probably trying to use `fixed` or `kzgcommit`"
.into(),
);
}

// Scale validations
if self.scale_rebase_multiplier < 1 {
errors.push("scale_rebase_multiplier must be >= 1".to_string());
return Err("scale_rebase_multiplier must be >= 1".into());
}

// if any of the scales are too small
if self.input_scale < 8 || self.param_scale < 8 {
warn!("low scale values (<8) may impact precision");
}

// Lookup range validations
if self.lookup_range.0 > self.lookup_range.1 {
errors.push(format!(
"Invalid lookup range: min ({}) is greater than max ({})",
self.lookup_range.0, self.lookup_range.1
));
return Err("lookup_range min is greater than max".into());
}

// Size validations
if self.logrows < 1 {
errors.push("logrows must be >= 1".to_string());
return Err("logrows must be >= 1".into());
}

if self.num_inner_cols < 1 {
errors.push("num_inner_cols must be >= 1".to_string());
return Err("num_inner_cols must be >= 1".into());
}

let batch_size = self.variables.iter().find(|(name, _)| name == "batch_size");
if let Some(batch_size) = batch_size {
if batch_size.1 == 0 {
errors.push("'batch_size' cannot be 0".to_string());
}
}

// Decomposition validations
if self.decomp_base == 0 {
errors.push("decomp_base cannot be 0".to_string());
}

if self.decomp_legs == 0 {
errors.push("decomp_legs cannot be 0".to_string());
}

// Performance validations
if self.logrows > MAX_PUBLIC_SRS {
warn!("logrows exceeds maximum public SRS size");
}

// Performance warnings
if self.input_scale > 20 || self.param_scale > 20 {
warn!("High scale values (>20) may impact performance");
}

if errors.is_empty() {
Ok(())
} else {
Err(errors.join("\n"))
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
return Err("tolerance > 0.0 requires output_visibility to be public".into());
}
Ok(())
}
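
The two validate bodies interleaved above differ in strategy, not just wording: one accumulates every failed check into errors and reports them joined by newlines (which is what test_multiple_validation_errors below counts on), the other returns at the first failure and adds a tolerance/output-visibility check. The accumulation pattern in isolation, as a sketch:

fn validate_all(checks: &[(bool, String)]) -> Result<(), String> {
    let errors: Vec<String> = checks
        .iter()
        .filter(|(ok, _)| !*ok)
        .map(|(_, msg)| msg.clone())
        .collect();
    if errors.is_empty() {
        Ok(())
    } else {
        Err(errors.join("\n")) // one line per failed check
    }
}
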

/// Exports the configuration as JSON
///
/// Serializes the RunArgs instance to a JSON string
///
/// # Returns
/// * `Ok(String)` containing JSON representation
/// * `Err` if serialization fails
/// Export the ezkl configuration as json
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
let res = serde_json::to_string(&self)?;
Ok(res)
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
return Err(Box::new(e));
}
};
Ok(serialized)
}

/// Parses configuration from JSON
///
/// Deserializes a RunArgs instance from a JSON string
///
/// # Arguments
/// * `arg_json` - JSON string containing configuration
///
/// # Returns
/// * `Ok(RunArgs)` if parsing succeeds
/// * `Err` if parsing fails
/// Parse an ezkl configuration from a json
pub fn from_json(arg_json: &str) -> Result<Self, serde_json::Error> {
serde_json::from_str(arg_json)
}
}

// Additional helper functions for the module

/// Parses a key-value pair from a string in the format "key->value"
///
/// # Arguments
/// * `s` - Input string in the format "key->value"
///
/// # Returns
/// * `Ok((T, U))` - Parsed key and value
/// * `Err` - If parsing fails
/// Parse a single key-value pair
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn parse_key_val<T, U>(
s: &str,
@@ -520,15 +415,14 @@ where
{
let pos = s
.find("->")
.ok_or_else(|| format!("invalid KEY->VALUE: no `->` found in `{s}`"))?;
Ok((s[..pos].parse()?, s[pos + 2..].parse()?))
.ok_or_else(|| format!("invalid x->y: no `->` found in `{s}`"))?;
let a = s[..pos].parse()?;
let b = s[pos + 2..].parse()?;
Ok((a, b))
}
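
parse_key_val above keys on the first literal "->", which is why the lookup-range default "-32768->32768" parses cleanly: the leading minus sign is not followed by '>'. The split in isolation:

fn split_arrow(s: &str) -> Option<(&str, &str)> {
    let pos = s.find("->")?; // first occurrence of the two-character arrow
    Some((&s[..pos], &s[pos + 2..]))
}

// split_arrow("batch_size->1") == Some(("batch_size", "1"))
// split_arrow("-32768->32768") == Some(("-32768", "32768"))
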
|
||||
|
||||
/// Verifies that a version string matches the expected artifact version
|
||||
/// Logs warnings for version mismatches or unversioned artifacts
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `artifact_version` - Version string from the artifact
|
||||
/// Check if the version string matches the artifact version
|
||||
/// If the version string does not match the artifact version, log a warning
|
||||
pub fn check_version_string_matches(artifact_version: &str) {
|
||||
if artifact_version == "0.0.0"
|
||||
|| artifact_version == "source - no compatibility guaranteed"
|
||||
@@ -553,81 +447,3 @@ pub fn check_version_string_matches(artifact_version: &str) {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::field_reassign_with_default)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_valid_default_args() {
|
||||
let args = RunArgs::default();
|
||||
assert!(args.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_param_visibility() {
|
||||
let mut args = RunArgs::default();
|
||||
args.param_visibility = Visibility::Public;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("Parameters cannot be public instances"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_scale_rebase() {
|
||||
let mut args = RunArgs::default();
|
||||
args.scale_rebase_multiplier = 0;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("scale_rebase_multiplier must be >= 1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_lookup_range() {
|
||||
let mut args = RunArgs::default();
|
||||
args.lookup_range = (100, -100);
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("Invalid lookup range"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_logrows() {
|
||||
let mut args = RunArgs::default();
|
||||
args.logrows = 0;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("logrows must be >= 1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_inner_cols() {
|
||||
let mut args = RunArgs::default();
|
||||
args.num_inner_cols = 0;
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("num_inner_cols must be >= 1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zero_batch_size() {
|
||||
let mut args = RunArgs::default();
|
||||
args.variables = vec![("batch_size".to_string(), 0)];
|
||||
let err = args.validate().unwrap_err();
|
||||
assert!(err.contains("'batch_size' cannot be 0"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_serialization() {
|
||||
let args = RunArgs::default();
|
||||
let json = args.as_json().unwrap();
|
||||
let deserialized = RunArgs::from_json(&json).unwrap();
|
||||
assert_eq!(args, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_validation_errors() {
|
||||
let mut args = RunArgs::default();
|
||||
args.logrows = 0;
|
||||
args.lookup_range = (100, -100);
|
||||
let err = args.validate().unwrap_err();
|
||||
// Should contain multiple error messages
|
||||
assert!(err.matches("\n").count() >= 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,6 +133,7 @@ pub fn aggregate<'a>(
|
||||
.collect_vec()
|
||||
}));
|
||||
|
||||
// loader.ctx().constrain_equal(cell_0, cell_1)
|
||||
let mut transcript = PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, snark.proof());
|
||||
let proof = PlonkSuccinctVerifier::read_proof(svk, &protocol, &instances, &mut transcript)
|
||||
.map_err(|_| plonk::Error::Synthesis)?;
|
||||
@@ -308,11 +309,11 @@ impl AggregationCircuit {
|
||||
})
|
||||
}
|
||||
|
||||
/// Number of limbs used for decomposition
|
||||
///
|
||||
pub fn num_limbs() -> usize {
|
||||
LIMBS
|
||||
}
|
||||
/// Number of bits used for decomposition
|
||||
///
|
||||
pub fn num_bits() -> usize {
|
||||
BITS
|
||||
}
|
||||
|
||||
@@ -17,16 +17,16 @@ use crate::{Commitments, EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
|
||||
use clap::ValueEnum;
|
||||
use halo2_proofs::circuit::Value;
|
||||
use halo2_proofs::plonk::{
|
||||
Circuit, ProvingKey, VerifyingKey, create_proof, keygen_pk, keygen_vk_custom, verify_proof,
|
||||
create_proof, keygen_pk, keygen_vk_custom, verify_proof, Circuit, ProvingKey, VerifyingKey,
|
||||
};
|
||||
use halo2_proofs::poly::VerificationStrategy;
|
||||
use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier};
|
||||
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::VerificationStrategy;
|
||||
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer};
|
||||
use halo2curves::CurveAffine;
|
||||
use halo2curves::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
|
||||
use halo2curves::serde::SerdeObject;
|
||||
use halo2curves::CurveAffine;
|
||||
use instant::Instant;
|
||||
use log::{debug, info, trace};
|
||||
#[cfg(not(feature = "det-prove"))]
|
||||
@@ -51,9 +51,6 @@ use pyo3::types::PyDictMethods;
|
||||
|
||||
use halo2curves::bn256::{Bn256, Fr, G1Affine};
|
||||
|
||||
/// Converts a string to a `SerdeFormat`.
|
||||
/// # Panics
|
||||
/// Panics if the provided `s` is not a valid `SerdeFormat` (i.e. not one of "processed", "raw-bytes-unchecked", or "raw-bytes").
|
||||
fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
|
||||
match s {
|
||||
"processed" => halo2_proofs::SerdeFormat::Processed,
|
||||
@@ -324,7 +321,7 @@ where
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::{PyObject, Python, ToPyObject, types::PyDict};
|
||||
use pyo3::{types::PyDict, PyObject, Python, ToPyObject};
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl<F: PrimeField + SerdeObject + Serialize, C: CurveAffine + Serialize> ToPyObject for Snark<F, C>
|
||||
where
|
||||
@@ -348,15 +345,14 @@ where
|
||||
}
|
||||
|
||||
impl<
|
||||
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
|
||||
C: CurveAffine + Serialize + DeserializeOwned,
|
||||
> Snark<F, C>
|
||||
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
|
||||
C: CurveAffine + Serialize + DeserializeOwned,
|
||||
> Snark<F, C>
|
||||
where
|
||||
C::Scalar: Serialize + DeserializeOwned,
|
||||
C::ScalarExt: Serialize + DeserializeOwned,
|
||||
{
|
||||
/// Create a new application snark from proof and instance variables ready for aggregation
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
protocol: Option<PlonkProtocol<C>>,
|
||||
instances: Vec<Vec<F>>,
|
||||
@@ -532,6 +528,7 @@ pub fn create_keys<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
disable_selector_compression: bool,
|
||||
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
|
||||
{
|
||||
// Real proof
|
||||
@@ -797,6 +794,7 @@ pub fn load_vk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
params: <C as Circuit<Scheme::Scalar>>::Params,
|
||||
) -> Result<VerifyingKey<Scheme::Curve>, PfsysError>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
Scheme::Curve: SerdeObject + CurveAffine,
|
||||
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
|
||||
{
|
||||
@@ -819,6 +817,7 @@ pub fn load_pk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
params: <C as Circuit<Scheme::Scalar>>::Params,
|
||||
) -> Result<ProvingKey<Scheme::Curve>, PfsysError>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
Scheme::Curve: SerdeObject + CurveAffine,
|
||||
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
|
||||
{
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use thiserror::Error;
|
||||
|
||||
use super::{ops::DecompositionError, DataFormat};
|
||||
use super::ops::DecompositionError;
|
||||
|
||||
/// A wrapper for tensor related errors.
|
||||
#[derive(Debug, Error)]
|
||||
@@ -38,13 +38,4 @@ pub enum TensorError {
|
||||
/// Decomposition error
|
||||
#[error("decomposition error: {0}")]
|
||||
DecompositionError(#[from] DecompositionError),
|
||||
/// Invalid argument
|
||||
#[error("invalid argument: {0}")]
|
||||
InvalidArgument(String),
|
||||
/// Index out of bounds
|
||||
#[error("index {0} out of bounds for dimension {1}")]
|
||||
IndexOutOfBounds(usize, usize),
|
||||
/// Invalid data conversion
|
||||
#[error("invalid data conversion from format {0} to {1}")]
|
||||
InvalidDataConversion(DataFormat, DataFormat),
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ pub mod var;
|
||||
|
||||
pub use errors::TensorError;
|
||||
|
||||
use core::hash::Hash;
|
||||
use halo2curves::ff::PrimeField;
|
||||
use maybe_rayon::{
|
||||
prelude::{
|
||||
@@ -25,9 +24,12 @@ use std::path::PathBuf;
|
||||
pub use val::*;
|
||||
pub use var::*;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use instant::Instant;
|
||||
|
||||
use crate::{
|
||||
circuit::utils,
|
||||
fieldutils::{IntegerRep, integer_rep_to_felt},
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
graph::Visibility,
|
||||
};
|
||||
|
||||
@@ -38,6 +40,8 @@ use halo2_proofs::{
|
||||
poly::Rotation,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
#[cfg(feature = "metal")]
|
||||
use metal::{Device, MTLResourceOptions, MTLSize};
|
||||
use std::error::Error;
|
||||
use std::fmt::Debug;
|
||||
use std::io::Read;
|
||||
@@ -45,6 +49,31 @@ use std::iter::Iterator;
|
||||
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
|
||||
use std::{cmp::max, ops::Rem};
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
const LIB_DATA: &[u8] = include_bytes!("metal/tensor_ops.metallib");
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
lazy_static::lazy_static! {
|
||||
static ref DEVICE: Device = Device::system_default().expect("no device found");
|
||||
|
||||
static ref LIB: metal::Library = DEVICE.new_library_with_data(LIB_DATA).unwrap();
|
||||
|
||||
static ref QUEUE: metal::CommandQueue = DEVICE.new_command_queue();
|
||||
|
||||
static ref PIPELINES: HashMap<String, metal::ComputePipelineState> = {
|
||||
let mut map = HashMap::new();
|
||||
for name in ["add", "sub", "mul"] {
|
||||
let function = LIB.get_function(name, None).unwrap();
|
||||
let pipeline = DEVICE.new_compute_pipeline_state_with_function(&function).unwrap();
|
||||
map.insert(name.to_string(), pipeline);
|
||||
}
|
||||
map
|
||||
};
|
||||
}
|
||||
|
||||
/// The (inner) type of tensor elements.
|
||||
pub trait TensorType: Clone + Debug + 'static {
|
||||
/// Returns the zero value.
|
||||
@@ -62,7 +91,7 @@ pub trait TensorType: Clone + Debug + 'static {
 }

 macro_rules! tensor_type {
-    ($rust_type:ty, $tensor_type:ident, $zero:expr_2021, $one:expr_2021) => {
+    ($rust_type:ty, $tensor_type:ident, $zero:expr, $one:expr) => {
         impl TensorType for $rust_type {
             fn zero() -> Option<Self> {
                 Some($zero)
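For readers unfamiliar with the macro, one invocation expands to a plain trait impl. A sketch of what `tensor_type!(f32, F32, 0.0, 1.0)` would roughly produce — the `f32` instantiation and the `one()` method are assumptions extrapolated from the visible skeleton:

```rust
// Hypothetical expansion of `tensor_type!(f32, F32, 0.0, 1.0)`:
impl TensorType for f32 {
    fn zero() -> Option<Self> {
        Some(0.0) // $zero
    }
    fn one() -> Option<Self> {
        Some(1.0) // $one
    }
}
```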
@@ -415,7 +444,7 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
             Err(_) => {
                 return Err(TensorError::FileLoadError(
                     "Failed to read tensor".to_string(),
-                ));
+                ))
             }
         }
     }
@@ -804,12 +833,6 @@ impl<T: Clone + TensorType> Tensor<T> {
         num_repeats: usize,
         initial_offset: usize,
     ) -> Result<Tensor<T>, TensorError> {
-        if n == 0 {
-            return Err(TensorError::InvalidArgument(
-                "Cannot duplicate every 0th element".to_string(),
-            ));
-        }
-
         let mut inner: Vec<T> = Vec::with_capacity(self.inner.len());
         let mut offset = initial_offset;
         for (i, elem) in self.inner.clone().into_iter().enumerate() {
@@ -839,17 +862,11 @@ impl<T: Clone + TensorType> Tensor<T> {
         num_repeats: usize,
         initial_offset: usize,
     ) -> Result<Tensor<T>, TensorError> {
-        if n == 0 {
-            return Err(TensorError::InvalidArgument(
-                "Cannot remove every 0th element".to_string(),
-            ));
-        }
-
         // Pre-calculate capacity to avoid reallocations
         let estimated_size = self.inner.len() - (self.inner.len() / n) * num_repeats;
         let mut inner = Vec::with_capacity(estimated_size);

-        // Use iterator directly instead of creating intermediate collectionsif
+        // Use iterator directly instead of creating intermediate collections
         let mut i = 0;
         while i < self.inner.len() {
             // Add the current element
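The duplicate/remove pair is easiest to see on a plain vector. A simplified, self-contained sketch (fixing `num_repeats = 1` and `initial_offset = 0`; the real methods work on `Tensor<T>` with arbitrary repeats and offsets):

```rust
// Repeat the element just before every n-th position, then undo it.
fn duplicate_every_n<T: Clone>(v: &[T], n: usize) -> Vec<T> {
    let mut out = Vec::with_capacity(v.len() + v.len() / n);
    for (i, elem) in v.iter().enumerate() {
        if i > 0 && i % n == 0 {
            out.push(v[i - 1].clone()); // duplicate the previous element
        }
        out.push(elem.clone());
    }
    out
}

fn remove_every_n<T: Clone>(v: &[T], n: usize) -> Vec<T> {
    // After duplication the copies sit at indices 2, 5, 8, ... i.e. i % (n + 1) == n.
    v.iter()
        .enumerate()
        .filter(|(i, _)| i % (n + 1) != n)
        .map(|(_, e)| e.clone())
        .collect()
}

fn main() {
    let v = vec![1, 2, 3, 4, 5];
    let dup = duplicate_every_n(&v, 2);
    assert_eq!(dup, vec![1, 2, 2, 3, 4, 4, 5]);
    assert_eq!(remove_every_n(&dup, 2), v); // round-trips back to the original
}
```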
@@ -868,6 +885,7 @@ impl<T: Clone + TensorType> Tensor<T> {
     }

     /// Remove indices
+    /// WARN: assumes indices are in ascending order for speed
     /// ```
     /// use ezkl::tensor::Tensor;
     /// use ezkl::fieldutils::IntegerRep;
@@ -894,11 +912,7 @@ impl<T: Clone + TensorType> Tensor<T> {
         }
         // remove indices
         for elem in indices.iter().rev() {
-            if *elem < self.len() {
-                inner.remove(*elem);
-            } else {
-                return Err(TensorError::IndexOutOfBounds(*elem, self.len()));
-            }
+            inner.remove(*elem);
         }

         Tensor::new(Some(&inner), &[inner.len()])
@@ -926,9 +940,6 @@ impl<T: Clone + TensorType> Tensor<T> {
             ));
         }
         self.dims = vec![];
     }
-    if self.dims() == &[0] && new_dims.iter().product::<usize>() == 1 {
-        self.dims = Vec::from(new_dims);
-    } else {
     let product = if new_dims != [0] {
         new_dims.iter().product::<usize>()
@@ -1107,10 +1118,6 @@ impl<T: Clone + TensorType> Tensor<T> {
         let mut output = self.clone();
         output.reshape(shape)?;
         return Ok(output);
-    } else if self.dims() == &[0] && shape.iter().product::<usize>() == 1 {
-        let mut output = self.clone();
-        output.reshape(shape)?;
-        return Ok(output);
     }

     if self.dims().len() > shape.len() {
@@ -1261,7 +1268,7 @@ impl<T: Clone + TensorType> Tensor<T> {
             None => {
                 return Err(TensorError::DimError(
                     "Cannot get last element of empty tensor".to_string(),
-                ));
+                ))
             }
         };

@@ -1286,7 +1293,7 @@ impl<T: Clone + TensorType> Tensor<T> {
             None => {
                 return Err(TensorError::DimError(
                     "Cannot get first element of empty tensor".to_string(),
-                ));
+                ))
             }
         };
@@ -1397,6 +1404,10 @@ impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync> Ad
         let lhs = self.expand(&broadcasted_shape).unwrap();
         let rhs = rhs.expand(&broadcasted_shape).unwrap();

+        #[cfg(feature = "metal")]
+        let res = metal_tensor_op(&lhs, &rhs, "add");
+
+        #[cfg(not(feature = "metal"))]
         let res = {
             let mut res: Tensor<T> = lhs
                 .par_iter()
@@ -1494,6 +1505,10 @@ impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Su
         let lhs = self.expand(&broadcasted_shape).unwrap();
         let rhs = rhs.expand(&broadcasted_shape).unwrap();

+        #[cfg(feature = "metal")]
+        let res = metal_tensor_op(&lhs, &rhs, "sub");
+
+        #[cfg(not(feature = "metal"))]
        let res = {
            let mut res: Tensor<T> = lhs
                .par_iter()
@@ -1561,6 +1576,10 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mu
         let lhs = self.expand(&broadcasted_shape).unwrap();
         let rhs = rhs.expand(&broadcasted_shape).unwrap();

+        #[cfg(feature = "metal")]
+        let res = metal_tensor_op(&lhs, &rhs, "mul");
+
+        #[cfg(not(feature = "metal"))]
        let res = {
            let mut res: Tensor<T> = lhs
                .par_iter()
@@ -1666,9 +1685,7 @@ impl<T: TensorType + Div<Output = T> + std::marker::Send + std::marker::Sync> Di
 }

 // implement remainder
-impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + PartialEq> Rem
-    for Tensor<T>
-{
+impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Rem for Tensor<T> {
     type Output = Result<Tensor<T>, TensorError>;

     /// Elementwise remainder of a tensor with another tensor.
@@ -1697,24 +1714,9 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
         let mut lhs = self.expand(&broadcasted_shape).unwrap();
         let rhs = rhs.expand(&broadcasted_shape).unwrap();

-        lhs.par_iter_mut()
-            .zip(rhs)
-            .map(|(o, r)| match T::zero() {
-                Some(zero) => {
-                    if r != zero {
-                        *o = o.clone() % r;
-                        Ok(())
-                    } else {
-                        Err(TensorError::InvalidArgument(
-                            "Cannot divide by zero in remainder".to_string(),
-                        ))
-                    }
-                }
-                _ => Err(TensorError::InvalidArgument(
-                    "Undefined zero value".to_string(),
-                )),
-            })
-            .collect::<Result<Vec<_>, _>>()?;
+        lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
+            *o = o.clone() % r;
+        });

         Ok(lhs)
     }
@@ -1749,6 +1751,7 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
 /// assert_eq!(c, vec![2, 3]);
 ///
 /// ```
+
 pub fn get_broadcasted_shape(
     shape_a: &[usize],
     shape_b: &[usize],
@@ -1756,247 +1759,23 @@ pub fn get_broadcasted_shape(
     let num_dims_a = shape_a.len();
     let num_dims_b = shape_b.len();

-    if num_dims_a == num_dims_b {
-        let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
-        for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
-            let max_dim = dim_a.max(dim_b);
-            broadcasted_shape.push(*max_dim);
+    match (num_dims_a, num_dims_b) {
+        (a, b) if a == b => {
+            let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
+            for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
+                let max_dim = dim_a.max(dim_b);
+                broadcasted_shape.push(*max_dim);
+            }
+            Ok(broadcasted_shape)
+        }
-        Ok(broadcasted_shape)
-    } else if num_dims_a < num_dims_b {
-        Ok(shape_b.to_vec())
-    } else if num_dims_a > num_dims_b {
-        Ok(shape_a.to_vec())
-    } else {
-        Err(TensorError::DimError(
+        (a, b) if a < b => Ok(shape_b.to_vec()),
+        (a, b) if a > b => Ok(shape_a.to_vec()),
+        _ => Err(TensorError::DimError(
             "Unknown condition for broadcasting".to_string(),
-        ))
+        )),
     }
 }
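The same broadcasting rule, extracted into a runnable standalone sketch (with `String` standing in for `TensorError`):

```rust
// Equal ranks broadcast dimension-wise to the max; otherwise the higher-rank
// shape wins outright. This mirrors the match arms above.
fn broadcasted_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, String> {
    match (a.len(), b.len()) {
        (x, y) if x == y => Ok(a.iter().zip(b).map(|(da, db)| *da.max(db)).collect()),
        (x, y) if x < y => Ok(b.to_vec()),
        _ => Ok(a.to_vec()),
    }
}

fn main() {
    assert_eq!(broadcasted_shape(&[2, 1], &[2, 3]).unwrap(), vec![2, 3]);
    assert_eq!(broadcasted_shape(&[3], &[2, 3]).unwrap(), vec![2, 3]);
}
```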
-////////////////////////
-///
-
-/// The shape of data for some operations
-#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]
-pub enum DataFormat {
-    /// NCHW
-    #[default]
-    NCHW,
-    /// NHWC
-    NHWC,
-    /// CHW
-    CHW,
-    /// HWC
-    HWC,
-}
-
-// as str
-impl core::fmt::Display for DataFormat {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            DataFormat::NCHW => write!(f, "NCHW"),
-            DataFormat::NHWC => write!(f, "NHWC"),
-            DataFormat::CHW => write!(f, "CHW"),
-            DataFormat::HWC => write!(f, "HWC"),
-        }
-    }
-}
-
-impl DataFormat {
-    /// Get the format's canonical form
-    pub fn canonical(&self) -> DataFormat {
-        match self {
-            DataFormat::NHWC => DataFormat::NCHW,
-            DataFormat::HWC => DataFormat::CHW,
-            _ => self.clone(),
-        }
-    }
-
-    /// no batch dim
-    pub fn has_no_batch(&self) -> bool {
-        match self {
-            DataFormat::CHW | DataFormat::HWC => true,
-            _ => false,
-        }
-    }
-
-    /// Convert tensor to canonical format (NCHW or CHW)
-    pub fn to_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
-        &self,
-        tensor: &mut ValTensor<F>,
-    ) -> Result<(), TensorError> {
-        match self {
-            DataFormat::NHWC => {
-                // For ND: Move channels from last axis to position after batch
-                let ndims = tensor.dims().len();
-                if ndims > 2 {
-                    tensor.move_axis(ndims - 1, 1)?;
-                }
-            }
-            DataFormat::HWC => {
-                // For ND: Move channels from last axis to first position
-                let ndims = tensor.dims().len();
-                if ndims > 1 {
-                    tensor.move_axis(ndims - 1, 0)?;
-                }
-            }
-            _ => {} // NCHW/CHW are already in canonical format
-        }
-        Ok(())
-    }
-
-    /// Convert tensor from canonical format to target format
-    pub fn from_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
-        &self,
-        tensor: &mut ValTensor<F>,
-    ) -> Result<(), TensorError> {
-        match self {
-            DataFormat::NHWC => {
-                // Move channels from position 1 to end
-                let ndims = tensor.dims().len();
-                if ndims > 2 {
-                    tensor.move_axis(1, ndims - 1)?;
-                }
-            }
-            DataFormat::HWC => {
-                // Move channels from position 0 to end
-                let ndims = tensor.dims().len();
-                if ndims > 1 {
-                    tensor.move_axis(0, ndims - 1)?;
-                }
-            }
-            _ => {} // NCHW/CHW don't need conversion
-        }
-        Ok(())
-    }
-
-    /// Get the position of the channel dimension
-    pub fn get_channel_dim(&self, ndims: usize) -> usize {
-        match self {
-            DataFormat::NCHW => 1,
-            DataFormat::NHWC => ndims - 1,
-            DataFormat::CHW => 0,
-            DataFormat::HWC => ndims - 1,
-        }
-    }
-}
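The axis moves in `to_canonical` are easier to follow on a concrete shape. A shape-level sketch — the real method permutes the tensor's data with `move_axis`, not just its dims vector:

```rust
// Shape-level view of NHWC -> NCHW and HWC -> CHW: move the channel axis.
fn move_axis(dims: &mut Vec<usize>, from: usize, to: usize) {
    let d = dims.remove(from);
    dims.insert(to, d);
}

fn main() {
    // NHWC -> NCHW: channels go from the last axis to position 1 (after batch).
    let mut nhwc = vec![1, 28, 28, 3];
    move_axis(&mut nhwc, 3, 1);
    assert_eq!(nhwc, vec![1, 3, 28, 28]);

    // HWC -> CHW: channels go from the last axis to position 0.
    let mut hwc = vec![28, 28, 3];
    move_axis(&mut hwc, 2, 0);
    assert_eq!(hwc, vec![3, 28, 28]);
}
```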
-/// The shape of the kernel for some operations
-#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]
-pub enum KernelFormat {
-    /// HWIO
-    HWIO,
-    /// OIHW
-    #[default]
-    OIHW,
-    /// OHWI
-    OHWI,
-}
-
-impl core::fmt::Display for KernelFormat {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            KernelFormat::HWIO => write!(f, "HWIO"),
-            KernelFormat::OIHW => write!(f, "OIHW"),
-            KernelFormat::OHWI => write!(f, "OHWI"),
-        }
-    }
-}
-
-impl KernelFormat {
-    /// Get the format's canonical form
-    pub fn canonical(&self) -> KernelFormat {
-        match self {
-            KernelFormat::HWIO => KernelFormat::OIHW,
-            KernelFormat::OHWI => KernelFormat::OIHW,
-            _ => self.clone(),
-        }
-    }
-
-    /// Convert kernel to canonical format (OIHW)
-    pub fn to_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
-        &self,
-        kernel: &mut ValTensor<F>,
-    ) -> Result<(), TensorError> {
-        match self {
-            KernelFormat::HWIO => {
-                let kdims = kernel.dims().len();
-                // Move output channels from last to first
-                kernel.move_axis(kdims - 1, 0)?;
-                // Move input channels from new last to second position
-                kernel.move_axis(kdims - 1, 1)?;
-            }
-            KernelFormat::OHWI => {
-                let kdims = kernel.dims().len();
-                // Move input channels from last to second position
-                kernel.move_axis(kdims - 1, 1)?;
-            }
-            _ => {} // OIHW is already canonical
-        }
-        Ok(())
-    }
-
-    /// Convert kernel from canonical format to target format
-    pub fn from_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
-        &self,
-        kernel: &mut ValTensor<F>,
-    ) -> Result<(), TensorError> {
-        match self {
-            KernelFormat::HWIO => {
-                let kdims = kernel.dims().len();
-                // Move input channels from second position to last
-                kernel.move_axis(1, kdims - 1)?;
-                // Move output channels from first to last
-                kernel.move_axis(0, kdims - 1)?;
-            }
-            KernelFormat::OHWI => {
-                let kdims = kernel.dims().len();
-                // Move input channels from second position to last
-                kernel.move_axis(1, kdims - 1)?;
-            }
-            _ => {} // OIHW doesn't need conversion
-        }
-        Ok(())
-    }
-
-    /// Get the position of input and output channel dimensions
-    pub fn get_channel_dims(&self, ndims: usize) -> (usize, usize) {
-        // (input_ch, output_ch)
-        match self {
-            KernelFormat::OIHW => (1, 0),
-            KernelFormat::HWIO => (ndims - 2, ndims - 1),
-            KernelFormat::OHWI => (ndims - 1, 0),
-        }
-    }
-}
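Likewise for kernels, a quick check of `get_channel_dims` against a rank-4 kernel in each layout (using string tags in place of the enum to keep the sketch standalone):

```rust
// (input_ch, output_ch) axis positions for a rank-4 kernel, as in the match above.
fn get_channel_dims(fmt: &str, ndims: usize) -> (usize, usize) {
    match fmt {
        "OIHW" => (1, 0),
        "HWIO" => (ndims - 2, ndims - 1),
        "OHWI" => (ndims - 1, 0),
        _ => unreachable!(),
    }
}

fn main() {
    // A kernel with O=8 output channels, I=4 input channels, 3x3 spatial dims:
    assert_eq!(get_channel_dims("OIHW", 4), (1, 0)); // shape [8, 4, 3, 3]
    assert_eq!(get_channel_dims("HWIO", 4), (2, 3)); // shape [3, 3, 4, 8]
    assert_eq!(get_channel_dims("OHWI", 4), (3, 0)); // shape [8, 3, 3, 4]
}
```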
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl From<tract_onnx::tract_hir::ops::nn::DataFormat> for DataFormat {
|
||||
fn from(fmt: tract_onnx::tract_hir::ops::nn::DataFormat) -> Self {
|
||||
match fmt {
|
||||
tract_onnx::tract_hir::ops::nn::DataFormat::NCHW => DataFormat::NCHW,
|
||||
tract_onnx::tract_hir::ops::nn::DataFormat::NHWC => DataFormat::NHWC,
|
||||
tract_onnx::tract_hir::ops::nn::DataFormat::CHW => DataFormat::CHW,
|
||||
tract_onnx::tract_hir::ops::nn::DataFormat::HWC => DataFormat::HWC,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl From<tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat> for KernelFormat {
|
||||
fn from(fmt: tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat) -> Self {
|
||||
match fmt {
|
||||
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::HWIO => {
|
||||
KernelFormat::HWIO
|
||||
}
|
||||
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::OIHW => {
|
||||
KernelFormat::OIHW
|
||||
}
|
||||
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::OHWI => {
|
||||
KernelFormat::OHWI
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
 #[cfg(test)]
 mod tests {

@@ -2032,4 +1811,66 @@ mod tests {
         let b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
         assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
     }

+    #[test]
+    #[cfg(feature = "metal")]
+    fn tensor_metal_int() {
+        let a = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
+        let b = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
+        let c = metal_tensor_op(&a, &b, "add");
+        assert_eq!(c, Tensor::new(Some(&[2, 4, 6, 8]), &[2, 2]).unwrap());
+
+        let c = metal_tensor_op(&a, &b, "sub");
+        assert_eq!(c, Tensor::new(Some(&[0, 0, 0, 0]), &[2, 2]).unwrap());
+
+        let c = metal_tensor_op(&a, &b, "mul");
+        assert_eq!(c, Tensor::new(Some(&[1, 4, 9, 16]), &[2, 2]).unwrap());
+    }
+
+    #[test]
+    #[cfg(feature = "metal")]
+    fn tensor_metal_felt() {
+        use halo2curves::bn256::Fr;
+
+        let a = Tensor::<Fr>::new(
+            Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
+            &[2, 2],
+        )
+        .unwrap();
+        let b = Tensor::<Fr>::new(
+            Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
+            &[2, 2],
+        )
+        .unwrap();
+
+        let c = metal_tensor_op(&a, &b, "add");
+        assert_eq!(
+            c,
+            Tensor::<Fr>::new(
+                Some(&[Fr::from(2), Fr::from(4), Fr::from(6), Fr::from(8)]),
+                &[2, 2],
+            )
+            .unwrap()
+        );
+
+        let c = metal_tensor_op(&a, &b, "sub");
+        assert_eq!(
+            c,
+            Tensor::<Fr>::new(
+                Some(&[Fr::from(0), Fr::from(0), Fr::from(0), Fr::from(0)]),
+                &[2, 2],
+            )
+            .unwrap()
+        );
+
+        let c = metal_tensor_op(&a, &b, "mul");
+        assert_eq!(
+            c,
+            Tensor::<Fr>::new(
+                Some(&[Fr::from(1), Fr::from(4), Fr::from(9), Fr::from(16)]),
+                &[2, 2],
+            )
+            .unwrap()
+        );
+    }
+}
@@ -27,7 +27,7 @@ pub fn get_rep(
     n: usize,
 ) -> Result<Vec<IntegerRep>, DecompositionError> {
     // check if x is too large
-    if (*x).abs() > ((base as i128).pow(n as u32)) - 1 {
+    if x.abs() > (base.pow(n as u32) as IntegerRep) - 1 {
         return Err(DecompositionError::TooLarge(*x, base, n));
     }
     let mut rep = vec![0; n + 1];
@@ -43,8 +43,8 @@ pub fn get_rep(
     let mut x = x.abs();
     //
     for i in (1..rep.len()).rev() {
-        rep[i] = x % base as IntegerRep;
-        x /= base as IntegerRep;
+        rep[i] = x % base as i128;
+        x /= base as i128;
     }

     Ok(rep)
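Putting the two hunks together: `rep` is filled least-significant-last, with index 0 left for a sign slot. A self-contained sketch of the decomposition under that assumption (the sign handling is extrapolated, since those lines fall outside the hunk):

```rust
// Decompose |x| into n base-`base` digits, most significant first.
// Index 0 holds the sign (an assumption; the hunk only shows the digit loop).
fn get_rep(x: i128, base: usize, n: usize) -> Result<Vec<i128>, String> {
    if x.abs() > (base as i128).pow(n as u32) - 1 {
        return Err(format!("{} too large for {} legs of base {}", x, n, base));
    }
    let mut rep = vec![0i128; n + 1];
    rep[0] = x.signum();
    let mut x = x.abs();
    for i in (1..rep.len()).rev() {
        rep[i] = x % base as i128;
        x /= base as i128;
    }
    Ok(rep)
}

fn main() {
    // 300 = 2 * 128 + 44 in base 128, with a leading +1 sign slot.
    assert_eq!(get_rep(300, 128, 2).unwrap(), vec![1, 2, 44]);
    assert_eq!(get_rep(-300, 128, 2).unwrap(), vec![-1, 2, 44]);
}
```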
@@ -127,7 +127,7 @@ pub fn decompose(
         .flatten()
         .collect::<Vec<IntegerRep>>();

-    let output = Tensor::<IntegerRep>::new(Some(&resp), &dims)?;
+    let output = Tensor::<i128>::new(Some(&resp), &dims)?;

     Ok(output)
 }
@@ -385,12 +385,6 @@ pub fn resize<T: TensorType + Send + Sync>(
 pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync>(
     t: &[Tensor<T>],
 ) -> Result<Tensor<T>, TensorError> {
-    if t.len() == 1 {
-        return Ok(t[0].clone());
-    } else if t.is_empty() {
-        return Err(TensorError::DimMismatch("add".to_string()));
-    }
-
     // calculate value of output
     let mut output: Tensor<T> = t[0].clone();

@@ -439,11 +433,6 @@ pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sy
 pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync>(
     t: &[Tensor<T>],
 ) -> Result<Tensor<T>, TensorError> {
-    if t.len() == 1 {
-        return Ok(t[0].clone());
-    } else if t.is_empty() {
-        return Err(TensorError::DimMismatch("sub".to_string()));
-    }
     // calculate value of output
     let mut output: Tensor<T> = t[0].clone();

@@ -490,11 +479,6 @@ pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sy
 pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync>(
     t: &[Tensor<T>],
 ) -> Result<Tensor<T>, TensorError> {
-    if t.len() == 1 {
-        return Ok(t[0].clone());
-    } else if t.is_empty() {
-        return Err(TensorError::DimMismatch("mult".to_string()));
-    }
     // calculate value of output
     let mut output: Tensor<T> = t[0].clone();
@@ -535,101 +519,30 @@ pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::S
 /// let result = downsample(&x, 1, 2, 2).unwrap();
 /// let expected = Tensor::<IntegerRep>::new(Some(&[3, 6]), &[2, 1]).unwrap();
 /// assert_eq!(result, expected);
-/// let x = Tensor::<IntegerRep>::new(
-///     Some(&[1, 2, 3, 4, 5, 6]),
-///     &[2, 3],
-/// ).unwrap();
-///
-/// // Test case 1: Negative stride along dimension 0
-/// // This should flip the order along dimension 0
-/// let result = downsample(&x, 0, -1, 0).unwrap();
-/// let expected = Tensor::<IntegerRep>::new(
-///     Some(&[4, 5, 6, 1, 2, 3]), // Flipped order of rows
-///     &[2, 3]
-/// ).unwrap();
-/// assert_eq!(result, expected);
-///
-/// // Test case 2: Negative stride along dimension 1
-/// // This should flip the order along dimension 1
-/// let result = downsample(&x, 1, -1, 0).unwrap();
-/// let expected = Tensor::<IntegerRep>::new(
-///     Some(&[3, 2, 1, 6, 5, 4]), // Flipped order of columns
-///     &[2, 3]
-/// ).unwrap();
-/// assert_eq!(result, expected);
-///
-/// // Test case 3: Negative stride with stride magnitude > 1
-/// // This should both skip and flip
-/// let result = downsample(&x, 1, -2, 0).unwrap();
-/// let expected = Tensor::<IntegerRep>::new(
-///     Some(&[3, 1, 6, 4]), // Take every 2nd element in reverse
-///     &[2, 2]
-/// ).unwrap();
-/// assert_eq!(result, expected);
-///
-/// // Test case 4: Negative stride with non-zero modulo
-/// // This should start at (size - 1 - modulo) and reverse
-/// let result = downsample(&x, 1, -2, 1).unwrap();
-/// let expected = Tensor::<IntegerRep>::new(
-///     Some(&[2, 5]), // Start at second element from end, take every 2nd in reverse
-///     &[2, 1]
-/// ).unwrap();
-/// assert_eq!(result, expected);
-///
-/// // Create a larger test case for more complex downsampling
-/// let y = Tensor::<IntegerRep>::new(
-///     Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
-///     &[3, 4],
-/// ).unwrap();
-///
-/// // Test case 5: Negative stride with modulo on larger tensor
-/// let result = downsample(&y, 1, -2, 1).unwrap();
-/// let expected = Tensor::<IntegerRep>::new(
-///     Some(&[3, 1, 7, 5, 11, 9]), // Start at one after reverse, take every 2nd
-///     &[3, 2]
-/// ).unwrap();
-/// assert_eq!(result, expected);
 /// ```
 pub fn downsample<T: TensorType + Send + Sync>(
     input: &Tensor<T>,
     dim: usize,
-    stride: isize, // Changed from usize to isize to support negative strides
+    stride: usize,
     modulo: usize,
 ) -> Result<Tensor<T>, TensorError> {
-    // Handle negative stride case
-    if stride == 0 {
-        return Err(TensorError::DimMismatch(
-            "downsample stride cannot be zero".to_string(),
-        ));
-    }
-
-    let stride_abs = stride.unsigned_abs();
     let mut output_shape = input.dims().to_vec();
+    // now downsample along axis dim offset by modulo, rounding up (+1 if remaidner is non-zero)
+    let remainder = (input.dims()[dim] - modulo) % stride;
+    let div = (input.dims()[dim] - modulo) / stride;
+    output_shape[dim] = div + (remainder > 0) as usize;
+    let mut output = Tensor::<T>::new(None, &output_shape)?;

-    if modulo >= input.dims()[dim] {
+    if modulo > input.dims()[dim] {
         return Err(TensorError::DimMismatch("downsample".to_string()));
     }

-    // Calculate output shape based on the absolute value of stride
-    let remainder = (input.dims()[dim] - modulo) % stride_abs;
-    let div = (input.dims()[dim] - modulo) / stride_abs;
-    output_shape[dim] = div + (remainder > 0) as usize;
-
-    let mut output = Tensor::<T>::new(None, &output_shape)?;
-
-    // Calculate indices based on stride direction
+    // now downsample along axis dim offset by modulo
     let indices = (0..output_shape.len())
         .map(|i| {
             if i == dim {
                 let mut index = vec![0; output_shape[i]];
-                for (j, idx) in index.iter_mut().enumerate() {
-                    if stride > 0 {
-                        // Positive stride: move forward from modulo
-                        *idx = j * stride_abs + modulo;
-                    } else {
-                        // Negative stride: move backward from (size - 1 - modulo)
-                        *idx = (input.dims()[dim] - 1 - modulo) - j * stride_abs;
-                    }
+                for (i, idx) in index.iter_mut().enumerate() {
+                    *idx = i * stride + modulo;
                 }
                 index
             } else {
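The index arithmetic for the `usize`-stride version on a single axis, as a runnable sketch:

```rust
// Indices selected along the downsampled axis: modulo, modulo + stride, ...
// Output length rounds up when the remaining span doesn't divide evenly.
fn downsample_indices(dim_size: usize, stride: usize, modulo: usize) -> Vec<usize> {
    let remainder = (dim_size - modulo) % stride;
    let div = (dim_size - modulo) / stride;
    let out_len = div + (remainder > 0) as usize;
    (0..out_len).map(|i| i * stride + modulo).collect()
}

fn main() {
    // A row of length 6, stride 2, offset 1 -> elements 1, 3, 5.
    assert_eq!(downsample_indices(6, 2, 1), vec![1, 3, 5]);
    // The doc-test above: axis of size 3, stride 2, modulo 2 -> index 2 only.
    assert_eq!(downsample_indices(3, 2, 2), vec![2]);
}
```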
@@ -1397,6 +1310,7 @@ pub fn pad<T: TensorType>(
 ///
 /// # Errors
 /// Returns a TensorError if the tensors in `inputs` have incompatible dimensions for concatenation along the specified `axis`.
+
 pub fn concat<T: TensorType + Send + Sync>(
     inputs: &[&Tensor<T>],
     axis: usize,
@@ -2172,6 +2086,7 @@ pub mod nonlinearities {
 /// let expected = Tensor::<IntegerRep>::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap();
 /// assert_eq!(result, expected);
 /// ```
+
 pub fn tanh(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
     a.par_enum_map(|_, a_i| {
         let kix = (a_i as f64) / scale_input;
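These lookup nonlinearities all follow one pattern: dequantize to `f64`, apply the function, requantize and round. A hedged sketch for `tanh` — multiplying back by `scale_input` is an assumption based on the single-scale signature, since the rest of the body falls outside this hunk:

```rust
// Quantized tanh: integer rep -> float -> tanh -> float -> rounded integer rep.
fn tanh_quantized(a: &[i128], scale_input: f64) -> Vec<i128> {
    a.iter()
        .map(|&a_i| {
            let kix = (a_i as f64) / scale_input; // dequantize
            let fout = scale_input * kix.tanh(); // apply and requantize (assumed)
            fout.round() as i128
        })
        .collect()
}

fn main() {
    // With scale 4: tanh(8 / 4) * 4 = tanh(2) * 4 ~= 3.856 -> rounds to 4.
    assert_eq!(tanh_quantized(&[8], 4.0), vec![4]);
}
```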
@@ -2346,11 +2261,7 @@ pub mod nonlinearities {
 pub fn recip(a: &Tensor<IntegerRep>, input_scale: f64, out_scale: f64) -> Tensor<IntegerRep> {
     a.par_enum_map(|_, a_i| {
         let rescaled = (a_i as f64) / input_scale;
-        let denom = if rescaled == 0_f64 {
-            (1_f64) / (rescaled + f64::EPSILON)
-        } else {
-            (1_f64) / (rescaled)
-        };
+        let denom = (1_f64) / (rescaled + f64::EPSILON);
         let d_inv_x = out_scale * denom;
         Ok::<_, TensorError>(d_inv_x.round() as IntegerRep)
     })
File diff suppressed because it is too large
@@ -2,38 +2,36 @@ use std::collections::HashSet;

 use log::{debug, error, warn};

-use crate::circuit::{CheckMode, region::ConstantsMap};
+use crate::circuit::{region::ConstantsMap, CheckMode};

 use super::*;
-/// A wrapper around Halo2's Column types that represents a tensor of variables in the circuit.
-/// VarTensors are used to store and manage circuit columns, typically for assigning ValTensor
-/// values during circuit layout. The tensor organizes storage into blocks of columns, where each
-/// block contains multiple columns and each column contains multiple rows.
+/// A wrapper around Halo2's `Column<Fixed>` or `Column<Advice>`.
+/// Typically assign [ValTensor]s to [VarTensor]s when laying out a circuit.
 #[derive(Clone, Default, Debug, PartialEq, Eq)]
 pub enum VarTensor {
     /// A VarTensor for holding Advice values, which are assigned at proving time.
     Advice {
         /// Vec of Advice columns, we have [[xx][xx][xx]...] where each inner vec is xx columns
         inner: Vec<Vec<Column<Advice>>>,
-        /// The number of columns in each inner block
+        ///
         num_inner_cols: usize,
         /// Number of rows available to be used in each column of the storage
         col_size: usize,
     },
-    /// A placeholder tensor used for testing or temporary storage
+    /// Dummy var
     Dummy {
-        /// The number of columns in each inner block
+        ///
         num_inner_cols: usize,
         /// Number of rows available to be used in each column of the storage
         col_size: usize,
     },
-    /// An empty tensor with no storage
+    /// Empty var
     #[default]
     Empty,
 }
 impl VarTensor {
-    /// Returns the name of the tensor variant as a static string
+    /// name of the tensor
     pub fn name(&self) -> &'static str {
         match self {
             VarTensor::Advice { .. } => "Advice",
@@ -42,35 +40,22 @@ impl VarTensor {
         }
     }

-    /// Returns true if the tensor is an Advice variant
+    ///
     pub fn is_advice(&self) -> bool {
         matches!(self, VarTensor::Advice { .. })
     }

-    /// Calculates the maximum number of usable rows in the constraint system
-    ///
-    /// # Arguments
-    /// * `cs` - The constraint system
-    /// * `logrows` - Log base 2 of the total number of rows (including system and blinding rows)
-    ///
-    /// # Returns
-    /// The maximum number of usable rows after accounting for blinding factors
     pub fn max_rows<F: PrimeField>(cs: &ConstraintSystem<F>, logrows: usize) -> usize {
         let base = 2u32;
         base.pow(logrows as u32) as usize - cs.blinding_factors() - 1
     }
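Concretely, the usable-row budget is `2^logrows` minus the blinding rows minus one. A one-liner sketch with an assumed blinding count (the real value comes from `cs.blinding_factors()`):

```rust
// Usable rows = total rows minus blinding factors minus one reserved row.
fn usable_rows(logrows: u32, blinding_factors: usize) -> usize {
    2usize.pow(logrows) - blinding_factors - 1
}

fn main() {
    // 2^17 = 131_072 total rows; with 6 blinding factors, 131_065 remain usable.
    assert_eq!(usable_rows(17, 6), 131_065);
}
```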
-    /// Creates a new VarTensor::Advice with unblinded columns. Unblinded columns are used when
-    /// the values do not need to be hidden in the proof.
-    ///
-    /// # Arguments
-    /// * `cs` - The constraint system to create columns in
-    /// * `logrows` - Log base 2 of the total number of rows
-    /// * `num_inner_cols` - Number of columns in each inner block
-    /// * `capacity` - Total number of advice cells to allocate
-    ///
-    /// # Returns
-    /// A new VarTensor::Advice with unblinded columns enabled for equality constraints
+    /// Create a new VarTensor::Advice that is unblinded
+    /// Arguments
+    /// * `cs` - The constraint system
+    /// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
+    /// * `capacity` - The number of advice cells to allocate
     pub fn new_unblinded_advice<F: PrimeField>(
         cs: &mut ConstraintSystem<F>,
         logrows: usize,
@@ -108,17 +93,11 @@ impl VarTensor {
         }
     }

-    /// Creates a new VarTensor::Advice with standard (blinded) columns, used when
-    /// the values need to be hidden in the proof.
-    ///
-    /// # Arguments
-    /// * `cs` - The constraint system to create columns in
-    /// * `logrows` - Log base 2 of the total number of rows
-    /// * `num_inner_cols` - Number of columns in each inner block
-    /// * `capacity` - Total number of advice cells to allocate
-    ///
-    /// # Returns
-    /// A new VarTensor::Advice with blinded columns enabled for equality constraints
+    /// Create a new VarTensor::Advice
+    /// Arguments
+    /// * `cs` - The constraint system
+    /// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
+    /// * `capacity` - The number of advice cells to allocate
     pub fn new_advice<F: PrimeField>(
         cs: &mut ConstraintSystem<F>,
         logrows: usize,
@@ -154,17 +133,11 @@ impl VarTensor {
         }
     }

-    /// Initializes fixed columns in the constraint system to support the VarTensor::Advice
-    /// Fixed columns are used for constant values that are known at circuit creation time.
-    ///
-    /// # Arguments
-    /// * `cs` - The constraint system to create columns in
-    /// * `logrows` - Log base 2 of the total number of rows
-    /// * `num_constants` - Number of constant values needed
-    /// * `module_requires_fixed` - Whether the module requires at least one fixed column
-    ///
-    /// # Returns
-    /// The number of fixed columns created
+    /// Initializes fixed columns to support the VarTensor::Advice
+    /// Arguments
+    /// * `cs` - The constraint system
+    /// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
+    /// * `capacity` - The number of advice cells to allocate
     pub fn constant_cols<F: PrimeField>(
         cs: &mut ConstraintSystem<F>,
         logrows: usize,
@@ -196,14 +169,7 @@ impl VarTensor {
         modulo
     }

-    /// Creates a new dummy VarTensor for testing or temporary storage
-    ///
-    /// # Arguments
-    /// * `logrows` - Log base 2 of the total number of rows
-    /// * `num_inner_cols` - Number of columns in each inner block
-    ///
-    /// # Returns
-    /// A new VarTensor::Dummy with the specified dimensions
+    /// Create a new VarTensor::Dummy
     pub fn dummy(logrows: usize, num_inner_cols: usize) -> Self {
         let base = 2u32;
         let max_rows = base.pow(logrows as u32) as usize - 6;
@@ -213,7 +179,7 @@ impl VarTensor {
         }
     }

-    /// Returns the number of blocks in the tensor
+    /// Gets the dims of the object the VarTensor represents
     pub fn num_blocks(&self) -> usize {
         match self {
             VarTensor::Advice { inner, .. } => inner.len(),
@@ -221,7 +187,7 @@ impl VarTensor {
         }
     }

-    /// Returns the number of columns in each inner block
+    /// Num inner cols
     pub fn num_inner_cols(&self) -> usize {
         match self {
             VarTensor::Advice { num_inner_cols, .. } | VarTensor::Dummy { num_inner_cols, .. } => {
@@ -231,7 +197,7 @@ impl VarTensor {
         }
     }

-    /// Returns the total number of columns across all blocks
+    /// Total number of columns
     pub fn num_cols(&self) -> usize {
         match self {
             VarTensor::Advice { inner, .. } => inner[0].len() * inner.len(),
@@ -239,7 +205,7 @@ impl VarTensor {
         }
     }

-    /// Returns the maximum number of rows in each column
+    /// Gets the size of each column
     pub fn col_size(&self) -> usize {
         match self {
             VarTensor::Advice { col_size, .. } | VarTensor::Dummy { col_size, .. } => *col_size,
@@ -247,7 +213,7 @@ impl VarTensor {
         }
     }

-    /// Returns the total size of each block (num_inner_cols * col_size)
+    /// Gets the size of each column
     pub fn block_size(&self) -> usize {
         match self {
             VarTensor::Advice {
@@ -264,13 +230,7 @@ impl VarTensor {
         }
     }

-    /// Converts a linear coordinate to (block, column, row) coordinates in the storage
-    ///
-    /// # Arguments
-    /// * `linear_coord` - The linear index to convert
-    ///
-    /// # Returns
-    /// A tuple of (block_index, column_index, row_index)
+    /// Take a linear coordinate and output the (column, row) position in the storage block.
     pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize, usize) {
         // x indexes over blocks of size num_inner_cols
         let x = linear_coord / self.block_size();
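A sketch of the full coordinate split under the storage layout described above — the within-block order (consecutive values walk across the inner columns first, then down rows) is an assumption, since the rest of the body falls outside this hunk:

```rust
// linear -> (block x, inner column y, row z) with block_size = num_inner_cols * col_size.
fn cartesian_coord(
    linear: usize,
    num_inner_cols: usize,
    col_size: usize,
) -> (usize, usize, usize) {
    let block_size = num_inner_cols * col_size;
    let x = linear / block_size; // which block
    let y = (linear % block_size) % num_inner_cols; // which inner column (assumed order)
    let z = (linear % block_size) / num_inner_cols; // which row (assumed order)
    (x, y, z)
}

fn main() {
    // 2 inner columns of 4 rows each: linear index 5 lands in block 0, column 1, row 2.
    assert_eq!(cartesian_coord(5, 2, 4), (0, 1, 2));
}
```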
@@ -283,17 +243,7 @@ impl VarTensor {
 }

 impl VarTensor {
-    /// Queries a range of cells in the tensor during circuit synthesis
-    ///
-    /// # Arguments
-    /// * `meta` - Virtual cells accessor
-    /// * `x` - Block index
-    /// * `y` - Column index within block
-    /// * `z` - Starting row offset
-    /// * `rng` - Number of consecutive rows to query
-    ///
-    /// # Returns
-    /// A tensor of expressions representing the queried cells
+    /// Retrieve the value of a specific cell in the tensor.
     pub fn query_rng<F: PrimeField>(
         &self,
         meta: &mut VirtualCells<'_, F>,
@@ -318,16 +268,7 @@ impl VarTensor {
         }
     }

-    /// Queries an entire block of cells at a given offset
-    ///
-    /// # Arguments
-    /// * `meta` - Virtual cells accessor
-    /// * `x` - Block index
-    /// * `z` - Row offset
-    /// * `rng` - Number of consecutive rows to query
-    ///
-    /// # Returns
-    /// A tensor of expressions representing the queried block
+    /// Retrieve the value of a specific block at an offset in the tensor.
     pub fn query_whole_block<F: PrimeField>(
         &self,
         meta: &mut VirtualCells<'_, F>,
@@ -352,16 +293,7 @@ impl VarTensor {
         }
     }

-    /// Assigns a constant value to a specific cell in the tensor
-    ///
-    /// # Arguments
-    /// * `region` - The region to assign values in
-    /// * `offset` - Base offset for the assignment
-    /// * `coord` - Coordinate within the tensor
-    /// * `constant` - The constant value to assign
-    ///
-    /// # Returns
-    /// The assigned cell or an error if assignment fails
+    /// Assigns a constant value to a specific cell in the tensor.
     pub fn assign_constant<F: PrimeField + TensorType + PartialOrd>(
         &self,
         region: &mut Region<F>,
@@ -381,17 +313,7 @@ impl VarTensor {
         }
     }

-    /// Assigns values from a ValTensor to this tensor, excluding specified positions
-    ///
-    /// # Arguments
-    /// * `region` - The region to assign values in
-    /// * `offset` - Base offset for assignments
-    /// * `values` - The ValTensor containing values to assign
-    /// * `omissions` - Set of positions to skip during assignment
-    /// * `constants` - Map for tracking constant assignments
-    ///
-    /// # Returns
-    /// The assigned ValTensor or an error if assignment fails
+    /// Assigns [ValTensor] to the columns of the inner tensor.
     pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
         &self,
         region: &mut Region<F>,
@@ -403,10 +325,7 @@ impl VarTensor {
         let mut assigned_coord = 0;
         let mut res: ValTensor<F> = match values {
             ValTensor::Instance { .. } => {
-                error!(
-                    "assignment with omissions is not supported on instance columns. increase K if you require more rows."
-                );
-                Err(halo2_proofs::plonk::Error::Synthesis)
+                unimplemented!("cannot assign instance to advice columns with omissions")
             }
             ValTensor::Value { inner: v, .. } => Ok::<ValTensor<F>, halo2_proofs::plonk::Error>(
                 v.enum_map(|coord, k| {
@@ -425,16 +344,7 @@ impl VarTensor {
         Ok(res)
     }

-    /// Assigns values from a ValTensor to this tensor
-    ///
-    /// # Arguments
-    /// * `region` - The region to assign values in
-    /// * `offset` - Base offset for assignments
-    /// * `values` - The ValTensor containing values to assign
-    /// * `constants` - Map for tracking constant assignments
-    ///
-    /// # Returns
-    /// The assigned ValTensor or an error if assignment fails
+    /// Assigns [ValTensor] to the columns of the inner tensor.
     pub fn assign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
         &self,
         region: &mut Region<F>,
@@ -486,23 +396,14 @@ impl VarTensor {
         Ok(res)
     }

-    /// Returns the remaining available space in a column for assignments
-    ///
-    /// # Arguments
-    /// * `offset` - Current offset in the column
-    /// * `values` - The ValTensor to check space for
-    ///
-    /// # Returns
-    /// The number of rows that need to be flushed or an error if space is insufficient
+    /// Helper function to get the remaining size of the column
     pub fn get_column_flush<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
         &self,
         offset: usize,
         values: &ValTensor<F>,
     ) -> Result<usize, halo2_proofs::plonk::Error> {
         if values.len() > self.col_size() {
-            error!(
-                "There are too many values to flush for this column size, try setting the logrows to a higher value (eg. --logrows 22 on the cli)"
-            );
+            error!("Values are too large for the column");
             return Err(halo2_proofs::plonk::Error::Synthesis);
         }

@@ -526,16 +427,8 @@ impl VarTensor {
         Ok(flush_len)
     }

-    /// Assigns values to a single column, avoiding column overflow by flushing to the next column if needed
-    ///
-    /// # Arguments
-    /// * `region` - The region to assign values in
-    /// * `offset` - Base offset for assignments
-    /// * `values` - The ValTensor containing values to assign
-    /// * `constants` - Map for tracking constant assignments
-    ///
-    /// # Returns
-    /// A tuple of (assigned ValTensor, number of rows flushed) or an error if assignment fails
+    /// Assigns [ValTensor] to the columns of the inner tensor. Whereby the values are assigned to a single column, without overflowing.
+    /// So for instance if we are assigning 10 values and we are at index 18 of the column, and the columns are of length 20, we skip the last 2 values of current column and start from the beginning of the next column.
     pub fn assign_exact_column<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
         &self,
         region: &mut Region<F>,
@@ -550,17 +443,8 @@ impl VarTensor {
         Ok((assigned_vals, flush_len))
     }
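The doc comment's example (10 values at row 18 of a 20-row column skip the last 2 rows) pins down the flush rule. A sketch of that computation — the real `get_column_flush` operates on a `ValTensor` and Halo2 error types:

```rust
// Rows to skip so the next assignment fits entirely within one column.
fn column_flush(offset: usize, values_len: usize, col_size: usize) -> Result<usize, String> {
    if values_len > col_size {
        return Err("values are too large for the column".to_string());
    }
    let pos_in_col = offset % col_size;
    // If the values would spill past the column end, flush to the next column.
    if pos_in_col + values_len > col_size {
        Ok(col_size - pos_in_col)
    } else {
        Ok(0)
    }
}

fn main() {
    // 10 values starting at row 18 of a 20-row column: skip the last 2 rows.
    assert_eq!(column_flush(18, 10, 20), Ok(2));
    assert_eq!(column_flush(5, 10, 20), Ok(0));
}
```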
-    /// Assigns values with duplication in dummy mode, used for testing and simulation
-    ///
-    /// # Arguments
-    /// * `row` - Starting row for assignment
-    /// * `offset` - Base offset for assignments
-    /// * `values` - The ValTensor containing values to assign
-    /// * `single_inner_col` - Whether to treat as a single column
-    /// * `constants` - Map for tracking constant assignments
-    ///
-    /// # Returns
-    /// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
+    /// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
+    /// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
     pub fn dummy_assign_with_duplication<
         F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
     >(
@@ -572,13 +456,8 @@ impl VarTensor {
         constants: &mut ConstantsMap<F>,
     ) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
         match values {
-            ValTensor::Instance { .. } => {
-                error!(
-                    "duplication is not supported on instance columns. increase K if you require more rows."
-                );
-                Err(halo2_proofs::plonk::Error::Synthesis)
-            }
-            ValTensor::Value { inner: v, dims, .. } => {
+            ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
+            ValTensor::Value { inner: v, dims , ..} => {
                 let duplication_freq = if single_inner_col {
                     self.col_size()
                 } else {
@@ -591,20 +470,21 @@ impl VarTensor {
                     self.num_inner_cols()
                 };

-                let duplication_offset = if single_inner_col { row } else { offset };
+                let duplication_offset = if single_inner_col {
+                    row
+                } else {
+                    offset
+                };
+
                 // duplicates every nth element to adjust for column overflow
-                let mut res: ValTensor<F> = v
-                    .duplicate_every_n(duplication_freq, num_repeats, duplication_offset)
-                    .unwrap()
-                    .into();
+                let mut res: ValTensor<F> = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap().into();

                 let constants_map = res.create_constants_map();
                 constants.extend(constants_map);

                 let total_used_len = res.len();
-                res.remove_every_n(duplication_freq, num_repeats, duplication_offset)
-                    .unwrap();
+                res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();

                 res.reshape(dims).unwrap();
                 res.set_scale(values.scale());
@@ -614,16 +494,7 @@ impl VarTensor {
         }
     }

-    /// Assigns values with duplication but without enforcing constraints between duplicated values
-    ///
-    /// # Arguments
-    /// * `region` - The region to assign values in
-    /// * `offset` - Base offset for assignments
-    /// * `values` - The ValTensor containing values to assign
-    /// * `constants` - Map for tracking constant assignments
-    ///
-    /// # Returns
-    /// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
+    /// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
     pub fn assign_with_duplication_unconstrained<
         F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
     >(
@@ -634,13 +505,9 @@ impl VarTensor {
         constants: &mut ConstantsMap<F>,
     ) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
         match values {
-            ValTensor::Instance { .. } => {
-                error!(
-                    "duplication is not supported on instance columns. increase K if you require more rows."
-                );
-                Err(halo2_proofs::plonk::Error::Synthesis)
-            }
-            ValTensor::Value { inner: v, dims, .. } => {
+            ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
+            ValTensor::Value { inner: v, dims , ..} => {

                 let duplication_freq = self.block_size();

                 let num_repeats = self.num_inner_cols();
@@ -648,31 +515,17 @@ impl VarTensor {
                 let duplication_offset = offset;

                 // duplicates every nth element to adjust for column overflow
-                let v = v
-                    .duplicate_every_n(duplication_freq, num_repeats, duplication_offset)
-                    .map_err(|e| {
-                        error!("Error duplicating values: {:?}", e);
-                        halo2_proofs::plonk::Error::Synthesis
-                    })?;
+                let v = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
                 let mut res: ValTensor<F> = {
                     v.enum_map(|coord, k| {
-                        let cell =
-                            self.assign_value(region, offset, k.clone(), coord, constants)?;
-                        Ok::<_, halo2_proofs::plonk::Error>(cell)
-                    })?
-                    .into()
-                };
-                let total_used_len = res.len();
-                res.remove_every_n(duplication_freq, num_repeats, duplication_offset)
-                    .map_err(|e| {
-                        error!("Error duplicating values: {:?}", e);
-                        halo2_proofs::plonk::Error::Synthesis
-                    })?;
+                        let cell = self.assign_value(region, offset, k.clone(), coord, constants)?;
+                        Ok::<_, halo2_proofs::plonk::Error>(cell)

-                res.reshape(dims).map_err(|e| {
-                    error!("Error duplicating values: {:?}", e);
-                    halo2_proofs::plonk::Error::Synthesis
-                })?;
+                })?.into()};
+                let total_used_len = res.len();
+                res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();

+                res.reshape(dims).unwrap();
                 res.set_scale(values.scale());

                 Ok((res, total_used_len))
@@ -680,18 +533,8 @@ impl VarTensor {
         }
     }

-    /// Assigns values with duplication and enforces equality constraints between duplicated values
-    ///
-    /// # Arguments
-    /// * `region` - The region to assign values in
-    /// * `row` - Starting row for assignment
-    /// * `offset` - Base offset for assignments
-    /// * `values` - The ValTensor containing values to assign
-    /// * `check_mode` - Mode for checking equality constraints
-    /// * `constants` - Map for tracking constant assignments
-    ///
-    /// # Returns
-    /// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
+    /// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
+    /// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
     pub fn assign_with_duplication_constrained<
         F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
     >(
@@ -706,71 +549,61 @@ impl VarTensor {
         let mut prev_cell = None;

         match values {
-            ValTensor::Instance { .. } => {
-                error!(
-                    "duplication is not supported on instance columns. increase K if you require more rows."
-                );
-                Err(halo2_proofs::plonk::Error::Synthesis)
-            }
-            ValTensor::Value { inner: v, dims, .. } => {
+            ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
+            ValTensor::Value { inner: v, dims , ..} => {

                 let duplication_freq = self.col_size();
                 let num_repeats = 1;
                 let duplication_offset = row;

                 // duplicates every nth element to adjust for column overflow
-                let v = v
-                    .duplicate_every_n(duplication_freq, num_repeats, duplication_offset)
-                    .unwrap();
-                let mut res: ValTensor<F> = v
-                    .enum_map(|coord, k| {
-                        let step = self.num_inner_cols();
+                let v = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
+                let mut res: ValTensor<F> =
+                    v.enum_map(|coord, k| {

-                        let (x, y, z) = self.cartesian_coord(offset + coord * step);
-                        if matches!(check_mode, CheckMode::SAFE) && coord > 0 && z == 0 && y == 0 {
-                            // assert that duplication occurred correctly
-                            assert_eq!(
-                                Into::<IntegerRep>::into(k.clone()),
-                                Into::<IntegerRep>::into(v[coord - 1].clone())
-                            );
-                        };
+                    let step = self.num_inner_cols();

-                        let cell =
-                            self.assign_value(region, offset, k.clone(), coord * step, constants)?;
+                    let (x, y, z) = self.cartesian_coord(offset + coord * step);
+                    if matches!(check_mode, CheckMode::SAFE) && coord > 0 && z == 0 && y == 0 {
+                        // assert that duplication occurred correctly
+                        assert_eq!(Into::<IntegerRep>::into(k.clone()), Into::<IntegerRep>::into(v[coord - 1].clone()));
+                    };

-                        let at_end_of_column = z == duplication_freq - 1;
-                        let at_beginning_of_column = z == 0;
+                    let cell = self.assign_value(region, offset, k.clone(), coord * step, constants)?;

-                        if at_end_of_column {
-                            // if we are at the end of the column, we need to copy the cell to the next column
-                            prev_cell = Some(cell.clone());
-                        } else if coord > 0 && at_beginning_of_column {
-                            if let Some(prev_cell) = prev_cell.as_ref() {
-                                let cell = if let Some(cell) = cell.cell() {
-                                    cell
-                                } else {
-                                    error!("Error getting cell: {:?}", (x, y));
-                                    return Err(halo2_proofs::plonk::Error::Synthesis);
-                                };
-                                let prev_cell = if let Some(prev_cell) = prev_cell.cell() {
-                                    prev_cell
-                                } else {
-                                    error!("Error getting prev cell: {:?}", (x, y));
-                                    return Err(halo2_proofs::plonk::Error::Synthesis);
-                                };
-                                region.constrain_equal(prev_cell, cell)?;
+                    let at_end_of_column = z == duplication_freq - 1;
+                    let at_beginning_of_column = z == 0;
+
+                    if at_end_of_column {
+                        // if we are at the end of the column, we need to copy the cell to the next column
+                        prev_cell = Some(cell.clone());
+                    } else if coord > 0 && at_beginning_of_column {
+                        if let Some(prev_cell) = prev_cell.as_ref() {
+                            let cell = if let Some(cell) = cell.cell() {
+                                cell
                             } else {
-                                error!("Previous cell was not set");
+                                error!("Error getting cell: {:?}", (x,y));
                                 return Err(halo2_proofs::plonk::Error::Synthesis);
-                            }
+                            };
+                            let prev_cell = if let Some(prev_cell) = prev_cell.cell() {
+                                prev_cell
+                            } else {
+                                error!("Error getting prev cell: {:?}", (x,y));
+                                return Err(halo2_proofs::plonk::Error::Synthesis);
+                            };
+                            region.constrain_equal(prev_cell,cell)?;
+                        } else {
+                            error!("Previous cell was not set");
+                            return Err(halo2_proofs::plonk::Error::Synthesis);
                         }
+                    }

-                        Ok(cell)
-                    })?
-                    .into();
+                    Ok(cell)
+
+                })?.into();

                 let total_used_len = res.len();
-                res.remove_every_n(duplication_freq, num_repeats, duplication_offset)
-                    .unwrap();
+                res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();

                 res.reshape(dims).unwrap();
                 res.set_scale(values.scale());
@@ -780,17 +613,6 @@ impl VarTensor {
         }
     }

-    /// Assigns a single value to the tensor. This is a helper function used by other assignment methods.
-    ///
-    /// # Arguments
-    /// * `region` - The region to assign values in
-    /// * `offset` - Base offset for the assignment
-    /// * `k` - The value to assign
-    /// * `coord` - The coordinate where to assign the value
-    /// * `constants` - Map for tracking constant assignments
-    ///
-    /// # Returns
-    /// The assigned value or an error if assignment fails
     fn assign_value<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
         &self,
         region: &mut Region<F>,
@@ -801,49 +623,32 @@ impl VarTensor {
     ) -> Result<ValType<F>, halo2_proofs::plonk::Error> {
         let (x, y, z) = self.cartesian_coord(offset + coord);
         let res = match k {
-            // Handle direct value assignment
             ValType::Value(v) => match &self {
                 VarTensor::Advice { inner: advices, .. } => {
                     ValType::PrevAssigned(region.assign_advice(|| "k", advices[x][y], z, || v)?)
                 }
-                _ => {
-                    error!("VarTensor was not initialized");
-                    return Err(halo2_proofs::plonk::Error::Synthesis);
-                }
+                _ => unimplemented!(),
             },
-            // Handle copying previously assigned value
             ValType::PrevAssigned(v) => match &self {
                 VarTensor::Advice { inner: advices, .. } => {
                     ValType::PrevAssigned(v.copy_advice(|| "k", region, advices[x][y], z)?)
                 }
-                _ => {
-                    error!("VarTensor was not initialized");
-                    return Err(halo2_proofs::plonk::Error::Synthesis);
-                }
+                _ => unimplemented!(),
             },
-            // Handle copying previously assigned constant
             ValType::AssignedConstant(v, val) => match &self {
                 VarTensor::Advice { inner: advices, .. } => {
                     ValType::AssignedConstant(v.copy_advice(|| "k", region, advices[x][y], z)?, val)
                 }
-                _ => {
-                    error!("VarTensor was not initialized");
-                    return Err(halo2_proofs::plonk::Error::Synthesis);
-                }
+                _ => unimplemented!(),
             },
-            // Handle assigning evaluated value
             ValType::AssignedValue(v) => match &self {
                 VarTensor::Advice { inner: advices, .. } => ValType::PrevAssigned(
                     region
                         .assign_advice(|| "k", advices[x][y], z, || v)?
                         .evaluate(),
                 ),
-                _ => {
-                    error!("VarTensor was not initialized");
-                    return Err(halo2_proofs::plonk::Error::Synthesis);
-                }
+                _ => unimplemented!(),
             },
-            // Handle constant value assignment with caching
             ValType::Constant(v) => {
                 if let std::collections::hash_map::Entry::Vacant(e) = constants.entry(v) {
                     let value = ValType::AssignedConstant(
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -28,12 +28,11 @@
     "commitment": "KZG",
     "decomp_base": 128,
     "decomp_legs": 2,
-    "bounded_log_lookup": false,
-    "ignore_range_check_inputs_outputs": false
+    "bounded_log_lookup": false
   },
-  "num_rows": 236,
-  "total_assignments": 472,
-  "total_const_size": 4,
+  "num_rows": 46,
+  "total_assignments": 92,
+  "total_const_size": 3,
   "total_dynamic_col_size": 0,
   "max_dynamic_input_len": 0,
   "num_dynamic_lookups": 0,
Binary file not shown.
14
tests/foundry/.gitignore
vendored
@@ -1,14 +0,0 @@
-# Compiler files
-cache/
-out/
-
-# Ignores development broadcast logs
-!/broadcast
-/broadcast/*/31337/
-/broadcast/**/dry-run/
-
-# Docs
-docs/
-
-# Dotenv file
-.env
@@ -1,66 +0,0 @@
-## Foundry
-
-**Foundry is a blazing fast, portable and modular toolkit for Ethereum application development written in Rust.**
-
-Foundry consists of:
-
-- **Forge**: Ethereum testing framework (like Truffle, Hardhat and DappTools).
-- **Cast**: Swiss army knife for interacting with EVM smart contracts, sending transactions and getting chain data.
-- **Anvil**: Local Ethereum node, akin to Ganache, Hardhat Network.
-- **Chisel**: Fast, utilitarian, and verbose solidity REPL.
-
-## Documentation
-
-https://book.getfoundry.sh/
-
-## Usage
-
-### Build
-
-```shell
-$ forge build
-```
-
-### Test
-
-```shell
-$ forge test
-```
-
-### Format
-
-```shell
-$ forge fmt
-```
-
-### Gas Snapshots
-
-```shell
-$ forge snapshot
-```
-
-### Anvil
-
-```shell
-$ anvil
-```
-
-### Deploy
-
-```shell
-$ forge script script/Counter.s.sol:CounterScript --rpc-url <your_rpc_url> --private-key <your_private_key>
-```
-
-### Cast
-
-```shell
-$ cast <subcommand>
-```
-
-### Help
-
-```shell
-$ forge --help
-$ anvil --help
-$ cast --help
-```
@@ -1,6 +0,0 @@
-[profile.default]
-src = "../../contracts"
-out = "out"
-libs = ["lib"]
-
-# See more config options https://github.com/foundry-rs/foundry/blob/master/crates/config/README.md#all-options
@@ -1 +0,0 @@
-contracts/=../../contracts/
@@ -1,429 +0,0 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.20;

import "forge-std/Test.sol";
import {console} from "forge-std/console.sol";
import "contracts/AttestData.sol" as AttestData;

contract MockVKA {
    constructor() {}
}

contract MockVerifier {
    bool public shouldVerify;

    constructor(bool _shouldVerify) {
        shouldVerify = _shouldVerify;
    }

    function verifyProof(
        bytes calldata,
        uint256[] calldata
    ) external view returns (bool) {
        require(shouldVerify, "Verification failed");
        return shouldVerify;
    }
}

contract MockVerifierSeperate {
    bool public shouldVerify;

    constructor(bool _shouldVerify) {
        shouldVerify = _shouldVerify;
    }

    function verifyProof(
        address,
        bytes calldata,
        uint256[] calldata
    ) external view returns (bool) {
        require(shouldVerify, "Verification failed");
        return shouldVerify;
    }
}

contract MockTargetContract {
    int256[] public data;

    constructor(int256[] memory _data) {
        data = _data;
    }

    function setData(int256[] memory _data) external {
        data = _data;
    }

    function getData() external view returns (int256[] memory) {
        return data;
    }
}

contract DataAttestationTest is Test {
    AttestData.DataAttestation das;
    MockVerifier verifier;
    MockVerifierSeperate verifierSeperate;
    MockVKA vka;
    MockTargetContract target;
    int256[] mockData = [int256(1e18), -int256(5e17)];
    uint256[] decimals = [18, 18];
    uint256[] bits = [13, 13];
    uint8 instanceOffset = 0;
    bytes callData;

    function setUp() public {
        target = new MockTargetContract(mockData);
        verifier = new MockVerifier(true);
        verifierSeperate = new MockVerifierSeperate(true);
        vka = new MockVKA();

        callData = abi.encodeWithSignature("getData()");

        das = new AttestData.DataAttestation(
            address(target),
            callData,
            decimals,
            bits,
            instanceOffset
        );
    }

    // Fork of mulDivRound that, instead of reverting on overflow, returns a boolean overflow flag alongside the result
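    // The product x * y is carried at full 512-bit precision before division,
    // and the quotient is rounded half up; if it cannot fit in 256 bits the
    // function returns (0, true) instead of reverting.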
    function mulDivRound(
        uint256 x,
        uint256 y,
        uint256 denominator
    ) public pure returns (uint256 result, bool overflow) {
        unchecked {
            uint256 prod0;
            uint256 prod1;
            assembly {
                let mm := mulmod(x, y, not(0))
                prod0 := mul(x, y)
                prod1 := sub(sub(mm, prod0), lt(mm, prod0))
            }
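            // [prod1 prod0] now holds the 512-bit product: mulmod(x, y, not(0))
            // gives x * y mod (2^256 - 1), and subtracting the low word (with a
            // borrow when mm < prod0) recovers the high word.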
            uint256 remainder = mulmod(x, y, denominator);
            bool addOne;
            if (remainder * 2 >= denominator) {
                addOne = true;
            }

            if (prod1 == 0) {
                if (addOne) {
                    return ((prod0 / denominator) + 1, false);
                }
                return (prod0 / denominator, false);
            }

            // Overflow: the quotient only fits in 256 bits when prod1 < denominator
            if (denominator <= prod1) {
                return (0, true);
            }

            assembly {
                prod1 := sub(prod1, gt(remainder, prod0))
                prod0 := sub(prod0, remainder)
            }

            uint256 twos = denominator & (~denominator + 1);
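            // `twos` is the largest power of two dividing the denominator; it is
            // divided out of both the denominator and the product below, so the
            // denominator becomes odd (and hence invertible mod 2^256).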
            assembly {
                denominator := div(denominator, twos)
                prod0 := div(prod0, twos)
                twos := add(div(sub(0, twos), twos), 1)
            }

            prod0 |= prod1 * twos;

            uint256 inverse = (3 * denominator) ^ 2;

            inverse *= 2 - denominator * inverse;
            inverse *= 2 - denominator * inverse;
            inverse *= 2 - denominator * inverse;
            inverse *= 2 - denominator * inverse;
            inverse *= 2 - denominator * inverse;
            inverse *= 2 - denominator * inverse;
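            // (3 * denominator) ^ 2 seeds an inverse of the odd denominator that is
            // correct modulo 2^4; each Newton-Raphson step doubles the number of
            // correct bits, so six steps reach a full inverse modulo 2^256.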
|
||||
result = prod0 * inverse;
|
||||
if (addOne) {
|
||||
result += 1;
|
||||
}
|
||||
return (result, false);
|
||||
}
|
||||
}
|
||||
struct SampleAttestation {
|
||||
int256 mockData;
|
||||
uint8 decimals;
|
||||
uint8 bits;
|
||||
}
|
||||
    function test_fuzzAttestedData(
        SampleAttestation[] memory _attestations
    ) public {
        vm.assume(_attestations.length == 1);
        int256[] memory _mockData = new int256[](1);
        uint256[] memory _decimals = new uint256[](1);
        uint256[] memory _bits = new uint256[](1);
        uint256[] memory _instances = new uint256[](1);
        for (uint256 i = 0; i < 1; i++) {
            SampleAttestation memory attestation = _attestations[i];
            _mockData[i] = attestation.mockData;
            vm.assume(attestation.mockData != type(int256).min); /// Will overflow int256 during negation op
            vm.assume(attestation.decimals < 77); /// Else will exceed uint256 bounds
            vm.assume(attestation.bits < 128); /// Else will exceed EZKL fixed point bounds for int128 type
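            // EZKL fixed-point quantization: q = round(data * 2^bits / 10^decimals);
            // a negative q is embedded in the field as ORDER - |q|, as computed below.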
            bool neg = attestation.mockData < 0;
            if (neg) {
                attestation.mockData = -attestation.mockData;
            }
            (uint256 _result, bool overflow) = mulDivRound(
                uint256(attestation.mockData),
                uint256(1 << attestation.bits),
                uint256(10 ** attestation.decimals)
            );
            vm.assume(!overflow);
            vm.assume(_result < das.HALF_ORDER());
            if (neg) {
                // No possibility of overflow here since output is less than or equal to HALF_ORDER
                // and therefore falls within the max range of int256 without overflow
                vm.assume(-int256(_result) > type(int128).min);
                _instances[i] =
                    uint256(int(das.ORDER()) - int256(_result)) %
                    das.ORDER();
            } else {
                vm.assume(_result < uint128(type(int128).max));
                _instances[i] = _result;
            }
            _decimals[i] = attestation.decimals;
            _bits[i] = attestation.bits;
        }
        // Update the attested data
        target.setData(_mockData);
        // Deploy the new data attestation contract
        AttestData.DataAttestation dasNew = new AttestData.DataAttestation(
            address(target),
            callData,
            _decimals,
            _bits,
            instanceOffset
        );
        bytes memory proof = hex"1234"; // Would normally contain commitments
        bytes memory encoded = abi.encodeWithSignature(
            "verifyProof(bytes,uint256[])",
            proof,
            _instances
        );

        AttestData.DataAttestation.Scalars memory _scalars = AttestData
            .DataAttestation
            .Scalars(10 ** _decimals[0], 1 << _bits[0]);

        int256 output = dasNew.quantizeData(_mockData[0], _scalars);
        console.log("output: ", output);
        uint256 fieldElement = dasNew.toFieldElement(output);
        // output should equal _instances[0]
        assertEq(fieldElement, _instances[0]);

        bool verificationResult = dasNew.verifyWithDataAttestation(
            address(verifier),
            encoded
        );
        assertTrue(verificationResult);
    }

    // Test deployment parameters
    function testDeployment() public view {
        assertEq(das.contractAddress(), address(target));
        assertEq(das.callData(), abi.encodeWithSignature("getData()"));
        assertEq(das.instanceOffset(), instanceOffset);

        AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
        assertEq(scalar.decimals, 1e18);
        assertEq(scalar.bits, 1 << 13);
    }

    // Test quantizeData function
    function testQuantizeData() public view {
        AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);

        int256 positive = das.quantizeData(1e18, scalar);
        assertEq(positive, int256(scalar.bits));

        int256 negative = das.quantizeData(-1e18, scalar);
        assertEq(negative, -int256(scalar.bits));

        // Test rounding
        int half = int(0.5e18 / scalar.bits);
        int256 rounded = das.quantizeData(half, scalar);
        assertEq(rounded, 1);
    }

    // Test staticCall functionality
    function testStaticCall() public view {
        bytes memory result = das.staticCall(
            address(target),
            abi.encodeWithSignature("getData()")
        );
        int256[] memory decoded = abi.decode(result, (int256[]));
        assertEq(decoded[0], mockData[0]);
        assertEq(decoded[1], mockData[1]);
    }

    // Test attestData validation
    function testAttestDataSuccess() public view {
        uint256[] memory instances = new uint256[](2);
        AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
        instances[0] = das.toFieldElement(int(scalar.bits));
        instances[1] = das.toFieldElement(-int(scalar.bits >> 1));
        das.attestData(instances); // Should not revert
    }

    function testAttestDataFailure() public {
        uint256[] memory instances = new uint256[](2);
        instances[0] = das.toFieldElement(1e18); // Incorrect value
        instances[1] = das.toFieldElement(5e17);

        vm.expectRevert("Public input does not match");
        das.attestData(instances);
    }

    // Test full verification flow
    function testSuccessfulVerification() public view {
        // Prepare valid instances
        uint256[] memory instances = new uint256[](2);
        AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
        instances[0] = das.toFieldElement(int(scalar.bits));
        instances[1] = das.toFieldElement(-int(scalar.bits >> 1));

        // Create valid calldata (mock)
        bytes memory proof = hex"1234"; // Would normally contain commitments
        bytes memory encoded = abi.encodeWithSignature(
            "verifyProof(bytes,uint256[])",
            proof,
            instances
        );
        bytes memory encoded_vka = abi.encodeWithSignature(
            "verifyProof(address,bytes,uint256[])",
            address(vka),
            proof,
            instances
        );

        bool result = das.verifyWithDataAttestation(address(verifier), encoded);
        assertTrue(result);
        result = das.verifyWithDataAttestation(
            address(verifierSeperate),
            encoded_vka
        );
        assertTrue(result);
    }

    function testLoadInstances() public view {
        uint256[] memory instances = new uint256[](2);
        AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
        instances[0] = das.toFieldElement(int(scalar.bits));
        instances[1] = das.toFieldElement(-int(scalar.bits >> 1));

        // Create valid calldata (mock)
        bytes memory proof = hex"1234"; // Would normally contain commitments
        bytes memory encoded = abi.encodeWithSignature(
            "verifyProof(bytes,uint256[])",
            proof,
            instances
        );
        bytes memory encoded_vka = abi.encodeWithSignature(
            "verifyProof(address,bytes,uint256[])",
            address(vka),
            proof,
            instances
        );

        // Load encoded instances from calldata
        uint256[] memory extracted_instances_calldata = das
            .getInstancesCalldata(encoded);
        assertEq(extracted_instances_calldata[0], instances[0]);
        assertEq(extracted_instances_calldata[1], instances[1]);
        // Load encoded instances from memory
        uint256[] memory extracted_instances_memory = das.getInstancesMemory(
            encoded
        );
        assertEq(extracted_instances_memory[0], instances[0]);
        assertEq(extracted_instances_memory[1], instances[1]);
        // Load encoded (with vk) instances from calldata
        uint256[] memory extracted_instances_calldata_vk = das
            .getInstancesCalldata(encoded_vka);
        assertEq(extracted_instances_calldata_vk[0], instances[0]);
        assertEq(extracted_instances_calldata_vk[1], instances[1]);
        // Load encoded (with vk) instances from memory
        uint256[] memory extracted_instances_memory_vk = das.getInstancesMemory(
            encoded_vka
        );
        assertEq(extracted_instances_memory_vk[0], instances[0]);
        assertEq(extracted_instances_memory_vk[1], instances[1]);
    }

    function testInvalidCommitments() public {
        // Create calldata with invalid commitments
        bytes memory invalidProof = hex"5678";
        uint256[] memory instances = new uint256[](2);
        AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
        instances[0] = das.toFieldElement(int(scalar.bits));
        instances[1] = das.toFieldElement(-int(scalar.bits >> 1));
        bytes memory encoded = abi.encodeWithSignature(
            "verifyProof(bytes,uint256[])",
            invalidProof,
            instances
        );

        vm.expectRevert("Invalid KZG commitments");
        das.verifyWithDataAttestation(address(verifier), encoded);
    }

    function testInvalidVerifier() public {
        MockVerifier invalidVerifier = new MockVerifier(false);
        uint256[] memory instances = new uint256[](2);
        AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
        instances[0] = das.toFieldElement(int(scalar.bits));
        instances[1] = das.toFieldElement(-int(scalar.bits >> 1));
        bytes memory encoded = abi.encodeWithSignature(
            "verifyProof(bytes,uint256[])",
            hex"1234",
            instances
        );

        vm.expectRevert("low-level call to verifier failed");
        das.verifyWithDataAttestation(address(invalidVerifier), encoded);
    }

    // Test edge cases
    function testZeroValueQuantization() public view {
        AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
        int256 zero = das.quantizeData(0, scalar);
        assertEq(zero, 0);
    }

    function testOverflowProtection() public {
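        // BN254 scalar-field modulus (the ORDER used by the attestation contract)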
        int256 order = int(
            uint256(
                0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
            )
        );
        // int256 half_order = int(order >> 1);
        AttestData.DataAttestation.Scalars memory scalar = AttestData
            .DataAttestation
            .Scalars(1, 1 << 2);

        vm.expectRevert("Overflow field modulus");
        das.quantizeData(order, scalar); // Value that would overflow
    }

    function testInvalidFunctionSignature() public {
        uint256[] memory instances = new uint256[](2);
        AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
        instances[0] = das.toFieldElement(int(scalar.bits));
        instances[1] = das.toFieldElement(-int(scalar.bits >> 1));
        bytes memory encoded_invalid_sig = abi.encodeWithSignature(
            "verifyProofff(bytes,uint256[])",
            hex"1234",
            instances
        );

        vm.expectRevert("Invalid function signature");
        das.verifyWithDataAttestation(address(verifier), encoded_invalid_sig);
    }
}
@@ -46,28 +46,7 @@ mod py_tests {
            assert!(status.success());
        });
        // set VOICE_DATA_DIR environment variable
        unsafe {
            std::env::set_var("VOICE_DATA_DIR", format!("{}", voice_data_dir));
        }
    }

    fn download_catdog_data() {
        let cat_and_dog_data_dir = shellexpand::tilde("~/data/catdog_data");

        DOWNLOAD_VOICE_DATA.call_once(|| {
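            // NOTE: this reuses the DOWNLOAD_VOICE_DATA Once guard, so within one
            // test process only the first of the voice / cat-and-dog downloads runs.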
let status = Command::new("bash")
|
||||
.args([
|
||||
"examples/notebooks/cat_and_dog_data.sh",
|
||||
&cat_and_dog_data_dir,
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
});
|
||||
// set VOICE_DATA_DIR environment variable
|
||||
unsafe {
|
||||
std::env::set_var("CATDOG_DATA_DIR", format!("{}", cat_and_dog_data_dir));
|
||||
}
|
||||
std::env::set_var("VOICE_DATA_DIR", format!("{}", voice_data_dir));
|
||||
}
|
||||
|
||||
    fn setup_py_env() {
@@ -93,10 +72,11 @@ mod py_tests {
                "torchtext==0.17.2",
                "torchvision==0.17.2",
                "pandas==2.2.1",
                "numpy==1.26.4",
                "seaborn==0.13.2",
                "notebook==7.1.2",
                "nbconvert==7.16.3",
                "onnx==1.17.0",
                "onnx==1.16.0",
                "kaggle==1.6.8",
                "py-solc-x==2.0.3",
                "web3==7.5.0",
@@ -110,13 +90,12 @@ mod py_tests {
                "xgboost==2.0.3",
                "hummingbird-ml==0.4.11",
                "lightgbm==4.3.0",
                "numpy==1.26.4",
            ])
            .status()
            .expect("failed to execute process");
        assert!(status.success());
        let status = Command::new("pip")
            .args(["install", "numpy==1.26.4"])
            .args(["install", "numpy==1.23"])
            .status()
            .expect("failed to execute process");

@@ -147,10 +126,10 @@ mod py_tests {
    }

    const TESTS: [&str; 35] = [
        "mnist_gan.ipynb",                 // 0
        "ezkl_demo_batch.ipynb",           // 1
        "proof_splitting.ipynb",           // 2
        "variance.ipynb",                  // 3
        "ezkl_demo_batch.ipynb",           // 0
        "proof_splitting.ipynb",           // 1
        "variance.ipynb",                  // 2
        "mnist_gan.ipynb",                 // 3
        "keras_simple_demo.ipynb",         // 4
        "mnist_gan_proof_splitting.ipynb", // 5
        "hashed_vis.ipynb",                // 6
@@ -246,20 +225,6 @@ mod py_tests {
        anvil_child.kill().unwrap();
    }


    #[test]
    fn cat_and_dog_notebook_() {
        crate::py_tests::init_binary();
        let mut anvil_child = crate::py_tests::start_anvil(false);
        crate::py_tests::download_catdog_data();
        let test_dir: TempDir = TempDir::new("cat_and_dog").unwrap();
        let path = test_dir.path().to_str().unwrap();
        crate::py_tests::mv_test_(path, "cat_and_dog.ipynb");
        run_notebook(path, "cat_and_dog.ipynb");
        test_dir.close().unwrap();
        anvil_child.kill().unwrap();
    }

    #[test]
    fn reusable_verifier_notebook_() {
        crate::py_tests::init_binary();

@@ -48,6 +48,7 @@ def test_py_run_args():
    run_args = ezkl.PyRunArgs()
    run_args.input_visibility = "hashed"
    run_args.output_visibility = "hashed"
    run_args.tolerance = 1.5


def test_poseidon_hash():
@@ -58,7 +59,7 @@ def test_poseidon_hash():
    message = [ezkl.float_to_felt(x, 7) for x in message]
    res = ezkl.poseidon_hash(message)
    assert ezkl.felt_to_big_endian(
        res[0]) == "0x2369898875588bf49b6539376b09705ea69aee318a58e6fcc1e68fc3e7ad81ab"
        res[0]) == "0x0da7e5e5c8877242fa699f586baf770d731defd54f952d4adeb85047a0e32f45"


@@ -872,8 +873,6 @@ def get_examples():
        'linear_regression',
        "mnist_gan",
        "smallworm",
        "fr_age",
        "1d_conv",
    ]
    examples = []
    for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):
@@ -900,12 +899,7 @@ async def test_all_examples(model_file, input_file):
    proof_path = os.path.join(folder_path, 'proof.json')

    print("Testing example: ", model_file)

    run_args = ezkl.PyRunArgs()
    run_args.variables = [("batch_size", 1), ("sequence_length", 100), ("<Sym1>", 1)]
    run_args.logrows = 22
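    # run_args.variables pins the model's symbolic ONNX dimensions to concrete values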

    res = ezkl.gen_settings(model_file, settings_path, py_run_args=run_args)
    res = ezkl.gen_settings(model_file, settings_path)
    assert res

    res = await ezkl.calibrate_settings(
@@ -11,6 +11,7 @@ mod wasm32 {
    use ezkl::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
    use ezkl::circuit::modules::poseidon::PoseidonChip;
    use ezkl::circuit::modules::Module;
    use ezkl::graph::modules::POSEIDON_LEN_GRAPH;
    use ezkl::graph::GraphCircuit;
    use ezkl::graph::{GraphSettings, GraphWitness};
    use ezkl::pfsys;
@@ -226,9 +227,11 @@ mod wasm32 {
        let hash: Vec<Vec<Fr>> = serde_json::from_slice(&hash[..]).unwrap();

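        // the added fourth const generic (POSEIDON_LEN_GRAPH) fixes the chip's
        // expected message length at compile time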
        let reference_hash =
            PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
                .map_err(|_| "failed")
                .unwrap();
            PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>::run(
                message.clone(),
            )
            .map_err(|_| "failed")
            .unwrap();

        assert_eq!(hash, reference_hash)
    }
@@ -337,7 +340,7 @@ mod wasm32 {
        // Run compiled circuit validation on onnx network (should fail)
        let circuit = compiledCircuitValidation(wasm_bindgen::Clamped(NETWORK.to_vec()));
        assert!(circuit.is_err());
        // Run compiled circuit validation on compiled network (should pass)
        // Run compiled circuit validation on comiled network (should pass)
        let circuit = compiledCircuitValidation(wasm_bindgen::Clamped(NETWORK_COMPILED.to_vec()));
        assert!(circuit.is_ok());
        // Run input validation on witness (should fail)
