feat: implement generalized Freivalds' algorithm for arbitrary einsum (#990)

---------

Co-authored-by: DoHoon Kim <59155248+DoHoonKim8@users.noreply.github.com>
Co-authored-by: therealyingtong <yingtong.lai@gmail.com>
Co-authored-by: DoHoonKim <dohoon1097819@gmail.com>
This commit is contained in:
dante
2025-10-08 07:34:32 -04:00
committed by GitHub
parent d64749fc71
commit 365d92a5f2
37 changed files with 3414 additions and 236 deletions

View File

@@ -18,6 +18,7 @@ jobs:
permissions:
contents: read
packages: write
id-token: write # Required for provenance
name: publish-wasm-bindings
env:
RELEASE_TAG: ${{ github.ref_name }}
@@ -25,19 +26,19 @@ jobs:
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
cache: false
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
with:
# Pin to version 0.12.1
version: 'v0.12.1'
version: "v0.12.1"
- name: Add wasm32-unknown-unknown target
run: rustup target add wasm32-unknown-unknown
@@ -51,41 +52,41 @@ jobs:
wasm-opt --version
- name: Build wasm files for both web and nodejs compilation targets
run: |
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
- name: Create package.json in pkg folder
shell: bash
run: |
cat > pkg/package.json << EOF
{
"name": "@ezkljs/engine",
"version": "$RELEASE_TAG",
"dependencies": {
"@types/json-bigint": "^1.0.1",
"json-bigint": "^1.0.0"
},
"files": [
"nodejs/ezkl_bg.wasm",
"nodejs/ezkl.js",
"nodejs/ezkl.d.ts",
"nodejs/package.json",
"nodejs/utils.js",
"web/ezkl_bg.wasm",
"web/ezkl.js",
"web/ezkl.d.ts",
"web/snippets/**/*",
"web/package.json",
"web/utils.js",
"ezkl.d.ts"
],
"main": "nodejs/ezkl.js",
"module": "web/ezkl.js",
"types": "nodejs/ezkl.d.ts",
"sideEffects": [
"web/snippets/*"
]
}
EOF
cat > pkg/package.json << EOF
{
"name": "@ezkljs/engine",
"version": "$RELEASE_TAG",
"dependencies": {
"@types/json-bigint": "^1.0.1",
"json-bigint": "^1.0.0"
},
"files": [
"nodejs/ezkl_bg.wasm",
"nodejs/ezkl.js",
"nodejs/ezkl.d.ts",
"nodejs/package.json",
"nodejs/utils.js",
"web/ezkl_bg.wasm",
"web/ezkl.js",
"web/ezkl.d.ts",
"web/snippets/**/*",
"web/package.json",
"web/utils.js",
"ezkl.d.ts"
],
"main": "nodejs/ezkl.js",
"module": "web/ezkl.js",
"types": "nodejs/ezkl.d.ts",
"sideEffects": [
"web/snippets/*"
]
}
EOF
- name: Replace memory definition in nodejs
run: |
@@ -177,16 +178,17 @@ jobs:
run: |
curl -s "https://raw.githubusercontent.com/zkonduit/ezkljs-engine/main/README.md" > ./pkg/README.md
# zizmor: ignore cache-poisoning
- name: Set up Node.js
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
- name: Publish to npm
package-manager-cache: false
- name: Publish to npm with provenance
run: |
cd pkg
npm install
npm ci
npm publish
npm publish --provenance --access public
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

View File

@@ -27,11 +27,11 @@ jobs:
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -52,12 +52,11 @@ jobs:
RUSTFLAGS: "-C linker=gcc"
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -78,10 +77,10 @@ jobs:
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -102,10 +101,10 @@ jobs:
RUSTFLAGS: "-C linker=gcc"
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -116,7 +115,7 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -126,7 +125,6 @@ jobs:
- name: Library tests
run: cargo nextest run --lib --verbose
ultra-overflow-tests-gpu:
permissions:
contents: read
@@ -137,19 +135,19 @@ jobs:
RUSTFLAGS: "-C linker=gcc"
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
- uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
- uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
with:
wasmtime-version: "3.0.1"
- name: Setup GPU dependencies
@@ -176,7 +174,7 @@ jobs:
ultra-overflow-tests:
permissions:
contents: read
runs-on: [ non-gpu, non-sgx ]
runs-on: [non-gpu, non-sgx]
needs: [build, library-tests, docs]
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
@@ -184,12 +182,11 @@ jobs:
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -200,11 +197,11 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
- uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
- uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
with:
wasmtime-version: "3.0.1"
# - name: Matmul overflow (wasi)
@@ -230,12 +227,11 @@ jobs:
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -246,7 +242,7 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -264,12 +260,11 @@ jobs:
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -280,11 +275,11 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
with:
# Pin to version 0.13.1
version: "v0.13.1"
- uses: nanasess/setup-chromedriver@affb1ea8848cbb080be372c1e8d7a5c173e9298f #v2.3.0
- uses: nanasess/setup-chromedriver@affb1ea8848cbb080be372c1e8d7a5c173e9298f #v2.3.0
# with:
# chromedriver-version: "115.0.5790.102"
- name: Install wasm32-unknown-unknown
@@ -305,7 +300,7 @@ jobs:
mock-proving-tests:
permissions:
contents: read
runs-on: [ non-gpu, non-sgx ]
runs-on: [non-gpu, non-sgx]
needs: [build, library-tests, docs]
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
@@ -313,16 +308,16 @@ jobs:
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -380,7 +375,7 @@ jobs:
prove-and-verify-evm-tests:
permissions:
contents: read
runs-on: [ non-gpu, non-sgx ]
runs-on: [non-gpu, non-sgx]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
@@ -388,35 +383,34 @@ jobs:
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
with:
# Pin to version 0.13.1
version: "v0.13.1"
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Use pnpm 8
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
with:
version: 8
- name: Use Node.js 22.17.1
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
with:
node-version: "22.17.1"
cache: "pnpm"
@@ -434,8 +428,8 @@ jobs:
wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
- name: KZG prove and verify tests (EVM)
run: cargo nextest run --verbose "tests_evm::kzg_evm_prove_and_verify_::" --test-threads 1
- name: KZG prove and verify tests (EVM + reusable verifier + col-overflow)
run: cargo nextest run --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --features reusable-verifier --test-threads 1
# - name: KZG prove and verify tests (EVM + reusable verifier + col-overflow)
# run: cargo nextest run --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --features reusable-verifier --test-threads 1
- name: KZG prove and verify tests (EVM + kzg all)
run: cargo nextest run --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + kzg inputs)
@@ -486,22 +480,21 @@ jobs:
prove-and-verify-tests:
permissions:
contents: read
runs-on: [ non-gpu, non-sgx ]
runs-on: [non-gpu, non-sgx]
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
with:
# Pin to version 0.13.1
version: "v0.13.1"
@@ -512,15 +505,15 @@ jobs:
run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Use pnpm 8
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
with:
version: 8
- name: Use Node.js 22.17.1
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
with:
node-version: "22.17.1"
cache: "pnpm"
@@ -530,7 +523,7 @@ jobs:
env:
CI: false
NODE_ENV: development
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -575,17 +568,17 @@ jobs:
RUSTFLAGS: "-C linker=gcc"
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -621,17 +614,16 @@ jobs:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: dtolnay/rust-toolchain@4f94fbe7e03939b0e674bcc9ca609a16088f63ff #nightly branch, TODO: update when required
- uses: dtolnay/rust-toolchain@4f94fbe7e03939b0e674bcc9ca609a16088f63ff #nightly branch, TODO: update when required
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -649,15 +641,15 @@ jobs:
RUSTFLAGS: "-C linker=gcc"
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -681,17 +673,16 @@ jobs:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -707,17 +698,16 @@ jobs:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -739,12 +729,11 @@ jobs:
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -755,7 +744,7 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -765,21 +754,20 @@ jobs:
python-tests:
permissions:
contents: read
runs-on: [ non-gpu, non-sgx ]
runs-on: [non-gpu, non-sgx]
needs: [build, library-tests, docs]
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: "3.12"
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -800,26 +788,25 @@ jobs:
accuracy-measurement-tests:
permissions:
contents: read
runs-on: [ non-gpu, non-sgx ]
runs-on: [non-gpu, non-sgx]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: "3.12"
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -845,20 +832,19 @@ jobs:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: "3.11"
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -910,17 +896,16 @@ jobs:
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -941,11 +926,11 @@ jobs:
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -983,7 +968,7 @@ jobs:
cd ezkl-swift-package
xcodebuild test \
-scheme EzklPackage \
-destination 'platform=iOS Simulator,name=iPhone 16 Pro,OS=18.2' \
-destination 'platform=iOS Simulator,name=iPhone 16 Pro,OS=18.4' \
-resultBundlePath ../testResults
- name: Run Example App Tests
@@ -992,7 +977,7 @@ jobs:
xcodebuild test \
-project Example.xcodeproj \
-scheme EzklApp \
-destination 'platform=iOS Simulator,name=iPhone 16 Pro,OS=18.2' \
-destination 'platform=iOS Simulator,name=iPhone 16 Pro,OS=18.4' \
-parallel-testing-enabled NO \
-resultBundlePath ../../exampleTestResults \
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder

2
Cargo.lock generated
View File

@@ -2431,7 +2431,7 @@ dependencies = [
[[package]]
name = "halo2_proofs"
version = "0.3.0"
source = "git+https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e?branch=ac%2Fconditional-compilation-icicle2#1dd2090741f006fd031a07da7f3c9dfce5e0015e"
source = "git+https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e?branch=ac%2Fconditional-compilation-icicle2#01c88842679b4308e43ae5ed91c4183e861669bd"
dependencies = [
"bincode",
"blake2b_simd",

View File

@@ -300,7 +300,6 @@ halo2_proofs = { git = "https://github.com/zkonduit/halo2#1dd2090741f006fd031a07
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
[profile.release]
# debug = true
rustflags = ["-C", "relocation-model=pic"]

View File

@@ -1,53 +1,78 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion::{
criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion, PlotConfiguration,
Throughput,
};
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::einsum::circuit_params::SingleEinsumParams;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::pfsys::srs::gen_srs;
use ezkl::pfsys::{create_keys, create_proof_circuit, TranscriptType};
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, SimpleFloorPlanner, Value},
circuit::{Layouter, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
use std::collections::HashMap;
static mut LEN: usize = 4;
const K: usize = 16;
static mut K: usize = 15;
#[derive(Clone)]
struct MyCircuit {
inputs: [ValTensor<Fr>; 2],
_marker: PhantomData<Fr>,
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 2],
einsum_params: SingleEinsumParams<F>,
}
impl Circuit<Fr> for MyCircuit {
impl Circuit<Fr> for MyCircuit<Fr> {
type Config = BaseConfig<Fr>;
type FloorPlanner = SimpleFloorPlanner;
type Params = ();
type FloorPlanner = V1;
type Params = SingleEinsumParams<Fr>;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
let len = unsafe { LEN };
fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
let mut config = Self::Config::default();
let a = VarTensor::new_advice(cs, K, 1, len * len);
let mut equations = HashMap::new();
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
unsafe {
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 2, false);
}
let b = VarTensor::new_advice(cs, K, 1, len * len);
config
}
let output = VarTensor::new_advice(cs, K, 1, (len + 1) * len);
fn params(&self) -> Self::Params {
SingleEinsumParams::<Fr>::new(
&self.einsum_params.equation,
&[
&self.inputs[0].get_inner().unwrap(),
&self.inputs[1].get_inner().unwrap(),
],
)
.unwrap()
}
Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE)
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unimplemented!("call configure_with_params instead")
}
fn synthesize(
@@ -55,16 +80,33 @@ impl Circuit<Fr> for MyCircuit {
mut config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let challenges = config
.einsums
.as_ref()
.ok_or(Error::Synthesis)?
.challenges()
.unwrap()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = region::RegionCtx::new_with_challenges(
region,
0,
1,
1024,
2,
challenges.clone(),
);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
equation: self.einsum_params.equation.clone(),
}),
)
.unwrap();
@@ -77,41 +119,49 @@ impl Circuit<Fr> for MyCircuit {
fn runmatmul(c: &mut Criterion) {
let mut group = c.benchmark_group("accum_einsum_matmul");
let params = gen_srs::<KZGCommitmentScheme<_>>(17);
for &len in [4, 32].iter() {
unsafe {
LEN = len;
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Linear));
group.sampling_mode(criterion::SamplingMode::Flat);
group.sample_size(10);
let len = 128;
unsafe {
LEN = len;
}
for k in 16..17 {
let params = unsafe {
K = k;
gen_srs::<KZGCommitmentScheme<_>>(K as u32)
};
let mut a = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
a.reshape(&[len, len]).unwrap();
// parameters
let mut b = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
b.reshape(&[len, len]).unwrap();
let einsum_params = SingleEinsumParams::<Fr>::new("ij,jk->ik", &[&a, &b]).unwrap();
let circuit = MyCircuit {
inputs: [ValTensor::from(a), ValTensor::from(b)],
_marker: PhantomData,
einsum_params,
};
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
group.bench_with_input(BenchmarkId::new("pk", k), &k, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit<Fr>>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit<Fr>>(&circuit, &params, false)
.unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
group.bench_with_input(BenchmarkId::new("prove", k), &k, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
MyCircuit<Fr>,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,

View File

@@ -0,0 +1,171 @@
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::Fr;
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use std::collections::HashMap;
use std::marker::PhantomData;
const K: usize = 13;
/// Minimal test circuit that lays out a single einsum over two witness
/// tensors using the universal einsum (Freivalds) configuration.
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
    /// The two witness input tensors of the einsum.
    inputs: [ValTensor<F>; 2],
    /// Equation and axis-dimension metadata; doubles as the circuit params.
    einsum: Einsum<F>,
}
/// Parsed einsum specification: the raw equation plus the concrete size of
/// every axis label, inferred from the input tensors.
#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
    /// Raw einsum equation, e.g. "ij,jk->ik".
    equation: String,
    /// Dimension of each axis label appearing on the input side.
    input_axes_to_dims: HashMap<char, usize>,
    /// Ties the otherwise-unused type parameter `F` to the struct.
    _marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
    /// Parses `equation` (e.g. "ij,jk->ik") against the concrete `inputs`,
    /// recording the dimension of every axis label that appears on the input
    /// side of the equation.
    ///
    /// # Errors
    /// * [`CircuitError::InvalidEinsum`] if the equation is malformed.
    /// * A `DimMismatch` error if the operand count disagrees with the
    ///   equation, a subscript has more labels than its tensor has axes, or
    ///   the same label maps to two different dimensions.
    pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
        use std::collections::hash_map::Entry;
        let mut eq = equation.split("->");
        let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
        let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
        // Check that the number of inputs matches the number of inputs in the equation
        if inputs.len() != inputs_eq.len() {
            return Err(TensorError::DimMismatch("einsum".to_string()).into());
        }
        let mut input_axes_to_dims = HashMap::new();
        for (input, subscript) in inputs.iter().zip(inputs_eq.iter()) {
            // A subscript with more labels than the tensor has axes would
            // otherwise panic on the `dims()[j]` index below.
            if subscript.chars().count() > input.dims().len() {
                return Err(TensorError::DimMismatch("einsum".to_string()).into());
            }
            // Single pass over the subscript instead of O(n^2) `chars().nth(j)`.
            for (j, c) in subscript.chars().enumerate() {
                match input_axes_to_dims.entry(c) {
                    Entry::Vacant(e) => {
                        e.insert(input.dims()[j]);
                    }
                    Entry::Occupied(e) => {
                        // Every occurrence of a label must agree on its dimension.
                        if *e.get() != input.dims()[j] {
                            return Err(TensorError::DimMismatch("einsum".to_string()).into());
                        }
                    }
                }
            }
        }
        Ok(Self {
            equation: equation.to_owned(),
            input_axes_to_dims,
            _marker: PhantomData,
        })
    }
}
impl Circuit<Fr> for MyCircuit<Fr> {
    type Config = BaseConfig<Fr>;
    type FloorPlanner = V1;
    // The parsed equation doubles as circuit params so `configure_with_params`
    // can size the einsum columns before synthesis.
    type Params = Einsum<Fr>;

    fn without_witnesses(&self) -> Self {
        self.clone()
    }

    /// Builds a config containing only the universal einsum columns; no base-op
    /// advice columns are needed for this test.
    fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
        let mut config = Self::Config::default();
        // A single (equation-index, equation) entry describes all einsum usage.
        let mut equations = HashMap::new();
        equations.insert((0, params.equation), params.input_axes_to_dims);
        let analysis = analyze_einsum_usage(&equations).unwrap();
        let num_einsum_inner_cols = 1;
        config
            .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
            .unwrap();
        config
    }

    /// Re-derives the params from the witness tensors so key generation and
    /// proving see the same equation/axis sizes.
    fn params(&self) -> Self::Params {
        Einsum::<Fr>::new(
            &self.einsum.equation,
            &[
                &self.inputs[0].get_inner().unwrap(),
                &self.inputs[1].get_inner().unwrap(),
            ],
        )
        .unwrap()
    }

    fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
        unimplemented!("call configure_with_params instead")
    }

    fn synthesize(
        &self,
        mut config: Self::Config,
        mut layouter: impl Layouter<Fr>,
    ) -> Result<(), Error> {
        // Squeeze the Freivalds challenges declared at configure time; they
        // only become concrete once the first-phase advice is committed.
        let challenges = config
            .einsums
            .as_ref()
            .ok_or(Error::Synthesis)?
            .challenges()
            .unwrap()
            .iter()
            .map(|c| layouter.get_challenge(*c))
            .collect_vec();
        layouter.assign_region(
            || "",
            |region| {
                // Region context carrying the challenges so the einsum layout
                // can evaluate random linear combinations.
                let mut region = region::RegionCtx::new_with_challenges(
                    region,
                    0,
                    1,
                    1024,
                    2,
                    challenges.clone(),
                );
                config
                    .layout(
                        &mut region,
                        &self.inputs.iter().collect_vec(),
                        Box::new(PolyOp::Einsum {
                            equation: self.einsum.equation.clone(),
                        }),
                    )
                    .unwrap();
                Ok(())
            },
        )?;
        Ok(())
    }
}
/// Mock-proves a 64x64 matrix multiplication expressed as the einsum "ij,jk->ik".
fn runmatmul() {
    let side = 64;
    let random_square = |side: usize| {
        let mut t = Tensor::from((0..side * side).map(|_| Value::known(Fr::random(OsRng))));
        t.reshape(&[side, side]).unwrap();
        t
    };
    let lhs = random_square(side);
    // parameters
    let rhs = random_square(side);
    let einsum = Einsum::<Fr>::new("ij,jk->ik", &[&lhs, &rhs]).unwrap();
    let circuit = MyCircuit {
        inputs: [ValTensor::from(lhs), ValTensor::from(rhs)],
        einsum,
    };
    let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
    mock_prover.assert_satisfied();
}
/// Entry point: runs the matmul mock-proof.
pub fn main() {
    runmatmul();
}

179
examples/batch_mat_mul.rs Normal file
View File

@@ -0,0 +1,179 @@
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::Fr;
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use std::collections::HashMap;
use std::marker::PhantomData;
static mut LEN: usize = 4;
const K: usize = 11;
/// Test circuit laying out a batched matmul einsum over two witness tensors
/// alongside the ordinary base-op columns.
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
    /// The two witness input tensors of the einsum.
    inputs: [ValTensor<F>; 2],
    /// Equation and axis-dimension metadata; doubles as the circuit params.
    einsum: Einsum<F>,
}
/// Parsed einsum specification: the raw equation plus the concrete size of
/// every axis label, inferred from the input tensors.
#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
    /// Raw einsum equation, e.g. "ijk,ikl->ijl".
    equation: String,
    /// Dimension of each axis label appearing on the input side.
    input_axes_to_dims: HashMap<char, usize>,
    /// Ties the otherwise-unused type parameter `F` to the struct.
    _marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
    /// Parses `equation` (e.g. "ijk,ikl->ijl") against the concrete `inputs`,
    /// recording the dimension of every axis label that appears on the input
    /// side of the equation.
    ///
    /// # Errors
    /// * [`CircuitError::InvalidEinsum`] if the equation is malformed.
    /// * A `DimMismatch` error if the operand count disagrees with the
    ///   equation, a subscript has more labels than its tensor has axes, or
    ///   the same label maps to two different dimensions.
    pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
        use std::collections::hash_map::Entry;
        let mut eq = equation.split("->");
        let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
        let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
        // Check that the number of inputs matches the number of inputs in the equation
        if inputs.len() != inputs_eq.len() {
            return Err(TensorError::DimMismatch("einsum".to_string()).into());
        }
        let mut input_axes_to_dims = HashMap::new();
        for (input, subscript) in inputs.iter().zip(inputs_eq.iter()) {
            // A subscript with more labels than the tensor has axes would
            // otherwise panic on the `dims()[j]` index below.
            if subscript.chars().count() > input.dims().len() {
                return Err(TensorError::DimMismatch("einsum".to_string()).into());
            }
            // Single pass over the subscript instead of O(n^2) `chars().nth(j)`.
            for (j, c) in subscript.chars().enumerate() {
                match input_axes_to_dims.entry(c) {
                    Entry::Vacant(e) => {
                        e.insert(input.dims()[j]);
                    }
                    Entry::Occupied(e) => {
                        // Every occurrence of a label must agree on its dimension.
                        if *e.get() != input.dims()[j] {
                            return Err(TensorError::DimMismatch("einsum".to_string()).into());
                        }
                    }
                }
            }
        }
        Ok(Self {
            equation: equation.to_owned(),
            input_axes_to_dims,
            _marker: PhantomData,
        })
    }
}
impl Circuit<Fr> for MyCircuit<Fr> {
    type Config = BaseConfig<Fr>;
    type FloorPlanner = V1;
    // The parsed equation doubles as circuit params so `configure_with_params`
    // can size the einsum columns before synthesis.
    type Params = Einsum<Fr>;

    fn without_witnesses(&self) -> Self {
        self.clone()
    }

    /// Configures base-op advice columns plus the universal einsum columns.
    fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
        // LEN is set by the driver before configuration (bench-style global).
        let len = unsafe { LEN };
        let a = VarTensor::new_advice(cs, K, 1, len);
        let b = VarTensor::new_advice(cs, K, 1, len);
        let output = VarTensor::new_advice(cs, K, 1, len);
        let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE);
        // A single (equation-index, equation) entry describes all einsum usage.
        let mut equations = HashMap::new();
        equations.insert((0, params.equation), params.input_axes_to_dims);
        let analysis = analyze_einsum_usage(&equations).unwrap();
        let num_einsum_inner_cols = 1;
        config
            .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
            .unwrap();
        config
    }

    /// Re-derives the params from the witness tensors so key generation and
    /// proving see the same equation/axis sizes.
    fn params(&self) -> Self::Params {
        Einsum::<Fr>::new(
            &self.einsum.equation,
            &[
                &self.inputs[0].get_inner().unwrap(),
                &self.inputs[1].get_inner().unwrap(),
            ],
        )
        .unwrap()
    }

    fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
        unimplemented!("call configure_with_params instead")
    }

    fn synthesize(
        &self,
        mut config: Self::Config,
        mut layouter: impl Layouter<Fr>,
    ) -> Result<(), Error> {
        // Squeeze the Freivalds challenges declared at configure time; they
        // only become concrete once the first-phase advice is committed.
        let challenges = config
            .einsums
            .as_ref()
            .ok_or(Error::Synthesis)?
            .challenges()
            .unwrap()
            .iter()
            .map(|c| layouter.get_challenge(*c))
            .collect_vec();
        layouter.assign_region(
            || "",
            |region| {
                // Region context carrying the challenges so the einsum layout
                // can evaluate random linear combinations.
                let mut region = region::RegionCtx::new_with_challenges(
                    region,
                    0,
                    1,
                    1024,
                    2,
                    challenges.clone(),
                );
                config
                    .layout(
                        &mut region,
                        &self.inputs.iter().collect_vec(),
                        Box::new(PolyOp::Einsum {
                            equation: self.einsum.equation.clone(),
                        }),
                    )
                    .unwrap();
                Ok(())
            },
        )?;
        Ok(())
    }
}
/// Mock-proves a batched matmul (batch of 5, 12x12 matrices) expressed as the
/// einsum "ijk,ikl->ijl".
fn runbatchmatmul() {
    let (batches, side) = (5, 12);
    let fill = |n: usize| Tensor::from((0..n).map(|_| Value::known(Fr::random(OsRng))));
    let mut lhs = fill(batches * side * side);
    lhs.reshape(&[batches, side, side]).unwrap();
    // parameters
    let mut rhs = fill(batches * side * side);
    rhs.reshape(&[batches, side, side]).unwrap();
    let einsum = Einsum::<Fr>::new("ijk,ikl->ijl", &[&lhs, &rhs]).unwrap();
    let circuit = MyCircuit {
        inputs: [ValTensor::from(lhs), ValTensor::from(rhs)],
        einsum,
    };
    let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
    mock_prover.assert_satisfied();
}
/// Entry point: runs the batched-matmul mock-proof.
pub fn main() {
    runbatchmatmul();
}

View File

@@ -866,7 +866,7 @@
},
{
"cell_type": "code",
"execution_count": 98,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -879,6 +879,7 @@
"run_args.input_visibility = \"private\"\n",
"run_args.param_visibility = \"fixed\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.disable_freivalds = True\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"\n",
"\n",
@@ -1142,4 +1143,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -0,0 +1,79 @@
from torch import nn
import torch.nn.init as init
import torch
import json
N = 100
class Model(nn.Module):
    """Nine stacked NxN linear layers with ReLU between them.

    The input batch is tiled 10x along dim 0 before the first layer; the final
    layer has no activation.
    """

    def __init__(self, inplace=False):
        super(Model, self).__init__()
        # Create aff1..aff9 programmatically; registration order (aff1..aff9,
        # then relu) matches the original hand-written attributes exactly.
        for idx in range(1, 10):
            setattr(self, f"aff{idx}", nn.Linear(N, N))
        self.relu = nn.ReLU()
        self._initialize_weights()

    def forward(self, x):
        # concat 10 x along dim 0
        x = x.repeat(10, 1)
        # aff1..aff8 each followed by ReLU; aff9 is the linear output head.
        for idx in range(1, 9):
            x = self.relu(getattr(self, f"aff{idx}")(x))
        return self.aff9(x)

    def _initialize_weights(self):
        # NOTE(review): only aff1 gets orthogonal init; aff2..aff9 keep
        # PyTorch's default init. Presumably intentional for this benchmark.
        init.orthogonal_(self.aff1.weight)
model = Model()
# Flips the neural net into inference mode
model.eval()
model.to('cpu')

# Single sample of N features; the batch dim is exported as dynamic below.
x = torch.randn(1, N)

# Export the model
torch.onnx.export(model,               # model being run
                      # model input (or a tuple for multiple inputs)
                      x,
                      # where to save the model (can be a file or file-like object)
                      "network.onnx",
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=12,          # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names=['input'],   # the model's input names
                      output_names=['output'], # the model's output names
                      dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                    'output': {0: 'batch_size'}})

# Dump the sample input as the flat JSON shape ezkl expects: {"input_data": [[...]]}.
data_array = ((x).detach().numpy()).reshape([-1]).tolist()
data_json = dict(input_data=[data_array])
print(data_json)

# Serialize data into file:
json.dump(data_json, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.33088353276252747, -0.8819183707237244, 1.245591163635254, -1.807046890258789, 1.9922369718551636, -0.3360576629638672, 0.4529011845588684, -0.3590165674686432, 0.08356846123933792, 0.5126393437385559, 0.44627535343170166, 1.4916497468948364, 0.49731069803237915, -0.9748706817626953, -0.4923185408115387, 1.3548223972320557, 0.2306872010231018, 1.125955581665039, -1.7063908576965332, 0.3777385354042053, -2.7988760471343994, -1.1846797466278076, 0.7473157048225403, 1.490412950515747, 0.017497723922133446, 2.113945245742798, -1.2141249179840088, -0.16120357811450958, 0.021127669140696526, 0.7207374572753906, -1.369688868522644, -0.7369781732559204, -0.630584180355072, -0.4520200788974762, 0.29123976826667786, 0.6334688067436218, -0.869332492351532, -1.258501648902893, 0.3012596666812897, -0.5507447123527527, 0.669975757598877, 0.15088629722595215, -0.1050339788198471, 0.5505334138870239, -0.1287376880645752, -1.4297826290130615, -0.01703289896249771, -1.2296998500823975, 0.5122153162956238, -0.16924428939819336, -0.415036678314209, -1.1979341506958008, 0.05831022188067436, -0.4411357045173645, 2.0713791847229004, 1.4611141681671143, -0.9357407093048096, -0.333297461271286, -0.676478385925293, 1.390028476715088, -0.05827632546424866, 1.535687804222107, 0.3060210347175598, -0.03171076253056526, -0.614985466003418, 1.2040390968322754, 0.31318482756614685, -1.2134959697723389, 0.13110508024692535, -1.4880926609039307, 1.7007993459701538, 1.5412729978561401, 0.09260450303554535, 0.7649128437042236, -0.5009126663208008, -0.5356241464614868, -0.069572813808918, -0.011717632412910461, 0.21314217150211334, -0.1985170543193817, -0.0223808903247118, 1.2128918170928955, 0.8334696888923645, 1.9029873609542847, -0.11491120606660843, -0.10303237289190292, -0.2467050403356552, 1.557223916053772, -1.1108328104019165, -0.9065343141555786, -0.2271333783864975, 0.6959827542304993, -0.48698121309280396, 0.5689510703086853, 1.115319013595581, -0.8907430768013, 
-0.24722427129745483, -0.7437837719917297, 0.6742106676101685, -1.7830933332443237]]}

Binary file not shown.

View File

@@ -0,0 +1,182 @@
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::Fr;
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use std::collections::HashMap;
use std::marker::PhantomData;
static mut LEN: usize = 4;
const K: usize = 11;
/// Test circuit laying out the contraction "inj,jk->ik" (the `n` axis is
/// summed out) alongside the ordinary base-op columns.
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
    /// The two witness input tensors of the einsum.
    inputs: [ValTensor<F>; 2],
    /// Equation and axis-dimension metadata; doubles as the circuit params.
    einsum: Einsum<F>,
}
/// Parsed einsum specification: the raw equation plus the concrete size of
/// every axis label, inferred from the input tensors.
#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
    /// Raw einsum equation, e.g. "inj,jk->ik".
    equation: String,
    /// Dimension of each axis label appearing on the input side.
    input_axes_to_dims: HashMap<char, usize>,
    /// Ties the otherwise-unused type parameter `F` to the struct.
    _marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
    /// Parses `equation` (e.g. "inj,jk->ik") against the concrete `inputs`,
    /// recording the dimension of every axis label that appears on the input
    /// side of the equation.
    ///
    /// # Errors
    /// * [`CircuitError::InvalidEinsum`] if the equation is malformed.
    /// * A `DimMismatch` error if the operand count disagrees with the
    ///   equation, a subscript has more labels than its tensor has axes, or
    ///   the same label maps to two different dimensions.
    pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
        use std::collections::hash_map::Entry;
        let mut eq = equation.split("->");
        let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
        let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
        // Check that the number of inputs matches the number of inputs in the equation
        if inputs.len() != inputs_eq.len() {
            return Err(TensorError::DimMismatch("einsum".to_string()).into());
        }
        let mut input_axes_to_dims = HashMap::new();
        for (input, subscript) in inputs.iter().zip(inputs_eq.iter()) {
            // A subscript with more labels than the tensor has axes would
            // otherwise panic on the `dims()[j]` index below.
            if subscript.chars().count() > input.dims().len() {
                return Err(TensorError::DimMismatch("einsum".to_string()).into());
            }
            // Single pass over the subscript instead of O(n^2) `chars().nth(j)`.
            for (j, c) in subscript.chars().enumerate() {
                match input_axes_to_dims.entry(c) {
                    Entry::Vacant(e) => {
                        e.insert(input.dims()[j]);
                    }
                    Entry::Occupied(e) => {
                        // Every occurrence of a label must agree on its dimension.
                        if *e.get() != input.dims()[j] {
                            return Err(TensorError::DimMismatch("einsum".to_string()).into());
                        }
                    }
                }
            }
        }
        Ok(Self {
            equation: equation.to_owned(),
            input_axes_to_dims,
            _marker: PhantomData,
        })
    }
}
impl Circuit<Fr> for MyCircuit<Fr> {
    type Config = BaseConfig<Fr>;
    type FloorPlanner = V1;
    // The parsed equation doubles as circuit params so `configure_with_params`
    // can size the einsum columns before synthesis.
    type Params = Einsum<Fr>;

    fn without_witnesses(&self) -> Self {
        self.clone()
    }

    /// Configures base-op advice columns, the universal einsum columns (two
    /// inner columns here), and fixed constant columns.
    fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
        // LEN is set by the driver before configuration (bench-style global).
        let len = unsafe { LEN };
        let a = VarTensor::new_advice(cs, K, 1, len);
        let b = VarTensor::new_advice(cs, K, 1, len);
        let output = VarTensor::new_advice(cs, K, 1, len);
        let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE);
        // A single (equation-index, equation) entry describes all einsum usage.
        let mut equations = HashMap::new();
        equations.insert((0, params.equation), params.input_axes_to_dims);
        let analysis = analyze_einsum_usage(&equations).unwrap();
        let num_einsum_inner_cols = 2;
        config
            .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
            .unwrap();
        // Fixed columns for constants (e.g. padding values) used at layout time.
        let _constant = VarTensor::constant_cols(cs, K, 2, false);
        config
    }

    /// Re-derives the params from the witness tensors so key generation and
    /// proving see the same equation/axis sizes.
    fn params(&self) -> Self::Params {
        Einsum::<Fr>::new(
            &self.einsum.equation,
            &[
                &self.inputs[0].get_inner().unwrap(),
                &self.inputs[1].get_inner().unwrap(),
            ],
        )
        .unwrap()
    }

    fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
        unimplemented!("call configure_with_params instead")
    }

    fn synthesize(
        &self,
        mut config: Self::Config,
        mut layouter: impl Layouter<Fr>,
    ) -> Result<(), Error> {
        // Squeeze the Freivalds challenges declared at configure time; they
        // only become concrete once the first-phase advice is committed.
        let challenges = config
            .einsums
            .as_ref()
            .ok_or(Error::Synthesis)?
            .challenges()
            .unwrap()
            .iter()
            .map(|c| layouter.get_challenge(*c))
            .collect_vec();
        layouter.assign_region(
            || "",
            |region| {
                // Region context carrying the challenges so the einsum layout
                // can evaluate random linear combinations.
                let mut region = region::RegionCtx::new_with_challenges(
                    region,
                    0,
                    1,
                    1024,
                    2,
                    challenges.clone(),
                );
                config
                    .layout(
                        &mut region,
                        &self.inputs.iter().collect_vec(),
                        Box::new(PolyOp::Einsum {
                            equation: self.einsum.equation.clone(),
                        }),
                    )
                    .unwrap();
                Ok(())
            },
        )?;
        Ok(())
    }
}
/// Mock-proves the contraction "inj,jk->ik", where the `n` axis of the first
/// operand is summed out.
fn runmatmul() {
    let (i, n, j, k) = (10, 10, 40, 10);
    let fill = |len: usize| Tensor::from((0..len).map(|_| Value::known(Fr::random(OsRng))));
    let mut lhs = fill(i * n * j);
    lhs.reshape(&[i, n, j]).unwrap();
    // parameters
    let mut rhs = fill(j * k);
    rhs.reshape(&[j, k]).unwrap();
    let einsum = Einsum::<Fr>::new("inj,jk->ik", &[&lhs, &rhs]).unwrap();
    let circuit = MyCircuit {
        inputs: [ValTensor::from(lhs), ValTensor::from(rhs)],
        einsum,
    };
    let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
    mock_prover.assert_satisfied();
}
/// Entry point: runs the contraction mock-proof.
pub fn main() {
    runmatmul();
}

View File

@@ -191,6 +191,9 @@ struct PyRunArgs {
/// float: epsilon used for arguments that use division
#[pyo3(get, set)]
pub epsilon: f64,
/// bool: Whether to disable using Freivalds' argument in einsum operations
#[pyo3(get, set)]
pub disable_freivalds: bool,
}
/// default instantiation of PyRunArgs
@@ -225,6 +228,7 @@ impl From<PyRunArgs> for RunArgs {
decomp_legs: py_run_args.decomp_legs,
ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
epsilon: Some(py_run_args.epsilon),
disable_freivalds: py_run_args.disable_freivalds,
}
}
}
@@ -252,6 +256,7 @@ impl Into<PyRunArgs> for RunArgs {
decomp_legs: self.decomp_legs,
ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
epsilon: eps,
disable_freivalds: self.disable_freivalds,
}
}
}

View File

@@ -14,6 +14,7 @@ use tosubcommand::ToFlags;
use crate::{
circuit::{
chip::einsum::analysis::EinsumAnalysis,
ops::base::BaseOp,
table::{Range, RangeCheck, Table},
},
@@ -24,6 +25,9 @@ use std::{collections::BTreeMap, marker::PhantomData};
use super::{lookup::LookupOp, region::RegionCtx, CircuitError, Op};
use halo2curves::ff::{Field, PrimeField};
///
pub mod einsum;
#[allow(missing_docs)]
/// An enum representing activating the sanity checks we can perform on the accumulated arguments
#[derive(
@@ -266,6 +270,8 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
pub range_checks: RangeChecks<F>,
/// [Selector]s for the shuffles
pub shuffles: Shuffles,
/// Einsum-specific configuration
pub einsums: Option<einsum::Einsums<F>>,
/// Activate sanity checks
pub check_mode: CheckMode,
_marker: PhantomData<F>,
@@ -280,6 +286,22 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
custom_gates: CustomGates::dummy(col_size, num_inner_cols),
static_lookups: StaticLookups::dummy(col_size, num_inner_cols),
dynamic_lookups: DynamicLookups::dummy(col_size, num_inner_cols),
einsums: Some(einsum::Einsums::<F>::dummy(col_size, num_inner_cols)),
shuffles: Shuffles::dummy(col_size, num_inner_cols),
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
check_mode: CheckMode::SAFE,
shared_table_inputs: vec![],
_marker: PhantomData,
}
}
/// Returns a new [BaseConfig] with no inputs, no selectors, no tables, and no Freivalds' argument.
pub fn dummy_without_freivalds(col_size: usize, num_inner_cols: usize) -> Self {
Self {
custom_gates: CustomGates::dummy(col_size, num_inner_cols),
static_lookups: StaticLookups::dummy(col_size, num_inner_cols),
dynamic_lookups: DynamicLookups::dummy(col_size, num_inner_cols),
einsums: None,
shuffles: Shuffles::dummy(col_size, num_inner_cols),
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
check_mode: CheckMode::SAFE,
@@ -414,6 +436,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
},
static_lookups: StaticLookups::default(),
dynamic_lookups: DynamicLookups::default(),
einsums: None,
shuffles: Shuffles::default(),
range_checks: RangeChecks::default(),
shared_table_inputs: vec![],
@@ -688,6 +711,27 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
Ok(())
}
/// Configures and creates einsums
///
/// Installs the universal einsum (Freivalds) configuration, sized from the
/// aggregate `analysis` of every equation in the circuit. Infallible today,
/// but returns `Result` for parity with the other `configure_*` helpers.
#[allow(clippy::too_many_arguments)]
pub fn configure_einsums(
    &mut self,
    cs: &mut ConstraintSystem<F>,
    analysis: &EinsumAnalysis,
    num_inner_cols: usize,
    logrows: usize,
) -> Result<(), CircuitError>
where
    F: Field,
{
    self.einsums = Some(einsum::Einsums::configure_universal(
        cs,
        analysis,
        num_inner_cols,
        logrows,
    ));
    Ok(())
}
/// Configures and creates lookup selectors
#[allow(clippy::too_many_arguments)]
pub fn configure_shuffles(

View File

@@ -0,0 +1,210 @@
use std::collections::{HashMap, HashSet};
use itertools::Itertools;
use crate::circuit::{
einsum::reduction_planner::{self, Reduction},
CircuitError,
};
/// Aggregate statistics over every einsum equation in a circuit, used to size
/// the universal einsum configuration.
#[derive(Debug, Clone)]
pub struct EinsumAnalysis {
    /// max size of input tensors
    pub max_input_size: usize,
    /// max size of output tensors
    pub max_output_size: usize,
    /// max number of input tensors
    pub max_num_inputs: usize,
    /// max number of output axes
    pub max_num_output_axes: usize,
    /// the sum of the lengths of dot product to compute all the reductions
    pub reduction_length: usize,
}
/// The strategy to use for einsum
#[derive(Debug, Clone)]
pub enum EinsumStrategy {
    /// Use only base ops (fully constrained in-circuit contraction)
    BaseOps,
    /// Use Freivalds' argument (probabilistic check via random challenges)
    Freivalds,
}
/// Per-equation analysis: tensor sizes, reduction cost, and the chosen
/// layout strategy.
#[derive(Debug, Clone)]
pub struct SingleEquationAnalysis {
    /// Sanitised equation (trivial size-1 axes removed).
    pub equation: String,
    /// Number of input operands in the equation.
    pub num_inputs: usize,
    /// Size (element count) of the largest input tensor.
    pub max_input_size: usize,
    /// Size (element count) of the output tensor.
    pub output_size: usize,
    /// Number of axes of the output tensor.
    pub num_output_axes: usize,
    /// Output axis labels, in order.
    pub output_indices: Vec<char>,
    /// the length of dot product to compute all the reductions
    pub reduction_length: usize,
    /// the strategy to use for einsum
    pub strategy: EinsumStrategy,
}
/// Aggregates per-equation einsum statistics into a single [`EinsumAnalysis`]
/// covering every equation used by the circuit.
pub fn analyze_einsum_usage(
    equations: &HashMap<(usize, String), HashMap<char, usize>>,
) -> Result<EinsumAnalysis, CircuitError> {
    // Running aggregate: maxima over all equations, summed reduction length.
    let mut acc = EinsumAnalysis {
        max_input_size: 0,
        max_output_size: 0,
        max_num_inputs: 0,
        max_num_output_axes: 0,
        reduction_length: 0,
    };
    for ((_, eqn), axes_to_dim) in equations.iter() {
        let single = analyze_single_equation(eqn, axes_to_dim)?;
        acc.max_input_size = acc.max_input_size.max(single.max_input_size);
        acc.max_output_size = acc.max_output_size.max(single.output_size);
        acc.max_num_inputs = acc.max_num_inputs.max(single.num_inputs);
        acc.max_num_output_axes = acc.max_num_output_axes.max(single.num_output_axes);
        acc.reduction_length += single.reduction_length;
    }
    Ok(acc)
}
///
pub fn analyze_single_equation(
equation: &str,
input_axes_to_dim: &HashMap<char, usize>,
) -> Result<SingleEquationAnalysis, CircuitError> {
// Sanitise equation to remove trivial axes
let equation = {
let (inputs_str, output_str) = equation.split_once("->").unwrap();
let input_equations: Vec<&str> = inputs_str.split(',').collect();
let inputs: Vec<String> = input_equations
.iter()
.map(|input| {
input
.chars()
.filter(|char| *input_axes_to_dim.get(char).unwrap() > 1)
.collect()
})
.collect();
let output = output_str
.chars()
.filter(|c| {
input_axes_to_dim.get(c).is_some() && *input_axes_to_dim.get(c).unwrap() > 1
})
.collect();
[inputs.join(","), output].join("->")
};
let (inputs_eq, output_eq) = equation.split_once("->").unwrap();
let input_equations: Vec<&str> = inputs_eq.split(',').collect();
let max_input_size = input_equations
.iter()
.map(|eqn| {
eqn.chars()
.map(|c| input_axes_to_dim.get(&c).unwrap())
.product()
})
.max()
.unwrap();
let output_indices: Vec<char> = output_eq.chars().collect();
let output_dims = output_indices
.iter()
.map(|c| input_axes_to_dim.get(&c).unwrap());
let output_size = output_dims.clone().product();
let output_reduction_length = {
let mut output_dims = output_dims.rev().cloned().collect_vec();
let mut total_length = 0;
for _ in 0..output_dims.len() {
let dot_product_len = output_dims.remove(0);
let num_dot_products: usize = output_dims.iter().product();
total_length += dot_product_len * num_dot_products;
}
total_length
};
let input_reductions_length = {
let input_reductions = reduction_planner::input_reductions(&equation)?;
input_reductions
.into_iter()
.map(|reduction| {
let (_, output_expr) = reduction.expression().split_once("->").unwrap();
let num_inputs = reduction.input_indices().len();
let dot_product_len = match reduction {
Reduction::RLC { axis, .. } => *input_axes_to_dim.get(&axis).unwrap(),
Reduction::Contraction { axis, .. } => *axis
.and_then(|axis| input_axes_to_dim.get(&axis))
.unwrap_or(&1),
};
let num_dot_products: usize = output_expr
.chars()
.map(|c| input_axes_to_dim.get(&c).unwrap())
.product();
// since `multi_dot` does pairwise mult between input pairs and final summation
if num_inputs <= 2 {
num_dot_products * dot_product_len
} else {
num_dot_products * (dot_product_len * num_inputs)
}
})
.sum::<usize>()
};
let dispatch_to_einsum_with_base_ops = {
let mut seen = HashSet::new();
let mut common_indices_to_inputs = vec![];
for input in input_equations.iter() {
for c in input.chars() {
if !seen.contains(&c) {
seen.insert(c);
} else {
common_indices_to_inputs.push(c);
}
}
}
let non_common_indices = input_axes_to_dim
.keys()
.filter(|&x| {
!common_indices_to_inputs.contains(x)
&& input_axes_to_dim.get(x).cloned().unwrap() > 1
})
.collect::<Vec<_>>();
!(output_indices.len() > 0
&& common_indices_to_inputs.len() > 0
&& non_common_indices.len() > 1)
};
let strategy = if dispatch_to_einsum_with_base_ops {
EinsumStrategy::BaseOps
} else {
EinsumStrategy::Freivalds
};
Ok(SingleEquationAnalysis {
output_size,
max_input_size,
equation: equation.to_string(),
num_inputs: input_equations.len(),
num_output_axes: output_indices.len(),
output_indices,
reduction_length: output_reduction_length + input_reductions_length,
strategy,
})
}

View File

@@ -0,0 +1,54 @@
use std::{collections::HashMap, marker::PhantomData};
use halo2_proofs::circuit::Value;
use halo2curves::ff::PrimeField;
use crate::{
circuit::CircuitError,
tensor::{Tensor, TensorError, TensorType},
};
/// Circuit parameter for a single einsum equation
#[derive(Clone, Debug, Default)]
pub struct SingleEinsumParams<F: PrimeField + TensorType + PartialOrd> {
    /// The raw einsum equation, e.g. "ij,jk->ik".
    pub equation: String,
    /// Map from input axes to dimensions
    pub input_axes_to_dims: HashMap<char, usize>,
    // Ties the otherwise-unused type parameter `F` to the struct.
    _marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> SingleEinsumParams<F> {
    /// Parses `equation` (e.g. "ij,jk->ik") against the concrete `inputs`,
    /// recording the dimension of every axis label that appears on the input
    /// side of the equation.
    ///
    /// # Errors
    /// * [`CircuitError::InvalidEinsum`] if the equation is malformed.
    /// * A `DimMismatch` error if the operand count disagrees with the
    ///   equation, a subscript has more labels than its tensor has axes, or
    ///   the same label maps to two different dimensions.
    pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
        use std::collections::hash_map::Entry;
        let mut eq = equation.split("->");
        let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
        let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
        // Check that the number of inputs matches the number of inputs in the equation
        if inputs.len() != inputs_eq.len() {
            return Err(TensorError::DimMismatch("einsum".to_string()).into());
        }
        let mut input_axes_to_dims = HashMap::new();
        for (input, subscript) in inputs.iter().zip(inputs_eq.iter()) {
            // A subscript with more labels than the tensor has axes would
            // otherwise panic on the `dims()[j]` index below.
            if subscript.chars().count() > input.dims().len() {
                return Err(TensorError::DimMismatch("einsum".to_string()).into());
            }
            // Single pass over the subscript instead of O(n^2) `chars().nth(j)`.
            for (j, c) in subscript.chars().enumerate() {
                match input_axes_to_dims.entry(c) {
                    Entry::Vacant(e) => {
                        e.insert(input.dims()[j]);
                    }
                    Entry::Occupied(e) => {
                        // Every occurrence of a label must agree on its dimension.
                        if *e.get() != input.dims()[j] {
                            return Err(TensorError::DimMismatch("einsum".to_string()).into());
                        }
                    }
                }
            }
        }
        Ok(Self {
            equation: equation.to_owned(),
            input_axes_to_dims,
            _marker: PhantomData,
        })
    }
}

View File

@@ -0,0 +1,359 @@
use halo2curves::ff::PrimeField;
use log::{error, trace};
use crate::{
circuit::{base::BaseOp, einsum::BaseOpInfo, region::RegionCtx, CheckMode, CircuitError},
tensor::{
get_broadcasted_shape,
ops::{accumulated, add, mult, sub},
TensorError, TensorType, ValTensor, ValType,
},
};
use super::ContractionConfig;
/// Pairwise (elementwise) op layout
///
/// Assigns `values[0] (op) values[1]` elementwise into the einsum region,
/// broadcasting both operands to a common shape first. `phases` gives the
/// challenge phase of each operand; operands are reordered so the lower-phase
/// one comes first, matching how the selectors were configured.
pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &ContractionConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[&ValTensor<F>; 2],
    op: BaseOp,
    phases: &[usize; 2],
) -> Result<ValTensor<F>, CircuitError> {
    // Order operands by phase: the column layout expects (lower, higher).
    let (mut lhs, mut rhs) = if phases[0] <= phases[1] {
        (values[0].clone(), values[1].clone())
    } else {
        (values[1].clone(), values[0].clone())
    };
    let broadcasted_shape = get_broadcasted_shape(lhs.dims(), rhs.dims())?;
    lhs.expand(&broadcasted_shape)?;
    rhs.expand(&broadcasted_shape)?;
    if lhs.len() != rhs.len() {
        return Err(CircuitError::DimMismatch(format!(
            "pairwise {} layout",
            op.as_str()
        )));
    }
    // Start from a fresh block so selector rows line up with the assignments.
    region.flush_einsum()?;
    let input_vars = config.get_input_vars(phases.as_slice().into());
    let output_var = config.get_output_var(phases.as_slice().into());
    let inputs = [lhs, rhs]
        .iter()
        .zip(input_vars)
        .map(|(val, var)| {
            let res = region.assign_einsum(var, val)?;
            Ok(res.get_inner()?)
        })
        .collect::<Result<Vec<_>, CircuitError>>()?;
    // Now we can assign the dot product
    // time the calc
    let op_result = match op {
        BaseOp::Add => add(&inputs),
        BaseOp::Sub => sub(&inputs),
        BaseOp::Mult => mult(&inputs),
        _ => return Err(CircuitError::UnsupportedOp),
    }
    .map_err(|e| {
        error!("{}", e);
        halo2_proofs::plonk::Error::Synthesis
    })?;
    let assigned_len = op_result.len();
    let mut output = region.assign_einsum(output_var, &op_result.into())?;
    // Enable the selectors
    // (skipped for dummy regions, which only measure layout size)
    if !region.is_dummy() {
        (0..assigned_len)
            .map(|i| {
                let (x, y, z) = output_var.cartesian_coord(region.einsum_col_coord() + i);
                let op_info = BaseOpInfo {
                    op_kind: op.clone(),
                    input_phases: phases.as_slice().into(),
                };
                let selector = config.selectors.get(&(op_info, x, y));
                region.enable(selector, z)?;
                Ok(())
            })
            .collect::<Result<Vec<_>, CircuitError>>()?;
    }
    region.increment_einsum_col_coord(assigned_len);
    output.reshape(&broadcasted_shape)?;
    Ok(output)
}
/// Sum-reduction layout: folds `values[0]` down to a single element using the
/// running-sum accumulator columns.
///
/// `phase` selects which challenge-phase column set to use (must be 0 or 1).
/// Returns the final accumulated element.
pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &ContractionConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[&ValTensor<F>; 1],
    phase: usize,
    check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
    // A single element is already its own sum; nothing to assign.
    if values[0].len() == 1 {
        return Ok(values[0].clone());
    }
    assert!(phase == 0 || phase == 1);
    region.flush_einsum()?;
    let mut input = values[0].clone();
    let block_width = config.block_width();
    let assigned_len: usize;
    let input = {
        // Zero-pad so the input is a whole number of accumulator blocks wide;
        // zeros do not change the sum.
        input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
        let var = config.get_input_vars([phase].as_slice().into())[0];
        let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &input)?;
        assigned_len = len;
        res.get_inner()?
    };
    // Now we can assign the dot product
    let accumulated_sum = accumulated::sum(&input, block_width)?;
    let output_var = config.get_output_var([phase].as_slice().into());
    let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
        output_var,
        &accumulated_sum.into(),
        check_mode,
    )?;
    // enable the selectors
    if !region.is_dummy() {
        for i in 0..output_assigned_len {
            let (x, _, z) = output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
            // skip over duplicates at start of column
            if z == 0 && i > 0 {
                continue;
            }
            // The first block seeds the accumulator (SumInit); later blocks
            // add onto the previous running total (Sum).
            let selector = if i == 0 {
                let op_info = BaseOpInfo {
                    op_kind: BaseOp::SumInit,
                    input_phases: [phase].as_slice().into(),
                };
                config.selectors.get(&(op_info, x, 0))
            } else {
                let op_info = BaseOpInfo {
                    op_kind: BaseOp::Sum,
                    input_phases: [phase].as_slice().into(),
                };
                config.selectors.get(&(op_info, x, 0))
            };
            region.enable(selector, z)?;
        }
    }
    let last_elem = output.last()?;
    region.increment_einsum_col_coord(assigned_len);
    // last element is the result
    Ok(last_elem)
}
/// Accumulated product of a single tensor's elements inside the einsum region.
///
/// Pads the input to a multiple of the block width with ones, assigns the
/// running block products to the output column, and returns the final (last)
/// accumulator as a singleton tensor.
pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &ContractionConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[&ValTensor<F>; 1],
    phase: usize,
    check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
    assert!(phase == 0 || phase == 1);
    region.flush_einsum()?;
    let block_width = config.block_width();
    // Pad with ones (the multiplicative identity) and assign the input
    // column, unconstrained, duplicating values across block boundaries.
    let mut padded = values[0].clone();
    padded.pad_to_zero_rem(block_width, ValType::Constant(F::ONE))?;
    let input_var = config.get_input_vars([phase].as_slice().into())[0];
    let (assigned_input, assigned_len) =
        region.assign_einsum_with_duplication_unconstrained(input_var, &padded)?;
    let inner = assigned_input.get_inner()?;
    // Assign the running products to the output column (constrained).
    let accumulated_prod = accumulated::prod(&inner, block_width)?;
    let output_var = config.get_output_var([phase].as_slice().into());
    let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
        output_var,
        &accumulated_prod.into(),
        check_mode,
    )?;
    // Turn on CumProdInit for the first block and CumProd for the rest.
    if !region.is_dummy() {
        for i in 0..output_assigned_len {
            let (x, _, z) =
                output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
            // skip over duplicates at start of column
            if z == 0 && i > 0 {
                continue;
            }
            let op_kind = if i == 0 {
                BaseOp::CumProdInit
            } else {
                BaseOp::CumProd
            };
            let op_info = BaseOpInfo {
                op_kind,
                input_phases: [phase].as_slice().into(),
            };
            region.enable(config.selectors.get(&(op_info, x, 0)), z)?;
        }
    }
    let result = output.last()?;
    region.increment_einsum_col_coord(assigned_len);
    // the final accumulator is the total product
    Ok(result)
}
/// Dot product of two equal-length tensors, laid out in the einsum region.
///
/// Operands are ordered so the lower-phase tensor is assigned first, padded
/// with zeros to the block width, and assigned unconstrained; the accumulated
/// dot product is then assigned (constrained) with the DotInit/Dot selectors
/// enabled. Returns the final accumulator as a singleton tensor.
///
/// # Errors
/// Returns a `CircuitError` if the tensors differ in length or if any region
/// assignment fails (previously the constrained assignment panicked via
/// `expect`; it now propagates like every other layout error).
pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &ContractionConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[&ValTensor<F>; 2],
    phases: &[usize; 2],
    check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
    if values[0].len() != values[1].len() {
        return Err(TensorError::DimMismatch("dot".to_string()).into());
    }
    region.flush_einsum()?;
    // time this entire function run
    let global_start = instant::Instant::now();
    // Order the operands so the first-phase tensor comes first, matching the
    // column order returned by `get_input_vars`.
    let mut values = if phases[0] <= phases[1] {
        [values[0].clone(), values[1].clone()]
    } else {
        [values[1].clone(), values[0].clone()]
    };
    let vars = config.get_input_vars(phases.as_slice().into());
    let mut inputs = vec![];
    let block_width = config.block_width();
    let mut assigned_len = 0;
    for (val, var) in values.iter_mut().zip(vars) {
        // Zero padding is safe for a dot product: padded terms contribute 0.
        val.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
        let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, val)?;
        assigned_len = len;
        inputs.push(res.get_inner()?);
    }
    // Now we can assign the dot product (constrained), propagating failure
    // instead of panicking.
    let accumulated_dot = accumulated::dot(&inputs[0], &inputs[1], block_width)?;
    let output_var = config.get_output_var(phases.as_slice().into());
    let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
        output_var,
        &accumulated_dot.into(),
        check_mode,
    )?;
    // enable the selectors: DotInit on the first block, Dot afterwards
    if !region.is_dummy() {
        (0..output_assigned_len)
            .map(|i| {
                let (x, _, z) =
                    output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
                // hop over duplicates at start of column
                if z == 0 && i > 0 {
                    return Ok(());
                }
                let op_kind = if i == 0 { BaseOp::DotInit } else { BaseOp::Dot };
                let op_info = BaseOpInfo {
                    op_kind,
                    input_phases: phases.as_slice().into(),
                };
                region.enable(config.selectors.get(&(op_info, x, 0)), z)?;
                Ok(())
            })
            .collect::<Result<Vec<_>, CircuitError>>()?;
    }
    let last_elem = output.last()?;
    region.increment_einsum_col_coord(assigned_len);
    let elapsed = global_start.elapsed();
    trace!("dot layout took: {:?}, row {}", elapsed, region.row());
    trace!("----------------------------");
    Ok(last_elem)
}
/// Dot product of more than two tensors
///
/// Multiplies the tensors together elementwise (pairwise, in order), then
/// sums the final intermediate tensor; the result is the multi-way dot
/// product as a singleton tensor. The phase of each intermediate is the max
/// of its operands' phases.
///
/// # Errors
/// Propagates any `CircuitError` from the pairwise multiplications or the
/// final sum (previously these were `unwrap`ped inside a `reduce`, turning
/// layout errors into panics).
///
/// # Panics
/// Panics if `values` is empty or any phase is not 0/1.
pub fn multi_dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &ContractionConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[&ValTensor<F>],
    phases: &[usize],
    check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
    assert!(phases.iter().all(|phase| *phase == 0 || *phase == 1));
    if !values.iter().all(|value| value.len() == values[0].len()) {
        return Err(TensorError::DimMismatch("dot".to_string()).into());
    }
    // time this entire function run
    let global_start = instant::Instant::now();
    // do pairwise multiplication between the running intermediate tensor and
    // each subsequent tensor, tracking the intermediate's phase
    let mut pairs = values.iter().copied().cloned().zip(phases.iter().copied());
    let (mut intermediate, mut intermediate_phase) = pairs
        .next()
        .expect("multi_dot requires at least one input tensor");
    for (input, phase) in pairs {
        intermediate = pairwise(
            config,
            region,
            &[&intermediate, &input],
            BaseOp::Mult,
            &[intermediate_phase, phase],
        )?;
        intermediate_phase = std::cmp::max(intermediate_phase, phase);
    }
    let accumulated_dot = sum(
        config,
        region,
        &[&intermediate],
        intermediate_phase,
        check_mode,
    )?;
    let last_elem = accumulated_dot.last()?;
    let elapsed = global_start.elapsed();
    trace!("multi_dot layout took: {:?}, row {}", elapsed, region.row());
    trace!("----------------------------");
    Ok(last_elem)
}

View File

@@ -0,0 +1,867 @@
use crate::circuit::base::BaseOp;
use crate::circuit::chip::einsum::analysis::{analyze_single_equation, EinsumAnalysis};
use crate::circuit::einsum::layouts::{pairwise, sum};
use crate::circuit::einsum::reduction_planner::Reduction;
use crate::circuit::layouts::einsum_with_base_ops;
use crate::circuit::region::RegionCtx;
use crate::circuit::{BaseConfig, CheckMode, CircuitError};
use crate::tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor};
use halo2_proofs::circuit::Value;
use halo2_proofs::plonk::{
Challenge, ConstraintSystem, Constraints, Expression, FirstPhase, Selector,
};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use layouts::{dot, multi_dot, prod};
use std::collections::{BTreeMap, HashMap};
use std::marker::PhantomData;
/// Analysis of einsum equations: reduction sizing and strategy selection
pub mod analysis;
/// Circuit parameters for the einsum argument
pub mod circuit_params;
// per-op layout routines (dot, multi_dot, prod, sum, pairwise)
mod layouts;
// plans the order of RLC/contraction reduction steps for an equation
mod reduction_planner;
/// The maximum number of challenges (one RLC gate is built per challenge)
pub const NUM_MAX_EINSUM_CHALLENGES: usize = 10;
/// A struct representing reductions for the einsums
///
/// Implements a generalized Freivalds' argument: the claimed einsum output is
/// squashed to a scalar with random linear combinations and constrained to
/// equal the same reduction applied to the inputs.
#[derive(Clone, Debug, Default)]
pub struct Einsums<F: PrimeField + TensorType + PartialOrd> {
    /// custom gate to constrain tensor contractions
    contraction_gate: ContractionConfig<F>,
    /// custom gate to constrain random linear combinations used by Freivalds' argument
    /// (one gate per challenge)
    rlc_gates: Vec<RLCConfig<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
    /// Build a placeholder config for dummy layout passes: dummy columns, no
    /// selectors, and RLC gates with no challenge (so `challenges()` errors).
    pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
        let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
        let dummy_contraction_gate = ContractionConfig {
            inputs: [
                [dummy_var.clone(), dummy_var.clone()],
                [dummy_var.clone(), dummy_var.clone()],
            ],
            outputs: [dummy_var.clone(), dummy_var.clone()],
            selectors: BTreeMap::default(),
            _marker: PhantomData,
        };
        Self {
            contraction_gate: dummy_contraction_gate,
            rlc_gates: (0..NUM_MAX_EINSUM_CHALLENGES)
                .map(|_| RLCConfig::dummy(&dummy_var))
                .collect(),
        }
    }
    /// Configure the columns based on universal Einsum analysis
    pub fn configure_universal(
        meta: &mut ConstraintSystem<F>,
        analysis: &EinsumAnalysis,
        num_inner_cols: usize,
        logrows: usize,
    ) -> Self {
        let capacity = analysis.reduction_length;
        // two first-phase and two second-phase input column sets
        let inputs: [VarTensor; 4] = [
            VarTensor::new_advice(meta, logrows, num_inner_cols, capacity),
            VarTensor::new_advice(meta, logrows, num_inner_cols, capacity),
            VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
            VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
        ];
        // [first-phase output, second-phase output]
        let outputs = [
            VarTensor::new_advice(meta, logrows, num_inner_cols, capacity),
            VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
        ];
        let contraction_gate = ContractionConfig::new(
            meta,
            &[&[&inputs[0], &inputs[1]], &[&inputs[2], &inputs[3]]],
            &[&outputs[0], &outputs[1]],
        );
        // one RLC gate (and hence one challenge) per output axis
        let mut rlc_gates = vec![];
        for _ in 0..analysis.max_num_output_axes {
            let rlc_gate =
                RLCConfig::new(meta, &[inputs[0].clone(), inputs[2].clone()], &outputs[1]);
            rlc_gates.push(rlc_gate);
        }
        Self {
            contraction_gate,
            rlc_gates,
        }
    }
    /// In dummy layout phase, calling this function will return error
    pub fn challenges(&self) -> Result<Vec<Challenge>, CircuitError> {
        self.rlc_gates
            .iter()
            .map(|gate| gate.challenge.ok_or(CircuitError::ChallengeNotSet))
            .collect::<Result<Vec<_>, _>>()
    }
    /// Lay out the Freivalds argument for `equation`: squash the claimed
    /// output tensor to a scalar with per-axis RLCs, reduce the input tensors
    /// with the planned RLC/contraction steps, and constrain the two scalars
    /// to be equal. Falls back to `einsum_with_base_ops` when the analysis
    /// selects the `BaseOps` strategy.
    pub fn assign_einsum(
        &self,
        base_config: &BaseConfig<F>,
        region: &mut RegionCtx<F>,
        input_tensors: &[&ValTensor<F>],
        output_tensor: &ValTensor<F>,
        equation: &str,
        check_mode: &CheckMode,
    ) -> Result<(), CircuitError> {
        region.set_num_einsum_inner_cols(self.contraction_gate.block_width());
        let (input_exprs, _) = equation.split_once("->").unwrap();
        let input_exprs = input_exprs.split(",").collect_vec();
        assert_eq!(input_exprs.len(), input_tensors.len());
        let mut input_tensors = input_tensors.iter().copied().cloned().collect_vec();
        let mut output_tensor = output_tensor.clone();
        // map each axis label to its dimension (first occurrence wins)
        let mut input_axes_to_dim: HashMap<char, usize> = HashMap::new();
        input_exprs
            .iter()
            .zip(input_tensors.iter())
            .for_each(|(indices, tensor)| {
                indices.chars().zip(tensor.dims()).for_each(|(index, dim)| {
                    if let std::collections::hash_map::Entry::Vacant(e) =
                        input_axes_to_dim.entry(index)
                    {
                        e.insert(*dim);
                    }
                });
            });
        // NOTE(review): the analysis may rewrite the equation (e.g. after
        // trivial-axis removal) — we use its version from here on.
        let equation_analysis = analyze_single_equation(&equation, &input_axes_to_dim)?;
        let equation = equation_analysis.equation;
        // Remove trivial axes from tensors
        input_tensors
            .iter_mut()
            .map(|tensor| tensor.remove_trivial_axes())
            .collect::<Result<Vec<_>, TensorError>>()?;
        output_tensor.remove_trivial_axes()?;
        // Small einsums are cheaper with plain base ops; no Freivalds needed.
        if matches!(
            equation_analysis.strategy,
            analysis::EinsumStrategy::BaseOps
        ) {
            let _ = einsum_with_base_ops(
                base_config,
                region,
                &input_tensors.iter().collect_vec(),
                &equation,
            )?;
            return Ok(());
        }
        let output_shape = equation_analysis
            .output_indices
            .iter()
            .map(|c| input_axes_to_dim.get(c).copied().unwrap())
            .collect_vec();
        // Squash the claimed output to a scalar via per-axis RLCs.
        let squashed_output =
            self.assign_output(region, &output_tensor, output_shape, check_mode)?;
        // reorder the reduction of input tensors and reduce
        let reordered_input_reductions = reduction_planner::input_reductions(&equation).unwrap();
        let mut tensors = input_tensors;
        let mut reduced_input_phase = 0;
        for reduction in reordered_input_reductions.iter() {
            let (input_expr, output_expr) = reduction.expression().split_once("->").unwrap();
            let input_exprs = input_expr.split(",").collect_vec();
            // axes that survive this reduction step
            let remaining_axes = output_expr.chars().collect_vec();
            let mut remaining_axes_indices = remaining_axes
                .iter()
                .map(|c| 0..input_axes_to_dim[c])
                .multi_cartesian_product()
                .collect_vec();
            // Dummy value to ensure the for loop runs at least once
            if remaining_axes.is_empty() {
                remaining_axes_indices.push(vec![]);
            }
            let input_tensors = reduction
                .input_indices()
                .iter()
                .map(|idx| tensors[*idx].clone())
                .collect_vec();
            // One flattened slice list per operand; each entry corresponds to
            // a single output coordinate of this reduction.
            let mut flattened_input_tensors: Vec<Vec<ValTensor<F>>> =
                vec![vec![]; input_tensors.len()];
            for remaining_axes_indices in remaining_axes_indices {
                // corresponds to 1 running sum of input tensors
                for (i, (input_tensor, input_expr)) in
                    input_tensors.iter().zip(input_exprs.iter()).enumerate()
                {
                    let mut sliced_dim = vec![];
                    input_expr.chars().for_each(|axis| {
                        if let Some(pos) = remaining_axes.iter().position(|c| *c == axis) {
                            sliced_dim
                                .push(remaining_axes_indices[pos]..remaining_axes_indices[pos] + 1);
                        } else {
                            // common axis
                            sliced_dim.push(0..input_axes_to_dim[&axis]);
                        }
                    });
                    let mut sliced_input_tensor = input_tensor.get_slice(&sliced_dim)?;
                    sliced_input_tensor.flatten();
                    flattened_input_tensors[i].push(sliced_input_tensor);
                }
            }
            // Concatenate each operand's slices into one flat tensor.
            let flattened_input_tensors = flattened_input_tensors
                .into_iter()
                .map(|tensors| {
                    ValTensor::from(
                        tensors
                            .into_iter()
                            .flat_map(|t| t.get_inner_tensor().unwrap().clone().into_iter())
                            .collect_vec(),
                    )
                })
                .collect_vec();
            let output_dims = output_expr
                .chars()
                .map(|c| input_axes_to_dim[&c])
                .collect_vec();
            let contracted_output = match reduction {
                Reduction::RLC {
                    axis,
                    input_phase,
                    challenge_index,
                    ..
                } => {
                    // RLC folds one axis of a single tensor with its challenge.
                    assert_eq!(flattened_input_tensors.len(), 1);
                    let rlc_len = input_axes_to_dim[axis];
                    let mut result = self.rlc_gates[*challenge_index].assign_rlc(
                        region,
                        &flattened_input_tensors[0],
                        region.challenges()[*challenge_index],
                        rlc_len,
                        *input_phase,
                        check_mode,
                    )?;
                    result.reshape(&output_dims)?;
                    result
                }
                Reduction::Contraction {
                    axis, input_phases, ..
                } => match axis {
                    // contraction along an axis: chunked dot products / sums
                    Some(axis) => {
                        let dot_product_len = input_axes_to_dim[axis];
                        assign_input_contraction(
                            &self.contraction_gate,
                            region,
                            flattened_input_tensors,
                            dot_product_len,
                            &output_dims,
                            input_phases,
                            check_mode,
                        )?
                    }
                    // no axis: plain elementwise multiplication
                    None => {
                        let mut result = assign_pairwise_mult(
                            &self.contraction_gate,
                            region,
                            flattened_input_tensors,
                            input_phases,
                        )?;
                        result.reshape(&output_dims)?;
                        result
                    }
                },
            };
            tensors.push(contracted_output);
            reduced_input_phase = reduction.output_phase();
        }
        // Only fully reduced (scalar) tensors remain relevant; multiply them
        // together to get the input-side scalar.
        tensors.retain(|tensor| tensor.is_singleton());
        let scalars: ValTensor<F> = tensors
            .into_iter()
            .map(|t| t.get_inner_tensor().unwrap().get_scalar())
            .collect_vec()
            .into();
        let squashed_input = prod(
            &self.contraction_gate,
            region,
            &[&scalars],
            reduced_input_phase,
            check_mode,
        )?;
        // Freivalds check: reduced inputs must equal the squashed output.
        region.constrain_equal(&squashed_input, &squashed_output)
    }
    /// Squash the claimed output tensor to a single value by folding each
    /// output axis — innermost first — into a random linear combination with
    /// that axis' challenge, then assign the final scalar to the output
    /// column and return it.
    fn assign_output(
        &self,
        region: &mut RegionCtx<F>,
        output: &ValTensor<F>,
        output_shape: Vec<usize>,
        check_mode: &CheckMode,
    ) -> Result<ValTensor<F>, CircuitError> {
        let mut intermediate_values = output.clone();
        let challenges = region
            .challenges()
            .iter()
            .take(output_shape.len())
            .copied()
            .collect_vec();
        // Loop over the output axes (reversed: last axis is folded first)
        for (idx, (rlc_config, challenge)) in self
            .rlc_gates
            .iter()
            .take(output_shape.len())
            .zip(challenges.iter())
            .rev()
            .enumerate()
        {
            let rlc_len = output_shape[output_shape.len() - idx - 1];
            intermediate_values.flatten();
            // after the first RLC the values depend on a challenge, so all
            // subsequent assignments go to second-phase columns
            let phase = if idx > 0 { 1 } else { 0 };
            intermediate_values = rlc_config.assign_rlc(
                region,
                &intermediate_values,
                *challenge,
                rlc_len,
                phase,
                check_mode,
            )?;
        }
        let phase = if challenges.len() > 0 { 1 } else { 0 };
        let output_var = self
            .contraction_gate
            .get_output_var([phase].as_slice().into());
        let res = region.assign_einsum(output_var, &intermediate_values)?;
        region.increment_einsum_col_coord(1);
        Ok(res)
    }
}
/// Elementwise (Hadamard) product of a list of flattened tensors.
///
/// Multiplies the tensors together pairwise with `BaseOp::Mult`; the phase of
/// each intermediate is the max of its operands' phases.
///
/// # Errors
/// Propagates any `CircuitError` from the pairwise multiplications
/// (previously these were `unwrap`ped inside a `reduce`, turning layout
/// errors into panics).
///
/// # Panics
/// Panics if `flattened_tensors` is empty or its length differs from
/// `input_phases`.
fn assign_pairwise_mult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &ContractionConfig<F>,
    region: &mut RegionCtx<F>,
    flattened_tensors: Vec<ValTensor<F>>,
    input_phases: &[usize],
) -> Result<ValTensor<F>, CircuitError> {
    assert_eq!(flattened_tensors.len(), input_phases.len());
    let mut pairs = flattened_tensors
        .into_iter()
        .zip(input_phases.iter().copied());
    let (mut acc, mut acc_phase) = pairs
        .next()
        .expect("assign_pairwise_mult requires at least one tensor");
    for (input, phase) in pairs {
        acc = pairwise(
            config,
            region,
            &[&acc, &input],
            BaseOp::Mult,
            &[acc_phase, phase],
        )?;
        acc_phase = std::cmp::max(acc_phase, phase);
    }
    Ok(acc)
}
/// Contract the flattened operands in `dot_product_len`-sized chunks — one
/// chunk per output coordinate — producing a scalar per coordinate, and
/// reshape the collected scalars to `output_shape`.
fn assign_input_contraction<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &ContractionConfig<F>,
    region: &mut RegionCtx<F>,
    flattened_tensors: Vec<ValTensor<F>>,
    dot_product_len: usize,
    output_shape: &[usize],
    input_phases: &[usize],
    check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
    assert_eq!(flattened_tensors.len(), input_phases.len());
    let num_dot_products: usize = output_shape.iter().product();
    let mut results = Vec::with_capacity(num_dot_products);
    for chunk_idx in 0..num_dot_products {
        let range = chunk_idx * dot_product_len..(chunk_idx + 1) * dot_product_len;
        // Slice every operand down to this output coordinate's chunk.
        let chunks = flattened_tensors
            .iter()
            .map(|tensor| tensor.get_slice(&[range.clone()]))
            .collect::<Result<Vec<_>, TensorError>>()?;
        // Dispatch on operand count: plain sum, two-way dot, or multi-way dot.
        let reduced = match chunks.as_slice() {
            [only] => sum(config, region, &[only], input_phases[0], check_mode)?,
            [a, b] => dot(
                config,
                region,
                &[a, b],
                &[input_phases[0], input_phases[1]],
                check_mode,
            )?,
            _ => multi_dot(
                config,
                region,
                chunks.iter().collect_vec().as_slice(),
                input_phases,
                check_mode,
            )?,
        };
        results.push(reduced.get_inner_tensor()?.get_scalar());
    }
    let mut tensor = ValTensor::from(results);
    tensor.reshape(output_shape)?;
    Ok(tensor)
}
// Phase assignment of a base op's operand(s); keys selector lookups and
// advice-column selection.
#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, Hash)]
enum InputPhases {
    // single input in the first phase: [0]
    FirstPhase,
    // single input in the second phase: [1]
    SecondPhase,
    BothFirstPhase, // [0, 0]
    Mixed,          // [0, 1] or [1, 0]
    BothSecondPhase, // [1, 1]
}
impl From<&[usize]> for InputPhases {
fn from(phases: &[usize]) -> Self {
match phases {
[0] => Self::FirstPhase,
[1] => Self::SecondPhase,
[0, 0] => Self::BothFirstPhase,
[0, 1] | [1, 0] => Self::Mixed,
[1, 1] => Self::BothSecondPhase,
_ => panic!("Invalid phase combination"),
}
}
}
// A base operation together with its operands' phases; used as (part of) the
// key when looking up the selector for a given (op, block, inner column).
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
struct BaseOpInfo {
    // which base operation the selector gates
    pub op_kind: BaseOp,
    // phases of the operation's inputs
    pub input_phases: InputPhases,
}
/// `ContractionConfig` is the custom gate to constrain tensor contractions
/// (dot products, sums, cumulative products, and elementwise mults)
#[derive(Clone, Debug, Default)]
struct ContractionConfig<F: PrimeField + TensorType + PartialOrd> {
    // [[first phase, first phase], [second phase, second phase]]
    inputs: [[VarTensor; 2]; 2],
    // [first phase, second phase]
    outputs: [VarTensor; 2],
    // (BaseOpInfo, block index, inner column index) -> selector
    selectors: BTreeMap<(BaseOpInfo, usize, usize), Selector>,
    _marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
fn get_input_vars(&self, input_phases: InputPhases) -> Vec<&VarTensor> {
match input_phases {
InputPhases::FirstPhase => vec![&self.inputs[0][0]],
InputPhases::SecondPhase => vec![&self.inputs[1][0]],
InputPhases::BothFirstPhase => vec![&self.inputs[0][0], &self.inputs[0][1]],
InputPhases::Mixed => vec![&self.inputs[0][0], &self.inputs[1][0]],
InputPhases::BothSecondPhase => vec![&self.inputs[1][0], &self.inputs[1][1]],
}
}
fn get_output_var(&self, input_phases: InputPhases) -> &VarTensor {
match input_phases {
InputPhases::FirstPhase => &self.outputs[0],
InputPhases::SecondPhase => &self.outputs[1],
InputPhases::BothFirstPhase => &self.outputs[0],
InputPhases::Mixed => &self.outputs[1],
InputPhases::BothSecondPhase => &self.outputs[1],
}
}
fn block_width(&self) -> usize {
self.outputs[0].num_inner_cols()
}
fn new(
meta: &mut ConstraintSystem<F>,
inputs: &[&[&VarTensor; 2]; 2],
outputs: &[&VarTensor; 2],
) -> Self {
let mut selectors = BTreeMap::new();
let num_blocks = outputs[0].num_blocks();
let block_width = outputs[0].num_inner_cols();
for input_phases in [
InputPhases::BothFirstPhase,
InputPhases::Mixed,
InputPhases::BothSecondPhase,
] {
for i in 0..num_blocks {
for j in 0..block_width {
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::Mult,
input_phases,
},
i,
j,
),
meta.selector(),
);
}
for i in 0..num_blocks {
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::DotInit,
input_phases,
},
i,
0,
),
meta.selector(),
);
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::Dot,
input_phases,
},
i,
0,
),
meta.selector(),
);
}
}
}
for input_phases in [InputPhases::FirstPhase, InputPhases::SecondPhase] {
for i in 0..num_blocks {
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::SumInit,
input_phases,
},
i,
0,
),
meta.selector(),
);
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::Sum,
input_phases,
},
i,
0,
),
meta.selector(),
);
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::CumProdInit,
input_phases,
},
i,
0,
),
meta.selector(),
);
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::CumProd,
input_phases,
},
i,
0,
),
meta.selector(),
);
}
}
for ((base_op, block_idx, inner_col_idx), selector) in selectors.iter() {
let inputs = match base_op.input_phases {
InputPhases::FirstPhase => vec![inputs[0][0]],
InputPhases::SecondPhase => vec![inputs[1][0]],
InputPhases::BothFirstPhase => vec![inputs[0][0], inputs[0][1]],
InputPhases::Mixed => vec![inputs[0][0], inputs[1][0]],
InputPhases::BothSecondPhase => vec![inputs[1][0], inputs[1][1]],
};
let output = match base_op.input_phases {
InputPhases::FirstPhase => outputs[0],
InputPhases::SecondPhase => outputs[1],
InputPhases::BothFirstPhase => outputs[0],
InputPhases::Mixed => outputs[1],
InputPhases::BothSecondPhase => outputs[1],
};
assert_eq!(inputs.len(), base_op.op_kind.num_inputs());
match base_op.op_kind {
BaseOp::Mult => {
meta.create_gate(base_op.op_kind.as_str(), |meta| {
let selector = meta.query_selector(*selector);
let zero = Expression::<F>::Constant(F::ZERO);
let mut qis = vec![zero; 2];
for (q_i, input) in qis.iter_mut().zip(inputs) {
*q_i = input
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("contraction config: input query failed")[0]
.clone()
}
// Get output expressions for each input channel
let (rotation_offset, rng) = base_op.op_kind.query_offset_rng();
let constraints = {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
.expect("contraction config: output query failed");
let res = base_op.op_kind.nonaccum_f((qis[0].clone(), qis[1].clone()));
vec![expected_output[base_op.op_kind.constraint_idx()].clone() - res]
};
Constraints::with_selector(selector, constraints)
});
}
_ => {
meta.create_gate(base_op.op_kind.as_str(), |meta| {
let selector = meta.query_selector(*selector);
let mut qis = vec![vec![]; 2];
for (q_i, input) in qis.iter_mut().zip(inputs) {
*q_i = input
.query_whole_block(meta, *block_idx, 0, 1)
.expect("contraction config: input query failed")
.into_iter()
.collect()
}
// Get output expressions for each input channel
let (rotation_offset, rng) = base_op.op_kind.query_offset_rng();
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, 0, rotation_offset, rng)
.expect("contraction config: output query failed");
let res = base_op.op_kind.accum_f(
expected_output[0].clone(),
qis[1].clone(),
qis[0].clone(),
);
let constraints =
vec![expected_output[base_op.op_kind.constraint_idx()].clone() - res];
Constraints::with_selector(selector, constraints)
});
}
}
}
let first_phase_inputs: [VarTensor; 2] = inputs[0]
.iter()
.copied()
.cloned()
.collect_vec()
.try_into()
.unwrap();
let second_phase_inputs: [VarTensor; 2] = inputs[1]
.iter()
.copied()
.cloned()
.collect_vec()
.try_into()
.unwrap();
Self {
inputs: [first_phase_inputs, second_phase_inputs],
outputs: [outputs[0].clone(), outputs[1].clone()],
selectors,
_marker: PhantomData,
}
}
}
/// `RLCConfig` is the custom gate used for random linear combination with the specific challenge
#[derive(Clone, Debug)]
struct RLCConfig<F: PrimeField + TensorType + PartialOrd> {
    /// The challenge used for the random linear combination
    /// Challenge is `None` in the dummy configuration
    pub challenge: Option<Challenge>,
    /// [first phase, second phase]
    pub inputs: [VarTensor; 2],
    pub output: VarTensor,
    /// (phase of input, block index) -> (init selector, acc selector)
    pub selectors: BTreeMap<(usize, usize), (Selector, Selector)>,
    _marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RLCConfig<F> {
    /// Placeholder config for dummy layout passes: no challenge, no selectors.
    fn dummy(dummy_var: &VarTensor) -> Self {
        let challenge = None;
        let inputs = [dummy_var.clone(), dummy_var.clone()];
        let output = dummy_var.clone();
        let selectors = BTreeMap::new();
        Self {
            challenge,
            inputs,
            output,
            selectors,
            _marker: PhantomData,
        }
    }
    /// Allocate a challenge plus per-(phase, block) init/acc selector pairs,
    /// and create the RLC gates.
    ///
    /// Each block folds `w = block_width` terms into a running combination
    /// with challenge `r`: `out = prev * r^w + sum_j input[j] * r^(w - j)`;
    /// the init gate is the same with `prev = 0`.
    fn new(meta: &mut ConstraintSystem<F>, inputs: &[VarTensor; 2], output: &VarTensor) -> Self {
        let challenge = meta.challenge_usable_after(FirstPhase);
        let mut selectors = BTreeMap::new();
        // one (init, acc) selector pair per (input phase, block)
        for (phase, input) in inputs.iter().enumerate() {
            for block_idx in 0..input.num_blocks() {
                let selector = (meta.selector(), meta.selector());
                selectors.insert((phase, block_idx), selector);
            }
        }
        let block_width = output.num_inner_cols();
        // [r, r^2, ..., r^block_width]
        let powers_of_challenge = (0..block_width)
            .scan(Expression::Constant(F::ONE), |r_power, _| {
                *r_power = r_power.clone() * challenge.expr();
                Some(r_power.clone())
            })
            .collect_vec();
        for ((phase, block_idx), (init_selector, acc_selector)) in selectors.iter() {
            // init gate: out[0] = dot(reversed challenge powers, block inputs)
            meta.create_gate("init", |meta| {
                let selector = meta.query_selector(*init_selector);
                let input_exprs = inputs[*phase]
                    .query_whole_block(meta, *block_idx, 0, 1)
                    .expect("rlc config: input query failed")
                    .into_iter()
                    .collect();
                let constraints = {
                    let expected_output: Tensor<Expression<F>> = output
                        .query_rng(meta, *block_idx, 0, 0, 1)
                        .expect("rlc config: output query failed");
                    let res = BaseOp::Dot.accum_f(
                        Expression::Constant(F::ZERO),
                        powers_of_challenge.iter().cloned().rev().collect_vec(),
                        input_exprs,
                    );
                    vec![expected_output[0].clone() - res]
                };
                Constraints::with_selector(selector, constraints)
            });
            // acc gate: out = prev * r^block_width
            //               + dot(reversed challenge powers, block inputs)
            // (queries rotations -1 and 0: [prev, current])
            meta.create_gate("acc", |meta| {
                let selector = meta.query_selector(*acc_selector);
                let input_exprs = inputs[*phase]
                    .query_whole_block(meta, *block_idx, 0, 1)
                    .expect("rlc config: input query failed")
                    .into_iter()
                    .collect();
                let constraints = {
                    let expected_output: Tensor<Expression<F>> = output
                        .query_rng(meta, *block_idx, 0, -1, 2)
                        .expect("rlc config: output query failed");
                    let res = BaseOp::Dot.accum_f(
                        expected_output[0].clone() * powers_of_challenge.last().cloned().unwrap(),
                        powers_of_challenge.iter().cloned().rev().collect_vec(),
                        input_exprs,
                    );
                    vec![expected_output[1].clone() - res]
                };
                Constraints::with_selector(selector, constraints)
            });
        }
        Self {
            inputs: inputs.clone(),
            output: output.clone(),
            selectors,
            challenge: Some(challenge),
            _marker: PhantomData,
        }
    }
    /// Assign a running random linear combination over `flattened_input`,
    /// folding each consecutive `rlc_len`-sized chunk down to one scalar with
    /// powers of `challenge`. Returns the tensor of per-chunk RLC results.
    fn assign_rlc(
        &self,
        region: &mut RegionCtx<F>,
        flattened_input: &ValTensor<F>,
        challenge: Value<F>,
        rlc_len: usize,
        phase: usize,
        check_mode: &CheckMode,
    ) -> Result<ValTensor<F>, CircuitError> {
        region.flush_einsum()?;
        let block_width = self.output.num_inner_cols();
        // witness-side [r, r^2, ..., r^block_width], mirroring the gate
        let powers_of_challenge = (0..block_width)
            .scan(Value::known(F::ONE), |challenge_power, _| {
                *challenge_power = challenge_power.clone() * challenge;
                Some(challenge_power.clone())
            })
            .collect_vec();
        let mut rlc_results: Vec<ValType<F>> = vec![];
        for tensor in flattened_input.get_inner_tensor()?.chunks_exact(rlc_len) {
            // Compute the running sums off-circuit, block by block, exactly
            // as the acc gate does: state = state * r^w + Σ_j v[j] * r^(w-j).
            let running_sums = tensor
                .iter()
                .chunks(block_width)
                .into_iter()
                .scan(Value::known(F::ZERO), |state, val| {
                    let curr_sum: Value<F> = val
                        .into_iter()
                        .zip(powers_of_challenge.iter().rev())
                        .map(|(v, c_power)| {
                            c_power.and_then(|c_power| {
                                v.get_felt_eval()
                                    .and_then(|v| Some(Value::known(c_power * v)))
                                    .unwrap_or(Value::unknown())
                            })
                        })
                        .reduce(|acc, v| acc + v)
                        .unwrap();
                    *state = *state * powers_of_challenge.last().unwrap() + curr_sum;
                    Some(*state)
                })
                .collect_vec();
            // Assign the (padded) chunk to the input column, unconstrained.
            let assigned_len = {
                let mut input: ValTensor<F> = tensor.iter().collect_vec().into();
                input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
                let (_, len) = region
                    .assign_einsum_with_duplication_unconstrained(&self.inputs[phase], &input)?;
                len
            };
            // Assign the running sums to the output column, constrained.
            let (assigned_output, assigned_output_len) = {
                let running_sums = running_sums.into_iter().map(ValType::from).collect_vec();
                region.assign_einsum_with_duplication_constrained(
                    &self.output,
                    &running_sums.into(),
                    check_mode,
                )?
            };
            // Enable init on the first block, acc on the rest, skipping
            // duplicated rows at the start of a column.
            (0..assigned_output_len)
                .map(|i| {
                    let (block_idx, _, z) = self
                        .output
                        .cartesian_coord(region.einsum_col_coord() + i * block_width);
                    if z == 0 && i > 0 {
                        return Ok(());
                    }
                    let selector = if i == 0 {
                        self.selectors
                            .get(&(phase, block_idx))
                            .map(|(init, _)| init)
                    } else {
                        self.selectors.get(&(phase, block_idx)).map(|(_, acc)| acc)
                    };
                    region.enable(selector, z)?;
                    Ok(())
                })
                .collect::<Result<Vec<_>, CircuitError>>()?;
            // the last running sum is this chunk's RLC value
            rlc_results.push(assigned_output.last()?.get_inner_tensor()?.get_scalar());
            region.increment_einsum_col_coord(assigned_len);
        }
        Ok(rlc_results.into())
    }
}

View File

@@ -0,0 +1,205 @@
use std::{collections::BTreeSet, ops::Index};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use crate::{
circuit::CircuitError,
tensor::{TensorType, ValTensor},
};
/// A single step in the plan that reduces the einsum inputs to scalars.
///
/// Example plans (step expression on the left, live tensor list on the right):
///
/// inj,jk->ik [inj,jk]
/// inj,i->nj => RLC [jk,nj]
/// jk,k->j => RLC [nj,j]
/// nj,j->n => Contraction [n]
/// n-> => Contraction []
///
/// bn,anm,bm->ba [bn,anm,bm]
/// bn,bm->bnm => Contraction [anm,bnm]
/// bnm,b->nm => RLC [anm,nm]
/// anm,a->nm => RLC [nm,nm]
/// nm,nm->m => Contraction [m]
/// m-> => Contraction []
#[derive(Debug)]
pub enum Reduction {
    /// Random linear combination with powers of challenge along the axis
    RLC {
        /// einsum-style equation for this step
        expression: String,
        /// the axis folded by the RLC
        axis: char,
        /// Uniquely identifying index of input tensor to be reduced
        input_index: TensorIndex,
        /// phase of input tensor
        input_phase: usize,
        /// phase of output tensor
        output_phase: usize,
        /// which challenge (i.e. which RLC gate) to use
        challenge_index: usize,
    },
    /// Contraction along an axis, or pairwise multiplication when no axis
    Contraction {
        /// einsum-style equation for this step
        expression: String,
        /// when axis is `None`, the contraction is pairwise multiplication
        axis: Option<char>,
        /// Uniquely identifying indices of input tensors to be contracted
        input_indices: Vec<TensorIndex>,
        /// phases of input tensors
        input_phases: Vec<usize>,
        /// phase of output tensor
        output_phase: usize,
    },
}
/// Opaque index uniquely identifying a tensor in the planner's working list.
#[derive(Clone, Copy, Debug)]
pub struct TensorIndex(usize);
impl<T: PrimeField + TensorType + PartialOrd> Index<TensorIndex> for Vec<ValTensor<T>> {
type Output = ValTensor<T>;
fn index(&self, index: TensorIndex) -> &Self::Output {
&self[index.0]
}
}
impl Reduction {
    /// The einsum-style equation describing this reduction step.
    /// (Arms merged with an or-pattern; the previous `&expression` in the RLC
    /// arm was a needless extra borrow.)
    pub fn expression(&self) -> &str {
        match self {
            Reduction::Contraction { expression, .. } | Reduction::RLC { expression, .. } => {
                expression
            }
        }
    }
    /// Indices of the input tensors consumed by this step.
    pub fn input_indices(&self) -> Vec<TensorIndex> {
        match self {
            Reduction::Contraction { input_indices, .. } => input_indices.clone(),
            Reduction::RLC { input_index, .. } => vec![*input_index],
        }
    }
    /// Phase of the tensor this step produces.
    pub fn output_phase(&self) -> usize {
        match self {
            Reduction::Contraction { output_phase, .. }
            | Reduction::RLC { output_phase, .. } => *output_phase,
        }
    }
}
/// Plan the ordered list of reductions (RLCs and contractions) that collapse
/// the einsum inputs of `expression` down to scalars.
///
/// Output axes are eliminated first, in output order: an output axis held by
/// a single live tensor is folded by an RLC keyed to that axis' challenge; one
/// shared by several tensors is first merged by pairwise multiplication. Axes
/// absent from the output are contracted (summed) away at the end. Each step
/// appends its result to the working list under a fresh [`TensorIndex`].
///
/// # Panics
/// Panics if `expression` contains no `"->"` separator.
pub fn input_reductions(expression: &str) -> Result<Vec<Reduction>, CircuitError> {
    let (input_exprs, output_expr) = expression
        .split_once("->")
        .expect("einsum expression must contain '->'");
    let input_exprs: Vec<_> = input_exprs.split(",").map(|eq| eq.to_string()).collect();
    // (phase, expression) — every original input starts in phase 0
    let input_exprs: Vec<(usize, String)> =
        input_exprs.into_iter().map(|expr| (0, expr)).collect_vec();
    let mut input_tensor_counter = input_exprs.len();
    let mut input_exprs: Vec<((usize, String), TensorIndex)> = input_exprs
        .into_iter()
        .zip((0..input_tensor_counter).map(TensorIndex))
        .collect();
    let mut reductions: Vec<Reduction> = vec![];
    // Reduce input_exprs along given axis
    let mut reduce = |mut input_exprs: Vec<((usize, String), TensorIndex)>,
                      axis: char|
     -> (Reduction, Vec<((usize, String), TensorIndex)>) {
        // live tensors whose expression mentions the axis
        let inputs = input_exprs
            .iter()
            .filter(|((_, eq), _)| eq.chars().contains(&axis))
            .cloned()
            .collect_vec();
        let (inputs_axes, input_indices): (Vec<(usize, String)>, Vec<TensorIndex>) =
            inputs.iter().cloned().unzip();
        let (input_phases, inputs_axes): (Vec<usize>, Vec<String>) =
            inputs_axes.into_iter().unzip();
        let is_output_axis = output_expr.chars().contains(&axis);
        // Result expression: the axis survives only when it is an output axis
        // that still has to be merged across several tensors.
        let output: String = if is_output_axis && inputs.len() > 1 {
            let output: BTreeSet<char> =
                inputs_axes.iter().flat_map(|input| input.chars()).collect();
            output.iter().collect()
        } else {
            let output: BTreeSet<char> = inputs_axes
                .iter()
                .flat_map(|input| input.chars().filter(|&c| c != axis))
                .collect();
            output.iter().collect()
        };
        let reduction = if is_output_axis && inputs.len() == 1 {
            // single tensor holding an output axis: fold it with an RLC
            let mut expression = inputs_axes.join(",");
            expression.push_str(format!(",{axis}").as_str());
            expression.push_str("->");
            expression.push_str(&output);
            Reduction::RLC {
                expression,
                axis,
                input_index: input_indices[0],
                input_phase: input_phases[0],
                // RLC results depend on a challenge, so they live in phase 1
                output_phase: 1,
                challenge_index: output_expr.chars().position(|c| c == axis).unwrap(),
            }
        } else if is_output_axis {
            // output axis shared by several tensors: merge them elementwise
            let mut expression = inputs_axes.join(",");
            let output_phase = input_phases.iter().copied().max().unwrap();
            expression.push_str("->");
            expression.push_str(&output);
            Reduction::Contraction {
                expression,
                axis: None,
                input_indices,
                input_phases,
                output_phase,
            }
        } else {
            // non-output axis: contract (sum) it away
            let mut expression = inputs_axes.join(",");
            let output_phase = input_phases.iter().copied().max().unwrap();
            expression.push_str("->");
            expression.push_str(&output);
            Reduction::Contraction {
                expression,
                axis: Some(axis),
                input_indices,
                input_phases,
                output_phase,
            }
        };
        // Replace the consumed tensors with the freshly produced one (the
        // parameter is owned, so no clone is needed before mutating).
        input_exprs.retain(|((_, input_eq), _)| !inputs_axes.contains(input_eq));
        input_exprs.push((
            (reduction.output_phase(), output.clone()),
            TensorIndex(input_tensor_counter),
        ));
        input_tensor_counter += 1;
        (reduction, input_exprs)
    };
    // Eliminate output axes first, in output order; an axis is popped once no
    // live tensor mentions it any more.
    let mut output_axes = output_expr.chars().collect_vec();
    while let Some(axis) = output_axes.first().cloned() {
        let num_inputs = input_exprs
            .iter()
            .filter(|((_, eq), _)| eq.chars().contains(&axis))
            .count();
        if num_inputs == 0 {
            output_axes.remove(0);
        } else {
            let (reduction, new_input_exprs) = reduce(input_exprs, axis);
            reductions.push(reduction);
            input_exprs = new_input_exprs;
        }
    }
    // These are not output axes and were not contracted with random vectors
    let remaining_axes: BTreeSet<_> = input_exprs
        .iter()
        .flat_map(|((_, eq), _)| eq.chars())
        .collect();
    for axis in remaining_axes.iter() {
        let (reduction, new_input_exprs) = reduce(input_exprs, *axis);
        reductions.push(reduction);
        input_exprs = new_input_exprs;
    }
    Ok(reductions)
}

View File

@@ -46,6 +46,9 @@ pub enum CircuitError {
/// Failed to get shuffle
#[error("failed to get shuffle for op: {0}")]
GetShuffleError(String),
/// Failed to get einsum
#[error("failed to get einsum for op: {0}")]
GetEinsumError(String),
/// Failed to get constants
#[error("failed to get constants for op: {0}")]
GetConstantsError(String),
@@ -61,6 +64,9 @@ pub enum CircuitError {
/// Missing product in einsum
#[error("missing product in einsum")]
MissingEinsumProduct,
/// Missing config in einsum
#[error("missing config in einsum")]
MissingEinsumConfig,
/// Mismatched lookup length
#[error("mismatched lookup lengths: {0} and {1}")]
MismatchedLookupLength(usize, usize),
@@ -109,4 +115,7 @@ pub enum CircuitError {
/// A decomposition base overflowed
#[error("decomposition base overflowed")]
DecompositionBaseOverflow,
/// Challenge not set
#[error("challenge not set")]
ChallengeNotSet,
}

View File

@@ -22,9 +22,8 @@ use crate::{
tensor::{
create_unit_tensor, get_broadcasted_shape,
ops::{accumulated, add, mult, sub},
Tensor, TensorError, ValType,
DataFormat, KernelFormat, Tensor, TensorError, ValType,
},
tensor::{DataFormat, KernelFormat},
};
use super::*;
@@ -823,14 +822,73 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// let result = einsum::<Fp>(&dummy_config, &mut dummy_region, &[&x, &k], "mk,n->ma").unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1773]), &[1, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// ```
///
pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    inputs: &[&ValTensor<F>],
    equation: &str,
) -> Result<ValTensor<F>, CircuitError> {
    let mut eq = equation.split("->");
    let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
    let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();

    // Check that the number of inputs matches the number of inputs in the equation
    if inputs.len() != inputs_eq.len() {
        return Err(TensorError::DimMismatch("einsum".to_string()).into());
    }

    // Bind every axis label to a single dimension size across all inputs,
    // erroring on any inconsistency. Iterating `chars().enumerate()` avoids
    // the O(n^2) `nth()` lookups and the byte-length / char-count confusion
    // of bounding the loop by `len()`; the entry match does one lookup
    // instead of an insert-then-index pair.
    let mut indices_to_size = HashMap::new();
    for (i, input) in inputs.iter().enumerate() {
        for (j, c) in inputs_eq[i].chars().enumerate() {
            match indices_to_size.entry(c) {
                std::collections::hash_map::Entry::Vacant(e) => {
                    e.insert(input.dims()[j]);
                }
                std::collections::hash_map::Entry::Occupied(o) => {
                    if *o.get() != input.dims()[j] {
                        return Err(TensorError::DimMismatch("einsum".to_string()).into());
                    }
                }
            }
        }
    }

    // Track the einsum equation
    region.add_used_einsum_equation(equation.to_string(), &indices_to_size)?;

    // No Freivalds config: decompose the einsum into base operations instead.
    let einsums = match config.einsums.as_ref() {
        Some(einsums) => einsums,
        None => return einsum_with_base_ops(config, region, inputs, equation),
    };

    let input_values = inputs
        .iter()
        .map(|t| t.get_inner())
        .collect::<Result<Vec<_>, TensorError>>()?;
    let (output_tensor, _) =
        crate::tensor::ops::accumulated::einsum(equation, &input_values.iter().collect_vec())?;

    // Convert once and reuse the output, avoiding the redundant clone the
    // previous `&output_tensor.clone().into()` pattern paid for.
    let output: ValTensor<F> = output_tensor.into();
    einsums.assign_einsum(
        config,
        region,
        inputs,
        &output,
        equation,
        &config.check_mode,
    )?;

    region.increment_einsum_index(1);

    Ok(output)
}
/// einsum with base ops
pub fn einsum_with_base_ops<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
inputs: &[&ValTensor<F>],
equation: &str,
) -> Result<ValTensor<F>, CircuitError> {
let mut equation = equation.split("->");
let inputs_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?;
@@ -1036,6 +1094,8 @@ pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let output: ValTensor<F> = output.into();
region.increment_einsum_index(1);
Ok(output)
}

View File

@@ -1,12 +1,12 @@
use crate::{
circuit::table::Range,
circuit::{einsum::NUM_MAX_EINSUM_CHALLENGES, table::Range},
fieldutils::IntegerRep,
tensor::{Tensor, TensorType, ValTensor, ValType, VarTensor},
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use colored::Colorize;
use halo2_proofs::{
circuit::Region,
circuit::{Region, Value},
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
@@ -85,6 +85,45 @@ impl ShuffleIndex {
}
}
/// Bookkeeping for the einsum custom-gate area: how many einsums have been
/// laid out, the current column coordinate, and the equations encountered.
#[derive(Clone, Debug, Default)]
pub struct EinsumIndex {
    index: usize,
    col_coord: usize,
    // (einsum equation, input axes to dimensions map)
    equations: Vec<(String, HashMap<char, usize>)>,
    num_inner_cols: usize,
}

impl EinsumIndex {
    /// Create a new einsum index with no recorded equations.
    pub fn new(index: usize, col_coord: usize, num_inner_cols: usize) -> EinsumIndex {
        Self {
            index,
            col_coord,
            num_inner_cols,
            equations: Vec::new(),
        }
    }

    /// Number of einsums laid out so far.
    pub fn index(&self) -> usize {
        self.index
    }

    /// Current column coordinate in the einsum area.
    pub fn col_coord(&self) -> usize {
        self.col_coord
    }

    /// Merge the counters and recorded equations of `other` into `self`.
    pub fn update(&mut self, other: &EinsumIndex) {
        self.index += other.index;
        self.col_coord += other.col_coord;
        self.equations.extend_from_slice(&other.equations);
    }
}
#[derive(Debug, Clone)]
/// Some settings for a region to differentiate it across the different phases of proof generation
pub struct RegionSettings {
@@ -176,9 +215,11 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
num_inner_cols: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
einsum_index: EinsumIndex,
statistics: RegionStatistics,
settings: RegionSettings,
assigned_constants: ConstantsMap<F>,
challenges: Vec<Value<F>>,
max_dynamic_input_len: usize,
}
@@ -250,6 +291,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.shuffle_index.col_coord += n;
}
    /// Increment the count of einsums laid out in this region by `n`.
    pub fn increment_einsum_index(&mut self, n: usize) {
        self.einsum_index.index += n;
    }
    /// Advance the column coordinate of the einsum custom-gate area by `n` cells.
    pub fn increment_einsum_col_coord(&mut self, n: usize) {
        self.einsum_index.col_coord += n;
    }
///
pub fn witness_gen(&self) -> bool {
self.settings.witness_gen
@@ -265,6 +316,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&self.statistics
}
    /// The verifier challenges available to this region (used by the Freivalds
    /// einsum argument); empty when no challenges were supplied at construction.
    pub fn challenges(&self) -> &[Value<F>] {
        &self.challenges
    }
/// Create a new region context
pub fn new(
region: Region<'a, F>,
@@ -283,9 +339,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
linear_coord,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
einsum_index: EinsumIndex::default(),
statistics: RegionStatistics::default(),
settings: RegionSettings::all_true(decomp_base, decomp_legs),
assigned_constants: HashMap::new(),
challenges: vec![],
max_dynamic_input_len: 0,
}
}
@@ -304,6 +362,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
new_self
}
/// Create a new region context with challenges
pub fn new_with_challenges(
region: Region<'a, F>,
row: usize,
num_inner_cols: usize,
decomp_base: usize,
decomp_legs: usize,
challenges: Vec<Value<F>>,
) -> RegionCtx<'a, F> {
let mut new_self = Self::new(region, row, num_inner_cols, decomp_base, decomp_legs);
new_self.challenges = challenges;
new_self
}
/// Create a new region context
pub fn new_dummy(
row: usize,
@@ -320,9 +392,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
row,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
einsum_index: EinsumIndex::default(),
statistics: RegionStatistics::default(),
settings,
assigned_constants: HashMap::new(),
challenges: vec![Value::unknown(); NUM_MAX_EINSUM_CHALLENGES],
max_dynamic_input_len: 0,
}
}
@@ -333,6 +407,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
linear_coord: usize,
num_inner_cols: usize,
settings: RegionSettings,
challenges: Vec<Value<F>>,
) -> RegionCtx<'a, F> {
let region = None;
RegionCtx {
@@ -342,9 +417,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
row,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
einsum_index: EinsumIndex::default(),
statistics: RegionStatistics::default(),
settings,
assigned_constants: HashMap::new(),
challenges,
max_dynamic_input_len: 0,
}
}
@@ -398,6 +475,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
let statistics = Arc::new(Mutex::new(self.statistics.clone()));
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
let einsum_index = Arc::new(Mutex::new(self.einsum_index.clone()));
let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
*output = output.par_enum_map(|idx, _| {
@@ -412,6 +490,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
starting_linear_coord,
self.num_inner_cols,
self.settings.clone(),
self.challenges.clone(),
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
@@ -430,6 +509,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
// update the shuffle index
let mut shuffle_index = shuffle_index.lock().unwrap();
shuffle_index.update(&local_reg.shuffle_index);
// update the einsum index
let mut einsum_index = einsum_index.lock().unwrap();
einsum_index.update(&local_reg.einsum_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants);
@@ -450,6 +532,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
.map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?;
self.einsum_index = Arc::try_unwrap(einsum_index)
.map_err(|e| CircuitError::GetEinsumError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| CircuitError::GetEinsumError(format!("{:?}", e)))?;
self.assigned_constants = Arc::try_unwrap(constants)
.map_err(|e| CircuitError::GetConstantsError(format!("{:?}", e)))?
.into_inner()
@@ -516,6 +602,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.update_max_min_lookup_range(range)
}
    /// Record an einsum equation, paired with its axis-label -> dimension-size
    /// map, as used by this region (consumed later when configuring einsums).
    ///
    /// Currently always returns `Ok(())`; the `Result` keeps the signature
    /// consistent with the sibling tracking methods.
    pub fn add_used_einsum_equation(
        &mut self,
        equation: String,
        input_axes_to_dims: &HashMap<char, usize>,
    ) -> Result<(), CircuitError> {
        self.einsum_index
            .equations
            .push((equation, input_axes_to_dims.clone()));
        Ok(())
    }
/// Get the offset
pub fn row(&self) -> usize {
self.row
@@ -551,6 +649,31 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.shuffle_index.col_coord
}
    /// Number of einsum operations laid out so far in this region.
    pub fn einsum_index(&self) -> usize {
        self.einsum_index.index
    }
    /// Current column coordinate within the einsum custom-gate area.
    pub fn einsum_col_coord(&self) -> usize {
        self.einsum_index.col_coord
    }
    /// The einsum equations recorded by this region, each paired with its
    /// axis-label -> dimension-size map (returned as an owned copy).
    pub fn used_einsum_equations(&self) -> Vec<(String, HashMap<char, usize>)> {
        self.einsum_index.equations.clone()
    }
    /// Set the number of inner columns used in the einsum custom gate.
    pub fn set_num_einsum_inner_cols(&mut self, num_inner_cols: usize) {
        self.einsum_index.num_inner_cols = num_inner_cols;
    }
    /// Number of inner columns used in the einsum custom gate.
    pub fn num_einsum_inner_cols(&self) -> usize {
        self.einsum_index.num_inner_cols
    }
/// get used lookups
pub fn used_lookups(&self) -> HashSet<LookupOp> {
self.statistics.used_lookups.clone()
@@ -640,6 +763,28 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.assign_dynamic_lookup(var, values)
}
    /// Assign a valtensor to a vartensor in the einsum custom-gate area,
    /// starting at the current einsum column coordinate.
    ///
    /// With a real region this performs the actual cell assignment; in a dummy
    /// (region-less) pass it only harvests constants from `values` so constant
    /// accounting stays accurate, and returns the input unchanged.
    pub fn assign_einsum(
        &mut self,
        var: &VarTensor,
        values: &ValTensor<F>,
    ) -> Result<ValTensor<F>, CircuitError> {
        if let Some(region) = &self.region {
            Ok(var.assign(
                &mut region.borrow_mut(),
                self.einsum_col_coord(),
                values,
                &mut self.assigned_constants,
            )?)
        } else {
            if !values.is_instance() {
                // instance tensors contribute no constants to collect
                let values_map = values.create_constants_map_iterator();
                self.assigned_constants.par_extend(values_map);
            }
            Ok(values.clone())
        }
    }
/// Assign a valtensor to a vartensor with duplication
pub fn assign_with_duplication_unconstrained(
&mut self,
@@ -697,6 +842,63 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
}
    /// Assign a valtensor to a vartensor in the einsum area with duplication,
    /// without adding equality constraints on the duplicated cells.
    ///
    /// Returns the assigned tensor together with its assigned length. In a
    /// dummy (region-less) pass the assignment is only simulated (to keep
    /// constant accounting accurate) and the input tensor is returned as-is.
    pub fn assign_einsum_with_duplication_unconstrained(
        &mut self,
        var: &VarTensor,
        values: &ValTensor<F>,
    ) -> Result<(ValTensor<F>, usize), Error> {
        if let Some(region) = &self.region {
            // duplicates every nth element to adjust for column overflow
            let (res, len) = var.assign_with_duplication_unconstrained(
                &mut region.borrow_mut(),
                self.einsum_col_coord(),
                values,
                &mut self.assigned_constants,
            )?;
            Ok((res, len))
        } else {
            // `false` selects the unconstrained simulation path
            let (_, len) = var.dummy_assign_with_duplication(
                self.row,
                self.einsum_col_coord(),
                values,
                false,
                &mut self.assigned_constants,
            )?;
            Ok((values.clone(), len))
        }
    }
    /// Assign a valtensor to a vartensor in the einsum area with duplication,
    /// constraining the duplicated cells (checked per `check_mode`).
    ///
    /// Returns the assigned tensor together with its assigned length. In a
    /// dummy (region-less) pass the assignment is only simulated — the `true`
    /// flag mirrors the constrained variant — and the input is returned as-is.
    pub fn assign_einsum_with_duplication_constrained(
        &mut self,
        var: &VarTensor,
        values: &ValTensor<F>,
        check_mode: &crate::circuit::CheckMode,
    ) -> Result<(ValTensor<F>, usize), Error> {
        if let Some(region) = &self.region {
            // duplicates every nth element to adjust for column overflow
            let (res, len) = var.assign_with_duplication_constrained(
                &mut region.borrow_mut(),
                self.row,
                self.einsum_col_coord(),
                values,
                check_mode,
                &mut self.assigned_constants,
            )?;
            Ok((res, len))
        } else {
            let (_, len) = var.dummy_assign_with_duplication(
                self.row,
                self.einsum_col_coord(),
                values,
                true,
                &mut self.assigned_constants,
            )?;
            Ok((values.clone(), len))
        }
    }
/// Enable a selector
pub fn enable(&mut self, selector: Option<&Selector>, offset: usize) -> Result<(), Error> {
match &self.region {
@@ -763,4 +965,19 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
Ok(())
}
/// flush row to the next row in einsum area
pub fn flush_einsum(&mut self) -> Result<(), CircuitError> {
// increment by the difference between the current linear coord and the next row
let num_einsum_inner_cols = self.num_einsum_inner_cols();
let remainder = self.einsum_col_coord() % num_einsum_inner_cols;
if remainder != 0 {
let diff = num_einsum_inner_cols - remainder;
self.increment_einsum_col_coord(diff);
}
if self.einsum_col_coord() % num_einsum_inner_cols != 0 {
return Err(CircuitError::FlushError);
}
Ok(())
}
}

View File

@@ -1290,6 +1290,7 @@ pub(crate) fn calibrate(
total_const_size: new_settings.total_const_size,
dynamic_lookup_params: new_settings.dynamic_lookup_params,
shuffle_params: new_settings.shuffle_params,
einsum_params: new_settings.einsum_params,
..settings.clone()
};

View File

@@ -64,6 +64,7 @@ use pyo3::types::PyDictMethods;
use pyo3::IntoPyObject;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::ops::Deref;
pub use utilities::*;
pub use vars::*;
@@ -438,6 +439,15 @@ pub struct ShuffleParams {
pub total_shuffle_col_size: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
/// Parameters for einsum operations
pub struct EinsumParams {
    /// einsum equations used by the model, each paired with its
    /// axis-label -> dimension-size map
    pub equations: Vec<(String, HashMap<char, usize>)>,
    /// total number of cells occupied in the einsum column area
    pub total_einsum_col_size: usize,
}
/// model parameters
#[derive(Clone, Debug, Default, PartialEq)]
pub struct GraphSettings {
@@ -453,6 +463,8 @@ pub struct GraphSettings {
pub dynamic_lookup_params: DynamicLookupParams,
/// shuffle parameters, flattened for backwards compatibility
pub shuffle_params: ShuffleParams,
/// einsum parameters
pub einsum_params: EinsumParams,
/// the shape of public inputs to the model (in order of appearance)
pub model_instance_shapes: Vec<Vec<usize>>,
/// model output scales
@@ -487,7 +499,7 @@ impl Serialize for GraphSettings {
if serializer.is_human_readable() {
// JSON format - use flattened fields for backwards compatibility
use serde::ser::SerializeStruct;
let mut state = serializer.serialize_struct("GraphSettings", 21)?;
let mut state = serializer.serialize_struct("GraphSettings", 22)?;
state.serialize_field("run_args", &self.run_args)?;
state.serialize_field("num_rows", &self.num_rows)?;
state.serialize_field("total_assignments", &self.total_assignments)?;
@@ -514,6 +526,9 @@ impl Serialize for GraphSettings {
&self.shuffle_params.total_shuffle_col_size,
)?;
// Serialize EinsumParams
state.serialize_field("einsum_params", &self.einsum_params)?;
state.serialize_field("model_instance_shapes", &self.model_instance_shapes)?;
state.serialize_field("model_output_scales", &self.model_output_scales)?;
state.serialize_field("model_input_scales", &self.model_input_scales)?;
@@ -530,13 +545,14 @@ impl Serialize for GraphSettings {
} else {
// Binary format (bincode) - use nested struct format
use serde::ser::SerializeTuple;
let mut state = serializer.serialize_tuple(18)?;
let mut state = serializer.serialize_tuple(19)?;
state.serialize_element(&self.run_args)?;
state.serialize_element(&self.num_rows)?;
state.serialize_element(&self.total_assignments)?;
state.serialize_element(&self.total_const_size)?;
state.serialize_element(&self.dynamic_lookup_params)?;
state.serialize_element(&self.shuffle_params)?;
state.serialize_element(&self.einsum_params)?;
state.serialize_element(&self.model_instance_shapes)?;
state.serialize_element(&self.model_output_scales)?;
state.serialize_element(&self.model_input_scales)?;
@@ -576,6 +592,8 @@ impl<'de> Deserialize<'de> for GraphSettings {
// Flattened ShuffleParams fields
NumShuffles,
TotalShuffleColSize,
// EinsumParams field
EinsumParams,
ModelInstanceShapes,
ModelOutputScales,
ModelInputScales,
@@ -615,6 +633,7 @@ impl<'de> Deserialize<'de> for GraphSettings {
let mut num_dynamic_lookups = None;
let mut num_shuffles = None;
let mut total_shuffle_col_size = None;
let mut einsum_params = None;
let mut model_instance_shapes = None;
let mut model_output_scales = None;
let mut model_input_scales = None;
@@ -684,6 +703,12 @@ impl<'de> Deserialize<'de> for GraphSettings {
}
total_shuffle_col_size = Some(map.next_value()?);
}
Field::EinsumParams => {
if einsum_params.is_some() {
return Err(de::Error::duplicate_field("einsum_params"));
}
einsum_params = Some(map.next_value()?);
}
Field::ModelInstanceShapes => {
if model_instance_shapes.is_some() {
return Err(de::Error::duplicate_field("model_instance_shapes"));
@@ -822,6 +847,7 @@ impl<'de> Deserialize<'de> for GraphSettings {
total_const_size,
dynamic_lookup_params,
shuffle_params,
einsum_params: einsum_params.unwrap_or_default(),
model_instance_shapes,
model_output_scales,
model_input_scales,
@@ -862,42 +888,45 @@ impl<'de> Deserialize<'de> for GraphSettings {
let shuffle_params = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(5, &self))?;
let model_instance_shapes = seq
let einsum_params = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(6, &self))?;
let model_output_scales = seq
let model_instance_shapes = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(7, &self))?;
let model_input_scales = seq
let model_output_scales = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(8, &self))?;
let module_sizes = seq
let model_input_scales = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(9, &self))?;
let required_lookups = seq
let module_sizes = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(10, &self))?;
let required_range_checks = seq
let required_lookups = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(11, &self))?;
let check_mode = seq
let required_range_checks = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(12, &self))?;
let version = seq
let check_mode = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(13, &self))?;
let num_blinding_factors = seq
let version = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(14, &self))?;
let timestamp = seq
let num_blinding_factors = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(15, &self))?;
let input_types = seq
let timestamp = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(16, &self))?;
let output_types = seq
let input_types = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(17, &self))?;
let output_types = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(18, &self))?;
Ok(GraphSettings {
run_args,
@@ -906,6 +935,7 @@ impl<'de> Deserialize<'de> for GraphSettings {
total_const_size,
dynamic_lookup_params,
shuffle_params,
einsum_params,
model_instance_shapes,
model_output_scales,
model_input_scales,
@@ -935,6 +965,7 @@ impl<'de> Deserialize<'de> for GraphSettings {
"num_dynamic_lookups",
"num_shuffles",
"total_shuffle_col_size",
"einsum_params",
"model_instance_shapes",
"model_output_scales",
"model_input_scales",
@@ -953,7 +984,7 @@ impl<'de> Deserialize<'de> for GraphSettings {
deserializer.deserialize_struct("GraphSettings", FIELDS, GraphSettingsVisitor)
} else {
// Binary format (bincode) - use tuple deserialization
deserializer.deserialize_tuple(18, GraphSettingsVisitor)
deserializer.deserialize_tuple(19, GraphSettingsVisitor)
}
}
}
@@ -1038,6 +1069,13 @@ impl GraphSettings {
.ceil() as u32
}
    /// Calculates the logrows for einsum computation area in which there is no column overflow
    ///
    /// NOTE(review): when `total_einsum_col_size` is 0, `log2` yields negative
    /// infinity and the saturating `as u32` cast produces 0; when
    /// `num_inner_cols` is 0 the division yields infinity and the cast
    /// saturates to `u32::MAX` — confirm callers never configure 0 inner cols.
    pub fn einsum_logrows(&self) -> u32 {
        (self.einsum_params.total_einsum_col_size as f64 / self.run_args.num_inner_cols as f64)
            .log2()
            .ceil() as u32
    }
/// calculate the total number of instances
pub fn total_instances(&self) -> Vec<usize> {
let mut instances: Vec<usize> = self.module_sizes.num_instances();
@@ -1593,10 +1631,11 @@ impl GraphCircuit {
let model_constraint_logrows = self.settings().model_constraint_logrows_with_blinding();
let min_bits = self.table_size_logrows(safe_lookup_range, max_range_size)?;
let constants_logrows = self.settings().constants_logrows();
let einsum_logrows = self.settings().einsum_logrows();
max_logrows = std::cmp::min(
max_logrows,
// max of the model constraint logrows, min_bits, and the constants logrows is the upper limit
*[model_constraint_logrows, min_bits, constants_logrows]
*[model_constraint_logrows, min_bits, constants_logrows, einsum_logrows]
.iter()
.max()
.unwrap(),
@@ -2180,6 +2219,7 @@ pub mod tests {
num_shuffles: 3,
total_shuffle_col_size: 256,
},
einsum_params: EinsumParams::default(),
model_instance_shapes: vec![vec![1, 2, 3]],
model_output_scales: vec![],
model_input_scales: vec![],
@@ -2254,7 +2294,8 @@ pub mod tests {
"decomp_base": 128,
"decomp_legs": 2,
"bounded_log_lookup": false,
"ignore_range_check_inputs_outputs": false
"ignore_range_check_inputs_outputs": false,
"disable_freivalds": false
},
"num_rows": 236,
"total_assignments": 472,

View File

@@ -3,6 +3,7 @@ use super::extract_const_quantized_values;
use super::node::*;
use super::vars::*;
use super::GraphSettings;
use crate::circuit::einsum::analysis::analyze_einsum_usage;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::RegionCtx;
@@ -37,7 +38,6 @@ use log::{debug, info, trace};
use serde::Deserialize;
use serde::Serialize;
use std::collections::BTreeMap;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::collections::HashMap;
use std::collections::HashSet;
use std::fs;
@@ -106,6 +106,8 @@ pub struct DummyPassRes {
pub dynamic_lookup_params: DynamicLookupParams,
/// shuffle parameters
pub shuffle_params: ShuffleParams,
/// einsum parameters
pub einsum_params: crate::graph::EinsumParams,
/// num shuffles
pub num_shuffles: usize,
/// shuffle
@@ -592,6 +594,7 @@ impl Model {
output_types: Some(self.get_output_types()),
dynamic_lookup_params: res.dynamic_lookup_params,
shuffle_params: res.shuffle_params,
einsum_params: res.einsum_params,
total_const_size: res.total_const_size,
check_mode,
version: env!("CARGO_PKG_VERSION").to_string(),
@@ -1047,6 +1050,7 @@ impl Model {
let lookup_range = settings.run_args.lookup_range;
let logrows = settings.run_args.logrows as usize;
let num_inner_cols = settings.run_args.num_inner_cols;
let required_lookups = settings.required_lookups.clone();
let required_range_checks = settings.required_range_checks.clone();
@@ -1095,6 +1099,24 @@ impl Model {
)?;
}
// Configures the circuit to use Freivalds' argument
        // In the dummy phase, Freivalds' is configured as the default (unless `disable-freivalds` is enabled),
// but if einsum coordinate is 0, it means that all the einsum layouts are dispatched to use only base operations.
if settings.einsum_params.total_einsum_col_size > 0 {
debug!("configuring einsums...");
let used_einsums: HashMap<(usize, String), HashMap<char, usize>> = settings
.einsum_params
.equations
.iter()
.enumerate()
.map(|(idx, (equation, indices_to_dims))| {
((idx, equation.clone()), indices_to_dims.clone())
})
.collect();
let analysis = analyze_einsum_usage(&used_einsums)?;
base_gate.configure_einsums(meta, &analysis, num_inner_cols, logrows)?;
}
Ok(base_gate)
}
@@ -1147,17 +1169,30 @@ impl Model {
let original_constants = constants.clone();
let challenges = {
if let Some(einsum_config) = &config.base.einsums {
einsum_config
.challenges()?
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec()
} else {
vec![]
}
};
let outputs = layouter.assign_region(
|| "model",
|region| {
let mut thread_safe_region = RegionCtx::new_with_constants(
let mut thread_safe_region = RegionCtx::new_with_challenges(
region,
0,
run_args.num_inner_cols,
run_args.decomp_base,
run_args.decomp_legs,
original_constants.clone(),
challenges.clone(),
);
thread_safe_region.update_constants(original_constants.clone());
// we need to do this as this loop is called multiple times
vars.set_instance_idx(instance_idx);
@@ -1459,8 +1494,16 @@ impl Model {
results.insert(*input_idx, vec![inputs[i].clone()]);
}
let mut dummy_config =
PolyConfig::dummy(run_args.logrows as usize, run_args.num_inner_cols);
let mut dummy_config = {
if run_args.disable_freivalds {
PolyConfig::dummy_without_freivalds(
run_args.logrows as usize,
run_args.num_inner_cols,
)
} else {
PolyConfig::dummy(run_args.logrows as usize, run_args.num_inner_cols)
}
};
let mut model_config = ModelConfig {
base: dummy_config.clone(),
vars: ModelVars::new_dummy(),
@@ -1529,6 +1572,10 @@ impl Model {
num_shuffles: region.shuffle_index(),
total_shuffle_col_size: region.shuffle_col_coord(),
},
einsum_params: crate::graph::EinsumParams {
equations: region.used_einsum_equations(),
total_einsum_col_size: region.einsum_col_coord(),
},
total_const_size: region.total_constants(),
lookup_ops: region.used_lookups(),
range_checks: region.used_range_checks(),

View File

@@ -363,6 +363,13 @@ pub struct RunArgs {
/// Optional override for epsilon value
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long))]
pub epsilon: Option<f64>,
    /// Forcefully disable using Freivalds' argument in einsum operations.
    /// Freivalds' argument can make the verifier bigger, so this option is
    /// useful when verifier size is a concern.
    /// Without this option the circuit layouter will always try to use
    /// Freivalds' argument when it is beneficial to do so.
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long))]
pub disable_freivalds: bool,
}
impl RunArgs {
@@ -398,6 +405,7 @@ impl Default for RunArgs {
decomp_legs: 2,
ignore_range_check_inputs_outputs: false,
epsilon: None,
disable_freivalds: false,
}
}
}

View File

@@ -480,6 +480,13 @@ impl<T: Clone + TensorType> Tensor<T> {
self[index].clone()
}
    /// Extracts a single value from the tensor
    ///
    /// # Panics
    /// Panics if the tensor does not contain exactly one element, or if any
    /// dimension is not equal to 1.
    pub fn get_scalar(&self) -> T {
        assert!(self.inner.len() == 1);
        assert!(self.dims.iter().all(|dim| *dim == 1));
        self.inner[0].clone()
    }
/// Get a mutable array index from rows / columns indices.
///
/// ```
@@ -901,6 +908,22 @@ impl<T: Clone + TensorType> Tensor<T> {
Ok(())
}
/// remove axes that have dimensions 1
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 2]).unwrap();
/// let b = a.remove_trivial_axes().unwrap();
/// assert_eq!(b, expected);
/// ```
pub fn remove_trivial_axes(&self) -> Result<Self, TensorError> {
let mut result = self.clone();
let new_dims: Vec<_> = self.dims.iter().copied().filter(|dim| *dim > 1).collect();
result.reshape(&new_dims)?;
Ok(result)
}
/// Move axis of the tensor
/// ```
/// use ezkl::tensor::Tensor;

View File

@@ -5,6 +5,7 @@ use crate::{
};
use itertools::Itertools;
use maybe_rayon::{iter::ParallelIterator, prelude::IntoParallelRefIterator};
use std::collections::{HashMap, HashSet};
pub use std::ops::{Add, Mul, Neg, Sub};
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
@@ -2396,6 +2397,8 @@ pub mod nonlinearities {
/// Ops that return the transcript i.e intermediate calcs of an op
pub mod accumulated {
use maybe_rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator};
use super::*;
/// Dot product of two tensors.
@@ -2523,4 +2526,327 @@ pub mod accumulated {
Ok(transcript)
}
#[inline]
fn row_major_strides(dims: &[usize]) -> Vec<usize> {
let mut s = vec![0; dims.len()];
let mut acc = 1;
for (i, &d) in dims.iter().enumerate().rev() {
s[i] = acc;
acc *= d;
}
s
}
/// Computes a generalized Einstein summation (einsum) over an arbitrary
/// number of input tensors, driven by an equation such as `"ij,jk->ik"`:
/// the comma-separated expressions before `->` name the axes of each input,
/// and the expression after `->` names the axes kept in the output. Every
/// index that appears in an input but not in the output is summed over, and
/// size-1 axes broadcast against larger axes carrying the same index.
///
/// Returns the result tensor (a scalar result has shape `[1]`) together with
/// the map from each index character to the dimension it resolved to.
///
/// # Panics
/// Panics if `equation` contains no `->`, or if the number of input
/// expressions does not match the number of input tensors.
///
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::accumulated::einsum;
///
/// // matmul case
/// let x = Tensor::<IntegerRep>::new(
///     Some(&[2, 1, 2, 1, 1, 1]),
///     &[2, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
///     Some(&[2, 3, 2, 1, 1, 1]),
///     &[3, 2],
/// ).unwrap();
/// let (result, _) = einsum::<IntegerRep>("ij,jk->ik", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[8, 9, 5, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// // element wise multiplication
/// let x = Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
///     &[3, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]),
///     &[3, 3],
/// ).unwrap();
/// let (result, _) = einsum::<IntegerRep>("ij,ij->ij", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 4, 9, 2, 6, 12, 3, 8, 15]), &[3, 3]).unwrap();
/// assert_eq!(result, expected);
///
///
/// // dot product of A with the transpose of B.
/// let x = Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
///     &[3, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]),
///     &[3, 3],
/// ).unwrap();
/// let (result, _) = einsum::<IntegerRep>("ik,jk->ij", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[14, 14, 14, 20, 20, 20, 26, 26, 26]), &[3, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // dot product
/// let x = Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
///     &[3, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]),
///     &[3, 3],
/// ).unwrap();
/// let (result, _) = einsum::<IntegerRep>("ik,ik->i", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[14, 20, 26]), &[3]).unwrap();
/// assert_eq!(result, expected);
///
///
/// // dot product
/// let x = Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3]),
///     &[3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3]),
///     &[3],
/// ).unwrap();
/// let (result, _) = einsum::<IntegerRep>("i,i->", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[14]), &[1]).unwrap();
/// assert_eq!(result, expected);
///
///
/// // batched contraction mixing axes drawn from different inputs
/// let x = Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]),
///     &[3, 3, 2],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
///     Some(&[4, 5, 7, 8]),
///     &[2, 2],
/// ).unwrap();
/// let (result, _) = einsum::<IntegerRep>("anm,bm->ba", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[68, 80, 95, 113, 134, 158]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // three-operand contraction
/// let x = Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]),
///     &[3, 3, 2],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
///     Some(&[4, 5, 7, 8]),
///     &[2, 2],
/// ).unwrap();
/// let z = Tensor::<IntegerRep>::new(
///     Some(&[4, 5, 7, 8, 9, 9]),
///     &[2, 3],
/// ).unwrap();
///
/// let (result, _) = einsum::<IntegerRep>("bn,anm,bm->ba", &[&z, &x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[390, 414, 534, 994, 1153, 1384]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
///
/// // contraction with a single common axis
/// let x = Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]),
///     &[3, 3, 2],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
///     Some(&[4, 5, 7, 8]),
///     &[2, 2],
/// ).unwrap();
/// let (result, _) = einsum::<IntegerRep>("abc,cd->", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[648]), &[1]).unwrap();
/// assert_eq!(result, expected);
///
/// // contraction with no common axes (outer product)
/// let x = Tensor::<IntegerRep>::new(
///     Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]),
///     &[3, 3, 2],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
///     Some(&[4, 5, 7, 8]),
///     &[2, 2],
/// ).unwrap();
/// let (result, _) = einsum::<IntegerRep>("abc,ed->", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1296]), &[1]).unwrap();
/// assert_eq!(result, expected);
///
/// // trivial axes mapping
/// let x = Tensor::<IntegerRep>::new(
///     Some(&[4, 5, 7, 8]),
///     &[2, 2],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
///     Some(&[4, 5]),
///     &[2],
/// ).unwrap();
///
/// let (result, _) = einsum::<IntegerRep>("mk,k->m", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[41, 68]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
/// let (result, _) = einsum::<IntegerRep>("mk,k->mn", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[41, 68]), &[2, 1]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<IntegerRep>::new(
///     Some(&[0, 0, 0, 3]),
///     &[1, 4],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
///     Some(&[213, 227, 74, 77]),
///     &[4],
/// ).unwrap();
///
/// let (result, _) = einsum::<IntegerRep>("mk,k->ma", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[231]), &[1, 1]).unwrap();
/// assert_eq!(result, expected);
/// // subtle difference
/// let (result, _) = einsum::<IntegerRep>("mk,n->ma", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1773]), &[1, 1]).unwrap();
/// assert_eq!(result, expected);
///
/// ```
///
pub fn einsum<T>(
    equation: &str,
    input_tensors: &[&Tensor<T>],
) -> Result<(Tensor<T>, HashMap<char, usize>), TensorError>
where
    T: Clone + TensorType + Mul<Output = T> + Add<Output = T> + Send + Sync,
{
    // Split "inputs->output"; panics if the equation has no "->".
    let (input_exprs, output_expr) = equation.split_once("->").unwrap();
    let input_exprs: Vec<&str> = input_exprs.split(',').collect();
    assert_eq!(input_exprs.len(), input_tensors.len());
    // Resolve each index character to a dimension. A size-1 axis is
    // broadcast-compatible with any size, so the maximum size wins.
    let mut dim_of: HashMap<char, usize> = HashMap::new();
    for (input_expr, t) in input_exprs.iter().zip(input_tensors.iter()) {
        for (c, &d) in input_expr.chars().zip(t.dims().iter()) {
            let e = dim_of.entry(c).or_insert(d);
            debug_assert!((*e == d) || (*e == 1) || (d == 1));
            *e = (*e).max(d);
        }
    }
    // Output dims (an output index absent from every input defaults to 1)
    let out_idx: Vec<char> = output_expr.chars().collect();
    let out_dims: Vec<usize> = out_idx
        .iter()
        .map(|c| *dim_of.get(c).unwrap_or(&1))
        .collect();
    // Reduction indices: appear in some input but not in the output
    let all_idx: HashSet<char> = dim_of.keys().copied().collect();
    let out_set: HashSet<char> = out_idx.iter().copied().collect();
    let red_idx: Vec<char> = all_idx.difference(&out_set).copied().collect();
    let red_dims: Vec<usize> = red_idx.iter().map(|c| dim_of[c]).collect();
    // Fast index->pos
    let out_pos: HashMap<char, usize> =
        out_idx.iter().enumerate().map(|(i, &c)| (c, i)).collect();
    let red_pos: HashMap<char, usize> =
        red_idx.iter().enumerate().map(|(i, &c)| (c, i)).collect();
    // Precompute strides per input and contributions: for each input, how far
    // its flat offset advances when an output index (or a reduction index)
    // increments by one. Broadcast (size-1) axes contribute stride 0.
    struct Contrib {
        out_stride: Vec<usize>,
        red_stride: Vec<usize>,
    }
    let contribs: Vec<Contrib> = input_exprs
        .iter()
        .zip(input_tensors.iter())
        .map(|(expr, t)| {
            let dims = t.dims().to_vec();
            let strides = row_major_strides(&dims);
            let mut out_stride = vec![0; out_idx.len()];
            let mut red_stride = vec![0; red_idx.len()];
            for (ax, (c, &d)) in expr.chars().zip(dims.iter()).enumerate() {
                // broadcast axes never move the offset
                let s = if d == 1 { 0 } else { strides[ax] };
                if let Some(&p) = out_pos.get(&c) {
                    out_stride[p] = s;
                } else if let Some(&q) = red_pos.get(&c) {
                    red_stride[q] = s;
                }
            }
            Contrib {
                out_stride,
                red_stride,
            }
        })
        .collect();
    // Prepare output buffer (scalar result is represented with shape [1])
    let mut out = if out_dims.is_empty() {
        Tensor::<T>::new(None, &[1])?
    } else {
        Tensor::<T>::new(None, &out_dims)?
    };
    let out_rank = out_dims.len();
    let red_rank = red_dims.len();
    // Materialize output elements one by one, in parallel over elements
    out.par_iter_mut()
        .enumerate()
        .for_each(|(out_linear_coord, out)| {
            // Decompose the linear coordinate into a row-major multi-index
            let mut out_index = vec![0usize; out_rank];
            {
                let mut x = out_linear_coord;
                for i in (0..out_rank).rev() {
                    let d = out_dims[i];
                    out_index[i] = x % d;
                    x /= d;
                }
            }
            // Base offset per input from output coordinates
            let mut base_off = vec![0usize; input_tensors.len()];
            for (i, c) in contribs.iter().enumerate() {
                let mut off = 0usize;
                for p in 0..out_rank {
                    off += out_index[p] * c.out_stride[p];
                }
                base_off[i] = off;
            }
            let mut acc = T::zero().unwrap();
            if red_rank == 0 {
                // No reduction -> just multiply corresponding elements
                let mut prod = T::one().unwrap();
                for (i, t) in input_tensors.iter().enumerate() {
                    let val = t.get_flat_index(base_off[i]);
                    prod = prod * val;
                }
                acc = acc + prod;
            } else {
                // Iterate over all reduction coords, accumulating the product
                // of the corresponding input elements at each one
                let red_size = red_dims.iter().product::<usize>();
                let mut red_index = vec![0usize; red_rank];
                for red_linear_coord in 0..red_size {
                    // Decompose the reduction coordinate (row-major)
                    {
                        let mut x = red_linear_coord;
                        for q in (0..red_rank).rev() {
                            let d = red_dims[q];
                            red_index[q] = x % d;
                            x /= d;
                        }
                    }
                    let mut prod = T::one().unwrap();
                    for (i, (t, c)) in input_tensors.iter().zip(contribs.iter()).enumerate() {
                        let mut off = base_off[i];
                        for q in 0..red_rank {
                            off += red_index[q] * c.red_stride[q];
                        }
                        let val = t.get_flat_index(off);
                        prod = prod * val;
                    }
                    acc = acc + prod;
                }
            }
            // write result
            *out = acc;
        });
    Ok((out, dim_of))
}
}

View File

@@ -940,6 +940,22 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}
/// Removes all axes of size 1, then refreshes the cached dims to match.
///
/// Returns `TensorError::WrongMethod` for `Instance` tensors, which cannot
/// be reshaped in place.
pub fn remove_trivial_axes(&mut self) -> Result<(), TensorError> {
    if let ValTensor::Value {
        inner: v, dims: d, ..
    } = self
    {
        *v = v.remove_trivial_axes()?;
        // keep the cached dims in sync with the squeezed inner tensor
        *d = v.dims().to_vec();
        Ok(())
    } else {
        Err(TensorError::WrongMethod)
    }
}
/// Takes a slice of the tensor along a given axis
///
/// # Arguments

View File

@@ -1,3 +1,4 @@
use halo2_proofs::plonk::SecondPhase;
use log::{debug, error, warn};
use crate::circuit::{region::ConstantsMap, CheckMode};
@@ -152,6 +153,52 @@ impl VarTensor {
}
}
/// Creates a new VarTensor::Advice whose columns are allocated in the
/// second challenge phase (`SecondPhase`), for values that can only be
/// assigned after first-phase challenges are available.
///
/// # Arguments
/// * `cs` - The constraint system to create columns in
/// * `logrows` - Log base 2 of the total number of rows
/// * `num_inner_cols` - Number of columns in each inner block
/// * `capacity` - Total number of advice cells to allocate
///
/// # Returns
/// A new VarTensor::Advice in SecondPhase with equality constraints enabled
/// on every column
pub fn new_advice_in_second_phase<F: PrimeField>(
    cs: &mut ConstraintSystem<F>,
    logrows: usize,
    num_inner_cols: usize,
    capacity: usize,
) -> Self {
    let max_rows = Self::max_rows(cs, logrows);
    // reuse `max_rows` rather than recomputing it from the constraint system
    let max_assignments = max_rows * num_inner_cols;
    let mut modulo = (capacity / max_assignments) + 1;
    // we add a buffer for duplicated rows (we get at most 1 duplicated row per column)
    modulo = ((capacity + modulo) / max_assignments) + 1;
    if modulo > 1 {
        debug!("using column duplication for {} advice blocks", modulo - 1);
    }
    let mut advices = Vec::with_capacity(modulo);
    for _ in 0..modulo {
        let mut inner = Vec::with_capacity(num_inner_cols);
        for _ in 0..num_inner_cols {
            let col = cs.advice_column_in(SecondPhase);
            // allow copy constraints against these cells
            cs.enable_equality(col);
            inner.push(col);
        }
        advices.push(inner);
    }
    VarTensor::Advice {
        inner: advices,
        num_inner_cols,
        col_size: max_rows,
    }
}
/// Initializes fixed columns in the constraint system to support the VarTensor::Advice
/// Fixed columns are used for constant values that are known at circuit creation time.
///
@@ -270,7 +317,7 @@ impl VarTensor {
/// # Returns
/// A tuple of (block_index, column_index, row_index)
pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize, usize) {
// x indexes over blocks of size num_inner_cols
// x (block idx) indexes over blocks of size num_inner_cols
let x = linear_coord / self.block_size();
// y indexes over the cols inside a block
let y = linear_coord % self.num_inner_cols();
@@ -519,7 +566,7 @@ impl VarTensor {
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
&self,
row: usize,
_row: usize,
offset: usize,
values: &ValTensor<F>,
single_inner_col: bool,
@@ -545,7 +592,7 @@ impl VarTensor {
self.num_inner_cols()
};
let duplication_offset = if single_inner_col { row } else { offset };
let (_, _, duplication_offset) = self.cartesian_coord(offset);
// duplicates every nth element to adjust for column overflow
let mut res: ValTensor<F> = v
@@ -651,7 +698,7 @@ impl VarTensor {
>(
&self,
region: &mut Region<F>,
row: usize,
_row: usize,
offset: usize,
values: &ValTensor<F>,
check_mode: &CheckMode,
@@ -669,7 +716,7 @@ impl VarTensor {
ValTensor::Value { inner: v, dims, .. } => {
let duplication_freq = self.col_size();
let num_repeats = 1;
let duplication_offset = row;
let (_, _, duplication_offset) = self.cartesian_coord(offset);
// duplicates every nth element to adjust for column overflow
let v = v

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@@ -28,7 +28,8 @@
"decomp_base": 128,
"decomp_legs": 2,
"bounded_log_lookup": false,
"ignore_range_check_inputs_outputs": false
"ignore_range_check_inputs_outputs": false,
"disable_freivalds": false
},
"num_rows": 236,
"total_assignments": 472,
@@ -36,6 +37,10 @@
"total_dynamic_col_size": 0,
"max_dynamic_input_len": 0,
"num_dynamic_lookups": 0,
"einsum_params": {
"equations": [],
"total_einsum_col_size": 0
},
"num_shuffles": 0,
"total_shuffle_col_size": 0,
"model_instance_shapes": [

Binary file not shown.

View File

@@ -207,7 +207,7 @@ mod native_tests {
"1l_tiny_div",
];
const TESTS: [&str; 99] = [
const TESTS: [&str; 100] = [
"1l_mlp", //0
"1l_slice", //1
"1l_concat", //2
@@ -311,6 +311,7 @@ mod native_tests {
"exp", // 96
"general_exp", // 97
"integer_div", // 98
"large_mlp", // 99
];
const WASM_TESTS: [&str; 44] = [
@@ -550,7 +551,7 @@ mod native_tests {
}
});
seq!(N in 0..=98 {
seq!(N in 0..99 {
// #(#[test_case(TESTS[N])])*
// #[ignore]
@@ -1023,7 +1024,7 @@ mod native_tests {
});
seq!(N in 0..=98 {
seq!(N in 0..=99 {
#(#[test_case(TESTS[N])])*
fn kzg_evm_prove_and_verify_reusable_verifier_(test: &str) {
crate::native_tests::init_binary();

View File

@@ -5,8 +5,7 @@ mod wasm32 {
bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian,
feltToFloat, feltToInt, feltToLittleEndian, genPk, genVk, genWitness, inputValidation,
kzgCommit, pkValidation, poseidonHash, proofValidation, prove, settingsValidation,
srsValidation, u8_array_to_u128_le, verify, verifyAggr, verifyEVM, vkValidation,
witnessValidation,
srsValidation, u8_array_to_u128_le, verify, verifyAggr, vkValidation, witnessValidation,
};
use ezkl::circuit::modules::polycommit::PolyCommitChip;
use ezkl::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
@@ -94,21 +93,6 @@ mod wasm32 {
assert_eq!(calldata, reference_calldata);
}
#[wasm_bindgen_test]
async fn can_verify_evm() {
// verify with single purpose evm verifier contract
let value = verifyEVM(
wasm_bindgen::Clamped(PROOF.to_vec()),
VERIFIER_BYTECODE.to_vec(),
None,
)
.map_err(|_| "failed")
.unwrap();
// should not fail
assert!(value);
}
#[wasm_bindgen_test]
fn verify_kzg_commit() {
// create a vector of field elements Vec<Fr> and assign it to the message variable