Mirror of https://github.com/zkonduit/ezkl.git
Synced 2026-01-13 08:17:57 -05:00

Compare commits: release-v1...ac/panic-o (21 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | be5d241b42 |  |
|  | ae076aef09 |  |
|  | a7544f4060 |  |
|  | c19fa5218a |  |
|  | eb205d0c73 |  |
|  | db498f8d7c |  |
|  | a363c91160 |  |
|  | f7f04415fa |  |
|  | de8d419e5d |  |
|  | a38d318923 |  |
|  | 864990fe2d |  |
|  | 29c3e4f977 |  |
|  | 0689115828 |  |
|  | 99f741304a |  |
|  | 20ac99fdbf |  |
|  | 532fa65e93 |  |
|  | cfe5db545c |  |
|  | 21ad56aea1 |  |
|  | 4ed7e0fd29 |  |
|  | 05d1f10615 |  |
|  | 9a8c754e45 |  |
44  .github/workflows/benchmarks.yml  (vendored)
@@ -12,10 +12,10 @@ jobs:
       contents: read
     runs-on: self-hosted
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2023-06-27
           override: true

@@ -29,10 +29,10 @@ jobs:
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2023-06-27
           override: true

@@ -46,10 +46,10 @@ jobs:
     runs-on: self-hosted
     needs: [bench_poseidon]
    steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2023-06-27
           override: true

@@ -63,10 +63,10 @@ jobs:
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2023-06-27
           override: true

@@ -80,10 +80,10 @@ jobs:
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2023-06-27
           override: true

@@ -97,10 +97,10 @@ jobs:
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2023-06-27
           override: true

@@ -114,10 +114,10 @@ jobs:
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2023-06-27
           override: true

@@ -131,10 +131,10 @@ jobs:
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2023-06-27
           override: true

@@ -148,10 +148,10 @@ jobs:
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2023-06-27
           override: true

@@ -165,10 +165,10 @@ jobs:
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2023-06-27
           override: true

@@ -182,10 +182,10 @@ jobs:
     runs-on: self-hosted
     needs: [bench_poseidon]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2023-06-27
           override: true
22  .github/workflows/engine.yml  (vendored)
@@ -24,15 +24,15 @@ jobs:
     runs-on: ubuntu-latest
     if: startsWith(github.ref, 'refs/tags/')
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: jetli/wasm-pack-action@v0.4.0
+      - uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
         with:
           # Pin to version 0.12.1
           version: 'v0.12.1'

@@ -47,6 +47,10 @@ jobs:
           curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz | tar xzf -
           export PATH=$PATH:$PWD/binaryen-version_116/bin
           wasm-opt --version
+      - name: Build wasm files for both web and nodejs compilation targets
+        run: |
+          wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
+          wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
       - name: Create package.json in pkg folder
         shell: bash
         run: |

@@ -172,7 +176,7 @@ jobs:
           curl -s "https://raw.githubusercontent.com/zkonduit/ezkljs-engine/main/README.md" > ./pkg/README.md

       - name: Set up Node.js
-        uses: actions/setup-node@v3
+        uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
         with:
           node-version: "18.12.1"
           registry-url: "https://registry.npmjs.org"

@@ -197,7 +201,7 @@ jobs:
       RELEASE_TAG: ${{ github.ref_name }}
     if: startsWith(github.ref, 'refs/tags/')
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
       - name: Update version in package.json

@@ -226,13 +230,13 @@ jobs:
           NR==30{$0=" specifier: \"" tag "\""}
           NR==31{$0=" version: \"" tag "\""}
           NR==400{$0=" /@ezkljs/engine@" tag ":"}
-          NR==401{$0=" resolution: {integrity: \"" integrity "\"}"} 1' in-browser-evm-verifier/pnpm-lock.yaml > temp.yaml && mv temp.yaml in-browser-evm-verifier/pnpm-lock.yaml
+          NR==401{$0=" resolution: {integrity: \"" integrity "\"}"} 1' in-browser-evm-verifier/pnpm-lock.yaml > temp.yaml && mv temp.yaml in-browser-evm-verifier/pnpm-lock.yaml
       - name: Use pnpm 8
-        uses: pnpm/action-setup@v2
+        uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
         with:
           version: 8
       - name: Set up Node.js
-        uses: actions/setup-node@v3
+        uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
         with:
           node-version: "18.12.1"
           registry-url: "https://registry.npmjs.org"

@@ -243,4 +247,4 @@ jobs:
           pnpm run build
           pnpm publish --no-git-checks
         env:
-          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
+          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
4  .github/workflows/large-tests.yml  (vendored)
@@ -10,10 +10,10 @@ jobs:
       contents: read
     runs-on: kaiju
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
34  .github/workflows/pypi-gpu.yml  (vendored)
@@ -28,34 +28,36 @@ jobs:
     env:
       RELEASE_TAG: ${{ github.ref_name }}
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions/setup-python@v4
+      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
         with:
           python-version: 3.12
           architecture: x64

-      - name: Set pyproject.toml version to match github tag
+      - name: Set pyproject.toml version to match github tag and rename ezkl to ezkl-gpu
         shell: bash
         run: |
           mv pyproject.toml pyproject.toml.orig
-          sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig >pyproject.toml
-          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
+          sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig > pyproject.toml.tmp
+          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.tmp > pyproject.toml

-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2023-06-27
           override: true
           components: rustfmt, clippy

-      - name: Set Cargo.toml version to match github tag
+      - name: Set Cargo.toml version to match github tag and rename ezkl to ezkl-gpu
         shell: bash
+        # the ezkl substitution here looks for the first instance of name = "ezkl" and changes it to "ezkl-gpu"
         run: |
           mv Cargo.toml Cargo.toml.orig
-          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
+          sed "0,/name = \"ezkl\"/s/name = \"ezkl\"/name = \"ezkl-gpu\"/" Cargo.toml.orig > Cargo.toml.tmp
+          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.tmp > Cargo.toml
           mv Cargo.lock Cargo.lock.orig
-          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
+          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig > Cargo.lock

       - name: Install required libraries
         shell: bash

@@ -63,7 +65,7 @@ jobs:
           sudo apt-get update && sudo apt-get install -y openssl pkg-config libssl-dev

       - name: Build wheels
-        uses: PyO3/maturin-action@v1
+        uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
         with:
           target: ${{ matrix.target }}
           manylinux: auto

@@ -76,7 +78,7 @@ jobs:
           pip install ezkl-gpu --no-index --find-links dist --force-reinstall

       - name: Upload wheels
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
         with:
           name: wheels
           path: dist

@@ -92,7 +94,7 @@ jobs:
       # needs: [ macos, windows, linux, linux-cross, musllinux, musllinux-cross ]
       needs: [linux]
     steps:
-      - uses: actions/download-artifact@v3
+      - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 #v4.1.8
         with:
           name: wheels
       - name: List Files

@@ -104,14 +106,14 @@ jobs:
       # publishes to PyPI
       - name: Publish package distributions to PyPI
         continue-on-error: true
-        uses: pypa/gh-action-pypi-publish@unstable/v1
+        uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
         with:
-          packages-dir: ./
+          packages-dir: ./wheels

       # publishes to TestPyPI
       - name: Publish package distribution to TestPyPI
         continue-on-error: true
-        uses: pypa/gh-action-pypi-publish@unstable/v1
+        uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
         with:
           repository-url: https://test.pypi.org/legacy/
-          packages-dir: ./
+          packages-dir: ./wheels
84  .github/workflows/pypi.yml  (vendored)
@@ -26,10 +26,10 @@ jobs:
     env:
       RELEASE_TAG: ${{ github.ref_name }}
     steps:
-      - uses: actions/checkout@v4
-        with:
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions/setup-python@v4
+      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
         with:
           python-version: 3.12
           architecture: x64

@@ -48,7 +48,7 @@ jobs:
           mv Cargo.lock Cargo.lock.orig
           sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock

-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true

@@ -56,13 +56,13 @@ jobs:

       - name: Build wheels
         if: matrix.target == 'universal2-apple-darwin'
-        uses: PyO3/maturin-action@v1
+        uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
         with:
           target: ${{ matrix.target }}
           args: --release --out dist --features python-bindings
       - name: Build wheels
         if: matrix.target == 'x86_64'
-        uses: PyO3/maturin-action@v1
+        uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
         with:
           target: ${{ matrix.target }}
           args: --release --out dist --features python-bindings

@@ -73,9 +73,9 @@ jobs:
           python -c "import ezkl"

       - name: Upload wheels
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
         with:
-          name: wheels
+          name: dist-macos-${{ matrix.target }}
           path: dist

   windows:

@@ -87,10 +87,10 @@ jobs:
       matrix:
         target: [x64, x86]
     steps:
-      - uses: actions/checkout@v4
-        with:
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions/setup-python@v4
+      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
         with:
           python-version: 3.12
           architecture: ${{ matrix.target }}

@@ -113,14 +113,14 @@ jobs:
           mv Cargo.lock Cargo.lock.orig
           sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock

-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy

       - name: Build wheels
-        uses: PyO3/maturin-action@v1
+        uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
         with:
           target: ${{ matrix.target }}
           args: --release --out dist --features python-bindings

@@ -130,9 +130,9 @@ jobs:
           python -c "import ezkl"

       - name: Upload wheels
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
         with:
-          name: wheels
+          name: dist-windows-${{ matrix.target }}
           path: dist

   linux:

@@ -144,10 +144,10 @@ jobs:
       matrix:
         target: [x86_64]
     steps:
-      - uses: actions/checkout@v4
-        with:
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions/setup-python@v4
+      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
         with:
           python-version: 3.12
           architecture: x64

@@ -176,7 +176,7 @@ jobs:
           sudo apt-get update && sudo apt-get install -y openssl pkg-config libssl-dev

       - name: Build wheels
-        uses: PyO3/maturin-action@v1
+        uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
         with:
           target: ${{ matrix.target }}
           manylinux: auto

@@ -203,9 +203,9 @@ jobs:
           python -c "import ezkl"

       - name: Upload wheels
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
         with:
-          name: wheels
+          name: dist-linux-${{ matrix.target }}
           path: dist

   musllinux:

@@ -218,10 +218,10 @@ jobs:
         target:
           - x86_64-unknown-linux-musl
     steps:
-      - uses: actions/checkout@v4
-        with:
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions/setup-python@v4
+      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
         with:
           python-version: 3.12
           architecture: x64

@@ -250,7 +250,7 @@ jobs:
           sudo apt-get update && sudo apt-get install -y pkg-config libssl-dev

       - name: Build wheels
-        uses: PyO3/maturin-action@v1
+        uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
         with:
           target: ${{ matrix.target }}
           manylinux: musllinux_1_2

@@ -271,9 +271,9 @@ jobs:
           python3 -c "import ezkl"

       - name: Upload wheels
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
         with:
-          name: wheels
+          name: dist-musllinux-${{ matrix.target }}
           path: dist

   musllinux-cross:

@@ -287,10 +287,10 @@ jobs:
           - target: aarch64-unknown-linux-musl
             arch: aarch64
     steps:
-      - uses: actions/checkout@v4
-        with:
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions/setup-python@v4
+      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
         with:
           python-version: 3.12

@@ -313,13 +313,13 @@ jobs:
           sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock

       - name: Build wheels
-        uses: PyO3/maturin-action@v1
+        uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
         with:
           target: ${{ matrix.platform.target }}
           manylinux: musllinux_1_2
           args: --release --out dist --features python-bindings

-      - uses: uraimo/run-on-arch-action@v2.8.1
+      - uses: uraimo/run-on-arch-action@5397f9e30a9b62422f302092631c99ae1effcd9e #v2.8.1
        name: Install built wheel
         with:
           arch: ${{ matrix.platform.arch }}

@@ -334,9 +334,9 @@ jobs:
           python3 -c "import ezkl"

       - name: Upload wheels
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
         with:
-          name: wheels
+          name: dist-musllinux-${{ matrix.platform.target }}
           path: dist

   pypi-publish:

@@ -347,24 +347,26 @@ jobs:
     if: "startsWith(github.ref, 'refs/tags/')"
     needs: [macos, windows, linux, musllinux, musllinux-cross]
     steps:
-      - uses: actions/download-artifact@v3
+      - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 #v4.1.8
         with:
-          name: wheels
+          pattern: dist-*
+          merge-multiple: true
+          path: wheels
       - name: List Files
         run: ls -R

       # # publishes to TestPyPI
       # - name: Publish package distribution to TestPyPI
-      #   uses: pypa/gh-action-pypi-publish@unstable/v1
+      #   uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
       #   with:
       #     repository-url: https://test.pypi.org/legacy/
       #     packages-dir: ./

       # publishes to PyPI
       - name: Publish package distributions to PyPI
-        uses: pypa/gh-action-pypi-publish@unstable/v1
+        uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
         with:
-          packages-dir: ./
+          packages-dir: ./wheels

   doc-publish:

@@ -375,7 +377,7 @@ jobs:
     needs: pypi-publish
     steps:
-      - uses: actions/checkout@v4
-        with:
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
       - name: Trigger RTDs build
         uses: dfm/rtds-action@v1
20  .github/workflows/release.yml  (vendored)
@@ -30,7 +30,7 @@ jobs:

       - name: Create Github Release
         id: create-release
-        uses: softprops/action-gh-release@v1
+        uses: softprops/action-gh-release@c95fe1489396fe8a9eb87c0abf8aa5b2ef267fda #v2.2.1
         with:
           token: ${{ secrets.RELEASE_TOKEN }}
           tag_name: ${{ env.EZKL_VERSION }}

@@ -49,14 +49,14 @@ jobs:
       RUST_BACKTRACE: 1
       PCRE2_SYS_STATIC: 1
     steps:
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - name: Checkout repo
-        uses: actions/checkout@v4
-        with:
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false

@@ -90,7 +90,7 @@ jobs:
           echo "ASSET=build-artifacts/ezkl-linux-gpu.tar.gz" >> $GITHUB_ENV

       - name: Upload release archive
-        uses: actions/upload-release-asset@v1.0.2
+        uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 #v1.0.2
         env:
           GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }}
         with:

@@ -144,8 +144,8 @@ jobs:

     steps:
       - name: Checkout repo
-        uses: actions/checkout@v4
-        with:
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false

       - name: Get release version from tag

@@ -170,7 +170,7 @@ jobs:
           fi

       - name: Install Rust
-        uses: dtolnay/rust-toolchain@nightly
+        uses: dtolnay/rust-toolchain@4f94fbe7e03939b0e674bcc9ca609a16088f63ff #nightly branch, TODO: update when required
         with:
           target: ${{ matrix.target }}

@@ -196,7 +196,7 @@ jobs:
           echo "target flag is: ${{ env.TARGET_FLAGS }}"
           echo "target dir is: ${{ env.TARGET_DIR }}"

-      - name: Build release binary (no asm or metal)
+      - name: Build release binary (no asm or metal)
         if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
         run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry

@@ -233,7 +233,7 @@ jobs:
           echo "ASSET=build-artifacts/ezkl-win.zip" >> $GITHUB_ENV

       - name: Upload release archive
-        uses: actions/upload-release-asset@v1.0.2
+        uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 #v1.0.2
         env:
           GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }}
         with:
176  .github/workflows/rust.yml  (vendored)
@@ -25,10 +25,10 @@ jobs:
       contents: read
     runs-on: large-self-hosted
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true

@@ -45,10 +45,10 @@ jobs:
       contents: read
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true

@@ -61,10 +61,10 @@ jobs:
       contents: read
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true

@@ -77,15 +77,15 @@ jobs:
       contents: read
     runs-on: ubuntu-latest-32-cores
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: baptiste0928/cargo-install@v1
+      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
         with:
           crate: cargo-nextest
           locked: true

@@ -102,19 +102,19 @@ jobs:
   #     env:
   #       ENABLE_ICICLE_GPU: true
   #     steps:
-  #       - uses: actions/checkout@v4
+  #       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+  #         with:
+  #           persist-credentials: false
-  #       - uses: actions-rs/toolchain@v1
+  #       - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
   #         with:
   #           toolchain: nightly-2024-07-18
   #           override: true
   #           components: rustfmt, clippy
-  #       - uses: baptiste0928/cargo-install@v1
+  #       - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
   #         with:
   #           crate: cargo-nextest
   #           locked: true
-  #       - uses: mwilliamson/setup-wasmtime-action@v2
+  #       - uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
   #         with:
   #           wasmtime-version: "3.0.1"
   #       - name: Install wasm32-wasi

@@ -139,19 +139,19 @@ jobs:
       contents: read
     runs-on: non-gpu
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: baptiste0928/cargo-install@v1
+      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
         with:
           crate: cargo-nextest
           locked: true
-      - uses: mwilliamson/setup-wasmtime-action@v2
+      - uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
         with:
           wasmtime-version: "3.0.1"
       - name: Install wasm32-wasi

@@ -176,19 +176,19 @@ jobs:
       contents: read
     runs-on: non-gpu
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: baptiste0928/cargo-install@v1
+      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
         with:
           crate: cargo-nextest
           locked: true
-      - uses: mwilliamson/setup-wasmtime-action@v2
+      - uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
         with:
           wasmtime-version: "3.0.1"
       - name: Install wasm32-wasi

@@ -213,15 +213,15 @@ jobs:
       contents: read
     runs-on: ubuntu-latest-16-cores
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: baptiste0928/cargo-install@v1
+      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
         with:
           crate: cargo-nextest
           locked: true

@@ -233,19 +233,19 @@ jobs:
       contents: read
     runs-on: non-gpu
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: jetli/wasm-pack-action@v0.4.0
+      - uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
         with:
           # Pin to version 0.12.1
           version: "v0.12.1"
-      - uses: nanasess/setup-chromedriver@v2
+      - uses: nanasess/setup-chromedriver@e93e57b843c0c92788f22483f1a31af8ee48db25 #v2.3.0
         # with:
         #   chromedriver-version: "115.0.5790.102"
       - name: Install wasm32-unknown-unknown

@@ -262,20 +262,22 @@ jobs:
       contents: read
     runs-on: non-gpu
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: baptiste0928/cargo-install@v1
+      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
         with:
           crate: cargo-nextest
           locked: true
       # - name: The Worm Mock
       #   run: cargo nextest run --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
+      - name: Large 1D Conv Mock
+        run: cargo nextest run --verbose tests::large_mock_::large_tests_7_expects -- --include-ignored
       - name: MNIST Gan Mock
         run: cargo nextest run --verbose tests::large_mock_::large_tests_4_expects -- --include-ignored
       - name: NanoGPT Mock

@@ -292,8 +294,6 @@ jobs:
         run: cargo nextest run --verbose tests::mock_fixed_params_ --test-threads 32
-      - name: public outputs and bounded lookup log
-        run: cargo nextest run --verbose tests::mock_bounded_lookup_log --test-threads 32
       - name: public outputs and tolerance > 0
         run: cargo nextest run --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
       - name: public outputs + batch size == 10
         run: cargo nextest run --verbose tests::mock_large_batch_public_outputs_ --test-threads 16
       - name: kzg inputs

@@ -329,27 +329,27 @@ jobs:
     runs-on: non-gpu
     needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: baptiste0928/cargo-install@v1
+      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
         with:
           crate: cargo-nextest
           locked: true
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
       - name: Use pnpm 8
-        uses: pnpm/action-setup@v2
+        uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
         with:
           version: 8
       - name: Use Node.js 18.12.1
-        uses: actions/setup-node@v3
+        uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
         with:
           node-version: "18.12.1"
           cache: "pnpm"

@@ -414,28 +414,28 @@ jobs:
   #     runs-on: macos-13
   #     # needs: [build, library-tests, docs]
   #     steps:
-  #       - uses: actions/checkout@v4
+  #       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+  #         with:
+  #           persist-credentials: false
-  #       - uses: actions-rs/toolchain@v1
+  #       - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
   #         with:
   #           toolchain: nightly-2024-07-18
   #           override: true
   #           components: rustfmt, clippy
-  #       - uses: jetli/wasm-pack-action@v0.4.0
+  #       - uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
   #         with:
   #           # Pin to version 0.12.1
   #           version: 'v0.12.1'
   #       - name: Add rust-src
   #         run: rustup component add rust-src --toolchain nightly-2024-07-18
-  #       - uses: actions/checkout@v3
+  #       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
   #         with:
   #           persist-credentials: false
   #       - name: Use pnpm 8
-  #         uses: pnpm/action-setup@v2
+  #         uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
   #         with:
   #           version: 8
-  #       - uses: baptiste0928/cargo-install@v1
+  #       - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
   #         with:
   #           crate: cargo-nextest
   #           locked: true

@@ -448,15 +448,15 @@ jobs:
     runs-on: non-gpu
     needs: [build, library-tests, docs]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: jetli/wasm-pack-action@v0.4.0
+      - uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
         with:
           # Pin to version 0.12.1
           version: "v0.12.1"

@@ -465,15 +465,15 @@ jobs:

       - name: Add rust-src
         run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
       - name: Use pnpm 8
-        uses: pnpm/action-setup@v2
+        uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
         with:
           version: 8
       - name: Use Node.js 18.12.1
-        uses: actions/setup-node@v3
+        uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
         with:
           node-version: "18.12.1"
           cache: "pnpm"

@@ -483,7 +483,7 @@ jobs:
         env:
           CI: false
           NODE_ENV: development
-      - uses: baptiste0928/cargo-install@v1
+      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
         with:
           crate: cargo-nextest
           locked: true

@@ -529,18 +529,18 @@ jobs:
   #     env:
   #       ENABLE_ICICLE_GPU: true
   #     steps:
-  #       - uses: actions/checkout@v4
+  #       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+  #         with:
+  #           persist-credentials: false
-  #       - uses: actions-rs/toolchain@v1
+  #       - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
   #         with:
   #           toolchain: nightly-2024-07-18
   #           override: true
   #           components: rustfmt, clippy
   #       - name: Add rust-src
   #         run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
-  #       - uses: actions/checkout@v3
-  #       - uses: baptiste0928/cargo-install@v1
+  #       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+  #       - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
   #         with:
   #           crate: cargo-nextest
   #           locked: true

@@ -567,15 +567,15 @@ jobs:
     runs-on: self-hosted
     needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: dtolnay/rust-toolchain@4f94fbe7e03939b0e674bcc9ca609a16088f63ff #nightly branch, TODO: update when required
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: baptiste0928/cargo-install@v1
+      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
         with:
           crate: cargo-nextest
           locked: true

@@ -587,15 +587,15 @@ jobs:
   #     env:
   #       ENABLE_ICICLE_GPU: true
   #     steps:
-  #       - uses: actions/checkout@v4
+  #       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+  #         with:
+  #           persist-credentials: false
-  #       - uses: actions-rs/toolchain@v1
+  #       - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
   #         with:
   #           toolchain: nightly-2024-07-18
   #           override: true
   #           components: rustfmt, clippy
-  #       - uses: baptiste0928/cargo-install@v1
+  #       - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
   #         with:
   #           crate: cargo-nextest
   #           locked: true

@@ -608,15 +608,15 @@ jobs:
     runs-on: large-self-hosted
     needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: baptiste0928/cargo-install@v1
+      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
         with:
           crate: cargo-nextest
           locked: true

@@ -629,15 +629,15 @@ jobs:
     runs-on: large-self-hosted
     needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: baptiste0928/cargo-install@v1
+      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
         with:
           crate: cargo-nextest
           locked: true

@@ -654,15 +654,15 @@ jobs:
     runs-on: ubuntu-latest-32-cores
     needs: [build, library-tests, docs]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: baptiste0928/cargo-install@v1
+      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
         with:
           crate: cargo-nextest
           locked: true

@@ -675,13 +675,13 @@ jobs:
     runs-on: non-gpu
     needs: [build, library-tests, docs]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions/setup-python@v4
+      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
         with:
           python-version: "3.12"
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true

@@ -705,18 +705,18 @@ jobs:
     runs-on: non-gpu
     needs: [build, library-tests, docs, python-tests, python-integration-tests]
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions/setup-python@v4
+      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
         with:
           python-version: "3.12"
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: baptiste0928/cargo-install@v1
+      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
         with:
           crate: cargo-nextest
           locked: true

@@ -756,18 +756,18 @@ jobs:
           # Maps tcp port 5432 on service container to the host
           - 5432:5432
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions/setup-python@v4
+      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
         with:
           python-version: "3.11"
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: baptiste0928/cargo-install@v1
+      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
         with:
           crate: cargo-nextest
           locked: true

@@ -781,6 +781,8 @@ jobs:
         run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
       - name: Build python ezkl
         run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --profile=test-runs
+      - name: Cat and Dog notebook
+        run: source .env/bin/activate; cargo nextest run py_tests::tests::cat_and_dog_notebook_
       - name: All notebooks
         run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
       - name: Voice tutorial

@@ -812,15 +814,15 @@ jobs:
       contents: read
     runs-on: macos-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
-      - uses: baptiste0928/cargo-install@v1
+      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
         with:
           crate: cargo-nextest
           locked: true

@@ -834,10 +836,10 @@ jobs:
     needs: [ios-integration-tests]

     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true
5  .github/workflows/static-analysis.yml  (vendored)
@@ -12,10 +12,10 @@ jobs:
       contents: read
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
-      - uses: actions-rs/toolchain@v1
+      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly-2024-07-18
           override: true

@@ -30,4 +30,3 @@ jobs:
         run: zizmor .
-
6  .github/workflows/swift-pm.yml  (vendored)
@@ -19,8 +19,8 @@ jobs:

     steps:
       - name: Checkout EZKL
-        uses: actions/checkout@v3
-        with:
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false

       - name: Extract TAG from github.ref_name

@@ -34,7 +34,7 @@ jobs:
           echo "TAG=$NEW_TAG" >> $GITHUB_ENV

       - name: Install Rust (nightly)
-        uses: actions-rs/toolchain@v1
+        uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
         with:
           toolchain: nightly
           override: true
6  .github/workflows/tagging.yml  (vendored)
@@ -11,12 +11,12 @@ jobs:
       contents: write

     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
+        with:
+          persist-credentials: false
       - name: Bump version and push tag
         id: tag_version
-        uses: mathieudutour/github-tag-action@v6.2
+        uses: mathieudutour/github-tag-action@a22cf08638b34d5badda920f9daf6e72c477b07b #v6.2
         with:
           github_token: ${{ secrets.GITHUB_TOKEN }}

@@ -46,7 +46,7 @@ jobs:
           git tag $RELEASE_TAG

       - name: Push changes
-        uses: ad-m/github-push-action@master
+        uses: ad-m/github-push-action@77c5b412c50b723d2a4fbc6d71fb5723bcd439aa #master
         env:
           RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
         with:
8  Cargo.lock  (generated)
@@ -944,7 +944,7 @@ dependencies = [
 "bitflags 2.5.0",
 "cexpr",
 "clang-sys",
-"itertools 0.12.1",
+"itertools 0.11.0",
 "lazy_static",
 "lazycell",
 "log",

@@ -2397,7 +2397,7 @@ dependencies = [
 [[package]]
 name = "halo2_gadgets"
 version = "0.2.0"
-source = "git+https://github.com/zkonduit/halo2#d7ecad83c7439fa1cb450ee4a89c2d0b45604ceb"
+source = "git+https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d"
 dependencies = [
 "arrayvec 0.7.4",
 "bitvec",

@@ -2414,7 +2414,7 @@ dependencies = [
 [[package]]
 name = "halo2_proofs"
 version = "0.3.0"
-source = "git+https://github.com/zkonduit/halo2#bf9d0057a82443be48c4779bbe14961c18fb5996#bf9d0057a82443be48c4779bbe14961c18fb5996"
+source = "git+https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d#f441c920be45f8f05d2c06a173d82e8885a5ed4d"
 dependencies = [
 "bincode",
 "blake2b_simd",

@@ -3139,7 +3139,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
 dependencies = [
 "cfg-if",
-"windows-targets 0.52.6",
+"windows-targets 0.48.5",
 ]

 [[package]]
Cargo.toml

@@ -276,10 +276,11 @@ macos-metal = ["halo2_proofs/macos"]
 ios-metal = ["halo2_proofs/ios"]

 [patch.'https://github.com/zkonduit/halo2']
-halo2_proofs = { git = "https://github.com/zkonduit/halo2#bf9d0057a82443be48c4779bbe14961c18fb5996", package = "halo2_proofs" }
+halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }

 [patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
-halo2_proofs = { git = "https://github.com/zkonduit/halo2#bf9d0057a82443be48c4779bbe14961c18fb5996", package = "halo2_proofs" }
+halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }

 [patch.crates-io]
 uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
README.md

@@ -150,6 +150,13 @@ Ezkl is unaudited, beta software undergoing rapid development. There may be bugs.

 > NOTE: Because operations are quantized when they are converted from an onnx file to a zk-circuit, outputs in python and ezkl may differ slightly.
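As an aside for readers wondering where that drift comes from: the sketch below is ours, not ezkl's actual quantizer, and the `quantize`/`dequantize` helpers and `scale` parameter are hypothetical names. It shows how fixed-point rounding alone makes a quantized product differ slightly from its float counterpart.

```rust
// Minimal sketch of symmetric fixed-point quantization with `scale`
// fractional bits. Hypothetical helpers, not ezkl's internal API.
fn quantize(x: f64, scale: u32) -> i64 {
    (x * f64::from(1u32 << scale)).round() as i64
}

fn dequantize(q: i64, scale: u32) -> f64 {
    q as f64 / f64::from(1u32 << scale)
}

fn main() {
    let (x, y, scale) = (0.3_f64, 0.6_f64, 7);
    // Float math: 0.3 * 0.6 = 0.18
    let float_prod = x * y;
    // Quantized math: the product carries 2 * scale fractional bits and
    // must be rescaled, losing precision to rounding at each step.
    let q_prod = quantize(x, scale) * quantize(y, scale);
    let requantized = dequantize(q_prod >> scale, scale);
    // Prints "float: 0.18 vs quantized: 0.171875" for these inputs.
    println!("float: {float_prod} vs quantized: {requantized}");
}
```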
+### Advanced security topics
+
+Check out `docs/advanced_security` for more advanced information on potential threat vectors.

 ### no warranty

 Copyright (c) 2024 Zkonduit Inc. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
@@ -73,6 +73,8 @@ impl Circuit<Fr> for MyCircuit {
                 padding: vec![(0, 0)],
                 stride: vec![1; 2],
                 group: 1,
+                data_format: DataFormat::NCHW,
+                kernel_format: KernelFormat::OIHW,
             }),
         )
         .unwrap();

@@ -69,6 +69,7 @@ impl Circuit<Fr> for MyCircuit {
                 stride: vec![1, 1],
                 kernel_shape: vec![2, 2],
                 normalized: false,
+                data_format: DataFormat::NCHW,
             }),
         )
         .unwrap();

@@ -23,8 +23,6 @@ use halo2curves::bn256::{Bn256, Fr};
 use rand::rngs::OsRng;
 use snark_verifier::system::halo2::transcript::evm::EvmTranscript;

-const L: usize = 10;
-
 #[derive(Clone, Debug)]
 struct MyCircuit {
     image: ValTensor<Fr>,

@@ -40,7 +38,7 @@ impl Circuit<Fr> for MyCircuit {
     }

     fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
-        PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, 10>::configure(cs, ())
+        PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::configure(cs, ())
     }

     fn synthesize(

@@ -48,7 +46,7 @@ impl Circuit<Fr> for MyCircuit {
         config: Self::Config,
         mut layouter: impl Layouter<Fr>,
     ) -> Result<(), Error> {
-        let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L> =
+        let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE> =
             PoseidonChip::new(config);
         chip.layout(&mut layouter, &[self.image.clone()], 0, &mut HashMap::new())?;
         Ok(())

@@ -59,7 +57,7 @@ fn runposeidon(c: &mut Criterion) {
     let mut group = c.benchmark_group("poseidon");

     for size in [64, 784, 2352, 12288].iter() {
-        let k = (PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L>::num_rows(*size)
+        let k = (PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::num_rows(*size)
             as f32)
             .log2()
             .ceil() as u32;

@@ -67,7 +65,7 @@ fn runposeidon(c: &mut Criterion) {

     let message = (0..*size).map(|_| Fr::random(OsRng)).collect::<Vec<_>>();
     let _output =
-        PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L>::run(message.to_vec())
+        PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.to_vec())
             .unwrap();

     let mut image = Tensor::from(message.into_iter().map(Value::known));
41  docs/advanced_security/public_commitments.md  (new file)
@@ -0,0 +1,41 @@
## EZKL Security Note: Public Commitments and Low-Entropy Data

> **Disclaimer:** this is a more technical post that requires some prior knowledge of how ZK proving systems like Halo2 operate, and in particular how their APIs are constructed. For background reading we highly recommend the [Halo2 book](https://zcash.github.io/halo2/) and [Halo2 Club](https://halo2.club/).

## Overview of commitments in EZKL

A common design pattern in a zero knowledge (zk) application is as follows:
- A prover has some data which is used within a circuit.
- This data, as it may be high-dimensional or somewhat private, is pre-committed to using some hash function.
- The zk-circuit which forms the core of the application then proves (paraphrasing) a statement of the form:
> "I know some data D which, when hashed, corresponds to the pre-committed value H + whatever else the circuit is proving over D."

From our own experience, we've implemented such patterns using snark-friendly hash functions like [Poseidon](https://www.poseidon-hash.info/), for which there is a relatively well vetted [implementation](https://docs.rs/halo2_gadgets/latest/halo2_gadgets/poseidon/index.html) in Halo2. Even then, these hash functions can introduce lots of overhead and can be very expensive to generate proofs for if the dimensionality of the data D is large.

You can also implement such a pattern using Halo2's `Fixed` columns _if privacy preservation of the pre-image is not necessary_. These are Halo2 columns (i.e. in reality just polynomials) that are left unblinded (unlike the blinded `Advice` columns), and whose commitments are shared with the verifier by way of the verifying key for the application's zk-circuit. These commitments are much lower cost to generate than implementing a hashing function, such as Poseidon, within a circuit.

> **Note:** Blinding is the process whereby a certain set of the final elements (i.e. rows) of a Halo2 column are set to random field elements. This is the mechanism by which Halo2 achieves its zero knowledge properties for `Advice` columns. By contrast, `Fixed` columns aren't zero-knowledge, in that they are vulnerable to dictionary attacks in the same manner a hash function is: given some set of known or popular data D, an attacker can attempt to recover the pre-image of a hash by running D through the hash function to see if the outputs match a public commitment. These attacks aren't "possible" on blinded `Advice` columns.

> **Further Note:** Without blinding, an attacker with access to M proofs, each of which contains an evaluation of the polynomial at a different point, can more easily recover a non-blinded column's pre-image. Each proof generates a new query and evaluation of the polynomial represented by the column, so with repetition a clearer picture of the column's pre-image emerges. Thus unblinded columns should only be used for privacy preservation, in the manner of a hash, if the number of proofs generated against a fixed set of values is limited. More formally, if M independent and _unique_ queries are generated, and M equals the degree + 1 of the polynomial represented by the column (i.e. the unique Lagrange interpolation of the values in the column), then the column's pre-image can be recovered. As the number of logrows K increases, more queries are required to recover the pre-image (2^K unique queries are needed). This assumes that the entries in the column are not structured; if they are, the number of queries required to recover the pre-image is reduced (e.g. if all rows above a certain point are known to be nil).
|
||||
|
||||
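To make the interpolation bound in the note above concrete (a sketch, using standard polynomial interpolation; the symbols follow the note):

```latex
% a column with 2^K rows encodes a polynomial p of degree d = 2^K - 1;
% d + 1 = 2^K distinct evaluations determine p uniquely
\[
\deg(p) = 2^K - 1
\quad\Longrightarrow\quad
M = 2^K \text{ unique queries recover } p \text{ (and hence the pre-image) by Lagrange interpolation.}
\]
```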
The annoyance in using `Fixed` columns comes from the fact that they require generating a new verifying key every time a new set of commitments is generated.

> **Example:** Say, for instance, an application leverages a zero-knowledge circuit to prove the correct execution of a neural network. Every week the neural network is fine-tuned or retrained on new data. If the architecture remains the same, then committing to the new network parameters, along with a new proof of performance on a test set, would be an ideal setup. If we leverage `Fixed` columns to commit to the model parameters, each new commitment will require re-generating a verifying key and sharing the new key with the verifier(s). This is not ideal UX and can become expensive if the verifier is deployed on-chain.

An ideal commitment would thus have the low cost of a `Fixed` column but wouldn't require regenerating a new verifying key for each new commitment.

### Unblinded Advice Columns

A first step in designing such a commitment is to allow for optionally unblinded `Advice` columns within the Halo2 API. These won't be included in the verifying key, and are "blinded" with a constant factor of `1` -- such that if someone knows the pre-image to the commitment, they can recover it by running it through the corresponding polynomial commitment scheme (in ezkl's case [KZG commitments](https://dankradfeist.de/ethereum/2020/06/16/kate-polynomial-commitments.html)).

This is implemented using the `polycommit` visibility parameter in the ezkl API, as the sketch below shows.
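A hedged sketch of selecting this visibility via the Python bindings (the `PyRunArgs` type appears in this changeset; the `gen_settings` call and its keyword argument are assumptions to check against your installed version):

```python
# a hedged sketch: commit to the model parameters with the low-cost
# `polycommit` scheme instead of re-keying Fixed columns or hashing
# with Poseidon in-circuit
import ezkl

run_args = ezkl.PyRunArgs()
run_args.param_visibility = "polycommit"  # KZG-commit the weights

# regenerate circuit settings with this visibility (paths are illustrative)
ezkl.gen_settings("network.onnx", "settings.json", py_run_args=run_args)
```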
## The Vulnerability of Public Commitments

Public commitments in EZKL (both Poseidon-hashed inputs and KZG commitments) can be vulnerable to brute-force attacks when input data has low entropy. A malicious actor could reveal committed data by searching through possible input values, compromising privacy in applications like anonymous credentials. This is particularly relevant when input data comes from known finite sets (e.g., names, dates).

Example Risk: In an anonymous credential system using EZKL for ID verification, an attacker could match hashed outputs against a database of common identifying information to deanonymize users. A sketch of such a dictionary attack follows.
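The sketch below illustrates the attack on a low-entropy pre-image (the `poseidon_hash` binding is the one shown in this changeset; the target and candidate set are illustrative):

```python
# a hedged sketch of a dictionary attack on a low-entropy Poseidon commitment:
# if D is drawn from a small known set, the commitment H = Poseidon(D)
# identifies D after at most |candidates| hash evaluations
import ezkl

target_commitment = ...  # the public commitment under attack (assumed known)

# e.g. birth years: ~100 candidates, so ~100 hashes suffice
candidates = [[str(year)] for year in range(1920, 2020)]

for candidate in candidates:
    if ezkl.poseidon_hash(candidate) == target_commitment:
        print(f"pre-image recovered: {candidate}")
        break
```

A standard mitigation is to salt the committed data with a high-entropy nonce known to the prover, which pushes the search space out of brute-force range.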
22
docs/advanced_security/quantization_backdoors.md
Normal file
@@ -0,0 +1,22 @@
# EZKL Security Note: Quantization-Induced Model Backdoors

> Note: this only affects a situation where a party separate from an application's developer has access to the model's weights and can modify them. This is a common scenario in adversarial machine learning research, but can be less common in real-world applications. If you're building your models in house and deploying them yourself, this is less of a concern. If you're building a permissionless system where anyone can submit models, this is more of a concern.

Models processed through EZKL's quantization step can harbor backdoors that are dormant in the original full-precision model but activate during quantization. These backdoors force specific outputs when triggered, with impact varying by application.

Key Factors:

- Larger models increase attack feasibility through more parameter capacity
- Smaller quantization scales facilitate attacks by allowing greater weight modifications
- A rebase ratio of 1 enables exploitation of convolutional layer consistency

Limitations:

- Attack effectiveness depends on calibration settings and internal rescaling operations.
- Further research is needed on backdoor persistence through witness/proof stages.
- Can be mitigated by evaluating the quantized model (using `ezkl gen-witness`), rather than relying on the evaluation of the original model; see the sketch below.
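A minimal sketch of that mitigation, driving the CLI from Python (the flag names are illustrative assumptions; check `ezkl gen-witness --help` for your installed version):

```python
# evaluate the *quantized* circuit rather than trusting the original
# full-precision model's predictions; a backdoor that only activates
# after quantization shows up here
import json
import subprocess

subprocess.run(
    ["ezkl", "gen-witness",
     "-D", "input.json",          # quantized inputs (path illustrative)
     "-M", "network.compiled",    # the compiled (quantized) circuit
     "-O", "witness.json"],
    check=True,
)

with open("witness.json") as f:
    witness = json.load(f)

# compare these rescaled outputs against the float model's outputs on a
# held-out trigger-detection set before accepting the model
print(witness["outputs"])
```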
References:

1. [Quantization Backdoors to Deep Learning Commercial Frameworks (Ma et al., 2021)](https://arxiv.org/abs/2108.09187)
2. [Planting Undetectable Backdoors in Machine Learning Models (Goldwasser et al., 2022)](https://arxiv.org/abs/2204.06974)
@@ -32,6 +32,7 @@ use mnist::*;
use rand::rngs::OsRng;
use std::marker::PhantomData;

mod params;

const K: usize = 20;
@@ -208,6 +209,8 @@ where
            padding: vec![(PADDING, PADDING); 2],
            stride: vec![STRIDE; 2],
            group: 1,
            data_format: DataFormat::NCHW,
            kernel_format: KernelFormat::OIHW,
        };
        let x = config
            .layer_config

1110
examples/notebooks/cat_and_dog.ipynb
Normal file
File diff suppressed because it is too large
13
examples/notebooks/cat_and_dog_data.sh
Normal file
@@ -0,0 +1,13 @@
# download test data
# check if the first argument has been set
if [ ! -z "$1" ]; then
    DATA_DIR=$1
else
    DATA_DIR=data
fi

echo "Downloading data to $DATA_DIR"

if [ ! -d "$DATA_DIR/CATDOG" ]; then
    kaggle datasets download tongpython/cat-and-dog -p $DATA_DIR/CATDOG --unzip
fi
106
examples/onnx/1d_conv/input.json
Normal file
@@ -0,0 +1,106 @@
{
    "input_data": [
        [
            8761, 7654, 8501, 2404, 6929, 8858, 5946, 3673, 4131, 3854,
            8137, 8239, 9038, 6299, 1118, 9737, 208, 7954, 3691, 610,
            3468, 3314, 8658, 8366, 2850, 477, 6114, 232, 4601, 7420,
            5713, 2936, 6061, 2870, 8421, 177, 7107, 7382, 6115, 5487,
            8502, 2559, 1875, 129, 8533, 8201, 8414, 4775, 9817, 3127,
            8761, 7654, 8501, 2404, 6929, 8858, 5946, 3673, 4131, 3854,
            8137, 8239, 9038, 6299, 1118, 9737, 208, 7954, 3691, 610,
            3468, 3314, 8658, 8366, 2850, 477, 6114, 232, 4601, 7420,
            5713, 2936, 6061, 2870, 8421, 177, 7107, 7382, 6115, 5487,
            8502, 2559, 1875, 129, 8533, 8201, 8414, 4775, 9817, 3127
        ]
    ]
}
BIN
examples/onnx/1d_conv/network.onnx
Normal file
Binary file not shown.
42
examples/onnx/integer_div/gen.py
Normal file
@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, x):
        return x // 3


circuit = MyModel()

x = torch.randint(0, 10, (1, 2, 2, 8))

out = circuit(x)

print(x)
print(out)
print(x/3)

torch.onnx.export(circuit, x, "network.onnx",
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=17,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                'output': {0: 'batch_size'}})


d1 = ((x).detach().numpy()).reshape([-1]).tolist()

data = dict(
    input_data=[d1],
)

# Serialize data into file:
json.dump(data, open("input.json", 'w'))
1
examples/onnx/integer_div/input.json
Normal file
@@ -0,0 +1 @@
{"input_data": [[3, 4, 0, 9, 2, 6, 2, 5, 1, 5, 3, 5, 5, 7, 0, 2, 6, 1, 4, 4, 1, 9, 7, 7, 5, 8, 2, 0, 1, 5, 9, 8]]}
BIN
examples/onnx/integer_div/network.onnx
Normal file
Binary file not shown.
@@ -4,11 +4,10 @@ use crate::circuit::modules::poseidon::{
    PoseidonChip,
};
use crate::circuit::modules::Module;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::circuit::{CheckMode, Tolerance};
use crate::commands::*;
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::graph::modules::POSEIDON_LEN_GRAPH;
use crate::graph::TestDataSource;
use crate::graph::{
    quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
@@ -156,9 +155,6 @@ impl pyo3::ToPyObject for PyG1Affine {
#[derive(Clone)]
#[gen_stub_pyclass]
struct PyRunArgs {
    #[pyo3(get, set)]
    /// float: The tolerance for error on model outputs
    pub tolerance: f32,
    #[pyo3(get, set)]
    /// int: The denominator in the fixed point representation used when quantizing inputs
    pub input_scale: crate::Scale,
@@ -226,7 +222,6 @@ impl From<PyRunArgs> for RunArgs {
    fn from(py_run_args: PyRunArgs) -> Self {
        RunArgs {
            bounded_log_lookup: py_run_args.bounded_log_lookup,
            tolerance: Tolerance::from(py_run_args.tolerance),
            input_scale: py_run_args.input_scale,
            param_scale: py_run_args.param_scale,
            num_inner_cols: py_run_args.num_inner_cols,
@@ -251,7 +246,6 @@ impl Into<PyRunArgs> for RunArgs {
    fn into(self) -> PyRunArgs {
        PyRunArgs {
            bounded_log_lookup: self.bounded_log_lookup,
            tolerance: self.tolerance.val,
            input_scale: self.input_scale,
            param_scale: self.param_scale,
            num_inner_cols: self.num_inner_cols,
@@ -338,6 +332,8 @@ enum PyInputType {
    Int,
    ///
    TDim,
    ///
    Unknown,
}

impl From<InputType> for PyInputType {
@@ -349,6 +345,7 @@ impl From<InputType> for PyInputType {
            InputType::F64 => PyInputType::F64,
            InputType::Int => PyInputType::Int,
            InputType::TDim => PyInputType::TDim,
            InputType::Unknown => PyInputType::Unknown,
        }
    }
}
@@ -362,6 +359,7 @@ impl From<PyInputType> for InputType {
            PyInputType::F64 => InputType::F64,
            PyInputType::Int => InputType::Int,
            PyInputType::TDim => InputType::TDim,
            PyInputType::Unknown => InputType::Unknown,
        }
    }
}
@@ -376,6 +374,7 @@ impl FromStr for PyInputType {
            "f64" => Ok(PyInputType::F64),
            "int" => Ok(PyInputType::Int),
            "tdim" => Ok(PyInputType::TDim),
            "unknown" => Ok(PyInputType::Unknown),
            _ => Err("Invalid value for InputType".to_string()),
        }
    }
@@ -578,10 +577,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
        .map(crate::pfsys::string_to_field::<Fr>)
        .collect::<Vec<_>>();

    let output =
        PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>::run(
            message.clone(),
        )
    let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
        .map_err(|_| PyIOError::new_err("Failed to run poseidon"))?;

    let hash = output[0]
@@ -596,7 +592,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
/// Arguments
/// -------
/// message: list[str]
///     List of field elements represnted as strings
///     List of field elements represented as strings
///
/// vk_path: str
///     Path to the verification key
@@ -655,7 +651,7 @@ fn kzg_commit(
/// Arguments
/// -------
/// message: list[str]
///     List of field elements represnted as strings
///     List of field elements represented as strings
///
/// vk_path: str
///     Path to the verification key
@@ -1949,7 +1945,7 @@ fn deploy_da_evm(
/// does the verifier use data attestation ?
///
/// addr_vk: str
///     The addess of the separate VK contract (if the verifier key is rendered as a separate contract)
///     The address of the separate VK contract (if the verifier key is rendered as a separate contract)
/// Returns
/// -------
/// bool
@@ -8,10 +8,7 @@ use crate::{
        Module,
    },
    fieldutils::{felt_to_integer_rep, integer_rep_to_felt},
    graph::{
        modules::POSEIDON_LEN_GRAPH, quantize_float, scale_to_multiplier, GraphCircuit,
        GraphSettings,
    },
    graph::{quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings},
};
use console_error_panic_hook;
use halo2_proofs::{
@@ -231,10 +228,7 @@ pub fn poseidonHash(
    let message: Vec<Fr> = serde_json::from_slice(&message[..])
        .map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?;

    let output =
        PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>::run(
            message.clone(),
        )
    let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
        .map_err(|e| JsError::new(&format!("{}", e)))?;

    Ok(wasm_bindgen::Clamped(serde_json::to_vec(&output).map_err(
@@ -1,7 +1,7 @@
/*
An easy-to-use implementation of the Poseidon Hash in the form of a Halo2 Chip. While the Poseidon Hash function
is already implemented in halo2_gadgets, there is no wrapper chip that makes it easy to use in other circuits.
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/zk_prover/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
*/

use std::collections::HashMap;

@@ -1,20 +1,18 @@
/*
An easy-to-use implementation of the Poseidon Hash in the form of a Halo2 Chip. While the Poseidon Hash function
is already implemented in halo2_gadgets, there is no wrapper chip that makes it easy to use in other circuits.
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/zk_prover/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
*/

pub mod poseidon_params;
pub mod spec;

// This chip adds a set of advice columns to the gadget Chip to store the inputs of the hash
use halo2_gadgets::poseidon::{primitives::*, Hash, Pow5Chip, Pow5Config};
use halo2_proofs::arithmetic::Field;
use halo2_gadgets::poseidon::{
    primitives::VariableLength, primitives::*, Hash, Pow5Chip, Pow5Config,
};
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use halo2_proofs::{circuit::*, plonk::*};
// use maybe_rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator};
use maybe_rayon::prelude::ParallelIterator;
use maybe_rayon::slice::ParallelSlice;

use std::marker::PhantomData;
@@ -40,22 +38,17 @@ pub struct PoseidonConfig<const WIDTH: usize, const RATE: usize> {
    pub pow5_config: Pow5Config<Fp, WIDTH, RATE>,
}

type InputAssignments = (Vec<AssignedCell<Fp, Fp>>, AssignedCell<Fp, Fp>);
type InputAssignments = Vec<AssignedCell<Fp, Fp>>;

/// PoseidonChip is a wrapper around the Pow5Chip that adds a set of advice columns to the gadget Chip to store the inputs of the hash
#[derive(Debug, Clone)]
pub struct PoseidonChip<
    S: Spec<Fp, WIDTH, RATE> + Sync,
    const WIDTH: usize,
    const RATE: usize,
    const L: usize,
> {
pub struct PoseidonChip<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> {
    config: PoseidonConfig<WIDTH, RATE>,
    _marker: PhantomData<S>,
}

impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, const L: usize>
    PoseidonChip<S, WIDTH, RATE, L>
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize>
    PoseidonChip<S, WIDTH, RATE>
{
    /// Creates a new PoseidonChip
    pub fn configure_with_cols(
@@ -82,8 +75,8 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
    }
}

impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, const L: usize>
    PoseidonChip<S, WIDTH, RATE, L>
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize>
    PoseidonChip<S, WIDTH, RATE>
{
    /// Configuration of the PoseidonChip
    pub fn configure_with_optional_instance(
@@ -113,8 +106,8 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
    }
}

impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, const L: usize>
    Module<Fp> for PoseidonChip<S, WIDTH, RATE, L>
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Module<Fp>
    for PoseidonChip<S, WIDTH, RATE>
{
    type Config = PoseidonConfig<WIDTH, RATE>;
    type InputAssignments = InputAssignments;
@@ -183,95 +176,81 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
        let res = layouter.assign_region(
            || "load message",
            |mut region| {
                let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> =
                    match &message {
                        ValTensor::Value { inner: v, .. } => {
                            v.iter()
                                .enumerate()
                                .map(|(i, value)| {
                                    let x = i % WIDTH;
                                    let y = i / WIDTH;
                let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, _> = match &message {
                    ValTensor::Value { inner: v, .. } => v
                        .iter()
                        .enumerate()
                        .map(|(i, value)| {
                            let x = i % WIDTH;
                            let y = i / WIDTH;

                            match value {
                                ValType::Value(v) => region
                                    .assign_advice(
                                        || format!("load message_{}", i),
                                        self.config.hash_inputs[x],
                                        y,
                                        || *v,
                                    )
                                    .map_err(|e| e.into()),
                                ValType::PrevAssigned(v)
                                | ValType::AssignedConstant(v, ..) => Ok(v.clone()),
                                ValType::Constant(f) => {
                                    if local_constants.contains_key(f) {
                                        Ok(constants
                                            .get(f)
                                            .unwrap()
                                            .assigned_cell()
                                            .ok_or(ModuleError::ConstantNotAssigned)?)
                                    } else {
                                        let res = region.assign_advice_from_constant(
                                            || format!("load message_{}", i),
                                            self.config.hash_inputs[x],
                                            y,
                                            *f,
                                        )?;

                                        constants.insert(
                                            *f,
                                            ValType::AssignedConstant(res.clone(), *f),
                                        );

                                        Ok(res)
                                    }
                                }
                                e => Err(ModuleError::WrongInputType(
                                    format!("{:?}", e),
                                    "AssignedValue".to_string(),
                                )),
                            }
                        })
                        .collect()
                    }
                    ValTensor::Instance {
                        dims,
                        inner: col,
                        idx,
                        initial_offset,
                        ..
                    } => {
                        // this should never ever fail
                        let num_elems = dims[*idx].iter().product::<usize>();
                        (0..num_elems)
                            .map(|i| {
                                let x = i % WIDTH;
                                let y = i / WIDTH;
                                region.assign_advice_from_instance(
                                    || "pub input anchor",
                                    *col,
                                    initial_offset + i,
                            match value {
                                ValType::Value(v) => region
                                    .assign_advice(
                                        || format!("load message_{}", i),
                                        self.config.hash_inputs[x],
                                        y,
                                        || *v,
                                    )
                            })
                            .collect::<Result<Vec<_>, _>>()
                            .map_err(|e| e.into())
                    }
                };
                                    .map_err(|e| e.into()),
                                ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
                                    Ok(v.clone())
                                }
                                ValType::Constant(f) => {
                                    if local_constants.contains_key(f) {
                                        Ok(constants
                                            .get(f)
                                            .unwrap()
                                            .assigned_cell()
                                            .ok_or(ModuleError::ConstantNotAssigned)?)
                                    } else {
                                        let res = region.assign_advice_from_constant(
                                            || format!("load message_{}", i),
                                            self.config.hash_inputs[x],
                                            y,
                                            *f,
                                        )?;

                let offset = message.len() / WIDTH + 1;
                                        constants
                                            .insert(*f, ValType::AssignedConstant(res.clone(), *f));

                let zero_val = region
                    .assign_advice_from_constant(
                        || "",
                        self.config.hash_inputs[0],
                        offset,
                        Fp::ZERO,
                    )
                    .unwrap();
                                        Ok(res)
                                    }
                                }
                                e => Err(ModuleError::WrongInputType(
                                    format!("{:?}", e),
                                    "AssignedValue".to_string(),
                                )),
                            }
                        })
                        .collect(),
                    ValTensor::Instance {
                        dims,
                        inner: col,
                        idx,
                        initial_offset,
                        ..
                    } => {
                        // this should never ever fail
                        let num_elems = dims[*idx].iter().product::<usize>();
                        (0..num_elems)
                            .map(|i| {
                                let x = i % WIDTH;
                                let y = i / WIDTH;
                                region.assign_advice_from_instance(
                                    || "pub input anchor",
                                    *col,
                                    initial_offset + i,
                                    self.config.hash_inputs[x],
                                    y,
                                )
                            })
                            .collect::<Result<Vec<_>, _>>()
                            .map_err(|e| e.into())
                    }
                };

                Ok((assigned_message?, zero_val))
                Ok(assigned_message?)
            },
        );
        log::trace!(
@@ -292,7 +271,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
        row_offset: usize,
        constants: &mut ConstantsMap<Fp>,
    ) -> Result<ValTensor<Fp>, ModuleError> {
        let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
        let input_cells = self.layout_inputs(layouter, input, constants)?;

        // empty hash case
        if input_cells.is_empty() {
@@ -306,52 +285,25 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con

        let start_time = instant::Instant::now();

        let mut one_iter = false;
        // do the Tree dance baby
        while input_cells.len() > 1 || !one_iter {
            let hashes: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> = input_cells
                .chunks(L)
                .enumerate()
                .map(|(i, block)| {
                    let _start_time = instant::Instant::now();
        let pow5_chip = Pow5Chip::construct(self.config.pow5_config.clone());
        // initialize the hasher
        let hasher = Hash::<_, _, S, VariableLength, WIDTH, RATE>::init(
            pow5_chip,
            layouter.namespace(|| "block_hasher"),
        )?;

                    let mut block = block.to_vec();
                    let remainder = block.len() % L;

                    if remainder != 0 {
                        block.extend(vec![zero_val.clone(); L - remainder]);
                    }

                    let pow5_chip = Pow5Chip::construct(self.config.pow5_config.clone());
                    // initialize the hasher
                    let hasher = Hash::<_, _, S, ConstantLength<L>, WIDTH, RATE>::init(
                        pow5_chip,
                        layouter.namespace(|| "block_hasher"),
                    )?;

                    let hash = hasher.hash(
                        layouter.namespace(|| "hash"),
                        block.to_vec().try_into().map_err(|_| Error::Synthesis)?,
                    );

                    if i == 0 {
                        log::trace!("block (L={:?}) took: {:?}", L, _start_time.elapsed());
                    }

                    hash
                })
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| e.into());

            log::trace!("hashes (N={:?}) took: {:?}", len, start_time.elapsed());
            one_iter = true;
            input_cells = hashes?;
        }
        let hash: AssignedCell<Fp, Fp> = hasher.hash(
            layouter.namespace(|| "hash"),
            input_cells
                .to_vec()
                .try_into()
                .map_err(|_| Error::Synthesis)?,
        )?;

        let duration = start_time.elapsed();
        log::trace!("layout (N={:?}) took: {:?}", len, duration);

        let result = Tensor::from(input_cells.iter().map(|e| ValType::from(e.clone())));
        let result = Tensor::from(vec![ValType::from(hash.clone())].into_iter());

        let output = match result[0].clone() {
            ValType::PrevAssigned(v) => v,
@@ -390,69 +342,59 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con

    ///
    fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, ModuleError> {
        let mut hash_inputs = message;

        let len = hash_inputs.len();
        let len = message.len();
        if len == 0 {
            return Ok(vec![vec![]]);
        }

        let start_time = instant::Instant::now();

        let mut one_iter = false;
        // do the Tree dance baby
        while hash_inputs.len() > 1 || !one_iter {
            let hashes: Vec<Fp> = hash_inputs
                .par_chunks(L)
                .map(|block| {
                    let mut block = block.to_vec();
                    let remainder = block.len() % L;

                    if remainder != 0 {
                        block.extend(vec![Fp::ZERO; L - remainder].iter());
                    }

                    let block_len = block.len();

                    let message = block
                        .try_into()
                        .map_err(|_| ModuleError::InputWrongLength(block_len))?;

                    Ok(halo2_gadgets::poseidon::primitives::Hash::<
                        _,
                        S,
                        ConstantLength<L>,
                        { WIDTH },
                        { RATE },
                    >::init()
                    .hash(message))
                })
                .collect::<Result<Vec<_>, ModuleError>>()?;
            one_iter = true;
            hash_inputs = hashes;
        }
        let hash = halo2_gadgets::poseidon::primitives::Hash::<
            _,
            S,
            VariableLength,
            { WIDTH },
            { RATE },
        >::init()
        .hash(message);

        let duration = start_time.elapsed();
        log::trace!("run (N={:?}) took: {:?}", len, duration);

        Ok(vec![hash_inputs])
        Ok(vec![vec![hash]])
    }

    fn num_rows(mut input_len: usize) -> usize {
    fn num_rows(input_len: usize) -> usize {
        // this was determined by running the circuit and looking at the number of constraints
        // in the test called hash_for_a_range_of_input_sizes, then regressing in python to find the slope
        let fixed_cost: usize = 41 * L;
        // import numpy as np
        // from scipy import stats

        let mut num_rows = 0;
        // x = np.array([32, 64, 96, 128, 160, 192])
        // y = np.array([1298, 2594, 3890, 5186, 6482, 7778])

        loop {
            // the number of times the input_len is divisible by L
            let num_chunks = input_len / L + 1;
            num_rows += num_chunks * fixed_cost;
            if num_chunks == 1 {
                break;
            }
            input_len = num_chunks;
        }
        // slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)

        num_rows
        // print(f"slope: {slope}")
        // print(f"intercept: {intercept}")
        // print(f"R^2: {r_value**2}")

        // # Predict for any x
        // def predict(x):
        //     return slope * x + intercept

        // # Test prediction
        // test_x = 256
        // print(f"Predicted value for x={test_x}: {predict(test_x)}")
        // our output:
        // slope: 40.5
        // intercept: 2.0
        // R^2: 1.0
        // Predicted value for x=256: 10370.0
        let fixed_cost: usize = 41 * input_len;

        // the cost of the hash function is linear with the number of inputs
        fixed_cost + 2
    }
}
@@ -479,12 +421,12 @@ mod tests {
    const RATE: usize = POSEIDON_RATE;
    const R: usize = 240;

    struct HashCircuit<S: Spec<Fp, WIDTH, RATE>, const L: usize> {
    struct HashCircuit<S: Spec<Fp, WIDTH, RATE>> {
        message: ValTensor<Fp>,
        _spec: PhantomData<S>,
    }

    impl<S: Spec<Fp, WIDTH, RATE>, const L: usize> Circuit<Fp> for HashCircuit<S, L> {
    impl<S: Spec<Fp, WIDTH, RATE>> Circuit<Fp> for HashCircuit<S> {
        type Config = PoseidonConfig<WIDTH, RATE>;
        type FloorPlanner = ModulePlanner;
        type Params = ();
@@ -500,7 +442,7 @@ mod tests {
        }

        fn configure(meta: &mut ConstraintSystem<Fp>) -> PoseidonConfig<WIDTH, RATE> {
            PoseidonChip::<PoseidonSpec, WIDTH, RATE, L>::configure(meta, ())
            PoseidonChip::<PoseidonSpec, WIDTH, RATE>::configure(meta, ())
        }

        fn synthesize(
@@ -508,7 +450,7 @@ mod tests {
            config: PoseidonConfig<WIDTH, RATE>,
            mut layouter: impl Layouter<Fp>,
        ) -> Result<(), Error> {
            let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, L> = PoseidonChip::new(config);
            let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE> = PoseidonChip::new(config);
            chip.layout(
                &mut layouter,
                &[self.message.clone()],
@@ -523,15 +465,15 @@ mod tests {
    #[test]
    fn poseidon_hash_empty() {
        let message = [];
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 2>::run(message.to_vec()).unwrap();
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.to_vec()).unwrap();
        let mut message: Tensor<ValType<Fp>> =
            message.into_iter().map(|m| Value::known(m).into()).into();
        let k = 9;
        let circuit = HashCircuit::<PoseidonSpec, 2> {
        let circuit = HashCircuit::<PoseidonSpec> {
            message: message.into(),
            _spec: PhantomData,
        };
        let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
        let prover = halo2_proofs::dev::MockProver::run(k, &circuit, vec![vec![]]).unwrap();
        assert_eq!(prover.verify(), Ok(()))
    }

@@ -540,13 +482,13 @@ mod tests {
        let rng = rand::rngs::OsRng;

        let message = [Fp::random(rng), Fp::random(rng)];
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 2>::run(message.to_vec()).unwrap();
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.to_vec()).unwrap();

        let mut message: Tensor<ValType<Fp>> =
            message.into_iter().map(|m| Value::known(m).into()).into();

        let k = 9;
        let circuit = HashCircuit::<PoseidonSpec, 2> {
        let circuit = HashCircuit::<PoseidonSpec> {
            message: message.into(),
            _spec: PhantomData,
        };
@@ -559,13 +501,13 @@ mod tests {
        let rng = rand::rngs::OsRng;

        let message = [Fp::random(rng), Fp::random(rng), Fp::random(rng)];
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 3>::run(message.to_vec()).unwrap();
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.to_vec()).unwrap();

        let mut message: Tensor<ValType<Fp>> =
            message.into_iter().map(|m| Value::known(m).into()).into();

        let k = 9;
        let circuit = HashCircuit::<PoseidonSpec, 3> {
        let circuit = HashCircuit::<PoseidonSpec> {
            message: message.into(),
            _spec: PhantomData,
        };
@@ -581,23 +523,21 @@ mod tests {
        #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
        env_logger::init();

        {
            let i = 32;
        for i in (32..128).step_by(32) {
            // print a bunch of new lines
            println!(
            log::info!(
                "i is {} -------------------------------------------------",
                i
            );

            let message: Vec<Fp> = (0..i).map(|_| Fp::random(rng)).collect::<Vec<_>>();
            let output =
                PoseidonChip::<PoseidonSpec, WIDTH, RATE, 32>::run(message.clone()).unwrap();
            let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.clone()).unwrap();

            let mut message: Tensor<ValType<Fp>> =
                message.into_iter().map(|m| Value::known(m).into()).into();

            let k = 17;
            let circuit = HashCircuit::<PoseidonSpec, 32> {
            let circuit = HashCircuit::<PoseidonSpec> {
                message: message.into(),
                _spec: PhantomData,
            };
@@ -614,13 +554,13 @@ mod tests {

        let mut message: Vec<Fp> = (0..2048).map(|_| Fp::random(rng)).collect::<Vec<_>>();

        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 25>::run(message.clone()).unwrap();
        let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.clone()).unwrap();

        let mut message: Tensor<ValType<Fp>> =
            message.into_iter().map(|m| Value::known(m).into()).into();

        let k = 17;
        let circuit = HashCircuit::<PoseidonSpec, 25> {
        let circuit = HashCircuit::<PoseidonSpec> {
            message: message.into(),
            _spec: PhantomData,
        };
@@ -20,7 +20,6 @@ use crate::{
    circuit::{
        ops::base::BaseOp,
        table::{Range, RangeCheck, Table},
        utils,
    },
    tensor::{Tensor, TensorType, ValTensor, VarTensor},
};
@@ -85,55 +84,6 @@ impl CheckMode {
    }
}

#[allow(missing_docs)]
/// An enum representing the tolerance we can accept for the accumulated arguments, either absolute or percentage
#[derive(Clone, Default, Debug, PartialEq, PartialOrd, Serialize, Deserialize, Copy)]
pub struct Tolerance {
    pub val: f32,
    pub scale: utils::F32,
}

impl std::fmt::Display for Tolerance {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:.2}", self.val)
    }
}

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Tolerance {
    /// Convert the struct to a subcommand string
    fn to_flags(&self) -> Vec<String> {
        vec![format!("{}", self)]
    }
}

impl FromStr for Tolerance {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if let Ok(val) = s.parse::<f32>() {
            Ok(Tolerance {
                val,
                scale: utils::F32(1.0),
            })
        } else {
            Err(
                "Invalid tolerance value provided. It should expressed as a percentage (f32)."
                    .to_string(),
            )
        }
    }
}

impl From<f32> for Tolerance {
    fn from(value: f32) -> Self {
        Tolerance {
            val: value,
            scale: utils::F32(1.0),
        }
    }
}

#[cfg(feature = "python-bindings")]
/// Converts CheckMode into a PyObject (Required for CheckMode to be compatible with Python)
impl IntoPy<PyObject> for CheckMode {
@@ -158,29 +108,6 @@ impl<'source> FromPyObject<'source> for CheckMode {
    }
}

#[cfg(feature = "python-bindings")]
/// Converts Tolerance into a PyObject (Required for Tolerance to be compatible with Python)
impl IntoPy<PyObject> for Tolerance {
    fn into_py(self, py: Python) -> PyObject {
        (self.val, self.scale.0).to_object(py)
    }
}

#[cfg(feature = "python-bindings")]
/// Obtains Tolerance from PyObject (Required for Tolerance to be compatible with Python)
impl<'source> FromPyObject<'source> for Tolerance {
    fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
        if let Ok((val, scale)) = <(f32, f32)>::extract_bound(ob) {
            Ok(Tolerance {
                val,
                scale: utils::F32(scale),
            })
        } else {
            Err(PyValueError::new_err("Invalid tolerance value provided. "))
        }
    }
}

/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct DynamicLookups {
@@ -1,9 +1,9 @@
use super::*;
use crate::{
    circuit::{layouts, utils, Tolerance},
    circuit::{layouts, utils},
    fieldutils::{integer_rep_to_felt, IntegerRep},
    graph::multiplier_to_scale,
    tensor::{self, Tensor, TensorType, ValTensor},
    tensor::{self, DataFormat, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
use serde::{Deserialize, Serialize};
@@ -57,11 +57,13 @@ pub enum HybridOp {
        stride: Vec<usize>,
        kernel_shape: Vec<usize>,
        normalized: bool,
        data_format: DataFormat,
    },
    MaxPool {
        padding: Vec<(usize, usize)>,
        stride: Vec<usize>,
        pool_dims: Vec<usize>,
        data_format: DataFormat,
    },
    ReduceMin {
        axes: Vec<usize>,
@@ -77,7 +79,6 @@ pub enum HybridOp {
        axes: Vec<usize>,
    },
    Output {
        tol: Tolerance,
        decomp: bool,
    },
    Greater,
@@ -154,10 +155,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                padding,
                stride,
                kernel_shape,
                normalized,
                normalized, data_format
            } => format!(
                "SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={})",
                padding, stride, kernel_shape, normalized
                "SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={}, data_format={:?})",
                padding, stride, kernel_shape, normalized, data_format
            ),
            HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
            HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
@@ -165,9 +166,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                padding,
                stride,
                pool_dims,
                data_format,
            } => format!(
                "MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
                padding, stride, pool_dims
                "MaxPool (padding={:?}, stride={:?}, pool_dims={:?}, data_format={:?})",
                padding, stride, pool_dims, data_format
            ),
            HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
            HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim),
@@ -181,8 +183,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                    input_scale, output_scale, axes
                )
            }
            HybridOp::Output { tol, decomp } => {
                format!("OUTPUT (tol={:?}, decomp={})", tol, decomp)
            HybridOp::Output { decomp } => {
                format!("OUTPUT (decomp={})", decomp)
            }
            HybridOp::Greater => "GREATER".to_string(),
            HybridOp::GreaterEqual => "GREATEREQUAL".to_string(),
@@ -239,6 +241,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                stride,
                kernel_shape,
                normalized,
                data_format,
            } => layouts::sumpool(
                config,
                region,
@@ -247,6 +250,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                stride,
                kernel_shape,
                *normalized,
                *data_format,
            )?,
            HybridOp::Recip {
                input_scale,
@@ -287,6 +291,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                padding,
                stride,
                pool_dims,
                data_format,
            } => layouts::max_pool(
                config,
                region,
@@ -294,6 +299,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                padding,
                stride,
                pool_dims,
                *data_format,
            )?,
            HybridOp::ReduceMax { axes } => {
                layouts::max_axes(config, region, values[..].try_into()?, axes)?
@@ -319,14 +325,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                *output_scale,
                axes,
            )?,
            HybridOp::Output { tol, decomp } => layouts::output(
                config,
                region,
                values[..].try_into()?,
                tol.scale,
                tol.val,
                *decomp,
            )?,
            HybridOp::Output { decomp } => {
                layouts::output(config, region, values[..].try_into()?, *decomp)?
            }
            HybridOp::Greater => layouts::greater(config, region, values[..].try_into()?)?,
            HybridOp::GreaterEqual => {
                layouts::greater_equal(config, region, values[..].try_into()?)?
@@ -24,6 +24,7 @@ use crate::{
        ops::{accumulated, add, mult, sub},
        Tensor, TensorError, ValType,
    },
    tensor::{DataFormat, KernelFormat},
};

use super::*;
@@ -156,23 +157,6 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    claimed_output.reshape(input_dims)?;
    // implicitly check if the prover provided output is within range
    let claimed_output = identity(config, region, &[claimed_output], true)?;
    // check if x is too large only if the decomp would support overflow in the previous op
    if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
        // here we decompose and extract the sign of the input
        let sign = sign(config, region, &[claimed_output.clone()])?;

        let abs_value = pairwise(
            config,
            region,
            &[claimed_output.clone(), sign],
            BaseOp::Mult,
        )?;
        let max_val = create_constant_tensor(integer_rep_to_felt(IntegerRep::MAX), 1);
        let less_than_max = less(config, region, &[abs_value.clone(), max_val])?;
        // assert the result is 1
        let comparison_unit = create_constant_tensor(F::ONE, less_than_max.len());
        enforce_equality(config, region, &[abs_value, comparison_unit])?;
    }

    let product = pairwise(
        config,
@@ -246,30 +230,6 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
        &[equal_zero_mask.clone(), equal_inverse_mask],
    )?;

    let masked_output = pairwise(
        config,
        region,
        &[claimed_output.clone(), not_equal_zero_mask.clone()],
        BaseOp::Mult,
    )?;

    // check if x is too large only if the decomp would support overflow in the previous op
    if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
        // here we decompose and extract the sign of the input
        let sign = sign(config, region, &[masked_output.clone()])?;
        let abs_value = pairwise(
            config,
            region,
            &[claimed_output.clone(), sign],
            BaseOp::Mult,
        )?;
        let max_val = create_constant_tensor(integer_rep_to_felt(IntegerRep::MAX), 1);
        let less_than_max = less(config, region, &[abs_value.clone(), max_val])?;
        // assert the result is 1
        let comparison_unit = create_constant_tensor(F::ONE, less_than_max.len());
        enforce_equality(config, region, &[abs_value, comparison_unit])?;
    }

    let err_func = |config: &BaseConfig<F>,
                    region: &mut RegionCtx<F>,
                    x: &ValTensor<F>|
@@ -345,7 +305,7 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
        .into()
    };
    claimed_output.reshape(input_dims)?;
    // force the output to be positive or zero, also implicitly checks that the ouput is in range
    // force the output to be positive or zero, also implicitly checks that the output is in range
    let claimed_output = abs(config, region, &[claimed_output.clone()])?;
    // rescaled input
    let rescaled_input = pairwise(config, region, &[input.clone(), unit_scale], BaseOp::Mult)?;
@@ -1837,7 +1797,7 @@ pub(crate) fn get_missing_set_elements<

    // get the difference between the two vectors
    for eval in input_evals.iter() {
        // delete first occurence of that value
        // delete first occurrence of that value
        if let Some(pos) = fullset_evals.iter().position(|x| x == eval) {
            fullset_evals.remove(pos);
        }
@@ -1865,7 +1825,7 @@ pub(crate) fn get_missing_set_elements<
    region.increment(claimed_output.len());

    // input and claimed output should be the shuffles of fullset
    // concatentate input and claimed output
    // concatenate input and claimed output
    let input_and_claimed_output = input.concat(claimed_output.clone())?;

    // assert that this is a permutation/shuffle
@@ -2652,9 +2612,9 @@ pub fn mean_of_squares_axes<F: PrimeField + TensorType + PartialOrd + std::hash:
    let squared = pow(config, region, values, 2)?;
    let sum_squared = sum_axes(config, region, &[squared], axes)?;

    let dividand: usize = values[0].len() / sum_squared.len();
    let dividend: usize = values[0].len() / sum_squared.len();

    let mean_squared = div(config, region, &[sum_squared], F::from(dividand as u64))?;
    let mean_squared = div(config, region, &[sum_squared], F::from(dividend as u64))?;
    Ok(mean_squared)
}
@@ -3221,6 +3181,7 @@ pub fn neg<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::circuit::region::RegionSettings;
/// use ezkl::circuit::BaseConfig;
/// use ezkl::tensor::ValTensor;
/// use ezkl::tensor::DataFormat;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
@@ -3230,12 +3191,12 @@ pub fn neg<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 1, 3, 3],
/// ).unwrap());
/// let pooled = sumpool::<Fp>(&dummy_config, &mut dummy_region, &[x.clone()], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], false).unwrap();
/// let pooled = sumpool::<Fp>(&dummy_config, &mut dummy_region, &[x.clone()], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], false, DataFormat::default()).unwrap();
/// let expected: Tensor<IntegerRep> = Tensor::<IntegerRep>::new(Some(&[11, 8, 8, 10]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled.int_evals().unwrap(), expected);
///
/// // This time with normalization
/// let pooled = sumpool::<Fp>(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], true).unwrap();
/// let pooled = sumpool::<Fp>(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], true, DataFormat::default()).unwrap();
/// let expected: Tensor<IntegerRep> = Tensor::<IntegerRep>::new(Some(&[3, 2, 2, 3]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled.int_evals().unwrap(), expected);
/// ```
@@ -3247,9 +3208,19 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    stride: &[usize],
    kernel_shape: &[usize],
    normalized: bool,
    data_format: DataFormat,
) -> Result<ValTensor<F>, CircuitError> {
    let batch_size = values[0].dims()[0];
    let image_channels = values[0].dims()[1];
    let mut image = values[0].clone();
    data_format.to_canonical(&mut image)?;

    if data_format.has_no_batch() {
        let mut dims = image.dims().to_vec();
        dims.insert(0, 1);
        image.reshape(&dims)?;
    }

    let batch_size = image.dims()[0];
    let image_channels = image.dims()[1];

    let kernel_len = kernel_shape.iter().product();

@@ -3274,7 +3245,16 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
        .map(|coord| {
            let (b, i) = (coord[0], coord[1]);
            let input = values[0].get_slice(&[b..b + 1, i..i + 1])?;
            let output = conv(config, region, &[input, kernel.clone()], padding, stride, 1)?;
            let output = conv(
                config,
                region,
                &[input, kernel.clone()],
                padding,
                stride,
                1,
                DataFormat::default(),
                KernelFormat::default(),
            )?;
            res.push(output);
            Ok(())
        })
@@ -3289,6 +3269,9 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    if normalized {
        last_elem = div(config, region, &[last_elem], F::from(kernel_len as u64))?;
    }

    data_format.from_canonical(&mut last_elem)?;

    Ok(last_elem)
}
@@ -3298,6 +3281,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::circuit::ops::layouts::max_pool;
/// use ezkl::tensor::DataFormat;
/// use halo2curves::bn256::Fr as Fp;
/// use ezkl::circuit::region::RegionCtx;
/// use ezkl::circuit::region::RegionSettings;
@@ -3312,7 +3296,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 1, 3, 3],
/// ).unwrap());
/// let pooled = max_pool::<Fp>(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2;2]).unwrap();
/// let pooled = max_pool::<Fp>(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2;2], DataFormat::default()).unwrap();
/// let expected: Tensor<IntegerRep> = Tensor::<IntegerRep>::new(Some(&[5, 4, 4, 6]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled.int_evals().unwrap(), expected);
///
@@ -3324,8 +3308,16 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    padding: &[(usize, usize)],
    stride: &[usize],
    pool_dims: &[usize],
    data_format: DataFormat,
) -> Result<ValTensor<F>, CircuitError> {
    let image = values[0].clone();
    let mut image = values[0].clone();
    data_format.to_canonical(&mut image)?;

    if data_format.has_no_batch() {
        let mut dims = image.dims().to_vec();
        dims.insert(0, 1);
        image.reshape(&dims)?;
    }

    let image_dims = image.dims();

@@ -3384,38 +3376,38 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(

    region.apply_in_loop(&mut output, inner_loop_function)?;

    let res: ValTensor<F> = output.into();
    let mut res: ValTensor<F> = output.into();

    data_format.from_canonical(&mut res)?;

    Ok(res)
}
/// Performs a deconvolution on the given input tensor.
/// # Examples
/// ```
// // expected outputs are taken from pytorch torch.nn.functional.conv_transpose2d
///
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::circuit::ops::layouts::deconv;
/// use ezkl::tensor::{val::ValTensor, DataFormat, KernelFormat};
/// use halo2curves::bn256::Fr as Fp;
/// use ezkl::circuit::region::RegionCtx;
/// use ezkl::circuit::region::RegionSettings;
/// use ezkl::circuit::BaseConfig;
/// use ezkl::tensor::ValTensor;
///
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
///
/// // Original test case 1: Channel expansion
/// let c = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap());
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
/// ).unwrap());
///
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![1;2], &vec![2;2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 32, 0, 32, 0, 6, 0, 12, 0, 4, 0, 8, 0, 4, 0, 8, 0, 0, 0, 3, 0, 0, 0, 2]), &[1, 2, 3, 4]).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 32, 0, 0, 6, 0, 0, 4, 0, 0, 0, 0]), &[1, 2, 2, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 2: Basic deconvolution
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
@@ -3424,11 +3416,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 14, 4, 2, 17, 21, 0, 1, 5]), &[1, 1, 3, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
///
/// // Original test case 3: With padding
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
@@ -3437,11 +3429,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[17]), &[1, 1, 1, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
///
/// // Original test case 4: With stride
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
@@ -3450,10 +3442,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[10, 4, 0, 3]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 5: Zero padding with stride
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[2, 4, 0, 1]),
/// &[1, 1, 2, 2],
@@ -3462,10 +3455,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 2, 12, 4, 2, 10, 4, 20, 0, 0, 3, 1, 0, 0, 1, 5]), &[1, 1, 4, 4]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Original test case 6: Different kernel shape
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 4, 0, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
@@ -3474,10 +3468,11 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// Some(&[3, 2]),
|
||||
/// &[1, 1, 2, 1],
|
||||
/// ).unwrap());
|
||||
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
|
||||
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0]), &[1, 1, 2, 1]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
///
|
||||
/// // Original test case 7: Different kernel shape without padding
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 4, 0, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
@@ -3486,20 +3481,21 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// Some(&[3, 2]),
|
||||
/// &[1, 1, 2, 1],
|
||||
/// ).unwrap());
|
||||
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
|
||||
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 1, 4, 3]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
///
|
||||
///
|
||||
/// // Original test case 8: Channel expansion with stride
|
||||
/// let c = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap());
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 4, 0, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
///
|
||||
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2], 1).unwrap();
|
||||
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 32, 0, 0, 6, 0, 0, 4, 0, 0, 0, 0]), &[1, 2, 2, 3]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
///
|
||||
/// // Original test case 9: With bias
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, 8, 0, 8, 4, 9, 8, 1, 8]),
|
||||
/// &[1, 1, 3, 3],
|
||||
@@ -3512,11 +3508,89 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// Some(&[1]),
|
||||
/// &[1],
|
||||
/// ).unwrap());
|
||||
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1).unwrap();
|
||||
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[55, 58, 66, 69]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
///
|
||||
/// // Additional test case 1: NHWC format with HWIO kernel
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 4, 0, 1]),
|
||||
/// &[1, 2, 2, 1], // NHWC format
|
||||
/// ).unwrap());
|
||||
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1, 5, 3]),
|
||||
/// &[2, 2, 1, 1], // HWIO format
|
||||
/// ).unwrap());
|
||||
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NHWC, KernelFormat::HWIO).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[27]), &[1, 1, 1, 1]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
///
|
||||
/// // Additional test case 2: 1D deconvolution with NCHW format
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3]),
|
||||
/// &[1, 1, 3], // NCH format
|
||||
/// ).unwrap());
|
||||
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2]),
|
||||
/// &[1, 1, 2], // OIH format
|
||||
/// ).unwrap());
|
||||
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0)], &vec![0], &vec![1], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 4, 7, 6]), &[1, 1, 4]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
///
|
||||
/// // Additional test case 3: 3D deconvolution with NCHW format
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4]),
|
||||
/// &[1, 1, 2, 2, 1], // NCDHW format
|
||||
/// ).unwrap());
|
||||
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 1]),
|
||||
/// &[1, 1, 1, 1, 2], // OIDHW format
|
||||
/// ).unwrap());
|
||||
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 3], &vec![0; 3], &vec![1; 3], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 2, 2, 3, 3, 4, 4]), &[1, 1, 2, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
///
|
||||
/// // Additional test case 4: Multi-channel with NHWC format and OHWI kernel
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 4, 0, 1, 3, 2, 1, 4]), // 2 channels, 2x2 spatial
|
||||
/// &[1, 2, 2, 2], // NHWC format [batch, height, width, channels]
|
||||
/// ).unwrap());
|
||||
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
|
||||
/// &[1, 2, 2, 2], // OHWI format [out_channels, height, width, in_channels]
|
||||
/// ).unwrap());
|
||||
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1, DataFormat::NHWC, KernelFormat::OHWI).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[10, 24, 4, 41, 78, 27, 27, 66, 39]), &[1, 3, 3, 1]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
///
|
||||
/// // Additional test case 5: CHW format (no batch dimension)
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 4, 0, 1]),
|
||||
/// &[1, 2, 2], // CHW format [channels, height, width]
|
||||
/// ).unwrap());
|
||||
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4]),
|
||||
/// &[1, 1, 2, 2], // OIHW format [out_channels, in_channels, height, width]
|
||||
/// ).unwrap());
|
||||
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1, DataFormat::CHW, KernelFormat::OIHW).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 6, 6]), &[1, 1, 1, 3]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
///
|
||||
/// // Additional test case 6: HWC format with HWIO kernel
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 3, 4, 1]),
|
||||
/// &[2, 2, 1], // HWC format [height, width, channels]
|
||||
/// ).unwrap());
|
||||
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 1, 2]),
|
||||
/// &[2, 2, 1, 1], // HWIO format [height, width, in_channels, out_channels]
|
||||
/// ).unwrap());
|
||||
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1, DataFormat::HWC, KernelFormat::HWIO).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[6, 6, 6]), &[1, 1, 3, 1]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
///
|
||||
pub fn deconv<
    F: PrimeField + TensorType + PartialOrd + std::hash::Hash + std::marker::Send + std::marker::Sync,
>(
@@ -3527,9 +3601,14 @@ pub fn deconv<
    output_padding: &[usize],
    stride: &[usize],
    num_groups: usize,
    data_format: DataFormat,
    kernel_format: KernelFormat,
) -> Result<ValTensor<F>, CircuitError> {
    let has_bias = inputs.len() == 3;
    let (image, kernel) = (&inputs[0], &inputs[1]);
    let (mut working_image, mut working_kernel) = (inputs[0].clone(), inputs[1].clone());

    data_format.to_canonical(&mut working_image)?;
    kernel_format.to_canonical(&mut working_kernel)?;

    if stride.iter().any(|&s| s == 0) {
        return Err(TensorError::DimMismatch(
@@ -3539,26 +3618,23 @@ pub fn deconv<
    }

    let null_val = ValType::Constant(F::ZERO);
    let mut expanded_image = working_image.clone();

    let mut expanded_image = image.clone();

    // Expand image by inserting zeros according to stride
    for (i, s) in stride.iter().enumerate() {
        expanded_image.intercalate_values(null_val.clone(), *s, 2 + i)?;
    }

    // Pad to kernel size for each spatial dimension
    expanded_image.pad(
        kernel.dims()[2..]
        working_kernel.dims()[2..]
            .iter()
            .map(|d| (d - 1, d - 1))
            .collect::<Vec<_>>(),
        2,
    )?; // pad to the kernel size

    // flip order
    let channel_coord = (0..kernel.dims()[0])
        .cartesian_product(0..kernel.dims()[1])
        .collect::<Vec<_>>();
    )?;

    // Calculate slice coordinates considering padding and output padding
    let slice_coord = expanded_image
        .dims()
        .iter()
@@ -3574,26 +3650,34 @@ pub fn deconv<

    let sliced_expanded_image = expanded_image.get_slice(&slice_coord)?;

    let mut inverted_kernels = vec![];
    // Generate channel coordinates for kernel transformation
    let (in_ch_dim, out_ch_dim) =
        KernelFormat::default().get_channel_dims(working_kernel.dims().len());
    let channel_coord = (0..working_kernel.dims()[out_ch_dim])
        .cartesian_product(0..working_kernel.dims()[in_ch_dim])
        .collect::<Vec<_>>();

    // Invert kernels for deconvolution
    let mut inverted_kernels = vec![];
    for (i, j) in channel_coord {
        let channel = kernel.get_slice(&[i..i + 1, j..j + 1])?;
        let channel = working_kernel.get_slice(&[i..i + 1, j..j + 1])?;
        let mut channel = Tensor::from(channel.get_inner_tensor()?.clone().into_iter().rev());
        channel.reshape(&kernel.dims()[2..])?;
        channel.reshape(&working_kernel.dims()[2..])?;
        inverted_kernels.push(channel);
    }

    let mut deconv_kernel =
        Tensor::new(Some(&inverted_kernels), &[inverted_kernels.len()])?.combine()?;
    deconv_kernel.reshape(kernel.dims())?;
    deconv_kernel.reshape(working_kernel.dims())?;

    // tensorflow formatting patch
    if kernel.dims()[0] == sliced_expanded_image.dims()[1] {
    // Handle tensorflow-style input/output channel ordering
    if working_kernel.dims()[0] == sliced_expanded_image.dims()[1] {
        let mut dims = deconv_kernel.dims().to_vec();
        dims.swap(0, 1);
        deconv_kernel.reshape(&dims)?;
    }

    // Prepare inputs for convolution
    let conv_input = if has_bias {
        vec![
            sliced_expanded_image,
@@ -3604,28 +3688,32 @@ pub fn deconv<
        vec![sliced_expanded_image, deconv_kernel.clone().into()]
    };

    let conv_dim = kernel.dims()[2..].len();
    let conv_dim = working_kernel.dims()[2..].len();

    let output = conv(
    // Perform convolution with canonical formats
    let mut output = conv(
        config,
        region,
        &conv_input,
        &vec![(0, 0); conv_dim],
        &vec![1; conv_dim],
        num_groups,
        data_format.canonical(), // Use canonical format
        kernel_format.canonical(), // Use canonical format
    )?;

    // Convert output back to requested format
    data_format.from_canonical(&mut output)?;

    Ok(output)
}

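For intuition, the rewrite keeps deconv's classic reduction to convolution: interleave zeros according to the stride, pad each spatial dimension by kernel size minus one, flip the kernel, then run a stride-1 conv. A minimal standalone sketch (plain integers, 1D, not ezkl's API) that reproduces the 1D doc test above:

fn deconv1d(input: &[i64], kernel: &[i64], stride: usize) -> Vec<i64> {
    let k = kernel.len();
    // Insert (stride - 1) zeros between input elements (the `intercalate_values` step).
    let mut expanded = vec![];
    for (i, v) in input.iter().enumerate() {
        expanded.push(*v);
        if i + 1 < input.len() {
            expanded.extend(std::iter::repeat(0).take(stride - 1));
        }
    }
    // Pad both ends by k - 1 zeros, then convolve with the flipped kernel at stride 1.
    let mut padded = vec![0; k - 1];
    padded.extend_from_slice(&expanded);
    padded.extend(std::iter::repeat(0).take(k - 1));
    let flipped: Vec<i64> = kernel.iter().rev().copied().collect();
    (0..=padded.len() - k)
        .map(|i| (0..k).map(|j| padded[i + j] * flipped[j]).sum::<i64>())
        .collect()
}

fn main() {
    // matches "Additional test case 2" above: x = [1, 2, 3], k = [1, 2], stride 1, no padding
    assert_eq!(deconv1d(&[1, 2, 3], &[1, 2], 1), vec![1, 4, 7, 6]);
}
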
/// Applies convolution over a ND tensor of shape C x H x D1...DN (and adds a bias).
/// ```
/// // expected outputs are taken from pytorch torch.nn.functional.conv2d
///
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::circuit::ops::layouts::conv;
/// use ezkl::tensor::val::ValTensor;
/// use ezkl::tensor::{val::ValTensor, DataFormat, KernelFormat};
/// use halo2curves::bn256::Fr as Fp;
/// use ezkl::circuit::region::RegionCtx;
/// use ezkl::circuit::region::RegionSettings;
@@ -3634,6 +3722,7 @@ pub fn deconv<
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(65536, 4));
///
/// // Test case 1: Basic 2D convolution with NCHW format (default)
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
///     &[1, 1, 3, 3],
@@ -3646,44 +3735,64 @@ pub fn deconv<
///     Some(&[0]),
///     &[1],
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1).unwrap();
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[31, 16, 8, 26]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Now test single channel
/// // Test case 2: NHWC format with HWIO kernel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 5, 2, 3, 0, 4, -1, 3, 1, 6]),
///     &[1, 2, 3, 3],
///     Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
///     &[1, 3, 3, 1], // NHWC format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[5, 1, 1, 1, 5, 2, 1, 1]),
///     &[2, 1, 2, 2],
///     Some(&[1, 1, 5, 1]),
///     &[2, 2, 1, 1], // HWIO format
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![1;2], 1, DataFormat::NHWC, KernelFormat::HWIO).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[11, 24, 20, 14]), &[1, 2, 2, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Test case 3: Multi-channel NHWC with OHWI kernel
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 5, 2, 3, 0, 4, -1, 3, 1, 6]),
///     &[1, 3, 3, 2], // NHWC format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[5, 1, 1, 2, 5, 2, 1, 2]),
///     &[1, 2, 2, 2], // OHWI format
/// ).unwrap());
/// let b = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[1, 1]),
///     &[2],
/// ).unwrap());
///
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 2).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[32, 17, 9, 27, 34, 20, 13, 26]), &[1, 2, 2, 2]).unwrap();
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1, DataFormat::NHWC, KernelFormat::OHWI).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[64, 66, 46, 58]), &[1, 2, 2, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// // Now test multi channel
/// // Test case 4: 1D convolution with NCHW format
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 5, 2, 3, 0, 4, -1, 3, 1, 6]),
///     &[1, 2, 3, 3],
///     Some(&[1, 2, 3, 4, 5]),
///     &[1, 1, 5], // NCHW format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[5, 1, 1, 1, 5, 2, 1, 1, 5, 3, 1, 1, 5, 4, 1, 1, 5, 1, 1, 1, 5, 2, 1, 1, 5, 3, 1, 1, 5, 4, 1, 1]),
///     &[4, 2, 2, 2],
/// ).unwrap());
/// let b = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[1, 1, 1, 1]),
///     &[4],
///     Some(&[1, 2, 3]),
///     &[1, 1, 3], // OIHW format
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0)], &vec![1], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[14, 20, 26]), &[1, 1, 3]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// let result =conv(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[65, 36, 21, 52, 73, 48, 37, 48, 65, 36, 21, 52, 73, 48, 37, 48]), &[1, 4, 2, 2]).unwrap();
/// // Test case 5: 3D convolution with NCHW format
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
///     &[1, 1, 2, 2, 2], // NCDHW format
/// ).unwrap());
/// let k = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
///     Some(&[1, 1]),
///     &[1, 1, 1, 1, 2], // OIDHW format
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 3], &vec![1; 3], 1, DataFormat::NCHW, KernelFormat::OIHW).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[3, 7, 11, 15]), &[1, 1, 2, 2, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
/// ```
///
@@ -3696,9 +3805,14 @@ pub fn conv<
    padding: &[(usize, usize)],
    stride: &[usize],
    num_groups: usize,
    data_format: DataFormat,
    kernel_format: KernelFormat,
) -> Result<ValTensor<F>, CircuitError> {
    let has_bias = values.len() == 3;
    let (mut image, mut kernel) = (values[0].clone(), values[1].clone());
    let (mut working_image, mut working_kernel) = (values[0].clone(), values[1].clone());

    data_format.to_canonical(&mut working_image)?;
    kernel_format.to_canonical(&mut working_kernel)?;

    if stride.iter().any(|&s| s == 0) {
        return Err(TensorError::DimMismatch(
@@ -3707,47 +3821,40 @@ pub fn conv<
        .into());
    }

    // we specifically want to use the same kernel and image for all the convolutions and need to enforce this by assigning them
    // 1. assign the kernel
    // Assign tensors
    let mut assigned_len = vec![];

    if !kernel.all_prev_assigned() {
        kernel = region.assign(&config.custom_gates.inputs[0], &kernel)?;
        assigned_len.push(kernel.len());
    if !working_kernel.all_prev_assigned() {
        working_kernel = region.assign(&config.custom_gates.inputs[0], &working_kernel)?;
        assigned_len.push(working_kernel.len());
    }
    // 2. assign the image
    if !image.all_prev_assigned() {
        image = region.assign(&config.custom_gates.inputs[1], &image)?;
        assigned_len.push(image.len());
    if !working_image.all_prev_assigned() {
        working_image = region.assign(&config.custom_gates.inputs[1], &working_image)?;
        assigned_len.push(working_image.len());
    }

    if !assigned_len.is_empty() {
        // safe to unwrap since we've just checked it has at least one element
        region.increment(*assigned_len.iter().max().unwrap());
    }

    // if image is 3d add a dummy batch dimension
    if image.dims().len() == kernel.dims().len() - 1 {
        image.reshape(&[1, image.dims()[0], image.dims()[1], image.dims()[2]])?;
    if data_format.has_no_batch() {
        let mut dim = working_image.dims().to_vec();
        dim.insert(0, 1);
        working_image.reshape(&dim)?;
    }

    let image_dims = image.dims();
    let kernel_dims = kernel.dims();
    let image_dims = working_image.dims();
    let kernel_dims = working_kernel.dims();

    let mut padded_image = image.clone();
    // Apply padding
    let mut padded_image = working_image.clone();
    padded_image.pad(padding.to_vec(), 2)?;

    // Extract dimensions
    let batch_size = image_dims[0];
    let input_channels = image_dims[1];
    let output_channels = kernel_dims[0];

    log::debug!(
        "batch_size: {}, output_channels: {}, input_channels: {}",
        batch_size,
        output_channels,
        input_channels
    );

    // Calculate slides for each spatial dimension
    let slides = image_dims[2..]
        .iter()
        .enumerate()
@@ -3762,8 +3869,6 @@ pub fn conv<
        })
        .collect::<Result<Vec<_>, TensorError>>()?;

    log::debug!("slides: {:?}", slides);

    let input_channels_per_group = input_channels / num_groups;
    let output_channels_per_group = output_channels / num_groups;

@@ -3771,24 +3876,15 @@ pub fn conv<
        return Err(TensorError::DimMismatch(format!(
            "Given groups={}, expected input channels and output channels to be divisible by groups, but got input_channels={}, output_channels={}",
            num_groups, input_channels, output_channels
        ))
        .into());
        )).into());
    }

    log::debug!(
        "num_groups: {}, input_channels_per_group: {}, output_channels_per_group: {}",
        num_groups,
        input_channels_per_group,
        output_channels_per_group
    );

    let num_outputs =
        batch_size * num_groups * output_channels_per_group * slides.iter().product::<usize>();

    log::debug!("num_outputs: {}", num_outputs);

    let mut output: Tensor<ValType<F>> = Tensor::new(None, &[num_outputs])?;

    // Create iteration space
    let mut iterations = vec![0..batch_size, 0..num_groups, 0..output_channels_per_group];
    for slide in slides.iter() {
        iterations.push(0..*slide);
@@ -3800,6 +3896,13 @@ pub fn conv<
        .multi_cartesian_product()
        .collect::<Vec<_>>();

    let batch_offset = if data_format.has_no_batch() {
        2 // No batch dimension, start coordinates after channels
    } else {
        3 // Has batch dimension, start coordinates after batch and channels
    };

    // Main convolution loop
    let inner_loop_function = |idx: usize, region: &mut RegionCtx<F>| {
        let cartesian_coord_per_group = &cartesian_coord[idx];
        let (batch, group, i) = (
@@ -3813,22 +3916,19 @@ pub fn conv<

        let mut slices = vec![batch..batch + 1, start_channel..end_channel];
        for (i, stride) in stride.iter().enumerate() {
            let coord = cartesian_coord_per_group[3 + i] * stride;
            let coord = cartesian_coord_per_group[batch_offset + i] * stride;
            let kernel_dim = kernel_dims[2 + i];
            slices.push(coord..(coord + kernel_dim));
        }

        let mut local_image = padded_image.get_slice(&slices)?;

        local_image.flatten();

        let start_kernel_index = group * output_channels_per_group + i;
        let end_kernel_index = start_kernel_index + 1;
        let mut local_kernel = kernel.get_slice(&[start_kernel_index..end_kernel_index])?;

        let mut local_kernel = working_kernel.get_slice(&[start_kernel_index..end_kernel_index])?;
        local_kernel.flatten();

        // this is dot product notation in einsum format
        let mut res = einsum(config, region, &[local_image, local_kernel], "i,i->")?;

        if has_bias {
@@ -3849,21 +3949,16 @@ pub fn conv<
    region.flush()?;
    region.apply_in_loop(&mut output, inner_loop_function)?;

    let reshape_output = |output: &mut Tensor<ValType<F>>| -> Result<(), TensorError> {
        // remove dummy batch dimension if we added one
        let mut dims = vec![batch_size, output_channels];
        dims.extend(slides.iter().cloned());
        output.reshape(&dims)?;
    // Reshape output
    let mut dims = vec![batch_size, output_channels];
    dims.extend(slides.iter().cloned());
    output.reshape(&dims)?;

        Ok(())
    };
    // Convert output back to requested format
    let mut final_output: ValTensor<F> = output.into();
    data_format.from_canonical(&mut final_output)?;

    // remove dummy batch dimension if we added one
    reshape_output(&mut output)?;

    let output: ValTensor<_> = output.into();

    Ok(output)
    Ok(final_output)
}

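The per-dimension `slides` computed above follow the standard convolution output-size formula. A quick standalone check (plain integers, not ezkl's API), assuming the usual semantics out = (in + pad_before + pad_after - kernel) / stride + 1:

fn conv_out_len(input: usize, pad: (usize, usize), kernel: usize, stride: usize) -> usize {
    (input + pad.0 + pad.1 - kernel) / stride + 1
}

fn main() {
    // 3x3 image, 2x2 kernel, no padding, stride 1 -> the 2x2 output of conv test case 1 above
    assert_eq!(conv_out_len(3, (0, 0), 2, 1), 2);
    // 3x3 image, 2x2 kernel, padding (1, 1), stride 2 -> 2, as in the test circuits below
    assert_eq!(conv_out_len(3, (1, 1), 2, 2), 2);
}
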
/// Power accumulated layout
@@ -3904,7 +3999,7 @@ pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
    Ok(rescaled_inputs)
}

/// Dummy (no contraints) reshape layout
/// Dummy (no constraints) reshape layout
pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    values: &[ValTensor<F>; 1],
    new_dims: &[usize],
@@ -3914,7 +4009,7 @@ pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
    Ok(t)
}

/// Dummy (no contraints) move_axis layout
/// Dummy (no constraints) move_axis layout
pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    values: &[ValTensor<F>; 1],
    source: usize,
@@ -5743,14 +5838,12 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
///     Some(&[101, 201, 302, 403, 503, 603]),
///     &[2, 3],
/// ).unwrap());
/// let result = output::<Fp>(&dummy_config, &mut dummy_region, &[x, y], 1024.0.into(), 1.0, false).unwrap();
/// let result = output::<Fp>(&dummy_config, &mut dummy_region, &[x, y], false).unwrap();
/// ```
pub fn output<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 2],
    scale: utils::F32,
    tol: f32,
    decomp: bool,
) -> Result<ValTensor<F>, CircuitError> {
    let mut values = [values[0].clone(), values[1].clone()];
@@ -5765,43 +5858,6 @@ pub fn output<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
        values[1] = layouts::identity(config, region, &[values[1].clone()], decomp)?;
    }

    if tol == 0.0 {
        // regular equality constraint
        return enforce_equality(config, region, &[values[0].clone(), values[1].clone()]);
    }

    // Calculate the difference between the expected output and actual output
    let diff = pairwise(config, region, &values, BaseOp::Sub)?;

    // integer scale
    let int_scale = scale.0 as IntegerRep;
    // felt scale
    let felt_scale = integer_rep_to_felt(int_scale);
    // input scale ratio we multiply by tol such that in the new scale range_check_len represents tol percent
    let input_scale_ratio = (scale.0 * tol) as IntegerRep / 2 * 2;

    let recip = recip(
        config,
        region,
        &[values[0].clone()],
        felt_scale,
        felt_scale * F::from(100),
    )?;

    log::debug!("recip: {}", recip.show());

    // Multiply the difference by the recip
    let product = pairwise(config, region, &[diff, recip], BaseOp::Mult)?;

    log::debug!("product: {}", product.show());
    let rebased_product = div(
        config,
        region,
        &[product],
        integer_rep_to_felt(input_scale_ratio),
    )?;
    log::debug!("rebased_product: {}", rebased_product.show());

    // check that it is within the tolerance range
    range_check(config, region, &[rebased_product], &(-int_scale, int_scale))
    // regular equality constraint
    return enforce_equality(config, region, &[values[0].clone(), values[1].clone()]);
}

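This hunk removes the percent-tolerance path, so public outputs are now constrained by strict equality. For reference, the deleted branch encoded in-circuit a relative-error bound that, over the rationals, amounts to the plain sketch below (a paraphrase of the removed logic, not ezkl API):

// the removed circuit range-checked a rebased (diff * 1/expected) value,
// which is just a relative-error bound of tol percent
fn within_percent_tolerance(expected: f64, actual: f64, tol_percent: f64) -> bool {
    ((expected - actual) / expected).abs() * 100.0 <= tol_percent
}

fn main() {
    assert!(within_percent_tolerance(1024.0, 1030.0, 1.0)); // ~0.6% off, accepted
    assert!(!within_percent_tolerance(1024.0, 1100.0, 1.0)); // ~7.4% off, rejected
}
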
@@ -1,6 +1,8 @@
use std::any::Any;

use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::prelude::DatumType;

use crate::{
    graph::quantize_tensor,
@@ -96,6 +98,8 @@ pub enum InputType {
    Int,
    ///
    TDim,
    ///
    Unknown,
}

impl InputType {
@@ -132,6 +136,7 @@ impl InputType {
                let int_input = input.clone().to_i64().unwrap();
                *input = T::from_i64(int_input).unwrap();
            }
            InputType::Unknown => {}
        }
    }
}
@@ -152,6 +157,28 @@ impl std::str::FromStr for InputType {
    }
}

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl From<DatumType> for InputType {
    fn from(datum_type: DatumType) -> Self {
        match datum_type {
            DatumType::Bool => InputType::Bool,
            DatumType::F16 => InputType::F16,
            DatumType::F32 => InputType::F32,
            DatumType::F64 => InputType::F64,
            DatumType::I8 => InputType::Int,
            DatumType::I16 => InputType::Int,
            DatumType::I32 => InputType::Int,
            DatumType::I64 => InputType::Int,
            DatumType::U8 => InputType::Int,
            DatumType::U16 => InputType::Int,
            DatumType::U32 => InputType::Int,
            DatumType::U64 => InputType::Int,
            DatumType::TDim => InputType::TDim,
            _ => unimplemented!(),
        }
    }
}

///
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Input {

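With this impl, onnx datum types coming out of tract fold directly into graph settings via a plain `.into()`. Two sample mappings (using `matches!` since the enum's trait derives aren't shown in this hunk):

// all integer widths collapse onto the single quantized representation
assert!(matches!(InputType::from(DatumType::U16), InputType::Int));
// floats keep their precision class
assert!(matches!(InputType::from(DatumType::F32), InputType::F32));
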
@@ -4,6 +4,7 @@ use crate::{
        utils::{self, F32},
    },
    tensor::{self, Tensor, TensorError},
    tensor::{DataFormat, KernelFormat},
};

use super::{base::BaseOp, *};
@@ -43,6 +44,8 @@ pub enum PolyOp {
        padding: Vec<(usize, usize)>,
        stride: Vec<usize>,
        group: usize,
        data_format: DataFormat,
        kernel_format: KernelFormat,
    },
    Downsample {
        axis: usize,
@@ -54,6 +57,8 @@ pub enum PolyOp {
        output_padding: Vec<usize>,
        stride: Vec<usize>,
        group: usize,
        data_format: DataFormat,
        kernel_format: KernelFormat,
    },
    Add,
    Sub,
@@ -165,10 +170,12 @@ impl<
                stride,
                padding,
                group,
                data_format,
                kernel_format,
            } => {
                format!(
                    "CONV (stride={:?}, padding={:?}, group={})",
                    stride, padding, group
                    "CONV (stride={:?}, padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
                    stride, padding, group, data_format, kernel_format
                )
            }
            PolyOp::DeConv {
@@ -176,11 +183,12 @@ impl<
                padding,
                output_padding,
                group,
                data_format,
                kernel_format,
            } => {
                format!(
                    "DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={})",
                    stride, padding, output_padding, group
                )
                    "DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
                    stride, padding, output_padding, group, data_format, kernel_format)
            }
            PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
            PolyOp::Slice { axis, start, end } => {
@@ -242,6 +250,8 @@ impl<
                padding,
                stride,
                group,
                data_format,
                kernel_format,
            } => layouts::conv(
                config,
                region,
@@ -249,6 +259,8 @@ impl<
                padding,
                stride,
                *group,
                *data_format,
                *kernel_format,
            )?,
            PolyOp::GatherElements { dim, constant_idx } => {
                if let Some(idx) = constant_idx {
@@ -309,6 +321,8 @@ impl<
                output_padding,
                stride,
                group,
                data_format,
                kernel_format,
            } => layouts::deconv(
                config,
                region,
@@ -317,6 +331,8 @@ impl<
                output_padding,
                stride,
                *group,
                *data_format,
                *kernel_format,
            )?,
            PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
            PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,

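With the two new fields, layout metadata travels with the op itself instead of being rejected at graph load. Constructing a conv node for a tensorflow-style graph now looks roughly like this (field values illustrative):

let op = PolyOp::Conv {
    padding: vec![(1, 1); 2],
    stride: vec![2; 2],
    group: 1,
    data_format: DataFormat::NHWC,   // tensorflow-style activations
    kernel_format: KernelFormat::HWIO, // tensorflow-style weights
};
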
@@ -1,5 +1,6 @@
use crate::circuit::ops::poly::PolyOp;
use crate::circuit::*;
use crate::tensor::{DataFormat, KernelFormat};
use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor};
use halo2_proofs::{
    circuit::{Layouter, SimpleFloorPlanner, Value},
@@ -1065,6 +1066,8 @@ mod conv {
                padding: vec![(1, 1); 2],
                stride: vec![2; 2],
                group: 1,
                data_format: DataFormat::default(),
                kernel_format: KernelFormat::default(),
            }),
        )
        .map_err(|_| Error::Synthesis)
@@ -1220,6 +1223,8 @@ mod conv_col_ultra_overflow {
                padding: vec![(1, 1); 2],
                stride: vec![2; 2],
                group: 1,
                data_format: DataFormat::default(),
                kernel_format: KernelFormat::default(),
            }),
        )
        .map_err(|_| Error::Synthesis)
@@ -1377,6 +1382,8 @@ mod conv_relu_col_ultra_overflow {
                padding: vec![(1, 1); 2],
                stride: vec![2; 2],
                group: 1,
                data_format: DataFormat::default(),
                kernel_format: KernelFormat::default(),
            }),
        )
        .map_err(|_| Error::Synthesis);
@@ -1999,7 +2006,7 @@ mod add_with_overflow_and_poseidon {
        let base = BaseConfig::configure(cs, &[a, b], &output, CheckMode::SAFE);
        VarTensor::constant_cols(cs, K, 2, false);

        let poseidon = PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::configure(cs, ());
        let poseidon = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::configure(cs, ());

        MyCircuitConfig { base, poseidon }
    }
@@ -2009,7 +2016,7 @@ mod add_with_overflow_and_poseidon {
        mut config: Self::Config,
        mut layouter: impl Layouter<Fr>,
    ) -> Result<(), Error> {
        let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, WIDTH> =
        let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE> =
            PoseidonChip::new(config.poseidon.clone());

        let assigned_inputs_a =
@@ -2044,11 +2051,9 @@ mod add_with_overflow_and_poseidon {
        let b = (0..LEN)
            .map(|i| halo2curves::bn256::Fr::from(i as u64 + 1))
            .collect::<Vec<_>>();
        let commitment_a =
            PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(a.clone()).unwrap()[0][0];
        let commitment_a = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(a.clone()).unwrap()[0][0];

        let commitment_b =
            PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(b.clone()).unwrap()[0][0];
        let commitment_b = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(b.clone()).unwrap()[0][0];

        // parameters
        let a = Tensor::from(a.into_iter().map(Value::known));
@@ -2070,13 +2075,11 @@ mod add_with_overflow_and_poseidon {
        let b = (0..LEN)
            .map(|i| halo2curves::bn256::Fr::from(i as u64 + 1))
            .collect::<Vec<_>>();
        let commitment_a = PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(a.clone())
            .unwrap()[0][0]
            + Fr::one();
        let commitment_a =
            PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(a.clone()).unwrap()[0][0] + Fr::one();

        let commitment_b = PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(b.clone())
            .unwrap()[0][0]
            + Fr::one();
        let commitment_b =
            PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(b.clone()).unwrap()[0][0] + Fr::one();

        // parameters
        let a = Tensor::from(a.into_iter().map(Value::known));

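The Poseidon hunks here (and in modules.rs below) all make the same mechanical change: the chip's fourth const generic, a fixed preimage length, is gone, presumably with the length now taken from the input at run time. In before/after form, with WIDTH, RATE, and a as in the surrounding test:

// before: preimage length baked into the type
// let commitment = PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(a.clone()).unwrap()[0][0];
// after: length-free type, same run() call
let commitment = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(a.clone()).unwrap()[0][0];
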
@@ -19,6 +19,11 @@ pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
/// Converts a PrimeField element to an f64.
pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
    if x > F::from_u128(IntegerRep::MAX as u128) {
        if x == -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
            return IntegerRep::MIN as f64;
        } else if x < -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
            panic!("Felt value out of range for conversion to integer rep");
        }
        let rep = (-x).to_repr();
        let negtmp: &[u8] = rep.as_ref();
        let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
@@ -31,11 +36,13 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
    }
}

/// Converts a PrimeField element to an i64.
/// Converts a PrimeField element to an integer rep.
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
    if x > F::from_u128(IntegerRep::MAX as u128) {
        if x == -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
            return IntegerRep::MIN;
        } else if x < -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
            panic!("Felt value out of range for conversion to integer rep");
        }
        let rep = (-x).to_repr();
        let negtmp: &[u8] = rep.as_ref();
@@ -70,6 +77,13 @@ mod test {
        assert_eq!(res, F::from(131072));
    }

    #[test]
    #[should_panic]
    fn felttointegerrep_overflow() {
        let fieldx: F = integer_rep_to_felt::<F>(IntegerRep::MIN) - F::ONE;
        let _xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
    }

    #[test]
    fn felttointegerrep() {
        for x in -(2_i128.pow(16))..(2_i128.pow(16)) {

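This new branch is the point of the change (and of the branch name): a felt just beyond the representable negative range used to be mis-converted silently, and now anything strictly past -IntegerRep::MAX - 1 panics instead, while the boundary value itself still maps to IntegerRep::MIN. The shape of the guard, modelled on plain integers with i64 standing in for IntegerRep (only the changed negative side is sketched):

fn checked_to_rep(x_as_signed: i128) -> i64 {
    if x_as_signed == i64::MIN as i128 {
        return i64::MIN; // exactly -MAX - 1 is still representable
    } else if x_as_signed < i64::MIN as i128 {
        panic!("Felt value out of range for conversion to integer rep");
    }
    x_as_signed as i64
}

fn main() {
    assert_eq!(checked_to_rep(i64::MIN as i128), i64::MIN); // boundary, as in the new test
    // checked_to_rep(i64::MIN as i128 - 1) would panic, mirroring felttointegerrep_overflow
}
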
@@ -33,7 +33,7 @@ pub enum GraphError {
    #[error("a node is missing required params: {0}")]
    MissingParams(String),
    /// A node has missing parameters
    #[error("a node is has misformed params: {0}")]
    #[error("a node has misformed params: {0}")]
    MisformedParams(String),
    /// Error in the configuration of the visibility of variables
    #[error("there should be at least one set of public variables")]

@@ -609,8 +609,12 @@ impl GraphData {
        if input.len() % input_size != 0 {
            return Err(GraphError::InvalidDims(
                0,
                "calibration data length must be evenly divisible by the original input_size"
                    .to_string(),
                format!(
                    "calibration data length (={}) must be evenly divisible by the original input_size(={})",
                    input.len(),
                    input_size
                ),
            ));
        }

@@ -455,6 +455,10 @@ pub struct GraphSettings {
    pub num_blinding_factors: Option<usize>,
    /// unix time timestamp
    pub timestamp: Option<u128>,
    /// Model inputs types (if any)
    pub input_types: Option<Vec<InputType>>,
    /// Model outputs types (if any)
    pub output_types: Option<Vec<InputType>>,
}

impl GraphSettings {

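The upgraded message interpolates the two numbers a user needs in order to fix a bad calibration file. The check itself is a plain divisibility test; in isolation:

// batched calibration data must contain a whole number of input tensors
fn check_calibration_len(data_len: usize, input_size: usize) -> Result<usize, String> {
    if data_len % input_size != 0 {
        return Err(format!(
            "calibration data length (={}) must be evenly divisible by the original input_size(={})",
            data_len, input_size
        ));
    }
    Ok(data_len / input_size) // number of calibration batches
}
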
@@ -1,7 +1,6 @@
use super::errors::GraphError;
use super::extract_const_quantized_values;
use super::node::*;
use super::scale_to_multiplier;
use super::vars::*;
use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
@@ -379,9 +378,15 @@ pub struct ParsedNodes {
    pub nodes: BTreeMap<usize, NodeType>,
    inputs: Vec<usize>,
    outputs: Vec<Outlet>,
    output_types: Vec<InputType>,
}

impl ParsedNodes {
    /// Returns the output types of the computational graph.
    pub fn get_output_types(&self) -> Vec<InputType> {
        self.output_types.clone()
    }

    /// Returns the number of the computational graph's inputs
    pub fn num_inputs(&self) -> usize {
        self.inputs.len()
@@ -491,6 +496,16 @@ impl Model {
        Ok(om)
    }

    /// Gets the input types from the parsed nodes
    pub fn get_input_types(&self) -> Result<Vec<InputType>, GraphError> {
        self.graph.get_input_types()
    }

    /// Gets the output types from the parsed nodes
    pub fn get_output_types(&self) -> Vec<InputType> {
        self.graph.get_output_types()
    }

    ///
    pub fn save(&self, path: PathBuf) -> Result<(), GraphError> {
        let f = std::fs::File::create(&path).map_err(|e| {
@@ -574,6 +589,11 @@ impl Model {
            required_range_checks: res.range_checks.into_iter().collect(),
            model_output_scales: self.graph.get_output_scales()?,
            model_input_scales: self.graph.get_input_scales(),
            input_types: match self.get_input_types() {
                Ok(x) => Some(x),
                Err(_) => None,
            },
            output_types: Some(self.get_output_types()),
            num_dynamic_lookups: res.num_dynamic_lookups,
            total_dynamic_col_size: res.dynamic_lookup_col_coord,
            num_shuffles: res.num_shuffles,
@@ -704,6 +724,11 @@ impl Model {
            nodes,
            inputs: model.inputs.iter().map(|o| o.node).collect(),
            outputs: model.outputs.iter().map(|o| (o.node, o.slot)).collect(),
            output_types: model
                .outputs
                .iter()
                .map(|o| Ok::<InputType, GraphError>(model.outlet_fact(*o)?.datum_type.into()))
                .collect::<Result<Vec<_>, GraphError>>()?,
        };

        let duration = start_time.elapsed();
@@ -862,6 +887,15 @@ impl Model {
            nodes: subgraph_nodes,
            inputs: model.inputs.iter().map(|o| o.node).collect(),
            outputs: model.outputs.iter().map(|o| (o.node, o.slot)).collect(),
            output_types: model
                .outputs
                .iter()
                .map(|o| {
                    Ok::<InputType, GraphError>(
                        model.outlet_fact(*o)?.datum_type.into(),
                    )
                })
                .collect::<Result<Vec<_>, GraphError>>()?,
        };

        let om = Model {
@@ -1138,17 +1172,10 @@ impl Model {
        })?;

        if run_args.output_visibility.is_public() || run_args.output_visibility.is_fixed() {
            let output_scales = self.graph.get_output_scales().map_err(|e| {
                error!("{}", e);
                halo2_proofs::plonk::Error::Synthesis
            })?;
            let res = outputs
                .iter()
                .enumerate()
                .map(|(i, output)| {
                    let mut tol: crate::circuit::Tolerance = run_args.tolerance;
                    tol.scale = scale_to_multiplier(output_scales[i]).into();

                    let comparators = if run_args.output_visibility == Visibility::Public {
                        let res = vars
                            .instance
@@ -1171,7 +1198,6 @@ impl Model {
                        &mut thread_safe_region,
                        &[output.clone(), comparators],
                        Box::new(HybridOp::Output {
                            tol,
                            decomp: !run_args.ignore_range_check_inputs_outputs,
                        }),
                    )
@@ -1433,11 +1459,9 @@ impl Model {
        let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;

        if self.visibility.output.is_public() || self.visibility.output.is_fixed() {
            let output_scales = self.graph.get_output_scales()?;
            let res = outputs
                .iter()
                .enumerate()
                .map(|(i, output)| {
                .map(|output| {
                    let mut comparator: ValTensor<Fp> = (0..output.len())
                        .map(|_| {
                            if !self.visibility.output.is_fixed() {
@@ -1450,14 +1474,10 @@ impl Model {
                        .into();
                    comparator.reshape(output.dims())?;

                    let mut tol = run_args.tolerance;
                    tol.scale = scale_to_multiplier(output_scales[i]).into();

                    dummy_config.layout(
                        &mut region,
                        &[output.clone(), comparator],
                        Box::new(HybridOp::Output {
                            tol,
                            decomp: !run_args.ignore_range_check_inputs_outputs,
                        }),
                    )
@@ -1579,4 +1599,16 @@ impl Model {
        }
        Ok(instance_shapes)
    }

    /// Input types of the computational graph's public inputs (if any)
    pub fn instance_types(&self) -> Result<Vec<InputType>, GraphError> {
        let mut instance_types = vec![];
        if self.visibility.input.is_public() {
            instance_types.extend(self.graph.get_input_types()?);
        }
        if self.visibility.output.is_public() {
            instance_types.extend(self.graph.get_output_types());
        }
        Ok(instance_types)
    }
}

@@ -14,14 +14,11 @@ use serde::{Deserialize, Serialize};
use super::errors::GraphError;
use super::{VarVisibility, Visibility};

/// poseidon len to hash in tree
pub const POSEIDON_LEN_GRAPH: usize = 32;
/// Poseidon number of instances
pub const POSEIDON_INSTANCES: usize = 1;

/// Poseidon module type
pub type ModulePoseidon =
    PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>;
pub type ModulePoseidon = PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>;
/// Poseidon module config
pub type ModulePoseidonConfig = PoseidonConfig<POSEIDON_WIDTH, POSEIDON_RATE>;

@@ -39,9 +39,8 @@ use tract_onnx::tract_hir::{
    ops::array::{Pad, PadMode, TypedConcat},
    ops::cnn::PoolSpec,
    ops::konst::Const,
    ops::nn::DataFormat,
    tract_core::ops::cast::Cast,
    tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
    tract_core::ops::cnn::{MaxPool, SumPool},
};

/// Quantizes an iterable of f64 to a [Tensor] of IntegerRep using a fixed point representation.
@@ -274,11 +273,9 @@ pub fn new_op_from_onnx(
    symbol_values: &SymbolValues,
    run_args: &crate::RunArgs,
) -> Result<(SupportedOp, Vec<usize>), GraphError> {
    use std::f64::consts::E;

    use tract_onnx::tract_core::ops::array::Trilu;

    use crate::circuit::InputType;
    use std::f64::consts::E;
    use tract_onnx::tract_core::ops::array::Trilu;

    let input_scales = inputs
        .iter()
@@ -1148,13 +1145,6 @@ pub fn new_op_from_onnx(

            let pool_spec: &PoolSpec = &sumpool_node.pool_spec;

            // only support pytorch type formatting for now
            if pool_spec.data_format != DataFormat::NCHW {
                return Err(GraphError::MissingParams(
                    "data in wrong format".to_string(),
                ));
            }

            let stride = extract_strides(pool_spec)?;
            let padding = extract_padding(pool_spec, &input_dims[0])?;
            let kernel_shape = &pool_spec.kernel_shape;
@@ -1163,6 +1153,7 @@ pub fn new_op_from_onnx(
                padding,
                stride: stride.to_vec(),
                pool_dims: kernel_shape.to_vec(),
                data_format: pool_spec.data_format.into(),
            })
        }
        "Ceil" => {
@@ -1274,9 +1265,19 @@ pub fn new_op_from_onnx(
            // get the non constant index
            let denom = c.raw_values[0];

            SupportedOp::Hybrid(HybridOp::Div {
            let op = SupportedOp::Hybrid(HybridOp::Div {
                denom: denom.into(),
            })
            });

            // if the input is scale 0 we re up to the max scale
            if input_scales[0] == 0 {
                SupportedOp::Rescaled(Rescaled {
                    inner: Box::new(op),
                    scale: vec![(0, scale_to_multiplier(scales.get_max()) as u128)],
                })
            } else {
                op
            }
        } else {
            return Err(GraphError::MisformedParams(
                "only support non zero divisors of size 1".to_string(),
@@ -1306,15 +1307,6 @@ pub fn new_op_from_onnx(
            }
        }

        if ((conv_node.pool_spec.data_format != DataFormat::NCHW)
            && (conv_node.pool_spec.data_format != DataFormat::CHW))
            || (conv_node.kernel_fmt != KernelFormat::OIHW)
        {
            return Err(GraphError::MisformedParams(
                "data or kernel in wrong format".to_string(),
            ));
        }

        let pool_spec = &conv_node.pool_spec;

        let stride = extract_strides(pool_spec)?;
@@ -1342,6 +1334,8 @@ pub fn new_op_from_onnx(
            padding,
            stride,
            group,
            data_format: conv_node.pool_spec.data_format.into(),
            kernel_format: conv_node.kernel_fmt.into(),
        })
    }
    "Not" => SupportedOp::Linear(PolyOp::Not),
@@ -1365,14 +1359,6 @@ pub fn new_op_from_onnx(
            }
        }

        if (deconv_node.pool_spec.data_format != DataFormat::NCHW)
            || (deconv_node.kernel_format != KernelFormat::OIHW)
        {
            return Err(GraphError::MisformedParams(
                "data or kernel in wrong format".to_string(),
            ));
        }

        let pool_spec = &deconv_node.pool_spec;

        let stride = extract_strides(pool_spec)?;
@@ -1398,6 +1384,8 @@ pub fn new_op_from_onnx(
            output_padding: deconv_node.adjustments.to_vec(),
            stride,
            group: deconv_node.group,
            data_format: deconv_node.pool_spec.data_format.into(),
            kernel_format: deconv_node.kernel_format.into(),
        })
    }
    "Downsample" => {
@@ -1481,13 +1469,6 @@ pub fn new_op_from_onnx(

        let pool_spec: &PoolSpec = &sumpool_node.pool_spec;

        // only support pytorch type formatting for now
        if pool_spec.data_format != DataFormat::NCHW {
            return Err(GraphError::MissingParams(
                "data in wrong format".to_string(),
            ));
        }

        let stride = extract_strides(pool_spec)?;
        let padding = extract_padding(pool_spec, &input_dims[0])?;

@@ -1496,6 +1477,7 @@ pub fn new_op_from_onnx(
            stride: stride.to_vec(),
            kernel_shape: pool_spec.kernel_shape.to_vec(),
            normalized: sumpool_node.normalize,
            data_format: pool_spec.data_format.into(),
        })
    }
    "Pad" => {

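The `.into()` calls above imply `From` impls from tract's `DataFormat`/`KernelFormat` into ezkl's own enums; the impls themselves are outside this diff. A hypothetical sketch of the mapping the loader presumably relies on (variant names from tract_onnx):

// hypothetical: the actual impl is not shown in this diff
impl From<tract_onnx::tract_hir::ops::nn::DataFormat> for DataFormat {
    fn from(fmt: tract_onnx::tract_hir::ops::nn::DataFormat) -> Self {
        use tract_onnx::tract_hir::ops::nn::DataFormat as TractFormat;
        match fmt {
            TractFormat::NCHW => DataFormat::NCHW,
            TractFormat::NHWC => DataFormat::NHWC,
            TractFormat::CHW => DataFormat::CHW,
            TractFormat::HWC => DataFormat::HWC,
        }
    }
}
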
44
src/lib.rs
44
src/lib.rs
@@ -97,10 +97,9 @@ impl From<String> for EZKLError {

use std::str::FromStr;

use circuit::{table::Range, CheckMode, Tolerance};
use circuit::{table::Range, CheckMode};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::Args;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use fieldutils::IntegerRep;
use graph::{Visibility, MAX_PUBLIC_SRS};
use halo2_proofs::poly::{
@@ -275,10 +274,6 @@ impl From<String> for Commitments {
    derive(Args, ToFlags)
)]
pub struct RunArgs {
    /// Error tolerance for model outputs
    /// Only applicable when outputs are public
    #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other))]
    pub tolerance: Tolerance,
    /// Fixed point scaling factor for quantizing inputs
    /// Higher values provide more precision but increase circuit complexity
    #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other))]
@@ -365,7 +360,6 @@ impl Default for RunArgs {
    fn default() -> Self {
        Self {
            bounded_log_lookup: false,
            tolerance: Tolerance::default(),
            input_scale: 7,
            param_scale: 7,
            scale_rebase_multiplier: 1,
@@ -399,6 +393,16 @@ impl RunArgs {
    pub fn validate(&self) -> Result<(), String> {
        let mut errors = Vec::new();

        // check if the largest represented integer in the decomposed form overflows IntegerRep
        // try it with the largest possible value
        let max_decomp = (self.decomp_base as IntegerRep).checked_pow(self.decomp_legs as u32);
        if max_decomp.is_none() {
            errors.push(format!(
                "decomp_base^decomp_legs overflows IntegerRep: {}^{}",
                self.decomp_base, self.decomp_legs
            ));
        }

        // Visibility validations
        if self.param_visibility == Visibility::Public {
            errors.push(
@@ -407,10 +411,6 @@ impl RunArgs {
            );
        }

        if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
            errors.push("Non-zero tolerance requires output_visibility to be public".to_string());
        }

        // Scale validations
        if self.scale_rebase_multiplier < 1 {
            errors.push("scale_rebase_multiplier must be >= 1".to_string());
@@ -459,11 +459,6 @@ impl RunArgs {
            warn!("logrows exceeds maximum public SRS size");
        }

        // Validate tolerance is non-negative
        if self.tolerance.val < 0.0 {
            errors.push("tolerance cannot be negative".to_string());
        }

        // Performance warnings
        if self.input_scale > 20 || self.param_scale > 20 {
            warn!("High scale values (>20) may impact performance");
@@ -610,23 +605,6 @@ mod tests {
        assert!(err.contains("num_inner_cols must be >= 1"));
    }

    #[test]
    fn test_invalid_tolerance() {
        let mut args = RunArgs::default();
        args.tolerance.val = 1.0;
        args.output_visibility = Visibility::Private;
        let err = args.validate().unwrap_err();
        assert!(err.contains("Non-zero tolerance requires output_visibility to be public"));
    }

    #[test]
    fn test_negative_tolerance() {
        let mut args = RunArgs::default();
        args.tolerance.val = -1.0;
        let err = args.validate().unwrap_err();
        assert!(err.contains("tolerance cannot be negative"));
    }

    #[test]
    fn test_zero_batch_size() {
        let mut args = RunArgs::default();

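The new `checked_pow` guard rejects configurations whose largest decomposed value cannot fit in `IntegerRep` before any circuit work begins. The same pattern on plain integers, with i128 standing in for IntegerRep:

fn decomp_fits(base: u64, legs: u32) -> bool {
    // checked_pow returns None on overflow instead of wrapping
    (base as i128).checked_pow(legs).is_some()
}

fn main() {
    assert!(decomp_fits(16384, 2)); // 2^28, comfortably representable
    assert!(!decomp_fits(16384, 10)); // 2^140 overflows i128
}
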
@@ -1,6 +1,6 @@
use thiserror::Error;

use super::ops::DecompositionError;
use super::{ops::DecompositionError, DataFormat};

/// A wrapper for tensor related errors.
#[derive(Debug, Error)]
@@ -44,4 +44,7 @@ pub enum TensorError {
    /// Index out of bounds
    #[error("index {0} out of bounds for dimension {1}")]
    IndexOutOfBounds(usize, usize),
    /// Invalid data conversion
    #[error("invalid data conversion from format {0} to {1}")]
    InvalidDataConversion(DataFormat, DataFormat),
}

@@ -9,6 +9,7 @@ pub mod var;

pub use errors::TensorError;

use core::hash::Hash;
use halo2curves::ff::PrimeField;
use maybe_rayon::{
    prelude::{
@@ -1767,6 +1768,229 @@ pub fn get_broadcasted_shape(
    }
}
////////////////////////
///

/// The shape of data for some operations
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]
pub enum DataFormat {
    /// NCHW
    #[default]
    NCHW,
    /// NHWC
    NHWC,
    /// CHW
    CHW,
    /// HWC
    HWC,
}

// as str
impl core::fmt::Display for DataFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DataFormat::NCHW => write!(f, "NCHW"),
            DataFormat::NHWC => write!(f, "NHWC"),
            DataFormat::CHW => write!(f, "CHW"),
            DataFormat::HWC => write!(f, "HWC"),
        }
    }
}

impl DataFormat {
    /// Get the format's canonical form
    pub fn canonical(&self) -> DataFormat {
        match self {
            DataFormat::NHWC => DataFormat::NCHW,
            DataFormat::HWC => DataFormat::CHW,
            _ => self.clone(),
        }
    }

    /// no batch dim
    pub fn has_no_batch(&self) -> bool {
        match self {
            DataFormat::CHW | DataFormat::HWC => true,
            _ => false,
        }
    }

    /// Convert tensor to canonical format (NCHW or CHW)
    pub fn to_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
        &self,
        tensor: &mut ValTensor<F>,
    ) -> Result<(), TensorError> {
        match self {
            DataFormat::NHWC => {
                // For ND: Move channels from last axis to position after batch
                let ndims = tensor.dims().len();
                if ndims > 2 {
                    tensor.move_axis(ndims - 1, 1)?;
                }
            }
            DataFormat::HWC => {
                // For ND: Move channels from last axis to first position
                let ndims = tensor.dims().len();
                if ndims > 1 {
                    tensor.move_axis(ndims - 1, 0)?;
                }
            }
            _ => {} // NCHW/CHW are already in canonical format
        }
        Ok(())
    }

    /// Convert tensor from canonical format to target format
    pub fn from_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
        &self,
        tensor: &mut ValTensor<F>,
    ) -> Result<(), TensorError> {
        match self {
            DataFormat::NHWC => {
                // Move channels from position 1 to end
                let ndims = tensor.dims().len();
                if ndims > 2 {
                    tensor.move_axis(1, ndims - 1)?;
                }
            }
            DataFormat::HWC => {
                // Move channels from position 0 to end
                let ndims = tensor.dims().len();
                if ndims > 1 {
                    tensor.move_axis(0, ndims - 1)?;
                }
            }
            _ => {} // NCHW/CHW don't need conversion
        }
        Ok(())
    }

    /// Get the position of the channel dimension
    pub fn get_channel_dim(&self, ndims: usize) -> usize {
        match self {
            DataFormat::NCHW => 1,
            DataFormat::NHWC => ndims - 1,
            DataFormat::CHW => 0,
            DataFormat::HWC => ndims - 1,
        }
    }
}

/// The shape of the kernel for some operations
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]
|
||||
pub enum KernelFormat {
|
||||
/// HWIO
|
||||
HWIO,
|
||||
/// OIHW
|
||||
#[default]
|
||||
OIHW,
|
||||
/// OHWI
|
||||
OHWI,
|
||||
}
|
||||
|
||||
impl core::fmt::Display for KernelFormat {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
KernelFormat::HWIO => write!(f, "HWIO"),
|
||||
KernelFormat::OIHW => write!(f, "OIHW"),
|
||||
KernelFormat::OHWI => write!(f, "OHWI"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelFormat {
|
||||
/// Get the format's canonical form
|
||||
pub fn canonical(&self) -> KernelFormat {
|
||||
match self {
|
||||
KernelFormat::HWIO => KernelFormat::OIHW,
|
||||
KernelFormat::OHWI => KernelFormat::OIHW,
|
||||
_ => self.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert kernel to canonical format (OIHW)
|
||||
pub fn to_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
|
||||
&self,
|
||||
kernel: &mut ValTensor<F>,
|
||||
) -> Result<(), TensorError> {
|
||||
match self {
|
||||
KernelFormat::HWIO => {
|
||||
let kdims = kernel.dims().len();
|
||||
// Move output channels from last to first
|
||||
kernel.move_axis(kdims - 1, 0)?;
|
||||
// Move input channels from new last to second position
|
||||
kernel.move_axis(kdims - 1, 1)?;
|
||||
}
|
||||
KernelFormat::OHWI => {
|
||||
let kdims = kernel.dims().len();
|
||||
// Move input channels from last to second position
|
||||
kernel.move_axis(kdims - 1, 1)?;
|
||||
}
|
||||
_ => {} // OIHW is already canonical
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert kernel from canonical format to target format
|
||||
pub fn from_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
|
||||
&self,
|
||||
kernel: &mut ValTensor<F>,
|
||||
) -> Result<(), TensorError> {
|
||||
match self {
|
||||
KernelFormat::HWIO => {
|
||||
let kdims = kernel.dims().len();
|
||||
// Move input channels from second position to last
|
||||
kernel.move_axis(1, kdims - 1)?;
|
||||
// Move output channels from first to last
|
||||
kernel.move_axis(0, kdims - 1)?;
|
||||
}
|
||||
KernelFormat::OHWI => {
|
||||
let kdims = kernel.dims().len();
|
||||
// Move input channels from second position to last
|
||||
kernel.move_axis(1, kdims - 1)?;
|
||||
}
|
||||
_ => {} // OIHW doesn't need conversion
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the position of input and output channel dimensions
|
||||
pub fn get_channel_dims(&self, ndims: usize) -> (usize, usize) {
|
||||
// (input_ch, output_ch)
|
||||
match self {
|
||||
KernelFormat::OIHW => (1, 0),
|
||||
KernelFormat::HWIO => (ndims - 2, ndims - 1),
|
||||
KernelFormat::OHWI => (ndims - 1, 0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl From<tract_onnx::tract_hir::ops::nn::DataFormat> for DataFormat {
|
||||
fn from(fmt: tract_onnx::tract_hir::ops::nn::DataFormat) -> Self {
|
||||
match fmt {
|
||||
tract_onnx::tract_hir::ops::nn::DataFormat::NCHW => DataFormat::NCHW,
|
||||
tract_onnx::tract_hir::ops::nn::DataFormat::NHWC => DataFormat::NHWC,
|
||||
tract_onnx::tract_hir::ops::nn::DataFormat::CHW => DataFormat::CHW,
|
||||
tract_onnx::tract_hir::ops::nn::DataFormat::HWC => DataFormat::HWC,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl From<tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat> for KernelFormat {
|
||||
fn from(fmt: tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat) -> Self {
|
||||
match fmt {
|
||||
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::HWIO => {
|
||||
KernelFormat::HWIO
|
||||
}
|
||||
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::OIHW => {
|
||||
KernelFormat::OIHW
|
||||
}
|
||||
tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat::OHWI => {
|
||||
KernelFormat::OHWI
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
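Concretely, to_canonical on an NHWC tensor moves the channel axis from last place to position 1, and from_canonical undoes it. A standalone illustration of that axis move on a plain shape vector (move_axis here is a toy stand-in for the ValTensor method):

    // Toy stand-in for ValTensor::move_axis, acting on the shape alone.
    fn move_axis(shape: &mut Vec<usize>, from: usize, to: usize) {
        let d = shape.remove(from);
        shape.insert(to, d);
    }

    fn main() {
        let mut shape = vec![1, 224, 224, 3];    // NHWC
        move_axis(&mut shape, 3, 1);             // NHWC -> canonical (NCHW)
        assert_eq!(shape, vec![1, 3, 224, 224]);
        move_axis(&mut shape, 1, 3);             // from_canonical round-trips
        assert_eq!(shape, vec![1, 224, 224, 3]);
    }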
@@ -387,7 +387,7 @@ pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sy
 ) -> Result<Tensor<T>, TensorError> {
     if t.len() == 1 {
         return Ok(t[0].clone());
-    } else if t.len() == 0 {
+    } else if t.is_empty() {
         return Err(TensorError::DimMismatch("add".to_string()));
     }

@@ -441,7 +441,7 @@ pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sy
 ) -> Result<Tensor<T>, TensorError> {
     if t.len() == 1 {
         return Ok(t[0].clone());
-    } else if t.len() == 0 {
+    } else if t.is_empty() {
         return Err(TensorError::DimMismatch("sub".to_string()));
     }
     // calculate value of output
@@ -492,7 +492,7 @@ pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::S
 ) -> Result<Tensor<T>, TensorError> {
     if t.len() == 1 {
         return Ok(t[0].clone());
-    } else if t.len() == 0 {
+    } else if t.is_empty() {
         return Err(TensorError::DimMismatch("mult".to_string()));
     }
     // calculate value of output
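
All three hunks are the same mechanical cleanup (clippy's len_zero lint): a length comparison against zero becomes is_empty, with no behavior change. In isolation:

    fn main() {
        let t: Vec<i64> = vec![];
        // clippy::len_zero flags the first form; the two are equivalent.
        assert_eq!(t.len() == 0, t.is_empty());
    }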
@@ -1326,7 +1326,6 @@ pub fn pad<T: TensorType>(
 ///
 /// # Errors
 /// Returns a TensorError if the tensors in `inputs` have incompatible dimensions for concatenation along the specified `axis`.
-
 pub fn concat<T: TensorType + Send + Sync>(
     inputs: &[&Tensor<T>],
     axis: usize,
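
The removed blank line simply reattaches the doc comment to concat. As a usage sketch of the signature shown (tensor construction mirrors the doc-test style used further down; exact API details assumed):

    // Hypothetical usage: concatenating two 2x3 tensors along axis 0.
    let a = Tensor::<i64>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
    let b = Tensor::<i64>::new(Some(&[7, 8, 9, 10, 11, 12]), &[2, 3]).unwrap();
    let c = concat(&[&a, &b], 0).unwrap();
    assert_eq!(c.dims(), &[4, 3]);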
@@ -2102,7 +2101,6 @@ pub mod nonlinearities {
 /// let expected = Tensor::<IntegerRep>::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap();
 /// assert_eq!(result, expected);
 /// ```
-
 pub fn tanh(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
     a.par_enum_map(|_, a_i| {
         let kix = (a_i as f64) / scale_input;
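
The tanh hunk cuts off after the dequantization step; the usual shape of these scaled nonlinearities is dequantize, apply, requantize. A sketch with the continuation assumed (only the kix line appears in the diff):

    // Scaled integer tanh, reduced to one element.
    fn tanh_quantized(a_i: i64, scale_input: f64) -> i64 {
        let kix = (a_i as f64) / scale_input; // dequantize
        let fout = scale_input * kix.tanh();  // apply and requantize
        fout.round() as i64
    }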
Binary file not shown.
@@ -1,8 +1,7 @@
 #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
 #[cfg(test)]
 mod native_tests {
-    use ezkl::circuit::Tolerance;
     use ezkl::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};

     // use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
     use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
     use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
@@ -187,7 +186,7 @@ mod native_tests {

     const PF_FAILURE_AGGR: &str = "examples/test_failure_aggr_proof.json";

-    const LARGE_TESTS: [&str; 7] = [
+    const LARGE_TESTS: [&str; 8] = [
         "self_attention",
         "nanoGPT",
         "multihead_attention",
@@ -195,6 +194,7 @@ mod native_tests {
         "mnist_gan",
         "smallworm",
         "fr_age",
+        "1d_conv",
     ];

     const ACCURACY_CAL_TESTS: [&str; 6] = [
@@ -206,7 +206,7 @@ mod native_tests {
         "1l_tiny_div",
     ];

-    const TESTS: [&str; 98] = [
+    const TESTS: [&str; 99] = [
         "1l_mlp", //0
         "1l_slice", //1
         "1l_concat", //2
@@ -309,6 +309,7 @@ mod native_tests {
         "log", // 95
         "exp", // 96
         "general_exp", // 97
+        "integer_div", // 98
     ];

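Whenever one of these arrays grows, the seq! upper bounds that drive the generated test cases move in lockstep, as the hunks below show (0..=97 to 0..=98 for TESTS, 0..=6 to 0..=7 for LARGE_TESTS). The coupling, with a hypothetical test body:

    // The macro bound is always TESTS.len() - 1, so both edits land together.
    seq!(N in 0..=98 {
        #(#[test_case(TESTS[N])])*
        fn mock_public_outputs_(test: &str) { /* ... */ }
    });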
     const WASM_TESTS: [&str; 46] = [
@@ -521,7 +522,7 @@ mod native_tests {
    use crate::native_tests::run_js_tests;
    use crate::native_tests::render_circuit;
+    use crate::native_tests::model_serialization_different_binaries;
-    use rand::Rng;

    use tempdir::TempDir;
    use ezkl::Commitments;

@@ -542,12 +543,12 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, false, None, None);
            test_dir.close().unwrap();
        }
    });

-    seq!(N in 0..=97 {
+    seq!(N in 0..=98 {

        #(#[test_case(TESTS[N])])*
        #[ignore]
@@ -607,7 +608,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -617,22 +618,10 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true, Some(8194), Some(4));
+            mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, true, Some(8194), Some(4));
            test_dir.close().unwrap();
        }

-        #(#[test_case(TESTS[N])])*
-        fn mock_tolerance_public_outputs_(test: &str) {
-            crate::native_tests::init_binary();
-            let test_dir = TempDir::new(test).unwrap();
-            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            // gen random number between 0.0 and 1.0
-            let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
-            mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false, Some(32776), Some(5));
-            test_dir.close().unwrap();
-        }
-
-

        #(#[test_case(TESTS[N])])*
        fn mock_large_batch_public_outputs_(test: &str) {
@@ -643,7 +632,7 @@ mod native_tests {
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
            let large_batch_dir = &format!("large_batches_{}", test);
            crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
-            mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0, false, None, None);
+            mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }
    }
@@ -653,7 +642,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -662,7 +651,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -671,7 +660,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -680,7 +669,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -689,7 +678,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -698,7 +687,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -707,7 +696,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -717,7 +706,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -727,7 +716,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -736,7 +725,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -746,7 +735,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -755,7 +744,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -765,7 +754,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -775,7 +764,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -785,7 +774,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -795,7 +784,7 @@ mod native_tests {
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
            // needs an extra row for the large model
-            mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -805,7 +794,7 @@ mod native_tests {
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
            // needs an extra row for the large model
-            mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false, None, None);
+            mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, false, None, None);
            test_dir.close().unwrap();
        }

@@ -964,7 +953,7 @@ mod native_tests {

    });

-    seq!(N in 0..=6 {
+    seq!(N in 0..=7 {

    #(#[test_case(LARGE_TESTS[N])])*
    #[ignore]
@@ -982,7 +971,7 @@ mod native_tests {
            crate::native_tests::init_binary();
            let test_dir = TempDir::new(test).unwrap();
            let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
-            mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false, None, Some(5));
+            mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, false, None, Some(5));
            test_dir.close().unwrap();
        }
    });
@@ -1458,12 +1447,10 @@ mod native_tests {
        batch_size: usize,
        cal_target: &str,
        scales_to_use: Option<Vec<u32>>,
-        tolerance: f32,
        bounded_lookup_log: bool,
        decomp_base: Option<usize>,
        decomp_legs: Option<usize>,
    ) {
-        let mut tolerance = tolerance;
        gen_circuit_settings_and_witness(
            test_dir,
            example_name.clone(),
@@ -1474,7 +1461,6 @@ mod native_tests {
            cal_target,
            scales_to_use,
            2,
-            &mut tolerance,
            Commitments::KZG,
            2,
            bounded_lookup_log,
@@ -1482,128 +1468,17 @@ mod native_tests {
            decomp_legs,
        );

-        if tolerance > 0.0 {
-            // load witness and shift the output by a small amount that is less than tolerance percent
-            let witness = GraphWitness::from_path(
-                format!("{}/{}/witness.json", test_dir, example_name).into(),
-            )
-            .unwrap();
-            let witness = witness.clone();
-            let outputs = witness.outputs.clone();
-
-            // get values as i64
-            let output_perturbed_safe: Vec<Vec<halo2curves::bn256::Fr>> = outputs
-                .iter()
-                .map(|sv| {
-                    sv.iter()
-                        .map(|v| {
-                            // randomly perturb by a small amount less than tolerance
-                            let perturbation = if v == &halo2curves::bn256::Fr::zero() {
-                                halo2curves::bn256::Fr::zero()
-                            } else {
-                                integer_rep_to_felt(
-                                    (felt_to_integer_rep(*v) as f32
-                                        * (rand::thread_rng().gen_range(-0.01..0.01) * tolerance))
-                                        as IntegerRep,
-                                )
-                            };
-
-                            *v + perturbation
-                        })
-                        .collect::<Vec<_>>()
-                })
-                .collect::<Vec<_>>();
-
-            // get values as i64
-            let output_perturbed_bad: Vec<Vec<halo2curves::bn256::Fr>> = outputs
-                .iter()
-                .map(|sv| {
-                    sv.iter()
-                        .map(|v| {
-                            // randomly perturb by a small amount less than tolerance
-                            let perturbation = if v == &halo2curves::bn256::Fr::zero() {
-                                halo2curves::bn256::Fr::from(2)
-                            } else {
-                                integer_rep_to_felt(
-                                    (felt_to_integer_rep(*v) as f32
-                                        * (rand::thread_rng().gen_range(0.02..0.1) * tolerance))
-                                        as IntegerRep,
-                                )
-                            };
-                            *v + perturbation
-                        })
-                        .collect::<Vec<_>>()
-                })
-                .collect::<Vec<_>>();
-
-            let good_witness = GraphWitness {
-                outputs: output_perturbed_safe,
-                ..witness.clone()
-            };
-
-            // save
-            good_witness
-                .save(format!("{}/{}/witness_ok.json", test_dir, example_name).into())
-                .unwrap();
-
-            let bad_witness = GraphWitness {
-                outputs: output_perturbed_bad,
-                ..witness.clone()
-            };
-
-            // save
-            bad_witness
-                .save(format!("{}/{}/witness_bad.json", test_dir, example_name).into())
-                .unwrap();
-
-            let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
-                .args([
-                    "mock",
-                    "-W",
-                    format!("{}/{}/witness.json", test_dir, example_name).as_str(),
-                    "-M",
-                    format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
-                ])
-                .status()
-                .expect("failed to execute process");
-            assert!(status.success());
-
-            let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
-                .args([
-                    "mock",
-                    "-W",
-                    format!("{}/{}/witness_ok.json", test_dir, example_name).as_str(),
-                    "-M",
-                    format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
-                ])
-                .status()
-                .expect("failed to execute process");
-            assert!(status.success());
-
-            let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
-                .args([
-                    "mock",
-                    "-W",
-                    format!("{}/{}/witness_bad.json", test_dir, example_name).as_str(),
-                    "-M",
-                    format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
-                ])
-                .status()
-                .expect("failed to execute process");
-            assert!(!status.success());
-        } else {
-            let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
-                .args([
-                    "mock",
-                    "-W",
-                    format!("{}/{}/witness.json", test_dir, example_name).as_str(),
-                    "-M",
-                    format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
-                ])
-                .status()
-                .expect("failed to execute process");
-            assert!(status.success());
-        }
+        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
+            .args([
+                "mock",
+                "-W",
+                format!("{}/{}/witness.json", test_dir, example_name).as_str(),
+                "-M",
+                format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
+            ])
+            .status()
+            .expect("failed to execute process");
+        assert!(status.success());
    }

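The deleted branch built two mutated witnesses for the removed tolerance feature: a safe one perturbing each nonzero output by under 1% of the tolerance (expected to still verify) and a bad one perturbing by 2-10% of it (expected to fail). The arithmetic, reduced to a single value, with a stand-in for the felt round trip:

    // Sketch of the deleted perturbation step; `rep` stands in for the
    // felt <-> IntegerRep round trip done with ezkl's fieldutils helpers.
    fn perturb(rep: i64, tolerance: f32, frac: f32) -> i64 {
        rep + (rep as f32 * (frac * tolerance)) as i64
    }

    fn main() {
        let v = 10_000;
        // a "safe" shift stays well inside a "bad" one
        assert!(perturb(v, 50.0, 0.009) - v < perturb(v, 50.0, 0.05) - v);
    }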
    #[allow(clippy::too_many_arguments)]
@@ -1617,7 +1492,6 @@ mod native_tests {
        cal_target: &str,
        scales_to_use: Option<Vec<u32>>,
        num_inner_columns: usize,
-        tolerance: &mut f32,
        commitment: Commitments,
        lookup_safety_margin: usize,
        bounded_lookup_log: bool,
@@ -1632,13 +1506,16 @@ mod native_tests {
                "--settings-path={}/{}/settings.json",
                test_dir, example_name
            ),
-            format!("--variables=batch_size->{}", batch_size),
+            format!(
+                "--variables=batch_size->{},sequence_length->100,<Sym1>->1",
+                batch_size
+            ),
            format!("--input-visibility={}", input_visibility),
            format!("--param-visibility={}", param_visibility),
            format!("--output-visibility={}", output_visibility),
            format!("--num-inner-cols={}", num_inner_columns),
-            format!("--tolerance={}", tolerance),
            format!("--commitment={}", commitment),
            format!("--logrows={}", 22),
        ];

        // if output-visibility is fixed set --range-check-inputs-outputs to False
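
Assembled, those flags yield a gen-settings invocation along these lines (a sketch mirroring the Command pattern used throughout these tests; binary name and paths hypothetical):

    // Hypothetical invocation; the flag set follows the hunk above.
    let status = std::process::Command::new("ezkl")
        .args([
            "gen-settings",
            "-M",
            "network.onnx",
            "--settings-path=settings.json",
            "--variables=batch_size->1,sequence_length->100,<Sym1>->1",
            "--num-inner-cols=2",
            "--logrows=22",
        ])
        .status()
        .expect("failed to execute process");
    assert!(status.success());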
@@ -1694,24 +1571,6 @@ mod native_tests {
        .expect("failed to execute process");
        assert!(status.success());

-        let mut settings =
-            GraphSettings::load(&format!("{}/{}/settings.json", test_dir, example_name).into())
-                .unwrap();
-
-        let any_output_scales_smol = settings.model_output_scales.iter().any(|s| *s <= 0);
-
-        if any_output_scales_smol {
-            // set the tolerance to 0.0
-            settings.run_args.tolerance = Tolerance {
-                val: 0.0,
-                scale: 0.0.into(),
-            };
-            settings
-                .save(&format!("{}/{}/settings.json", test_dir, example_name).into())
-                .unwrap();
-            *tolerance = 0.0;
-        }
-
        let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
            .args([
                "compile-circuit",
@@ -1724,7 +1583,6 @@ mod native_tests {
                test_dir, example_name
            ),
        ])
-        .stdout(std::process::Stdio::null())
        .status()
        .expect("failed to execute process");
        assert!(status.success());
@@ -1767,7 +1625,6 @@ mod native_tests {
            cal_target,
            None,
            2,
-            &mut 0.0,
            Commitments::KZG,
            2,
            false,
@@ -2053,7 +1910,6 @@ mod native_tests {
            target_str,
            scales_to_use,
            num_inner_columns,
-            &mut 0.0,
            commitment,
            lookup_safety_margin,
            false,
@@ -2488,7 +2344,6 @@ mod native_tests {
            // we need the accuracy
            Some(vec![4]),
            1,
-            &mut 0.0,
            Commitments::KZG,
            2,
            false,

@@ -49,6 +49,23 @@ mod py_tests {
        std::env::set_var("VOICE_DATA_DIR", format!("{}", voice_data_dir));
    }

+    fn download_catdog_data() {
+        let cat_and_dog_data_dir = shellexpand::tilde("~/data/catdog_data");
+
+        DOWNLOAD_VOICE_DATA.call_once(|| {
+            let status = Command::new("bash")
+                .args([
+                    "examples/notebooks/cat_and_dog_data.sh",
+                    &cat_and_dog_data_dir,
+                ])
+                .status()
+                .expect("failed to execute process");
+            assert!(status.success());
+        });
+        // set VOICE_DATA_DIR environment variable
+        std::env::set_var("CATDOG_DATA_DIR", format!("{}", cat_and_dog_data_dir));
+    }
+
    fn setup_py_env() {
        ENV_SETUP.call_once(|| {
            // supposes that you have a virtualenv called .env and have run the following
@@ -225,6 +242,20 @@ mod py_tests {
        anvil_child.kill().unwrap();
    }

+
+    #[test]
+    fn cat_and_dog_notebook_() {
+        crate::py_tests::init_binary();
+        let mut anvil_child = crate::py_tests::start_anvil(false);
+        crate::py_tests::download_catdog_data();
+        let test_dir: TempDir = TempDir::new("cat_and_dog").unwrap();
+        let path = test_dir.path().to_str().unwrap();
+        crate::py_tests::mv_test_(path, "cat_and_dog.ipynb");
+        run_notebook(path, "cat_and_dog.ipynb");
+        test_dir.close().unwrap();
+        anvil_child.kill().unwrap();
+    }
+
    #[test]
    fn reusable_verifier_notebook_() {
        crate::py_tests::init_binary();

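One thing worth flagging in the first hunk: download_catdog_data gates on the existing DOWNLOAD_VOICE_DATA Once, and its trailing comment still says VOICE_DATA_DIR, both copy-paste leftovers from the voice-data helper. As written, whichever download runs first suppresses the other; a dedicated guard would look like (name hypothetical):

    // Hypothetical dedicated once-guard for the catdog download.
    static DOWNLOAD_CATDOG_DATA: std::sync::Once = std::sync::Once::new();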
@@ -48,7 +48,6 @@ def test_py_run_args():
    run_args = ezkl.PyRunArgs()
    run_args.input_visibility = "hashed"
    run_args.output_visibility = "hashed"
-    run_args.tolerance = 1.5


def test_poseidon_hash():
@@ -59,7 +58,7 @@ def test_poseidon_hash():
    message = [ezkl.float_to_felt(x, 7) for x in message]
    res = ezkl.poseidon_hash(message)
    assert ezkl.felt_to_big_endian(
-        res[0]) == "0x0da7e5e5c8877242fa699f586baf770d731defd54f952d4adeb85047a0e32f45"
+        res[0]) == "0x2369898875588bf49b6539376b09705ea69aee318a58e6fcc1e68fc3e7ad81ab"



@@ -873,7 +872,8 @@ def get_examples():
        'linear_regression',
        "mnist_gan",
        "smallworm",
-        "fr_age"
+        "fr_age",
+        "1d_conv",
    ]
    examples = []
    for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):
@@ -900,7 +900,12 @@ async def test_all_examples(model_file, input_file):
    proof_path = os.path.join(folder_path, 'proof.json')

    print("Testing example: ", model_file)
-    res = ezkl.gen_settings(model_file, settings_path)
+
+    run_args = ezkl.PyRunArgs()
+    run_args.variables = [("batch_size", 1), ("sequence_length", 100), ("<Sym1>", 1)]
+    run_args.logrows = 22
+
+    res = ezkl.gen_settings(model_file, settings_path, py_run_args=run_args)
    assert res

    res = await ezkl.calibrate_settings(
@@ -11,7 +11,6 @@ mod wasm32 {
    use ezkl::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
    use ezkl::circuit::modules::poseidon::PoseidonChip;
    use ezkl::circuit::modules::Module;
-    use ezkl::graph::modules::POSEIDON_LEN_GRAPH;
    use ezkl::graph::GraphCircuit;
    use ezkl::graph::{GraphSettings, GraphWitness};
    use ezkl::pfsys;
@@ -227,11 +226,9 @@ mod wasm32 {
    let hash: Vec<Vec<Fr>> = serde_json::from_slice(&hash[..]).unwrap();

    let reference_hash =
-        PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>::run(
-            message.clone(),
-        )
-        .map_err(|_| "failed")
-        .unwrap();
+        PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
+            .map_err(|_| "failed")
+            .unwrap();

    assert_eq!(hash, reference_hash)
}
@@ -340,7 +337,7 @@ mod wasm32 {
    // Run compiled circuit validation on onnx network (should fail)
    let circuit = compiledCircuitValidation(wasm_bindgen::Clamped(NETWORK.to_vec()));
    assert!(circuit.is_err());
-    // Run compiled circuit validation on comiled network (should pass)
+    // Run compiled circuit validation on compiled network (should pass)
    let circuit = compiledCircuitValidation(wasm_bindgen::Clamped(NETWORK_COMPILED.to_vec()));
    assert!(circuit.is_ok());
    // Run input validation on witness (should fail)