mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
10 Commits
ac/artifac
...
v18.0.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
566ccddf6d | ||
|
|
675628cd08 | ||
|
|
4fe7290240 | ||
|
|
3e027db9b6 | ||
|
|
e566acc22a | ||
|
|
75ea99e81d | ||
|
|
c5354c382d | ||
|
|
bdcba5ca61 | ||
|
|
6752a05f19 | ||
|
|
03aefb85eb |
33
.github/workflows/benchmarks.yml
vendored
33
.github/workflows/benchmarks.yml
vendored
@@ -6,22 +6,13 @@ on:
|
||||
description: "Test scenario tags"
|
||||
|
||||
jobs:
|
||||
bench_elgamal:
|
||||
runs-on: self-hosted
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Bench elgamal
|
||||
run: cargo bench --verbose --bench elgamal
|
||||
|
||||
bench_poseidon:
|
||||
runs-on: self-hosted
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -35,6 +26,8 @@ jobs:
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -48,6 +41,8 @@ jobs:
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -61,6 +56,8 @@ jobs:
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -74,6 +71,8 @@ jobs:
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -87,6 +86,8 @@ jobs:
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -100,6 +101,8 @@ jobs:
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -113,6 +116,8 @@ jobs:
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -126,6 +131,8 @@ jobs:
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -139,6 +146,8 @@ jobs:
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
@@ -152,6 +161,8 @@ jobs:
|
||||
needs: [bench_poseidon]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
|
||||
10
.github/workflows/engine.yml
vendored
10
.github/workflows/engine.yml
vendored
@@ -20,6 +20,8 @@ jobs:
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -51,7 +53,7 @@ jobs:
|
||||
run: |
|
||||
echo '{
|
||||
"name": "@ezkljs/engine",
|
||||
"version": "${{ github.ref_name }}",
|
||||
"version": "${RELEASE_TAG}",
|
||||
"dependencies": {
|
||||
"@types/json-bigint": "^1.0.1",
|
||||
"json-bigint": "^1.0.0"
|
||||
@@ -190,15 +192,17 @@ jobs:
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Update version in package.json
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
|
||||
sed -i "s|\"version\": \".*\"|\"version\": \"${RELEASE_TAG}\"|" in-browser-evm-verifier/package.json
|
||||
- name: Prepare tag and fetch package integrity
|
||||
run: |
|
||||
CLEANED_TAG=${{ github.ref_name }} # Get the tag from ref_name
|
||||
CLEANED_TAG=${RELEASE_TAG} # Get the tag from ref_name
|
||||
CLEANED_TAG="${CLEANED_TAG#v}" # Remove leading 'v'
|
||||
echo "CLEANED_TAG=${CLEANED_TAG}" >> $GITHUB_ENV # Set it as an environment variable for later steps
|
||||
ENGINE_INTEGRITY=$(npm view @ezkljs/engine@$CLEANED_TAG dist.integrity)
|
||||
|
||||
2
.github/workflows/large-tests.yml
vendored
2
.github/workflows/large-tests.yml
vendored
@@ -9,6 +9,8 @@ jobs:
|
||||
runs-on: kaiju
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
|
||||
2
.github/workflows/pypi-gpu.yml
vendored
2
.github/workflows/pypi-gpu.yml
vendored
@@ -24,6 +24,8 @@ jobs:
|
||||
target: [x86_64]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
|
||||
13
.github/workflows/pypi.yml
vendored
13
.github/workflows/pypi.yml
vendored
@@ -23,6 +23,8 @@ jobs:
|
||||
target: [x86_64, universal2-apple-darwin]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
@@ -69,6 +71,8 @@ jobs:
|
||||
target: [x64, x86]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
@@ -114,6 +118,8 @@ jobs:
|
||||
target: [x86_64]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
@@ -228,6 +234,8 @@ jobs:
|
||||
- x86_64-unknown-linux-musl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
@@ -292,6 +300,8 @@ jobs:
|
||||
arch: aarch64
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.12
|
||||
@@ -373,7 +383,8 @@ jobs:
|
||||
needs: pypi-publish
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Trigger RTDs build
|
||||
uses: dfm/rtds-action@v1
|
||||
with:
|
||||
|
||||
5
.github/workflows/release.yml
vendored
5
.github/workflows/release.yml
vendored
@@ -50,6 +50,9 @@ jobs:
|
||||
components: rustfmt, clippy
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- name: Get release version from tag
|
||||
shell: bash
|
||||
@@ -132,6 +135,8 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Get release version from tag
|
||||
shell: bash
|
||||
|
||||
74
.github/workflows/rust.yml
vendored
74
.github/workflows/rust.yml
vendored
@@ -19,11 +19,28 @@ env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
|
||||
fr-age-test:
|
||||
runs-on: large-self-hosted
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: fr age Mock
|
||||
run: cargo test --release --verbose tests::large_mock_::large_tests_6_expects -- --include-ignored
|
||||
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -36,6 +53,8 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -48,6 +67,8 @@ jobs:
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -71,6 +92,8 @@ jobs:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
@@ -104,6 +127,8 @@ jobs:
|
||||
runs-on: non-gpu
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -137,6 +162,8 @@ jobs:
|
||||
runs-on: non-gpu
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -170,6 +197,8 @@ jobs:
|
||||
runs-on: ubuntu-latest-16-cores
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -186,6 +215,8 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -212,6 +243,8 @@ jobs:
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -275,6 +308,8 @@ jobs:
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -285,6 +320,8 @@ jobs:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Use pnpm 8
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
@@ -354,6 +391,8 @@ jobs:
|
||||
needs: [build, library-tests, docs]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -369,6 +408,8 @@ jobs:
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Use pnpm 8
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
@@ -431,6 +472,8 @@ jobs:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
@@ -465,6 +508,8 @@ jobs:
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -483,6 +528,8 @@ jobs:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
@@ -500,6 +547,8 @@ jobs:
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -517,6 +566,8 @@ jobs:
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -538,6 +589,8 @@ jobs:
|
||||
needs: [build, library-tests, docs]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -555,6 +608,8 @@ jobs:
|
||||
needs: [build, library-tests, docs]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.12"
|
||||
@@ -577,10 +632,12 @@ jobs:
|
||||
run: source .env/bin/activate; pip install pytest-asyncio; pytest -vv
|
||||
|
||||
accuracy-measurement-tests:
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.12"
|
||||
@@ -628,6 +685,8 @@ jobs:
|
||||
- 5432:5432
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.11"
|
||||
@@ -680,6 +739,8 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -698,6 +759,8 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
@@ -715,6 +778,15 @@ jobs:
|
||||
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
|
||||
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
|
||||
|
||||
- name: Copy Test Files
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Tests/EzklAssets/
|
||||
mkdir -p ezkl-swift-package/Tests/EzklAssets/
|
||||
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
|
||||
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
|
||||
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
|
||||
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
|
||||
|
||||
- name: Set up Xcode environment
|
||||
run: |
|
||||
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
|
||||
|
||||
32
.github/workflows/static-analysis.yml
vendored
Normal file
32
.github/workflows/static-analysis.yml
vendored
Normal file
@@ -0,0 +1,32 @@
|
||||
name: Static Analysis
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
|
||||
# Run Zizmor static analysis
|
||||
|
||||
- name: Install Zizmor
|
||||
run: cargo install --locked zizmor
|
||||
|
||||
- name: Run Zizmor Analysis
|
||||
run: zizmor .
|
||||
|
||||
|
||||
|
||||
10
.github/workflows/swift-pm.yml
vendored
10
.github/workflows/swift-pm.yml
vendored
@@ -12,15 +12,18 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
env:
|
||||
EZKL_SWIFT_PACKAGE_REPO: github.com/zkonduit/ezkl-swift-package.git
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
|
||||
steps:
|
||||
- name: Checkout EZKL
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Extract TAG from github.ref_name
|
||||
run: |
|
||||
# github.ref_name is provided by GitHub Actions and contains the tag name directly.
|
||||
TAG="${{ github.ref_name }}"
|
||||
TAG="${RELEASE_TAG}"
|
||||
echo "Original TAG: $TAG"
|
||||
# Remove leading 'v' if present to match the Swift Package Manager version format.
|
||||
NEW_TAG=${TAG#v}
|
||||
@@ -47,7 +50,8 @@ jobs:
|
||||
|
||||
- name: Copy Test Files
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Tests/EzklAssets/*
|
||||
rm -rf ezkl-swift-package/Tests/EzklAssets/
|
||||
mkdir -p ezkl-swift-package/Tests/EzklAssets/
|
||||
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
|
||||
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
|
||||
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
|
||||
@@ -105,7 +109,6 @@ jobs:
|
||||
cd ezkl-swift-package
|
||||
git add Sources/EzklCoreBindings Tests/EzklAssets
|
||||
git commit -m "Automatically updated EzklCoreBindings for EZKL"
|
||||
|
||||
if ! git push origin; then
|
||||
echo "::error::Failed to push changes to ${{ env.EZKL_SWIFT_PACKAGE_REPO }}. Please ensure that EZKL_PORTER_TOKEN has the correct permissions."
|
||||
exit 1
|
||||
@@ -115,7 +118,6 @@ jobs:
|
||||
run: |
|
||||
cd ezkl-swift-package
|
||||
source $GITHUB_ENV
|
||||
|
||||
# Tag the latest commit on the current branch
|
||||
if git rev-parse "$TAG" >/dev/null 2>&1; then
|
||||
echo "Tag $TAG already exists locally. Skipping tag creation."
|
||||
|
||||
2
.github/workflows/tagging.yml
vendored
2
.github/workflows/tagging.yml
vendored
@@ -12,6 +12,8 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Bump version and push tag
|
||||
id: tag_version
|
||||
uses: mathieudutour/github-tag-action@v6.2
|
||||
|
||||
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -2377,7 +2377,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_gadgets"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/zkonduit/halo2#6d72498928cdb69ce0de9f2230d2873ca2cf5324"
|
||||
source = "git+https://github.com/zkonduit/halo2#d7ecad83c7439fa1cb450ee4a89c2d0b45604ceb"
|
||||
dependencies = [
|
||||
"arrayvec 0.7.4",
|
||||
"bitvec",
|
||||
@@ -2394,7 +2394,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_proofs"
|
||||
version = "0.3.0"
|
||||
source = "git+https://github.com/zkonduit/halo2#6d72498928cdb69ce0de9f2230d2873ca2cf5324#6d72498928cdb69ce0de9f2230d2873ca2cf5324"
|
||||
source = "git+https://github.com/zkonduit/halo2#ee4e1a09ebdb1f79f797685b78951c6034c430a6#ee4e1a09ebdb1f79f797685b78951c6034c430a6"
|
||||
dependencies = [
|
||||
"bincode",
|
||||
"blake2b_simd",
|
||||
|
||||
@@ -280,10 +280,10 @@ no-update = []
|
||||
|
||||
|
||||
[patch.'https://github.com/zkonduit/halo2']
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#6d72498928cdb69ce0de9f2230d2873ca2cf5324", package = "halo2_proofs" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#ee4e1a09ebdb1f79f797685b78951c6034c430a6", package = "halo2_proofs" }
|
||||
|
||||
[patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#6d72498928cdb69ce0de9f2230d2873ca2cf5324", package = "halo2_proofs" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#ee4e1a09ebdb1f79f797685b78951c6034c430a6", package = "halo2_proofs" }
|
||||
|
||||
[patch.crates-io]
|
||||
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ezkl
|
||||
|
||||
project = 'ezkl'
|
||||
release = '0.0.0'
|
||||
release = '18.0.0'
|
||||
version = release
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -54,7 +54,7 @@
|
||||
" gip_run_args.param_scale = 19\n",
|
||||
" gip_run_args.logrows = 8\n",
|
||||
" run_args = ezkl.gen_settings(py_run_args=gip_run_args)\n",
|
||||
" ezkl.get_srs(commitment=ezkl.PyCommitments.KZG)\n",
|
||||
" await ezkl.get_srs(commitment=ezkl.PyCommitments.KZG)\n",
|
||||
" ezkl.compile_circuit()\n",
|
||||
" res = await ezkl.gen_witness()\n",
|
||||
" print(res)\n",
|
||||
|
||||
@@ -1,279 +1,284 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Linear Regression\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
|
||||
"This notebook showcases how to do so using the `hummingbird-ml` python package ! "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "95613ee9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import torch\n",
|
||||
"import ezkl\n",
|
||||
"import json\n",
|
||||
"from hummingbird.ml import convert\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# here we create and (potentially train a model)\n",
|
||||
"\n",
|
||||
"# make sure you have the dependencies required here already installed\n",
|
||||
"import numpy as np\n",
|
||||
"from sklearn.linear_model import LinearRegression\n",
|
||||
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
|
||||
"# y = 1 * x_0 + 2 * x_1 + 3\n",
|
||||
"y = np.dot(X, np.array([1, 2])) + 3\n",
|
||||
"reg = LinearRegression().fit(X, y)\n",
|
||||
"reg.score(X, y)\n",
|
||||
"\n",
|
||||
"circuit = convert(reg, \"torch\", X[:1]).model\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b37637c4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_path = os.path.join('network.onnx')\n",
|
||||
"compiled_model_path = os.path.join('network.compiled')\n",
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"settings_path = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"witness_path = os.path.join('witness.json')\n",
|
||||
"data_path = os.path.join('input.json')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "82db373a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"# export to onnx format\n",
|
||||
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
|
||||
"\n",
|
||||
"# Input to the model\n",
|
||||
"shape = X.shape[1:]\n",
|
||||
"x = torch.rand(1, *shape, requires_grad=True)\n",
|
||||
"torch_out = circuit(x)\n",
|
||||
"# Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" # model input (or a tuple for multiple inputs)\n",
|
||||
" x,\n",
|
||||
" # where to save the model (can be a file or file-like object)\n",
|
||||
" \"network.onnx\",\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=10, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names=['input'], # the model's input names\n",
|
||||
" output_names=['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
|
||||
" 'output': {0: 'batch_size'}})\n",
|
||||
"\n",
|
||||
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_shapes=[shape],\n",
|
||||
" input_data=[d],\n",
|
||||
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
|
||||
"\n",
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(\"input.json\", 'w'))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d5e374a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!RUST_LOG=trace\n",
|
||||
"# TODO: Dictionary outputs\n",
|
||||
"res = ezkl.gen_settings(model_path, settings_path)\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cal_path = os.path.join(\"calibration.json\")\n",
|
||||
"\n",
|
||||
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_data = [data_array])\n",
|
||||
"\n",
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3aa4f090",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8b74dcee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "18c8b7c7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b1c561a8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
|
||||
"# WE GOT KEYS\n",
|
||||
"# WE GOT CIRCUIT PARAMETERS\n",
|
||||
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_model_path,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c384cbc8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# GENERATE A PROOF\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"res = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_model_path,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \n",
|
||||
" \"single\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"print(res)\n",
|
||||
"assert os.path.isfile(proof_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "76f00d41",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# VERIFY IT\n",
|
||||
"\n",
|
||||
"res = ezkl.verify(\n",
|
||||
" proof_path,\n",
|
||||
" settings_path,\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Linear Regression\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
|
||||
"This notebook showcases how to do so using the `hummingbird-ml` python package ! "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "95613ee9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import torch\n",
|
||||
"import ezkl\n",
|
||||
"import json\n",
|
||||
"from hummingbird.ml import convert\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# here we create and (potentially train a model)\n",
|
||||
"\n",
|
||||
"# make sure you have the dependencies required here already installed\n",
|
||||
"import numpy as np\n",
|
||||
"from sklearn.linear_model import LinearRegression\n",
|
||||
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
|
||||
"# y = 1 * x_0 + 2 * x_1 + 3\n",
|
||||
"y = np.dot(X, np.array([1, 2])) + 3\n",
|
||||
"reg = LinearRegression().fit(X, y)\n",
|
||||
"reg.score(X, y)\n",
|
||||
"\n",
|
||||
"circuit = convert(reg, \"torch\", X[:1]).model\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b37637c4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_path = os.path.join('network.onnx')\n",
|
||||
"compiled_model_path = os.path.join('network.compiled')\n",
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"settings_path = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"witness_path = os.path.join('witness.json')\n",
|
||||
"data_path = os.path.join('input.json')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "82db373a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"# export to onnx format\n",
|
||||
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
|
||||
"\n",
|
||||
"# Input to the model\n",
|
||||
"shape = X.shape[1:]\n",
|
||||
"x = torch.rand(1, *shape, requires_grad=True)\n",
|
||||
"torch_out = circuit(x)\n",
|
||||
"# Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" # model input (or a tuple for multiple inputs)\n",
|
||||
" x,\n",
|
||||
" # where to save the model (can be a file or file-like object)\n",
|
||||
" \"network.onnx\",\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=10, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names=['input'], # the model's input names\n",
|
||||
" output_names=['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
|
||||
" 'output': {0: 'batch_size'}})\n",
|
||||
"\n",
|
||||
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_shapes=[shape],\n",
|
||||
" input_data=[d],\n",
|
||||
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
|
||||
"\n",
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(\"input.json\", 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# note that you can also call the following function to generate random data for the model\n",
|
||||
"# it is functionally equivalent to the code above\n",
|
||||
"ezkl.gen_random_data()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d5e374a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!RUST_LOG=trace\n",
|
||||
"# TODO: Dictionary outputs\n",
|
||||
"res = ezkl.gen_settings(model_path, settings_path)\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cal_path = os.path.join(\"calibration.json\")\n",
|
||||
"\n",
|
||||
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_data = [data_array])\n",
|
||||
"\n",
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3aa4f090",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8b74dcee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "18c8b7c7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b1c561a8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
|
||||
"# WE GOT KEYS\n",
|
||||
"# WE GOT CIRCUIT PARAMETERS\n",
|
||||
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_model_path,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c384cbc8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# GENERATE A PROOF\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"res = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_model_path,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \n",
|
||||
" \"single\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"print(res)\n",
|
||||
"assert os.path.isfile(proof_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "76f00d41",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# VERIFY IT\n",
|
||||
"\n",
|
||||
"res = ezkl.verify(\n",
|
||||
" proof_path,\n",
|
||||
" settings_path,\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
|
||||
@@ -1,456 +1,459 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Mean of ERC20 transfer amounts\n",
|
||||
"\n",
|
||||
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
|
||||
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
|
||||
"\n",
|
||||
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import getpass\n",
|
||||
"import json\n",
|
||||
"import time\n",
|
||||
"import subprocess\n",
|
||||
"\n",
|
||||
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
|
||||
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
|
||||
"os.system(\"chmod +x shovel\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
|
||||
"\n",
|
||||
"# create a config.json file with the following contents\n",
|
||||
"config = {\n",
|
||||
" \"pg_url\": \"$PG_URL\",\n",
|
||||
" \"eth_sources\": [\n",
|
||||
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
|
||||
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
|
||||
" ],\n",
|
||||
" \"integrations\": [{\n",
|
||||
" \"name\": \"usdc_transfer\",\n",
|
||||
" \"enabled\": True,\n",
|
||||
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
|
||||
" \"table\": {\n",
|
||||
" \"name\": \"usdc\",\n",
|
||||
" \"columns\": [\n",
|
||||
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
|
||||
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"block\": [\n",
|
||||
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
|
||||
" {\n",
|
||||
" \"name\": \"log_addr\",\n",
|
||||
" \"column\": \"log_addr\",\n",
|
||||
" \"filter_op\": \"contains\",\n",
|
||||
" \"filter_arg\": [\n",
|
||||
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
|
||||
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" \"event\": {\n",
|
||||
" \"name\": \"Transfer\",\n",
|
||||
" \"type\": \"event\",\n",
|
||||
" \"anonymous\": False,\n",
|
||||
" \"inputs\": [\n",
|
||||
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
|
||||
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
|
||||
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" }]\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# write the config to a file\n",
|
||||
"with open(\"config.json\", \"w\") as f:\n",
|
||||
" f.write(json.dumps(config))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# print the two env variables\n",
|
||||
"os.system(\"echo $PG_URL\")\n",
|
||||
"\n",
|
||||
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel is now installed. starting:\")\n",
|
||||
"\n",
|
||||
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
|
||||
"proc = subprocess.Popen(command)\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel started.\")\n",
|
||||
"\n",
|
||||
"time.sleep(10)\n",
|
||||
"\n",
|
||||
"# after we've fetched some data -- kill the process\n",
|
||||
"proc.terminate()\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "2wIAHwqH2_mo"
|
||||
},
|
||||
"source": [
|
||||
"**Import Dependencies**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "9Byiv2Nc2MsK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"import ezkl\n",
|
||||
"import torch\n",
|
||||
"import datetime\n",
|
||||
"import pandas as pd\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import logging\n",
|
||||
"# # uncomment for more descriptive logging \n",
|
||||
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
|
||||
"logging.basicConfig(format=FORMAT)\n",
|
||||
"logging.getLogger().setLevel(logging.DEBUG)\n",
|
||||
"\n",
|
||||
"print(\"ezkl version: \", ezkl.__version__)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "osjj-0Ta3E8O"
|
||||
},
|
||||
"source": [
|
||||
"**Create Computational Graph**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "x1vl9ZXF3EEW",
|
||||
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch import nn\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Model(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Model, self).__init__()\n",
|
||||
"\n",
|
||||
" # x is a time series \n",
|
||||
" def forward(self, x):\n",
|
||||
" return [torch.mean(x)]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"circuit = Model()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"x = 0.1*torch.rand(1,*[1,5], requires_grad=True)\n",
|
||||
"\n",
|
||||
"# # print(torch.__version__)\n",
|
||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"\n",
|
||||
"print(device)\n",
|
||||
"\n",
|
||||
"circuit.to(device)\n",
|
||||
"\n",
|
||||
"# Flips the neural net into inference mode\n",
|
||||
"circuit.eval()\n",
|
||||
"\n",
|
||||
"# Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" x, # model input (or a tuple for multiple inputs)\n",
|
||||
" \"lol.onnx\", # where to save the model (can be a file or file-like object)\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=11, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names = ['input'], # the model's input names\n",
|
||||
" output_names = ['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
|
||||
" 'output' : {0 : 'batch_size'}})\n",
|
||||
"\n",
|
||||
"# export(circuit, input_shape=[1, 20])\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "E3qCeX-X5xqd"
|
||||
},
|
||||
"source": [
|
||||
"**Set Data Source and Get Data**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "6RAMplxk5xPk",
|
||||
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"# make an input.json file from the df above\n",
|
||||
"input_filename = os.path.join('input.json')\n",
|
||||
"\n",
|
||||
"pg_input_file = dict(input_data = {\n",
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
"json_formatted_str = json.dumps(pg_input_file, indent=2)\n",
|
||||
"print(json_formatted_str)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# this corresponds to 4 batches\n",
|
||||
"calibration_filename = os.path.join('calibration.json')\n",
|
||||
"\n",
|
||||
"pg_cal_file = dict(input_data = {\n",
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump( pg_cal_file, open(calibration_filename, 'w' ))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "eLJ7oirQ_HQR"
|
||||
},
|
||||
"source": [
|
||||
"**EZKL Workflow**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "rNw0C9QL6W88"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import subprocess\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"onnx_filename = os.path.join('lol.onnx')\n",
|
||||
"compiled_filename = os.path.join('lol.compiled')\n",
|
||||
"settings_filename = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"# Generate settings using ezkl\n",
|
||||
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
|
||||
"\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "4MmE9SX66_Il",
|
||||
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# generate settings\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# show the settings.json\n",
|
||||
"with open(\"settings.json\") as f:\n",
|
||||
" data = json.load(f)\n",
|
||||
" json_formatted_str = json.dumps(data, indent=2)\n",
|
||||
"\n",
|
||||
" print(json_formatted_str)\n",
|
||||
"\n",
|
||||
"assert os.path.exists(\"settings.json\")\n",
|
||||
"assert os.path.exists(\"input.json\")\n",
|
||||
"assert os.path.exists(\"lol.onnx\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "fULvvnK7_CMb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# setup the proof\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_filename,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_filename)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"# generate the witness\n",
|
||||
"res = await ezkl.gen_witness(\n",
|
||||
" input_filename,\n",
|
||||
" compiled_filename,\n",
|
||||
" witness_path\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Oog3j6Kd-Wed",
|
||||
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# prove the zk circuit\n",
|
||||
"# GENERATE A PROOF\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"proof = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_filename,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \"single\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"proved\")\n",
|
||||
"\n",
|
||||
"assert os.path.isfile(proof_path)\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
}
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Mean of ERC20 transfer amounts\n",
|
||||
"\n",
|
||||
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
|
||||
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
|
||||
"\n",
|
||||
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import getpass\n",
|
||||
"import json\n",
|
||||
"import time\n",
|
||||
"import subprocess\n",
|
||||
"\n",
|
||||
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
|
||||
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
|
||||
"os.system(\"chmod +x shovel\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
|
||||
"\n",
|
||||
"# create a config.json file with the following contents\n",
|
||||
"config = {\n",
|
||||
" \"pg_url\": \"$PG_URL\",\n",
|
||||
" \"eth_sources\": [\n",
|
||||
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
|
||||
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
|
||||
" ],\n",
|
||||
" \"integrations\": [{\n",
|
||||
" \"name\": \"usdc_transfer\",\n",
|
||||
" \"enabled\": True,\n",
|
||||
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
|
||||
" \"table\": {\n",
|
||||
" \"name\": \"usdc\",\n",
|
||||
" \"columns\": [\n",
|
||||
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
|
||||
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"block\": [\n",
|
||||
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
|
||||
" {\n",
|
||||
" \"name\": \"log_addr\",\n",
|
||||
" \"column\": \"log_addr\",\n",
|
||||
" \"filter_op\": \"contains\",\n",
|
||||
" \"filter_arg\": [\n",
|
||||
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
|
||||
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" \"event\": {\n",
|
||||
" \"name\": \"Transfer\",\n",
|
||||
" \"type\": \"event\",\n",
|
||||
" \"anonymous\": False,\n",
|
||||
" \"inputs\": [\n",
|
||||
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
|
||||
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
|
||||
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" }]\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# write the config to a file\n",
|
||||
"with open(\"config.json\", \"w\") as f:\n",
|
||||
" f.write(json.dumps(config))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# print the two env variables\n",
|
||||
"os.system(\"echo $PG_URL\")\n",
|
||||
"\n",
|
||||
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel is now installed. starting:\")\n",
|
||||
"\n",
|
||||
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
|
||||
"proc = subprocess.Popen(command)\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel started.\")\n",
|
||||
"\n",
|
||||
"time.sleep(10)\n",
|
||||
"\n",
|
||||
"# after we've fetched some data -- kill the process\n",
|
||||
"proc.terminate()\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "2wIAHwqH2_mo"
|
||||
},
|
||||
"source": [
|
||||
"**Import Dependencies**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "9Byiv2Nc2MsK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"import ezkl\n",
|
||||
"import torch\n",
|
||||
"import datetime\n",
|
||||
"import pandas as pd\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import logging\n",
|
||||
"# # uncomment for more descriptive logging \n",
|
||||
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
|
||||
"logging.basicConfig(format=FORMAT)\n",
|
||||
"logging.getLogger().setLevel(logging.DEBUG)\n",
|
||||
"\n",
|
||||
"print(\"ezkl version: \", ezkl.__version__)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "osjj-0Ta3E8O"
|
||||
},
|
||||
"source": [
|
||||
"**Create Computational Graph**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "x1vl9ZXF3EEW",
|
||||
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch import nn\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Model(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Model, self).__init__()\n",
|
||||
"\n",
|
||||
" # x is a time series \n",
|
||||
" def forward(self, x):\n",
|
||||
" return [torch.mean(x)]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"circuit = Model()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"x = 0.1*torch.rand(1,*[1,5], requires_grad=True)\n",
|
||||
"\n",
|
||||
"# # print(torch.__version__)\n",
|
||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"\n",
|
||||
"print(device)\n",
|
||||
"\n",
|
||||
"circuit.to(device)\n",
|
||||
"\n",
|
||||
"# Flips the neural net into inference mode\n",
|
||||
"circuit.eval()\n",
|
||||
"\n",
|
||||
"# Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" x, # model input (or a tuple for multiple inputs)\n",
|
||||
" \"lol.onnx\", # where to save the model (can be a file or file-like object)\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=11, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names = ['input'], # the model's input names\n",
|
||||
" output_names = ['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
|
||||
" 'output' : {0 : 'batch_size'}})\n",
|
||||
"\n",
|
||||
"# export(circuit, input_shape=[1, 20])\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "E3qCeX-X5xqd"
|
||||
},
|
||||
"source": [
|
||||
"**Set Data Source and Get Data**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "6RAMplxk5xPk",
|
||||
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"# make an input.json file from the df above\n",
|
||||
"input_filename = os.path.join('input.json')\n",
|
||||
"\n",
|
||||
"pg_input_file = dict(input_data = {\n",
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
"json_formatted_str = json.dumps(pg_input_file, indent=2)\n",
|
||||
"print(json_formatted_str)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# this corresponds to 4 batches\n",
|
||||
"calibration_filename = os.path.join('calibration.json')\n",
|
||||
"\n",
|
||||
"pg_cal_file = dict(input_data = {\n",
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump( pg_cal_file, open(calibration_filename, 'w' ))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "eLJ7oirQ_HQR"
|
||||
},
|
||||
"source": [
|
||||
"**EZKL Workflow**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "rNw0C9QL6W88"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import subprocess\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"onnx_filename = os.path.join('lol.onnx')\n",
|
||||
"compiled_filename = os.path.join('lol.compiled')\n",
|
||||
"settings_filename = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"# Generate settings using ezkl\n",
|
||||
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"await ezkl.get_srs(settings_filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "4MmE9SX66_Il",
|
||||
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# generate settings\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# show the settings.json\n",
|
||||
"with open(\"settings.json\") as f:\n",
|
||||
" data = json.load(f)\n",
|
||||
" json_formatted_str = json.dumps(data, indent=2)\n",
|
||||
"\n",
|
||||
" print(json_formatted_str)\n",
|
||||
"\n",
|
||||
"assert os.path.exists(\"settings.json\")\n",
|
||||
"assert os.path.exists(\"input.json\")\n",
|
||||
"assert os.path.exists(\"lol.onnx\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "fULvvnK7_CMb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# setup the proof\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_filename,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_filename)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"# generate the witness\n",
|
||||
"res = await ezkl.gen_witness(\n",
|
||||
" input_filename,\n",
|
||||
" compiled_filename,\n",
|
||||
" witness_path\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Oog3j6Kd-Wed",
|
||||
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# prove the zk circuit\n",
|
||||
"# GENERATE A PROOF\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"proof = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_filename,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \"single\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"proved\")\n",
|
||||
"\n",
|
||||
"assert os.path.isfile(proof_path)\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
|
||||
@@ -453,18 +453,18 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# now mock aggregate the proofs\n",
|
||||
"proofs = []\n",
|
||||
"for i in range(3):\n",
|
||||
" proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
|
||||
" proofs.append(proof_path)\n",
|
||||
"# proofs = []\n",
|
||||
"# for i in range(3):\n",
|
||||
"# proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
|
||||
"# proofs.append(proof_path)\n",
|
||||
"\n",
|
||||
"ezkl.mock_aggregate(proofs, logrows=23, split_proofs = True)"
|
||||
"# ezkl.mock_aggregate(proofs, logrows=26, split_proofs = True)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "ezkl",
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -478,7 +478,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
"version": "3.12.7"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
|
||||
1
examples/onnx/fr_age/input.json
Normal file
1
examples/onnx/fr_age/input.json
Normal file
File diff suppressed because one or more lines are too long
BIN
examples/onnx/fr_age/network.onnx
Normal file
BIN
examples/onnx/fr_age/network.onnx
Normal file
Binary file not shown.
@@ -938,6 +938,45 @@ fn gen_settings(
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Generates random data for the model
|
||||
///
|
||||
/// Arguments
|
||||
/// ---------
|
||||
/// model: str
|
||||
/// Path to the onnx file
|
||||
///
|
||||
/// output: str
|
||||
/// Path to create the data file
|
||||
///
|
||||
/// seed: int
|
||||
/// Random seed to use for generated data
|
||||
///
|
||||
/// variables
|
||||
/// Returns
|
||||
/// -------
|
||||
/// bool
|
||||
///
|
||||
#[pyfunction(signature = (
|
||||
model=PathBuf::from(DEFAULT_MODEL),
|
||||
output=PathBuf::from(DEFAULT_SETTINGS),
|
||||
variables=Vec::from([("batch_size".to_string(), 1)]),
|
||||
seed=DEFAULT_SEED.parse().unwrap(),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn gen_random_data(
|
||||
model: PathBuf,
|
||||
output: PathBuf,
|
||||
variables: Vec<(String, usize)>,
|
||||
seed: u64,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::gen_random_data(model, output, variables, seed).map_err(|e| {
|
||||
let err_str = format!("Failed to generate settings: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Calibrates the circuit settings
|
||||
///
|
||||
/// Arguments
|
||||
@@ -2055,6 +2094,7 @@ fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(get_srs, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(gen_witness, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(gen_settings, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(gen_random_data, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(calibrate_settings, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(aggregate, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(mock_aggregate, m)?)?;
|
||||
|
||||
@@ -100,9 +100,6 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
|
||||
for input in hash_inputs.iter().take(WIDTH) {
|
||||
meta.enable_equality(*input);
|
||||
}
|
||||
meta.enable_constant(rc_b[0]);
|
||||
|
||||
Self::configure_with_cols(
|
||||
@@ -152,9 +149,6 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
|
||||
|
||||
for input in hash_inputs.iter().take(WIDTH) {
|
||||
meta.enable_equality(*input);
|
||||
}
|
||||
meta.enable_constant(rc_b[0]);
|
||||
|
||||
let instance = meta.instance_column();
|
||||
|
||||
@@ -75,6 +75,16 @@ impl FromStr for CheckMode {
|
||||
}
|
||||
}
|
||||
|
||||
impl CheckMode {
|
||||
/// Returns the value of the check mode
|
||||
pub fn is_safe(&self) -> bool {
|
||||
match self {
|
||||
CheckMode::SAFE => true,
|
||||
CheckMode::UNSAFE => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
/// An enum representing the tolerance we can accept for the accumulated arguments, either absolute or percentage
|
||||
#[derive(Clone, Default, Debug, PartialEq, PartialOrd, Serialize, Deserialize, Copy)]
|
||||
@@ -205,15 +215,16 @@ impl DynamicLookups {
|
||||
|
||||
/// A struct representing the selectors for the dynamic lookup tables
|
||||
#[derive(Clone, Debug, Default)]
|
||||
|
||||
pub struct Shuffles {
|
||||
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
|
||||
pub input_selectors: BTreeMap<(usize, (usize, usize)), Selector>,
|
||||
/// Selectors for the dynamic lookup tables
|
||||
pub reference_selectors: Vec<Selector>,
|
||||
pub output_selectors: Vec<Selector>,
|
||||
/// Inputs:
|
||||
pub inputs: Vec<VarTensor>,
|
||||
/// tables
|
||||
pub references: Vec<VarTensor>,
|
||||
pub outputs: Vec<VarTensor>,
|
||||
}
|
||||
|
||||
impl Shuffles {
|
||||
@@ -224,9 +235,13 @@ impl Shuffles {
|
||||
|
||||
Self {
|
||||
input_selectors: BTreeMap::new(),
|
||||
reference_selectors: vec![],
|
||||
inputs: vec![dummy_var.clone(), dummy_var.clone()],
|
||||
references: vec![single_col_dummy_var.clone(), single_col_dummy_var.clone()],
|
||||
output_selectors: vec![],
|
||||
inputs: vec![dummy_var.clone(), dummy_var.clone(), dummy_var.clone()],
|
||||
outputs: vec![
|
||||
single_col_dummy_var.clone(),
|
||||
single_col_dummy_var.clone(),
|
||||
single_col_dummy_var.clone(),
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -364,6 +379,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
if inputs[0].num_cols() != output.num_cols() {
|
||||
log::warn!("input and output shapes do not match");
|
||||
}
|
||||
if inputs[0].num_inner_cols() != inputs[1].num_inner_cols() {
|
||||
log::warn!("input number of inner columns do not match");
|
||||
}
|
||||
if inputs[0].num_inner_cols() != output.num_inner_cols() {
|
||||
log::warn!("input and output number of inner columns do not match");
|
||||
}
|
||||
|
||||
for i in 0..output.num_blocks() {
|
||||
for j in 0..output.num_inner_cols() {
|
||||
@@ -730,8 +751,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
pub fn configure_shuffles(
|
||||
&mut self,
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
inputs: &[VarTensor; 2],
|
||||
references: &[VarTensor; 2],
|
||||
inputs: &[VarTensor; 3],
|
||||
outputs: &[VarTensor; 3],
|
||||
) -> Result<(), CircuitError>
|
||||
where
|
||||
F: Field,
|
||||
@@ -742,14 +763,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
}
|
||||
}
|
||||
|
||||
for t in references.iter() {
|
||||
for t in outputs.iter() {
|
||||
if !t.is_advice() || t.num_inner_cols() > 1 {
|
||||
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// assert all tables have the same number of blocks
|
||||
if references
|
||||
if outputs
|
||||
.iter()
|
||||
.map(|t| t.num_blocks())
|
||||
.collect::<Vec<_>>()
|
||||
@@ -757,23 +778,23 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
.any(|w| w[0] != w[1])
|
||||
{
|
||||
return Err(CircuitError::WrongDynamicColumnType(
|
||||
"references inner cols".to_string(),
|
||||
"outputs inner cols".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let one = Expression::Constant(F::ONE);
|
||||
|
||||
for q in 0..references[0].num_blocks() {
|
||||
let s_reference = cs.complex_selector();
|
||||
for q in 0..outputs[0].num_blocks() {
|
||||
let s_output = cs.complex_selector();
|
||||
|
||||
for x in 0..inputs[0].num_blocks() {
|
||||
for y in 0..inputs[0].num_inner_cols() {
|
||||
let s_input = cs.complex_selector();
|
||||
|
||||
cs.lookup_any("lookup", |cs| {
|
||||
cs.lookup_any("shuffle", |cs| {
|
||||
let s_inputq = cs.query_selector(s_input);
|
||||
let mut expression = vec![];
|
||||
let s_referenceq = cs.query_selector(s_reference);
|
||||
let s_outputq = cs.query_selector(s_output);
|
||||
let mut input_queries = vec![one.clone()];
|
||||
|
||||
for input in inputs {
|
||||
@@ -785,9 +806,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
});
|
||||
}
|
||||
|
||||
let mut ref_queries = vec![one.clone()];
|
||||
for reference in references {
|
||||
ref_queries.push(match reference {
|
||||
let mut output_queries = vec![one.clone()];
|
||||
for output in outputs {
|
||||
output_queries.push(match output {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
cs.query_advice(advices[q][0], Rotation(0))
|
||||
}
|
||||
@@ -796,7 +817,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
}
|
||||
|
||||
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
|
||||
let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
|
||||
let rhs = output_queries.into_iter().map(|c| c * s_outputq.clone());
|
||||
expression.extend(lhs.zip(rhs));
|
||||
|
||||
expression
|
||||
@@ -807,13 +828,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
.or_insert(s_input);
|
||||
}
|
||||
}
|
||||
self.shuffles.reference_selectors.push(s_reference);
|
||||
self.shuffles.output_selectors.push(s_output);
|
||||
}
|
||||
|
||||
// if we haven't previously initialized the input/output, do so now
|
||||
if self.shuffles.references.is_empty() {
|
||||
debug!("assigning shuffles reference");
|
||||
self.shuffles.references = references.to_vec();
|
||||
if self.shuffles.outputs.is_empty() {
|
||||
debug!("assigning shuffles output");
|
||||
self.shuffles.outputs = outputs.to_vec();
|
||||
}
|
||||
if self.shuffles.inputs.is_empty() {
|
||||
debug!("assigning shuffles input");
|
||||
|
||||
@@ -100,4 +100,7 @@ pub enum CircuitError {
|
||||
#[error("invalid input type {0}")]
|
||||
/// Invalid input type
|
||||
InvalidInputType(String),
|
||||
#[error("an element is missing from the shuffled version of the tensor")]
|
||||
/// An element is missing from the shuffled version of the tensor
|
||||
MissingShuffleElement,
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::*;
|
||||
use crate::{
|
||||
circuit::{layouts, utils, Tolerance},
|
||||
fieldutils::integer_rep_to_felt,
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorType, ValTensor},
|
||||
};
|
||||
@@ -250,8 +250,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(input_scale.0 as i128),
|
||||
integer_rep_to_felt(output_scale.0 as i128),
|
||||
integer_rep_to_felt(input_scale.0 as IntegerRep),
|
||||
integer_rep_to_felt(output_scale.0 as IntegerRep),
|
||||
)?,
|
||||
HybridOp::Div { denom, .. } => {
|
||||
if denom.0.fract() == 0.0 {
|
||||
@@ -259,7 +259,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(denom.0 as i128),
|
||||
integer_rep_to_felt(denom.0 as IntegerRep),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
|
||||
@@ -43,7 +43,7 @@ const ASCII_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz";
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
|
||||
/// &[3, 3],
|
||||
@@ -160,9 +160,26 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
.into()
|
||||
};
|
||||
claimed_output.reshape(input_dims)?;
|
||||
region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
let claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
// here we decompose and extract the sign of the input
|
||||
let sign = sign(config, region, &[claimed_output.clone()])?;
|
||||
// check if x is too large only if the decomp would support overflow in the previous op
|
||||
if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
|
||||
let abs_value = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), sign],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
let max_val = create_constant_tensor(integer_rep_to_felt(IntegerRep::MAX), 1);
|
||||
let less_than_max = less(config, region, &[abs_value.clone(), max_val])?;
|
||||
// assert the result is 1
|
||||
let comparison_unit = create_constant_tensor(F::ONE, less_than_max.len());
|
||||
enforce_equality(config, region, &[abs_value, comparison_unit])?;
|
||||
}
|
||||
|
||||
let product = pairwise(
|
||||
config,
|
||||
region,
|
||||
@@ -235,6 +252,30 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&[equal_zero_mask.clone(), equal_inverse_mask],
|
||||
)?;
|
||||
|
||||
let masked_output = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), not_equal_zero_mask.clone()],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
// here we decompose and extract the sign of the input
|
||||
let sign = sign(config, region, &[masked_output.clone()])?;
|
||||
// check if x is too large only if the decomp would support overflow in the previous op
|
||||
if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
|
||||
let abs_value = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[claimed_output.clone(), sign],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
let max_val = create_constant_tensor(integer_rep_to_felt(IntegerRep::MAX), 1);
|
||||
let less_than_max = less(config, region, &[abs_value.clone(), max_val])?;
|
||||
// assert the result is 1
|
||||
let comparison_unit = create_constant_tensor(F::ONE, less_than_max.len());
|
||||
enforce_equality(config, region, &[abs_value, comparison_unit])?;
|
||||
}
|
||||
|
||||
let err_func = |config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
x: &ValTensor<F>|
|
||||
@@ -266,7 +307,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 9]),
|
||||
/// &[3, 3],
|
||||
@@ -338,7 +379,7 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
|
||||
/// &[3, 3],
|
||||
@@ -377,7 +418,7 @@ pub fn rsqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::layouts::dot;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
|
||||
@@ -404,22 +445,6 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
|
||||
let mut values = values.clone();
|
||||
|
||||
// this section has been optimized to death, don't mess with it
|
||||
let mut removal_indices = values[0].get_const_zero_indices();
|
||||
let second_zero_indices = values[1].get_const_zero_indices();
|
||||
removal_indices.extend(second_zero_indices);
|
||||
removal_indices.par_sort_unstable();
|
||||
removal_indices.dedup();
|
||||
|
||||
// if empty return a const
|
||||
if removal_indices.len() == values[0].len() {
|
||||
return Ok(create_zero_tensor(1));
|
||||
}
|
||||
|
||||
// is already sorted
|
||||
values[0].remove_indices(&mut removal_indices, true)?;
|
||||
values[1].remove_indices(&mut removal_indices, true)?;
|
||||
|
||||
let mut inputs = vec![];
|
||||
let block_width = config.custom_gates.output.num_inner_cols();
|
||||
|
||||
@@ -492,7 +517,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// // matmul case
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
@@ -906,7 +931,6 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
let window_b = assigned_sort.get_slice(&[1..assigned_sort.len()])?;
|
||||
|
||||
let is_greater = greater_equal(config, region, &[window_b.clone(), window_a.clone()])?;
|
||||
|
||||
let unit = create_unit_tensor(is_greater.len());
|
||||
|
||||
enforce_equality(config, region, &[unit, is_greater])?;
|
||||
@@ -945,7 +969,7 @@ fn _select_topk<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -1168,30 +1192,38 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
|
||||
}
|
||||
|
||||
/// Shuffle arg
|
||||
/// 1. the input is a set of pairs (index_input, value_input) -- looked up against a (dynamic) set of values (index_output, value_output).
|
||||
/// 2. index_input is copy constrained to values in a fixed column and is thus a fixed set of incrementing values over the input
|
||||
/// 3. value_input is just the input that we are ascertaining is shuffled
|
||||
/// 4. index_output is (typically) a prover generated witness committed to in an advice column
|
||||
/// 5. value_output is (typically) a prover generated witness committed to in an advice column
|
||||
/// 6. Given the above, and given the fixed index_input , we go through every (index_input, value_input) pair and ascertain that it is contained in the input.
|
||||
/// Given the fixed incrementing index index_input, we avoid multiplicity in the output by leveraging this surrogate index: if index_output isn't matched to the exact value where for `index_input=index_output` -> `value_input=value_output`, then the lookup fails
|
||||
pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
output: &[ValTensor<F>; 1],
|
||||
input: &[ValTensor<F>; 1],
|
||||
reference: &[ValTensor<F>; 1],
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let shuffle_index = region.shuffle_index();
|
||||
let (input, reference) = (input[0].clone(), reference[0].clone());
|
||||
let (output, input) = (output[0].clone(), input[0].clone());
|
||||
|
||||
// assert input and reference are same length
|
||||
if input.len() != reference.len() {
|
||||
if output.len() != input.len() {
|
||||
return Err(CircuitError::MismatchedShuffleLength(
|
||||
output.len(),
|
||||
input.len(),
|
||||
reference.len(),
|
||||
));
|
||||
}
|
||||
|
||||
let (reference, flush_len_ref) =
|
||||
region.assign_shuffle(&config.shuffles.references[0], &reference)?;
|
||||
let reference_len = reference.len();
|
||||
let (output, flush_len_ref) = region.assign_shuffle(&config.shuffles.outputs[0], &output)?;
|
||||
let output_len = output.len();
|
||||
let input = region.assign(&config.shuffles.inputs[0], &input)?;
|
||||
|
||||
// now create a vartensor of constants for the shuffle index
|
||||
let index = create_constant_tensor(F::from(shuffle_index as u64), reference_len);
|
||||
let (index, flush_len_index) = region.assign_shuffle(&config.shuffles.references[1], &index)?;
|
||||
let index = create_constant_tensor(F::from(shuffle_index as u64), output_len);
|
||||
let (index, flush_len_index) = region.assign_shuffle(&config.shuffles.outputs[1], &index)?;
|
||||
region.assign(&config.shuffles.inputs[1], &index)?;
|
||||
|
||||
if flush_len_index != flush_len_ref {
|
||||
return Err(CircuitError::MismatchedShuffleLength(
|
||||
@@ -1200,18 +1232,71 @@ pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
));
|
||||
}
|
||||
|
||||
let input = region.assign(&config.shuffles.inputs[0], &input)?;
|
||||
region.assign(&config.shuffles.inputs[1], &index)?;
|
||||
// now found the position of each element of the reference to the input
|
||||
|
||||
let is_known = !output.any_unknowns()? && !input.any_unknowns()?;
|
||||
|
||||
let claimed_index_output = if is_known {
|
||||
let input = input.int_evals()?;
|
||||
let output = output.int_evals()?;
|
||||
|
||||
// Keep track of which positions we've used for each value
|
||||
let mut used_positions: HashMap<usize, bool> = HashMap::new();
|
||||
|
||||
let index_output = output
|
||||
.iter()
|
||||
.map(|x| {
|
||||
// Find all positions of the current element
|
||||
let positions: Vec<usize> = input
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, y)| *y == x)
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
// Find the first unused position for this element
|
||||
let pos = positions
|
||||
.iter()
|
||||
.find(|&&i| !used_positions.get(&i).unwrap_or(&false))
|
||||
.unwrap_or(&0);
|
||||
|
||||
// Mark this position as used
|
||||
used_positions.insert(*pos, true);
|
||||
|
||||
Ok::<_, CircuitError>(Value::known(F::from(*pos as u64)))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// Verify all indices were used exactly once
|
||||
if !used_positions.values().all(|&x| x) {
|
||||
return Err(CircuitError::MissingShuffleElement);
|
||||
}
|
||||
|
||||
Tensor::from(index_output.into_iter()).into()
|
||||
} else {
|
||||
Tensor::new(
|
||||
Some(&vec![Value::<F>::unknown(); output_len]),
|
||||
&[output_len],
|
||||
)?
|
||||
.into()
|
||||
};
|
||||
|
||||
region.assign_shuffle(&config.shuffles.outputs[2], &claimed_index_output)?;
|
||||
|
||||
// the incrementing index is the set of numbered values for the input tensor 0...n, and is FIXED
|
||||
let incrementing_index: ValTensor<F> =
|
||||
Tensor::from((0..output.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();
|
||||
region.assign(&config.shuffles.inputs[2], &incrementing_index)?;
|
||||
|
||||
let mut shuffle_block = 0;
|
||||
|
||||
if !region.is_dummy() {
|
||||
(0..reference_len)
|
||||
(0..output_len)
|
||||
.map(|i| {
|
||||
let (x, _, z) = config.shuffles.references[0]
|
||||
let (x, _, z) = config.shuffles.outputs[0]
|
||||
.cartesian_coord(region.combined_dynamic_shuffle_coord() + i + flush_len_ref);
|
||||
shuffle_block = x;
|
||||
let ref_selector = config.shuffles.reference_selectors[shuffle_block];
|
||||
let ref_selector = config.shuffles.output_selectors[shuffle_block];
|
||||
region.enable(Some(&ref_selector), z)?;
|
||||
Ok(())
|
||||
})
|
||||
@@ -1220,7 +1305,7 @@ pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
|
||||
if !region.is_dummy() {
|
||||
// Enable the selectors
|
||||
(0..reference_len)
|
||||
(0..output_len)
|
||||
.map(|i| {
|
||||
let (x, y, z) =
|
||||
config.custom_gates.inputs[0].cartesian_coord(region.linear_coord() + i);
|
||||
@@ -1237,11 +1322,11 @@ pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
.collect::<Result<Vec<_>, CircuitError>>()?;
|
||||
}
|
||||
|
||||
region.increment_shuffle_col_coord(reference_len + flush_len_ref);
|
||||
region.increment_shuffle_col_coord(output_len + flush_len_ref);
|
||||
region.increment_shuffle_index(1);
|
||||
region.increment(reference_len);
|
||||
region.increment(output_len);
|
||||
|
||||
Ok(input)
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// One hot accumulated layout
|
||||
@@ -1698,6 +1783,8 @@ pub(crate) fn get_missing_set_elements<
|
||||
// assign the claimed output
|
||||
claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
|
||||
|
||||
region.increment(claimed_output.len());
|
||||
|
||||
// input and claimed output should be the shuffles of fullset
|
||||
// concatentate input and claimed output
|
||||
let input_and_claimed_output = input.concat(claimed_output.clone())?;
|
||||
@@ -1891,7 +1978,7 @@ pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd + std::hash::Ha
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -1989,7 +2076,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2144,7 +2231,7 @@ fn axes_wise_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2180,7 +2267,7 @@ pub fn prod_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2216,7 +2303,7 @@ pub fn sum_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2258,7 +2345,7 @@ pub fn argmax_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
@@ -2294,7 +2381,7 @@ pub fn max_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2340,7 +2427,7 @@ pub fn argmin_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2381,25 +2468,6 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
lhs.expand(&broadcasted_shape)?;
|
||||
rhs.expand(&broadcasted_shape)?;
|
||||
|
||||
// original values
|
||||
let orig_lhs = lhs.clone();
|
||||
let orig_rhs = rhs.clone();
|
||||
|
||||
let first_zero_indices = HashSet::from_iter(lhs.get_const_zero_indices());
|
||||
let second_zero_indices = HashSet::from_iter(rhs.get_const_zero_indices());
|
||||
|
||||
let removal_indices = match op {
|
||||
BaseOp::Add | BaseOp::Mult => {
|
||||
// join the zero indices
|
||||
first_zero_indices
|
||||
.union(&second_zero_indices)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
BaseOp::Sub => second_zero_indices.clone(),
|
||||
_ => return Err(CircuitError::UnsupportedOp),
|
||||
};
|
||||
|
||||
if lhs.len() != rhs.len() {
|
||||
return Err(CircuitError::DimMismatch(format!(
|
||||
"pairwise {} layout",
|
||||
@@ -2411,11 +2479,7 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, input)| {
|
||||
let res = region.assign_with_omissions(
|
||||
&config.custom_gates.inputs[i],
|
||||
input,
|
||||
&removal_indices,
|
||||
)?;
|
||||
let res = region.assign(&config.custom_gates.inputs[i], input)?;
|
||||
|
||||
Ok(res.get_inner()?)
|
||||
})
|
||||
@@ -2434,12 +2498,8 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
halo2_proofs::plonk::Error::Synthesis
|
||||
})?;
|
||||
|
||||
let assigned_len = op_result.len() - removal_indices.len();
|
||||
let mut output = region.assign_with_omissions(
|
||||
&config.custom_gates.output,
|
||||
&op_result.into(),
|
||||
&removal_indices,
|
||||
)?;
|
||||
let assigned_len = op_result.len();
|
||||
let mut output = region.assign(&config.custom_gates.output, &op_result.into())?;
|
||||
|
||||
// Enable the selectors
|
||||
if !region.is_dummy() {
|
||||
@@ -2457,44 +2517,6 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
}
|
||||
region.increment(assigned_len);
|
||||
|
||||
let a_tensor = orig_lhs.get_inner_tensor()?;
|
||||
let b_tensor = orig_rhs.get_inner_tensor()?;
|
||||
|
||||
// infill the zero indices with the correct values from values[0] or values[1]
|
||||
if !removal_indices.is_empty() {
|
||||
output
|
||||
.get_inner_tensor_mut()?
|
||||
.par_enum_map_mut_filtered(&removal_indices, |i| {
|
||||
let val = match op {
|
||||
BaseOp::Add => {
|
||||
let a_is_null = first_zero_indices.contains(&i);
|
||||
let b_is_null = second_zero_indices.contains(&i);
|
||||
|
||||
if a_is_null && b_is_null {
|
||||
ValType::Constant(F::ZERO)
|
||||
} else if a_is_null {
|
||||
b_tensor[i].clone()
|
||||
} else {
|
||||
a_tensor[i].clone()
|
||||
}
|
||||
}
|
||||
BaseOp::Sub => {
|
||||
let a_is_null = first_zero_indices.contains(&i);
|
||||
// by default b is null in this case for sub
|
||||
if a_is_null {
|
||||
ValType::Constant(F::ZERO)
|
||||
} else {
|
||||
a_tensor[i].clone()
|
||||
}
|
||||
}
|
||||
BaseOp::Mult => ValType::Constant(F::ZERO),
|
||||
// can safely panic as the prior check ensures this is not called
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok::<_, TensorError>(val)
|
||||
})?;
|
||||
}
|
||||
|
||||
output.reshape(&broadcasted_shape)?;
|
||||
|
||||
let end = global_start.elapsed();
|
||||
@@ -2522,7 +2544,7 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
@@ -2579,7 +2601,7 @@ pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128, 2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 12, 6, 4, 5, 6]),
|
||||
@@ -2607,7 +2629,8 @@ pub fn greater<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
|
||||
let diff = pairwise(config, region, &[lhs, rhs], BaseOp::Sub)?;
|
||||
let sign = sign(config, region, &[diff])?;
|
||||
equals(config, region, &[sign, create_unit_tensor(1)])
|
||||
let eq = equals(config, region, &[sign, create_unit_tensor(1)])?;
|
||||
Ok(eq)
|
||||
}
|
||||
|
||||
/// Greater equals than operation.
|
||||
@@ -2626,7 +2649,7 @@ pub fn greater<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128, 2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
@@ -2676,7 +2699,7 @@ pub fn greater_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128, 2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 0, 5, 4, 5, 1]),
|
||||
@@ -2717,7 +2740,7 @@ pub fn less<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 0, 5, 4, 5, 1]),
|
||||
@@ -2758,7 +2781,7 @@ pub fn less_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1, 1, 1, 1, 0]),
|
||||
@@ -2802,7 +2825,7 @@ pub fn and<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1, 1, 1, 1, 0]),
|
||||
@@ -2850,7 +2873,7 @@ pub fn or<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1, 1, 1, 1, 0]),
|
||||
@@ -2881,6 +2904,7 @@ pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd + std::hash::H
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let values = values[0].clone();
|
||||
let values_inverse = values.inverse()?;
|
||||
|
||||
let product_values_and_invert = pairwise(
|
||||
config,
|
||||
region,
|
||||
@@ -2924,7 +2948,7 @@ pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd + std::hash::H
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let a = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1, 1, 1, 1, 0]),
|
||||
@@ -2980,7 +3004,7 @@ pub fn xor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1, 1, 1, 1, 0]),
|
||||
@@ -3023,7 +3047,7 @@ pub fn not<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let mask = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 0, 1, 0, 1, 0]),
|
||||
@@ -3081,7 +3105,7 @@ pub fn iff<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
@@ -3113,7 +3137,7 @@ pub fn neg<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
@@ -3195,7 +3219,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
@@ -3249,32 +3273,30 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
output
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.map(|(flat_index, o)| {
|
||||
let coord = &cartesian_coord[flat_index];
|
||||
let (b, i) = (coord[0], coord[1]);
|
||||
let inner_loop_function = |idx: usize, region: &mut RegionCtx<F>| {
|
||||
let coord = &cartesian_coord[idx];
|
||||
let (b, i) = (coord[0], coord[1]);
|
||||
|
||||
let mut slice = vec![b..b + 1, i..i + 1];
|
||||
slice.extend(
|
||||
coord[2..]
|
||||
.iter()
|
||||
.zip(stride.iter())
|
||||
.zip(pool_dims.iter())
|
||||
.map(|((c, s), k)| {
|
||||
let start = c * s;
|
||||
let end = start + k;
|
||||
start..end
|
||||
}),
|
||||
);
|
||||
let mut slice = vec![b..b + 1, i..i + 1];
|
||||
slice.extend(
|
||||
coord[2..]
|
||||
.iter()
|
||||
.zip(stride.iter())
|
||||
.zip(pool_dims.iter())
|
||||
.map(|((c, s), k)| {
|
||||
let start = c * s;
|
||||
let end = start + k;
|
||||
start..end
|
||||
}),
|
||||
);
|
||||
|
||||
let slice = padded_image.get_slice(&slice)?;
|
||||
let max_w = max(config, region, &[slice])?;
|
||||
*o = max_w.get_inner_tensor()?[0].clone();
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, CircuitError>>()?;
|
||||
let slice = padded_image.get_slice(&slice)?;
|
||||
let max_w = max(config, region, &[slice])?;
|
||||
|
||||
Ok::<_, CircuitError>(max_w.get_inner_tensor()?[0].clone())
|
||||
};
|
||||
|
||||
region.apply_in_loop(&mut output, inner_loop_function)?;
|
||||
|
||||
let res: ValTensor<F> = output.into();
|
||||
|
||||
@@ -3296,7 +3318,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::tensor::ValTensor;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let c = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap());
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
@@ -3524,7 +3546,7 @@ pub fn deconv<
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
|
||||
@@ -4038,7 +4060,7 @@ pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd + std::hash::H
|
||||
}
|
||||
|
||||
let is_assigned = !w.any_unknowns()?;
|
||||
if is_assigned && region.check_range() {
|
||||
if is_assigned && region.check_range() && config.check_mode.is_safe() {
|
||||
// assert is within range
|
||||
let int_values = w.int_evals()?;
|
||||
for v in int_values.iter() {
|
||||
@@ -4248,7 +4270,7 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 2, 3, 0]),
|
||||
/// &[1, 1, 2, 2],
|
||||
@@ -4301,7 +4323,7 @@ pub fn max_comp<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 2, 3, 0]),
|
||||
/// &[1, 1, 2, 2],
|
||||
@@ -4340,10 +4362,7 @@ pub(crate) fn max<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let input_len = values[0].len();
|
||||
_sort_ascending(config, region, values)?
|
||||
.get_slice(&[input_len - 1..input_len])
|
||||
.map_err(|e| e.into())
|
||||
Ok(_sort_ascending(config, region, values)?.last()?)
|
||||
}
|
||||
|
||||
/// min layout
|
||||
@@ -4352,9 +4371,7 @@ pub(crate) fn min<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
_sort_ascending(config, region, values)?
|
||||
.get_slice(&[0..1])
|
||||
.map_err(|e| e.into())
|
||||
Ok(_sort_ascending(config, region, values)?.first()?)
|
||||
}
|
||||
|
||||
/// floor layout
|
||||
@@ -4377,7 +4394,7 @@ pub(crate) fn min<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, -2, -3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
@@ -4489,7 +4506,7 @@ pub fn floor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, -2, 3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
@@ -4602,7 +4619,7 @@ pub fn ceil<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, 2, 3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
@@ -4889,7 +4906,7 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, -2, 3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
@@ -5032,7 +5049,7 @@ pub fn round<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, -2, 3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
@@ -5267,7 +5284,24 @@ pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has
|
||||
let rest = claimed_output.get_slice(&rest_slice)?;
|
||||
|
||||
let sign = range_check(config, region, &[sign], &(-1, 1))?;
|
||||
let rest = range_check(config, region, &[rest], &(0, (*base - 1) as i128))?;
|
||||
|
||||
// isZero(input) * sign == 0.
|
||||
{
|
||||
let is_zero = equals_zero(config, region, &[input.clone()])?;
|
||||
// take the product of the sign and is_zero
|
||||
let sign_is_zero = pairwise(config, region, &[sign.clone(), is_zero], BaseOp::Mult)?;
|
||||
// constrain the sign_is_zero to be 0
|
||||
enforce_equality(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
sign_is_zero.clone(),
|
||||
create_constant_tensor(F::ZERO, sign_is_zero.len()),
|
||||
],
|
||||
)?;
|
||||
}
|
||||
|
||||
let rest = range_check(config, region, &[rest], &(0, (*base - 1) as IntegerRep))?;
|
||||
|
||||
// equation needs to be constructed as ij,ij->i but for arbitrary n dims we need to construct this dynamically
|
||||
// indices should map in order of the alphabet
|
||||
@@ -5531,7 +5565,7 @@ pub(crate) fn percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 2, 3, 2, 2, 0]),
|
||||
@@ -5582,7 +5616,7 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
///
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
|
||||
///
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[100, 200, 300, 400, 500, 600]),
|
||||
|
||||
@@ -1040,6 +1040,10 @@ mod conv {
|
||||
let a = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
|
||||
|
||||
// column for constants
|
||||
let _constant = VarTensor::constant_cols(cs, K, 8, false);
|
||||
|
||||
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
|
||||
}
|
||||
|
||||
@@ -1171,7 +1175,7 @@ mod conv_col_ultra_overflow {
|
||||
|
||||
use super::*;
|
||||
|
||||
const K: usize = 4;
|
||||
const K: usize = 6;
|
||||
const LEN: usize = 10;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -1191,9 +1195,10 @@ mod conv_col_ultra_overflow {
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN * LEN * LEN * LEN, false);
|
||||
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
|
||||
}
|
||||
|
||||
@@ -1776,13 +1781,18 @@ mod shuffle {
|
||||
|
||||
let d = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let e = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let f: VarTensor = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN * NUM_LOOP, false);
|
||||
|
||||
let mut config =
|
||||
Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
|
||||
config
|
||||
.configure_shuffles(cs, &[a.clone(), b.clone()], &[d.clone(), e.clone()])
|
||||
.configure_shuffles(
|
||||
cs,
|
||||
&[a.clone(), b.clone(), c.clone()],
|
||||
&[d.clone(), e.clone(), f.clone()],
|
||||
)
|
||||
.unwrap();
|
||||
config
|
||||
}
|
||||
|
||||
@@ -83,13 +83,15 @@ pub const DEFAULT_VK_SOL: &str = "vk.sol";
|
||||
/// Default VK abi path
|
||||
pub const DEFAULT_VK_ABI: &str = "vk.abi";
|
||||
/// Default scale rebase multipliers for calibration
|
||||
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
|
||||
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,10";
|
||||
/// Default use reduced srs for verification
|
||||
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
|
||||
/// Default only check for range check rebase
|
||||
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";
|
||||
/// Default commitment
|
||||
pub const DEFAULT_COMMITMENT: &str = "kzg";
|
||||
/// Default seed used to generate random data
|
||||
pub const DEFAULT_SEED: &str = "21242";
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
|
||||
@@ -422,7 +424,21 @@ pub enum Commands {
|
||||
#[clap(flatten)]
|
||||
args: RunArgs,
|
||||
},
|
||||
|
||||
/// Generate random data for a model
|
||||
GenRandomData {
|
||||
/// The path to the .onnx model file
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
|
||||
model: Option<PathBuf>,
|
||||
/// The path to the .json data file
|
||||
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
|
||||
data: Option<PathBuf>,
|
||||
/// Hand-written parser for graph variables, eg. batch_size=1
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'V', long, value_parser = crate::parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other))]
|
||||
variables: Vec<(String, usize)>,
|
||||
/// random seed for reproducibility (optional)
|
||||
#[arg(long, value_hint = clap::ValueHint::Other, default_value = DEFAULT_SEED)]
|
||||
seed: u64,
|
||||
},
|
||||
/// Calibrates the proving scale, lookup bits and logrows from a circuit settings file.
|
||||
CalibrateSettings {
|
||||
/// The path to the .json calibration data file.
|
||||
|
||||
@@ -65,6 +65,8 @@ use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use tabled::Tabled;
|
||||
use thiserror::Error;
|
||||
use tract_onnx::prelude::IntoTensor;
|
||||
use tract_onnx::prelude::Tensor as TractTensor;
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
@@ -134,6 +136,17 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
args,
|
||||
),
|
||||
Commands::GenRandomData {
|
||||
model,
|
||||
data,
|
||||
variables,
|
||||
seed,
|
||||
} => gen_random_data(
|
||||
model.unwrap_or(DEFAULT_MODEL.into()),
|
||||
data.unwrap_or(DEFAULT_DATA.into()),
|
||||
variables,
|
||||
seed,
|
||||
),
|
||||
Commands::CalibrateSettings {
|
||||
model,
|
||||
settings_path,
|
||||
@@ -828,6 +841,71 @@ pub(crate) fn gen_circuit_settings(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
/// Generate a circuit settings file
|
||||
pub(crate) fn gen_random_data(
|
||||
model_path: PathBuf,
|
||||
data_path: PathBuf,
|
||||
variables: Vec<(String, usize)>,
|
||||
seed: u64,
|
||||
) -> Result<String, EZKLError> {
|
||||
let mut file = std::fs::File::open(&model_path).map_err(|e| {
|
||||
crate::graph::errors::GraphError::ReadWriteFileError(
|
||||
model_path.display().to_string(),
|
||||
e.to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let (tract_model, _symbol_values) = Model::load_onnx_using_tract(&mut file, &variables)?;
|
||||
|
||||
let input_facts = tract_model
|
||||
.input_outlets()
|
||||
.map_err(|e| EZKLError::from(e.to_string()))?
|
||||
.iter()
|
||||
.map(|&i| tract_model.outlet_fact(i))
|
||||
.collect::<tract_onnx::prelude::TractResult<Vec<_>>>()
|
||||
.map_err(|e| EZKLError::from(e.to_string()))?;
|
||||
|
||||
/// Generates a random tensor of a given size and type.
|
||||
fn random(
|
||||
sizes: &[usize],
|
||||
datum_type: tract_onnx::prelude::DatumType,
|
||||
seed: u64,
|
||||
) -> TractTensor {
|
||||
use rand::{Rng, SeedableRng};
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
|
||||
|
||||
let mut tensor = TractTensor::zero::<f32>(sizes).unwrap();
|
||||
let slice = tensor.as_slice_mut::<f32>().unwrap();
|
||||
slice.iter_mut().for_each(|x| *x = rng.gen());
|
||||
tensor.cast_to_dt(datum_type).unwrap().into_owned()
|
||||
}
|
||||
|
||||
fn tensor_for_fact(fact: &tract_onnx::prelude::TypedFact, seed: u64) -> TractTensor {
|
||||
if let Some(value) = &fact.konst {
|
||||
return value.clone().into_tensor();
|
||||
}
|
||||
|
||||
random(
|
||||
fact.shape
|
||||
.as_concrete()
|
||||
.expect("Expected concrete shape, found: {fact:?}"),
|
||||
fact.datum_type,
|
||||
seed,
|
||||
)
|
||||
}
|
||||
|
||||
let generated = input_facts
|
||||
.iter()
|
||||
.map(|v| tensor_for_fact(v, seed))
|
||||
.collect_vec();
|
||||
|
||||
let data = GraphData::from_tract_data(&generated)?;
|
||||
|
||||
data.save(data_path)?;
|
||||
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
// not for wasm targets
|
||||
pub(crate) fn init_spinner() -> ProgressBar {
|
||||
let pb = indicatif::ProgressBar::new_spinner();
|
||||
|
||||
@@ -5,7 +5,7 @@ use halo2curves::ff::PrimeField;
|
||||
/// Integer representation of a PrimeField element.
|
||||
pub type IntegerRep = i128;
|
||||
|
||||
/// Converts an i64 to a PrimeField element.
|
||||
/// Converts an integer rep to a PrimeField element.
|
||||
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
|
||||
if x >= 0 {
|
||||
F::from_u128(x as u128)
|
||||
@@ -69,7 +69,7 @@ mod test {
|
||||
fn felttointegerrep() {
|
||||
for x in -(2_i128.pow(16))..(2_i128.pow(16)) {
|
||||
let fieldx: F = integer_rep_to_felt::<F>(x);
|
||||
let xf: i128 = felt_to_integer_rep::<F>(fieldx);
|
||||
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ use pyo3::prelude::*;
|
||||
use pyo3::types::PyDict;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::io::BufReader;
|
||||
use std::io::BufWriter;
|
||||
@@ -515,7 +514,7 @@ impl<'de> Deserialize<'de> for DataSource {
|
||||
|
||||
/// Input to graph as a datasource
|
||||
/// Always use JSON serialization for GraphData. Seriously.
|
||||
#[derive(Clone, Debug, Deserialize, Default, PartialEq)]
|
||||
#[derive(Clone, Debug, Deserialize, Default, PartialEq, Serialize)]
|
||||
pub struct GraphData {
|
||||
/// Inputs to the model / computational graph (can be empty vectors if inputs are coming from on-chain).
|
||||
pub input_data: DataSource,
|
||||
@@ -557,6 +556,34 @@ impl GraphData {
|
||||
Ok(inputs)
|
||||
}
|
||||
|
||||
// not wasm
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
/// Convert the tract data to tract data
|
||||
pub fn from_tract_data(tensors: &[TractTensor]) -> Result<Self, GraphError> {
|
||||
use tract_onnx::prelude::DatumType;
|
||||
|
||||
let mut input_data = vec![];
|
||||
for tensor in tensors {
|
||||
match tensor.datum_type() {
|
||||
tract_onnx::prelude::DatumType::Bool => {
|
||||
let tensor = tensor.to_array_view::<bool>()?;
|
||||
let tensor = tensor.iter().map(|e| FileSourceInner::Bool(*e)).collect();
|
||||
input_data.push(tensor);
|
||||
}
|
||||
_ => {
|
||||
let cast_tensor = tensor.cast_to_dt(DatumType::F64)?;
|
||||
let tensor = cast_tensor.to_array_view::<f64>()?;
|
||||
let tensor = tensor.iter().map(|e| FileSourceInner::Float(*e)).collect();
|
||||
input_data.push(tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(GraphData {
|
||||
input_data: DataSource::File(input_data),
|
||||
output_data: None,
|
||||
})
|
||||
}
|
||||
|
||||
///
|
||||
pub fn new(input_data: DataSource) -> Self {
|
||||
GraphData {
|
||||
@@ -741,18 +768,6 @@ impl ToPyObject for FileSourceInner {
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for GraphData {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut state = serializer.serialize_struct("GraphData", 4)?;
|
||||
state.serialize_field("input_data", &self.input_data)?;
|
||||
state.serialize_field("output_data", &self.output_data)?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -280,7 +280,13 @@ impl GraphWitness {
|
||||
})?;
|
||||
|
||||
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
serde_json::from_reader(reader).map_err(|e| e.into())
|
||||
let witness: GraphWitness =
|
||||
serde_json::from_reader(reader).map_err(Into::<GraphError>::into)?;
|
||||
|
||||
// check versions match
|
||||
crate::check_version_string_matches(witness.version.as_deref().unwrap_or(""));
|
||||
|
||||
Ok(witness)
|
||||
}
|
||||
|
||||
/// Save the model input to a file
|
||||
@@ -572,10 +578,14 @@ impl GraphSettings {
|
||||
// buf reader
|
||||
let reader =
|
||||
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
|
||||
serde_json::from_reader(reader).map_err(|e| {
|
||||
let settings: GraphSettings = serde_json::from_reader(reader).map_err(|e| {
|
||||
error!("failed to load settings file at {}", e);
|
||||
std::io::Error::new(std::io::ErrorKind::Other, e)
|
||||
})
|
||||
})?;
|
||||
|
||||
crate::check_version_string_matches(&settings.version);
|
||||
|
||||
Ok(settings)
|
||||
}
|
||||
|
||||
/// Export the ezkl configuration as json
|
||||
@@ -697,6 +707,9 @@ impl GraphCircuit {
|
||||
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
|
||||
let result: GraphCircuit = bincode::deserialize_from(reader)?;
|
||||
|
||||
// check the versions matche
|
||||
crate::check_version_string_matches(&result.core.settings.version);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -621,16 +621,16 @@ impl Model {
|
||||
/// * `scale` - The scale to use for quantization.
|
||||
/// * `public_params` - Whether to make the params public.
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
fn load_onnx_using_tract(
|
||||
pub(crate) fn load_onnx_using_tract(
|
||||
reader: &mut dyn std::io::Read,
|
||||
run_args: &RunArgs,
|
||||
variables: &[(String, usize)],
|
||||
) -> Result<TractResult, GraphError> {
|
||||
use tract_onnx::tract_hir::internal::GenericFactoid;
|
||||
|
||||
let mut model = tract_onnx::onnx().model_for_read(reader)?;
|
||||
|
||||
let variables: std::collections::HashMap<String, usize> =
|
||||
std::collections::HashMap::from_iter(run_args.variables.clone());
|
||||
std::collections::HashMap::from_iter(variables.iter().map(|(k, v)| (k.clone(), *v)));
|
||||
|
||||
for (i, id) in model.clone().inputs.iter().enumerate() {
|
||||
let input = model.node_mut(id.node);
|
||||
@@ -655,7 +655,7 @@ impl Model {
|
||||
}
|
||||
|
||||
let mut symbol_values = SymbolValues::default();
|
||||
for (symbol, value) in run_args.variables.iter() {
|
||||
for (symbol, value) in variables.iter() {
|
||||
let symbol = model.symbols.sym(symbol);
|
||||
symbol_values = symbol_values.with(&symbol, *value as i64);
|
||||
debug!("set {} to {}", symbol, value);
|
||||
@@ -683,7 +683,7 @@ impl Model {
|
||||
) -> Result<ParsedNodes, GraphError> {
|
||||
let start_time = instant::Instant::now();
|
||||
|
||||
let (model, symbol_values) = Self::load_onnx_using_tract(reader, run_args)?;
|
||||
let (model, symbol_values) = Self::load_onnx_using_tract(reader, &run_args.variables)?;
|
||||
|
||||
let scales = VarScales::from_args(run_args);
|
||||
let nodes = Self::nodes_from_graph(
|
||||
@@ -964,7 +964,7 @@ impl Model {
|
||||
GraphError::ReadWriteFileError(model_path.display().to_string(), e.to_string())
|
||||
})?;
|
||||
|
||||
let (model, _) = Model::load_onnx_using_tract(&mut file, run_args)?;
|
||||
let (model, _) = Model::load_onnx_using_tract(&mut file, &run_args.variables)?;
|
||||
|
||||
let datum_types: Vec<DatumType> = model
|
||||
.input_outlets()?
|
||||
@@ -1045,8 +1045,8 @@ impl Model {
|
||||
if settings.requires_shuffle() {
|
||||
base_gate.configure_shuffles(
|
||||
meta,
|
||||
vars.advices[0..2].try_into()?,
|
||||
vars.advices[3..5].try_into()?,
|
||||
vars.advices[0..3].try_into()?,
|
||||
vars.advices[3..6].try_into()?,
|
||||
)?;
|
||||
}
|
||||
|
||||
|
||||
@@ -435,7 +435,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
.collect_vec();
|
||||
|
||||
if requires_dynamic_lookup || requires_shuffle {
|
||||
let num_cols = if requires_dynamic_lookup { 3 } else { 2 };
|
||||
let num_cols = 3;
|
||||
for _ in 0..num_cols {
|
||||
let dynamic_lookup =
|
||||
VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_and_shuffle_size);
|
||||
|
||||
27
src/lib.rs
27
src/lib.rs
@@ -420,3 +420,30 @@ where
|
||||
let b = s[pos + 2..].parse()?;
|
||||
Ok((a, b))
|
||||
}
|
||||
|
||||
/// Check if the version string matches the artifact version
|
||||
/// If the version string does not match the artifact version, log a warning
|
||||
pub fn check_version_string_matches(artifact_version: &str) {
|
||||
if artifact_version == "0.0.0"
|
||||
|| artifact_version == "source - no compatibility guaranteed"
|
||||
|| artifact_version.is_empty()
|
||||
{
|
||||
log::warn!("Artifact version is 0.0.0, skipping version check");
|
||||
return;
|
||||
}
|
||||
|
||||
let version = crate::version();
|
||||
|
||||
if version == "source - no compatibility guaranteed" {
|
||||
log::warn!("Compiled source version is not guaranteed to match artifact version");
|
||||
return;
|
||||
}
|
||||
|
||||
if version != artifact_version {
|
||||
log::warn!(
|
||||
"Version mismatch: CLI version is {} but artifact version is {}",
|
||||
version,
|
||||
artifact_version
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ pub fn get_rep(
|
||||
n: usize,
|
||||
) -> Result<Vec<IntegerRep>, DecompositionError> {
|
||||
// check if x is too large
|
||||
if x.abs() > (base.pow(n as u32) as IntegerRep) - 1 {
|
||||
if (*x).abs() > ((base as i128).pow(n as u32)) - 1 {
|
||||
return Err(DecompositionError::TooLarge(*x, base, n));
|
||||
}
|
||||
let mut rep = vec![0; n + 1];
|
||||
@@ -43,8 +43,8 @@ pub fn get_rep(
|
||||
let mut x = x.abs();
|
||||
//
|
||||
for i in (1..rep.len()).rev() {
|
||||
rep[i] = x % base as i128;
|
||||
x /= base as i128;
|
||||
rep[i] = x % base as IntegerRep;
|
||||
x /= base as IntegerRep;
|
||||
}
|
||||
|
||||
Ok(rep)
|
||||
@@ -127,7 +127,7 @@ pub fn decompose(
|
||||
.flatten()
|
||||
.collect::<Vec<IntegerRep>>();
|
||||
|
||||
let output = Tensor::<i128>::new(Some(&resp), &dims)?;
|
||||
let output = Tensor::<IntegerRep>::new(Some(&resp), &dims)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
@@ -558,7 +558,7 @@ impl VarTensor {
|
||||
|
||||
// duplicates every nth element to adjust for column overflow
|
||||
let v = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
|
||||
let mut res: ValTensor<F> = {
|
||||
let mut res: ValTensor<F> =
|
||||
v.enum_map(|coord, k| {
|
||||
|
||||
let step = self.num_inner_cols();
|
||||
@@ -579,12 +579,18 @@ impl VarTensor {
|
||||
prev_cell = Some(cell.clone());
|
||||
} else if coord > 0 && at_beginning_of_column {
|
||||
if let Some(prev_cell) = prev_cell.as_ref() {
|
||||
let cell = cell.cell().ok_or({
|
||||
let cell = if let Some(cell) = cell.cell() {
|
||||
cell
|
||||
} else {
|
||||
error!("Error getting cell: {:?}", (x,y));
|
||||
halo2_proofs::plonk::Error::Synthesis})?;
|
||||
let prev_cell = prev_cell.cell().ok_or({
|
||||
error!("Error getting cell: {:?}", (x,y));
|
||||
halo2_proofs::plonk::Error::Synthesis})?;
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
};
|
||||
let prev_cell = if let Some(prev_cell) = prev_cell.cell() {
|
||||
prev_cell
|
||||
} else {
|
||||
error!("Error getting prev cell: {:?}", (x,y));
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
};
|
||||
region.constrain_equal(prev_cell,cell)?;
|
||||
} else {
|
||||
error!("Previous cell was not set");
|
||||
@@ -594,7 +600,8 @@ impl VarTensor {
|
||||
|
||||
Ok(cell)
|
||||
|
||||
})?.into()};
|
||||
})?.into();
|
||||
|
||||
let total_used_len = res.len();
|
||||
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
|
||||
|
||||
|
||||
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.
@@ -187,13 +187,14 @@ mod native_tests {
|
||||
|
||||
const PF_FAILURE_AGGR: &str = "examples/test_failure_aggr_proof.json";
|
||||
|
||||
const LARGE_TESTS: [&str; 6] = [
|
||||
const LARGE_TESTS: [&str; 7] = [
|
||||
"self_attention",
|
||||
"nanoGPT",
|
||||
"multihead_attention",
|
||||
"mobilenet",
|
||||
"mnist_gan",
|
||||
"smallworm",
|
||||
"fr_age",
|
||||
];
|
||||
|
||||
const ACCURACY_CAL_TESTS: [&str; 6] = [
|
||||
@@ -395,29 +396,29 @@ mod native_tests {
|
||||
const TESTS_AGGR: [&str; 3] = ["1l_mlp", "1l_flatten", "1l_average"];
|
||||
|
||||
const TESTS_EVM: [&str; 23] = [
|
||||
"1l_mlp",
|
||||
"1l_flatten",
|
||||
"1l_average",
|
||||
"1l_reshape",
|
||||
"1l_sigmoid",
|
||||
"1l_div",
|
||||
"1l_sqrt",
|
||||
"1l_prelu",
|
||||
"1l_var",
|
||||
"1l_leakyrelu",
|
||||
"1l_gelu_noappx",
|
||||
"1l_relu",
|
||||
"1l_tanh",
|
||||
"2l_relu_sigmoid_small",
|
||||
"2l_relu_small",
|
||||
"min",
|
||||
"max",
|
||||
"1l_max_pool",
|
||||
"idolmodel",
|
||||
"1l_identity",
|
||||
"lstm",
|
||||
"rnn",
|
||||
"quantize_dequantize",
|
||||
"1l_mlp", // 0
|
||||
"1l_flatten", // 1
|
||||
"1l_average", // 2
|
||||
"1l_reshape", // 3
|
||||
"1l_sigmoid", // 4
|
||||
"1l_div", // 5
|
||||
"1l_sqrt", // 6
|
||||
"1l_prelu", // 7
|
||||
"1l_var", // 8
|
||||
"1l_leakyrelu", // 9
|
||||
"1l_gelu_noappx", // 10
|
||||
"1l_relu", // 11
|
||||
"1l_tanh", // 12
|
||||
"2l_relu_sigmoid_small", // 13
|
||||
"2l_relu_small", // 14
|
||||
"min", // 15
|
||||
"max", // 16
|
||||
"1l_max_pool", // 17
|
||||
"idolmodel", // 18
|
||||
"1l_identity", // 19
|
||||
"lstm", // 20
|
||||
"rnn", // 21
|
||||
"quantize_dequantize", // 22
|
||||
];
|
||||
|
||||
const TESTS_EVM_AGGR: [&str; 18] = [
|
||||
@@ -541,7 +542,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -606,7 +607,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -616,7 +617,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true, Some(8194), Some(4));
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -627,7 +628,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// gen random number between 0.0 and 1.0
|
||||
let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false, Some(8194), Some(5));
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -642,7 +643,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let large_batch_dir = &format!("large_batches_{}", test);
|
||||
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0, false);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
}
|
||||
@@ -652,7 +653,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -661,7 +662,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -670,7 +671,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -679,7 +680,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -688,7 +689,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -697,7 +698,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -706,7 +707,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -716,7 +717,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -726,7 +727,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -735,7 +736,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -745,7 +746,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -754,7 +755,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -764,7 +765,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -774,7 +775,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -784,7 +785,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -794,7 +795,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -804,7 +805,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -963,7 +964,7 @@ mod native_tests {
|
||||
|
||||
});
|
||||
|
||||
seq!(N in 0..=5 {
|
||||
seq!(N in 0..=6 {
|
||||
|
||||
#(#[test_case(LARGE_TESTS[N])])*
|
||||
#[ignore]
|
||||
@@ -981,7 +982,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false, None, None);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -1459,6 +1460,8 @@ mod native_tests {
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
tolerance: f32,
|
||||
bounded_lookup_log: bool,
|
||||
decomp_base: Option<usize>,
|
||||
decomp_legs: Option<usize>,
|
||||
) {
|
||||
let mut tolerance = tolerance;
|
||||
gen_circuit_settings_and_witness(
|
||||
@@ -1475,6 +1478,8 @@ mod native_tests {
|
||||
Commitments::KZG,
|
||||
2,
|
||||
bounded_lookup_log,
|
||||
decomp_base,
|
||||
decomp_legs,
|
||||
);
|
||||
|
||||
if tolerance > 0.0 {
|
||||
@@ -1616,6 +1621,8 @@ mod native_tests {
|
||||
commitment: Commitments,
|
||||
lookup_safety_margin: usize,
|
||||
bounded_lookup_log: bool,
|
||||
decomp_base: Option<usize>,
|
||||
decomp_legs: Option<usize>,
|
||||
) {
|
||||
let mut args = vec![
|
||||
"gen-settings".to_string(),
|
||||
@@ -1634,6 +1641,14 @@ mod native_tests {
|
||||
format!("--commitment={}", commitment),
|
||||
];
|
||||
|
||||
if let Some(decomp_base) = decomp_base {
|
||||
args.push(format!("--decomp-base={}", decomp_base));
|
||||
}
|
||||
|
||||
if let Some(decomp_legs) = decomp_legs {
|
||||
args.push(format!("--decomp-legs={}", decomp_legs));
|
||||
}
|
||||
|
||||
if bounded_lookup_log {
|
||||
args.push("--bounded-log-lookup".to_string());
|
||||
}
|
||||
@@ -1751,6 +1766,8 @@ mod native_tests {
|
||||
Commitments::KZG,
|
||||
2,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
println!(
|
||||
@@ -2035,6 +2052,8 @@ mod native_tests {
|
||||
commitment,
|
||||
lookup_safety_margin,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
|
||||
@@ -2467,6 +2486,8 @@ mod native_tests {
|
||||
Commitments::KZG,
|
||||
2,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);
|
||||
|
||||
@@ -873,6 +873,7 @@ def get_examples():
|
||||
'linear_regression',
|
||||
"mnist_gan",
|
||||
"smallworm",
|
||||
"fr_age"
|
||||
]
|
||||
examples = []
|
||||
for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):
|
||||
|
||||
Reference in New Issue
Block a user