Compare commits

..

18 Commits

Author       SHA1        Message                                                     Date
dante        ca354f9261  refactor: allow for negative stride downsample              2025-03-13 16:33:36 -04:00
dante        ab4997d0c2  chore: update docs and panics (#952)                        2025-03-10 11:32:28 -04:00
dante        701e69dd2f  fix: handle [] shapes in sort (#954)                        2025-03-08 13:17:45 -05:00
dante        f631445e26  docs: document arguments better (#950)                      2025-03-05 16:10:50 -05:00
dante        fcbb27677f  fix: empty dim len can be 1 (#949)                          2025-02-28 23:56:19 -05:00
dante        bc26691bd5  chore: smaller cat dog example (#947)                       2025-02-28 10:37:08 -05:00
dante        73c813a81d  feat: pass data directly in cli (#939)                      2025-02-13 12:35:13 -05:00
dante        ae076aef09  refactor: rm tolerance parameter (#937)                     2025-02-11 12:57:18 -05:00
dante        a7544f4060  feat: generalize conv mem layout and ND (#935)              2025-02-10 09:11:58 -05:00
dante        c19fa5218a  refactor: enforce max decomp base/legs in args (#936)       2025-02-09 16:15:40 -05:00
rebustron    eb205d0c73  chore: fix typos in comments and docs (#934)                2025-02-08 19:13:17 -05:00
dante        db498f8d7c  docs: cat-dog example (#929)                                2025-02-08 17:30:13 -05:00
Cypher Pepe  a363c91160  fix: broken links in polycommit.rs and poseidon.rs (#932)   2025-02-08 12:40:53 -05:00
dante        f7f04415fa  chore!: add model input/output types to settings (#933)     2025-02-07 16:05:59 -05:00
                         BREAKING CHANGE: compiled model serialization is not backwards compatible
Jseam        de8d419e5d  ci: change to sha hashes (#922)                             2025-02-07 12:27:35 -05:00
dante        a38d318923  fix: pypi publication pipeline (#931)                       2025-02-05 23:03:21 -05:00
dante        864990fe2d  fix: publishing path                                        2025-02-05 19:57:13 -05:00
dante        29c3e4f977  fix: bump download artifact to v4                           2025-02-05 19:05:29 -05:00
57 changed files with 3732 additions and 1157 deletions


@@ -12,10 +12,10 @@ jobs:
contents: read
runs-on: self-hosted
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -29,10 +29,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -46,10 +46,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -63,10 +63,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -80,10 +80,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -97,10 +97,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -114,10 +114,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -131,10 +131,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -148,10 +148,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -165,10 +165,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -182,10 +182,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
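Every workflow in this range pins third-party actions to a full commit SHA, keeping the tag as a trailing comment (commit `ci: change to sha hashes (#922)`), so a moved or compromised tag can no longer change what runs in CI. A hedged sketch of how one might verify that a pinned SHA still matches its commented tag, using the public GitHub REST tag-ref endpoint (annotated tags point at a tag object and would need one extra dereference step this sketch omits):

```python
import json
import re
import urllib.request

# One pinned line from the workflows above.
line = "uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2"
repo, sha, tag = re.search(r"uses: (\S+)@([0-9a-f]{40}) #(\S+)", line).groups()

# Lightweight tags resolve straight to the commit SHA; annotated tags
# resolve to a tag object and would need one more lookup, not shown here.
url = f"https://api.github.com/repos/{repo}/git/ref/tags/{tag}"
with urllib.request.urlopen(url) as resp:
    ref = json.load(resp)
print("pin matches tag:", ref["object"]["sha"] == sha)
```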


@@ -24,15 +24,15 @@ jobs:
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
with:
# Pin to version 0.12.1
version: 'v0.12.1'
@@ -40,7 +40,7 @@ jobs:
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
- name: Install binaryen
run: |
set -e
@@ -176,7 +176,7 @@ jobs:
curl -s "https://raw.githubusercontent.com/zkonduit/ezkljs-engine/main/README.md" > ./pkg/README.md
- name: Set up Node.js
uses: actions/setup-node@v3
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
@@ -201,7 +201,7 @@ jobs:
RELEASE_TAG: ${{ github.ref_name }}
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Update version in package.json
@@ -230,13 +230,13 @@ jobs:
NR==30{$0=" specifier: \"" tag "\""}
NR==31{$0=" version: \"" tag "\""}
NR==400{$0=" /@ezkljs/engine@" tag ":"}
NR==401{$0=" resolution: {integrity: \"" integrity "\"}"} 1' in-browser-evm-verifier/pnpm-lock.yaml > temp.yaml && mv temp.yaml in-browser-evm-verifier/pnpm-lock.yaml
- name: Use pnpm 8
uses: pnpm/action-setup@v2
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
with:
version: 8
- name: Set up Node.js
uses: actions/setup-node@v3
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
@@ -247,4 +247,4 @@ jobs:
pnpm run build
pnpm publish --no-git-checks
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}


@@ -10,12 +10,12 @@ jobs:
contents: read
runs-on: kaiju
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- name: nanoGPT Mock


@@ -28,37 +28,36 @@ jobs:
env:
RELEASE_TAG: ${{ github.ref_name }}
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@v4
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag
- name: Set pyproject.toml version to match github tag and rename ezkl to ezkl-gpu
shell: bash
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig >pyproject.toml
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig > pyproject.toml.tmp
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.tmp > pyproject.toml
- name: rename ezkl.pyi to ezkl-gpu.pyi
run: mv ezkl.pyi ezkl-gpu.pyi
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
components: rustfmt, clippy
- name: Set Cargo.toml version to match github tag
- name: Set Cargo.toml version to match github tag and rename ezkl to ezkl-gpu
shell: bash
# the ezkl substitution here looks for the first instance of name = "ezkl" and changes it to "ezkl-gpu"
run: |
mv Cargo.toml Cargo.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
sed "0,/name = \"ezkl\"/s/name = \"ezkl\"/name = \"ezkl-gpu\"/" Cargo.toml.orig > Cargo.toml.tmp
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.tmp > Cargo.toml
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig > Cargo.lock
- name: Install required libraries
shell: bash
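The chained substitution in the hunk above fixes a real bug (see the `fix: publishing path` and `fix: pypi publication pipeline (#931)` commits): previously both `sed` commands read the pristine `*.orig` file and wrote the same output file, so the second redirection overwrote the first and the `ezkl` to `ezkl-gpu` rename was lost. A minimal Python sketch of the same failure mode, with illustrative contents and an illustrative "1.2.3" tag rather than anything taken from the workflow:

```python
# Hedged sketch: why the two substitutions must be chained.
text = 'name = "ezkl"\nversion = "0.0.0"\n'

# Buggy shape (old workflow): each step starts from the pristine text,
# so keeping only the second result silently discards the first rename.
renamed = text.replace("ezkl", "ezkl-gpu")
versioned = text.replace("0.0.0", "1.2.3")  # starts from `text`, not `renamed`
assert "ezkl-gpu" not in versioned          # the rename is gone

# Fixed shape (new workflow): the second step consumes the first step's
# output, mirroring `sed ... > tmp` followed by `sed ... tmp > final`.
fixed = renamed.replace("0.0.0", "1.2.3")
assert "ezkl-gpu" in fixed and "1.2.3" in fixed
```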
@@ -66,7 +65,7 @@ jobs:
sudo apt-get update && sudo apt-get install -y openssl pkg-config libssl-dev
- name: Build wheels
uses: PyO3/maturin-action@v1
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
with:
target: ${{ matrix.target }}
manylinux: auto
@@ -79,7 +78,7 @@ jobs:
pip install ezkl-gpu --no-index --find-links dist --force-reinstall
- name: Upload wheels
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
with:
name: wheels
path: dist
@@ -95,7 +94,7 @@ jobs:
# needs: [ macos, windows, linux, linux-cross, musllinux, musllinux-cross ]
needs: [linux]
steps:
- uses: actions/download-artifact@v3
- uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 #v4.1.8
with:
name: wheels
- name: List Files
@@ -107,14 +106,14 @@ jobs:
# publishes to PyPI
- name: Publish package distributions to PyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@unstable/v1
uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
with:
packages-dir: ./
packages-dir: ./wheels
# publishes to TestPyPI
- name: Publish package distribution to TestPyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@unstable/v1
uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ./
packages-dir: ./wheels


@@ -26,10 +26,10 @@ jobs:
env:
RELEASE_TAG: ${{ github.ref_name }}
steps:
- uses: actions/checkout@v4
with:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@v4
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: 3.12
architecture: x64
@@ -48,21 +48,21 @@ jobs:
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- name: Build wheels
if: matrix.target == 'universal2-apple-darwin'
uses: PyO3/maturin-action@v1
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
with:
target: ${{ matrix.target }}
args: --release --out dist --features python-bindings
- name: Build wheels
if: matrix.target == 'x86_64'
uses: PyO3/maturin-action@v1
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
with:
target: ${{ matrix.target }}
args: --release --out dist --features python-bindings
@@ -73,7 +73,7 @@ jobs:
python -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
with:
name: dist-macos-${{ matrix.target }}
path: dist
@@ -87,10 +87,10 @@ jobs:
matrix:
target: [x64, x86]
steps:
- uses: actions/checkout@v4
with:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@v4
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: 3.12
architecture: ${{ matrix.target }}
@@ -113,14 +113,14 @@ jobs:
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- name: Build wheels
uses: PyO3/maturin-action@v1
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
with:
target: ${{ matrix.target }}
args: --release --out dist --features python-bindings
@@ -130,7 +130,7 @@ jobs:
python -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
with:
name: dist-windows-${{ matrix.target }}
path: dist
@@ -144,10 +144,10 @@ jobs:
matrix:
target: [x86_64]
steps:
- uses: actions/checkout@v4
with:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@v4
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: 3.12
architecture: x64
@@ -176,7 +176,7 @@ jobs:
sudo apt-get update && sudo apt-get install -y openssl pkg-config libssl-dev
- name: Build wheels
uses: PyO3/maturin-action@v1
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
with:
target: ${{ matrix.target }}
manylinux: auto
@@ -203,7 +203,7 @@ jobs:
python -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
with:
name: dist-linux-${{ matrix.target }}
path: dist
@@ -218,10 +218,10 @@ jobs:
target:
- x86_64-unknown-linux-musl
steps:
- uses: actions/checkout@v4
with:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@v4
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: 3.12
architecture: x64
@@ -250,7 +250,7 @@ jobs:
sudo apt-get update && sudo apt-get install -y pkg-config libssl-dev
- name: Build wheels
uses: PyO3/maturin-action@v1
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
with:
target: ${{ matrix.target }}
manylinux: musllinux_1_2
@@ -271,7 +271,7 @@ jobs:
python3 -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
with:
name: dist-musllinux-${{ matrix.target }}
path: dist
@@ -287,10 +287,10 @@ jobs:
- target: aarch64-unknown-linux-musl
arch: aarch64
steps:
- uses: actions/checkout@v4
with:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@v4
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: 3.12
@@ -313,13 +313,13 @@ jobs:
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- name: Build wheels
uses: PyO3/maturin-action@v1
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
with:
target: ${{ matrix.platform.target }}
manylinux: musllinux_1_2
args: --release --out dist --features python-bindings
- uses: uraimo/run-on-arch-action@v2.8.1
- uses: uraimo/run-on-arch-action@5397f9e30a9b62422f302092631c99ae1effcd9e #v2.8.1
name: Install built wheel
with:
arch: ${{ matrix.platform.arch }}
@@ -334,7 +334,7 @@ jobs:
python3 -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
with:
name: dist-musllinux-${{ matrix.platform.target }}
path: dist
@@ -347,26 +347,26 @@ jobs:
if: "startsWith(github.ref, 'refs/tags/')"
needs: [macos, windows, linux, musllinux, musllinux-cross]
steps:
- uses: actions/download-artifact@v3
- uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 #v4.1.8
with:
pattern: dist-*
merge-multiple: true
path: dist
path: wheels
- name: List Files
run: ls -R
# # publishes to TestPyPI
# - name: Publish package distribution to TestPyPI
# uses: pypa/gh-action-pypi-publish@unstable/v1
# uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
# with:
# repository-url: https://test.pypi.org/legacy/
# packages-dir: ./
# publishes to PyPI
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@unstable/v1
uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
with:
packages-dir: ./
packages-dir: ./wheels
doc-publish:
@@ -377,7 +377,7 @@ jobs:
needs: pypi-publish
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Trigger RTDs build
uses: dfm/rtds-action@v1


@@ -30,7 +30,7 @@ jobs:
- name: Create Github Release
id: create-release
uses: softprops/action-gh-release@v1
uses: softprops/action-gh-release@c95fe1489396fe8a9eb87c0abf8aa5b2ef267fda #v2.2.1
with:
token: ${{ secrets.RELEASE_TOKEN }}
tag_name: ${{ env.EZKL_VERSION }}
@@ -49,14 +49,14 @@ jobs:
RUST_BACKTRACE: 1
PCRE2_SYS_STATIC: 1
steps:
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- name: Checkout repo
uses: actions/checkout@v4
with:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
@@ -90,7 +90,7 @@ jobs:
echo "ASSET=build-artifacts/ezkl-linux-gpu.tar.gz" >> $GITHUB_ENV
- name: Upload release archive
uses: actions/upload-release-asset@v1.0.2
uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 #v1.0.2
env:
GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }}
with:
@@ -119,33 +119,33 @@ jobs:
include:
- build: windows-msvc
os: windows-latest
rust: nightly-2024-07-18
rust: nightly-2025-02-17
target: x86_64-pc-windows-msvc
- build: macos
os: macos-13
rust: nightly-2024-07-18
rust: nightly-2025-02-17
target: x86_64-apple-darwin
- build: macos-aarch64
os: macos-13
rust: nightly-2024-07-18
rust: nightly-2025-02-17
target: aarch64-apple-darwin
- build: linux-musl
os: ubuntu-22.04
rust: nightly-2024-07-18
rust: nightly-2025-02-17
target: x86_64-unknown-linux-musl
- build: linux-gnu
os: ubuntu-22.04
rust: nightly-2024-07-18
rust: nightly-2025-02-17
target: x86_64-unknown-linux-gnu
- build: linux-aarch64
os: ubuntu-22.04
rust: nightly-2024-07-18
rust: nightly-2025-02-17
target: aarch64-unknown-linux-gnu
steps:
- name: Checkout repo
uses: actions/checkout@v4
with:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Get release version from tag
@@ -170,7 +170,7 @@ jobs:
fi
- name: Install Rust
uses: dtolnay/rust-toolchain@nightly
uses: dtolnay/rust-toolchain@4f94fbe7e03939b0e674bcc9ca609a16088f63ff #nightly branch, TODO: update when required
with:
target: ${{ matrix.target }}
@@ -196,7 +196,7 @@ jobs:
echo "target flag is: ${{ env.TARGET_FLAGS }}"
echo "target dir is: ${{ env.TARGET_DIR }}"
- name: Build release binary (no asm or metal)
if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
@@ -233,7 +233,7 @@ jobs:
echo "ASSET=build-artifacts/ezkl-win.zip" >> $GITHUB_ENV
- name: Upload release archive
uses: actions/upload-release-asset@v1.0.2
uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 #v1.0.2
env:
GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }}
with:


@@ -25,12 +25,12 @@ jobs:
contents: read
runs-on: large-self-hosted
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -45,12 +45,12 @@ jobs:
contents: read
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- name: Build
@@ -61,12 +61,12 @@ jobs:
contents: read
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- name: Docs
@@ -77,15 +77,15 @@ jobs:
contents: read
runs-on: ubuntu-latest-32-cores
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -102,25 +102,21 @@ jobs:
# env:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@v1
# - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
# with:
# toolchain: nightly-2024-07-18
# toolchain: nightly-2025-02-17
# override: true
# components: rustfmt, clippy
# - uses: baptiste0928/cargo-install@v1
# - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
# with:
# crate: cargo-nextest
# locked: true
# - uses: mwilliamson/setup-wasmtime-action@v2
# - uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
# with:
# wasmtime-version: "3.0.1"
# - name: Install wasm32-wasi
# run: rustup target add wasm32-wasi
# - name: Install cargo-wasi
# run: cargo install cargo-wasi
# # - name: Matmul overflow (wasi)
# # run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
# # - name: Conv overflow (wasi)
@@ -139,25 +135,21 @@ jobs:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
- uses: mwilliamson/setup-wasmtime-action@v2
- uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
with:
wasmtime-version: "3.0.1"
- name: Install wasm32-wasi
run: rustup target add wasm32-wasi
- name: Install cargo-wasi
run: cargo install cargo-wasi
# - name: Matmul overflow (wasi)
# run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
# - name: Conv overflow (wasi)
@@ -176,25 +168,21 @@ jobs:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
- uses: mwilliamson/setup-wasmtime-action@v2
- uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
with:
wasmtime-version: "3.0.1"
- name: Install wasm32-wasi
run: rustup target add wasm32-wasi
- name: Install cargo-wasi
run: cargo install cargo-wasi
# - name: Matmul overflow (wasi)
# run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
# - name: Conv overflow (wasi)
@@ -213,15 +201,15 @@ jobs:
contents: read
runs-on: ubuntu-latest-16-cores
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -233,25 +221,25 @@ jobs:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
with:
# Pin to version 0.12.1
version: "v0.12.1"
- uses: nanasess/setup-chromedriver@v2
- uses: nanasess/setup-chromedriver@e93e57b843c0c92788f22483f1a31af8ee48db25 #v2.3.0
# with:
# chromedriver-version: "115.0.5790.102"
- name: Install wasm32-unknown-unknown
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
- name: Run wasm verifier tests
# on mac:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
@@ -262,20 +250,22 @@ jobs:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
# - name: The Worm Mock
# run: cargo nextest run --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
- name: Large 1D Conv Mock
run: cargo nextest run --verbose tests::large_mock_::large_tests_7_expects -- --include-ignored
- name: MNIST Gan Mock
run: cargo nextest run --verbose tests::large_mock_::large_tests_4_expects -- --include-ignored
- name: NanoGPT Mock
@@ -292,8 +282,6 @@ jobs:
run: cargo nextest run --verbose tests::mock_fixed_params_ --test-threads 32
- name: public outputs and bounded lookup log
run: cargo nextest run --verbose tests::mock_bounded_lookup_log --test-threads 32
- name: public outputs and tolerance > 0
run: cargo nextest run --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10
run: cargo nextest run --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
- name: kzg inputs
@@ -329,32 +317,32 @@ jobs:
runs-on: non-gpu
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
- uses: actions/checkout@v3
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Use pnpm 8
uses: pnpm/action-setup@v2
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
with:
version: 8
- name: Use Node.js 18.12.1
uses: actions/setup-node@v3
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
with:
node-version: "18.12.1"
cache: "pnpm"
- name: "Add rust-src"
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
- name: Install dependencies for js tests and in-browser-evm-verifier package
run: |
pnpm install --frozen-lockfile
@@ -414,28 +402,28 @@ jobs:
# runs-on: macos-13
# # needs: [build, library-tests, docs]
# steps:
# - uses: actions/checkout@v4
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@v1
# - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
# with:
# toolchain: nightly-2024-07-18
# toolchain: nightly-2025-02-17
# override: true
# components: rustfmt, clippy
# - uses: jetli/wasm-pack-action@v0.4.0
# - uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
# with:
# # Pin to version 0.12.1
# version: 'v0.12.1'
# - name: Add rust-src
# run: rustup component add rust-src --toolchain nightly-2024-07-18
# - uses: actions/checkout@v3
# run: rustup component add rust-src --toolchain nightly-2025-02-17
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
# with:
# persist-credentials: false
# - name: Use pnpm 8
# uses: pnpm/action-setup@v2
# uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
# with:
# version: 8
# - uses: baptiste0928/cargo-install@v1
# - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
# with:
# crate: cargo-nextest
# locked: true
@@ -448,15 +436,15 @@ jobs:
runs-on: non-gpu
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
with:
# Pin to version 0.12.1
version: "v0.12.1"
@@ -464,16 +452,16 @@ jobs:
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Use pnpm 8
uses: pnpm/action-setup@v2
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
with:
version: 8
- name: Use Node.js 18.12.1
uses: actions/setup-node@v3
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
with:
node-version: "18.12.1"
cache: "pnpm"
@@ -483,7 +471,7 @@ jobs:
env:
CI: false
NODE_ENV: development
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -529,18 +517,18 @@ jobs:
# env:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@v1
# - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
# with:
# toolchain: nightly-2024-07-18
# toolchain: nightly-2025-02-17
# override: true
# components: rustfmt, clippy
# - name: Add rust-src
# run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
# - uses: actions/checkout@v3
# - uses: baptiste0928/cargo-install@v1
# run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
# - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
# with:
# crate: cargo-nextest
# locked: true
@@ -567,15 +555,15 @@ jobs:
runs-on: self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: dtolnay/rust-toolchain@4f94fbe7e03939b0e674bcc9ca609a16088f63ff #nightly branch, TODO: update when required
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -587,15 +575,15 @@ jobs:
# env:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@v1
# - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
# with:
# toolchain: nightly-2024-07-18
# toolchain: nightly-2025-02-17
# override: true
# components: rustfmt, clippy
# - uses: baptiste0928/cargo-install@v1
# - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
# with:
# crate: cargo-nextest
# locked: true
@@ -608,15 +596,15 @@ jobs:
runs-on: large-self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -629,15 +617,15 @@ jobs:
runs-on: large-self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -654,15 +642,15 @@ jobs:
runs-on: ubuntu-latest-32-cores
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -675,15 +663,15 @@ jobs:
runs-on: non-gpu
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@v4
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: "3.12"
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- name: Install cmake
@@ -705,18 +693,18 @@ jobs:
runs-on: non-gpu
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@v4
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: "3.12"
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -756,18 +744,18 @@ jobs:
# Maps tcp port 5432 on service container to the host
- 5432:5432
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@v4
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: "3.11"
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -781,6 +769,8 @@ jobs:
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --profile=test-runs
- name: Cat and Dog notebook
run: source .env/bin/activate; cargo nextest run py_tests::tests::cat_and_dog_notebook_
- name: All notebooks
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
- name: Voice tutorial
@@ -812,20 +802,20 @@ jobs:
contents: read
runs-on: macos-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
- name: Run ios tests
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2025-02-17-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
swift-package-tests:
permissions:
@@ -834,12 +824,12 @@ jobs:
needs: [ios-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- name: Build EzklCoreBindings


@@ -12,12 +12,12 @@ jobs:
contents: read
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
@@ -30,4 +30,3 @@ jobs:
run: zizmor .


@@ -19,8 +19,8 @@ jobs:
steps:
- name: Checkout EZKL
uses: actions/checkout@v3
with:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Extract TAG from github.ref_name
@@ -34,7 +34,7 @@ jobs:
echo "TAG=$NEW_TAG" >> $GITHUB_ENV
- name: Install Rust (nightly)
uses: actions-rs/toolchain@v1
uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly
override: true


@@ -11,12 +11,12 @@ jobs:
contents: write
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Bump version and push tag
id: tag_version
uses: mathieudutour/github-tag-action@v6.2
uses: mathieudutour/github-tag-action@a22cf08638b34d5badda920f9daf6e72c477b07b #v6.2
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
@@ -46,7 +46,7 @@ jobs:
git tag $RELEASE_TAG
- name: Push changes
uses: ad-m/github-push-action@master
uses: ad-m/github-push-action@77c5b412c50b723d2a4fbc6d71fb5723bcd439aa #master
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
with:


@@ -3,7 +3,7 @@ cargo-features = ["profile-rustflags"]
[package]
name = "ezkl"
version = "0.0.0"
edition = "2021"
edition = "2024"
default-run = "ezkl"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html


@@ -73,6 +73,8 @@ impl Circuit<Fr> for MyCircuit {
padding: vec![(0, 0)],
stride: vec![1; 2],
group: 1,
data_format: DataFormat::NCHW,
kernel_format: KernelFormat::OIHW,
}),
)
.unwrap();


@@ -69,6 +69,7 @@ impl Circuit<Fr> for MyCircuit {
stride: vec![1, 1],
kernel_shape: vec![2, 2],
normalized: false,
data_format: DataFormat::NCHW,
}),
)
.unwrap();


@@ -23,8 +23,6 @@ use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
const L: usize = 10;
#[derive(Clone, Debug)]
struct MyCircuit {
image: ValTensor<Fr>,
@@ -40,7 +38,7 @@ impl Circuit<Fr> for MyCircuit {
}
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, 10>::configure(cs, ())
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::configure(cs, ())
}
fn synthesize(
@@ -48,7 +46,7 @@ impl Circuit<Fr> for MyCircuit {
config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L> =
let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE> =
PoseidonChip::new(config);
chip.layout(&mut layouter, &[self.image.clone()], 0, &mut HashMap::new())?;
Ok(())
@@ -59,7 +57,7 @@ fn runposeidon(c: &mut Criterion) {
let mut group = c.benchmark_group("poseidon");
for size in [64, 784, 2352, 12288].iter() {
let k = (PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L>::num_rows(*size)
let k = (PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::num_rows(*size)
as f32)
.log2()
.ceil() as u32;
@@ -67,7 +65,7 @@ fn runposeidon(c: &mut Criterion) {
let message = (0..*size).map(|_| Fr::random(OsRng)).collect::<Vec<_>>();
let _output =
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L>::run(message.to_vec())
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.to_vec())
.unwrap();
let mut image = Tensor::from(message.into_iter().map(Value::known));
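The benchmark hunks above track an API change: `PoseidonChip` no longer carries the message-length const `L` as a generic parameter, so one chip handles any input length at runtime. A hedged sketch of what that looks like through the Python bindings (the `poseidon_hash` function appears in the bindings diff further down, where its docstring describes field elements passed as strings; the exact accepted string encoding is an assumption here):

```python
import ezkl

# Field elements are passed as strings per the bindings' docstrings;
# whether plain decimal is accepted alongside hex is an assumption.
short_digest = ezkl.poseidon_hash(["1", "2"])
long_digest = ezkl.poseidon_hash([str(i) for i in range(64)])

# No compile-time length parameter: both calls go through the same chip.
print(short_digest, long_digest)
```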


@@ -22,7 +22,7 @@ fn generate_test_data(size: usize, zero_probability: f64) -> Vec<ValType> {
let mut rng = rand::thread_rng();
(0..size)
.map(|_i| {
if rng.gen::<f64>() < zero_probability {
if rng.r#gen::<f64>() < zero_probability {
ValType::Constant(F::ZERO)
} else {
ValType::Constant(F::ONE) // Or some other non-zero value


@@ -1,20 +1,52 @@
-# EZKL Security Note: Quantization-Induced Model Backdoors
+# EZKL Security Note: Quantization-Activated Model Backdoors

 > Note: this only affects a situation where a party separate from an application's developer has access to the model's weights and can modify them. This is a common scenario in adversarial machine learning research, but can be less common in real-world applications. If you're building your models in house and deploying them yourself, this is less of a concern. If you're building a permissionless system where anyone can submit models, this is more of a concern.

+## Model backdoors and provenance

-Models processed through EZKL's quantization step can harbor backdoors that are dormant in the original full-precision model but activate during quantization. These backdoors force specific outputs when triggered, with impact varying by application.
+Machine learning models inherently suffer from robustness issues, which can lead to various kinds of attacks, from backdoors to evasion attacks. These vulnerabilities are a direct byproduct of how machine learning models learn and cannot be remediated.

-Key Factors:
+We say a model has a backdoor whenever a specific attacker-chosen trigger in the input leads to the model misbehaving. For instance, if we have an image classifier discriminating cats from dogs, the ability to turn any image of a cat into an image classified as a dog by changing a specific pixel pattern constitutes a backdoor.

-- Larger models increase attack feasibility through more parameter capacity
-- Smaller quantization scales facilitate attacks by allowing greater weight modifications
-- Rebase ratio of 1 enables exploitation of convolutional layer consistency
+Backdoors can be introduced using many different vectors. An attacker can introduce a backdoor using traditional security vulnerabilities. For instance, they could directly alter the file containing model weights or dynamically hack the Python code of the model. In addition, backdoors can be introduced by the training data through a process known as poisoning. In this case, an attacker adds malicious data points to the dataset before the model is trained so that the model learns to associate the backdoor trigger with the intended misbehavior.

-Limitations:
+All these vectors constitute a whole range of provenance challenges, as any component of an AI system can virtually be an entrypoint for a backdoor. Although provenance is already a concern with traditional code, the issue is exacerbated with AI, as retraining a model is cost-prohibitive. It is thus impractical to translate the “recompile it yourself” thinking to AI.

-- Attack effectiveness depends on calibration settings and internal rescaling operations.
+## Quantization activated backdoors

+Backdoors are a generic concern in AI that is outside the scope of EZKL. However, EZKL may activate a specific subset of backdoors. Several academic papers have demonstrated the possibility, both in theory and in practice, of implanting undetectable and inactive backdoors in a full precision model that can be reactivated by quantization.

+An external attacker may trick the user of an application running EZKL into loading a model containing a quantization backdoor. This backdoor is active in the resulting model and circuit but not in the full-precision model supplied to EZKL, compromising the integrity of the target application and the resulting proof.

+### When is this a concern for me as a user?

+Any untrusted component in your AI stack may be a backdoor vector. In practice, the most sensitive parts include:

+- Datasets downloaded from the web or containing crowdsourced data
+- Models downloaded from the web even after finetuning
+- Untrusted software dependencies (well-known frameworks such as PyTorch can typically be considered trusted)
+- Any component loaded through an unsafe serialization format, such as Pickle.

+Because backdoors are inherent to ML and cannot be eliminated, reviewing the provenance of these sensitive components is especially important.

+### Responsibilities of the user and EZKL

+As EZKL cannot prevent backdoored models from being used, it is the responsibility of the user to review the provenance of all the components in their AI stack to ensure that no backdoor could have been implanted. EZKL shall not be held responsible for misleading prediction proofs resulting from using a backdoored model or for any harm caused to a system or its users due to a misbehaving model.

+### Limitations:

+- Attack effectiveness depends on calibration settings and internal rescaling operations.
 - Further research needed on backdoor persistence through witness/proof stages.
-- Can be mitigated by evaluating the quantized model (using `ezkl gen-witness`), rather than relying on the evaluation of the original model.
+- Can be mitigated by evaluating the quantized model (using `ezkl gen-witness`), rather than relying on the evaluation of the original model in pytorch or onnx-runtime as difference in evaluation could reveal a backdoor.

 References:
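The mitigation in the last Limitations bullet can be made concrete. A hedged sketch using the Python bindings: evaluate the quantized circuit via `gen_witness` (whose signature appears in the bindings diff below) on the same input as the original model, and flag divergence, since a quantization-activated backdoor would surface as exactly such a discrepancy. The file names, the witness-file field names, and the `onnxruntime` usage are assumptions here, not something this compare specifies:

```python
import json

import ezkl
import numpy as np
import onnxruntime as ort

# 1) Evaluate the original full-precision ONNX model on the input.
data = json.load(open("input.json"))
sess = ort.InferenceSession("network.onnx")
x = np.array(data["input_data"], dtype=np.float32)
(fp_out,) = sess.run(None, {sess.get_inputs()[0].name: x})

# 2) Evaluate the quantized circuit on the same input. In some versions of
# the bindings this call returns a coroutine; await it if so.
ezkl.gen_witness("input.json", "network.compiled", "witness.json")
witness = json.load(open("witness.json"))

# Assumption: the witness exposes dequantized outputs under
# "pretty_elements" -> "rescaled_outputs"; adjust to the actual schema.
quant_out = np.array(witness["pretty_elements"]["rescaled_outputs"], dtype=np.float32)

# 3) Large disagreement between the two evaluations is a red flag.
if not np.allclose(fp_out.flatten(), quant_out.flatten(), atol=1e-2):
    print("WARNING: quantized and full-precision outputs diverge")
```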


@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '19.0.4'
release = '0.0.0'
version = release


@@ -32,6 +32,7 @@ use mnist::*;
use rand::rngs::OsRng;
use std::marker::PhantomData;
mod params;
const K: usize = 20;
@@ -208,6 +209,8 @@ where
padding: vec![(PADDING, PADDING); 2],
stride: vec![STRIDE; 2],
group: 1,
data_format: DataFormat::NCHW,
kernel_format: KernelFormat::OIHW,
};
let x = config
.layer_config

File diff suppressed because it is too large


@@ -0,0 +1,13 @@
# download test data
# check if first argument has been set
if [ ! -z "$1" ]; then
DATA_DIR=$1
else
DATA_DIR=data
fi
echo "Downloading data to $DATA_DIR"
if [ ! -d "$DATA_DIR/CATDOG" ]; then
kaggle datasets download tongpython/cat-and-dog -p $DATA_DIR/CATDOG --unzip
fi


@@ -0,0 +1,106 @@
{
"input_data": [
[
8761,
7654,
8501,
2404,
6929,
8858,
5946,
3673,
4131,
3854,
8137,
8239,
9038,
6299,
1118,
9737,
208,
7954,
3691,
610,
3468,
3314,
8658,
8366,
2850,
477,
6114,
232,
4601,
7420,
5713,
2936,
6061,
2870,
8421,
177,
7107,
7382,
6115,
5487,
8502,
2559,
1875,
129,
8533,
8201,
8414,
4775,
9817,
3127,
8761,
7654,
8501,
2404,
6929,
8858,
5946,
3673,
4131,
3854,
8137,
8239,
9038,
6299,
1118,
9737,
208,
7954,
3691,
610,
3468,
3314,
8658,
8366,
2850,
477,
6114,
232,
4601,
7420,
5713,
2936,
6061,
2870,
8421,
177,
7107,
7382,
6115,
5487,
8502,
2559,
1875,
129,
8533,
8201,
8414,
4775,
9817,
3127
]
]
}

Binary file not shown.


@@ -0,0 +1 @@
{"run_args":{"input_scale":7,"param_scale":7,"scale_rebase_multiplier":1,"lookup_range":[-32768,32768],"logrows":17,"num_inner_cols":2,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private","rebase_frac_zero_constants":false,"check_mode":"UNSAFE","commitment":"KZG","decomp_base":16384,"decomp_legs":2,"bounded_log_lookup":false,"ignore_range_check_inputs_outputs":false},"num_rows":54,"total_assignments":109,"total_const_size":4,"total_dynamic_col_size":0,"max_dynamic_input_len":0,"num_dynamic_lookups":0,"num_shuffles":0,"total_shuffle_col_size":0,"model_instance_shapes":[[1,1]],"model_output_scales":[7],"model_input_scales":[7],"module_sizes":{"polycommit":[],"poseidon":[0,[0]]},"required_lookups":[],"required_range_checks":[[-1,1],[0,16383]],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null,"timestamp":1739396322131,"input_types":["F32"],"output_types":["F32"]}

File diff suppressed because one or more lines are too long

Binary file not shown.


@@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2024-07-18"
channel = "nightly-2025-02-17"
components = ["rustfmt", "clippy"]


@@ -28,6 +28,8 @@ use std::env;
#[tokio::main(flavor = "current_thread")]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub async fn main() {
use log::debug;
let args = Cli::parse();
if let Some(generator) = args.generator {
@@ -42,7 +44,7 @@ pub async fn main() {
} else {
info!("Running with CPU");
}
info!(
debug!(
"command: \n {}",
&command.as_json().to_colored_json_auto().unwrap()
);


@@ -4,8 +4,8 @@ use crate::circuit::modules::poseidon::{
PoseidonChip,
};
use crate::circuit::modules::Module;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::circuit::{CheckMode, Tolerance};
use crate::commands::*;
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::graph::TestDataSource;
@@ -155,9 +155,6 @@ impl pyo3::ToPyObject for PyG1Affine {
#[derive(Clone)]
#[gen_stub_pyclass]
struct PyRunArgs {
#[pyo3(get, set)]
/// float: The tolerance for error on model outputs
pub tolerance: f32,
#[pyo3(get, set)]
/// int: The denominator in the fixed point representation used when quantizing inputs
pub input_scale: crate::Scale,
@@ -225,7 +222,6 @@ impl From<PyRunArgs> for RunArgs {
fn from(py_run_args: PyRunArgs) -> Self {
RunArgs {
bounded_log_lookup: py_run_args.bounded_log_lookup,
tolerance: Tolerance::from(py_run_args.tolerance),
input_scale: py_run_args.input_scale,
param_scale: py_run_args.param_scale,
num_inner_cols: py_run_args.num_inner_cols,
@@ -250,7 +246,6 @@ impl Into<PyRunArgs> for RunArgs {
fn into(self) -> PyRunArgs {
PyRunArgs {
bounded_log_lookup: self.bounded_log_lookup,
tolerance: self.tolerance.val,
input_scale: self.input_scale,
param_scale: self.param_scale,
num_inner_cols: self.num_inner_cols,
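For downstream users of the bindings, the hunks above (commit `refactor: rm tolerance parameter (#937)`) mean `PyRunArgs` no longer has a `tolerance` field. A hedged sketch of an adjusted call site; the attribute names come from this diff, while `gen_settings` and its `py_run_args` keyword are assumptions based on the public Python API rather than anything shown here:

```python
import ezkl

run_args = ezkl.PyRunArgs()
run_args.input_scale = 7
run_args.param_scale = 7
# run_args.tolerance = 0.1  # removed in #937; setting it now fails

ezkl.gen_settings("network.onnx", "settings.json", py_run_args=run_args)
```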
@@ -337,6 +332,8 @@ enum PyInputType {
Int,
///
TDim,
///
Unknown,
}
impl From<InputType> for PyInputType {
@@ -348,6 +345,7 @@ impl From<InputType> for PyInputType {
InputType::F64 => PyInputType::F64,
InputType::Int => PyInputType::Int,
InputType::TDim => PyInputType::TDim,
InputType::Unknown => PyInputType::Unknown,
}
}
}
@@ -361,6 +359,7 @@ impl From<PyInputType> for InputType {
PyInputType::F64 => InputType::F64,
PyInputType::Int => InputType::Int,
PyInputType::TDim => InputType::TDim,
PyInputType::Unknown => InputType::Unknown,
}
}
}
@@ -375,6 +374,7 @@ impl FromStr for PyInputType {
"f64" => Ok(PyInputType::F64),
"int" => Ok(PyInputType::Int),
"tdim" => Ok(PyInputType::TDim),
"unknown" => Ok(PyInputType::Unknown),
_ => Err("Invalid value for InputType".to_string()),
}
}
@@ -592,7 +592,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represnted as strings
/// List of field elements represented as strings
///
/// vk_path: str
/// Path to the verification key
@@ -651,7 +651,7 @@ fn kzg_commit(
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represnted as strings
/// List of field elements represented as strings
///
/// vk_path: str
/// Path to the verification key
@@ -1009,7 +1009,7 @@ fn gen_random_data(
/// bool
///
#[pyfunction(signature = (
data = PathBuf::from(DEFAULT_CALIBRATION_FILE),
data = String::from(DEFAULT_CALIBRATION_FILE),
model = PathBuf::from(DEFAULT_MODEL),
settings = PathBuf::from(DEFAULT_SETTINGS),
target = CalibrationTarget::default(), // default is "resources"
@@ -1021,7 +1021,7 @@ fn gen_random_data(
#[gen_stub_pyfunction]
fn calibrate_settings(
py: Python,
data: PathBuf,
data: String,
model: PathBuf,
settings: PathBuf,
target: CalibrationTarget,
@@ -1076,7 +1076,7 @@ fn calibrate_settings(
/// Python object containing the witness values
///
#[pyfunction(signature = (
data=PathBuf::from(DEFAULT_DATA),
data=String::from(DEFAULT_DATA),
model=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
output=PathBuf::from(DEFAULT_WITNESS),
vk_path=None,
@@ -1085,7 +1085,7 @@ fn calibrate_settings(
#[gen_stub_pyfunction]
fn gen_witness(
py: Python,
data: PathBuf,
data: String,
model: PathBuf,
output: Option<PathBuf>,
vk_path: Option<PathBuf>,
@@ -1754,7 +1754,7 @@ fn create_evm_vka(
/// bool
///
#[pyfunction(signature = (
input_data=PathBuf::from(DEFAULT_DATA),
input_data=String::from(DEFAULT_DATA),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE_DA),
abi_path=PathBuf::from(DEFAULT_VERIFIER_DA_ABI),
@@ -1763,7 +1763,7 @@ fn create_evm_vka(
#[gen_stub_pyfunction]
fn create_evm_data_attestation(
py: Python,
input_data: PathBuf,
input_data: String,
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
@@ -1824,7 +1824,7 @@ fn create_evm_data_attestation(
#[gen_stub_pyfunction]
fn setup_test_evm_witness(
py: Python,
data_path: PathBuf,
data_path: String,
compiled_circuit_path: PathBuf,
test_data: PathBuf,
input_source: PyTestDataSource,
@@ -1902,7 +1902,7 @@ fn deploy_evm(
fn deploy_da_evm(
py: Python,
addr_path: PathBuf,
input_data: PathBuf,
input_data: String,
settings_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
@@ -1945,7 +1945,7 @@ fn deploy_da_evm(
/// does the verifier use data attestation ?
///
/// addr_vk: str
/// The addess of the separate VK contract (if the verifier key is rendered as a separate contract)
/// The address of the separate VK contract (if the verifier key is rendered as a separate contract)
/// Returns
/// -------
/// bool
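Note: the signature changes above relax the data arguments from PathBuf to String so Python callers can hand over either a file path or inline JSON. A minimal sketch of the pyo3 default-value pattern these hunks rely on (assuming a recent pyo3 with the Bound API; DEFAULT_DATA is a stand-in constant and the module wiring is illustrative):

use pyo3::prelude::*;

const DEFAULT_DATA: &str = "input.json";

/// `data` accepts a file path or an inline JSON string.
#[pyfunction(signature = (data = String::from(DEFAULT_DATA)))]
fn gen_witness(data: String) -> PyResult<String> {
    // Real code resolves `data` as JSON-or-path; see GraphData::from_str below.
    Ok(format!("resolving data from: {data}"))
}

#[pymodule]
fn ezkl_sketch(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(gen_witness, m)?)
}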

View File

@@ -1,7 +1,7 @@
/*
An easy-to-use implementation of the Poseidon Hash in the form of a Halo2 Chip. While the Poseidon Hash function
is already implemented in halo2_gadgets, there is no wrapper chip that makes it easy to use in other circuits.
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/zk_prover/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
*/
use std::collections::HashMap;

View File

@@ -1,7 +1,7 @@
/*
An easy-to-use implementation of the Poseidon Hash in the form of a Halo2 Chip. While the Poseidon Hash function
is already implemented in halo2_gadgets, there is no wrapper chip that makes it easy to use in other circuits.
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/zk_prover/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
*/
pub mod poseidon_params;

View File

@@ -21,7 +21,10 @@ pub enum BaseOp {
/// Matches a [BaseOp] to an operation over inputs
impl BaseOp {
/// forward func
/// forward func for non-accumulating operations
/// # Panics
/// Panics if called on an accumulating operation
/// # Examples
pub fn nonaccum_f<
T: TensorType + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Neg<Output = T>,
>(
@@ -37,7 +40,9 @@ impl BaseOp {
}
}
/// forward func
/// forward func for accumulating operations
/// # Panics
/// Panics if called on a non-accumulating operation
pub fn accum_f<
T: TensorType + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Neg<Output = T>,
>(
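Note: these doc additions pin down a partitioned dispatch: every op is either accumulating or not, and calling the wrong forward function is a programmer error that panics rather than returning Err. A self-contained sketch of that contract (the enum and operations are illustrative, not ezkl's actual BaseOp):

#[derive(Debug)]
enum Op {
    Add,
    Sub,
    Sum,
    Prod,
}

impl Op {
    /// Forward func for non-accumulating operations; panics on accumulating ones.
    fn nonaccum_f(&self, a: i64, b: i64) -> i64 {
        match self {
            Op::Add => a + b,
            Op::Sub => a - b,
            other => panic!("nonaccum_f called on accumulating op {:?}", other),
        }
    }

    /// Forward func for accumulating operations; panics on non-accumulating ones.
    fn accum_f(&self, xs: &[i64]) -> i64 {
        match self {
            Op::Sum => xs.iter().sum(),
            Op::Prod => xs.iter().product(),
            other => panic!("accum_f called on non-accumulating op {:?}", other),
        }
    }
}

fn main() {
    assert_eq!(Op::Add.nonaccum_f(2, 3), 5);
    assert_eq!(Op::Sum.accum_f(&[1, 2, 3]), 6);
}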

View File

@@ -20,7 +20,6 @@ use crate::{
circuit::{
ops::base::BaseOp,
table::{Range, RangeCheck, Table},
utils,
},
tensor::{Tensor, TensorType, ValTensor, VarTensor},
};
@@ -85,55 +84,6 @@ impl CheckMode {
}
}
#[allow(missing_docs)]
/// An enum representing the tolerance we can accept for the accumulated arguments, either absolute or percentage
#[derive(Clone, Default, Debug, PartialEq, PartialOrd, Serialize, Deserialize, Copy)]
pub struct Tolerance {
pub val: f32,
pub scale: utils::F32,
}
impl std::fmt::Display for Tolerance {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:.2}", self.val)
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Tolerance {
/// Convert the struct to a subcommand string
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl FromStr for Tolerance {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if let Ok(val) = s.parse::<f32>() {
Ok(Tolerance {
val,
scale: utils::F32(1.0),
})
} else {
Err(
"Invalid tolerance value provided. It should expressed as a percentage (f32)."
.to_string(),
)
}
}
}
impl From<f32> for Tolerance {
fn from(value: f32) -> Self {
Tolerance {
val: value,
scale: utils::F32(1.0),
}
}
}
#[cfg(feature = "python-bindings")]
/// Converts CheckMode into a PyObject (Required for CheckMode to be compatible with Python)
impl IntoPy<PyObject> for CheckMode {
@@ -158,29 +108,6 @@ impl<'source> FromPyObject<'source> for CheckMode {
}
}
#[cfg(feature = "python-bindings")]
/// Converts Tolerance into a PyObject (Required for Tolerance to be compatible with Python)
impl IntoPy<PyObject> for Tolerance {
fn into_py(self, py: Python) -> PyObject {
(self.val, self.scale.0).to_object(py)
}
}
#[cfg(feature = "python-bindings")]
/// Obtains Tolerance from PyObject (Required for Tolerance to be compatible with Python)
impl<'source> FromPyObject<'source> for Tolerance {
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
if let Ok((val, scale)) = <(f32, f32)>::extract_bound(ob) {
Ok(Tolerance {
val,
scale: utils::F32(scale),
})
} else {
Err(PyValueError::new_err("Invalid tolerance value provided. "))
}
}
}
/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct DynamicLookups {

View File

@@ -1,9 +1,9 @@
use super::*;
use crate::{
circuit::{layouts, utils, Tolerance},
circuit::{layouts, utils},
fieldutils::{integer_rep_to_felt, IntegerRep},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorType, ValTensor},
tensor::{self, DataFormat, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
use serde::{Deserialize, Serialize};
@@ -57,11 +57,13 @@ pub enum HybridOp {
stride: Vec<usize>,
kernel_shape: Vec<usize>,
normalized: bool,
data_format: DataFormat,
},
MaxPool {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
pool_dims: Vec<usize>,
data_format: DataFormat,
},
ReduceMin {
axes: Vec<usize>,
@@ -77,7 +79,6 @@ pub enum HybridOp {
axes: Vec<usize>,
},
Output {
tol: Tolerance,
decomp: bool,
},
Greater,
@@ -154,10 +155,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
kernel_shape,
normalized,
normalized, data_format
} => format!(
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={})",
padding, stride, kernel_shape, normalized
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={}, data_format={:?})",
padding, stride, kernel_shape, normalized, data_format
),
HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
@@ -165,9 +166,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
pool_dims,
data_format,
} => format!(
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
padding, stride, pool_dims
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?}, data_format={:?})",
padding, stride, pool_dims, data_format
),
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim),
@@ -181,8 +183,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
input_scale, output_scale, axes
)
}
HybridOp::Output { tol, decomp } => {
format!("OUTPUT (tol={:?}, decomp={})", tol, decomp)
HybridOp::Output { decomp } => {
format!("OUTPUT (decomp={})", decomp)
}
HybridOp::Greater => "GREATER".to_string(),
HybridOp::GreaterEqual => "GREATEREQUAL".to_string(),
@@ -239,6 +241,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
stride,
kernel_shape,
normalized,
data_format,
} => layouts::sumpool(
config,
region,
@@ -247,6 +250,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
stride,
kernel_shape,
*normalized,
*data_format,
)?,
HybridOp::Recip {
input_scale,
@@ -287,6 +291,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
pool_dims,
data_format,
} => layouts::max_pool(
config,
region,
@@ -294,6 +299,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
pool_dims,
*data_format,
)?,
HybridOp::ReduceMax { axes } => {
layouts::max_axes(config, region, values[..].try_into()?, axes)?
@@ -319,14 +325,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
*output_scale,
axes,
)?,
HybridOp::Output { tol, decomp } => layouts::output(
config,
region,
values[..].try_into()?,
tol.scale,
tol.val,
*decomp,
)?,
HybridOp::Output { decomp } => {
layouts::output(config, region, values[..].try_into()?, *decomp)?
}
HybridOp::Greater => layouts::greater(config, region, values[..].try_into()?)?,
HybridOp::GreaterEqual => {
layouts::greater_equal(config, region, values[..].try_into()?)?

File diff suppressed because it is too large

View File

@@ -1,6 +1,8 @@
use std::any::Any;
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::prelude::DatumType;
use crate::{
graph::quantize_tensor,
@@ -96,6 +98,8 @@ pub enum InputType {
Int,
///
TDim,
///
Unknown,
}
impl InputType {
@@ -132,6 +136,7 @@ impl InputType {
let int_input = input.clone().to_i64().unwrap();
*input = T::from_i64(int_input).unwrap();
}
InputType::Unknown => {}
}
}
}
@@ -152,6 +157,30 @@ impl std::str::FromStr for InputType {
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl From<DatumType> for InputType {
/// # Panics
/// Panics if the datum type is not supported
fn from(datum_type: DatumType) -> Self {
match datum_type {
DatumType::Bool => InputType::Bool,
DatumType::F16 => InputType::F16,
DatumType::F32 => InputType::F32,
DatumType::F64 => InputType::F64,
DatumType::I8 => InputType::Int,
DatumType::I16 => InputType::Int,
DatumType::I32 => InputType::Int,
DatumType::I64 => InputType::Int,
DatumType::U8 => InputType::Int,
DatumType::U16 => InputType::Int,
DatumType::U32 => InputType::Int,
DatumType::U64 => InputType::Int,
DatumType::TDim => InputType::TDim,
_ => unimplemented!(),
}
}
}
///
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Input {
@@ -290,13 +319,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
}
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self

View File

@@ -4,6 +4,7 @@ use crate::{
utils::{self, F32},
},
tensor::{self, Tensor, TensorError},
tensor::{DataFormat, KernelFormat},
};
use super::{base::BaseOp, *};
@@ -43,10 +44,12 @@ pub enum PolyOp {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
group: usize,
data_format: DataFormat,
kernel_format: KernelFormat,
},
Downsample {
axis: usize,
stride: usize,
stride: isize,
modulo: usize,
},
DeConv {
@@ -54,6 +57,8 @@ pub enum PolyOp {
output_padding: Vec<usize>,
stride: Vec<usize>,
group: usize,
data_format: DataFormat,
kernel_format: KernelFormat,
},
Add,
Sub,
@@ -103,13 +108,8 @@ pub enum PolyOp {
}
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for PolyOp
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for PolyOp
{
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
@@ -165,10 +165,12 @@ impl<
stride,
padding,
group,
data_format,
kernel_format,
} => {
format!(
"CONV (stride={:?}, padding={:?}, group={})",
stride, padding, group
"CONV (stride={:?}, padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
stride, padding, group, data_format, kernel_format
)
}
PolyOp::DeConv {
@@ -176,10 +178,12 @@ impl<
padding,
output_padding,
group,
data_format,
kernel_format,
} => {
format!(
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={})",
stride, padding, output_padding, group
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
stride, padding, output_padding, group, data_format, kernel_format
)
}
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
@@ -242,6 +246,8 @@ impl<
padding,
stride,
group,
data_format,
kernel_format,
} => layouts::conv(
config,
region,
@@ -249,6 +255,8 @@ impl<
padding,
stride,
*group,
*data_format,
*kernel_format,
)?,
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
@@ -309,6 +317,8 @@ impl<
output_padding,
stride,
group,
data_format,
kernel_format,
} => layouts::deconv(
config,
region,
@@ -317,6 +327,8 @@ impl<
output_padding,
stride,
*group,
*data_format,
*kernel_format,
)?,
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,

View File

@@ -1,5 +1,6 @@
use crate::circuit::ops::poly::PolyOp;
use crate::circuit::*;
use crate::tensor::{DataFormat, KernelFormat};
use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor};
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
@@ -1065,6 +1066,8 @@ mod conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
data_format: DataFormat::default(),
kernel_format: KernelFormat::default(),
}),
)
.map_err(|_| Error::Synthesis)
@@ -1220,6 +1223,8 @@ mod conv_col_ultra_overflow {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
data_format: DataFormat::default(),
kernel_format: KernelFormat::default(),
}),
)
.map_err(|_| Error::Synthesis)
@@ -1377,6 +1382,8 @@ mod conv_relu_col_ultra_overflow {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
data_format: DataFormat::default(),
kernel_format: KernelFormat::default(),
}),
)
.map_err(|_| Error::Synthesis);

View File

@@ -1,6 +1,6 @@
use alloy::primitives::Address as H160;
use clap::{Command, Parser, Subcommand};
use clap_complete::{generate, Generator, Shell};
use clap_complete::{Generator, Shell, generate};
#[cfg(feature = "python-bindings")]
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
@@ -8,7 +8,7 @@ use std::path::PathBuf;
use std::str::FromStr;
use tosubcommand::{ToFlags, ToSubcommand};
use crate::{pfsys::ProofType, Commitments, RunArgs};
use crate::{Commitments, RunArgs, pfsys::ProofType};
use crate::circuit::CheckMode;
use crate::graph::TestDataSource;
@@ -360,8 +360,13 @@ pub fn get_styles() -> clap::builder::Styles {
}
/// Print completions for the given generator
pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
generate(gen, cmd, cmd.get_name().to_string(), &mut std::io::stdout());
pub fn print_completions<G: Generator>(r#gen: G, cmd: &mut Command) {
generate(
r#gen,
cmd,
cmd.get_name().to_string(),
&mut std::io::stdout(),
);
}
#[allow(missing_docs)]
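Note: the r#gen rename here (and rng.r#gen() in execute.rs below) follows from the toolchain bump at the top of this comparison: gen is a reserved keyword in the Rust 2024 edition (for gen blocks), so any identifier named gen must use raw-identifier syntax. This is an inference from the changes, not stated in a commit message. A minimal illustration:

// Under edition 2024 `gen` is reserved, so a parameter with that name
// needs the r# prefix at the declaration and at every use site.
fn print_completions<G: std::fmt::Debug>(r#gen: G) {
    println!("{:?}", r#gen);
}

fn main() {
    print_completions("zsh");
}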
@@ -396,8 +401,9 @@ pub enum Commands {
/// Generates the witness from an input file.
GenWitness {
/// The path to the .json data file
/// You can also pass the input data as a string, e.g. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
data: Option<String>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
@@ -429,7 +435,7 @@ pub enum Commands {
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
/// The path to the .json data file
/// The path to the .json data file to output
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// Hand-written parser for graph variables, e.g. batch_size=1
@@ -442,8 +448,9 @@ pub enum Commands {
/// Calibrates the proving scale, lookup bits and logrows from a circuit settings file.
CalibrateSettings {
/// The path to the .json calibration data file.
/// You can also pass the input data as a string, e.g. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
#[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
data: Option<String>,
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
@@ -626,8 +633,9 @@ pub enum Commands {
#[command(arg_required_else_help = true)]
SetupTestEvmData {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
/// You can also pass the input data as a string, e.g. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
#[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
data: Option<String>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
@@ -653,8 +661,9 @@ pub enum Commands {
#[arg(long, value_hint = clap::ValueHint::Other)]
addr: H160Flag,
/// The path to the .json data file.
/// You can also pass the input data as a string, e.g. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
#[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
data: Option<String>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
@@ -771,7 +780,7 @@ pub enum Commands {
/// view functions that return the data that the network
/// ingests as inputs.
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
data: Option<String>,
/// The path to the witness file. This is needed for proof swapping for kzg commitments.
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
witness: Option<PathBuf>,
@@ -866,8 +875,9 @@ pub enum Commands {
#[command(name = "deploy-evm-da")]
DeployEvmDataAttestation {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
/// You can also pass the input data as a string, e.g. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
data: Option<String>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,

View File

@@ -383,7 +383,7 @@ pub async fn deploy_contract_via_solidity(
///
pub async fn deploy_da_verifier_via_solidity(
settings_path: PathBuf,
input: PathBuf,
input: String,
sol_code_path: PathBuf,
rpc_url: Option<&str>,
runs: usize,
@@ -391,7 +391,7 @@ pub async fn deploy_da_verifier_via_solidity(
) -> Result<H160, EthError> {
let (client, client_address) = setup_eth_backend(rpc_url, private_key).await?;
let input = GraphData::from_path(input).map_err(|_| EthError::GraphData)?;
let input = GraphData::from_str(&input).map_err(|_| EthError::GraphData)?;
let settings = GraphSettings::load(&settings_path).map_err(|_| EthError::GraphSettings)?;
@@ -688,10 +688,10 @@ fn parse_call_to_account(call_to_account: CallToAccount) -> Result<ParsedCallToA
pub async fn update_account_calls(
addr: H160,
input: PathBuf,
input: String,
rpc_url: Option<&str>,
) -> Result<(), EthError> {
let input = GraphData::from_path(input).map_err(|_| EthError::GraphData)?;
let input = GraphData::from_str(&input).map_err(|_| EthError::GraphData)?;
// The data that will be stored in the test contracts that will eventually be read from.
let mut calls_to_accounts = vec![];

View File

@@ -1,5 +1,6 @@
use crate::circuit::region::RegionSettings;
use crate::EZKL_BUF_CAPACITY;
use crate::circuit::CheckMode;
use crate::circuit::region::RegionSettings;
use crate::commands::CalibrationTarget;
use crate::eth::{
deploy_contract_via_solidity, deploy_da_verifier_via_solidity, fix_da_multi_sol,
@@ -12,21 +13,21 @@ use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model};
use crate::graph::{TestDataSource, TestSources};
use crate::pfsys::evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript};
use crate::pfsys::{
create_keys, load_pk, load_vk, save_params, save_pk, Snark, StrategyType, TranscriptType,
ProofSplitCommit, create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit,
};
use crate::pfsys::{
create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit, ProofSplitCommit,
Snark, StrategyType, TranscriptType, create_keys, load_pk, load_vk, save_params, save_pk,
};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::EZKL_BUF_CAPACITY;
use crate::{commands::*, EZKLError};
use crate::{Commitments, RunArgs};
use crate::{EZKLError, commands::*};
use colored::Colorize;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
use halo2_proofs::plonk::{self, Circuit};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::poly::commitment::{CommitmentScheme, Params};
use halo2_proofs::poly::commitment::{ParamsProver, Verifier};
use halo2_proofs::poly::ipa::commitment::{IPACommitmentScheme, ParamsIPA};
@@ -39,7 +40,6 @@ use halo2_proofs::poly::kzg::strategy::AccumulatorStrategy as KZGAccumulatorStra
use halo2_proofs::poly::kzg::{
commitment::ParamsKZG, strategy::SingleStrategy as KZGSingleStrategy,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer};
use halo2_solidity_verifier;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
@@ -50,12 +50,12 @@ use instant::Instant;
use itertools::Itertools;
use log::debug;
use log::{info, trace, warn};
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde::de::DeserializeOwned;
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::Config;
use snark_verifier::system::halo2::compile;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use snark_verifier::system::halo2::Config;
use std::fs::File;
use std::io::BufWriter;
use std::io::{Cursor, Write};
@@ -516,7 +516,9 @@ fn update_ezkl_binary(version: &Option<String>) -> Result<String, EZKLError> {
.status()
.is_err()
{
log::warn!("bash is not installed on this system, trying to run the install script with sh (may fail)");
log::warn!(
"bash is not installed on this system, trying to run the install script with sh (may fail)"
);
"sh"
} else {
"bash"
@@ -725,7 +727,7 @@ pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, EZKLErr
pub(crate) async fn gen_witness(
compiled_circuit_path: PathBuf,
data: PathBuf,
data: String,
output: Option<PathBuf>,
vk_path: Option<PathBuf>,
srs_path: Option<PathBuf>,
@@ -733,7 +735,7 @@ pub(crate) async fn gen_witness(
// these aren't real values so the sanity checks are mostly meaningless
let mut circuit = GraphCircuit::load(compiled_circuit_path)?;
let data: GraphData = GraphData::from_path(data)?;
let data = GraphData::from_str(&data)?;
let settings = circuit.settings().clone();
let vk = if let Some(vk) = vk_path {
@@ -876,7 +878,7 @@ pub(crate) fn gen_random_data(
let mut tensor = TractTensor::zero::<f32>(sizes).unwrap();
let slice = tensor.as_slice_mut::<f32>().unwrap();
slice.iter_mut().for_each(|x| *x = rng.gen());
slice.iter_mut().for_each(|x| *x = rng.r#gen());
tensor.cast_to_dt(datum_type).unwrap().into_owned()
}
@@ -1044,7 +1046,7 @@ impl AccuracyResults {
#[allow(clippy::too_many_arguments)]
pub(crate) async fn calibrate(
model_path: PathBuf,
data: PathBuf,
data: String,
settings_path: PathBuf,
target: CalibrationTarget,
lookup_safety_margin: f64,
@@ -1058,7 +1060,7 @@ pub(crate) async fn calibrate(
use crate::fieldutils::IntegerRep;
let data = GraphData::from_path(data)?;
let data = GraphData::from_str(&data)?;
// load the pre-generated settings
let settings = GraphSettings::load(&settings_path)?;
// now retrieve the run args
@@ -1522,7 +1524,7 @@ pub(crate) async fn create_evm_data_attestation(
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
input: PathBuf,
input: String,
witness: Option<PathBuf>,
) -> Result<String, EZKLError> {
#[allow(unused_imports)]
@@ -1536,7 +1538,7 @@ pub(crate) async fn create_evm_data_attestation(
// if input is not provided, we just instantiate dummy input data
let data =
GraphData::from_path(input).unwrap_or_else(|_| GraphData::new(DataSource::File(vec![])));
GraphData::from_str(&input).unwrap_or_else(|_| GraphData::new(DataSource::File(vec![])));
// The number of input and output instances we attest to for the single call data attestation
let mut input_len = None;
@@ -1625,7 +1627,7 @@ pub(crate) async fn create_evm_data_attestation(
}
pub(crate) async fn deploy_da_evm(
data: PathBuf,
data: String,
settings_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
@@ -1868,7 +1870,7 @@ pub(crate) fn setup(
}
pub(crate) async fn setup_test_evm_witness(
data_path: PathBuf,
data_path: String,
compiled_circuit_path: PathBuf,
test_data: PathBuf,
rpc_url: Option<String>,
@@ -1877,7 +1879,7 @@ pub(crate) async fn setup_test_evm_witness(
) -> Result<String, EZKLError> {
use crate::graph::TestOnChainData;
let mut data = GraphData::from_path(data_path)?;
let mut data = GraphData::from_str(&data_path)?;
let mut circuit = GraphCircuit::load(compiled_circuit_path)?;
// if both input and output are from files fail
@@ -1905,7 +1907,7 @@ pub(crate) async fn setup_test_evm_witness(
use crate::pfsys::ProofType;
pub(crate) async fn test_update_account_calls(
addr: H160Flag,
data: PathBuf,
data: String,
rpc_url: Option<String>,
) -> Result<String, EZKLError> {
use crate::eth::update_account_calls;

View File

@@ -33,7 +33,7 @@ pub enum GraphError {
#[error("a node is missing required params: {0}")]
MissingParams(String),
/// A node has missing parameters
#[error("a node is has misformed params: {0}")]
#[error("a node has misformed params: {0}")]
MisformedParams(String),
/// Error in the configuration of the visibility of variables
#[error("there should be at least one set of public variables")]

View File

@@ -528,6 +528,27 @@ impl GraphData {
}
}
/// Loads graph input data from a string that may contain either inline JSON or a file path
/// It first attempts to parse the string as JSON data
/// If parsing fails, it treats the string as a file path and loads the data from the file
///
/// # Arguments
/// * `data` - String containing the input data
/// # Returns
/// A new GraphData instance containing the loaded data
pub fn from_str(data: &str) -> Result<Self, GraphError> {
let graph_input = serde_json::from_str(data);
match graph_input {
Ok(graph_input) => {
return Ok(graph_input);
}
Err(_) => {
let path = std::path::PathBuf::from(data);
GraphData::from_path(path)
}
}
}
/// Loads graph input data from a file
///
/// # Arguments
@@ -609,8 +630,12 @@ impl GraphData {
if input.len() % input_size != 0 {
return Err(GraphError::InvalidDims(
0,
"calibration data length must be evenly divisible by the original input_size"
.to_string(),
format!(
"calibration data length (={}) must be evenly divisible by the original input_size(={})",
input.len(),
input_size
),
));
}
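Note: GraphData::from_str above is the single resolution point behind all the PathBuf-to-String changes in this diff. A standalone sketch of the same strategy (types simplified; ezkl's GraphData and error types are richer):

use std::path::PathBuf;

#[derive(Debug)]
enum Source {
    InlineJson(serde_json::Value),
    File(PathBuf),
}

/// Mirror of the strategy above: try JSON first, fall back to a file path.
fn resolve(data: &str) -> Source {
    match serde_json::from_str::<serde_json::Value>(data) {
        Ok(v) => Source::InlineJson(v),
        Err(_) => Source::File(PathBuf::from(data)),
    }
}

fn main() {
    println!("{:?}", resolve(r#"{"input_data": [[1.0, 2.0, 3.0]]}"#)); // inline data
    println!("{:?}", resolve("input.json")); // not valid JSON, so treated as a path
}

One consequence of trying JSON first: any argument that parses as JSON, even a bare number, is taken as inline data and never as a path; path-like strings such as input.json fail JSON parsing and fall through correctly.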

View File

@@ -455,6 +455,10 @@ pub struct GraphSettings {
pub num_blinding_factors: Option<usize>,
/// unix time timestamp
pub timestamp: Option<u128>,
/// Model inputs types (if any)
pub input_types: Option<Vec<InputType>>,
/// Model outputs types (if any)
pub output_types: Option<Vec<InputType>>,
}
impl GraphSettings {

View File

@@ -1,7 +1,6 @@
use super::errors::GraphError;
use super::extract_const_quantized_values;
use super::node::*;
use super::scale_to_multiplier;
use super::vars::*;
use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
@@ -379,9 +378,15 @@ pub struct ParsedNodes {
pub nodes: BTreeMap<usize, NodeType>,
inputs: Vec<usize>,
outputs: Vec<Outlet>,
output_types: Vec<InputType>,
}
impl ParsedNodes {
/// Returns the output types of the computational graph.
pub fn get_output_types(&self) -> Vec<InputType> {
self.output_types.clone()
}
/// Returns the number of the computational graph's inputs
pub fn num_inputs(&self) -> usize {
self.inputs.len()
@@ -491,6 +496,16 @@ impl Model {
Ok(om)
}
/// Gets the input types from the parsed nodes
pub fn get_input_types(&self) -> Result<Vec<InputType>, GraphError> {
self.graph.get_input_types()
}
/// Gets the output types from the parsed nodes
pub fn get_output_types(&self) -> Vec<InputType> {
self.graph.get_output_types()
}
///
pub fn save(&self, path: PathBuf) -> Result<(), GraphError> {
let f = std::fs::File::create(&path).map_err(|e| {
@@ -574,6 +589,11 @@ impl Model {
required_range_checks: res.range_checks.into_iter().collect(),
model_output_scales: self.graph.get_output_scales()?,
model_input_scales: self.graph.get_input_scales(),
input_types: match self.get_input_types() {
Ok(x) => Some(x),
Err(_) => None,
},
output_types: Some(self.get_output_types()),
num_dynamic_lookups: res.num_dynamic_lookups,
total_dynamic_col_size: res.dynamic_lookup_col_coord,
num_shuffles: res.num_shuffles,
@@ -704,6 +724,11 @@ impl Model {
nodes,
inputs: model.inputs.iter().map(|o| o.node).collect(),
outputs: model.outputs.iter().map(|o| (o.node, o.slot)).collect(),
output_types: model
.outputs
.iter()
.map(|o| Ok::<InputType, GraphError>(model.outlet_fact(*o)?.datum_type.into()))
.collect::<Result<Vec<_>, GraphError>>()?,
};
let duration = start_time.elapsed();
@@ -862,6 +887,15 @@ impl Model {
nodes: subgraph_nodes,
inputs: model.inputs.iter().map(|o| o.node).collect(),
outputs: model.outputs.iter().map(|o| (o.node, o.slot)).collect(),
output_types: model
.outputs
.iter()
.map(|o| {
Ok::<InputType, GraphError>(
model.outlet_fact(*o)?.datum_type.into(),
)
})
.collect::<Result<Vec<_>, GraphError>>()?,
};
let om = Model {
@@ -1138,17 +1172,10 @@ impl Model {
})?;
if run_args.output_visibility.is_public() || run_args.output_visibility.is_fixed() {
let output_scales = self.graph.get_output_scales().map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
let res = outputs
.iter()
.enumerate()
.map(|(i, output)| {
let mut tol: crate::circuit::Tolerance = run_args.tolerance;
tol.scale = scale_to_multiplier(output_scales[i]).into();
let comparators = if run_args.output_visibility == Visibility::Public {
let res = vars
.instance
@@ -1171,7 +1198,6 @@ impl Model {
&mut thread_safe_region,
&[output.clone(), comparators],
Box::new(HybridOp::Output {
tol,
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
)
@@ -1433,11 +1459,9 @@ impl Model {
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
if self.visibility.output.is_public() || self.visibility.output.is_fixed() {
let output_scales = self.graph.get_output_scales()?;
let res = outputs
.iter()
.enumerate()
.map(|(i, output)| {
.map(|output| {
let mut comparator: ValTensor<Fp> = (0..output.len())
.map(|_| {
if !self.visibility.output.is_fixed() {
@@ -1450,14 +1474,10 @@ impl Model {
.into();
comparator.reshape(output.dims())?;
let mut tol = run_args.tolerance;
tol.scale = scale_to_multiplier(output_scales[i]).into();
dummy_config.layout(
&mut region,
&[output.clone(), comparator],
Box::new(HybridOp::Output {
tol,
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
)
@@ -1579,4 +1599,16 @@ impl Model {
}
Ok(instance_shapes)
}
/// Input types of the computational graph's public inputs (if any)
pub fn instance_types(&self) -> Result<Vec<InputType>, GraphError> {
let mut instance_types = vec![];
if self.visibility.input.is_public() {
instance_types.extend(self.graph.get_input_types()?);
}
if self.visibility.output.is_public() {
instance_types.extend(self.graph.get_output_types());
}
Ok(instance_types)
}
}
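Note: instance_types mirrors the existing instance_shapes logic: when inputs are public their types come first, followed by the output types when outputs are public. A standalone sketch of that concatenation order (Visibility and InputType here are simplified stand-ins for ezkl's types):

#[derive(Clone, Copy, PartialEq)]
enum Visibility {
    Public,
    Private,
}

#[derive(Clone, Debug, PartialEq)]
enum InputType {
    F32,
    Int,
}

fn instance_types(
    input_vis: Visibility,
    output_vis: Visibility,
    input_types: &[InputType],
    output_types: &[InputType],
) -> Vec<InputType> {
    let mut types = vec![];
    if input_vis == Visibility::Public {
        types.extend_from_slice(input_types); // public inputs first
    }
    if output_vis == Visibility::Public {
        types.extend_from_slice(output_types); // then public outputs
    }
    types
}

fn main() {
    // Private inputs, public outputs: only output types appear as instances.
    let t = instance_types(
        Visibility::Private,
        Visibility::Public,
        &[InputType::F32],
        &[InputType::F32, InputType::Int],
    );
    assert_eq!(t, vec![InputType::F32, InputType::Int]);
}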

View File

@@ -1,14 +1,14 @@
use super::errors::GraphError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::VarScales;
use super::errors::GraphError;
use super::{Rescaled, SupportedOp, Visibility};
use crate::circuit::Op;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::hybrid::HybridOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::lookup::LookupOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::poly::PolyOp;
use crate::circuit::Op;
use crate::fieldutils::IntegerRep;
use crate::tensor::{Tensor, TensorError, TensorType};
use halo2curves::bn256::Fr as Fp;
@@ -22,6 +22,7 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::ops::{
Downsample,
array::{
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
Slice, Topk,
@@ -31,7 +32,6 @@ use tract_onnx::tract_core::ops::{
einsum::EinSum,
element_wise::ElementWiseOp,
nn::{LeakyRelu, Reduce, Softmax},
Downsample,
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::{
@@ -39,9 +39,8 @@ use tract_onnx::tract_hir::{
ops::array::{Pad, PadMode, TypedConcat},
ops::cnn::PoolSpec,
ops::konst::Const,
ops::nn::DataFormat,
tract_core::ops::cast::Cast,
tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
tract_core::ops::cnn::{MaxPool, SumPool},
};
/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
@@ -1146,13 +1145,6 @@ pub fn new_op_from_onnx(
let pool_spec: &PoolSpec = &sumpool_node.pool_spec;
// only support pytorch type formatting for now
if pool_spec.data_format != DataFormat::NCHW {
return Err(GraphError::MissingParams(
"data in wrong format".to_string(),
));
}
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, &input_dims[0])?;
let kernel_shape = &pool_spec.kernel_shape;
@@ -1161,6 +1153,7 @@ pub fn new_op_from_onnx(
padding,
stride: stride.to_vec(),
pool_dims: kernel_shape.to_vec(),
data_format: pool_spec.data_format.into(),
})
}
"Ceil" => {
@@ -1314,15 +1307,6 @@ pub fn new_op_from_onnx(
}
}
if ((conv_node.pool_spec.data_format != DataFormat::NCHW)
&& (conv_node.pool_spec.data_format != DataFormat::CHW))
|| (conv_node.kernel_fmt != KernelFormat::OIHW)
{
return Err(GraphError::MisformedParams(
"data or kernel in wrong format".to_string(),
));
}
let pool_spec = &conv_node.pool_spec;
let stride = extract_strides(pool_spec)?;
@@ -1350,6 +1334,8 @@ pub fn new_op_from_onnx(
padding,
stride,
group,
data_format: conv_node.pool_spec.data_format.into(),
kernel_format: conv_node.kernel_fmt.into(),
})
}
"Not" => SupportedOp::Linear(PolyOp::Not),
@@ -1373,14 +1359,6 @@ pub fn new_op_from_onnx(
}
}
if (deconv_node.pool_spec.data_format != DataFormat::NCHW)
|| (deconv_node.kernel_format != KernelFormat::OIHW)
{
return Err(GraphError::MisformedParams(
"data or kernel in wrong format".to_string(),
));
}
let pool_spec = &deconv_node.pool_spec;
let stride = extract_strides(pool_spec)?;
@@ -1406,6 +1384,8 @@ pub fn new_op_from_onnx(
output_padding: deconv_node.adjustments.to_vec(),
stride,
group: deconv_node.group,
data_format: deconv_node.pool_spec.data_format.into(),
kernel_format: deconv_node.kernel_format.into(),
})
}
"Downsample" => {
@@ -1418,7 +1398,7 @@ pub fn new_op_from_onnx(
SupportedOp::Linear(PolyOp::Downsample {
axis: downsample_node.axis,
stride: downsample_node.stride as usize,
stride: downsample_node.stride,
modulo: downsample_node.modulo,
})
}
@@ -1489,13 +1469,6 @@ pub fn new_op_from_onnx(
let pool_spec: &PoolSpec = &sumpool_node.pool_spec;
// only support pytorch type formatting for now
if pool_spec.data_format != DataFormat::NCHW {
return Err(GraphError::MissingParams(
"data in wrong format".to_string(),
));
}
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, &input_dims[0])?;
@@ -1504,6 +1477,7 @@ pub fn new_op_from_onnx(
stride: stride.to_vec(),
kernel_shape: pool_spec.kernel_shape.to_vec(),
normalized: sumpool_node.normalize,
data_format: pool_spec.data_format.into(),
})
}
"Pad" => {

View File

@@ -97,10 +97,9 @@ impl From<String> for EZKLError {
use std::str::FromStr;
use circuit::{table::Range, CheckMode, Tolerance};
use circuit::{table::Range, CheckMode};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::Args;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use fieldutils::IntegerRep;
use graph::{Visibility, MAX_PUBLIC_SRS};
use halo2_proofs::poly::{
@@ -275,10 +274,6 @@ impl From<String> for Commitments {
derive(Args, ToFlags)
)]
pub struct RunArgs {
/// Error tolerance for model outputs
/// Only applicable when outputs are public
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other))]
pub tolerance: Tolerance,
/// Fixed point scaling factor for quantizing inputs
/// Higher values provide more precision but increase circuit complexity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other))]
@@ -365,7 +360,6 @@ impl Default for RunArgs {
fn default() -> Self {
Self {
bounded_log_lookup: false,
tolerance: Tolerance::default(),
input_scale: 7,
param_scale: 7,
scale_rebase_multiplier: 1,
@@ -399,6 +393,16 @@ impl RunArgs {
pub fn validate(&self) -> Result<(), String> {
let mut errors = Vec::new();
// check if the largest represented integer in the decomposed form overflows IntegerRep
// try it with the largest possible value
let max_decomp = (self.decomp_base as IntegerRep).checked_pow(self.decomp_legs as u32);
if max_decomp.is_none() {
errors.push(format!(
"decomp_base^decomp_legs overflows IntegerRep: {}^{}",
self.decomp_base, self.decomp_legs
));
}
// Visibility validations
if self.param_visibility == Visibility::Public {
errors.push(
@@ -407,10 +411,6 @@ impl RunArgs {
);
}
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
errors.push("Non-zero tolerance requires output_visibility to be public".to_string());
}
// Scale validations
if self.scale_rebase_multiplier < 1 {
errors.push("scale_rebase_multiplier must be >= 1".to_string());
@@ -459,11 +459,6 @@ impl RunArgs {
warn!("logrows exceeds maximum public SRS size");
}
// Validate tolerance is non-negative
if self.tolerance.val < 0.0 {
errors.push("tolerance cannot be negative".to_string());
}
// Performance warnings
if self.input_scale > 20 || self.param_scale > 20 {
warn!("High scale values (>20) may impact performance");
@@ -610,23 +605,6 @@ mod tests {
assert!(err.contains("num_inner_cols must be >= 1"));
}
#[test]
fn test_invalid_tolerance() {
let mut args = RunArgs::default();
args.tolerance.val = 1.0;
args.output_visibility = Visibility::Private;
let err = args.validate().unwrap_err();
assert!(err.contains("Non-zero tolerance requires output_visibility to be public"));
}
#[test]
fn test_negative_tolerance() {
let mut args = RunArgs::default();
args.tolerance.val = -1.0;
let err = args.validate().unwrap_err();
assert!(err.contains("tolerance cannot be negative"));
}
#[test]
fn test_zero_batch_size() {
let mut args = RunArgs::default();
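Note: the new validation rejects decomposition parameters whose largest representable value overflows IntegerRep. With the defaults visible in the settings at the top of this comparison (decomp_base = 16384, decomp_legs = 2) the bound is 16384^2 = 2^28, which fits easily. A standalone sketch of the check (assuming IntegerRep is a fixed-width signed integer; i128 is used here purely for illustration):

type IntegerRep = i128; // assumption for illustration; ezkl defines its own width

fn validate_decomp(decomp_base: u32, decomp_legs: u32) -> Result<(), String> {
    // checked_pow returns None instead of wrapping on overflow.
    match (decomp_base as IntegerRep).checked_pow(decomp_legs) {
        Some(_) => Ok(()),
        None => Err(format!(
            "decomp_base^decomp_legs overflows IntegerRep: {}^{}",
            decomp_base, decomp_legs
        )),
    }
}

fn main() {
    assert!(validate_decomp(16384, 2).is_ok()); // 2^28: fits
    assert!(validate_decomp(16384, 10).is_err()); // 2^140: overflows i128
}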

View File

@@ -17,16 +17,16 @@ use crate::{Commitments, EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
use clap::ValueEnum;
use halo2_proofs::circuit::Value;
use halo2_proofs::plonk::{
create_proof, keygen_pk, keygen_vk_custom, verify_proof, Circuit, ProvingKey, VerifyingKey,
Circuit, ProvingKey, VerifyingKey, create_proof, keygen_pk, keygen_vk_custom, verify_proof,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier};
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer};
use halo2curves::CurveAffine;
use halo2curves::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use halo2curves::serde::SerdeObject;
use halo2curves::CurveAffine;
use instant::Instant;
use log::{debug, info, trace};
#[cfg(not(feature = "det-prove"))]
@@ -51,6 +51,9 @@ use pyo3::types::PyDictMethods;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
/// Converts a string to a `SerdeFormat`.
/// # Panics
/// Panics if the provided `s` is not a valid `SerdeFormat` (i.e. not one of "processed", "raw-bytes-unchecked", or "raw-bytes").
fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
match s {
"processed" => halo2_proofs::SerdeFormat::Processed,
@@ -321,7 +324,7 @@ where
}
#[cfg(feature = "python-bindings")]
use pyo3::{types::PyDict, PyObject, Python, ToPyObject};
use pyo3::{PyObject, Python, ToPyObject, types::PyDict};
#[cfg(feature = "python-bindings")]
impl<F: PrimeField + SerdeObject + Serialize, C: CurveAffine + Serialize> ToPyObject for Snark<F, C>
where
@@ -345,9 +348,9 @@ where
}
impl<
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
C: CurveAffine + Serialize + DeserializeOwned,
> Snark<F, C>
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
C: CurveAffine + Serialize + DeserializeOwned,
> Snark<F, C>
where
C::Scalar: Serialize + DeserializeOwned,
C::ScalarExt: Serialize + DeserializeOwned,

View File

@@ -1,6 +1,6 @@
use thiserror::Error;
use super::ops::DecompositionError;
use super::{ops::DecompositionError, DataFormat};
/// A wrapper for tensor related errors.
#[derive(Debug, Error)]
@@ -44,4 +44,7 @@ pub enum TensorError {
/// Index out of bounds
#[error("index {0} out of bounds for dimension {1}")]
IndexOutOfBounds(usize, usize),
/// Invalid data conversion
#[error("invalid data conversion from format {0} to {1}")]
InvalidDataConversion(DataFormat, DataFormat),
}

View File

@@ -9,6 +9,7 @@ pub mod var;
pub use errors::TensorError;
use core::hash::Hash;
use halo2curves::ff::PrimeField;
use maybe_rayon::{
prelude::{
@@ -26,7 +27,7 @@ pub use var::*;
use crate::{
circuit::utils,
fieldutils::{integer_rep_to_felt, IntegerRep},
fieldutils::{IntegerRep, integer_rep_to_felt},
graph::Visibility,
};
@@ -61,7 +62,7 @@ pub trait TensorType: Clone + Debug + 'static {
}
macro_rules! tensor_type {
($rust_type:ty, $tensor_type:ident, $zero:expr, $one:expr) => {
($rust_type:ty, $tensor_type:ident, $zero:expr_2021, $one:expr_2021) => {
impl TensorType for $rust_type {
fn zero() -> Option<Self> {
Some($zero)
@@ -414,7 +415,7 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
Err(_) => {
return Err(TensorError::FileLoadError(
"Failed to read tensor".to_string(),
))
));
}
}
}
@@ -925,6 +926,9 @@ impl<T: Clone + TensorType> Tensor<T> {
));
}
self.dims = vec![];
}
if self.dims() == &[0] && new_dims.iter().product::<usize>() == 1 {
self.dims = Vec::from(new_dims);
} else {
let product = if new_dims != [0] {
new_dims.iter().product::<usize>()
@@ -1103,6 +1107,10 @@ impl<T: Clone + TensorType> Tensor<T> {
let mut output = self.clone();
output.reshape(shape)?;
return Ok(output);
} else if self.dims() == &[0] && shape.iter().product::<usize>() == 1 {
let mut output = self.clone();
output.reshape(shape)?;
return Ok(output);
}
if self.dims().len() > shape.len() {
@@ -1253,7 +1261,7 @@ impl<T: Clone + TensorType> Tensor<T> {
None => {
return Err(TensorError::DimError(
"Cannot get last element of empty tensor".to_string(),
))
));
}
};
@@ -1278,7 +1286,7 @@ impl<T: Clone + TensorType> Tensor<T> {
None => {
return Err(TensorError::DimError(
"Cannot get first element of empty tensor".to_string(),
))
));
}
};
@@ -1691,8 +1699,8 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
lhs.par_iter_mut()
.zip(rhs)
.map(|(o, r)| {
if let Some(zero) = T::zero() {
.map(|(o, r)| match T::zero() {
Some(zero) => {
if r != zero {
*o = o.clone() % r;
Ok(())
@@ -1701,11 +1709,10 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
"Cannot divide by zero in remainder".to_string(),
))
}
} else {
Err(TensorError::InvalidArgument(
"Undefined zero value".to_string(),
))
}
_ => Err(TensorError::InvalidArgument(
"Undefined zero value".to_string(),
)),
})
.collect::<Result<Vec<_>, _>>()?;
@@ -1767,6 +1774,229 @@ pub fn get_broadcasted_shape(
}
}
////////////////////////
///
/// The shape of data for some operations
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]
pub enum DataFormat {
/// NCHW
#[default]
NCHW,
/// NHWC
NHWC,
/// CHW
CHW,
/// HWC
HWC,
}
// as str
impl core::fmt::Display for DataFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DataFormat::NCHW => write!(f, "NCHW"),
DataFormat::NHWC => write!(f, "NHWC"),
DataFormat::CHW => write!(f, "CHW"),
DataFormat::HWC => write!(f, "HWC"),
}
}
}
impl DataFormat {
/// Get the format's canonical form
pub fn canonical(&self) -> DataFormat {
match self {
DataFormat::NHWC => DataFormat::NCHW,
DataFormat::HWC => DataFormat::CHW,
_ => self.clone(),
}
}
/// Returns true if the format carries no batch dimension
pub fn has_no_batch(&self) -> bool {
match self {
DataFormat::CHW | DataFormat::HWC => true,
_ => false,
}
}
/// Convert tensor to canonical format (NCHW or CHW)
pub fn to_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
&self,
tensor: &mut ValTensor<F>,
) -> Result<(), TensorError> {
match self {
DataFormat::NHWC => {
// For ND: Move channels from last axis to position after batch
let ndims = tensor.dims().len();
if ndims > 2 {
tensor.move_axis(ndims - 1, 1)?;
}
}
DataFormat::HWC => {
// For ND: Move channels from last axis to first position
let ndims = tensor.dims().len();
if ndims > 1 {
tensor.move_axis(ndims - 1, 0)?;
}
}
_ => {} // NCHW/CHW are already in canonical format
}
Ok(())
}
/// Convert tensor from canonical format to target format
pub fn from_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
&self,
tensor: &mut ValTensor<F>,
) -> Result<(), TensorError> {
match self {
DataFormat::NHWC => {
// Move channels from position 1 to end
let ndims = tensor.dims().len();
if ndims > 2 {
tensor.move_axis(1, ndims - 1)?;
}
}
DataFormat::HWC => {
// Move channels from position 0 to end
let ndims = tensor.dims().len();
if ndims > 1 {
tensor.move_axis(0, ndims - 1)?;
}
}
_ => {} // NCHW/CHW don't need conversion
}
Ok(())
}
/// Get the position of the channel dimension
pub fn get_channel_dim(&self, ndims: usize) -> usize {
match self {
DataFormat::NCHW => 1,
DataFormat::NHWC => ndims - 1,
DataFormat::CHW => 0,
DataFormat::HWC => ndims - 1,
}
}
}
/// The shape of the kernel for some operations
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]
pub enum KernelFormat {
/// HWIO
HWIO,
/// OIHW
#[default]
OIHW,
/// OHWI
OHWI,
}
impl core::fmt::Display for KernelFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
KernelFormat::HWIO => write!(f, "HWIO"),
KernelFormat::OIHW => write!(f, "OIHW"),
KernelFormat::OHWI => write!(f, "OHWI"),
}
}
}
impl KernelFormat {
/// Get the format's canonical form
pub fn canonical(&self) -> KernelFormat {
match self {
KernelFormat::HWIO => KernelFormat::OIHW,
KernelFormat::OHWI => KernelFormat::OIHW,
_ => self.clone(),
}
}
/// Convert kernel to canonical format (OIHW)
pub fn to_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
&self,
kernel: &mut ValTensor<F>,
) -> Result<(), TensorError> {
match self {
KernelFormat::HWIO => {
let kdims = kernel.dims().len();
// Move output channels from last to first
kernel.move_axis(kdims - 1, 0)?;
// Move input channels from new last to second position
kernel.move_axis(kdims - 1, 1)?;
}
KernelFormat::OHWI => {
let kdims = kernel.dims().len();
// Move input channels from last to second position
kernel.move_axis(kdims - 1, 1)?;
}
_ => {} // OIHW is already canonical
}
Ok(())
}
/// Convert kernel from canonical format to target format
pub fn from_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
&self,
kernel: &mut ValTensor<F>,
) -> Result<(), TensorError> {
match self {
KernelFormat::HWIO => {
let kdims = kernel.dims().len();
// Move input channels from second position to last
kernel.move_axis(1, kdims - 1)?;
// Move output channels from first to last
kernel.move_axis(0, kdims - 1)?;
}
KernelFormat::OHWI => {
let kdims = kernel.dims().len();
// Move input channels from second position to last
kernel.move_axis(1, kdims - 1)?;
}
_ => {} // OIHW doesn't need conversion
}
Ok(())
}
/// Get the position of input and output channel dimensions
pub fn get_channel_dims(&self, ndims: usize) -> (usize, usize) {
// (input_ch, output_ch)
match self {
KernelFormat::OIHW => (1, 0),
KernelFormat::HWIO => (ndims - 2, ndims - 1),
KernelFormat::OHWI => (ndims - 1, 0),
}
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl From<tract_onnx::tract_hir::ops::nn::DataFormat> for DataFormat {
fn from(fmt: tract_onnx::tract_hir::ops::nn::DataFormat) -> Self {
match fmt {
tract_onnx::tract_hir::ops::nn::DataFormat::NCHW => DataFormat::NCHW,
tract_onnx::tract_hir::ops::nn::DataFormat::NHWC => DataFormat::NHWC,
tract_onnx::tract_hir::ops::nn::DataFormat::CHW => DataFormat::CHW,
tract_onnx::tract_hir::ops::nn::DataFormat::HWC => DataFormat::HWC,
}
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl From<tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat> for KernelFormat {
fn from(fmt: tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat) -> Self {
match fmt {
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::HWIO => {
KernelFormat::HWIO
}
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::OIHW => {
KernelFormat::OIHW
}
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::OHWI => {
KernelFormat::OHWI
}
}
}
}
#[cfg(test)]
mod tests {
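Note: to make the move_axis-based conversions above concrete, here is a dependency-free sketch of what an NHWC-to-NCHW relayout does to a flat buffer (ezkl performs this on a ValTensor via move_axis rather than on raw Vecs):

/// Convert a flat NHWC buffer to NCHW by re-indexing.
fn nhwc_to_nchw(data: &[f32], n: usize, h: usize, w: usize, c: usize) -> Vec<f32> {
    let mut out = vec![0.0; data.len()];
    for ni in 0..n {
        for hi in 0..h {
            for wi in 0..w {
                for ci in 0..c {
                    let src = ((ni * h + hi) * w + wi) * c + ci; // NHWC strides
                    let dst = ((ni * c + ci) * h + hi) * w + wi; // NCHW strides
                    out[dst] = data[src];
                }
            }
        }
    }
    out
}

fn main() {
    // One batch, 2x2 spatial, 2 channels; NHWC interleaves channels per pixel.
    let nhwc = [1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0];
    let nchw = nhwc_to_nchw(&nhwc, 1, 2, 2, 2);
    assert_eq!(nchw, vec![1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0]);
}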

View File

@@ -387,7 +387,7 @@ pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sy
) -> Result<Tensor<T>, TensorError> {
if t.len() == 1 {
return Ok(t[0].clone());
} else if t.len() == 0 {
} else if t.is_empty() {
return Err(TensorError::DimMismatch("add".to_string()));
}
@@ -441,7 +441,7 @@ pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sy
) -> Result<Tensor<T>, TensorError> {
if t.len() == 1 {
return Ok(t[0].clone());
} else if t.len() == 0 {
} else if t.is_empty() {
return Err(TensorError::DimMismatch("sub".to_string()));
}
// calculate value of output
@@ -492,7 +492,7 @@ pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::S
) -> Result<Tensor<T>, TensorError> {
if t.len() == 1 {
return Ok(t[0].clone());
} else if t.len() == 0 {
} else if t.is_empty() {
return Err(TensorError::DimMismatch("mult".to_string()));
}
// calculate value of output
@@ -535,30 +535,101 @@ pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::S
/// let result = downsample(&x, 1, 2, 2).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[3, 6]), &[2, 1]).unwrap();
/// assert_eq!(result, expected);
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6]),
/// &[2, 3],
/// ).unwrap();
///
/// // Test case 1: Negative stride along dimension 0
/// // This should flip the order along dimension 0
/// let result = downsample(&x, 0, -1, 0).unwrap();
/// let expected = Tensor::<IntegerRep>::new(
/// Some(&[4, 5, 6, 1, 2, 3]), // Flipped order of rows
/// &[2, 3]
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// // Test case 2: Negative stride along dimension 1
/// // This should flip the order along dimension 1
/// let result = downsample(&x, 1, -1, 0).unwrap();
/// let expected = Tensor::<IntegerRep>::new(
/// Some(&[3, 2, 1, 6, 5, 4]), // Flipped order of columns
/// &[2, 3]
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// // Test case 3: Negative stride with stride magnitude > 1
/// // This should both skip and flip
/// let result = downsample(&x, 1, -2, 0).unwrap();
/// let expected = Tensor::<IntegerRep>::new(
/// Some(&[3, 1, 6, 4]), // Take every 2nd element in reverse
/// &[2, 2]
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// // Test case 4: Negative stride with non-zero modulo
/// // This should start at (size - 1 - modulo) and reverse
/// let result = downsample(&x, 1, -2, 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(
/// Some(&[2, 5]), // Start at second element from end, take every 2nd in reverse
/// &[2, 1]
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// // Create a larger test case for more complex downsampling
/// let y = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
/// &[3, 4],
/// ).unwrap();
///
/// // Test case 5: Negative stride with modulo on larger tensor
/// let result = downsample(&y, 1, -2, 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(
/// Some(&[3, 1, 7, 5, 11, 9]), // Start at one after reverse, take every 2nd
/// &[3, 2]
/// ).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn downsample<T: TensorType + Send + Sync>(
input: &Tensor<T>,
dim: usize,
stride: usize,
stride: isize, // Changed from usize to isize to support negative strides
modulo: usize,
) -> Result<Tensor<T>, TensorError> {
let mut output_shape = input.dims().to_vec();
// now downsample along axis dim offset by modulo, rounding up (+1 if remainder is non-zero)
let remainder = (input.dims()[dim] - modulo) % stride;
let div = (input.dims()[dim] - modulo) / stride;
output_shape[dim] = div + (remainder > 0) as usize;
let mut output = Tensor::<T>::new(None, &output_shape)?;
// Handle negative stride case
if stride == 0 {
return Err(TensorError::DimMismatch(
"downsample stride cannot be zero".to_string(),
));
}
if modulo > input.dims()[dim] {
let stride_abs = stride.unsigned_abs();
let mut output_shape = input.dims().to_vec();
if modulo >= input.dims()[dim] {
return Err(TensorError::DimMismatch("downsample".to_string()));
}
// now downsample along axis dim offset by modulo
// Calculate output shape based on the absolute value of stride
let remainder = (input.dims()[dim] - modulo) % stride_abs;
let div = (input.dims()[dim] - modulo) / stride_abs;
output_shape[dim] = div + (remainder > 0) as usize;
let mut output = Tensor::<T>::new(None, &output_shape)?;
// Calculate indices based on stride direction
let indices = (0..output_shape.len())
.map(|i| {
if i == dim {
let mut index = vec![0; output_shape[i]];
for (i, idx) in index.iter_mut().enumerate() {
*idx = i * stride + modulo;
for (j, idx) in index.iter_mut().enumerate() {
if stride > 0 {
// Positive stride: move forward from modulo
*idx = j * stride_abs + modulo;
} else {
// Negative stride: move backward from (size - 1 - modulo)
*idx = (input.dims()[dim] - 1 - modulo) - j * stride_abs;
}
}
index
} else {
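
An aside for readers of the hunk above: the negative-stride rule reduces to plain index arithmetic. Below is a minimal, self-contained sketch (a hypothetical `downsample_indices` helper, not part of the crate's API) that reproduces the indices from the doc tests, assuming only what the comments above state.

// Hypothetical standalone helper mirroring the index rule above: a positive
// stride walks forward from `modulo`; a negative stride walks backward from
// `size - 1 - modulo`, stepping by |stride| either way.
fn downsample_indices(size: usize, stride: isize, modulo: usize) -> Vec<usize> {
    assert!(stride != 0 && modulo < size, "mirrors the checks in downsample");
    let stride_abs = stride.unsigned_abs();
    let span = size - modulo;
    let len = span / stride_abs + (span % stride_abs > 0) as usize;
    (0..len)
        .map(|j| {
            if stride > 0 {
                j * stride_abs + modulo
            } else {
                (size - 1 - modulo) - j * stride_abs
            }
        })
        .collect()
}

fn main() {
    // Matches doc test case 3: dim of size 3, stride -2, modulo 0 -> indices [2, 0]
    assert_eq!(downsample_indices(3, -2, 0), vec![2, 0]);
    // Matches doc test case 5: dim of size 4, stride -2, modulo 1 -> indices [2, 0]
    assert_eq!(downsample_indices(4, -2, 1), vec![2, 0]);
}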
@@ -1326,7 +1397,6 @@ pub fn pad<T: TensorType>(
///
/// # Errors
/// Returns a TensorError if the tensors in `inputs` have incompatible dimensions for concatenation along the specified `axis`.
pub fn concat<T: TensorType + Send + Sync>(
inputs: &[&Tensor<T>],
axis: usize,
@@ -2102,7 +2172,6 @@ pub mod nonlinearities {
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn tanh(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let kix = (a_i as f64) / scale_input;

View File

@@ -1342,9 +1342,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Gets the total number of elements in the tensor
pub fn len(&self) -> usize {
match self {
ValTensor::Value { dims, inner, .. } => {
if !dims.is_empty() && (dims != &[0]) {
dims.iter().product::<usize>()
} else if dims.is_empty() {
inner.inner.len()
} else {
0
}
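
In words: a tensor with a non-empty, non-`[0]` shape reports the product of its dims; a `[]`-shaped scalar now falls back to the inner buffer's length instead of 0; an explicit `[0]` shape stays empty. A hedged sketch of the three branches, using a plain stand-in struct (the real ValTensor wraps Halo2 value types):

// Stand-in sketch of the three len() cases after this change.
struct FakeValTensor {
    dims: Vec<usize>,
    inner: Vec<i64>,
}

impl FakeValTensor {
    fn len(&self) -> usize {
        if !self.dims.is_empty() && self.dims != [0] {
            // normal case: product of dimensions
            self.dims.iter().product()
        } else if self.dims.is_empty() {
            // [] shape (scalar): fall back to the inner buffer length
            self.inner.len()
        } else {
            // explicit [0] shape: genuinely empty
            0
        }
    }
}

fn main() {
    assert_eq!(FakeValTensor { dims: vec![2, 3], inner: vec![0; 6] }.len(), 6);
    assert_eq!(FakeValTensor { dims: vec![], inner: vec![42] }.len(), 1);
    assert_eq!(FakeValTensor { dims: vec![0], inner: vec![] }.len(), 0);
}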

View File

@@ -2,7 +2,7 @@ use std::collections::HashSet;
use log::{debug, error, warn};
use crate::circuit::{CheckMode, region::ConstantsMap};
use super::*;
/// A wrapper around Halo2's Column types that represents a tensor of variables in the circuit.
@@ -403,7 +403,10 @@ impl VarTensor {
let mut assigned_coord = 0;
let mut res: ValTensor<F> = match values {
ValTensor::Instance { .. } => {
unimplemented!("cannot assign instance to advice columns with omissions")
error!(
"assignment with omissions is not supported on instance columns. increase K if you require more rows."
);
Err(halo2_proofs::plonk::Error::Synthesis)
}
ValTensor::Value { inner: v, .. } => Ok::<ValTensor<F>, halo2_proofs::plonk::Error>(
v.enum_map(|coord, k| {
@@ -569,8 +572,13 @@ impl VarTensor {
constants: &mut ConstantsMap<F>,
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
match values {
ValTensor::Instance { .. } => {
error!(
"duplication is not supported on instance columns. increase K if you require more rows."
);
Err(halo2_proofs::plonk::Error::Synthesis)
}
ValTensor::Value { inner: v, dims, .. } => {
let duplication_freq = if single_inner_col {
self.col_size()
} else {
@@ -583,21 +591,20 @@ impl VarTensor {
self.num_inner_cols()
};
let duplication_offset = if single_inner_col { row } else { offset };
// duplicates every nth element to adjust for column overflow
let mut res: ValTensor<F> = v
.duplicate_every_n(duplication_freq, num_repeats, duplication_offset)
.unwrap()
.into();
let constants_map = res.create_constants_map();
constants.extend(constants_map);
let total_used_len = res.len();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset)
.unwrap();
res.reshape(dims).unwrap();
res.set_scale(values.scale());
@@ -627,9 +634,13 @@ impl VarTensor {
constants: &mut ConstantsMap<F>,
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
match values {
ValTensor::Instance { .. } => {
error!(
"duplication is not supported on instance columns. increase K if you require more rows."
);
Err(halo2_proofs::plonk::Error::Synthesis)
}
ValTensor::Value { inner: v, dims, .. } => {
let duplication_freq = self.block_size();
let num_repeats = self.num_inner_cols();
@@ -637,17 +648,31 @@ impl VarTensor {
let duplication_offset = offset;
// duplicates every nth element to adjust for column overflow
let v = v
.duplicate_every_n(duplication_freq, num_repeats, duplication_offset)
.map_err(|e| {
error!("Error duplicating values: {:?}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
let mut res: ValTensor<F> = {
v.enum_map(|coord, k| {
let cell =
self.assign_value(region, offset, k.clone(), coord, constants)?;
Ok::<_, halo2_proofs::plonk::Error>(cell)
})?
.into()
};
let total_used_len = res.len();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset)
    .map_err(|e| {
        error!("Error removing duplicated values: {:?}", e);
        halo2_proofs::plonk::Error::Synthesis
    })?;
res.reshape(dims).map_err(|e| {
    error!("Error reshaping values: {:?}", e);
    halo2_proofs::plonk::Error::Synthesis
})?;
res.set_scale(values.scale());
Ok((res, total_used_len))
@@ -681,61 +706,71 @@ impl VarTensor {
let mut prev_cell = None;
match values {
ValTensor::Instance { .. } => {
error!(
"duplication is not supported on instance columns. increase K if you require more rows."
);
Err(halo2_proofs::plonk::Error::Synthesis)
}
ValTensor::Value { inner: v, dims, .. } => {
let duplication_freq = self.col_size();
let num_repeats = 1;
let duplication_offset = row;
// duplicates every nth element to adjust for column overflow
let v = v
    .duplicate_every_n(duplication_freq, num_repeats, duplication_offset)
    .unwrap();
let mut res: ValTensor<F> = v
    .enum_map(|coord, k| {
        let step = self.num_inner_cols();
        let (x, y, z) = self.cartesian_coord(offset + coord * step);
        if matches!(check_mode, CheckMode::SAFE) && coord > 0 && z == 0 && y == 0 {
            // assert that duplication occurred correctly
            assert_eq!(
                Into::<IntegerRep>::into(k.clone()),
                Into::<IntegerRep>::into(v[coord - 1].clone())
            );
        };
        let cell =
            self.assign_value(region, offset, k.clone(), coord * step, constants)?;
        let at_end_of_column = z == duplication_freq - 1;
        let at_beginning_of_column = z == 0;
        if at_end_of_column {
            // if we are at the end of the column, we need to copy the cell to the next column
            prev_cell = Some(cell.clone());
        } else if coord > 0 && at_beginning_of_column {
            if let Some(prev_cell) = prev_cell.as_ref() {
                let cell = if let Some(cell) = cell.cell() {
                    cell
                } else {
                    error!("Error getting cell: {:?}", (x, y));
                    return Err(halo2_proofs::plonk::Error::Synthesis);
                };
                let prev_cell = if let Some(prev_cell) = prev_cell.cell() {
                    prev_cell
                } else {
                    error!("Error getting prev cell: {:?}", (x, y));
                    return Err(halo2_proofs::plonk::Error::Synthesis);
                };
                region.constrain_equal(prev_cell, cell)?;
            } else {
                error!("Previous cell was not set");
                return Err(halo2_proofs::plonk::Error::Synthesis);
            }
        }
        Ok(cell)
    })?
    .into();
let total_used_len = res.len();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset)
    .unwrap();
res.reshape(dims).unwrap();
res.set_scale(values.scale());
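
The duplicate-then-remove dance above exists to bridge column overflow: the value that ends one column is repeated at the top of the next so the two cells can be constrained equal, and the copies are stripped out again after assignment. A plain-Vec sketch of that round trip (hypothetical helpers, not the Tensor API; `n` plays the role of the column size):

// Plain-Vec sketch of the duplicate-then-remove round trip.
fn duplicate_every_n<T: Clone>(v: &[T], n: usize) -> Vec<T> {
    let mut out: Vec<T> = Vec::new();
    for (i, x) in v.iter().enumerate() {
        out.push(x.clone());
        // the element that lands at the end of a column is repeated at the
        // top of the next column, so the two cells can be constrained equal
        if out.len() % n == 0 && i + 1 < v.len() {
            out.push(x.clone());
        }
    }
    out
}

fn remove_every_n<T: Clone>(v: &[T], n: usize) -> Vec<T> {
    // strip the duplicate that begins every column after the first
    v.iter()
        .enumerate()
        .filter(|(i, _)| *i < n || i % n != 0)
        .map(|(_, x)| x.clone())
        .collect()
}

fn main() {
    let vals = vec![1, 2, 3, 4, 5];
    let dup = duplicate_every_n(&vals, 2);
    assert_eq!(dup, vec![1, 2, 2, 3, 3, 4, 4, 5]);
    assert_eq!(remove_every_n(&dup, 2), vals); // round trip restores the original
}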
@@ -771,21 +806,30 @@ impl VarTensor {
VarTensor::Advice { inner: advices, .. } => {
ValType::PrevAssigned(region.assign_advice(|| "k", advices[x][y], z, || v)?)
}
_ => {
error!("VarTensor was not initialized");
return Err(halo2_proofs::plonk::Error::Synthesis);
}
},
// Handle copying previously assigned value
ValType::PrevAssigned(v) => match &self {
VarTensor::Advice { inner: advices, .. } => {
ValType::PrevAssigned(v.copy_advice(|| "k", region, advices[x][y], z)?)
}
_ => {
error!("VarTensor was not initialized");
return Err(halo2_proofs::plonk::Error::Synthesis);
}
},
// Handle copying previously assigned constant
ValType::AssignedConstant(v, val) => match &self {
VarTensor::Advice { inner: advices, .. } => {
ValType::AssignedConstant(v.copy_advice(|| "k", region, advices[x][y], z)?, val)
}
_ => {
error!("VarTensor was not initialized");
return Err(halo2_proofs::plonk::Error::Synthesis);
}
},
// Handle assigning evaluated value
ValType::AssignedValue(v) => match &self {
@@ -794,7 +838,10 @@ impl VarTensor {
.assign_advice(|| "k", advices[x][y], z, || v)?
.evaluate(),
),
_ => {
error!("VarTensor was not initialized");
return Err(halo2_proofs::plonk::Error::Synthesis);
}
},
// Handle constant value assignment with caching
ValType::Constant(v) => {
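
The recurring edit in this file swaps panicking `unimplemented!` calls for a logged message plus `Err(Error::Synthesis)`; since Halo2's `Synthesis` variant carries no payload, the `error!` log is the only place the context survives. A minimal sketch of the shape, assuming the `log` crate as used in this repo and a stand-in error enum in place of `halo2_proofs::plonk::Error`:

use log::error;

// Stand-in for halo2_proofs::plonk::Error, whose Synthesis variant is opaque.
#[derive(Debug)]
enum PlonkError {
    Synthesis,
}

// Log the context first, then surface the opaque error instead of panicking.
fn assign_or_synthesis_err(res: Result<u64, String>) -> Result<u64, PlonkError> {
    res.map_err(|e| {
        error!("Error assigning value: {:?}", e);
        PlonkError::Synthesis
    })
}

fn main() {
    assert!(assign_or_synthesis_err(Ok(7)).is_ok());
    assert!(matches!(
        assign_or_synthesis_err(Err("overflow".into())),
        Err(PlonkError::Synthesis)
    ));
}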

Binary file not shown.

View File

@@ -1,13 +1,12 @@
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[cfg(test)]
mod native_tests {
use ezkl::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
use ezkl::Commitments;
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
use ezkl::pfsys::Snark;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::Bn256;
use lazy_static::lazy_static;
@@ -187,7 +186,7 @@ mod native_tests {
const PF_FAILURE_AGGR: &str = "examples/test_failure_aggr_proof.json";
const LARGE_TESTS: [&str; 8] = [
"self_attention",
"nanoGPT",
"multihead_attention",
@@ -195,6 +194,7 @@ mod native_tests {
"mnist_gan",
"smallworm",
"fr_age",
"1d_conv",
];
const ACCURACY_CAL_TESTS: [&str; 6] = [
@@ -522,7 +522,7 @@ mod native_tests {
use crate::native_tests::run_js_tests;
use crate::native_tests::render_circuit;
use crate::native_tests::model_serialization_different_binaries;
use rand::Rng;
use tempdir::TempDir;
use ezkl::Commitments;
@@ -543,7 +543,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false, None, None);
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, false, None, None);
test_dir.close().unwrap();
}
});
@@ -608,7 +608,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -618,22 +618,10 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true, Some(8194), Some(4));
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, true, Some(8194), Some(4));
test_dir.close().unwrap();
}
#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
@@ -644,7 +632,7 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, false, None, None);
test_dir.close().unwrap();
}
}
@@ -654,7 +642,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -663,7 +651,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -672,7 +660,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -681,7 +669,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -690,7 +678,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -699,7 +687,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -708,7 +696,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -718,7 +706,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -728,7 +716,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -737,7 +725,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -747,7 +735,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -756,7 +744,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -766,7 +754,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -776,7 +764,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -786,7 +774,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -796,7 +784,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// needs an extra row for the large model
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -806,7 +794,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// needs an extra row for the large model
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false, None, None);
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, false, None, None);
test_dir.close().unwrap();
}
@@ -965,7 +953,7 @@ mod native_tests {
});
seq!(N in 0..=7 {
#(#[test_case(LARGE_TESTS[N])])*
#[ignore]
@@ -983,7 +971,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false, None, Some(5));
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, false, None, Some(5));
test_dir.close().unwrap();
}
});
@@ -1459,12 +1447,10 @@ mod native_tests {
batch_size: usize,
cal_target: &str,
scales_to_use: Option<Vec<u32>>,
bounded_lookup_log: bool,
decomp_base: Option<usize>,
decomp_legs: Option<usize>,
) {
gen_circuit_settings_and_witness(
test_dir,
example_name.clone(),
@@ -1475,7 +1461,6 @@ mod native_tests {
cal_target,
scales_to_use,
2,
Commitments::KZG,
2,
bounded_lookup_log,
@@ -1483,128 +1468,17 @@ mod native_tests {
decomp_legs,
);
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"mock",
"-W",
format!("{}/{}/witness.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
}
#[allow(clippy::too_many_arguments)]
@@ -1618,7 +1492,6 @@ mod native_tests {
cal_target: &str,
scales_to_use: Option<Vec<u32>>,
num_inner_columns: usize,
commitment: Commitments,
lookup_safety_margin: usize,
bounded_lookup_log: bool,
@@ -1633,13 +1506,16 @@ mod native_tests {
"--settings-path={}/{}/settings.json",
test_dir, example_name
),
format!("--variables=batch_size->{}", batch_size),
format!(
"--variables=batch_size->{},sequence_length->100,<Sym1>->1",
batch_size
),
format!("--input-visibility={}", input_visibility),
format!("--param-visibility={}", param_visibility),
format!("--output-visibility={}", output_visibility),
format!("--num-inner-cols={}", num_inner_columns),
format!("--tolerance={}", tolerance),
format!("--commitment={}", commitment),
format!("--logrows={}", 22),
];
// if output-visibility is fixed set --range-check-inputs-outputs to False
@@ -1695,24 +1571,6 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"compile-circuit",
@@ -1725,7 +1583,6 @@ mod native_tests {
test_dir, example_name
),
])
.status()
.expect("failed to execute process");
assert!(status.success());
@@ -1768,7 +1625,6 @@ mod native_tests {
cal_target,
None,
2,
Commitments::KZG,
2,
false,
@@ -2054,7 +1910,6 @@ mod native_tests {
target_str,
scales_to_use,
num_inner_columns,
commitment,
lookup_safety_margin,
false,
@@ -2438,7 +2293,12 @@ mod native_tests {
.expect("failed to execute process");
if status.success() {
log::error!("Verification unexpectedly succeeded for modified proof {}. Flipped bit {} in byte {}", i, random_bit, random_byte);
log::error!(
"Verification unexpectedly succeeded for modified proof {}. Flipped bit {} in byte {}",
i,
random_bit,
random_byte
);
}
assert!(
@@ -2489,7 +2349,6 @@ mod native_tests {
// we need the accuracy
Some(vec![4]),
1,
Commitments::KZG,
2,
false,

View File

@@ -46,7 +46,28 @@ mod py_tests {
assert!(status.success());
});
// set VOICE_DATA_DIR environment variable
std::env::set_var("VOICE_DATA_DIR", format!("{}", voice_data_dir));
unsafe {
std::env::set_var("VOICE_DATA_DIR", format!("{}", voice_data_dir));
}
}
fn download_catdog_data() {
let cat_and_dog_data_dir = shellexpand::tilde("~/data/catdog_data");
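// NOTE: this reuses the DOWNLOAD_VOICE_DATA Once guard, so only the first of
// the two download helpers will run per process; a dedicated Once for the
// cat/dog data is likely intended.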
DOWNLOAD_VOICE_DATA.call_once(|| {
let status = Command::new("bash")
.args([
"examples/notebooks/cat_and_dog_data.sh",
&cat_and_dog_data_dir,
])
.status()
.expect("failed to execute process");
assert!(status.success());
});
// set CATDOG_DATA_DIR environment variable
unsafe {
std::env::set_var("CATDOG_DATA_DIR", format!("{}", cat_and_dog_data_dir));
}
}
fn setup_py_env() {
@@ -225,6 +246,20 @@ mod py_tests {
anvil_child.kill().unwrap();
}
#[test]
fn cat_and_dog_notebook_() {
crate::py_tests::init_binary();
let mut anvil_child = crate::py_tests::start_anvil(false);
crate::py_tests::download_catdog_data();
let test_dir: TempDir = TempDir::new("cat_and_dog").unwrap();
let path = test_dir.path().to_str().unwrap();
crate::py_tests::mv_test_(path, "cat_and_dog.ipynb");
run_notebook(path, "cat_and_dog.ipynb");
test_dir.close().unwrap();
anvil_child.kill().unwrap();
}
#[test]
fn reusable_verifier_notebook_() {
crate::py_tests::init_binary();

View File

@@ -48,7 +48,6 @@ def test_py_run_args():
run_args = ezkl.PyRunArgs()
run_args.input_visibility = "hashed"
run_args.output_visibility = "hashed"
def test_poseidon_hash():
@@ -873,7 +872,8 @@ def get_examples():
'linear_regression',
"mnist_gan",
"smallworm",
"fr_age"
"fr_age",
"1d_conv",
]
examples = []
for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):
@@ -900,7 +900,12 @@ async def test_all_examples(model_file, input_file):
proof_path = os.path.join(folder_path, 'proof.json')
print("Testing example: ", model_file)
run_args = ezkl.PyRunArgs()
run_args.variables = [("batch_size", 1), ("sequence_length", 100), ("<Sym1>", 1)]
run_args.logrows = 22
res = ezkl.gen_settings(model_file, settings_path, py_run_args=run_args)
assert res
res = await ezkl.calibrate_settings(

View File

@@ -337,7 +337,7 @@ mod wasm32 {
// Run compiled circuit validation on onnx network (should fail)
let circuit = compiledCircuitValidation(wasm_bindgen::Clamped(NETWORK.to_vec()));
assert!(circuit.is_err());
// Run compiled circuit validation on compiled network (should pass)
let circuit = compiledCircuitValidation(wasm_bindgen::Clamped(NETWORK_COMPILED.to_vec()));
assert!(circuit.is_ok());
// Run input validation on witness (should fail)