Compare commits

..

1 Commits

Author SHA1 Message Date
github-actions[bot]
76764073b3 ci: update version string in docs 2025-02-05 22:27:00 +00:00
51 changed files with 969 additions and 2441 deletions

View File

@@ -12,10 +12,10 @@ jobs:
contents: read
runs-on: self-hosted
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
@@ -29,10 +29,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
@@ -46,10 +46,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
@@ -63,10 +63,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
@@ -80,10 +80,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
@@ -97,10 +97,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
@@ -114,10 +114,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
@@ -131,10 +131,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
@@ -148,10 +148,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
@@ -165,10 +165,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
@@ -182,10 +182,10 @@ jobs:
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true

View File

@@ -24,15 +24,15 @@ jobs:
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
- uses: jetli/wasm-pack-action@v0.4.0
with:
# Pin to version 0.12.1
version: 'v0.12.1'
@@ -40,7 +40,7 @@ jobs:
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- name: Install binaryen
run: |
set -e
@@ -176,7 +176,7 @@ jobs:
curl -s "https://raw.githubusercontent.com/zkonduit/ezkljs-engine/main/README.md" > ./pkg/README.md
- name: Set up Node.js
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
@@ -201,7 +201,7 @@ jobs:
RELEASE_TAG: ${{ github.ref_name }}
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Update version in package.json
@@ -230,13 +230,13 @@ jobs:
NR==30{$0=" specifier: \"" tag "\""}
NR==31{$0=" version: \"" tag "\""}
NR==400{$0=" /@ezkljs/engine@" tag ":"}
NR==401{$0=" resolution: {integrity: \"" integrity "\"}"} 1' in-browser-evm-verifier/pnpm-lock.yaml > temp.yaml && mv temp.yaml in-browser-evm-verifier/pnpm-lock.yaml
NR==401{$0=" resolution: {integrity: \"" integrity "\"}"} 1' in-browser-evm-verifier/pnpm-lock.yaml > temp.yaml && mv temp.yaml in-browser-evm-verifier/pnpm-lock.yaml
- name: Use pnpm 8
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
uses: pnpm/action-setup@v2
with:
version: 8
- name: Set up Node.js
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
@@ -247,4 +247,4 @@ jobs:
pnpm run build
pnpm publish --no-git-checks
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

View File

@@ -10,12 +10,12 @@ jobs:
contents: read
runs-on: kaiju
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: nanoGPT Mock

View File

@@ -28,36 +28,34 @@ jobs:
env:
RELEASE_TAG: ${{ github.ref_name }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
- uses: actions/setup-python@v4
with:
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag and rename ezkl to ezkl-gpu
- name: Set pyproject.toml version to match github tag
shell: bash
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig > pyproject.toml.tmp
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.tmp > pyproject.toml
sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig >pyproject.toml
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
components: rustfmt, clippy
- name: Set Cargo.toml version to match github tag and rename ezkl to ezkl-gpu
- name: Set Cargo.toml version to match github tag
shell: bash
# the ezkl substitution here looks for the first instance of name = "ezkl" and changes it to "ezkl-gpu"
run: |
mv Cargo.toml Cargo.toml.orig
sed "0,/name = \"ezkl\"/s/name = \"ezkl\"/name = \"ezkl-gpu\"/" Cargo.toml.orig > Cargo.toml.tmp
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.tmp > Cargo.toml
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig > Cargo.lock
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- name: Install required libraries
shell: bash
@@ -65,7 +63,7 @@ jobs:
sudo apt-get update && sudo apt-get install -y openssl pkg-config libssl-dev
- name: Build wheels
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.target }}
manylinux: auto
@@ -78,7 +76,7 @@ jobs:
pip install ezkl-gpu --no-index --find-links dist --force-reinstall
- name: Upload wheels
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
uses: actions/upload-artifact@v4
with:
name: wheels
path: dist
@@ -94,7 +92,7 @@ jobs:
# needs: [ macos, windows, linux, linux-cross, musllinux, musllinux-cross ]
needs: [linux]
steps:
- uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 #v4.1.8
- uses: actions/download-artifact@v3
with:
name: wheels
- name: List Files
@@ -106,14 +104,14 @@ jobs:
# publishes to PyPI
- name: Publish package distributions to PyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
uses: pypa/gh-action-pypi-publish@unstable/v1
with:
packages-dir: ./wheels
packages-dir: ./
# publishes to TestPyPI
- name: Publish package distribution to TestPyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
uses: pypa/gh-action-pypi-publish@unstable/v1
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ./wheels
packages-dir: ./

View File

@@ -26,10 +26,10 @@ jobs:
env:
RELEASE_TAG: ${{ github.ref_name }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
- uses: actions/setup-python@v4
with:
python-version: 3.12
architecture: x64
@@ -48,21 +48,21 @@ jobs:
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Build wheels
if: matrix.target == 'universal2-apple-darwin'
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.target }}
args: --release --out dist --features python-bindings
- name: Build wheels
if: matrix.target == 'x86_64'
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.target }}
args: --release --out dist --features python-bindings
@@ -73,9 +73,9 @@ jobs:
python -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
uses: actions/upload-artifact@v4
with:
name: dist-macos-${{ matrix.target }}
name: wheels
path: dist
windows:
@@ -87,10 +87,10 @@ jobs:
matrix:
target: [x64, x86]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
- uses: actions/setup-python@v4
with:
python-version: 3.12
architecture: ${{ matrix.target }}
@@ -113,14 +113,14 @@ jobs:
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Build wheels
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.target }}
args: --release --out dist --features python-bindings
@@ -130,9 +130,9 @@ jobs:
python -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0 #v4.6.0
uses: actions/upload-artifact@v4
with:
name: dist-windows-${{ matrix.target }}
name: wheels
path: dist
linux:
@@ -144,10 +144,10 @@ jobs:
matrix:
target: [x86_64]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
- uses: actions/setup-python@v4
with:
python-version: 3.12
architecture: x64
@@ -176,7 +176,7 @@ jobs:
sudo apt-get update && sudo apt-get install -y openssl pkg-config libssl-dev
- name: Build wheels
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.target }}
manylinux: auto
@@ -203,9 +203,9 @@ jobs:
python -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
uses: actions/upload-artifact@v4
with:
name: dist-linux-${{ matrix.target }}
name: wheels
path: dist
musllinux:
@@ -218,10 +218,10 @@ jobs:
target:
- x86_64-unknown-linux-musl
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
- uses: actions/setup-python@v4
with:
python-version: 3.12
architecture: x64
@@ -250,7 +250,7 @@ jobs:
sudo apt-get update && sudo apt-get install -y pkg-config libssl-dev
- name: Build wheels
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.target }}
manylinux: musllinux_1_2
@@ -271,9 +271,9 @@ jobs:
python3 -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
uses: actions/upload-artifact@v4
with:
name: dist-musllinux-${{ matrix.target }}
name: wheels
path: dist
musllinux-cross:
@@ -287,10 +287,10 @@ jobs:
- target: aarch64-unknown-linux-musl
arch: aarch64
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
- uses: actions/setup-python@v4
with:
python-version: 3.12
@@ -313,13 +313,13 @@ jobs:
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- name: Build wheels
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.platform.target }}
manylinux: musllinux_1_2
args: --release --out dist --features python-bindings
- uses: uraimo/run-on-arch-action@5397f9e30a9b62422f302092631c99ae1effcd9e #v2.8.1
- uses: uraimo/run-on-arch-action@v2.8.1
name: Install built wheel
with:
arch: ${{ matrix.platform.arch }}
@@ -334,9 +334,9 @@ jobs:
python3 -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
uses: actions/upload-artifact@v4
with:
name: dist-musllinux-${{ matrix.platform.target }}
name: wheels
path: dist
pypi-publish:
@@ -347,26 +347,24 @@ jobs:
if: "startsWith(github.ref, 'refs/tags/')"
needs: [macos, windows, linux, musllinux, musllinux-cross]
steps:
- uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 #v4.1.8
- uses: actions/download-artifact@v3
with:
pattern: dist-*
merge-multiple: true
path: wheels
name: wheels
- name: List Files
run: ls -R
# # publishes to TestPyPI
# - name: Publish package distribution to TestPyPI
# uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
# uses: pypa/gh-action-pypi-publish@unstable/v1
# with:
# repository-url: https://test.pypi.org/legacy/
# packages-dir: ./
# publishes to PyPI
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
uses: pypa/gh-action-pypi-publish@unstable/v1
with:
packages-dir: ./wheels
packages-dir: ./
doc-publish:
@@ -377,7 +375,7 @@ jobs:
needs: pypi-publish
steps:
- uses: actions/checkout@v4
with:
with:
persist-credentials: false
- name: Trigger RTDs build
uses: dfm/rtds-action@v1

View File

@@ -30,7 +30,7 @@ jobs:
- name: Create Github Release
id: create-release
uses: softprops/action-gh-release@c95fe1489396fe8a9eb87c0abf8aa5b2ef267fda #v2.2.1
uses: softprops/action-gh-release@v1
with:
token: ${{ secrets.RELEASE_TOKEN }}
tag_name: ${{ env.EZKL_VERSION }}
@@ -49,14 +49,14 @@ jobs:
RUST_BACKTRACE: 1
PCRE2_SYS_STATIC: 1
steps:
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
uses: actions/checkout@v4
with:
persist-credentials: false
@@ -90,7 +90,7 @@ jobs:
echo "ASSET=build-artifacts/ezkl-linux-gpu.tar.gz" >> $GITHUB_ENV
- name: Upload release archive
uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 #v1.0.2
uses: actions/upload-release-asset@v1.0.2
env:
GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }}
with:
@@ -119,33 +119,33 @@ jobs:
include:
- build: windows-msvc
os: windows-latest
rust: nightly-2025-02-17
rust: nightly-2024-07-18
target: x86_64-pc-windows-msvc
- build: macos
os: macos-13
rust: nightly-2025-02-17
rust: nightly-2024-07-18
target: x86_64-apple-darwin
- build: macos-aarch64
os: macos-13
rust: nightly-2025-02-17
rust: nightly-2024-07-18
target: aarch64-apple-darwin
- build: linux-musl
os: ubuntu-22.04
rust: nightly-2025-02-17
rust: nightly-2024-07-18
target: x86_64-unknown-linux-musl
- build: linux-gnu
os: ubuntu-22.04
rust: nightly-2025-02-17
rust: nightly-2024-07-18
target: x86_64-unknown-linux-gnu
- build: linux-aarch64
os: ubuntu-22.04
rust: nightly-2025-02-17
rust: nightly-2024-07-18
target: aarch64-unknown-linux-gnu
steps:
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Get release version from tag
@@ -170,7 +170,7 @@ jobs:
fi
- name: Install Rust
uses: dtolnay/rust-toolchain@4f94fbe7e03939b0e674bcc9ca609a16088f63ff #nightly branch, TODO: update when required
uses: dtolnay/rust-toolchain@nightly
with:
target: ${{ matrix.target }}
@@ -196,7 +196,7 @@ jobs:
echo "target flag is: ${{ env.TARGET_FLAGS }}"
echo "target dir is: ${{ env.TARGET_DIR }}"
- name: Build release binary (no asm or metal)
- name: Build release binary (no asm or metal)
if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
@@ -233,7 +233,7 @@ jobs:
echo "ASSET=build-artifacts/ezkl-win.zip" >> $GITHUB_ENV
- name: Upload release archive
uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 #v1.0.2
uses: actions/upload-release-asset@v1.0.2
env:
GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }}
with:

View File

@@ -25,12 +25,12 @@ jobs:
contents: read
runs-on: large-self-hosted
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -45,12 +45,12 @@ jobs:
contents: read
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Build
@@ -61,12 +61,12 @@ jobs:
contents: read
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Docs
@@ -77,15 +77,15 @@ jobs:
contents: read
runs-on: ubuntu-latest-32-cores
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
@@ -102,21 +102,25 @@ jobs:
# env:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
# - uses: actions/checkout@v4
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2025-02-17
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
# - uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
# - uses: mwilliamson/setup-wasmtime-action@v2
# with:
# wasmtime-version: "3.0.1"
# - name: Install wasm32-wasi
# run: rustup target add wasm32-wasi
# - name: Install cargo-wasi
# run: cargo install cargo-wasi
# # - name: Matmul overflow (wasi)
# # run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
# # - name: Conv overflow (wasi)
@@ -135,21 +139,25 @@ jobs:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
- uses: mwilliamson/setup-wasmtime-action@v2
with:
wasmtime-version: "3.0.1"
- name: Install wasm32-wasi
run: rustup target add wasm32-wasi
- name: Install cargo-wasi
run: cargo install cargo-wasi
# - name: Matmul overflow (wasi)
# run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
# - name: Conv overflow (wasi)
@@ -168,21 +176,25 @@ jobs:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
- uses: mwilliamson/setup-wasmtime-action@v2
with:
wasmtime-version: "3.0.1"
- name: Install wasm32-wasi
run: rustup target add wasm32-wasi
- name: Install cargo-wasi
run: cargo install cargo-wasi
# - name: Matmul overflow (wasi)
# run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
# - name: Conv overflow (wasi)
@@ -201,15 +213,15 @@ jobs:
contents: read
runs-on: ubuntu-latest-16-cores
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
@@ -221,25 +233,25 @@ jobs:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
- uses: jetli/wasm-pack-action@v0.4.0
with:
# Pin to version 0.12.1
version: "v0.12.1"
- uses: nanasess/setup-chromedriver@e93e57b843c0c92788f22483f1a31af8ee48db25 #v2.3.0
- uses: nanasess/setup-chromedriver@v2
# with:
# chromedriver-version: "115.0.5790.102"
- name: Install wasm32-unknown-unknown
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- name: Run wasm verifier tests
# on mac:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
@@ -250,22 +262,20 @@ jobs:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
# - name: The Worm Mock
# run: cargo nextest run --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
- name: Large 1D Conv Mock
run: cargo nextest run --verbose tests::large_mock_::large_tests_7_expects -- --include-ignored
- name: MNIST Gan Mock
run: cargo nextest run --verbose tests::large_mock_::large_tests_4_expects -- --include-ignored
- name: NanoGPT Mock
@@ -282,6 +292,8 @@ jobs:
run: cargo nextest run --verbose tests::mock_fixed_params_ --test-threads 32
- name: public outputs and bounded lookup log
run: cargo nextest run --verbose tests::mock_bounded_lookup_log --test-threads 32
- name: public outputs and tolerance > 0
run: cargo nextest run --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10
run: cargo nextest run --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
- name: kzg inputs
@@ -317,32 +329,32 @@ jobs:
runs-on: non-gpu
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v3
with:
persist-credentials: false
- name: Use pnpm 8
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
uses: pnpm/action-setup@v2
with:
version: 8
- name: Use Node.js 18.12.1
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
cache: "pnpm"
- name: "Add rust-src"
run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- name: Install dependencies for js tests and in-browser-evm-verifier package
run: |
pnpm install --frozen-lockfile
@@ -402,28 +414,28 @@ jobs:
# runs-on: macos-13
# # needs: [build, library-tests, docs]
# steps:
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
# - uses: actions/checkout@v4
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2025-02-17
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
# - uses: jetli/wasm-pack-action@v0.4.0
# with:
# # Pin to version 0.12.1
# version: 'v0.12.1'
# - name: Add rust-src
# run: rustup component add rust-src --toolchain nightly-2025-02-17
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
# run: rustup component add rust-src --toolchain nightly-2024-07-18
# - uses: actions/checkout@v3
# with:
# persist-credentials: false
# - name: Use pnpm 8
# uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
# uses: pnpm/action-setup@v2
# with:
# version: 8
# - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
@@ -436,15 +448,15 @@ jobs:
runs-on: non-gpu
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
- uses: jetli/wasm-pack-action@v0.4.0
with:
# Pin to version 0.12.1
version: "v0.12.1"
@@ -452,16 +464,16 @@ jobs:
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
with:
persist-credentials: false
- name: Use pnpm 8
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
uses: pnpm/action-setup@v2
with:
version: 8
- name: Use Node.js 18.12.1
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
cache: "pnpm"
@@ -471,7 +483,7 @@ jobs:
env:
CI: false
NODE_ENV: development
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
@@ -517,18 +529,18 @@ jobs:
# env:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
# - uses: actions/checkout@v4
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2025-02-17
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - name: Add rust-src
# run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
# - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
# run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
# - uses: actions/checkout@v3
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
@@ -555,15 +567,15 @@ jobs:
runs-on: self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: dtolnay/rust-toolchain@4f94fbe7e03939b0e674bcc9ca609a16088f63ff #nightly branch, TODO: update when required
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
@@ -575,15 +587,15 @@ jobs:
# env:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
# - uses: actions/checkout@v4
# with:
# persist-credentials: false
# - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2025-02-17
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
@@ -596,15 +608,15 @@ jobs:
runs-on: large-self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
@@ -617,15 +629,15 @@ jobs:
runs-on: large-self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
@@ -642,15 +654,15 @@ jobs:
runs-on: ubuntu-latest-32-cores
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
@@ -663,15 +675,15 @@ jobs:
runs-on: non-gpu
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
- uses: actions/setup-python@v4
with:
python-version: "3.12"
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Install cmake
@@ -693,18 +705,18 @@ jobs:
runs-on: non-gpu
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
- uses: actions/setup-python@v4
with:
python-version: "3.12"
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
@@ -744,18 +756,18 @@ jobs:
# Maps tcp port 5432 on service container to the host
- 5432:5432
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
- uses: actions/setup-python@v4
with:
python-version: "3.11"
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
@@ -769,8 +781,6 @@ jobs:
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --profile=test-runs
- name: Cat and Dog notebook
run: source .env/bin/activate; cargo nextest run py_tests::tests::cat_and_dog_notebook_
- name: All notebooks
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
- name: Voice tutorial
@@ -802,20 +812,20 @@ jobs:
contents: read
runs-on: macos-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Run ios tests
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2025-02-17-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
swift-package-tests:
permissions:
@@ -824,12 +834,12 @@ jobs:
needs: [ios-integration-tests]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- name: Build EzklCoreBindings

View File

@@ -12,12 +12,12 @@ jobs:
contents: read
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2025-02-17
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
@@ -30,3 +30,4 @@ jobs:
run: zizmor .

View File

@@ -19,8 +19,8 @@ jobs:
steps:
- name: Checkout EZKL
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
uses: actions/checkout@v3
with:
persist-credentials: false
- name: Extract TAG from github.ref_name
@@ -34,7 +34,7 @@ jobs:
echo "TAG=$NEW_TAG" >> $GITHUB_ENV
- name: Install Rust (nightly)
uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
override: true

View File

@@ -11,12 +11,12 @@ jobs:
contents: write
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Bump version and push tag
id: tag_version
uses: mathieudutour/github-tag-action@a22cf08638b34d5badda920f9daf6e72c477b07b #v6.2
uses: mathieudutour/github-tag-action@v6.2
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
@@ -46,7 +46,7 @@ jobs:
git tag $RELEASE_TAG
- name: Push changes
uses: ad-m/github-push-action@77c5b412c50b723d2a4fbc6d71fb5723bcd439aa #master
uses: ad-m/github-push-action@master
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
with:

View File

@@ -3,7 +3,7 @@ cargo-features = ["profile-rustflags"]
[package]
name = "ezkl"
version = "0.0.0"
edition = "2024"
edition = "2021"
default-run = "ezkl"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

View File

@@ -73,8 +73,6 @@ impl Circuit<Fr> for MyCircuit {
padding: vec![(0, 0)],
stride: vec![1; 2],
group: 1,
data_format: DataFormat::NCHW,
kernel_format: KernelFormat::OIHW,
}),
)
.unwrap();

View File

@@ -69,7 +69,6 @@ impl Circuit<Fr> for MyCircuit {
stride: vec![1, 1],
kernel_shape: vec![2, 2],
normalized: false,
data_format: DataFormat::NCHW,
}),
)
.unwrap();

View File

@@ -23,6 +23,8 @@ use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
const L: usize = 10;
#[derive(Clone, Debug)]
struct MyCircuit {
image: ValTensor<Fr>,
@@ -38,7 +40,7 @@ impl Circuit<Fr> for MyCircuit {
}
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::configure(cs, ())
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, 10>::configure(cs, ())
}
fn synthesize(
@@ -46,7 +48,7 @@ impl Circuit<Fr> for MyCircuit {
config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE> =
let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L> =
PoseidonChip::new(config);
chip.layout(&mut layouter, &[self.image.clone()], 0, &mut HashMap::new())?;
Ok(())
@@ -57,7 +59,7 @@ fn runposeidon(c: &mut Criterion) {
let mut group = c.benchmark_group("poseidon");
for size in [64, 784, 2352, 12288].iter() {
let k = (PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::num_rows(*size)
let k = (PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L>::num_rows(*size)
as f32)
.log2()
.ceil() as u32;
@@ -65,7 +67,7 @@ fn runposeidon(c: &mut Criterion) {
let message = (0..*size).map(|_| Fr::random(OsRng)).collect::<Vec<_>>();
let _output =
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.to_vec())
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L>::run(message.to_vec())
.unwrap();
let mut image = Tensor::from(message.into_iter().map(Value::known));

View File

@@ -22,7 +22,7 @@ fn generate_test_data(size: usize, zero_probability: f64) -> Vec<ValType> {
let mut rng = rand::thread_rng();
(0..size)
.map(|_i| {
if rng.r#gen::<f64>() < zero_probability {
if rng.gen::<f64>() < zero_probability {
ValType::Constant(F::ZERO)
} else {
ValType::Constant(F::ONE) // Or some other non-zero value

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '0.0.0'
release = '18.1.14'
version = release

View File

@@ -32,7 +32,6 @@ use mnist::*;
use rand::rngs::OsRng;
use std::marker::PhantomData;
mod params;
const K: usize = 20;
@@ -209,8 +208,6 @@ where
padding: vec![(PADDING, PADDING); 2],
stride: vec![STRIDE; 2],
group: 1,
data_format: DataFormat::NCHW,
kernel_format: KernelFormat::OIHW,
};
let x = config
.layer_config

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +0,0 @@
# download tess data
# check if first argument has been set
if [ ! -z "$1" ]; then
DATA_DIR=$1
else
DATA_DIR=data
fi
echo "Downloading data to $DATA_DIR"
if [ ! -d "$DATA_DIR/CATDOG" ]; then
kaggle datasets download tongpython/cat-and-dog -p $DATA_DIR/CATDOG --unzip
fi

View File

@@ -1,106 +0,0 @@
{
"input_data": [
[
8761,
7654,
8501,
2404,
6929,
8858,
5946,
3673,
4131,
3854,
8137,
8239,
9038,
6299,
1118,
9737,
208,
7954,
3691,
610,
3468,
3314,
8658,
8366,
2850,
477,
6114,
232,
4601,
7420,
5713,
2936,
6061,
2870,
8421,
177,
7107,
7382,
6115,
5487,
8502,
2559,
1875,
129,
8533,
8201,
8414,
4775,
9817,
3127,
8761,
7654,
8501,
2404,
6929,
8858,
5946,
3673,
4131,
3854,
8137,
8239,
9038,
6299,
1118,
9737,
208,
7954,
3691,
610,
3468,
3314,
8658,
8366,
2850,
477,
6114,
232,
4601,
7420,
5713,
2936,
6061,
2870,
8421,
177,
7107,
7382,
6115,
5487,
8502,
2559,
1875,
129,
8533,
8201,
8414,
4775,
9817,
3127
]
]
}

Binary file not shown.

View File

@@ -1 +0,0 @@
{"run_args":{"input_scale":7,"param_scale":7,"scale_rebase_multiplier":1,"lookup_range":[-32768,32768],"logrows":17,"num_inner_cols":2,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private","rebase_frac_zero_constants":false,"check_mode":"UNSAFE","commitment":"KZG","decomp_base":16384,"decomp_legs":2,"bounded_log_lookup":false,"ignore_range_check_inputs_outputs":false},"num_rows":54,"total_assignments":109,"total_const_size":4,"total_dynamic_col_size":0,"max_dynamic_input_len":0,"num_dynamic_lookups":0,"num_shuffles":0,"total_shuffle_col_size":0,"model_instance_shapes":[[1,1]],"model_output_scales":[7],"model_input_scales":[7],"module_sizes":{"polycommit":[],"poseidon":[0,[0]]},"required_lookups":[],"required_range_checks":[[-1,1],[0,16383]],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null,"timestamp":1739396322131,"input_types":["F32"],"output_types":["F32"]}

View File

@@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2025-02-17"
channel = "nightly-2024-07-18"
components = ["rustfmt", "clippy"]

View File

@@ -28,8 +28,6 @@ use std::env;
#[tokio::main(flavor = "current_thread")]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub async fn main() {
use log::debug;
let args = Cli::parse();
if let Some(generator) = args.generator {
@@ -44,7 +42,7 @@ pub async fn main() {
} else {
info!("Running with CPU");
}
debug!(
info!(
"command: \n {}",
&command.as_json().to_colored_json_auto().unwrap()
);

View File

@@ -4,8 +4,8 @@ use crate::circuit::modules::poseidon::{
PoseidonChip,
};
use crate::circuit::modules::Module;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::circuit::{CheckMode, Tolerance};
use crate::commands::*;
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::graph::TestDataSource;
@@ -155,6 +155,9 @@ impl pyo3::ToPyObject for PyG1Affine {
#[derive(Clone)]
#[gen_stub_pyclass]
struct PyRunArgs {
#[pyo3(get, set)]
/// float: The tolerance for error on model outputs
pub tolerance: f32,
#[pyo3(get, set)]
/// int: The denominator in the fixed point representation used when quantizing inputs
pub input_scale: crate::Scale,
@@ -222,6 +225,7 @@ impl From<PyRunArgs> for RunArgs {
fn from(py_run_args: PyRunArgs) -> Self {
RunArgs {
bounded_log_lookup: py_run_args.bounded_log_lookup,
tolerance: Tolerance::from(py_run_args.tolerance),
input_scale: py_run_args.input_scale,
param_scale: py_run_args.param_scale,
num_inner_cols: py_run_args.num_inner_cols,
@@ -246,6 +250,7 @@ impl Into<PyRunArgs> for RunArgs {
fn into(self) -> PyRunArgs {
PyRunArgs {
bounded_log_lookup: self.bounded_log_lookup,
tolerance: self.tolerance.val,
input_scale: self.input_scale,
param_scale: self.param_scale,
num_inner_cols: self.num_inner_cols,
@@ -332,8 +337,6 @@ enum PyInputType {
Int,
///
TDim,
///
Unknown,
}
impl From<InputType> for PyInputType {
@@ -345,7 +348,6 @@ impl From<InputType> for PyInputType {
InputType::F64 => PyInputType::F64,
InputType::Int => PyInputType::Int,
InputType::TDim => PyInputType::TDim,
InputType::Unknown => PyInputType::Unknown,
}
}
}
@@ -359,7 +361,6 @@ impl From<PyInputType> for InputType {
PyInputType::F64 => InputType::F64,
PyInputType::Int => InputType::Int,
PyInputType::TDim => InputType::TDim,
PyInputType::Unknown => InputType::Unknown,
}
}
}
@@ -374,7 +375,6 @@ impl FromStr for PyInputType {
"f64" => Ok(PyInputType::F64),
"int" => Ok(PyInputType::Int),
"tdim" => Ok(PyInputType::TDim),
"unknown" => Ok(PyInputType::Unknown),
_ => Err("Invalid value for InputType".to_string()),
}
}
@@ -592,7 +592,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represented as strings
/// List of field elements represnted as strings
///
/// vk_path: str
/// Path to the verification key
@@ -651,7 +651,7 @@ fn kzg_commit(
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represented as strings
/// List of field elements represnted as strings
///
/// vk_path: str
/// Path to the verification key
@@ -1009,7 +1009,7 @@ fn gen_random_data(
/// bool
///
#[pyfunction(signature = (
data = String::from(DEFAULT_CALIBRATION_FILE),
data = PathBuf::from(DEFAULT_CALIBRATION_FILE),
model = PathBuf::from(DEFAULT_MODEL),
settings = PathBuf::from(DEFAULT_SETTINGS),
target = CalibrationTarget::default(), // default is "resources
@@ -1021,7 +1021,7 @@ fn gen_random_data(
#[gen_stub_pyfunction]
fn calibrate_settings(
py: Python,
data: String,
data: PathBuf,
model: PathBuf,
settings: PathBuf,
target: CalibrationTarget,
@@ -1076,7 +1076,7 @@ fn calibrate_settings(
/// Python object containing the witness values
///
#[pyfunction(signature = (
data=String::from(DEFAULT_DATA),
data=PathBuf::from(DEFAULT_DATA),
model=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
output=PathBuf::from(DEFAULT_WITNESS),
vk_path=None,
@@ -1085,7 +1085,7 @@ fn calibrate_settings(
#[gen_stub_pyfunction]
fn gen_witness(
py: Python,
data: String,
data: PathBuf,
model: PathBuf,
output: Option<PathBuf>,
vk_path: Option<PathBuf>,
@@ -1754,7 +1754,7 @@ fn create_evm_vka(
/// bool
///
#[pyfunction(signature = (
input_data=String::from(DEFAULT_DATA),
input_data=PathBuf::from(DEFAULT_DATA),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE_DA),
abi_path=PathBuf::from(DEFAULT_VERIFIER_DA_ABI),
@@ -1763,7 +1763,7 @@ fn create_evm_vka(
#[gen_stub_pyfunction]
fn create_evm_data_attestation(
py: Python,
input_data: String,
input_data: PathBuf,
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
@@ -1824,7 +1824,7 @@ fn create_evm_data_attestation(
#[gen_stub_pyfunction]
fn setup_test_evm_witness(
py: Python,
data_path: String,
data_path: PathBuf,
compiled_circuit_path: PathBuf,
test_data: PathBuf,
input_source: PyTestDataSource,
@@ -1902,7 +1902,7 @@ fn deploy_evm(
fn deploy_da_evm(
py: Python,
addr_path: PathBuf,
input_data: String,
input_data: PathBuf,
settings_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
@@ -1945,7 +1945,7 @@ fn deploy_da_evm(
/// does the verifier use data attestation ?
///
/// addr_vk: str
/// The address of the separate VK contract (if the verifier key is rendered as a separate contract)
/// The addess of the separate VK contract (if the verifier key is rendered as a separate contract)
/// Returns
/// -------
/// bool

View File

@@ -1,7 +1,7 @@
/*
An easy-to-use implementation of the Poseidon Hash in the form of a Halo2 Chip. While the Poseidon Hash function
is already implemented in halo2_gadgets, there is no wrapper chip that makes it easy to use in other circuits.
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/zk_prover/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
*/
use std::collections::HashMap;

View File

@@ -1,7 +1,7 @@
/*
An easy-to-use implementation of the Poseidon Hash in the form of a Halo2 Chip. While the Poseidon Hash function
is already implemented in halo2_gadgets, there is no wrapper chip that makes it easy to use in other circuits.
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/zk_prover/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
*/
pub mod poseidon_params;

View File

@@ -20,6 +20,7 @@ use crate::{
circuit::{
ops::base::BaseOp,
table::{Range, RangeCheck, Table},
utils,
},
tensor::{Tensor, TensorType, ValTensor, VarTensor},
};
@@ -84,6 +85,55 @@ impl CheckMode {
}
}
#[allow(missing_docs)]
/// An enum representing the tolerance we can accept for the accumulated arguments, either absolute or percentage
#[derive(Clone, Default, Debug, PartialEq, PartialOrd, Serialize, Deserialize, Copy)]
pub struct Tolerance {
pub val: f32,
pub scale: utils::F32,
}
impl std::fmt::Display for Tolerance {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:.2}", self.val)
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Tolerance {
/// Convert the struct to a subcommand string
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl FromStr for Tolerance {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if let Ok(val) = s.parse::<f32>() {
Ok(Tolerance {
val,
scale: utils::F32(1.0),
})
} else {
Err(
"Invalid tolerance value provided. It should expressed as a percentage (f32)."
.to_string(),
)
}
}
}
impl From<f32> for Tolerance {
fn from(value: f32) -> Self {
Tolerance {
val: value,
scale: utils::F32(1.0),
}
}
}
#[cfg(feature = "python-bindings")]
/// Converts CheckMode into a PyObject (Required for CheckMode to be compatible with Python)
impl IntoPy<PyObject> for CheckMode {
@@ -108,6 +158,29 @@ impl<'source> FromPyObject<'source> for CheckMode {
}
}
#[cfg(feature = "python-bindings")]
/// Converts Tolerance into a PyObject (Required for Tolerance to be compatible with Python)
impl IntoPy<PyObject> for Tolerance {
fn into_py(self, py: Python) -> PyObject {
(self.val, self.scale.0).to_object(py)
}
}
#[cfg(feature = "python-bindings")]
/// Obtains Tolerance from PyObject (Required for Tolerance to be compatible with Python)
impl<'source> FromPyObject<'source> for Tolerance {
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
if let Ok((val, scale)) = <(f32, f32)>::extract_bound(ob) {
Ok(Tolerance {
val,
scale: utils::F32(scale),
})
} else {
Err(PyValueError::new_err("Invalid tolerance value provided. "))
}
}
}
/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct DynamicLookups {

View File

@@ -1,9 +1,9 @@
use super::*;
use crate::{
circuit::{layouts, utils},
circuit::{layouts, utils, Tolerance},
fieldutils::{integer_rep_to_felt, IntegerRep},
graph::multiplier_to_scale,
tensor::{self, DataFormat, Tensor, TensorType, ValTensor},
tensor::{self, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
use serde::{Deserialize, Serialize};
@@ -57,13 +57,11 @@ pub enum HybridOp {
stride: Vec<usize>,
kernel_shape: Vec<usize>,
normalized: bool,
data_format: DataFormat,
},
MaxPool {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
pool_dims: Vec<usize>,
data_format: DataFormat,
},
ReduceMin {
axes: Vec<usize>,
@@ -79,6 +77,7 @@ pub enum HybridOp {
axes: Vec<usize>,
},
Output {
tol: Tolerance,
decomp: bool,
},
Greater,
@@ -155,10 +154,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
kernel_shape,
normalized, data_format
normalized,
} => format!(
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={}, data_format={:?})",
padding, stride, kernel_shape, normalized, data_format
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={})",
padding, stride, kernel_shape, normalized
),
HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
@@ -166,10 +165,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
pool_dims,
data_format,
} => format!(
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?}, data_format={:?})",
padding, stride, pool_dims, data_format
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
padding, stride, pool_dims
),
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim),
@@ -183,8 +181,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
input_scale, output_scale, axes
)
}
HybridOp::Output { decomp } => {
format!("OUTPUT (decomp={})", decomp)
HybridOp::Output { tol, decomp } => {
format!("OUTPUT (tol={:?}, decomp={})", tol, decomp)
}
HybridOp::Greater => "GREATER".to_string(),
HybridOp::GreaterEqual => "GREATEREQUAL".to_string(),
@@ -241,7 +239,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
stride,
kernel_shape,
normalized,
data_format,
} => layouts::sumpool(
config,
region,
@@ -250,7 +247,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
stride,
kernel_shape,
*normalized,
*data_format,
)?,
HybridOp::Recip {
input_scale,
@@ -291,7 +287,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
pool_dims,
data_format,
} => layouts::max_pool(
config,
region,
@@ -299,7 +294,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
pool_dims,
*data_format,
)?,
HybridOp::ReduceMax { axes } => {
layouts::max_axes(config, region, values[..].try_into()?, axes)?
@@ -325,9 +319,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
*output_scale,
axes,
)?,
HybridOp::Output { decomp } => {
layouts::output(config, region, values[..].try_into()?, *decomp)?
}
HybridOp::Output { tol, decomp } => layouts::output(
config,
region,
values[..].try_into()?,
tol.scale,
tol.val,
*decomp,
)?,
HybridOp::Greater => layouts::greater(config, region, values[..].try_into()?)?,
HybridOp::GreaterEqual => {
layouts::greater_equal(config, region, values[..].try_into()?)?

View File

@@ -18,11 +18,11 @@ use self::tensor::{create_constant_tensor, create_zero_tensor};
use super::{chip::BaseConfig, region::RegionCtx};
use crate::{
circuit::{ops::base::BaseOp, utils},
fieldutils::{IntegerRep, felt_to_integer_rep, integer_rep_to_felt},
tensor::{DataFormat, KernelFormat},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
tensor::{
Tensor, TensorError, ValType, create_unit_tensor, get_broadcasted_shape,
create_unit_tensor, get_broadcasted_shape,
ops::{accumulated, add, mult, sub},
Tensor, TensorError, ValType,
},
};
@@ -156,6 +156,25 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
claimed_output.reshape(input_dims)?;
// implicitly check if the prover provided output is within range
let claimed_output = identity(config, region, &[claimed_output], true)?;
// check if x is too large only if the decomp would support overflow in the previous op
if F::from_u128(IntegerRep::MAX as u128)
< F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
{
// here we decompose and extract the sign of the input
let sign = sign(config, region, &[claimed_output.clone()])?;
let abs_value = pairwise(
config,
region,
&[claimed_output.clone(), sign],
BaseOp::Mult,
)?;
let max_val = create_constant_tensor(integer_rep_to_felt(IntegerRep::MAX), 1);
let less_than_max = less(config, region, &[abs_value.clone(), max_val])?;
// assert the result is 1
let comparison_unit = create_constant_tensor(F::ONE, less_than_max.len());
enforce_equality(config, region, &[abs_value, comparison_unit])?;
}
let product = pairwise(
config,
@@ -229,6 +248,32 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&[equal_zero_mask.clone(), equal_inverse_mask],
)?;
let masked_output = pairwise(
config,
region,
&[claimed_output.clone(), not_equal_zero_mask.clone()],
BaseOp::Mult,
)?;
// check if x is too large only if the decomp would support overflow in the previous op
if F::from_u128(IntegerRep::MAX as u128)
< F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
{
// here we decompose and extract the sign of the input
let sign = sign(config, region, &[masked_output.clone()])?;
let abs_value = pairwise(
config,
region,
&[claimed_output.clone(), sign],
BaseOp::Mult,
)?;
let max_val = create_constant_tensor(integer_rep_to_felt(IntegerRep::MAX), 1);
let less_than_max = less(config, region, &[abs_value.clone(), max_val])?;
// assert the result is 1
let comparison_unit = create_constant_tensor(F::ONE, less_than_max.len());
enforce_equality(config, region, &[abs_value, comparison_unit])?;
}
let err_func = |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
x: &ValTensor<F>|
@@ -304,7 +349,7 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
.into()
};
claimed_output.reshape(input_dims)?;
// force the output to be positive or zero, also implicitly checks that the output is in range
// force the output to be positive or zero, also implicitly checks that the ouput is in range
let claimed_output = abs(config, region, &[claimed_output.clone()])?;
// rescaled input
let rescaled_input = pairwise(config, region, &[input.clone(), unit_scale], BaseOp::Mult)?;
@@ -1750,15 +1795,15 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd + std::
// assert than res is less than the product of the dims
if region.witness_gen() {
assert!(
res.int_evals()?
.iter()
.all(|x| *x < dims.iter().product::<usize>() as IntegerRep),
"res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})",
dims.iter().product::<usize>(),
index_val.show(),
index_dim_multiplier.show(),
res.show()
);
res.int_evals()?
.iter()
.all(|x| *x < dims.iter().product::<usize>() as IntegerRep),
"res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})",
dims.iter().product::<usize>(),
index_val.show(),
index_dim_multiplier.show(),
res.show()
);
}
}
@@ -1796,7 +1841,7 @@ pub(crate) fn get_missing_set_elements<
// get the difference between the two vectors
for eval in input_evals.iter() {
// delete first occurrence of that value
// delete first occurence of that value
if let Some(pos) = fullset_evals.iter().position(|x| x == eval) {
fullset_evals.remove(pos);
}
@@ -1824,7 +1869,7 @@ pub(crate) fn get_missing_set_elements<
region.increment(claimed_output.len());
// input and claimed output should be the shuffles of fullset
// concatenate input and claimed output
// concatentate input and claimed output
let input_and_claimed_output = input.concat(claimed_output.clone())?;
// assert that this is a permutation/shuffle
@@ -2210,12 +2255,12 @@ fn axes_wise_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
axes: &[usize],
// generic layout op
op: impl Fn(
&BaseConfig<F>,
&mut RegionCtx<F>,
&[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError>
+ Send
+ Sync,
&BaseConfig<F>,
&mut RegionCtx<F>,
&[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError>
+ Send
+ Sync,
) -> Result<ValTensor<F>, CircuitError> {
// calculate value of output
@@ -3180,7 +3225,6 @@ pub fn neg<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::circuit::region::RegionSettings;
/// use ezkl::circuit::BaseConfig;
/// use ezkl::tensor::ValTensor;
/// use ezkl::tensor::DataFormat;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
@@ -3190,12 +3234,12 @@ pub fn neg<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 1, 3, 3],
/// ).unwrap());
/// let pooled = sumpool::<Fp>(&dummy_config, &mut dummy_region, &[x.clone()], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], false, DataFormat::default()).unwrap();
/// let pooled = sumpool::<Fp>(&dummy_config, &mut dummy_region, &[x.clone()], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], false).unwrap();
/// let expected: Tensor<IntegerRep> = Tensor::<IntegerRep>::new(Some(&[11, 8, 8, 10]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled.int_evals().unwrap(), expected);
///
/// // This time with normalization
/// let pooled = sumpool::<Fp>(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], true, DataFormat::default()).unwrap();
/// let pooled = sumpool::<Fp>(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], true).unwrap();
/// let expected: Tensor<IntegerRep> = Tensor::<IntegerRep>::new(Some(&[3, 2, 2, 3]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled.int_evals().unwrap(), expected);
/// ```
@@ -3207,19 +3251,9 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
stride: &[usize],
kernel_shape: &[usize],
normalized: bool,
data_format: DataFormat,
) -> Result<ValTensor<F>, CircuitError> {
let mut image = values[0].clone();
data_format.to_canonical(&mut image)?;
if data_format.has_no_batch() {
let mut dims = image.dims().to_vec();
dims.insert(0, 1);
image.reshape(&dims)?;
}
let batch_size = image.dims()[0];
let image_channels = image.dims()[1];
let batch_size = values[0].dims()[0];
let image_channels = values[0].dims()[1];
let kernel_len = kernel_shape.iter().product();
@@ -3244,16 +3278,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
.map(|coord| {
let (b, i) = (coord[0], coord[1]);
let input = values[0].get_slice(&[b..b + 1, i..i + 1])?;
let output = conv(
config,
region,
&[input, kernel.clone()],
padding,
stride,
1,
DataFormat::default(),
KernelFormat::default(),
)?;
let output = conv(config, region, &[input, kernel.clone()], padding, stride, 1)?;
res.push(output);
Ok(())
})
@@ -3268,9 +3293,6 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
if normalized {
last_elem = div(config, region, &[last_elem], F::from(kernel_len as u64))?;
}
data_format.from_canonical(&mut last_elem)?;
Ok(last_elem)
}
@@ -3280,7 +3302,6 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::circuit::ops::layouts::max_pool;
/// use ezkl::tensor::DataFormat;
/// use halo2curves::bn256::Fr as Fp;
/// use ezkl::circuit::region::RegionCtx;
/// use ezkl::circuit::region::RegionSettings;
@@ -3295,7 +3316,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 1, 3, 3],
/// ).unwrap());
/// let pooled = max_pool::<Fp>(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2;2], DataFormat::default()).unwrap();
/// let pooled = max_pool::<Fp>(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2;2]).unwrap();
/// let expected: Tensor<IntegerRep> = Tensor::<IntegerRep>::new(Some(&[5, 4, 4, 6]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled.int_evals().unwrap(), expected);
///
@@ -3307,16 +3328,8 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
padding: &[(usize, usize)],
stride: &[usize],
pool_dims: &[usize],
data_format: DataFormat,
) -> Result<ValTensor<F>, CircuitError> {
let mut image = values[0].clone();
data_format.to_canonical(&mut image)?;
if data_format.has_no_batch() {
let mut dims = image.dims().to_vec();
dims.insert(0, 1);
image.reshape(&dims)?;
}
let image = values[0].clone();
let image_dims = image.dims();
@@ -3375,38 +3388,38 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region.apply_in_loop(&mut output, inner_loop_function)?;
let mut res: ValTensor<F> = output.into();
data_format.from_canonical(&mut res)?;
let res: ValTensor<F> = output.into();
Ok(res)
}
/// Performs a deconvolution on the given input tensor.
/// # Examples
/// ```
// // expected ouputs are taken from pytorch torch.nn.functional.conv_transpose2d
///
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::circuit::ops::layouts::deconv;
/// use ezkl::tensor::{val::ValTensor, DataFormat, KernelFormat};
/// use halo2curves::bn256::Fr as Fp;
/// use ezkl::circuit::region::RegionCtx;
/// use ezkl::circuit::region::RegionSettings;
/// use ezkl::circuit::BaseConfig;
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
///
/// // Original test case 1: Channel expansion
/// let c = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap());
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 32, 0, 0, 6, 0, 0, 4, 0, 0, 0, 0]), &[1, 2, 2, 3]).unwrap();
///
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![1;2], &vec![2;2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 32, 0, 32, 0, 6, 0, 12, 0, 4, 0, 8, 0, 4, 0, 8, 0, 0, 0, 3, 0, 0, 0, 2]), &[1, 2, 3, 4]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 2: Basic deconvolution
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
@@ -3415,11 +3428,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 14, 4, 2, 17, 21, 0, 1, 5]), &[1, 1, 3, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 3: With padding
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
@@ -3428,11 +3441,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[17]), &[1, 1, 1, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 4: With stride
///
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
@@ -3441,11 +3454,10 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[10, 4, 0, 3]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 5: Zero padding with stride
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
@@ -3454,11 +3466,10 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 2, 12, 4, 2, 10, 4, 20, 0, 0, 3, 1, 0, 0, 1, 5]), &[1, 1, 4, 4]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 6: Different kernel shape
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
@@ -3467,11 +3478,10 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[3, 2]),
/// &[1, 1, 2, 1],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0]), &[1, 1, 2, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 7: Different kernel shape without padding
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
@@ -3480,21 +3490,20 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[3, 2]),
/// &[1, 1, 2, 1],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 1, 4, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 8: Channel expansion with stride
///
/// let c = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap());
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
///
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 32, 0, 0, 6, 0, 0, 4, 0, 0, 0, 0]), &[1, 2, 2, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 9: With bias
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[3, 8, 0, 8, 4, 9, 8, 1, 8]),
/// &[1, 1, 3, 3],
@@ -3507,89 +3516,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[1]),
/// &[1],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[55, 58, 66, 69]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 1: NHWC format with HWIO kernel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 2, 2, 1], // NHWC format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 1, 5, 3]),
/// &[2, 2, 1, 1], // HWIO format
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NHWC, KernelFormat::HWIO).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[27]), &[1, 1, 1, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 2: 1D deconvolution with NCHW format
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3]),
/// &[1, 1, 3], // NCH format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 2]),
/// &[1, 1, 2], // OIH format
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0)], &vec![0], &vec![1], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 4, 7, 6]), &[1, 1, 4]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 3: 3D deconvolution with NCHW format
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4]),
/// &[1, 1, 2, 2, 1], // NCDHW format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 1]),
/// &[1, 1, 1, 1, 2], // OIDHW format
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 3], &vec![0; 3], &vec![1; 3], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 2, 2, 3, 3, 4, 4]), &[1, 1, 2, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 4: Multi-channel with NHWC format and OHWI kernel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1, 3, 2, 1, 4]), // 2 channels, 2x2 spatial
/// &[1, 2, 2, 2], // NHWC format [batch, height, width, channels]
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[1, 2, 2, 2], // OHWI format [out_channels, height, width, in_channels]
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NHWC, KernelFormat::OHWI).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[10, 24, 4, 41, 78, 27, 27, 66, 39]), &[1, 3, 3, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 5: CHW format (no batch dimension)
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 2, 2], // CHW format [channels, height, width]
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4]),
/// &[1, 1, 2, 2], // OIHW format [out_channels, in_channels, height, width]
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1, DataFormat::CHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 6, 6]), &[1, 1, 1, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Additional test case 6: HWC format with HWIO kernel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 3, 4, 1]),
/// &[2, 2, 1], // HWC format [height, width, channels]
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 1, 2]),
/// &[2, 2, 1, 1], // HWIO format [height, width, in_channels, out_channels]
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1, DataFormat::HWC, KernelFormat::HWIO).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 6, 6]), &[1, 1, 3, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
/// ```
///
pub fn deconv<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + std::marker::Send + std::marker::Sync,
>(
@@ -3600,14 +3531,9 @@ pub fn deconv<
output_padding: &[usize],
stride: &[usize],
num_groups: usize,
data_format: DataFormat,
kernel_format: KernelFormat,
) -> Result<ValTensor<F>, CircuitError> {
let has_bias = inputs.len() == 3;
let (mut working_image, mut working_kernel) = (inputs[0].clone(), inputs[1].clone());
data_format.to_canonical(&mut working_image)?;
kernel_format.to_canonical(&mut working_kernel)?;
let (image, kernel) = (&inputs[0], &inputs[1]);
if stride.iter().any(|&s| s == 0) {
return Err(TensorError::DimMismatch(
@@ -3617,23 +3543,26 @@ pub fn deconv<
}
let null_val = ValType::Constant(F::ZERO);
let mut expanded_image = working_image.clone();
// Expand image by inserting zeros according to stride
let mut expanded_image = image.clone();
for (i, s) in stride.iter().enumerate() {
expanded_image.intercalate_values(null_val.clone(), *s, 2 + i)?;
}
// Pad to kernel size for each spatial dimension
expanded_image.pad(
working_kernel.dims()[2..]
kernel.dims()[2..]
.iter()
.map(|d| (d - 1, d - 1))
.collect::<Vec<_>>(),
2,
)?;
)?; // pad to the kernel size
// flip order
let channel_coord = (0..kernel.dims()[0])
.cartesian_product(0..kernel.dims()[1])
.collect::<Vec<_>>();
// Calculate slice coordinates considering padding and output padding
let slice_coord = expanded_image
.dims()
.iter()
@@ -3649,34 +3578,26 @@ pub fn deconv<
let sliced_expanded_image = expanded_image.get_slice(&slice_coord)?;
// Generate channel coordinates for kernel transformation
let (in_ch_dim, out_ch_dim) =
KernelFormat::default().get_channel_dims(working_kernel.dims().len());
let channel_coord = (0..working_kernel.dims()[out_ch_dim])
.cartesian_product(0..working_kernel.dims()[in_ch_dim])
.collect::<Vec<_>>();
// Invert kernels for deconvolution
let mut inverted_kernels = vec![];
for (i, j) in channel_coord {
let channel = working_kernel.get_slice(&[i..i + 1, j..j + 1])?;
let channel = kernel.get_slice(&[i..i + 1, j..j + 1])?;
let mut channel = Tensor::from(channel.get_inner_tensor()?.clone().into_iter().rev());
channel.reshape(&working_kernel.dims()[2..])?;
channel.reshape(&kernel.dims()[2..])?;
inverted_kernels.push(channel);
}
let mut deconv_kernel =
Tensor::new(Some(&inverted_kernels), &[inverted_kernels.len()])?.combine()?;
deconv_kernel.reshape(working_kernel.dims())?;
deconv_kernel.reshape(kernel.dims())?;
// Handle tensorflow-style input/output channel ordering
if working_kernel.dims()[0] == sliced_expanded_image.dims()[1] {
// tensorflow formatting patch
if kernel.dims()[0] == sliced_expanded_image.dims()[1] {
let mut dims = deconv_kernel.dims().to_vec();
dims.swap(0, 1);
deconv_kernel.reshape(&dims)?;
}
// Prepare inputs for convolution
let conv_input = if has_bias {
vec![
sliced_expanded_image,
@@ -3687,32 +3608,28 @@ pub fn deconv<
vec![sliced_expanded_image, deconv_kernel.clone().into()]
};
let conv_dim = working_kernel.dims()[2..].len();
let conv_dim = kernel.dims()[2..].len();
// Perform convolution with canonical formats
let mut output = conv(
let output = conv(
config,
region,
&conv_input,
&vec![(0, 0); conv_dim],
&vec![1; conv_dim],
num_groups,
data_format.canonical(), // Use canonical format
kernel_format.canonical(), // Use canonical format
)?;
// Convert output back to requested format
data_format.from_canonical(&mut output)?;
Ok(output)
}
/// Applies convolution over a ND tensor of shape C x H x D1...DN (and adds a bias).
/// ```
/// // expected ouputs are taken from pytorch torch.nn.functional.conv2d
///
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::circuit::ops::layouts::conv;
/// use ezkl::tensor::{val::ValTensor, DataFormat, KernelFormat};
/// use ezkl::tensor::val::ValTensor;
/// use halo2curves::bn256::Fr as Fp;
/// use ezkl::circuit::region::RegionCtx;
/// use ezkl::circuit::region::RegionSettings;
@@ -3721,7 +3638,6 @@ pub fn deconv<
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
///
/// // Test case 1: Basic 2D convolution with NCHW format (default)
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 1, 3, 3],
@@ -3734,64 +3650,44 @@ pub fn deconv<
/// Some(&[0]),
/// &[1],
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[31, 16, 8, 26]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Test case 2: NHWC format with HWIO kernel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 3, 3, 1], // NHWC format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 1, 5, 1]),
/// &[2, 2, 1, 1], // HWIO format
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![1;2], 1, DataFormat::NHWC, KernelFormat::HWIO).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[11, 24, 20, 14]), &[1, 2, 2, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Test case 3: Multi-channel NHWC with OHWI kernel
/// // Now test single channel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 3, 3, 2], // NHWC format
/// &[1, 2, 3, 3],
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[5, 1, 1, 2, 5, 2, 1, 2]),
/// &[1, 2, 2, 2], // OHWI format
/// Some(&[5, 1, 1, 1, 5, 2, 1, 1]),
/// &[2, 1, 2, 2],
/// ).unwrap());
/// let b = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 1]),
/// &[2],
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1, DataFormat::NHWC, KernelFormat::OHWI).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[64, 66, 46, 58]), &[1, 2, 2, 1]).unwrap();
///
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 2).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[32, 17, 9, 27, 34, 20, 13, 26]), &[1, 2, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Test case 4: 1D convolution with NCHW format
/// // Now test multi channel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5]),
/// &[1, 1, 5], // NCHW format
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 2, 3, 3],
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3]),
/// &[1, 1, 3], // OIHW format
/// Some(&[5, 1, 1, 1, 5, 2, 1, 1, 5, 3, 1, 1, 5, 4, 1, 1, 5, 1, 1, 1, 5, 2, 1, 1, 5, 3, 1, 1, 5, 4, 1, 1]),
/// &[4, 2, 2, 2],
/// ).unwrap());
/// let b = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 1, 1, 1]),
/// &[4],
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0)], &vec![1], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[14, 20, 26]), &[1, 1, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Test case 5: 3D convolution with NCHW format
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[1, 1, 2, 2, 2], // NCDHW format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[1, 1]),
/// &[1, 1, 1, 1, 2], // OIDHW format
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 3], &vec![1; 3], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[3, 7, 11, 15]), &[1, 1, 2, 2, 1]).unwrap();
/// let result =conv(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[65, 36, 21, 52, 73, 48, 37, 48, 65, 36, 21, 52, 73, 48, 37, 48]), &[1, 4, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
/// ```
///
@@ -3804,14 +3700,9 @@ pub fn conv<
padding: &[(usize, usize)],
stride: &[usize],
num_groups: usize,
data_format: DataFormat,
kernel_format: KernelFormat,
) -> Result<ValTensor<F>, CircuitError> {
let has_bias = values.len() == 3;
let (mut working_image, mut working_kernel) = (values[0].clone(), values[1].clone());
data_format.to_canonical(&mut working_image)?;
kernel_format.to_canonical(&mut working_kernel)?;
let (mut image, mut kernel) = (values[0].clone(), values[1].clone());
if stride.iter().any(|&s| s == 0) {
return Err(TensorError::DimMismatch(
@@ -3820,40 +3711,47 @@ pub fn conv<
.into());
}
// Assign tensors
// we specifically want to use the same kernel and image for all the convolutions and need to enforce this by assigning them
// 1. assign the kernel
let mut assigned_len = vec![];
if !working_kernel.all_prev_assigned() {
working_kernel = region.assign(&config.custom_gates.inputs[0], &working_kernel)?;
assigned_len.push(working_kernel.len());
if !kernel.all_prev_assigned() {
kernel = region.assign(&config.custom_gates.inputs[0], &kernel)?;
assigned_len.push(kernel.len());
}
if !working_image.all_prev_assigned() {
working_image = region.assign(&config.custom_gates.inputs[1], &working_image)?;
assigned_len.push(working_image.len());
// 2. assign the image
if !image.all_prev_assigned() {
image = region.assign(&config.custom_gates.inputs[1], &image)?;
assigned_len.push(image.len());
}
if !assigned_len.is_empty() {
// safe to unwrap since we've just checked it has at least one element
region.increment(*assigned_len.iter().max().unwrap());
}
if data_format.has_no_batch() {
let mut dim = working_image.dims().to_vec();
dim.insert(0, 1);
working_image.reshape(&dim)?;
// if image is 3d add a dummy batch dimension
if image.dims().len() == kernel.dims().len() - 1 {
image.reshape(&[1, image.dims()[0], image.dims()[1], image.dims()[2]])?;
}
let image_dims = working_image.dims();
let kernel_dims = working_kernel.dims();
let image_dims = image.dims();
let kernel_dims = kernel.dims();
// Apply padding
let mut padded_image = working_image.clone();
let mut padded_image = image.clone();
padded_image.pad(padding.to_vec(), 2)?;
// Extract dimensions
let batch_size = image_dims[0];
let input_channels = image_dims[1];
let output_channels = kernel_dims[0];
// Calculate slides for each spatial dimension
log::debug!(
"batch_size: {}, output_channels: {}, input_channels: {}",
batch_size,
output_channels,
input_channels
);
let slides = image_dims[2..]
.iter()
.enumerate()
@@ -3868,6 +3766,8 @@ pub fn conv<
})
.collect::<Result<Vec<_>, TensorError>>()?;
log::debug!("slides: {:?}", slides);
let input_channels_per_group = input_channels / num_groups;
let output_channels_per_group = output_channels / num_groups;
@@ -3875,15 +3775,24 @@ pub fn conv<
return Err(TensorError::DimMismatch(format!(
"Given groups={}, expected input channels and output channels to be divisible by groups, but got input_channels={}, output_channels={}",
num_groups, input_channels, output_channels
)).into());
))
.into());
}
log::debug!(
"num_groups: {}, input_channels_per_group: {}, output_channels_per_group: {}",
num_groups,
input_channels_per_group,
output_channels_per_group
);
let num_outputs =
batch_size * num_groups * output_channels_per_group * slides.iter().product::<usize>();
log::debug!("num_outputs: {}", num_outputs);
let mut output: Tensor<ValType<F>> = Tensor::new(None, &[num_outputs])?;
// Create iteration space
let mut iterations = vec![0..batch_size, 0..num_groups, 0..output_channels_per_group];
for slide in slides.iter() {
iterations.push(0..*slide);
@@ -3895,13 +3804,6 @@ pub fn conv<
.multi_cartesian_product()
.collect::<Vec<_>>();
let batch_offset = if data_format.has_no_batch() {
2 // No batch dimension, start coordinates after channels
} else {
3 // Has batch dimension, start coordinates after batch and channels
};
// Main convolution loop
let inner_loop_function = |idx: usize, region: &mut RegionCtx<F>| {
let cartesian_coord_per_group = &cartesian_coord[idx];
let (batch, group, i) = (
@@ -3915,19 +3817,22 @@ pub fn conv<
let mut slices = vec![batch..batch + 1, start_channel..end_channel];
for (i, stride) in stride.iter().enumerate() {
let coord = cartesian_coord_per_group[batch_offset + i] * stride;
let coord = cartesian_coord_per_group[3 + i] * stride;
let kernel_dim = kernel_dims[2 + i];
slices.push(coord..(coord + kernel_dim));
}
let mut local_image = padded_image.get_slice(&slices)?;
local_image.flatten();
let start_kernel_index = group * output_channels_per_group + i;
let end_kernel_index = start_kernel_index + 1;
let mut local_kernel = working_kernel.get_slice(&[start_kernel_index..end_kernel_index])?;
let mut local_kernel = kernel.get_slice(&[start_kernel_index..end_kernel_index])?;
local_kernel.flatten();
// this is dot product notation in einsum format
let mut res = einsum(config, region, &[local_image, local_kernel], "i,i->")?;
if has_bias {
@@ -3948,16 +3853,21 @@ pub fn conv<
region.flush()?;
region.apply_in_loop(&mut output, inner_loop_function)?;
// Reshape output
let mut dims = vec![batch_size, output_channels];
dims.extend(slides.iter().cloned());
output.reshape(&dims)?;
let reshape_output = |output: &mut Tensor<ValType<F>>| -> Result<(), TensorError> {
// remove dummy batch dimension if we added one
let mut dims = vec![batch_size, output_channels];
dims.extend(slides.iter().cloned());
output.reshape(&dims)?;
// Convert output back to requested format
let mut final_output: ValTensor<F> = output.into();
data_format.from_canonical(&mut final_output)?;
Ok(())
};
Ok(final_output)
// remove dummy batch dimension if we added one
reshape_output(&mut output)?;
let output: ValTensor<_> = output.into();
Ok(output)
}
/// Power accumulated layout
@@ -3998,7 +3908,7 @@ pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
Ok(rescaled_inputs)
}
/// Dummy (no constraints) reshape layout
/// Dummy (no contraints) reshape layout
pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values: &[ValTensor<F>; 1],
new_dims: &[usize],
@@ -4008,7 +3918,7 @@ pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
Ok(t)
}
/// Dummy (no constraints) move_axis layout
/// Dummy (no contraints) move_axis layout
pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values: &[ValTensor<F>; 1],
source: usize,
@@ -4100,11 +4010,6 @@ pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
decomp: bool,
) -> Result<ValTensor<F>, CircuitError> {
let mut output = values[0].clone();
println!("output: {:?}", output);
println!("all prev assigned: {:?}", output.all_prev_assigned());
println!("decomp: {:?}", decomp);
if !output.all_prev_assigned() {
// checks they are in range
if decomp {
@@ -5599,7 +5504,6 @@ pub(crate) fn leaky_relu<F: PrimeField + TensorType + PartialOrd + std::hash::Ha
Tensor::from([alpha.0; 1].into_iter()),
*input_scale,
&crate::graph::Visibility::Fixed,
false,
)?;
let alpha_tensor = create_constant_tensor(quantized_alpha[0], 1);
@@ -5622,12 +5526,12 @@ fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values: &[ValTensor<F>; 1],
axes: &[usize],
op: impl Fn(
&BaseConfig<F>,
&mut RegionCtx<F>,
&[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError>
+ Send
+ Sync,
&BaseConfig<F>,
&mut RegionCtx<F>,
&[ValTensor<F>; 1],
) -> Result<ValTensor<F>, CircuitError>
+ Send
+ Sync,
) -> Result<ValTensor<F>, CircuitError> {
let mut input = values[0].clone();
@@ -5843,12 +5747,14 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[101, 201, 302, 403, 503, 603]),
/// &[2, 3],
/// ).unwrap());
/// let result = output::<Fp>(&dummy_config, &mut dummy_region, &[x, y], false).unwrap();
/// let result = output::<Fp>(&dummy_config, &mut dummy_region, &[x, y], 1024.0.into(), 1.0, false).unwrap();
/// ```
pub fn output<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
scale: utils::F32,
tol: f32,
decomp: bool,
) -> Result<ValTensor<F>, CircuitError> {
let mut values = [values[0].clone(), values[1].clone()];
@@ -5863,6 +5769,43 @@ pub fn output<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values[1] = layouts::identity(config, region, &[values[1].clone()], decomp)?;
}
// regular equality constraint
return enforce_equality(config, region, &[values[0].clone(), values[1].clone()]);
if tol == 0.0 {
// regular equality constraint
return enforce_equality(config, region, &[values[0].clone(), values[1].clone()]);
}
// Calculate the difference between the expected output and actual output
let diff = pairwise(config, region, &values, BaseOp::Sub)?;
// integer scale
let int_scale = scale.0 as IntegerRep;
// felt scale
let felt_scale = integer_rep_to_felt(int_scale);
// input scale ratio we multiply by tol such that in the new scale range_check_len represents tol percent
let input_scale_ratio = (scale.0 * tol) as IntegerRep / 2 * 2;
let recip = recip(
config,
region,
&[values[0].clone()],
felt_scale,
felt_scale * F::from(100),
)?;
log::debug!("recip: {}", recip.show());
// Multiply the difference by the recip
let product = pairwise(config, region, &[diff, recip], BaseOp::Mult)?;
log::debug!("product: {}", product.show());
let rebased_product = div(
config,
region,
&[product],
integer_rep_to_felt(input_scale_ratio),
)?;
log::debug!("rebased_product: {}", rebased_product.show());
// check that it is within the tolerance range
range_check(config, region, &[rebased_product], &(-int_scale, int_scale))
}

View File

@@ -1,8 +1,6 @@
use std::any::Any;
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::prelude::DatumType;
use crate::{
graph::quantize_tensor,
@@ -98,8 +96,6 @@ pub enum InputType {
Int,
///
TDim,
///
Unknown,
}
impl InputType {
@@ -136,7 +132,6 @@ impl InputType {
let int_input = input.clone().to_i64().unwrap();
*input = T::from_i64(int_input).unwrap();
}
InputType::Unknown => {}
}
}
}
@@ -157,28 +152,6 @@ impl std::str::FromStr for InputType {
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl From<DatumType> for InputType {
fn from(datum_type: DatumType) -> Self {
match datum_type {
DatumType::Bool => InputType::Bool,
DatumType::F16 => InputType::F16,
DatumType::F32 => InputType::F32,
DatumType::F64 => InputType::F64,
DatumType::I8 => InputType::Int,
DatumType::I16 => InputType::Int,
DatumType::I32 => InputType::Int,
DatumType::I64 => InputType::Int,
DatumType::U8 => InputType::Int,
DatumType::U16 => InputType::Int,
DatumType::U32 => InputType::Int,
DatumType::U64 => InputType::Int,
DatumType::TDim => InputType::TDim,
_ => unimplemented!(),
}
}
}
///
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Input {
@@ -301,8 +274,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
Some(v) => v,
None => return Err(CircuitError::UnsetVisibility),
};
self.quantized_values =
quantize_tensor(self.raw_values.clone(), new_scale, &visibility, true)?;
self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
Ok(())
}
@@ -318,8 +290,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
}
impl<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self

View File

@@ -4,7 +4,6 @@ use crate::{
utils::{self, F32},
},
tensor::{self, Tensor, TensorError},
tensor::{DataFormat, KernelFormat},
};
use super::{base::BaseOp, *};
@@ -44,8 +43,6 @@ pub enum PolyOp {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
group: usize,
data_format: DataFormat,
kernel_format: KernelFormat,
},
Downsample {
axis: usize,
@@ -57,8 +54,6 @@ pub enum PolyOp {
output_padding: Vec<usize>,
stride: Vec<usize>,
group: usize,
data_format: DataFormat,
kernel_format: KernelFormat,
},
Add,
Sub,
@@ -170,12 +165,10 @@ impl<
stride,
padding,
group,
data_format,
kernel_format,
} => {
format!(
"CONV (stride={:?}, padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
stride, padding, group, data_format, kernel_format
"CONV (stride={:?}, padding={:?}, group={})",
stride, padding, group
)
}
PolyOp::DeConv {
@@ -183,12 +176,11 @@ impl<
padding,
output_padding,
group,
data_format,
kernel_format,
} => {
format!(
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
stride, padding, output_padding, group, data_format, kernel_format)
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={})",
stride, padding, output_padding, group
)
}
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
PolyOp::Slice { axis, start, end } => {
@@ -250,8 +242,6 @@ impl<
padding,
stride,
group,
data_format,
kernel_format,
} => layouts::conv(
config,
region,
@@ -259,8 +249,6 @@ impl<
padding,
stride,
*group,
*data_format,
*kernel_format,
)?,
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
@@ -321,8 +309,6 @@ impl<
output_padding,
stride,
group,
data_format,
kernel_format,
} => layouts::deconv(
config,
region,
@@ -331,8 +317,6 @@ impl<
output_padding,
stride,
*group,
*data_format,
*kernel_format,
)?,
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,

View File

@@ -1,6 +1,5 @@
use crate::circuit::ops::poly::PolyOp;
use crate::circuit::*;
use crate::tensor::{DataFormat, KernelFormat};
use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor};
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
@@ -1066,8 +1065,6 @@ mod conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
data_format: DataFormat::default(),
kernel_format: KernelFormat::default(),
}),
)
.map_err(|_| Error::Synthesis)
@@ -1223,8 +1220,6 @@ mod conv_col_ultra_overflow {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
data_format: DataFormat::default(),
kernel_format: KernelFormat::default(),
}),
)
.map_err(|_| Error::Synthesis)
@@ -1382,8 +1377,6 @@ mod conv_relu_col_ultra_overflow {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
data_format: DataFormat::default(),
kernel_format: KernelFormat::default(),
}),
)
.map_err(|_| Error::Synthesis);

View File

@@ -1,6 +1,6 @@
use alloy::primitives::Address as H160;
use clap::{Command, Parser, Subcommand};
use clap_complete::{Generator, Shell, generate};
use clap_complete::{generate, Generator, Shell};
#[cfg(feature = "python-bindings")]
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
@@ -8,7 +8,7 @@ use std::path::PathBuf;
use std::str::FromStr;
use tosubcommand::{ToFlags, ToSubcommand};
use crate::{Commitments, RunArgs, pfsys::ProofType};
use crate::{pfsys::ProofType, Commitments, RunArgs};
use crate::circuit::CheckMode;
use crate::graph::TestDataSource;
@@ -360,13 +360,8 @@ pub fn get_styles() -> clap::builder::Styles {
}
/// Print completions for the given generator
pub fn print_completions<G: Generator>(r#gen: G, cmd: &mut Command) {
generate(
r#gen,
cmd,
cmd.get_name().to_string(),
&mut std::io::stdout(),
);
pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
generate(gen, cmd, cmd.get_name().to_string(), &mut std::io::stdout());
}
#[allow(missing_docs)]
@@ -401,9 +396,8 @@ pub enum Commands {
/// Generates the witness from an input file.
GenWitness {
/// The path to the .json data file
/// You can also pass the input data as a string, eg. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<String>,
data: Option<PathBuf>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
@@ -435,7 +429,7 @@ pub enum Commands {
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
/// The path to the .json data file to output
/// The path to the .json data file
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// Hand-written parser for graph variables, eg. batch_size=1
@@ -448,9 +442,8 @@ pub enum Commands {
/// Calibrates the proving scale, lookup bits and logrows from a circuit settings file.
CalibrateSettings {
/// The path to the .json calibration data file.
/// You can also pass the input data as a string, eg. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
#[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE, value_hint = clap::ValueHint::FilePath)]
data: Option<String>,
data: Option<PathBuf>,
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
@@ -633,9 +626,8 @@ pub enum Commands {
#[command(arg_required_else_help = true)]
SetupTestEvmData {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
/// You can also pass the input data as a string, eg. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
#[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
data: Option<String>,
data: Option<PathBuf>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
@@ -661,9 +653,8 @@ pub enum Commands {
#[arg(long, value_hint = clap::ValueHint::Other)]
addr: H160Flag,
/// The path to the .json data file.
/// You can also pass the input data as a string, eg. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
#[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
data: Option<String>,
data: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
@@ -780,7 +771,7 @@ pub enum Commands {
/// view functions that return the data that the network
/// ingests as inputs.
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<String>,
data: Option<PathBuf>,
/// The path to the witness file. This is needed for proof swapping for kzg commitments.
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
witness: Option<PathBuf>,
@@ -875,9 +866,8 @@ pub enum Commands {
#[command(name = "deploy-evm-da")]
DeployEvmDataAttestation {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
/// You can also pass the input data as a string, eg. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<String>,
data: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,

View File

@@ -383,7 +383,7 @@ pub async fn deploy_contract_via_solidity(
///
pub async fn deploy_da_verifier_via_solidity(
settings_path: PathBuf,
input: String,
input: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<&str>,
runs: usize,
@@ -391,7 +391,7 @@ pub async fn deploy_da_verifier_via_solidity(
) -> Result<H160, EthError> {
let (client, client_address) = setup_eth_backend(rpc_url, private_key).await?;
let input = GraphData::from_str(&input).map_err(|_| EthError::GraphData)?;
let input = GraphData::from_path(input).map_err(|_| EthError::GraphData)?;
let settings = GraphSettings::load(&settings_path).map_err(|_| EthError::GraphSettings)?;
@@ -688,10 +688,10 @@ fn parse_call_to_account(call_to_account: CallToAccount) -> Result<ParsedCallToA
pub async fn update_account_calls(
addr: H160,
input: String,
input: PathBuf,
rpc_url: Option<&str>,
) -> Result<(), EthError> {
let input = GraphData::from_str(&input).map_err(|_| EthError::GraphData)?;
let input = GraphData::from_path(input).map_err(|_| EthError::GraphData)?;
// The data that will be stored in the test contracts that will eventually be read from.
let mut calls_to_accounts = vec![];

View File

@@ -1,6 +1,5 @@
use crate::EZKL_BUF_CAPACITY;
use crate::circuit::CheckMode;
use crate::circuit::region::RegionSettings;
use crate::circuit::CheckMode;
use crate::commands::CalibrationTarget;
use crate::eth::{
deploy_contract_via_solidity, deploy_da_verifier_via_solidity, fix_da_multi_sol,
@@ -13,21 +12,21 @@ use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model};
use crate::graph::{TestDataSource, TestSources};
use crate::pfsys::evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript};
use crate::pfsys::{
ProofSplitCommit, create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit,
create_keys, load_pk, load_vk, save_params, save_pk, Snark, StrategyType, TranscriptType,
};
use crate::pfsys::{
Snark, StrategyType, TranscriptType, create_keys, load_pk, load_vk, save_params, save_pk,
create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit, ProofSplitCommit,
};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::EZKL_BUF_CAPACITY;
use crate::{commands::*, EZKLError};
use crate::{Commitments, RunArgs};
use crate::{EZKLError, commands::*};
use colored::Colorize;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
use halo2_proofs::plonk::{self, Circuit};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::poly::commitment::{CommitmentScheme, Params};
use halo2_proofs::poly::commitment::{ParamsProver, Verifier};
use halo2_proofs::poly::ipa::commitment::{IPACommitmentScheme, ParamsIPA};
@@ -40,6 +39,7 @@ use halo2_proofs::poly::kzg::strategy::AccumulatorStrategy as KZGAccumulatorStra
use halo2_proofs::poly::kzg::{
commitment::ParamsKZG, strategy::SingleStrategy as KZGSingleStrategy,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer};
use halo2_solidity_verifier;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
@@ -50,12 +50,12 @@ use instant::Instant;
use itertools::Itertools;
use log::debug;
use log::{info, trace, warn};
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde::Serialize;
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::Config;
use snark_verifier::system::halo2::compile;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use snark_verifier::system::halo2::Config;
use std::fs::File;
use std::io::BufWriter;
use std::io::{Cursor, Write};
@@ -516,9 +516,7 @@ fn update_ezkl_binary(version: &Option<String>) -> Result<String, EZKLError> {
.status()
.is_err()
{
log::warn!(
"bash is not installed on this system, trying to run the install script with sh (may fail)"
);
log::warn!("bash is not installed on this system, trying to run the install script with sh (may fail)");
"sh"
} else {
"bash"
@@ -727,7 +725,7 @@ pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, EZKLErr
pub(crate) async fn gen_witness(
compiled_circuit_path: PathBuf,
data: String,
data: PathBuf,
output: Option<PathBuf>,
vk_path: Option<PathBuf>,
srs_path: Option<PathBuf>,
@@ -735,7 +733,7 @@ pub(crate) async fn gen_witness(
// these aren't real values so the sanity checks are mostly meaningless
let mut circuit = GraphCircuit::load(compiled_circuit_path)?;
let data = GraphData::from_str(&data)?;
let data: GraphData = GraphData::from_path(data)?;
let settings = circuit.settings().clone();
let vk = if let Some(vk) = vk_path {
@@ -878,7 +876,7 @@ pub(crate) fn gen_random_data(
let mut tensor = TractTensor::zero::<f32>(sizes).unwrap();
let slice = tensor.as_slice_mut::<f32>().unwrap();
slice.iter_mut().for_each(|x| *x = rng.r#gen());
slice.iter_mut().for_each(|x| *x = rng.gen());
tensor.cast_to_dt(datum_type).unwrap().into_owned()
}
@@ -1046,7 +1044,7 @@ impl AccuracyResults {
#[allow(clippy::too_many_arguments)]
pub(crate) async fn calibrate(
model_path: PathBuf,
data: String,
data: PathBuf,
settings_path: PathBuf,
target: CalibrationTarget,
lookup_safety_margin: f64,
@@ -1060,7 +1058,7 @@ pub(crate) async fn calibrate(
use crate::fieldutils::IntegerRep;
let data = GraphData::from_str(&data)?;
let data = GraphData::from_path(data)?;
// load the pre-generated settings
let settings = GraphSettings::load(&settings_path)?;
// now retrieve the run args
@@ -1524,7 +1522,7 @@ pub(crate) async fn create_evm_data_attestation(
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
input: String,
input: PathBuf,
witness: Option<PathBuf>,
) -> Result<String, EZKLError> {
#[allow(unused_imports)]
@@ -1538,7 +1536,7 @@ pub(crate) async fn create_evm_data_attestation(
// if input is not provided, we just instantiate dummy input data
let data =
GraphData::from_str(&input).unwrap_or_else(|_| GraphData::new(DataSource::File(vec![])));
GraphData::from_path(input).unwrap_or_else(|_| GraphData::new(DataSource::File(vec![])));
// The number of input and output instances we attest to for the single call data attestation
let mut input_len = None;
@@ -1627,7 +1625,7 @@ pub(crate) async fn create_evm_data_attestation(
}
pub(crate) async fn deploy_da_evm(
data: String,
data: PathBuf,
settings_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
@@ -1870,7 +1868,7 @@ pub(crate) fn setup(
}
pub(crate) async fn setup_test_evm_witness(
data_path: String,
data_path: PathBuf,
compiled_circuit_path: PathBuf,
test_data: PathBuf,
rpc_url: Option<String>,
@@ -1879,7 +1877,7 @@ pub(crate) async fn setup_test_evm_witness(
) -> Result<String, EZKLError> {
use crate::graph::TestOnChainData;
let mut data = GraphData::from_str(&data_path)?;
let mut data = GraphData::from_path(data_path)?;
let mut circuit = GraphCircuit::load(compiled_circuit_path)?;
// if both input and output are from files fail
@@ -1907,7 +1905,7 @@ pub(crate) async fn setup_test_evm_witness(
use crate::pfsys::ProofType;
pub(crate) async fn test_update_account_calls(
addr: H160Flag,
data: String,
data: PathBuf,
rpc_url: Option<String>,
) -> Result<String, EZKLError> {
use crate::eth::update_account_calls;

View File

@@ -33,7 +33,7 @@ pub enum GraphError {
#[error("a node is missing required params: {0}")]
MissingParams(String),
/// A node has missing parameters
#[error("a node has misformed params: {0}")]
#[error("a node is has misformed params: {0}")]
MisformedParams(String),
/// Error in the configuration of the visibility of variables
#[error("there should be at least one set of public variables")]

View File

@@ -528,27 +528,6 @@ impl GraphData {
}
}
/// Loads graph input data from a string, first seeing if it is a file path or JSON data
/// If it is a file path, it will load the data from the file
/// Otherwise, it will attempt to parse the string as JSON data
///
/// # Arguments
/// * `data` - String containing the input data
/// # Returns
/// A new GraphData instance containing the loaded data
pub fn from_str(data: &str) -> Result<Self, GraphError> {
let graph_input = serde_json::from_str(data);
match graph_input {
Ok(graph_input) => {
return Ok(graph_input);
}
Err(_) => {
let path = std::path::PathBuf::from(data);
GraphData::from_path(path)
}
}
}
/// Loads graph input data from a file
///
/// # Arguments
@@ -630,12 +609,8 @@ impl GraphData {
if input.len() % input_size != 0 {
return Err(GraphError::InvalidDims(
0,
format!(
"calibration data length (={}) must be evenly divisible by the original input_size(={})",
input.len(),
input_size
),
"calibration data length must be evenly divisible by the original input_size"
.to_string(),
));
}

View File

@@ -455,10 +455,6 @@ pub struct GraphSettings {
pub num_blinding_factors: Option<usize>,
/// unix time timestamp
pub timestamp: Option<u128>,
/// Model inputs types (if any)
pub input_types: Option<Vec<InputType>>,
/// Model outputs types (if any)
pub output_types: Option<Vec<InputType>>,
}
impl GraphSettings {

View File

@@ -1,6 +1,7 @@
use super::errors::GraphError;
use super::extract_const_quantized_values;
use super::node::*;
use super::scale_to_multiplier;
use super::vars::*;
use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
@@ -378,15 +379,9 @@ pub struct ParsedNodes {
pub nodes: BTreeMap<usize, NodeType>,
inputs: Vec<usize>,
outputs: Vec<Outlet>,
output_types: Vec<InputType>,
}
impl ParsedNodes {
/// Returns the output types of the computational graph.
pub fn get_output_types(&self) -> Vec<InputType> {
self.output_types.clone()
}
/// Returns the number of the computational graph's inputs
pub fn num_inputs(&self) -> usize {
self.inputs.len()
@@ -496,16 +491,6 @@ impl Model {
Ok(om)
}
/// Gets the input types from the parsed nodes
pub fn get_input_types(&self) -> Result<Vec<InputType>, GraphError> {
self.graph.get_input_types()
}
/// Gets the output types from the parsed nodes
pub fn get_output_types(&self) -> Vec<InputType> {
self.graph.get_output_types()
}
///
pub fn save(&self, path: PathBuf) -> Result<(), GraphError> {
let f = std::fs::File::create(&path).map_err(|e| {
@@ -589,11 +574,6 @@ impl Model {
required_range_checks: res.range_checks.into_iter().collect(),
model_output_scales: self.graph.get_output_scales()?,
model_input_scales: self.graph.get_input_scales(),
input_types: match self.get_input_types() {
Ok(x) => Some(x),
Err(_) => None,
},
output_types: Some(self.get_output_types()),
num_dynamic_lookups: res.num_dynamic_lookups,
total_dynamic_col_size: res.dynamic_lookup_col_coord,
num_shuffles: res.num_shuffles,
@@ -724,11 +704,6 @@ impl Model {
nodes,
inputs: model.inputs.iter().map(|o| o.node).collect(),
outputs: model.outputs.iter().map(|o| (o.node, o.slot)).collect(),
output_types: model
.outputs
.iter()
.map(|o| Ok::<InputType, GraphError>(model.outlet_fact(*o)?.datum_type.into()))
.collect::<Result<Vec<_>, GraphError>>()?,
};
let duration = start_time.elapsed();
@@ -887,15 +862,6 @@ impl Model {
nodes: subgraph_nodes,
inputs: model.inputs.iter().map(|o| o.node).collect(),
outputs: model.outputs.iter().map(|o| (o.node, o.slot)).collect(),
output_types: model
.outputs
.iter()
.map(|o| {
Ok::<InputType, GraphError>(
model.outlet_fact(*o)?.datum_type.into(),
)
})
.collect::<Result<Vec<_>, GraphError>>()?,
};
let om = Model {
@@ -1172,10 +1138,17 @@ impl Model {
})?;
if run_args.output_visibility.is_public() || run_args.output_visibility.is_fixed() {
let output_scales = self.graph.get_output_scales().map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
let res = outputs
.iter()
.enumerate()
.map(|(i, output)| {
let mut tol: crate::circuit::Tolerance = run_args.tolerance;
tol.scale = scale_to_multiplier(output_scales[i]).into();
let comparators = if run_args.output_visibility == Visibility::Public {
let res = vars
.instance
@@ -1198,6 +1171,7 @@ impl Model {
&mut thread_safe_region,
&[output.clone(), comparators],
Box::new(HybridOp::Output {
tol,
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
)
@@ -1459,9 +1433,11 @@ impl Model {
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
if self.visibility.output.is_public() || self.visibility.output.is_fixed() {
let output_scales = self.graph.get_output_scales()?;
let res = outputs
.iter()
.map(|output| {
.enumerate()
.map(|(i, output)| {
let mut comparator: ValTensor<Fp> = (0..output.len())
.map(|_| {
if !self.visibility.output.is_fixed() {
@@ -1474,10 +1450,14 @@ impl Model {
.into();
comparator.reshape(output.dims())?;
let mut tol = run_args.tolerance;
tol.scale = scale_to_multiplier(output_scales[i]).into();
dummy_config.layout(
&mut region,
&[output.clone(), comparator],
Box::new(HybridOp::Output {
tol,
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
)
@@ -1599,16 +1579,4 @@ impl Model {
}
Ok(instance_shapes)
}
/// Input types of the computational graph's public inputs (if any)
pub fn instance_types(&self) -> Result<Vec<InputType>, GraphError> {
let mut instance_types = vec![];
if self.visibility.input.is_public() {
instance_types.extend(self.graph.get_input_types()?);
}
if self.visibility.output.is_public() {
instance_types.extend(self.graph.get_output_types());
}
Ok(instance_types)
}
}

View File

@@ -14,14 +14,14 @@ use super::VarScales;
use super::Visibility;
// Import operation types for different circuit components
use crate::circuit::hybrid::HybridOp;
use crate::circuit::lookup::LookupOp;
use crate::circuit::poly::PolyOp;
use crate::circuit::CircuitError;
use crate::circuit::Constant;
use crate::circuit::Input;
use crate::circuit::Op;
use crate::circuit::Unknown;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::lookup::LookupOp;
use crate::circuit::poly::PolyOp;
// Import graph error types for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
@@ -740,7 +740,7 @@ fn rescale_const_with_single_use(
if scale_max > &current_scale {
let raw_values = constant.raw_values.clone();
constant.quantized_values =
super::quantize_tensor(raw_values, *scale_max, param_visibility, true)?;
super::quantize_tensor(raw_values, *scale_max, param_visibility)?;
}
}

View File

@@ -1,14 +1,14 @@
use super::errors::GraphError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::VarScales;
use super::errors::GraphError;
use super::{Rescaled, SupportedOp, Visibility};
use crate::circuit::Op;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::hybrid::HybridOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::lookup::LookupOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::poly::PolyOp;
use crate::circuit::Op;
use crate::fieldutils::IntegerRep;
use crate::tensor::{Tensor, TensorError, TensorType};
use halo2curves::bn256::Fr as Fp;
@@ -22,7 +22,6 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::ops::{
Downsample,
array::{
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
Slice, Topk,
@@ -32,6 +31,7 @@ use tract_onnx::tract_core::ops::{
einsum::EinSum,
element_wise::ElementWiseOp,
nn::{LeakyRelu, Reduce, Softmax},
Downsample,
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::{
@@ -39,8 +39,9 @@ use tract_onnx::tract_hir::{
ops::array::{Pad, PadMode, TypedConcat},
ops::cnn::PoolSpec,
ops::konst::Const,
ops::nn::DataFormat,
tract_core::ops::cast::Cast,
tract_core::ops::cnn::{MaxPool, SumPool},
tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
};
/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
@@ -68,33 +69,6 @@ pub fn quantize_float(
Ok(scaled)
}
/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
/// NAN gets mapped to 0. And values that exceed the max are capped.
/// Arguments
///
/// * `elem` - the element to quantize.
/// * `shift` - offset used in the fixed point representation.
/// * `scale` - `2^scale` used in the fixed point representation.
pub fn quantize_float_and_cap(
elem: &f64,
shift: f64,
scale: crate::Scale,
) -> Result<IntegerRep, TensorError> {
let mult = scale_to_multiplier(scale);
let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
if *elem > max_value {
return Ok(IntegerRep::MAX);
} else if *elem < -max_value {
return Ok(IntegerRep::MIN);
}
// we parallelize the quantization process as it seems to be quite slow at times
let scaled = (mult * *elem + shift).round() as IntegerRep;
Ok(scaled)
}
/// Dequantizes a field element to a f64 using a fixed point representation.
/// Arguments
/// * `felt` - the field element to dequantize.
@@ -406,7 +380,7 @@ pub fn new_op_from_onnx(
let range = (start..end).step_by(delta).collect::<Vec<_>>();
let raw_value = range.iter().map(|x| *x as f32).collect::<Tensor<_>>();
// Quantize the raw value (integers)
let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed, false)?;
let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed)?;
let c = crate::circuit::ops::Constant::new(
quantized_value,
@@ -738,7 +712,6 @@ pub fn new_op_from_onnx(
raw_value.clone(),
constant_scale,
&run_args.param_visibility,
true,
)?;
let c = crate::circuit::ops::Constant::new(
quantized_value,
@@ -1173,6 +1146,13 @@ pub fn new_op_from_onnx(
let pool_spec: &PoolSpec = &sumpool_node.pool_spec;
// only support pytorch type formatting for now
if pool_spec.data_format != DataFormat::NCHW {
return Err(GraphError::MissingParams(
"data in wrong format".to_string(),
));
}
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, &input_dims[0])?;
let kernel_shape = &pool_spec.kernel_shape;
@@ -1181,7 +1161,6 @@ pub fn new_op_from_onnx(
padding,
stride: stride.to_vec(),
pool_dims: kernel_shape.to_vec(),
data_format: pool_spec.data_format.into(),
})
}
"Ceil" => {
@@ -1335,6 +1314,15 @@ pub fn new_op_from_onnx(
}
}
if ((conv_node.pool_spec.data_format != DataFormat::NCHW)
&& (conv_node.pool_spec.data_format != DataFormat::CHW))
|| (conv_node.kernel_fmt != KernelFormat::OIHW)
{
return Err(GraphError::MisformedParams(
"data or kernel in wrong format".to_string(),
));
}
let pool_spec = &conv_node.pool_spec;
let stride = extract_strides(pool_spec)?;
@@ -1362,8 +1350,6 @@ pub fn new_op_from_onnx(
padding,
stride,
group,
data_format: conv_node.pool_spec.data_format.into(),
kernel_format: conv_node.kernel_fmt.into(),
})
}
"Not" => SupportedOp::Linear(PolyOp::Not),
@@ -1387,6 +1373,14 @@ pub fn new_op_from_onnx(
}
}
if (deconv_node.pool_spec.data_format != DataFormat::NCHW)
|| (deconv_node.kernel_format != KernelFormat::OIHW)
{
return Err(GraphError::MisformedParams(
"data or kernel in wrong format".to_string(),
));
}
let pool_spec = &deconv_node.pool_spec;
let stride = extract_strides(pool_spec)?;
@@ -1412,8 +1406,6 @@ pub fn new_op_from_onnx(
output_padding: deconv_node.adjustments.to_vec(),
stride,
group: deconv_node.group,
data_format: deconv_node.pool_spec.data_format.into(),
kernel_format: deconv_node.kernel_format.into(),
})
}
"Downsample" => {
@@ -1497,6 +1489,13 @@ pub fn new_op_from_onnx(
let pool_spec: &PoolSpec = &sumpool_node.pool_spec;
// only support pytorch type formatting for now
if pool_spec.data_format != DataFormat::NCHW {
return Err(GraphError::MissingParams(
"data in wrong format".to_string(),
));
}
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, &input_dims[0])?;
@@ -1505,7 +1504,6 @@ pub fn new_op_from_onnx(
stride: stride.to_vec(),
kernel_shape: pool_spec.kernel_shape.to_vec(),
normalized: sumpool_node.normalize,
data_format: pool_spec.data_format.into(),
})
}
"Pad" => {
@@ -1578,20 +1576,13 @@ pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
const_value: Tensor<f32>,
scale: crate::Scale,
visibility: &Visibility,
cap: bool,
) -> Result<Tensor<F>, TensorError> {
let mut value: Tensor<F> = const_value.par_enum_map(|_, x| {
if cap {
Ok::<F, TensorError>(crate::fieldutils::integer_rep_to_felt::<F>(
quantize_float_and_cap(&(x).into(), 0.0, scale)?,
))
} else {
Ok(crate::fieldutils::integer_rep_to_felt::<F>(quantize_float(
&(x).into(),
0.0,
scale,
)?))
}
Ok::<_, TensorError>(crate::fieldutils::integer_rep_to_felt::<F>(quantize_float(
&(x).into(),
0.0,
scale,
)?))
})?;
value.set_scale(scale);
@@ -1679,7 +1670,7 @@ pub mod tests {
let reference: Tensor<Fp> = (0..10).map(|x| x.into()).into();
let scale = 0;
let visibility = &Visibility::Public;
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility, false).unwrap();
let quantized: Tensor<Fp> = quantize_tensor(tensor, scale, visibility).unwrap();
assert_eq!(quantized.len(), 10);
assert_eq!(quantized, reference);
}

View File

@@ -97,9 +97,10 @@ impl From<String> for EZKLError {
use std::str::FromStr;
use circuit::{table::Range, CheckMode};
use circuit::{table::Range, CheckMode, Tolerance};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::Args;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use fieldutils::IntegerRep;
use graph::{Visibility, MAX_PUBLIC_SRS};
use halo2_proofs::poly::{
@@ -274,6 +275,10 @@ impl From<String> for Commitments {
derive(Args, ToFlags)
)]
pub struct RunArgs {
/// Error tolerance for model outputs
/// Only applicable when outputs are public
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other))]
pub tolerance: Tolerance,
/// Fixed point scaling factor for quantizing inputs
/// Higher values provide more precision but increase circuit complexity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other))]
@@ -360,6 +365,7 @@ impl Default for RunArgs {
fn default() -> Self {
Self {
bounded_log_lookup: false,
tolerance: Tolerance::default(),
input_scale: 7,
param_scale: 7,
scale_rebase_multiplier: 1,
@@ -393,16 +399,6 @@ impl RunArgs {
pub fn validate(&self) -> Result<(), String> {
let mut errors = Vec::new();
// check if the largest represented integer in the decomposed form overflows IntegerRep
// try it with the largest possible value
let max_decomp = (self.decomp_base as IntegerRep).checked_pow(self.decomp_legs as u32);
if max_decomp.is_none() {
errors.push(format!(
"decomp_base^decomp_legs overflows IntegerRep: {}^{}",
self.decomp_base, self.decomp_legs
));
}
// Visibility validations
if self.param_visibility == Visibility::Public {
errors.push(
@@ -411,6 +407,10 @@ impl RunArgs {
);
}
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
errors.push("Non-zero tolerance requires output_visibility to be public".to_string());
}
// Scale validations
if self.scale_rebase_multiplier < 1 {
errors.push("scale_rebase_multiplier must be >= 1".to_string());
@@ -459,6 +459,11 @@ impl RunArgs {
warn!("logrows exceeds maximum public SRS size");
}
// Validate tolerance is non-negative
if self.tolerance.val < 0.0 {
errors.push("tolerance cannot be negative".to_string());
}
// Performance warnings
if self.input_scale > 20 || self.param_scale > 20 {
warn!("High scale values (>20) may impact performance");
@@ -605,6 +610,23 @@ mod tests {
assert!(err.contains("num_inner_cols must be >= 1"));
}
#[test]
fn test_invalid_tolerance() {
let mut args = RunArgs::default();
args.tolerance.val = 1.0;
args.output_visibility = Visibility::Private;
let err = args.validate().unwrap_err();
assert!(err.contains("Non-zero tolerance requires output_visibility to be public"));
}
#[test]
fn test_negative_tolerance() {
let mut args = RunArgs::default();
args.tolerance.val = -1.0;
let err = args.validate().unwrap_err();
assert!(err.contains("tolerance cannot be negative"));
}
#[test]
fn test_zero_batch_size() {
let mut args = RunArgs::default();

View File

@@ -1,6 +1,6 @@
use thiserror::Error;
use super::{ops::DecompositionError, DataFormat};
use super::ops::DecompositionError;
/// A wrapper for tensor related errors.
#[derive(Debug, Error)]
@@ -44,7 +44,4 @@ pub enum TensorError {
/// Index out of bounds
#[error("index {0} out of bounds for dimension {1}")]
IndexOutOfBounds(usize, usize),
/// Invalid data conversion
#[error("invalid data conversion from format {0} to {1}")]
InvalidDataConversion(DataFormat, DataFormat),
}

View File

@@ -9,7 +9,6 @@ pub mod var;
pub use errors::TensorError;
use core::hash::Hash;
use halo2curves::ff::PrimeField;
use maybe_rayon::{
prelude::{
@@ -27,7 +26,7 @@ pub use var::*;
use crate::{
circuit::utils,
fieldutils::{IntegerRep, integer_rep_to_felt},
fieldutils::{integer_rep_to_felt, IntegerRep},
graph::Visibility,
};
@@ -62,7 +61,7 @@ pub trait TensorType: Clone + Debug + 'static {
}
macro_rules! tensor_type {
($rust_type:ty, $tensor_type:ident, $zero:expr_2021, $one:expr_2021) => {
($rust_type:ty, $tensor_type:ident, $zero:expr, $one:expr) => {
impl TensorType for $rust_type {
fn zero() -> Option<Self> {
Some($zero)
@@ -415,7 +414,7 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
Err(_) => {
return Err(TensorError::FileLoadError(
"Failed to read tensor".to_string(),
));
))
}
}
}
@@ -1254,7 +1253,7 @@ impl<T: Clone + TensorType> Tensor<T> {
None => {
return Err(TensorError::DimError(
"Cannot get last element of empty tensor".to_string(),
));
))
}
};
@@ -1279,7 +1278,7 @@ impl<T: Clone + TensorType> Tensor<T> {
None => {
return Err(TensorError::DimError(
"Cannot get first element of empty tensor".to_string(),
));
))
}
};
@@ -1692,8 +1691,8 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
lhs.par_iter_mut()
.zip(rhs)
.map(|(o, r)| match T::zero() {
Some(zero) => {
.map(|(o, r)| {
if let Some(zero) = T::zero() {
if r != zero {
*o = o.clone() % r;
Ok(())
@@ -1702,10 +1701,11 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
"Cannot divide by zero in remainder".to_string(),
))
}
} else {
Err(TensorError::InvalidArgument(
"Undefined zero value".to_string(),
))
}
_ => Err(TensorError::InvalidArgument(
"Undefined zero value".to_string(),
)),
})
.collect::<Result<Vec<_>, _>>()?;
@@ -1768,228 +1768,6 @@ pub fn get_broadcasted_shape(
}
////////////////////////
/// The shape of data for some operations
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]
pub enum DataFormat {
/// NCHW
#[default]
NCHW,
/// NHWC
NHWC,
/// CHW
CHW,
/// HWC
HWC,
}
// as str
impl core::fmt::Display for DataFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DataFormat::NCHW => write!(f, "NCHW"),
DataFormat::NHWC => write!(f, "NHWC"),
DataFormat::CHW => write!(f, "CHW"),
DataFormat::HWC => write!(f, "HWC"),
}
}
}
impl DataFormat {
/// Get the format's canonical form
pub fn canonical(&self) -> DataFormat {
match self {
DataFormat::NHWC => DataFormat::NCHW,
DataFormat::HWC => DataFormat::CHW,
_ => self.clone(),
}
}
/// no batch dim
pub fn has_no_batch(&self) -> bool {
match self {
DataFormat::CHW | DataFormat::HWC => true,
_ => false,
}
}
/// Convert tensor to canonical format (NCHW or CHW)
pub fn to_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
&self,
tensor: &mut ValTensor<F>,
) -> Result<(), TensorError> {
match self {
DataFormat::NHWC => {
// For ND: Move channels from last axis to position after batch
let ndims = tensor.dims().len();
if ndims > 2 {
tensor.move_axis(ndims - 1, 1)?;
}
}
DataFormat::HWC => {
// For ND: Move channels from last axis to first position
let ndims = tensor.dims().len();
if ndims > 1 {
tensor.move_axis(ndims - 1, 0)?;
}
}
_ => {} // NCHW/CHW are already in canonical format
}
Ok(())
}
/// Convert tensor from canonical format to target format
pub fn from_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
&self,
tensor: &mut ValTensor<F>,
) -> Result<(), TensorError> {
match self {
DataFormat::NHWC => {
// Move channels from position 1 to end
let ndims = tensor.dims().len();
if ndims > 2 {
tensor.move_axis(1, ndims - 1)?;
}
}
DataFormat::HWC => {
// Move channels from position 0 to end
let ndims = tensor.dims().len();
if ndims > 1 {
tensor.move_axis(0, ndims - 1)?;
}
}
_ => {} // NCHW/CHW don't need conversion
}
Ok(())
}
/// Get the position of the channel dimension
pub fn get_channel_dim(&self, ndims: usize) -> usize {
match self {
DataFormat::NCHW => 1,
DataFormat::NHWC => ndims - 1,
DataFormat::CHW => 0,
DataFormat::HWC => ndims - 1,
}
}
}
/// The shape of the kernel for some operations
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]
pub enum KernelFormat {
/// HWIO
HWIO,
/// OIHW
#[default]
OIHW,
/// OHWI
OHWI,
}
impl core::fmt::Display for KernelFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
KernelFormat::HWIO => write!(f, "HWIO"),
KernelFormat::OIHW => write!(f, "OIHW"),
KernelFormat::OHWI => write!(f, "OHWI"),
}
}
}
impl KernelFormat {
/// Get the format's canonical form
pub fn canonical(&self) -> KernelFormat {
match self {
KernelFormat::HWIO => KernelFormat::OIHW,
KernelFormat::OHWI => KernelFormat::OIHW,
_ => self.clone(),
}
}
/// Convert kernel to canonical format (OIHW)
pub fn to_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
&self,
kernel: &mut ValTensor<F>,
) -> Result<(), TensorError> {
match self {
KernelFormat::HWIO => {
let kdims = kernel.dims().len();
// Move output channels from last to first
kernel.move_axis(kdims - 1, 0)?;
// Move input channels from new last to second position
kernel.move_axis(kdims - 1, 1)?;
}
KernelFormat::OHWI => {
let kdims = kernel.dims().len();
// Move input channels from last to second position
kernel.move_axis(kdims - 1, 1)?;
}
_ => {} // OIHW is already canonical
}
Ok(())
}
/// Convert kernel from canonical format to target format
pub fn from_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
&self,
kernel: &mut ValTensor<F>,
) -> Result<(), TensorError> {
match self {
KernelFormat::HWIO => {
let kdims = kernel.dims().len();
// Move input channels from second position to last
kernel.move_axis(1, kdims - 1)?;
// Move output channels from first to last
kernel.move_axis(0, kdims - 1)?;
}
KernelFormat::OHWI => {
let kdims = kernel.dims().len();
// Move input channels from second position to last
kernel.move_axis(1, kdims - 1)?;
}
_ => {} // OIHW doesn't need conversion
}
Ok(())
}
/// Get the position of input and output channel dimensions
pub fn get_channel_dims(&self, ndims: usize) -> (usize, usize) {
// (input_ch, output_ch)
match self {
KernelFormat::OIHW => (1, 0),
KernelFormat::HWIO => (ndims - 2, ndims - 1),
KernelFormat::OHWI => (ndims - 1, 0),
}
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl From<tract_onnx::tract_hir::ops::nn::DataFormat> for DataFormat {
fn from(fmt: tract_onnx::tract_hir::ops::nn::DataFormat) -> Self {
match fmt {
tract_onnx::tract_hir::ops::nn::DataFormat::NCHW => DataFormat::NCHW,
tract_onnx::tract_hir::ops::nn::DataFormat::NHWC => DataFormat::NHWC,
tract_onnx::tract_hir::ops::nn::DataFormat::CHW => DataFormat::CHW,
tract_onnx::tract_hir::ops::nn::DataFormat::HWC => DataFormat::HWC,
}
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl From<tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat> for KernelFormat {
fn from(fmt: tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat) -> Self {
match fmt {
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::HWIO => {
KernelFormat::HWIO
}
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::OIHW => {
KernelFormat::OIHW
}
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::OHWI => {
KernelFormat::OHWI
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -387,7 +387,7 @@ pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sy
) -> Result<Tensor<T>, TensorError> {
if t.len() == 1 {
return Ok(t[0].clone());
} else if t.is_empty() {
} else if t.len() == 0 {
return Err(TensorError::DimMismatch("add".to_string()));
}
@@ -441,7 +441,7 @@ pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sy
) -> Result<Tensor<T>, TensorError> {
if t.len() == 1 {
return Ok(t[0].clone());
} else if t.is_empty() {
} else if t.len() == 0 {
return Err(TensorError::DimMismatch("sub".to_string()));
}
// calculate value of output
@@ -492,7 +492,7 @@ pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::S
) -> Result<Tensor<T>, TensorError> {
if t.len() == 1 {
return Ok(t[0].clone());
} else if t.is_empty() {
} else if t.len() == 0 {
return Err(TensorError::DimMismatch("mult".to_string()));
}
// calculate value of output
@@ -1326,6 +1326,7 @@ pub fn pad<T: TensorType>(
///
/// # Errors
/// Returns a TensorError if the tensors in `inputs` have incompatible dimensions for concatenation along the specified `axis`.
pub fn concat<T: TensorType + Send + Sync>(
inputs: &[&Tensor<T>],
axis: usize,
@@ -2101,6 +2102,7 @@ pub mod nonlinearities {
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn tanh(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let kix = (a_i as f64) / scale_input;

Binary file not shown.

View File

@@ -1,7 +1,8 @@
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[cfg(test)]
mod native_tests {
use ezkl::circuit::Tolerance;
use ezkl::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
@@ -186,7 +187,7 @@ mod native_tests {
const PF_FAILURE_AGGR: &str = "examples/test_failure_aggr_proof.json";
const LARGE_TESTS: [&str; 8] = [
const LARGE_TESTS: [&str; 7] = [
"self_attention",
"nanoGPT",
"multihead_attention",
@@ -194,7 +195,6 @@ mod native_tests {
"mnist_gan",
"smallworm",
"fr_age",
"1d_conv",
];
const ACCURACY_CAL_TESTS: [&str; 6] = [
@@ -522,7 +522,7 @@ mod native_tests {
use crate::native_tests::run_js_tests;
use crate::native_tests::render_circuit;
use crate::native_tests::model_serialization_different_binaries;
use rand::Rng;
use tempdir::TempDir;
use ezkl::Commitments;
@@ -543,7 +543,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, false, None, None);
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
});
@@ -608,7 +608,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, false, None, None);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -618,10 +618,22 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, true, Some(8194), Some(4));
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true, Some(8194), Some(4));
test_dir.close().unwrap();
}
#(#[test_case(TESTS[N])])*
fn mock_tolerance_public_outputs_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// gen random number between 0.0 and 1.0
let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false, Some(32776), Some(5));
test_dir.close().unwrap();
}
#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
@@ -632,7 +644,7 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, false, None, None);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
}
@@ -642,7 +654,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, false, None, None);
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -651,7 +663,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, false, None, None);
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -660,7 +672,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, false, None, None);
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -669,7 +681,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, false, None, None);
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -678,7 +690,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, false, None, None);
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -687,7 +699,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, false, None, None);
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -696,7 +708,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, false, None, None);
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -706,7 +718,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, false, None, None);
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -716,7 +728,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, false, None, None);
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -725,7 +737,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, false, None, None);
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -735,7 +747,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, false, None, None);
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -744,7 +756,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, false, None, None);
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -754,7 +766,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, false, None, None);
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -764,7 +776,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, false, None, None);
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -774,7 +786,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, false, None, None);
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -784,7 +796,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// needs an extra row for the large model
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, false, None, None);
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -794,7 +806,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
// needs an extra row for the large model
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, false, None, None);
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false, None, None);
test_dir.close().unwrap();
}
@@ -953,7 +965,7 @@ mod native_tests {
});
seq!(N in 0..=7 {
seq!(N in 0..=6 {
#(#[test_case(LARGE_TESTS[N])])*
#[ignore]
@@ -971,7 +983,7 @@ mod native_tests {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, false, None, Some(5));
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false, None, Some(5));
test_dir.close().unwrap();
}
});
@@ -1447,10 +1459,12 @@ mod native_tests {
batch_size: usize,
cal_target: &str,
scales_to_use: Option<Vec<u32>>,
tolerance: f32,
bounded_lookup_log: bool,
decomp_base: Option<usize>,
decomp_legs: Option<usize>,
) {
let mut tolerance = tolerance;
gen_circuit_settings_and_witness(
test_dir,
example_name.clone(),
@@ -1461,6 +1475,7 @@ mod native_tests {
cal_target,
scales_to_use,
2,
&mut tolerance,
Commitments::KZG,
2,
bounded_lookup_log,
@@ -1468,17 +1483,128 @@ mod native_tests {
decomp_legs,
);
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"mock",
"-W",
format!("{}/{}/witness.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
if tolerance > 0.0 {
// load witness and shift the output by a small amount that is less than tolerance percent
let witness = GraphWitness::from_path(
format!("{}/{}/witness.json", test_dir, example_name).into(),
)
.unwrap();
let witness = witness.clone();
let outputs = witness.outputs.clone();
// get values as i64
let output_perturbed_safe: Vec<Vec<halo2curves::bn256::Fr>> = outputs
.iter()
.map(|sv| {
sv.iter()
.map(|v| {
// randomly perturb by a small amount less than tolerance
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
halo2curves::bn256::Fr::zero()
} else {
integer_rep_to_felt(
(felt_to_integer_rep(*v) as f32
* (rand::thread_rng().gen_range(-0.01..0.01) * tolerance))
as IntegerRep,
)
};
*v + perturbation
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
// get values as i64
let output_perturbed_bad: Vec<Vec<halo2curves::bn256::Fr>> = outputs
.iter()
.map(|sv| {
sv.iter()
.map(|v| {
// randomly perturb by a small amount less than tolerance
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
halo2curves::bn256::Fr::from(2)
} else {
integer_rep_to_felt(
(felt_to_integer_rep(*v) as f32
* (rand::thread_rng().gen_range(0.02..0.1) * tolerance))
as IntegerRep,
)
};
*v + perturbation
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
let good_witness = GraphWitness {
outputs: output_perturbed_safe,
..witness.clone()
};
// save
good_witness
.save(format!("{}/{}/witness_ok.json", test_dir, example_name).into())
.unwrap();
let bad_witness = GraphWitness {
outputs: output_perturbed_bad,
..witness.clone()
};
// save
bad_witness
.save(format!("{}/{}/witness_bad.json", test_dir, example_name).into())
.unwrap();
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"mock",
"-W",
format!("{}/{}/witness.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"mock",
"-W",
format!("{}/{}/witness_ok.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"mock",
"-W",
format!("{}/{}/witness_bad.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
])
.status()
.expect("failed to execute process");
assert!(!status.success());
} else {
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"mock",
"-W",
format!("{}/{}/witness.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
}
}
#[allow(clippy::too_many_arguments)]
@@ -1492,6 +1618,7 @@ mod native_tests {
cal_target: &str,
scales_to_use: Option<Vec<u32>>,
num_inner_columns: usize,
tolerance: &mut f32,
commitment: Commitments,
lookup_safety_margin: usize,
bounded_lookup_log: bool,
@@ -1506,16 +1633,13 @@ mod native_tests {
"--settings-path={}/{}/settings.json",
test_dir, example_name
),
format!(
"--variables=batch_size->{},sequence_length->100,<Sym1>->1",
batch_size
),
format!("--variables=batch_size->{}", batch_size),
format!("--input-visibility={}", input_visibility),
format!("--param-visibility={}", param_visibility),
format!("--output-visibility={}", output_visibility),
format!("--num-inner-cols={}", num_inner_columns),
format!("--tolerance={}", tolerance),
format!("--commitment={}", commitment),
format!("--logrows={}", 22),
];
// if output-visibility is fixed set --range-check-inputs-outputs to False
@@ -1571,6 +1695,24 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());
let mut settings =
GraphSettings::load(&format!("{}/{}/settings.json", test_dir, example_name).into())
.unwrap();
let any_output_scales_smol = settings.model_output_scales.iter().any(|s| *s <= 0);
if any_output_scales_smol {
// set the tolerance to 0.0
settings.run_args.tolerance = Tolerance {
val: 0.0,
scale: 0.0.into(),
};
settings
.save(&format!("{}/{}/settings.json", test_dir, example_name).into())
.unwrap();
*tolerance = 0.0;
}
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"compile-circuit",
@@ -1583,6 +1725,7 @@ mod native_tests {
test_dir, example_name
),
])
.stdout(std::process::Stdio::null())
.status()
.expect("failed to execute process");
assert!(status.success());
@@ -1625,6 +1768,7 @@ mod native_tests {
cal_target,
None,
2,
&mut 0.0,
Commitments::KZG,
2,
false,
@@ -1910,6 +2054,7 @@ mod native_tests {
target_str,
scales_to_use,
num_inner_columns,
&mut 0.0,
commitment,
lookup_safety_margin,
false,
@@ -2344,6 +2489,7 @@ mod native_tests {
// we need the accuracy
Some(vec![4]),
1,
&mut 0.0,
Commitments::KZG,
2,
false,

View File

@@ -46,28 +46,7 @@ mod py_tests {
assert!(status.success());
});
// set VOICE_DATA_DIR environment variable
unsafe {
std::env::set_var("VOICE_DATA_DIR", format!("{}", voice_data_dir));
}
}
fn download_catdog_data() {
let cat_and_dog_data_dir = shellexpand::tilde("~/data/catdog_data");
DOWNLOAD_VOICE_DATA.call_once(|| {
let status = Command::new("bash")
.args([
"examples/notebooks/cat_and_dog_data.sh",
&cat_and_dog_data_dir,
])
.status()
.expect("failed to execute process");
assert!(status.success());
});
// set VOICE_DATA_DIR environment variable
unsafe {
std::env::set_var("CATDOG_DATA_DIR", format!("{}", cat_and_dog_data_dir));
}
std::env::set_var("VOICE_DATA_DIR", format!("{}", voice_data_dir));
}
fn setup_py_env() {
@@ -246,20 +225,6 @@ mod py_tests {
anvil_child.kill().unwrap();
}
#[test]
fn cat_and_dog_notebook_() {
crate::py_tests::init_binary();
let mut anvil_child = crate::py_tests::start_anvil(false);
crate::py_tests::download_catdog_data();
let test_dir: TempDir = TempDir::new("cat_and_dog").unwrap();
let path = test_dir.path().to_str().unwrap();
crate::py_tests::mv_test_(path, "cat_and_dog.ipynb");
run_notebook(path, "cat_and_dog.ipynb");
test_dir.close().unwrap();
anvil_child.kill().unwrap();
}
#[test]
fn reusable_verifier_notebook_() {
crate::py_tests::init_binary();

View File

@@ -48,6 +48,7 @@ def test_py_run_args():
run_args = ezkl.PyRunArgs()
run_args.input_visibility = "hashed"
run_args.output_visibility = "hashed"
run_args.tolerance = 1.5
def test_poseidon_hash():
@@ -872,8 +873,7 @@ def get_examples():
'linear_regression',
"mnist_gan",
"smallworm",
"fr_age",
"1d_conv",
"fr_age"
]
examples = []
for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):
@@ -900,12 +900,7 @@ async def test_all_examples(model_file, input_file):
proof_path = os.path.join(folder_path, 'proof.json')
print("Testing example: ", model_file)
run_args = ezkl.PyRunArgs()
run_args.variables = [("batch_size", 1), ("sequence_length", 100), ("<Sym1>", 1)]
run_args.logrows = 22
res = ezkl.gen_settings(model_file, settings_path, py_run_args=run_args)
res = ezkl.gen_settings(model_file, settings_path)
assert res
res = await ezkl.calibrate_settings(

View File

@@ -337,7 +337,7 @@ mod wasm32 {
// Run compiled circuit validation on onnx network (should fail)
let circuit = compiledCircuitValidation(wasm_bindgen::Clamped(NETWORK.to_vec()));
assert!(circuit.is_err());
// Run compiled circuit validation on compiled network (should pass)
// Run compiled circuit validation on comiled network (should pass)
let circuit = compiledCircuitValidation(wasm_bindgen::Clamped(NETWORK_COMPILED.to_vec()));
assert!(circuit.is_ok());
// Run input validation on witness (should fail)