mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 16:27:59 -05:00
Compare commits
35 Commits
ac/fix-uns
...
ac/einsum
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2fe5a3bd0c | ||
|
|
282c1f8ff1 | ||
|
|
bccf9f7282 | ||
|
|
fca1e2f7c2 | ||
|
|
cde73c0eb4 | ||
|
|
d846e5bee9 | ||
|
|
1ce8376037 | ||
|
|
70af803c5c | ||
|
|
282a96b3f1 | ||
|
|
76e80df09c | ||
|
|
3f9159c4f4 | ||
|
|
d2b4cd1dc7 | ||
|
|
7e89d1d67b | ||
|
|
be8560e2be | ||
|
|
e9a6443f26 | ||
|
|
fe0689389f | ||
|
|
57cf318e42 | ||
|
|
e3355dbf69 | ||
|
|
fa548efb7f | ||
|
|
67b97f9ab8 | ||
|
|
71e86ade32 | ||
|
|
be5fb23ef4 | ||
|
|
d64749fc71 | ||
|
|
d7b04d0d25 | ||
|
|
d50a4b7d59 | ||
|
|
fad31de5b4 | ||
|
|
2f1a3f430e | ||
|
|
edd4d7f5b8 | ||
|
|
1c3ae450e1 | ||
|
|
afb4ca9f06 | ||
|
|
0e79d36238 | ||
|
|
2ba6417913 | ||
|
|
910d096a10 | ||
|
|
24649702ad | ||
|
|
e1c126c735 |
22
.github/workflows/benchmarks.yml
vendored
22
.github/workflows/benchmarks.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
override: true
|
||||
@@ -32,7 +32,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
override: true
|
||||
@@ -49,7 +49,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
override: true
|
||||
@@ -66,7 +66,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
override: true
|
||||
@@ -83,7 +83,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
override: true
|
||||
@@ -100,7 +100,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
override: true
|
||||
@@ -117,7 +117,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
override: true
|
||||
@@ -134,7 +134,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
override: true
|
||||
@@ -151,7 +151,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
override: true
|
||||
@@ -168,7 +168,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
override: true
|
||||
@@ -185,7 +185,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
override: true
|
||||
|
||||
90
.github/workflows/engine.yml
vendored
90
.github/workflows/engine.yml
vendored
@@ -18,29 +18,32 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
id-token: write # Required for provenance
|
||||
name: publish-wasm-bindings
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
RUSTFLAGS: "-C target-feature=+atomics,+bulk-memory"
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
|
||||
cache: false
|
||||
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
|
||||
with:
|
||||
# Pin to version 0.12.1
|
||||
version: 'v0.12.1'
|
||||
version: "v0.12.1"
|
||||
- name: Add wasm32-unknown-unknown target
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
|
||||
- name: Install binaryen
|
||||
run: |
|
||||
set -e
|
||||
@@ -49,41 +52,41 @@ jobs:
|
||||
wasm-opt --version
|
||||
- name: Build wasm files for both web and nodejs compilation targets
|
||||
run: |
|
||||
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
|
||||
wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
|
||||
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
|
||||
wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
|
||||
- name: Create package.json in pkg folder
|
||||
shell: bash
|
||||
run: |
|
||||
cat > pkg/package.json << EOF
|
||||
{
|
||||
"name": "@ezkljs/engine",
|
||||
"version": "$RELEASE_TAG",
|
||||
"dependencies": {
|
||||
"@types/json-bigint": "^1.0.1",
|
||||
"json-bigint": "^1.0.0"
|
||||
},
|
||||
"files": [
|
||||
"nodejs/ezkl_bg.wasm",
|
||||
"nodejs/ezkl.js",
|
||||
"nodejs/ezkl.d.ts",
|
||||
"nodejs/package.json",
|
||||
"nodejs/utils.js",
|
||||
"web/ezkl_bg.wasm",
|
||||
"web/ezkl.js",
|
||||
"web/ezkl.d.ts",
|
||||
"web/snippets/**/*",
|
||||
"web/package.json",
|
||||
"web/utils.js",
|
||||
"ezkl.d.ts"
|
||||
],
|
||||
"main": "nodejs/ezkl.js",
|
||||
"module": "web/ezkl.js",
|
||||
"types": "nodejs/ezkl.d.ts",
|
||||
"sideEffects": [
|
||||
"web/snippets/*"
|
||||
]
|
||||
}
|
||||
EOF
|
||||
cat > pkg/package.json << EOF
|
||||
{
|
||||
"name": "@ezkljs/engine",
|
||||
"version": "$RELEASE_TAG",
|
||||
"dependencies": {
|
||||
"@types/json-bigint": "^1.0.1",
|
||||
"json-bigint": "^1.0.0"
|
||||
},
|
||||
"files": [
|
||||
"nodejs/ezkl_bg.wasm",
|
||||
"nodejs/ezkl.js",
|
||||
"nodejs/ezkl.d.ts",
|
||||
"nodejs/package.json",
|
||||
"nodejs/utils.js",
|
||||
"web/ezkl_bg.wasm",
|
||||
"web/ezkl.js",
|
||||
"web/ezkl.d.ts",
|
||||
"web/snippets/**/*",
|
||||
"web/package.json",
|
||||
"web/utils.js",
|
||||
"ezkl.d.ts"
|
||||
],
|
||||
"main": "nodejs/ezkl.js",
|
||||
"module": "web/ezkl.js",
|
||||
"types": "nodejs/ezkl.d.ts",
|
||||
"sideEffects": [
|
||||
"web/snippets/*"
|
||||
]
|
||||
}
|
||||
EOF
|
||||
|
||||
- name: Replace memory definition in nodejs
|
||||
run: |
|
||||
@@ -175,16 +178,17 @@ jobs:
|
||||
run: |
|
||||
curl -s "https://raw.githubusercontent.com/zkonduit/ezkljs-engine/main/README.md" > ./pkg/README.md
|
||||
|
||||
# zizmor: ignore cache-poisoning
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
|
||||
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
|
||||
with:
|
||||
node-version: "18.12.1"
|
||||
registry-url: "https://registry.npmjs.org"
|
||||
- name: Publish to npm
|
||||
package-manager-cache: false
|
||||
|
||||
- name: Publish to npm with provenance
|
||||
run: |
|
||||
cd pkg
|
||||
npm install
|
||||
npm ci
|
||||
npm publish
|
||||
npm publish --provenance --access public
|
||||
env:
|
||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
|
||||
|
||||
4
.github/workflows/large-tests.yml
vendored
4
.github/workflows/large-tests.yml
vendored
@@ -13,9 +13,9 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: nanoGPT Mock
|
||||
|
||||
17
.github/workflows/pypi-gpu.yml
vendored
17
.github/workflows/pypi-gpu.yml
vendored
@@ -27,6 +27,8 @@ jobs:
|
||||
target: [x86_64]
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
RUSTFLAGS: "-C linker=gcc"
|
||||
OPENSSL_NO_VENDOR: 1
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
@@ -36,6 +38,16 @@ jobs:
|
||||
python-version: 3.12
|
||||
architecture: x64
|
||||
|
||||
|
||||
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential g++ gcc cmake libclang-dev llvm-dev libstdc++-12-dev libc6 libc6-dev libssl-dev pkg-config
|
||||
|
||||
- name: Force rebuild icicle dependencies
|
||||
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
|
||||
|
||||
- name: Set pyproject.toml version to match github tag and rename ezkl to ezkl-gpu
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -43,11 +55,12 @@ jobs:
|
||||
sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig > pyproject.toml.tmp
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.tmp > pyproject.toml
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
cache: false
|
||||
|
||||
- name: Set Cargo.toml version to match github tag and rename ezkl to ezkl-gpu
|
||||
shell: bash
|
||||
@@ -70,7 +83,7 @@ jobs:
|
||||
target: ${{ matrix.target }}
|
||||
manylinux: auto
|
||||
container: off
|
||||
args: --release --out dist --features python-bindings,icicle
|
||||
args: --release --out dist --features python-bindings,gpu-accelerated
|
||||
|
||||
- name: Install built wheel
|
||||
if: matrix.target == 'x86_64'
|
||||
|
||||
10
.github/workflows/pypi.yml
vendored
10
.github/workflows/pypi.yml
vendored
@@ -48,11 +48,12 @@ jobs:
|
||||
mv Cargo.lock Cargo.lock.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
cache: false
|
||||
|
||||
- name: Build wheels
|
||||
if: matrix.target == 'universal2-apple-darwin'
|
||||
@@ -113,11 +114,12 @@ jobs:
|
||||
mv Cargo.lock Cargo.lock.orig
|
||||
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
cache: false
|
||||
|
||||
- name: Build wheels
|
||||
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
|
||||
|
||||
30
.github/workflows/release.yml
vendored
30
.github/workflows/release.yml
vendored
@@ -47,17 +47,27 @@ jobs:
|
||||
TARGET_DIR: ./target
|
||||
RUST_BACKTRACE: 1
|
||||
PCRE2_SYS_STATIC: 1
|
||||
RUSTFLAGS: "-C linker=gcc"
|
||||
OPENSSL_NO_VENDOR: 1
|
||||
steps:
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
cache: false
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential g++ gcc cmake libclang-dev llvm-dev libstdc++-12-dev libc6 libc6-dev libssl-dev pkg-config
|
||||
|
||||
- name: Force rebuild icicle dependencies
|
||||
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
|
||||
|
||||
- name: Get release version from tag
|
||||
shell: bash
|
||||
@@ -78,7 +88,7 @@ jobs:
|
||||
sudo apt-get update
|
||||
|
||||
- name: Build release binary
|
||||
run: cargo build --release -Z sparse-registry --features icicle
|
||||
run: cargo build --release -Z sparse-registry --features gpu-accelerated
|
||||
|
||||
- name: Build archive
|
||||
shell: bash
|
||||
@@ -117,27 +127,27 @@ jobs:
|
||||
include:
|
||||
- build: windows-msvc
|
||||
os: windows-latest
|
||||
rust: nightly-2025-02-17
|
||||
rust: nightly-2025-05-01
|
||||
target: x86_64-pc-windows-msvc
|
||||
- build: macos
|
||||
os: macos-13
|
||||
rust: nightly-2025-02-17
|
||||
rust: nightly-2025-05-01
|
||||
target: x86_64-apple-darwin
|
||||
- build: macos-aarch64
|
||||
os: macos-13
|
||||
rust: nightly-2025-02-17
|
||||
rust: nightly-2025-05-01
|
||||
target: aarch64-apple-darwin
|
||||
- build: linux-musl
|
||||
os: ubuntu-22.04
|
||||
rust: nightly-2025-02-17
|
||||
rust: nightly-2025-05-01
|
||||
target: x86_64-unknown-linux-musl
|
||||
- build: linux-gnu
|
||||
os: ubuntu-22.04
|
||||
rust: nightly-2025-02-17
|
||||
rust: nightly-2025-05-01
|
||||
target: x86_64-unknown-linux-gnu
|
||||
- build: linux-aarch64
|
||||
os: ubuntu-22.04
|
||||
rust: nightly-2025-02-17
|
||||
rust: nightly-2025-05-01
|
||||
target: aarch64-unknown-linux-gnu
|
||||
|
||||
steps:
|
||||
@@ -199,7 +209,7 @@ jobs:
|
||||
|
||||
- name: Build release binary (asm)
|
||||
if: matrix.build == 'linux-gnu'
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm,mimalloc
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features mimalloc
|
||||
|
||||
- name: Build release binary (metal)
|
||||
if: matrix.build == 'macos-aarch64'
|
||||
|
||||
567
.github/workflows/rust.yml
vendored
567
.github/workflows/rust.yml
vendored
@@ -27,13 +27,13 @@ jobs:
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3
|
||||
@@ -46,57 +46,76 @@ jobs:
|
||||
build:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
RUSTFLAGS: "-C linker=gcc"
|
||||
OPENSSL_NO_VENDOR: 1
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: install libc6
|
||||
run: sudo apt-get install -y libc6
|
||||
- name: Install cmake and build dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
|
||||
- name: Force rebuild icicle dependencies
|
||||
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
|
||||
- name: Build
|
||||
run: cargo build --verbose
|
||||
|
||||
docs:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: install libc6
|
||||
run: sudo apt-get install -y libc6
|
||||
- name: Install cmake and build dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
|
||||
- name: Docs
|
||||
run: cargo doc --verbose
|
||||
|
||||
library-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
RUSTFLAGS: "-C linker=gcc"
|
||||
OPENSSL_NO_VENDOR: 1
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
- name: install libc6
|
||||
run: sudo apt-get install -y libc6
|
||||
- name: Install cmake and build dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
|
||||
- name: Force rebuild icicle dependencies
|
||||
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
@@ -105,100 +124,84 @@ jobs:
|
||||
run: cargo test --doc --verbose
|
||||
- name: Library tests
|
||||
run: cargo nextest run --lib --verbose
|
||||
- name: Library tests (original lookup)
|
||||
run: cargo nextest run --lib --verbose --no-default-features --features ezkl,eth-original-lookup
|
||||
|
||||
# ultra-overflow-tests-gpu:
|
||||
# runs-on: GPU
|
||||
# env:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
# with:
|
||||
# toolchain: nightly-2025-02-17
|
||||
# override: true
|
||||
# components: rustfmt, clippy
|
||||
# - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
# with:
|
||||
# crate: cargo-nextest
|
||||
# locked: true
|
||||
# - uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
|
||||
# with:
|
||||
# wasmtime-version: "3.0.1"
|
||||
# # - name: Matmul overflow (wasi)
|
||||
# # run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
|
||||
# # - name: Conv overflow (wasi)
|
||||
# # run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
|
||||
# - name: lookup overflow
|
||||
# run: cargo nextest run lookup_ultra_overflow --no-capture --features icicle -- --include-ignored
|
||||
# - name: Matmul overflow
|
||||
# run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --features icicle -- --include-ignored
|
||||
# - name: Conv overflow
|
||||
# run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --features icicle -- --include-ignored
|
||||
# - name: Conv + relu overflow
|
||||
# run: cargo nextest run conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored
|
||||
|
||||
ultra-overflow-tests_og-lookup:
|
||||
ultra-overflow-tests-gpu:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
runs-on: gpu
|
||||
env:
|
||||
ENABLE_ICICLE_GPU: true
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
RUSTFLAGS: "-C linker=gcc"
|
||||
OPENSSL_NO_VENDOR: 1
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
|
||||
- uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
|
||||
with:
|
||||
wasmtime-version: "3.0.1"
|
||||
- name: Setup GPU dependencies
|
||||
run: sudo ./setup-gpu.sh --yes
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential g++ gcc cmake libclang-dev llvm-dev libstdc++-12-dev libc6 libc6-dev libssl-dev pkg-config
|
||||
- name: Force rebuild icicle dependencies
|
||||
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
|
||||
# - name: Matmul overflow (wasi)
|
||||
# run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
|
||||
# - name: Conv overflow (wasi)
|
||||
# run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
|
||||
- name: lookup overflow
|
||||
run: cargo nextest run --release lookup_ultra_overflow --no-capture --no-default-features --features ezkl,eth-original-lookup -- --include-ignored
|
||||
run: cargo nextest run lookup_ultra_overflow --no-capture --features gpu-accelerated -- --include-ignored
|
||||
- name: Matmul overflow
|
||||
run: RUST_LOG=debug cargo nextest run --release matmul_col_ultra_overflow --no-capture --no-default-features --features ezkl,eth-original-lookup -- --include-ignored
|
||||
run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --features gpu-accelerated -- --include-ignored
|
||||
- name: Conv overflow
|
||||
run: RUST_LOG=debug cargo nextest run --release conv_col_ultra_overflow --no-capture --no-default-features --features ezkl,eth-original-lookup -- --include-ignored
|
||||
run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --features gpu-accelerated -- --include-ignored
|
||||
- name: Conv + relu overflow
|
||||
run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --no-default-features --features ezkl,eth-original-lookup -- --include-ignored
|
||||
run: cargo nextest run conv_relu_col_ultra_overflow --no-capture --features gpu-accelerated -- --include-ignored
|
||||
|
||||
ultra-overflow-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
runs-on: [non-gpu, non-sgx]
|
||||
needs: [build, library-tests, docs]
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
RUSTFLAGS: "-C linker=gcc"
|
||||
OPENSSL_NO_VENDOR: 1
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
- name: install libc6
|
||||
run: sudo apt-get install -y libc6
|
||||
- name: Install cmake and build dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
|
||||
- name: Force rebuild icicle dependencies
|
||||
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
|
||||
- uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
|
||||
with:
|
||||
wasmtime-version: "3.0.1"
|
||||
# - name: Matmul overflow (wasi)
|
||||
@@ -217,21 +220,29 @@ jobs:
|
||||
model-serialization:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest-16-cores
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
RUSTFLAGS: "-C linker=gcc"
|
||||
OPENSSL_NO_VENDOR: 1
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
- name: install libc6
|
||||
run: sudo apt-get install -y libc6
|
||||
- name: Install cmake and build dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
|
||||
- name: Force rebuild icicle dependencies
|
||||
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
@@ -241,31 +252,40 @@ jobs:
|
||||
wasm32-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest-64-cores
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
# add `atomics` and `bulk-memory` to RUSTFLAGS to enable wasm-bindgen tests
|
||||
RUSTFLAGS: "-C target-feature=+atomics,+bulk-memory"
|
||||
OPENSSL_NO_VENDOR: 1
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
|
||||
- name: install libc6
|
||||
run: sudo apt-get install -y libc6
|
||||
- name: Install cmake and build dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
|
||||
- name: Force rebuild icicle dependencies
|
||||
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
|
||||
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
|
||||
with:
|
||||
# Pin to version 0.12.1
|
||||
version: "v0.12.1"
|
||||
- uses: nanasess/setup-chromedriver@affb1ea8848cbb080be372c1e8d7a5c173e9298f #v2.3.0
|
||||
# Pin to version 0.13.1
|
||||
version: "v0.13.1"
|
||||
- uses: nanasess/setup-chromedriver@affb1ea8848cbb080be372c1e8d7a5c173e9298f #v2.3.0
|
||||
# with:
|
||||
# chromedriver-version: "115.0.5790.102"
|
||||
- name: Install wasm32-unknown-unknown
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
|
||||
- name: Create webdriver.json to disable timeouts
|
||||
run: |
|
||||
echo '{"args": ["--headless", "--disable-gpu", "--disable-dev-shm-usage", "--no-sandbox"]}' > webdriver.json
|
||||
@@ -280,24 +300,29 @@ jobs:
|
||||
mock-proving-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
runs-on: [non-gpu, non-sgx]
|
||||
needs: [build, library-tests, docs]
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
RUSTFLAGS: "-C linker=gcc"
|
||||
OPENSSL_NO_VENDOR: 1
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Force rebuild icicle dependencies
|
||||
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
|
||||
# - name: The Worm Mock
|
||||
# run: cargo nextest run --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
|
||||
- name: Large 1D Conv Mock
|
||||
@@ -350,53 +375,61 @@ jobs:
|
||||
prove-and-verify-evm-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
# needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
runs-on: [non-gpu, non-sgx]
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
RUSTFLAGS: "-C target-feature=+atomics,+bulk-memory"
|
||||
OPENSSL_NO_VENDOR: 1
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
|
||||
with:
|
||||
# Pin to version 0.13.1
|
||||
version: "v0.13.1"
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- name: Force rebuild icicle dependencies
|
||||
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Use pnpm 8
|
||||
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
|
||||
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
|
||||
with:
|
||||
version: 8
|
||||
- name: Use Node.js 18.12.1
|
||||
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
|
||||
- name: Use Node.js 22.17.1
|
||||
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
|
||||
with:
|
||||
node-version: "18.12.1"
|
||||
node-version: "22.17.1"
|
||||
cache: "pnpm"
|
||||
- name: "Add rust-src"
|
||||
run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
|
||||
- name: Install dependencies for js tests and package
|
||||
run: |
|
||||
pnpm install --frozen-lockfile
|
||||
# - name: Install solc
|
||||
# run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 56b806a3ba7866a3b061093bebd0fa2ace97f1fc --locked anvil --force
|
||||
- name: Build wasm package for nodejs target.
|
||||
run: |
|
||||
wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
|
||||
- name: KZG prove and verify tests (EVM)
|
||||
run: cargo nextest run --verbose "tests_evm::kzg_evm_prove_and_verify_::" --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + reusable verifier + col-overflow)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --features reusable-verifier --test-threads 1
|
||||
# - name: KZG prove and verify tests (EVM + reusable verifier + col-overflow)
|
||||
# run: cargo nextest run --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --features reusable-verifier --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + kzg all)
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + kzg inputs)
|
||||
@@ -419,17 +452,17 @@ jobs:
|
||||
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
# - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
# with:
|
||||
# toolchain: nightly-2025-02-17
|
||||
# toolchain: nightly-2025-05-01
|
||||
# override: true
|
||||
# components: rustfmt, clippy
|
||||
# - uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
|
||||
# with:
|
||||
# # Pin to version 0.12.1
|
||||
# version: 'v0.12.1'
|
||||
# # Pin to version 0.13.1
|
||||
# version: 'v0.13.1'
|
||||
# - name: Add rust-src
|
||||
# run: rustup component add rust-src --toolchain nightly-2025-02-17
|
||||
# run: rustup component add rust-src --toolchain nightly-2025-05-01
|
||||
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
@@ -447,41 +480,42 @@ jobs:
|
||||
prove-and-verify-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs]
|
||||
runs-on: [non-gpu, non-sgx]
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
|
||||
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
|
||||
with:
|
||||
# Pin to version 0.12.1
|
||||
version: "v0.12.1"
|
||||
# Pin to version 0.13.1
|
||||
version: "v0.13.1"
|
||||
- name: Add wasm32-unknown-unknown target
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
|
||||
- name: Force rebuild icicle dependencies
|
||||
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Use pnpm 8
|
||||
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
|
||||
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
|
||||
with:
|
||||
version: 8
|
||||
- name: Use Node.js 18.12.1
|
||||
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
|
||||
- name: Use Node.js 22.17.1
|
||||
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
|
||||
with:
|
||||
node-version: "18.12.1"
|
||||
node-version: "22.17.1"
|
||||
cache: "pnpm"
|
||||
- name: Install dependencies for js tests
|
||||
run: |
|
||||
@@ -489,13 +523,13 @@ jobs:
|
||||
env:
|
||||
CI: false
|
||||
NODE_ENV: development
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Build wasm package for nodejs target.
|
||||
run: |
|
||||
wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
|
||||
# - name: Build wasm package for nodejs target.
|
||||
# run: |
|
||||
# wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
|
||||
- name: KZG prove and verify tests (public outputs + column overflow)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w
|
||||
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
|
||||
@@ -504,10 +538,6 @@ jobs:
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_hashed_inputs_
|
||||
- name: KZG prove and verify tests (public outputs)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_tight_lookup_::t
|
||||
- name: IPA prove and verify tests
|
||||
run: cargo nextest run --verbose tests::ipa_prove_and_verify_::t --test-threads 1
|
||||
- name: IPA prove and verify tests (ipa outputs)
|
||||
run: cargo nextest run --verbose tests::ipa_prove_and_verify_ipa_output
|
||||
- name: KZG prove and verify tests single inner col
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_single_col
|
||||
- name: KZG prove and verify tests triple inner col
|
||||
@@ -527,42 +557,53 @@ jobs:
|
||||
- name: KZG prove and verify tests (hashed outputs)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_hashed
|
||||
|
||||
# prove-and-verify-tests-gpu:
|
||||
# runs-on: GPU
|
||||
# env:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
# with:
|
||||
# toolchain: nightly-2025-02-17
|
||||
# override: true
|
||||
# components: rustfmt, clippy
|
||||
# - name: Add rust-src
|
||||
# run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
|
||||
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
# - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
# with:
|
||||
# crate: cargo-nextest
|
||||
# locked: true
|
||||
# - name: KZG prove and verify tests (kzg outputs)
|
||||
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (public outputs + column overflow)
|
||||
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (public outputs + fixed params + column overflow)
|
||||
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (public outputs)
|
||||
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (public outputs + column overflow)
|
||||
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (public inputs)
|
||||
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (fixed params)
|
||||
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
|
||||
# - name: KZG prove and verify tests (hashed outputs)
|
||||
# run: cargo nextest run --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
|
||||
prove-and-verify-tests-gpu:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: gpu
|
||||
needs: [build, library-tests, docs]
|
||||
env:
|
||||
ENABLE_ICICLE_GPU: true
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
RUSTFLAGS: "-C linker=gcc"
|
||||
OPENSSL_NO_VENDOR: 1
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Setup GPU dependencies
|
||||
run: sudo ./setup-gpu.sh --yes
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential g++ gcc cmake libclang-dev llvm-dev libstdc++-12-dev libc6 libc6-dev libssl-dev pkg-config
|
||||
- name: Force rebuild icicle dependencies
|
||||
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
|
||||
- name: KZG prove and verify tests (kzg outputs)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_kzg_output --features gpu-accelerated --test-threads 1
|
||||
- name: KZG prove and verify tests (public outputs + column overflow)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w --features gpu-accelerated --test-threads 1
|
||||
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features gpu-accelerated --test-threads 1
|
||||
- name: KZG prove and verify tests (public outputs)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t --features gpu-accelerated --test-threads 1
|
||||
- name: KZG prove and verify tests (public inputs)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_public_input --features gpu-accelerated --test-threads 1
|
||||
- name: KZG prove and verify tests (fixed params)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_fixed_params --features gpu-accelerated --test-threads 1
|
||||
- name: KZG prove and verify tests (hashed outputs)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_hashed --features gpu-accelerated --test-threads 1
|
||||
|
||||
prove-and-verify-mock-aggr-tests:
|
||||
permissions:
|
||||
@@ -571,43 +612,57 @@ jobs:
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- uses: dtolnay/rust-toolchain@4f94fbe7e03939b0e674bcc9ca609a16088f63ff #nightly branch, TODO: update when required
|
||||
- uses: dtolnay/rust-toolchain@4f94fbe7e03939b0e674bcc9ca609a16088f63ff #nightly branch, TODO: update when required
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Mock aggr tests (KZG)
|
||||
run: cargo nextest run --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
|
||||
|
||||
# prove-and-verify-aggr-tests-gpu:
|
||||
# runs-on: GPU
|
||||
# env:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
# with:
|
||||
# persist-credentials: false
|
||||
# - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
# with:
|
||||
# toolchain: nightly-2025-02-17
|
||||
# override: true
|
||||
# components: rustfmt, clippy
|
||||
# - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
# with:
|
||||
# crate: cargo-nextest
|
||||
# locked: true
|
||||
# - name: KZG tests
|
||||
# run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
|
||||
prove-and-verify-aggr-tests-gpu:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: gpu
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
env:
|
||||
ENABLE_ICICLE_GPU: true
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
RUSTFLAGS: "-C linker=gcc"
|
||||
OPENSSL_NO_VENDOR: 1
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Setup GPU dependencies
|
||||
run: sudo ./setup-gpu.sh --yes
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential g++ gcc cmake libclang-dev llvm-dev libstdc++-12-dev libc6 libc6-dev libssl-dev pkg-config
|
||||
- name: Force rebuild icicle dependencies
|
||||
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
|
||||
- name: KZG tests
|
||||
run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features gpu-accelerated --test-threads 1 -- --include-ignored
|
||||
|
||||
prove-and-verify-aggr-tests:
|
||||
permissions:
|
||||
@@ -616,18 +671,18 @@ jobs:
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
@@ -641,47 +696,55 @@ jobs:
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
# - name: Install solc
|
||||
# run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 56b806a3ba7866a3b061093bebd0fa2ace97f1fc --locked anvil --force
|
||||
- name: KZG prove and verify aggr tests
|
||||
run: cargo nextest run --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
|
||||
|
||||
examples:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
runs-on: ubuntu-22.04
|
||||
needs: [build, library-tests, docs]
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
RUSTFLAGS: "-C linker=gcc"
|
||||
OPENSSL_NO_VENDOR: 1
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
- name: install libc6
|
||||
run: sudo apt-get install -y libc6
|
||||
- name: Install cmake and build dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
|
||||
- name: Force rebuild icicle dependencies
|
||||
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
@@ -691,22 +754,22 @@ jobs:
|
||||
python-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
runs-on: [non-gpu, non-sgx]
|
||||
needs: [build, library-tests, docs]
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
|
||||
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Install cmake
|
||||
@@ -716,7 +779,7 @@ jobs:
|
||||
- name: Setup Virtual Env and Install python dependencies
|
||||
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 56b806a3ba7866a3b061093bebd0fa2ace97f1fc --locked anvil --force
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings,reusable-verifier --profile=test-runs
|
||||
- name: Run pytest
|
||||
@@ -725,25 +788,25 @@ jobs:
|
||||
accuracy-measurement-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
runs-on: [non-gpu, non-sgx]
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
|
||||
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
@@ -764,30 +827,31 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: large-self-hosted
|
||||
needs: [build, library-tests, docs, python-tests]
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
|
||||
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
# - name: Install solc
|
||||
# run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 56b806a3ba7866a3b061093bebd0fa2ace97f1fc --locked anvil --force
|
||||
- name: Install pip
|
||||
run: python -m ensurepip --upgrade
|
||||
- name: Setup Virtual Env and Install python dependencies
|
||||
@@ -828,23 +892,27 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
RUSTFLAGS: "-C linker=gcc"
|
||||
OPENSSL_NO_VENDOR: 1
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Force rebuild icicle dependencies
|
||||
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
|
||||
- name: Run ios tests
|
||||
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2025-02-17-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
|
||||
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2025-05-01-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
|
||||
|
||||
swift-package-tests:
|
||||
permissions:
|
||||
@@ -854,16 +922,21 @@ jobs:
|
||||
|
||||
env:
|
||||
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
|
||||
RUSTFLAGS: "-C linker=gcc"
|
||||
OPENSSL_NO_VENDOR: 1
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Force rebuild icicle dependencies
|
||||
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
|
||||
- name: Build EzklCoreBindings
|
||||
run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
|
||||
|
||||
@@ -895,7 +968,7 @@ jobs:
|
||||
cd ezkl-swift-package
|
||||
xcodebuild test \
|
||||
-scheme EzklPackage \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 16 Pro,OS=18.4' \
|
||||
-resultBundlePath ../testResults
|
||||
|
||||
- name: Run Example App Tests
|
||||
@@ -904,7 +977,7 @@ jobs:
|
||||
xcodebuild test \
|
||||
-project Example.xcodeproj \
|
||||
-scheme EzklApp \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-destination 'platform=iOS Simulator,name=iPhone 16 Pro,OS=18.4' \
|
||||
-parallel-testing-enabled NO \
|
||||
-resultBundlePath ../../exampleTestResults \
|
||||
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
|
||||
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
|
||||
|
||||
4
.github/workflows/static-analysis.yml
vendored
4
.github/workflows/static-analysis.yml
vendored
@@ -15,9 +15,9 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
|
||||
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
|
||||
with:
|
||||
toolchain: nightly-2025-02-17
|
||||
toolchain: nightly-2025-05-01
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
|
||||
|
||||
300
Cargo.lock
generated
300
Cargo.lock
generated
@@ -881,17 +881,6 @@ dependencies = [
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atty"
|
||||
version = "0.2.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
|
||||
dependencies = [
|
||||
"hermit-abi 0.1.19",
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aurora-engine-modexp"
|
||||
version = "1.2.0"
|
||||
@@ -970,29 +959,6 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bindgen"
|
||||
version = "0.69.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
|
||||
dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.12.1",
|
||||
"lazy_static",
|
||||
"lazycell",
|
||||
"log",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"regex",
|
||||
"rustc-hash 1.1.0",
|
||||
"shlex",
|
||||
"syn 2.0.101",
|
||||
"which",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bit-set"
|
||||
version = "0.5.3"
|
||||
@@ -1211,15 +1177,6 @@ dependencies = [
|
||||
"shlex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cexpr"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
|
||||
dependencies = [
|
||||
"nom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.0"
|
||||
@@ -1270,29 +1227,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
|
||||
dependencies = [
|
||||
"ciborium-io",
|
||||
"half 2.6.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clang-sys"
|
||||
version = "1.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
|
||||
dependencies = [
|
||||
"glob",
|
||||
"libc",
|
||||
"libloading",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "2.34.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"textwrap 0.11.0",
|
||||
"unicode-width 0.1.14",
|
||||
"half",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1323,7 +1258,7 @@ version = "4.5.47"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c06f5378ea264ad4f82bbc826628b5aad714a75abf6ece087e923010eb937fb6"
|
||||
dependencies = [
|
||||
"clap 4.5.37",
|
||||
"clap",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1510,32 +1445,6 @@ dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "criterion"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f"
|
||||
dependencies = [
|
||||
"atty",
|
||||
"cast",
|
||||
"clap 2.34.0",
|
||||
"criterion-plot 0.4.5",
|
||||
"csv",
|
||||
"itertools 0.10.5",
|
||||
"lazy_static",
|
||||
"num-traits",
|
||||
"oorandom",
|
||||
"plotters",
|
||||
"rayon",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_cbor",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"tinytemplate",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "criterion"
|
||||
version = "0.5.1"
|
||||
@@ -1545,8 +1454,8 @@ dependencies = [
|
||||
"anes",
|
||||
"cast",
|
||||
"ciborium",
|
||||
"clap 4.5.37",
|
||||
"criterion-plot 0.5.0",
|
||||
"clap",
|
||||
"criterion-plot",
|
||||
"is-terminal",
|
||||
"itertools 0.10.5",
|
||||
"num-traits",
|
||||
@@ -1562,16 +1471,6 @@ dependencies = [
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "criterion-plot"
|
||||
version = "0.4.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876"
|
||||
dependencies = [
|
||||
"cast",
|
||||
"itertools 0.10.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "criterion-plot"
|
||||
version = "0.5.0"
|
||||
@@ -1644,27 +1543,6 @@ dependencies = [
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "csv"
|
||||
version = "1.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf"
|
||||
dependencies = [
|
||||
"csv-core",
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "csv-core"
|
||||
version = "0.1.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "5.5.3"
|
||||
@@ -2065,12 +1943,12 @@ dependencies = [
|
||||
"bincode",
|
||||
"camino",
|
||||
"chrono",
|
||||
"clap 4.5.37",
|
||||
"clap",
|
||||
"clap_complete",
|
||||
"colored",
|
||||
"colored_json",
|
||||
"console_error_panic_hook",
|
||||
"criterion 0.5.1",
|
||||
"criterion",
|
||||
"ecc",
|
||||
"env_logger 0.10.2",
|
||||
"ethabi",
|
||||
@@ -2082,6 +1960,7 @@ dependencies = [
|
||||
"halo2_solidity_verifier",
|
||||
"halo2curves 0.7.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
|
||||
"hex",
|
||||
"icicle-runtime",
|
||||
"indicatif",
|
||||
"instant",
|
||||
"itertools 0.10.5",
|
||||
@@ -2328,7 +2207,7 @@ version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8640e34b88f7652208ce9e88b1a37a2ae95227d84abec377ccd3c5cfeb141ed4"
|
||||
dependencies = [
|
||||
"rustix 1.0.5",
|
||||
"rustix",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
@@ -2521,12 +2400,6 @@ dependencies = [
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "1.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "2.6.0"
|
||||
@@ -2541,7 +2414,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_gadgets"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=ac%2Fconditional-compilation-icicle2#1dd2090741f006fd031a07da7f3c9dfce5e0015e"
|
||||
dependencies = [
|
||||
"arrayvec 0.7.6",
|
||||
"bitvec",
|
||||
@@ -2558,7 +2431,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_proofs"
|
||||
version = "0.3.0"
|
||||
source = "git+https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d"
|
||||
source = "git+https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e?branch=ac%2Fconditional-compilation-icicle2#01c88842679b4308e43ae5ed91c4183e861669bd"
|
||||
dependencies = [
|
||||
"bincode",
|
||||
"blake2b_simd",
|
||||
@@ -2568,7 +2441,7 @@ dependencies = [
|
||||
"halo2curves 0.7.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
|
||||
"icicle-bn254",
|
||||
"icicle-core",
|
||||
"icicle-cuda-runtime",
|
||||
"icicle-runtime",
|
||||
"instant",
|
||||
"lazy_static",
|
||||
"log",
|
||||
@@ -2576,7 +2449,7 @@ dependencies = [
|
||||
"mopro-msm",
|
||||
"rand_chacha 0.3.1",
|
||||
"rand_core 0.6.4",
|
||||
"rustc-hash 2.1.1",
|
||||
"rustc-hash",
|
||||
"serde",
|
||||
"sha3 0.9.1",
|
||||
"tracing",
|
||||
@@ -2791,15 +2664,6 @@ version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.1.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.3.9"
|
||||
@@ -2991,33 +2855,45 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "icicle-bn254"
|
||||
version = "2.8.0"
|
||||
source = "git+https://github.com/ingonyama-zk/icicle?branch=ezkl-icicle2#5dfe006a0f1bc62ea82ca297709bbf3d22a2ca25"
|
||||
version = "3.7.0"
|
||||
source = "git+https://github.com/ingonyama-zk/icicle?branch=emir%2Fgate_eval_2#012e00694f4cf399fe7a42d9cfbfa6cd7a60f876"
|
||||
dependencies = [
|
||||
"cmake",
|
||||
"criterion 0.3.6",
|
||||
"icicle-core",
|
||||
"icicle-cuda-runtime",
|
||||
"icicle-hash",
|
||||
"icicle-runtime",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "icicle-core"
|
||||
version = "2.8.0"
|
||||
source = "git+https://github.com/ingonyama-zk/icicle?branch=ezkl-icicle2#5dfe006a0f1bc62ea82ca297709bbf3d22a2ca25"
|
||||
version = "3.7.0"
|
||||
source = "git+https://github.com/ingonyama-zk/icicle?branch=emir%2Fgate_eval_2#012e00694f4cf399fe7a42d9cfbfa6cd7a60f876"
|
||||
dependencies = [
|
||||
"criterion 0.3.6",
|
||||
"hex",
|
||||
"icicle-cuda-runtime",
|
||||
"icicle-runtime",
|
||||
"once_cell",
|
||||
"rand 0.8.5",
|
||||
"rayon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "icicle-cuda-runtime"
|
||||
version = "2.8.0"
|
||||
source = "git+https://github.com/ingonyama-zk/icicle?branch=ezkl-icicle2#5dfe006a0f1bc62ea82ca297709bbf3d22a2ca25"
|
||||
name = "icicle-hash"
|
||||
version = "3.7.0"
|
||||
source = "git+https://github.com/ingonyama-zk/icicle?branch=emir%2Fgate_eval_2#012e00694f4cf399fe7a42d9cfbfa6cd7a60f876"
|
||||
dependencies = [
|
||||
"bindgen",
|
||||
"bitflags 1.3.2",
|
||||
"cmake",
|
||||
"icicle-core",
|
||||
"icicle-runtime",
|
||||
"rand 0.8.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "icicle-runtime"
|
||||
version = "3.7.0"
|
||||
source = "git+https://github.com/ingonyama-zk/icicle?branch=emir%2Fgate_eval_2#012e00694f4cf399fe7a42d9cfbfa6cd7a60f876"
|
||||
dependencies = [
|
||||
"cmake",
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3470,28 +3346,12 @@ dependencies = [
|
||||
"spin",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazycell"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.172"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa"
|
||||
|
||||
[[package]]
|
||||
name = "libloading"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.13"
|
||||
@@ -3519,12 +3379,6 @@ dependencies = [
|
||||
"redox_syscall",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.4.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.9.4"
|
||||
@@ -3963,7 +3817,7 @@ dependencies = [
|
||||
"num-traits",
|
||||
"pyo3",
|
||||
"pyo3-build-config",
|
||||
"rustc-hash 2.1.1",
|
||||
"rustc-hash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4394,16 +4248,6 @@ version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
|
||||
|
||||
[[package]]
|
||||
name = "prettyplease"
|
||||
version = "0.2.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "664ec5419c51e34154eec046ebcba56312d5a2fc3b09a06da188e1ad21afadf6"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "primal-check"
|
||||
version = "0.3.4"
|
||||
@@ -4656,7 +4500,7 @@ dependencies = [
|
||||
"pin-project-lite",
|
||||
"quinn-proto",
|
||||
"quinn-udp",
|
||||
"rustc-hash 2.1.1",
|
||||
"rustc-hash",
|
||||
"rustls",
|
||||
"socket2",
|
||||
"thiserror 2.0.12",
|
||||
@@ -4675,7 +4519,7 @@ dependencies = [
|
||||
"getrandom 0.3.2",
|
||||
"rand 0.9.1",
|
||||
"ring",
|
||||
"rustc-hash 2.1.1",
|
||||
"rustc-hash",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"slab",
|
||||
@@ -5173,12 +5017,6 @@ version = "0.1.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "2.1.1"
|
||||
@@ -5224,19 +5062,6 @@ dependencies = [
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.44"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
|
||||
dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.4.15",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "1.0.5"
|
||||
@@ -5246,7 +5071,7 @@ dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.9.4",
|
||||
"linux-raw-sys",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
@@ -5494,16 +5319,6 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_cbor"
|
||||
version = "0.11.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5"
|
||||
dependencies = [
|
||||
"half 1.8.3",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.219"
|
||||
@@ -5950,7 +5765,7 @@ dependencies = [
|
||||
"fastrand",
|
||||
"getrandom 0.3.2",
|
||||
"once_cell",
|
||||
"rustix 1.0.5",
|
||||
"rustix",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
@@ -5996,15 +5811,6 @@ dependencies = [
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "textwrap"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
|
||||
dependencies = [
|
||||
"unicode-width 0.1.14",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "textwrap"
|
||||
version = "0.16.2"
|
||||
@@ -6398,7 +6204,7 @@ dependencies = [
|
||||
"downcast-rs",
|
||||
"dyn-clone",
|
||||
"dyn-hash",
|
||||
"half 2.6.0",
|
||||
"half",
|
||||
"itertools 0.12.1",
|
||||
"lazy_static",
|
||||
"maplit",
|
||||
@@ -6433,7 +6239,7 @@ dependencies = [
|
||||
"downcast-rs",
|
||||
"dyn-clone",
|
||||
"dyn-hash",
|
||||
"half 2.6.0",
|
||||
"half",
|
||||
"lazy_static",
|
||||
"liquid",
|
||||
"liquid-core",
|
||||
@@ -6614,7 +6420,7 @@ dependencies = [
|
||||
"once_cell",
|
||||
"paste",
|
||||
"serde",
|
||||
"textwrap 0.16.2",
|
||||
"textwrap",
|
||||
"toml 0.5.11",
|
||||
"uniffi_meta",
|
||||
"uniffi_testing",
|
||||
@@ -6706,7 +6512,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cef408229a3a407fafa4c36dc4f6ece78a6fb258ab28d2b64bddd49c8cb680f6"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"textwrap 0.16.2",
|
||||
"textwrap",
|
||||
"uniffi_meta",
|
||||
"uniffi_testing",
|
||||
"weedle2",
|
||||
@@ -7046,18 +6852,6 @@ dependencies = [
|
||||
"nom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "which"
|
||||
version = "4.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
|
||||
dependencies = [
|
||||
"either",
|
||||
"home",
|
||||
"once_cell",
|
||||
"rustix 0.38.44",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
@@ -7435,7 +7229,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0d65cbf2f12c15564212d48f4e3dfb87923d25d611f2aed18f4cb23f0413d89e"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"rustix 1.0.5",
|
||||
"rustix",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
39
Cargo.toml
39
Cargo.toml
@@ -16,12 +16,12 @@ crate-type = ["cdylib", "rlib", "staticlib"]
|
||||
|
||||
|
||||
[dependencies]
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2" }
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "ac/conditional-compilation-icicle2" }
|
||||
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
|
||||
"derive_serde",
|
||||
] }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", features = [
|
||||
"circuit-params",
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", branch= "ac/conditional-compilation-icicle2", features = [
|
||||
"circuit-params", "mv-lookup"
|
||||
] }
|
||||
rand = { version = "0.8", default-features = false }
|
||||
itertools = { version = "0.10.3", default-features = false }
|
||||
@@ -33,10 +33,10 @@ thiserror = { version = "1.0.38", default-features = false }
|
||||
hex = { version = "0.4.3", default-features = false }
|
||||
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
|
||||
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
|
||||
"derive_serde",
|
||||
"derive_serde", "mv-lookup"
|
||||
] }
|
||||
halo2_solidity_verifier = { git = "https://github.com/zkonduit/ezkl-verifier", branch = "main", optional = true, features = [
|
||||
"evm",
|
||||
"evm", "mv-lookup",
|
||||
] }
|
||||
maybe-rayon = { version = "0.1.1", default-features = false }
|
||||
bincode = { version = "1.3.3", default-features = false }
|
||||
@@ -103,6 +103,10 @@ uniffi_bindgen = { version = "=0.28.0", optional = true }
|
||||
camino = { version = "^1.1", optional = true }
|
||||
uuid = { version = "1.10.0", features = ["v4"], optional = true }
|
||||
|
||||
# GPU / device related things (optional - only enabled with gpu-accelerated feature)
|
||||
icicle-runtime = { git = "https://github.com/ingonyama-zk/icicle", branch="emir/gate_eval_2", package="icicle-runtime", optional = true }
|
||||
|
||||
|
||||
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
|
||||
colored = { version = "2.0.0", default-features = false, optional = true }
|
||||
env_logger = { version = "0.10.0", default-features = false, optional = true }
|
||||
@@ -222,11 +226,13 @@ required-features = ["python-bindings"]
|
||||
[features]
|
||||
web = ["wasm-bindgen-rayon"]
|
||||
default = [
|
||||
"eth-mv-lookup",
|
||||
"eth",
|
||||
"dep:halo2_solidity_verifier",
|
||||
"ezkl",
|
||||
"precompute-coset",
|
||||
"no-banner",
|
||||
"parallel-poly-read",
|
||||
"reusable-verifier",
|
||||
]
|
||||
onnx = ["dep:tract-onnx"]
|
||||
python-bindings = ["pyo3", "pyo3-log", "pyo3-async-runtimes", "pyo3-stub-gen"]
|
||||
@@ -235,7 +241,7 @@ universal-bindings = [
|
||||
"mv-lookup",
|
||||
"precompute-coset",
|
||||
"parallel-poly-read",
|
||||
"solidity-verifier-mv-lookup",
|
||||
"dep:halo2_solidity_verifier"
|
||||
]
|
||||
logging = ["dep:colored", "dep:env_logger", "dep:chrono"]
|
||||
ios-bindings = ["universal-bindings"]
|
||||
@@ -261,10 +267,6 @@ ezkl = [
|
||||
"logging",
|
||||
]
|
||||
eth = ["dep:alloy", "dep:foundry-compilers", "dep:ethabi"]
|
||||
solidity-verifier = ["dep:halo2_solidity_verifier"]
|
||||
solidity-verifier-mv-lookup = ["halo2_solidity_verifier/mv-lookup"]
|
||||
eth-mv-lookup = ["solidity-verifier-mv-lookup", "mv-lookup", "eth"]
|
||||
eth-original-lookup = ["eth", "solidity-verifier"]
|
||||
parallel-poly-read = [
|
||||
"halo2_proofs/circuit-params",
|
||||
"halo2_proofs/parallel-poly-read",
|
||||
@@ -273,7 +275,7 @@ mv-lookup = ["halo2_proofs/mv-lookup", "snark-verifier/mv-lookup"]
|
||||
asm = ["halo2curves/asm", "halo2_proofs/asm"]
|
||||
precompute-coset = ["halo2_proofs/precompute-coset"]
|
||||
det-prove = []
|
||||
icicle = ["halo2_proofs/icicle_gpu"]
|
||||
gpu-accelerated = ["halo2_proofs/gpu-accelerated", "dep:icicle-runtime"]
|
||||
empty-cmd = []
|
||||
no-banner = []
|
||||
no-update = []
|
||||
@@ -284,6 +286,17 @@ mimalloc = ["dep:mimalloc"]
|
||||
reusable-verifier = []
|
||||
|
||||
|
||||
[patch.'https://github.com/zkonduit/halo2']
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e", package = "halo2_proofs", branch= "ac/conditional-compilation-icicle2", features = [
|
||||
"circuit-params", "mv-lookup"
|
||||
] }
|
||||
|
||||
[patch.'https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d']
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e", package = "halo2_proofs", branch= "ac/conditional-compilation-icicle2", features = [
|
||||
"circuit-params", "mv-lookup"
|
||||
] }
|
||||
|
||||
|
||||
[patch.crates-io]
|
||||
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
|
||||
|
||||
@@ -301,3 +314,5 @@ opt-level = 3
|
||||
|
||||
[package.metadata.wasm-pack.profile.release]
|
||||
wasm-opt = ["-O4", "--flexible-inline-max-function-size", "4294967295"]
|
||||
|
||||
|
||||
|
||||
@@ -1,53 +1,78 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use criterion::{
|
||||
criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion, PlotConfiguration,
|
||||
Throughput,
|
||||
};
|
||||
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
|
||||
use ezkl::circuit::einsum::circuit_params::SingleEinsumParams;
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::circuit::*;
|
||||
use ezkl::pfsys::create_proof_circuit;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::pfsys::{create_keys, srs::gen_srs};
|
||||
use ezkl::pfsys::srs::gen_srs;
|
||||
use ezkl::pfsys::{create_keys, create_proof_circuit, TranscriptType};
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::circuit::floor_planner::V1;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
|
||||
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
|
||||
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
|
||||
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
|
||||
use halo2_proofs::{
|
||||
arithmetic::Field,
|
||||
circuit::{Layouter, SimpleFloorPlanner, Value},
|
||||
circuit::{Layouter, Value},
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use halo2curves::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::marker::PhantomData;
|
||||
use std::collections::HashMap;
|
||||
|
||||
static mut LEN: usize = 4;
|
||||
const K: usize = 16;
|
||||
static mut K: usize = 15;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit {
|
||||
inputs: [ValTensor<Fr>; 2],
|
||||
_marker: PhantomData<Fr>,
|
||||
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
inputs: [ValTensor<F>; 2],
|
||||
einsum_params: SingleEinsumParams<F>,
|
||||
}
|
||||
|
||||
impl Circuit<Fr> for MyCircuit {
|
||||
impl Circuit<Fr> for MyCircuit<Fr> {
|
||||
type Config = BaseConfig<Fr>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = ();
|
||||
type FloorPlanner = V1;
|
||||
type Params = SingleEinsumParams<Fr>;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
|
||||
let len = unsafe { LEN };
|
||||
fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
|
||||
let mut config = Self::Config::default();
|
||||
|
||||
let a = VarTensor::new_advice(cs, K, 1, len * len);
|
||||
let mut equations = HashMap::new();
|
||||
equations.insert((0, params.equation), params.input_axes_to_dims);
|
||||
let analysis = analyze_einsum_usage(&equations).unwrap();
|
||||
let num_einsum_inner_cols = 1;
|
||||
unsafe {
|
||||
config
|
||||
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
|
||||
.unwrap();
|
||||
let _constant = VarTensor::constant_cols(cs, K, 2, false);
|
||||
}
|
||||
|
||||
let b = VarTensor::new_advice(cs, K, 1, len * len);
|
||||
config
|
||||
}
|
||||
|
||||
let output = VarTensor::new_advice(cs, K, 1, (len + 1) * len);
|
||||
fn params(&self) -> Self::Params {
|
||||
SingleEinsumParams::<Fr>::new(
|
||||
&self.einsum_params.equation,
|
||||
&[
|
||||
&self.inputs[0].get_inner().unwrap(),
|
||||
&self.inputs[1].get_inner().unwrap(),
|
||||
],
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE)
|
||||
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
|
||||
unimplemented!("call configure_with_params instead")
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
@@ -55,16 +80,33 @@ impl Circuit<Fr> for MyCircuit {
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<Fr>,
|
||||
) -> Result<(), Error> {
|
||||
let challenges = config
|
||||
.einsums
|
||||
.as_ref()
|
||||
.ok_or(Error::Synthesis)?
|
||||
.challenges()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|c| layouter.get_challenge(*c))
|
||||
.collect_vec();
|
||||
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
let mut region = region::RegionCtx::new_with_challenges(
|
||||
region,
|
||||
0,
|
||||
1,
|
||||
1024,
|
||||
2,
|
||||
challenges.clone(),
|
||||
);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "ab,bc->ac".to_string(),
|
||||
equation: self.einsum_params.equation.clone(),
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
@@ -77,41 +119,49 @@ impl Circuit<Fr> for MyCircuit {
|
||||
|
||||
fn runmatmul(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("accum_einsum_matmul");
|
||||
let params = gen_srs::<KZGCommitmentScheme<_>>(17);
|
||||
for &len in [4, 32].iter() {
|
||||
unsafe {
|
||||
LEN = len;
|
||||
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Linear));
|
||||
group.sampling_mode(criterion::SamplingMode::Flat);
|
||||
group.sample_size(10);
|
||||
let len = 128;
|
||||
unsafe {
|
||||
LEN = len;
|
||||
}
|
||||
for k in 16..17 {
|
||||
let params = unsafe {
|
||||
K = k;
|
||||
gen_srs::<KZGCommitmentScheme<_>>(K as u32)
|
||||
};
|
||||
|
||||
let mut a = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
|
||||
a.reshape(&[len, len]).unwrap();
|
||||
|
||||
// parameters
|
||||
let mut b = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
|
||||
b.reshape(&[len, len]).unwrap();
|
||||
|
||||
let einsum_params = SingleEinsumParams::<Fr>::new("ij,jk->ik", &[&a, &b]).unwrap();
|
||||
|
||||
let circuit = MyCircuit {
|
||||
inputs: [ValTensor::from(a), ValTensor::from(b)],
|
||||
_marker: PhantomData,
|
||||
einsum_params,
|
||||
};
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
group.bench_with_input(BenchmarkId::new("pk", k), &k, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit<Fr>>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true).unwrap();
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit<Fr>>(&circuit, ¶ms, false)
|
||||
.unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
group.bench_with_input(BenchmarkId::new("prove", k), &k, |b, &_| {
|
||||
b.iter(|| {
|
||||
let prover = create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
MyCircuit,
|
||||
MyCircuit<Fr>,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
|
||||
171
examples/accum_einsum_matmul.rs
Normal file
171
examples/accum_einsum_matmul.rs
Normal file
@@ -0,0 +1,171 @@
|
||||
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::circuit::*;
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::circuit::floor_planner::V1;
|
||||
use halo2_proofs::dev::MockProver;
|
||||
use halo2_proofs::{
|
||||
arithmetic::Field,
|
||||
circuit::{Layouter, Value},
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
use halo2curves::bn256::Fr;
|
||||
use halo2curves::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
use std::collections::HashMap;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
const K: usize = 13;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
inputs: [ValTensor<F>; 2],
|
||||
einsum: Einsum<F>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
|
||||
equation: String,
|
||||
input_axes_to_dims: HashMap<char, usize>,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
|
||||
pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
|
||||
let mut eq = equation.split("->");
|
||||
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
|
||||
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
|
||||
|
||||
// Check that the number of inputs matches the number of inputs in the equation
|
||||
if inputs.len() != inputs_eq.len() {
|
||||
return Err(TensorError::DimMismatch("einsum".to_string()).into());
|
||||
}
|
||||
|
||||
let mut input_axes_to_dims = HashMap::new();
|
||||
for (i, input) in inputs.iter().enumerate() {
|
||||
for j in 0..inputs_eq[i].len() {
|
||||
let c = inputs_eq[i]
|
||||
.chars()
|
||||
.nth(j)
|
||||
.ok_or(CircuitError::InvalidEinsum)?;
|
||||
if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
|
||||
e.insert(input.dims()[j]);
|
||||
} else if input_axes_to_dims[&c] != input.dims()[j] {
|
||||
return Err(TensorError::DimMismatch("einsum".to_string()).into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
equation: equation.to_owned(),
|
||||
input_axes_to_dims,
|
||||
_marker: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Circuit<Fr> for MyCircuit<Fr> {
|
||||
type Config = BaseConfig<Fr>;
|
||||
type FloorPlanner = V1;
|
||||
type Params = Einsum<Fr>;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
|
||||
let mut config = Self::Config::default();
|
||||
|
||||
let mut equations = HashMap::new();
|
||||
equations.insert((0, params.equation), params.input_axes_to_dims);
|
||||
let analysis = analyze_einsum_usage(&equations).unwrap();
|
||||
let num_einsum_inner_cols = 1;
|
||||
config
|
||||
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
|
||||
.unwrap();
|
||||
|
||||
config
|
||||
}
|
||||
|
||||
fn params(&self) -> Self::Params {
|
||||
Einsum::<Fr>::new(
|
||||
&self.einsum.equation,
|
||||
&[
|
||||
&self.inputs[0].get_inner().unwrap(),
|
||||
&self.inputs[1].get_inner().unwrap(),
|
||||
],
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
|
||||
unimplemented!("call configure_with_params instead")
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<Fr>,
|
||||
) -> Result<(), Error> {
|
||||
let challenges = config
|
||||
.einsums
|
||||
.as_ref()
|
||||
.ok_or(Error::Synthesis)?
|
||||
.challenges()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|c| layouter.get_challenge(*c))
|
||||
.collect_vec();
|
||||
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new_with_challenges(
|
||||
region,
|
||||
0,
|
||||
1,
|
||||
1024,
|
||||
2,
|
||||
challenges.clone(),
|
||||
);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: self.einsum.equation.clone(),
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn runmatmul() {
|
||||
let len = 64;
|
||||
|
||||
let mut a = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
|
||||
a.reshape(&[len, len]).unwrap();
|
||||
|
||||
// parameters
|
||||
let mut b = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
|
||||
b.reshape(&[len, len]).unwrap();
|
||||
|
||||
let einsum = Einsum::<Fr>::new("ij,jk->ik", &[&a, &b]).unwrap();
|
||||
|
||||
let circuit = MyCircuit {
|
||||
inputs: [ValTensor::from(a), ValTensor::from(b)],
|
||||
einsum,
|
||||
};
|
||||
|
||||
let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
mock_prover.assert_satisfied();
|
||||
}
|
||||
|
||||
pub fn main() {
|
||||
runmatmul()
|
||||
}
|
||||
179
examples/batch_mat_mul.rs
Normal file
179
examples/batch_mat_mul.rs
Normal file
@@ -0,0 +1,179 @@
|
||||
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::circuit::*;
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::circuit::floor_planner::V1;
|
||||
use halo2_proofs::dev::MockProver;
|
||||
use halo2_proofs::{
|
||||
arithmetic::Field,
|
||||
circuit::{Layouter, Value},
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
use halo2curves::bn256::Fr;
|
||||
use halo2curves::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
use std::collections::HashMap;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
static mut LEN: usize = 4;
|
||||
const K: usize = 11;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
inputs: [ValTensor<F>; 2],
|
||||
einsum: Einsum<F>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
|
||||
equation: String,
|
||||
input_axes_to_dims: HashMap<char, usize>,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
|
||||
pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
|
||||
let mut eq = equation.split("->");
|
||||
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
|
||||
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
|
||||
|
||||
// Check that the number of inputs matches the number of inputs in the equation
|
||||
if inputs.len() != inputs_eq.len() {
|
||||
return Err(TensorError::DimMismatch("einsum".to_string()).into());
|
||||
}
|
||||
|
||||
let mut input_axes_to_dims = HashMap::new();
|
||||
for (i, input) in inputs.iter().enumerate() {
|
||||
for j in 0..inputs_eq[i].len() {
|
||||
let c = inputs_eq[i]
|
||||
.chars()
|
||||
.nth(j)
|
||||
.ok_or(CircuitError::InvalidEinsum)?;
|
||||
if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
|
||||
e.insert(input.dims()[j]);
|
||||
} else if input_axes_to_dims[&c] != input.dims()[j] {
|
||||
return Err(TensorError::DimMismatch("einsum".to_string()).into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
equation: equation.to_owned(),
|
||||
input_axes_to_dims,
|
||||
_marker: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Circuit<Fr> for MyCircuit<Fr> {
|
||||
type Config = BaseConfig<Fr>;
|
||||
type FloorPlanner = V1;
|
||||
type Params = Einsum<Fr>;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
|
||||
let len = unsafe { LEN };
|
||||
|
||||
let a = VarTensor::new_advice(cs, K, 1, len);
|
||||
let b = VarTensor::new_advice(cs, K, 1, len);
|
||||
let output = VarTensor::new_advice(cs, K, 1, len);
|
||||
|
||||
let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE);
|
||||
|
||||
let mut equations = HashMap::new();
|
||||
equations.insert((0, params.equation), params.input_axes_to_dims);
|
||||
let analysis = analyze_einsum_usage(&equations).unwrap();
|
||||
let num_einsum_inner_cols = 1;
|
||||
config
|
||||
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
|
||||
.unwrap();
|
||||
|
||||
config
|
||||
}
|
||||
|
||||
fn params(&self) -> Self::Params {
|
||||
Einsum::<Fr>::new(
|
||||
&self.einsum.equation,
|
||||
&[
|
||||
&self.inputs[0].get_inner().unwrap(),
|
||||
&self.inputs[1].get_inner().unwrap(),
|
||||
],
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
|
||||
unimplemented!("call configure_with_params instead")
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<Fr>,
|
||||
) -> Result<(), Error> {
|
||||
let challenges = config
|
||||
.einsums
|
||||
.as_ref()
|
||||
.ok_or(Error::Synthesis)?
|
||||
.challenges()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|c| layouter.get_challenge(*c))
|
||||
.collect_vec();
|
||||
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new_with_challenges(
|
||||
region,
|
||||
0,
|
||||
1,
|
||||
1024,
|
||||
2,
|
||||
challenges.clone(),
|
||||
);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: self.einsum.equation.clone(),
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn runbatchmatmul() {
|
||||
let batch_size = 5;
|
||||
let len = 12;
|
||||
|
||||
let mut a = Tensor::from((0..batch_size * len * len).map(|_| Value::known(Fr::random(OsRng))));
|
||||
a.reshape(&[batch_size, len, len]).unwrap();
|
||||
|
||||
// parameters
|
||||
let mut b = Tensor::from((0..batch_size * len * len).map(|_| Value::known(Fr::random(OsRng))));
|
||||
b.reshape(&[batch_size, len, len]).unwrap();
|
||||
|
||||
let einsum = Einsum::<Fr>::new("ijk,ikl->ijl", &[&a, &b]).unwrap();
|
||||
|
||||
let circuit = MyCircuit {
|
||||
inputs: [ValTensor::from(a), ValTensor::from(b)],
|
||||
einsum,
|
||||
};
|
||||
|
||||
let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
mock_prover.assert_satisfied();
|
||||
}
|
||||
|
||||
pub fn main() {
|
||||
runbatchmatmul()
|
||||
}
|
||||
@@ -866,7 +866,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 98,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -879,6 +879,7 @@
|
||||
"run_args.input_visibility = \"private\"\n",
|
||||
"run_args.param_visibility = \"fixed\"\n",
|
||||
"run_args.output_visibility = \"public\"\n",
|
||||
"run_args.disable_freivalds = True\n",
|
||||
"run_args.variables = [(\"batch_size\", 1)]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
@@ -1142,4 +1143,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
|
||||
79
examples/onnx/large_mlp/gen.py
Normal file
79
examples/onnx/large_mlp/gen.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from torch import nn
|
||||
import torch.nn.init as init
|
||||
import torch
|
||||
import json
|
||||
|
||||
N = 100
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, inplace=False):
|
||||
super(Model, self).__init__()
|
||||
|
||||
self.aff1 = nn.Linear(N,N)
|
||||
self.aff2 = nn.Linear(N,N)
|
||||
self.aff3 = nn.Linear(N,N)
|
||||
self.aff4 = nn.Linear(N,N)
|
||||
self.aff5 = nn.Linear(N,N)
|
||||
self.aff6 = nn.Linear(N,N)
|
||||
self.aff7 = nn.Linear(N,N)
|
||||
self.aff8 = nn.Linear(N,N)
|
||||
self.aff9 = nn.Linear(N,N)
|
||||
self.relu = nn.ReLU()
|
||||
self._initialize_weights()
|
||||
|
||||
def forward(self, x):
|
||||
# concat 10 x along dim 0
|
||||
x = x.repeat(10, 1)
|
||||
x = self.aff1(x)
|
||||
x = self.relu(x)
|
||||
x = self.aff2(x)
|
||||
x = self.relu(x)
|
||||
x = self.aff3(x)
|
||||
x = self.relu(x)
|
||||
x = self.aff4(x)
|
||||
x = self.relu(x)
|
||||
x = self.aff5(x)
|
||||
x = self.relu(x)
|
||||
x = self.aff6(x)
|
||||
x = self.relu(x)
|
||||
x = self.aff7(x)
|
||||
x = self.relu(x)
|
||||
x = self.aff8(x)
|
||||
x = self.relu(x)
|
||||
x = self.aff9(x)
|
||||
return x
|
||||
|
||||
|
||||
def _initialize_weights(self):
|
||||
init.orthogonal_(self.aff1.weight)
|
||||
|
||||
model = Model()
|
||||
|
||||
# Flips the neural net into inference mode
|
||||
model.eval()
|
||||
model.to('cpu')
|
||||
|
||||
|
||||
x = torch.randn(1, N)
|
||||
# Export the model
|
||||
torch.onnx.export(model, # model being run
|
||||
# model input (or a tuple for multiple inputs)
|
||||
x,
|
||||
# where to save the model (can be a file or file-like object)
|
||||
"network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=12, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
data_array = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data_json = dict(input_data=[data_array])
|
||||
|
||||
print(data_json)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data_json, open("input.json", 'w'))
|
||||
1
examples/onnx/large_mlp/input.json
Normal file
1
examples/onnx/large_mlp/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.33088353276252747, -0.8819183707237244, 1.245591163635254, -1.807046890258789, 1.9922369718551636, -0.3360576629638672, 0.4529011845588684, -0.3590165674686432, 0.08356846123933792, 0.5126393437385559, 0.44627535343170166, 1.4916497468948364, 0.49731069803237915, -0.9748706817626953, -0.4923185408115387, 1.3548223972320557, 0.2306872010231018, 1.125955581665039, -1.7063908576965332, 0.3777385354042053, -2.7988760471343994, -1.1846797466278076, 0.7473157048225403, 1.490412950515747, 0.017497723922133446, 2.113945245742798, -1.2141249179840088, -0.16120357811450958, 0.021127669140696526, 0.7207374572753906, -1.369688868522644, -0.7369781732559204, -0.630584180355072, -0.4520200788974762, 0.29123976826667786, 0.6334688067436218, -0.869332492351532, -1.258501648902893, 0.3012596666812897, -0.5507447123527527, 0.669975757598877, 0.15088629722595215, -0.1050339788198471, 0.5505334138870239, -0.1287376880645752, -1.4297826290130615, -0.01703289896249771, -1.2296998500823975, 0.5122153162956238, -0.16924428939819336, -0.415036678314209, -1.1979341506958008, 0.05831022188067436, -0.4411357045173645, 2.0713791847229004, 1.4611141681671143, -0.9357407093048096, -0.333297461271286, -0.676478385925293, 1.390028476715088, -0.05827632546424866, 1.535687804222107, 0.3060210347175598, -0.03171076253056526, -0.614985466003418, 1.2040390968322754, 0.31318482756614685, -1.2134959697723389, 0.13110508024692535, -1.4880926609039307, 1.7007993459701538, 1.5412729978561401, 0.09260450303554535, 0.7649128437042236, -0.5009126663208008, -0.5356241464614868, -0.069572813808918, -0.011717632412910461, 0.21314217150211334, -0.1985170543193817, -0.0223808903247118, 1.2128918170928955, 0.8334696888923645, 1.9029873609542847, -0.11491120606660843, -0.10303237289190292, -0.2467050403356552, 1.557223916053772, -1.1108328104019165, -0.9065343141555786, -0.2271333783864975, 0.6959827542304993, -0.48698121309280396, 0.5689510703086853, 1.115319013595581, -0.8907430768013, 
-0.24722427129745483, -0.7437837719917297, 0.6742106676101685, -1.7830933332443237]]}
|
||||
BIN
examples/onnx/large_mlp/network.onnx
Normal file
BIN
examples/onnx/large_mlp/network.onnx
Normal file
Binary file not shown.
182
examples/tensor_contraction.rs
Normal file
182
examples/tensor_contraction.rs
Normal file
@@ -0,0 +1,182 @@
|
||||
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::circuit::*;
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::circuit::floor_planner::V1;
|
||||
use halo2_proofs::dev::MockProver;
|
||||
use halo2_proofs::{
|
||||
arithmetic::Field,
|
||||
circuit::{Layouter, Value},
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
use halo2curves::bn256::Fr;
|
||||
use halo2curves::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
use std::collections::HashMap;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
static mut LEN: usize = 4;
|
||||
const K: usize = 11;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
inputs: [ValTensor<F>; 2],
|
||||
einsum: Einsum<F>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
|
||||
equation: String,
|
||||
input_axes_to_dims: HashMap<char, usize>,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
|
||||
pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
|
||||
let mut eq = equation.split("->");
|
||||
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
|
||||
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
|
||||
|
||||
// Check that the number of inputs matches the number of inputs in the equation
|
||||
if inputs.len() != inputs_eq.len() {
|
||||
return Err(TensorError::DimMismatch("einsum".to_string()).into());
|
||||
}
|
||||
|
||||
let mut input_axes_to_dims = HashMap::new();
|
||||
for (i, input) in inputs.iter().enumerate() {
|
||||
for j in 0..inputs_eq[i].len() {
|
||||
let c = inputs_eq[i]
|
||||
.chars()
|
||||
.nth(j)
|
||||
.ok_or(CircuitError::InvalidEinsum)?;
|
||||
if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
|
||||
e.insert(input.dims()[j]);
|
||||
} else if input_axes_to_dims[&c] != input.dims()[j] {
|
||||
return Err(TensorError::DimMismatch("einsum".to_string()).into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
equation: equation.to_owned(),
|
||||
input_axes_to_dims,
|
||||
_marker: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Circuit<Fr> for MyCircuit<Fr> {
|
||||
type Config = BaseConfig<Fr>;
|
||||
type FloorPlanner = V1;
|
||||
type Params = Einsum<Fr>;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
|
||||
let len = unsafe { LEN };
|
||||
|
||||
let a = VarTensor::new_advice(cs, K, 1, len);
|
||||
let b = VarTensor::new_advice(cs, K, 1, len);
|
||||
let output = VarTensor::new_advice(cs, K, 1, len);
|
||||
|
||||
let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE);
|
||||
|
||||
let mut equations = HashMap::new();
|
||||
equations.insert((0, params.equation), params.input_axes_to_dims);
|
||||
let analysis = analyze_einsum_usage(&equations).unwrap();
|
||||
let num_einsum_inner_cols = 2;
|
||||
config
|
||||
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
|
||||
.unwrap();
|
||||
let _constant = VarTensor::constant_cols(cs, K, 2, false);
|
||||
|
||||
config
|
||||
}
|
||||
|
||||
fn params(&self) -> Self::Params {
|
||||
Einsum::<Fr>::new(
|
||||
&self.einsum.equation,
|
||||
&[
|
||||
&self.inputs[0].get_inner().unwrap(),
|
||||
&self.inputs[1].get_inner().unwrap(),
|
||||
],
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
|
||||
unimplemented!("call configure_with_params instead")
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<Fr>,
|
||||
) -> Result<(), Error> {
|
||||
let challenges = config
|
||||
.einsums
|
||||
.as_ref()
|
||||
.ok_or(Error::Synthesis)?
|
||||
.challenges()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|c| layouter.get_challenge(*c))
|
||||
.collect_vec();
|
||||
|
||||
layouter.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new_with_challenges(
|
||||
region,
|
||||
0,
|
||||
1,
|
||||
1024,
|
||||
2,
|
||||
challenges.clone(),
|
||||
);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: self.einsum.equation.clone(),
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn runmatmul() {
|
||||
let i = 10;
|
||||
let n = 10;
|
||||
let j = 40;
|
||||
let k = 10;
|
||||
|
||||
let mut a = Tensor::from((0..i * n * j).map(|_| Value::known(Fr::random(OsRng))));
|
||||
a.reshape(&[i, n, j]).unwrap();
|
||||
|
||||
// parameters
|
||||
let mut b = Tensor::from((0..j * k).map(|_| Value::known(Fr::random(OsRng))));
|
||||
b.reshape(&[j, k]).unwrap();
|
||||
|
||||
let einsum = Einsum::<Fr>::new("inj,jk->ik", &[&a, &b]).unwrap();
|
||||
|
||||
let circuit = MyCircuit {
|
||||
inputs: [ValTensor::from(a), ValTensor::from(b)],
|
||||
einsum,
|
||||
};
|
||||
|
||||
let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
mock_prover.assert_satisfied();
|
||||
}
|
||||
|
||||
pub fn main() {
|
||||
runmatmul()
|
||||
}
|
||||
@@ -1,3 +1,3 @@
|
||||
[toolchain]
|
||||
channel = "nightly-2025-02-17"
|
||||
channel = "nightly-2025-05-01"
|
||||
components = ["rustfmt", "clippy"]
|
||||
|
||||
229
setup-gpu.sh
Executable file
229
setup-gpu.sh
Executable file
@@ -0,0 +1,229 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Default installation directory
|
||||
DEFAULT_INSTALL_DIR="/opt/icicle/lib/backend/halo2"
|
||||
|
||||
# Halo2 repository details
|
||||
HALO2_REPO="https://github.com/zkonduit/halo2"
|
||||
HALO2_BRANCH="ac/conditional-compilation-icicle2"
|
||||
|
||||
# Parse command line arguments
|
||||
AUTO_YES=false
|
||||
for arg in "$@"; do
|
||||
case $arg in
|
||||
-y|--yes)
|
||||
AUTO_YES=true
|
||||
shift
|
||||
;;
|
||||
-h|--help)
|
||||
echo "Usage: $0 [OPTIONS]"
|
||||
echo "Options:"
|
||||
echo " -y, --yes Automatically answer 'yes' to all prompts"
|
||||
echo " -h, --help Show this help message"
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $arg"
|
||||
echo "Use -h or --help for usage information"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
echo -e "${GREEN}EZKL GPU Setup Script${NC}"
|
||||
echo -e "${GREEN}=====================${NC}"
|
||||
echo ""
|
||||
|
||||
# Parse commit hash from Cargo.lock
|
||||
echo "Parsing halo2 commit hash from Cargo.lock..."
|
||||
if [ ! -f "Cargo.lock" ]; then
|
||||
echo -e "${RED}Error: Cargo.lock not found. Please run this script from the project root.${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
HALO2_COMMIT=$(grep "github\.com/zkonduit/halo2?" Cargo.lock | grep -v "halo2wrong" | head -1 | grep -o "#[a-f0-9]\{40\}" | cut -c2-)
|
||||
|
||||
if [ -z "$HALO2_COMMIT" ]; then
|
||||
echo -e "${RED}Error: Could not parse halo2 commit hash from Cargo.lock${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}Found halo2 commit: $HALO2_COMMIT${NC}"
|
||||
echo ""
|
||||
echo "This script will:"
|
||||
echo "1. Sparse checkout the halo2 repository at commit $HALO2_COMMIT"
|
||||
echo "2. Extract only the icicle/backend/cuda/ directory"
|
||||
echo "3. Set the ICICLE_BACKEND_INSTALL_DIR environment variable"
|
||||
echo ""
|
||||
|
||||
# Check if user wants to override the default directory
|
||||
if [ "$AUTO_YES" = true ]; then
|
||||
INSTALL_DIR="$DEFAULT_INSTALL_DIR"
|
||||
echo -e "${GREEN}Using default installation directory: ${INSTALL_DIR}${NC}"
|
||||
else
|
||||
echo -e "${YELLOW}Default installation directory: ${DEFAULT_INSTALL_DIR}${NC}"
|
||||
read -p "Do you want to use a different directory? [y/N]: " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
||||
read -p "Enter the installation directory: " INSTALL_DIR
|
||||
INSTALL_DIR="${INSTALL_DIR/#\~/$HOME}" # Expand ~ to $HOME
|
||||
else
|
||||
INSTALL_DIR="$DEFAULT_INSTALL_DIR"
|
||||
fi
|
||||
|
||||
# Confirm the installation directory
|
||||
echo ""
|
||||
echo -e "${YELLOW}Installation directory: ${INSTALL_DIR}${NC}"
|
||||
read -p "Continue with this directory? [y/N]: " -n 1 -r
|
||||
echo
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
echo -e "${RED}Setup cancelled by user.${NC}"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check if ICICLE_BACKEND_INSTALL_DIR is already set
|
||||
if [ ! -z "$ICICLE_BACKEND_INSTALL_DIR" ] && [ "$AUTO_YES" = false ]; then
|
||||
echo ""
|
||||
echo -e "${YELLOW}Warning: ICICLE_BACKEND_INSTALL_DIR is already set to: $ICICLE_BACKEND_INSTALL_DIR${NC}"
|
||||
read -p "Do you want to override it? [y/N]: " -n 1 -r
|
||||
echo
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
echo -e "${RED}Setup cancelled by user.${NC}"
|
||||
exit 1
|
||||
fi
|
||||
elif [ ! -z "$ICICLE_BACKEND_INSTALL_DIR" ] && [ "$AUTO_YES" = true ]; then
|
||||
echo -e "${GREEN}Overriding existing ICICLE_BACKEND_INSTALL_DIR (was: $ICICLE_BACKEND_INSTALL_DIR)${NC}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo -e "${GREEN}Starting GPU setup...${NC}"
|
||||
|
||||
# Create installation directory
|
||||
echo "Creating installation directory..."
|
||||
mkdir -p "$INSTALL_DIR"
|
||||
|
||||
# Create temporary directory for sparse checkout
|
||||
TEMP_DIR=$(mktemp -d)
|
||||
echo "Using temporary directory: $TEMP_DIR"
|
||||
|
||||
# Clone with sparse checkout
|
||||
echo "Cloning halo2 repository with sparse checkout..."
|
||||
cd "$TEMP_DIR"
|
||||
git clone --filter=blob:none --sparse "$HALO2_REPO" halo2
|
||||
cd halo2
|
||||
|
||||
# Checkout the specific branch and commit
|
||||
echo "Checking out branch $HALO2_BRANCH at commit $HALO2_COMMIT..."
|
||||
git checkout "$HALO2_BRANCH"
|
||||
git checkout "$HALO2_COMMIT"
|
||||
|
||||
# Configure sparse checkout
|
||||
echo "Configuring sparse checkout for icicle/backend/cuda/..."
|
||||
git sparse-checkout init --cone
|
||||
git sparse-checkout set icicle/backend/cuda/
|
||||
|
||||
# Copy the icicle directory to the installation location
|
||||
if [ -d "icicle/backend/cuda" ]; then
|
||||
echo "Copying icicle/backend/cuda/ to $INSTALL_DIR..."
|
||||
cp -r icicle/backend/cuda/* "$INSTALL_DIR/"
|
||||
echo -e "${GREEN}Files copied successfully!${NC}"
|
||||
else
|
||||
echo -e "${RED}Error: icicle/backend/cuda directory not found in the repository${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Clean up temporary directory
|
||||
echo "Cleaning up temporary files..."
|
||||
rm -rf "$TEMP_DIR"
|
||||
|
||||
# Ask user about setting environment variable permanently
|
||||
SETUP_PERMANENT_ENV=false
|
||||
if [ "$AUTO_YES" = true ]; then
|
||||
SETUP_PERMANENT_ENV=true
|
||||
echo ""
|
||||
echo -e "${GREEN}Setting ICICLE_BACKEND_INSTALL_DIR environment variable permanently...${NC}"
|
||||
else
|
||||
echo ""
|
||||
echo -e "${YELLOW}Do you want to set ICICLE_BACKEND_INSTALL_DIR environment variable permanently?${NC}"
|
||||
echo "This will add 'export ICICLE_BACKEND_INSTALL_DIR=\"$INSTALL_DIR\"' to your shell configuration file."
|
||||
read -p "Set environment variable permanently? [y/N]: " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
||||
SETUP_PERMANENT_ENV=true
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ "$SETUP_PERMANENT_ENV" = true ]; then
|
||||
echo "Setting ICICLE_BACKEND_INSTALL_DIR environment variable..."
|
||||
|
||||
# Detect shell and set environment variable accordingly
|
||||
if [ -n "$ZSH_VERSION" ]; then
|
||||
SHELL_RC="$HOME/.zshrc"
|
||||
elif [ -n "$BASH_VERSION" ]; then
|
||||
SHELL_RC="$HOME/.bashrc"
|
||||
else
|
||||
# Try to detect based on $SHELL
|
||||
case "$SHELL" in
|
||||
*/zsh)
|
||||
SHELL_RC="$HOME/.zshrc"
|
||||
;;
|
||||
*/bash)
|
||||
SHELL_RC="$HOME/.bashrc"
|
||||
;;
|
||||
*)
|
||||
SHELL_RC="$HOME/.profile"
|
||||
;;
|
||||
esac
|
||||
fi
|
||||
|
||||
# Add environment variable to shell configuration
|
||||
ENV_EXPORT="export ICICLE_BACKEND_INSTALL_DIR=\"$INSTALL_DIR\""
|
||||
|
||||
# Check if the variable is already set in the file
|
||||
if [ -f "$SHELL_RC" ] && grep -q "ICICLE_BACKEND_INSTALL_DIR" "$SHELL_RC"; then
|
||||
# Replace existing line
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
# macOS
|
||||
sed -i '' "s|export ICICLE_BACKEND_INSTALL_DIR=.*|$ENV_EXPORT|" "$SHELL_RC"
|
||||
else
|
||||
# Linux
|
||||
sed -i "s|export ICICLE_BACKEND_INSTALL_DIR=.*|$ENV_EXPORT|" "$SHELL_RC"
|
||||
fi
|
||||
echo "Updated existing ICICLE_BACKEND_INSTALL_DIR in $SHELL_RC"
|
||||
else
|
||||
# Add new line
|
||||
echo "$ENV_EXPORT" >> "$SHELL_RC"
|
||||
echo "Added ICICLE_BACKEND_INSTALL_DIR to $SHELL_RC"
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}Environment variable set permanently.${NC}"
|
||||
else
|
||||
echo "Skipping permanent environment variable setup."
|
||||
fi
|
||||
|
||||
# Export for current session regardless
|
||||
export ICICLE_BACKEND_INSTALL_DIR="$INSTALL_DIR"
|
||||
echo "Environment variable set for current session."
|
||||
|
||||
echo ""
|
||||
echo -e "${GREEN}GPU setup completed successfully!${NC}"
|
||||
echo ""
|
||||
echo -e "${YELLOW}Important:${NC}"
|
||||
echo "1. The ICICLE_BACKEND_INSTALL_DIR environment variable has been set to: $INSTALL_DIR"
|
||||
if [ "$SETUP_PERMANENT_ENV" = true ]; then
|
||||
echo "2. Please restart your terminal or run: source $SHELL_RC"
|
||||
else
|
||||
echo "2. To use GPU features, set: export ICICLE_BACKEND_INSTALL_DIR=\"$INSTALL_DIR\""
|
||||
fi
|
||||
echo "3. You can now build with GPU support using: cargo build --features gpu-accelerated"
|
||||
echo ""
|
||||
echo -e "${GREEN}Setup complete!${NC}"
|
||||
@@ -15,9 +15,6 @@ use log::{error, info};
|
||||
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
|
||||
use rand::prelude::SliceRandom;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
#[cfg(feature = "icicle")]
|
||||
use std::env;
|
||||
|
||||
#[tokio::main(flavor = "current_thread")]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
pub async fn main() {
|
||||
@@ -31,12 +28,7 @@ pub async fn main() {
|
||||
init_logger();
|
||||
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
|
||||
banner();
|
||||
#[cfg(feature = "icicle")]
|
||||
if env::var("ENABLE_ICICLE_GPU").is_ok() {
|
||||
info!("Running with ICICLE GPU");
|
||||
} else {
|
||||
info!("Running with CPU");
|
||||
}
|
||||
|
||||
debug!(
|
||||
"command: \n {}",
|
||||
&command.as_json().to_colored_json_auto().unwrap()
|
||||
|
||||
@@ -93,17 +93,6 @@ impl From<PyG1> for G1 {
|
||||
}
|
||||
}
|
||||
|
||||
impl pyo3::ToPyObject for PyG1 {
|
||||
fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject {
|
||||
let g1_dict = pyo3::types::PyDict::new(py);
|
||||
|
||||
g1_dict.set_item("x", self.x.to_object(py)).unwrap();
|
||||
g1_dict.set_item("y", self.y.to_object(py)).unwrap();
|
||||
g1_dict.set_item("z", self.z.to_object(py)).unwrap();
|
||||
g1_dict.into()
|
||||
}
|
||||
}
|
||||
|
||||
/// pyclass containing the struct used for G1
|
||||
#[pyclass]
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -135,16 +124,6 @@ impl From<PyG1Affine> for G1Affine {
|
||||
}
|
||||
}
|
||||
|
||||
impl pyo3::ToPyObject for PyG1Affine {
|
||||
fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject {
|
||||
let g1_dict = pyo3::types::PyDict::new(py);
|
||||
|
||||
g1_dict.set_item("x", self.x.to_object(py)).unwrap();
|
||||
g1_dict.set_item("y", self.y.to_object(py)).unwrap();
|
||||
g1_dict.into()
|
||||
}
|
||||
}
|
||||
|
||||
/// Python class containing the struct used for run_args
|
||||
///
|
||||
/// Returns
|
||||
@@ -161,6 +140,9 @@ struct PyRunArgs {
|
||||
#[pyo3(get, set)]
|
||||
/// int: The denominator in the fixed point representation used when quantizing parameters
|
||||
pub param_scale: crate::Scale,
|
||||
/// int: The scale to rebase to (optional). If None, we rebase to the max of input_scale and param_scale
|
||||
/// This is an advanced parameter that should be used with caution
|
||||
pub rebase_scale: Option<crate::Scale>,
|
||||
#[pyo3(get, set)]
|
||||
/// int: If the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this a more advanced parameter, use with caution)
|
||||
pub scale_rebase_multiplier: u32,
|
||||
@@ -209,6 +191,9 @@ struct PyRunArgs {
|
||||
/// float: epsilon used for arguments that use division
|
||||
#[pyo3(get, set)]
|
||||
pub epsilon: f64,
|
||||
/// bool: Whether to disable using Freivalds' argument in einsum operations
|
||||
#[pyo3(get, set)]
|
||||
pub disable_freivalds: bool,
|
||||
}
|
||||
|
||||
/// default instantiation of PyRunArgs
|
||||
@@ -227,6 +212,7 @@ impl From<PyRunArgs> for RunArgs {
|
||||
bounded_log_lookup: py_run_args.bounded_log_lookup,
|
||||
input_scale: py_run_args.input_scale,
|
||||
param_scale: py_run_args.param_scale,
|
||||
rebase_scale: py_run_args.rebase_scale,
|
||||
num_inner_cols: py_run_args.num_inner_cols,
|
||||
scale_rebase_multiplier: py_run_args.scale_rebase_multiplier,
|
||||
lookup_range: py_run_args.lookup_range,
|
||||
@@ -242,6 +228,7 @@ impl From<PyRunArgs> for RunArgs {
|
||||
decomp_legs: py_run_args.decomp_legs,
|
||||
ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
|
||||
epsilon: Some(py_run_args.epsilon),
|
||||
disable_freivalds: py_run_args.disable_freivalds,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -253,6 +240,7 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
bounded_log_lookup: self.bounded_log_lookup,
|
||||
input_scale: self.input_scale,
|
||||
param_scale: self.param_scale,
|
||||
rebase_scale: self.rebase_scale,
|
||||
num_inner_cols: self.num_inner_cols,
|
||||
scale_rebase_multiplier: self.scale_rebase_multiplier,
|
||||
lookup_range: self.lookup_range,
|
||||
@@ -268,6 +256,7 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
decomp_legs: self.decomp_legs,
|
||||
ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
|
||||
epsilon: eps,
|
||||
disable_freivalds: self.disable_freivalds,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -694,7 +683,7 @@ fn ipa_commit(
|
||||
.map_err(|_| PyIOError::new_err("Failed to load circuit settings"))?;
|
||||
|
||||
let srs_path =
|
||||
crate::execute::get_srs_path(settings.run_args.logrows, srs_path, Commitments::KZG);
|
||||
crate::execute::get_srs_path(settings.run_args.logrows, srs_path, Commitments::IPA);
|
||||
|
||||
let srs = load_srs_prover::<IPACommitmentScheme<G1Affine>>(srs_path)
|
||||
.map_err(|_| PyIOError::new_err("Failed to load srs"))?;
|
||||
@@ -884,7 +873,7 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn get_srs(
|
||||
py: Python,
|
||||
py: Python<'_>,
|
||||
settings_path: Option<PathBuf>,
|
||||
logrows: Option<u32>,
|
||||
srs_path: Option<PathBuf>,
|
||||
@@ -1101,7 +1090,7 @@ fn gen_witness(
|
||||
let err_str = format!("Failed to generate witness: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
Python::with_gil(|py| Ok(output.to_object(py)))
|
||||
Python::with_gil(|py| Ok(output.into_pyobject(py).unwrap().into()))
|
||||
}
|
||||
|
||||
/// Mocks the prover
|
||||
@@ -1283,7 +1272,7 @@ fn prove(
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Python::with_gil(|py| Ok(snark.to_object(py)))
|
||||
Python::with_gil(|py| Ok(snark.into_pyobject(py).unwrap().into()))
|
||||
}
|
||||
|
||||
/// Verifies a given proof
|
||||
@@ -1649,7 +1638,7 @@ fn encode_evm_calldata<'a>(
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn create_evm_verifier(
|
||||
py: Python,
|
||||
py: Python<'_>,
|
||||
vk_path: PathBuf,
|
||||
settings_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
@@ -1710,7 +1699,7 @@ fn create_evm_verifier(
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn create_evm_vka(
|
||||
py: Python,
|
||||
py: Python<'_>,
|
||||
vk_path: PathBuf,
|
||||
settings_path: PathBuf,
|
||||
vka_path: PathBuf,
|
||||
@@ -1740,7 +1729,7 @@ fn create_evm_vka(
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn deploy_evm(
|
||||
py: Python,
|
||||
py: Python<'_>,
|
||||
addr_path: PathBuf,
|
||||
rpc_url: String,
|
||||
sol_code_path: PathBuf,
|
||||
@@ -1924,7 +1913,7 @@ fn verify_evm<'a>(
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn create_evm_verifier_aggr(
|
||||
py: Python,
|
||||
py: Python<'_>,
|
||||
aggregation_settings: Vec<PathBuf>,
|
||||
vk_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
|
||||
@@ -7,17 +7,14 @@ use halo2_proofs::{
|
||||
};
|
||||
use log::debug;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::{
|
||||
conversion::{FromPyObject, IntoPy},
|
||||
exceptions::PyValueError,
|
||||
prelude::*,
|
||||
};
|
||||
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*, IntoPyObject};
|
||||
use serde::{Deserialize, Serialize};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
use crate::{
|
||||
circuit::{
|
||||
chip::einsum::analysis::EinsumAnalysis,
|
||||
ops::base::BaseOp,
|
||||
table::{Range, RangeCheck, Table},
|
||||
},
|
||||
@@ -28,6 +25,9 @@ use std::{collections::BTreeMap, marker::PhantomData};
|
||||
use super::{lookup::LookupOp, region::RegionCtx, CircuitError, Op};
|
||||
use halo2curves::ff::{Field, PrimeField};
|
||||
|
||||
///
|
||||
pub mod einsum;
|
||||
|
||||
#[allow(missing_docs)]
|
||||
/// An enum representing activating the sanity checks we can perform on the accumulated arguments
|
||||
#[derive(
|
||||
@@ -86,12 +86,17 @@ impl CheckMode {
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts CheckMode into a PyObject (Required for CheckMode to be compatible with Python)
|
||||
impl IntoPy<PyObject> for CheckMode {
|
||||
fn into_py(self, py: Python) -> PyObject {
|
||||
match self {
|
||||
CheckMode::SAFE => "safe".to_object(py),
|
||||
CheckMode::UNSAFE => "unsafe".to_object(py),
|
||||
}
|
||||
impl<'py> IntoPyObject<'py> for CheckMode {
|
||||
type Target = pyo3::PyAny;
|
||||
type Output = pyo3::Bound<'py, Self::Target>;
|
||||
type Error = pyo3::PyErr;
|
||||
|
||||
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
|
||||
let result = match self {
|
||||
CheckMode::SAFE => "safe",
|
||||
CheckMode::UNSAFE => "unsafe",
|
||||
};
|
||||
Ok(result.into_pyobject(py)?.into_any())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -265,6 +270,8 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub range_checks: RangeChecks<F>,
|
||||
/// [Selector]s for the shuffles
|
||||
pub shuffles: Shuffles,
|
||||
/// Einsum-specific configuration
|
||||
pub einsums: Option<einsum::Einsums<F>>,
|
||||
/// Activate sanity checks
|
||||
pub check_mode: CheckMode,
|
||||
_marker: PhantomData<F>,
|
||||
@@ -279,6 +286,22 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
custom_gates: CustomGates::dummy(col_size, num_inner_cols),
|
||||
static_lookups: StaticLookups::dummy(col_size, num_inner_cols),
|
||||
dynamic_lookups: DynamicLookups::dummy(col_size, num_inner_cols),
|
||||
einsums: Some(einsum::Einsums::<F>::dummy(col_size, num_inner_cols)),
|
||||
shuffles: Shuffles::dummy(col_size, num_inner_cols),
|
||||
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
|
||||
check_mode: CheckMode::SAFE,
|
||||
shared_table_inputs: vec![],
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new [BaseConfig] with no inputs, no selectors, no tables, and no Freivalds' argument.
|
||||
pub fn dummy_without_freivalds(col_size: usize, num_inner_cols: usize) -> Self {
|
||||
Self {
|
||||
custom_gates: CustomGates::dummy(col_size, num_inner_cols),
|
||||
static_lookups: StaticLookups::dummy(col_size, num_inner_cols),
|
||||
dynamic_lookups: DynamicLookups::dummy(col_size, num_inner_cols),
|
||||
einsums: None,
|
||||
shuffles: Shuffles::dummy(col_size, num_inner_cols),
|
||||
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
|
||||
check_mode: CheckMode::SAFE,
|
||||
@@ -413,6 +436,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
},
|
||||
static_lookups: StaticLookups::default(),
|
||||
dynamic_lookups: DynamicLookups::default(),
|
||||
einsums: None,
|
||||
shuffles: Shuffles::default(),
|
||||
range_checks: RangeChecks::default(),
|
||||
shared_table_inputs: vec![],
|
||||
@@ -687,6 +711,27 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Configures and creates einsums
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn configure_einsums(
|
||||
&mut self,
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
analysis: &EinsumAnalysis,
|
||||
num_inner_cols: usize,
|
||||
logrows: usize,
|
||||
) -> Result<(), CircuitError>
|
||||
where
|
||||
F: Field,
|
||||
{
|
||||
self.einsums = Some(einsum::Einsums::configure_universal(
|
||||
cs,
|
||||
analysis,
|
||||
num_inner_cols,
|
||||
logrows,
|
||||
));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Configures and creates lookup selectors
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn configure_shuffles(
|
||||
|
||||
210
src/circuit/ops/chip/einsum/analysis.rs
Normal file
210
src/circuit/ops/chip/einsum/analysis.rs
Normal file
@@ -0,0 +1,210 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::circuit::{
|
||||
einsum::reduction_planner::{self, Reduction},
|
||||
CircuitError,
|
||||
};
|
||||
|
||||
///
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EinsumAnalysis {
|
||||
/// max size of input tensors
|
||||
pub max_input_size: usize,
|
||||
/// max size of output tensors
|
||||
pub max_output_size: usize,
|
||||
/// max number of input tensors
|
||||
pub max_num_inputs: usize,
|
||||
/// max number of output axes
|
||||
pub max_num_output_axes: usize,
|
||||
/// the sum of the lengths of dot product to compute all the reductions
|
||||
pub reduction_length: usize,
|
||||
}
|
||||
|
||||
/// The strategy to use for einsum
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum EinsumStrategy {
|
||||
/// Use only base ops
|
||||
BaseOps,
|
||||
/// Use Freivalds' argument
|
||||
Freivalds,
|
||||
}
|
||||
|
||||
///
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SingleEquationAnalysis {
|
||||
///
|
||||
pub equation: String,
|
||||
///
|
||||
pub num_inputs: usize,
|
||||
///
|
||||
pub max_input_size: usize,
|
||||
///
|
||||
pub output_size: usize,
|
||||
///
|
||||
pub num_output_axes: usize,
|
||||
///
|
||||
pub output_indices: Vec<char>,
|
||||
/// the length of dot product to compute all the reductions
|
||||
pub reduction_length: usize,
|
||||
/// the strategy to use for einsum
|
||||
pub strategy: EinsumStrategy,
|
||||
}
|
||||
|
||||
///
|
||||
pub fn analyze_einsum_usage(
|
||||
equations: &HashMap<(usize, String), HashMap<char, usize>>,
|
||||
) -> Result<EinsumAnalysis, CircuitError> {
|
||||
let mut max_num_inputs = 0;
|
||||
let mut max_input_size = 0;
|
||||
let mut max_output_size = 0;
|
||||
let mut max_num_output_axes = 0;
|
||||
let mut reduction_length = 0;
|
||||
|
||||
for ((_, equation), input_axes_to_dim) in equations.iter() {
|
||||
let analysis = analyze_single_equation(equation, input_axes_to_dim)?;
|
||||
max_input_size = max_input_size.max(analysis.max_input_size);
|
||||
max_output_size = max_output_size.max(analysis.output_size);
|
||||
max_num_inputs = max_num_inputs.max(analysis.num_inputs);
|
||||
max_num_output_axes = max_num_output_axes.max(analysis.num_output_axes);
|
||||
reduction_length += analysis.reduction_length;
|
||||
}
|
||||
|
||||
Ok(EinsumAnalysis {
|
||||
max_input_size,
|
||||
max_output_size,
|
||||
max_num_inputs,
|
||||
max_num_output_axes,
|
||||
reduction_length,
|
||||
})
|
||||
}
|
||||
|
||||
///
|
||||
pub fn analyze_single_equation(
|
||||
equation: &str,
|
||||
input_axes_to_dim: &HashMap<char, usize>,
|
||||
) -> Result<SingleEquationAnalysis, CircuitError> {
|
||||
// Sanitise equation to remove trivial axes
|
||||
let equation = {
|
||||
let (inputs_str, output_str) = equation.split_once("->").unwrap();
|
||||
let input_equations: Vec<&str> = inputs_str.split(',').collect();
|
||||
|
||||
let inputs: Vec<String> = input_equations
|
||||
.iter()
|
||||
.map(|input| {
|
||||
input
|
||||
.chars()
|
||||
.filter(|char| *input_axes_to_dim.get(char).unwrap() > 1)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let output = output_str
|
||||
.chars()
|
||||
.filter(|c| {
|
||||
input_axes_to_dim.get(c).is_some() && *input_axes_to_dim.get(c).unwrap() > 1
|
||||
})
|
||||
.collect();
|
||||
|
||||
[inputs.join(","), output].join("->")
|
||||
};
|
||||
|
||||
let (inputs_eq, output_eq) = equation.split_once("->").unwrap();
|
||||
let input_equations: Vec<&str> = inputs_eq.split(',').collect();
|
||||
|
||||
let max_input_size = input_equations
|
||||
.iter()
|
||||
.map(|eqn| {
|
||||
eqn.chars()
|
||||
.map(|c| input_axes_to_dim.get(&c).unwrap())
|
||||
.product()
|
||||
})
|
||||
.max()
|
||||
.unwrap();
|
||||
|
||||
let output_indices: Vec<char> = output_eq.chars().collect();
|
||||
let output_dims = output_indices
|
||||
.iter()
|
||||
.map(|c| input_axes_to_dim.get(&c).unwrap());
|
||||
let output_size = output_dims.clone().product();
|
||||
|
||||
let output_reduction_length = {
|
||||
let mut output_dims = output_dims.rev().cloned().collect_vec();
|
||||
let mut total_length = 0;
|
||||
for _ in 0..output_dims.len() {
|
||||
let dot_product_len = output_dims.remove(0);
|
||||
let num_dot_products: usize = output_dims.iter().product();
|
||||
total_length += dot_product_len * num_dot_products;
|
||||
}
|
||||
total_length
|
||||
};
|
||||
|
||||
let input_reductions_length = {
|
||||
let input_reductions = reduction_planner::input_reductions(&equation)?;
|
||||
input_reductions
|
||||
.into_iter()
|
||||
.map(|reduction| {
|
||||
let (_, output_expr) = reduction.expression().split_once("->").unwrap();
|
||||
let num_inputs = reduction.input_indices().len();
|
||||
let dot_product_len = match reduction {
|
||||
Reduction::RLC { axis, .. } => *input_axes_to_dim.get(&axis).unwrap(),
|
||||
Reduction::Contraction { axis, .. } => *axis
|
||||
.and_then(|axis| input_axes_to_dim.get(&axis))
|
||||
.unwrap_or(&1),
|
||||
};
|
||||
let num_dot_products: usize = output_expr
|
||||
.chars()
|
||||
.map(|c| input_axes_to_dim.get(&c).unwrap())
|
||||
.product();
|
||||
// since `multi_dot` does pairwise mult between input pairs and final summation
|
||||
if num_inputs <= 2 {
|
||||
num_dot_products * dot_product_len
|
||||
} else {
|
||||
num_dot_products * (dot_product_len * num_inputs)
|
||||
}
|
||||
})
|
||||
.sum::<usize>()
|
||||
};
|
||||
|
||||
let dispatch_to_einsum_with_base_ops = {
|
||||
let mut seen = HashSet::new();
|
||||
let mut common_indices_to_inputs = vec![];
|
||||
for input in input_equations.iter() {
|
||||
for c in input.chars() {
|
||||
if !seen.contains(&c) {
|
||||
seen.insert(c);
|
||||
} else {
|
||||
common_indices_to_inputs.push(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
let non_common_indices = input_axes_to_dim
|
||||
.keys()
|
||||
.filter(|&x| {
|
||||
!common_indices_to_inputs.contains(x)
|
||||
&& input_axes_to_dim.get(x).cloned().unwrap() > 1
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
!(output_indices.len() > 0
|
||||
&& common_indices_to_inputs.len() > 0
|
||||
&& non_common_indices.len() > 1)
|
||||
};
|
||||
|
||||
let strategy = if dispatch_to_einsum_with_base_ops {
|
||||
EinsumStrategy::BaseOps
|
||||
} else {
|
||||
EinsumStrategy::Freivalds
|
||||
};
|
||||
|
||||
Ok(SingleEquationAnalysis {
|
||||
output_size,
|
||||
max_input_size,
|
||||
equation: equation.to_string(),
|
||||
num_inputs: input_equations.len(),
|
||||
num_output_axes: output_indices.len(),
|
||||
output_indices,
|
||||
reduction_length: output_reduction_length + input_reductions_length,
|
||||
strategy,
|
||||
})
|
||||
}
|
||||
54
src/circuit/ops/chip/einsum/circuit_params.rs
Normal file
54
src/circuit/ops/chip/einsum/circuit_params.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
use std::{collections::HashMap, marker::PhantomData};
|
||||
|
||||
use halo2_proofs::circuit::Value;
|
||||
use halo2curves::ff::PrimeField;
|
||||
|
||||
use crate::{
|
||||
circuit::CircuitError,
|
||||
tensor::{Tensor, TensorError, TensorType},
|
||||
};
|
||||
|
||||
/// Circuit parameter for a single einsum equation
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct SingleEinsumParams<F: PrimeField + TensorType + PartialOrd> {
|
||||
///
|
||||
pub equation: String,
|
||||
/// Map from input axes to dimensions
|
||||
pub input_axes_to_dims: HashMap<char, usize>,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> SingleEinsumParams<F> {
|
||||
///
|
||||
pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
|
||||
let mut eq = equation.split("->");
|
||||
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
|
||||
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
|
||||
|
||||
// Check that the number of inputs matches the number of inputs in the equation
|
||||
if inputs.len() != inputs_eq.len() {
|
||||
return Err(TensorError::DimMismatch("einsum".to_string()).into());
|
||||
}
|
||||
|
||||
let mut input_axes_to_dims = HashMap::new();
|
||||
for (i, input) in inputs.iter().enumerate() {
|
||||
for j in 0..inputs_eq[i].len() {
|
||||
let c = inputs_eq[i]
|
||||
.chars()
|
||||
.nth(j)
|
||||
.ok_or(CircuitError::InvalidEinsum)?;
|
||||
if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
|
||||
e.insert(input.dims()[j]);
|
||||
} else if input_axes_to_dims[&c] != input.dims()[j] {
|
||||
return Err(TensorError::DimMismatch("einsum".to_string()).into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
equation: equation.to_owned(),
|
||||
input_axes_to_dims,
|
||||
_marker: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
359
src/circuit/ops/chip/einsum/layouts.rs
Normal file
359
src/circuit/ops/chip/einsum/layouts.rs
Normal file
@@ -0,0 +1,359 @@
|
||||
use halo2curves::ff::PrimeField;
|
||||
use log::{error, trace};
|
||||
|
||||
use crate::{
|
||||
circuit::{base::BaseOp, einsum::BaseOpInfo, region::RegionCtx, CheckMode, CircuitError},
|
||||
tensor::{
|
||||
get_broadcasted_shape,
|
||||
ops::{accumulated, add, mult, sub},
|
||||
TensorError, TensorType, ValTensor, ValType,
|
||||
},
|
||||
};
|
||||
|
||||
use super::ContractionConfig;
|
||||
|
||||
/// Pairwise (elementwise) op layout
|
||||
pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &ContractionConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[&ValTensor<F>; 2],
|
||||
op: BaseOp,
|
||||
phases: &[usize; 2],
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let (mut lhs, mut rhs) = if phases[0] <= phases[1] {
|
||||
(values[0].clone(), values[1].clone())
|
||||
} else {
|
||||
(values[1].clone(), values[0].clone())
|
||||
};
|
||||
|
||||
let broadcasted_shape = get_broadcasted_shape(lhs.dims(), rhs.dims())?;
|
||||
|
||||
lhs.expand(&broadcasted_shape)?;
|
||||
rhs.expand(&broadcasted_shape)?;
|
||||
|
||||
if lhs.len() != rhs.len() {
|
||||
return Err(CircuitError::DimMismatch(format!(
|
||||
"pairwise {} layout",
|
||||
op.as_str()
|
||||
)));
|
||||
}
|
||||
|
||||
region.flush_einsum()?;
|
||||
|
||||
let input_vars = config.get_input_vars(phases.as_slice().into());
|
||||
let output_var = config.get_output_var(phases.as_slice().into());
|
||||
|
||||
let inputs = [lhs, rhs]
|
||||
.iter()
|
||||
.zip(input_vars)
|
||||
.map(|(val, var)| {
|
||||
let res = region.assign_einsum(var, val)?;
|
||||
Ok(res.get_inner()?)
|
||||
})
|
||||
.collect::<Result<Vec<_>, CircuitError>>()?;
|
||||
|
||||
// Now we can assign the dot product
|
||||
// time the calc
|
||||
let op_result = match op {
|
||||
BaseOp::Add => add(&inputs),
|
||||
BaseOp::Sub => sub(&inputs),
|
||||
BaseOp::Mult => mult(&inputs),
|
||||
_ => return Err(CircuitError::UnsupportedOp),
|
||||
}
|
||||
.map_err(|e| {
|
||||
error!("{}", e);
|
||||
halo2_proofs::plonk::Error::Synthesis
|
||||
})?;
|
||||
|
||||
let assigned_len = op_result.len();
|
||||
let mut output = region.assign_einsum(output_var, &op_result.into())?;
|
||||
|
||||
// Enable the selectors
|
||||
if !region.is_dummy() {
|
||||
(0..assigned_len)
|
||||
.map(|i| {
|
||||
let (x, y, z) = output_var.cartesian_coord(region.einsum_col_coord() + i);
|
||||
let op_info = BaseOpInfo {
|
||||
op_kind: op.clone(),
|
||||
input_phases: phases.as_slice().into(),
|
||||
};
|
||||
let selector = config.selectors.get(&(op_info, x, y));
|
||||
|
||||
region.enable(selector, z)?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, CircuitError>>()?;
|
||||
}
|
||||
region.increment_einsum_col_coord(assigned_len);
|
||||
|
||||
output.reshape(&broadcasted_shape)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &ContractionConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[&ValTensor<F>; 1],
|
||||
phase: usize,
|
||||
check_mode: &CheckMode,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
if values[0].len() == 1 {
|
||||
return Ok(values[0].clone());
|
||||
}
|
||||
assert!(phase == 0 || phase == 1);
|
||||
|
||||
region.flush_einsum()?;
|
||||
let mut input = values[0].clone();
|
||||
|
||||
let block_width = config.block_width();
|
||||
|
||||
let assigned_len: usize;
|
||||
let input = {
|
||||
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
|
||||
let var = config.get_input_vars([phase].as_slice().into())[0];
|
||||
let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &input)?;
|
||||
assigned_len = len;
|
||||
res.get_inner()?
|
||||
};
|
||||
|
||||
// Now we can assign the dot product
|
||||
let accumulated_sum = accumulated::sum(&input, block_width)?;
|
||||
|
||||
let output_var = config.get_output_var([phase].as_slice().into());
|
||||
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
|
||||
output_var,
|
||||
&accumulated_sum.into(),
|
||||
check_mode,
|
||||
)?;
|
||||
|
||||
// enable the selectors
|
||||
if !region.is_dummy() {
|
||||
for i in 0..output_assigned_len {
|
||||
let (x, _, z) = output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
|
||||
// skip over duplicates at start of column
|
||||
if z == 0 && i > 0 {
|
||||
continue;
|
||||
}
|
||||
let selector = if i == 0 {
|
||||
let op_info = BaseOpInfo {
|
||||
op_kind: BaseOp::SumInit,
|
||||
input_phases: [phase].as_slice().into(),
|
||||
};
|
||||
config.selectors.get(&(op_info, x, 0))
|
||||
} else {
|
||||
let op_info = BaseOpInfo {
|
||||
op_kind: BaseOp::Sum,
|
||||
input_phases: [phase].as_slice().into(),
|
||||
};
|
||||
config.selectors.get(&(op_info, x, 0))
|
||||
};
|
||||
|
||||
region.enable(selector, z)?;
|
||||
}
|
||||
}
|
||||
|
||||
let last_elem = output.last()?;
|
||||
|
||||
region.increment_einsum_col_coord(assigned_len);
|
||||
|
||||
// last element is the result
|
||||
Ok(last_elem)
|
||||
}
|
||||
|
||||
pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &ContractionConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[&ValTensor<F>; 1],
|
||||
phase: usize,
|
||||
check_mode: &CheckMode,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
assert!(phase == 0 || phase == 1);
|
||||
region.flush_einsum()?;
|
||||
let block_width = config.block_width();
|
||||
let assigned_len: usize;
|
||||
let input = {
|
||||
let mut input = values[0].clone();
|
||||
input.pad_to_zero_rem(block_width, ValType::Constant(F::ONE))?;
|
||||
let var = config.get_input_vars([phase].as_slice().into())[0];
|
||||
let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &input)?;
|
||||
assigned_len = len;
|
||||
res.get_inner()?
|
||||
};
|
||||
|
||||
// Now we can assign the dot product
|
||||
let accumulated_prod = accumulated::prod(&input, block_width)?;
|
||||
|
||||
let output_var = config.get_output_var([phase].as_slice().into());
|
||||
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
|
||||
output_var,
|
||||
&accumulated_prod.into(),
|
||||
check_mode,
|
||||
)?;
|
||||
|
||||
// enable the selectors
|
||||
if !region.is_dummy() {
|
||||
(0..output_assigned_len)
|
||||
.map(|i| {
|
||||
let (x, _, z) =
|
||||
output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
|
||||
// skip over duplicates at start of column
|
||||
if z == 0 && i > 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let selector = if i == 0 {
|
||||
let op_info = BaseOpInfo {
|
||||
op_kind: BaseOp::CumProdInit,
|
||||
input_phases: [phase].as_slice().into(),
|
||||
};
|
||||
config.selectors.get(&(op_info, x, 0))
|
||||
} else {
|
||||
let op_info = BaseOpInfo {
|
||||
op_kind: BaseOp::CumProd,
|
||||
input_phases: [phase].as_slice().into(),
|
||||
};
|
||||
config.selectors.get(&(op_info, x, 0))
|
||||
};
|
||||
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, CircuitError>>()?;
|
||||
}
|
||||
|
||||
let last_elem = output.last()?;
|
||||
|
||||
region.increment_einsum_col_coord(assigned_len);
|
||||
|
||||
// last element is the result
|
||||
Ok(last_elem)
|
||||
}
|
||||
|
||||
pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &ContractionConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[&ValTensor<F>; 2],
|
||||
phases: &[usize; 2],
|
||||
check_mode: &CheckMode,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
if values[0].len() != values[1].len() {
|
||||
return Err(TensorError::DimMismatch("dot".to_string()).into());
|
||||
}
|
||||
|
||||
region.flush_einsum()?;
|
||||
// time this entire function run
|
||||
let global_start = instant::Instant::now();
|
||||
|
||||
let mut values = if phases[0] <= phases[1] {
|
||||
[values[0].clone(), values[1].clone()]
|
||||
} else {
|
||||
[values[1].clone(), values[0].clone()]
|
||||
};
|
||||
let vars = config.get_input_vars(phases.as_slice().into());
|
||||
|
||||
let mut inputs = vec![];
|
||||
let block_width = config.block_width();
|
||||
|
||||
let mut assigned_len = 0;
|
||||
for (val, var) in values.iter_mut().zip(vars) {
|
||||
val.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
|
||||
let inp = {
|
||||
let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &val)?;
|
||||
assigned_len = len;
|
||||
res.get_inner()?
|
||||
};
|
||||
inputs.push(inp);
|
||||
}
|
||||
|
||||
// Now we can assign the dot product
|
||||
// time this step
|
||||
let accumulated_dot = accumulated::dot(&inputs[0], &inputs[1], block_width)?;
|
||||
let output_var = config.get_output_var(phases.as_slice().into());
|
||||
let (output, output_assigned_len) = region
|
||||
.assign_einsum_with_duplication_constrained(output_var, &accumulated_dot.into(), check_mode)
|
||||
.expect("failed to assign einsum with duplication constrained");
|
||||
|
||||
// enable the selectors
|
||||
if !region.is_dummy() {
|
||||
(0..output_assigned_len)
|
||||
.map(|i| {
|
||||
let (x, _, z) =
|
||||
output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
|
||||
// hop over duplicates at start of column
|
||||
if z == 0 && i > 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let selector = if i == 0 {
|
||||
let op_info = BaseOpInfo {
|
||||
op_kind: BaseOp::DotInit,
|
||||
input_phases: phases.as_slice().into(),
|
||||
};
|
||||
config.selectors.get(&(op_info, x, 0))
|
||||
} else {
|
||||
let op_info = BaseOpInfo {
|
||||
op_kind: BaseOp::Dot,
|
||||
input_phases: phases.as_slice().into(),
|
||||
};
|
||||
config.selectors.get(&(op_info, x, 0))
|
||||
};
|
||||
region.enable(selector, z)?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, CircuitError>>()?;
|
||||
}
|
||||
|
||||
let last_elem = output.last()?;
|
||||
|
||||
region.increment_einsum_col_coord(assigned_len);
|
||||
|
||||
let elapsed = global_start.elapsed();
|
||||
trace!("dot layout took: {:?}, row {}", elapsed, region.row());
|
||||
trace!("----------------------------");
|
||||
Ok(last_elem)
|
||||
}
|
||||
|
||||
/// Dot product of more than two tensors
|
||||
pub fn multi_dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &ContractionConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[&ValTensor<F>],
|
||||
phases: &[usize],
|
||||
check_mode: &CheckMode,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
assert!(phases.iter().all(|phase| *phase == 0 || *phase == 1));
|
||||
if !values.iter().all(|value| value.len() == values[0].len()) {
|
||||
return Err(TensorError::DimMismatch("dot".to_string()).into());
|
||||
}
|
||||
// time this entire function run
|
||||
let global_start = instant::Instant::now();
|
||||
|
||||
let values: Vec<ValTensor<F>> = values.iter().copied().cloned().collect();
|
||||
// do pairwise dot product between intermediate tensor and the next tensor
|
||||
let (intermediate, output_phase) = values
|
||||
.into_iter()
|
||||
.zip(phases.iter().cloned())
|
||||
.reduce(|(intermediate, intermediate_phase), (input, phase)| {
|
||||
(
|
||||
pairwise(
|
||||
config,
|
||||
region,
|
||||
&[&intermediate, &input],
|
||||
BaseOp::Mult,
|
||||
&[intermediate_phase, phase],
|
||||
)
|
||||
.unwrap(),
|
||||
std::cmp::max(intermediate_phase, phase),
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let accumulated_dot = sum(config, region, &[&intermediate], output_phase, check_mode)?;
|
||||
let last_elem = accumulated_dot.last()?;
|
||||
|
||||
let elapsed = global_start.elapsed();
|
||||
trace!("multi_dot layout took: {:?}, row {}", elapsed, region.row());
|
||||
trace!("----------------------------");
|
||||
Ok(last_elem)
|
||||
}
|
||||
867
src/circuit/ops/chip/einsum/mod.rs
Normal file
867
src/circuit/ops/chip/einsum/mod.rs
Normal file
@@ -0,0 +1,867 @@
|
||||
use crate::circuit::base::BaseOp;
|
||||
use crate::circuit::chip::einsum::analysis::{analyze_single_equation, EinsumAnalysis};
|
||||
use crate::circuit::einsum::layouts::{pairwise, sum};
|
||||
use crate::circuit::einsum::reduction_planner::Reduction;
|
||||
use crate::circuit::layouts::einsum_with_base_ops;
|
||||
use crate::circuit::region::RegionCtx;
|
||||
use crate::circuit::{BaseConfig, CheckMode, CircuitError};
|
||||
use crate::tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor};
|
||||
use halo2_proofs::circuit::Value;
|
||||
use halo2_proofs::plonk::{
|
||||
Challenge, ConstraintSystem, Constraints, Expression, FirstPhase, Selector,
|
||||
};
|
||||
use halo2curves::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
use layouts::{dot, multi_dot, prod};
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
///
|
||||
pub mod analysis;
|
||||
///
|
||||
pub mod circuit_params;
|
||||
mod layouts;
|
||||
mod reduction_planner;
|
||||
|
||||
/// The maximum number of challenges
|
||||
pub const NUM_MAX_EINSUM_CHALLENGES: usize = 10;
|
||||
|
||||
/// A struct representing reductions for the einsums
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct Einsums<F: PrimeField + TensorType + PartialOrd> {
|
||||
/// custom gate to constrain tensor contractions
|
||||
contraction_gate: ContractionConfig<F>,
|
||||
/// custom gate to constrain random linear combinations used by Freivalds' argument
|
||||
rlc_gates: Vec<RLCConfig<F>>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
|
||||
///
|
||||
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
|
||||
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
|
||||
let dummy_contraction_gate = ContractionConfig {
|
||||
inputs: [
|
||||
[dummy_var.clone(), dummy_var.clone()],
|
||||
[dummy_var.clone(), dummy_var.clone()],
|
||||
],
|
||||
outputs: [dummy_var.clone(), dummy_var.clone()],
|
||||
selectors: BTreeMap::default(),
|
||||
_marker: PhantomData,
|
||||
};
|
||||
Self {
|
||||
contraction_gate: dummy_contraction_gate,
|
||||
rlc_gates: (0..NUM_MAX_EINSUM_CHALLENGES)
|
||||
.map(|_| RLCConfig::dummy(&dummy_var))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Configure the columns based on universal Einsum analysis
///
/// Allocates four input columns (two first-phase, two second-phase) and two
/// output columns (one per phase), each sized by the analysis' reduction
/// length, then wires them into one contraction gate plus one RLC gate per
/// output axis the analysis may need.
pub fn configure_universal(
    meta: &mut ConstraintSystem<F>,
    analysis: &EinsumAnalysis,
    num_inner_cols: usize,
    logrows: usize,
) -> Self {
    // Column capacity is driven by the longest reduction across the model.
    let capacity = analysis.reduction_length;
    // [first phase, first phase, second phase, second phase]
    let inputs: [VarTensor; 4] = [
        VarTensor::new_advice(meta, logrows, num_inner_cols, capacity),
        VarTensor::new_advice(meta, logrows, num_inner_cols, capacity),
        VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
        VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
    ];
    // [first phase, second phase]
    let outputs = [
        VarTensor::new_advice(meta, logrows, num_inner_cols, capacity),
        VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
    ];
    let contraction_gate = ContractionConfig::new(
        meta,
        &[&[&inputs[0], &inputs[1]], &[&inputs[2], &inputs[3]]],
        &[&outputs[0], &outputs[1]],
    );

    // One RLC gate (and hence one challenge) per potential output axis.
    // Each gate reads a first-phase and a second-phase input column and
    // writes to the second-phase output column.
    let mut rlc_gates = vec![];
    for _ in 0..analysis.max_num_output_axes {
        let rlc_gate =
            RLCConfig::new(meta, &[inputs[0].clone(), inputs[2].clone()], &outputs[1]);
        rlc_gates.push(rlc_gate);
    }

    Self {
        contraction_gate,
        rlc_gates,
    }
}
|
||||
|
||||
/// In dummy layout phase, calling this function will return error
|
||||
pub fn challenges(&self) -> Result<Vec<Challenge>, CircuitError> {
|
||||
self.rlc_gates
|
||||
.iter()
|
||||
.map(|gate| gate.challenge.ok_or(CircuitError::ChallengeNotSet))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
|
||||
/// Lay out and constrain a single einsum via Freivalds-style checking.
///
/// The claimed `output_tensor` is compressed to one scalar with nested RLCs
/// (`assign_output`); the `input_tensors` are reduced step by step according
/// to the plan from `reduction_planner::input_reductions`; finally the
/// product of the surviving input-side scalars is constrained equal to the
/// compressed output. Equations classified as `EinsumStrategy::BaseOps` are
/// delegated to `einsum_with_base_ops` instead.
///
/// # Arguments
/// * `base_config` - config used for the `BaseOps` fallback strategy
/// * `region` - region context tracking einsum column coordinates
/// * `input_tensors` - einsum operands, one per comma-separated input term
/// * `output_tensor` - the claimed einsum result
/// * `equation` - einsum equation string, e.g. `"ij,jk->ik"`
/// * `check_mode` - whether assigned values are sanity-checked
pub fn assign_einsum(
    &self,
    base_config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    input_tensors: &[&ValTensor<F>],
    output_tensor: &ValTensor<F>,
    equation: &str,
    check_mode: &CheckMode,
) -> Result<(), CircuitError> {
    region.set_num_einsum_inner_cols(self.contraction_gate.block_width());

    // Sanity: one input expression per input tensor.
    let (input_exprs, _) = equation.split_once("->").unwrap();
    let input_exprs = input_exprs.split(",").collect_vec();
    assert_eq!(input_exprs.len(), input_tensors.len());

    let mut input_tensors = input_tensors.iter().copied().cloned().collect_vec();
    let mut output_tensor = output_tensor.clone();

    // Map each axis label to its dimension, taken from the first tensor
    // that mentions it.
    let mut input_axes_to_dim: HashMap<char, usize> = HashMap::new();
    input_exprs
        .iter()
        .zip(input_tensors.iter())
        .for_each(|(indices, tensor)| {
            indices.chars().zip(tensor.dims()).for_each(|(index, dim)| {
                if let std::collections::hash_map::Entry::Vacant(e) =
                    input_axes_to_dim.entry(index)
                {
                    e.insert(*dim);
                }
            });
        });

    // Normalize the equation and decide the layout strategy.
    let equation_analysis = analyze_single_equation(&equation, &input_axes_to_dim)?;
    let equation = equation_analysis.equation;

    // Remove trivial axes from tensors
    input_tensors
        .iter_mut()
        .map(|tensor| tensor.remove_trivial_axes())
        .collect::<Result<Vec<_>, TensorError>>()?;
    output_tensor.remove_trivial_axes()?;

    // Small/simple equations are cheaper with plain base ops.
    if matches!(
        equation_analysis.strategy,
        analysis::EinsumStrategy::BaseOps
    ) {
        let _ = einsum_with_base_ops(
            base_config,
            region,
            &input_tensors.iter().collect_vec(),
            &equation,
        )?;
        return Ok(());
    }

    // Compress the claimed output to a single (second-phase) scalar.
    let output_shape = equation_analysis
        .output_indices
        .iter()
        .map(|c| input_axes_to_dim.get(c).copied().unwrap())
        .collect_vec();
    let squashed_output =
        self.assign_output(region, &output_tensor, output_shape, check_mode)?;

    // reorder the reduction of input tensors and reduce
    let reordered_input_reductions = reduction_planner::input_reductions(&equation).unwrap();
    let mut tensors = input_tensors;
    let mut reduced_input_phase = 0;

    for reduction in reordered_input_reductions.iter() {
        let (input_expr, output_expr) = reduction.expression().split_once("->").unwrap();
        let input_exprs = input_expr.split(",").collect_vec();

        // Iterate over every coordinate of the axes that survive this step.
        let remaining_axes = output_expr.chars().collect_vec();
        let mut remaining_axes_indices = remaining_axes
            .iter()
            .map(|c| 0..input_axes_to_dim[c])
            .multi_cartesian_product()
            .collect_vec();

        // Dummy value to ensure the for loop runs at least once
        if remaining_axes.is_empty() {
            remaining_axes_indices.push(vec![]);
        }

        // Gather the (possibly intermediate) tensors this step consumes.
        let input_tensors = reduction
            .input_indices()
            .iter()
            .map(|idx| tensors[*idx].clone())
            .collect_vec();

        // For each input tensor, collect one flat slice per surviving
        // coordinate; concatenated they form the gate input layout.
        let mut flattened_input_tensors: Vec<Vec<ValTensor<F>>> =
            vec![vec![]; input_tensors.len()];
        for remaining_axes_indices in remaining_axes_indices {
            // corresponds to 1 running sum of input tensors
            for (i, (input_tensor, input_expr)) in
                input_tensors.iter().zip(input_exprs.iter()).enumerate()
            {
                let mut sliced_dim = vec![];
                input_expr.chars().for_each(|axis| {
                    if let Some(pos) = remaining_axes.iter().position(|c| *c == axis) {
                        sliced_dim
                            .push(remaining_axes_indices[pos]..remaining_axes_indices[pos] + 1);
                    } else {
                        // common axis
                        sliced_dim.push(0..input_axes_to_dim[&axis]);
                    }
                });
                let mut sliced_input_tensor = input_tensor.get_slice(&sliced_dim)?;
                sliced_input_tensor.flatten();
                flattened_input_tensors[i].push(sliced_input_tensor);
            }
        }
        let flattened_input_tensors = flattened_input_tensors
            .into_iter()
            .map(|tensors| {
                ValTensor::from(
                    tensors
                        .into_iter()
                        .flat_map(|t| t.get_inner_tensor().unwrap().clone().into_iter())
                        .collect_vec(),
                )
            })
            .collect_vec();

        let output_dims = output_expr
            .chars()
            .map(|c| input_axes_to_dim[&c])
            .collect_vec();

        // Dispatch the step to the matching gate.
        let contracted_output = match reduction {
            Reduction::RLC {
                axis,
                input_phase,
                challenge_index,
                ..
            } => {
                // An RLC step always consumes exactly one tensor.
                assert_eq!(flattened_input_tensors.len(), 1);
                let rlc_len = input_axes_to_dim[axis];
                let mut result = self.rlc_gates[*challenge_index].assign_rlc(
                    region,
                    &flattened_input_tensors[0],
                    region.challenges()[*challenge_index],
                    rlc_len,
                    *input_phase,
                    check_mode,
                )?;
                result.reshape(&output_dims)?;
                result
            }
            Reduction::Contraction {
                axis, input_phases, ..
            } => match axis {
                Some(axis) => {
                    // Contract (dot/sum) along the given axis.
                    let dot_product_len = input_axes_to_dim[axis];
                    assign_input_contraction(
                        &self.contraction_gate,
                        region,
                        flattened_input_tensors,
                        dot_product_len,
                        &output_dims,
                        input_phases,
                        check_mode,
                    )?
                }
                None => {
                    // No axis eliminated: element-wise product only.
                    let mut result = assign_pairwise_mult(
                        &self.contraction_gate,
                        region,
                        flattened_input_tensors,
                        input_phases,
                    )?;
                    result.reshape(&output_dims)?;
                    result
                }
            },
        };
        tensors.push(contracted_output);
        reduced_input_phase = reduction.output_phase();
    }
    // Only fully-reduced (scalar) tensors contribute to the final product.
    tensors.retain(|tensor| tensor.is_singleton());

    let scalars: ValTensor<F> = tensors
        .into_iter()
        .map(|t| t.get_inner_tensor().unwrap().get_scalar())
        .collect_vec()
        .into();
    let squashed_input = prod(
        &self.contraction_gate,
        region,
        &[&scalars],
        reduced_input_phase,
        check_mode,
    )?;

    // Freivalds check: compressed inputs must equal the compressed output.
    region.constrain_equal(&squashed_input, &squashed_output)
}
|
||||
|
||||
fn assign_output(
|
||||
&self,
|
||||
region: &mut RegionCtx<F>,
|
||||
output: &ValTensor<F>,
|
||||
output_shape: Vec<usize>,
|
||||
check_mode: &CheckMode,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let mut intermediate_values = output.clone();
|
||||
|
||||
let challenges = region
|
||||
.challenges()
|
||||
.iter()
|
||||
.take(output_shape.len())
|
||||
.copied()
|
||||
.collect_vec();
|
||||
|
||||
// Loop over the output axes
|
||||
for (idx, (rlc_config, challenge)) in self
|
||||
.rlc_gates
|
||||
.iter()
|
||||
.take(output_shape.len())
|
||||
.zip(challenges.iter())
|
||||
.rev()
|
||||
.enumerate()
|
||||
{
|
||||
let rlc_len = output_shape[output_shape.len() - idx - 1];
|
||||
intermediate_values.flatten();
|
||||
let phase = if idx > 0 { 1 } else { 0 };
|
||||
intermediate_values = rlc_config.assign_rlc(
|
||||
region,
|
||||
&intermediate_values,
|
||||
*challenge,
|
||||
rlc_len,
|
||||
phase,
|
||||
check_mode,
|
||||
)?;
|
||||
}
|
||||
|
||||
let phase = if challenges.len() > 0 { 1 } else { 0 };
|
||||
let output_var = self
|
||||
.contraction_gate
|
||||
.get_output_var([phase].as_slice().into());
|
||||
let res = region.assign_einsum(output_var, &intermediate_values)?;
|
||||
region.increment_einsum_col_coord(1);
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
fn assign_pairwise_mult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &ContractionConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
flattened_tensors: Vec<ValTensor<F>>,
|
||||
input_phases: &[usize],
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
assert_eq!(flattened_tensors.len(), input_phases.len());
|
||||
let (result, _) = flattened_tensors
|
||||
.into_iter()
|
||||
.zip(input_phases.iter().cloned())
|
||||
.reduce(|(acc, acc_phase), (input, phase)| {
|
||||
(
|
||||
pairwise(
|
||||
config,
|
||||
region,
|
||||
&[&acc, &input],
|
||||
BaseOp::Mult,
|
||||
&[acc_phase, phase],
|
||||
)
|
||||
.unwrap(),
|
||||
std::cmp::max(acc_phase, phase),
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Contract the flattened input tensors along one shared axis.
///
/// Each `dot_product_len`-sized chunk of the flattened tensors is reduced
/// to a scalar: a plain sum for a single tensor, a dot product for two, and
/// a multi-tensor dot product otherwise. The resulting scalars are
/// reassembled into a tensor of `output_shape`.
fn assign_input_contraction<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &ContractionConfig<F>,
    region: &mut RegionCtx<F>,
    flattened_tensors: Vec<ValTensor<F>>,
    dot_product_len: usize,
    output_shape: &[usize],
    input_phases: &[usize],
    check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
    assert_eq!(flattened_tensors.len(), input_phases.len());
    // One dot product per output coordinate.
    let num_dot_products = output_shape.iter().product();
    let mut dot_product_results = vec![];
    for chunk_idx in 0..num_dot_products {
        let start = chunk_idx * dot_product_len;
        // Take the matching slice of every input tensor.
        let tensors: Vec<_> = flattened_tensors
            .iter()
            .map(|tensor| tensor.get_slice(&[start..(start + dot_product_len)]))
            .collect::<Result<Vec<_>, TensorError>>()?;
        let result = if tensors.len() == 1 {
            // One tensor: contraction degenerates to a sum along the axis.
            sum(config, region, &[&tensors[0]], input_phases[0], check_mode)?
        } else if tensors.len() == 2 {
            dot(
                config,
                region,
                &[&tensors[0], &tensors[1]],
                &[input_phases[0], input_phases[1]],
                check_mode,
            )?
        } else {
            multi_dot(
                config,
                region,
                tensors.iter().collect_vec().as_slice(),
                input_phases,
                check_mode,
            )?
        };
        dot_product_results.push(result.get_inner_tensor()?.get_scalar());
    }
    let mut tensor = ValTensor::from(dot_product_results);
    tensor.reshape(output_shape)?;
    Ok(tensor)
}
|
||||
|
||||
/// Classification of the proving phase(s) a gate's inputs live in.
///
/// Single-element variants describe one-input ops; the remaining variants
/// describe the phase combinations of two-input ops.
#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, Hash)]
enum InputPhases {
    /// single input in the first phase: [0]
    FirstPhase,
    /// single input in the second phase: [1]
    SecondPhase,
    BothFirstPhase, // [0, 0]
    Mixed, // [0, 1] or [1, 0]
    BothSecondPhase, // [1, 1]
}
|
||||
|
||||
impl From<&[usize]> for InputPhases {
|
||||
fn from(phases: &[usize]) -> Self {
|
||||
match phases {
|
||||
[0] => Self::FirstPhase,
|
||||
[1] => Self::SecondPhase,
|
||||
[0, 0] => Self::BothFirstPhase,
|
||||
[0, 1] | [1, 0] => Self::Mixed,
|
||||
[1, 1] => Self::BothSecondPhase,
|
||||
_ => panic!("Invalid phase combination"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A base operation paired with the phase combination of its inputs.
/// Used as part of the selector lookup key in `ContractionConfig`.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
struct BaseOpInfo {
    // the underlying accumulation / multiplication op
    pub op_kind: BaseOp,
    // which phases the op's inputs are assigned in
    pub input_phases: InputPhases,
}
|
||||
|
||||
/// `ContractionConfig` is the custom gate to constrain tensor contractions
///
/// Holds two input column pairs (one per phase), one output column per
/// phase, and a selector per (operation, phase combination, block,
/// inner column) tuple.
#[derive(Clone, Debug, Default)]
struct ContractionConfig<F: PrimeField + TensorType + PartialOrd> {
    // [[first phase, first phase], [second phase, second phase]]
    inputs: [[VarTensor; 2]; 2],
    // [first phase, second phase]
    outputs: [VarTensor; 2],
    // (BaseOpInfo, block index, inner column index) -> selector
    selectors: BTreeMap<(BaseOpInfo, usize, usize), Selector>,
    _marker: PhantomData<F>,
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
|
||||
fn get_input_vars(&self, input_phases: InputPhases) -> Vec<&VarTensor> {
|
||||
match input_phases {
|
||||
InputPhases::FirstPhase => vec![&self.inputs[0][0]],
|
||||
InputPhases::SecondPhase => vec![&self.inputs[1][0]],
|
||||
InputPhases::BothFirstPhase => vec![&self.inputs[0][0], &self.inputs[0][1]],
|
||||
InputPhases::Mixed => vec![&self.inputs[0][0], &self.inputs[1][0]],
|
||||
InputPhases::BothSecondPhase => vec![&self.inputs[1][0], &self.inputs[1][1]],
|
||||
}
|
||||
}
|
||||
|
||||
fn get_output_var(&self, input_phases: InputPhases) -> &VarTensor {
|
||||
match input_phases {
|
||||
InputPhases::FirstPhase => &self.outputs[0],
|
||||
InputPhases::SecondPhase => &self.outputs[1],
|
||||
InputPhases::BothFirstPhase => &self.outputs[0],
|
||||
InputPhases::Mixed => &self.outputs[1],
|
||||
InputPhases::BothSecondPhase => &self.outputs[1],
|
||||
}
|
||||
}
|
||||
|
||||
/// Number of inner advice columns per block (read off the first-phase
/// output column; all columns share the same width).
fn block_width(&self) -> usize {
    self.outputs[0].num_inner_cols()
}
|
||||
|
||||
fn new(
|
||||
meta: &mut ConstraintSystem<F>,
|
||||
inputs: &[&[&VarTensor; 2]; 2],
|
||||
outputs: &[&VarTensor; 2],
|
||||
) -> Self {
|
||||
let mut selectors = BTreeMap::new();
|
||||
let num_blocks = outputs[0].num_blocks();
|
||||
let block_width = outputs[0].num_inner_cols();
|
||||
for input_phases in [
|
||||
InputPhases::BothFirstPhase,
|
||||
InputPhases::Mixed,
|
||||
InputPhases::BothSecondPhase,
|
||||
] {
|
||||
for i in 0..num_blocks {
|
||||
for j in 0..block_width {
|
||||
selectors.insert(
|
||||
(
|
||||
BaseOpInfo {
|
||||
op_kind: BaseOp::Mult,
|
||||
input_phases,
|
||||
},
|
||||
i,
|
||||
j,
|
||||
),
|
||||
meta.selector(),
|
||||
);
|
||||
}
|
||||
for i in 0..num_blocks {
|
||||
selectors.insert(
|
||||
(
|
||||
BaseOpInfo {
|
||||
op_kind: BaseOp::DotInit,
|
||||
input_phases,
|
||||
},
|
||||
i,
|
||||
0,
|
||||
),
|
||||
meta.selector(),
|
||||
);
|
||||
selectors.insert(
|
||||
(
|
||||
BaseOpInfo {
|
||||
op_kind: BaseOp::Dot,
|
||||
input_phases,
|
||||
},
|
||||
i,
|
||||
0,
|
||||
),
|
||||
meta.selector(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for input_phases in [InputPhases::FirstPhase, InputPhases::SecondPhase] {
|
||||
for i in 0..num_blocks {
|
||||
selectors.insert(
|
||||
(
|
||||
BaseOpInfo {
|
||||
op_kind: BaseOp::SumInit,
|
||||
input_phases,
|
||||
},
|
||||
i,
|
||||
0,
|
||||
),
|
||||
meta.selector(),
|
||||
);
|
||||
selectors.insert(
|
||||
(
|
||||
BaseOpInfo {
|
||||
op_kind: BaseOp::Sum,
|
||||
input_phases,
|
||||
},
|
||||
i,
|
||||
0,
|
||||
),
|
||||
meta.selector(),
|
||||
);
|
||||
selectors.insert(
|
||||
(
|
||||
BaseOpInfo {
|
||||
op_kind: BaseOp::CumProdInit,
|
||||
input_phases,
|
||||
},
|
||||
i,
|
||||
0,
|
||||
),
|
||||
meta.selector(),
|
||||
);
|
||||
selectors.insert(
|
||||
(
|
||||
BaseOpInfo {
|
||||
op_kind: BaseOp::CumProd,
|
||||
input_phases,
|
||||
},
|
||||
i,
|
||||
0,
|
||||
),
|
||||
meta.selector(),
|
||||
);
|
||||
}
|
||||
}
|
||||
for ((base_op, block_idx, inner_col_idx), selector) in selectors.iter() {
|
||||
let inputs = match base_op.input_phases {
|
||||
InputPhases::FirstPhase => vec![inputs[0][0]],
|
||||
InputPhases::SecondPhase => vec![inputs[1][0]],
|
||||
InputPhases::BothFirstPhase => vec![inputs[0][0], inputs[0][1]],
|
||||
InputPhases::Mixed => vec![inputs[0][0], inputs[1][0]],
|
||||
InputPhases::BothSecondPhase => vec![inputs[1][0], inputs[1][1]],
|
||||
};
|
||||
let output = match base_op.input_phases {
|
||||
InputPhases::FirstPhase => outputs[0],
|
||||
InputPhases::SecondPhase => outputs[1],
|
||||
InputPhases::BothFirstPhase => outputs[0],
|
||||
InputPhases::Mixed => outputs[1],
|
||||
InputPhases::BothSecondPhase => outputs[1],
|
||||
};
|
||||
assert_eq!(inputs.len(), base_op.op_kind.num_inputs());
|
||||
match base_op.op_kind {
|
||||
BaseOp::Mult => {
|
||||
meta.create_gate(base_op.op_kind.as_str(), |meta| {
|
||||
let selector = meta.query_selector(*selector);
|
||||
|
||||
let zero = Expression::<F>::Constant(F::ZERO);
|
||||
let mut qis = vec![zero; 2];
|
||||
for (q_i, input) in qis.iter_mut().zip(inputs) {
|
||||
*q_i = input
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
|
||||
.expect("contraction config: input query failed")[0]
|
||||
.clone()
|
||||
}
|
||||
// Get output expressions for each input channel
|
||||
let (rotation_offset, rng) = base_op.op_kind.query_offset_rng();
|
||||
let constraints = {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
|
||||
.expect("contraction config: output query failed");
|
||||
|
||||
let res = base_op.op_kind.nonaccum_f((qis[0].clone(), qis[1].clone()));
|
||||
vec![expected_output[base_op.op_kind.constraint_idx()].clone() - res]
|
||||
};
|
||||
Constraints::with_selector(selector, constraints)
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
meta.create_gate(base_op.op_kind.as_str(), |meta| {
|
||||
let selector = meta.query_selector(*selector);
|
||||
let mut qis = vec![vec![]; 2];
|
||||
for (q_i, input) in qis.iter_mut().zip(inputs) {
|
||||
*q_i = input
|
||||
.query_whole_block(meta, *block_idx, 0, 1)
|
||||
.expect("contraction config: input query failed")
|
||||
.into_iter()
|
||||
.collect()
|
||||
}
|
||||
// Get output expressions for each input channel
|
||||
let (rotation_offset, rng) = base_op.op_kind.query_offset_rng();
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, 0, rotation_offset, rng)
|
||||
.expect("contraction config: output query failed");
|
||||
|
||||
let res = base_op.op_kind.accum_f(
|
||||
expected_output[0].clone(),
|
||||
qis[1].clone(),
|
||||
qis[0].clone(),
|
||||
);
|
||||
let constraints =
|
||||
vec![expected_output[base_op.op_kind.constraint_idx()].clone() - res];
|
||||
|
||||
Constraints::with_selector(selector, constraints)
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let first_phase_inputs: [VarTensor; 2] = inputs[0]
|
||||
.iter()
|
||||
.copied()
|
||||
.cloned()
|
||||
.collect_vec()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
let second_phase_inputs: [VarTensor; 2] = inputs[1]
|
||||
.iter()
|
||||
.copied()
|
||||
.cloned()
|
||||
.collect_vec()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
Self {
|
||||
inputs: [first_phase_inputs, second_phase_inputs],
|
||||
outputs: [outputs[0].clone(), outputs[1].clone()],
|
||||
selectors,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// `RLCConfig` is the custom gate used for random linear combination with the specific challenge
#[derive(Clone, Debug)]
struct RLCConfig<F: PrimeField + TensorType + PartialOrd> {
    /// The challenge used for the random linear combination
    /// Challenge is `None` in the dummy configuration
    pub challenge: Option<Challenge>,
    /// [first phase, second phase]
    pub inputs: [VarTensor; 2],
    /// output column holding the running sums (second phase)
    pub output: VarTensor,
    /// (phase of input, block index) -> (init selector, acc selector)
    pub selectors: BTreeMap<(usize, usize), (Selector, Selector)>,
    _marker: PhantomData<F>,
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RLCConfig<F> {
|
||||
fn dummy(dummy_var: &VarTensor) -> Self {
|
||||
let challenge = None;
|
||||
let inputs = [dummy_var.clone(), dummy_var.clone()];
|
||||
let output = dummy_var.clone();
|
||||
let selectors = BTreeMap::new();
|
||||
Self {
|
||||
challenge,
|
||||
inputs,
|
||||
output,
|
||||
selectors,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the RLC gate.
///
/// Allocates a challenge usable after the first phase, an (init, acc)
/// selector pair per (input phase, block), and registers two gates per
/// pair: `init` seeds the running sum from the first block of inputs,
/// `acc` folds each further block into the previous running sum.
fn new(meta: &mut ConstraintSystem<F>, inputs: &[VarTensor; 2], output: &VarTensor) -> Self {
    let challenge = meta.challenge_usable_after(FirstPhase);

    let mut selectors = BTreeMap::new();
    for (phase, input) in inputs.iter().enumerate() {
        for block_idx in 0..input.num_blocks() {
            // (init selector, acc selector)
            let selector = (meta.selector(), meta.selector());
            selectors.insert((phase, block_idx), selector);
        }
    }
    let block_width = output.num_inner_cols();
    // [r, r^2, ..., r^block_width] as expressions
    let powers_of_challenge = (0..block_width)
        .scan(Expression::Constant(F::ONE), |r_power, _| {
            *r_power = r_power.clone() * challenge.expr();
            Some(r_power.clone())
        })
        .collect_vec();
    for ((phase, block_idx), (init_selector, acc_selector)) in selectors.iter() {
        // init: output[0] = sum_j r^(w-j) * input[j] over the first block
        meta.create_gate("init", |meta| {
            let selector = meta.query_selector(*init_selector);
            let input_exprs = inputs[*phase]
                .query_whole_block(meta, *block_idx, 0, 1)
                .expect("rlc config: input query failed")
                .into_iter()
                .collect();
            let constraints = {
                let expected_output: Tensor<Expression<F>> = output
                    .query_rng(meta, *block_idx, 0, 0, 1)
                    .expect("rlc config: output query failed");

                // dot of the block with descending challenge powers,
                // starting from a zero accumulator
                let res = BaseOp::Dot.accum_f(
                    Expression::Constant(F::ZERO),
                    powers_of_challenge.iter().cloned().rev().collect_vec(),
                    input_exprs,
                );
                vec![expected_output[0].clone() - res]
            };
            Constraints::with_selector(selector, constraints)
        });
        // acc: output[1] = output[0] * r^block_width + dot(block, powers)
        meta.create_gate("acc", |meta| {
            let selector = meta.query_selector(*acc_selector);
            let input_exprs = inputs[*phase]
                .query_whole_block(meta, *block_idx, 0, 1)
                .expect("rlc config: input query failed")
                .into_iter()
                .collect();
            let constraints = {
                // query previous (rotation -1) and current running sums
                let expected_output: Tensor<Expression<F>> = output
                    .query_rng(meta, *block_idx, 0, -1, 2)
                    .expect("rlc config: output query failed");

                let res = BaseOp::Dot.accum_f(
                    expected_output[0].clone() * powers_of_challenge.last().cloned().unwrap(),
                    powers_of_challenge.iter().cloned().rev().collect_vec(),
                    input_exprs,
                );
                vec![expected_output[1].clone() - res]
            };
            Constraints::with_selector(selector, constraints)
        });
    }
    Self {
        inputs: inputs.clone(),
        output: output.clone(),
        selectors,
        challenge: Some(challenge),
        _marker: PhantomData,
    }
}
|
||||
|
||||
/// Assign one RLC per `rlc_len`-sized chunk of `flattened_input`.
///
/// For each chunk, witnesses the inputs (padded to the block width), the
/// running sums of the challenge-weighted combination, and enables the
/// init/acc selectors at the matching block offsets. Returns the tensor of
/// final running sums, one scalar per chunk.
///
/// # Arguments
/// * `flattened_input` - concatenation of the chunks to be combined
/// * `challenge` - the challenge value for this gate
/// * `rlc_len` - length of each combination (the folded axis' dimension)
/// * `phase` - which input column (0 = first phase, 1 = second phase)
fn assign_rlc(
    &self,
    region: &mut RegionCtx<F>,
    flattened_input: &ValTensor<F>,
    challenge: Value<F>,
    rlc_len: usize,
    phase: usize,
    check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
    region.flush_einsum()?;
    let block_width = self.output.num_inner_cols();
    // [r, r^2, ..., r^block_width] as witness values
    let powers_of_challenge = (0..block_width)
        .scan(Value::known(F::ONE), |challenge_power, _| {
            *challenge_power = challenge_power.clone() * challenge;
            Some(challenge_power.clone())
        })
        .collect_vec();
    let mut rlc_results: Vec<ValType<F>> = vec![];
    for tensor in flattened_input.get_inner_tensor()?.chunks_exact(rlc_len) {
        // Witness running sums block by block, mirroring the gate:
        // state = state * r^block_width + dot(block, descending powers).
        let running_sums = tensor
            .iter()
            .chunks(block_width)
            .into_iter()
            .scan(Value::known(F::ZERO), |state, val| {
                let curr_sum: Value<F> = val
                    .into_iter()
                    .zip(powers_of_challenge.iter().rev())
                    .map(|(v, c_power)| {
                        c_power.and_then(|c_power| {
                            // unknown witnesses stay unknown
                            v.get_felt_eval()
                                .and_then(|v| Some(Value::known(c_power * v)))
                                .unwrap_or(Value::unknown())
                        })
                    })
                    .reduce(|acc, v| acc + v)
                    .unwrap();
                *state = *state * powers_of_challenge.last().unwrap() + curr_sum;
                Some(*state)
            })
            .collect_vec();

        // Assign the chunk's inputs, padded to a whole number of blocks.
        let assigned_len = {
            let mut input: ValTensor<F> = tensor.iter().collect_vec().into();
            input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
            let (_, len) = region
                .assign_einsum_with_duplication_unconstrained(&self.inputs[phase], &input)?;
            len
        };
        // Assign the running sums to the output column.
        let (assigned_output, assigned_output_len) = {
            let running_sums = running_sums.into_iter().map(ValType::from).collect_vec();
            region.assign_einsum_with_duplication_constrained(
                &self.output,
                &running_sums.into(),
                check_mode,
            )?
        };

        // Enable init on the first running sum and acc on the rest,
        // skipping duplicated rows at block boundaries (z == 0 && i > 0).
        (0..assigned_output_len)
            .map(|i| {
                let (block_idx, _, z) = self
                    .output
                    .cartesian_coord(region.einsum_col_coord() + i * block_width);
                if z == 0 && i > 0 {
                    return Ok(());
                }
                let selector = if i == 0 {
                    self.selectors
                        .get(&(phase, block_idx))
                        .map(|(init, _)| init)
                } else {
                    self.selectors.get(&(phase, block_idx)).map(|(_, acc)| acc)
                };
                region.enable(selector, z)?;
                Ok(())
            })
            .collect::<Result<Vec<_>, CircuitError>>()?;
        // The chunk's RLC is the last running sum.
        rlc_results.push(assigned_output.last()?.get_inner_tensor()?.get_scalar());

        region.increment_einsum_col_coord(assigned_len);
    }
    Ok(rlc_results.into())
}
|
||||
}
|
||||
205
src/circuit/ops/chip/einsum/reduction_planner.rs
Normal file
205
src/circuit/ops/chip/einsum/reduction_planner.rs
Normal file
@@ -0,0 +1,205 @@
|
||||
use std::{collections::BTreeSet, ops::Index};
|
||||
|
||||
use halo2curves::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::{
|
||||
circuit::CircuitError,
|
||||
tensor::{TensorType, ValTensor},
|
||||
};
|
||||
|
||||
/// inj,jk->ik [inj,jk]
|
||||
/// inj,i->nj => RLC [jk,nj]
|
||||
/// jk,k->j => RLC [nj,j]
|
||||
/// nj,j->n => Contraction [n]
|
||||
/// n-> => Contraction []
|
||||
///
|
||||
/// bn,anm,bm->ba [bn,anm,bm]
|
||||
/// bn,bm->bnm => Contraction [anm,bnm]
|
||||
/// bnm,b->nm => RLC [anm,nm]
|
||||
/// anm,a->nm => RLC [nm,nm]
|
||||
/// nm,nm->m => Contraction [m]
|
||||
/// m-> => Contraction []
|
||||
|
||||
/// One planned step in the reduction of einsum inputs.
#[derive(Debug)]
pub enum Reduction {
    /// Random linear combination with powers of challenge along the axis
    RLC {
        /// einsum sub-expression performed by this step, e.g. `"inj,i->nj"`
        expression: String,
        /// the axis folded by the RLC
        axis: char,
        /// Uniquely identifying index of input tensor to be reduced
        input_index: TensorIndex,
        /// phase of input tensor
        input_phase: usize,
        /// phase of output tensor
        output_phase: usize,
        /// which challenge (output-axis position) to fold with
        challenge_index: usize,
    },
    /// Tensor contraction (dot / sum) or element-wise product
    Contraction {
        /// einsum sub-expression performed by this step
        expression: String,
        /// when axis is `None`, the contraction is pairwise multiplication
        axis: Option<char>,
        /// Uniquely identifying indices of input tensors to be contracted
        input_indices: Vec<TensorIndex>,
        /// phases of input tensors
        input_phases: Vec<usize>,
        /// phase of output tensor
        output_phase: usize,
    },
}
|
||||
|
||||
/// Opaque identifier for an (original or intermediate) tensor in the
/// reduction plan; wraps its position in the tensor list.
#[derive(Clone, Copy, Debug)]
pub struct TensorIndex(usize);
|
||||
|
||||
/// Allow indexing a tensor list directly with a `TensorIndex`
/// (legal under the orphan rules since `TensorIndex` is local).
impl<T: PrimeField + TensorType + PartialOrd> Index<TensorIndex> for Vec<ValTensor<T>> {
    type Output = ValTensor<T>;

    fn index(&self, index: TensorIndex) -> &Self::Output {
        &self[index.0]
    }
}
|
||||
|
||||
impl Reduction {
|
||||
pub fn expression(&self) -> &str {
|
||||
match self {
|
||||
Reduction::Contraction { expression, .. } => expression,
|
||||
Reduction::RLC { expression, .. } => &expression,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn input_indices(&self) -> Vec<TensorIndex> {
|
||||
match self {
|
||||
Reduction::Contraction { input_indices, .. } => input_indices.clone(),
|
||||
Reduction::RLC { input_index, .. } => vec![*input_index],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn output_phase(&self) -> usize {
|
||||
match self {
|
||||
Reduction::Contraction { output_phase, .. } => *output_phase,
|
||||
Reduction::RLC { output_phase, .. } => *output_phase,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn input_reductions(expression: &str) -> Result<Vec<Reduction>, CircuitError> {
|
||||
let (input_exprs, output_expr) = expression.split_once("->").unwrap();
|
||||
let input_exprs: Vec<_> = input_exprs.split(",").map(|eq| eq.to_string()).collect();
|
||||
// (phase, expression)
|
||||
let input_exprs: Vec<(usize, String)> =
|
||||
input_exprs.into_iter().map(|expr| (0, expr)).collect_vec();
|
||||
|
||||
let mut input_tensor_counter = input_exprs.len();
|
||||
let mut input_exprs: Vec<((usize, String), TensorIndex)> = input_exprs
|
||||
.into_iter()
|
||||
.zip((0..input_tensor_counter).map(TensorIndex))
|
||||
.collect();
|
||||
let mut reductions: Vec<Reduction> = vec![];
|
||||
|
||||
// Reduce input_exprs along given axis
|
||||
let mut reduce = |input_exprs: Vec<((usize, String), TensorIndex)>,
|
||||
axis: char|
|
||||
-> (Reduction, Vec<((usize, String), TensorIndex)>) {
|
||||
let inputs = input_exprs
|
||||
.iter()
|
||||
.filter(|((_, eq), _)| eq.chars().contains(&axis))
|
||||
.cloned()
|
||||
.collect_vec();
|
||||
let (inputs_axes, input_indices): (Vec<(usize, String)>, Vec<TensorIndex>) =
|
||||
inputs.iter().cloned().unzip();
|
||||
let (input_phases, inputs_axes): (Vec<usize>, Vec<String>) =
|
||||
inputs_axes.into_iter().unzip();
|
||||
|
||||
let is_output_axis = output_expr.chars().contains(&axis);
|
||||
let output: String = if is_output_axis == true && inputs.len() > 1 {
|
||||
let output: BTreeSet<char> =
|
||||
inputs_axes.iter().flat_map(|input| input.chars()).collect();
|
||||
output.iter().collect()
|
||||
} else {
|
||||
let output: BTreeSet<char> = inputs_axes
|
||||
.iter()
|
||||
.flat_map(|input| input.chars().filter(|&c| c != axis))
|
||||
.collect();
|
||||
output.iter().collect()
|
||||
};
|
||||
|
||||
let reduction = if is_output_axis == true && inputs.len() == 1 {
|
||||
let mut expression = inputs_axes.join(",");
|
||||
expression.push_str(format!(",{axis}").as_str());
|
||||
expression.push_str("->");
|
||||
expression.push_str(&output);
|
||||
Reduction::RLC {
|
||||
expression,
|
||||
axis,
|
||||
input_index: input_indices[0],
|
||||
input_phase: input_phases[0],
|
||||
output_phase: 1,
|
||||
challenge_index: output_expr.chars().position(|c| c == axis).unwrap(),
|
||||
}
|
||||
} else if is_output_axis == true {
|
||||
let mut expression = inputs_axes.join(",");
|
||||
let output_phase = input_phases.iter().copied().max().unwrap();
|
||||
expression.push_str("->");
|
||||
expression.push_str(&output);
|
||||
Reduction::Contraction {
|
||||
expression,
|
||||
axis: None,
|
||||
input_indices: input_indices,
|
||||
input_phases,
|
||||
output_phase,
|
||||
}
|
||||
} else {
|
||||
let mut expression = inputs_axes.join(",");
|
||||
let output_phase = input_phases.iter().copied().max().unwrap();
|
||||
expression.push_str("->");
|
||||
expression.push_str(&output);
|
||||
Reduction::Contraction {
|
||||
expression,
|
||||
axis: Some(axis),
|
||||
input_indices: input_indices,
|
||||
input_phases,
|
||||
output_phase,
|
||||
}
|
||||
};
|
||||
|
||||
// Mutate input_exprs
|
||||
let mut input_exprs = input_exprs.clone();
|
||||
input_exprs.retain(|((_, input_eq), _)| !inputs_axes.contains(input_eq));
|
||||
input_exprs.push((
|
||||
(reduction.output_phase(), output.clone()),
|
||||
TensorIndex(input_tensor_counter),
|
||||
));
|
||||
input_tensor_counter += 1;
|
||||
|
||||
(reduction, input_exprs)
|
||||
};
|
||||
|
||||
let mut output_axes = output_expr.chars().collect_vec();
|
||||
while let Some(axis) = output_axes.first().cloned() {
|
||||
let num_inputs = input_exprs
|
||||
.iter()
|
||||
.filter(|((_, eq), _)| eq.chars().contains(&axis))
|
||||
.count();
|
||||
if num_inputs == 0 {
|
||||
output_axes.remove(0);
|
||||
} else {
|
||||
let (reduction, new_input_exprs) = reduce(input_exprs, axis);
|
||||
reductions.push(reduction);
|
||||
input_exprs = new_input_exprs;
|
||||
}
|
||||
}
|
||||
|
||||
// These are not output axes and were not contracted with random vectors
|
||||
let remaining_axes: BTreeSet<_> = input_exprs
|
||||
.iter()
|
||||
.flat_map(|((_, eq), _)| eq.chars())
|
||||
.collect();
|
||||
|
||||
for axis in remaining_axes.iter() {
|
||||
let (reduction, new_input_exprs) = reduce(input_exprs, *axis);
|
||||
reductions.push(reduction);
|
||||
input_exprs = new_input_exprs;
|
||||
}
|
||||
|
||||
Ok(reductions)
|
||||
}
|
||||
@@ -46,6 +46,9 @@ pub enum CircuitError {
|
||||
/// Failed to get shuffle
|
||||
#[error("failed to get shuffle for op: {0}")]
|
||||
GetShuffleError(String),
|
||||
/// Failed to get einsum
|
||||
#[error("failed to get einsum for op: {0}")]
|
||||
GetEinsumError(String),
|
||||
/// Failed to get constants
|
||||
#[error("failed to get constants for op: {0}")]
|
||||
GetConstantsError(String),
|
||||
@@ -61,6 +64,9 @@ pub enum CircuitError {
|
||||
/// Missing product in einsum
|
||||
#[error("missing product in einsum")]
|
||||
MissingEinsumProduct,
|
||||
/// Missing config in einsum
|
||||
#[error("missing config in einsum")]
|
||||
MissingEinsumConfig,
|
||||
/// Mismatched lookup length
|
||||
#[error("mismatched lookup lengths: {0} and {1}")]
|
||||
MismatchedLookupLength(usize, usize),
|
||||
@@ -109,4 +115,7 @@ pub enum CircuitError {
|
||||
/// A decomposition base overflowed
|
||||
#[error("decomposition base overflowed")]
|
||||
DecompositionBaseOverflow,
|
||||
/// Challenge not set
|
||||
#[error("challenge not set")]
|
||||
ChallengeNotSet,
|
||||
}
|
||||
|
||||
@@ -22,9 +22,8 @@ use crate::{
|
||||
tensor::{
|
||||
create_unit_tensor, get_broadcasted_shape,
|
||||
ops::{accumulated, add, mult, sub},
|
||||
Tensor, TensorError, ValType,
|
||||
DataFormat, KernelFormat, Tensor, TensorError, ValType,
|
||||
},
|
||||
tensor::{DataFormat, KernelFormat},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
@@ -823,14 +822,73 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// let result = einsum::<Fp>(&dummy_config, &mut dummy_region, &[&x, &k], "mk,n->ma").unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1773]), &[1, 1]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
inputs: &[&ValTensor<F>],
|
||||
equation: &str,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let mut eq = equation.split("->");
|
||||
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
|
||||
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
|
||||
|
||||
// Check that the number of inputs matches the number of inputs in the equation
|
||||
if inputs.len() != inputs_eq.len() {
|
||||
return Err(TensorError::DimMismatch("einsum".to_string()).into());
|
||||
}
|
||||
|
||||
let mut indices_to_size = HashMap::new();
|
||||
for (i, input) in inputs.iter().enumerate() {
|
||||
for j in 0..inputs_eq[i].len() {
|
||||
let c = inputs_eq[i]
|
||||
.chars()
|
||||
.nth(j)
|
||||
.ok_or(CircuitError::InvalidEinsum)?;
|
||||
if let std::collections::hash_map::Entry::Vacant(e) = indices_to_size.entry(c) {
|
||||
e.insert(input.dims()[j]);
|
||||
} else if indices_to_size[&c] != input.dims()[j] {
|
||||
return Err(TensorError::DimMismatch("einsum".to_string()).into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Track the einsum equation
|
||||
region.add_used_einsum_equation(equation.to_string(), &indices_to_size)?;
|
||||
|
||||
if config.einsums.is_none() {
|
||||
return einsum_with_base_ops(config, region, inputs, equation);
|
||||
}
|
||||
|
||||
let input_values = inputs
|
||||
.iter()
|
||||
.map(|t| t.get_inner())
|
||||
.collect::<Result<Vec<_>, TensorError>>()?;
|
||||
let (output_tensor, _) =
|
||||
crate::tensor::ops::accumulated::einsum(equation, &input_values.iter().collect_vec())?;
|
||||
|
||||
config.einsums.as_ref().unwrap().assign_einsum(
|
||||
config,
|
||||
region,
|
||||
inputs,
|
||||
&output_tensor.clone().into(),
|
||||
equation,
|
||||
&config.check_mode,
|
||||
)?;
|
||||
|
||||
let output: ValTensor<F> = output_tensor.into();
|
||||
|
||||
region.increment_einsum_index(1);
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// einsum with base ops
|
||||
pub fn einsum_with_base_ops<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
inputs: &[&ValTensor<F>],
|
||||
equation: &str,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let mut equation = equation.split("->");
|
||||
let inputs_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?;
|
||||
@@ -1036,6 +1094,8 @@ pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
|
||||
let output: ValTensor<F> = output.into();
|
||||
|
||||
region.increment_einsum_index(1);
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
@@ -4626,7 +4686,7 @@ pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
|
||||
let mut rescaled_inputs = vec![];
|
||||
for (i, ri) in values.iter().enumerate() {
|
||||
if scales[i].1 == 1 {
|
||||
rescaled_inputs.push(ri.clone().clone());
|
||||
rescaled_inputs.push((*ri).clone());
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -5709,13 +5769,13 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
let abs_distance_to_prior_pow2 = l1_distance(config, region, &[&input, &prior_pow2])?;
|
||||
|
||||
// because we round up this can be equal
|
||||
let is_closest_to_0: ValTensor<F> = less(
|
||||
let is_closest_to_0: ValTensor<F> = less_equal(
|
||||
config,
|
||||
region,
|
||||
&[&abs_distance_to_claimed, &abs_distance_to_next_pow2],
|
||||
)?;
|
||||
|
||||
let is_closest_to_1 = less(
|
||||
let is_closest_to_1 = less_equal(
|
||||
config,
|
||||
region,
|
||||
&[&abs_distance_to_claimed, &abs_distance_to_prior_pow2],
|
||||
|
||||
@@ -364,7 +364,15 @@ impl<
|
||||
};
|
||||
Ok(Some(if self.decomp {
|
||||
log::debug!("constraining constant to be decomp");
|
||||
super::layouts::decompose(config, region, &[&value], ®ion.base(), ®ion.legs(), false)?.1
|
||||
super::layouts::decompose(
|
||||
config,
|
||||
region,
|
||||
&[&value],
|
||||
®ion.base(),
|
||||
®ion.legs(),
|
||||
false,
|
||||
)?
|
||||
.1
|
||||
} else {
|
||||
log::debug!("constraining constant to be identity");
|
||||
super::layouts::identity(config, region, &[&value])?
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
use crate::{
|
||||
circuit::table::Range,
|
||||
circuit::{einsum::NUM_MAX_EINSUM_CHALLENGES, table::Range},
|
||||
fieldutils::IntegerRep,
|
||||
tensor::{Tensor, TensorType, ValTensor, ValType, VarTensor},
|
||||
};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use colored::Colorize;
|
||||
use halo2_proofs::{
|
||||
circuit::Region,
|
||||
circuit::{Region, Value},
|
||||
plonk::{Error, Selector},
|
||||
};
|
||||
use halo2curves::ff::PrimeField;
|
||||
@@ -85,6 +85,45 @@ impl ShuffleIndex {
|
||||
}
|
||||
}
|
||||
|
||||
/// Einsum index
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct EinsumIndex {
|
||||
index: usize,
|
||||
col_coord: usize,
|
||||
// (einsum equation, input axes to dimensions map)
|
||||
equations: Vec<(String, HashMap<char, usize>)>,
|
||||
num_inner_cols: usize,
|
||||
}
|
||||
|
||||
impl EinsumIndex {
|
||||
/// Create a new einsum index
|
||||
pub fn new(index: usize, col_coord: usize, num_inner_cols: usize) -> EinsumIndex {
|
||||
EinsumIndex {
|
||||
index,
|
||||
col_coord,
|
||||
equations: Vec::new(),
|
||||
num_inner_cols,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the einsum index
|
||||
pub fn index(&self) -> usize {
|
||||
self.index
|
||||
}
|
||||
|
||||
/// Get the column coord
|
||||
pub fn col_coord(&self) -> usize {
|
||||
self.col_coord
|
||||
}
|
||||
|
||||
/// update with another einsum index
|
||||
pub fn update(&mut self, other: &EinsumIndex) {
|
||||
self.index += other.index;
|
||||
self.col_coord += other.col_coord;
|
||||
self.equations.extend(other.equations.clone());
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Some settings for a region to differentiate it across the different phases of proof generation
|
||||
pub struct RegionSettings {
|
||||
@@ -176,9 +215,11 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
|
||||
num_inner_cols: usize,
|
||||
dynamic_lookup_index: DynamicLookupIndex,
|
||||
shuffle_index: ShuffleIndex,
|
||||
einsum_index: EinsumIndex,
|
||||
statistics: RegionStatistics,
|
||||
settings: RegionSettings,
|
||||
assigned_constants: ConstantsMap<F>,
|
||||
challenges: Vec<Value<F>>,
|
||||
max_dynamic_input_len: usize,
|
||||
}
|
||||
|
||||
@@ -250,6 +291,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
self.shuffle_index.col_coord += n;
|
||||
}
|
||||
|
||||
/// increment the einsum index
|
||||
pub fn increment_einsum_index(&mut self, n: usize) {
|
||||
self.einsum_index.index += n;
|
||||
}
|
||||
|
||||
/// increment the einsum column coordinate
|
||||
pub fn increment_einsum_col_coord(&mut self, n: usize) {
|
||||
self.einsum_index.col_coord += n;
|
||||
}
|
||||
|
||||
///
|
||||
pub fn witness_gen(&self) -> bool {
|
||||
self.settings.witness_gen
|
||||
@@ -265,6 +316,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
&self.statistics
|
||||
}
|
||||
|
||||
///
|
||||
pub fn challenges(&self) -> &[Value<F>] {
|
||||
&self.challenges
|
||||
}
|
||||
|
||||
/// Create a new region context
|
||||
pub fn new(
|
||||
region: Region<'a, F>,
|
||||
@@ -283,9 +339,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
linear_coord,
|
||||
dynamic_lookup_index: DynamicLookupIndex::default(),
|
||||
shuffle_index: ShuffleIndex::default(),
|
||||
einsum_index: EinsumIndex::default(),
|
||||
statistics: RegionStatistics::default(),
|
||||
settings: RegionSettings::all_true(decomp_base, decomp_legs),
|
||||
assigned_constants: HashMap::new(),
|
||||
challenges: vec![],
|
||||
max_dynamic_input_len: 0,
|
||||
}
|
||||
}
|
||||
@@ -304,6 +362,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
new_self
|
||||
}
|
||||
|
||||
/// Create a new region context with challenges
|
||||
pub fn new_with_challenges(
|
||||
region: Region<'a, F>,
|
||||
row: usize,
|
||||
num_inner_cols: usize,
|
||||
decomp_base: usize,
|
||||
decomp_legs: usize,
|
||||
challenges: Vec<Value<F>>,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let mut new_self = Self::new(region, row, num_inner_cols, decomp_base, decomp_legs);
|
||||
new_self.challenges = challenges;
|
||||
new_self
|
||||
}
|
||||
|
||||
/// Create a new region context
|
||||
pub fn new_dummy(
|
||||
row: usize,
|
||||
@@ -320,9 +392,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
row,
|
||||
dynamic_lookup_index: DynamicLookupIndex::default(),
|
||||
shuffle_index: ShuffleIndex::default(),
|
||||
einsum_index: EinsumIndex::default(),
|
||||
statistics: RegionStatistics::default(),
|
||||
settings,
|
||||
assigned_constants: HashMap::new(),
|
||||
challenges: vec![Value::unknown(); NUM_MAX_EINSUM_CHALLENGES],
|
||||
max_dynamic_input_len: 0,
|
||||
}
|
||||
}
|
||||
@@ -333,6 +407,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
linear_coord: usize,
|
||||
num_inner_cols: usize,
|
||||
settings: RegionSettings,
|
||||
challenges: Vec<Value<F>>,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let region = None;
|
||||
RegionCtx {
|
||||
@@ -342,9 +417,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
row,
|
||||
dynamic_lookup_index: DynamicLookupIndex::default(),
|
||||
shuffle_index: ShuffleIndex::default(),
|
||||
einsum_index: EinsumIndex::default(),
|
||||
statistics: RegionStatistics::default(),
|
||||
settings,
|
||||
assigned_constants: HashMap::new(),
|
||||
challenges,
|
||||
max_dynamic_input_len: 0,
|
||||
}
|
||||
}
|
||||
@@ -398,6 +475,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
let statistics = Arc::new(Mutex::new(self.statistics.clone()));
|
||||
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
|
||||
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
|
||||
let einsum_index = Arc::new(Mutex::new(self.einsum_index.clone()));
|
||||
let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
|
||||
|
||||
*output = output.par_enum_map(|idx, _| {
|
||||
@@ -412,6 +490,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
starting_linear_coord,
|
||||
self.num_inner_cols,
|
||||
self.settings.clone(),
|
||||
self.challenges.clone(),
|
||||
);
|
||||
let res = inner_loop_function(idx, &mut local_reg);
|
||||
// we update the offset and constants
|
||||
@@ -430,6 +509,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
// update the shuffle index
|
||||
let mut shuffle_index = shuffle_index.lock().unwrap();
|
||||
shuffle_index.update(&local_reg.shuffle_index);
|
||||
// update the einsum index
|
||||
let mut einsum_index = einsum_index.lock().unwrap();
|
||||
einsum_index.update(&local_reg.einsum_index);
|
||||
// update the constants
|
||||
let mut constants = constants.lock().unwrap();
|
||||
constants.extend(local_reg.assigned_constants);
|
||||
@@ -450,6 +532,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
.map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?
|
||||
.into_inner()
|
||||
.map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?;
|
||||
self.einsum_index = Arc::try_unwrap(einsum_index)
|
||||
.map_err(|e| CircuitError::GetEinsumError(format!("{:?}", e)))?
|
||||
.into_inner()
|
||||
.map_err(|e| CircuitError::GetEinsumError(format!("{:?}", e)))?;
|
||||
self.assigned_constants = Arc::try_unwrap(constants)
|
||||
.map_err(|e| CircuitError::GetConstantsError(format!("{:?}", e)))?
|
||||
.into_inner()
|
||||
@@ -516,6 +602,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
self.update_max_min_lookup_range(range)
|
||||
}
|
||||
|
||||
/// add used einsum equation
|
||||
pub fn add_used_einsum_equation(
|
||||
&mut self,
|
||||
equation: String,
|
||||
input_axes_to_dims: &HashMap<char, usize>,
|
||||
) -> Result<(), CircuitError> {
|
||||
self.einsum_index
|
||||
.equations
|
||||
.push((equation, input_axes_to_dims.clone()));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the offset
|
||||
pub fn row(&self) -> usize {
|
||||
self.row
|
||||
@@ -551,6 +649,31 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
self.shuffle_index.col_coord
|
||||
}
|
||||
|
||||
/// einsum index
|
||||
pub fn einsum_index(&self) -> usize {
|
||||
self.einsum_index.index
|
||||
}
|
||||
|
||||
/// einsum column coordinate
|
||||
pub fn einsum_col_coord(&self) -> usize {
|
||||
self.einsum_index.col_coord
|
||||
}
|
||||
|
||||
/// get used einsum equations
|
||||
pub fn used_einsum_equations(&self) -> Vec<(String, HashMap<char, usize>)> {
|
||||
self.einsum_index.equations.clone()
|
||||
}
|
||||
|
||||
/// set the number of inner columns used in einsum custom gate
|
||||
pub fn set_num_einsum_inner_cols(&mut self, num_inner_cols: usize) {
|
||||
self.einsum_index.num_inner_cols = num_inner_cols;
|
||||
}
|
||||
|
||||
/// number of inner columns used in einsum custom gate
|
||||
pub fn num_einsum_inner_cols(&self) -> usize {
|
||||
self.einsum_index.num_inner_cols
|
||||
}
|
||||
|
||||
/// get used lookups
|
||||
pub fn used_lookups(&self) -> HashSet<LookupOp> {
|
||||
self.statistics.used_lookups.clone()
|
||||
@@ -640,6 +763,28 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
self.assign_dynamic_lookup(var, values)
|
||||
}
|
||||
|
||||
/// Assign a valtensor to a vartensor in einsum area
|
||||
pub fn assign_einsum(
|
||||
&mut self,
|
||||
var: &VarTensor,
|
||||
values: &ValTensor<F>,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
if let Some(region) = &self.region {
|
||||
Ok(var.assign(
|
||||
&mut region.borrow_mut(),
|
||||
self.einsum_col_coord(),
|
||||
values,
|
||||
&mut self.assigned_constants,
|
||||
)?)
|
||||
} else {
|
||||
if !values.is_instance() {
|
||||
let values_map = values.create_constants_map_iterator();
|
||||
self.assigned_constants.par_extend(values_map);
|
||||
}
|
||||
Ok(values.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Assign a valtensor to a vartensor with duplication
|
||||
pub fn assign_with_duplication_unconstrained(
|
||||
&mut self,
|
||||
@@ -697,6 +842,63 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
}
|
||||
}
|
||||
|
||||
/// Assign a valtensor to a vartensor with duplication
|
||||
pub fn assign_einsum_with_duplication_unconstrained(
|
||||
&mut self,
|
||||
var: &VarTensor,
|
||||
values: &ValTensor<F>,
|
||||
) -> Result<(ValTensor<F>, usize), Error> {
|
||||
if let Some(region) = &self.region {
|
||||
// duplicates every nth element to adjust for column overflow
|
||||
let (res, len) = var.assign_with_duplication_unconstrained(
|
||||
&mut region.borrow_mut(),
|
||||
self.einsum_col_coord(),
|
||||
values,
|
||||
&mut self.assigned_constants,
|
||||
)?;
|
||||
Ok((res, len))
|
||||
} else {
|
||||
let (_, len) = var.dummy_assign_with_duplication(
|
||||
self.row,
|
||||
self.einsum_col_coord(),
|
||||
values,
|
||||
false,
|
||||
&mut self.assigned_constants,
|
||||
)?;
|
||||
Ok((values.clone(), len))
|
||||
}
|
||||
}
|
||||
|
||||
/// Assign a valtensor to a vartensor with duplication
|
||||
pub fn assign_einsum_with_duplication_constrained(
|
||||
&mut self,
|
||||
var: &VarTensor,
|
||||
values: &ValTensor<F>,
|
||||
check_mode: &crate::circuit::CheckMode,
|
||||
) -> Result<(ValTensor<F>, usize), Error> {
|
||||
if let Some(region) = &self.region {
|
||||
// duplicates every nth element to adjust for column overflow
|
||||
let (res, len) = var.assign_with_duplication_constrained(
|
||||
&mut region.borrow_mut(),
|
||||
self.row,
|
||||
self.einsum_col_coord(),
|
||||
values,
|
||||
check_mode,
|
||||
&mut self.assigned_constants,
|
||||
)?;
|
||||
Ok((res, len))
|
||||
} else {
|
||||
let (_, len) = var.dummy_assign_with_duplication(
|
||||
self.row,
|
||||
self.einsum_col_coord(),
|
||||
values,
|
||||
true,
|
||||
&mut self.assigned_constants,
|
||||
)?;
|
||||
Ok((values.clone(), len))
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable a selector
|
||||
pub fn enable(&mut self, selector: Option<&Selector>, offset: usize) -> Result<(), Error> {
|
||||
match &self.region {
|
||||
@@ -763,4 +965,19 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// flush row to the next row in einsum area
|
||||
pub fn flush_einsum(&mut self) -> Result<(), CircuitError> {
|
||||
// increment by the difference between the current linear coord and the next row
|
||||
let num_einsum_inner_cols = self.num_einsum_inner_cols();
|
||||
let remainder = self.einsum_col_coord() % num_einsum_inner_cols;
|
||||
if remainder != 0 {
|
||||
let diff = num_einsum_inner_cols - remainder;
|
||||
self.increment_einsum_col_coord(diff);
|
||||
}
|
||||
if self.einsum_col_coord() % num_einsum_inner_cols != 0 {
|
||||
return Err(CircuitError::FlushError);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,12 +93,17 @@ pub const DEFAULT_VKA_DIGEST: &str = "vka.digest";
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
|
||||
impl IntoPy<PyObject> for TranscriptType {
|
||||
fn into_py(self, py: Python) -> PyObject {
|
||||
match self {
|
||||
TranscriptType::Poseidon => "poseidon".to_object(py),
|
||||
TranscriptType::EVM => "evm".to_object(py),
|
||||
}
|
||||
impl<'py> IntoPyObject<'py> for TranscriptType {
|
||||
type Target = pyo3::PyAny;
|
||||
type Output = pyo3::Bound<'py, Self::Target>;
|
||||
type Error = pyo3::PyErr;
|
||||
|
||||
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
|
||||
let result = match self {
|
||||
TranscriptType::Poseidon => "poseidon",
|
||||
TranscriptType::EVM => "evm",
|
||||
};
|
||||
Ok(result.into_pyobject(py)?.into_any())
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "python-bindings")]
|
||||
@@ -257,17 +262,20 @@ impl From<&str> for H160Flag {
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts CalibrationTarget into a PyObject (Required for CalibrationTarget to be compatible with Python)
|
||||
impl IntoPy<PyObject> for CalibrationTarget {
|
||||
fn into_py(self, py: Python) -> PyObject {
|
||||
match self {
|
||||
CalibrationTarget::Resources { col_overflow: true } => {
|
||||
"resources/col-overflow".to_object(py)
|
||||
}
|
||||
impl<'py> IntoPyObject<'py> for CalibrationTarget {
|
||||
type Target = pyo3::PyAny;
|
||||
type Output = pyo3::Bound<'py, Self::Target>;
|
||||
type Error = pyo3::PyErr;
|
||||
|
||||
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
|
||||
let result = match self {
|
||||
CalibrationTarget::Resources { col_overflow: true } => "resources/col-overflow",
|
||||
CalibrationTarget::Resources {
|
||||
col_overflow: false,
|
||||
} => "resources".to_object(py),
|
||||
CalibrationTarget::Accuracy => "accuracy".to_object(py),
|
||||
}
|
||||
} => "resources",
|
||||
CalibrationTarget::Accuracy => "accuracy",
|
||||
};
|
||||
Ok(result.into_pyobject(py)?.into_any())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -289,12 +297,17 @@ impl<'source> FromPyObject<'source> for CalibrationTarget {
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts ContractType into a PyObject (Required for ContractType to be compatible with Python)
|
||||
impl IntoPy<PyObject> for ContractType {
|
||||
fn into_py(self, py: Python) -> PyObject {
|
||||
match self {
|
||||
ContractType::Verifier { reusable: true } => "verifier/reusable".to_object(py),
|
||||
ContractType::Verifier { reusable: false } => "verifier".to_object(py),
|
||||
}
|
||||
impl<'py> IntoPyObject<'py> for ContractType {
|
||||
type Target = pyo3::PyAny;
|
||||
type Output = pyo3::Bound<'py, Self::Target>;
|
||||
type Error = pyo3::PyErr;
|
||||
|
||||
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
|
||||
let result = match self {
|
||||
ContractType::Verifier { reusable: true } => "verifier/reusable",
|
||||
ContractType::Verifier { reusable: false } => "verifier",
|
||||
};
|
||||
Ok(result.into_pyobject(py)?.into_any())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,8 @@ use colored::Colorize;
|
||||
#[cfg(unix)]
|
||||
use gag::Gag;
|
||||
use halo2_proofs::dev::VerifyFailure;
|
||||
#[cfg(feature = "gpu-accelerated")]
|
||||
use halo2_proofs::icicle::try_load_and_set_backend_device;
|
||||
use halo2_proofs::plonk::{self, Circuit};
|
||||
use halo2_proofs::poly::commitment::{CommitmentScheme, Params};
|
||||
use halo2_proofs::poly::commitment::{ParamsProver, Verifier};
|
||||
@@ -46,6 +48,8 @@ use halo2_solidity_verifier;
|
||||
use halo2curves::bn256::{Bn256, Fr, G1Affine};
|
||||
use halo2curves::ff::{FromUniformBytes, WithSmallOrderMulGroup};
|
||||
use halo2curves::serde::SerdeObject;
|
||||
#[cfg(feature = "gpu-accelerated")]
|
||||
use icicle_runtime::{stream::IcicleStream, warmup};
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use instant::Instant;
|
||||
use itertools::Itertools;
|
||||
@@ -87,6 +91,22 @@ lazy_static! {
|
||||
|
||||
}
|
||||
|
||||
/// Set the device used for computation.
|
||||
#[cfg(feature = "gpu-accelerated")]
|
||||
pub fn set_device() {
|
||||
if std::env::var("ICICLE_BACKEND_INSTALL_DIR").is_ok() {
|
||||
info!("Running with ICICLE GPU");
|
||||
try_load_and_set_backend_device("CUDA");
|
||||
match warmup(&IcicleStream::default()) {
|
||||
Ok(_) => info!("GPU warmed :)"),
|
||||
Err(e) => log::error!("GPU warmup failed: {:?}", e),
|
||||
}
|
||||
} else {
|
||||
info!("Running with CPU: 'ICICLE_BACKEND_INSTALL_DIR' not set");
|
||||
try_load_and_set_backend_device("CPU");
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper for execution errors
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ExecutionError {
|
||||
@@ -108,6 +128,8 @@ lazy_static::lazy_static! {
|
||||
|
||||
/// Run an ezkl command with given args
|
||||
pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
#[cfg(feature = "gpu-accelerated")]
|
||||
set_device();
|
||||
// set working dir
|
||||
std::env::set_current_dir(WORKING_DIR.as_path())?;
|
||||
|
||||
@@ -1268,6 +1290,7 @@ pub(crate) fn calibrate(
|
||||
total_const_size: new_settings.total_const_size,
|
||||
dynamic_lookup_params: new_settings.dynamic_lookup_params,
|
||||
shuffle_params: new_settings.shuffle_params,
|
||||
einsum_params: new_settings.einsum_params,
|
||||
..settings.clone()
|
||||
};
|
||||
|
||||
|
||||
@@ -7,9 +7,7 @@ use halo2curves::bn256::Fr as Fp;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::prelude::*;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::types::PyDict;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
use pyo3::IntoPyObject;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::io::BufReader;
|
||||
use std::io::BufWriter;
|
||||
@@ -429,34 +427,23 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for CallToAccount {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("account", &self.address).unwrap();
|
||||
dict.set_item("call_data", &self.call_data).unwrap();
|
||||
dict.set_item("decimals", &self.decimals).unwrap();
|
||||
dict.to_object(py)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for DataSource {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
self.0.to_object(py)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use crate::pfsys::field_to_string;
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for FileSourceInner {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
impl<'py> IntoPyObject<'py> for FileSourceInner {
|
||||
type Target = pyo3::PyAny;
|
||||
type Output = pyo3::Bound<'py, Self::Target>;
|
||||
type Error = pyo3::PyErr;
|
||||
|
||||
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
|
||||
match self {
|
||||
FileSourceInner::Field(data) => field_to_string(data).to_object(py),
|
||||
FileSourceInner::Bool(data) => data.to_object(py),
|
||||
FileSourceInner::Float(data) => data.to_object(py),
|
||||
FileSourceInner::Field(data) => {
|
||||
let s = field_to_string(&data);
|
||||
Ok(pyo3::types::PyString::new(py, &s).into_any())
|
||||
}
|
||||
FileSourceInner::Bool(data) => Ok(pyo3::types::PyBool::new(py, data).as_any().clone()),
|
||||
FileSourceInner::Float(data) => Ok(pyo3::types::PyFloat::new(py, data).into_any()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
690
src/graph/mod.rs
690
src/graph/mod.rs
@@ -61,9 +61,10 @@ use pyo3::types::PyDict;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::types::PyDictMethods;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
use pyo3::IntoPyObject;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::ops::Deref;
|
||||
pub use utilities::*;
|
||||
pub use vars::*;
|
||||
@@ -319,8 +320,12 @@ impl GraphWitness {
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for GraphWitness {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
impl<'py> IntoPyObject<'py> for GraphWitness {
|
||||
type Target = pyo3::PyAny;
|
||||
type Output = pyo3::Bound<'py, Self::Target>;
|
||||
type Error = pyo3::PyErr;
|
||||
|
||||
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
|
||||
// Create a Python dictionary
|
||||
let dict = PyDict::new(py);
|
||||
let dict_inputs = PyDict::new(py);
|
||||
@@ -383,7 +388,7 @@ impl ToPyObject for GraphWitness {
|
||||
dict.set_item("processed_outputs", dict_outputs).unwrap();
|
||||
}
|
||||
|
||||
dict.to_object(py)
|
||||
Ok(dict.into_any())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -434,8 +439,17 @@ pub struct ShuffleParams {
|
||||
pub total_shuffle_col_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||
/// Parameters for einsum operations
|
||||
pub struct EinsumParams {
|
||||
/// einsum equations
|
||||
pub equations: Vec<(String, HashMap<char, usize>)>,
|
||||
/// total einsum column size
|
||||
pub total_einsum_col_size: usize,
|
||||
}
|
||||
|
||||
/// model parameters
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub struct GraphSettings {
|
||||
/// run args
|
||||
pub run_args: RunArgs,
|
||||
@@ -445,12 +459,12 @@ pub struct GraphSettings {
|
||||
pub total_assignments: usize,
|
||||
/// total const size
|
||||
pub total_const_size: usize,
|
||||
/// dynamic lookup parameters, flattened for backwards compatibility
|
||||
#[serde(flatten)]
|
||||
/// dynamic lookup parameters, flattened for backwards compatibility, serialize and deserialize flattened for backwards compatibility
|
||||
pub dynamic_lookup_params: DynamicLookupParams,
|
||||
/// shuffle parameters, flattened for backwards compatibility
|
||||
#[serde(flatten)]
|
||||
pub shuffle_params: ShuffleParams,
|
||||
/// einsum parameters
|
||||
pub einsum_params: EinsumParams,
|
||||
/// the shape of public inputs to the model (in order of appearance)
|
||||
pub model_instance_shapes: Vec<Vec<usize>>,
|
||||
/// model output scales
|
||||
@@ -477,6 +491,504 @@ pub struct GraphSettings {
|
||||
pub output_types: Option<Vec<InputType>>,
|
||||
}
|
||||
|
||||
impl Serialize for GraphSettings {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
if serializer.is_human_readable() {
|
||||
// JSON format - use flattened fields for backwards compatibility
|
||||
use serde::ser::SerializeStruct;
|
||||
let mut state = serializer.serialize_struct("GraphSettings", 22)?;
|
||||
state.serialize_field("run_args", &self.run_args)?;
|
||||
state.serialize_field("num_rows", &self.num_rows)?;
|
||||
state.serialize_field("total_assignments", &self.total_assignments)?;
|
||||
state.serialize_field("total_const_size", &self.total_const_size)?;
|
||||
|
||||
// Flatten DynamicLookupParams fields
|
||||
state.serialize_field(
|
||||
"total_dynamic_col_size",
|
||||
&self.dynamic_lookup_params.total_dynamic_col_size,
|
||||
)?;
|
||||
state.serialize_field(
|
||||
"max_dynamic_input_len",
|
||||
&self.dynamic_lookup_params.max_dynamic_input_len,
|
||||
)?;
|
||||
state.serialize_field(
|
||||
"num_dynamic_lookups",
|
||||
&self.dynamic_lookup_params.num_dynamic_lookups,
|
||||
)?;
|
||||
|
||||
// Flatten ShuffleParams fields
|
||||
state.serialize_field("num_shuffles", &self.shuffle_params.num_shuffles)?;
|
||||
state.serialize_field(
|
||||
"total_shuffle_col_size",
|
||||
&self.shuffle_params.total_shuffle_col_size,
|
||||
)?;
|
||||
|
||||
// Serialize EinsumParams
|
||||
state.serialize_field("einsum_params", &self.einsum_params)?;
|
||||
|
||||
state.serialize_field("model_instance_shapes", &self.model_instance_shapes)?;
|
||||
state.serialize_field("model_output_scales", &self.model_output_scales)?;
|
||||
state.serialize_field("model_input_scales", &self.model_input_scales)?;
|
||||
state.serialize_field("module_sizes", &self.module_sizes)?;
|
||||
state.serialize_field("required_lookups", &self.required_lookups)?;
|
||||
state.serialize_field("required_range_checks", &self.required_range_checks)?;
|
||||
state.serialize_field("check_mode", &self.check_mode)?;
|
||||
state.serialize_field("version", &self.version)?;
|
||||
state.serialize_field("num_blinding_factors", &self.num_blinding_factors)?;
|
||||
state.serialize_field("timestamp", &self.timestamp)?;
|
||||
state.serialize_field("input_types", &self.input_types)?;
|
||||
state.serialize_field("output_types", &self.output_types)?;
|
||||
state.end()
|
||||
} else {
|
||||
// Binary format (bincode) - use nested struct format
|
||||
use serde::ser::SerializeTuple;
|
||||
let mut state = serializer.serialize_tuple(19)?;
|
||||
state.serialize_element(&self.run_args)?;
|
||||
state.serialize_element(&self.num_rows)?;
|
||||
state.serialize_element(&self.total_assignments)?;
|
||||
state.serialize_element(&self.total_const_size)?;
|
||||
state.serialize_element(&self.dynamic_lookup_params)?;
|
||||
state.serialize_element(&self.shuffle_params)?;
|
||||
state.serialize_element(&self.einsum_params)?;
|
||||
state.serialize_element(&self.model_instance_shapes)?;
|
||||
state.serialize_element(&self.model_output_scales)?;
|
||||
state.serialize_element(&self.model_input_scales)?;
|
||||
state.serialize_element(&self.module_sizes)?;
|
||||
state.serialize_element(&self.required_lookups)?;
|
||||
state.serialize_element(&self.required_range_checks)?;
|
||||
state.serialize_element(&self.check_mode)?;
|
||||
state.serialize_element(&self.version)?;
|
||||
state.serialize_element(&self.num_blinding_factors)?;
|
||||
state.serialize_element(&self.timestamp)?;
|
||||
state.serialize_element(&self.input_types)?;
|
||||
state.serialize_element(&self.output_types)?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for GraphSettings {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
use serde::de::{self, MapAccess, Visitor};
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(field_identifier, rename_all = "snake_case")]
|
||||
enum Field {
|
||||
RunArgs,
|
||||
NumRows,
|
||||
TotalAssignments,
|
||||
TotalConstSize,
|
||||
// Flattened DynamicLookupParams fields
|
||||
TotalDynamicColSize,
|
||||
MaxDynamicInputLen,
|
||||
NumDynamicLookups,
|
||||
// Flattened ShuffleParams fields
|
||||
NumShuffles,
|
||||
TotalShuffleColSize,
|
||||
// EinsumParams field
|
||||
EinsumParams,
|
||||
ModelInstanceShapes,
|
||||
ModelOutputScales,
|
||||
ModelInputScales,
|
||||
ModuleSizes,
|
||||
RequiredLookups,
|
||||
RequiredRangeChecks,
|
||||
CheckMode,
|
||||
Version,
|
||||
NumBlindingFactors,
|
||||
Timestamp,
|
||||
InputTypes,
|
||||
OutputTypes,
|
||||
// Legacy nested struct fields for backwards compatibility
|
||||
DynamicLookupParams,
|
||||
ShuffleParams,
|
||||
}
|
||||
|
||||
struct GraphSettingsVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for GraphSettingsVisitor {
|
||||
type Value = GraphSettings;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("struct GraphSettings")
|
||||
}
|
||||
|
||||
fn visit_map<V>(self, mut map: V) -> Result<GraphSettings, V::Error>
|
||||
where
|
||||
V: MapAccess<'de>,
|
||||
{
|
||||
let mut run_args = None;
|
||||
let mut num_rows = None;
|
||||
let mut total_assignments = None;
|
||||
let mut total_const_size = None;
|
||||
let mut total_dynamic_col_size = None;
|
||||
let mut max_dynamic_input_len = None;
|
||||
let mut num_dynamic_lookups = None;
|
||||
let mut num_shuffles = None;
|
||||
let mut total_shuffle_col_size = None;
|
||||
let mut einsum_params = None;
|
||||
let mut model_instance_shapes = None;
|
||||
let mut model_output_scales = None;
|
||||
let mut model_input_scales = None;
|
||||
let mut module_sizes = None;
|
||||
let mut required_lookups = None;
|
||||
let mut required_range_checks = None;
|
||||
let mut check_mode = None;
|
||||
let mut version = None;
|
||||
let mut num_blinding_factors = None;
|
||||
let mut timestamp = None;
|
||||
let mut input_types = None;
|
||||
let mut output_types = None;
|
||||
|
||||
while let Some(key) = map.next_key()? {
|
||||
match key {
|
||||
Field::RunArgs => {
|
||||
if run_args.is_some() {
|
||||
return Err(de::Error::duplicate_field("run_args"));
|
||||
}
|
||||
run_args = Some(map.next_value()?);
|
||||
}
|
||||
Field::NumRows => {
|
||||
if num_rows.is_some() {
|
||||
return Err(de::Error::duplicate_field("num_rows"));
|
||||
}
|
||||
num_rows = Some(map.next_value()?);
|
||||
}
|
||||
Field::TotalAssignments => {
|
||||
if total_assignments.is_some() {
|
||||
return Err(de::Error::duplicate_field("total_assignments"));
|
||||
}
|
||||
total_assignments = Some(map.next_value()?);
|
||||
}
|
||||
Field::TotalConstSize => {
|
||||
if total_const_size.is_some() {
|
||||
return Err(de::Error::duplicate_field("total_const_size"));
|
||||
}
|
||||
total_const_size = Some(map.next_value()?);
|
||||
}
|
||||
Field::TotalDynamicColSize => {
|
||||
if total_dynamic_col_size.is_some() {
|
||||
return Err(de::Error::duplicate_field("total_dynamic_col_size"));
|
||||
}
|
||||
total_dynamic_col_size = Some(map.next_value()?);
|
||||
}
|
||||
Field::MaxDynamicInputLen => {
|
||||
if max_dynamic_input_len.is_some() {
|
||||
return Err(de::Error::duplicate_field("max_dynamic_input_len"));
|
||||
}
|
||||
max_dynamic_input_len = Some(map.next_value()?);
|
||||
}
|
||||
Field::NumDynamicLookups => {
|
||||
if num_dynamic_lookups.is_some() {
|
||||
return Err(de::Error::duplicate_field("num_dynamic_lookups"));
|
||||
}
|
||||
num_dynamic_lookups = Some(map.next_value()?);
|
||||
}
|
||||
Field::NumShuffles => {
|
||||
if num_shuffles.is_some() {
|
||||
return Err(de::Error::duplicate_field("num_shuffles"));
|
||||
}
|
||||
num_shuffles = Some(map.next_value()?);
|
||||
}
|
||||
Field::TotalShuffleColSize => {
|
||||
if total_shuffle_col_size.is_some() {
|
||||
return Err(de::Error::duplicate_field("total_shuffle_col_size"));
|
||||
}
|
||||
total_shuffle_col_size = Some(map.next_value()?);
|
||||
}
|
||||
Field::EinsumParams => {
|
||||
if einsum_params.is_some() {
|
||||
return Err(de::Error::duplicate_field("einsum_params"));
|
||||
}
|
||||
einsum_params = Some(map.next_value()?);
|
||||
}
|
||||
Field::ModelInstanceShapes => {
|
||||
if model_instance_shapes.is_some() {
|
||||
return Err(de::Error::duplicate_field("model_instance_shapes"));
|
||||
}
|
||||
model_instance_shapes = Some(map.next_value()?);
|
||||
}
|
||||
Field::ModelOutputScales => {
|
||||
if model_output_scales.is_some() {
|
||||
return Err(de::Error::duplicate_field("model_output_scales"));
|
||||
}
|
||||
model_output_scales = Some(map.next_value()?);
|
||||
}
|
||||
Field::ModelInputScales => {
|
||||
if model_input_scales.is_some() {
|
||||
return Err(de::Error::duplicate_field("model_input_scales"));
|
||||
}
|
||||
model_input_scales = Some(map.next_value()?);
|
||||
}
|
||||
Field::ModuleSizes => {
|
||||
if module_sizes.is_some() {
|
||||
return Err(de::Error::duplicate_field("module_sizes"));
|
||||
}
|
||||
module_sizes = Some(map.next_value()?);
|
||||
}
|
||||
Field::RequiredLookups => {
|
||||
if required_lookups.is_some() {
|
||||
return Err(de::Error::duplicate_field("required_lookups"));
|
||||
}
|
||||
required_lookups = Some(map.next_value()?);
|
||||
}
|
||||
Field::RequiredRangeChecks => {
|
||||
if required_range_checks.is_some() {
|
||||
return Err(de::Error::duplicate_field("required_range_checks"));
|
||||
}
|
||||
required_range_checks = Some(map.next_value()?);
|
||||
}
|
||||
Field::CheckMode => {
|
||||
if check_mode.is_some() {
|
||||
return Err(de::Error::duplicate_field("check_mode"));
|
||||
}
|
||||
check_mode = Some(map.next_value()?);
|
||||
}
|
||||
Field::Version => {
|
||||
if version.is_some() {
|
||||
return Err(de::Error::duplicate_field("version"));
|
||||
}
|
||||
version = Some(map.next_value()?);
|
||||
}
|
||||
Field::NumBlindingFactors => {
|
||||
if num_blinding_factors.is_some() {
|
||||
return Err(de::Error::duplicate_field("num_blinding_factors"));
|
||||
}
|
||||
num_blinding_factors = map.next_value()?;
|
||||
}
|
||||
Field::Timestamp => {
|
||||
if timestamp.is_some() {
|
||||
return Err(de::Error::duplicate_field("timestamp"));
|
||||
}
|
||||
timestamp = Some(map.next_value()?);
|
||||
}
|
||||
Field::InputTypes => {
|
||||
if input_types.is_some() {
|
||||
return Err(de::Error::duplicate_field("input_types"));
|
||||
}
|
||||
input_types = map.next_value()?;
|
||||
}
|
||||
Field::OutputTypes => {
|
||||
if output_types.is_some() {
|
||||
return Err(de::Error::duplicate_field("output_types"));
|
||||
}
|
||||
output_types = map.next_value()?;
|
||||
}
|
||||
// Handle legacy nested struct fields for backwards compatibility
|
||||
Field::DynamicLookupParams => {
|
||||
let legacy_params: DynamicLookupParams = map.next_value()?;
|
||||
if total_dynamic_col_size.is_none() {
|
||||
total_dynamic_col_size = Some(legacy_params.total_dynamic_col_size);
|
||||
}
|
||||
if max_dynamic_input_len.is_none() {
|
||||
max_dynamic_input_len = Some(legacy_params.max_dynamic_input_len);
|
||||
}
|
||||
if num_dynamic_lookups.is_none() {
|
||||
num_dynamic_lookups = Some(legacy_params.num_dynamic_lookups);
|
||||
}
|
||||
}
|
||||
Field::ShuffleParams => {
|
||||
let legacy_params: ShuffleParams = map.next_value()?;
|
||||
if num_shuffles.is_none() {
|
||||
num_shuffles = Some(legacy_params.num_shuffles);
|
||||
}
|
||||
if total_shuffle_col_size.is_none() {
|
||||
total_shuffle_col_size = Some(legacy_params.total_shuffle_col_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let run_args = run_args.ok_or_else(|| de::Error::missing_field("run_args"))?;
|
||||
let num_rows = num_rows.ok_or_else(|| de::Error::missing_field("num_rows"))?;
|
||||
let total_assignments = total_assignments
|
||||
.ok_or_else(|| de::Error::missing_field("total_assignments"))?;
|
||||
let total_const_size =
|
||||
total_const_size.ok_or_else(|| de::Error::missing_field("total_const_size"))?;
|
||||
let model_instance_shapes = model_instance_shapes
|
||||
.ok_or_else(|| de::Error::missing_field("model_instance_shapes"))?;
|
||||
let model_output_scales = model_output_scales
|
||||
.ok_or_else(|| de::Error::missing_field("model_output_scales"))?;
|
||||
let model_input_scales = model_input_scales
|
||||
.ok_or_else(|| de::Error::missing_field("model_input_scales"))?;
|
||||
let module_sizes =
|
||||
module_sizes.ok_or_else(|| de::Error::missing_field("module_sizes"))?;
|
||||
let required_lookups =
|
||||
required_lookups.ok_or_else(|| de::Error::missing_field("required_lookups"))?;
|
||||
let required_range_checks = required_range_checks
|
||||
.ok_or_else(|| de::Error::missing_field("required_range_checks"))?;
|
||||
let check_mode =
|
||||
check_mode.ok_or_else(|| de::Error::missing_field("check_mode"))?;
|
||||
let version = version.ok_or_else(|| de::Error::missing_field("version"))?;
|
||||
|
||||
// Build the nested structs from flattened fields, with defaults if missing
|
||||
let dynamic_lookup_params = DynamicLookupParams {
|
||||
total_dynamic_col_size: total_dynamic_col_size.unwrap_or_default(),
|
||||
max_dynamic_input_len: max_dynamic_input_len.unwrap_or_default(),
|
||||
num_dynamic_lookups: num_dynamic_lookups.unwrap_or_default(),
|
||||
};
|
||||
|
||||
let shuffle_params = ShuffleParams {
|
||||
num_shuffles: num_shuffles.unwrap_or_default(),
|
||||
total_shuffle_col_size: total_shuffle_col_size.unwrap_or_default(),
|
||||
};
|
||||
|
||||
Ok(GraphSettings {
|
||||
run_args,
|
||||
num_rows,
|
||||
total_assignments,
|
||||
total_const_size,
|
||||
dynamic_lookup_params,
|
||||
shuffle_params,
|
||||
einsum_params: einsum_params.unwrap_or_default(),
|
||||
model_instance_shapes,
|
||||
model_output_scales,
|
||||
model_input_scales,
|
||||
module_sizes,
|
||||
required_lookups,
|
||||
required_range_checks,
|
||||
check_mode,
|
||||
version,
|
||||
num_blinding_factors,
|
||||
timestamp,
|
||||
input_types,
|
||||
output_types,
|
||||
})
|
||||
}
|
||||
|
||||
fn visit_seq<V>(self, mut seq: V) -> Result<GraphSettings, V::Error>
|
||||
where
|
||||
V: serde::de::SeqAccess<'de>,
|
||||
{
|
||||
use serde::de::Error;
|
||||
|
||||
// For bincode compatibility, deserialize in the same order as tuple serialization
|
||||
let run_args = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(0, &self))?;
|
||||
let num_rows = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(1, &self))?;
|
||||
let total_assignments = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(2, &self))?;
|
||||
let total_const_size = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(3, &self))?;
|
||||
let dynamic_lookup_params = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(4, &self))?;
|
||||
let shuffle_params = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(5, &self))?;
|
||||
let einsum_params = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(6, &self))?;
|
||||
let model_instance_shapes = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(7, &self))?;
|
||||
let model_output_scales = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(8, &self))?;
|
||||
let model_input_scales = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(9, &self))?;
|
||||
let module_sizes = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(10, &self))?;
|
||||
let required_lookups = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(11, &self))?;
|
||||
let required_range_checks = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(12, &self))?;
|
||||
let check_mode = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(13, &self))?;
|
||||
let version = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(14, &self))?;
|
||||
let num_blinding_factors = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(15, &self))?;
|
||||
let timestamp = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(16, &self))?;
|
||||
let input_types = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(17, &self))?;
|
||||
let output_types = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| Error::invalid_length(18, &self))?;
|
||||
|
||||
Ok(GraphSettings {
|
||||
run_args,
|
||||
num_rows,
|
||||
total_assignments,
|
||||
total_const_size,
|
||||
dynamic_lookup_params,
|
||||
shuffle_params,
|
||||
einsum_params,
|
||||
model_instance_shapes,
|
||||
model_output_scales,
|
||||
model_input_scales,
|
||||
module_sizes,
|
||||
required_lookups,
|
||||
required_range_checks,
|
||||
check_mode,
|
||||
version,
|
||||
num_blinding_factors,
|
||||
timestamp,
|
||||
input_types,
|
||||
output_types,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Universal deserializer that works with both JSON (map) and bincode (tuple)
|
||||
if deserializer.is_human_readable() {
|
||||
// JSON format - use struct/map deserialization with flattened fields
|
||||
const FIELDS: &'static [&'static str] = &[
|
||||
"run_args",
|
||||
"num_rows",
|
||||
"total_assignments",
|
||||
"total_const_size",
|
||||
"total_dynamic_col_size",
|
||||
"max_dynamic_input_len",
|
||||
"num_dynamic_lookups",
|
||||
"num_shuffles",
|
||||
"total_shuffle_col_size",
|
||||
"einsum_params",
|
||||
"model_instance_shapes",
|
||||
"model_output_scales",
|
||||
"model_input_scales",
|
||||
"module_sizes",
|
||||
"required_lookups",
|
||||
"required_range_checks",
|
||||
"check_mode",
|
||||
"version",
|
||||
"num_blinding_factors",
|
||||
"timestamp",
|
||||
"input_types",
|
||||
"output_types",
|
||||
"dynamic_lookup_params",
|
||||
"shuffle_params",
|
||||
];
|
||||
deserializer.deserialize_struct("GraphSettings", FIELDS, GraphSettingsVisitor)
|
||||
} else {
|
||||
// Binary format (bincode) - use tuple deserialization
|
||||
deserializer.deserialize_tuple(19, GraphSettingsVisitor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphSettings {
|
||||
/// Calc the number of rows required for lookup tables
|
||||
pub fn lookup_log_rows(&self) -> u32 {
|
||||
@@ -557,6 +1069,13 @@ impl GraphSettings {
|
||||
.ceil() as u32
|
||||
}
|
||||
|
||||
/// Calculates the logrows for einsum computation area in which there is no column overflow
|
||||
pub fn einsum_logrows(&self) -> u32 {
|
||||
(self.einsum_params.total_einsum_col_size as f64 / self.run_args.num_inner_cols as f64)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
}
|
||||
|
||||
/// calculate the total number of instances
|
||||
pub fn total_instances(&self) -> Vec<usize> {
|
||||
let mut instances: Vec<usize> = self.module_sizes.num_instances();
|
||||
@@ -1112,10 +1631,11 @@ impl GraphCircuit {
|
||||
let model_constraint_logrows = self.settings().model_constraint_logrows_with_blinding();
|
||||
let min_bits = self.table_size_logrows(safe_lookup_range, max_range_size)?;
|
||||
let constants_logrows = self.settings().constants_logrows();
|
||||
let einsum_logrows = self.settings().einsum_logrows();
|
||||
max_logrows = std::cmp::min(
|
||||
max_logrows,
|
||||
// max of the model constraint logrows, min_bits, and the constants logrows is the upper limit
|
||||
*[model_constraint_logrows, min_bits, constants_logrows]
|
||||
*[model_constraint_logrows, min_bits, constants_logrows, einsum_logrows]
|
||||
.iter()
|
||||
.max()
|
||||
.unwrap(),
|
||||
@@ -1674,3 +2194,155 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
/// Tests for the graph module
|
||||
pub mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_graph_settings_serialization_roundtrip() {
|
||||
use crate::{CheckMode, RunArgs};
|
||||
|
||||
// Create a test GraphSettings with nested structs
|
||||
let original = GraphSettings {
|
||||
run_args: RunArgs::default(),
|
||||
num_rows: 1000,
|
||||
total_assignments: 500,
|
||||
total_const_size: 100,
|
||||
dynamic_lookup_params: DynamicLookupParams {
|
||||
total_dynamic_col_size: 42,
|
||||
max_dynamic_input_len: 128,
|
||||
num_dynamic_lookups: 5,
|
||||
},
|
||||
shuffle_params: ShuffleParams {
|
||||
num_shuffles: 3,
|
||||
total_shuffle_col_size: 256,
|
||||
},
|
||||
einsum_params: EinsumParams::default(),
|
||||
model_instance_shapes: vec![vec![1, 2, 3]],
|
||||
model_output_scales: vec![],
|
||||
model_input_scales: vec![],
|
||||
module_sizes: ModuleSizes::default(),
|
||||
required_lookups: vec![],
|
||||
required_range_checks: vec![],
|
||||
check_mode: CheckMode::SAFE,
|
||||
version: "1.0.0".to_string(),
|
||||
num_blinding_factors: Some(5),
|
||||
timestamp: Some(123456789),
|
||||
input_types: None,
|
||||
output_types: None,
|
||||
};
|
||||
|
||||
// Test 1: JSON serialization roundtrip with flattened format
|
||||
let json_str = serde_json::to_string_pretty(&original).unwrap();
|
||||
println!("JSON serialized (flattened):\n{}", json_str);
|
||||
|
||||
// Verify the JSON contains flattened fields
|
||||
assert!(json_str.contains("\"total_dynamic_col_size\": 42"));
|
||||
assert!(json_str.contains("\"max_dynamic_input_len\": 128"));
|
||||
assert!(json_str.contains("\"num_dynamic_lookups\": 5"));
|
||||
assert!(json_str.contains("\"num_shuffles\": 3"));
|
||||
assert!(json_str.contains("\"total_shuffle_col_size\": 256"));
|
||||
|
||||
// Verify the JSON does NOT contain nested structs
|
||||
assert!(!json_str.contains("\"dynamic_lookup_params\""));
|
||||
assert!(!json_str.contains("\"shuffle_params\""));
|
||||
|
||||
// Deserialize from JSON
|
||||
let deserialized: GraphSettings = serde_json::from_str(&json_str).unwrap();
|
||||
assert_eq!(original, deserialized);
|
||||
|
||||
// now do JSON bytes
|
||||
let json_bytes = serde_json::to_vec(&original).unwrap();
|
||||
let deserialized_from_bytes: GraphSettings = serde_json::from_slice(&json_bytes).unwrap();
|
||||
assert_eq!(original, deserialized_from_bytes);
|
||||
|
||||
// Test 2: Bincode serialization roundtrip
|
||||
let bincode_data = bincode::serialize(&original).unwrap();
|
||||
let bincode_deserialized: GraphSettings = bincode::deserialize(&bincode_data).unwrap();
|
||||
assert_eq!(original, bincode_deserialized);
|
||||
|
||||
// Test 3: Backwards compatibility - deserialize old nested format
|
||||
let old_format_json = r#"{
|
||||
"run_args": {
|
||||
"tolerance": {
|
||||
"val": 0.0,
|
||||
"scale": 1.0
|
||||
},
|
||||
"input_scale": 0,
|
||||
"param_scale": 0,
|
||||
"scale_rebase_multiplier": 10,
|
||||
"lookup_range": [
|
||||
0,
|
||||
0
|
||||
],
|
||||
"logrows": 6,
|
||||
"num_inner_cols": 2,
|
||||
"variables": [
|
||||
[
|
||||
"batch_size",
|
||||
1
|
||||
]
|
||||
],
|
||||
"input_visibility": "Private",
|
||||
"output_visibility": "Public",
|
||||
"param_visibility": "Private",
|
||||
"rebase_frac_zero_constants": false,
|
||||
"check_mode": "UNSAFE",
|
||||
"commitment": "KZG",
|
||||
"decomp_base": 128,
|
||||
"decomp_legs": 2,
|
||||
"bounded_log_lookup": false,
|
||||
"ignore_range_check_inputs_outputs": false,
|
||||
"disable_freivalds": false
|
||||
},
|
||||
"num_rows": 236,
|
||||
"total_assignments": 472,
|
||||
"total_const_size": 4,
|
||||
"total_dynamic_col_size": 0,
|
||||
"max_dynamic_input_len": 0,
|
||||
"num_dynamic_lookups": 0,
|
||||
"num_shuffles": 0,
|
||||
"total_shuffle_col_size": 0,
|
||||
"model_instance_shapes": [
|
||||
[
|
||||
1,
|
||||
4
|
||||
]
|
||||
],
|
||||
"model_output_scales": [
|
||||
0
|
||||
],
|
||||
"model_input_scales": [
|
||||
0
|
||||
],
|
||||
"module_sizes": {
|
||||
"polycommit": [],
|
||||
"poseidon": [
|
||||
0,
|
||||
[
|
||||
0
|
||||
]
|
||||
]
|
||||
},
|
||||
"required_lookups": [],
|
||||
"required_range_checks": [
|
||||
[
|
||||
-1,
|
||||
1
|
||||
],
|
||||
[
|
||||
0,
|
||||
127
|
||||
]
|
||||
],
|
||||
"check_mode": "UNSAFE",
|
||||
"version": "0.0.0",
|
||||
"num_blinding_factors": null,
|
||||
"timestamp": 1741214578354
|
||||
}"#;
|
||||
|
||||
let _backwards_compatible: GraphSettings = serde_json::from_str(old_format_json).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ use super::extract_const_quantized_values;
|
||||
use super::node::*;
|
||||
use super::vars::*;
|
||||
use super::GraphSettings;
|
||||
use crate::circuit::einsum::analysis::analyze_einsum_usage;
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::circuit::region::RegionCtx;
|
||||
@@ -37,7 +38,6 @@ use log::{debug, info, trace};
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::collections::BTreeMap;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::fs;
|
||||
@@ -106,6 +106,8 @@ pub struct DummyPassRes {
|
||||
pub dynamic_lookup_params: DynamicLookupParams,
|
||||
/// shuffle parameters
|
||||
pub shuffle_params: ShuffleParams,
|
||||
/// einsum parameters
|
||||
pub einsum_params: crate::graph::EinsumParams,
|
||||
/// num shuffles
|
||||
pub num_shuffles: usize,
|
||||
/// shuffle
|
||||
@@ -592,6 +594,7 @@ impl Model {
|
||||
output_types: Some(self.get_output_types()),
|
||||
dynamic_lookup_params: res.dynamic_lookup_params,
|
||||
shuffle_params: res.shuffle_params,
|
||||
einsum_params: res.einsum_params,
|
||||
total_const_size: res.total_const_size,
|
||||
check_mode,
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
@@ -1047,6 +1050,7 @@ impl Model {
|
||||
|
||||
let lookup_range = settings.run_args.lookup_range;
|
||||
let logrows = settings.run_args.logrows as usize;
|
||||
let num_inner_cols = settings.run_args.num_inner_cols;
|
||||
let required_lookups = settings.required_lookups.clone();
|
||||
let required_range_checks = settings.required_range_checks.clone();
|
||||
|
||||
@@ -1095,6 +1099,24 @@ impl Model {
|
||||
)?;
|
||||
}
|
||||
|
||||
// Configures the circuit to use Freivalds' argument
|
||||
// In the dummy phase, Freivalds' is configured as a default (unless `disable-freivalds` is not enabled),
|
||||
// but if einsum coordinate is 0, it means that all the einsum layouts are dispatched to use only base operations.
|
||||
if settings.einsum_params.total_einsum_col_size > 0 {
|
||||
debug!("configuring einsums...");
|
||||
let used_einsums: HashMap<(usize, String), HashMap<char, usize>> = settings
|
||||
.einsum_params
|
||||
.equations
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, (equation, indices_to_dims))| {
|
||||
((idx, equation.clone()), indices_to_dims.clone())
|
||||
})
|
||||
.collect();
|
||||
let analysis = analyze_einsum_usage(&used_einsums)?;
|
||||
base_gate.configure_einsums(meta, &analysis, num_inner_cols, logrows)?;
|
||||
}
|
||||
|
||||
Ok(base_gate)
|
||||
}
|
||||
|
||||
@@ -1147,17 +1169,30 @@ impl Model {
|
||||
|
||||
let original_constants = constants.clone();
|
||||
|
||||
let challenges = {
|
||||
if let Some(einsum_config) = &config.base.einsums {
|
||||
einsum_config
|
||||
.challenges()?
|
||||
.iter()
|
||||
.map(|c| layouter.get_challenge(*c))
|
||||
.collect_vec()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
};
|
||||
|
||||
let outputs = layouter.assign_region(
|
||||
|| "model",
|
||||
|region| {
|
||||
let mut thread_safe_region = RegionCtx::new_with_constants(
|
||||
let mut thread_safe_region = RegionCtx::new_with_challenges(
|
||||
region,
|
||||
0,
|
||||
run_args.num_inner_cols,
|
||||
run_args.decomp_base,
|
||||
run_args.decomp_legs,
|
||||
original_constants.clone(),
|
||||
challenges.clone(),
|
||||
);
|
||||
thread_safe_region.update_constants(original_constants.clone());
|
||||
// we need to do this as this loop is called multiple times
|
||||
vars.set_instance_idx(instance_idx);
|
||||
|
||||
@@ -1459,8 +1494,16 @@ impl Model {
|
||||
results.insert(*input_idx, vec![inputs[i].clone()]);
|
||||
}
|
||||
|
||||
let mut dummy_config =
|
||||
PolyConfig::dummy(run_args.logrows as usize, run_args.num_inner_cols);
|
||||
let mut dummy_config = {
|
||||
if run_args.disable_freivalds {
|
||||
PolyConfig::dummy_without_freivalds(
|
||||
run_args.logrows as usize,
|
||||
run_args.num_inner_cols,
|
||||
)
|
||||
} else {
|
||||
PolyConfig::dummy(run_args.logrows as usize, run_args.num_inner_cols)
|
||||
}
|
||||
};
|
||||
let mut model_config = ModelConfig {
|
||||
base: dummy_config.clone(),
|
||||
vars: ModelVars::new_dummy(),
|
||||
@@ -1529,6 +1572,10 @@ impl Model {
|
||||
num_shuffles: region.shuffle_index(),
|
||||
total_shuffle_col_size: region.shuffle_col_coord(),
|
||||
},
|
||||
einsum_params: crate::graph::EinsumParams {
|
||||
equations: region.used_einsum_equations(),
|
||||
total_einsum_col_size: region.einsum_col_coord(),
|
||||
},
|
||||
total_const_size: region.total_constants(),
|
||||
lookup_ops: region.used_lookups(),
|
||||
range_checks: region.used_range_checks(),
|
||||
|
||||
@@ -695,8 +695,8 @@ impl Node {
|
||||
opkind = opkind.homogenous_rescale(in_scales.clone())?.into();
|
||||
let mut out_scale = opkind.out_scale(in_scales.clone())?;
|
||||
// rescale the inputs if necessary to get consistent fixed points, we select the largest scale (highest precision)
|
||||
let global_scale = scales.get_max();
|
||||
opkind = RebaseScale::rebase(opkind, global_scale, out_scale, scales.rebase_multiplier);
|
||||
let rebase_scale = scales.get_rebase_scale();
|
||||
opkind = RebaseScale::rebase(opkind, rebase_scale, out_scale, scales.rebase_multiplier);
|
||||
|
||||
out_scale = opkind.out_scale(in_scales)?;
|
||||
|
||||
|
||||
@@ -8,9 +8,7 @@ use halo2curves::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
use log::debug;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::{
|
||||
exceptions::PyValueError, FromPyObject, IntoPy, PyObject, PyResult, Python, ToPyObject,
|
||||
};
|
||||
use pyo3::{exceptions::PyValueError, FromPyObject, IntoPyObject, PyResult, Python};
|
||||
use serde::{Deserialize, Serialize};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tosubcommand::ToFlags;
|
||||
@@ -107,27 +105,33 @@ impl<'a> From<&'a str> for Visibility {
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl IntoPy<PyObject> for Visibility {
|
||||
impl<'py> IntoPyObject<'py> for Visibility {
|
||||
type Target = pyo3::PyAny;
|
||||
type Output = pyo3::Bound<'py, Self::Target>;
|
||||
type Error = pyo3::PyErr;
|
||||
|
||||
/// Converts Visibility to Python object
|
||||
fn into_py(self, py: Python) -> PyObject {
|
||||
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
|
||||
match self {
|
||||
Visibility::Private => "private".to_object(py),
|
||||
Visibility::Public => "public".to_object(py),
|
||||
Visibility::Fixed => "fixed".to_object(py),
|
||||
Visibility::KZGCommit => "polycommit".to_object(py),
|
||||
Visibility::Private => Ok("private".into_pyobject(py)?.into_any()),
|
||||
Visibility::Public => Ok("public".into_pyobject(py)?.into_any()),
|
||||
Visibility::Fixed => Ok("fixed".into_pyobject(py)?.into_any()),
|
||||
Visibility::KZGCommit => Ok("polycommit".into_pyobject(py)?.into_any()),
|
||||
Visibility::Hashed {
|
||||
hash_is_public,
|
||||
outlets,
|
||||
} => {
|
||||
if hash_is_public {
|
||||
"hashed/public".to_object(py)
|
||||
Ok("hashed/public".into_pyobject(py)?.into_any())
|
||||
} else {
|
||||
let outlets = outlets
|
||||
.iter()
|
||||
.map(|o| o.to_string())
|
||||
.collect_vec()
|
||||
.join(",");
|
||||
format!("hashed/private/{}", outlets).to_object(py)
|
||||
Ok(format!("hashed/private/{}", outlets)
|
||||
.into_pyobject(py)?
|
||||
.into_any())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -246,6 +250,8 @@ pub struct VarScales {
|
||||
pub params: crate::Scale,
|
||||
/// Multiplier for scale rebasing
|
||||
pub rebase_multiplier: u32,
|
||||
/// rebase scale factor (optional). if None, we rebase to the max of input_scale and param_scale
|
||||
pub rebase_scale: Option<crate::Scale>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for VarScales {
|
||||
@@ -265,11 +271,21 @@ impl VarScales {
|
||||
std::cmp::min(self.input, self.params)
|
||||
}
|
||||
|
||||
/// Returns the scale to rebase to, if specified
|
||||
pub fn get_rebase_scale(&self) -> crate::Scale {
|
||||
if let Some(rebase_scale) = self.rebase_scale {
|
||||
rebase_scale
|
||||
} else {
|
||||
self.get_max()
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates VarScales from runtime arguments
|
||||
pub fn from_args(args: &RunArgs) -> Self {
|
||||
Self {
|
||||
input: args.input_scale,
|
||||
params: args.param_scale,
|
||||
rebase_scale: args.rebase_scale,
|
||||
rebase_multiplier: args.scale_rebase_multiplier,
|
||||
}
|
||||
}
|
||||
|
||||
17
src/lib.rs
17
src/lib.rs
@@ -288,6 +288,10 @@ pub struct RunArgs {
|
||||
/// Higher values provide more precision but increase circuit complexity
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "7", value_hint = clap::ValueHint::Other))]
|
||||
pub param_scale: Scale,
|
||||
/// Scale to rebase to when the input scale exceeds rebase_scale * multiplier. If None we rebase to the max of input_scale and param_scale
|
||||
/// This is an advanced parameter that should be used with caution
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, required = false, value_hint = clap::ValueHint::Other))]
|
||||
pub rebase_scale: Option<Scale>,
|
||||
/// Scale rebase threshold multiplier
|
||||
/// When scale exceeds input_scale * multiplier, it is rebased to input_scale
|
||||
/// Advanced parameter that should be used with caution
|
||||
@@ -359,6 +363,13 @@ pub struct RunArgs {
|
||||
/// Optional override for epsilon value
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long))]
|
||||
pub epsilon: Option<f64>,
|
||||
/// Forcefully disable using Freivalds' argument in einsum operations
|
||||
/// Freivalds' argument can make verifier bigger, so this option is useful when
|
||||
/// the verifier size is a concern
|
||||
/// Without this option the circuit layouter will always try to use Freivalds' argument
|
||||
/// when it is good to do so
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long))]
|
||||
pub disable_freivalds: bool,
|
||||
}
|
||||
|
||||
impl RunArgs {
|
||||
@@ -378,14 +389,15 @@ impl Default for RunArgs {
|
||||
bounded_log_lookup: false,
|
||||
input_scale: 7,
|
||||
param_scale: 7,
|
||||
rebase_scale: None,
|
||||
scale_rebase_multiplier: 1,
|
||||
lookup_range: (-32768, 32768),
|
||||
logrows: 17,
|
||||
num_inner_cols: 2,
|
||||
variables: vec![("batch_size".to_string(), 1)],
|
||||
input_visibility: Visibility::Private,
|
||||
input_visibility: Visibility::Public,
|
||||
output_visibility: Visibility::Public,
|
||||
param_visibility: Visibility::Private,
|
||||
param_visibility: Visibility::Fixed,
|
||||
rebase_frac_zero_constants: false,
|
||||
check_mode: CheckMode::UNSAFE,
|
||||
commitment: None,
|
||||
@@ -393,6 +405,7 @@ impl Default for RunArgs {
|
||||
decomp_legs: 2,
|
||||
ignore_range_check_inputs_outputs: false,
|
||||
epsilon: None,
|
||||
disable_freivalds: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,12 +187,17 @@ impl From<ProofType> for StrategyType {
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for ProofType {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
match self {
|
||||
ProofType::Single => "Single".to_object(py),
|
||||
ProofType::ForAggr => "ForAggr".to_object(py),
|
||||
}
|
||||
impl<'py> pyo3::IntoPyObject<'py> for ProofType {
|
||||
type Target = pyo3::PyAny;
|
||||
type Output = pyo3::Bound<'py, Self::Target>;
|
||||
type Error = pyo3::PyErr;
|
||||
|
||||
fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
|
||||
let result = match self {
|
||||
ProofType::Single => "Single",
|
||||
ProofType::ForAggr => "ForAggr",
|
||||
};
|
||||
Ok(result.into_pyobject(py)?.into_any())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -245,12 +250,17 @@ impl std::fmt::Display for StrategyType {
|
||||
}
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts StrategyType into a PyObject (Required for StrategyType to be compatible with Python)
|
||||
impl pyo3::IntoPy<PyObject> for StrategyType {
|
||||
fn into_py(self, py: Python) -> PyObject {
|
||||
match self {
|
||||
StrategyType::Single => "single".to_object(py),
|
||||
StrategyType::Accum => "accum".to_object(py),
|
||||
}
|
||||
impl<'py> pyo3::IntoPyObject<'py> for StrategyType {
|
||||
type Target = pyo3::PyAny;
|
||||
type Output = pyo3::Bound<'py, Self::Target>;
|
||||
type Error = pyo3::PyErr;
|
||||
|
||||
fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
|
||||
let result = match self {
|
||||
StrategyType::Single => "single",
|
||||
StrategyType::Accum => "accum",
|
||||
};
|
||||
Ok(result.into_pyobject(py)?.into_any())
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "python-bindings")]
|
||||
@@ -304,16 +314,6 @@ impl ToFlags for TranscriptType {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for TranscriptType {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
match self {
|
||||
TranscriptType::Poseidon => "Poseidon".to_object(py),
|
||||
TranscriptType::EVM => "EVM".to_object(py),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
///
|
||||
pub fn g1affine_to_pydict(g1affine_dict: &pyo3::Bound<'_, PyDict>, g1affine: &G1Affine) {
|
||||
@@ -401,14 +401,19 @@ where
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::{types::PyDict, PyObject, Python, ToPyObject};
|
||||
use pyo3::{types::PyDict, IntoPyObject, Python};
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl<F: PrimeField + SerdeObject + Serialize, C: CurveAffine + Serialize> ToPyObject for Snark<F, C>
|
||||
impl<'py, F: PrimeField + SerdeObject + Serialize, C: CurveAffine + Serialize> IntoPyObject<'py>
|
||||
for Snark<F, C>
|
||||
where
|
||||
C::Scalar: Serialize + DeserializeOwned,
|
||||
C::ScalarExt: Serialize + DeserializeOwned,
|
||||
{
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
type Target = pyo3::PyAny;
|
||||
type Output = pyo3::Bound<'py, Self::Target>;
|
||||
type Error = pyo3::PyErr;
|
||||
|
||||
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
|
||||
let dict = PyDict::new(py);
|
||||
let field_elems: Vec<Vec<String>> = self
|
||||
.instances
|
||||
@@ -418,9 +423,9 @@ where
|
||||
dict.set_item("instances", field_elems).unwrap();
|
||||
let hex_proof = hex::encode(&self.proof);
|
||||
dict.set_item("proof", format!("0x{}", hex_proof)).unwrap();
|
||||
dict.set_item("transcript_type", self.transcript_type.to_object(py))
|
||||
dict.set_item("transcript_type", self.transcript_type.into_pyobject(py)?)
|
||||
.unwrap();
|
||||
dict.to_object(py)
|
||||
Ok(dict.into_any())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -480,6 +480,13 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
self[index].clone()
|
||||
}
|
||||
|
||||
/// Extracts a single value from the tensor
|
||||
pub fn get_scalar(&self) -> T {
|
||||
assert!(self.inner.len() == 1);
|
||||
assert!(self.dims.iter().all(|dim| *dim == 1));
|
||||
self.inner[0].clone()
|
||||
}
|
||||
|
||||
/// Get a mutable array index from rows / columns indices.
|
||||
///
|
||||
/// ```
|
||||
@@ -901,6 +908,22 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// remove axes that have dimensions 1
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 2]).unwrap();
|
||||
/// let b = a.remove_trivial_axes().unwrap();
|
||||
/// assert_eq!(b, expected);
|
||||
/// ```
|
||||
pub fn remove_trivial_axes(&self) -> Result<Self, TensorError> {
|
||||
let mut result = self.clone();
|
||||
let new_dims: Vec<_> = self.dims.iter().copied().filter(|dim| *dim > 1).collect();
|
||||
result.reshape(&new_dims)?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Move axis of the tensor
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
|
||||
@@ -5,6 +5,7 @@ use crate::{
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use maybe_rayon::{iter::ParallelIterator, prelude::IntoParallelRefIterator};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
pub use std::ops::{Add, Mul, Neg, Sub};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
|
||||
@@ -329,7 +330,7 @@ pub fn resize<T: TensorType + Send + Sync>(
|
||||
|
||||
let cartesian_coord: Vec<Vec<usize>> = new_shape
|
||||
.iter()
|
||||
.map(|d| (0..*d))
|
||||
.map(|d| 0..*d)
|
||||
.multi_cartesian_product()
|
||||
.collect();
|
||||
|
||||
@@ -1218,7 +1219,7 @@ pub fn intercalate_values<T: TensorType>(
|
||||
let cartesian_coord = output
|
||||
.dims()
|
||||
.iter()
|
||||
.map(|d| (0..*d))
|
||||
.map(|d| 0..*d)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
@@ -1263,7 +1264,7 @@ pub fn one_hot(
|
||||
let cartesian_coord = output
|
||||
.dims()
|
||||
.iter()
|
||||
.map(|d| (0..*d))
|
||||
.map(|d| 0..*d)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
@@ -1341,7 +1342,7 @@ pub fn pad<T: TensorType>(
|
||||
let cartesian_coord = image
|
||||
.dims()
|
||||
.iter()
|
||||
.map(|d| (0..*d))
|
||||
.map(|d| 0..*d)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
@@ -2396,6 +2397,8 @@ pub mod nonlinearities {
|
||||
|
||||
/// Ops that return the transcript i.e intermediate calcs of an op
|
||||
pub mod accumulated {
|
||||
use maybe_rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator};
|
||||
|
||||
use super::*;
|
||||
|
||||
/// Dot product of two tensors.
|
||||
@@ -2523,4 +2526,327 @@ pub mod accumulated {
|
||||
|
||||
Ok(transcript)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn row_major_strides(dims: &[usize]) -> Vec<usize> {
|
||||
let mut s = vec![0; dims.len()];
|
||||
let mut acc = 1;
|
||||
for (i, &d) in dims.iter().enumerate().rev() {
|
||||
s[i] = acc;
|
||||
acc *= d;
|
||||
}
|
||||
s
|
||||
}
|
||||
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::accumulated::einsum;
|
||||
///
|
||||
/// // matmul case
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 3, 2, 1, 1, 1]),
|
||||
/// &[3, 2],
|
||||
/// ).unwrap();
|
||||
/// let (result, _) = einsum::<IntegerRep>("ij,jk->ik", &[&x, &k]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[8, 9, 5, 5]), &[2, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // element wise multiplication
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
|
||||
/// &[3, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]),
|
||||
/// &[3, 3],
|
||||
/// ).unwrap();
|
||||
/// let (result, _) = einsum::<IntegerRep>("ij,ij->ij", &[&x, &k]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 4, 9, 2, 6, 12, 3, 8, 15]), &[3, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
///
|
||||
/// // dot product of A with the transpose of B.
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
|
||||
/// &[3, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]),
|
||||
/// &[3, 3],
|
||||
/// ).unwrap();
|
||||
/// let (result, _) = einsum::<IntegerRep>("ik,jk->ij", &[&x, &k]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[14, 14, 14, 20, 20, 20, 26, 26, 26]), &[3, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // dot product
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
|
||||
/// &[3, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]),
|
||||
/// &[3, 3],
|
||||
/// ).unwrap();
|
||||
/// let (result, _) = einsum::<IntegerRep>("ik,ik->i", &[&x, &k]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[14, 20, 26]), &[3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
///
|
||||
/// // dot product
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3]),
|
||||
/// &[3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3]),
|
||||
/// &[3],
|
||||
/// ).unwrap();
|
||||
/// let (result, _) = einsum::<IntegerRep>("i,i->", &[&x, &k]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[14]), &[1]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
///
|
||||
/// // wut ?
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]),
|
||||
/// &[3, 3, 2],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[4, 5, 7, 8]),
|
||||
/// &[2, 2],
|
||||
/// ).unwrap();
|
||||
/// let (result, _) = einsum::<IntegerRep>("anm,bm->ba", &[&x, &k]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[68, 80, 95, 113, 134, 158]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // wutttttt ?
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]),
|
||||
/// &[3, 3, 2],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[4, 5, 7, 8]),
|
||||
/// &[2, 2],
|
||||
/// ).unwrap();
|
||||
/// let z = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[4, 5, 7, 8, 9, 9]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
///
|
||||
/// let (result, _) = einsum::<IntegerRep>("bn,anm,bm->ba", &[&z, &x, &k]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[390, 414, 534, 994, 1153, 1384]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
///
|
||||
/// // contraction with a single common axis
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]),
|
||||
/// &[3, 3, 2],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[4, 5, 7, 8]),
|
||||
/// &[2, 2],
|
||||
/// ).unwrap();
|
||||
/// let (result, _) = einsum::<IntegerRep>("abc,cd->", &[&x, &k]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[648]), &[1]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // contraction with no common axes (outer product)
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]),
|
||||
/// &[3, 3, 2],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[4, 5, 7, 8]),
|
||||
/// &[2, 2],
|
||||
/// ).unwrap();
|
||||
/// let (result, _) = einsum::<IntegerRep>("abc,ed->", &[&x, &k]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1296]), &[1]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // trivial axes mapping
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[4, 5, 7, 8]),
|
||||
/// &[2, 2],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[4, 5]),
|
||||
/// &[2],
|
||||
/// ).unwrap();
|
||||
///
|
||||
/// let (result, _) = einsum::<IntegerRep>("mk,k->m", &[&x, &k]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[41, 68]), &[2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let (result, _) = einsum::<IntegerRep>("mk,k->mn", &[&x, &k]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[41, 68]), &[2, 1]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[0, 0, 0, 3]),
|
||||
/// &[1, 4],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[213, 227, 74, 77]),
|
||||
/// &[4],
|
||||
/// ).unwrap();
|
||||
///
|
||||
/// let (result, _) = einsum::<IntegerRep>("mk,k->ma", &[&x, &k]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[231]), &[1, 1]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// // subtle difference
|
||||
/// let (result, _) = einsum::<IntegerRep>("mk,n->ma", &[&x, &k]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1773]), &[1, 1]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn einsum<T>(
|
||||
equation: &str,
|
||||
input_tensors: &[&Tensor<T>],
|
||||
) -> Result<(Tensor<T>, HashMap<char, usize>), TensorError>
|
||||
where
|
||||
T: Clone + TensorType + Mul<Output = T> + Add<Output = T> + Send + Sync,
|
||||
{
|
||||
let (input_exprs, output_expr) = equation.split_once("->").unwrap();
|
||||
let input_exprs: Vec<&str> = input_exprs.split(',').collect();
|
||||
assert_eq!(input_exprs.len(), input_tensors.len());
|
||||
|
||||
let mut dim_of: HashMap<char, usize> = HashMap::new();
|
||||
for (input_expr, t) in input_exprs.iter().zip(input_tensors.iter()) {
|
||||
for (c, &d) in input_expr.chars().zip(t.dims().iter()) {
|
||||
let e = dim_of.entry(c).or_insert(d);
|
||||
debug_assert!((*e == d) || (*e == 1) || (d == 1));
|
||||
*e = (*e).max(d);
|
||||
}
|
||||
}
|
||||
|
||||
// Output dims
|
||||
let out_idx: Vec<char> = output_expr.chars().collect();
|
||||
let out_dims: Vec<usize> = out_idx
|
||||
.iter()
|
||||
.map(|c| *dim_of.get(c).unwrap_or(&1))
|
||||
.collect();
|
||||
|
||||
// Reduction indices
|
||||
let all_idx: HashSet<char> = dim_of.keys().copied().collect();
|
||||
let out_set: HashSet<char> = out_idx.iter().copied().collect();
|
||||
let red_idx: Vec<char> = all_idx.difference(&out_set).copied().collect();
|
||||
let red_dims: Vec<usize> = red_idx.iter().map(|c| dim_of[c]).collect();
|
||||
|
||||
// Fast index->pos
|
||||
let out_pos: HashMap<char, usize> =
|
||||
out_idx.iter().enumerate().map(|(i, &c)| (c, i)).collect();
|
||||
let red_pos: HashMap<char, usize> =
|
||||
red_idx.iter().enumerate().map(|(i, &c)| (c, i)).collect();
|
||||
|
||||
// Precompute strides per input and contributions
|
||||
struct Contrib {
|
||||
out_stride: Vec<usize>,
|
||||
red_stride: Vec<usize>,
|
||||
}
|
||||
let contribs: Vec<Contrib> = input_exprs
|
||||
.iter()
|
||||
.zip(input_tensors.iter())
|
||||
.map(|(expr, t)| {
|
||||
let dims = t.dims().to_vec();
|
||||
let strides = row_major_strides(&dims);
|
||||
let mut out_stride = vec![0; out_idx.len()];
|
||||
let mut red_stride = vec![0; red_idx.len()];
|
||||
for (ax, (c, &d)) in expr.chars().zip(dims.iter()).enumerate() {
|
||||
let s = if d == 1 { 0 } else { strides[ax] };
|
||||
if let Some(&p) = out_pos.get(&c) {
|
||||
out_stride[p] = s;
|
||||
} else if let Some(&q) = red_pos.get(&c) {
|
||||
red_stride[q] = s;
|
||||
}
|
||||
}
|
||||
Contrib {
|
||||
out_stride,
|
||||
red_stride,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Prepare output buffer
|
||||
let mut out = if out_dims.is_empty() {
|
||||
Tensor::<T>::new(None, &[1])?
|
||||
} else {
|
||||
Tensor::<T>::new(None, &out_dims)?
|
||||
};
|
||||
|
||||
let out_rank = out_dims.len();
|
||||
let red_rank = red_dims.len();
|
||||
|
||||
// Materialize output elements one by one
|
||||
out.par_iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(out_linear_coord, out)| {
|
||||
let mut out_index = vec![0usize; out_rank];
|
||||
{
|
||||
let mut x = out_linear_coord;
|
||||
for i in (0..out_rank).rev() {
|
||||
let d = out_dims[i];
|
||||
out_index[i] = x % d;
|
||||
x /= d;
|
||||
}
|
||||
}
|
||||
|
||||
// Base offset per input from output coordinates
|
||||
let mut base_off = vec![0usize; input_tensors.len()];
|
||||
for (i, c) in contribs.iter().enumerate() {
|
||||
let mut off = 0usize;
|
||||
for p in 0..out_rank {
|
||||
off += out_index[p] * c.out_stride[p];
|
||||
}
|
||||
base_off[i] = off;
|
||||
}
|
||||
|
||||
let mut acc = T::zero().unwrap();
|
||||
|
||||
if red_rank == 0 {
|
||||
// No reduction -> just multiply corresponding elements
|
||||
let mut prod = T::one().unwrap();
|
||||
for (i, t) in input_tensors.iter().enumerate() {
|
||||
let val = t.get_flat_index(base_off[i]);
|
||||
prod = prod * val;
|
||||
}
|
||||
acc = acc + prod;
|
||||
} else {
|
||||
// Iterate over all reduction coords
|
||||
let red_size = red_dims.iter().product::<usize>();
|
||||
let mut red_index = vec![0usize; red_rank];
|
||||
for red_linear_coord in 0..red_size {
|
||||
{
|
||||
let mut x = red_linear_coord;
|
||||
for q in (0..red_rank).rev() {
|
||||
let d = red_dims[q];
|
||||
red_index[q] = x % d;
|
||||
x /= d;
|
||||
}
|
||||
}
|
||||
let mut prod = T::one().unwrap();
|
||||
for (i, (t, c)) in input_tensors.iter().zip(contribs.iter()).enumerate() {
|
||||
let mut off = base_off[i];
|
||||
for q in 0..red_rank {
|
||||
off += red_index[q] * c.red_stride[q];
|
||||
}
|
||||
let val = t.get_flat_index(off);
|
||||
prod = prod * val;
|
||||
}
|
||||
acc = acc + prod;
|
||||
}
|
||||
}
|
||||
|
||||
// write result
|
||||
*out = acc;
|
||||
});
|
||||
Ok((out, dim_of))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -940,6 +940,22 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// remove axes that have dimensions 1
|
||||
pub fn remove_trivial_axes(&mut self) -> Result<(), TensorError> {
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: d, ..
|
||||
} => {
|
||||
*v = v.remove_trivial_axes()?;
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
return Err(TensorError::WrongMethod);
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Takes a slice of the tensor along a given axis
|
||||
///
|
||||
/// # Arguments
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use halo2_proofs::plonk::SecondPhase;
|
||||
use log::{debug, error, warn};
|
||||
|
||||
use crate::circuit::{region::ConstantsMap, CheckMode};
|
||||
@@ -152,6 +153,52 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new VarTensor::Advice with standard (blinded) columns, used when
|
||||
/// the values need to be hidden in the proof.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `cs` - The constraint system to create columns in
|
||||
/// * `logrows` - Log base 2 of the total number of rows
|
||||
/// * `num_inner_cols` - Number of columns in each inner block
|
||||
/// * `capacity` - Total number of advice cells to allocate
|
||||
///
|
||||
/// # Returns
|
||||
/// A new VarTensor::Advice in SecondPhase with blinded columns enabled for equality constraints
|
||||
pub fn new_advice_in_second_phase<F: PrimeField>(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
logrows: usize,
|
||||
num_inner_cols: usize,
|
||||
capacity: usize,
|
||||
) -> Self {
|
||||
let max_rows = Self::max_rows(cs, logrows);
|
||||
let max_assignments = Self::max_rows(cs, logrows) * num_inner_cols;
|
||||
|
||||
let mut modulo = (capacity / max_assignments) + 1;
|
||||
// we add a buffer for duplicated rows (we get at most 1 duplicated row per column)
|
||||
modulo = ((capacity + modulo) / max_assignments) + 1;
|
||||
let mut advices = vec![];
|
||||
|
||||
if modulo > 1 {
|
||||
debug!("using column duplication for {} advice blocks", modulo - 1);
|
||||
}
|
||||
|
||||
for _ in 0..modulo {
|
||||
let mut inner = vec![];
|
||||
for _ in 0..num_inner_cols {
|
||||
let col = cs.advice_column_in(SecondPhase);
|
||||
cs.enable_equality(col);
|
||||
inner.push(col);
|
||||
}
|
||||
advices.push(inner);
|
||||
}
|
||||
|
||||
VarTensor::Advice {
|
||||
inner: advices,
|
||||
num_inner_cols,
|
||||
col_size: max_rows,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initializes fixed columns in the constraint system to support the VarTensor::Advice
|
||||
/// Fixed columns are used for constant values that are known at circuit creation time.
|
||||
///
|
||||
@@ -270,7 +317,7 @@ impl VarTensor {
|
||||
/// # Returns
|
||||
/// A tuple of (block_index, column_index, row_index)
|
||||
pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize, usize) {
|
||||
// x indexes over blocks of size num_inner_cols
|
||||
// x (block idx) indexes over blocks of size num_inner_cols
|
||||
let x = linear_coord / self.block_size();
|
||||
// y indexes over the cols inside a block
|
||||
let y = linear_coord % self.num_inner_cols();
|
||||
@@ -519,7 +566,7 @@ impl VarTensor {
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
|
||||
>(
|
||||
&self,
|
||||
row: usize,
|
||||
_row: usize,
|
||||
offset: usize,
|
||||
values: &ValTensor<F>,
|
||||
single_inner_col: bool,
|
||||
@@ -545,7 +592,7 @@ impl VarTensor {
|
||||
self.num_inner_cols()
|
||||
};
|
||||
|
||||
let duplication_offset = if single_inner_col { row } else { offset };
|
||||
let (_, _, duplication_offset) = self.cartesian_coord(offset);
|
||||
|
||||
// duplicates every nth element to adjust for column overflow
|
||||
let mut res: ValTensor<F> = v
|
||||
@@ -651,7 +698,7 @@ impl VarTensor {
|
||||
>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
row: usize,
|
||||
_row: usize,
|
||||
offset: usize,
|
||||
values: &ValTensor<F>,
|
||||
check_mode: &CheckMode,
|
||||
@@ -669,7 +716,7 @@ impl VarTensor {
|
||||
ValTensor::Value { inner: v, dims, .. } => {
|
||||
let duplication_freq = self.col_size();
|
||||
let num_repeats = 1;
|
||||
let duplication_offset = row;
|
||||
let (_, _, duplication_offset) = self.cartesian_coord(offset);
|
||||
|
||||
// duplicates every nth element to adjust for column overflow
|
||||
let v = v
|
||||
|
||||
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -28,7 +28,8 @@
|
||||
"decomp_base": 128,
|
||||
"decomp_legs": 2,
|
||||
"bounded_log_lookup": false,
|
||||
"ignore_range_check_inputs_outputs": false
|
||||
"ignore_range_check_inputs_outputs": false,
|
||||
"disable_freivalds": false
|
||||
},
|
||||
"num_rows": 236,
|
||||
"total_assignments": 472,
|
||||
@@ -36,6 +37,10 @@
|
||||
"total_dynamic_col_size": 0,
|
||||
"max_dynamic_input_len": 0,
|
||||
"num_dynamic_lookups": 0,
|
||||
"einsum_params": {
|
||||
"equations": [],
|
||||
"total_einsum_col_size": 0
|
||||
},
|
||||
"num_shuffles": 0,
|
||||
"total_shuffle_col_size": 0,
|
||||
"model_instance_shapes": [
|
||||
|
||||
Binary file not shown.
@@ -163,9 +163,18 @@ mod native_tests {
|
||||
let data = GraphData::from_path(format!("{}/{}/input.json", test_dir, test).into())
|
||||
.expect("failed to load input data");
|
||||
|
||||
let duplicated_input_data = data.input_data;
|
||||
let duplicated_input_data = data
|
||||
.input_data
|
||||
.into_iter()
|
||||
.map(|input| {
|
||||
(0..num_batches)
|
||||
.map(move |_| input.clone())
|
||||
.flatten()
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let duplicated_data = GraphData::new(duplicated_input_data.into());
|
||||
let duplicated_data = GraphData::new(duplicated_input_data);
|
||||
|
||||
let res =
|
||||
duplicated_data.save(format!("{}/{}/input.json", test_dir, output_dir).into());
|
||||
@@ -198,7 +207,7 @@ mod native_tests {
|
||||
"1l_tiny_div",
|
||||
];
|
||||
|
||||
const TESTS: [&str; 99] = [
|
||||
const TESTS: [&str; 100] = [
|
||||
"1l_mlp", //0
|
||||
"1l_slice", //1
|
||||
"1l_concat", //2
|
||||
@@ -302,65 +311,66 @@ mod native_tests {
|
||||
"exp", // 96
|
||||
"general_exp", // 97
|
||||
"integer_div", // 98
|
||||
"large_mlp", // 99
|
||||
];
|
||||
|
||||
const WASM_TESTS: [&str; 46] = [
|
||||
"1l_mlp",
|
||||
"1l_slice",
|
||||
"1l_concat",
|
||||
"1l_flatten",
|
||||
const WASM_TESTS: [&str; 44] = [
|
||||
"1l_mlp", // 0
|
||||
"1l_slice", // 1
|
||||
"1l_concat", // 2
|
||||
"1l_flatten", // 3
|
||||
// "1l_average",
|
||||
"1l_div",
|
||||
"1l_pad",
|
||||
"1l_reshape",
|
||||
"1l_eltwise_div",
|
||||
"1l_sigmoid",
|
||||
"1l_sqrt",
|
||||
"1l_softmax",
|
||||
"1l_div", // 4
|
||||
"1l_pad", // 5
|
||||
"1l_reshape", // 6
|
||||
"1l_eltwise_div", // 7
|
||||
"1l_sigmoid", // 8
|
||||
"1l_sqrt", // 9
|
||||
"1l_softmax", // 10
|
||||
// "1l_instance_norm",
|
||||
"1l_batch_norm",
|
||||
"1l_prelu",
|
||||
"1l_leakyrelu",
|
||||
"1l_gelu_noappx",
|
||||
"1l_batch_norm", // 11
|
||||
"1l_prelu", // 12
|
||||
"1l_leakyrelu", // 13
|
||||
"1l_gelu_noappx", // 14
|
||||
// "1l_gelu_tanh_appx",
|
||||
"1l_relu",
|
||||
"1l_downsample",
|
||||
"1l_tanh",
|
||||
"2l_relu_sigmoid_small",
|
||||
"2l_relu_fc",
|
||||
"2l_relu_small",
|
||||
"2l_relu_sigmoid",
|
||||
"1l_conv",
|
||||
"2l_sigmoid_small",
|
||||
"2l_relu_sigmoid_conv",
|
||||
"3l_relu_conv_fc",
|
||||
"4l_relu_conv_fc",
|
||||
"1l_erf",
|
||||
"1l_var",
|
||||
"1l_elu",
|
||||
"min",
|
||||
"max",
|
||||
"1l_max_pool",
|
||||
"1l_conv_transpose",
|
||||
"1l_upsample",
|
||||
"1l_identity",
|
||||
"1l_relu", // 15
|
||||
"1l_downsample", // 16
|
||||
"1l_tanh", // 17
|
||||
"2l_relu_sigmoid_small", // 18
|
||||
"2l_relu_fc", // 19
|
||||
"2l_relu_small", // 20
|
||||
"2l_relu_sigmoid", // 21
|
||||
"1l_conv", // 22
|
||||
"2l_sigmoid_small", // 23
|
||||
"2l_relu_sigmoid_conv", // 24
|
||||
// "3l_relu_conv_fc",
|
||||
// "4l_relu_conv_fc",
|
||||
"1l_erf", // 25
|
||||
"1l_var", // 26
|
||||
"1l_elu", // 27
|
||||
"min", // 28
|
||||
"max", // 29
|
||||
"1l_max_pool", // 30
|
||||
"1l_conv_transpose", // 31
|
||||
"1l_upsample", // 32
|
||||
"1l_identity", // 33
|
||||
// "idolmodel",
|
||||
"trig",
|
||||
"prelu_gmm",
|
||||
"lstm",
|
||||
"rnn",
|
||||
"quantize_dequantize",
|
||||
"1l_where",
|
||||
"boolean",
|
||||
"boolean_identity",
|
||||
"gradient_boosted_trees",
|
||||
"1l_topk",
|
||||
// "xgboost",
|
||||
// "lightgbm",
|
||||
// "hummingbird_decision_tree",
|
||||
"trig", // 34
|
||||
"prelu_gmm", // 35
|
||||
"lstm", // 36
|
||||
"rnn", // 37
|
||||
"quantize_dequantize", // 38
|
||||
"1l_where", // 39
|
||||
"boolean", // 40
|
||||
"boolean_identity", // 41
|
||||
"gradient_boosted_trees", // 42
|
||||
"1l_topk", // 43
|
||||
// "xgboost",
|
||||
// "lightgbm",
|
||||
// "hummingbird_decision_tree",
|
||||
];
|
||||
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
#[cfg(not(feature = "gpu-accelerated"))]
|
||||
const TESTS_AGGR: [&str; 21] = [
|
||||
"1l_mlp",
|
||||
"1l_flatten",
|
||||
@@ -385,7 +395,7 @@ mod native_tests {
|
||||
"1l_max_pool",
|
||||
];
|
||||
|
||||
#[cfg(feature = "icicle")]
|
||||
#[cfg(feature = "gpu-accelerated")]
|
||||
const TESTS_AGGR: [&str; 3] = ["1l_mlp", "1l_flatten", "1l_average"];
|
||||
|
||||
const TESTS_EVM: [&str; 23] = [
|
||||
@@ -445,11 +455,12 @@ mod native_tests {
|
||||
use crate::native_tests::TESTS_AGGR;
|
||||
use test_case::test_case;
|
||||
use crate::native_tests::aggr_prove_and_verify;
|
||||
#[cfg(not(feature = "gpu-accelerated"))]
|
||||
use crate::native_tests::kzg_aggr_mock_prove_and_verify;
|
||||
use tempdir::TempDir;
|
||||
use ezkl::Commitments;
|
||||
|
||||
#[cfg(not(feature="icicle"))]
|
||||
#[cfg(not(feature="gpu-accelerated"))]
|
||||
seq!(N in 0..=20 {
|
||||
|
||||
#(#[test_case(TESTS_AGGR[N])])*
|
||||
@@ -483,7 +494,7 @@ mod native_tests {
|
||||
|
||||
});
|
||||
|
||||
#[cfg(feature="icicle")]
|
||||
#[cfg(feature="gpu-accelerated")]
|
||||
seq!(N in 0..=2 {
|
||||
#(#[test_case(TESTS_AGGR[N])])*
|
||||
fn kzg_aggr_prove_and_verify_(test: &str) {
|
||||
@@ -511,8 +522,8 @@ mod native_tests {
|
||||
use crate::native_tests::mock;
|
||||
use crate::native_tests::accuracy_measurement;
|
||||
use crate::native_tests::prove_and_verify;
|
||||
use crate::native_tests::run_js_tests;
|
||||
use crate::native_tests::render_circuit;
|
||||
// use crate::native_tests::run_js_tests;
|
||||
// use crate::native_tests::render_circuit;
|
||||
use crate::native_tests::model_serialization_different_binaries;
|
||||
|
||||
use tempdir::TempDir;
|
||||
@@ -540,17 +551,17 @@ mod native_tests {
|
||||
}
|
||||
});
|
||||
|
||||
seq!(N in 0..=98 {
|
||||
seq!(N in 0..99 {
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
#[ignore]
|
||||
fn render_circuit_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
render_circuit(path, test.to_string());
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
// #(#[test_case(TESTS[N])])*
|
||||
// #[ignore]
|
||||
// fn render_circuit_(test: &str) {
|
||||
// crate::native_tests::init_binary();
|
||||
// let test_dir = TempDir::new(test).unwrap();
|
||||
// let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// render_circuit(path, test.to_string());
|
||||
// test_dir.close().unwrap();
|
||||
// }
|
||||
|
||||
|
||||
|
||||
@@ -902,7 +913,7 @@ mod native_tests {
|
||||
|
||||
});
|
||||
|
||||
seq!(N in 0..=45 {
|
||||
seq!(N in 0..=43 {
|
||||
|
||||
#(#[test_case(WASM_TESTS[N])])*
|
||||
fn kzg_prove_and_verify_with_overflow_(test: &str) {
|
||||
@@ -912,8 +923,8 @@ mod native_tests {
|
||||
env_logger::init();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, true, "single", Commitments::KZG, 2);
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testWasm", false);
|
||||
// #[cfg(not(feature = "gpu-accelerated"))]
|
||||
// run_js_tests(path, test.to_string(), "testWasm", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -925,8 +936,8 @@ mod native_tests {
|
||||
env_logger::init();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
prove_and_verify(path, test.to_string(), "safe", "hashed", "private", "public", 1, None, true, "single", Commitments::KZG, 2);
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testWasm", false);
|
||||
// #[cfg(not(feature = "gpu-accelerated"))]
|
||||
// run_js_tests(path, test.to_string(), "testWasm", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -938,8 +949,8 @@ mod native_tests {
|
||||
env_logger::init();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, None, true, "single", Commitments::KZG, 2);
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testWasm", false);
|
||||
// #[cfg(not(feature = "gpu-accelerated"))]
|
||||
// run_js_tests(path, test.to_string(), "testWasm", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -986,6 +997,7 @@ mod native_tests {
|
||||
use crate::native_tests::kzg_evm_aggr_prove_and_verify;
|
||||
use tempdir::TempDir;
|
||||
use crate::native_tests::Hardfork;
|
||||
#[cfg(not(feature = "gpu-accelerated"))]
|
||||
use crate::native_tests::run_js_tests;
|
||||
use ezkl::logger::init_logger;
|
||||
use crate::native_tests::lazy_static;
|
||||
@@ -1012,7 +1024,7 @@ mod native_tests {
|
||||
|
||||
});
|
||||
|
||||
seq!(N in 0..=98 {
|
||||
seq!(N in 0..=99 {
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn kzg_evm_prove_and_verify_reusable_verifier_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
@@ -1086,7 +1098,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "public");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
#[cfg(not(feature = "gpu-accelerated"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
|
||||
test_dir.close().unwrap();
|
||||
|
||||
@@ -1100,7 +1112,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let mut _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "hashed", "private", "private");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
#[cfg(not(feature = "gpu-accelerated"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
@@ -1117,7 +1129,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let mut _anvil_child = crate::native_tests::start_anvil(false, hardfork);
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "polycommit", "private", "public");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
#[cfg(not(feature = "gpu-accelerated"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
@@ -1130,7 +1142,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "hashed", "public");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
#[cfg(not(feature = "gpu-accelerated"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
|
||||
test_dir.close().unwrap();
|
||||
|
||||
@@ -1143,7 +1155,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "hashed");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
#[cfg(not(feature = "gpu-accelerated"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
@@ -1156,7 +1168,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "polycommit", "public");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
#[cfg(not(feature = "gpu-accelerated"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
@@ -1169,7 +1181,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "polycommit");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
#[cfg(not(feature = "gpu-accelerated"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
@@ -1181,7 +1193,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "polycommit", "polycommit", "polycommit");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
#[cfg(not(feature = "gpu-accelerated"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
@@ -1548,24 +1560,25 @@ mod native_tests {
|
||||
assert!(status.success());
|
||||
}
|
||||
|
||||
// Mock prove (fast, but does not cover some potential issues)
|
||||
fn render_circuit(test_dir: &str, example_name: String) {
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
.args([
|
||||
"render-circuit",
|
||||
"-M",
|
||||
format!("{}/{}/network.onnx", test_dir, example_name).as_str(),
|
||||
"-O",
|
||||
format!("{}/{}/render.png", test_dir, example_name).as_str(),
|
||||
"--lookup-range=-32768->32768",
|
||||
"-K=17",
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
}
|
||||
// // Mock prove (fast, but does not cover some potential issues)
|
||||
// fn render_circuit(test_dir: &str, example_name: String) {
|
||||
// let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
// .args([
|
||||
// "render-circuit",
|
||||
// "-M",
|
||||
// format!("{}/{}/network.onnx", test_dir, example_name).as_str(),
|
||||
// "-O",
|
||||
// format!("{}/{}/render.png", test_dir, example_name).as_str(),
|
||||
// "--lookup-range=-32768->32768",
|
||||
// "-K=17",
|
||||
// ])
|
||||
// .status()
|
||||
// .expect("failed to execute process");
|
||||
// assert!(status.success());
|
||||
// }
|
||||
|
||||
// prove-serialize-verify, the usual full path
|
||||
#[cfg(not(feature = "gpu-accelerated"))]
|
||||
fn kzg_aggr_mock_prove_and_verify(test_dir: &str, example_name: String) {
|
||||
prove_and_verify(
|
||||
test_dir,
|
||||
@@ -2222,6 +2235,7 @@ mod native_tests {
|
||||
}
|
||||
|
||||
// run js browser evm verify tests for a given example
|
||||
#[cfg(not(feature = "gpu-accelerated"))]
|
||||
fn run_js_tests(test_dir: &str, example_name: String, js_test: &str, vk: bool) {
|
||||
let example = format!("--example={}", example_name);
|
||||
let dir = format!("--dir={}", test_dir);
|
||||
@@ -2238,8 +2252,9 @@ mod native_tests {
|
||||
assert!(status.success());
|
||||
}
|
||||
|
||||
#[allow(unused_variables)]
|
||||
fn build_ezkl() {
|
||||
#[cfg(feature = "icicle")]
|
||||
#[cfg(feature = "gpu-accelerated")]
|
||||
let args = [
|
||||
"build",
|
||||
"--profile=test-runs",
|
||||
@@ -2258,18 +2273,8 @@ mod native_tests {
|
||||
"macos-metal",
|
||||
];
|
||||
// not macos-metal and not icicle
|
||||
#[cfg(all(not(feature = "icicle"), not(feature = "macos-metal")))]
|
||||
#[cfg(all(not(feature = "gpu-accelerated"), not(feature = "macos-metal")))]
|
||||
let args = ["build", "--profile=test-runs", "--bin", "ezkl"];
|
||||
#[cfg(feature = "eth-original-lookup")]
|
||||
let args = [
|
||||
"build",
|
||||
"--profile=test-runs",
|
||||
"--bin",
|
||||
"ezkl",
|
||||
"--no-default-features",
|
||||
"--features",
|
||||
"ezkl,solidity-verifier,eth",
|
||||
];
|
||||
#[cfg(feature = "reusable-verifier")]
|
||||
let args = [
|
||||
"build",
|
||||
|
||||
@@ -352,7 +352,7 @@ def test_prove_and_verify():
|
||||
"for-aggr",
|
||||
srs_path=srs_path,
|
||||
)
|
||||
assert res['transcript_type'] == 'Poseidon'
|
||||
assert res['transcript_type'] == 'poseidon'
|
||||
assert os.path.isfile(proof_path)
|
||||
|
||||
settings_path = os.path.join(folder_path, 'settings.json')
|
||||
@@ -388,7 +388,7 @@ def test_prove_evm():
|
||||
"single",
|
||||
srs_path=srs_path,
|
||||
)
|
||||
assert res['transcript_type'] == 'EVM'
|
||||
assert res['transcript_type'] == 'evm'
|
||||
assert os.path.isfile(proof_path)
|
||||
|
||||
|
||||
|
||||
@@ -5,8 +5,7 @@ mod wasm32 {
|
||||
bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian,
|
||||
feltToFloat, feltToInt, feltToLittleEndian, genPk, genVk, genWitness, inputValidation,
|
||||
kzgCommit, pkValidation, poseidonHash, proofValidation, prove, settingsValidation,
|
||||
srsValidation, u8_array_to_u128_le, verify, verifyAggr, verifyEVM, vkValidation,
|
||||
witnessValidation,
|
||||
srsValidation, u8_array_to_u128_le, verify, verifyAggr, vkValidation, witnessValidation,
|
||||
};
|
||||
use ezkl::circuit::modules::polycommit::PolyCommitChip;
|
||||
use ezkl::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
|
||||
@@ -42,21 +41,21 @@ mod wasm32 {
|
||||
pub const SRS1: &[u8] = include_bytes!("assets/kzg1.srs");
|
||||
pub const VERIFIER_BYTECODE: &[u8] = include_bytes!("assets/wasm.code");
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
async fn can_verify_aggr() {
|
||||
let value = verifyAggr(
|
||||
wasm_bindgen::Clamped(PROOF_AGGR.to_vec()),
|
||||
wasm_bindgen::Clamped(VK_AGGR.to_vec()),
|
||||
21,
|
||||
wasm_bindgen::Clamped(SRS1.to_vec()),
|
||||
"kzg",
|
||||
)
|
||||
.map_err(|_| "failed")
|
||||
.unwrap();
|
||||
// #[wasm_bindgen_test]
|
||||
// async fn can_verify_aggr() {
|
||||
// let value = verifyAggr(
|
||||
// wasm_bindgen::Clamped(PROOF_AGGR.to_vec()),
|
||||
// wasm_bindgen::Clamped(VK_AGGR.to_vec()),
|
||||
// 21,
|
||||
// wasm_bindgen::Clamped(SRS1.to_vec()),
|
||||
// "kzg",
|
||||
// )
|
||||
// .map_err(|_| "failed")
|
||||
// .unwrap();
|
||||
|
||||
// should not fail
|
||||
assert!(value);
|
||||
}
|
||||
// // should not fail
|
||||
// assert!(value);
|
||||
// }
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
async fn verify_encode_verifier_calldata() {
|
||||
@@ -94,21 +93,6 @@ mod wasm32 {
|
||||
assert_eq!(calldata, reference_calldata);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
async fn can_verify_evm() {
|
||||
// verify with single purpose evm verifier contract
|
||||
let value = verifyEVM(
|
||||
wasm_bindgen::Clamped(PROOF.to_vec()),
|
||||
VERIFIER_BYTECODE.to_vec(),
|
||||
None,
|
||||
)
|
||||
.map_err(|_| "failed")
|
||||
.unwrap();
|
||||
|
||||
// should not fail
|
||||
assert!(value);
|
||||
}
|
||||
|
||||
#[wasm_bindgen_test]
|
||||
fn verify_kzg_commit() {
|
||||
// create a vector of field elements Vec<Fr> and assign it to the message variable
|
||||
|
||||
Reference in New Issue
Block a user