Mirror of https://github.com/zkonduit/ezkl.git, synced 2026-01-13 08:17:57 -05:00

Compare commits: 1 commit (3468a2b02d)
.github/workflows/engine.yml (vendored, new file, +192 lines)
@@ -0,0 +1,192 @@
name: Build and Publish EZKL Engine npm package

on:
  workflow_dispatch:
    inputs:
      tag:
        description: "The tag to release"
        required: true
  push:
    tags:
      - "*"

defaults:
  run:
    working-directory: .

jobs:
  publish-wasm-bindings:
    permissions:
      contents: read
      packages: write
    name: publish-wasm-bindings
    env:
      RELEASE_TAG: ${{ github.ref_name }}
      RUSTFLAGS: "-C target-feature=+atomics,+bulk-memory"
    runs-on: ubuntu-latest
    if: startsWith(github.ref, 'refs/tags/')
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false
      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
          cache: false
      - uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
        with:
          # Pin to version 0.12.1
          version: 'v0.12.1'
      - name: Add wasm32-unknown-unknown target
        run: rustup target add wasm32-unknown-unknown

      - name: Add rust-src
        run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
      - name: Install binaryen
        run: |
          set -e
          curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz | tar xzf -
          export PATH=$PATH:$PWD/binaryen-version_116/bin
          wasm-opt --version
      - name: Build wasm files for both web and nodejs compilation targets
        run: |
          wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
          wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
      - name: Create package.json in pkg folder
        shell: bash
        run: |
          cat > pkg/package.json << EOF
          {
            "name": "@ezkljs/engine",
            "version": "$RELEASE_TAG",
            "dependencies": {
              "@types/json-bigint": "^1.0.1",
              "json-bigint": "^1.0.0"
            },
            "files": [
              "nodejs/ezkl_bg.wasm",
              "nodejs/ezkl.js",
              "nodejs/ezkl.d.ts",
              "nodejs/package.json",
              "nodejs/utils.js",
              "web/ezkl_bg.wasm",
              "web/ezkl.js",
              "web/ezkl.d.ts",
              "web/snippets/**/*",
              "web/package.json",
              "web/utils.js",
              "ezkl.d.ts"
            ],
            "main": "nodejs/ezkl.js",
            "module": "web/ezkl.js",
            "types": "nodejs/ezkl.d.ts",
            "sideEffects": [
              "web/snippets/*"
            ]
          }
          EOF

      - name: Replace memory definition in nodejs
        run: |
          sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:21,maximum:65536,shared:true})}|" pkg/nodejs/ezkl.js

      - name: Replace `import.meta.url` with `import.meta.resolve` in workerHelpers.js
        run: |
          find ./pkg/web/snippets -type f -name "*.js" -exec sed -i "s|import.meta.url|import.meta.resolve|" {} +

      - name: Add serialize and deserialize methods to nodejs bundle
        run: |
          echo '
          const JSONBig = require("json-bigint");

          // buffer is a Uint8ClampedArray | Uint8Array; returns a JSON object
          function deserialize(buffer) {
            if (buffer instanceof Uint8ClampedArray) {
              buffer = new Uint8Array(buffer.buffer);
            }
            const string = new TextDecoder().decode(buffer);
            const jsonObject = JSONBig.parse(string);
            return jsonObject;
          }

          // data is an object; returns a Uint8ClampedArray
          function serialize(data) {
            // Step 1: Stringify the object with BigInt support
            if (typeof data === "object") {
              data = JSONBig.stringify(data);
            }
            // Step 2: Encode the JSON string
            const uint8Array = new TextEncoder().encode(data);

            // Step 3: Convert to Uint8ClampedArray
            return new Uint8ClampedArray(uint8Array.buffer);
          }

          module.exports = {
            deserialize,
            serialize
          };
          ' > pkg/nodejs/utils.js
      - name: Add serialize and deserialize methods to web bundle
        run: |
          echo '
          import { parse, stringify } from "json-bigint";

          // buffer is a Uint8ClampedArray | Uint8Array; returns a JSON object
          export function deserialize(buffer) {
            if (buffer instanceof Uint8ClampedArray) {
              buffer = new Uint8Array(buffer.buffer);
            }
            const string = new TextDecoder().decode(buffer);
            const jsonObject = parse(string);
            return jsonObject;
          }

          // data is an object; returns a Uint8ClampedArray
          export function serialize(data) {
            // Step 1: Stringify the object with BigInt support
            if (typeof data === "object") {
              data = stringify(data);
            }
            // Step 2: Encode the JSON string
            const uint8Array = new TextEncoder().encode(data);

            // Step 3: Convert to Uint8ClampedArray
            return new Uint8ClampedArray(uint8Array.buffer);
          }
          ' > pkg/web/utils.js
      - name: Expose serialize and deserialize imports in nodejs target
        run: |
          sed -i '53i// import serialize and deserialize from utils.js\nconst { serialize, deserialize } = require(`./utils.js`);\nmodule.exports.serialize = serialize;\nmodule.exports.deserialize = deserialize;' pkg/nodejs/ezkl.js
      - name: Expose serialize and deserialize imports in web target
        run: |
          sed -i '51i\
          // import serialize and deserialize from utils.js\
          import { serialize, deserialize } from '\''./utils.js'\'';\
          export { serialize, deserialize };' pkg/web/ezkl.js
      - name: Add serialize and deserialize declarations to nodejs ezkl.d.ts
        run: |
          sed -i '1i\
          export declare function serialize(data: object | string): Uint8ClampedArray;\
          export declare function deserialize(buffer: Uint8ClampedArray | Uint8Array): any;' pkg/nodejs/ezkl.d.ts

      - name: Add serialize and deserialize declarations to web ezkl.d.ts
        run: |
          sed -i '1i\
          export declare function serialize(data: object | string): Uint8ClampedArray;\
          export declare function deserialize(buffer: Uint8ClampedArray | Uint8Array): any;' pkg/web/ezkl.d.ts

      - name: Create README.md in pkg folder
        run: |
          curl -s "https://raw.githubusercontent.com/zkonduit/ezkljs-engine/main/README.md" > ./pkg/README.md

      - name: Set up Node.js
        uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
        with:
          node-version: "18.12.1"
          registry-url: "https://registry.npmjs.org"
      - name: Publish to npm
        run: |
          cd pkg
          npm install
          npm ci
          npm publish
        env:
          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
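The `serialize`/`deserialize` helpers injected above define the package's I/O convention: values are JSON-stringified with json-bigint, so field elements larger than `Number.MAX_SAFE_INTEGER` survive, and are exchanged as `Uint8ClampedArray` buffers. A minimal round-trip sketch, assuming the published `@ezkljs/engine` nodejs bundle is installed; the payload shape is illustrative only:

```js
// Round-trip through the helpers added by the workflow above; the
// signatures match the ezkl.d.ts declarations it inserts:
//   serialize(data: object | string): Uint8ClampedArray
//   deserialize(buffer: Uint8ClampedArray | Uint8Array): any
const { serialize, deserialize } = require("@ezkljs/engine");

// Hypothetical payload: string-encoded big values always round-trip cleanly;
// bare oversized numbers come back as json-bigint big-number values.
const payload = { input_data: [["123456789012345678901234567890", "42"]] };

const bytes = serialize(payload);                // Uint8ClampedArray
const decoded = deserialize(bytes);              // plain JSON object again

console.log(bytes instanceof Uint8ClampedArray); // true
console.log(decoded.input_data[0][0]);           // "123456789012345678901234567890"
```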
.github/workflows/pypi-gpu.yml (vendored, 14 lines changed)
@@ -27,8 +27,6 @@ jobs:
        target: [x86_64]
    env:
      RELEASE_TAG: ${{ github.ref_name }}
-      RUSTFLAGS: "-C linker=gcc"
-      OPENSSL_NO_VENDOR: 1
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
@@ -38,16 +36,6 @@ jobs:
          python-version: 3.12
          architecture: x64

-      - name: Install build dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y build-essential g++ gcc cmake libclang-dev llvm-dev libstdc++-12-dev libc6 libc6-dev libssl-dev pkg-config
-
-      - name: Force rebuild icicle dependencies
-        run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254

      - name: Set pyproject.toml version to match github tag and rename ezkl to ezkl-gpu
        shell: bash
        run: |
@@ -83,7 +71,7 @@ jobs:
          target: ${{ matrix.target }}
          manylinux: auto
          container: off
-          args: --release --out dist --features python-bindings,gpu-accelerated
+          args: --release --out dist --features python-bindings,icicle

      - name: Install built wheel
        if: matrix.target == 'x86_64'
.github/workflows/release.yml (vendored, 11 lines changed)
@@ -47,8 +47,6 @@ jobs:
      TARGET_DIR: ./target
      RUST_BACKTRACE: 1
      PCRE2_SYS_STATIC: 1
-      RUSTFLAGS: "-C linker=gcc"
-      OPENSSL_NO_VENDOR: 1
    steps:
      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
@@ -61,13 +59,6 @@ jobs:
        with:
          persist-credentials: false

-      - name: Install build dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y build-essential g++ gcc cmake libclang-dev llvm-dev libstdc++-12-dev libc6 libc6-dev libssl-dev pkg-config
-
-      - name: Force rebuild icicle dependencies
-        run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254

      - name: Get release version from tag
        shell: bash
@@ -88,7 +79,7 @@ jobs:
          sudo apt-get update

      - name: Build release binary
-        run: cargo build --release -Z sparse-registry --features gpu-accelerated
+        run: cargo build --release -Z sparse-registry --features icicle

      - name: Build archive
        shell: bash
.github/workflows/rust.yml (vendored, 629 lines changed)
@@ -20,18 +20,18 @@ env:

jobs:
  fr-age-test:
-    needs: [build, library-tests, docs]
+    needs: [build, library-tests, docs, python-tests, python-integration-tests]
    permissions:
      contents: read
    runs-on: large-self-hosted
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
@@ -46,76 +46,57 @@ jobs:
  build:
    permissions:
      contents: read
-    runs-on: ubuntu-22.04
+    runs-on: ubuntu-latest
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
      RUSTFLAGS: "-C linker=gcc"
      OPENSSL_NO_VENDOR: 1
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
-      - name: install libc6
-        run: sudo apt-get install -y libc6
+      - name: Install cmake and build dependencies
+        run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
      - name: Force rebuild icicle dependencies
        run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
      - name: Build
        run: cargo build --verbose

  docs:
    permissions:
      contents: read
-    runs-on: ubuntu-22.04
+    runs-on: ubuntu-latest
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false
      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
-      - name: install libc6
-        run: sudo apt-get install -y libc6
+      - name: Install cmake and build dependencies
+        run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
      - name: Docs
        run: cargo doc --verbose

  library-tests:
    permissions:
      contents: read
-    runs-on: ubuntu-22.04
+    runs-on: ubuntu-latest-32-cores
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
      RUSTFLAGS: "-C linker=gcc"
      OPENSSL_NO_VENDOR: 1
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false
      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
-      - name: install libc6
-        run: sudo apt-get install -y libc6
+      - name: Install cmake and build dependencies
+        run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
      - name: Force rebuild icicle dependencies
        run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
        with:
          crate: cargo-nextest
          locked: true
@@ -124,80 +105,108 @@ jobs:
        run: cargo test --doc --verbose
      - name: Library tests
        run: cargo nextest run --lib --verbose
      - name: Library tests (original lookup)
        run: cargo nextest run --lib --verbose --no-default-features --features ezkl,eth-original-lookup

-  ultra-overflow-tests-gpu:
+  # ultra-overflow-tests-gpu:
  #   runs-on: GPU
  #   env:
  #     ENABLE_ICICLE_GPU: true
  #   steps:
  #     - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
  #       with:
  #         persist-credentials: false
  #     - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
  #       with:
  #         toolchain: nightly-2025-05-01
  #         override: true
  #         components: rustfmt, clippy
  #     - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
  #       with:
  #         crate: cargo-nextest
  #         locked: true
  #     - uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
  #       with:
  #         wasmtime-version: "3.0.1"
  #     # - name: Matmul overflow (wasi)
  #     #   run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
  #     # - name: Conv overflow (wasi)
  #     #   run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
  #     - name: lookup overflow
  #       run: cargo nextest run lookup_ultra_overflow --no-capture --features icicle -- --include-ignored
  #     - name: Matmul overflow
  #       run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --features icicle -- --include-ignored
  #     - name: Conv overflow
  #       run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --features icicle -- --include-ignored
  #     - name: Conv + relu overflow
  #       run: cargo nextest run conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored

  ultra-overflow-tests_og-lookup:
    permissions:
      contents: read
-    runs-on: gpu
+    runs-on: non-gpu,non-sgx
    env:
      ENABLE_ICICLE_GPU: true
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
      RUSTFLAGS: "-C linker=gcc"
      OPENSSL_NO_VENDOR: 1
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
        with:
          crate: cargo-nextest
          locked: true
      - name: Setup GPU dependencies
        run: sudo ./setup-gpu.sh --yes
      - name: Install build dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y build-essential g++ gcc cmake libclang-dev llvm-dev libstdc++-12-dev libc6 libc6-dev libssl-dev pkg-config
      - name: Force rebuild icicle dependencies
        run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
      - uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
        with:
          wasmtime-version: "3.0.1"
      # - name: Matmul overflow (wasi)
      #   run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
      # - name: Conv overflow (wasi)
      #   run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
      - name: lookup overflow
-        run: cargo nextest run lookup_ultra_overflow --no-capture --features gpu-accelerated -- --include-ignored
+        run: cargo nextest run --release lookup_ultra_overflow --no-capture --no-default-features --features ezkl,eth-original-lookup -- --include-ignored
      - name: Matmul overflow
-        run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture --features gpu-accelerated -- --include-ignored
+        run: RUST_LOG=debug cargo nextest run --release matmul_col_ultra_overflow --no-capture --no-default-features --features ezkl,eth-original-lookup -- --include-ignored
      - name: Conv overflow
-        run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --features gpu-accelerated -- --include-ignored
+        run: RUST_LOG=debug cargo nextest run --release conv_col_ultra_overflow --no-capture --no-default-features --features ezkl,eth-original-lookup -- --include-ignored
      - name: Conv + relu overflow
-        run: cargo nextest run conv_relu_col_ultra_overflow --no-capture --features gpu-accelerated -- --include-ignored
+        run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --no-default-features --features ezkl,eth-original-lookup -- --include-ignored

  ultra-overflow-tests:
    permissions:
      contents: read
-    runs-on: [non-gpu, non-sgx]
-    needs: [build, library-tests, docs]
+    runs-on: non-gpu,non-sgx
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
      RUSTFLAGS: "-C linker=gcc"
      OPENSSL_NO_VENDOR: 1

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
-      - name: install libc6
-        run: sudo apt-get install -y libc6
+      - name: Install cmake and build dependencies
+        run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
      - name: Force rebuild icicle dependencies
        run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
        with:
          crate: cargo-nextest
          locked: true
      - uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
        with:
          wasmtime-version: "3.0.1"
      # - name: Matmul overflow (wasi)
      #   run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
      # - name: Conv overflow (wasi)
      #   run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
      - name: lookup overflow
        run: cargo nextest run lookup_ultra_overflow --no-capture -- --include-ignored
      - name: Matmul overflow
@@ -210,61 +219,91 @@ jobs:
  model-serialization:
    permissions:
      contents: read
-    runs-on: ubuntu-22.04
+    runs-on: ubuntu-latest-16-cores
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
      RUSTFLAGS: "-C linker=gcc"
      OPENSSL_NO_VENDOR: 1

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
-      - name: install libc6
-        run: sudo apt-get install -y libc6
+      - name: Install cmake and build dependencies
+        run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
      - name: Force rebuild icicle dependencies
        run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
        with:
          crate: cargo-nextest
          locked: true
      - name: Model serialization different binary ID
        run: cargo nextest run native_tests::tests::model_serialization_different_binaries_ --test-threads 1

-  mock-proving-tests:
+  wasm32-tests:
    permissions:
      contents: read
-    runs-on: [non-gpu, non-sgx]
-    needs: [build, library-tests, docs]
+    runs-on: ubuntu-latest-64-cores
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
-      RUSTFLAGS: "-C linker=gcc"
-      OPENSSL_NO_VENDOR: 1
+      # add `atomics` and `bulk-memory` to RUSTFLAGS to enable wasm-bindgen tests
+      RUSTFLAGS: "-C target-feature=+atomics,+bulk-memory"

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
-      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
+      - uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
        with:
          # Pin to version 0.12.1
          version: "v0.12.1"
      - uses: nanasess/setup-chromedriver@affb1ea8848cbb080be372c1e8d7a5c173e9298f #v2.3.0
        # with:
        #   chromedriver-version: "115.0.5790.102"
      - name: Install wasm32-unknown-unknown
        run: rustup target add wasm32-unknown-unknown
      - name: Add rust-src
        run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
      - name: Create webdriver.json to disable timeouts
        run: |
          echo '{"args": ["--headless", "--disable-gpu", "--disable-dev-shm-usage", "--no-sandbox"]}' > webdriver.json
      - name: Run wasm verifier tests
        run: |
          ulimit -n 65536
          WASM_BINDGEN_TEST_THREADS=1 \
          WASM_BINDGEN_TEST_TIMEOUT=1800 \
          CHROMEDRIVER_ARGS="--log-level=INFO" \
          wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web -- --nocapture

  mock-proving-tests:
    permissions:
      contents: read
    runs-on: non-gpu,non-sgx
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
        with:
          crate: cargo-nextest
          locked: true
      - name: Force rebuild icicle dependencies
        run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
      # - name: The Worm Mock
      #   run: cargo nextest run --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
      - name: Large 1D Conv Mock
@@ -317,42 +356,59 @@ jobs:
  prove-and-verify-evm-tests:
    permissions:
      contents: read
-    runs-on: [non-gpu, non-sgx]
-    needs: [build, library-tests, docs]
+    runs-on: non-gpu,non-sgx
+    # needs: [build, library-tests, docs, python-tests, python-integration-tests]
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
      RUSTFLAGS: "-C target-feature=+atomics,+bulk-memory"
      OPENSSL_NO_VENDOR: 1

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
-      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
+      - uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
        with:
          # Pin to version 0.12.1
          version: "v0.12.1"
      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
        with:
          crate: cargo-nextest
          locked: true
      - name: Force rebuild icicle dependencies
        run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false
      - name: Use pnpm 8
        uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
        with:
          version: 8
      - name: Use Node.js 18.12.1
        uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
        with:
          node-version: "18.12.1"
          cache: "pnpm"
      - name: "Add rust-src"
        run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
      - name: Install dependencies for js tests and package
        run: |
          pnpm install --frozen-lockfile
      # - name: Install solc
      #   run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
      - name: Install Anvil
        run: cargo install --git https://github.com/foundry-rs/foundry --rev 56b806a3ba7866a3b061093bebd0fa2ace97f1fc --locked anvil --force
      - name: Build wasm package for nodejs target.
        run: |
          wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
      - name: KZG prove and verify tests (EVM)
        run: cargo nextest run --verbose "tests_evm::kzg_evm_prove_and_verify_::" --test-threads 1
-      # - name: KZG prove and verify tests (EVM + reusable verifier + col-overflow)
-      #   run: cargo nextest run --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --features reusable-verifier --test-threads 1
+      - name: KZG prove and verify tests (EVM + reusable verifier + col-overflow)
+        run: cargo nextest run --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --features reusable-verifier --test-threads 1
      - name: KZG prove and verify tests (EVM + kzg all)
        run: cargo nextest run --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
      - name: KZG prove and verify tests (EVM + kzg inputs)
@@ -366,42 +422,106 @@ jobs:
      - name: KZG prove and verify tests (EVM + hashed outputs)
        run: cargo nextest run --verbose tests_evm::kzg_evm_hashed_output_prove_and_verify --test-threads 1

  # prove-and-verify-tests-metal:
  #   permissions:
  #     contents: read
  #   runs-on: macos-13
  #   # needs: [build, library-tests, docs]
  #   steps:
  #     - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
  #       with:
  #         persist-credentials: false
  #     - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
  #       with:
  #         toolchain: nightly-2025-05-01
  #         override: true
  #         components: rustfmt, clippy
  #     - uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
  #       with:
  #         # Pin to version 0.12.1
  #         version: 'v0.12.1'
  #     - name: Add rust-src
  #       run: rustup component add rust-src --toolchain nightly-2025-05-01
  #     - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
  #       with:
  #         persist-credentials: false
  #     - name: Use pnpm 8
  #       uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
  #       with:
  #         version: 8
  #     - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
  #       with:
  #         crate: cargo-nextest
  #         locked: true
  #     - name: KZG prove and verify tests (public outputs)
  #       run: cargo nextest run --features macos-metal --verbose tests::kzg_prove_and_verify_::t --no-capture

  prove-and-verify-tests:
    permissions:
      contents: read
-    runs-on: [non-gpu, non-sgx]
+    runs-on: non-gpu,non-sgx
    needs: [build, library-tests, docs]
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
      RUSTFLAGS: "-C target-feature=+atomics,+bulk-memory"

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
      - uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
        with:
          # Pin to version 0.12.1
          version: "v0.12.1"
      - name: Add wasm32-unknown-unknown target
        run: rustup target add wasm32-unknown-unknown

      - name: Add rust-src
        run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
      - name: Force rebuild icicle dependencies
        run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false
-      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
+      - name: Use pnpm 8
+        uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
        with:
          version: 8
      - name: Use Node.js 18.12.1
        uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
        with:
          node-version: "18.12.1"
          cache: "pnpm"
      - name: Install dependencies for js tests
        run: |
          pnpm install --frozen-lockfile
        env:
          CI: false
          NODE_ENV: development
      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
        with:
          crate: cargo-nextest
          locked: true
      - name: Build wasm package for nodejs target.
        run: |
          wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
      - name: KZG prove and verify tests (public outputs + column overflow)
-        run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::t
+        run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w
      - name: KZG prove and verify tests (public outputs + fixed params + column overflow)
        run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
      - name: KZG prove and verify tests (hashed inputs + column overflow)
        run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_hashed_inputs_
      - name: KZG prove and verify tests (public outputs)
        run: cargo nextest run --verbose tests::kzg_prove_and_verify_tight_lookup_::t
      - name: IPA prove and verify tests
        run: cargo nextest run --verbose tests::ipa_prove_and_verify_::t --test-threads 1
      - name: IPA prove and verify tests (ipa outputs)
        run: cargo nextest run --verbose tests::ipa_prove_and_verify_ipa_output
      - name: KZG prove and verify tests single inner col
        run: cargo nextest run --verbose tests::kzg_prove_and_verify_single_col
      - name: KZG prove and verify tests triple inner col
@@ -421,81 +541,165 @@ jobs:
      - name: KZG prove and verify tests (hashed outputs)
        run: cargo nextest run --verbose tests::kzg_prove_and_verify_hashed

-  prove-and-verify-tests-gpu:
+  # prove-and-verify-tests-gpu:
  #   runs-on: GPU
  #   env:
  #     ENABLE_ICICLE_GPU: true
  #   steps:
  #     - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
  #       with:
  #         persist-credentials: false
  #     - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
  #       with:
  #         toolchain: nightly-2025-05-01
  #         override: true
  #         components: rustfmt, clippy
  #     - name: Add rust-src
  #       run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
  #     - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
  #     - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
  #       with:
  #         crate: cargo-nextest
  #         locked: true
  #     - name: KZG prove and verify tests (kzg outputs)
  #       run: cargo nextest run --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
  #     - name: KZG prove and verify tests (public outputs + column overflow)
  #       run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
  #     - name: KZG prove and verify tests (public outputs + fixed params + column overflow)
  #       run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
  #     - name: KZG prove and verify tests (public outputs)
  #       run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
  #     - name: KZG prove and verify tests (public outputs + column overflow)
  #       run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
  #     - name: KZG prove and verify tests (public inputs)
  #       run: cargo nextest run --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
  #     - name: KZG prove and verify tests (fixed params)
  #       run: cargo nextest run --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
  #     - name: KZG prove and verify tests (hashed outputs)
  #       run: cargo nextest run --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1

  prove-and-verify-mock-aggr-tests:
    permissions:
      contents: read
-    runs-on: gpu
-    needs: [build, library-tests, docs]
+    runs-on: self-hosted
+    needs: [build, library-tests, docs, python-tests, python-integration-tests]
    env:
      ENABLE_ICICLE_GPU: true
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
      RUSTFLAGS: "-C linker=gcc"
      OPENSSL_NO_VENDOR: 1

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false
-      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
+      - uses: dtolnay/rust-toolchain@4f94fbe7e03939b0e674bcc9ca609a16088f63ff #nightly branch, TODO: update when required
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
      - name: Add rust-src
        run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
        with:
          crate: cargo-nextest
          locked: true
      - name: Setup GPU dependencies
        run: sudo ./setup-gpu.sh --yes
      - name: Install build dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y build-essential g++ gcc cmake libclang-dev llvm-dev libstdc++-12-dev libc6 libc6-dev libssl-dev pkg-config
      - name: Force rebuild icicle dependencies
        run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
      - name: KZG prove and verify tests (kzg outputs)
        run: cargo nextest run --verbose tests::kzg_prove_and_verify_kzg_output --features gpu-accelerated --test-threads 1
      - name: KZG prove and verify tests (public outputs + column overflow)
        run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::t --features gpu-accelerated --test-threads 1
      - name: KZG prove and verify tests (public outputs + fixed params + column overflow)
        run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features gpu-accelerated --test-threads 1
      - name: KZG prove and verify tests (public outputs)
        run: cargo nextest run --verbose tests::kzg_prove_and_verify_::t --features gpu-accelerated --test-threads 1
      - name: KZG prove and verify tests (public inputs)
        run: cargo nextest run --verbose tests::kzg_prove_and_verify_public_input --features gpu-accelerated --test-threads 1
      - name: KZG prove and verify tests (fixed params)
        run: cargo nextest run --verbose tests::kzg_prove_and_verify_fixed_params --features gpu-accelerated --test-threads 1
      - name: KZG prove and verify tests (hashed outputs)
        run: cargo nextest run --verbose tests::kzg_prove_and_verify_hashed --features gpu-accelerated --test-threads 1
      - name: Mock aggr tests (KZG)
        run: cargo nextest run --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8

  # prove-and-verify-aggr-tests-gpu:
  #   runs-on: GPU
  #   env:
  #     ENABLE_ICICLE_GPU: true
  #   steps:
  #     - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
  #       with:
  #         persist-credentials: false
  #     - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
  #       with:
  #         toolchain: nightly-2025-05-01
  #         override: true
  #         components: rustfmt, clippy
  #     - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
  #       with:
  #         crate: cargo-nextest
  #         locked: true
  #     - name: KZG tests
  #       run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored

  prove-and-verify-aggr-tests:
    permissions:
      contents: read
    runs-on: large-self-hosted
    needs: [build, library-tests, docs, python-tests, python-integration-tests]
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
        with:
          crate: cargo-nextest
          locked: true
      - name: KZG tests
        run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored

  prove-and-verify-aggr-evm-tests:
    permissions:
      contents: read
    runs-on: large-self-hosted
    needs: [build, library-tests, docs, python-tests, python-integration-tests]
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
        with:
          crate: cargo-nextest
          locked: true
      # - name: Install solc
      #   run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
      - name: Install Anvil
        run: cargo install --git https://github.com/foundry-rs/foundry --rev 56b806a3ba7866a3b061093bebd0fa2ace97f1fc --locked anvil --force
      - name: KZG prove and verify aggr tests
        run: cargo nextest run --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored

  examples:
    permissions:
      contents: read
-    runs-on: ubuntu-22.04
+    runs-on: ubuntu-latest-32-cores
    needs: [build, library-tests, docs]
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
      RUSTFLAGS: "-C linker=gcc"
      OPENSSL_NO_VENDOR: 1

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
-      - name: install libc6
-        run: sudo apt-get install -y libc6
+      - name: Install cmake and build dependencies
+        run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
      - name: Force rebuild icicle dependencies
        run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
        with:
          crate: cargo-nextest
          locked: true
@@ -505,20 +709,21 @@ jobs:
  python-tests:
    permissions:
      contents: read
-    runs-on: [non-gpu, non-sgx]
+    runs-on: non-gpu,non-sgx
    needs: [build, library-tests, docs]
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
        with:
          python-version: "3.12"
      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
@@ -539,25 +744,26 @@ jobs:
  accuracy-measurement-tests:
    permissions:
      contents: read
-    runs-on: [non-gpu, non-sgx]
-    needs: [build, library-tests, docs]
+    runs-on: non-gpu,non-sgx
+    needs: [build, library-tests, docs, python-tests, python-integration-tests]
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
        with:
          python-version: "3.12"
      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
        with:
          crate: cargo-nextest
          locked: true
@@ -578,24 +784,24 @@ jobs:
    permissions:
      contents: read
    runs-on: large-self-hosted
    needs: [build, library-tests, docs, python-tests]
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
        with:
          python-version: "3.11"
      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
        with:
          crate: cargo-nextest
          locked: true
@@ -636,3 +842,92 @@ jobs:
      #   run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_verifier_ --no-capture
      - name: Reusable verifier tutorial
        run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_verifier_ --no-capture --test-threads 1

  ios-integration-tests:
    permissions:
      contents: read
    runs-on: macos-latest
    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
      - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
        with:
          crate: cargo-nextest
          locked: true
      - name: Run ios tests
        run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2025-05-01-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features

  swift-package-tests:
    permissions:
      contents: read
    runs-on: macos-latest
    needs: [ios-integration-tests]

    env:
      EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
        with:
          toolchain: nightly-2025-05-01
          override: true
          components: rustfmt, clippy
      - name: Build EzklCoreBindings
        run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features

      - name: Clone ezkl-swift-package repository
        run: |
          git clone https://github.com/zkonduit/ezkl-swift-package.git

      - name: Copy EzklCoreBindings
        run: |
          rm -rf ezkl-swift-package/Sources/EzklCoreBindings
          cp -r build/EzklCoreBindings ezkl-swift-package/Sources/

      - name: Copy Test Files
        run: |
          rm -rf ezkl-swift-package/Tests/EzklAssets/
          mkdir -p ezkl-swift-package/Tests/EzklAssets/
          cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
          cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
          cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
          cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json

      - name: Set up Xcode environment
        run: |
          sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
          sudo xcodebuild -license accept

      - name: Run Package Tests
        run: |
          cd ezkl-swift-package
          xcodebuild test \
            -scheme EzklPackage \
            -destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
            -resultBundlePath ../testResults

      - name: Run Example App Tests
        run: |
          cd ezkl-swift-package/Example
          xcodebuild test \
            -project Example.xcodeproj \
            -scheme EzklApp \
            -destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
            -parallel-testing-enabled NO \
            -resultBundlePath ../../exampleTestResults \
            -skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
.github/workflows/swift-pm.yml (vendored, new file, +134 lines)
@@ -0,0 +1,134 @@
name: Build and Publish EZKL iOS SPM package

on:
  push:
    tags:
      # Only support SemVer versioning tags
      - 'v[0-9]+.[0-9]+.[0-9]+'
      - '[0-9]+.[0-9]+.[0-9]+'

jobs:
  build-and-update:
    permissions:
      contents: read
      packages: write
    runs-on: macos-latest
    env:
      EZKL_SWIFT_PACKAGE_REPO: github.com/zkonduit/ezkl-swift-package.git
      RELEASE_TAG: ${{ github.ref_name }}

    steps:
      - name: Checkout EZKL
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - name: Extract TAG from github.ref_name
        run: |
          # github.ref_name is provided by GitHub Actions and contains the tag name directly.
          TAG="${RELEASE_TAG}"
          echo "Original TAG: $TAG"
          # Remove leading 'v' if present to match the Swift Package Manager version format.
          NEW_TAG=${TAG#v}
          echo "Stripped TAG: $NEW_TAG"
          echo "TAG=$NEW_TAG" >> $GITHUB_ENV

      - name: Install Rust (nightly)
        uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
        with:
          toolchain: nightly
          override: true

      - name: Build EzklCoreBindings
        run: CONFIGURATION=release cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features

      - name: Clone ezkl-swift-package repository
        run: |
          git clone https://${{ env.EZKL_SWIFT_PACKAGE_REPO }}

      - name: Copy EzklCoreBindings
        run: |
          rm -rf ezkl-swift-package/Sources/EzklCoreBindings
          cp -r build/EzklCoreBindings ezkl-swift-package/Sources/

      - name: Copy Test Files
        run: |
          rm -rf ezkl-swift-package/Tests/EzklAssets/
          mkdir -p ezkl-swift-package/Tests/EzklAssets/
          cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
          cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
          cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
          cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json

      - name: Check for changes
        id: check_changes
        run: |
          cd ezkl-swift-package
          if git diff --quiet Sources/EzklCoreBindings Tests/EzklAssets; then
            echo "no_changes=true" >> $GITHUB_OUTPUT
          else
            echo "no_changes=false" >> $GITHUB_OUTPUT
          fi

      - name: Set up Xcode environment
        if: steps.check_changes.outputs.no_changes == 'false'
        run: |
          sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
          sudo xcodebuild -license accept

      - name: Run Package Tests
        if: steps.check_changes.outputs.no_changes == 'false'
        run: |
          cd ezkl-swift-package
          xcodebuild test \
            -scheme EzklPackage \
            -destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
            -resultBundlePath ../testResults

      - name: Run Example App Tests
        if: steps.check_changes.outputs.no_changes == 'false'
        run: |
          cd ezkl-swift-package/Example
          xcodebuild test \
            -project Example.xcodeproj \
            -scheme EzklApp \
            -destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
            -parallel-testing-enabled NO \
            -resultBundlePath ../../exampleTestResults \
            -skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder

      - name: Setup Git
        run: |
          cd ezkl-swift-package
          git config user.name "GitHub Action"
          git config user.email "action@github.com"
          git remote set-url origin https://zkonduit:${EZKL_SWIFT_PACKAGE_REPO_TOKEN}@${{ env.EZKL_SWIFT_PACKAGE_REPO }}
        env:
          EZKL_SWIFT_PACKAGE_REPO_TOKEN: ${{ secrets.EZKL_PORTER_TOKEN }}

      - name: Commit and Push Changes
        if: steps.check_changes.outputs.no_changes == 'false'
        run: |
          cd ezkl-swift-package
          git add Sources/EzklCoreBindings Tests/EzklAssets
          git commit -m "Automatically updated EzklCoreBindings for EZKL"
          if ! git push origin; then
            echo "::error::Failed to push changes to ${{ env.EZKL_SWIFT_PACKAGE_REPO }}. Please ensure that EZKL_PORTER_TOKEN has the correct permissions."
            exit 1
          fi

      - name: Tag the latest commit
        run: |
          cd ezkl-swift-package
          source $GITHUB_ENV
          # Tag the latest commit on the current branch
          if git rev-parse "$TAG" >/dev/null 2>&1; then
            echo "Tag $TAG already exists locally. Skipping tag creation."
          else
            git tag "$TAG"
          fi

          if ! git push origin "$TAG"; then
            echo "::error::Failed to push tag '$TAG' to ${{ env.EZKL_SWIFT_PACKAGE_REPO }}. Please ensure EZKL_PORTER_TOKEN has correct permissions."
            exit 1
          fi
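The Extract TAG step above relies on shell parameter expansion: `${TAG#v}` drops a single leading `v` so the tag pushed to the package repo matches SPM's plain SemVer format. A minimal Python sketch of the same normalization, for readers who do not parse shell (the helper name is hypothetical, not part of the workflow):

def normalize_tag(tag: str) -> str:
    # Mirror the workflow's ${TAG#v}: strip one leading 'v' if present.
    return tag[1:] if tag.startswith("v") else tag

assert normalize_tag("v1.2.3") == "1.2.3"
assert normalize_tag("1.2.3") == "1.2.3"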
313 Cargo.lock generated
@@ -881,6 +881,17 @@ dependencies = [
"syn 2.0.101",
]

[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi 0.1.19",
"libc",
"winapi",
]

[[package]]
name = "aurora-engine-modexp"
version = "1.2.0"
@@ -959,6 +970,29 @@ dependencies = [
"serde",
]

[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags 2.9.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.101",
"which",
]

[[package]]
name = "bit-set"
version = "0.5.3"
@@ -1177,6 +1211,15 @@ dependencies = [
"shlex",
]

[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]

[[package]]
name = "cfg-if"
version = "1.0.0"
@@ -1227,7 +1270,29 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
dependencies = [
"ciborium-io",
"half",
"half 2.6.0",
]

[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
]

[[package]]
name = "clap"
version = "2.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c"
dependencies = [
"bitflags 1.3.2",
"textwrap 0.11.0",
"unicode-width 0.1.14",
]

[[package]]
@@ -1258,7 +1323,7 @@ version = "4.5.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c06f5378ea264ad4f82bbc826628b5aad714a75abf6ece087e923010eb937fb6"
dependencies = [
"clap",
"clap 4.5.37",
]

[[package]]
@@ -1445,6 +1510,32 @@ dependencies = [
"cfg-if",
]

[[package]]
name = "criterion"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f"
dependencies = [
"atty",
"cast",
"clap 2.34.0",
"criterion-plot 0.4.5",
"csv",
"itertools 0.10.5",
"lazy_static",
"num-traits",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_cbor",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]

[[package]]
name = "criterion"
version = "0.5.1"
@@ -1454,8 +1545,8 @@ dependencies = [
"anes",
"cast",
"ciborium",
"clap",
"criterion-plot",
"clap 4.5.37",
"criterion-plot 0.5.0",
"is-terminal",
"itertools 0.10.5",
"num-traits",
@@ -1471,6 +1562,16 @@ dependencies = [
"walkdir",
]

[[package]]
name = "criterion-plot"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876"
dependencies = [
"cast",
"itertools 0.10.5",
]

[[package]]
name = "criterion-plot"
version = "0.5.0"
@@ -1543,6 +1644,27 @@ dependencies = [
"typenum",
]

[[package]]
name = "csv"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf"
dependencies = [
"csv-core",
"itoa",
"ryu",
"serde",
]

[[package]]
name = "csv-core"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d"
dependencies = [
"memchr",
]

[[package]]
name = "dashmap"
version = "5.5.3"
@@ -1941,13 +2063,14 @@ version = "0.0.0"
dependencies = [
"alloy",
"bincode",
"camino",
"chrono",
"clap",
"clap 4.5.37",
"clap_complete",
"colored",
"colored_json",
"console_error_panic_hook",
"criterion",
"criterion 0.5.1",
"ecc",
"env_logger 0.10.2",
"ethabi",
@@ -1959,7 +2082,6 @@ dependencies = [
"halo2_solidity_verifier",
"halo2curves 0.7.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
"hex",
"icicle-runtime",
"indicatif",
"instant",
"itertools 0.10.5",
@@ -1995,7 +2117,9 @@ dependencies = [
"tosubcommand",
"tract-onnx",
"uniffi",
"uniffi_bindgen",
"unzip-n",
"uuid",
"wasm-bindgen",
"wasm-bindgen-console-logger",
"wasm-bindgen-rayon",
@@ -2204,7 +2328,7 @@ version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8640e34b88f7652208ce9e88b1a37a2ae95227d84abec377ccd3c5cfeb141ed4"
dependencies = [
"rustix",
"rustix 1.0.5",
"windows-sys 0.59.0",
]

@@ -2397,6 +2521,12 @@ dependencies = [
"subtle",
]

[[package]]
name = "half"
version = "1.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"

[[package]]
name = "half"
version = "2.6.0"
@@ -2411,7 +2541,7 @@ dependencies = [
[[package]]
name = "halo2_gadgets"
version = "0.2.0"
source = "git+https://github.com/zkonduit/halo2?branch=ac%2Fconditional-compilation-icicle2#1dd2090741f006fd031a07da7f3c9dfce5e0015e"
source = "git+https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d"
dependencies = [
"arrayvec 0.7.6",
"bitvec",
@@ -2428,7 +2558,7 @@ dependencies = [
[[package]]
name = "halo2_proofs"
version = "0.3.0"
source = "git+https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e?branch=ac%2Fconditional-compilation-icicle2#01c88842679b4308e43ae5ed91c4183e861669bd"
source = "git+https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d"
dependencies = [
"bincode",
"blake2b_simd",
@@ -2438,7 +2568,7 @@ dependencies = [
"halo2curves 0.7.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
"icicle-bn254",
"icicle-core",
"icicle-runtime",
"icicle-cuda-runtime",
"instant",
"lazy_static",
"log",
@@ -2446,7 +2576,7 @@ dependencies = [
"mopro-msm",
"rand_chacha 0.3.1",
"rand_core 0.6.4",
"rustc-hash",
"rustc-hash 2.1.1",
"serde",
"sha3 0.9.1",
"tracing",
@@ -2661,6 +2791,15 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"

[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]

[[package]]
name = "hermit-abi"
version = "0.3.9"
@@ -2852,45 +2991,33 @@ dependencies = [

[[package]]
name = "icicle-bn254"
version = "3.7.0"
source = "git+https://github.com/ingonyama-zk/icicle?branch=emir%2Fgate_eval_2#012e00694f4cf399fe7a42d9cfbfa6cd7a60f876"
version = "2.8.0"
source = "git+https://github.com/ingonyama-zk/icicle?branch=ezkl-icicle2#5dfe006a0f1bc62ea82ca297709bbf3d22a2ca25"
dependencies = [
"cmake",
"criterion 0.3.6",
"icicle-core",
"icicle-hash",
"icicle-runtime",
"icicle-cuda-runtime",
]

[[package]]
name = "icicle-core"
version = "3.7.0"
source = "git+https://github.com/ingonyama-zk/icicle?branch=emir%2Fgate_eval_2#012e00694f4cf399fe7a42d9cfbfa6cd7a60f876"
version = "2.8.0"
source = "git+https://github.com/ingonyama-zk/icicle?branch=ezkl-icicle2#5dfe006a0f1bc62ea82ca297709bbf3d22a2ca25"
dependencies = [
"criterion 0.3.6",
"hex",
"icicle-runtime",
"once_cell",
"rand 0.8.5",
"icicle-cuda-runtime",
"rayon",
]

[[package]]
name = "icicle-hash"
version = "3.7.0"
source = "git+https://github.com/ingonyama-zk/icicle?branch=emir%2Fgate_eval_2#012e00694f4cf399fe7a42d9cfbfa6cd7a60f876"
name = "icicle-cuda-runtime"
version = "2.8.0"
source = "git+https://github.com/ingonyama-zk/icicle?branch=ezkl-icicle2#5dfe006a0f1bc62ea82ca297709bbf3d22a2ca25"
dependencies = [
"cmake",
"icicle-core",
"icicle-runtime",
"rand 0.8.5",
]

[[package]]
name = "icicle-runtime"
version = "3.7.0"
source = "git+https://github.com/ingonyama-zk/icicle?branch=emir%2Fgate_eval_2#012e00694f4cf399fe7a42d9cfbfa6cd7a60f876"
dependencies = [
"cmake",
"once_cell",
"bindgen",
"bitflags 1.3.2",
]

[[package]]
@@ -3343,12 +3470,28 @@ dependencies = [
"spin",
]

[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"

[[package]]
name = "libc"
version = "0.2.172"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa"

[[package]]
name = "libloading"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if",
"windows-targets 0.52.6",
]

[[package]]
name = "libm"
version = "0.2.13"
@@ -3376,6 +3519,12 @@ dependencies = [
"redox_syscall",
]

[[package]]
name = "linux-raw-sys"
version = "0.4.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"

[[package]]
name = "linux-raw-sys"
version = "0.9.4"
@@ -3814,7 +3963,7 @@ dependencies = [
"num-traits",
"pyo3",
"pyo3-build-config",
"rustc-hash",
"rustc-hash 2.1.1",
]

[[package]]
@@ -4245,6 +4394,16 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"

[[package]]
name = "prettyplease"
version = "0.2.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "664ec5419c51e34154eec046ebcba56312d5a2fc3b09a06da188e1ad21afadf6"
dependencies = [
"proc-macro2",
"syn 2.0.101",
]

[[package]]
name = "primal-check"
version = "0.3.4"
@@ -4497,7 +4656,7 @@ dependencies = [
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustc-hash 2.1.1",
"rustls",
"socket2",
"thiserror 2.0.12",
@@ -4516,7 +4675,7 @@ dependencies = [
"getrandom 0.3.2",
"rand 0.9.1",
"ring",
"rustc-hash",
"rustc-hash 2.1.1",
"rustls",
"rustls-pki-types",
"slab",
@@ -5014,6 +5173,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"

[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"

[[package]]
name = "rustc-hash"
version = "2.1.1"
@@ -5059,6 +5224,19 @@ dependencies = [
"version_check",
]

[[package]]
name = "rustix"
version = "0.38.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
dependencies = [
"bitflags 2.9.0",
"errno",
"libc",
"linux-raw-sys 0.4.15",
"windows-sys 0.59.0",
]

[[package]]
name = "rustix"
version = "1.0.5"
@@ -5068,7 +5246,7 @@ dependencies = [
"bitflags 2.9.0",
"errno",
"libc",
"linux-raw-sys",
"linux-raw-sys 0.9.4",
"windows-sys 0.59.0",
]

@@ -5316,6 +5494,16 @@ dependencies = [
"serde",
]

[[package]]
name = "serde_cbor"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5"
dependencies = [
"half 1.8.3",
"serde",
]

[[package]]
name = "serde_derive"
version = "1.0.219"
@@ -5762,7 +5950,7 @@ dependencies = [
"fastrand",
"getrandom 0.3.2",
"once_cell",
"rustix",
"rustix 1.0.5",
"windows-sys 0.59.0",
]

@@ -5808,6 +5996,15 @@ dependencies = [
"syn 1.0.109",
]

[[package]]
name = "textwrap"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
dependencies = [
"unicode-width 0.1.14",
]

[[package]]
name = "textwrap"
version = "0.16.2"
@@ -6201,7 +6398,7 @@ dependencies = [
"downcast-rs",
"dyn-clone",
"dyn-hash",
"half",
"half 2.6.0",
"itertools 0.12.1",
"lazy_static",
"maplit",
@@ -6236,7 +6433,7 @@ dependencies = [
"downcast-rs",
"dyn-clone",
"dyn-hash",
"half",
"half 2.6.0",
"lazy_static",
"liquid",
"liquid-core",
@@ -6394,6 +6591,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31bff6daf87277a9014bcdefbc2842b0553392919d1096843c5aad899ca4588"
dependencies = [
"anyhow",
"uniffi_bindgen",
"uniffi_build",
"uniffi_core",
"uniffi_macros",
@@ -6416,7 +6614,7 @@ dependencies = [
"once_cell",
"paste",
"serde",
"textwrap",
"textwrap 0.16.2",
"toml 0.5.11",
"uniffi_meta",
"uniffi_testing",
@@ -6508,7 +6706,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cef408229a3a407fafa4c36dc4f6ece78a6fb258ab28d2b64bddd49c8cb680f6"
dependencies = [
"anyhow",
"textwrap",
"textwrap 0.16.2",
"uniffi_meta",
"uniffi_testing",
"weedle2",
@@ -6576,6 +6774,15 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"

[[package]]
name = "uuid"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9"
dependencies = [
"getrandom 0.3.2",
]

[[package]]
name = "valuable"
version = "0.1.1"
@@ -6839,6 +7046,18 @@ dependencies = [
"nom",
]

[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix 0.38.44",
]

[[package]]
name = "winapi"
version = "0.3.9"
@@ -7216,7 +7435,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d65cbf2f12c15564212d48f4e3dfb87923d25d611f2aed18f4cb23f0413d89e"
dependencies = [
"libc",
"rustix",
"rustix 1.0.5",
]

[[package]]
55 Cargo.toml
@@ -16,12 +16,12 @@ crate-type = ["cdylib", "rlib", "staticlib"]
[dependencies]
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "ac/conditional-compilation-icicle2" }
halo2_gadgets = { git = "https://github.com/zkonduit/halo2" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
"derive_serde",
] }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", branch= "ac/conditional-compilation-icicle2", features = [
"circuit-params", "mv-lookup"
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", features = [
"circuit-params",
] }
rand = { version = "0.8", default-features = false }
itertools = { version = "0.10.3", default-features = false }
@@ -33,10 +33,10 @@ thiserror = { version = "1.0.38", default-features = false }
hex = { version = "0.4.3", default-features = false }
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
"derive_serde", "mv-lookup"
"derive_serde",
] }
halo2_solidity_verifier = { git = "https://github.com/zkonduit/ezkl-verifier", branch = "main", optional = true, features = [
"evm", "mv-lookup",
"evm",
] }
maybe-rayon = { version = "0.1.1", default-features = false }
bincode = { version = "1.3.3", default-features = false }
@@ -96,9 +96,12 @@ objc = { version = "0.2.4", optional = true }
pyo3-stub-gen = { version = "0.6.0", optional = true }
jemallocator = { version = "0.5", optional = true }
mimalloc = { version = "0.1", optional = true }
# GPU / device related things (optional - only enabled with gpu-accelerated feature)
icicle-runtime = { git = "https://github.com/ingonyama-zk/icicle", branch="emir/gate_eval_2", package="icicle-runtime", optional = true }

# universal bindings
uniffi = { version = "=0.28.0", optional = true }
getrandom = { version = "0.2.8", optional = true }
uniffi_bindgen = { version = "=0.28.0", optional = true }
camino = { version = "^1.1", optional = true }
uuid = { version = "1.10.0", features = ["v4"], optional = true }

[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
colored = { version = "2.0.0", default-features = false, optional = true }
@@ -208,7 +211,9 @@ test = false
bench = false
required-features = ["ezkl"]

[[bin]]
name = "ios_gen_bindings"
required-features = ["ios-bindings", "uuid", "camino", "uniffi_bindgen"]

[[bin]]
name = "py_stub_gen"
@@ -217,17 +222,24 @@ required-features = ["python-bindings"]
[features]
web = ["wasm-bindgen-rayon"]
default = [
"eth",
"dep:halo2_solidity_verifier",
"eth-mv-lookup",
"ezkl",
"precompute-coset",
"no-banner",
"parallel-poly-read",
"reusable-verifier",
]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-async-runtimes", "pyo3-stub-gen"]
universal-bindings = [
"uniffi",
"mv-lookup",
"precompute-coset",
"parallel-poly-read",
"solidity-verifier-mv-lookup",
]
logging = ["dep:colored", "dep:env_logger", "dep:chrono"]
ios-bindings = ["universal-bindings"]
ios-bindings-test = ["ios-bindings", "uniffi/bindgen-tests"]
ezkl = [
"onnx",
"tabled/color",
@@ -249,6 +261,10 @@ ezkl = [
"logging",
]
eth = ["dep:alloy", "dep:foundry-compilers", "dep:ethabi"]
solidity-verifier = ["dep:halo2_solidity_verifier"]
solidity-verifier-mv-lookup = ["halo2_solidity_verifier/mv-lookup"]
eth-mv-lookup = ["solidity-verifier-mv-lookup", "mv-lookup", "eth"]
eth-original-lookup = ["eth", "solidity-verifier"]
parallel-poly-read = [
"halo2_proofs/circuit-params",
"halo2_proofs/parallel-poly-read",
@@ -257,7 +273,7 @@ mv-lookup = ["halo2_proofs/mv-lookup", "snark-verifier/mv-lookup"]
asm = ["halo2curves/asm", "halo2_proofs/asm"]
precompute-coset = ["halo2_proofs/precompute-coset"]
det-prove = []
gpu-accelerated = ["halo2_proofs/gpu-accelerated", "dep:icicle-runtime"]
icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
no-banner = []
no-update = []
@@ -268,17 +284,6 @@ mimalloc = ["dep:mimalloc"]
reusable-verifier = []

[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e", package = "halo2_proofs", branch= "ac/conditional-compilation-icicle2", features = [
"circuit-params", "mv-lookup"
] }

[patch.'https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e", package = "halo2_proofs", branch= "ac/conditional-compilation-icicle2", features = [
"circuit-params", "mv-lookup"
] }

[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }

@@ -296,5 +301,3 @@ opt-level = 3

[package.metadata.wasm-pack.profile.release]
wasm-opt = ["-O4", "--flexible-inline-max-function-size", "4294967295"]
@@ -4,6 +4,7 @@ use ezkl::circuit::*;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::srs::gen_srs;
use ezkl::pfsys::TranscriptType;
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
@@ -152,6 +153,8 @@ fn runcnvrl(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);

@@ -2,6 +2,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -119,6 +120,8 @@ fn rundot(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);

@@ -1,78 +1,53 @@
use criterion::{
criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion, PlotConfiguration,
Throughput,
};
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::einsum::circuit_params::SingleEinsumParams;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::srs::gen_srs;
use ezkl::pfsys::{create_keys, create_proof_circuit};
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, Value},
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::collections::HashMap;
use std::marker::PhantomData;

static mut LEN: usize = 4;
static mut K: usize = 15;
const K: usize = 16;

#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 2],
einsum_params: SingleEinsumParams<F>,
struct MyCircuit {
inputs: [ValTensor<Fr>; 2],
_marker: PhantomData<Fr>,
}

impl Circuit<Fr> for MyCircuit<Fr> {
impl Circuit<Fr> for MyCircuit {
type Config = BaseConfig<Fr>;
type FloorPlanner = V1;
type Params = SingleEinsumParams<Fr>;
type FloorPlanner = SimpleFloorPlanner;
type Params = ();

fn without_witnesses(&self) -> Self {
self.clone()
}

fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
let mut config = Self::Config::default();
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
let len = unsafe { LEN };

let mut equations = HashMap::new();
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
unsafe {
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 2, false);
}
let a = VarTensor::new_advice(cs, K, 1, len * len);

config
}
let b = VarTensor::new_advice(cs, K, 1, len * len);

fn params(&self) -> Self::Params {
SingleEinsumParams::<Fr>::new(
&self.einsum_params.equation,
&[
&self.inputs[0].get_inner().unwrap(),
&self.inputs[1].get_inner().unwrap(),
],
)
.unwrap()
}
let output = VarTensor::new_advice(cs, K, 1, (len + 1) * len);

fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unimplemented!("call configure_with_params instead")
Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE)
}

fn synthesize(
@@ -80,33 +55,16 @@ impl Circuit<Fr> for MyCircuit<Fr> {
mut config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let challenges = config
.einsums
.as_ref()
.ok_or(Error::Synthesis)?
.challenges()
.unwrap()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();

layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new_with_challenges(
region,
0,
1,
1024,
2,
challenges.clone(),
);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: self.einsum_params.equation.clone(),
equation: "ab,bc->ac".to_string(),
}),
)
.unwrap();
@@ -119,49 +77,41 @@ impl Circuit<Fr> for MyCircuit<Fr> {

fn runmatmul(c: &mut Criterion) {
let mut group = c.benchmark_group("accum_einsum_matmul");
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Linear));
group.sampling_mode(criterion::SamplingMode::Flat);
group.sample_size(10);
let len = 128;
unsafe {
LEN = len;
}
for k in 16..17 {
let params = unsafe {
K = k;
gen_srs::<KZGCommitmentScheme<_>>(K as u32)
let params = gen_srs::<KZGCommitmentScheme<_>>(17);
for &len in [4, 32].iter() {
unsafe {
LEN = len;
};

let mut a = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
a.reshape(&[len, len]).unwrap();

// parameters
let mut b = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
b.reshape(&[len, len]).unwrap();

let einsum_params = SingleEinsumParams::<Fr>::new("ij,jk->ik", &[&a, &b]).unwrap();

let circuit = MyCircuit {
inputs: [ValTensor::from(a), ValTensor::from(b)],
einsum_params,
_marker: PhantomData,
};

group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", k), &k, |b, &_| {
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit<Fr>>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});

let pk = create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit<Fr>>(&circuit, &params, false)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();

group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", k), &k, |b, &_| {
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit<Fr>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
@@ -174,6 +124,8 @@ fn runmatmul(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);

@@ -5,7 +5,7 @@ use ezkl::circuit::*;
use ezkl::circuit::lookup::LookupOp;
use ezkl::circuit::poly::PolyOp;
use ezkl::pfsys::create_proof_circuit;

use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -154,6 +154,8 @@ fn runmatmul(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);

@@ -5,7 +5,7 @@ use ezkl::circuit::lookup::LookupOp;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::table::Range;
use ezkl::pfsys::create_proof_circuit;

use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -157,6 +157,8 @@ fn runmatmul(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);

@@ -2,7 +2,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit;

use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -116,6 +116,8 @@ fn runsum(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);

@@ -4,7 +4,7 @@ use ezkl::circuit::*;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::srs::gen_srs;

use ezkl::pfsys::TranscriptType;
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
@@ -131,6 +131,8 @@ fn runsumpool(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);

@@ -2,7 +2,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit;

use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -118,6 +118,8 @@ fn runadd(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);

@@ -3,7 +3,7 @@ use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit;

use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -117,6 +117,8 @@ fn runpow(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);

@@ -8,7 +8,7 @@ use ezkl::circuit::*;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::srs::gen_srs;

use ezkl::pfsys::TranscriptType;
use ezkl::tensor::*;
use halo2_proofs::circuit::Value;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -104,6 +104,8 @@ fn runposeidon(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);

@@ -4,7 +4,7 @@ use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{BaseConfig as Config, CheckMode};
use ezkl::fieldutils::IntegerRep;
use ezkl::pfsys::create_proof_circuit;

use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -130,6 +130,8 @@ fn runrelu(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);

@@ -4,7 +4,7 @@ use ezkl::circuit::table::Range;
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
use ezkl::fieldutils::IntegerRep;
use ezkl::pfsys::create_proof_circuit;

use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -124,6 +124,8 @@ fn runrelu(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
4 build.rs
@@ -1,3 +1,7 @@
fn main() {
    if cfg!(feature = "ios-bindings-test") {
        println!("cargo::rustc-env=UNIFFI_CARGO_BUILD_EXTRA_ARGS=--features=ios-bindings --no-default-features");
    }

    println!("cargo::rerun-if-changed=build.rs");
}
@@ -1,7 +1,7 @@
import ezkl

project = 'ezkl'
release = '23.0.2'
release = '22.2.0'
version = release
@@ -1,171 +0,0 @@
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
    arithmetic::Field,
    circuit::{Layouter, Value},
    plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::Fr;
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use std::collections::HashMap;
use std::marker::PhantomData;

const K: usize = 13;

#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
    inputs: [ValTensor<F>; 2],
    einsum: Einsum<F>,
}

#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
    equation: String,
    input_axes_to_dims: HashMap<char, usize>,
    _marker: PhantomData<F>,
}

impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
    pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
        let mut eq = equation.split("->");
        let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
        let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();

        // Check that the number of inputs matches the number of inputs in the equation
        if inputs.len() != inputs_eq.len() {
            return Err(TensorError::DimMismatch("einsum".to_string()).into());
        }

        let mut input_axes_to_dims = HashMap::new();
        for (i, input) in inputs.iter().enumerate() {
            for j in 0..inputs_eq[i].len() {
                let c = inputs_eq[i]
                    .chars()
                    .nth(j)
                    .ok_or(CircuitError::InvalidEinsum)?;
                if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
                    e.insert(input.dims()[j]);
                } else if input_axes_to_dims[&c] != input.dims()[j] {
                    return Err(TensorError::DimMismatch("einsum".to_string()).into());
                }
            }
        }

        Ok(Self {
            equation: equation.to_owned(),
            input_axes_to_dims,
            _marker: PhantomData,
        })
    }
}

impl Circuit<Fr> for MyCircuit<Fr> {
    type Config = BaseConfig<Fr>;
    type FloorPlanner = V1;
    type Params = Einsum<Fr>;

    fn without_witnesses(&self) -> Self {
        self.clone()
    }

    fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
        let mut config = Self::Config::default();

        let mut equations = HashMap::new();
        equations.insert((0, params.equation), params.input_axes_to_dims);
        let analysis = analyze_einsum_usage(&equations).unwrap();
        let num_einsum_inner_cols = 1;
        config
            .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
            .unwrap();

        config
    }

    fn params(&self) -> Self::Params {
        Einsum::<Fr>::new(
            &self.einsum.equation,
            &[
                &self.inputs[0].get_inner().unwrap(),
                &self.inputs[1].get_inner().unwrap(),
            ],
        )
        .unwrap()
    }

    fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
        unimplemented!("call configure_with_params instead")
    }

    fn synthesize(
        &self,
        mut config: Self::Config,
        mut layouter: impl Layouter<Fr>,
    ) -> Result<(), Error> {
        let challenges = config
            .einsums
            .as_ref()
            .ok_or(Error::Synthesis)?
            .challenges()
            .unwrap()
            .iter()
            .map(|c| layouter.get_challenge(*c))
            .collect_vec();

        layouter.assign_region(
            || "",
            |region| {
                let mut region = region::RegionCtx::new_with_challenges(
                    region,
                    0,
                    1,
                    1024,
                    2,
                    challenges.clone(),
                );
                config
                    .layout(
                        &mut region,
                        &self.inputs.iter().collect_vec(),
                        Box::new(PolyOp::Einsum {
                            equation: self.einsum.equation.clone(),
                        }),
                    )
                    .unwrap();
                Ok(())
            },
        )?;
        Ok(())
    }
}

fn runmatmul() {
    let len = 64;

    let mut a = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
    a.reshape(&[len, len]).unwrap();

    // parameters
    let mut b = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
    b.reshape(&[len, len]).unwrap();

    let einsum = Einsum::<Fr>::new("ij,jk->ik", &[&a, &b]).unwrap();

    let circuit = MyCircuit {
        inputs: [ValTensor::from(a), ValTensor::from(b)],
        einsum,
    };

    let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
    mock_prover.assert_satisfied();
}

pub fn main() {
    runmatmul()
}
@@ -1,179 +0,0 @@
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
    arithmetic::Field,
    circuit::{Layouter, Value},
    plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::Fr;
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use std::collections::HashMap;
use std::marker::PhantomData;

static mut LEN: usize = 4;
const K: usize = 11;

#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
    inputs: [ValTensor<F>; 2],
    einsum: Einsum<F>,
}

#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
    equation: String,
    input_axes_to_dims: HashMap<char, usize>,
    _marker: PhantomData<F>,
}

impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
    pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
        let mut eq = equation.split("->");
        let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
        let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();

        // Check that the number of inputs matches the number of inputs in the equation
        if inputs.len() != inputs_eq.len() {
            return Err(TensorError::DimMismatch("einsum".to_string()).into());
        }

        let mut input_axes_to_dims = HashMap::new();
        for (i, input) in inputs.iter().enumerate() {
            for j in 0..inputs_eq[i].len() {
                let c = inputs_eq[i]
                    .chars()
                    .nth(j)
                    .ok_or(CircuitError::InvalidEinsum)?;
                if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
                    e.insert(input.dims()[j]);
                } else if input_axes_to_dims[&c] != input.dims()[j] {
                    return Err(TensorError::DimMismatch("einsum".to_string()).into());
                }
            }
        }

        Ok(Self {
            equation: equation.to_owned(),
            input_axes_to_dims,
            _marker: PhantomData,
        })
    }
}

impl Circuit<Fr> for MyCircuit<Fr> {
    type Config = BaseConfig<Fr>;
    type FloorPlanner = V1;
    type Params = Einsum<Fr>;

    fn without_witnesses(&self) -> Self {
        self.clone()
    }

    fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
        let len = unsafe { LEN };

        let a = VarTensor::new_advice(cs, K, 1, len);
        let b = VarTensor::new_advice(cs, K, 1, len);
        let output = VarTensor::new_advice(cs, K, 1, len);

        let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE);

        let mut equations = HashMap::new();
        equations.insert((0, params.equation), params.input_axes_to_dims);
        let analysis = analyze_einsum_usage(&equations).unwrap();
        let num_einsum_inner_cols = 1;
        config
            .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
            .unwrap();

        config
    }

    fn params(&self) -> Self::Params {
        Einsum::<Fr>::new(
            &self.einsum.equation,
            &[
                &self.inputs[0].get_inner().unwrap(),
                &self.inputs[1].get_inner().unwrap(),
            ],
        )
        .unwrap()
    }

    fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
        unimplemented!("call configure_with_params instead")
    }

    fn synthesize(
        &self,
        mut config: Self::Config,
        mut layouter: impl Layouter<Fr>,
    ) -> Result<(), Error> {
        let challenges = config
            .einsums
            .as_ref()
            .ok_or(Error::Synthesis)?
            .challenges()
            .unwrap()
            .iter()
            .map(|c| layouter.get_challenge(*c))
            .collect_vec();

        layouter.assign_region(
            || "",
            |region| {
                let mut region = region::RegionCtx::new_with_challenges(
                    region,
                    0,
                    1,
                    1024,
                    2,
                    challenges.clone(),
                );
                config
                    .layout(
                        &mut region,
                        &self.inputs.iter().collect_vec(),
                        Box::new(PolyOp::Einsum {
                            equation: self.einsum.equation.clone(),
                        }),
                    )
                    .unwrap();
                Ok(())
            },
        )?;
        Ok(())
    }
}

fn runbatchmatmul() {
    let batch_size = 5;
    let len = 12;

    let mut a = Tensor::from((0..batch_size * len * len).map(|_| Value::known(Fr::random(OsRng))));
    a.reshape(&[batch_size, len, len]).unwrap();

    // parameters
    let mut b = Tensor::from((0..batch_size * len * len).map(|_| Value::known(Fr::random(OsRng))));
    b.reshape(&[batch_size, len, len]).unwrap();

    let einsum = Einsum::<Fr>::new("ijk,ikl->ijl", &[&a, &b]).unwrap();

    let circuit = MyCircuit {
        inputs: [ValTensor::from(a), ValTensor::from(b)],
        einsum,
    };

    let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
    mock_prover.assert_satisfied();
}

pub fn main() {
    runbatchmatmul()
}
@@ -866,7 +866,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
@@ -879,7 +879,6 @@
"run_args.input_visibility = \"private\"\n",
"run_args.param_visibility = \"fixed\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.disable_freivalds = True\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"\n",
"\n",
@@ -1143,4 +1142,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
@@ -253,6 +253,8 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
@@ -301,4 +303,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
@@ -546,7 +546,7 @@
"\n",
"proof_path = os.path.join('proof.json')\n",
"\n",
"proof = ezkl.prove(proof_path=proof_path)\n",
"proof = ezkl.prove(proof_type=\"single\", proof_path=proof_path)\n",
"\n",
"print(proof)\n",
"assert os.path.isfile(proof_path)"
@@ -736,4 +736,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}
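Taken together, the notebook hunks above and below show the updated Python proving API: `ezkl.prove` now takes an explicit proof type, with "single" for standalone proofs and "for-aggr" for proofs destined for aggregation. A minimal sketch of the updated call, assuming the `proof.json` layout these notebooks use:

import os
import ezkl  # the ezkl Python bindings, as imported in the notebooks

proof_path = os.path.join('proof.json')

# "single" produces a standalone proof; "for-aggr" (used in the split-proof
# notebooks further below) produces one intended for later aggregation.
proof = ezkl.prove(proof_type="single", proof_path=proof_path)

print(proof)
assert os.path.isfile(proof_path)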
@@ -574,7 +574,7 @@
"\n",
"proof_path = os.path.join('proof.json')\n",
"\n",
"proof = ezkl.prove(proof_path=proof_path)\n",
"proof = ezkl.prove(proof_type=\"single\", proof_path=proof_path)\n",
"\n",
"print(proof)\n",
"assert os.path.isfile(proof_path)"
@@ -768,4 +768,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}
@@ -54,7 +54,7 @@
" gip_run_args.param_scale = 19\n",
" gip_run_args.logrows = 8\n",
" run_args = ezkl.gen_settings(py_run_args=gip_run_args)\n",
" await ezkl.get_srs()\n",
" await ezkl.get_srs(commitment=ezkl.PyCommitments.KZG)\n",
" ezkl.compile_circuit()\n",
" res = ezkl.gen_witness()\n",
" print(res)\n",
@@ -127,4 +127,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
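The `get_srs` hunks make the commitment scheme explicit rather than leaving it implicit. A minimal sketch combining the two variants the notebooks show, assuming the bindings expose `PyCommitments` as used there:

import ezkl

# Async notebook flow, SRS size taken from settings:
#   await ezkl.get_srs(commitment=ezkl.PyCommitments.KZG)

# Synchronous flow with an explicit number of rows (logrows values such as
# 15 or 18 appear in these notebooks; pick the one matching your settings):
ezkl.get_srs(logrows=18, commitment=ezkl.PyCommitments.KZG)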
@@ -105,7 +105,7 @@
"\n",
"class GCNConv(Module):\n",
" def __init__(self, in_channels, out_channels):\n",
" super(GCNConv, self).__init__() \n",
" super(GCNConv, self).__init__() # \"Add\" aggregation.\n",
" self.lin = torch.nn.Linear(in_channels, out_channels)\n",
"\n",
" self.reset_parameters()\n",
@@ -563,6 +563,7 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
@@ -624,4 +625,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
@@ -286,6 +286,8 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
@@ -339,4 +341,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
@@ -248,7 +248,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
@@ -263,6 +263,8 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
@@ -311,4 +313,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
@@ -368,7 +368,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -236,7 +236,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -240,7 +240,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -358,7 +358,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -278,7 +278,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -232,7 +232,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -442,7 +442,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -227,7 +227,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -252,7 +252,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -422,7 +422,7 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -378,7 +378,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -301,7 +301,7 @@
"run_args.param_scale = 0\n",
"run_args.logrows = 18\n",
"\n",
"ezkl.get_srs(logrows=run_args.logrows, )\n"
"ezkl.get_srs(logrows=run_args.logrows, commitment=ezkl.PyCommitments.KZG)\n"
]
},
{
@@ -399,6 +399,7 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"for-aggr\",\n",
" )\n",
"\n",
" print(res)\n",
@@ -437,6 +438,28 @@
" print(\"----- proving split \"+str(i))\n",
" prove_model(i)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also mock aggregate the split proofs into a single proof. This is useful if you want to verify the proof on chain at a lower cost. Here we mock aggregate the proofs to save time. You can use other notebooks to see how to aggregate in full ! "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# now mock aggregate the proofs\n",
"# proofs = []\n",
"# for i in range(3):\n",
"# proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
"# proofs.append(proof_path)\n",
"\n",
"# ezkl.mock_aggregate(proofs, logrows=26, split_proofs = True)"
]
}
],
"metadata": {
@@ -461,4 +484,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
@@ -303,7 +303,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
@@ -543,7 +543,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -939,7 +939,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -234,7 +234,7 @@
"run_args.input_scale = 2\n",
"run_args.logrows = 15\n",
"\n",
"ezkl.get_srs(logrows=run_args.logrows, )"
"ezkl.get_srs(logrows=run_args.logrows, commitment=ezkl.PyCommitments.KZG)"
]
},
{
@@ -330,6 +330,7 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"for-aggr\",\n",
" )\n",
"\n",
" print(res)\n",
@@ -425,6 +426,28 @@
"for i in range(2):\n",
" prove_model(i)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also mock aggregate the split proofs into a single proof. This is useful if you want to verify the proof on chain at a lower cost. Here we mock aggregate the proofs to save time. You can use other notebooks to see how to aggregate in full ! "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# now mock aggregate the proofs\n",
"proofs = []\n",
"for i in range(2):\n",
" proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
" proofs.append(proof_path)\n",
"\n",
"ezkl.mock_aggregate(proofs, logrows=22, split_proofs = True)"
]
}
],
"metadata": {
@@ -449,4 +472,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
@@ -260,7 +260,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -173,7 +173,7 @@
" assert os.path.isfile(settings_path)\n",
"\n",
" # GENERATE A PROOF\n",
" res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path)\n",
" res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, \"single\")\n",
" assert os.path.isfile(proof_path)\n",
"\n",
" res = await ezkl.create_evm_verifier(vk_path, settings_path, sol_code_path, abi_path, reusable=True)\n",

@@ -384,7 +384,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
@@ -411,7 +411,7 @@
" pk_path,\n",
" proof_path_faulty,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
@@ -438,7 +438,7 @@
" pk_path,\n",
" proof_path_truthy,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

407
examples/notebooks/simple_demo_aggregated_proofs.ipynb
Normal file
@@ -0,0 +1,407 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## EZKL Jupyter Notebook Demo (Aggregated Proofs) \n",
"\n",
"Demonstrates how to use EZKL with aggregated proofs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"\n",
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import torch\n",
"\n",
"\n",
"# Defines the model\n",
"# we got convs, we got relu, we got linear layers\n",
"# What else could one want ????\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
"\n",
" self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=5, stride=2)\n",
" self.conv2 = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=5, stride=2)\n",
"\n",
" self.relu = nn.ReLU()\n",
"\n",
" self.d1 = nn.Linear(48, 48)\n",
" self.d2 = nn.Linear(48, 10)\n",
"\n",
" def forward(self, x):\n",
" # 32x1x28x28 => 32x32x26x26\n",
" x = self.conv1(x)\n",
" x = self.relu(x)\n",
" x = self.conv2(x)\n",
" x = self.relu(x)\n",
"\n",
" # flatten => 32 x (32*26*26)\n",
" x = x.flatten(start_dim = 1)\n",
"\n",
" # 32 x (32*26*26) => 32x128\n",
" x = self.d1(x)\n",
" x = self.relu(x)\n",
"\n",
" # logits => 32x10\n",
" logits = self.d2(x)\n",
"\n",
" return logits\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"# Train the model as you like here (skipped for brevity)\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"proof_path = os.path.join('test.pf')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.srs')\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')\n",
"aggregate_proof_path = os.path.join('aggr.pf')\n",
"aggregate_vk_path = os.path.join('aggr.vk')\n",
"aggregate_pk_path = os.path.join('aggr.pk')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"shape = [1, 28, 28]\n",
"# After training, export to onnx (network.onnx) and create a data file (input.json)\n",
"x = 0.1*torch.rand(1,*shape, requires_grad=True)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" model_path, # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump( data, open(data_path, 'w' ))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (torch.rand(20, *shape, requires_grad=True).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"for-aggr\", # IMPORTANT NOTE: To produce an aggregated EVM proof you will want to use poseidon for the smaller proofs\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0832b909",
"metadata": {},
"outputs": [],
"source": [
"# Generate a larger SRS. This is needed for the aggregated proof\n",
"\n",
"res = await ezkl.get_srs(settings_path=None, logrows=21, commitment=ezkl.PyCommitments.KZG)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c5a64be6",
"metadata": {},
"outputs": [],
"source": [
"# Run mock aggregate to check whether the proof works\n",
"# Use mock to check for validity as it takes a shorter time to check compared to a full aggregated proof\n",
"\n",
"res = ezkl.mock_aggregate([proof_path], 21)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fee8acc6",
"metadata": {},
"outputs": [],
"source": [
"# Setup the vk and pk for aggregate\n",
"res = ezkl.setup_aggregate(\n",
" [proof_path],\n",
" aggregate_vk_path,\n",
" aggregate_pk_path,\n",
" 21\n",
")\n",
"\n",
"assert os.path.isfile(aggregate_vk_path)\n",
"assert os.path.isfile(aggregate_pk_path)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "171702d3",
"metadata": {},
"outputs": [],
"source": [
"# Run aggregate proof\n",
"res = ezkl.aggregate(\n",
" [proof_path],\n",
" aggregate_proof_path,\n",
" aggregate_pk_path,\n",
" \"evm\",\n",
" 21,\n",
" \"safe\"\n",
")\n",
"\n",
"assert os.path.isfile(aggregate_proof_path)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "671dfdd5",
"metadata": {},
"outputs": [],
"source": [
"# Check if the proof is valid\n",
"res = ezkl.verify_aggr(\n",
" aggregate_proof_path,\n",
" aggregate_vk_path,\n",
" 21,\n",
")\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "50eba2f4",
"metadata": {},
"outputs": [],
"source": [
"# Create a smart contract verifier for the aggregated proof\n",
"\n",
"sol_code_path = os.path.join(\"Verifier.sol\")\n",
"abi_path = os.path.join(\"Verifier_ABI.json\")\n",
"\n",
"res = await ezkl.create_evm_verifier_aggr(\n",
" [settings_path],\n",
" aggregate_vk_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" logrows=21)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@@ -255,7 +255,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -253,7 +253,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -254,7 +254,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -233,7 +233,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -323,7 +323,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"assert os.path.isfile(proof_path)\n",
@@ -442,7 +442,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -271,7 +271,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -236,7 +236,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -707,7 +707,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -596,7 +596,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -580,7 +580,7 @@
" compiled_filename,\n",
" pk_path,\n",
" proof_path,\n",
" ",
" \"single\",\n",
" )\n",
"\n",
"\n",

@@ -759,7 +759,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -277,7 +277,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" ",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",

@@ -1,79 +0,0 @@
from torch import nn
import torch.nn.init as init
import torch
import json

N = 100

class Model(nn.Module):
    def __init__(self, inplace=False):
        super(Model, self).__init__()

        self.aff1 = nn.Linear(N,N)
        self.aff2 = nn.Linear(N,N)
        self.aff3 = nn.Linear(N,N)
        self.aff4 = nn.Linear(N,N)
        self.aff5 = nn.Linear(N,N)
        self.aff6 = nn.Linear(N,N)
        self.aff7 = nn.Linear(N,N)
        self.aff8 = nn.Linear(N,N)
        self.aff9 = nn.Linear(N,N)
        self.relu = nn.ReLU()
        self._initialize_weights()

    def forward(self, x):
        # concat 10 x along dim 0
        x = x.repeat(10, 1)
        x = self.aff1(x)
        x = self.relu(x)
        x = self.aff2(x)
        x = self.relu(x)
        x = self.aff3(x)
        x = self.relu(x)
        x = self.aff4(x)
        x = self.relu(x)
        x = self.aff5(x)
        x = self.relu(x)
        x = self.aff6(x)
        x = self.relu(x)
        x = self.aff7(x)
        x = self.relu(x)
        x = self.aff8(x)
        x = self.relu(x)
        x = self.aff9(x)
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.aff1.weight)

model = Model()

# Flips the neural net into inference mode
model.eval()
model.to('cpu')

x = torch.randn(1, N)
# Export the model
torch.onnx.export(model,  # model being run
                  # model input (or a tuple for multiple inputs)
                  x,
                  # where to save the model (can be a file or file-like object)
                  "network.onnx",
                  export_params=True,  # store the trained parameter weights inside the model file
                  opset_version=12,  # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],  # the model's input names
                  output_names=['output'],  # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},  # variable length axes
                                'output': {0: 'batch_size'}})

data_array = ((x).detach().numpy()).reshape([-1]).tolist()

data_json = dict(input_data=[data_array])

print(data_json)

# Serialize data into file:
json.dump(data_json, open("input.json", 'w'))
@@ -1 +0,0 @@
{"input_data": [[0.33088353276252747, -0.8819183707237244, 1.245591163635254, -1.807046890258789, 1.9922369718551636, -0.3360576629638672, 0.4529011845588684, -0.3590165674686432, 0.08356846123933792, 0.5126393437385559, 0.44627535343170166, 1.4916497468948364, 0.49731069803237915, -0.9748706817626953, -0.4923185408115387, 1.3548223972320557, 0.2306872010231018, 1.125955581665039, -1.7063908576965332, 0.3777385354042053, -2.7988760471343994, -1.1846797466278076, 0.7473157048225403, 1.490412950515747, 0.017497723922133446, 2.113945245742798, -1.2141249179840088, -0.16120357811450958, 0.021127669140696526, 0.7207374572753906, -1.369688868522644, -0.7369781732559204, -0.630584180355072, -0.4520200788974762, 0.29123976826667786, 0.6334688067436218, -0.869332492351532, -1.258501648902893, 0.3012596666812897, -0.5507447123527527, 0.669975757598877, 0.15088629722595215, -0.1050339788198471, 0.5505334138870239, -0.1287376880645752, -1.4297826290130615, -0.01703289896249771, -1.2296998500823975, 0.5122153162956238, -0.16924428939819336, -0.415036678314209, -1.1979341506958008, 0.05831022188067436, -0.4411357045173645, 2.0713791847229004, 1.4611141681671143, -0.9357407093048096, -0.333297461271286, -0.676478385925293, 1.390028476715088, -0.05827632546424866, 1.535687804222107, 0.3060210347175598, -0.03171076253056526, -0.614985466003418, 1.2040390968322754, 0.31318482756614685, -1.2134959697723389, 0.13110508024692535, -1.4880926609039307, 1.7007993459701538, 1.5412729978561401, 0.09260450303554535, 0.7649128437042236, -0.5009126663208008, -0.5356241464614868, -0.069572813808918, -0.011717632412910461, 0.21314217150211334, -0.1985170543193817, -0.0223808903247118, 1.2128918170928955, 0.8334696888923645, 1.9029873609542847, -0.11491120606660843, -0.10303237289190292, -0.2467050403356552, 1.557223916053772, -1.1108328104019165, -0.9065343141555786, -0.2271333783864975, 0.6959827542304993, -0.48698121309280396, 0.5689510703086853, 1.115319013595581, -0.8907430768013, -0.24722427129745483, -0.7437837719917297, 0.6742106676101685, -1.7830933332443237]]}
Binary file not shown.
@@ -1,23 +1,23 @@
## The worm

This is an onnx file for a [WormVAE](https://github.com/TuragaLab/wormvae?tab=readme-ov-file) model, which is a VAE / latent-space representation of the C. elegans connectome.

The model "is a large-scale latent variable model with a very high-dimensional latent space
consisting of voltage dynamics of 300 neurons over 5 minutes of time at the simulation frequency
of 160 Hz. The generative model for these latent variables is described by stochastic differential
equations modeling the nonlinear dynamics of the network activity." (see [here](https://openreview.net/pdf?id=CJzi3dRlJE-)).

In effect this is a generative model for a worm's voltage dynamics, which can be used to generate new worm-like voltage dynamics given previous connectome state.

Using ezkl you can create a zk circuit equivalent to the wormvae model, allowing you to "prove" execution of the worm model. If you're feeling particularly adventurous, you can also use the zk circuit to generate new worm-state that can be verified on chain.

To do so you'll first want to fetch the files using git-lfs (as the onnx file is too large to be stored in git).

```bash
git lfs fetch --all
```

You'll then want to use the usual ezkl loop to generate the zk circuit. We recommend using fixed visibility for the model parameters, as the model is quite large and this will prune the circuit significantly.

```bash
ezkl gen-settings --param-visibility=fixed
@@ -28,7 +28,17 @@ ezkl gen-witness
ezkl prove
```

You might also need to aggregate the proof to get it to fit on chain.

```bash
ezkl aggregate
```

You can then create a smart contract that verifies this aggregate proof

```bash
ezkl create-evm-verifier-aggr
```

This can then be deployed on the chain of your choice; a Python-bindings sketch of that last step follows below.

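For reference, the deployment step is also exposed through the Python bindings (the `deploy_evm` signature is documented in `ezkl.pyi` later in this diff). A minimal sketch, assuming illustrative file names, a local RPC endpoint, and a `contract_type` value — check these against your ezkl version before relying on them:

```python
import ezkl

# Sketch: deploy the aggregate verifier produced above.
# All paths, the RPC URL, and the contract_type string are assumptions.
# `await` works at the top level of a notebook; wrap in asyncio.run otherwise.
res = await ezkl.deploy_evm(
    "addr.txt",               # addr_path: file that receives the deployed address
    "Verifier.sol",           # sol_code_path: the Solidity verifier to deploy
    "http://127.0.0.1:8545",  # rpc_url: assumed local node
    "verifier",               # contract_type: assumed value
    1,                        # optimizer_runs
    None,                     # private_key: supply one for a real deployment
)
```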
@@ -1,182 +0,0 @@
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
    arithmetic::Field,
    circuit::{Layouter, Value},
    plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::Fr;
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use std::collections::HashMap;
use std::marker::PhantomData;

static mut LEN: usize = 4;
const K: usize = 11;

#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
    inputs: [ValTensor<F>; 2],
    einsum: Einsum<F>,
}

#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
    equation: String,
    input_axes_to_dims: HashMap<char, usize>,
    _marker: PhantomData<F>,
}

impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
    pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
        let mut eq = equation.split("->");
        let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
        let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();

        // Check that the number of inputs matches the number of inputs in the equation
        if inputs.len() != inputs_eq.len() {
            return Err(TensorError::DimMismatch("einsum".to_string()).into());
        }

        let mut input_axes_to_dims = HashMap::new();
        for (i, input) in inputs.iter().enumerate() {
            for j in 0..inputs_eq[i].len() {
                let c = inputs_eq[i]
                    .chars()
                    .nth(j)
                    .ok_or(CircuitError::InvalidEinsum)?;
                if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
                    e.insert(input.dims()[j]);
                } else if input_axes_to_dims[&c] != input.dims()[j] {
                    return Err(TensorError::DimMismatch("einsum".to_string()).into());
                }
            }
        }

        Ok(Self {
            equation: equation.to_owned(),
            input_axes_to_dims,
            _marker: PhantomData,
        })
    }
}

impl Circuit<Fr> for MyCircuit<Fr> {
    type Config = BaseConfig<Fr>;
    type FloorPlanner = V1;
    type Params = Einsum<Fr>;

    fn without_witnesses(&self) -> Self {
        self.clone()
    }

    fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
        let len = unsafe { LEN };

        let a = VarTensor::new_advice(cs, K, 1, len);
        let b = VarTensor::new_advice(cs, K, 1, len);
        let output = VarTensor::new_advice(cs, K, 1, len);

        let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE);

        let mut equations = HashMap::new();
        equations.insert((0, params.equation), params.input_axes_to_dims);
        let analysis = analyze_einsum_usage(&equations).unwrap();
        let num_einsum_inner_cols = 2;
        config
            .configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
            .unwrap();
        let _constant = VarTensor::constant_cols(cs, K, 2, false);

        config
    }

    fn params(&self) -> Self::Params {
        Einsum::<Fr>::new(
            &self.einsum.equation,
            &[
                &self.inputs[0].get_inner().unwrap(),
                &self.inputs[1].get_inner().unwrap(),
            ],
        )
        .unwrap()
    }

    fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
        unimplemented!("call configure_with_params instead")
    }

    fn synthesize(
        &self,
        mut config: Self::Config,
        mut layouter: impl Layouter<Fr>,
    ) -> Result<(), Error> {
        let challenges = config
            .einsums
            .as_ref()
            .ok_or(Error::Synthesis)?
            .challenges()
            .unwrap()
            .iter()
            .map(|c| layouter.get_challenge(*c))
            .collect_vec();

        layouter.assign_region(
            || "",
            |region| {
                let mut region = region::RegionCtx::new_with_challenges(
                    region,
                    0,
                    1,
                    1024,
                    2,
                    challenges.clone(),
                );
                config
                    .layout(
                        &mut region,
                        &self.inputs.iter().collect_vec(),
                        Box::new(PolyOp::Einsum {
                            equation: self.einsum.equation.clone(),
                        }),
                    )
                    .unwrap();
                Ok(())
            },
        )?;
        Ok(())
    }
}

fn runmatmul() {
    let i = 10;
    let n = 10;
    let j = 40;
    let k = 10;

    let mut a = Tensor::from((0..i * n * j).map(|_| Value::known(Fr::random(OsRng))));
    a.reshape(&[i, n, j]).unwrap();

    // parameters
    let mut b = Tensor::from((0..j * k).map(|_| Value::known(Fr::random(OsRng))));
    b.reshape(&[j, k]).unwrap();

    let einsum = Einsum::<Fr>::new("inj,jk->ik", &[&a, &b]).unwrap();

    let circuit = MyCircuit {
        inputs: [ValTensor::from(a), ValTensor::from(b)],
        einsum,
    };

    let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
    mock_prover.assert_satisfied();
}

pub fn main() {
    runmatmul()
}
573
ezkl.pyi
@@ -10,27 +10,31 @@ class PyG1:
    r"""
    pyclass containing the struct used for G1, this is mostly a helper class
    """

    ...

class PyG1Affine:
    r"""
    pyclass containing the struct used for G1
    """

    ...

class PyRunArgs:
    r"""
    Python class containing the struct used for run_args

    Returns
    -------
    PyRunArgs
    """

    ...

class PyCommitments(Enum):
    r"""
    pyclass representing an enum, denoting the type of commitment
    """
    KZG = auto()
    IPA = auto()

class PyInputType(Enum):
    Bool = auto()
    F16 = auto()
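# Usage sketch for the classes above, written from a consumer's perspective
# (`import ezkl`); the attributes shown are the ones the notebooks in this
# diff actually set, and the values are illustrative assumptions:
#
#     run_args = ezkl.PyRunArgs()
#     run_args.input_scale = 2
#     run_args.param_scale = 0
#     run_args.logrows = 15
#     commitment = ezkl.PyCommitments.KZG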
@@ -43,19 +47,57 @@ class PyTestDataSource(Enum):
    r"""
    pyclass representing an enum
    """

    File = auto()
    OnChain = auto()

def buffer_to_felts(buffer: typing.Sequence[int]) -> list[str]:
def aggregate(aggregation_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],proof_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,transcript:str,logrows:int,check_mode:str,split_proofs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],commitment:PyCommitments) -> bool:
    r"""
    Creates an aggregated proof

    Arguments
    ---------
    aggregation_snarks: list[str]
        List of paths to the various proofs

    proof_path: str
        Path to output the aggregated proof

    vk_path: str
        Path to the VK file

    transcript:
        Proof transcript type to be used. `evm` used by default. `poseidon` is also supported

    logrows:
        Logrows used for aggregation circuit

    check_mode: str
        Run sanity checks during calculations. Accepts `safe` or `unsafe`

    split_proofs: bool
        Whether the accumulated proofs are segments of a larger circuit

    srs_path: str
        Path to the SRS used

    commitment: str
        Accepts "kzg" or "ipa"

    Returns
    -------
    bool
    """
    ...

def buffer_to_felts(buffer:typing.Sequence[int]) -> list[str]:
    r"""
    Converts a buffer to vector of field elements

    Arguments
    -------
    buffer: list[int]
        List of integers representing a buffer

    Returns
    -------
    list[str]
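# Usage sketch for `aggregate` (illustrative paths; the input proofs must have
# been generated with proof_type "for-aggr", as in the aggregated-proofs
# notebook earlier in this diff; argument order follows the stub above):
#
#     res = ezkl.aggregate(
#         ["proof_split_0.json", "proof_split_1.json"],  # aggregation_snarks
#         "aggr.pf",                 # proof_path
#         "aggr.vk",                 # vk_path
#         "evm",                     # transcript
#         21,                        # logrows
#         "safe",                    # check_mode
#         True,                      # split_proofs: segments of one larger circuit
#         None,                      # srs_path
#         ezkl.PyCommitments.KZG,    # commitment
#     )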
@@ -63,175 +105,173 @@ def buffer_to_felts(buffer: typing.Sequence[int]) -> list[str]:
    """
    ...

def calibrate_settings(
    data: str | os.PathLike | pathlib.Path,
    model: str | os.PathLike | pathlib.Path,
    settings: str | os.PathLike | pathlib.Path,
    target: str,
    lookup_safety_margin: float,
    scales: typing.Optional[typing.Sequence[int]],
    scale_rebase_multiplier: typing.Sequence[int],
    max_logrows: typing.Optional[int],
) -> typing.Any:
def calibrate_settings(data:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,settings:str | os.PathLike | pathlib.Path,target:str,lookup_safety_margin:float,scales:typing.Optional[typing.Sequence[int]],scale_rebase_multiplier:typing.Sequence[int],max_logrows:typing.Optional[int]) -> typing.Any:
    r"""
    Calibrates the circuit settings

    Arguments
    ---------
    data: str
        Path to the calibration data

    model: str
        Path to the onnx file

    settings: str
        Path to the settings file

    lookup_safety_margin: int
        The lookup safety margin to use for calibration. If the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. Larger = safer but slower.

    scales: list[int]
        Optional scales to specifically try for calibration

    scale_rebase_multiplier: list[int]
        Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale.

    max_logrows: int
        Optional max logrows to use for calibration

    Returns
    -------
    bool
    """
    ...

def compile_circuit(
    model: str | os.PathLike | pathlib.Path,
    compiled_circuit: str | os.PathLike | pathlib.Path,
    settings_path: str | os.PathLike | pathlib.Path,
) -> bool:
def compile_circuit(model:str | os.PathLike | pathlib.Path,compiled_circuit:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path) -> bool:
    r"""
    Compiles the circuit for use in other steps

    Arguments
    ---------
    model: str
        Path to the onnx model file

    compiled_circuit: str
        Path to output the compiled circuit

    settings_path: str
        Path to the settings file

    Returns
    -------
    bool
    """
    ...

def create_evm_verifier(
    vk_path: str | os.PathLike | pathlib.Path,
    settings_path: str | os.PathLike | pathlib.Path,
    sol_code_path: str | os.PathLike | pathlib.Path,
    abi_path: str | os.PathLike | pathlib.Path,
    srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
    reusable: bool,
) -> typing.Any:
def create_evm_verifier(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reusable:bool) -> typing.Any:
    r"""
    Creates an EVM-compatible verifier; you will need solc installed in your environment to run this

    Arguments
    ---------
    vk_path: str
        The path to the verification key file

    settings_path: str
        The path to the settings file

    sol_code_path: str
        The path at which to create the Solidity verifier

    abi_path: str
        The path at which to create the ABI for the Solidity verifier

    srs_path: str
        The path to the SRS file

    reusable: bool
        Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately, which you can generate using the create_evm_vka command

    Returns
    -------
    bool
    """
    ...

def create_evm_vka(
    vk_path: str | os.PathLike | pathlib.Path,
    settings_path: str | os.PathLike | pathlib.Path,
    vka_path: str | os.PathLike | pathlib.Path,
    srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
) -> typing.Any:
def create_evm_verifier_aggr(aggregation_settings:typing.Sequence[str | os.PathLike | pathlib.Path],vk_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,logrows:int,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reusable:bool) -> typing.Any:
    r"""
    Creates an EVM-compatible aggregate verifier; you will need solc installed in your environment to run this

    Arguments
    ---------
    aggregation_settings: str
        Path to the settings file

    vk_path: str
        The path to load the desired verification key file

    sol_code_path: str
        The path to the Solidity code

    abi_path: str
        The path to output the Solidity verifier ABI

    logrows: int
        Number of logrows used during aggregated setup

    srs_path: str
        The path to the SRS file

    reusable: bool
        Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately, which you can generate using the create_evm_vka command

    Returns
    -------
    bool
    """
    ...

def create_evm_vka(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,vka_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
    r"""
    Creates an EVM VK artifact. This command generates a VK with circuit-specific metadata encoded in memory for use by the reusable H2 verifier.

    This is useful for deploying verifiers that would otherwise be too big to fit on chain and would require aggregation.

    Arguments
    ---------
    vk_path: str
        The path to the verification key file

    settings_path: str
        The path to the settings file

    vka_path: str
        The path at which to create the VKA calldata.

    abi_path: str
        The path at which to create the ABI for the Solidity verifier

    srs_path: str
        The path to the SRS file

    Returns
    -------
    bool
    """
    ...

def deploy_evm(
    addr_path: str | os.PathLike | pathlib.Path,
    sol_code_path: str | os.PathLike | pathlib.Path,
    rpc_url: typing.Optional[str],
    contract_type: str,
    optimizer_runs: int,
    private_key: typing.Optional[str],
) -> typing.Any:
def deploy_evm(addr_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],contract_type:str,optimizer_runs:int,private_key:typing.Optional[str]) -> typing.Any:
    r"""
    Deploys the Solidity verifier
    """
    ...

def encode_evm_calldata(
    proof: str | os.PathLike | pathlib.Path,
    calldata: str | os.PathLike | pathlib.Path,
    addr_vk: typing.Optional[str],
) -> list[int]:
def encode_evm_calldata(proof:str | os.PathLike | pathlib.Path,calldata:str | os.PathLike | pathlib.Path,addr_vk:typing.Optional[str]) -> list[int]:
    r"""
    Creates encoded EVM calldata from a proof file

    Arguments
    ---------
    proof: str
        Path to the proof file

    calldata: str
        Path to the calldata file to save

    addr_vk: str
        The address of the verification key contract (if the verifier key is to be rendered as a separate contract)

    Returns
    -------
    vec[u8]
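# Usage sketch tying the functions above together (illustrative paths and
# parameter values; `gen_settings` and `setup` are documented later in this
# stub, and the notebooks in this diff await the EVM-related calls):
#
#     ezkl.calibrate_settings("calibration.json", "network.onnx",
#                             "settings.json", "resources",
#                             2.0, None, [1], None)
#     ezkl.compile_circuit("network.onnx", "network.compiled", "settings.json")
#     # after setup() produces test.vk, emit a Solidity verifier:
#     await ezkl.create_evm_verifier("test.vk", "settings.json",
#                                    "Verifier.sol", "Verifier_ABI.json",
#                                    None, False)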
@@ -239,16 +279,16 @@ def encode_evm_calldata(
    """
    ...

def felt_to_big_endian(felt: str) -> str:
def felt_to_big_endian(felt:str) -> str:
    r"""
    Converts a field element hex string to big endian

    Arguments
    -------
    felt: str
        The field element represented as a string

    Returns
    -------
    str
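# Usage sketch for `encode_evm_calldata` (illustrative paths; per the stub
# above it returns the encoded calldata as a list of ints and writes it to
# the given file; pass a VK address only for split verifier/VKA deployments):
#
#     calldata = ezkl.encode_evm_calldata("test.pf", "calldata.bytes", None)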
@@ -256,54 +296,54 @@ def felt_to_big_endian(felt: str) -> str:
    """
    ...

def felt_to_float(felt: str, scale: int) -> float:
def felt_to_float(felt:str,scale:int) -> float:
    r"""
    Converts a field element hex string to a floating point number

    Arguments
    -------
    felt: str
        The field element represented as a string

    scale: float
        The scaling factor used to convert the field element into a floating point representation

    Returns
    -------
    float
    """
    ...

def felt_to_int(felt: str) -> int:
def felt_to_int(felt:str) -> int:
    r"""
    Converts a field element hex string to an integer

    Arguments
    -------
    felt: str
        The field element represented as a string

    Returns
    -------
    int
    """
    ...

def float_to_felt(input: float, scale: int, input_type: PyInputType) -> str:
def float_to_felt(input:float,scale:int,input_type:PyInputType) -> str:
    r"""
    Converts a floating point element to a field element hex string

    Arguments
    -------
    input: float
        The floating point number to convert

    scale: float
        The scaling factor used to quantize the float into a field element

    input_type: PyInputType
        The type of the input

    Returns
    -------
    str
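# Round-trip sketch for the conversion helpers above (values are illustrative;
# the scale passed to felt_to_float must match the one used for quantization):
#
#     felt = ezkl.float_to_felt(0.5, 7, ezkl.PyInputType.F16)
#     approx = ezkl.felt_to_float(felt, 7)   # ~0.5 back
#     as_int = ezkl.felt_to_int(felt)        # the underlying quantized integer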
@@ -311,97 +351,101 @@ def float_to_felt(input: float, scale: int, input_type: PyInputType) -> str:
    """
    ...

def gen_settings(
    model: str | os.PathLike | pathlib.Path,
    output: str | os.PathLike | pathlib.Path,
    py_run_args: typing.Optional[PyRunArgs],
) -> bool:
def gen_settings(model:str | os.PathLike | pathlib.Path,output:str | os.PathLike | pathlib.Path,py_run_args:typing.Optional[PyRunArgs]) -> bool:
    r"""
    Generates the circuit settings

    Arguments
    ---------
    model: str
        Path to the onnx file

    output: str
        Path to create the settings file

    py_run_args: PyRunArgs
        PyRunArgs object to initialize the settings

    Returns
    -------
    bool
    """
    ...

def gen_srs(srs_path: str | os.PathLike | pathlib.Path, logrows: int) -> None:
def gen_srs(srs_path:str | os.PathLike | pathlib.Path,logrows:int) -> None:
    r"""
    Generates the Structured Reference String (SRS); use this only for testing purposes

    Arguments
    ---------
    srs_path: str
        Path at which to create the SRS file

    logrows: int
        The number of logrows for the SRS file
    """
    ...

def gen_vk_from_pk_single(
    path_to_pk: str | os.PathLike | pathlib.Path,
    circuit_settings_path: str | os.PathLike | pathlib.Path,
    vk_output_path: str | os.PathLike | pathlib.Path,
) -> bool:
def gen_vk_from_pk_aggr(path_to_pk:str | os.PathLike | pathlib.Path,vk_output_path:str | os.PathLike | pathlib.Path) -> bool:
    r"""
    Generates a vk from a pk for a model circuit and saves it to a file
    Generates a vk from a pk for an aggregate circuit and saves it to a file

    Arguments
    -------
    path_to_pk: str
        Path to the proving key

    circuit_settings_path: str
        Path to the circuit settings file

    vk_output_path: str
        Path to create the vk file

    Returns
    -------
    bool
    """
    ...

def gen_witness(
    data: str | os.PathLike | pathlib.Path,
    model: str | os.PathLike | pathlib.Path,
    output: typing.Optional[str | os.PathLike | pathlib.Path],
    vk_path: typing.Optional[str | os.PathLike | pathlib.Path],
    srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
) -> typing.Any:
def gen_vk_from_pk_single(path_to_pk:str | os.PathLike | pathlib.Path,circuit_settings_path:str | os.PathLike | pathlib.Path,vk_output_path:str | os.PathLike | pathlib.Path) -> bool:
    r"""
    Generates a vk from a pk for a model circuit and saves it to a file

    Arguments
    -------
    path_to_pk: str
        Path to the proving key

    circuit_settings_path: str
        Path to the circuit settings file

    vk_output_path: str
        Path to create the vk file

    Returns
    -------
    bool
    """
    ...

def gen_witness(data:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,output:typing.Optional[str | os.PathLike | pathlib.Path],vk_path:typing.Optional[str | os.PathLike | pathlib.Path],srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
    r"""
    Runs the forward pass operation to generate a witness

    Arguments
    ---------
    data: str
        Path to the data file

    model: str
        Path to the compiled model file

    output: str
        Path to create the witness file

    vk_path: str
        Path to the verification key

    srs_path: str
        Path to the SRS file

    Returns
    -------
    dict
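# Usage sketch for the settings generators above (illustrative paths;
# PyRunArgs attributes are set as in the notebooks earlier in this diff):
#
#     run_args = ezkl.PyRunArgs()
#     run_args.logrows = 15
#     ezkl.gen_settings("network.onnx", "settings.json", run_args)
#     ezkl.gen_srs("kzg.srs", run_args.logrows)   # test-only SRS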
@@ -409,89 +453,126 @@ def gen_witness(
    """
    ...

def get_srs(
    settings_path: typing.Optional[str | os.PathLike | pathlib.Path],
    logrows: typing.Optional[int],
    srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
) -> typing.Any:
def get_srs(settings_path:typing.Optional[str | os.PathLike | pathlib.Path],logrows:typing.Optional[int],srs_path:typing.Optional[str | os.PathLike | pathlib.Path],commitment:typing.Optional[PyCommitments]) -> typing.Any:
    r"""
    Gets a public srs

    Arguments
    ---------
    settings_path: str
        Path to the settings file

    logrows: int
        The number of logrows for the SRS file

    srs_path: str
        Path at which to create the SRS file

    commitment: str
        Specify the commitment used ("kzg", "ipa")

    Returns
    -------
    bool
    """
    ...

def kzg_commit(
    message: typing.Sequence[str],
    vk_path: str | os.PathLike | pathlib.Path,
    settings_path: str | os.PathLike | pathlib.Path,
    srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
) -> list[PyG1Affine]:
def ipa_commit(message:typing.Sequence[str],vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> list[PyG1Affine]:
    r"""
    Generate a kzg commitment.
    Generate an ipa commitment.

    Arguments
    -------
    message: list[str]
        List of field elements represented as strings

    vk_path: str
        Path to the verification key

    settings_path: str
        Path to the settings file

    srs_path: str
        Path to the Structured Reference String (SRS) file

    Returns
    -------
    list[PyG1Affine]
    """
    ...

def mock(
    witness: str | os.PathLike | pathlib.Path, model: str | os.PathLike | pathlib.Path
) -> bool:
def kzg_commit(message:typing.Sequence[str],vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> list[PyG1Affine]:
    r"""
    Generate a kzg commitment.

    Arguments
    -------
    message: list[str]
        List of field elements represented as strings

    vk_path: str
        Path to the verification key

    settings_path: str
        Path to the settings file

    srs_path: str
        Path to the Structured Reference String (SRS) file

    Returns
    -------
    list[PyG1Affine]
    """
    ...

def mock(witness:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path) -> bool:
    r"""
    Mocks the prover

    Arguments
    ---------
    witness: str
        Path to the witness file

    model: str
        Path to the compiled model file

    Returns
    -------
    bool
    """
    ...

def poseidon_hash(message: typing.Sequence[str]) -> list[str]:
def mock_aggregate(aggregation_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],logrows:int,split_proofs:bool) -> bool:
    r"""
    Mocks the aggregate prover

    Arguments
    ---------
    aggregation_snarks: list[str]
        List of paths to the relevant proof files

    logrows: int
        Number of logrows to use for the aggregation circuit

    split_proofs: bool
        Indicates whether the accumulated proofs are segments of a larger proof

    Returns
    -------
    bool
    """
    ...

def poseidon_hash(message:typing.Sequence[str]) -> list[str]:
    r"""
    Generate a poseidon hash.

    Arguments
    -------
    message: list[str]
        List of field elements represented as strings

    Returns
    -------
    list[str]
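# Usage sketch for the witness/SRS helpers above (illustrative paths; the
# notebooks in this diff await get_srs, so it is shown awaited here too):
#
#     await ezkl.get_srs("settings.json", None, None, ezkl.PyCommitments.KZG)
#     witness = ezkl.gen_witness("input.json", "network.compiled",
#                                "witness.json", None, None)
#     assert ezkl.mock("witness.json", "network.compiled")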
@@ -499,104 +580,126 @@ def poseidon_hash(message: typing.Sequence[str]) -> list[str]:
|
||||
"""
|
||||
...
|
||||
|
||||
def prove(
|
||||
witness: str | os.PathLike | pathlib.Path,
|
||||
model: str | os.PathLike | pathlib.Path,
|
||||
pk_path: str | os.PathLike | pathlib.Path,
|
||||
proof_path: typing.Optional[str | os.PathLike | pathlib.Path],
|
||||
srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
|
||||
) -> typing.Any:
|
||||
def prove(witness:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,proof_path:typing.Optional[str | os.PathLike | pathlib.Path],proof_type:str,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
|
||||
r"""
|
||||
Runs the prover on a set of inputs
|
||||
|
||||
|
||||
Arguments
|
||||
---------
|
||||
witness: str
|
||||
Path to the witness file
|
||||
|
||||
|
||||
model: str
|
||||
Path to the compiled model file
|
||||
|
||||
|
||||
pk_path: str
|
||||
Path to the proving key file
|
||||
|
||||
|
||||
proof_path: str
|
||||
Path to create the proof file
|
||||
|
||||
|
||||
proof_type: str
|
||||
Accepts `single`, `for-aggr`
|
||||
|
||||
srs_path: str
|
||||
Path to the SRS file
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def setup(
|
||||
model: str | os.PathLike | pathlib.Path,
|
||||
vk_path: str | os.PathLike | pathlib.Path,
|
||||
pk_path: str | os.PathLike | pathlib.Path,
|
||||
srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
|
||||
witness_path: typing.Optional[str | os.PathLike | pathlib.Path],
|
||||
disable_selector_compression: bool,
|
||||
) -> bool:
|
||||
def setup(model:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],witness_path:typing.Optional[str | os.PathLike | pathlib.Path],disable_selector_compression:bool) -> bool:
|
||||
r"""
|
||||
Runs the setup process
|
||||
|
||||
|
||||
Arguments
|
||||
---------
|
||||
model: str
|
||||
Path to the compiled model file
|
||||
|
||||
|
||||
vk_path: str
|
||||
Path to create the verification key file
|
||||
|
||||
|
||||
pk_path: str
|
||||
Path to create the proving key file
|
||||
|
||||
|
||||
srs_path: str
|
||||
Path to the SRS file
|
||||
|
||||
|
||||
witness_path: str
|
||||
Path to the witness file
|
||||
|
||||
|
||||
disable_selector_compression: bool
|
||||
Whether to compress the selectors or not
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def swap_proof_commitments(
|
||||
proof_path: str | os.PathLike | pathlib.Path,
|
||||
witness_path: str | os.PathLike | pathlib.Path,
|
||||
) -> None:
|
||||
def setup_aggregate(sample_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],vk_path:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,logrows:int,split_proofs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],disable_selector_compression:bool,commitment:PyCommitments) -> bool:
|
||||
r"""
|
||||
Runs the setup process for an aggregate setup
|
||||
|
||||
Arguments
|
||||
---------
|
||||
sample_snarks: list[str]
|
||||
List of paths to the various proofs
|
||||
|
||||
vk_path: str
|
||||
Path to create the aggregated VK
|
||||
|
||||
pk_path: str
|
||||
Path to create the aggregated PK
|
||||
|
||||
logrows: int
|
||||
Number of logrows to use
|
||||
|
||||
split_proofs: bool
|
||||
Whether the accumulated are segments of a larger proof
|
||||
|
||||
srs_path: str
|
||||
Path to the SRS file
|
||||
|
||||
disable_selector_compression: bool
|
||||
Whether to compress selectors
|
||||
|
||||
commitment: str
|
||||
Accepts `kzg`, `ipa`
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def swap_proof_commitments(
    proof_path: str | os.PathLike | pathlib.Path,
    witness_path: str | os.PathLike | pathlib.Path,
) -> None:
    r"""
    Swap the commitments in a proof

    Arguments
    ---------
    proof_path: str
        Path to the proof file

    witness_path: str
        Path to the witness file
    """
    ...

def table(
    model: str | os.PathLike | pathlib.Path,
    py_run_args: typing.Optional[PyRunArgs],
) -> str:
    r"""
    Displays the table as a string in python

    Arguments
    ---------
    model: str
        Path to the onnx file

    Returns
    -------
    str
    """
    ...

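For example, a one-liner to inspect the circuit layout of a hypothetical model file:

import ezkl

# Returns the layout table as a string; py_run_args=None uses defaults.
print(ezkl.table("network.onnx", None))
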
def verify(
    proof_path: str | os.PathLike | pathlib.Path,
    settings_path: str | os.PathLike | pathlib.Path,
    vk_path: str | os.PathLike | pathlib.Path,
    srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
    reduced_srs: bool,
) -> bool:
    r"""
    Verifies a given proof

    Arguments
    ---------
    proof_path: str
        Path to the proof file

    settings_path: str
        Path to the settings file

    vk_path: str
        Path to the verification key file

    srs_path: str
        Path to the SRS file

    reduced_srs: bool
        Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the SRS were generated in the same ceremony)

    Returns
    -------
    bool
    """
    ...

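A minimal verification sketch against the artifacts from the earlier (hypothetical) prove and setup calls:

import ezkl

ok = ezkl.verify(
    proof_path="proof.json",
    settings_path="settings.json",
    vk_path="vk.key",
    srs_path="kzg.srs",
    reduced_srs=False,
)
assert ok
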
def verify_aggr(
    proof_path: str | os.PathLike | pathlib.Path,
    vk_path: str | os.PathLike | pathlib.Path,
    logrows: int,
    commitment: PyCommitments,
    reduced_srs: bool,
    srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
) -> bool:
    r"""
    Verifies an aggregate proof

    Arguments
    ---------
    proof_path: str
        The path to the proof file

    vk_path: str
        The path to the verification key file

    logrows: int
        Logrows used for the aggregation circuit

    commitment: str
        Accepts "kzg" or "ipa"

    reduced_srs: bool
        Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the SRS were generated in the same ceremony)

    srs_path: str
        The path to the SRS file

    Returns
    -------
    bool
    """
    ...

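And a sketch for the aggregate case, matching the hypothetical setup_aggregate artifacts above:

import ezkl

ok = ezkl.verify_aggr(
    proof_path="proof_aggr.json",
    vk_path="vk_aggr.key",
    logrows=23,
    commitment=ezkl.PyCommitments.KZG,
    reduced_srs=False,
    srs_path="kzg23.srs",
)
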
def verify_evm(
    addr_verifier: str,
    proof_path: str | os.PathLike | pathlib.Path,
    rpc_url: typing.Optional[str],
    vka_path: typing.Optional[str],
) -> typing.Any:
    r"""
    Verifies an EVM-compatible proof; you will need solc installed in your environment to run this

    Arguments
    ---------
    addr_verifier: str
        The verifier contract's address as a hex string

    proof_path: str
        The path to the proof file (generated using the prove command)

    rpc_url: str
        RPC URL for an Ethereum node; if None, Anvil will be used but WON'T persist state

    vka_path: str
        The path to the VKA calldata bytes file (generated using the create_evm_vka command)

    Returns
    -------
    bool
    """
    ...

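The binding appears to return an awaitable (the Rust side wraps an async call), so it is awaited here; treat that, and the placeholder address, as assumptions in this sketch:

import asyncio
import ezkl

async def main():
    # Address is a placeholder; use the address returned when deploying the verifier.
    return await ezkl.verify_evm(
        addr_verifier="0x0000000000000000000000000000000000000000",
        proof_path="proof.json",
        rpc_url=None,  # spins up a throwaway local Anvil instance
        vka_path=None,
    )

ok = asyncio.run(main())
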
4 jest.config.js Normal file
@@ -0,0 +1,4 @@
module.exports = {
  preset: 'ts-jest',
  testEnvironment: 'node',
};
30 package.json Normal file
@@ -0,0 +1,30 @@
{
  "name": "ezkljs-tests",
  "version": "0.1.0",
  "author": "Ethan Cemer",
  "private": true,
  "scripts": {
    "test": "jest"
  },
  "devDependencies": {
    "@ezkljs/engine": "^9.4.4",
    "@ezkljs/verify": "^0.0.6",
    "@jest/types": "^29.6.3",
    "@types/file-saver": "^2.0.5",
    "@types/jest": "^29.5.3",
    "@types/json-bigint": "^1.0.1",
    "@types/node": "20.4.5",
    "buffer": "^6.0.3",
    "env": "^0.0.2",
    "fs": "0.0.1-security",
    "jest": "^29.6.3",
    "json-bigint": "^1.0.0",
    "minimist": "^1.2.8",
    "solc": "^0.8.21",
    "ts-jest": "^29.1.1",
    "ts-loader": "^9.4.4",
    "ts-node": "^10.9.1",
    "tsconfig-paths": "^4.2.0",
    "typescript": "5.1.6"
  }
}
3596 pnpm-lock.yaml generated Normal file
File diff suppressed because it is too large
229 setup-gpu.sh
@@ -1,229 +0,0 @@
#!/bin/bash

set -e

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Default installation directory
DEFAULT_INSTALL_DIR="/opt/icicle/lib/backend/halo2"

# Halo2 repository details
HALO2_REPO="https://github.com/zkonduit/halo2"
HALO2_BRANCH="ac/conditional-compilation-icicle2"

# Parse command line arguments
AUTO_YES=false
for arg in "$@"; do
    case $arg in
        -y|--yes)
            AUTO_YES=true
            shift
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo "Options:"
            echo "  -y, --yes    Automatically answer 'yes' to all prompts"
            echo "  -h, --help   Show this help message"
            exit 0
            ;;
        *)
            echo "Unknown option: $arg"
            echo "Use -h or --help for usage information"
            exit 1
            ;;
    esac
done

echo -e "${GREEN}EZKL GPU Setup Script${NC}"
echo -e "${GREEN}=====================${NC}"
echo ""

# Parse commit hash from Cargo.lock
echo "Parsing halo2 commit hash from Cargo.lock..."
if [ ! -f "Cargo.lock" ]; then
    echo -e "${RED}Error: Cargo.lock not found. Please run this script from the project root.${NC}"
    exit 1
fi

HALO2_COMMIT=$(grep "github\.com/zkonduit/halo2?" Cargo.lock | grep -v "halo2wrong" | head -1 | grep -o "#[a-f0-9]\{40\}" | cut -c2-)

if [ -z "$HALO2_COMMIT" ]; then
    echo -e "${RED}Error: Could not parse halo2 commit hash from Cargo.lock${NC}"
    exit 1
fi

echo -e "${GREEN}Found halo2 commit: $HALO2_COMMIT${NC}"
echo ""
echo "This script will:"
echo "1. Sparse checkout the halo2 repository at commit $HALO2_COMMIT"
echo "2. Extract only the icicle/backend/cuda/ directory"
echo "3. Set the ICICLE_BACKEND_INSTALL_DIR environment variable"
echo ""

# Check if user wants to override the default directory
if [ "$AUTO_YES" = true ]; then
    INSTALL_DIR="$DEFAULT_INSTALL_DIR"
    echo -e "${GREEN}Using default installation directory: ${INSTALL_DIR}${NC}"
else
    echo -e "${YELLOW}Default installation directory: ${DEFAULT_INSTALL_DIR}${NC}"
    read -p "Do you want to use a different directory? [y/N]: " -n 1 -r
    echo
    if [[ $REPLY =~ ^[Yy]$ ]]; then
        read -p "Enter the installation directory: " INSTALL_DIR
        INSTALL_DIR="${INSTALL_DIR/#\~/$HOME}" # Expand ~ to $HOME
    else
        INSTALL_DIR="$DEFAULT_INSTALL_DIR"
    fi

    # Confirm the installation directory
    echo ""
    echo -e "${YELLOW}Installation directory: ${INSTALL_DIR}${NC}"
    read -p "Continue with this directory? [y/N]: " -n 1 -r
    echo
    if [[ ! $REPLY =~ ^[Yy]$ ]]; then
        echo -e "${RED}Setup cancelled by user.${NC}"
        exit 1
    fi
fi

# Check if ICICLE_BACKEND_INSTALL_DIR is already set
if [ ! -z "$ICICLE_BACKEND_INSTALL_DIR" ] && [ "$AUTO_YES" = false ]; then
    echo ""
    echo -e "${YELLOW}Warning: ICICLE_BACKEND_INSTALL_DIR is already set to: $ICICLE_BACKEND_INSTALL_DIR${NC}"
    read -p "Do you want to override it? [y/N]: " -n 1 -r
    echo
    if [[ ! $REPLY =~ ^[Yy]$ ]]; then
        echo -e "${RED}Setup cancelled by user.${NC}"
        exit 1
    fi
elif [ ! -z "$ICICLE_BACKEND_INSTALL_DIR" ] && [ "$AUTO_YES" = true ]; then
    echo -e "${GREEN}Overriding existing ICICLE_BACKEND_INSTALL_DIR (was: $ICICLE_BACKEND_INSTALL_DIR)${NC}"
fi

echo ""
echo -e "${GREEN}Starting GPU setup...${NC}"

# Create installation directory
echo "Creating installation directory..."
mkdir -p "$INSTALL_DIR"

# Create temporary directory for sparse checkout
TEMP_DIR=$(mktemp -d)
echo "Using temporary directory: $TEMP_DIR"

# Clone with sparse checkout
echo "Cloning halo2 repository with sparse checkout..."
cd "$TEMP_DIR"
git clone --filter=blob:none --sparse "$HALO2_REPO" halo2
cd halo2

# Checkout the specific branch and commit
echo "Checking out branch $HALO2_BRANCH at commit $HALO2_COMMIT..."
git checkout "$HALO2_BRANCH"
git checkout "$HALO2_COMMIT"

# Configure sparse checkout
echo "Configuring sparse checkout for icicle/backend/cuda/..."
git sparse-checkout init --cone
git sparse-checkout set icicle/backend/cuda/

# Copy the icicle directory to the installation location
if [ -d "icicle/backend/cuda" ]; then
    echo "Copying icicle/backend/cuda/ to $INSTALL_DIR..."
    cp -r icicle/backend/cuda/* "$INSTALL_DIR/"
    echo -e "${GREEN}Files copied successfully!${NC}"
else
    echo -e "${RED}Error: icicle/backend/cuda directory not found in the repository${NC}"
    exit 1
fi

# Clean up temporary directory
echo "Cleaning up temporary files..."
rm -rf "$TEMP_DIR"

# Ask user about setting environment variable permanently
SETUP_PERMANENT_ENV=false
if [ "$AUTO_YES" = true ]; then
    SETUP_PERMANENT_ENV=true
    echo ""
    echo -e "${GREEN}Setting ICICLE_BACKEND_INSTALL_DIR environment variable permanently...${NC}"
else
    echo ""
    echo -e "${YELLOW}Do you want to set ICICLE_BACKEND_INSTALL_DIR environment variable permanently?${NC}"
    echo "This will add 'export ICICLE_BACKEND_INSTALL_DIR=\"$INSTALL_DIR\"' to your shell configuration file."
    read -p "Set environment variable permanently? [y/N]: " -n 1 -r
    echo
    if [[ $REPLY =~ ^[Yy]$ ]]; then
        SETUP_PERMANENT_ENV=true
    fi
fi

if [ "$SETUP_PERMANENT_ENV" = true ]; then
    echo "Setting ICICLE_BACKEND_INSTALL_DIR environment variable..."

    # Detect shell and set environment variable accordingly
    if [ -n "$ZSH_VERSION" ]; then
        SHELL_RC="$HOME/.zshrc"
    elif [ -n "$BASH_VERSION" ]; then
        SHELL_RC="$HOME/.bashrc"
    else
        # Try to detect based on $SHELL
        case "$SHELL" in
            */zsh)
                SHELL_RC="$HOME/.zshrc"
                ;;
            */bash)
                SHELL_RC="$HOME/.bashrc"
                ;;
            *)
                SHELL_RC="$HOME/.profile"
                ;;
        esac
    fi

    # Add environment variable to shell configuration
    ENV_EXPORT="export ICICLE_BACKEND_INSTALL_DIR=\"$INSTALL_DIR\""

    # Check if the variable is already set in the file
    if [ -f "$SHELL_RC" ] && grep -q "ICICLE_BACKEND_INSTALL_DIR" "$SHELL_RC"; then
        # Replace existing line
        if [[ "$OSTYPE" == "darwin"* ]]; then
            # macOS
            sed -i '' "s|export ICICLE_BACKEND_INSTALL_DIR=.*|$ENV_EXPORT|" "$SHELL_RC"
        else
            # Linux
            sed -i "s|export ICICLE_BACKEND_INSTALL_DIR=.*|$ENV_EXPORT|" "$SHELL_RC"
        fi
        echo "Updated existing ICICLE_BACKEND_INSTALL_DIR in $SHELL_RC"
    else
        # Add new line
        echo "$ENV_EXPORT" >> "$SHELL_RC"
        echo "Added ICICLE_BACKEND_INSTALL_DIR to $SHELL_RC"
    fi

    echo -e "${GREEN}Environment variable set permanently.${NC}"
else
    echo "Skipping permanent environment variable setup."
fi

# Export for current session regardless
export ICICLE_BACKEND_INSTALL_DIR="$INSTALL_DIR"
echo "Environment variable set for current session."

echo ""
echo -e "${GREEN}GPU setup completed successfully!${NC}"
echo ""
echo -e "${YELLOW}Important:${NC}"
echo "1. The ICICLE_BACKEND_INSTALL_DIR environment variable has been set to: $INSTALL_DIR"
if [ "$SETUP_PERMANENT_ENV" = true ]; then
    echo "2. Please restart your terminal or run: source $SHELL_RC"
else
    echo "2. To use GPU features, set: export ICICLE_BACKEND_INSTALL_DIR=\"$INSTALL_DIR\""
fi
echo "3. You can now build with GPU support using: cargo build --features gpu-accelerated"
echo ""
echo -e "${GREEN}Setup complete!${NC}"
@@ -15,6 +15,9 @@ use log::{error, info};
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
use rand::prelude::SliceRandom;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[cfg(feature = "icicle")]
use std::env;

#[tokio::main(flavor = "current_thread")]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub async fn main() {
@@ -28,7 +31,12 @@ pub async fn main() {
    init_logger();
    #[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
    banner();

    #[cfg(feature = "icicle")]
    if env::var("ENABLE_ICICLE_GPU").is_ok() {
        info!("Running with ICICLE GPU");
    } else {
        info!("Running with CPU");
    }
    debug!(
        "command: \n {}",
        &command.as_json().to_colored_json_auto().unwrap()
269 src/bin/ios_gen_bindings.rs Normal file
@@ -0,0 +1,269 @@
use camino::Utf8Path;
use std::fs;
use std::fs::remove_dir_all;
use std::path::{Path, PathBuf};
use std::process::Command;
use uniffi_bindgen::bindings::SwiftBindingGenerator;
use uniffi_bindgen::library_mode::generate_bindings;
use uuid::Uuid;

fn main() {
    let library_name = std::env::var("CARGO_PKG_NAME").expect("CARGO_PKG_NAME is not set");
    let mode = determine_build_mode();
    build_bindings(&library_name, mode);
}

/// Determines the build mode based on the CONFIGURATION environment variable.
/// Defaults to "release" if not set or unrecognized.
/// "release" mode takes longer to build but produces optimized code, which is smaller and faster.
fn determine_build_mode() -> &'static str {
    match std::env::var("CONFIGURATION").map(|s| s.to_lowercase()) {
        Ok(ref config) if config == "debug" => "debug",
        _ => "release",
    }
}

/// Builds the Swift bindings and XCFramework for the specified library and build mode.
fn build_bindings(library_name: &str, mode: &str) {
    // Get the root directory of this Cargo project
    let manifest_dir = std::env::var_os("CARGO_MANIFEST_DIR")
        .map(PathBuf::from)
        .unwrap_or_else(|| std::env::current_dir().unwrap());

    // Define the build directory inside the manifest directory
    let build_dir = manifest_dir.join("build");

    // Create a temporary directory to store the bindings and combined library
    let tmp_dir = mktemp_local(&build_dir);

    // Define directories for Swift bindings and output bindings
    let swift_bindings_dir = tmp_dir.join("SwiftBindings");
    let bindings_out = create_bindings_out_dir(&tmp_dir);
    let framework_out = bindings_out.join("EzklCore.xcframework");

    // Define target architectures for building
    // We currently only support iOS devices and simulators running on ARM Macs
    // This is due to limiting the library size to under 100MB for GitHub's commit size limit
    // To support older (Intel) Macs, follow the instructions in the comments below
    #[allow(clippy::useless_vec)]
    let target_archs = vec![
        vec!["aarch64-apple-ios"],     // iOS device
        vec!["aarch64-apple-ios-sim"], // iOS simulator on ARM Macs
        // vec!["aarch64-apple-ios-sim", "x86_64-apple-ios"], // TODO - replace the above line with this line to allow running on older (Intel) Macs
    ];

    // Build the library for each architecture and combine them
    let out_lib_paths: Vec<PathBuf> = target_archs
        .iter()
        .map(|archs| build_combined_archs(library_name, archs, &build_dir, mode))
        .collect();

    // Generate the path to the built dynamic library (.dylib)
    let out_dylib_path = build_dir.join(format!(
        "{}/{}/lib{}.dylib",
        target_archs[0][0], mode, library_name
    ));

    // Generate Swift bindings using uniffi_bindgen
    generate_ios_bindings(&out_dylib_path, &swift_bindings_dir)
        .expect("Failed to generate iOS bindings");

    // Move the generated Swift file to the bindings output directory
    fs::rename(
        swift_bindings_dir.join(format!("{}.swift", library_name)),
        bindings_out.join("EzklCore.swift"),
    )
    .expect("Failed to copy swift bindings file");

    // Rename the `ios_ezklFFI.modulemap` file to `module.modulemap`
    fs::rename(
        swift_bindings_dir.join(format!("{}FFI.modulemap", library_name)),
        swift_bindings_dir.join("module.modulemap"),
    )
    .expect("Failed to rename modulemap file");

    // Create the XCFramework from the combined libraries and Swift bindings
    create_xcframework(&out_lib_paths, &swift_bindings_dir, &framework_out);

    // Define the destination directory for the bindings
    let bindings_dest = build_dir.join("EzklCoreBindings");
    if bindings_dest.exists() {
        fs::remove_dir_all(&bindings_dest).expect("Failed to remove existing bindings directory");
    }

    // Move the bindings output to the destination directory
    fs::rename(&bindings_out, &bindings_dest).expect("Failed to move framework into place");

    // Clean up temporary directories
    cleanup_temp_dirs(&build_dir);
}

/// Creates the output directory for the bindings.
/// Returns the path to the bindings output directory.
fn create_bindings_out_dir(base_dir: &Path) -> PathBuf {
    let bindings_out = base_dir.join("EzklCoreBindings");
    fs::create_dir_all(&bindings_out).expect("Failed to create bindings output directory");
    bindings_out
}

/// Builds the library for each architecture and combines them into a single library using lipo.
/// Returns the path to the combined library.
fn build_combined_archs(
    library_name: &str,
    archs: &[&str],
    build_dir: &Path,
    mode: &str,
) -> PathBuf {
    // Build the library for each architecture
    let out_lib_paths: Vec<PathBuf> = archs
        .iter()
        .map(|&arch| {
            build_for_arch(arch, build_dir, mode);
            build_dir
                .join(arch)
                .join(mode)
                .join(format!("lib{}.a", library_name))
        })
        .collect();

    // Create a unique temporary directory for the combined library
    let lib_out = mktemp_local(build_dir).join(format!("lib{}.a", library_name));

    // Combine the libraries using lipo
    let mut lipo_cmd = Command::new("lipo");
    lipo_cmd
        .arg("-create")
        .arg("-output")
        .arg(lib_out.to_str().unwrap());
    for lib_path in &out_lib_paths {
        lipo_cmd.arg(lib_path.to_str().unwrap());
    }

    let status = lipo_cmd.status().expect("Failed to run lipo command");
    if !status.success() {
        panic!("lipo command failed with status: {}", status);
    }

    lib_out
}

/// Builds the library for a specific architecture.
fn build_for_arch(arch: &str, build_dir: &Path, mode: &str) {
    // Ensure the target architecture is installed
    install_arch(arch);

    // Run cargo build for the specified architecture and mode
    let mut build_cmd = Command::new("cargo");
    build_cmd
        .arg("build")
        .arg("--no-default-features")
        .arg("--features")
        .arg("ios-bindings");

    if mode == "release" {
        build_cmd.arg("--release");
    }
    build_cmd
        .arg("--lib")
        .env("CARGO_BUILD_TARGET_DIR", build_dir)
        .env("CARGO_BUILD_TARGET", arch);

    let status = build_cmd.status().expect("Failed to run cargo build");
    if !status.success() {
        panic!("cargo build failed for architecture: {}", arch);
    }
}

/// Installs the specified target architecture using rustup.
fn install_arch(arch: &str) {
    let status = Command::new("rustup")
        .arg("target")
        .arg("add")
        .arg(arch)
        .status()
        .expect("Failed to run rustup command");

    if !status.success() {
        panic!("Failed to install target architecture: {}", arch);
    }
}

/// Generates Swift bindings for the iOS library using uniffi_bindgen.
fn generate_ios_bindings(dylib_path: &Path, binding_dir: &Path) -> Result<(), std::io::Error> {
    // Remove the existing binding directory if it exists
    if binding_dir.exists() {
        remove_dir_all(binding_dir)?;
    }

    // Generate the Swift bindings using uniffi_bindgen
    generate_bindings(
        Utf8Path::from_path(dylib_path).ok_or_else(|| {
            std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid dylib path")
        })?,
        None,
        &SwiftBindingGenerator,
        None,
        Utf8Path::from_path(binding_dir).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid Swift bindings directory",
            )
        })?,
        true,
    )
    .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;

    Ok(())
}

/// Creates an XCFramework from the combined libraries and Swift bindings.
fn create_xcframework(lib_paths: &[PathBuf], swift_bindings_dir: &Path, framework_out: &Path) {
    let mut xcbuild_cmd = Command::new("xcodebuild");
    xcbuild_cmd.arg("-create-xcframework");

    // Add each library and its corresponding headers to the xcodebuild command
    for lib_path in lib_paths {
        println!("Including library: {:?}", lib_path);
        xcbuild_cmd.arg("-library");
        xcbuild_cmd.arg(lib_path.to_str().unwrap());
        xcbuild_cmd.arg("-headers");
        xcbuild_cmd.arg(swift_bindings_dir.to_str().unwrap());
    }

    xcbuild_cmd.arg("-output");
    xcbuild_cmd.arg(framework_out.to_str().unwrap());

    let status = xcbuild_cmd.status().expect("Failed to run xcodebuild");
    if !status.success() {
        panic!("xcodebuild failed with status: {}", status);
    }
}

/// Creates a temporary directory inside the build path with a unique UUID.
/// This ensures unique build artifacts for concurrent builds.
fn mktemp_local(build_path: &Path) -> PathBuf {
    let dir = tmp_local(build_path).join(Uuid::new_v4().to_string());
    fs::create_dir(&dir).expect("Failed to create temporary directory");
    dir
}

/// Gets the path to the local temporary directory inside the build path.
fn tmp_local(build_path: &Path) -> PathBuf {
    let tmp_path = build_path.join("tmp");
    if let Ok(metadata) = fs::metadata(&tmp_path) {
        if !metadata.is_dir() {
            panic!("Expected 'tmp' to be a directory");
        }
    } else {
        fs::create_dir_all(&tmp_path).expect("Failed to create local temporary directory");
    }
    tmp_path
}

/// Cleans up temporary directories inside the build path.
fn cleanup_temp_dirs(build_dir: &Path) {
    let tmp_dir = build_dir.join("tmp");
    if tmp_dir.exists() {
        fs::remove_dir_all(tmp_dir).expect("Failed to remove temporary directories");
    }
}
@@ -1,3 +1,12 @@
/// Python bindings
#[cfg(feature = "python-bindings")]
pub mod python;
/// Universal bindings for all platforms
#[cfg(any(
    feature = "universal-bindings",
    all(target_arch = "wasm32", target_os = "unknown")
))]
pub mod universal;
/// wasm prover and verifier
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub mod wasm;
@@ -12,10 +12,14 @@ use crate::graph::TestDataSource;
use crate::graph::{
    quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
};
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::{
    load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
    ProofType, TranscriptType,
};
use crate::Commitments;
use crate::RunArgs;
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
use pyo3::exceptions::{PyIOError, PyRuntimeError};
@@ -89,6 +93,7 @@ impl From<PyG1> for G1 {
    }
}

/// pyclass containing the struct used for G1
#[pyclass]
#[derive(Debug, Clone)]
@@ -120,6 +125,7 @@ impl From<PyG1Affine> for G1Affine {
    }
}

/// Python class containing the struct used for run_args
///
/// Returns
@@ -169,6 +175,9 @@ struct PyRunArgs {
    #[pyo3(get, set)]
    /// str: check mode, accepts `safe`, `unsafe`
    pub check_mode: CheckMode,
    #[pyo3(get, set)]
    /// str: commitment type, accepts `kzg`, `ipa`
    pub commitment: PyCommitments,
    /// int: The base used for decomposition
    #[pyo3(get, set)]
    pub decomp_base: usize,
@@ -184,9 +193,6 @@ struct PyRunArgs {
    /// float: epsilon used for arguments that use division
    #[pyo3(get, set)]
    pub epsilon: f64,
    /// bool: Whether to disable using Freivalds' argument in einsum operations
    #[pyo3(get, set)]
    pub disable_freivalds: bool,
}

/// default instantiation of PyRunArgs
@@ -216,11 +222,11 @@ impl From<PyRunArgs> for RunArgs {
            variables: py_run_args.variables,
            rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
            check_mode: py_run_args.check_mode,
            commitment: Some(py_run_args.commitment.into()),
            decomp_base: py_run_args.decomp_base,
            decomp_legs: py_run_args.decomp_legs,
            ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
            epsilon: Some(py_run_args.epsilon),
            disable_freivalds: py_run_args.disable_freivalds,
        }
    }
@@ -243,11 +249,61 @@ impl Into<PyRunArgs> for RunArgs {
            variables: self.variables,
            rebase_frac_zero_constants: self.rebase_frac_zero_constants,
            check_mode: self.check_mode,
            commitment: self.commitment.into(),
            decomp_base: self.decomp_base,
            decomp_legs: self.decomp_legs,
            ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
            epsilon: eps,
            disable_freivalds: self.disable_freivalds,
        }
    }
}

#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass_enum]
/// pyclass representing an enum, denoting the type of commitment
pub enum PyCommitments {
    /// KZG commitment
    KZG,
    /// IPA commitment
    IPA,
}

impl From<Option<Commitments>> for PyCommitments {
    fn from(commitment: Option<Commitments>) -> Self {
        match commitment {
            Some(Commitments::KZG) => PyCommitments::KZG,
            Some(Commitments::IPA) => PyCommitments::IPA,
            None => PyCommitments::KZG,
        }
    }
}

impl From<PyCommitments> for Commitments {
    fn from(py_commitments: PyCommitments) -> Self {
        match py_commitments {
            PyCommitments::KZG => Commitments::KZG,
            PyCommitments::IPA => Commitments::IPA,
        }
    }
}

impl Into<PyCommitments> for Commitments {
    fn into(self) -> PyCommitments {
        match self {
            Commitments::KZG => PyCommitments::KZG,
            Commitments::IPA => PyCommitments::IPA,
        }
    }
}

impl FromStr for PyCommitments {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "kzg" => Ok(PyCommitments::KZG),
            "ipa" => Ok(PyCommitments::IPA),
            _ => Err("Invalid value for Commitments".to_string()),
        }
    }
}
@@ -564,7 +620,8 @@ fn kzg_commit(
    let settings = GraphSettings::load(&settings_path)
        .map_err(|_| PyIOError::new_err("Failed to load circuit settings"))?;

    let srs_path =
        crate::execute::get_srs_path(settings.run_args.logrows, srs_path, Commitments::KZG);

    let srs = load_srs_prover::<KZGCommitmentScheme<Bn256>>(srs_path)
        .map_err(|_| PyIOError::new_err("Failed to load srs"))?;
@@ -581,6 +638,65 @@ fn kzg_commit(
    Ok(output.iter().map(|x| (*x).into()).collect::<Vec<_>>())
}
/// Generate an ipa commitment.
///
/// Arguments
/// -------
/// message: list[str]
///     List of field elements represented as strings
///
/// vk_path: str
///     Path to the verification key
///
/// settings_path: str
///     Path to the settings file
///
/// srs_path: str
///     Path to the Structured Reference String (SRS) file
///
/// Returns
/// -------
/// list[PyG1Affine]
///
#[pyfunction(signature = (
    message,
    vk_path=PathBuf::from(DEFAULT_VK),
    settings_path=PathBuf::from(DEFAULT_SETTINGS),
    srs_path=None
))]
#[gen_stub_pyfunction]
fn ipa_commit(
    message: Vec<PyFelt>,
    vk_path: PathBuf,
    settings_path: PathBuf,
    srs_path: Option<PathBuf>,
) -> PyResult<Vec<PyG1Affine>> {
    let message: Vec<Fr> = message
        .iter()
        .map(crate::pfsys::string_to_field::<Fr>)
        .collect::<Vec<_>>();

    let settings = GraphSettings::load(&settings_path)
        .map_err(|_| PyIOError::new_err("Failed to load circuit settings"))?;

    let srs_path =
        crate::execute::get_srs_path(settings.run_args.logrows, srs_path, Commitments::IPA);

    let srs = load_srs_prover::<IPACommitmentScheme<G1Affine>>(srs_path)
        .map_err(|_| PyIOError::new_err("Failed to load srs"))?;

    let vk = load_vk::<IPACommitmentScheme<G1Affine>, GraphCircuit>(vk_path, settings)
        .map_err(|_| PyIOError::new_err("Failed to load vk"))?;

    let output = PolyCommitChip::commit::<IPACommitmentScheme<G1Affine>>(
        message,
        (vk.cs().blinding_factors() + 1) as u32,
        &srs,
    );

    Ok(output.iter().map(|x| (*x).into()).collect::<Vec<_>>())
}
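From Python, both commit helpers take the message as a list of field-element strings; a hypothetical sketch using the IPA variant documented above:

import ezkl

# Field elements serialized as strings (e.g. pulled from a witness file);
# the paths are placeholders.
commitments = ezkl.ipa_commit(
    message=["1", "2", "3"],
    vk_path="vk.key",
    settings_path="settings.json",
    srs_path="ipa.srs",
)
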
/// Swap the commitments in a proof
///
/// Arguments
@@ -646,6 +762,37 @@ fn gen_vk_from_pk_single(
    Ok(true)
}
/// Generates a vk from a pk for an aggregate circuit and saves it to a file
///
/// Arguments
/// -------
/// path_to_pk: str
///     Path to the proving key
///
/// vk_output_path: str
///     Path to create the vk file
///
/// Returns
/// -------
/// bool
#[pyfunction(signature = (
    path_to_pk=PathBuf::from(DEFAULT_PK_AGGREGATED),
    vk_output_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
))]
#[gen_stub_pyfunction]
fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult<bool> {
    let pk = load_pk::<KZGCommitmentScheme<Bn256>, AggregationCircuit>(path_to_pk, ())
        .map_err(|_| PyIOError::new_err("Failed to load pk"))?;

    let vk = pk.get_vk();

    // now save
    save_vk::<G1Affine>(&vk_output_path, vk)
        .map_err(|_| PyIOError::new_err("Failed to save vk"))?;

    Ok(true)
}
/// Displays the table as a string in python
///
/// Arguments
@@ -708,6 +855,8 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
/// srs_path: str
///     Path to create the SRS file
///
/// commitment: str
///     Specify the commitment used ("kzg", "ipa")
///
/// Returns
/// -------
@@ -717,6 +866,7 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
    settings_path=PathBuf::from(DEFAULT_SETTINGS),
    logrows=None,
    srs_path=None,
    commitment=None,
))]
#[gen_stub_pyfunction]
fn get_srs(
@@ -724,9 +874,15 @@ fn get_srs(
    settings_path: Option<PathBuf>,
    logrows: Option<u32>,
    srs_path: Option<PathBuf>,
    commitment: Option<PyCommitments>,
) -> PyResult<Bound<'_, PyAny>> {
    let commitment: Option<Commitments> = match commitment {
        Some(c) => Some(c.into()),
        None => None,
    };

    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        crate::execute::get_srs_cmd(srs_path, settings_path, logrows, commitment)
            .await
            .map_err(|e| {
                let err_str = format!("Failed to get srs: {}", e);
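Because get_srs is exposed as an async Python function (it is wrapped with future_into_py above), it must be awaited; a sketch fetching the SRS implied by a hypothetical settings file:

import asyncio
import ezkl

async def fetch_srs():
    # logrows/srs_path/commitment of None fall back to the settings file's values.
    await ezkl.get_srs(
        settings_path="settings.json",
        logrows=None,
        srs_path=None,
        commitment=None,
    )

asyncio.run(fetch_srs())
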
@@ -961,6 +1117,42 @@ fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
    Ok(true)
}

/// Mocks the aggregate prover
///
/// Arguments
/// ---------
/// aggregation_snarks: list[str]
///     List of paths to the relevant proof files
///
/// logrows: int
///     Number of logrows to use for the aggregation circuit
///
/// split_proofs: bool
///     Indicates whether the accumulated proofs are segments of a larger proof
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
    aggregation_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
    logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
    split_proofs = false,
))]
#[gen_stub_pyfunction]
fn mock_aggregate(
    aggregation_snarks: Vec<PathBuf>,
    logrows: u32,
    split_proofs: bool,
) -> PyResult<bool> {
    crate::execute::mock_aggregate(aggregation_snarks, logrows, split_proofs).map_err(|e| {
        let err_str = format!("Failed to run mock: {}", e);
        PyRuntimeError::new_err(err_str)
    })?;

    Ok(true)
}
/// Runs the setup process
///
/// Arguments
@@ -1036,6 +1228,8 @@ fn setup(
/// proof_path: str
///     Path to create the proof file
///
/// proof_type: str
///     Accepts `single`, `for-aggr`
///
/// srs_path: str
///     Path to the SRS file
@@ -1049,6 +1243,7 @@ fn setup(
    model=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
    pk_path=PathBuf::from(DEFAULT_PK),
    proof_path=None,
    proof_type=ProofType::default(),
    srs_path=None,
))]
#[gen_stub_pyfunction]
@@ -1057,6 +1252,7 @@ fn prove(
    model: PathBuf,
    pk_path: PathBuf,
    proof_path: Option<PathBuf>,
    proof_type: ProofType,
    srs_path: Option<PathBuf>,
) -> PyResult<PyObject> {
    let snark = crate::execute::prove(
@@ -1065,6 +1261,7 @@ fn prove(
        pk_path,
        proof_path,
        srs_path,
        proof_type,
        CheckMode::UNSAFE,
    )
    .map_err(|e| {
@@ -1123,6 +1320,77 @@ fn verify(
    Ok(true)
}

/// Runs the setup process for an aggregate setup
///
/// Arguments
/// ---------
/// sample_snarks: list[str]
///     List of paths to the various proofs
///
/// vk_path: str
///     Path to create the aggregated VK
///
/// pk_path: str
///     Path to create the aggregated PK
///
/// logrows: int
///     Number of logrows to use
///
/// split_proofs: bool
///     Whether the accumulated proofs are segments of a larger proof
///
/// srs_path: str
///     Path to the SRS file
///
/// disable_selector_compression: bool
///     Whether to disable selector compression
///
/// commitment: str
///     Accepts `kzg`, `ipa`
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
    sample_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
    vk_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
    pk_path=PathBuf::from(DEFAULT_PK_AGGREGATED),
    logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
    split_proofs = false,
    srs_path = None,
    disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
    commitment=DEFAULT_COMMITMENT.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn setup_aggregate(
    sample_snarks: Vec<PathBuf>,
    vk_path: PathBuf,
    pk_path: PathBuf,
    logrows: u32,
    split_proofs: bool,
    srs_path: Option<PathBuf>,
    disable_selector_compression: bool,
    commitment: PyCommitments,
) -> Result<bool, PyErr> {
    crate::execute::setup_aggregate(
        sample_snarks,
        vk_path,
        pk_path,
        srs_path,
        logrows,
        split_proofs,
        disable_selector_compression,
        commitment.into(),
    )
    .map_err(|e| {
        let err_str = format!("Failed to setup aggregate: {}", e);
        PyRuntimeError::new_err(err_str)
    })?;

    Ok(true)
}
/// Compiles the circuit for use in other steps
///
/// Arguments
@@ -1152,7 +1420,144 @@ fn compile_circuit(
    settings_path: PathBuf,
) -> Result<bool, PyErr> {
    crate::execute::compile_circuit(model, compiled_circuit, settings_path).map_err(|e| {
        let err_str = format!("Failed to compile circuit: {}", e);
        PyRuntimeError::new_err(err_str)
    })?;

    Ok(true)
}

/// Creates an aggregated proof
///
/// Arguments
/// ---------
/// aggregation_snarks: list[str]
///     List of paths to the various proofs
///
/// proof_path: str
///     Path to output the aggregated proof
///
/// vk_path: str
///     Path to the VK file
///
/// transcript: str
///     Proof transcript type to be used. `evm` used by default. `poseidon` is also supported
///
/// logrows: int
///     Logrows used for the aggregation circuit
///
/// check_mode: str
///     Run sanity checks during calculations. Accepts `safe` or `unsafe`
///
/// split_proofs: bool
///     Whether the accumulated proofs are segments of a larger circuit
///
/// srs_path: str
///     Path to the SRS used
///
/// commitment: str
///     Accepts "kzg" or "ipa"
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
    aggregation_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
    proof_path=PathBuf::from(DEFAULT_PROOF_AGGREGATED),
    vk_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
    transcript=TranscriptType::default(),
    logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
    check_mode=CheckMode::UNSAFE,
    split_proofs = false,
    srs_path=None,
    commitment=DEFAULT_COMMITMENT.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn aggregate(
    aggregation_snarks: Vec<PathBuf>,
    proof_path: PathBuf,
    vk_path: PathBuf,
    transcript: TranscriptType,
    logrows: u32,
    check_mode: CheckMode,
    split_proofs: bool,
    srs_path: Option<PathBuf>,
    commitment: PyCommitments,
) -> Result<bool, PyErr> {
    // the K used for the aggregation circuit
    crate::execute::aggregate(
        proof_path,
        aggregation_snarks,
        vk_path,
        srs_path,
        transcript,
        logrows,
        check_mode,
        split_proofs,
        commitment.into(),
    )
    .map_err(|e| {
        let err_str = format!("Failed to run aggregate: {}", e);
        PyRuntimeError::new_err(err_str)
    })?;

    Ok(true)
}
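From Python, a sketch aggregating two hypothetical `for-aggr` proofs; the file names and logrows value are illustrative:

import ezkl

ok = ezkl.aggregate(
    aggregation_snarks=["proof_0.json", "proof_1.json"],
    proof_path="proof_aggr.json",
    vk_path="vk_aggr.key",
    transcript="evm",     # or "poseidon"
    logrows=23,
    check_mode="unsafe",  # "safe" enables sanity checks
    split_proofs=False,
    srs_path="kzg23.srs",
    commitment=ezkl.PyCommitments.KZG,
)
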
/// Verifies an aggregate proof
///
/// Arguments
/// ---------
/// proof_path: str
///     The path to the proof file
///
/// vk_path: str
///     The path to the verification key file
///
/// logrows: int
///     Logrows used for the aggregation circuit
///
/// commitment: str
///     Accepts "kzg" or "ipa"
///
/// reduced_srs: bool
///     Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the SRS were generated in the same ceremony)
///
/// srs_path: str
///     The path to the SRS file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
    proof_path=PathBuf::from(DEFAULT_PROOF_AGGREGATED),
    vk_path=PathBuf::from(DEFAULT_VK),
    logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
    commitment=DEFAULT_COMMITMENT.parse().unwrap(),
    reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse().unwrap(),
    srs_path=None,
))]
#[gen_stub_pyfunction]
fn verify_aggr(
    proof_path: PathBuf,
    vk_path: PathBuf,
    logrows: u32,
    commitment: PyCommitments,
    reduced_srs: bool,
    srs_path: Option<PathBuf>,
) -> Result<bool, PyErr> {
    crate::execute::verify_aggr(
        proof_path,
        vk_path,
        srs_path,
        logrows,
        reduced_srs,
        commitment.into(),
    )
    .map_err(|e| {
        let err_str = format!("Failed to run verify_aggr: {}", e);
        PyRuntimeError::new_err(err_str)
    })?;

@@ -1259,7 +1664,7 @@ fn create_evm_verifier(

#[cfg(feature = "reusable-verifier")]
/// Creates an EVM VK artifact. This command generates a VK with circuit-specific metadata encoded in memory for use by the reusable H2 verifier.
/// This is useful for deploying verifiers that were otherwise too big to fit on-chain and required aggregation.
///
/// Arguments
/// ---------
@@ -1465,6 +1870,75 @@ fn verify_evm<'a>(
    })
}
/// Creates an EVM-compatible aggregate verifier; you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// aggregation_settings: str
///     Path to the settings file
///
/// vk_path: str
///     The path to load the desired verification key file
///
/// sol_code_path: str
///     The path to the Solidity code
///
/// abi_path: str
///     The path to output the Solidity verifier ABI
///
/// logrows: int
///     Number of logrows used during aggregated setup
///
/// srs_path: str
///     The path to the SRS file
///
/// reusable: bool
///     Whether the verifier should be rendered as a reusable contract. If so, you will need to deploy the VK artifact separately, which you can generate using the create_evm_vka command
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
    aggregation_settings=vec![PathBuf::from(DEFAULT_PROOF)],
    vk_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
    sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
    abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
    logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
    srs_path=None,
    reusable = DEFAULT_RENDER_REUSABLE.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn create_evm_verifier_aggr(
    py: Python<'_>,
    aggregation_settings: Vec<PathBuf>,
    vk_path: PathBuf,
    sol_code_path: PathBuf,
    abi_path: PathBuf,
    logrows: u32,
    srs_path: Option<PathBuf>,
    reusable: bool,
) -> PyResult<Bound<'_, PyAny>> {
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        crate::execute::create_evm_aggregate_verifier(
            vk_path,
            srs_path,
            sol_code_path,
            abi_path,
            aggregation_settings,
            logrows,
            reusable,
        )
        .await
        .map_err(|e| {
            let err_str = format!("Failed to run create_evm_verifier_aggr: {}", e);
            PyRuntimeError::new_err(err_str)
        })?;

        Ok(true)
    })
}
// Define a function to gather stub information.
define_stub_info_gatherer!(stub_info);

@@ -1476,16 +1950,19 @@ fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyG1Affine>()?;
    m.add_class::<PyG1>()?;
    m.add_class::<PyTestDataSource>()?;
    m.add_class::<PyCommitments>()?;
    m.add_class::<PyInputType>()?;
    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
    m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
    m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
    m.add_function(wrap_pyfunction!(felt_to_float, m)?)?;
    m.add_function(wrap_pyfunction!(kzg_commit, m)?)?;
    m.add_function(wrap_pyfunction!(ipa_commit, m)?)?;
    m.add_function(wrap_pyfunction!(swap_proof_commitments, m)?)?;
    m.add_function(wrap_pyfunction!(poseidon_hash, m)?)?;
    m.add_function(wrap_pyfunction!(float_to_felt, m)?)?;
    m.add_function(wrap_pyfunction!(buffer_to_felts, m)?)?;
    m.add_function(wrap_pyfunction!(gen_vk_from_pk_aggr, m)?)?;
    m.add_function(wrap_pyfunction!(gen_vk_from_pk_single, m)?)?;
    m.add_function(wrap_pyfunction!(table, m)?)?;
    m.add_function(wrap_pyfunction!(mock, m)?)?;
@@ -1498,12 +1975,17 @@ fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(gen_settings, m)?)?;
    m.add_function(wrap_pyfunction!(gen_random_data, m)?)?;
    m.add_function(wrap_pyfunction!(calibrate_settings, m)?)?;
    m.add_function(wrap_pyfunction!(aggregate, m)?)?;
    m.add_function(wrap_pyfunction!(mock_aggregate, m)?)?;
    m.add_function(wrap_pyfunction!(setup_aggregate, m)?)?;
    m.add_function(wrap_pyfunction!(compile_circuit, m)?)?;
    m.add_function(wrap_pyfunction!(verify_aggr, m)?)?;
    m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?;
    #[cfg(feature = "reusable-verifier")]
    m.add_function(wrap_pyfunction!(create_evm_vka, m)?)?;
    m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
    m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
    m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?;
    m.add_function(wrap_pyfunction!(encode_evm_calldata, m)?)?;
    #[cfg(feature = "reusable-verifier")]
    m.add_function(wrap_pyfunction!(register_vka, m)?)?;
@@ -1519,6 +2001,24 @@ impl pyo3_stub_gen::PyStubType for CalibrationTarget {
    }
}

impl pyo3_stub_gen::PyStubType for ProofType {
    fn type_output() -> TypeInfo {
        TypeInfo {
            name: "str".to_string(),
            import: HashSet::new(),
        }
    }
}

impl pyo3_stub_gen::PyStubType for TranscriptType {
    fn type_output() -> TypeInfo {
        TypeInfo {
            name: "str".to_string(),
            import: HashSet::new(),
        }
    }
}

impl pyo3_stub_gen::PyStubType for CheckMode {
    fn type_output() -> TypeInfo {
        TypeInfo {
606 src/bindings/universal.rs Normal file
@@ -0,0 +1,606 @@
use halo2_proofs::{
|
||||
plonk::*,
|
||||
poly::{
|
||||
commitment::{CommitmentScheme, ParamsProver},
|
||||
ipa::{
|
||||
commitment::{IPACommitmentScheme, ParamsIPA},
|
||||
multiopen::{ProverIPA, VerifierIPA},
|
||||
strategy::SingleStrategy as IPASingleStrategy,
|
||||
},
|
||||
kzg::{
|
||||
commitment::{KZGCommitmentScheme, ParamsKZG},
|
||||
multiopen::{ProverSHPLONK, VerifierSHPLONK},
|
||||
strategy::SingleStrategy as KZGSingleStrategy,
|
||||
},
|
||||
VerificationStrategy,
|
||||
},
|
||||
};
|
||||
use std::fmt::Display;
|
||||
use std::io::BufReader;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::{
|
||||
circuit::region::RegionSettings,
|
||||
graph::GraphSettings,
|
||||
pfsys::{
|
||||
create_proof_circuit, encode_calldata,
|
||||
evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript},
|
||||
verify_proof_circuit, TranscriptType,
|
||||
},
|
||||
tensor::TensorType,
|
||||
CheckMode, Commitments, EZKLError as InnerEZKLError,
|
||||
};
|
||||
|
||||
use crate::circuit::modules::poseidon::{
|
||||
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
|
||||
PoseidonChip,
|
||||
};
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::graph::{GraphCircuit, GraphWitness};
|
||||
use halo2curves::{
|
||||
bn256::{Bn256, Fr, G1Affine},
|
||||
ff::{FromUniformBytes, PrimeField},
|
||||
};
|
||||
use snark_verifier::{loader::native::NativeLoader, system::halo2::transcript::evm::EvmTranscript};
|
||||
|
||||
/// Wrapper around the Error Message
|
||||
#[cfg_attr(feature = "ios-bindings", derive(uniffi::Error))]
|
||||
#[derive(Debug)]
|
||||
pub enum EZKLError {
|
||||
/// Some Comment
|
||||
InternalError(String),
|
||||
}
|
||||
|
||||
impl Display for EZKLError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
EZKLError::InternalError(e) => write!(f, "Internal error: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<InnerEZKLError> for EZKLError {
|
||||
fn from(e: InnerEZKLError) -> Self {
|
||||
EZKLError::InternalError(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Hash the input message with poseidon
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub fn poseidon_hash(message: Vec<u8>) -> Result<Vec<u8>, EZKLError> {
|
||||
let message: Vec<Fr> = serde_json::from_slice(&message[..]).map_err(InnerEZKLError::from)?;
|
||||
|
||||
let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
|
||||
.map_err(InnerEZKLError::from)?;
|
||||
|
||||
Ok(serde_json::to_vec(&output).map_err(InnerEZKLError::from)?)
|
||||
}
|
||||
|
||||
/// Hash the input message with poseidon without converting to Fr
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub fn poseidon_hash_no_felt(message: Vec<u8>) -> Result<Vec<u8>, EZKLError> {
|
||||
let message: Vec<Fr> = message.iter().map(|x| Fr::from(*x as u64)).collect();
|
||||
|
||||
let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
|
||||
.map_err(InnerEZKLError::from)?;
|
||||
|
||||
Ok(serde_json::to_vec(&output).map_err(InnerEZKLError::from)?)
|
||||
}
|
||||
|
||||
/// Encode verifier calldata from proof and ethereum vk_address
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub fn encode_verifier_calldata(
|
||||
// TODO - shuold it be pub or pub or pub(super)?
|
||||
proof: Vec<u8>,
|
||||
vka: Option<Vec<u8>>,
|
||||
) -> Result<Vec<u8>, EZKLError> {
|
||||
let snark: crate::pfsys::Snark<Fr, G1Affine> =
|
||||
serde_json::from_slice(&proof[..]).map_err(InnerEZKLError::from)?;
|
||||
|
||||
let vka_buf: Option<Vec<[u8; 32]>> = if let Some(vka) = vka {
|
||||
let array: Vec<[u8; 32]> =
|
||||
serde_json::from_slice(&vka[..]).map_err(InnerEZKLError::from)?;
|
||||
Some(array)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let vka: Option<&[[u8; 32]]> = vka_buf.as_deref();
|
||||
|
||||
let flattened_instances = snark.instances.into_iter().flatten();
|
||||
|
||||
let encoded = encode_calldata(vka, &snark.proof, &flattened_instances.collect::<Vec<_>>());
|
||||
|
||||
Ok(encoded)
|
||||
}
|
||||
|
||||
/// Generate witness from compiled circuit and input json
|
||||
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
|
||||
pub fn gen_witness(compiled_circuit: Vec<u8>, input: Vec<u8>) -> Result<Vec<u8>, EZKLError> {
|
||||
println!("[circuit]");
|
||||
let mut circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..])
|
||||
.map_err(|e| {
|
||||
EZKLError::InternalError(format!("Failed to deserialize compiled model: {}", e))
|
||||
})?;
|
||||
|
||||
println!("[input]");
|
||||
let input: crate::graph::input::GraphData = serde_json::from_slice(&input[..])
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize input: {}", e)))?;
|
||||
|
||||
println!("[load graph input]");
|
||||
let mut input = circuit
|
||||
.load_graph_input(&input)
|
||||
.map_err(|e| EZKLError::InternalError(format!("{}", e)))?;
|
||||
|
||||
println!("[load graph witness]");
|
||||
let witness = circuit
|
||||
.forward::<KZGCommitmentScheme<Bn256>>(
|
||||
&mut input,
|
||||
None,
|
||||
None,
|
||||
RegionSettings::all_true(
|
||||
circuit.settings().run_args.decomp_base,
|
||||
circuit.settings().run_args.decomp_legs,
|
||||
),
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("{}", e)))?;
|
||||
|
||||
println!("[serialize witness]");
|
||||
serde_json::to_vec(&witness)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize witness: {}", e)))
|
||||
}

/// Generate verifying key from compiled circuit, and parameters srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn gen_vk(
    compiled_circuit: Vec<u8>,
    srs: Vec<u8>,
    compress_selectors: bool,
) -> Result<Vec<u8>, EZKLError> {
    let mut reader = BufReader::new(&srs[..]);
    let params: ParamsKZG<Bn256> = get_params(&mut reader)?;

    let circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..])
        .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?;

    let vk = create_vk_lean::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
        &circuit,
        &params,
        compress_selectors,
    )
    .map_err(|e| EZKLError::InternalError(format!("Failed to create verifying key: {}", e)))?;

    let mut serialized_vk = Vec::new();
    vk.write(
        &mut serialized_vk,
        halo2_proofs::SerdeFormat::RawBytesUnchecked,
    )
    .map_err(|e| EZKLError::InternalError(format!("Failed to serialize verifying key: {}", e)))?;

    Ok(serialized_vk)
}

/// Generate proving key from vk, compiled circuit and parameters srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn gen_pk(vk: Vec<u8>, compiled_circuit: Vec<u8>, srs: Vec<u8>) -> Result<Vec<u8>, EZKLError> {
    let mut reader = BufReader::new(&srs[..]);
    let params: ParamsKZG<Bn256> = get_params(&mut reader)?;

    let circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..])
        .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?;

    let mut reader = BufReader::new(&vk[..]);
    let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
        &mut reader,
        halo2_proofs::SerdeFormat::RawBytesUnchecked,
        circuit.settings().clone(),
    )
    .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;

    let pk = create_pk_lean::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk, &circuit, &params)
        .map_err(|e| EZKLError::InternalError(format!("Failed to create proving key: {}", e)))?;

    let mut serialized_pk = Vec::new();
    pk.write(&mut serialized_pk, halo2_proofs::SerdeFormat::RawBytes)
        .map_err(|e| EZKLError::InternalError(format!("Failed to serialize proving key: {}", e)))?;

    Ok(serialized_pk)
}
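// Usage sketch (hypothetical file names; not part of the original source):
// generate the verifying key first, then derive the proving key from it.
//
//     let compiled = std::fs::read("model.compiled")?;
//     let srs = std::fs::read("kzg.srs")?;
//     let vk = gen_vk(compiled.clone(), srs.clone(), true)?;
//     let pk = gen_pk(vk, compiled, srs)?;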

/// Verify proof with vk, proof json, circuit settings json and srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn verify(
    proof: Vec<u8>,
    vk: Vec<u8>,
    settings: Vec<u8>,
    srs: Vec<u8>,
) -> Result<bool, EZKLError> {
    let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..])
        .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize settings: {}", e)))?;

    let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
        .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proof: {}", e)))?;

    let mut reader = BufReader::new(&vk[..]);
    let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
        &mut reader,
        halo2_proofs::SerdeFormat::RawBytesUnchecked,
        circuit_settings.clone(),
    )
    .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;

    let orig_n = 1 << circuit_settings.run_args.logrows;
    let commitment = circuit_settings.run_args.commitment.into();

    let mut reader = BufReader::new(&srs[..]);
    let result = match commitment {
        Commitments::KZG => {
            let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
            let strategy = KZGSingleStrategy::new(params.verifier_params());
            match proof.transcript_type {
                TranscriptType::EVM => verify_proof_circuit::<
                    VerifierSHPLONK<'_, Bn256>,
                    KZGCommitmentScheme<Bn256>,
                    KZGSingleStrategy<_>,
                    _,
                    EvmTranscript<G1Affine, _, _, _>,
                >(&proof, &params, &vk, strategy, orig_n),
                TranscriptType::Poseidon => {
                    verify_proof_circuit::<
                        VerifierSHPLONK<'_, Bn256>,
                        KZGCommitmentScheme<Bn256>,
                        KZGSingleStrategy<_>,
                        _,
                        PoseidonTranscript<NativeLoader, _>,
                    >(&proof, &params, &vk, strategy, orig_n)
                }
            }
        }
        Commitments::IPA => {
            let params: ParamsIPA<_> = get_params(&mut reader)?;
            let strategy = IPASingleStrategy::new(params.verifier_params());
            match proof.transcript_type {
                TranscriptType::EVM => verify_proof_circuit::<
                    VerifierIPA<_>,
                    IPACommitmentScheme<G1Affine>,
                    IPASingleStrategy<_>,
                    _,
                    EvmTranscript<G1Affine, _, _, _>,
                >(&proof, &params, &vk, strategy, orig_n),
                TranscriptType::Poseidon => {
                    verify_proof_circuit::<
                        VerifierIPA<_>,
                        IPACommitmentScheme<G1Affine>,
                        IPASingleStrategy<_>,
                        _,
                        PoseidonTranscript<NativeLoader, _>,
                    >(&proof, &params, &vk, strategy, orig_n)
                }
            }
        }
    };

    match result {
        Ok(_) => Ok(true),
        Err(e) => Err(EZKLError::InternalError(format!(
            "Verification failed: {}",
            e
        ))),
    }
}
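// Usage sketch (hypothetical file names; not part of the original source):
//
//     let ok = verify(
//         std::fs::read("proof.json")?,
//         std::fs::read("vk.key")?,
//         std::fs::read("settings.json")?,
//         std::fs::read("kzg.srs")?,
//     )?;
//     assert!(ok);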

/// Verify aggregate proof with proof json, vk, logrows, srs and commitment type
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn verify_aggr(
    proof: Vec<u8>,
    vk: Vec<u8>,
    logrows: u64,
    srs: Vec<u8>,
    commitment: &str,
) -> Result<bool, EZKLError> {
    let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
        .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proof: {}", e)))?;

    let mut reader = BufReader::new(&vk[..]);
    let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
        &mut reader,
        halo2_proofs::SerdeFormat::RawBytesUnchecked,
        (),
    )
    .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;

    let commit = Commitments::from_str(commitment)
        .map_err(|e| EZKLError::InternalError(format!("Invalid commitment: {}", e)))?;

    let orig_n = 1 << logrows;

    let mut reader = BufReader::new(&srs[..]);
    let result = match commit {
        Commitments::KZG => {
            let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
            let strategy = KZGSingleStrategy::new(params.verifier_params());
            match proof.transcript_type {
                TranscriptType::EVM => verify_proof_circuit::<
                    VerifierSHPLONK<'_, Bn256>,
                    KZGCommitmentScheme<Bn256>,
                    KZGSingleStrategy<_>,
                    _,
                    EvmTranscript<G1Affine, _, _, _>,
                >(&proof, &params, &vk, strategy, orig_n),

                TranscriptType::Poseidon => {
                    verify_proof_circuit::<
                        VerifierSHPLONK<'_, Bn256>,
                        KZGCommitmentScheme<Bn256>,
                        KZGSingleStrategy<_>,
                        _,
                        PoseidonTranscript<NativeLoader, _>,
                    >(&proof, &params, &vk, strategy, orig_n)
                }
            }
        }
        Commitments::IPA => {
            let params: ParamsIPA<_> =
                halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(
                    |e| EZKLError::InternalError(format!("Failed to deserialize params: {}", e)),
                )?;
            let strategy = IPASingleStrategy::new(params.verifier_params());
            match proof.transcript_type {
                TranscriptType::EVM => verify_proof_circuit::<
                    VerifierIPA<_>,
                    IPACommitmentScheme<G1Affine>,
                    IPASingleStrategy<_>,
                    _,
                    EvmTranscript<G1Affine, _, _, _>,
                >(&proof, &params, &vk, strategy, orig_n),
                TranscriptType::Poseidon => {
                    verify_proof_circuit::<
                        VerifierIPA<_>,
                        IPACommitmentScheme<G1Affine>,
                        IPASingleStrategy<_>,
                        _,
                        PoseidonTranscript<NativeLoader, _>,
                    >(&proof, &params, &vk, strategy, orig_n)
                }
            }
        }
    };

    result
        .map(|_| true)
        .map_err(|e| EZKLError::InternalError(format!("{}", e)))
}

/// Prove with witness json, proving key, compiled circuit, and srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn prove(
    witness: Vec<u8>,
    pk: Vec<u8>,
    compiled_circuit: Vec<u8>,
    srs: Vec<u8>,
) -> Result<Vec<u8>, EZKLError> {
    #[cfg(feature = "det-prove")]
    log::set_max_level(log::LevelFilter::Debug);
    #[cfg(not(feature = "det-prove"))]
    log::set_max_level(log::LevelFilter::Info);

    let mut circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..])
        .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?;

    let data: GraphWitness = serde_json::from_slice(&witness[..]).map_err(InnerEZKLError::from)?;

    let mut reader = BufReader::new(&pk[..]);
    let pk = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
        &mut reader,
        halo2_proofs::SerdeFormat::RawBytesUnchecked,
        circuit.settings().clone(),
    )
    .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;

    circuit
        .load_graph_witness(&data)
        .map_err(InnerEZKLError::from)?;
    let public_inputs = circuit
        .prepare_public_inputs(&data)
        .map_err(InnerEZKLError::from)?;
    let proof_split_commits: Option<crate::pfsys::ProofSplitCommit> = data.into();

    let mut reader = BufReader::new(&srs[..]);
    let commitment = circuit.settings().run_args.commitment.into();

    let proof = match commitment {
        Commitments::KZG => {
            let params: ParamsKZG<Bn256> =
                halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(
                    |e| EZKLError::InternalError(format!("Failed to deserialize srs: {}", e)),
                )?;

            create_proof_circuit::<
                KZGCommitmentScheme<Bn256>,
                _,
                ProverSHPLONK<_>,
                VerifierSHPLONK<_>,
                KZGSingleStrategy<_>,
                _,
                EvmTranscript<_, _, _, _>,
                EvmTranscript<_, _, _, _>,
            >(
                circuit,
                vec![public_inputs],
                &params,
                &pk,
                CheckMode::UNSAFE,
                Commitments::KZG,
                TranscriptType::EVM,
                proof_split_commits,
                None,
            )
        }
        Commitments::IPA => {
            let params: ParamsIPA<_> =
                halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(
                    |e| EZKLError::InternalError(format!("Failed to deserialize srs: {}", e)),
                )?;

            create_proof_circuit::<
                IPACommitmentScheme<G1Affine>,
                _,
                ProverIPA<_>,
                VerifierIPA<_>,
                IPASingleStrategy<_>,
                _,
                EvmTranscript<_, _, _, _>,
                EvmTranscript<_, _, _, _>,
            >(
                circuit,
                vec![public_inputs],
                &params,
                &pk,
                CheckMode::UNSAFE,
                Commitments::IPA,
                TranscriptType::EVM,
                proof_split_commits,
                None,
            )
        }
    }
    .map_err(InnerEZKLError::from)?;

    Ok(serde_json::to_vec(&proof).map_err(InnerEZKLError::from)?)
}
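// Usage sketch tying the pipeline together (hypothetical in-scope byte buffers
// `compiled`, `input`, `pk_bytes`, `vk_bytes`, `settings`, `srs`; not part of
// the original source):
//
//     let witness = gen_witness(compiled.clone(), input)?;
//     let proof = prove(witness, pk_bytes, compiled, srs.clone())?;
//     assert!(verify(proof, vk_bytes, settings, srs)?);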

/// Validate the witness json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn witness_validation(witness: Vec<u8>) -> Result<bool, EZKLError> {
    let _: GraphWitness = serde_json::from_slice(&witness[..]).map_err(InnerEZKLError::from)?;

    Ok(true)
}

/// Validate the compiled circuit
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn compiled_circuit_validation(compiled_circuit: Vec<u8>) -> Result<bool, EZKLError> {
    let _: GraphCircuit = bincode::deserialize(&compiled_circuit[..]).map_err(|e| {
        EZKLError::InternalError(format!("Failed to deserialize compiled circuit: {}", e))
    })?;

    Ok(true)
}

/// Validate the input json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn input_validation(input: Vec<u8>) -> Result<bool, EZKLError> {
    let _: crate::graph::input::GraphData =
        serde_json::from_slice(&input[..]).map_err(InnerEZKLError::from)?;

    Ok(true)
}

/// Validate the proof json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn proof_validation(proof: Vec<u8>) -> Result<bool, EZKLError> {
    let _: crate::pfsys::Snark<Fr, G1Affine> =
        serde_json::from_slice(&proof[..]).map_err(InnerEZKLError::from)?;

    Ok(true)
}

/// Validate the verifying key given the settings json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn vk_validation(vk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKLError> {
    let circuit_settings: GraphSettings =
        serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;

    let mut reader = BufReader::new(&vk[..]);
    let _ = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
        &mut reader,
        halo2_proofs::SerdeFormat::RawBytesUnchecked,
        circuit_settings,
    )
    .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;

    Ok(true)
}

/// Validate the proving key given the settings json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn pk_validation(pk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKLError> {
    let circuit_settings: GraphSettings =
        serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;

    let mut reader = BufReader::new(&pk[..]);
    let _ = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
        &mut reader,
        halo2_proofs::SerdeFormat::RawBytesUnchecked,
        circuit_settings,
    )
    .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;

    Ok(true)
}

/// Validate the settings json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn settings_validation(settings: Vec<u8>) -> Result<bool, EZKLError> {
    let _: GraphSettings = serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;

    Ok(true)
}

/// Validate the srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn srs_validation(srs: Vec<u8>) -> Result<bool, EZKLError> {
    let mut reader = BufReader::new(&srs[..]);
    let _: ParamsKZG<Bn256> =
        halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(|e| {
            EZKLError::InternalError(format!("Failed to deserialize params: {}", e))
        })?;

    Ok(true)
}

// HELPER FUNCTIONS

fn get_params<
    Scheme: for<'a> halo2_proofs::poly::commitment::Params<'a, halo2curves::bn256::G1Affine>,
>(
    mut reader: &mut BufReader<&[u8]>,
) -> Result<Scheme, EZKLError> {
    halo2_proofs::poly::commitment::Params::<G1Affine>::read(&mut reader)
        .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize params: {}", e)))
}

/// Creates a [VerifyingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target
pub fn create_vk_lean<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
    circuit: &C,
    params: &'_ Scheme::ParamsProver,
    compress_selectors: bool,
) -> Result<VerifyingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
    C: Circuit<Scheme::Scalar>,
    <Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
{
    // Key generation runs on an empty (witness-free) instance of the circuit
    let empty_circuit = <C as Circuit<F>>::without_witnesses(circuit);

    // Initialize the verifying key
    let vk = keygen_vk_custom(params, &empty_circuit, compress_selectors)?;
    Ok(vk)
}

/// Creates a [ProvingKey] from a [VerifyingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target
pub fn create_pk_lean<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
    vk: VerifyingKey<Scheme::Curve>,
    circuit: &C,
    params: &'_ Scheme::ParamsProver,
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
    C: Circuit<Scheme::Scalar>,
    <Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
{
    // Key generation runs on an empty (witness-free) instance of the circuit
    let empty_circuit = <C as Circuit<F>>::without_witnesses(circuit);

    // Initialize the proving key
    let pk = keygen_pk(params, vk, &empty_circuit)?;
    Ok(pk)
}
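// Sketch of how the lean keygen helpers compose (assumes an in-scope
// `circuit: GraphCircuit` and KZG `params`; illustrative only):
//
//     let vk = create_vk_lean::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
//         &circuit, &params, /* compress_selectors = */ true,
//     )?;
//     let pk =
//         create_pk_lean::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk, &circuit, &params)?;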

398
src/bindings/wasm.rs
Normal file
@@ -0,0 +1,398 @@
use crate::{
    circuit::modules::polycommit::PolyCommitChip,
    fieldutils::{felt_to_integer_rep, integer_rep_to_felt},
    graph::{quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings},
};
use console_error_panic_hook;
use halo2_proofs::{
    plonk::*,
    poly::kzg::commitment::{KZGCommitmentScheme, ParamsKZG},
};
use halo2_solidity_verifier::Evm;
use halo2curves::{
    bn256::{Bn256, Fr, G1Affine},
    ff::PrimeField,
};
use std::str::FromStr;
use wasm_bindgen::prelude::*;
use wasm_bindgen_console_logger::DEFAULT_LOGGER;

use crate::bindings::universal::{
    compiled_circuit_validation, encode_verifier_calldata, gen_pk, gen_vk, gen_witness,
    input_validation, pk_validation, proof_validation, settings_validation, srs_validation,
    verify_aggr, vk_validation, witness_validation, EZKLError as ExternalEZKLError,
};
#[cfg(feature = "web")]
pub use wasm_bindgen_rayon::init_thread_pool;

impl From<ExternalEZKLError> for JsError {
    fn from(e: ExternalEZKLError) -> Self {
        JsError::new(&format!("{}", e))
    }
}

#[wasm_bindgen]
/// Initialize logger for wasm
pub fn init_logger() {
    log::set_logger(&DEFAULT_LOGGER).unwrap();
}

#[wasm_bindgen]
/// Initialize panic hook for wasm
pub fn init_panic_hook() {
    console_error_panic_hook::set_once();
}

/// Wrapper around the halo2 encode call data method
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn encodeVerifierCalldata(
    proof: wasm_bindgen::Clamped<Vec<u8>>,
    vk_address: Option<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
    encode_verifier_calldata(proof.0, vk_address).map_err(JsError::from)
}

/// Converts a felt to a big endian string
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToBigEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
    let felt: Fr = serde_json::from_slice(&array[..])
        .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
    Ok(format!("{:?}", felt))
}

/// Converts a felt to a little endian string
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToLittleEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
    let felt: Fr = serde_json::from_slice(&array[..])
        .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
    let repr = serde_json::to_string(&felt).unwrap();
    let b: String = serde_json::from_str(&repr).unwrap();
    Ok(b)
}

/// Converts a felt to its integer representation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToInt(
    array: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
    let felt: Fr = serde_json::from_slice(&array[..])
        .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
    Ok(wasm_bindgen::Clamped(
        serde_json::to_vec(&felt_to_integer_rep(felt))
            .map_err(|e| JsError::new(&format!("Failed to serialize integer: {}", e)))?,
    ))
}

/// Converts a felt to a floating point number
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToFloat(
    array: wasm_bindgen::Clamped<Vec<u8>>,
    scale: crate::Scale,
) -> Result<f64, JsError> {
    let felt: Fr = serde_json::from_slice(&array[..])
        .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
    let int_rep = felt_to_integer_rep(felt);
    let multiplier = scale_to_multiplier(scale);
    Ok(int_rep as f64 / multiplier)
}

/// Converts a floating point number to a hex string representing a fixed point field element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn floatToFelt(
    mut input: f64,
    scale: crate::Scale,
    input_type: &str,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
    crate::circuit::InputType::roundtrip(
        &crate::circuit::InputType::from_str(input_type)
            .map_err(|e| JsError::new(&format!("{}", e)))?,
        &mut input,
    );
    let int_rep =
        quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
    let felt = integer_rep_to_felt(int_rep);
    let vec = crate::pfsys::field_to_string::<halo2curves::bn256::Fr>(&felt);
    Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err(
        |e| JsError::new(&format!("Failed to serialize a float to felt: {}", e)),
    )?))
}

/// Generate a kzg commitment.
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn kzgCommit(
    message: wasm_bindgen::Clamped<Vec<u8>>,
    vk: wasm_bindgen::Clamped<Vec<u8>>,
    settings: wasm_bindgen::Clamped<Vec<u8>>,
    params_ser: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
    let message: Vec<Fr> = serde_json::from_slice(&message[..])
        .map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?;

    let mut reader = std::io::BufReader::new(&params_ser[..]);
    let params: ParamsKZG<Bn256> =
        halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
            .map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;

    let mut reader = std::io::BufReader::new(&vk[..]);
    let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..])
        .map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?;
    let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
        &mut reader,
        halo2_proofs::SerdeFormat::RawBytes,
        circuit_settings,
    )
    .map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;

    let output = PolyCommitChip::commit::<KZGCommitmentScheme<Bn256>>(
        message,
        (vk.cs().blinding_factors() + 1) as u32,
        &params,
    );

    Ok(wasm_bindgen::Clamped(
        serde_json::to_vec(&output).map_err(|e| JsError::new(&format!("{}", e)))?,
    ))
}

/// Converts a byte buffer to a vector of field elements, 16 bytes per element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn bufferToVecOfFelt(
    buffer: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
    // Convert the buffer to a slice
    let buffer: &[u8] = &buffer;

    // Divide the buffer into chunks of 16 bytes
    let chunks = buffer.chunks_exact(16);

    // Get the remainder
    let remainder = chunks.remainder();

    // Copy the remainder so it can be zero-padded to 16 bytes below
    let mut remainder = remainder.to_vec();

    // Collect chunks into a Vec<[u8; 16]>.
    let chunks: Result<Vec<[u8; 16]>, JsError> = chunks
        .map(|slice| {
            let array: [u8; 16] = slice
                .try_into()
                .map_err(|_| JsError::new("failed to slice input chunks"))?;
            Ok(array)
        })
        .collect();

    let mut chunks = chunks?;

    if !remainder.is_empty() {
        remainder.resize(16, 0);
        // Convert the Vec<u8> to [u8; 16]
        let remainder_array: [u8; 16] = remainder
            .try_into()
            .map_err(|_| JsError::new("failed to slice remainder"))?;
        // append the remainder to the chunks
        chunks.push(remainder_array);
    }

    // Convert each chunk to a field element
    let field_elements: Vec<Fr> = chunks
        .iter()
        .map(|x| PrimeField::from_u128(u8_array_to_u128_le(*x)))
        .collect();

    Ok(wasm_bindgen::Clamped(
        serde_json::to_vec(&field_elements)
            .map_err(|e| JsError::new(&format!("Failed to serialize field elements: {}", e)))?,
    ))
}
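// Why 16-byte chunks (explanatory note, not part of the original source): every
// u128 is strictly smaller than the BN254 scalar-field modulus (~2^254), so
// `Fr::from_u128` on a 16-byte little-endian chunk can never wrap around the
// modulus, and distinct chunks always map to distinct field elements.
//
//     // e.g. the 16-byte buffer [1, 0, .., 0] maps to the single element Fr::from(1u128)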

/// Generate a poseidon hash of the input message in browser
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn poseidonHash(
    message: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
    super::universal::poseidon_hash(message.0)
        .map_err(JsError::from)
        .map(wasm_bindgen::Clamped)
}

/// Generate a witness file from an input json file and a compiled circuit.
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn genWitness(
    compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
    input: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
    gen_witness(compiled_circuit.0, input.0).map_err(JsError::from)
}

/// Generate verifying key in browser
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn genVk(
    compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
    params_ser: wasm_bindgen::Clamped<Vec<u8>>,
    compress_selectors: bool,
) -> Result<Vec<u8>, JsError> {
    gen_vk(compiled_circuit.0, params_ser.0, compress_selectors).map_err(JsError::from)
}

/// Generate proving key in browser
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn genPk(
    vk: wasm_bindgen::Clamped<Vec<u8>>,
    compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
    params_ser: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
    gen_pk(vk.0, compiled_circuit.0, params_ser.0).map_err(JsError::from)
}

/// Verify proof in browser using wasm
#[wasm_bindgen]
pub fn verify(
    proof_js: wasm_bindgen::Clamped<Vec<u8>>,
    vk: wasm_bindgen::Clamped<Vec<u8>>,
    settings: wasm_bindgen::Clamped<Vec<u8>>,
    srs: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
    super::universal::verify(proof_js.0, vk.0, settings.0, srs.0).map_err(JsError::from)
}

/// Verify proof against an EVM verifier contract in browser using wasm
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn verifyEVM(
    proof_js: wasm_bindgen::Clamped<Vec<u8>>,
    bytecode_verifier: Vec<u8>,
    bytecode_vka: Option<Vec<u8>>,
) -> Result<bool, JsError> {
    let mut evm = Evm::unlimited();
    let decoded_verifier = utf8_bytes_to_hex_decoded(&bytecode_verifier)?;
    let (verifier_address, _) = evm.create(decoded_verifier);
    // if bytecode_vka is Some, deploy the verifying-key-artifact contract
    let vk_address = if let Some(bytecode_vka) = bytecode_vka {
        let decoded_vka = utf8_bytes_to_hex_decoded(&bytecode_vka)?;
        let (address, _) = evm.create(decoded_vka);
        Some(address.as_slice().to_vec())
    } else {
        None
    };
    let calldata = encode_verifier_calldata(proof_js.0, vk_address).map_err(JsError::from);
    let output = evm.call(verifier_address, calldata?).1;
    let true_word = [vec![0; 31], vec![1]].concat();
    Ok(output == true_word)
}

/// Verify aggregate proof in browser using wasm
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn verifyAggr(
    proof_js: wasm_bindgen::Clamped<Vec<u8>>,
    vk: wasm_bindgen::Clamped<Vec<u8>>,
    logrows: u64,
    srs: wasm_bindgen::Clamped<Vec<u8>>,
    commitment: &str,
) -> Result<bool, JsError> {
    verify_aggr(proof_js.0, vk.0, logrows, srs.0, commitment).map_err(JsError::from)
}

/// Prove in browser using wasm
#[wasm_bindgen]
pub fn prove(
    witness: wasm_bindgen::Clamped<Vec<u8>>,
    pk: wasm_bindgen::Clamped<Vec<u8>>,
    compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
    srs: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
    super::universal::prove(witness.0, pk.0, compiled_circuit.0, srs.0).map_err(JsError::from)
}

// VALIDATION FUNCTIONS

/// Witness file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn witnessValidation(witness: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
    witness_validation(witness.0).map_err(JsError::from)
}
/// Compiled circuit validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn compiledCircuitValidation(
    compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
    compiled_circuit_validation(compiled_circuit.0).map_err(JsError::from)
}
/// Input file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn inputValidation(input: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
    input_validation(input.0).map_err(JsError::from)
}
/// Proof file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn proofValidation(proof: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
    proof_validation(proof.0).map_err(JsError::from)
}
/// Vk file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn vkValidation(
    vk: wasm_bindgen::Clamped<Vec<u8>>,
    settings: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
    vk_validation(vk.0, settings.0).map_err(JsError::from)
}
/// Pk file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn pkValidation(
    pk: wasm_bindgen::Clamped<Vec<u8>>,
    settings: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
    pk_validation(pk.0, settings.0).map_err(JsError::from)
}
/// Settings file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn settingsValidation(settings: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
    settings_validation(settings.0).map_err(JsError::from)
}
/// Srs file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn srsValidation(srs: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
    srs_validation(srs.0).map_err(JsError::from)
}

// HELPER FUNCTIONS

/// Interprets a 16-byte array as a little-endian u128
pub fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 {
    let mut n: u128 = 0;
    for &b in arr.iter().rev() {
        n <<= 8;
        n |= b as u128;
    }
    n
}
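// Worked example (illustrative, not part of the original source): the least
// significant byte comes first, so
//
//     assert_eq!(u8_array_to_u128_le([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 1);
//     assert_eq!(u8_array_to_u128_le([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 256);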

/// Decodes a UTF-8 hex string (optionally "0x"-prefixed, whitespace-trimmed) into bytes
pub fn utf8_bytes_to_hex_decoded(input: &[u8]) -> Result<Vec<u8>, JsError> {
    let string = std::str::from_utf8(input)?.trim();
    let hex_string = string.strip_prefix("0x").unwrap_or(string);
    hex::decode(hex_string).map_err(JsError::from)
}
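// Example (illustrative, not part of the original source):
//
//     assert_eq!(utf8_bytes_to_hex_decoded(b"0xdeadbeef")?, vec![0xde, 0xad, 0xbe, 0xef]);
//     assert_eq!(utf8_bytes_to_hex_decoded(b"  00ff  ")?, vec![0x00, 0xff]);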
@@ -7,14 +7,18 @@ use halo2_proofs::{
};
use log::debug;
#[cfg(feature = "python-bindings")]
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*, IntoPyObject};
use pyo3::{
    conversion::FromPyObject,
    exceptions::PyValueError,
    IntoPyObject,
    prelude::*,
};
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;

use crate::{
    circuit::{
        chip::einsum::analysis::EinsumAnalysis,
        ops::base::BaseOp,
        table::{Range, RangeCheck, Table},
    },
@@ -25,9 +29,6 @@ use std::{collections::BTreeMap, marker::PhantomData};
use super::{lookup::LookupOp, region::RegionCtx, CircuitError, Op};
use halo2curves::ff::{Field, PrimeField};

/// Einsum operations
pub mod einsum;

#[allow(missing_docs)]
/// An enum for activating the sanity checks we can perform on the accumulated arguments
#[derive(
@@ -270,8 +271,6 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
    pub range_checks: RangeChecks<F>,
    /// [Selector]s for the shuffles
    pub shuffles: Shuffles,
    /// Einsum-specific configuration
    pub einsums: Option<einsum::Einsums<F>>,
    /// Activate sanity checks
    pub check_mode: CheckMode,
    _marker: PhantomData<F>,
@@ -286,22 +285,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
            custom_gates: CustomGates::dummy(col_size, num_inner_cols),
            static_lookups: StaticLookups::dummy(col_size, num_inner_cols),
            dynamic_lookups: DynamicLookups::dummy(col_size, num_inner_cols),
            einsums: Some(einsum::Einsums::<F>::dummy(col_size, num_inner_cols)),
            shuffles: Shuffles::dummy(col_size, num_inner_cols),
            range_checks: RangeChecks::dummy(col_size, num_inner_cols),
            check_mode: CheckMode::SAFE,
            shared_table_inputs: vec![],
            _marker: PhantomData,
        }
    }

    /// Returns a new [BaseConfig] with no inputs, no selectors, no tables, and no Freivalds' argument.
    pub fn dummy_without_freivalds(col_size: usize, num_inner_cols: usize) -> Self {
        Self {
            custom_gates: CustomGates::dummy(col_size, num_inner_cols),
            static_lookups: StaticLookups::dummy(col_size, num_inner_cols),
            dynamic_lookups: DynamicLookups::dummy(col_size, num_inner_cols),
            einsums: None,
            shuffles: Shuffles::dummy(col_size, num_inner_cols),
            range_checks: RangeChecks::dummy(col_size, num_inner_cols),
            check_mode: CheckMode::SAFE,
@@ -436,7 +419,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
            },
            static_lookups: StaticLookups::default(),
            dynamic_lookups: DynamicLookups::default(),
            einsums: None,
            shuffles: Shuffles::default(),
            range_checks: RangeChecks::default(),
            shared_table_inputs: vec![],
@@ -711,27 +693,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
        Ok(())
    }

    /// Configures and creates einsums
    #[allow(clippy::too_many_arguments)]
    pub fn configure_einsums(
        &mut self,
        cs: &mut ConstraintSystem<F>,
        analysis: &EinsumAnalysis,
        num_inner_cols: usize,
        logrows: usize,
    ) -> Result<(), CircuitError>
    where
        F: Field,
    {
        self.einsums = Some(einsum::Einsums::configure_universal(
            cs,
            analysis,
            num_inner_cols,
            logrows,
        ));
        Ok(())
    }

    /// Configures and creates lookup selectors
    #[allow(clippy::too_many_arguments)]
    pub fn configure_shuffles(

@@ -1,210 +0,0 @@
use std::collections::{HashMap, HashSet};

use itertools::Itertools;

use crate::circuit::{
    einsum::reduction_planner::{self, Reduction},
    CircuitError,
};

///
#[derive(Debug, Clone)]
pub struct EinsumAnalysis {
    /// max size of input tensors
    pub max_input_size: usize,
    /// max size of output tensors
    pub max_output_size: usize,
    /// max number of input tensors
    pub max_num_inputs: usize,
    /// max number of output axes
    pub max_num_output_axes: usize,
    /// the sum of the lengths of dot product to compute all the reductions
    pub reduction_length: usize,
}

/// The strategy to use for einsum
#[derive(Debug, Clone)]
pub enum EinsumStrategy {
    /// Use only base ops
    BaseOps,
    /// Use Freivalds' argument
    Freivalds,
}

///
#[derive(Debug, Clone)]
pub struct SingleEquationAnalysis {
    ///
    pub equation: String,
    ///
    pub num_inputs: usize,
    ///
    pub max_input_size: usize,
    ///
    pub output_size: usize,
    ///
    pub num_output_axes: usize,
    ///
    pub output_indices: Vec<char>,
    /// the length of dot product to compute all the reductions
    pub reduction_length: usize,
    /// the strategy to use for einsum
    pub strategy: EinsumStrategy,
}

///
pub fn analyze_einsum_usage(
    equations: &HashMap<(usize, String), HashMap<char, usize>>,
) -> Result<EinsumAnalysis, CircuitError> {
    let mut max_num_inputs = 0;
    let mut max_input_size = 0;
    let mut max_output_size = 0;
    let mut max_num_output_axes = 0;
    let mut reduction_length = 0;

    for ((_, equation), input_axes_to_dim) in equations.iter() {
        let analysis = analyze_single_equation(equation, input_axes_to_dim)?;
        max_input_size = max_input_size.max(analysis.max_input_size);
        max_output_size = max_output_size.max(analysis.output_size);
        max_num_inputs = max_num_inputs.max(analysis.num_inputs);
        max_num_output_axes = max_num_output_axes.max(analysis.num_output_axes);
        reduction_length += analysis.reduction_length;
    }

    Ok(EinsumAnalysis {
        max_input_size,
        max_output_size,
        max_num_inputs,
        max_num_output_axes,
        reduction_length,
    })
}
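// Usage sketch (hypothetical dims; not part of the original source): analyze a
// single matmul-style equation, keyed by (node index, equation string).
//
//     use std::collections::HashMap;
//     let mut equations = HashMap::new();
//     equations.insert(
//         (0usize, "ij,jk->ik".to_string()),
//         HashMap::from([('i', 4), ('j', 8), ('k', 2)]),
//     );
//     let analysis = analyze_einsum_usage(&equations)?;
//     assert_eq!(analysis.max_num_inputs, 2);
//     assert_eq!(analysis.max_input_size, 32); // the 4x8 "ij" input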

///
pub fn analyze_single_equation(
    equation: &str,
    input_axes_to_dim: &HashMap<char, usize>,
) -> Result<SingleEquationAnalysis, CircuitError> {
    // Sanitise equation to remove trivial axes
    let equation = {
        let (inputs_str, output_str) = equation.split_once("->").unwrap();
        let input_equations: Vec<&str> = inputs_str.split(',').collect();

        let inputs: Vec<String> = input_equations
            .iter()
            .map(|input| {
                input
                    .chars()
                    .filter(|char| *input_axes_to_dim.get(char).unwrap() > 1)
                    .collect()
            })
            .collect();

        let output = output_str
            .chars()
            .filter(|c| {
                input_axes_to_dim.get(c).is_some() && *input_axes_to_dim.get(c).unwrap() > 1
            })
            .collect();

        [inputs.join(","), output].join("->")
    };

    let (inputs_eq, output_eq) = equation.split_once("->").unwrap();
    let input_equations: Vec<&str> = inputs_eq.split(',').collect();

    let max_input_size = input_equations
        .iter()
        .map(|eqn| {
            eqn.chars()
                .map(|c| input_axes_to_dim.get(&c).unwrap())
                .product()
        })
        .max()
        .unwrap();

    let output_indices: Vec<char> = output_eq.chars().collect();
    let output_dims = output_indices
        .iter()
        .map(|c| input_axes_to_dim.get(&c).unwrap());
    let output_size = output_dims.clone().product();

    let output_reduction_length = {
        let mut output_dims = output_dims.rev().cloned().collect_vec();
        let mut total_length = 0;
        for _ in 0..output_dims.len() {
            let dot_product_len = output_dims.remove(0);
            let num_dot_products: usize = output_dims.iter().product();
            total_length += dot_product_len * num_dot_products;
        }
        total_length
    };

    let input_reductions_length = {
        let input_reductions = reduction_planner::input_reductions(&equation)?;
        input_reductions
            .into_iter()
            .map(|reduction| {
                let (_, output_expr) = reduction.expression().split_once("->").unwrap();
                let num_inputs = reduction.input_indices().len();
                let dot_product_len = match reduction {
                    Reduction::RLC { axis, .. } => *input_axes_to_dim.get(&axis).unwrap(),
                    Reduction::Contraction { axis, .. } => *axis
                        .and_then(|axis| input_axes_to_dim.get(&axis))
                        .unwrap_or(&1),
                };
                let num_dot_products: usize = output_expr
                    .chars()
                    .map(|c| input_axes_to_dim.get(&c).unwrap())
                    .product();
                // since `multi_dot` does pairwise mult between input pairs and final summation
                if num_inputs <= 2 {
                    num_dot_products * dot_product_len
                } else {
                    num_dot_products * (dot_product_len * num_inputs)
                }
            })
            .sum::<usize>()
    };

    let dispatch_to_einsum_with_base_ops = {
        let mut seen = HashSet::new();
        let mut common_indices_to_inputs = vec![];
        for input in input_equations.iter() {
            for c in input.chars() {
                if !seen.contains(&c) {
                    seen.insert(c);
                } else {
                    common_indices_to_inputs.push(c);
                }
            }
        }
        let non_common_indices = input_axes_to_dim
            .keys()
            .filter(|&x| {
                !common_indices_to_inputs.contains(x)
                    && input_axes_to_dim.get(x).cloned().unwrap() > 1
            })
            .collect::<Vec<_>>();
        !(output_indices.len() > 0
            && common_indices_to_inputs.len() > 0
            && non_common_indices.len() > 1)
    };

    let strategy = if dispatch_to_einsum_with_base_ops {
        EinsumStrategy::BaseOps
    } else {
        EinsumStrategy::Freivalds
    };

    Ok(SingleEquationAnalysis {
        output_size,
        max_input_size,
        equation: equation.to_string(),
        num_inputs: input_equations.len(),
        num_output_axes: output_indices.len(),
        output_indices,
        reduction_length: output_reduction_length + input_reductions_length,
        strategy,
    })
}
@@ -1,54 +0,0 @@
use std::{collections::HashMap, marker::PhantomData};

use halo2_proofs::circuit::Value;
use halo2curves::ff::PrimeField;

use crate::{
    circuit::CircuitError,
    tensor::{Tensor, TensorError, TensorType},
};

/// Circuit parameter for a single einsum equation
#[derive(Clone, Debug, Default)]
pub struct SingleEinsumParams<F: PrimeField + TensorType + PartialOrd> {
    ///
    pub equation: String,
    /// Map from input axes to dimensions
    pub input_axes_to_dims: HashMap<char, usize>,
    _marker: PhantomData<F>,
}

impl<F: PrimeField + TensorType + PartialOrd> SingleEinsumParams<F> {
    ///
    pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
        let mut eq = equation.split("->");
        let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
        let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();

        // Check that the number of inputs matches the number of inputs in the equation
        if inputs.len() != inputs_eq.len() {
            return Err(TensorError::DimMismatch("einsum".to_string()).into());
        }

        let mut input_axes_to_dims = HashMap::new();
        for (i, input) in inputs.iter().enumerate() {
            for j in 0..inputs_eq[i].len() {
                let c = inputs_eq[i]
                    .chars()
                    .nth(j)
                    .ok_or(CircuitError::InvalidEinsum)?;
                if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
                    e.insert(input.dims()[j]);
                } else if input_axes_to_dims[&c] != input.dims()[j] {
                    return Err(TensorError::DimMismatch("einsum".to_string()).into());
                }
            }
        }

        Ok(Self {
            equation: equation.to_owned(),
            input_axes_to_dims,
            _marker: PhantomData,
        })
    }
}
@@ -1,359 +0,0 @@
use halo2curves::ff::PrimeField;
use log::{error, trace};

use crate::{
    circuit::{base::BaseOp, einsum::BaseOpInfo, region::RegionCtx, CheckMode, CircuitError},
    tensor::{
        get_broadcasted_shape,
        ops::{accumulated, add, mult, sub},
        TensorError, TensorType, ValTensor, ValType,
    },
};

use super::ContractionConfig;

/// Pairwise (elementwise) op layout
pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &ContractionConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[&ValTensor<F>; 2],
    op: BaseOp,
    phases: &[usize; 2],
) -> Result<ValTensor<F>, CircuitError> {
    let (mut lhs, mut rhs) = if phases[0] <= phases[1] {
        (values[0].clone(), values[1].clone())
    } else {
        (values[1].clone(), values[0].clone())
    };

    let broadcasted_shape = get_broadcasted_shape(lhs.dims(), rhs.dims())?;

    lhs.expand(&broadcasted_shape)?;
    rhs.expand(&broadcasted_shape)?;

    if lhs.len() != rhs.len() {
        return Err(CircuitError::DimMismatch(format!(
            "pairwise {} layout",
            op.as_str()
        )));
    }

    region.flush_einsum()?;

    let input_vars = config.get_input_vars(phases.as_slice().into());
    let output_var = config.get_output_var(phases.as_slice().into());

    let inputs = [lhs, rhs]
        .iter()
        .zip(input_vars)
        .map(|(val, var)| {
            let res = region.assign_einsum(var, val)?;
            Ok(res.get_inner()?)
        })
        .collect::<Result<Vec<_>, CircuitError>>()?;

    // Now we can assign the dot product
    // time the calc
    let op_result = match op {
        BaseOp::Add => add(&inputs),
        BaseOp::Sub => sub(&inputs),
        BaseOp::Mult => mult(&inputs),
        _ => return Err(CircuitError::UnsupportedOp),
    }
    .map_err(|e| {
        error!("{}", e);
        halo2_proofs::plonk::Error::Synthesis
    })?;

    let assigned_len = op_result.len();
    let mut output = region.assign_einsum(output_var, &op_result.into())?;

    // Enable the selectors
    if !region.is_dummy() {
        (0..assigned_len)
            .map(|i| {
                let (x, y, z) = output_var.cartesian_coord(region.einsum_col_coord() + i);
                let op_info = BaseOpInfo {
                    op_kind: op.clone(),
                    input_phases: phases.as_slice().into(),
                };
                let selector = config.selectors.get(&(op_info, x, y));

                region.enable(selector, z)?;

                Ok(())
            })
            .collect::<Result<Vec<_>, CircuitError>>()?;
    }
    region.increment_einsum_col_coord(assigned_len);

    output.reshape(&broadcasted_shape)?;

    Ok(output)
}

pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &ContractionConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[&ValTensor<F>; 1],
    phase: usize,
    check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
    if values[0].len() == 1 {
        return Ok(values[0].clone());
    }
    assert!(phase == 0 || phase == 1);

    region.flush_einsum()?;
    let mut input = values[0].clone();

    let block_width = config.block_width();

    let assigned_len: usize;
    let input = {
        input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
        let var = config.get_input_vars([phase].as_slice().into())[0];
        let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &input)?;
        assigned_len = len;
        res.get_inner()?
    };

    // Now we can assign the dot product
    let accumulated_sum = accumulated::sum(&input, block_width)?;

    let output_var = config.get_output_var([phase].as_slice().into());
    let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
        output_var,
        &accumulated_sum.into(),
        check_mode,
    )?;

    // enable the selectors
    if !region.is_dummy() {
        for i in 0..output_assigned_len {
            let (x, _, z) = output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
            // skip over duplicates at start of column
            if z == 0 && i > 0 {
                continue;
            }
            let selector = if i == 0 {
                let op_info = BaseOpInfo {
                    op_kind: BaseOp::SumInit,
                    input_phases: [phase].as_slice().into(),
                };
                config.selectors.get(&(op_info, x, 0))
            } else {
                let op_info = BaseOpInfo {
                    op_kind: BaseOp::Sum,
                    input_phases: [phase].as_slice().into(),
                };
                config.selectors.get(&(op_info, x, 0))
            };

            region.enable(selector, z)?;
        }
    }

    let last_elem = output.last()?;

    region.increment_einsum_col_coord(assigned_len);

    // last element is the result
    Ok(last_elem)
}

pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &ContractionConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[&ValTensor<F>; 1],
    phase: usize,
    check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
    assert!(phase == 0 || phase == 1);
    region.flush_einsum()?;
    let block_width = config.block_width();
    let assigned_len: usize;
    let input = {
        let mut input = values[0].clone();
        input.pad_to_zero_rem(block_width, ValType::Constant(F::ONE))?;
        let var = config.get_input_vars([phase].as_slice().into())[0];
        let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &input)?;
        assigned_len = len;
        res.get_inner()?
    };

    // Now we can assign the dot product
    let accumulated_prod = accumulated::prod(&input, block_width)?;

    let output_var = config.get_output_var([phase].as_slice().into());
    let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
        output_var,
        &accumulated_prod.into(),
        check_mode,
    )?;

    // enable the selectors
    if !region.is_dummy() {
        (0..output_assigned_len)
            .map(|i| {
                let (x, _, z) =
                    output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
                // skip over duplicates at start of column
                if z == 0 && i > 0 {
                    return Ok(());
                }
                let selector = if i == 0 {
                    let op_info = BaseOpInfo {
                        op_kind: BaseOp::CumProdInit,
                        input_phases: [phase].as_slice().into(),
                    };
                    config.selectors.get(&(op_info, x, 0))
                } else {
                    let op_info = BaseOpInfo {
                        op_kind: BaseOp::CumProd,
                        input_phases: [phase].as_slice().into(),
                    };
                    config.selectors.get(&(op_info, x, 0))
                };

                region.enable(selector, z)?;
                Ok(())
            })
            .collect::<Result<Vec<_>, CircuitError>>()?;
    }

    let last_elem = output.last()?;

    region.increment_einsum_col_coord(assigned_len);

    // last element is the result
    Ok(last_elem)
}

pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &ContractionConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[&ValTensor<F>; 2],
    phases: &[usize; 2],
    check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
    if values[0].len() != values[1].len() {
        return Err(TensorError::DimMismatch("dot".to_string()).into());
    }

    region.flush_einsum()?;
    // time this entire function run
    let global_start = instant::Instant::now();

    let mut values = if phases[0] <= phases[1] {
        [values[0].clone(), values[1].clone()]
    } else {
        [values[1].clone(), values[0].clone()]
    };
    let vars = config.get_input_vars(phases.as_slice().into());

    let mut inputs = vec![];
    let block_width = config.block_width();

    let mut assigned_len = 0;
    for (val, var) in values.iter_mut().zip(vars) {
        val.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
        let inp = {
            let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &val)?;
            assigned_len = len;
            res.get_inner()?
        };
        inputs.push(inp);
    }

    // Now we can assign the dot product
    // time this step
    let accumulated_dot = accumulated::dot(&inputs[0], &inputs[1], block_width)?;
    let output_var = config.get_output_var(phases.as_slice().into());
    let (output, output_assigned_len) = region
        .assign_einsum_with_duplication_constrained(output_var, &accumulated_dot.into(), check_mode)
        .expect("failed to assign einsum with duplication constrained");

    // enable the selectors
    if !region.is_dummy() {
        (0..output_assigned_len)
            .map(|i| {
                let (x, _, z) =
                    output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
                // hop over duplicates at start of column
                if z == 0 && i > 0 {
                    return Ok(());
                }
                let selector = if i == 0 {
                    let op_info = BaseOpInfo {
                        op_kind: BaseOp::DotInit,
                        input_phases: phases.as_slice().into(),
                    };
                    config.selectors.get(&(op_info, x, 0))
                } else {
                    let op_info = BaseOpInfo {
                        op_kind: BaseOp::Dot,
                        input_phases: phases.as_slice().into(),
                    };
                    config.selectors.get(&(op_info, x, 0))
                };
                region.enable(selector, z)?;

                Ok(())
            })
            .collect::<Result<Vec<_>, CircuitError>>()?;
    }

    let last_elem = output.last()?;

    region.increment_einsum_col_coord(assigned_len);

    let elapsed = global_start.elapsed();
    trace!("dot layout took: {:?}, row {}", elapsed, region.row());
    trace!("----------------------------");
    Ok(last_elem)
}

/// Dot product of more than two tensors
pub fn multi_dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &ContractionConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[&ValTensor<F>],
    phases: &[usize],
    check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
    assert!(phases.iter().all(|phase| *phase == 0 || *phase == 1));
    if !values.iter().all(|value| value.len() == values[0].len()) {
        return Err(TensorError::DimMismatch("dot".to_string()).into());
    }
    // time this entire function run
    let global_start = instant::Instant::now();

    let values: Vec<ValTensor<F>> = values.iter().copied().cloned().collect();
    // do pairwise dot product between intermediate tensor and the next tensor
    let (intermediate, output_phase) = values
        .into_iter()
        .zip(phases.iter().cloned())
        .reduce(|(intermediate, intermediate_phase), (input, phase)| {
            (
                pairwise(
                    config,
                    region,
                    &[&intermediate, &input],
                    BaseOp::Mult,
                    &[intermediate_phase, phase],
                )
                .unwrap(),
                std::cmp::max(intermediate_phase, phase),
            )
        })
        .unwrap();

    let accumulated_dot = sum(config, region, &[&intermediate], output_phase, check_mode)?;
    let last_elem = accumulated_dot.last()?;

    let elapsed = global_start.elapsed();
    trace!("multi_dot layout took: {:?}, row {}", elapsed, region.row());
    trace!("----------------------------");
    Ok(last_elem)
}
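// Conceptual note (illustrative, not part of the original source): `multi_dot`
// reduces N equal-length tensors by repeated pairwise Hadamard products
// followed by a single sum, e.g. for three vectors a, b, c:
//
//     multi_dot([a, b, c]) = sum_i(a_i * b_i * c_i)
//
// which is why, for more than two inputs, the reduction length in the einsum
// analysis above scales with `dot_product_len * num_inputs`.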
|
||||
@@ -1,867 +0,0 @@
use crate::circuit::base::BaseOp;
use crate::circuit::chip::einsum::analysis::{analyze_single_equation, EinsumAnalysis};
use crate::circuit::einsum::layouts::{pairwise, sum};
use crate::circuit::einsum::reduction_planner::Reduction;
use crate::circuit::layouts::einsum_with_base_ops;
use crate::circuit::region::RegionCtx;
use crate::circuit::{BaseConfig, CheckMode, CircuitError};
use crate::tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor};
use halo2_proofs::circuit::Value;
use halo2_proofs::plonk::{
    Challenge, ConstraintSystem, Constraints, Expression, FirstPhase, Selector,
};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use layouts::{dot, multi_dot, prod};
use std::collections::{BTreeMap, HashMap};
use std::marker::PhantomData;

/// Analysis of einsum equations
pub mod analysis;
/// Circuit parameters for einsums
pub mod circuit_params;
mod layouts;
mod reduction_planner;

/// The maximum number of challenges
pub const NUM_MAX_EINSUM_CHALLENGES: usize = 10;

/// A struct representing reductions for the einsums
#[derive(Clone, Debug, Default)]
pub struct Einsums<F: PrimeField + TensorType + PartialOrd> {
    /// custom gate to constrain tensor contractions
    contraction_gate: ContractionConfig<F>,
    /// custom gate to constrain random linear combinations used by Freivalds' argument
    rlc_gates: Vec<RLCConfig<F>>,
}

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
    /// Create a dummy configuration for layout-only passes
    pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
        let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
        let dummy_contraction_gate = ContractionConfig {
            inputs: [
                [dummy_var.clone(), dummy_var.clone()],
                [dummy_var.clone(), dummy_var.clone()],
            ],
            outputs: [dummy_var.clone(), dummy_var.clone()],
            selectors: BTreeMap::default(),
            _marker: PhantomData,
        };
        Self {
            contraction_gate: dummy_contraction_gate,
            rlc_gates: (0..NUM_MAX_EINSUM_CHALLENGES)
                .map(|_| RLCConfig::dummy(&dummy_var))
                .collect(),
        }
    }

    /// Configure the columns based on universal Einsum analysis
    pub fn configure_universal(
        meta: &mut ConstraintSystem<F>,
        analysis: &EinsumAnalysis,
        num_inner_cols: usize,
        logrows: usize,
    ) -> Self {
        let capacity = analysis.reduction_length;
        let inputs: [VarTensor; 4] = [
            VarTensor::new_advice(meta, logrows, num_inner_cols, capacity),
            VarTensor::new_advice(meta, logrows, num_inner_cols, capacity),
            VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
            VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
        ];
        let outputs = [
            VarTensor::new_advice(meta, logrows, num_inner_cols, capacity),
            VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
        ];
        let contraction_gate = ContractionConfig::new(
            meta,
            &[&[&inputs[0], &inputs[1]], &[&inputs[2], &inputs[3]]],
            &[&outputs[0], &outputs[1]],
        );

        let mut rlc_gates = vec![];
        for _ in 0..analysis.max_num_output_axes {
            let rlc_gate =
                RLCConfig::new(meta, &[inputs[0].clone(), inputs[2].clone()], &outputs[1]);
            rlc_gates.push(rlc_gate);
        }

        Self {
            contraction_gate,
            rlc_gates,
        }
    }

    /// In dummy layout phase, calling this function will return an error
    pub fn challenges(&self) -> Result<Vec<Challenge>, CircuitError> {
        self.rlc_gates
            .iter()
            .map(|gate| gate.challenge.ok_or(CircuitError::ChallengeNotSet))
            .collect::<Result<Vec<_>, _>>()
    }

    /// Assign an einsum computation to the region
    pub fn assign_einsum(
        &self,
        base_config: &BaseConfig<F>,
        region: &mut RegionCtx<F>,
        input_tensors: &[&ValTensor<F>],
        output_tensor: &ValTensor<F>,
        equation: &str,
        check_mode: &CheckMode,
    ) -> Result<(), CircuitError> {
        region.set_num_einsum_inner_cols(self.contraction_gate.block_width());

        let (input_exprs, _) = equation.split_once("->").unwrap();
        let input_exprs = input_exprs.split(",").collect_vec();
        assert_eq!(input_exprs.len(), input_tensors.len());

        let mut input_tensors = input_tensors.iter().copied().cloned().collect_vec();
        let mut output_tensor = output_tensor.clone();

        let mut input_axes_to_dim: HashMap<char, usize> = HashMap::new();
        input_exprs
            .iter()
            .zip(input_tensors.iter())
            .for_each(|(indices, tensor)| {
                indices.chars().zip(tensor.dims()).for_each(|(index, dim)| {
                    if let std::collections::hash_map::Entry::Vacant(e) =
                        input_axes_to_dim.entry(index)
                    {
                        e.insert(*dim);
                    }
                });
            });

        let equation_analysis = analyze_single_equation(equation, &input_axes_to_dim)?;
        let equation = equation_analysis.equation;

        // Remove trivial axes from tensors
        input_tensors
            .iter_mut()
            .map(|tensor| tensor.remove_trivial_axes())
            .collect::<Result<Vec<_>, TensorError>>()?;
        output_tensor.remove_trivial_axes()?;

        if matches!(
            equation_analysis.strategy,
            analysis::EinsumStrategy::BaseOps
        ) {
            let _ = einsum_with_base_ops(
                base_config,
                region,
                &input_tensors.iter().collect_vec(),
                &equation,
            )?;
            return Ok(());
        }

        let output_shape = equation_analysis
            .output_indices
            .iter()
            .map(|c| input_axes_to_dim.get(c).copied().unwrap())
            .collect_vec();
        let squashed_output =
            self.assign_output(region, &output_tensor, output_shape, check_mode)?;

        // reorder the reduction of input tensors and reduce
        let reordered_input_reductions = reduction_planner::input_reductions(&equation).unwrap();
        let mut tensors = input_tensors;
        let mut reduced_input_phase = 0;

        for reduction in reordered_input_reductions.iter() {
            let (input_expr, output_expr) = reduction.expression().split_once("->").unwrap();
            let input_exprs = input_expr.split(",").collect_vec();

            let remaining_axes = output_expr.chars().collect_vec();
            let mut remaining_axes_indices = remaining_axes
                .iter()
                .map(|c| 0..input_axes_to_dim[c])
                .multi_cartesian_product()
                .collect_vec();

            // Dummy value to ensure the for loop runs at least once
            if remaining_axes.is_empty() {
                remaining_axes_indices.push(vec![]);
            }

            let input_tensors = reduction
                .input_indices()
                .iter()
                .map(|idx| tensors[*idx].clone())
                .collect_vec();

            let mut flattened_input_tensors: Vec<Vec<ValTensor<F>>> =
                vec![vec![]; input_tensors.len()];
            for remaining_axes_indices in remaining_axes_indices {
                // corresponds to 1 running sum of input tensors
                for (i, (input_tensor, input_expr)) in
                    input_tensors.iter().zip(input_exprs.iter()).enumerate()
                {
                    let mut sliced_dim = vec![];
                    input_expr.chars().for_each(|axis| {
                        if let Some(pos) = remaining_axes.iter().position(|c| *c == axis) {
                            sliced_dim
                                .push(remaining_axes_indices[pos]..remaining_axes_indices[pos] + 1);
                        } else {
                            // common axis
                            sliced_dim.push(0..input_axes_to_dim[&axis]);
                        }
                    });
                    let mut sliced_input_tensor = input_tensor.get_slice(&sliced_dim)?;
                    sliced_input_tensor.flatten();
                    flattened_input_tensors[i].push(sliced_input_tensor);
                }
            }
            let flattened_input_tensors = flattened_input_tensors
                .into_iter()
                .map(|tensors| {
                    ValTensor::from(
                        tensors
                            .into_iter()
                            .flat_map(|t| t.get_inner_tensor().unwrap().clone().into_iter())
                            .collect_vec(),
                    )
                })
                .collect_vec();

            let output_dims = output_expr
                .chars()
                .map(|c| input_axes_to_dim[&c])
                .collect_vec();

            let contracted_output = match reduction {
                Reduction::RLC {
                    axis,
                    input_phase,
                    challenge_index,
                    ..
                } => {
                    assert_eq!(flattened_input_tensors.len(), 1);
                    let rlc_len = input_axes_to_dim[axis];
                    let mut result = self.rlc_gates[*challenge_index].assign_rlc(
                        region,
                        &flattened_input_tensors[0],
                        region.challenges()[*challenge_index],
                        rlc_len,
                        *input_phase,
                        check_mode,
                    )?;
                    result.reshape(&output_dims)?;
                    result
                }
                Reduction::Contraction {
                    axis, input_phases, ..
                } => match axis {
                    Some(axis) => {
                        let dot_product_len = input_axes_to_dim[axis];
                        assign_input_contraction(
                            &self.contraction_gate,
                            region,
                            flattened_input_tensors,
                            dot_product_len,
                            &output_dims,
                            input_phases,
                            check_mode,
                        )?
                    }
                    None => {
                        let mut result = assign_pairwise_mult(
                            &self.contraction_gate,
                            region,
                            flattened_input_tensors,
                            input_phases,
                        )?;
                        result.reshape(&output_dims)?;
                        result
                    }
                },
            };
            tensors.push(contracted_output);
            reduced_input_phase = reduction.output_phase();
        }
        tensors.retain(|tensor| tensor.is_singleton());

        let scalars: ValTensor<F> = tensors
            .into_iter()
            .map(|t| t.get_inner_tensor().unwrap().get_scalar())
            .collect_vec()
            .into();
        let squashed_input = prod(
            &self.contraction_gate,
            region,
            &[&scalars],
            reduced_input_phase,
            check_mode,
        )?;

        region.constrain_equal(&squashed_input, &squashed_output)
    }

    fn assign_output(
        &self,
        region: &mut RegionCtx<F>,
        output: &ValTensor<F>,
        output_shape: Vec<usize>,
        check_mode: &CheckMode,
    ) -> Result<ValTensor<F>, CircuitError> {
        let mut intermediate_values = output.clone();

        let challenges = region
            .challenges()
            .iter()
            .take(output_shape.len())
            .copied()
            .collect_vec();

        // Loop over the output axes
        for (idx, (rlc_config, challenge)) in self
            .rlc_gates
            .iter()
            .take(output_shape.len())
            .zip(challenges.iter())
            .rev()
            .enumerate()
        {
            let rlc_len = output_shape[output_shape.len() - idx - 1];
            intermediate_values.flatten();
            let phase = if idx > 0 { 1 } else { 0 };
            intermediate_values = rlc_config.assign_rlc(
                region,
                &intermediate_values,
                *challenge,
                rlc_len,
                phase,
                check_mode,
            )?;
        }

        let phase = if !challenges.is_empty() { 1 } else { 0 };
        let output_var = self
            .contraction_gate
            .get_output_var([phase].as_slice().into());
        let res = region.assign_einsum(output_var, &intermediate_values)?;
        region.increment_einsum_col_coord(1);

        Ok(res)
    }
}

fn assign_pairwise_mult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &ContractionConfig<F>,
    region: &mut RegionCtx<F>,
    flattened_tensors: Vec<ValTensor<F>>,
    input_phases: &[usize],
) -> Result<ValTensor<F>, CircuitError> {
    assert_eq!(flattened_tensors.len(), input_phases.len());
    let (result, _) = flattened_tensors
        .into_iter()
        .zip(input_phases.iter().cloned())
        .reduce(|(acc, acc_phase), (input, phase)| {
            (
                pairwise(
                    config,
                    region,
                    &[&acc, &input],
                    BaseOp::Mult,
                    &[acc_phase, phase],
                )
                .unwrap(),
                std::cmp::max(acc_phase, phase),
            )
        })
        .unwrap();
    Ok(result)
}

fn assign_input_contraction<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &ContractionConfig<F>,
    region: &mut RegionCtx<F>,
    flattened_tensors: Vec<ValTensor<F>>,
    dot_product_len: usize,
    output_shape: &[usize],
    input_phases: &[usize],
    check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
    assert_eq!(flattened_tensors.len(), input_phases.len());
    let num_dot_products = output_shape.iter().product();
    let mut dot_product_results = vec![];
    for chunk_idx in 0..num_dot_products {
        let start = chunk_idx * dot_product_len;
        let tensors: Vec<_> = flattened_tensors
            .iter()
            .map(|tensor| tensor.get_slice(&[start..(start + dot_product_len)]))
            .collect::<Result<Vec<_>, TensorError>>()?;
        let result = if tensors.len() == 1 {
            sum(config, region, &[&tensors[0]], input_phases[0], check_mode)?
        } else if tensors.len() == 2 {
            dot(
                config,
                region,
                &[&tensors[0], &tensors[1]],
                &[input_phases[0], input_phases[1]],
                check_mode,
            )?
        } else {
            multi_dot(
                config,
                region,
                tensors.iter().collect_vec().as_slice(),
                input_phases,
                check_mode,
            )?
        };
        dot_product_results.push(result.get_inner_tensor()?.get_scalar());
    }
    let mut tensor = ValTensor::from(dot_product_results);
    tensor.reshape(output_shape)?;
    Ok(tensor)
}

#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, Hash)]
enum InputPhases {
    FirstPhase,
    SecondPhase,
    BothFirstPhase,  // [0, 0]
    Mixed,           // [0, 1] or [1, 0]
    BothSecondPhase, // [1, 1]
}

impl From<&[usize]> for InputPhases {
    fn from(phases: &[usize]) -> Self {
        match phases {
            [0] => Self::FirstPhase,
            [1] => Self::SecondPhase,
            [0, 0] => Self::BothFirstPhase,
            [0, 1] | [1, 0] => Self::Mixed,
            [1, 1] => Self::BothSecondPhase,
            _ => panic!("Invalid phase combination"),
        }
    }
}

#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
struct BaseOpInfo {
    pub op_kind: BaseOp,
    pub input_phases: InputPhases,
}

/// `ContractionConfig` is the custom gate to constrain tensor contractions
#[derive(Clone, Debug, Default)]
struct ContractionConfig<F: PrimeField + TensorType + PartialOrd> {
    // [[first phase, first phase], [second phase, second phase]]
    inputs: [[VarTensor; 2]; 2],
    // [first phase, second phase]
    outputs: [VarTensor; 2],
    // (BaseOpInfo, block index, inner column index) -> selector
    selectors: BTreeMap<(BaseOpInfo, usize, usize), Selector>,
    _marker: PhantomData<F>,
}

impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
    fn get_input_vars(&self, input_phases: InputPhases) -> Vec<&VarTensor> {
        match input_phases {
            InputPhases::FirstPhase => vec![&self.inputs[0][0]],
            InputPhases::SecondPhase => vec![&self.inputs[1][0]],
            InputPhases::BothFirstPhase => vec![&self.inputs[0][0], &self.inputs[0][1]],
            InputPhases::Mixed => vec![&self.inputs[0][0], &self.inputs[1][0]],
            InputPhases::BothSecondPhase => vec![&self.inputs[1][0], &self.inputs[1][1]],
        }
    }

    fn get_output_var(&self, input_phases: InputPhases) -> &VarTensor {
        match input_phases {
            InputPhases::FirstPhase => &self.outputs[0],
            InputPhases::SecondPhase => &self.outputs[1],
            InputPhases::BothFirstPhase => &self.outputs[0],
            InputPhases::Mixed => &self.outputs[1],
            InputPhases::BothSecondPhase => &self.outputs[1],
        }
    }

    fn block_width(&self) -> usize {
        self.outputs[0].num_inner_cols()
    }

    fn new(
        meta: &mut ConstraintSystem<F>,
        inputs: &[&[&VarTensor; 2]; 2],
        outputs: &[&VarTensor; 2],
    ) -> Self {
        let mut selectors = BTreeMap::new();
        let num_blocks = outputs[0].num_blocks();
        let block_width = outputs[0].num_inner_cols();
        for input_phases in [
            InputPhases::BothFirstPhase,
            InputPhases::Mixed,
            InputPhases::BothSecondPhase,
        ] {
            // one selector per (op, block, inner column) for the pairwise mult gate
            for i in 0..num_blocks {
                for j in 0..block_width {
                    selectors.insert(
                        (
                            BaseOpInfo {
                                op_kind: BaseOp::Mult,
                                input_phases,
                            },
                            i,
                            j,
                        ),
                        meta.selector(),
                    );
                }
            }
            // the accumulating dot-product gates act on whole blocks, so one
            // selector per block suffices
            for i in 0..num_blocks {
                selectors.insert(
                    (
                        BaseOpInfo {
                            op_kind: BaseOp::DotInit,
                            input_phases,
                        },
                        i,
                        0,
                    ),
                    meta.selector(),
                );
                selectors.insert(
                    (
                        BaseOpInfo {
                            op_kind: BaseOp::Dot,
                            input_phases,
                        },
                        i,
                        0,
                    ),
                    meta.selector(),
                );
            }
        }

        for input_phases in [InputPhases::FirstPhase, InputPhases::SecondPhase] {
            for i in 0..num_blocks {
                selectors.insert(
                    (
                        BaseOpInfo {
                            op_kind: BaseOp::SumInit,
                            input_phases,
                        },
                        i,
                        0,
                    ),
                    meta.selector(),
                );
                selectors.insert(
                    (
                        BaseOpInfo {
                            op_kind: BaseOp::Sum,
                            input_phases,
                        },
                        i,
                        0,
                    ),
                    meta.selector(),
                );
                selectors.insert(
                    (
                        BaseOpInfo {
                            op_kind: BaseOp::CumProdInit,
                            input_phases,
                        },
                        i,
                        0,
                    ),
                    meta.selector(),
                );
                selectors.insert(
                    (
                        BaseOpInfo {
                            op_kind: BaseOp::CumProd,
                            input_phases,
                        },
                        i,
                        0,
                    ),
                    meta.selector(),
                );
            }
        }
        for ((base_op, block_idx, inner_col_idx), selector) in selectors.iter() {
            let inputs = match base_op.input_phases {
                InputPhases::FirstPhase => vec![inputs[0][0]],
                InputPhases::SecondPhase => vec![inputs[1][0]],
                InputPhases::BothFirstPhase => vec![inputs[0][0], inputs[0][1]],
                InputPhases::Mixed => vec![inputs[0][0], inputs[1][0]],
                InputPhases::BothSecondPhase => vec![inputs[1][0], inputs[1][1]],
            };
            let output = match base_op.input_phases {
                InputPhases::FirstPhase => outputs[0],
                InputPhases::SecondPhase => outputs[1],
                InputPhases::BothFirstPhase => outputs[0],
                InputPhases::Mixed => outputs[1],
                InputPhases::BothSecondPhase => outputs[1],
            };
            assert_eq!(inputs.len(), base_op.op_kind.num_inputs());
            match base_op.op_kind {
                BaseOp::Mult => {
                    meta.create_gate(base_op.op_kind.as_str(), |meta| {
                        let selector = meta.query_selector(*selector);

                        let zero = Expression::<F>::Constant(F::ZERO);
                        let mut qis = vec![zero; 2];
                        for (q_i, input) in qis.iter_mut().zip(inputs) {
                            *q_i = input
                                .query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
                                .expect("contraction config: input query failed")[0]
                                .clone()
                        }
                        // Get output expressions for each input channel
                        let (rotation_offset, rng) = base_op.op_kind.query_offset_rng();
                        let constraints = {
                            let expected_output: Tensor<Expression<F>> = output
                                .query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
                                .expect("contraction config: output query failed");

                            let res = base_op.op_kind.nonaccum_f((qis[0].clone(), qis[1].clone()));
                            vec![expected_output[base_op.op_kind.constraint_idx()].clone() - res]
                        };
                        Constraints::with_selector(selector, constraints)
                    });
                }
                _ => {
                    meta.create_gate(base_op.op_kind.as_str(), |meta| {
                        let selector = meta.query_selector(*selector);
                        let mut qis = vec![vec![]; 2];
                        for (q_i, input) in qis.iter_mut().zip(inputs) {
                            *q_i = input
                                .query_whole_block(meta, *block_idx, 0, 1)
                                .expect("contraction config: input query failed")
                                .into_iter()
                                .collect()
                        }
                        // Get output expressions for each input channel
                        let (rotation_offset, rng) = base_op.op_kind.query_offset_rng();
                        let expected_output: Tensor<Expression<F>> = output
                            .query_rng(meta, *block_idx, 0, rotation_offset, rng)
                            .expect("contraction config: output query failed");

                        let res = base_op.op_kind.accum_f(
                            expected_output[0].clone(),
                            qis[1].clone(),
                            qis[0].clone(),
                        );
                        let constraints =
                            vec![expected_output[base_op.op_kind.constraint_idx()].clone() - res];

                        Constraints::with_selector(selector, constraints)
                    });
                }
            }
        }

        let first_phase_inputs: [VarTensor; 2] = inputs[0]
            .iter()
            .copied()
            .cloned()
            .collect_vec()
            .try_into()
            .unwrap();
        let second_phase_inputs: [VarTensor; 2] = inputs[1]
            .iter()
            .copied()
            .cloned()
            .collect_vec()
            .try_into()
            .unwrap();

        Self {
            inputs: [first_phase_inputs, second_phase_inputs],
            outputs: [outputs[0].clone(), outputs[1].clone()],
            selectors,
            _marker: PhantomData,
        }
    }
}

/// `RLCConfig` is the custom gate used for random linear combination with the specific challenge
#[derive(Clone, Debug)]
struct RLCConfig<F: PrimeField + TensorType + PartialOrd> {
    /// The challenge used for the random linear combination
    /// Challenge is `None` in the dummy configuration
    pub challenge: Option<Challenge>,
    /// [first phase, second phase]
    pub inputs: [VarTensor; 2],
    pub output: VarTensor,
    /// (phase of input, block index) -> (init selector, acc selector)
    pub selectors: BTreeMap<(usize, usize), (Selector, Selector)>,
    _marker: PhantomData<F>,
}

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RLCConfig<F> {
    fn dummy(dummy_var: &VarTensor) -> Self {
        let challenge = None;
        let inputs = [dummy_var.clone(), dummy_var.clone()];
        let output = dummy_var.clone();
        let selectors = BTreeMap::new();
        Self {
            challenge,
            inputs,
            output,
            selectors,
            _marker: PhantomData,
        }
    }

    fn new(meta: &mut ConstraintSystem<F>, inputs: &[VarTensor; 2], output: &VarTensor) -> Self {
        let challenge = meta.challenge_usable_after(FirstPhase);

        let mut selectors = BTreeMap::new();
        for (phase, input) in inputs.iter().enumerate() {
            for block_idx in 0..input.num_blocks() {
                let selector = (meta.selector(), meta.selector());
                selectors.insert((phase, block_idx), selector);
            }
        }
        let block_width = output.num_inner_cols();
        let powers_of_challenge = (0..block_width)
            .scan(Expression::Constant(F::ONE), |r_power, _| {
                *r_power = r_power.clone() * challenge.expr();
                Some(r_power.clone())
            })
            .collect_vec();
        for ((phase, block_idx), (init_selector, acc_selector)) in selectors.iter() {
            meta.create_gate("init", |meta| {
                let selector = meta.query_selector(*init_selector);
                let input_exprs = inputs[*phase]
                    .query_whole_block(meta, *block_idx, 0, 1)
                    .expect("rlc config: input query failed")
                    .into_iter()
                    .collect();
                let constraints = {
                    let expected_output: Tensor<Expression<F>> = output
                        .query_rng(meta, *block_idx, 0, 0, 1)
                        .expect("rlc config: output query failed");

                    let res = BaseOp::Dot.accum_f(
                        Expression::Constant(F::ZERO),
                        powers_of_challenge.iter().cloned().rev().collect_vec(),
                        input_exprs,
                    );
                    vec![expected_output[0].clone() - res]
                };
                Constraints::with_selector(selector, constraints)
            });
            meta.create_gate("acc", |meta| {
                let selector = meta.query_selector(*acc_selector);
                let input_exprs = inputs[*phase]
                    .query_whole_block(meta, *block_idx, 0, 1)
                    .expect("rlc config: input query failed")
                    .into_iter()
                    .collect();
                let constraints = {
                    let expected_output: Tensor<Expression<F>> = output
                        .query_rng(meta, *block_idx, 0, -1, 2)
                        .expect("rlc config: output query failed");

                    let res = BaseOp::Dot.accum_f(
                        expected_output[0].clone() * powers_of_challenge.last().cloned().unwrap(),
                        powers_of_challenge.iter().cloned().rev().collect_vec(),
                        input_exprs,
                    );
                    vec![expected_output[1].clone() - res]
                };
                Constraints::with_selector(selector, constraints)
            });
        }
        Self {
            inputs: inputs.clone(),
            output: output.clone(),
            selectors,
            challenge: Some(challenge),
            _marker: PhantomData,
        }
    }

    fn assign_rlc(
        &self,
        region: &mut RegionCtx<F>,
        flattened_input: &ValTensor<F>,
        challenge: Value<F>,
        rlc_len: usize,
        phase: usize,
        check_mode: &CheckMode,
    ) -> Result<ValTensor<F>, CircuitError> {
        region.flush_einsum()?;
        let block_width = self.output.num_inner_cols();
        let powers_of_challenge = (0..block_width)
            .scan(Value::known(F::ONE), |challenge_power, _| {
                *challenge_power = challenge_power.clone() * challenge;
                Some(challenge_power.clone())
            })
            .collect_vec();
        let mut rlc_results: Vec<ValType<F>> = vec![];
        for tensor in flattened_input.get_inner_tensor()?.chunks_exact(rlc_len) {
            let running_sums = tensor
                .iter()
                .chunks(block_width)
                .into_iter()
                .scan(Value::known(F::ZERO), |state, val| {
                    let curr_sum: Value<F> = val
                        .into_iter()
                        .zip(powers_of_challenge.iter().rev())
                        .map(|(v, c_power)| {
                            c_power.and_then(|c_power| {
                                v.get_felt_eval()
                                    .map(|v| Value::known(c_power * v))
                                    .unwrap_or(Value::unknown())
                            })
                        })
                        .reduce(|acc, v| acc + v)
                        .unwrap();
                    *state = *state * powers_of_challenge.last().unwrap() + curr_sum;
                    Some(*state)
                })
                .collect_vec();

            let assigned_len = {
                let mut input: ValTensor<F> = tensor.iter().collect_vec().into();
                input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
                let (_, len) = region
                    .assign_einsum_with_duplication_unconstrained(&self.inputs[phase], &input)?;
                len
            };
            let (assigned_output, assigned_output_len) = {
                let running_sums = running_sums.into_iter().map(ValType::from).collect_vec();
                region.assign_einsum_with_duplication_constrained(
                    &self.output,
                    &running_sums.into(),
                    check_mode,
                )?
            };

            (0..assigned_output_len)
                .map(|i| {
                    let (block_idx, _, z) = self
                        .output
                        .cartesian_coord(region.einsum_col_coord() + i * block_width);
                    if z == 0 && i > 0 {
                        return Ok(());
                    }
                    let selector = if i == 0 {
                        self.selectors
                            .get(&(phase, block_idx))
                            .map(|(init, _)| init)
                    } else {
                        self.selectors.get(&(phase, block_idx)).map(|(_, acc)| acc)
                    };
                    region.enable(selector, z)?;
                    Ok(())
                })
                .collect::<Result<Vec<_>, CircuitError>>()?;
            rlc_results.push(assigned_output.last()?.get_inner_tensor()?.get_scalar());

            region.increment_einsum_col_coord(assigned_len);
        }
        Ok(rlc_results.into())
    }
}
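The `init` and `acc` gates above pin down a blockwise Horner evaluation: each output row folds `block_width` fresh inputs into the running sum via the powers r, r^2, ..., r^w of the challenge, so the last running sum equals sum_i r^(n-i) * x_i. A standalone sketch over modular integers (the modulus and inputs are illustrative stand-ins for field elements, not ezkl APIs):

// Out-of-circuit reference for the RLC running sum enforced by the gates above.
const MODULUS: u128 = (1 << 61) - 1; // a Mersenne prime, for illustration only

fn rlc_reference(values: &[u128], r: u128, block_width: usize) -> u128 {
    // powers r^1 ..= r^w, mirroring `powers_of_challenge`
    let mut powers = Vec::with_capacity(block_width);
    let mut acc = 1u128;
    for _ in 0..block_width {
        acc = acc * r % MODULUS;
        powers.push(acc);
    }
    let r_w = *powers.last().unwrap(); // r^w, the per-row shift
    let mut state = 0u128;
    for block in values.chunks(block_width) {
        // one row: state' = state * r^w + sum_j r^(w-j) * x_j
        let mut row = 0u128;
        for (x, p) in block.iter().zip(powers.iter().rev()) {
            row = (row + x * p) % MODULUS;
        }
        state = (state * r_w % MODULUS + row) % MODULUS;
    }
    state // equals sum_i r^(n-i) * x_i when n is a multiple of block_width
}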
@@ -1,205 +0,0 @@
use std::{collections::BTreeSet, ops::Index};

use halo2curves::ff::PrimeField;
use itertools::Itertools;

use crate::{
    circuit::CircuitError,
    tensor::{TensorType, ValTensor},
};

/// inj,jk->ik [inj,jk]
/// inj,i->nj => RLC [jk,nj]
/// jk,k->j => RLC [nj,j]
/// nj,j->n => Contraction [n]
/// n-> => Contraction []
///
/// bn,anm,bm->ba [bn,anm,bm]
/// bn,bm->bnm => Contraction [anm,bnm]
/// bnm,b->nm => RLC [anm,nm]
/// anm,a->nm => RLC [nm,nm]
/// nm,nm->m => Contraction [m]
/// m-> => Contraction []

#[derive(Debug)]
pub enum Reduction {
    /// Random linear combination with powers of challenge along the axis
    RLC {
        expression: String,
        axis: char,
        /// Uniquely identifying index of input tensor to be reduced
        input_index: TensorIndex,
        /// phase of input tensor
        input_phase: usize,
        /// phase of output tensor
        output_phase: usize,
        challenge_index: usize,
    },
    Contraction {
        expression: String,
        /// when axis is `None`, the contraction is pairwise multiplication
        axis: Option<char>,
        /// Uniquely identifying indices of input tensors to be contracted
        input_indices: Vec<TensorIndex>,
        /// phases of input tensors
        input_phases: Vec<usize>,
        /// phase of output tensor
        output_phase: usize,
    },
}

#[derive(Clone, Copy, Debug)]
pub struct TensorIndex(usize);

impl<T: PrimeField + TensorType + PartialOrd> Index<TensorIndex> for Vec<ValTensor<T>> {
    type Output = ValTensor<T>;

    fn index(&self, index: TensorIndex) -> &Self::Output {
        &self[index.0]
    }
}

impl Reduction {
    pub fn expression(&self) -> &str {
        match self {
            Reduction::Contraction { expression, .. } => expression,
            Reduction::RLC { expression, .. } => expression,
        }
    }

    pub fn input_indices(&self) -> Vec<TensorIndex> {
        match self {
            Reduction::Contraction { input_indices, .. } => input_indices.clone(),
            Reduction::RLC { input_index, .. } => vec![*input_index],
        }
    }

    pub fn output_phase(&self) -> usize {
        match self {
            Reduction::Contraction { output_phase, .. } => *output_phase,
            Reduction::RLC { output_phase, .. } => *output_phase,
        }
    }
}

pub fn input_reductions(expression: &str) -> Result<Vec<Reduction>, CircuitError> {
    let (input_exprs, output_expr) = expression.split_once("->").unwrap();
    let input_exprs: Vec<_> = input_exprs.split(",").map(|eq| eq.to_string()).collect();
    // (phase, expression)
    let input_exprs: Vec<(usize, String)> =
        input_exprs.into_iter().map(|expr| (0, expr)).collect_vec();

    let mut input_tensor_counter = input_exprs.len();
    let mut input_exprs: Vec<((usize, String), TensorIndex)> = input_exprs
        .into_iter()
        .zip((0..input_tensor_counter).map(TensorIndex))
        .collect();
    let mut reductions: Vec<Reduction> = vec![];

    // Reduce input_exprs along given axis
    let mut reduce = |input_exprs: Vec<((usize, String), TensorIndex)>,
                      axis: char|
     -> (Reduction, Vec<((usize, String), TensorIndex)>) {
        let inputs = input_exprs
            .iter()
            .filter(|((_, eq), _)| eq.chars().contains(&axis))
            .cloned()
            .collect_vec();
        let (inputs_axes, input_indices): (Vec<(usize, String)>, Vec<TensorIndex>) =
            inputs.iter().cloned().unzip();
        let (input_phases, inputs_axes): (Vec<usize>, Vec<String>) =
            inputs_axes.into_iter().unzip();

        let is_output_axis = output_expr.chars().contains(&axis);
        let output: String = if is_output_axis && inputs.len() > 1 {
            let output: BTreeSet<char> =
                inputs_axes.iter().flat_map(|input| input.chars()).collect();
            output.iter().collect()
        } else {
            let output: BTreeSet<char> = inputs_axes
                .iter()
                .flat_map(|input| input.chars().filter(|&c| c != axis))
                .collect();
            output.iter().collect()
        };

        let reduction = if is_output_axis && inputs.len() == 1 {
            let mut expression = inputs_axes.join(",");
            expression.push_str(format!(",{axis}").as_str());
            expression.push_str("->");
            expression.push_str(&output);
            Reduction::RLC {
                expression,
                axis,
                input_index: input_indices[0],
                input_phase: input_phases[0],
                output_phase: 1,
                challenge_index: output_expr.chars().position(|c| c == axis).unwrap(),
            }
        } else if is_output_axis {
            let mut expression = inputs_axes.join(",");
            let output_phase = input_phases.iter().copied().max().unwrap();
            expression.push_str("->");
            expression.push_str(&output);
            Reduction::Contraction {
                expression,
                axis: None,
                input_indices,
                input_phases,
                output_phase,
            }
        } else {
            let mut expression = inputs_axes.join(",");
            let output_phase = input_phases.iter().copied().max().unwrap();
            expression.push_str("->");
            expression.push_str(&output);
            Reduction::Contraction {
                expression,
                axis: Some(axis),
                input_indices,
                input_phases,
                output_phase,
            }
        };

        // Mutate input_exprs
        let mut input_exprs = input_exprs.clone();
        input_exprs.retain(|((_, input_eq), _)| !inputs_axes.contains(input_eq));
        input_exprs.push((
            (reduction.output_phase(), output.clone()),
            TensorIndex(input_tensor_counter),
        ));
        input_tensor_counter += 1;

        (reduction, input_exprs)
    };

    let mut output_axes = output_expr.chars().collect_vec();
    while let Some(axis) = output_axes.first().cloned() {
        let num_inputs = input_exprs
            .iter()
            .filter(|((_, eq), _)| eq.chars().contains(&axis))
            .count();
        if num_inputs == 0 {
            output_axes.remove(0);
        } else {
            let (reduction, new_input_exprs) = reduce(input_exprs, axis);
            reductions.push(reduction);
            input_exprs = new_input_exprs;
        }
    }

    // These are not output axes and were not contracted with random vectors
    let remaining_axes: BTreeSet<_> = input_exprs
        .iter()
        .flat_map(|((_, eq), _)| eq.chars())
        .collect();

    for axis in remaining_axes.iter() {
        let (reduction, new_input_exprs) = reduce(input_exprs, *axis);
        reductions.push(reduction);
        input_exprs = new_input_exprs;
    }

    Ok(reductions)
}
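To make the planner concrete, an illustrative trace of `input_reductions` on the second example from the doc comment above ("bn,anm,bm->ba"); this restates the documented sequence, it is not executable ezkl code:

// output axes are consumed first, then any leftover contraction axes:
//   axis 'b' (output axis, 2 inputs hold it) -> Contraction { axis: None }: bn,bm -> bnm
//   axis 'b' (output axis, 1 input holds it) -> RLC over 'b':               bnm   -> nm
//   axis 'a' (output axis, 1 input holds it) -> RLC over 'a':               anm   -> nm
//   axis 'n' (remaining)                     -> Contraction over 'n':       nm,nm -> m
//   axis 'm' (remaining)                     -> Contraction over 'm':       m     -> scalar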
@@ -46,9 +46,6 @@ pub enum CircuitError {
    /// Failed to get shuffle
    #[error("failed to get shuffle for op: {0}")]
    GetShuffleError(String),
    /// Failed to get einsum
    #[error("failed to get einsum for op: {0}")]
    GetEinsumError(String),
    /// Failed to get constants
    #[error("failed to get constants for op: {0}")]
    GetConstantsError(String),
@@ -64,9 +61,6 @@ pub enum CircuitError {
    /// Missing product in einsum
    #[error("missing product in einsum")]
    MissingEinsumProduct,
    /// Missing config in einsum
    #[error("missing config in einsum")]
    MissingEinsumConfig,
    /// Mismatched lookup length
    #[error("mismatched lookup lengths: {0} and {1}")]
    MismatchedLookupLength(usize, usize),
@@ -115,7 +109,4 @@ pub enum CircuitError {
    /// A decomposition base overflowed
    #[error("decomposition base overflowed")]
    DecompositionBaseOverflow,
    /// Challenge not set
    #[error("challenge not set")]
    ChallengeNotSet,
}

@@ -22,8 +22,9 @@ use crate::{
    tensor::{
        create_unit_tensor, get_broadcasted_shape,
        ops::{accumulated, add, mult, sub},
-       DataFormat, KernelFormat, Tensor, TensorError, ValType,
+       Tensor, TensorError, ValType,
    },
+   tensor::{DataFormat, KernelFormat},
};

use super::*;
@@ -822,73 +823,14 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// let result = einsum::<Fp>(&dummy_config, &mut dummy_region, &[&x, &k], "mk,n->ma").unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1773]), &[1, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// ```
///
pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    inputs: &[&ValTensor<F>],
    equation: &str,
) -> Result<ValTensor<F>, CircuitError> {
    let mut eq = equation.split("->");
    let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
    let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();

    // Check that the number of inputs matches the number of inputs in the equation
    if inputs.len() != inputs_eq.len() {
        return Err(TensorError::DimMismatch("einsum".to_string()).into());
    }

    let mut indices_to_size = HashMap::new();
    for (i, input) in inputs.iter().enumerate() {
        for j in 0..inputs_eq[i].len() {
            let c = inputs_eq[i]
                .chars()
                .nth(j)
                .ok_or(CircuitError::InvalidEinsum)?;
            if let std::collections::hash_map::Entry::Vacant(e) = indices_to_size.entry(c) {
                e.insert(input.dims()[j]);
            } else if indices_to_size[&c] != input.dims()[j] {
                return Err(TensorError::DimMismatch("einsum".to_string()).into());
            }
        }
    }

    // Track the einsum equation
    region.add_used_einsum_equation(equation.to_string(), &indices_to_size)?;

    if config.einsums.is_none() {
        return einsum_with_base_ops(config, region, inputs, equation);
    }

    let input_values = inputs
        .iter()
        .map(|t| t.get_inner())
        .collect::<Result<Vec<_>, TensorError>>()?;
    let (output_tensor, _) =
        crate::tensor::ops::accumulated::einsum(equation, &input_values.iter().collect_vec())?;

    config.einsums.as_ref().unwrap().assign_einsum(
        config,
        region,
        inputs,
        &output_tensor.clone().into(),
        equation,
        &config.check_mode,
    )?;

    let output: ValTensor<F> = output_tensor.into();

    region.increment_einsum_index(1);

    Ok(output)
}

/// einsum with base ops
pub fn einsum_with_base_ops<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    inputs: &[&ValTensor<F>],
    equation: &str,
) -> Result<ValTensor<F>, CircuitError> {
    let mut equation = equation.split("->");
    let inputs_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?;
@@ -1094,8 +1036,6 @@ pub fn einsum_with_base_ops<F: PrimeField + TensorType + PartialOrd + std::hash:

    let output: ValTensor<F> = output.into();

    region.increment_einsum_index(1);

    Ok(output)
}

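The index-size bookkeeping in `einsum` above reduces to a small invariant: every occurrence of an index letter across all inputs must carry the same dimension. A self-contained sketch of just that check, with hypothetical names and plain shape slices rather than ezkl types:

use std::collections::HashMap;

// Returns true iff repeated einsum indices agree on their dimension.
fn check_einsum_dims(inputs_eq: &[&str], dims: &[&[usize]]) -> bool {
    let mut indices_to_size: HashMap<char, usize> = HashMap::new();
    for (eq, shape) in inputs_eq.iter().zip(dims.iter()) {
        for (c, &d) in eq.chars().zip(shape.iter()) {
            if *indices_to_size.entry(c).or_insert(d) != d {
                return false; // same letter, two different sizes
            }
        }
    }
    true
}
// check_einsum_dims(&["ij", "jk"], &[&[2, 3], &[3, 4]]) == true
// check_einsum_dims(&["ij", "jk"], &[&[2, 3], &[5, 4]]) == false ('j' is 3 vs 5)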
@@ -6229,9 +6169,9 @@ pub(crate) fn recompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has
    (0..num_first_dims)
        .flat_map(|_| {
            (0..n).rev().map(|x| {
-               let base = (*base as IntegerRep).checked_pow(x as u32);
+               let base = (*base).checked_pow(x as u32);
                if let Some(base) = base {
-                   Ok(ValType::Constant(integer_rep_to_felt(base)))
+                   Ok(ValType::Constant(integer_rep_to_felt(base as IntegerRep)))
                } else {
                    Err(CircuitError::DecompositionBaseOverflow)
                }

@@ -364,15 +364,7 @@ impl<
        };
        Ok(Some(if self.decomp {
            log::debug!("constraining constant to be decomp");
-           super::layouts::decompose(
-               config,
-               region,
-               &[&value],
-               &region.base(),
-               &region.legs(),
-               false,
-           )?
-           .1
+           super::layouts::decompose(config, region, &[&value], &region.base(), &region.legs(), false)?.1
        } else {
            log::debug!("constraining constant to be identity");
            super::layouts::identity(config, region, &[&value])?

@@ -1,12 +1,12 @@
use crate::{
-   circuit::{einsum::NUM_MAX_EINSUM_CHALLENGES, table::Range},
+   circuit::table::Range,
    fieldutils::IntegerRep,
    tensor::{Tensor, TensorType, ValTensor, ValType, VarTensor},
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use colored::Colorize;
use halo2_proofs::{
-   circuit::{Region, Value},
+   circuit::Region,
    plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
@@ -85,45 +85,6 @@ impl ShuffleIndex {
    }
}

/// Einsum index
#[derive(Clone, Debug, Default)]
pub struct EinsumIndex {
    index: usize,
    col_coord: usize,
    // (einsum equation, input axes to dimensions map)
    equations: Vec<(String, HashMap<char, usize>)>,
    num_inner_cols: usize,
}

impl EinsumIndex {
    /// Create a new einsum index
    pub fn new(index: usize, col_coord: usize, num_inner_cols: usize) -> EinsumIndex {
        EinsumIndex {
            index,
            col_coord,
            equations: Vec::new(),
            num_inner_cols,
        }
    }

    /// Get the einsum index
    pub fn index(&self) -> usize {
        self.index
    }

    /// Get the column coord
    pub fn col_coord(&self) -> usize {
        self.col_coord
    }

    /// update with another einsum index
    pub fn update(&mut self, other: &EinsumIndex) {
        self.index += other.index;
        self.col_coord += other.col_coord;
        self.equations.extend(other.equations.clone());
    }
}

#[derive(Debug, Clone)]
/// Some settings for a region to differentiate it across the different phases of proof generation
pub struct RegionSettings {
@@ -215,11 +176,9 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
    num_inner_cols: usize,
    dynamic_lookup_index: DynamicLookupIndex,
    shuffle_index: ShuffleIndex,
    einsum_index: EinsumIndex,
    statistics: RegionStatistics,
    settings: RegionSettings,
    assigned_constants: ConstantsMap<F>,
    challenges: Vec<Value<F>>,
    max_dynamic_input_len: usize,
}

@@ -291,16 +250,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        self.shuffle_index.col_coord += n;
    }

    /// increment the einsum index
    pub fn increment_einsum_index(&mut self, n: usize) {
        self.einsum_index.index += n;
    }

    /// increment the einsum column coordinate
    pub fn increment_einsum_col_coord(&mut self, n: usize) {
        self.einsum_index.col_coord += n;
    }

    /// Whether witness generation is active
    pub fn witness_gen(&self) -> bool {
        self.settings.witness_gen
@@ -316,11 +265,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        &self.statistics
    }

    /// The challenge values available to the region
    pub fn challenges(&self) -> &[Value<F>] {
        &self.challenges
    }

    /// Create a new region context
    pub fn new(
        region: Region<'a, F>,
@@ -339,11 +283,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
            linear_coord,
            dynamic_lookup_index: DynamicLookupIndex::default(),
            shuffle_index: ShuffleIndex::default(),
            einsum_index: EinsumIndex::default(),
            statistics: RegionStatistics::default(),
            settings: RegionSettings::all_true(decomp_base, decomp_legs),
            assigned_constants: HashMap::new(),
            challenges: vec![],
            max_dynamic_input_len: 0,
        }
    }
@@ -362,20 +304,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        new_self
    }

    /// Create a new region context with challenges
    pub fn new_with_challenges(
        region: Region<'a, F>,
        row: usize,
        num_inner_cols: usize,
        decomp_base: usize,
        decomp_legs: usize,
        challenges: Vec<Value<F>>,
    ) -> RegionCtx<'a, F> {
        let mut new_self = Self::new(region, row, num_inner_cols, decomp_base, decomp_legs);
        new_self.challenges = challenges;
        new_self
    }

    /// Create a new region context
    pub fn new_dummy(
        row: usize,
@@ -392,11 +320,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
            row,
            dynamic_lookup_index: DynamicLookupIndex::default(),
            shuffle_index: ShuffleIndex::default(),
            einsum_index: EinsumIndex::default(),
            statistics: RegionStatistics::default(),
            settings,
            assigned_constants: HashMap::new(),
            challenges: vec![Value::unknown(); NUM_MAX_EINSUM_CHALLENGES],
            max_dynamic_input_len: 0,
        }
    }
@@ -407,7 +333,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        linear_coord: usize,
        num_inner_cols: usize,
        settings: RegionSettings,
        challenges: Vec<Value<F>>,
    ) -> RegionCtx<'a, F> {
        let region = None;
        RegionCtx {
@@ -417,11 +342,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
            row,
            dynamic_lookup_index: DynamicLookupIndex::default(),
            shuffle_index: ShuffleIndex::default(),
            einsum_index: EinsumIndex::default(),
            statistics: RegionStatistics::default(),
            settings,
            assigned_constants: HashMap::new(),
            challenges,
            max_dynamic_input_len: 0,
        }
    }
@@ -475,7 +398,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        let statistics = Arc::new(Mutex::new(self.statistics.clone()));
        let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
        let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
        let einsum_index = Arc::new(Mutex::new(self.einsum_index.clone()));
        let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));

        *output = output.par_enum_map(|idx, _| {
@@ -490,7 +412,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
            starting_linear_coord,
            self.num_inner_cols,
            self.settings.clone(),
            self.challenges.clone(),
        );
        let res = inner_loop_function(idx, &mut local_reg);
        // we update the offset and constants
@@ -509,9 +430,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        // update the shuffle index
        let mut shuffle_index = shuffle_index.lock().unwrap();
        shuffle_index.update(&local_reg.shuffle_index);
        // update the einsum index
        let mut einsum_index = einsum_index.lock().unwrap();
        einsum_index.update(&local_reg.einsum_index);
        // update the constants
        let mut constants = constants.lock().unwrap();
        constants.extend(local_reg.assigned_constants);
@@ -532,10 +450,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
            .map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?
            .into_inner()
            .map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?;
        self.einsum_index = Arc::try_unwrap(einsum_index)
            .map_err(|e| CircuitError::GetEinsumError(format!("{:?}", e)))?
            .into_inner()
            .map_err(|e| CircuitError::GetEinsumError(format!("{:?}", e)))?;
        self.assigned_constants = Arc::try_unwrap(constants)
            .map_err(|e| CircuitError::GetConstantsError(format!("{:?}", e)))?
            .into_inner()
@@ -602,18 +516,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        self.update_max_min_lookup_range(range)
    }

    /// add used einsum equation
    pub fn add_used_einsum_equation(
        &mut self,
        equation: String,
        input_axes_to_dims: &HashMap<char, usize>,
    ) -> Result<(), CircuitError> {
        self.einsum_index
            .equations
            .push((equation, input_axes_to_dims.clone()));
        Ok(())
    }

    /// Get the offset
    pub fn row(&self) -> usize {
        self.row
@@ -649,31 +551,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        self.shuffle_index.col_coord
    }

    /// einsum index
    pub fn einsum_index(&self) -> usize {
        self.einsum_index.index
    }

    /// einsum column coordinate
    pub fn einsum_col_coord(&self) -> usize {
        self.einsum_index.col_coord
    }

    /// get used einsum equations
    pub fn used_einsum_equations(&self) -> Vec<(String, HashMap<char, usize>)> {
        self.einsum_index.equations.clone()
    }

    /// set the number of inner columns used in einsum custom gate
    pub fn set_num_einsum_inner_cols(&mut self, num_inner_cols: usize) {
        self.einsum_index.num_inner_cols = num_inner_cols;
    }

    /// number of inner columns used in einsum custom gate
    pub fn num_einsum_inner_cols(&self) -> usize {
        self.einsum_index.num_inner_cols
    }

    /// get used lookups
    pub fn used_lookups(&self) -> HashSet<LookupOp> {
        self.statistics.used_lookups.clone()
@@ -763,28 +640,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        self.assign_dynamic_lookup(var, values)
    }

    /// Assign a valtensor to a vartensor in einsum area
    pub fn assign_einsum(
        &mut self,
        var: &VarTensor,
        values: &ValTensor<F>,
    ) -> Result<ValTensor<F>, CircuitError> {
        if let Some(region) = &self.region {
            Ok(var.assign(
                &mut region.borrow_mut(),
                self.einsum_col_coord(),
                values,
                &mut self.assigned_constants,
            )?)
        } else {
            if !values.is_instance() {
                let values_map = values.create_constants_map_iterator();
                self.assigned_constants.par_extend(values_map);
            }
            Ok(values.clone())
        }
    }

    /// Assign a valtensor to a vartensor with duplication
    pub fn assign_with_duplication_unconstrained(
        &mut self,
@@ -842,63 +697,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        }
    }

    /// Assign a valtensor to a vartensor with duplication
    pub fn assign_einsum_with_duplication_unconstrained(
        &mut self,
        var: &VarTensor,
        values: &ValTensor<F>,
    ) -> Result<(ValTensor<F>, usize), Error> {
        if let Some(region) = &self.region {
            // duplicates every nth element to adjust for column overflow
            let (res, len) = var.assign_with_duplication_unconstrained(
                &mut region.borrow_mut(),
                self.einsum_col_coord(),
                values,
                &mut self.assigned_constants,
            )?;
            Ok((res, len))
        } else {
            let (_, len) = var.dummy_assign_with_duplication(
                self.row,
                self.einsum_col_coord(),
                values,
                false,
                &mut self.assigned_constants,
            )?;
            Ok((values.clone(), len))
        }
    }

    /// Assign a valtensor to a vartensor with duplication
    pub fn assign_einsum_with_duplication_constrained(
        &mut self,
        var: &VarTensor,
        values: &ValTensor<F>,
        check_mode: &crate::circuit::CheckMode,
    ) -> Result<(ValTensor<F>, usize), Error> {
        if let Some(region) = &self.region {
            // duplicates every nth element to adjust for column overflow
            let (res, len) = var.assign_with_duplication_constrained(
                &mut region.borrow_mut(),
                self.row,
                self.einsum_col_coord(),
                values,
                check_mode,
                &mut self.assigned_constants,
            )?;
            Ok((res, len))
        } else {
            let (_, len) = var.dummy_assign_with_duplication(
                self.row,
                self.einsum_col_coord(),
                values,
                true,
                &mut self.assigned_constants,
            )?;
            Ok((values.clone(), len))
        }
    }

    /// Enable a selector
    pub fn enable(&mut self, selector: Option<&Selector>, offset: usize) -> Result<(), Error> {
        match &self.region {
@@ -965,19 +763,4 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        }
        Ok(())
    }

    /// flush row to the next row in einsum area
    pub fn flush_einsum(&mut self) -> Result<(), CircuitError> {
        // increment by the difference between the current linear coord and the next row
        let num_einsum_inner_cols = self.num_einsum_inner_cols();
        let remainder = self.einsum_col_coord() % num_einsum_inner_cols;
        if remainder != 0 {
            let diff = num_einsum_inner_cols - remainder;
            self.increment_einsum_col_coord(diff);
        }
        if self.einsum_col_coord() % num_einsum_inner_cols != 0 {
            return Err(CircuitError::FlushError);
        }
        Ok(())
    }
}
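The padding arithmetic `flush_einsum` performs above, as a standalone sketch: advance the column coordinate to the next multiple of the block width (the function name here is illustrative only):

fn flushed_coord(col_coord: usize, num_inner_cols: usize) -> usize {
    let remainder = col_coord % num_inner_cols;
    if remainder != 0 {
        // skip past the partially filled row
        col_coord + (num_inner_cols - remainder)
    } else {
        col_coord
    }
}
// e.g. flushed_coord(10, 4) == 12 and flushed_coord(12, 4) == 12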
|
||||
@@ -359,6 +359,8 @@ mod matmul_col_ultra_overflow_double_col {
|
||||
&pk,
|
||||
// use safe mode to verify that the proof is correct
|
||||
CheckMode::SAFE,
|
||||
crate::Commitments::KZG,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
@@ -478,6 +480,8 @@ mod matmul_col_ultra_overflow {
|
||||
&pk,
|
||||
// use safe mode to verify that the proof is correct
|
||||
CheckMode::SAFE,
|
||||
crate::Commitments::KZG,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
@@ -1294,6 +1298,8 @@ mod conv_col_ultra_overflow {
|
||||
&pk,
|
||||
// use safe mode to verify that the proof is correct
|
||||
CheckMode::SAFE,
|
||||
crate::Commitments::KZG,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
@@ -1461,6 +1467,8 @@ mod conv_relu_col_ultra_overflow {
|
||||
¶ms,
|
||||
&pk,
|
||||
CheckMode::SAFE,
|
||||
crate::Commitments::KZG,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
// use safe mode to verify that the proof is correct
|
||||
None,
|
||||
None,
|
||||
@@ -2635,6 +2643,8 @@ mod lookup_ultra_overflow {
|
||||
&pk,
|
||||
// use safe mode to verify that the proof is correct
|
||||
CheckMode::SAFE,
|
||||
crate::Commitments::KZG,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
188  src/commands.rs
@@ -9,9 +9,10 @@ use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
use tosubcommand::{ToFlags, ToSubcommand};
|
||||
|
||||
use crate::RunArgs;
|
||||
use crate::{pfsys::ProofType, Commitments, RunArgs};
|
||||
|
||||
use crate::circuit::CheckMode;
|
||||
use crate::pfsys::TranscriptType;
|
||||
|
||||
/// The default path to the .json data file
|
||||
pub const DEFAULT_DATA: &str = "input.json";
|
||||
@@ -27,16 +28,26 @@ pub const DEFAULT_SETTINGS: &str = "settings.json";
pub const DEFAULT_PK: &str = "pk.key";
/// The default path to the verification key file
pub const DEFAULT_VK: &str = "vk.key";
/// The default path to the proving key file for aggregated proofs
pub const DEFAULT_PK_AGGREGATED: &str = "pk_aggr.key";
/// The default path to the verification key file for aggregated proofs
pub const DEFAULT_VK_AGGREGATED: &str = "vk_aggr.key";
/// The default path to the proof file
pub const DEFAULT_PROOF: &str = "proof.json";
/// The default path to the proof file for aggregated proofs
pub const DEFAULT_PROOF_AGGREGATED: &str = "proof_aggr.json";
/// Default for whether to split proofs
pub const DEFAULT_SPLIT: &str = "false";
/// Default verifier abi
pub const DEFAULT_VERIFIER_ABI: &str = "verifier_abi.json";
/// Default verifier abi for aggregated proofs
pub const DEFAULT_VERIFIER_AGGREGATED_ABI: &str = "verifier_aggr_abi.json";
/// Default solidity code
pub const DEFAULT_SOL_CODE: &str = "evm_deploy.sol";
/// Default calldata path
pub const DEFAULT_CALLDATA: &str = "calldata.bytes";
/// Default solidity code for aggregated proofs
pub const DEFAULT_SOL_CODE_AGGREGATED: &str = "evm_deploy_aggr.sol";
/// Default contract address
pub const DEFAULT_CONTRACT_ADDRESS: &str = "contract.address";
/// Default contract address for vk
@@ -45,6 +56,8 @@ pub const DEFAULT_CONTRACT_ADDRESS_VK: &str = "contract_vk.address";
pub const DEFAULT_CHECKMODE: &str = "safe";
/// Default calibration target
pub const DEFAULT_CALIBRATION_TARGET: &str = "resources";
/// Default logrows for aggregated proofs
pub const DEFAULT_AGGREGATED_LOGROWS: &str = "23";
/// Default optimizer runs
pub const DEFAULT_OPTIMIZER_RUNS: &str = "1";
/// Default fuzz runs
@@ -78,6 +91,35 @@ pub const DEFAULT_DECIMALS: &str = "18";
/// Default path for the vka digest file
pub const DEFAULT_VKA_DIGEST: &str = "vka.digest";

#[cfg(feature = "python-bindings")]
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
impl<'py> IntoPyObject<'py> for TranscriptType {
    type Target = pyo3::PyAny;
    type Output = pyo3::Bound<'py, Self::Target>;
    type Error = pyo3::PyErr;

    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        let result = match self {
            TranscriptType::Poseidon => "poseidon",
            TranscriptType::EVM => "evm",
        };
        Ok(result.into_pyobject(py)?.into_any())
    }
}
#[cfg(feature = "python-bindings")]
/// Obtains TranscriptType from PyObject (Required for TranscriptType to be compatible with Python)
impl<'source> FromPyObject<'source> for TranscriptType {
    fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
        let trystr = String::extract_bound(ob)?;
        let strval = trystr.to_string();
        match strval.to_lowercase().as_str() {
            "poseidon" => Ok(TranscriptType::Poseidon),
            "evm" => Ok(TranscriptType::EVM),
            _ => Err(PyValueError::new_err("Invalid value for TranscriptType")),
        }
    }
}

#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
/// Determines what the calibration pass should optimize for
pub enum CalibrationTarget {
@@ -145,6 +187,7 @@ pub enum ContractType {
    /// Deploys a verifier contract tailored to the circuit and not reusable
    Verifier {
        /// Whether to deploy a reusable verifier. This can reduce state bloat on-chain since you need only deploy a verifying key artifact (vka) for a given circuit, which is significantly smaller than the verifier contract (up to 4 times smaller for large circuits)
        /// Can also be used as an alternative to aggregation for verifiers that are otherwise too large to fit on-chain.
        reusable: bool,
    },
}
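For intuition, a minimal sketch of how calling code might branch on this flag (the descriptions and the catch-all arm are hypothetical, not part of this diff):

fn describe(contract: &ContractType) -> &'static str {
    match contract {
        // One reusable verifier serves many circuits; each circuit only needs a small vka on-chain.
        ContractType::Verifier { reusable: true } => "deploy reusable verifier, then register a vka per circuit",
        // A non-reusable verifier bakes the verifying key into a circuit-specific contract.
        ContractType::Verifier { reusable: false } => "deploy a standalone, circuit-specific verifier",
        // Other artifact kinds (elided in this hunk) are handled elsewhere.
        _ => "other artifact",
    }
}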
@@ -226,7 +269,9 @@ impl<'py> IntoPyObject<'py> for CalibrationTarget {

    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        let result = match self {
            CalibrationTarget::Resources { col_overflow: true } => "resources/col-overflow",
            CalibrationTarget::Resources { col_overflow: true } => {
                "resources/col-overflow"
            }
            CalibrationTarget::Resources {
                col_overflow: false,
            } => "resources",
@@ -492,6 +537,9 @@ pub enum Commands {
    /// number of logrows to use for srs
    #[arg(long, value_hint = clap::ValueHint::Other)]
    logrows: usize,
    /// commitment used
    #[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
    commitment: Option<Commitments>,
},

/// Gets an SRS from a circuit settings file.
@@ -506,6 +554,9 @@ pub enum Commands {
    /// Number of logrows to use for srs. Overrides settings_path if specified.
    #[arg(long, default_value = None, value_hint = clap::ValueHint::Other)]
    logrows: Option<u32>,
    /// Commitment used
    #[arg(long, default_value = None, value_hint = clap::ValueHint::Other)]
    commitment: Option<Commitments>,
},
/// Loads model and input and runs mock prover (for testing)
Mock {
@@ -517,6 +568,82 @@ pub enum Commands {
    model: Option<PathBuf>,
},

/// Mock aggregate proofs
MockAggregate {
    /// The path to the snarks to aggregate over (generated using the prove command with the --proof-type=for-aggr flag)
    #[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
    aggregation_snarks: Vec<PathBuf>,
    /// logrows used for aggregation circuit
    #[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
    logrows: Option<u32>,
    /// whether the accumulated proofs are segments of a larger proof
    #[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
    split_proofs: Option<bool>,
},

/// Setup aggregation circuit and generate pk and vk
SetupAggregate {
    /// The path to samples of snarks that will be aggregated over (generated using the prove command with the --proof-type=for-aggr flag)
    #[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
    sample_snarks: Vec<PathBuf>,
    /// The path to save the desired verification key file to
    #[arg(long, default_value = DEFAULT_VK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
    vk_path: Option<PathBuf>,
    /// The path to save the proving key to
    #[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
    pk_path: Option<PathBuf>,
    /// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
    #[arg(long, value_hint = clap::ValueHint::FilePath)]
    srs_path: Option<PathBuf>,
    /// logrows used for aggregation circuit
    #[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
    logrows: Option<u32>,
    /// whether the accumulated proofs are segments of a larger proof
    #[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
    split_proofs: Option<bool>,
    /// compress selectors
    #[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)]
    disable_selector_compression: Option<bool>,
    /// commitment used
    #[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
    commitment: Option<Commitments>,
},
/// Aggregates proofs
Aggregate {
    /// The path to the snarks to aggregate over (generated using the prove command with the --proof-type=for-aggr flag)
    #[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
    aggregation_snarks: Vec<PathBuf>,
    /// The path to load the desired proving key file (generated using the setup-aggregate command)
    #[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
    pk_path: Option<PathBuf>,
    /// The path to output the proof file to
    #[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
    proof_path: Option<PathBuf>,
    /// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
    #[arg(long)]
    srs_path: Option<PathBuf>,
    #[arg(
        long,
        require_equals = true,
        num_args = 0..=1,
        default_value_t = TranscriptType::default(),
        value_enum,
        value_hint = clap::ValueHint::Other
    )]
    transcript: TranscriptType,
    /// logrows used for aggregation circuit
    #[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
    logrows: Option<u32>,
    /// run sanity checks during calculations (safe or unsafe)
    #[arg(long, default_value = DEFAULT_CHECKMODE, value_hint = clap::ValueHint::Other)]
    check_mode: Option<CheckMode>,
    /// whether the accumulated proofs are segments of a larger circuit
    #[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
    split_proofs: Option<bool>,
    /// commitment used
    #[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
    commitment: Option<Commitments>,
},
/// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements
CompileCircuit {
    /// The path to the .onnx model file
@@ -577,6 +704,15 @@ pub enum Commands {
    /// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
    #[arg(long, value_hint = clap::ValueHint::FilePath)]
    srs_path: Option<PathBuf>,
    #[arg(
        long,
        require_equals = true,
        num_args = 0..=1,
        default_value_t = ProofType::Single,
        value_enum,
        value_hint = clap::ValueHint::Other
    )]
    proof_type: ProofType,
    /// run sanity checks during calculations (safe or unsafe)
    #[arg(long, default_value = DEFAULT_CHECKMODE, value_hint = clap::ValueHint::Other)]
    check_mode: Option<CheckMode>,
@@ -644,6 +780,32 @@ pub enum Commands {
    decimals: Option<usize>,
},

/// Creates an Evm verifier for an aggregate proof
#[command(name = "create-evm-verifier-aggr")]
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
CreateEvmVerifierAggr {
    /// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
    #[arg(long, value_hint = clap::ValueHint::FilePath)]
    srs_path: Option<PathBuf>,
    /// The path to load the desired verification key file
    #[arg(long, default_value = DEFAULT_VK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
    vk_path: Option<PathBuf>,
    /// The path to the Solidity code
    #[arg(long, default_value = DEFAULT_SOL_CODE_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
    sol_code_path: Option<PathBuf>,
    /// The path to output the Solidity verifier ABI
    #[arg(long, default_value = DEFAULT_VERIFIER_AGGREGATED_ABI, value_hint = clap::ValueHint::FilePath)]
    abi_path: Option<PathBuf>,
    /// aggregated circuit settings paths, used to calculate the number of instances in the aggregate proof
    #[arg(long, default_value = DEFAULT_SETTINGS, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
    aggregation_settings: Vec<PathBuf>,
    /// logrows used for aggregation circuit
    #[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
    logrows: Option<u32>,
    /// Whether to render the verifier as reusable or not. If true, you will need to deploy a VK artifact, passing it as part of the calldata to the verifier.
    #[cfg_attr(all(feature = "reusable-verifier", not(target_arch = "wasm32")), arg(short = 'R', long, action = clap::ArgAction::SetTrue))]
    reusable: Option<bool>,
},
/// Verifies a proof, returning accept or reject
Verify {
    /// The path to load circuit settings .json file from (generated using the gen-settings command)
@@ -662,7 +824,27 @@ pub enum Commands {
    #[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
    reduced_srs: Option<bool>,
},

/// Verifies an aggregate proof, returning accept or reject
VerifyAggr {
    /// The path to the proof file (generated using the prove command)
    #[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
    proof_path: Option<PathBuf>,
    /// The path to the verification key file (generated using the setup-aggregate command)
    #[arg(long, default_value = DEFAULT_VK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
    vk_path: Option<PathBuf>,
    /// reduced srs
    #[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
    reduced_srs: Option<bool>,
    /// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
    #[arg(long, value_hint = clap::ValueHint::FilePath)]
    srs_path: Option<PathBuf>,
    /// logrows used for aggregation circuit
    #[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
    logrows: Option<u32>,
    /// commitment
    #[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
    commitment: Option<Commitments>,
},
/// Deploys an evm contract (verifier, reusable verifier, or vk artifact) that is generated by ezkl
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
DeployEvm {
src/eth.rs (11 lines changed)
@@ -1,3 +1,4 @@
use crate::pfsys::evm::EvmVerificationError;
use crate::pfsys::{encode_calldata, Snark};
use alloy::contract::CallBuilder;
use alloy::core::primitives::Address as H160;
@@ -56,6 +57,8 @@ pub enum EthError {
    Wallet(#[from] WalletError),
    #[error("failed to parse url {0}")]
    UrlParse(String),
    #[error("evm verification error: {0}")]
    EvmVerification(#[from] EvmVerificationError),
    #[error("Private key must be in hex format, 64 chars, without 0x prefix")]
    PrivateKeyFormat,
    #[error("failed to parse hex: {0}")]
@@ -97,8 +100,6 @@ pub enum EthError {
    VkaData(String),
#[error("rescaled‑instance mismatch: {0}")]
|
||||
    RescaleCheckError(#[from] RescaleCheckError),
    #[error("evm verification error: {0}")]
    EvmVerificationError(String),
}

pub type EthersClient = Arc<
@@ -197,7 +198,7 @@ pub async fn register_vka_via_rv(
    let result = client.call(&tx).await;

    if let Err(e) = result {
        return Err(EthError::EvmVerificationError(e.to_string()).into());
        return Err(EvmVerificationError::SolidityExecution(e.to_string()).into());
    }
    let result = result?;
    debug!("result: {:#?}", result.to_vec());
@@ -269,7 +270,7 @@ pub async fn verify_proof_via_solidity(
    let result = client.call(&tx).await;

    if let Err(e) = result {
        return Err(EthError::EvmVerificationError(e.to_string()).into());
        return Err(EvmVerificationError::SolidityExecution(e.to_string()).into());
    }
    let result = result?;
    debug!("result: {:#?}", result.to_vec());
@@ -305,7 +306,7 @@ pub async fn verify_proof_via_solidity(
        .ok_or(EthError::NoContractOutput)?
        == &1u8;
    if !result {
        return Err(EthError::EvmVerificationError("Invalid proof".into()));
        return Err(EvmVerificationError::InvalidProof.into());
    }

    let gas = client.estimate_gas(&tx).await?;
src/execute.rs (1011 lines changed; diff suppressed because it is too large)
@@ -427,6 +427,8 @@ mod tests {
    }
}


#[cfg(feature = "python-bindings")]
use crate::pfsys::field_to_string;

@@ -441,9 +443,13 @@ impl<'py> IntoPyObject<'py> for FileSourceInner {
            FileSourceInner::Field(data) => {
                let s = field_to_string(&data);
                Ok(pyo3::types::PyString::new(py, &s).into_any())
            }
            FileSourceInner::Bool(data) => Ok(pyo3::types::PyBool::new(py, data).as_any().clone()),
            FileSourceInner::Float(data) => Ok(pyo3::types::PyFloat::new(py, data).into_any()),
        },
            FileSourceInner::Bool(data) => {
                Ok(pyo3::types::PyBool::new(py, data).as_any().clone())
            },
            FileSourceInner::Float(data) => {
                Ok(pyo3::types::PyFloat::new(py, data).into_any())
            },
        }
    }
}

src/graph/mod.rs (167 lines changed)
@@ -64,7 +64,6 @@ use pyo3::types::PyDictMethods;
use pyo3::IntoPyObject;

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::ops::Deref;
pub use utilities::*;
pub use vars::*;
@@ -439,15 +438,6 @@ pub struct ShuffleParams {
    pub total_shuffle_col_size: usize,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
/// Parameters for einsum operations
pub struct EinsumParams {
    /// einsum equations
    pub equations: Vec<(String, HashMap<char, usize>)>,
    /// total einsum column size
    pub total_einsum_col_size: usize,
}

/// model parameters
#[derive(Clone, Debug, Default, PartialEq)]
pub struct GraphSettings {
@@ -463,8 +453,6 @@ pub struct GraphSettings {
    pub dynamic_lookup_params: DynamicLookupParams,
    /// shuffle parameters, flattened for backwards compatibility
    pub shuffle_params: ShuffleParams,
    /// einsum parameters
    pub einsum_params: EinsumParams,
    /// the shape of public inputs to the model (in order of appearance)
    pub model_instance_shapes: Vec<Vec<usize>>,
    /// model output scales
@@ -499,7 +487,7 @@ impl Serialize for GraphSettings {
        if serializer.is_human_readable() {
            // JSON format - use flattened fields for backwards compatibility
            use serde::ser::SerializeStruct;
            let mut state = serializer.serialize_struct("GraphSettings", 22)?;
            let mut state = serializer.serialize_struct("GraphSettings", 21)?;
            state.serialize_field("run_args", &self.run_args)?;
            state.serialize_field("num_rows", &self.num_rows)?;
            state.serialize_field("total_assignments", &self.total_assignments)?;
@@ -526,9 +514,6 @@ impl Serialize for GraphSettings {
                &self.shuffle_params.total_shuffle_col_size,
            )?;

            // Serialize EinsumParams
            state.serialize_field("einsum_params", &self.einsum_params)?;

            state.serialize_field("model_instance_shapes", &self.model_instance_shapes)?;
            state.serialize_field("model_output_scales", &self.model_output_scales)?;
            state.serialize_field("model_input_scales", &self.model_input_scales)?;
@@ -545,14 +530,13 @@ impl Serialize for GraphSettings {
        } else {
            // Binary format (bincode) - use nested struct format
            use serde::ser::SerializeTuple;
            let mut state = serializer.serialize_tuple(19)?;
            let mut state = serializer.serialize_tuple(18)?;
            state.serialize_element(&self.run_args)?;
            state.serialize_element(&self.num_rows)?;
            state.serialize_element(&self.total_assignments)?;
            state.serialize_element(&self.total_const_size)?;
            state.serialize_element(&self.dynamic_lookup_params)?;
            state.serialize_element(&self.shuffle_params)?;
            state.serialize_element(&self.einsum_params)?;
            state.serialize_element(&self.model_instance_shapes)?;
            state.serialize_element(&self.model_output_scales)?;
            state.serialize_element(&self.model_input_scales)?;
@@ -592,8 +576,6 @@ impl<'de> Deserialize<'de> for GraphSettings {
            // Flattened ShuffleParams fields
            NumShuffles,
            TotalShuffleColSize,
            // EinsumParams field
            EinsumParams,
            ModelInstanceShapes,
            ModelOutputScales,
            ModelInputScales,
@@ -633,7 +615,6 @@ impl<'de> Deserialize<'de> for GraphSettings {
                let mut num_dynamic_lookups = None;
                let mut num_shuffles = None;
                let mut total_shuffle_col_size = None;
                let mut einsum_params = None;
                let mut model_instance_shapes = None;
                let mut model_output_scales = None;
                let mut model_input_scales = None;
@@ -703,12 +684,6 @@ impl<'de> Deserialize<'de> for GraphSettings {
                        }
                        total_shuffle_col_size = Some(map.next_value()?);
                    }
                    Field::EinsumParams => {
                        if einsum_params.is_some() {
                            return Err(de::Error::duplicate_field("einsum_params"));
                        }
                        einsum_params = Some(map.next_value()?);
                    }
                    Field::ModelInstanceShapes => {
                        if model_instance_shapes.is_some() {
                            return Err(de::Error::duplicate_field("model_instance_shapes"));
@@ -847,7 +822,6 @@ impl<'de> Deserialize<'de> for GraphSettings {
                total_const_size,
                dynamic_lookup_params,
                shuffle_params,
                einsum_params: einsum_params.unwrap_or_default(),
                model_instance_shapes,
                model_output_scales,
                model_input_scales,
@@ -870,63 +844,24 @@ impl<'de> Deserialize<'de> for GraphSettings {
            use serde::de::Error;

            // For bincode compatibility, deserialize in the same order as tuple serialization
            let run_args = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(0, &self))?;
            let num_rows = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(1, &self))?;
            let total_assignments = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(2, &self))?;
            let total_const_size = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(3, &self))?;
            let dynamic_lookup_params = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(4, &self))?;
            let shuffle_params = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(5, &self))?;
            let einsum_params = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(6, &self))?;
            let model_instance_shapes = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(7, &self))?;
            let model_output_scales = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(8, &self))?;
            let model_input_scales = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(9, &self))?;
            let module_sizes = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(10, &self))?;
            let required_lookups = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(11, &self))?;
            let required_range_checks = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(12, &self))?;
            let check_mode = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(13, &self))?;
            let version = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(14, &self))?;
            let num_blinding_factors = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(15, &self))?;
            let timestamp = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(16, &self))?;
            let input_types = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(17, &self))?;
            let output_types = seq
                .next_element()?
                .ok_or_else(|| Error::invalid_length(18, &self))?;
            let run_args = seq.next_element()?.ok_or_else(|| Error::invalid_length(0, &self))?;
            let num_rows = seq.next_element()?.ok_or_else(|| Error::invalid_length(1, &self))?;
            let total_assignments = seq.next_element()?.ok_or_else(|| Error::invalid_length(2, &self))?;
            let total_const_size = seq.next_element()?.ok_or_else(|| Error::invalid_length(3, &self))?;
            let dynamic_lookup_params = seq.next_element()?.ok_or_else(|| Error::invalid_length(4, &self))?;
            let shuffle_params = seq.next_element()?.ok_or_else(|| Error::invalid_length(5, &self))?;
            let model_instance_shapes = seq.next_element()?.ok_or_else(|| Error::invalid_length(6, &self))?;
            let model_output_scales = seq.next_element()?.ok_or_else(|| Error::invalid_length(7, &self))?;
            let model_input_scales = seq.next_element()?.ok_or_else(|| Error::invalid_length(8, &self))?;
            let module_sizes = seq.next_element()?.ok_or_else(|| Error::invalid_length(9, &self))?;
            let required_lookups = seq.next_element()?.ok_or_else(|| Error::invalid_length(10, &self))?;
            let required_range_checks = seq.next_element()?.ok_or_else(|| Error::invalid_length(11, &self))?;
            let check_mode = seq.next_element()?.ok_or_else(|| Error::invalid_length(12, &self))?;
            let version = seq.next_element()?.ok_or_else(|| Error::invalid_length(13, &self))?;
            let num_blinding_factors = seq.next_element()?.ok_or_else(|| Error::invalid_length(14, &self))?;
            let timestamp = seq.next_element()?.ok_or_else(|| Error::invalid_length(15, &self))?;
            let input_types = seq.next_element()?.ok_or_else(|| Error::invalid_length(16, &self))?;
            let output_types = seq.next_element()?.ok_or_else(|| Error::invalid_length(17, &self))?;

            Ok(GraphSettings {
                run_args,
@@ -935,7 +870,6 @@ impl<'de> Deserialize<'de> for GraphSettings {
                total_const_size,
                dynamic_lookup_params,
                shuffle_params,
                einsum_params,
                model_instance_shapes,
                model_output_scales,
                model_input_scales,
@@ -950,41 +884,25 @@ impl<'de> Deserialize<'de> for GraphSettings {
                output_types,
            })
        }

    }

    // Universal deserializer that works with both JSON (map) and bincode (tuple)
    if deserializer.is_human_readable() {
        // JSON format - use struct/map deserialization with flattened fields
        const FIELDS: &'static [&'static str] = &[
            "run_args",
            "num_rows",
            "total_assignments",
            "total_const_size",
            "total_dynamic_col_size",
            "max_dynamic_input_len",
            "num_dynamic_lookups",
            "num_shuffles",
            "total_shuffle_col_size",
            "einsum_params",
            "model_instance_shapes",
            "model_output_scales",
            "model_input_scales",
            "module_sizes",
            "required_lookups",
            "required_range_checks",
            "check_mode",
            "version",
            "num_blinding_factors",
            "timestamp",
            "input_types",
            "output_types",
            "dynamic_lookup_params",
            "shuffle_params",
            "run_args", "num_rows", "total_assignments", "total_const_size",
            "total_dynamic_col_size", "max_dynamic_input_len", "num_dynamic_lookups",
            "num_shuffles", "total_shuffle_col_size", "model_instance_shapes",
            "model_output_scales", "model_input_scales", "module_sizes",
            "required_lookups", "required_range_checks", "check_mode", "version",
            "num_blinding_factors", "timestamp", "input_types", "output_types",
            "dynamic_lookup_params", "shuffle_params",
        ];
        deserializer.deserialize_struct("GraphSettings", FIELDS, GraphSettingsVisitor)
    } else {
        // Binary format (bincode) - use tuple deserialization
        deserializer.deserialize_tuple(19, GraphSettingsVisitor)
        deserializer.deserialize_tuple(18, GraphSettingsVisitor)
    }
    }
}
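The dual-format pattern above (named fields when the format is human-readable, a positional tuple otherwise) generalizes; a minimal self-contained sketch, using a hypothetical two-field struct:

use serde::{Serialize, Serializer};

struct Params {
    a: u32,
    b: u32,
}

impl Serialize for Params {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        if serializer.is_human_readable() {
            // JSON and similar: named fields tolerate additive schema changes
            use serde::ser::SerializeStruct;
            let mut s = serializer.serialize_struct("Params", 2)?;
            s.serialize_field("a", &self.a)?;
            s.serialize_field("b", &self.b)?;
            s.end()
        } else {
            // bincode and similar: positional tuple, so the element count must match the deserializer exactly
            use serde::ser::SerializeTuple;
            let mut t = serializer.serialize_tuple(2)?;
            t.serialize_element(&self.a)?;
            t.serialize_element(&self.b)?;
            t.end()
        }
    }
}

This is also why the tuple arity drops from 19 to 18 in this diff: removing `einsum_params` from the positional encoding changes the binary layout, while the JSON path merely loses a named field.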
@@ -1069,13 +987,6 @@ impl GraphSettings {
            .ceil() as u32
    }

    /// Calculates the logrows for einsum computation area in which there is no column overflow
    pub fn einsum_logrows(&self) -> u32 {
        (self.einsum_params.total_einsum_col_size as f64 / self.run_args.num_inner_cols as f64)
            .log2()
            .ceil() as u32
    }

    /// calculate the total number of instances
    pub fn total_instances(&self) -> Vec<usize> {
        let mut instances: Vec<usize> = self.module_sizes.num_instances();
@@ -1631,19 +1542,13 @@ impl GraphCircuit {
        let model_constraint_logrows = self.settings().model_constraint_logrows_with_blinding();
        let min_bits = self.table_size_logrows(safe_lookup_range, max_range_size)?;
        let constants_logrows = self.settings().constants_logrows();
        let einsum_logrows = self.settings().einsum_logrows();
        max_logrows = std::cmp::min(
            max_logrows,
            // max of the model constraint logrows, min_bits, and the constants logrows is the upper limit
            *[
                model_constraint_logrows,
                min_bits,
                constants_logrows,
                einsum_logrows,
            ]
            .iter()
            .max()
            .unwrap(),
            *[model_constraint_logrows, min_bits, constants_logrows]
                .iter()
                .max()
                .unwrap(),
        );

        // we now have a min and max logrows
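For a feel of the `einsum_logrows` bound removed above: with a hypothetical `total_einsum_col_size` of 1000 and `num_inner_cols` of 2, the formula gives ceil(log2(500)) = 9, since 2^9 = 512 is the smallest power-of-two row count that fits 500 rows:

fn einsum_logrows_example() {
    // Hypothetical sizes, for intuition only.
    let einsum_logrows = (1000_f64 / 2_f64).log2().ceil() as u32;
    assert_eq!(einsum_logrows, 9);
}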
@@ -2224,7 +2129,6 @@ pub mod tests {
            num_shuffles: 3,
            total_shuffle_col_size: 256,
        },
        einsum_params: EinsumParams::default(),
        model_instance_shapes: vec![vec![1, 2, 3]],
        model_output_scales: vec![],
        model_input_scales: vec![],
@@ -2258,6 +2162,7 @@ pub mod tests {
        let deserialized: GraphSettings = serde_json::from_str(&json_str).unwrap();
        assert_eq!(original, deserialized);

        // now do JSON bytes
        let json_bytes = serde_json::to_vec(&original).unwrap();
        let deserialized_from_bytes: GraphSettings = serde_json::from_slice(&json_bytes).unwrap();
@@ -2299,8 +2204,7 @@ pub mod tests {
        "decomp_base": 128,
        "decomp_legs": 2,
        "bounded_log_lookup": false,
        "ignore_range_check_inputs_outputs": false,
        "disable_freivalds": false
        "ignore_range_check_inputs_outputs": false
        },
        "num_rows": 236,
        "total_assignments": 472,
@@ -2349,5 +2253,6 @@ pub mod tests {
        }"#;

        let _backwards_compatible: GraphSettings = serde_json::from_str(old_format_json).unwrap();
    }
}

@@ -3,7 +3,6 @@ use super::extract_const_quantized_values;
use super::node::*;
use super::vars::*;
use super::GraphSettings;
use crate::circuit::einsum::analysis::analyze_einsum_usage;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::RegionCtx;
@@ -38,6 +37,7 @@ use log::{debug, info, trace};
use serde::Deserialize;
use serde::Serialize;
use std::collections::BTreeMap;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::collections::HashMap;
use std::collections::HashSet;
use std::fs;
@@ -106,8 +106,6 @@ pub struct DummyPassRes {
    pub dynamic_lookup_params: DynamicLookupParams,
    /// shuffle parameters
    pub shuffle_params: ShuffleParams,
    /// einsum parameters
    pub einsum_params: crate::graph::EinsumParams,
    /// num shuffles
    pub num_shuffles: usize,
    /// shuffle
@@ -594,7 +592,6 @@ impl Model {
            output_types: Some(self.get_output_types()),
            dynamic_lookup_params: res.dynamic_lookup_params,
            shuffle_params: res.shuffle_params,
            einsum_params: res.einsum_params,
            total_const_size: res.total_const_size,
            check_mode,
            version: env!("CARGO_PKG_VERSION").to_string(),
@@ -1050,7 +1047,6 @@ impl Model {

        let lookup_range = settings.run_args.lookup_range;
        let logrows = settings.run_args.logrows as usize;
        let num_inner_cols = settings.run_args.num_inner_cols;
        let required_lookups = settings.required_lookups.clone();
        let required_range_checks = settings.required_range_checks.clone();

@@ -1099,24 +1095,6 @@ impl Model {
            )?;
        }

        // Configures the circuit to use Freivalds' argument.
        // In the dummy phase, Freivalds' is configured as a default (unless `disable-freivalds` is enabled),
        // but if the einsum coordinate is 0, it means that all the einsum layouts are dispatched to use only base operations.
        if settings.einsum_params.total_einsum_col_size > 0 {
            debug!("configuring einsums...");
            let used_einsums: HashMap<(usize, String), HashMap<char, usize>> = settings
                .einsum_params
                .equations
                .iter()
                .enumerate()
                .map(|(idx, (equation, indices_to_dims))| {
                    ((idx, equation.clone()), indices_to_dims.clone())
                })
                .collect();
            let analysis = analyze_einsum_usage(&used_einsums)?;
            base_gate.configure_einsums(meta, &analysis, num_inner_cols, logrows)?;
        }

        Ok(base_gate)
    }

@@ -1169,30 +1147,17 @@ impl Model {

        let original_constants = constants.clone();

        let challenges = {
            if let Some(einsum_config) = &config.base.einsums {
                einsum_config
                    .challenges()?
                    .iter()
                    .map(|c| layouter.get_challenge(*c))
                    .collect_vec()
            } else {
                vec![]
            }
        };

        let outputs = layouter.assign_region(
            || "model",
            |region| {
                let mut thread_safe_region = RegionCtx::new_with_challenges(
                let mut thread_safe_region = RegionCtx::new_with_constants(
                    region,
                    0,
                    run_args.num_inner_cols,
                    run_args.decomp_base,
                    run_args.decomp_legs,
                    challenges.clone(),
                    original_constants.clone(),
                );
                thread_safe_region.update_constants(original_constants.clone());
                // we need to do this as this loop is called multiple times
                vars.set_instance_idx(instance_idx);

@@ -1494,16 +1459,8 @@ impl Model {
            results.insert(*input_idx, vec![inputs[i].clone()]);
        }

        let mut dummy_config = {
            if run_args.disable_freivalds {
                PolyConfig::dummy_without_freivalds(
                    run_args.logrows as usize,
                    run_args.num_inner_cols,
                )
            } else {
                PolyConfig::dummy(run_args.logrows as usize, run_args.num_inner_cols)
            }
        };
        let mut dummy_config =
            PolyConfig::dummy(run_args.logrows as usize, run_args.num_inner_cols);
        let mut model_config = ModelConfig {
            base: dummy_config.clone(),
            vars: ModelVars::new_dummy(),
@@ -1572,10 +1529,6 @@ impl Model {
            num_shuffles: region.shuffle_index(),
            total_shuffle_col_size: region.shuffle_col_coord(),
        },
        einsum_params: crate::graph::EinsumParams {
            equations: region.used_einsum_equations(),
            total_einsum_col_size: region.einsum_col_coord(),
        },
        total_const_size: region.total_constants(),
        lookup_ops: region.used_lookups(),
        range_checks: region.used_range_checks(),

@@ -8,7 +8,9 @@ use halo2curves::ff::PrimeField;
use itertools::Itertools;
use log::debug;
#[cfg(feature = "python-bindings")]
use pyo3::{exceptions::PyValueError, FromPyObject, IntoPyObject, PyResult, Python};
use pyo3::{
    exceptions::PyValueError, FromPyObject, IntoPyObject, PyResult, Python,
};
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
@@ -129,9 +131,7 @@ impl<'py> IntoPyObject<'py> for Visibility {
            .map(|o| o.to_string())
            .collect_vec()
            .join(",");
        Ok(format!("hashed/private/{}", outlets)
            .into_pyobject(py)?
            .into_any())
        Ok(format!("hashed/private/{}", outlets).into_pyobject(py)?.into_any())
        }
    }
}

src/lib.rs (101 lines changed)
@@ -42,6 +42,8 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
#[derive(thiserror::Error, Debug)]
#[allow(missing_docs)]
pub enum EZKLError {
    #[error("[aggregation] {0}")]
    AggregationError(#[from] pfsys::evm::aggregation_kzg::AggregationError),
    #[cfg(all(
        feature = "ezkl",
        not(all(target_arch = "wasm32", target_os = "unknown"))
@@ -98,11 +100,18 @@ impl From<String> for EZKLError {
        EZKLError::UncategorizedError(s)
    }
}

use std::str::FromStr;

use circuit::{table::Range, CheckMode};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::Args;
use fieldutils::IntegerRep;
use graph::{Visibility, MAX_PUBLIC_SRS};
use halo2_proofs::poly::{
    ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
};
use halo2curves::bn256::{Bn256, G1Affine};
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
@@ -121,6 +130,7 @@ pub fn version() -> &'static str {

/// Bindings management
#[cfg(any(
    feature = "universal-bindings",
    all(target_arch = "wasm32", target_os = "unknown"),
    feature = "python-bindings"
))]
@@ -161,6 +171,8 @@ pub mod pfsys;
pub mod srs_sha;
/// An implementation of multi-dimensional tensors.
pub mod tensor;
#[cfg(feature = "ios-bindings")]
uniffi::setup_scaffolding!();
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use lazy_static::lazy_static;

@@ -186,6 +198,78 @@ const EZKL_KEY_FORMAT: &str = "raw-bytes";
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
const EZKL_BUF_CAPACITY: &usize = &8000;

#[derive(
    Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize, Default, Copy,
)]
/// Commitment scheme
pub enum Commitments {
    #[default]
    /// KZG
    KZG,
    /// IPA
    IPA,
}

impl From<Option<Commitments>> for Commitments {
    fn from(value: Option<Commitments>) -> Self {
        value.unwrap_or(Commitments::KZG)
    }
}

impl FromStr for Commitments {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "kzg" => Ok(Commitments::KZG),
            "ipa" => Ok(Commitments::IPA),
            _ => Err("Invalid value for Commitments".to_string()),
        }
    }
}

impl From<KZGCommitmentScheme<Bn256>> for Commitments {
    fn from(_value: KZGCommitmentScheme<Bn256>) -> Self {
        Commitments::KZG
    }
}

impl From<IPACommitmentScheme<G1Affine>> for Commitments {
    fn from(_value: IPACommitmentScheme<G1Affine>) -> Self {
        Commitments::IPA
    }
}

impl std::fmt::Display for Commitments {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Commitments::KZG => write!(f, "kzg"),
            Commitments::IPA => write!(f, "ipa"),
        }
    }
}

#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Commitments {
    /// Convert the struct to a subcommand string
    fn to_flags(&self) -> Vec<String> {
        vec![format!("{}", self)]
    }
}

impl From<String> for Commitments {
    fn from(value: String) -> Self {
        match value.to_lowercase().as_str() {
            "kzg" => Commitments::KZG,
            "ipa" => Commitments::IPA,
            _ => {
                log::error!("Invalid value for Commitments");
                log::warn!("defaulting to KZG");
                Commitments::KZG
            }
        }
    }
}
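A quick illustrative round-trip of the conversions above (a sketch, assuming `Commitments` is in scope):

fn commitments_examples() {
    use std::str::FromStr;
    // FromStr is case-insensitive and fallible.
    let c = Commitments::from_str("KZG").unwrap();
    assert_eq!(c.to_string(), "kzg");
    // From<String> is infallible: unknown names log an error and fall back to KZG.
    let d: Commitments = String::from("not-a-scheme").into();
    assert_eq!(d, Commitments::KZG);
}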

/// Parameters specific to a proving run
///
/// RunArgs contains all configuration parameters needed to control the proving process,
@@ -252,6 +336,10 @@ pub struct RunArgs {
    /// Controls level of constraint verification
    #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "unsafe", value_hint = clap::ValueHint::Other))]
    pub check_mode: CheckMode,
    /// Commitment scheme for circuit proving
    /// Affects proof size and verification time
    #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other))]
    pub commitment: Option<Commitments>,
    /// Base for number decomposition
    /// Must be a power of 2
    #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "16384", value_hint = clap::ValueHint::Other))]
@@ -275,13 +363,6 @@ pub struct RunArgs {
    /// Optional override for epsilon value
    #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long))]
    pub epsilon: Option<f64>,
    /// Forcefully disable using Freivalds' argument in einsum operations.
    /// Freivalds' argument can make the verifier bigger, so this option is useful when
    /// verifier size is a concern.
    /// Without this option, the circuit layouter will always try to use Freivalds' argument
    /// when it is beneficial to do so.
    #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long))]
    pub disable_freivalds: bool,
}

impl RunArgs {
@@ -307,16 +388,16 @@ impl Default for RunArgs {
            logrows: 17,
            num_inner_cols: 2,
            variables: vec![("batch_size".to_string(), 1)],
            input_visibility: Visibility::Public,
            input_visibility: Visibility::Private,
            output_visibility: Visibility::Public,
            param_visibility: Visibility::Fixed,
            param_visibility: Visibility::Private,
            rebase_frac_zero_constants: false,
            check_mode: CheckMode::UNSAFE,
            commitment: None,
            decomp_base: 16384,
            decomp_legs: 2,
            ignore_range_check_inputs_outputs: false,
            epsilon: None,
            disable_freivalds: false,
        }
    }
}

src/pfsys/evm/aggregation_kzg.rs (new file, 442 lines)
@@ -0,0 +1,442 @@
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::CircuitSize;
use crate::pfsys::{Snark, SnarkWitness};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use colored_json::ToColoredJson;
use halo2_proofs::circuit::AssignedCell;
use halo2_proofs::plonk::{self};
use halo2_proofs::{
    circuit::{Layouter, SimpleFloorPlanner, Value},
    plonk::{Circuit, ConstraintSystem},
};
use halo2_wrong_ecc::{
    integer::rns::Rns,
    maingate::{
        MainGate, MainGateConfig, MainGateInstructions, RangeChip, RangeConfig, RangeInstructions,
        RegionCtx,
    },
    EccConfig,
};
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use log::debug;
use log::trace;
use rand::rngs::OsRng;
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::loader::EcPointLoader;
use snark_verifier::{
    loader,
    pcs::{
        kzg::{
            Bdfg21, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding,
            LimbsEncodingInstructions,
        },
        AccumulationScheme, AccumulationSchemeProver,
    },
    system,
    util::arithmetic::fe_to_limbs,
    verifier::{self, SnarkVerifier},
};
use std::rc::Rc;
use thiserror::Error;

const LIMBS: usize = 4;
const BITS: usize = 68;
type As = KzgAs<Bn256, Bdfg21>;
/// Type for aggregator verification
type PlonkSuccinctVerifier = verifier::plonk::PlonkSuccinctVerifier<As, LimbsEncoding<LIMBS, BITS>>;

const T: usize = 5;
const RATE: usize = 4;
const R_F: usize = 8;
const R_P: usize = 60;

type Svk = KzgSuccinctVerifyingKey<G1Affine>;
type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip<G1Affine, LIMBS, BITS>;
/// The loader type used in the transcript definition
type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>;
/// Application snark transcript
pub type PoseidonTranscript<L, S> =
    system::halo2::transcript::halo2::PoseidonTranscript<G1Affine, L, S, T, RATE, R_F, R_P>;

#[derive(Error, Debug)]
/// Errors related to proof aggregation
pub enum AggregationError {
    /// A KZG proof could not be verified
    #[error("failed to verify KZG proof")]
    KZGProofVerification,
    /// proof read errors
    #[error("Failed to read proof")]
    ProofRead,
    /// proof verification errors
    #[error("Failed to verify proof")]
    ProofVerify,
    /// proof creation errors
    #[error("Failed to create proof")]
    ProofCreate,
}

type AggregationResult<'a> = (
    // accumulator
    KzgAccumulator<G1Affine, Rc<Halo2Loader<'a>>>,
    // the set of assigned cells
    Vec<Vec<AssignedCell<Fr, Fr>>>,
);

type LoadedProof<'a> = verifier::plonk::PlonkProof<
    G1Affine,
    Rc<
        loader::halo2::Halo2Loader<
            'a,
            G1Affine,
            halo2_wrong_ecc::BaseFieldEccChip<G1Affine, 4, 68>,
        >,
    >,
    KzgAs<Bn256, Bdfg21>,
>;

/// Aggregate one or more application snarks of the same shape into a KzgAccumulator
pub fn aggregate<'a>(
    svk: &Svk,
    loader: &Rc<Halo2Loader<'a>>,
    snarks: &[SnarkWitness<Fr, G1Affine>],
    as_proof: Value<&'_ [u8]>,
    split_proofs: bool,
) -> Result<AggregationResult<'a>, plonk::Error> {
    let assign_instances = |instances: &[Vec<Value<Fr>>]| {
        instances
            .iter()
            .map(|instances| {
                instances
                    .iter()
                    .map(|instance| loader.assign_scalar(*instance))
                    .collect_vec()
            })
            .collect_vec()
    };

    let mut accumulators = vec![];
    let mut snark_instances = vec![];
    let mut proofs: Vec<LoadedProof<'_>> = vec![];

    for snark in snarks.iter() {
        let protocol = snark.protocol.as_ref().unwrap().loaded(loader);
        let instances = assign_instances(&snark.instances);

        // get assigned cells
        snark_instances.extend(instances.iter().map(|instance| {
            instance
                .iter()
                .map(|v| v.clone().into_assigned())
                .collect_vec()
        }));

        let mut transcript = PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, snark.proof());
        let proof = PlonkSuccinctVerifier::read_proof(svk, &protocol, &instances, &mut transcript)
            .map_err(|_| plonk::Error::Synthesis)?;

        if split_proofs {
            let previous_proof = proofs.last();
            let split_commit = match snark.clone().split {
                Some(split) => split,
                None => {
                    log::error!("Failed to split KZG commit for sequential proofs");
                    return Err(plonk::Error::Synthesis);
                }
            };
            if let Some(previous_proof) = previous_proof {
                // output of previous proof
                let output = &previous_proof.witnesses[split_commit.start..split_commit.end];
                // input of current proof
                let split_commit_len = split_commit.end - split_commit.start;
                let input = &proof.witnesses[..split_commit_len];
                // these points were already assigned previously when loading the transcript so this is safe
                // and equivalent to a copy constraint and an equality constraint
                for (output, input) in output.iter().zip(input.iter()) {
                    loader
                        .ec_point_assert_eq("assert commits match", output, input)
                        .map_err(|e| {
                            log::error!(
                                "Failed to match KZG commits for sequential proofs: {:?}",
                                e
                            );
                            plonk::Error::Synthesis
                        })?;
                }
            }
            proofs.push(proof.clone());
        }

        let mut accum = PlonkSuccinctVerifier::verify(svk, &protocol, &instances, &proof)
            .map_err(|_| plonk::Error::Synthesis)?;
        accumulators.append(&mut accum);
    }
    let accumulator = {
        let mut transcript = PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, as_proof);
        let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap();
        As::verify(&Default::default(), &accumulators, &proof).map_err(|_| plonk::Error::Synthesis)
    }?;
    Ok((accumulator, snark_instances))
}
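The split-proof branch above chains sequential proofs by asserting that the KZG commitments a previous proof exposes as outputs equal the ones the current proof consumes as inputs. A plain-data sketch of that invariant, with hypothetical `start`/`end` positions and u64 ids standing in for EC points:

struct SplitCommit {
    start: usize,
    end: usize,
}

// True when the previous proof's output commitments line up with the current proof's inputs.
fn commits_match(prev_witnesses: &[u64], cur_witnesses: &[u64], split: &SplitCommit) -> bool {
    let output = &prev_witnesses[split.start..split.end];
    let input = &cur_witnesses[..split.end - split.start];
    output == input
}

In-circuit the same check is done with `ec_point_assert_eq`, which is cheap here because the points were already assigned while reading the transcripts, so it amounts to copy and equality constraints.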

/// The Halo2 Config for the aggregation circuit
#[derive(Clone, Debug)]
pub struct AggregationConfig {
    main_gate_config: MainGateConfig,
    range_config: RangeConfig,
}

impl AggregationConfig {
    /// Configure the aggregation circuit
    pub fn configure<F: PrimeField>(
        meta: &mut ConstraintSystem<F>,
        composition_bits: Vec<usize>,
        overflow_bits: Vec<usize>,
    ) -> Self {
        let main_gate_config = MainGate::<F>::configure(meta);
        let range_config =
            RangeChip::<F>::configure(meta, &main_gate_config, composition_bits, overflow_bits);

        #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
        {
            let circuit_size = CircuitSize::from_cs(meta, 23);

            // not wasm

            debug!(
                "circuit size: \n {}",
                circuit_size
                    .as_json()
                    .unwrap()
                    .to_colored_json_auto()
                    .unwrap()
            );
        }

        AggregationConfig {
            main_gate_config,
            range_config,
        }
    }

    /// Create a MainGate from the aggregation approach
    pub fn main_gate(&self) -> MainGate<Fr> {
        MainGate::new(self.main_gate_config.clone())
    }

    /// Create a range chip to decompose and range check inputs
    pub fn range_chip(&self) -> RangeChip<Fr> {
        RangeChip::new(self.range_config.clone())
    }

    /// Create an ecc chip for ec ops
    pub fn ecc_chip(&self) -> BaseFieldEccChip {
        BaseFieldEccChip::new(EccConfig::new(
            self.range_config.clone(),
            self.main_gate_config.clone(),
        ))
    }
}

/// Aggregation Circuit with a SuccinctVerifyingKey, application snark witnesses (each with a proof and instance variables), and the instance variables and the resulting aggregation circuit proof.
#[derive(Clone, Debug)]
pub struct AggregationCircuit {
    svk: Svk,
    snarks: Vec<SnarkWitness<Fr, G1Affine>>,
    instances: Vec<Fr>,
    as_proof: Value<Vec<u8>>,
    split_proof: bool,
}

impl AggregationCircuit {
    /// Create a new Aggregation Circuit with a SuccinctVerifyingKey, application snark witnesses (each with a proof and instance variables), and the instance variables and the resulting aggregation circuit proof.
    pub fn new(
        svk: &KzgSuccinctVerifyingKey<G1Affine>,
        snarks: impl IntoIterator<Item = Snark<Fr, G1Affine>>,
        split_proof: bool,
    ) -> Result<Self, AggregationError> {
        let snarks = snarks.into_iter().collect_vec();

        let mut accumulators = vec![];

        for snark in snarks.iter() {
            trace!("Aggregating with snark instances {:?}", snark.instances);
            let mut transcript = PoseidonTranscript::<NativeLoader, _>::new(snark.proof.as_slice());
            let proof = PlonkSuccinctVerifier::read_proof(
                svk,
                snark.protocol.as_ref().unwrap(),
                &snark.instances,
                &mut transcript,
            )
            .map_err(|e| {
                log::error!("{:?}", e);
                AggregationError::ProofRead
            })?;
            let mut accum = PlonkSuccinctVerifier::verify(
                svk,
                snark.protocol.as_ref().unwrap(),
                &snark.instances,
                &proof,
            )
            .map_err(|_| AggregationError::ProofVerify)?;
            accumulators.append(&mut accum);
        }

        trace!("Accumulator");
        let (accumulator, as_proof) = {
            let mut transcript = PoseidonTranscript::<NativeLoader, _>::new(Vec::new());
            let accumulator =
                As::create_proof(&Default::default(), &accumulators, &mut transcript, OsRng)
                    .map_err(|_| AggregationError::ProofCreate)?;
            (accumulator, transcript.finalize())
        };

        trace!("KzgAccumulator");
        let KzgAccumulator { lhs, rhs } = accumulator;
        let instances = [lhs.x, lhs.y, rhs.x, rhs.y]
            .map(fe_to_limbs::<_, _, LIMBS, BITS>)
            .concat();

        Ok(Self {
            svk: *svk,
            snarks: snarks.into_iter().map_into().collect(),
            instances,
            as_proof: Value::known(as_proof),
            split_proof,
        })
    }

    /// Number of limbs used for decomposition
    pub fn num_limbs() -> usize {
        LIMBS
    }
    /// Number of bits used for decomposition
    pub fn num_bits() -> usize {
        BITS
    }

    /// Accumulator indices used in generating verifier.
    pub fn accumulator_indices() -> Vec<(usize, usize)> {
        (0..4 * LIMBS).map(|idx| (0, idx)).collect()
    }

    /// Number of instance variables for the aggregation circuit, used in generating verifier.
    pub fn num_instance(orginal_circuit_instances: usize) -> Vec<usize> {
        let accumulation_instances = 4 * LIMBS;
        vec![accumulation_instances + orginal_circuit_instances]
    }
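    With LIMBS = 4, the accumulator occupies four coordinates (lhs.x, lhs.y, rhs.x, rhs.y), each split into 4 limbs, i.e. 16 instance slots ahead of the application circuit's own instances. An illustrative check:

    fn aggregation_instance_counts() {
        // 4 coordinates x 4 limbs each = 16 accumulator slots, prepended to the circuit's 3 instances.
        assert_eq!(AggregationCircuit::num_instance(3), vec![16 + 3]);
        assert_eq!(AggregationCircuit::accumulator_indices().len(), 16);
    }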
    /// Instance variables for the aggregation circuit, fed to verifier.
    pub fn instances(&self) -> Vec<Fr> {
        // also get snark instances here
        let mut snark_instances: Vec<Vec<Vec<Value<Fr>>>> = self
            .snarks
            .iter()
            .map(|snark| snark.instances.clone())
            .collect_vec();

        // reduce from Vec<Vec<Vec<Value<Fr>>>> to Vec<Vec<Value<Fr>>>
        let mut instances: Vec<Fr> = self.instances.clone();
        for snark_instance in snark_instances.iter_mut() {
            for instance in snark_instance.iter_mut() {
                let mut felt_evals = vec![];
                for value in instance.iter_mut() {
                    value.map(|v| felt_evals.push(v));
                }
                instances.extend(felt_evals);
            }
        }

        instances
    }

    fn as_proof(&self) -> Value<&[u8]> {
        self.as_proof.as_ref().map(Vec::as_slice)
    }
}

impl Circuit<Fr> for AggregationCircuit {
    type Config = AggregationConfig;
    type FloorPlanner = SimpleFloorPlanner;
    type Params = ();

    fn without_witnesses(&self) -> Self {
        Self {
            svk: self.svk,
            snarks: self
                .snarks
                .iter()
                .map(SnarkWitness::without_witnesses)
                .collect(),
            instances: Vec::new(),
            as_proof: Value::unknown(),
            split_proof: self.split_proof,
        }
    }

    fn configure(meta: &mut ConstraintSystem<Fr>) -> Self::Config {
        AggregationConfig::configure(
            meta,
            vec![BITS / LIMBS],
            Rns::<Fq, Fr, LIMBS, BITS>::construct().overflow_lengths(),
        )
    }

    fn synthesize(
        &self,
        config: Self::Config,
        mut layouter: impl Layouter<Fr>,
    ) -> Result<(), plonk::Error> {
        let main_gate = config.main_gate();
        let range_chip = config.range_chip();

        range_chip.load_table(&mut layouter)?;

        let (accumulator_limbs, snark_instances) = layouter.assign_region(
            || "",
            |region| {
                let ctx = RegionCtx::new(region, 0);

                let ecc_chip = config.ecc_chip();
                let loader = Halo2Loader::new(ecc_chip, ctx);
                let (accumulator, snark_instances) = aggregate(
                    &self.svk,
                    &loader,
                    &self.snarks,
                    self.as_proof(),
                    self.split_proof,
                )?;

                let accumulator_limbs = [accumulator.lhs, accumulator.rhs]
                    .iter()
                    .map(|ec_point| {
                        loader
                            .ecc_chip()
                            .assign_ec_point_to_limbs(&mut loader.ctx_mut(), ec_point.assigned())
                    })
                    .collect::<Result<Vec<_>, plonk::Error>>()?
                    .into_iter()
                    .flatten();

                Ok((accumulator_limbs, snark_instances))
            },
        )?;

        let mut instance_offset = 0;
        for limb in accumulator_limbs {
            main_gate.expose_public(layouter.namespace(|| ""), limb, instance_offset)?;
            instance_offset += 1;
        }

        for instance in snark_instances.into_iter() {
            for elem in instance.into_iter() {
                main_gate.expose_public(layouter.namespace(|| ""), elem, instance_offset)?;
                instance_offset += 1;
            }
        }

        Ok(())
    }
}
src/pfsys/evm/mod.rs (new file, 24 lines)
@@ -0,0 +1,24 @@
use thiserror::Error;

/// Aggregate proof generation for EVM using KZG
pub mod aggregation_kzg;

#[derive(Error, Debug)]
/// Errors related to evm verification
pub enum EvmVerificationError {
    /// If the Solidity verifier worked but returned false
    #[error("Solidity verifier found the proof invalid")]
    InvalidProof,
    /// If the Solidity verifier threw an error (e.g. OutOfGas)
#[error("Execution of Solidity code failed: {0}")]
|
||||
SolidityExecution(String),
|
||||
/// EVM verify errors
|
||||
#[error("evm verification reverted: {0}")]
|
||||
Reverted(String),
|
||||
/// EVM verify errors
|
||||
#[error("evm deployment failed: {0}")]
|
||||
DeploymentFailed(String),
|
||||
/// Invalid Visibility
|
||||
#[error("Invalid visibility")]
|
||||
InvalidVisibility,
|
||||
}
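Since the enum derives `thiserror::Error`, each `#[error]` attribute defines the `Display` output; an illustrative check:

fn evm_error_display_example() {
    let e = EvmVerificationError::SolidityExecution("out of gas".into());
    assert_eq!(e.to_string(), "Execution of Solidity code failed: out of gas");
}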
src/pfsys/mod.rs (303 lines changed)
@@ -1,3 +1,6 @@
|
||||
/// EVM related proving and verification
|
||||
pub mod evm;
|
||||
|
||||
/// SRS generation, processing, verification and downloading
|
||||
pub mod srs;
|
||||
|
||||
@@ -10,11 +13,17 @@ use std::borrow::Borrow;
|
||||
|
||||
use crate::circuit::CheckMode;
|
||||
use crate::graph::GraphWitness;
|
||||
use crate::{EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
|
||||
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
|
||||
use crate::{Commitments, EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use clap::ValueEnum;
|
||||
use halo2_proofs::circuit::Value;
|
||||
use halo2_proofs::plonk::{
|
||||
create_proof, keygen_pk, keygen_vk_custom, verify_proof, Circuit, ProvingKey, VerifyingKey,
|
||||
};
|
||||
use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier};
|
||||
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::VerificationStrategy;
|
||||
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer};
|
||||
use halo2curves::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
|
||||
@@ -28,16 +37,22 @@ use rand::rngs::OsRng;
use rand::rngs::StdRng;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use snark_verifier::verifier::plonk::PlonkProtocol;
use std::fs::File;
use std::io::{self, BufReader, BufWriter, Cursor, Write};
use std::ops::Deref;
use std::path::PathBuf;
use thiserror::Error as thisError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;

#[cfg(feature = "python-bindings")]
use pyo3::types::PyDictMethods;

use halo2curves::bn256::{Bn256, Fr, G1Affine};

/// Converts a string to a `SerdeFormat`.
/// # Panics
/// Panics if the provided `s` is not a valid `SerdeFormat` (i.e. not one of "processed", "raw-bytes-unchecked", or "raw-bytes").
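The converter's body is elided by the diff; under the assumption that it maps onto halo2_proofs::SerdeFormat, a sketch of the documented behaviour (the function name is hypothetical):

// Sketch only: accepts exactly the three documented strings, panics otherwise.
fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
    match s {
        "processed" => halo2_proofs::SerdeFormat::Processed,
        "raw-bytes-unchecked" => halo2_proofs::SerdeFormat::RawBytesUnchecked,
        "raw-bytes" => halo2_proofs::SerdeFormat::RawBytes,
        _ => panic!("invalid serde format: {s}"),
    }
}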
@@ -125,6 +140,144 @@ where
    bytes
}

#[allow(missing_docs)]
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
pub enum ProofType {
    #[default]
    Single,
    ForAggr,
}

impl std::fmt::Display for ProofType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}",
            match self {
                ProofType::Single => "single",
                ProofType::ForAggr => "for-aggr",
            }
        )
    }
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for ProofType {
    fn to_flags(&self) -> Vec<String> {
        vec![format!("{}", self)]
    }
}

impl From<ProofType> for TranscriptType {
    fn from(val: ProofType) -> Self {
        match val {
            ProofType::Single => TranscriptType::EVM,
            ProofType::ForAggr => TranscriptType::Poseidon,
        }
    }
}

impl From<ProofType> for StrategyType {
    fn from(val: ProofType) -> Self {
        match val {
            ProofType::Single => StrategyType::Single,
            ProofType::ForAggr => StrategyType::Accum,
        }
    }
}
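Taken together, one ProofType selection fans out to both the transcript and the verification strategy. A hypothetical unit test (not part of the diff) pinning down the mapping:

#[test]
fn proof_type_determines_transcript_and_strategy() {
    // ForAggr proofs use the Poseidon transcript and the accumulation strategy;
    // Single proofs target the EVM transcript directly.
    assert_eq!(TranscriptType::from(ProofType::Single), TranscriptType::EVM);
    assert_eq!(TranscriptType::from(ProofType::ForAggr), TranscriptType::Poseidon);
    assert_eq!(StrategyType::from(ProofType::Single), StrategyType::Single);
    assert_eq!(StrategyType::from(ProofType::ForAggr), StrategyType::Accum);
    assert_eq!(ProofType::ForAggr.to_string(), "for-aggr");
}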

#[cfg(feature = "python-bindings")]
impl<'py> pyo3::IntoPyObject<'py> for ProofType {
    type Target = pyo3::PyAny;
    type Output = pyo3::Bound<'py, Self::Target>;
    type Error = pyo3::PyErr;

    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
        let result = match self {
            ProofType::Single => "Single",
            ProofType::ForAggr => "ForAggr",
        };
        Ok(result.into_pyobject(py)?.into_any())
    }
}

#[cfg(feature = "python-bindings")]
/// Obtains ProofType from PyObject (Required for ProofType to be compatible with Python)
impl<'source> pyo3::FromPyObject<'source> for ProofType {
    fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        let strval = String::extract_bound(ob)?;
        match strval.to_lowercase().as_str() {
            "single" => Ok(ProofType::Single),
            "for-aggr" => Ok(ProofType::ForAggr),
            _ => Err(pyo3::exceptions::PyValueError::new_err(
                "Invalid value for ProofType",
            )),
        }
    }
}

#[allow(missing_docs)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
pub enum StrategyType {
    Single,
    Accum,
}
impl std::fmt::Display for StrategyType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // When the `ezkl` feature is disabled or we're targeting `wasm32`, use the basic string representation.
        #[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
        {
            write!(
                f,
                "{}",
                match self {
                    StrategyType::Single => "single",
                    StrategyType::Accum => "accum",
                }
            )
        }

        // When the `ezkl` feature is enabled and we're not targeting `wasm32`, use `to_possible_value`.
        #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
        {
            self.to_possible_value()
                .expect("no values are skipped")
                .get_name()
                .fmt(f)
        }
    }
}
#[cfg(feature = "python-bindings")]
/// Converts StrategyType into a PyObject (Required for StrategyType to be compatible with Python)
impl<'py> pyo3::IntoPyObject<'py> for StrategyType {
    type Target = pyo3::PyAny;
    type Output = pyo3::Bound<'py, Self::Target>;
    type Error = pyo3::PyErr;

    fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
        let result = match self {
            StrategyType::Single => "single",
            StrategyType::Accum => "accum",
        };
        Ok(result.into_pyobject(py)?.into_any())
    }
}
#[cfg(feature = "python-bindings")]
/// Obtains StrategyType from PyObject (Required for StrategyType to be compatible with Python)
impl<'source> pyo3::FromPyObject<'source> for StrategyType {
    fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        let strval = String::extract_bound(ob)?;
        match strval.to_lowercase().as_str() {
            "single" => Ok(StrategyType::Single),
            "accum" => Ok(StrategyType::Accum),
            _ => Err(pyo3::exceptions::PyValueError::new_err(
                "Invalid value for StrategyType",
            )),
        }
    }
}

#[derive(thisError, Debug)]
/// Errors related to pfsys
pub enum PfSysError {
@@ -133,8 +286,34 @@ pub enum PfSysError {
    PackingExponent,
}

-#[cfg(feature = "python-bindings")]
-use halo2curves::bn256::G1Affine;
#[allow(missing_docs)]
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
pub enum TranscriptType {
    Poseidon,
    #[default]
    EVM,
}

impl std::fmt::Display for TranscriptType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}",
            match self {
                TranscriptType::Poseidon => "poseidon",
                TranscriptType::EVM => "evm",
            }
        )
    }
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for TranscriptType {
    fn to_flags(&self) -> Vec<String> {
        vec![format!("{}", self)]
    }
}

#[cfg(feature = "python-bindings")]
///
@@ -193,7 +372,7 @@ pub struct PrettyElements {
    pub outputs: Vec<Vec<String>>,
}

-/// An application snark with proof and instance variables
+/// An application snark with proof and instance variables ready for aggregation (raw field element)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Snark<F: PrimeField + SerdeObject, C: CurveAffine>
where
@@ -208,12 +387,16 @@ where
    pub proof: Vec<u8>,
    /// hex encoded proof
    pub hex_proof: Option<String>,
    /// transcript type
    pub transcript_type: TranscriptType,
    /// the split proof
    pub split: Option<ProofSplitCommit>,
    /// the proof instances as rescaled floats
    pub pretty_public_inputs: Option<PrettyElements>,
    /// timestamp
    pub timestamp: Option<u128>,
    /// commitment
    pub commitment: Option<Commitments>,
    /// (optional) version of ezkl used to generate the proof
    version: Option<String>,
}
@@ -221,8 +404,7 @@ where
#[cfg(feature = "python-bindings")]
use pyo3::{types::PyDict, IntoPyObject, Python};
#[cfg(feature = "python-bindings")]
-impl<'py, F: PrimeField + SerdeObject + Serialize, C: CurveAffine + Serialize> IntoPyObject<'py>
-    for Snark<F, C>
+impl<'py, F: PrimeField + SerdeObject + Serialize, C: CurveAffine + Serialize> IntoPyObject<'py> for Snark<F, C>
where
    C::Scalar: Serialize + DeserializeOwned,
    C::ScalarExt: Serialize + DeserializeOwned,
@@ -241,6 +423,8 @@ where
        dict.set_item("instances", field_elems).unwrap();
        let hex_proof = hex::encode(&self.proof);
        dict.set_item("proof", format!("0x{}", hex_proof)).unwrap();
        dict.set_item("transcript_type", self.transcript_type.into_pyobject(py)?)
            .unwrap();
        Ok(dict.into_any())
    }
}
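The proof reaches Python as a 0x-prefixed hex string; illustrative values only:

// Illustrative: two proof bytes [1, 2] serialize to the string "0x0102".
let proof_bytes = vec![1u8, 2u8];
assert_eq!(format!("0x{}", hex::encode(&proof_bytes)), "0x0102");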
@@ -253,21 +437,24 @@ where
    C::Scalar: Serialize + DeserializeOwned,
    C::ScalarExt: Serialize + DeserializeOwned,
{
-    /// Create a new application snark from proof and instance variables
+    /// Create a new application snark from proof and instance variables ready for aggregation
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        protocol: Option<PlonkProtocol<C>>,
        instances: Vec<Vec<F>>,
        proof: Vec<u8>,
        hex_proof: Option<String>,
        transcript_type: TranscriptType,
        split: Option<ProofSplitCommit>,
        pretty_public_inputs: Option<PrettyElements>,
        commitment: Option<Commitments>,
    ) -> Self {
        Self {
            protocol,
            instances,
            proof,
            hex_proof,
            transcript_type,
            split,
            pretty_public_inputs,
            // unix timestamp
@@ -277,6 +464,7 @@ where
                    .unwrap()
                    .as_millis(),
            ),
            commitment,
            version: Some(crate::version().to_string()),
        }
    }
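For context, a hypothetical call site (values illustrative; public_inputs and proof_bytes stand in for real prover output):

let snark = Snark::<Fr, G1Affine>::new(
    None,                   // no aggregation protocol attached
    public_inputs,          // Vec<Vec<Fr>>: instance columns
    proof_bytes,            // raw transcript bytes from the prover
    None,                   // hex encoding can be filled in later
    TranscriptType::EVM,
    None,                   // no split-proof commitment
    None,                   // no pretty-printed inputs
    Some(Commitments::KZG),
);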
@@ -372,6 +560,53 @@ impl From<GraphWitness> for Option<ProofSplitCommit> {
    }
}

/// An application snark with proof and instance variables ready for aggregation (wrapped field element)
#[derive(Clone, Debug)]
pub struct SnarkWitness<F: PrimeField, C: CurveAffine> {
    protocol: Option<PlonkProtocol<C>>,
    instances: Vec<Vec<Value<F>>>,
    proof: Value<Vec<u8>>,
    split: Option<ProofSplitCommit>,
}

impl<F: PrimeField, C: CurveAffine> SnarkWitness<F, C> {
    fn without_witnesses(&self) -> Self {
        SnarkWitness {
            protocol: self.protocol.clone(),
            instances: self
                .instances
                .iter()
                .map(|instances| vec![Value::unknown(); instances.len()])
                .collect(),
            proof: Value::unknown(),
            split: self.split.clone(),
        }
    }

    fn proof(&self) -> Value<&[u8]> {
        self.proof.as_ref().map(Vec::as_slice)
    }
}

impl<F: PrimeField + SerdeObject, C: CurveAffine> From<Snark<F, C>> for SnarkWitness<F, C>
where
    C::Scalar: Serialize + DeserializeOwned,
    C::ScalarExt: Serialize + DeserializeOwned,
{
    fn from(snark: Snark<F, C>) -> Self {
        Self {
            protocol: snark.protocol,
            instances: snark
                .instances
                .into_iter()
                .map(|instances| instances.into_iter().map(Value::known).collect())
                .collect(),
            proof: Value::known(snark.proof),
            split: snark.split,
        }
    }
}
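This mirrors halo2's two-phase flow: key generation runs on a witness-free copy, proving on the known values. A sketch, usable only inside this module since the fields and without_witnesses are private (app_snark is hypothetical):

// Hypothetical: wrap a previously generated Snark for use in an aggregation circuit.
let witness: SnarkWitness<Fr, G1Affine> = app_snark.into();
// During key generation every instance and the proof become Value::unknown():
let keygen_copy = witness.without_witnesses();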

/// Creates a [VerifyingKey] and [ProvingKey] for a [crate::graph::GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`).
pub fn create_keys<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
    circuit: &C,
@@ -417,6 +652,8 @@ pub fn create_proof_circuit<
    params: &'params Scheme::ParamsProver,
    pk: &ProvingKey<Scheme::Curve>,
    check_mode: CheckMode,
    commitment: Commitments,
    transcript_type: TranscriptType,
    split: Option<ProofSplitCommit>,
    protocol: Option<PlonkProtocol<Scheme::Curve>>,
) -> Result<Snark<Scheme::Scalar, Scheme::Curve>, PfsysError>
@@ -464,7 +701,16 @@ where
    let proof = transcript.finalize();
    let hex_proof = format!("0x{}", hex::encode(&proof));

-    let checkable_pf = Snark::new(protocol, instances, proof, Some(hex_proof), split, None);
+    let checkable_pf = Snark::new(
+        protocol,
+        instances,
+        proof,
+        Some(hex_proof),
+        transcript_type,
+        split,
+        None,
+        Some(commitment),
+    );

    // sanity check that the generated proof is valid
    if check_mode == CheckMode::SAFE {
@@ -553,6 +799,44 @@ where
    Ok(proof_first_bytes)
}

/// Swap the proof commitments to a new set in the proof (KZG or IPA)
pub fn swap_proof_commitments_polycommit(
    snark: &Snark<Fr, G1Affine>,
    commitments: &[G1Affine],
) -> Result<Snark<Fr, G1Affine>, PfsysError> {
    let proof = match snark.commitment {
        Some(Commitments::KZG) => match snark.transcript_type {
            TranscriptType::EVM => swap_proof_commitments::<
                KZGCommitmentScheme<Bn256>,
                _,
                EvmTranscript<G1Affine, _, _, _>,
            >(snark, commitments)?,
            TranscriptType::Poseidon => swap_proof_commitments::<
                KZGCommitmentScheme<Bn256>,
                _,
                PoseidonTranscript<NativeLoader, _>,
            >(snark, commitments)?,
        },
        Some(Commitments::IPA) => match snark.transcript_type {
            TranscriptType::EVM => swap_proof_commitments::<
                IPACommitmentScheme<G1Affine>,
                _,
                EvmTranscript<G1Affine, _, _, _>,
            >(snark, commitments)?,
            TranscriptType::Poseidon => swap_proof_commitments::<
                IPACommitmentScheme<G1Affine>,
                _,
                PoseidonTranscript<NativeLoader, _>,
            >(snark, commitments)?,
        },
        None => {
            return Err(PfsysError::InvalidCommitmentScheme);
        }
    };

    Ok(proof)
}
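The four-way dispatch is just the cross product of commitment scheme (KZG or IPA) and transcript (EVM or Poseidon). A hypothetical call site re-pointing a proof at fresh commitments:

// Hypothetical: snark and new_commitments come from earlier pipeline stages;
// the returned Snark carries the rewritten proof bytes.
let updated = swap_proof_commitments_polycommit(&snark, &new_commitments)?;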

/// A wrapper around halo2's verify_proof
pub fn verify_proof_circuit<
    'params,
@@ -709,11 +993,13 @@ mod tests {
        let snark = Snark::<Fr, G1Affine> {
            proof: vec![1, 2, 3, 4, 5, 6, 7, 8],
            instances: vec![vec![Fr::from(1)], vec![Fr::from(2)]],
            transcript_type: TranscriptType::EVM,
            protocol: None,
            hex_proof: None,
            split: None,
            pretty_public_inputs: None,
            timestamp: None,
            commitment: None,
            version: None,
        };
@@ -726,5 +1012,6 @@ mod tests {
            .unwrap();
        assert_eq!(snark.instances, snark2.instances);
        assert_eq!(snark.proof, snark2.proof);
        assert_eq!(snark.transcript_type, snark2.transcript_type);
    }
}
Some files were not shown because too many files have changed in this diff.