Compare commits

..

5 Commits

Author SHA1 Message Date
dante   5144991b21   fix: update decompose base type   2025-10-26 14:34:36 -04:00
dante   64acb1d9d6   fix: bump decomp base to integerrep (#1016)   2025-10-25 23:22:12 -04:00
dante   8cf28456b3   refactor: remove IPA (#1014)   2025-10-14 09:32:28 -04:00
dante   e70e13a9e3   refactor!: rm ios,js,aggregation (#1013)   2025-10-10 09:56:41 -04:00
        BREAKING CHANGE: removes support for iOS, JS, WASM and removes aggregation circuit.
dante   365d92a5f2   feat: implement generalized Freivalds' algorithm for arbitrary einsum (#990)   2025-10-08 07:34:32 -04:00
        Co-authored-by: DoHoon Kim <59155248+DoHoonKim8@users.noreply.github.com>
        Co-authored-by: therealyingtong <yingtong.lai@gmail.com>
        Co-authored-by: DoHoonKim <dohoon1097819@gmail.com>
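A note on 365d92a5f2: Freivalds' algorithm checks a claimed product A*B = C by multiplying both sides by a random vector r and comparing A(Br) with Cr, replacing an O(n^3) recomputation with O(n^2) work; the commit generalizes this idea to arbitrary einsum contractions. The minimal Rust sketch below shows only the plain matmul case over a toy prime field and is purely illustrative; none of these names come from ezkl's API.

// Minimal sketch of Freivalds' probabilistic check for A * B == C mod p.
// Illustrative only; not ezkl's implementation or API.

fn mat_vec(m: &[Vec<u64>], v: &[u64], p: u64) -> Vec<u64> {
    m.iter()
        .map(|row| row.iter().zip(v).fold(0, |acc, (a, b)| (acc + a * b) % p))
        .collect()
}

fn freivalds_check(a: &[Vec<u64>], b: &[Vec<u64>], c: &[Vec<u64>], r: &[u64], p: u64) -> bool {
    let br = mat_vec(b, r, p); // B * r
    let abr = mat_vec(a, &br, p); // A * (B * r): two O(n^2) products
    let cr = mat_vec(c, r, p); // C * r
    abr == cr // holds for every r when C == A * B
}

fn main() {
    let p = 101; // toy prime; a real check samples r from a large field
    let a = vec![vec![1, 2], vec![3, 4]];
    let b = vec![vec![5, 6], vec![7, 8]];
    let c = vec![vec![19, 22], vec![43, 50]]; // the true product A * B
    let r = vec![3, 7]; // stand-in for a randomly sampled vector
    assert!(freivalds_check(&a, &b, &c, &r, p));
    println!("Freivalds check passed");
}

With r sampled uniformly from a field of size p, a wrong C passes a single trial with probability at most 1/p, which is negligible for a 254-bit field such as BN254's scalar field used elsewhere in this diff.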
120 changed files with 3961 additions and 11379 deletions


@@ -1,192 +0,0 @@
name: Build and Publish EZKL Engine npm package
on:
workflow_dispatch:
inputs:
tag:
description: "The tag to release"
required: true
push:
tags:
- "*"
defaults:
run:
working-directory: .
jobs:
publish-wasm-bindings:
permissions:
contents: read
packages: write
name: publish-wasm-bindings
env:
RELEASE_TAG: ${{ github.ref_name }}
RUSTFLAGS: "-C target-feature=+atomics,+bulk-memory"
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
cache: false
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
with:
# Pin to version 0.12.1
version: 'v0.12.1'
- name: Add wasm32-unknown-unknown target
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
- name: Install binaryen
run: |
set -e
curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz | tar xzf -
export PATH=$PATH:$PWD/binaryen-version_116/bin
wasm-opt --version
- name: Build wasm files for both web and nodejs compilation targets
run: |
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
- name: Create package.json in pkg folder
shell: bash
run: |
cat > pkg/package.json << EOF
{
"name": "@ezkljs/engine",
"version": "$RELEASE_TAG",
"dependencies": {
"@types/json-bigint": "^1.0.1",
"json-bigint": "^1.0.0"
},
"files": [
"nodejs/ezkl_bg.wasm",
"nodejs/ezkl.js",
"nodejs/ezkl.d.ts",
"nodejs/package.json",
"nodejs/utils.js",
"web/ezkl_bg.wasm",
"web/ezkl.js",
"web/ezkl.d.ts",
"web/snippets/**/*",
"web/package.json",
"web/utils.js",
"ezkl.d.ts"
],
"main": "nodejs/ezkl.js",
"module": "web/ezkl.js",
"types": "nodejs/ezkl.d.ts",
"sideEffects": [
"web/snippets/*"
]
}
EOF
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:21,maximum:65536,shared:true})}|" pkg/nodejs/ezkl.js
- name: Replace `import.meta.url` with `import.meta.resolve` definition in workerHelpers.js
run: |
find ./pkg/web/snippets -type f -name "*.js" -exec sed -i "s|import.meta.url|import.meta.resolve|" {} +
- name: Add serialize and deserialize methods to nodejs bundle
run: |
echo '
const JSONBig = require("json-bigint");
function deserialize(buffer) { // buffer is a Uint8ClampedArray | Uint8Array // return a JSON object
if (buffer instanceof Uint8ClampedArray) {
buffer = new Uint8Array(buffer.buffer);
}
const string = new TextDecoder().decode(buffer);
const jsonObject = JSONBig.parse(string);
return jsonObject;
}
function serialize(data) { // data is an object // return a Uint8ClampedArray
// Step 1: Stringify the Object with BigInt support
if (typeof data === "object") {
data = JSONBig.stringify(data);
}
// Step 2: Encode the JSON String
const uint8Array = new TextEncoder().encode(data);
// Step 3: Convert to Uint8ClampedArray
return new Uint8ClampedArray(uint8Array.buffer);
}
module.exports = {
deserialize,
serialize
};
' > pkg/nodejs/utils.js
- name: Add serialize and deserialize methods to web bundle
run: |
echo '
import { parse, stringify } from "json-bigint";
export function deserialize(buffer) { // buffer is a Uint8ClampedArray | Uint8Array // return a JSON object
if (buffer instanceof Uint8ClampedArray) {
buffer = new Uint8Array(buffer.buffer);
}
const string = new TextDecoder().decode(buffer);
const jsonObject = parse(string);
return jsonObject;
}
export function serialize(data) { // data is an object // return a Uint8ClampedArray
// Step 1: Stringify the Object with BigInt support
if (typeof data === "object") {
data = stringify(data);
}
// Step 2: Encode the JSON String
const uint8Array = new TextEncoder().encode(data);
// Step 3: Convert to Uint8ClampedArray
return new Uint8ClampedArray(uint8Array.buffer);
}
' > pkg/web/utils.js
- name: Expose serialize and deserialize imports in nodejs target
run: |
sed -i '53i// import serialize and deserialize from utils.js\nconst { serialize, deserialize } = require(`./utils.js`);\nmodule.exports.serialize = serialize;\nmodule.exports.deserialize = deserialize;' pkg/nodejs/ezkl.js
- name: Expose serialize and deserialize imports in web target
run: |
sed -i '51i\
// import serialize and deserialize from utils.js\
import { serialize, deserialize } from '\''./utils.js'\'';\
export { serialize, deserialize };' pkg/web/ezkl.js
- name: Add serialize and deserialize imports to nodejs ezkl.d.ts
run: |
sed -i '1i\
export declare function serialize(data: object | string): Uint8ClampedArray;\
export declare function deserialize(buffer: Uint8ClampedArray | Uint8Array): any;' pkg/nodejs/ezkl.d.ts
- name: Add serialize and deserialize imports to web ezkl.d.ts
run: |
sed -i '1i\
export declare function serialize(data: object | string): Uint8ClampedArray;\
export declare function deserialize(buffer: Uint8ClampedArray | Uint8Array): any;' pkg/web/ezkl.d.ts
- name: Create README.md in pkg folder
run: |
curl -s "https://raw.githubusercontent.com/zkonduit/ezkljs-engine/main/README.md" > ./pkg/README.md
- name: Set up Node.js
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
- name: Publish to npm
run: |
cd pkg
npm install
npm ci
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}


@@ -20,18 +20,18 @@ env:
jobs:
fr-age-test:
needs: [build, library-tests, docs, python-tests, python-integration-tests]
needs: [build, library-tests, docs]
permissions:
contents: read
runs-on: large-self-hosted
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -52,12 +52,11 @@ jobs:
RUSTFLAGS: "-C linker=gcc"
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -78,10 +77,10 @@ jobs:
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -102,10 +101,10 @@ jobs:
RUSTFLAGS: "-C linker=gcc"
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -116,7 +115,7 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -126,7 +125,6 @@ jobs:
- name: Library tests
run: cargo nextest run --lib --verbose
ultra-overflow-tests-gpu:
permissions:
contents: read
@@ -137,21 +135,18 @@ jobs:
RUSTFLAGS: "-C linker=gcc"
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
- uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
with:
wasmtime-version: "3.0.1"
- name: Setup GPU dependencies
run: sudo ./setup-gpu.sh --yes
- name: Install build dependencies
@@ -176,7 +171,7 @@ jobs:
ultra-overflow-tests:
permissions:
contents: read
runs-on: [ non-gpu, non-sgx ]
runs-on: [non-gpu, non-sgx]
needs: [build, library-tests, docs]
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
@@ -184,12 +179,11 @@ jobs:
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -200,17 +194,10 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
- uses: mwilliamson/setup-wasmtime-action@bf814d7d8fc3c3a77dfe114bd9fb8a2c575f6ad6 #v2.0.0
with:
wasmtime-version: "3.0.1"
# - name: Matmul overflow (wasi)
# run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
# - name: Conv overflow (wasi)
# run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
- name: lookup overflow
run: cargo nextest run lookup_ultra_overflow --no-capture -- --include-ignored
- name: Matmul overflow
@@ -230,12 +217,11 @@ jobs:
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -246,66 +232,17 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
- name: Model serialization different binary ID
run: cargo nextest run native_tests::tests::model_serialization_different_binaries_ --test-threads 1
wasm32-tests:
permissions:
contents: read
runs-on: ubuntu-22.04
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
# add `atomics` and `bulk-memory` to RUSTFLAGS to enable wasm-bindgen tests
RUSTFLAGS: "-C target-feature=+atomics,+bulk-memory"
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- name: install libc6
run: sudo apt-get install -y libc6
- name: Install cmake and build dependencies
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
with:
# Pin to version 0.13.1
version: "v0.13.1"
- uses: nanasess/setup-chromedriver@affb1ea8848cbb080be372c1e8d7a5c173e9298f #v2.3.0
# with:
# chromedriver-version: "115.0.5790.102"
- name: Install wasm32-unknown-unknown
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
- name: Create webdriver.json to disable timeouts
run: |
echo '{"args": ["--headless", "--disable-gpu", "--disable-dev-shm-usage", "--no-sandbox"]}' > webdriver.json
- name: Run wasm verifier tests
run: |
ulimit -n 65536
WASM_BINDGEN_TEST_THREADS=1 \
WASM_BINDGEN_TEST_TIMEOUT=1800 \
CHROMEDRIVER_ARGS="--log-level=INFO" \
wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web -- --nocapture
mock-proving-tests:
permissions:
contents: read
runs-on: [ non-gpu, non-sgx ]
runs-on: [non-gpu, non-sgx]
needs: [build, library-tests, docs]
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
@@ -313,16 +250,16 @@ jobs:
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -380,62 +317,42 @@ jobs:
prove-and-verify-evm-tests:
permissions:
contents: read
runs-on: [ non-gpu, non-sgx ]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
runs-on: [non-gpu, non-sgx]
needs: [build, library-tests, docs]
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
RUSTFLAGS: "-C target-feature=+atomics,+bulk-memory"
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
with:
# Pin to version 0.13.1
version: "v0.13.1"
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Use pnpm 8
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
with:
version: 8
- name: Use Node.js 22.17.1
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
with:
node-version: "22.17.1"
cache: "pnpm"
- name: "Add rust-src"
run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
- name: Install dependencies for js tests and package
run: |
pnpm install --frozen-lockfile
# - name: Install solc
# run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 56b806a3ba7866a3b061093bebd0fa2ace97f1fc --locked anvil --force
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
- name: KZG prove and verify tests (EVM)
run: cargo nextest run --verbose "tests_evm::kzg_evm_prove_and_verify_::" --test-threads 1
- name: KZG prove and verify tests (EVM + reusable verifier + col-overflow)
run: cargo nextest run --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --features reusable-verifier --test-threads 1
# - name: KZG prove and verify tests (EVM + reusable verifier + col-overflow)
# run: cargo nextest run --verbose tests_evm::kzg_evm_prove_and_verify_reusable_verifier --features reusable-verifier --test-threads 1
- name: KZG prove and verify tests (EVM + kzg all)
run: cargo nextest run --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + kzg inputs)
@@ -449,96 +366,36 @@ jobs:
- name: KZG prove and verify tests (EVM + hashed outputs)
run: cargo nextest run --verbose tests_evm::kzg_evm_hashed_output_prove_and_verify --test-threads 1
# prove-and-verify-tests-metal:
# permissions:
# contents: read
# runs-on: macos-13
# # needs: [build, library-tests, docs]
# steps:
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
# with:
# persist-credentials: false
# - uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
# with:
# toolchain: nightly-2025-05-01
# override: true
# components: rustfmt, clippy
# - uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
# with:
# # Pin to version 0.13.1
# version: 'v0.13.1'
# - name: Add rust-src
# run: rustup component add rust-src --toolchain nightly-2025-05-01
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
# with:
# persist-credentials: false
# - name: Use pnpm 8
# uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
# with:
# version: 8
# - uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
# with:
# crate: cargo-nextest
# locked: true
# - name: KZG prove and verify tests (public outputs)
# run: cargo nextest run --features macos-metal --verbose tests::kzg_prove_and_verify_::t --no-capture
prove-and-verify-tests:
permissions:
contents: read
runs-on: [ non-gpu, non-sgx ]
runs-on: [non-gpu, non-sgx]
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
with:
# Pin to version 0.13.1
version: "v0.13.1"
- name: Add wasm32-unknown-unknown target
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Use pnpm 8
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
with:
version: 8
- name: Use Node.js 22.17.1
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
with:
node-version: "22.17.1"
cache: "pnpm"
- name: Install dependencies for js tests
run: |
pnpm install --frozen-lockfile
env:
CI: false
NODE_ENV: development
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
# - name: Build wasm package for nodejs target.
# run: |
# wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::t
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
- name: KZG prove and verify tests (hashed inputs + column overflow)
@@ -575,17 +432,17 @@ jobs:
RUSTFLAGS: "-C linker=gcc"
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -600,7 +457,7 @@ jobs:
- name: KZG prove and verify tests (kzg outputs)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_kzg_output --features gpu-accelerated --test-threads 1
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w --features gpu-accelerated --test-threads 1
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::t --features gpu-accelerated --test-threads 1
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features gpu-accelerated --test-threads 1
- name: KZG prove and verify tests (public outputs)
@@ -612,122 +469,6 @@ jobs:
- name: KZG prove and verify tests (hashed outputs)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_hashed --features gpu-accelerated --test-threads 1
prove-and-verify-mock-aggr-tests:
permissions:
contents: read
runs-on: self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: dtolnay/rust-toolchain@4f94fbe7e03939b0e674bcc9ca609a16088f63ff #nightly branch, TODO: update when required
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
- name: Mock aggr tests (KZG)
run: cargo nextest run --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
prove-and-verify-aggr-tests-gpu:
permissions:
contents: read
runs-on: gpu
needs: [build, library-tests, docs, python-tests, python-integration-tests]
env:
ENABLE_ICICLE_GPU: true
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
RUSTFLAGS: "-C linker=gcc"
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
- name: Setup GPU dependencies
run: sudo ./setup-gpu.sh --yes
- name: Install build dependencies
run: |
sudo apt-get update
sudo apt-get install -y build-essential g++ gcc cmake libclang-dev llvm-dev libstdc++-12-dev libc6 libc6-dev libssl-dev pkg-config
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- name: KZG tests
run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features gpu-accelerated --test-threads 1 -- --include-ignored
prove-and-verify-aggr-tests:
permissions:
contents: read
runs-on: large-self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
- name: KZG tests
run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
prove-and-verify-aggr-evm-tests:
permissions:
contents: read
runs-on: large-self-hosted
needs: [build, library-tests, docs, python-tests, python-integration-tests]
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
# - name: Install solc
# run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 56b806a3ba7866a3b061093bebd0fa2ace97f1fc --locked anvil --force
- name: KZG prove and verify aggr tests
run: cargo nextest run --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
examples:
permissions:
contents: read
@@ -739,12 +480,11 @@ jobs:
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -755,7 +495,7 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y cmake build-essential g++ gcc libclang-dev llvm-dev libstdc++-12-dev libc6-dev libssl-dev pkg-config
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -765,21 +505,20 @@ jobs:
python-tests:
permissions:
contents: read
runs-on: [ non-gpu, non-sgx ]
runs-on: [non-gpu, non-sgx]
needs: [build, library-tests, docs]
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: "3.12"
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
@@ -800,26 +539,25 @@ jobs:
accuracy-measurement-tests:
permissions:
contents: read
runs-on: [ non-gpu, non-sgx ]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
runs-on: [non-gpu, non-sgx]
needs: [build, library-tests, docs]
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: "3.12"
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -845,20 +583,19 @@ jobs:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: "3.11"
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
@@ -899,100 +636,3 @@ jobs:
# run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_verifier_ --no-capture
- name: Reusable verifier tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_verifier_ --no-capture --test-threads 1
ios-integration-tests:
permissions:
contents: read
runs-on: macos-latest
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
RUSTFLAGS: "-C linker=gcc"
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3 #v3.3.0
with:
crate: cargo-nextest
locked: true
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- name: Run ios tests
run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2025-05-01-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features
swift-package-tests:
permissions:
contents: read
runs-on: macos-latest
needs: [ios-integration-tests]
env:
EVM_VERIFIER_EZKL_TOKEN: ${{ secrets.EVM_VERIFIER_EZKL_TOKEN }}
RUSTFLAGS: "-C linker=gcc"
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
with:
toolchain: nightly-2025-05-01
override: true
components: rustfmt, clippy
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- name: Build EzklCoreBindings
run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
- name: Clone ezkl-swift-package repository
run: |
git clone https://github.com/zkonduit/ezkl-swift-package.git
- name: Copy EzklCoreBindings
run: |
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
- name: Copy Test Files
run: |
rm -rf ezkl-swift-package/Tests/EzklAssets/
mkdir -p ezkl-swift-package/Tests/EzklAssets/
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
- name: Set up Xcode environment
run: |
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license accept
- name: Run Package Tests
run: |
cd ezkl-swift-package
xcodebuild test \
-scheme EzklPackage \
-destination 'platform=iOS Simulator,name=iPhone 16 Pro,OS=18.2' \
-resultBundlePath ../testResults
- name: Run Example App Tests
run: |
cd ezkl-swift-package/Example
xcodebuild test \
-project Example.xcodeproj \
-scheme EzklApp \
-destination 'platform=iOS Simulator,name=iPhone 16 Pro,OS=18.2' \
-parallel-testing-enabled NO \
-resultBundlePath ../../exampleTestResults \
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder


@@ -1,134 +0,0 @@
name: Build and Publish EZKL iOS SPM package
on:
push:
tags:
# Only support SemVer versioning tags
- 'v[0-9]+.[0-9]+.[0-9]+'
- '[0-9]+.[0-9]+.[0-9]+'
jobs:
build-and-update:
permissions:
contents: read
packages: write
runs-on: macos-latest
env:
EZKL_SWIFT_PACKAGE_REPO: github.com/zkonduit/ezkl-swift-package.git
RELEASE_TAG: ${{ github.ref_name }}
steps:
- name: Checkout EZKL
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Extract TAG from github.ref_name
run: |
# github.ref_name is provided by GitHub Actions and contains the tag name directly.
TAG="${RELEASE_TAG}"
echo "Original TAG: $TAG"
# Remove leading 'v' if present to match the Swift Package Manager version format.
NEW_TAG=${TAG#v}
echo "Stripped TAG: $NEW_TAG"
echo "TAG=$NEW_TAG" >> $GITHUB_ENV
- name: Install Rust (nightly)
uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly
override: true
- name: Build EzklCoreBindings
run: CONFIGURATION=release cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
- name: Clone ezkl-swift-package repository
run: |
git clone https://${{ env.EZKL_SWIFT_PACKAGE_REPO }}
- name: Copy EzklCoreBindings
run: |
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
- name: Copy Test Files
run: |
rm -rf ezkl-swift-package/Tests/EzklAssets/
mkdir -p ezkl-swift-package/Tests/EzklAssets/
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
- name: Check for changes
id: check_changes
run: |
cd ezkl-swift-package
if git diff --quiet Sources/EzklCoreBindings Tests/EzklAssets; then
echo "no_changes=true" >> $GITHUB_OUTPUT
else
echo "no_changes=false" >> $GITHUB_OUTPUT
fi
- name: Set up Xcode environment
if: steps.check_changes.outputs.no_changes == 'false'
run: |
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license accept
- name: Run Package Tests
if: steps.check_changes.outputs.no_changes == 'false'
run: |
cd ezkl-swift-package
xcodebuild test \
-scheme EzklPackage \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-resultBundlePath ../testResults
- name: Run Example App Tests
if: steps.check_changes.outputs.no_changes == 'false'
run: |
cd ezkl-swift-package/Example
xcodebuild test \
-project Example.xcodeproj \
-scheme EzklApp \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-parallel-testing-enabled NO \
-resultBundlePath ../../exampleTestResults \
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
- name: Setup Git
run: |
cd ezkl-swift-package
git config user.name "GitHub Action"
git config user.email "action@github.com"
git remote set-url origin https://zkonduit:${EZKL_SWIFT_PACKAGE_REPO_TOKEN}@${{ env.EZKL_SWIFT_PACKAGE_REPO }}
env:
EZKL_SWIFT_PACKAGE_REPO_TOKEN: ${{ secrets.EZKL_PORTER_TOKEN }}
- name: Commit and Push Changes
if: steps.check_changes.outputs.no_changes == 'false'
run: |
cd ezkl-swift-package
git add Sources/EzklCoreBindings Tests/EzklAssets
git commit -m "Automatically updated EzklCoreBindings for EZKL"
if ! git push origin; then
echo "::error::Failed to push changes to ${{ env.EZKL_SWIFT_PACKAGE_REPO }}. Please ensure that EZKL_PORTER_TOKEN has the correct permissions."
exit 1
fi
- name: Tag the latest commit
run: |
cd ezkl-swift-package
source $GITHUB_ENV
# Tag the latest commit on the current branch
if git rev-parse "$TAG" >/dev/null 2>&1; then
echo "Tag $TAG already exists locally. Skipping tag creation."
else
git tag "$TAG"
fi
if ! git push origin "$TAG"; then
echo "::error::Failed to push tag '$TAG' to ${{ env.EZKL_SWIFT_PACKAGE_REPO }}. Please ensure EZKL_PORTER_TOKEN has correct permissions."
exit 1
fi

Cargo.lock (generated, 15 changed lines)

@@ -1941,7 +1941,6 @@ version = "0.0.0"
dependencies = [
"alloy",
"bincode",
"camino",
"chrono",
"clap",
"clap_complete",
@@ -1996,9 +1995,7 @@ dependencies = [
"tosubcommand",
"tract-onnx",
"uniffi",
"uniffi_bindgen",
"unzip-n",
"uuid",
"wasm-bindgen",
"wasm-bindgen-console-logger",
"wasm-bindgen-rayon",
@@ -2431,7 +2428,7 @@ dependencies = [
[[package]]
name = "halo2_proofs"
version = "0.3.0"
source = "git+https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e?branch=ac%2Fconditional-compilation-icicle2#1dd2090741f006fd031a07da7f3c9dfce5e0015e"
source = "git+https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e?branch=ac%2Fconditional-compilation-icicle2#01c88842679b4308e43ae5ed91c4183e861669bd"
dependencies = [
"bincode",
"blake2b_simd",
@@ -6397,7 +6394,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31bff6daf87277a9014bcdefbc2842b0553392919d1096843c5aad899ca4588"
dependencies = [
"anyhow",
"uniffi_bindgen",
"uniffi_build",
"uniffi_core",
"uniffi_macros",
@@ -6580,15 +6576,6 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "uuid"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9"
dependencies = [
"getrandom 0.3.2",
]
[[package]]
name = "valuable"
version = "0.1.1"


@@ -96,13 +96,6 @@ objc = { version = "0.2.4", optional = true }
pyo3-stub-gen = { version = "0.6.0", optional = true }
jemallocator = { version = "0.5", optional = true }
mimalloc = { version = "0.1", optional = true }
# universal bindings
uniffi = { version = "=0.28.0", optional = true }
getrandom = { version = "0.2.8", optional = true }
uniffi_bindgen = { version = "=0.28.0", optional = true }
camino = { version = "^1.1", optional = true }
uuid = { version = "1.10.0", features = ["v4"], optional = true }
# GPU / device related things (optional - only enabled with gpu-accelerated feature)
icicle-runtime = { git = "https://github.com/ingonyama-zk/icicle", branch="emir/gate_eval_2", package="icicle-runtime", optional = true }
@@ -215,9 +208,7 @@ test = false
bench = false
required-features = ["ezkl"]
[[bin]]
name = "ios_gen_bindings"
required-features = ["ios-bindings", "uuid", "camino", "uniffi_bindgen"]
[[bin]]
name = "py_stub_gen"
@@ -236,16 +227,7 @@ default = [
]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-async-runtimes", "pyo3-stub-gen"]
universal-bindings = [
"uniffi",
"mv-lookup",
"precompute-coset",
"parallel-poly-read",
"dep:halo2_solidity_verifier"
]
logging = ["dep:colored", "dep:env_logger", "dep:chrono"]
ios-bindings = ["universal-bindings"]
ios-bindings-test = ["ios-bindings", "uniffi/bindgen-tests"]
ezkl = [
"onnx",
"tabled/color",
@@ -300,7 +282,6 @@ halo2_proofs = { git = "https://github.com/zkonduit/halo2#1dd2090741f006fd031a07
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
[profile.release]
# debug = true
rustflags = ["-C", "relocation-model=pic"]


@@ -4,7 +4,6 @@ use ezkl::circuit::*;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::srs::gen_srs;
use ezkl::pfsys::TranscriptType;
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
@@ -153,8 +152,6 @@ fn runcnvrl(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);


@@ -2,7 +2,6 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -120,8 +119,6 @@ fn rundot(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);


@@ -1,53 +1,78 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion::{
criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion, PlotConfiguration,
Throughput,
};
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::einsum::circuit_params::SingleEinsumParams;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::pfsys::srs::gen_srs;
use ezkl::pfsys::{create_keys, create_proof_circuit};
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, SimpleFloorPlanner, Value},
circuit::{Layouter, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
use std::collections::HashMap;
static mut LEN: usize = 4;
const K: usize = 16;
static mut K: usize = 15;
#[derive(Clone)]
struct MyCircuit {
inputs: [ValTensor<Fr>; 2],
_marker: PhantomData<Fr>,
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 2],
einsum_params: SingleEinsumParams<F>,
}
impl Circuit<Fr> for MyCircuit {
impl Circuit<Fr> for MyCircuit<Fr> {
type Config = BaseConfig<Fr>;
type FloorPlanner = SimpleFloorPlanner;
type Params = ();
type FloorPlanner = V1;
type Params = SingleEinsumParams<Fr>;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
let len = unsafe { LEN };
fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
let mut config = Self::Config::default();
let a = VarTensor::new_advice(cs, K, 1, len * len);
let mut equations = HashMap::new();
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
unsafe {
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 2, false);
}
let b = VarTensor::new_advice(cs, K, 1, len * len);
config
}
let output = VarTensor::new_advice(cs, K, 1, (len + 1) * len);
fn params(&self) -> Self::Params {
SingleEinsumParams::<Fr>::new(
&self.einsum_params.equation,
&[
&self.inputs[0].get_inner().unwrap(),
&self.inputs[1].get_inner().unwrap(),
],
)
.unwrap()
}
Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE)
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unimplemented!("call configure_with_params instead")
}
fn synthesize(
@@ -55,16 +80,33 @@ impl Circuit<Fr> for MyCircuit {
mut config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let challenges = config
.einsums
.as_ref()
.ok_or(Error::Synthesis)?
.challenges()
.unwrap()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = region::RegionCtx::new_with_challenges(
region,
0,
1,
1024,
2,
challenges.clone(),
);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
equation: self.einsum_params.equation.clone(),
}),
)
.unwrap();
@@ -77,41 +119,49 @@ impl Circuit<Fr> for MyCircuit {
fn runmatmul(c: &mut Criterion) {
let mut group = c.benchmark_group("accum_einsum_matmul");
let params = gen_srs::<KZGCommitmentScheme<_>>(17);
for &len in [4, 32].iter() {
unsafe {
LEN = len;
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Linear));
group.sampling_mode(criterion::SamplingMode::Flat);
group.sample_size(10);
let len = 128;
unsafe {
LEN = len;
}
for k in 16..17 {
let params = unsafe {
K = k;
gen_srs::<KZGCommitmentScheme<_>>(K as u32)
};
let mut a = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
a.reshape(&[len, len]).unwrap();
// parameters
let mut b = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
b.reshape(&[len, len]).unwrap();
let einsum_params = SingleEinsumParams::<Fr>::new("ij,jk->ik", &[&a, &b]).unwrap();
let circuit = MyCircuit {
inputs: [ValTensor::from(a), ValTensor::from(b)],
_marker: PhantomData,
einsum_params,
};
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
group.bench_with_input(BenchmarkId::new("pk", k), &k, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit<Fr>>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit<Fr>>(&circuit, &params, false)
.unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
group.bench_with_input(BenchmarkId::new("prove", k), &k, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
MyCircuit<Fr>,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
@@ -124,8 +174,6 @@ fn runmatmul(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);


@@ -5,7 +5,7 @@ use ezkl::circuit::*;
use ezkl::circuit::lookup::LookupOp;
use ezkl::circuit::poly::PolyOp;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -154,8 +154,6 @@ fn runmatmul(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);


@@ -5,7 +5,7 @@ use ezkl::circuit::lookup::LookupOp;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::table::Range;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -157,8 +157,6 @@ fn runmatmul(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);


@@ -2,7 +2,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -116,8 +116,6 @@ fn runsum(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);


@@ -4,7 +4,7 @@ use ezkl::circuit::*;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::srs::gen_srs;
use ezkl::pfsys::TranscriptType;
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
@@ -131,8 +131,6 @@ fn runsumpool(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);


@@ -2,7 +2,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -118,8 +118,6 @@ fn runadd(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);


@@ -3,7 +3,7 @@ use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -117,8 +117,6 @@ fn runpow(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);


@@ -8,7 +8,7 @@ use ezkl::circuit::*;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::srs::gen_srs;
use ezkl::pfsys::TranscriptType;
use ezkl::tensor::*;
use halo2_proofs::circuit::Value;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -104,8 +104,6 @@ fn runposeidon(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);


@@ -4,7 +4,7 @@ use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{BaseConfig as Config, CheckMode};
use ezkl::fieldutils::IntegerRep;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -130,8 +130,6 @@ fn runrelu(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);


@@ -4,7 +4,7 @@ use ezkl::circuit::table::Range;
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
use ezkl::fieldutils::IntegerRep;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -124,8 +124,6 @@ fn runrelu(c: &mut Criterion) {
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);


@@ -1,7 +1,3 @@
fn main() {
if cfg!(feature = "ios-bindings-test") {
println!("cargo::rustc-env=UNIFFI_CARGO_BUILD_EXTRA_ARGS=--features=ios-bindings --no-default-features");
}
println!("cargo::rerun-if-changed=build.rs");
}


@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '22.2.4'
release = '0.0.0'
version = release


@@ -0,0 +1,171 @@
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::Fr;
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use std::collections::HashMap;
use std::marker::PhantomData;
const K: usize = 13;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 2],
einsum: Einsum<F>,
}
#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
equation: String,
input_axes_to_dims: HashMap<char, usize>,
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
let mut eq = equation.split("->");
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
// Check that the number of inputs matches the number of inputs in the equation
if inputs.len() != inputs_eq.len() {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
let mut input_axes_to_dims = HashMap::new();
for (i, input) in inputs.iter().enumerate() {
for j in 0..inputs_eq[i].len() {
let c = inputs_eq[i]
.chars()
.nth(j)
.ok_or(CircuitError::InvalidEinsum)?;
if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
e.insert(input.dims()[j]);
} else if input_axes_to_dims[&c] != input.dims()[j] {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
}
}
Ok(Self {
equation: equation.to_owned(),
input_axes_to_dims,
_marker: PhantomData,
})
}
}
impl Circuit<Fr> for MyCircuit<Fr> {
type Config = BaseConfig<Fr>;
type FloorPlanner = V1;
type Params = Einsum<Fr>;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
let mut config = Self::Config::default();
let mut equations = HashMap::new();
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
config
}
fn params(&self) -> Self::Params {
Einsum::<Fr>::new(
&self.einsum.equation,
&[
&self.inputs[0].get_inner().unwrap(),
&self.inputs[1].get_inner().unwrap(),
],
)
.unwrap()
}
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unimplemented!("call configure_with_params instead")
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let challenges = config
.einsums
.as_ref()
.ok_or(Error::Synthesis)?
.challenges()
.unwrap()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new_with_challenges(
region,
0,
1,
1024,
2,
challenges.clone(),
);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: self.einsum.equation.clone(),
}),
)
.unwrap();
Ok(())
},
)?;
Ok(())
}
}
fn runmatmul() {
let len = 64;
let mut a = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
a.reshape(&[len, len]).unwrap();
// parameters
let mut b = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
b.reshape(&[len, len]).unwrap();
let einsum = Einsum::<Fr>::new("ij,jk->ik", &[&a, &b]).unwrap();
let circuit = MyCircuit {
inputs: [ValTensor::from(a), ValTensor::from(b)],
einsum,
};
let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
mock_prover.assert_satisfied();
}
pub fn main() {
runmatmul()
}
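The equation constrained here, "ij,jk->ik", is plain matrix multiplication. As a quick sanity check outside the circuit, the same contraction in NumPy (a hypothetical check script, not part of the diff):

```python
# Sanity-check the einsum relation "ij,jk->ik" outside the circuit (NumPy assumed).
import numpy as np

a = np.random.rand(64, 64)
b = np.random.rand(64, 64)
assert np.allclose(np.einsum("ij,jk->ik", a, b), a @ b)
```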

examples/batch_mat_mul.rs (new file)

@@ -0,0 +1,179 @@
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::Fr;
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use std::collections::HashMap;
use std::marker::PhantomData;
static mut LEN: usize = 4;
const K: usize = 11;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 2],
einsum: Einsum<F>,
}
#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
equation: String,
input_axes_to_dims: HashMap<char, usize>,
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
let mut eq = equation.split("->");
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
// Check that the number of inputs matches the number of inputs in the equation
if inputs.len() != inputs_eq.len() {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
let mut input_axes_to_dims = HashMap::new();
for (i, input) in inputs.iter().enumerate() {
for j in 0..inputs_eq[i].len() {
let c = inputs_eq[i]
.chars()
.nth(j)
.ok_or(CircuitError::InvalidEinsum)?;
if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
e.insert(input.dims()[j]);
} else if input_axes_to_dims[&c] != input.dims()[j] {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
}
}
Ok(Self {
equation: equation.to_owned(),
input_axes_to_dims,
_marker: PhantomData,
})
}
}
impl Circuit<Fr> for MyCircuit<Fr> {
type Config = BaseConfig<Fr>;
type FloorPlanner = V1;
type Params = Einsum<Fr>;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
let len = unsafe { LEN };
let a = VarTensor::new_advice(cs, K, 1, len);
let b = VarTensor::new_advice(cs, K, 1, len);
let output = VarTensor::new_advice(cs, K, 1, len);
let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE);
let mut equations = HashMap::new();
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
config
}
fn params(&self) -> Self::Params {
Einsum::<Fr>::new(
&self.einsum.equation,
&[
&self.inputs[0].get_inner().unwrap(),
&self.inputs[1].get_inner().unwrap(),
],
)
.unwrap()
}
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unimplemented!("call configure_with_params instead")
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let challenges = config
.einsums
.as_ref()
.ok_or(Error::Synthesis)?
.challenges()
.unwrap()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new_with_challenges(
region,
0,
1,
1024,
2,
challenges.clone(),
);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: self.einsum.equation.clone(),
}),
)
.unwrap();
Ok(())
},
)?;
Ok(())
}
}
fn runbatchmatmul() {
let batch_size = 5;
let len = 12;
let mut a = Tensor::from((0..batch_size * len * len).map(|_| Value::known(Fr::random(OsRng))));
a.reshape(&[batch_size, len, len]).unwrap();
// parameters
let mut b = Tensor::from((0..batch_size * len * len).map(|_| Value::known(Fr::random(OsRng))));
b.reshape(&[batch_size, len, len]).unwrap();
let einsum = Einsum::<Fr>::new("ijk,ikl->ijl", &[&a, &b]).unwrap();
let circuit = MyCircuit {
inputs: [ValTensor::from(a), ValTensor::from(b)],
einsum,
};
let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
mock_prover.assert_satisfied();
}
pub fn main() {
runbatchmatmul()
}
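The leading i axis here batches independent matrix multiplications, which NumPy's broadcasting reproduces directly (again a hypothetical check, not part of the diff):

```python
# "ijk,ikl->ijl" is a batch of independent matmuls over the leading axis (NumPy assumed).
import numpy as np

a = np.random.rand(5, 12, 12)
b = np.random.rand(5, 12, 12)
assert np.allclose(np.einsum("ijk,ikl->ijl", a, b), a @ b)
```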


@@ -866,7 +866,7 @@
},
{
"cell_type": "code",
"execution_count": 98,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -879,6 +879,7 @@
"run_args.input_visibility = \"private\"\n",
"run_args.param_visibility = \"fixed\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.disable_freivalds = True\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"\n",
"\n",
@@ -1142,4 +1143,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}

@@ -253,8 +253,6 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
@@ -303,4 +301,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -546,7 +546,7 @@
"\n",
"proof_path = os.path.join('proof.json')\n",
"\n",
"proof = ezkl.prove(proof_type=\"single\", proof_path=proof_path)\n",
"proof = ezkl.prove(proof_path=proof_path)\n",
"\n",
"print(proof)\n",
"assert os.path.isfile(proof_path)"
@@ -736,4 +736,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}

@@ -574,7 +574,7 @@
"\n",
"proof_path = os.path.join('proof.json')\n",
"\n",
"proof = ezkl.prove(proof_type=\"single\", proof_path=proof_path)\n",
"proof = ezkl.prove(proof_path=proof_path)\n",
"\n",
"print(proof)\n",
"assert os.path.isfile(proof_path)"
@@ -768,4 +768,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}

@@ -54,7 +54,7 @@
" gip_run_args.param_scale = 19\n",
" gip_run_args.logrows = 8\n",
" run_args = ezkl.gen_settings(py_run_args=gip_run_args)\n",
" await ezkl.get_srs(commitment=ezkl.PyCommitments.KZG)\n",
" await ezkl.get_srs()\n",
" ezkl.compile_circuit()\n",
" res = ezkl.gen_witness()\n",
" print(res)\n",
@@ -127,4 +127,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}

@@ -105,7 +105,7 @@
"\n",
"class GCNConv(Module):\n",
" def __init__(self, in_channels, out_channels):\n",
" super(GCNConv, self).__init__() # \"Add\" aggregation.\n",
" super(GCNConv, self).__init__() \n",
" self.lin = torch.nn.Linear(in_channels, out_channels)\n",
"\n",
" self.reset_parameters()\n",
@@ -563,7 +563,6 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
@@ -625,4 +624,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -286,8 +286,6 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
@@ -341,4 +339,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -248,7 +248,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
@@ -263,8 +263,6 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
@@ -313,4 +311,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -368,7 +368,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -236,7 +236,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -240,7 +240,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -358,7 +358,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -278,7 +278,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -232,7 +232,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -442,7 +442,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -227,7 +227,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -252,7 +252,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -422,7 +422,7 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -378,7 +378,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -301,7 +301,7 @@
"run_args.param_scale = 0\n",
"run_args.logrows = 18\n",
"\n",
"ezkl.get_srs(logrows=run_args.logrows, commitment=ezkl.PyCommitments.KZG)\n"
"ezkl.get_srs(logrows=run_args.logrows, )\n"
]
},
{
@@ -399,7 +399,6 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"for-aggr\",\n",
" )\n",
"\n",
" print(res)\n",
@@ -438,28 +437,6 @@
" print(\"----- proving split \"+str(i))\n",
" prove_model(i)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also mock aggregate the split proofs into a single proof. This is useful if you want to verify the proof on chain at a lower cost. Here we mock aggregate the proofs to save time. You can use other notebooks to see how to aggregate in full ! "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# now mock aggregate the proofs\n",
"# proofs = []\n",
"# for i in range(3):\n",
"# proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
"# proofs.append(proof_path)\n",
"\n",
"# ezkl.mock_aggregate(proofs, logrows=26, split_proofs = True)"
]
}
],
"metadata": {
@@ -484,4 +461,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}

@@ -303,7 +303,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",
@@ -543,7 +543,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -939,7 +939,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -234,7 +234,7 @@
"run_args.input_scale = 2\n",
"run_args.logrows = 15\n",
"\n",
"ezkl.get_srs(logrows=run_args.logrows, commitment=ezkl.PyCommitments.KZG)"
"ezkl.get_srs(logrows=run_args.logrows, )"
]
},
{
@@ -330,7 +330,6 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"for-aggr\",\n",
" )\n",
"\n",
" print(res)\n",
@@ -426,28 +425,6 @@
"for i in range(2):\n",
" prove_model(i)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also mock aggregate the split proofs into a single proof. This is useful if you want to verify the proof on chain at a lower cost. Here we mock aggregate the proofs to save time. You can use other notebooks to see how to aggregate in full ! "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# now mock aggregate the proofs\n",
"proofs = []\n",
"for i in range(2):\n",
" proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
" proofs.append(proof_path)\n",
"\n",
"ezkl.mock_aggregate(proofs, logrows=22, split_proofs = True)"
]
}
],
"metadata": {
@@ -472,4 +449,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}

@@ -260,7 +260,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -173,7 +173,7 @@
" assert os.path.isfile(settings_path)\n",
"\n",
" # GENERATE A PROOF\n",
" res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, \"single\")\n",
" res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path)\n",
" assert os.path.isfile(proof_path)\n",
"\n",
" res = await ezkl.create_evm_verifier(vk_path, settings_path, sol_code_path, abi_path, reusable=True)\n",


@@ -384,7 +384,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",
@@ -411,7 +411,7 @@
" pk_path,\n",
" proof_path_faulty,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",
@@ -438,7 +438,7 @@
" pk_path,\n",
" proof_path_truthy,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -1,407 +0,0 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## EZKL Jupyter Notebook Demo (Aggregated Proofs) \n",
"\n",
"Demonstrates how to use EZKL with aggregated proofs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"\n",
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import torch\n",
"\n",
"\n",
"# Defines the model\n",
"# we got convs, we got relu, we got linear layers\n",
"# What else could one want ????\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
"\n",
" self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=5, stride=2)\n",
" self.conv2 = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=5, stride=2)\n",
"\n",
" self.relu = nn.ReLU()\n",
"\n",
" self.d1 = nn.Linear(48, 48)\n",
" self.d2 = nn.Linear(48, 10)\n",
"\n",
" def forward(self, x):\n",
" # 32x1x28x28 => 32x32x26x26\n",
" x = self.conv1(x)\n",
" x = self.relu(x)\n",
" x = self.conv2(x)\n",
" x = self.relu(x)\n",
"\n",
" # flatten => 32 x (32*26*26)\n",
" x = x.flatten(start_dim = 1)\n",
"\n",
" # 32 x (32*26*26) => 32x128\n",
" x = self.d1(x)\n",
" x = self.relu(x)\n",
"\n",
" # logits => 32x10\n",
" logits = self.d2(x)\n",
"\n",
" return logits\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"# Train the model as you like here (skipped for brevity)\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"proof_path = os.path.join('test.pf')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.srs')\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')\n",
"aggregate_proof_path = os.path.join('aggr.pf')\n",
"aggregate_vk_path = os.path.join('aggr.vk')\n",
"aggregate_pk_path = os.path.join('aggr.pk')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"shape = [1, 28, 28]\n",
"# After training, export to onnx (network.onnx) and create a data file (input.json)\n",
"x = 0.1*torch.rand(1,*shape, requires_grad=True)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" model_path, # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump( data, open(data_path, 'w' ))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (torch.rand(20, *shape, requires_grad=True).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"for-aggr\", # IMPORTANT NOTE: To produce an aggregated EVM proof you will want to use poseidon for the smaller proofs\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0832b909",
"metadata": {},
"outputs": [],
"source": [
"# Generate a larger SRS. This is needed for the aggregated proof\n",
"\n",
"res = await ezkl.get_srs(settings_path=None, logrows=21, commitment=ezkl.PyCommitments.KZG)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c5a64be6",
"metadata": {},
"outputs": [],
"source": [
"# Run mock aggregate to check whether the proof works\n",
"# Use mock to check for validity as it takes a shorter time to check compared to a full aggregated proof\n",
"\n",
"res = ezkl.mock_aggregate([proof_path], 21)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fee8acc6",
"metadata": {},
"outputs": [],
"source": [
"# Setup the vk and pk for aggregate\n",
"res = ezkl.setup_aggregate(\n",
" [proof_path],\n",
" aggregate_vk_path,\n",
" aggregate_pk_path,\n",
" 21\n",
")\n",
"\n",
"assert os.path.isfile(aggregate_vk_path)\n",
"assert os.path.isfile(aggregate_pk_path)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "171702d3",
"metadata": {},
"outputs": [],
"source": [
"# Run aggregate proof\n",
"res = ezkl.aggregate(\n",
" [proof_path],\n",
" aggregate_proof_path,\n",
" aggregate_pk_path,\n",
" \"evm\",\n",
" 21,\n",
" \"safe\"\n",
")\n",
"\n",
"assert os.path.isfile(aggregate_proof_path)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "671dfdd5",
"metadata": {},
"outputs": [],
"source": [
"# Check if the proof is valid\n",
"res = ezkl.verify_aggr(\n",
" aggregate_proof_path,\n",
" aggregate_vk_path,\n",
" 21,\n",
")\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "50eba2f4",
"metadata": {},
"outputs": [],
"source": [
"# Create a smart contract verifier for the aggregated proof\n",
"\n",
"sol_code_path = os.path.join(\"Verifier.sol\")\n",
"abi_path = os.path.join(\"Verifier_ABI.json\")\n",
"\n",
"res = await ezkl.create_evm_verifier_aggr(\n",
" [settings_path],\n",
" aggregate_vk_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" logrows=21)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -255,7 +255,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -253,7 +253,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -254,7 +254,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -233,7 +233,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -323,7 +323,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"assert os.path.isfile(proof_path)\n",
@@ -442,7 +442,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -271,7 +271,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -236,7 +236,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -707,7 +707,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -596,7 +596,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -580,7 +580,7 @@
" compiled_filename,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\",\n",
" ",
" )\n",
"\n",
"\n",


@@ -759,7 +759,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -277,7 +277,7 @@
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" ",
" )\n",
"\n",
"print(res)\n",


@@ -0,0 +1,79 @@
from torch import nn
import torch.nn.init as init
import torch
import json
N = 100
class Model(nn.Module):
def __init__(self, inplace=False):
super(Model, self).__init__()
self.aff1 = nn.Linear(N,N)
self.aff2 = nn.Linear(N,N)
self.aff3 = nn.Linear(N,N)
self.aff4 = nn.Linear(N,N)
self.aff5 = nn.Linear(N,N)
self.aff6 = nn.Linear(N,N)
self.aff7 = nn.Linear(N,N)
self.aff8 = nn.Linear(N,N)
self.aff9 = nn.Linear(N,N)
self.relu = nn.ReLU()
self._initialize_weights()
def forward(self, x):
# concat 10 x along dim 0
x = x.repeat(10, 1)
x = self.aff1(x)
x = self.relu(x)
x = self.aff2(x)
x = self.relu(x)
x = self.aff3(x)
x = self.relu(x)
x = self.aff4(x)
x = self.relu(x)
x = self.aff5(x)
x = self.relu(x)
x = self.aff6(x)
x = self.relu(x)
x = self.aff7(x)
x = self.relu(x)
x = self.aff8(x)
x = self.relu(x)
x = self.aff9(x)
return x
def _initialize_weights(self):
init.orthogonal_(self.aff1.weight)
model = Model()
# Flips the neural net into inference mode
model.eval()
model.to('cpu')
x = torch.randn(1, N)
# Export the model
torch.onnx.export(model, # model being run
# model input (or a tuple for multiple inputs)
x,
# where to save the model (can be a file or file-like object)
"network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=12, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
data_array = ((x).detach().numpy()).reshape([-1]).tolist()
data_json = dict(input_data=[data_array])
print(data_json)
# Serialize data into file:
json.dump(data_json, open("input.json", 'w'))


@@ -0,0 +1 @@
{"input_data": [[0.33088353276252747, -0.8819183707237244, 1.245591163635254, -1.807046890258789, 1.9922369718551636, -0.3360576629638672, 0.4529011845588684, -0.3590165674686432, 0.08356846123933792, 0.5126393437385559, 0.44627535343170166, 1.4916497468948364, 0.49731069803237915, -0.9748706817626953, -0.4923185408115387, 1.3548223972320557, 0.2306872010231018, 1.125955581665039, -1.7063908576965332, 0.3777385354042053, -2.7988760471343994, -1.1846797466278076, 0.7473157048225403, 1.490412950515747, 0.017497723922133446, 2.113945245742798, -1.2141249179840088, -0.16120357811450958, 0.021127669140696526, 0.7207374572753906, -1.369688868522644, -0.7369781732559204, -0.630584180355072, -0.4520200788974762, 0.29123976826667786, 0.6334688067436218, -0.869332492351532, -1.258501648902893, 0.3012596666812897, -0.5507447123527527, 0.669975757598877, 0.15088629722595215, -0.1050339788198471, 0.5505334138870239, -0.1287376880645752, -1.4297826290130615, -0.01703289896249771, -1.2296998500823975, 0.5122153162956238, -0.16924428939819336, -0.415036678314209, -1.1979341506958008, 0.05831022188067436, -0.4411357045173645, 2.0713791847229004, 1.4611141681671143, -0.9357407093048096, -0.333297461271286, -0.676478385925293, 1.390028476715088, -0.05827632546424866, 1.535687804222107, 0.3060210347175598, -0.03171076253056526, -0.614985466003418, 1.2040390968322754, 0.31318482756614685, -1.2134959697723389, 0.13110508024692535, -1.4880926609039307, 1.7007993459701538, 1.5412729978561401, 0.09260450303554535, 0.7649128437042236, -0.5009126663208008, -0.5356241464614868, -0.069572813808918, -0.011717632412910461, 0.21314217150211334, -0.1985170543193817, -0.0223808903247118, 1.2128918170928955, 0.8334696888923645, 1.9029873609542847, -0.11491120606660843, -0.10303237289190292, -0.2467050403356552, 1.557223916053772, -1.1108328104019165, -0.9065343141555786, -0.2271333783864975, 0.6959827542304993, -0.48698121309280396, 0.5689510703086853, 1.115319013595581, -0.8907430768013, -0.24722427129745483, -0.7437837719917297, 0.6742106676101685, -1.7830933332443237]]}

Binary file not shown.


@@ -1,23 +1,23 @@
## The worm

This is an onnx file for a [WormVAE](https://github.com/TuragaLab/wormvae?tab=readme-ov-file) model, which is a VAE / latent-space representation of the C. elegans connectome.

The model "is a large-scale latent variable model with a very high-dimensional latent space consisting of voltage dynamics of 300 neurons over 5 minutes of time at the simulation frequency of 160 Hz. The generative model for these latent variables is described by stochastic differential equations modeling the nonlinear dynamics of the network activity." (see [here](https://openreview.net/pdf?id=CJzi3dRlJE-)).

In effect this is a generative model for a worm's voltage dynamics, which can be used to generate new worm-like voltage dynamics given previous connectome state.

Using ezkl you can create a zk circuit equivalent to the wormvae model, allowing you to "prove" execution of the worm model. If you're feeling particularly adventurous, you can also use the zk circuit to generate new worm-state that can be verified on chain.

To do so you'll first want to fetch the files using git-lfs (as the onnx file is too large to be stored in git).

```bash
git lfs fetch --all
```

You'll then want to use the usual ezkl loop to generate the zk circuit. We recommend using fixed visibility for the model parameters, as the model is quite large and this will prune the circuit significantly.
```bash
ezkl gen-settings --param-visibility=fixed
@@ -28,17 +28,7 @@ ezkl gen-witness
ezkl prove
```
You might also need to aggregate the proof to get it to fit on chain.
```bash
ezkl aggregate
```
You can then create a smart contract that verifies this aggregate proof
```bash
ezkl create-evm-verifier-aggr
```
This can then be deployed on the chain of your choice.
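For reference, the same loop can be driven from the Python bindings documented later in this diff; a hedged sketch with placeholder paths (get_srs is awaitable in the Python API, and PyRunArgs construction is assumed to match the notebooks):

```python
# Hedged sketch: the wormvae flow via ezkl's Python bindings (placeholder paths).
import asyncio
import ezkl

run_args = ezkl.PyRunArgs()
run_args.param_visibility = "fixed"  # prune the large parameter set, as recommended above

ezkl.gen_settings("network.onnx", "settings.json", py_run_args=run_args)
ezkl.compile_circuit("network.onnx", "network.compiled", "settings.json")
asyncio.run(ezkl.get_srs("settings.json"))  # fetch an SRS sized to the settings
ezkl.gen_witness("input.json", "network.compiled", "witness.json")
ezkl.setup("network.compiled", "test.vk", "test.pk")
ezkl.prove("witness.json", "network.compiled", "test.pk", "proof.json")
```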


@@ -0,0 +1,182 @@
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::Fr;
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use std::collections::HashMap;
use std::marker::PhantomData;
static mut LEN: usize = 4;
const K: usize = 11;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 2],
einsum: Einsum<F>,
}
#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
equation: String,
input_axes_to_dims: HashMap<char, usize>,
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
let mut eq = equation.split("->");
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
// Check that the number of inputs matches the number of inputs in the equation
if inputs.len() != inputs_eq.len() {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
let mut input_axes_to_dims = HashMap::new();
for (i, input) in inputs.iter().enumerate() {
for j in 0..inputs_eq[i].len() {
let c = inputs_eq[i]
.chars()
.nth(j)
.ok_or(CircuitError::InvalidEinsum)?;
if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
e.insert(input.dims()[j]);
} else if input_axes_to_dims[&c] != input.dims()[j] {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
}
}
Ok(Self {
equation: equation.to_owned(),
input_axes_to_dims,
_marker: PhantomData,
})
}
}
impl Circuit<Fr> for MyCircuit<Fr> {
type Config = BaseConfig<Fr>;
type FloorPlanner = V1;
type Params = Einsum<Fr>;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
let len = unsafe { LEN };
let a = VarTensor::new_advice(cs, K, 1, len);
let b = VarTensor::new_advice(cs, K, 1, len);
let output = VarTensor::new_advice(cs, K, 1, len);
let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE);
let mut equations = HashMap::new();
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 2;
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 2, false);
config
}
fn params(&self) -> Self::Params {
Einsum::<Fr>::new(
&self.einsum.equation,
&[
&self.inputs[0].get_inner().unwrap(),
&self.inputs[1].get_inner().unwrap(),
],
)
.unwrap()
}
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unimplemented!("call configure_with_params instead")
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let challenges = config
.einsums
.as_ref()
.ok_or(Error::Synthesis)?
.challenges()
.unwrap()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new_with_challenges(
region,
0,
1,
1024,
2,
challenges.clone(),
);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: self.einsum.equation.clone(),
}),
)
.unwrap();
Ok(())
},
)?;
Ok(())
}
}
fn runmatmul() {
let i = 10;
let n = 10;
let j = 40;
let k = 10;
let mut a = Tensor::from((0..i * n * j).map(|_| Value::known(Fr::random(OsRng))));
a.reshape(&[i, n, j]).unwrap();
// parameters
let mut b = Tensor::from((0..j * k).map(|_| Value::known(Fr::random(OsRng))));
b.reshape(&[j, k]).unwrap();
let einsum = Einsum::<Fr>::new("inj,jk->ik", &[&a, &b]).unwrap();
let circuit = MyCircuit {
inputs: [ValTensor::from(a), ValTensor::from(b)],
einsum,
};
let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
mock_prover.assert_satisfied();
}
pub fn main() {
runmatmul()
}
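Unlike the earlier examples, "inj,jk->ik" contracts two axes: the reduction sums over both n and j, so the result equals collapsing n first and then multiplying. A hypothetical NumPy check:

```python
# "inj,jk->ik" sums over both n and j: equivalent to a.sum(axis=1) @ b.
import numpy as np

a = np.random.rand(10, 10, 40)
b = np.random.rand(40, 10)
assert np.allclose(np.einsum("inj,jk->ik", a, b), a.sum(axis=1) @ b)
```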

ezkl.pyi

@@ -10,30 +10,26 @@ class PyG1:
r"""
pyclass containing the struct used for G1, this is mostly a helper class
"""
...
class PyG1Affine:
r"""
pyclass containing the struct used for G1
"""
...
class PyRunArgs:
r"""
Python class containing the struct used for run_args
Returns
-------
PyRunArgs
"""
...
class PyCommitments(Enum):
r"""
pyclass representing an enum, denoting the type of commitment
"""
KZG = auto()
IPA = auto()
...
class PyInputType(Enum):
Bool = auto()
@@ -47,57 +43,19 @@ class PyTestDataSource(Enum):
r"""
pyclass representing an enum
"""
File = auto()
OnChain = auto()
def aggregate(aggregation_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],proof_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,transcript:str,logrows:int,check_mode:str,split_proofs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],commitment:PyCommitments) -> bool:
r"""
Creates an aggregated proof
Arguments
---------
aggregation_snarks: list[str]
List of paths to the various proofs
proof_path: str
Path to output the aggregated proof
vk_path: str
Path to the VK file
transcript:
Proof transcript type to be used. `evm` used by default. `poseidon` is also supported
logrows:
Logrows used for aggregation circuit
check_mode: str
Run sanity checks during calculations. Accepts `safe` or `unsafe`
split_proofs: bool
Whether the accumulated proofs are segments of a larger circuit
srs_path: str
Path to the SRS used
commitment: str
Accepts "kzg" or "ipa"
Returns
-------
bool
"""
...
def buffer_to_felts(buffer:typing.Sequence[int]) -> list[str]:
def buffer_to_felts(buffer: typing.Sequence[int]) -> list[str]:
r"""
Converts a buffer to vector of field elements
Arguments
-------
buffer: list[int]
List of integers representing a buffer
Returns
-------
list[str]
@@ -105,173 +63,175 @@ def buffer_to_felts(buffer:typing.Sequence[int]) -> list[str]:
"""
...
def calibrate_settings(data:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,settings:str | os.PathLike | pathlib.Path,target:str,lookup_safety_margin:float,scales:typing.Optional[typing.Sequence[int]],scale_rebase_multiplier:typing.Sequence[int],max_logrows:typing.Optional[int]) -> typing.Any:
def calibrate_settings(
data: str | os.PathLike | pathlib.Path,
model: str | os.PathLike | pathlib.Path,
settings: str | os.PathLike | pathlib.Path,
target: str,
lookup_safety_margin: float,
scales: typing.Optional[typing.Sequence[int]],
scale_rebase_multiplier: typing.Sequence[int],
max_logrows: typing.Optional[int],
) -> typing.Any:
r"""
Calibrates the circuit settings
Arguments
---------
data: str
Path to the calibration data
model: str
Path to the onnx file
settings: str
Path to the settings file
lookup_safety_margin: float
The lookup safety margin to use for calibration. If the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. Larger = safer but slower.
scales: list[int]
Optional scales to specifically try for calibration
scale_rebase_multiplier: list[int]
Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale.
max_logrows: int
Optional max logrows to use for calibration
Returns
-------
bool
"""
...
def compile_circuit(model:str | os.PathLike | pathlib.Path,compiled_circuit:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path) -> bool:
def compile_circuit(
model: str | os.PathLike | pathlib.Path,
compiled_circuit: str | os.PathLike | pathlib.Path,
settings_path: str | os.PathLike | pathlib.Path,
) -> bool:
r"""
Compiles the circuit for use in other steps
Arguments
---------
model: str
Path to the onnx model file
compiled_circuit: str
Path to output the compiled circuit
settings_path: str
Path to the settings files
Returns
-------
bool
"""
...
def create_evm_verifier(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reusable:bool) -> typing.Any:
def create_evm_verifier(
vk_path: str | os.PathLike | pathlib.Path,
settings_path: str | os.PathLike | pathlib.Path,
sol_code_path: str | os.PathLike | pathlib.Path,
abi_path: str | os.PathLike | pathlib.Path,
srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
reusable: bool,
) -> typing.Any:
r"""
Creates an EVM-compatible verifier; you will need solc installed in your environment to run this
Arguments
---------
vk_path: str
The path to the verification key file
settings_path: str
The path to the settings file
sol_code_path: str
The path to create the Solidity verifier
abi_path: str
The path to create the ABI for the solidity verifier
srs_path: str
The path to the SRS file
reusable: bool
Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
Returns
-------
bool
"""
...
def create_evm_verifier_aggr(aggregation_settings:typing.Sequence[str | os.PathLike | pathlib.Path],vk_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,logrows:int,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reusable:bool) -> typing.Any:
r"""
Creates an EVM-compatible aggregate verifier; you will need solc installed in your environment to run this
Arguments
---------
aggregation_settings: str
path to the settings file
vk_path: str
The path to load the desired verification key file
sol_code_path: str
The path to the Solidity code
abi_path: str
The path to output the Solidity verifier ABI
logrows: int
Number of logrows used during aggregated setup
srs_path: str
The path to the SRS file
reusable: bool
Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
Returns
-------
bool
"""
...
def create_evm_vka(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,vka_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
def create_evm_vka(
vk_path: str | os.PathLike | pathlib.Path,
settings_path: str | os.PathLike | pathlib.Path,
vka_path: str | os.PathLike | pathlib.Path,
srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
) -> typing.Any:
r"""
Creates an EVM VK artifact. This command generates a VK with circuit-specific metadata encoded in memory, for use by the reusable H2 verifier.
This is useful for deploying verifiers that would otherwise be too big to fit on chain and would have required aggregation.
Arguments
---------
vk_path: str
The path to the verification key file
settings_path: str
The path to the settings file
vka_path: str
The path to create the VKA calldata
srs_path: str
The path to the SRS file
Returns
-------
bool
"""
...
def deploy_evm(addr_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],contract_type:str,optimizer_runs:int,private_key:typing.Optional[str]) -> typing.Any:
def deploy_evm(
addr_path: str | os.PathLike | pathlib.Path,
sol_code_path: str | os.PathLike | pathlib.Path,
rpc_url: typing.Optional[str],
contract_type: str,
optimizer_runs: int,
private_key: typing.Optional[str],
) -> typing.Any:
r"""
deploys the solidity verifier
"""
...
def encode_evm_calldata(proof:str | os.PathLike | pathlib.Path,calldata:str | os.PathLike | pathlib.Path,addr_vk:typing.Optional[str]) -> list[int]:
def encode_evm_calldata(
proof: str | os.PathLike | pathlib.Path,
calldata: str | os.PathLike | pathlib.Path,
addr_vk: typing.Optional[str],
) -> list[int]:
r"""
Creates encoded evm calldata from a proof file
Arguments
---------
proof: str
Path to the proof file
calldata: str
Path to the calldata file to save
addr_vk: str
The address of the verification key contract (if the verifier key is to be rendered as a separate contract)
Returns
-------
vec[u8]
@@ -279,16 +239,16 @@ def encode_evm_calldata(proof:str | os.PathLike | pathlib.Path,calldata:str | os
"""
...
def felt_to_big_endian(felt:str) -> str:
def felt_to_big_endian(felt: str) -> str:
r"""
Converts a field element hex string to big endian
Arguments
-------
felt: str
The field element represented as a string
Returns
-------
str
@@ -296,54 +256,54 @@ def felt_to_big_endian(felt:str) -> str:
"""
...
def felt_to_float(felt:str,scale:int) -> float:
def felt_to_float(felt: str, scale: int) -> float:
r"""
Converts a field element hex string to a floating point number
Arguments
-------
felt: str
The field element represented as a string
scale: float
The scaling factor used to convert the field element into a floating point representation
Returns
-------
float
"""
...
def felt_to_int(felt:str) -> int:
def felt_to_int(felt: str) -> int:
r"""
Converts a field element hex string to an integer
Arguments
-------
felt: str
The field element represented as a string
Returns
-------
int
"""
...
def float_to_felt(input:float,scale:int,input_type:PyInputType) -> str:
def float_to_felt(input: float, scale: int, input_type: PyInputType) -> str:
r"""
Converts a floating point element to a field element hex string
Arguments
-------
input: float
The field element represented as a string
scale: float
The scaling factor used to quantize the float into a field element
input_type: PyInputType
The type of the input
Returns
-------
str
@@ -351,101 +311,97 @@ def float_to_felt(input:float,scale:int,input_type:PyInputType) -> str:
"""
...
def gen_settings(model:str | os.PathLike | pathlib.Path,output:str | os.PathLike | pathlib.Path,py_run_args:typing.Optional[PyRunArgs]) -> bool:
def gen_settings(
model: str | os.PathLike | pathlib.Path,
output: str | os.PathLike | pathlib.Path,
py_run_args: typing.Optional[PyRunArgs],
) -> bool:
r"""
Generates the circuit settings
Arguments
---------
model: str
Path to the onnx file
output: str
Path to create the settings file
py_run_args: PyRunArgs
PyRunArgs object to initialize the settings
Returns
-------
bool
"""
...
def gen_srs(srs_path:str | os.PathLike | pathlib.Path,logrows:int) -> None:
def gen_srs(srs_path: str | os.PathLike | pathlib.Path, logrows: int) -> None:
r"""
Generates the Structured Reference String (SRS); use this only for testing purposes
Arguments
---------
srs_path: str
Path to create the SRS file
logrows: int
The number of logrows for the SRS file
"""
...
def gen_vk_from_pk_aggr(path_to_pk:str | os.PathLike | pathlib.Path,vk_output_path:str | os.PathLike | pathlib.Path) -> bool:
r"""
Generates a vk from a pk for an aggregate circuit and saves it to a file
Arguments
-------
path_to_pk: str
Path to the proving key
vk_output_path: str
Path to create the vk file
Returns
-------
bool
"""
...
def gen_vk_from_pk_single(path_to_pk:str | os.PathLike | pathlib.Path,circuit_settings_path:str | os.PathLike | pathlib.Path,vk_output_path:str | os.PathLike | pathlib.Path) -> bool:
def gen_vk_from_pk_single(
path_to_pk: str | os.PathLike | pathlib.Path,
circuit_settings_path: str | os.PathLike | pathlib.Path,
vk_output_path: str | os.PathLike | pathlib.Path,
) -> bool:
r"""
Generates a vk from a pk for a model circuit and saves it to a file
Arguments
-------
path_to_pk: str
Path to the proving key
circuit_settings_path: str
Path to the circuit settings file
vk_output_path: str
Path to create the vk file
Returns
-------
bool
"""
...
def gen_witness(data:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,output:typing.Optional[str | os.PathLike | pathlib.Path],vk_path:typing.Optional[str | os.PathLike | pathlib.Path],srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
def gen_witness(
data: str | os.PathLike | pathlib.Path,
model: str | os.PathLike | pathlib.Path,
output: typing.Optional[str | os.PathLike | pathlib.Path],
vk_path: typing.Optional[str | os.PathLike | pathlib.Path],
srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
) -> typing.Any:
r"""
Runs the forward pass operation to generate a witness
Arguments
---------
data: str
Path to the data file
model: str
Path to the compiled model file
output: str
Path to create the witness file
vk_path: str
Path to the verification key
srs_path: str
Path to the SRS file
Returns
-------
dict
@@ -453,126 +409,89 @@ def gen_witness(data:str | os.PathLike | pathlib.Path,model:str | os.PathLike |
"""
...
def get_srs(settings_path:typing.Optional[str | os.PathLike | pathlib.Path],logrows:typing.Optional[int],srs_path:typing.Optional[str | os.PathLike | pathlib.Path],commitment:typing.Optional[PyCommitments]) -> typing.Any:
def get_srs(
settings_path: typing.Optional[str | os.PathLike | pathlib.Path],
logrows: typing.Optional[int],
srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
) -> typing.Any:
r"""
Gets a public srs
Arguments
---------
settings_path: str
Path to the settings file
logrows: int
The number of logrows for the SRS file
srs_path: str
Path to create the SRS file
Returns
-------
bool
"""
...
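With IPA gone, the commitment argument has been dropped from the signature; a minimal sketch of the two remaining call shapes (awaited, as in the notebooks):

```python
# Minimal sketch: size the SRS from a settings file, or request explicit logrows.
import asyncio
import ezkl

asyncio.run(ezkl.get_srs("settings.json"))
# or, without a settings file:
asyncio.run(ezkl.get_srs(settings_path=None, logrows=21))
```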
def ipa_commit(message:typing.Sequence[str],vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> list[PyG1Affine]:
r"""
Generate an ipa commitment.
Arguments
-------
message: list[str]
List of field elements represented as strings
vk_path: str
Path to the verification key
settings_path: str
Path to the settings file
srs_path: str
Path to the Structured Reference String (SRS) file
Returns
-------
list[PyG1Affine]
"""
...
def kzg_commit(message:typing.Sequence[str],vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> list[PyG1Affine]:
def kzg_commit(
message: typing.Sequence[str],
vk_path: str | os.PathLike | pathlib.Path,
settings_path: str | os.PathLike | pathlib.Path,
srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
) -> list[PyG1Affine]:
r"""
Generate a kzg commitment.
Arguments
-------
message: list[str]
List of field elements represented as strings
vk_path: str
Path to the verification key
settings_path: str
Path to the settings file
srs_path: str
Path to the Structured Reference String (SRS) file
Returns
-------
list[PyG1Affine]
"""
...
def mock(witness:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path) -> bool:
def mock(
witness: str | os.PathLike | pathlib.Path, model: str | os.PathLike | pathlib.Path
) -> bool:
r"""
Mocks the prover
Arguments
---------
witness: str
Path to the witness file
model: str
Path to the compiled model file
Returns
-------
bool
"""
...
def mock_aggregate(aggregation_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],logrows:int,split_proofs:bool) -> bool:
r"""
Mocks the aggregate prover
Arguments
---------
aggregation_snarks: list[str]
List of paths to the relevant proof files
logrows: int
Number of logrows to use for the aggregation circuit
split_proofs: bool
Indicates whether the accumulated proofs are segments of a larger proof
Returns
-------
bool
"""
...
def poseidon_hash(message:typing.Sequence[str]) -> list[str]:
def poseidon_hash(message: typing.Sequence[str]) -> list[str]:
r"""
Generate a poseidon hash.
Arguments
-------
message: list[str]
List of field elements represented as strings
Returns
-------
list[str]
@@ -580,126 +499,104 @@ def poseidon_hash(message:typing.Sequence[str]) -> list[str]:
"""
...
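A small, hedged composition of the helpers above (assumes the buffer's integers are valid byte values):

```python
# Convert a byte buffer to field elements, then Poseidon-hash them.
import ezkl

felts = ezkl.buffer_to_felts([1, 2, 3, 4])  # -> list of field-element hex strings
digest = ezkl.poseidon_hash(felts)          # -> list[str]
print(digest)
```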
def prove(witness:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,proof_path:typing.Optional[str | os.PathLike | pathlib.Path],proof_type:str,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
def prove(
witness: str | os.PathLike | pathlib.Path,
model: str | os.PathLike | pathlib.Path,
pk_path: str | os.PathLike | pathlib.Path,
proof_path: typing.Optional[str | os.PathLike | pathlib.Path],
srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
) -> typing.Any:
r"""
Runs the prover on a set of inputs
Arguments
---------
witness: str
Path to the witness file
model: str
Path to the compiled model file
pk_path: str
Path to the proving key file
proof_path: str
Path to create the proof file
srs_path: str
Path to the SRS file
Returns
-------
bool
"""
...
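With the aggregation refactor, proof_type is gone and every proof is a single proof; a sketch of the updated call with placeholder paths:

```python
# Updated call shape after the removal of proof_type (placeholder paths).
import ezkl

res = ezkl.prove("witness.json", "network.compiled", "test.pk", "proof.json")
print(res)
```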
def setup(model:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],witness_path:typing.Optional[str | os.PathLike | pathlib.Path],disable_selector_compression:bool) -> bool:
def setup(
model: str | os.PathLike | pathlib.Path,
vk_path: str | os.PathLike | pathlib.Path,
pk_path: str | os.PathLike | pathlib.Path,
srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
witness_path: typing.Optional[str | os.PathLike | pathlib.Path],
disable_selector_compression: bool,
) -> bool:
r"""
Runs the setup process
Arguments
---------
model: str
Path to the compiled model file
vk_path: str
Path to create the verification key file
pk_path: str
Path to create the proving key file
srs_path: str
Path to the SRS file
witness_path: str
Path to the witness file
disable_selector_compression: bool
Whether to disable selector compression
Returns
-------
bool
"""
...
def setup_aggregate(sample_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],vk_path:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,logrows:int,split_proofs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],disable_selector_compression:bool,commitment:PyCommitments) -> bool:
r"""
Runs the setup process for an aggregate setup
Arguments
---------
sample_snarks: list[str]
List of paths to the various proofs
vk_path: str
Path to create the aggregated VK
pk_path: str
Path to create the aggregated PK
logrows: int
Number of logrows to use
split_proofs: bool
Whether the accumulated proofs are segments of a larger proof
srs_path: str
Path to the SRS file
disable_selector_compression: bool
Whether to disable selector compression
commitment: str
Accepts `kzg`, `ipa`
Returns
-------
bool
"""
...
def swap_proof_commitments(proof_path:str | os.PathLike | pathlib.Path,witness_path:str | os.PathLike | pathlib.Path) -> None:
def swap_proof_commitments(
proof_path: str | os.PathLike | pathlib.Path,
witness_path: str | os.PathLike | pathlib.Path,
) -> None:
r"""
Swap the commitments in a proof
Arguments
-------
proof_path: str
Path to the proof file
witness_path: str
Path to the witness file
"""
...
def table(model:str | os.PathLike | pathlib.Path,py_run_args:typing.Optional[PyRunArgs]) -> str:
def table(
model: str | os.PathLike | pathlib.Path, py_run_args: typing.Optional[PyRunArgs]
) -> str:
r"""
Displays the table as a string in python
Arguments
---------
model: str
Path to the onnx file
Returns
---------
str
@@ -707,78 +604,59 @@ def table(model:str | os.PathLike | pathlib.Path,py_run_args:typing.Optional[PyR
"""
...
def verify(proof_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reduced_srs:bool) -> bool:
def verify(
proof_path: str | os.PathLike | pathlib.Path,
settings_path: str | os.PathLike | pathlib.Path,
vk_path: str | os.PathLike | pathlib.Path,
srs_path: typing.Optional[str | os.PathLike | pathlib.Path],
reduced_srs: bool,
) -> bool:
r"""
Verifies a given proof
Arguments
---------
proof_path: str
Path to the proof file
settings_path: str
Path to the settings file
vk_path: str
Path to the verification key file
srs_path: str
Path to the SRS file
non_reduced_srs: bool
Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
Returns
-------
bool
"""
...
def verify_aggr(proof_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,logrows:int,commitment:PyCommitments,reduced_srs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> bool:
r"""
Verifies an aggregate proof
Arguments
---------
proof_path: str
The path to the proof file
vk_path: str
The path to the verification key file
logrows: int
logrows used for aggregation circuit
commitment: str
Accepts "kzg" or "ipa"
reduced_srs: bool
Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
srs_path: str
The path to the SRS file
Returns
-------
bool
"""
...
def verify_evm(addr_verifier:str,proof_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],vka_path:typing.Optional[str]) -> typing.Any:
def verify_evm(
addr_verifier: str,
proof_path: str | os.PathLike | pathlib.Path,
rpc_url: typing.Optional[str],
vka_path: typing.Optional[str],
) -> typing.Any:
r"""
Verifies an EVM-compatible proof; you will need solc installed in your environment to run this
Arguments
---------
addr_verifier: str
The verifier contract's address as a hex string
proof_path: str
The path to the proof file (generated using the prove command)
rpc_url: str
RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
vka_path: str
The path to the VKA calldata bytes file (generated using the create_evm_vka command)
Returns
@@ -786,4 +664,3 @@ def verify_evm(addr_verifier:str,proof_path:str | os.PathLike | pathlib.Path,rpc
bool
"""
...
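A hedged usage sketch (the verifier address is a placeholder; it is awaited here on the assumption that, like the other EVM helpers, it returns a coroutine):

import asyncio
import ezkl

async def main():
    ok = await ezkl.verify_evm(
        addr_verifier="0x0000000000000000000000000000000000000000",  # placeholder address
        proof_path="proof.json",
        rpc_url=None,   # None spins up an ephemeral Anvil instance; state is not persisted
        vka_path=None,
    )
    assert ok

asyncio.run(main())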


@@ -1,4 +0,0 @@
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
};


@@ -1,30 +0,0 @@
{
"name": "ezkljs-tests",
"version": "0.1.0",
"author": "Ethan Cemer",
"private": true,
"scripts": {
"test": "jest"
},
"devDependencies": {
"@ezkljs/engine": "^9.4.4",
"@ezkljs/verify": "^0.0.6",
"@jest/types": "^29.6.3",
"@types/file-saver": "^2.0.5",
"@types/jest": "^29.5.3",
"@types/json-bigint": "^1.0.1",
"@types/node": "20.4.5",
"buffer": "^6.0.3",
"env": "^0.0.2",
"fs": "0.0.1-security",
"jest": "^29.6.3",
"json-bigint": "^1.0.0",
"minimist": "^1.2.8",
"solc": "^0.8.21",
"ts-jest": "^29.1.1",
"ts-loader": "^9.4.4",
"ts-node": "^10.9.1",
"tsconfig-paths": "^4.2.0",
"typescript": "5.1.6"
}
}

pnpm-lock.yaml: 3596 lines changed (generated file; diff suppressed because it is too large)


@@ -1,269 +0,0 @@
use camino::Utf8Path;
use std::fs;
use std::fs::remove_dir_all;
use std::path::{Path, PathBuf};
use std::process::Command;
use uniffi_bindgen::bindings::SwiftBindingGenerator;
use uniffi_bindgen::library_mode::generate_bindings;
use uuid::Uuid;
fn main() {
let library_name = std::env::var("CARGO_PKG_NAME").expect("CARGO_PKG_NAME is not set");
let mode = determine_build_mode();
build_bindings(&library_name, mode);
}
/// Determines the build mode based on the CONFIGURATION environment variable.
/// Defaults to "release" if not set or unrecognized.
/// "release" mode takes longer to build but produces optimized code, which has smaller size and is faster.
fn determine_build_mode() -> &'static str {
match std::env::var("CONFIGURATION").map(|s| s.to_lowercase()) {
Ok(ref config) if config == "debug" => "debug",
_ => "release",
}
}
/// Builds the Swift bindings and XCFramework for the specified library and build mode.
fn build_bindings(library_name: &str, mode: &str) {
// Get the root directory of this Cargo project
let manifest_dir = std::env::var_os("CARGO_MANIFEST_DIR")
.map(PathBuf::from)
.unwrap_or_else(|| std::env::current_dir().unwrap());
// Define the build directory inside the manifest directory
let build_dir = manifest_dir.join("build");
// Create a temporary directory to store the bindings and combined library
let tmp_dir = mktemp_local(&build_dir);
// Define directories for Swift bindings and output bindings
let swift_bindings_dir = tmp_dir.join("SwiftBindings");
let bindings_out = create_bindings_out_dir(&tmp_dir);
let framework_out = bindings_out.join("EzklCore.xcframework");
// Define target architectures for building
// We currently only support iOS devices and simulators running on ARM Macs
// This keeps the library size under 100MB to satisfy GitHub's commit size limit
// To support older Macs (Intel), follow the instructions in the comments below
#[allow(clippy::useless_vec)]
let target_archs = vec![
vec!["aarch64-apple-ios"], // iOS device
vec!["aarch64-apple-ios-sim"], // iOS simulator ARM Mac
// vec!["aarch64-apple-ios-sim", "x86_64-apple-ios"], // TODO - replace the above line with this line to allow running on older Macs (Intel)
];
// Build the library for each architecture and combine them
let out_lib_paths: Vec<PathBuf> = target_archs
.iter()
.map(|archs| build_combined_archs(library_name, archs, &build_dir, mode))
.collect();
// Generate the path to the built dynamic library (.dylib)
let out_dylib_path = build_dir.join(format!(
"{}/{}/lib{}.dylib",
target_archs[0][0], mode, library_name
));
// Generate Swift bindings using uniffi_bindgen
generate_ios_bindings(&out_dylib_path, &swift_bindings_dir)
.expect("Failed to generate iOS bindings");
// Move the generated Swift file to the bindings output directory
fs::rename(
swift_bindings_dir.join(format!("{}.swift", library_name)),
bindings_out.join("EzklCore.swift"),
)
.expect("Failed to copy swift bindings file");
// Rename the `ios_ezklFFI.modulemap` file to `module.modulemap`
fs::rename(
swift_bindings_dir.join(format!("{}FFI.modulemap", library_name)),
swift_bindings_dir.join("module.modulemap"),
)
.expect("Failed to rename modulemap file");
// Create the XCFramework from the combined libraries and Swift bindings
create_xcframework(&out_lib_paths, &swift_bindings_dir, &framework_out);
// Define the destination directory for the bindings
let bindings_dest = build_dir.join("EzklCoreBindings");
if bindings_dest.exists() {
fs::remove_dir_all(&bindings_dest).expect("Failed to remove existing bindings directory");
}
// Move the bindings output to the destination directory
fs::rename(&bindings_out, &bindings_dest).expect("Failed to move framework into place");
// Clean up temporary directories
cleanup_temp_dirs(&build_dir);
}
/// Creates the output directory for the bindings.
/// Returns the path to the bindings output directory.
fn create_bindings_out_dir(base_dir: &Path) -> PathBuf {
let bindings_out = base_dir.join("EzklCoreBindings");
fs::create_dir_all(&bindings_out).expect("Failed to create bindings output directory");
bindings_out
}
/// Builds the library for each architecture and combines them into a single library using lipo.
/// Returns the path to the combined library.
fn build_combined_archs(
library_name: &str,
archs: &[&str],
build_dir: &Path,
mode: &str,
) -> PathBuf {
// Build the library for each architecture
let out_lib_paths: Vec<PathBuf> = archs
.iter()
.map(|&arch| {
build_for_arch(arch, build_dir, mode);
build_dir
.join(arch)
.join(mode)
.join(format!("lib{}.a", library_name))
})
.collect();
// Create a unique temporary directory for the combined library
let lib_out = mktemp_local(build_dir).join(format!("lib{}.a", library_name));
// Combine the libraries using lipo
let mut lipo_cmd = Command::new("lipo");
lipo_cmd
.arg("-create")
.arg("-output")
.arg(lib_out.to_str().unwrap());
for lib_path in &out_lib_paths {
lipo_cmd.arg(lib_path.to_str().unwrap());
}
let status = lipo_cmd.status().expect("Failed to run lipo command");
if !status.success() {
panic!("lipo command failed with status: {}", status);
}
lib_out
}
/// Builds the library for a specific architecture.
fn build_for_arch(arch: &str, build_dir: &Path, mode: &str) {
// Ensure the target architecture is installed
install_arch(arch);
// Run cargo build for the specified architecture and mode
let mut build_cmd = Command::new("cargo");
build_cmd
.arg("build")
.arg("--no-default-features")
.arg("--features")
.arg("ios-bindings");
if mode == "release" {
build_cmd.arg("--release");
}
build_cmd
.arg("--lib")
.env("CARGO_BUILD_TARGET_DIR", build_dir)
.env("CARGO_BUILD_TARGET", arch);
let status = build_cmd.status().expect("Failed to run cargo build");
if !status.success() {
panic!("cargo build failed for architecture: {}", arch);
}
}
/// Installs the specified target architecture using rustup.
fn install_arch(arch: &str) {
let status = Command::new("rustup")
.arg("target")
.arg("add")
.arg(arch)
.status()
.expect("Failed to run rustup command");
if !status.success() {
panic!("Failed to install target architecture: {}", arch);
}
}
/// Generates Swift bindings for the iOS library using uniffi_bindgen.
fn generate_ios_bindings(dylib_path: &Path, binding_dir: &Path) -> Result<(), std::io::Error> {
// Remove existing binding directory if it exists
if binding_dir.exists() {
remove_dir_all(binding_dir)?;
}
// Generate the Swift bindings using uniffi_bindgen
generate_bindings(
Utf8Path::from_path(dylib_path).ok_or_else(|| {
std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid dylib path")
})?,
None,
&SwiftBindingGenerator,
None,
Utf8Path::from_path(binding_dir).ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Invalid Swift bindings directory",
)
})?,
true,
)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
Ok(())
}
/// Creates an XCFramework from the combined libraries and Swift bindings.
fn create_xcframework(lib_paths: &[PathBuf], swift_bindings_dir: &Path, framework_out: &Path) {
let mut xcbuild_cmd = Command::new("xcodebuild");
xcbuild_cmd.arg("-create-xcframework");
// Add each library and its corresponding headers to the xcodebuild command
for lib_path in lib_paths {
println!("Including library: {:?}", lib_path);
xcbuild_cmd.arg("-library");
xcbuild_cmd.arg(lib_path.to_str().unwrap());
xcbuild_cmd.arg("-headers");
xcbuild_cmd.arg(swift_bindings_dir.to_str().unwrap());
}
xcbuild_cmd.arg("-output");
xcbuild_cmd.arg(framework_out.to_str().unwrap());
let status = xcbuild_cmd.status().expect("Failed to run xcodebuild");
if !status.success() {
panic!("xcodebuild failed with status: {}", status);
}
}
/// Creates a temporary directory inside the build path with a unique UUID.
/// This ensures unique build artifacts for concurrent builds.
fn mktemp_local(build_path: &Path) -> PathBuf {
let dir = tmp_local(build_path).join(Uuid::new_v4().to_string());
fs::create_dir(&dir).expect("Failed to create temporary directory");
dir
}
/// Gets the path to the local temporary directory inside the build path.
fn tmp_local(build_path: &Path) -> PathBuf {
let tmp_path = build_path.join("tmp");
if let Ok(metadata) = fs::metadata(&tmp_path) {
if !metadata.is_dir() {
panic!("Expected 'tmp' to be a directory");
}
} else {
fs::create_dir_all(&tmp_path).expect("Failed to create local temporary directory");
}
tmp_path
}
/// Cleans up temporary directories inside the build path.
fn cleanup_temp_dirs(build_dir: &Path) {
let tmp_dir = build_dir.join("tmp");
if tmp_dir.exists() {
fs::remove_dir_all(tmp_dir).expect("Failed to remove temporary directories");
}
}


@@ -1,12 +1,3 @@
/// Python bindings
#[cfg(feature = "python-bindings")]
pub mod python;
/// Universal bindings for all platforms
#[cfg(any(
feature = "universal-bindings",
all(target_arch = "wasm32", target_os = "unknown")
))]
pub mod universal;
/// wasm prover and verifier
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub mod wasm;


@@ -12,14 +12,10 @@ use crate::graph::TestDataSource;
use crate::graph::{
quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
};
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::{
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
ProofType, TranscriptType,
};
use crate::Commitments;
use crate::RunArgs;
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
use pyo3::exceptions::{PyIOError, PyRuntimeError};
@@ -173,9 +169,6 @@ struct PyRunArgs {
#[pyo3(get, set)]
/// str: check mode, accepts `safe`, `unsafe`
pub check_mode: CheckMode,
#[pyo3(get, set)]
/// str: commitment type, accepts `kzg`, `ipa`
pub commitment: PyCommitments,
/// int: The base used for decomposition
#[pyo3(get, set)]
pub decomp_base: usize,
@@ -191,6 +184,9 @@ struct PyRunArgs {
/// float: epsilon used for arguments that use division
#[pyo3(get, set)]
pub epsilon: f64,
/// bool: Whether to disable using Freivalds' argument in einsum operations
#[pyo3(get, set)]
pub disable_freivalds: bool,
}
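On the Python side the new field surfaces as a plain attribute; a hedged sketch (the model path is illustrative):

import ezkl

run_args = ezkl.PyRunArgs()
run_args.disable_freivalds = True  # opt out of the Freivalds' argument for einsum ops
# run_args can then be passed wherever PyRunArgs is accepted, e.g.:
print(ezkl.table("network.onnx", run_args))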
/// default instantiation of PyRunArgs
@@ -220,11 +216,11 @@ impl From<PyRunArgs> for RunArgs {
variables: py_run_args.variables,
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
check_mode: py_run_args.check_mode,
commitment: Some(py_run_args.commitment.into()),
decomp_base: py_run_args.decomp_base,
decomp_legs: py_run_args.decomp_legs,
ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
epsilon: Some(py_run_args.epsilon),
disable_freivalds: py_run_args.disable_freivalds,
}
}
}
@@ -247,61 +243,11 @@ impl Into<PyRunArgs> for RunArgs {
variables: self.variables,
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
check_mode: self.check_mode,
commitment: self.commitment.into(),
decomp_base: self.decomp_base,
decomp_legs: self.decomp_legs,
ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
epsilon: eps,
}
}
}
#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass_enum]
/// pyclass representing an enum, denoting the type of commitment
pub enum PyCommitments {
/// KZG commitment
KZG,
/// IPA commitment
IPA,
}
impl From<Option<Commitments>> for PyCommitments {
fn from(commitment: Option<Commitments>) -> Self {
match commitment {
Some(Commitments::KZG) => PyCommitments::KZG,
Some(Commitments::IPA) => PyCommitments::IPA,
None => PyCommitments::KZG,
}
}
}
impl From<PyCommitments> for Commitments {
fn from(py_commitments: PyCommitments) -> Self {
match py_commitments {
PyCommitments::KZG => Commitments::KZG,
PyCommitments::IPA => Commitments::IPA,
}
}
}
impl Into<PyCommitments> for Commitments {
fn into(self) -> PyCommitments {
match self {
Commitments::KZG => PyCommitments::KZG,
Commitments::IPA => PyCommitments::IPA,
}
}
}
impl FromStr for PyCommitments {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"kzg" => Ok(PyCommitments::KZG),
"ipa" => Ok(PyCommitments::IPA),
_ => Err("Invalid value for Commitments".to_string()),
disable_freivalds: self.disable_freivalds,
}
}
}
@@ -618,8 +564,7 @@ fn kzg_commit(
let settings = GraphSettings::load(&settings_path)
.map_err(|_| PyIOError::new_err("Failed to load circuit settings"))?;
let srs_path =
crate::execute::get_srs_path(settings.run_args.logrows, srs_path, Commitments::KZG);
let srs_path = crate::execute::get_srs_path(settings.run_args.logrows, srs_path);
let srs = load_srs_prover::<KZGCommitmentScheme<Bn256>>(srs_path)
.map_err(|_| PyIOError::new_err("Failed to load srs"))?;
@@ -636,65 +581,6 @@ fn kzg_commit(
Ok(output.iter().map(|x| (*x).into()).collect::<Vec<_>>())
}
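A hedged sketch of calling the binding from Python (message entries are field elements serialized as decimal strings; paths are placeholders):

import ezkl

points = ezkl.kzg_commit(
    message=["1", "2", "3"],
    vk_path="vk.key",
    settings_path="settings.json",
    srs_path=None,
)
print(points)  # list of PyG1Affine commitment points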
/// Generate an ipa commitment.
///
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represented as strings
///
/// vk_path: str
/// Path to the verification key
///
/// settings_path: str
/// Path to the settings file
///
/// srs_path: str
/// Path to the Structure Reference String (SRS) file
///
/// Returns
/// -------
/// list[PyG1Affine]
///
#[pyfunction(signature = (
message,
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
#[gen_stub_pyfunction]
fn ipa_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
settings_path: PathBuf,
srs_path: Option<PathBuf>,
) -> PyResult<Vec<PyG1Affine>> {
let message: Vec<Fr> = message
.iter()
.map(crate::pfsys::string_to_field::<Fr>)
.collect::<Vec<_>>();
let settings = GraphSettings::load(&settings_path)
.map_err(|_| PyIOError::new_err("Failed to load circuit settings"))?;
let srs_path =
crate::execute::get_srs_path(settings.run_args.logrows, srs_path, Commitments::IPA);
let srs = load_srs_prover::<IPACommitmentScheme<G1Affine>>(srs_path)
.map_err(|_| PyIOError::new_err("Failed to load srs"))?;
let vk = load_vk::<IPACommitmentScheme<G1Affine>, GraphCircuit>(vk_path, settings)
.map_err(|_| PyIOError::new_err("Failed to load vk"))?;
let output = PolyCommitChip::commit::<IPACommitmentScheme<G1Affine>>(
message,
(vk.cs().blinding_factors() + 1) as u32,
&srs,
);
Ok(output.iter().map(|x| (*x).into()).collect::<Vec<_>>())
}
/// Swap the commitments in a proof
///
/// Arguments
@@ -760,37 +646,6 @@ fn gen_vk_from_pk_single(
Ok(true)
}
/// Generates a vk from a pk for an aggregate circuit and saves it to a file
///
/// Arguments
/// -------
/// path_to_pk: str
/// Path to the proving key
///
/// vk_output_path: str
/// Path to create the vk file
///
/// Returns
/// -------
/// bool
#[pyfunction(signature = (
path_to_pk=PathBuf::from(DEFAULT_PK_AGGREGATED),
vk_output_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
))]
#[gen_stub_pyfunction]
fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult<bool> {
let pk = load_pk::<KZGCommitmentScheme<Bn256>, AggregationCircuit>(path_to_pk, ())
.map_err(|_| PyIOError::new_err("Failed to load pk"))?;
let vk = pk.get_vk();
// now save
save_vk::<G1Affine>(&vk_output_path, vk)
.map_err(|_| PyIOError::new_err("Failed to save vk"))?;
Ok(true)
}
/// Displays the table as a string in python
///
/// Arguments
@@ -853,8 +708,6 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
/// srs_path: str
/// Path to create the SRS file
///
/// commitment: str
/// Specify the commitment used ("kzg", "ipa")
///
/// Returns
/// -------
@@ -864,7 +717,6 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
settings_path=PathBuf::from(DEFAULT_SETTINGS),
logrows=None,
srs_path=None,
commitment=None,
))]
#[gen_stub_pyfunction]
fn get_srs(
@@ -872,15 +724,9 @@ fn get_srs(
settings_path: Option<PathBuf>,
logrows: Option<u32>,
srs_path: Option<PathBuf>,
commitment: Option<PyCommitments>,
) -> PyResult<Bound<'_, PyAny>> {
let commitment: Option<Commitments> = match commitment {
Some(c) => Some(c.into()),
None => None,
};
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::get_srs_cmd(srs_path, settings_path, logrows, commitment)
crate::execute::get_srs_cmd(srs_path, settings_path, logrows)
.await
.map_err(|e| {
let err_str = format!("Failed to get srs: {}", e);
@@ -1115,42 +961,6 @@ fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
Ok(true)
}
/// Mocks the aggregate prover
///
/// Arguments
/// ---------
/// aggregation_snarks: list[str]
/// List of paths to the relevant proof files
///
/// logrows: int
/// Number of logrows to use for the aggregation circuit
///
/// split_proofs: bool
/// Indicates whether the accumulated are segments of a larger proof
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
aggregation_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
split_proofs = false,
))]
#[gen_stub_pyfunction]
fn mock_aggregate(
aggregation_snarks: Vec<PathBuf>,
logrows: u32,
split_proofs: bool,
) -> PyResult<bool> {
crate::execute::mock_aggregate(aggregation_snarks, logrows, split_proofs).map_err(|e| {
let err_str = format!("Failed to run mock: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
}
/// Runs the setup process
///
/// Arguments
@@ -1226,8 +1036,6 @@ fn setup(
/// proof_path: str
/// Path to create the proof file
///
/// proof_type: str
/// Accepts `single`, `for-aggr`
///
/// srs_path: str
/// Path to the SRS file
@@ -1241,7 +1049,6 @@ fn setup(
model=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
pk_path=PathBuf::from(DEFAULT_PK),
proof_path=None,
proof_type=ProofType::default(),
srs_path=None,
))]
#[gen_stub_pyfunction]
@@ -1250,7 +1057,6 @@ fn prove(
model: PathBuf,
pk_path: PathBuf,
proof_path: Option<PathBuf>,
proof_type: ProofType,
srs_path: Option<PathBuf>,
) -> PyResult<PyObject> {
let snark = crate::execute::prove(
@@ -1259,7 +1065,6 @@ fn prove(
pk_path,
proof_path,
srs_path,
proof_type,
CheckMode::UNSAFE,
)
.map_err(|e| {
@@ -1318,77 +1123,6 @@ fn verify(
Ok(true)
}
/// Runs the setup process for an aggregate setup
///
/// Arguments
/// ---------
/// sample_snarks: list[str]
/// List of paths to the various proofs
///
/// vk_path: str
/// Path to create the aggregated VK
///
/// pk_path: str
/// Path to create the aggregated PK
///
/// logrows: int
/// Number of logrows to use
///
/// split_proofs: bool
/// Whether the accumulated are segments of a larger proof
///
/// srs_path: str
/// Path to the SRS file
///
/// disable_selector_compression: bool
/// Whether to compress selectors
///
/// commitment: str
/// Accepts `kzg`, `ipa`
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
sample_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
vk_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
pk_path=PathBuf::from(DEFAULT_PK_AGGREGATED),
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
split_proofs = false,
srs_path = None,
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn setup_aggregate(
sample_snarks: Vec<PathBuf>,
vk_path: PathBuf,
pk_path: PathBuf,
logrows: u32,
split_proofs: bool,
srs_path: Option<PathBuf>,
disable_selector_compression: bool,
commitment: PyCommitments,
) -> Result<bool, PyErr> {
crate::execute::setup_aggregate(
sample_snarks,
vk_path,
pk_path,
srs_path,
logrows,
split_proofs,
disable_selector_compression,
commitment.into(),
)
.map_err(|e| {
let err_str = format!("Failed to setup aggregate: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
}
/// Compiles the circuit for use in other steps
///
/// Arguments
@@ -1418,144 +1152,7 @@ fn compile_circuit(
settings_path: PathBuf,
) -> Result<bool, PyErr> {
crate::execute::compile_circuit(model, compiled_circuit, settings_path).map_err(|e| {
let err_str = format!("Failed to setup aggregate: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
}
/// Creates an aggregated proof
///
/// Arguments
/// ---------
/// aggregation_snarks: list[str]
/// List of paths to the various proofs
///
/// proof_path: str
/// Path to output the aggregated proof
///
/// vk_path: str
/// Path to the VK file
///
/// transcript:
/// Proof transcript type to be used. `evm` used by default. `poseidon` is also supported
///
/// logrows:
/// Logrows used for aggregation circuit
///
/// check_mode: str
/// Run sanity checks during calculations. Accepts `safe` or `unsafe`
///
/// split_proofs: bool
/// Whether the accumulated proofs are segments of a larger circuit
///
/// srs_path: str
/// Path to the SRS used
///
/// commitment: str
/// Accepts "kzg" or "ipa"
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
aggregation_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
proof_path=PathBuf::from(DEFAULT_PROOF_AGGREGATED),
vk_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
transcript=TranscriptType::default(),
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
check_mode=CheckMode::UNSAFE,
split_proofs = false,
srs_path=None,
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn aggregate(
aggregation_snarks: Vec<PathBuf>,
proof_path: PathBuf,
vk_path: PathBuf,
transcript: TranscriptType,
logrows: u32,
check_mode: CheckMode,
split_proofs: bool,
srs_path: Option<PathBuf>,
commitment: PyCommitments,
) -> Result<bool, PyErr> {
// the K used for the aggregation circuit
crate::execute::aggregate(
proof_path,
aggregation_snarks,
vk_path,
srs_path,
transcript,
logrows,
check_mode,
split_proofs,
commitment.into(),
)
.map_err(|e| {
let err_str = format!("Failed to run aggregate: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
}
/// Verifies an aggregate proof
///
/// Arguments
/// ---------
/// proof_path: str
/// The path to the proof file
///
/// vk_path: str
/// The path to the verification key file
///
/// logrows: int
/// logrows used for aggregation circuit
///
/// commitment: str
/// Accepts "kzg" or "ipa"
///
/// reduced_srs: bool
/// Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
///
/// srs_path: str
/// The path to the SRS file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
proof_path=PathBuf::from(DEFAULT_PROOF_AGGREGATED),
vk_path=PathBuf::from(DEFAULT_VK),
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse().unwrap(),
srs_path=None,
))]
#[gen_stub_pyfunction]
fn verify_aggr(
proof_path: PathBuf,
vk_path: PathBuf,
logrows: u32,
commitment: PyCommitments,
reduced_srs: bool,
srs_path: Option<PathBuf>,
) -> Result<bool, PyErr> {
crate::execute::verify_aggr(
proof_path,
vk_path,
srs_path,
logrows,
reduced_srs,
commitment.into(),
)
.map_err(|e| {
let err_str = format!("Failed to run verify_aggr: {}", e);
let err_str = format!("Failed to compile circuit: {}", e);
PyRuntimeError::new_err(err_str)
})?;
@@ -1662,7 +1259,7 @@ fn create_evm_verifier(
#[cfg(feature = "reusable-verifier")]
/// Creates an EVM VK artifact. This command generates a VK with circuit-specific metadata encoded in memory for use by the reusable H2 verifier.
/// This is useful for deploying verifiers that were otherwise too big to fit on chain and required aggregation.
/// This is useful for deploying verifiers that were otherwise too big to fit on chain.
///
/// Arguments
/// ---------
@@ -1868,75 +1465,6 @@ fn verify_evm<'a>(
})
}
/// Creates an evm compatible aggregate verifier, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// aggregation_settings: str
/// path to the settings file
///
/// vk_path: str
/// The path to load the desired verification key file
///
/// sol_code_path: str
/// The path to the Solidity code
///
/// abi_path: str
/// The path to output the Solidity verifier ABI
///
/// logrows: int
/// Number of logrows used during aggregated setup
///
/// srs_path: str
/// The path to the SRS file
///
/// reusable: bool
/// Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
aggregation_settings=vec![PathBuf::from(DEFAULT_PROOF)],
vk_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
srs_path=None,
reusable = DEFAULT_RENDER_REUSABLE.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn create_evm_verifier_aggr(
py: Python<'_>,
aggregation_settings: Vec<PathBuf>,
vk_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
logrows: u32,
srs_path: Option<PathBuf>,
reusable: bool,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::create_evm_aggregate_verifier(
vk_path,
srs_path,
sol_code_path,
abi_path,
aggregation_settings,
logrows,
reusable,
)
.await
.map_err(|e| {
let err_str = format!("Failed to run create_evm_verifier_aggr: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
})
}
// Define a function to gather stub information.
define_stub_info_gatherer!(stub_info);
@@ -1948,19 +1476,16 @@ fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyG1Affine>()?;
m.add_class::<PyG1>()?;
m.add_class::<PyTestDataSource>()?;
m.add_class::<PyCommitments>()?;
m.add_class::<PyInputType>()?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_float, m)?)?;
m.add_function(wrap_pyfunction!(kzg_commit, m)?)?;
m.add_function(wrap_pyfunction!(ipa_commit, m)?)?;
m.add_function(wrap_pyfunction!(swap_proof_commitments, m)?)?;
m.add_function(wrap_pyfunction!(poseidon_hash, m)?)?;
m.add_function(wrap_pyfunction!(float_to_felt, m)?)?;
m.add_function(wrap_pyfunction!(buffer_to_felts, m)?)?;
m.add_function(wrap_pyfunction!(gen_vk_from_pk_aggr, m)?)?;
m.add_function(wrap_pyfunction!(gen_vk_from_pk_single, m)?)?;
m.add_function(wrap_pyfunction!(table, m)?)?;
m.add_function(wrap_pyfunction!(mock, m)?)?;
@@ -1973,17 +1498,12 @@ fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(gen_settings, m)?)?;
m.add_function(wrap_pyfunction!(gen_random_data, m)?)?;
m.add_function(wrap_pyfunction!(calibrate_settings, m)?)?;
m.add_function(wrap_pyfunction!(aggregate, m)?)?;
m.add_function(wrap_pyfunction!(mock_aggregate, m)?)?;
m.add_function(wrap_pyfunction!(setup_aggregate, m)?)?;
m.add_function(wrap_pyfunction!(compile_circuit, m)?)?;
m.add_function(wrap_pyfunction!(verify_aggr, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?;
#[cfg(feature = "reusable-verifier")]
m.add_function(wrap_pyfunction!(create_evm_vka, m)?)?;
m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?;
m.add_function(wrap_pyfunction!(encode_evm_calldata, m)?)?;
#[cfg(feature = "reusable-verifier")]
m.add_function(wrap_pyfunction!(register_vka, m)?)?;
@@ -1999,24 +1519,6 @@ impl pyo3_stub_gen::PyStubType for CalibrationTarget {
}
}
impl pyo3_stub_gen::PyStubType for ProofType {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}
impl pyo3_stub_gen::PyStubType for TranscriptType {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}
impl pyo3_stub_gen::PyStubType for CheckMode {
fn type_output() -> TypeInfo {
TypeInfo {


@@ -1,606 +0,0 @@
use halo2_proofs::{
plonk::*,
poly::{
commitment::{CommitmentScheme, ParamsProver},
ipa::{
commitment::{IPACommitmentScheme, ParamsIPA},
multiopen::{ProverIPA, VerifierIPA},
strategy::SingleStrategy as IPASingleStrategy,
},
kzg::{
commitment::{KZGCommitmentScheme, ParamsKZG},
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy as KZGSingleStrategy,
},
VerificationStrategy,
},
};
use std::fmt::Display;
use std::io::BufReader;
use std::str::FromStr;
use crate::{
circuit::region::RegionSettings,
graph::GraphSettings,
pfsys::{
create_proof_circuit, encode_calldata,
evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript},
verify_proof_circuit, TranscriptType,
},
tensor::TensorType,
CheckMode, Commitments, EZKLError as InnerEZKLError,
};
use crate::circuit::modules::poseidon::{
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
PoseidonChip,
};
use crate::circuit::modules::Module;
use crate::graph::{GraphCircuit, GraphWitness};
use halo2curves::{
bn256::{Bn256, Fr, G1Affine},
ff::{FromUniformBytes, PrimeField},
};
use snark_verifier::{loader::native::NativeLoader, system::halo2::transcript::evm::EvmTranscript};
/// Wrapper around the Error Message
#[cfg_attr(feature = "ios-bindings", derive(uniffi::Error))]
#[derive(Debug)]
pub enum EZKLError {
/// An internal error, carrying a description
InternalError(String),
}
impl Display for EZKLError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
EZKLError::InternalError(e) => write!(f, "Internal error: {}", e),
}
}
}
impl From<InnerEZKLError> for EZKLError {
fn from(e: InnerEZKLError) -> Self {
EZKLError::InternalError(e.to_string())
}
}
/// Hash the input message with poseidon
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn poseidon_hash(message: Vec<u8>) -> Result<Vec<u8>, EZKLError> {
let message: Vec<Fr> = serde_json::from_slice(&message[..]).map_err(InnerEZKLError::from)?;
let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
.map_err(InnerEZKLError::from)?;
Ok(serde_json::to_vec(&output).map_err(InnerEZKLError::from)?)
}
/// Hash the input message with poseidon without converting to Fr
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn poseidon_hash_no_felt(message: Vec<u8>) -> Result<Vec<u8>, EZKLError> {
let message: Vec<Fr> = message.iter().map(|x| Fr::from(*x as u64)).collect();
let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
.map_err(InnerEZKLError::from)?;
Ok(serde_json::to_vec(&output).map_err(InnerEZKLError::from)?)
}
/// Encode verifier calldata from proof and ethereum vk_address
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn encode_verifier_calldata(
// TODO - should it be pub or pub(super)?
proof: Vec<u8>,
vka: Option<Vec<u8>>,
) -> Result<Vec<u8>, EZKLError> {
let snark: crate::pfsys::Snark<Fr, G1Affine> =
serde_json::from_slice(&proof[..]).map_err(InnerEZKLError::from)?;
let vka_buf: Option<Vec<[u8; 32]>> = if let Some(vka) = vka {
let array: Vec<[u8; 32]> =
serde_json::from_slice(&vka[..]).map_err(InnerEZKLError::from)?;
Some(array)
} else {
None
};
let vka: Option<&[[u8; 32]]> = vka_buf.as_deref();
let flattened_instances = snark.instances.into_iter().flatten();
let encoded = encode_calldata(vka, &snark.proof, &flattened_instances.collect::<Vec<_>>());
Ok(encoded)
}
/// Generate witness from compiled circuit and input json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn gen_witness(compiled_circuit: Vec<u8>, input: Vec<u8>) -> Result<Vec<u8>, EZKLError> {
println!("[circuit]");
let mut circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..])
.map_err(|e| {
EZKLError::InternalError(format!("Failed to deserialize compiled model: {}", e))
})?;
println!("[input]");
let input: crate::graph::input::GraphData = serde_json::from_slice(&input[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize input: {}", e)))?;
println!("[load graph input]");
let mut input = circuit
.load_graph_input(&input)
.map_err(|e| EZKLError::InternalError(format!("{}", e)))?;
println!("[load graph witness]");
let witness = circuit
.forward::<KZGCommitmentScheme<Bn256>>(
&mut input,
None,
None,
RegionSettings::all_true(
circuit.settings().run_args.decomp_base,
circuit.settings().run_args.decomp_legs,
),
)
.map_err(|e| EZKLError::InternalError(format!("{}", e)))?;
println!("[serialize witness]");
serde_json::to_vec(&witness)
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize witness: {}", e)))
}
/// Generate verifying key from compiled circuit, and parameters srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn gen_vk(
compiled_circuit: Vec<u8>,
srs: Vec<u8>,
compress_selectors: bool,
) -> Result<Vec<u8>, EZKLError> {
let mut reader = BufReader::new(&srs[..]);
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
let circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?;
let vk = create_vk_lean::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&params,
compress_selectors,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to create verifying key: {}", e)))?;
let mut serialized_vk = Vec::new();
vk.write(
&mut serialized_vk,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize verifying key: {}", e)))?;
Ok(serialized_vk)
}
/// Generate proving key from vk, compiled circuit and parameters srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn gen_pk(vk: Vec<u8>, compiled_circuit: Vec<u8>, srs: Vec<u8>) -> Result<Vec<u8>, EZKLError> {
let mut reader = BufReader::new(&srs[..]);
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
let circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?;
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit.settings().clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
let pk = create_pk_lean::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk, &circuit, &params)
.map_err(|e| EZKLError::InternalError(format!("Failed to create proving key: {}", e)))?;
let mut serialized_pk = Vec::new();
pk.write(&mut serialized_pk, halo2_proofs::SerdeFormat::RawBytes)
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize proving key: {}", e)))?;
Ok(serialized_pk)
}
/// Verify proof with vk, proof json, circuit settings json and srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn verify(
proof: Vec<u8>,
vk: Vec<u8>,
settings: Vec<u8>,
srs: Vec<u8>,
) -> Result<bool, EZKLError> {
let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize settings: {}", e)))?;
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proof: {}", e)))?;
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit_settings.clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
let orig_n = 1 << circuit_settings.run_args.logrows;
let commitment = circuit_settings.run_args.commitment.into();
let mut reader = BufReader::new(&srs[..]);
let result = match commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
Commitments::IPA => {
let params: ParamsIPA<_> = get_params(&mut reader)?;
let strategy = IPASingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
};
match result {
Ok(_) => Ok(true),
Err(e) => Err(EZKLError::InternalError(format!(
"Verification failed: {}",
e
))),
}
}
/// Verify aggregate proof with vk, proof, circuit settings and srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn verify_aggr(
proof: Vec<u8>,
vk: Vec<u8>,
logrows: u64,
srs: Vec<u8>,
commitment: &str,
) -> Result<bool, EZKLError> {
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proof: {}", e)))?;
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
let commit = Commitments::from_str(commitment)
.map_err(|e| EZKLError::InternalError(format!("Invalid commitment: {}", e)))?;
let orig_n = 1 << logrows;
let mut reader = BufReader::new(&srs[..]);
let result = match commit {
Commitments::KZG => {
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
Commitments::IPA => {
let params: ParamsIPA<_> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(
|e| EZKLError::InternalError(format!("Failed to deserialize params: {}", e)),
)?;
let strategy = IPASingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
};
result
.map(|_| true)
.map_err(|e| EZKLError::InternalError(format!("{}", e)))
}
/// Prove in browser with compiled circuit, witness json, proving key, and srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn prove(
witness: Vec<u8>,
pk: Vec<u8>,
compiled_circuit: Vec<u8>,
srs: Vec<u8>,
) -> Result<Vec<u8>, EZKLError> {
#[cfg(feature = "det-prove")]
log::set_max_level(log::LevelFilter::Debug);
#[cfg(not(feature = "det-prove"))]
log::set_max_level(log::LevelFilter::Info);
let mut circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?;
let data: GraphWitness = serde_json::from_slice(&witness[..]).map_err(InnerEZKLError::from)?;
let mut reader = BufReader::new(&pk[..]);
let pk = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit.settings().clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;
circuit
.load_graph_witness(&data)
.map_err(InnerEZKLError::from)?;
let public_inputs = circuit
.prepare_public_inputs(&data)
.map_err(InnerEZKLError::from)?;
let proof_split_commits: Option<crate::pfsys::ProofSplitCommit> = data.into();
let mut reader = BufReader::new(&srs[..]);
let commitment = circuit.settings().run_args.commitment.into();
let proof = match commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(
|e| EZKLError::InternalError(format!("Failed to deserialize srs: {}", e)),
)?;
create_proof_circuit::<
KZGCommitmentScheme<Bn256>,
_,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
KZGSingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit,
vec![public_inputs],
&params,
&pk,
CheckMode::UNSAFE,
Commitments::KZG,
TranscriptType::EVM,
proof_split_commits,
None,
)
}
Commitments::IPA => {
let params: ParamsIPA<_> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(
|e| EZKLError::InternalError(format!("Failed to deserialize srs: {}", e)),
)?;
create_proof_circuit::<
IPACommitmentScheme<G1Affine>,
_,
ProverIPA<_>,
VerifierIPA<_>,
IPASingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit,
vec![public_inputs],
&params,
&pk,
CheckMode::UNSAFE,
Commitments::IPA,
TranscriptType::EVM,
proof_split_commits,
None,
)
}
}
.map_err(InnerEZKLError::from)?;
Ok(serde_json::to_vec(&proof).map_err(InnerEZKLError::from)?)
}
/// Validate the witness json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn witness_validation(witness: Vec<u8>) -> Result<bool, EZKLError> {
let _: GraphWitness = serde_json::from_slice(&witness[..]).map_err(InnerEZKLError::from)?;
Ok(true)
}
/// Validate the compiled circuit
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn compiled_circuit_validation(compiled_circuit: Vec<u8>) -> Result<bool, EZKLError> {
let _: GraphCircuit = bincode::deserialize(&compiled_circuit[..]).map_err(|e| {
EZKLError::InternalError(format!("Failed to deserialize compiled circuit: {}", e))
})?;
Ok(true)
}
/// Validate the input json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn input_validation(input: Vec<u8>) -> Result<bool, EZKLError> {
let _: crate::graph::input::GraphData =
serde_json::from_slice(&input[..]).map_err(InnerEZKLError::from)?;
Ok(true)
}
/// Validate the proof json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn proof_validation(proof: Vec<u8>) -> Result<bool, EZKLError> {
let _: crate::pfsys::Snark<Fr, G1Affine> =
serde_json::from_slice(&proof[..]).map_err(InnerEZKLError::from)?;
Ok(true)
}
/// Validate the verifying key given the settings json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn vk_validation(vk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKLError> {
let circuit_settings: GraphSettings =
serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;
let mut reader = BufReader::new(&vk[..]);
let _ = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit_settings,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
Ok(true)
}
/// Validate the proving key given the settings json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn pk_validation(pk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKLError> {
let circuit_settings: GraphSettings =
serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;
let mut reader = BufReader::new(&pk[..]);
let _ = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit_settings,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;
Ok(true)
}
/// Validate the settings json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn settings_validation(settings: Vec<u8>) -> Result<bool, EZKLError> {
let _: GraphSettings = serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;
Ok(true)
}
/// Validate the srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn srs_validation(srs: Vec<u8>) -> Result<bool, EZKLError> {
let mut reader = BufReader::new(&srs[..]);
let _: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(|e| {
EZKLError::InternalError(format!("Failed to deserialize params: {}", e))
})?;
Ok(true)
}
// HELPER FUNCTIONS
fn get_params<
Scheme: for<'a> halo2_proofs::poly::commitment::Params<'a, halo2curves::bn256::G1Affine>,
>(
mut reader: &mut BufReader<&[u8]>,
) -> Result<Scheme, EZKLError> {
halo2_proofs::poly::commitment::Params::<G1Affine>::read(&mut reader)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize params: {}", e)))
}
/// Creates a [VerifyingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target
pub fn create_vk_lean<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
circuit: &C,
params: &'_ Scheme::ParamsProver,
compress_selectors: bool,
) -> Result<VerifyingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
C: Circuit<Scheme::Scalar>,
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
{
// Key generation only needs the circuit structure, not a witness
let empty_circuit = <C as Circuit<F>>::without_witnesses(circuit);
// Initialize the verifying key
let vk = keygen_vk_custom(params, &empty_circuit, compress_selectors)?;
Ok(vk)
}
/// Creates a [ProvingKey] from a [VerifyingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target
pub fn create_pk_lean<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
vk: VerifyingKey<Scheme::Curve>,
circuit: &C,
params: &'_ Scheme::ParamsProver,
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
C: Circuit<Scheme::Scalar>,
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
{
// Key generation only needs the circuit structure, not a witness
let empty_circuit = <C as Circuit<F>>::without_witnesses(circuit);
// Initialize the proving key
let pk = keygen_pk(params, vk, &empty_circuit)?;
Ok(pk)
}


@@ -1,398 +0,0 @@
use crate::{
circuit::modules::polycommit::PolyCommitChip,
fieldutils::{felt_to_integer_rep, integer_rep_to_felt},
graph::{quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings},
};
use console_error_panic_hook;
use halo2_proofs::{
plonk::*,
poly::kzg::commitment::{KZGCommitmentScheme, ParamsKZG},
};
use halo2_solidity_verifier::Evm;
use halo2curves::{
bn256::{Bn256, Fr, G1Affine},
ff::PrimeField,
};
use std::str::FromStr;
use wasm_bindgen::prelude::*;
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
use crate::bindings::universal::{
compiled_circuit_validation, encode_verifier_calldata, gen_pk, gen_vk, gen_witness,
input_validation, pk_validation, proof_validation, settings_validation, srs_validation,
verify_aggr, vk_validation, witness_validation, EZKLError as ExternalEZKLError,
};
#[cfg(feature = "web")]
pub use wasm_bindgen_rayon::init_thread_pool;
impl From<ExternalEZKLError> for JsError {
fn from(e: ExternalEZKLError) -> Self {
JsError::new(&format!("{}", e))
}
}
#[wasm_bindgen]
/// Initialize logger for wasm
pub fn init_logger() {
log::set_logger(&DEFAULT_LOGGER).unwrap();
}
#[wasm_bindgen]
/// Initialize panic hook for wasm
pub fn init_panic_hook() {
console_error_panic_hook::set_once();
}
/// Wrapper around the halo2 encode call data method
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn encodeVerifierCalldata(
proof: wasm_bindgen::Clamped<Vec<u8>>,
vk_address: Option<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
encode_verifier_calldata(proof.0, vk_address).map_err(JsError::from)
}
/// Converts a hex string to a byte array
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToBigEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
Ok(format!("{:?}", felt))
}
/// Converts a felt to a little endian string
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToLittleEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
let repr = serde_json::to_string(&felt).unwrap();
let b: String = serde_json::from_str(&repr).unwrap();
Ok(b)
}
/// Converts a hex string to a byte array
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToInt(
array: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
Ok(wasm_bindgen::Clamped(
serde_json::to_vec(&felt_to_integer_rep(felt))
.map_err(|e| JsError::new(&format!("Failed to serialize integer: {}", e)))?,
))
}
/// Converts felts to a floating point element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToFloat(
array: wasm_bindgen::Clamped<Vec<u8>>,
scale: crate::Scale,
) -> Result<f64, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
let int_rep = felt_to_integer_rep(felt);
let multiplier = scale_to_multiplier(scale);
Ok(int_rep as f64 / multiplier)
}
/// Converts a floating point number to a hex string representing a fixed point field element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn floatToFelt(
mut input: f64,
scale: crate::Scale,
input_type: &str,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
crate::circuit::InputType::roundtrip(
&crate::circuit::InputType::from_str(input_type)
.map_err(|e| JsError::new(&format!("{}", e)))?,
&mut input,
);
let int_rep =
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
let felt = integer_rep_to_felt(int_rep);
let vec = crate::pfsys::field_to_string::<halo2curves::bn256::Fr>(&felt);
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err(
|e| JsError::new(&format!("Failed to serialize a float to felt: {}", e)),
)?))
}
/// Generate a kzg commitment.
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn kzgCommit(
message: wasm_bindgen::Clamped<Vec<u8>>,
vk: wasm_bindgen::Clamped<Vec<u8>>,
settings: wasm_bindgen::Clamped<Vec<u8>>,
params_ser: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let message: Vec<Fr> = serde_json::from_slice(&message[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?;
let mut reader = std::io::BufReader::new(&params_ser[..]);
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let mut reader = std::io::BufReader::new(&vk[..]);
let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?;
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
circuit_settings,
)
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
let output = PolyCommitChip::commit::<KZGCommitmentScheme<Bn256>>(
message,
(vk.cs().blinding_factors() + 1) as u32,
&params,
);
Ok(wasm_bindgen::Clamped(
serde_json::to_vec(&output).map_err(|e| JsError::new(&format!("{}", e)))?,
))
}
/// Converts a buffer to vector of 4 u64s representing a fixed point field element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn bufferToVecOfFelt(
buffer: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
// Convert the buffer to a slice
let buffer: &[u8] = &buffer;
// Divide the buffer into chunks of 16 bytes
let chunks = buffer.chunks_exact(16);
// Get the remainder
let remainder = chunks.remainder();
// Pad the remainder with 0s to make it 16 bytes
let mut remainder = remainder.to_vec();
// Collect chunks into a Vec<[u8; 16]>.
let chunks: Result<Vec<[u8; 16]>, JsError> = chunks
.map(|slice| {
let array: [u8; 16] = slice
.try_into()
.map_err(|_| JsError::new("failed to slice input chunks"))?;
Ok(array)
})
.collect();
let mut chunks = chunks?;
if !remainder.is_empty() {
remainder.resize(16, 0);
// Convert the Vec<u8> to [u8; 16]
let remainder_array: [u8; 16] = remainder
.try_into()
.map_err(|_| JsError::new("failed to slice remainder"))?;
// append the remainder to the chunks
chunks.push(remainder_array);
}
// Convert each chunk to a field element
let field_elements: Vec<Fr> = chunks
.iter()
.map(|x| PrimeField::from_u128(u8_array_to_u128_le(*x)))
.collect();
Ok(wasm_bindgen::Clamped(
serde_json::to_vec(&field_elements)
.map_err(|e| JsError::new(&format!("Failed to serialize field elements: {}", e)))?,
))
}
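The chunking above is equivalent to the following Python sketch: split the buffer into 16-byte little-endian limbs and zero-pad the tail:

def buffer_to_felts(buf: bytes) -> list[int]:
    felts = []
    for i in range(0, len(buf), 16):
        chunk = buf[i:i + 16].ljust(16, b"\x00")  # zero-pad the final chunk
        felts.append(int.from_bytes(chunk, "little"))
    return felts

assert buffer_to_felts(b"\x01") == [1]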
/// Generate a poseidon hash in browser. Input message
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn poseidonHash(
message: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
super::universal::poseidon_hash(message.0)
.map_err(JsError::from)
.map(|x| wasm_bindgen::Clamped(x.clone()))
}
/// Generate a witness file from input.json, compiled model and a settings.json file.
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn genWitness(
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
input: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
gen_witness(compiled_circuit.0, input.0).map_err(JsError::from)
}
/// Generate verifying key in browser
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn genVk(
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
params_ser: wasm_bindgen::Clamped<Vec<u8>>,
compress_selectors: bool,
) -> Result<Vec<u8>, JsError> {
gen_vk(compiled_circuit.0, params_ser.0, compress_selectors).map_err(JsError::from)
}
/// Generate proving key in browser
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn genPk(
vk: wasm_bindgen::Clamped<Vec<u8>>,
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
params_ser: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
gen_pk(vk.0, compiled_circuit.0, params_ser.0).map_err(JsError::from)
}
/// Verify proof in browser using wasm
#[wasm_bindgen]
pub fn verify(
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
vk: wasm_bindgen::Clamped<Vec<u8>>,
settings: wasm_bindgen::Clamped<Vec<u8>>,
srs: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
super::universal::verify(proof_js.0, vk.0, settings.0, srs.0).map_err(JsError::from)
}
/// Verify proof in browser evm using wasm
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn verifyEVM(
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
bytecode_verifier: Vec<u8>,
bytecode_vka: Option<Vec<u8>>,
) -> Result<bool, JsError> {
let mut evm = Evm::unlimited();
let decoded_verifier = utf8_bytes_to_hex_decoded(&bytecode_verifier)?;
let (verifier_address, _) = evm.create(decoded_verifier);
// if bytecode_vka is Some, then create the VKA contract
let vk_address = if let Some(bytecode_vka) = bytecode_vka {
let decoded_vka = utf8_bytes_to_hex_decoded(&bytecode_vka)?;
let (address, _) = evm.create(decoded_vka);
Some(address.as_slice().to_vec())
} else {
None
};
let calldata = encode_verifier_calldata(proof_js.0, vk_address).map_err(JsError::from);
let output = evm.call(verifier_address, calldata?).1;
let true_word = [vec![0; 31], vec![1]].concat();
Ok(output == true_word)
}
/// Verify aggregate proof in browser using wasm
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn verifyAggr(
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
vk: wasm_bindgen::Clamped<Vec<u8>>,
logrows: u64,
srs: wasm_bindgen::Clamped<Vec<u8>>,
commitment: &str,
) -> Result<bool, JsError> {
verify_aggr(proof_js.0, vk.0, logrows, srs.0, commitment).map_err(JsError::from)
}
/// Prove in browser using wasm
#[wasm_bindgen]
pub fn prove(
witness: wasm_bindgen::Clamped<Vec<u8>>,
pk: wasm_bindgen::Clamped<Vec<u8>>,
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
srs: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
super::universal::prove(witness.0, pk.0, compiled_circuit.0, srs.0).map_err(JsError::from)
}
// VALIDATION FUNCTIONS
/// Witness file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn witnessValidation(witness: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
witness_validation(witness.0).map_err(JsError::from)
}
/// Compiled circuit validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn compiledCircuitValidation(
compiled_circuit: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
compiled_circuit_validation(compiled_circuit.0).map_err(JsError::from)
}
/// Input file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn inputValidation(input: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
input_validation(input.0).map_err(JsError::from)
}
/// Proof file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn proofValidation(proof: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
proof_validation(proof.0).map_err(JsError::from)
}
/// Vk file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn vkValidation(
vk: wasm_bindgen::Clamped<Vec<u8>>,
settings: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
vk_validation(vk.0, settings.0).map_err(JsError::from)
}
/// Pk file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn pkValidation(
pk: wasm_bindgen::Clamped<Vec<u8>>,
settings: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
pk_validation(pk.0, settings.0).map_err(JsError::from)
}
/// Settings file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn settingsValidation(settings: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
settings_validation(settings.0).map_err(JsError::from)
}
/// Srs file validation
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn srsValidation(srs: wasm_bindgen::Clamped<Vec<u8>>) -> Result<bool, JsError> {
srs_validation(srs.0).map_err(JsError::from)
}
// HELPER FUNCTIONS
/// Interpret a 16-byte array as a little-endian u128
pub fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 {
let mut n: u128 = 0;
for &b in arr.iter().rev() {
n <<= 8;
n |= b as u128;
}
n
}
/// Decode a UTF-8 string of hex characters (optionally `0x`-prefixed) into bytes
pub fn utf8_bytes_to_hex_decoded(input: &[u8]) -> Result<Vec<u8>, JsError> {
let string = std::str::from_utf8(input)?.trim();
let hex_string = string.strip_prefix("0x").unwrap_or(string);
hex::decode(hex_string).map_err(JsError::from)
}

View File

@@ -14,6 +14,7 @@ use tosubcommand::ToFlags;
use crate::{
circuit::{
chip::einsum::analysis::EinsumAnalysis,
ops::base::BaseOp,
table::{Range, RangeCheck, Table},
},
@@ -24,6 +25,9 @@ use std::{collections::BTreeMap, marker::PhantomData};
use super::{lookup::LookupOp, region::RegionCtx, CircuitError, Op};
use halo2curves::ff::{Field, PrimeField};
/// Einsum chip and Freivalds' argument machinery
pub mod einsum;
#[allow(missing_docs)]
/// An enum representing activating the sanity checks we can perform on the accumulated arguments
#[derive(
@@ -266,6 +270,8 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
pub range_checks: RangeChecks<F>,
/// [Selector]s for the shuffles
pub shuffles: Shuffles,
/// Einsum-specific configuration
pub einsums: Option<einsum::Einsums<F>>,
/// Activate sanity checks
pub check_mode: CheckMode,
_marker: PhantomData<F>,
@@ -280,6 +286,22 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
custom_gates: CustomGates::dummy(col_size, num_inner_cols),
static_lookups: StaticLookups::dummy(col_size, num_inner_cols),
dynamic_lookups: DynamicLookups::dummy(col_size, num_inner_cols),
einsums: Some(einsum::Einsums::<F>::dummy(col_size, num_inner_cols)),
shuffles: Shuffles::dummy(col_size, num_inner_cols),
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
check_mode: CheckMode::SAFE,
shared_table_inputs: vec![],
_marker: PhantomData,
}
}
/// Returns a new [BaseConfig] with no inputs, no selectors, no tables, and no Freivalds' argument.
pub fn dummy_without_freivalds(col_size: usize, num_inner_cols: usize) -> Self {
Self {
custom_gates: CustomGates::dummy(col_size, num_inner_cols),
static_lookups: StaticLookups::dummy(col_size, num_inner_cols),
dynamic_lookups: DynamicLookups::dummy(col_size, num_inner_cols),
einsums: None,
shuffles: Shuffles::dummy(col_size, num_inner_cols),
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
check_mode: CheckMode::SAFE,
@@ -414,6 +436,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
},
static_lookups: StaticLookups::default(),
dynamic_lookups: DynamicLookups::default(),
einsums: None,
shuffles: Shuffles::default(),
range_checks: RangeChecks::default(),
shared_table_inputs: vec![],
@@ -688,6 +711,27 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
Ok(())
}
/// Configures and creates einsums
#[allow(clippy::too_many_arguments)]
pub fn configure_einsums(
&mut self,
cs: &mut ConstraintSystem<F>,
analysis: &EinsumAnalysis,
num_inner_cols: usize,
logrows: usize,
) -> Result<(), CircuitError>
where
F: Field,
{
self.einsums = Some(einsum::Einsums::configure_universal(
cs,
analysis,
num_inner_cols,
logrows,
));
Ok(())
}
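// Illustrative call sequence (assumed, not shown in this diff): during configuration
// the model's einsum usage is analysed first, and `configure_einsums` is only invoked
// when Freivalds' argument will actually be used, e.g.:
//
// let analysis = analyze_einsum_usage(&equations)?;
// config.configure_einsums(cs, &analysis, num_inner_cols, logrows)?;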
/// Configures and creates lookup selectors
#[allow(clippy::too_many_arguments)]
pub fn configure_shuffles(

View File

@@ -0,0 +1,210 @@
use std::collections::{HashMap, HashSet};
use itertools::Itertools;
use crate::circuit::{
einsum::reduction_planner::{self, Reduction},
CircuitError,
};
/// Aggregate analysis of every einsum equation used in a circuit
#[derive(Debug, Clone)]
pub struct EinsumAnalysis {
/// max size of input tensors
pub max_input_size: usize,
/// max size of output tensors
pub max_output_size: usize,
/// max number of input tensors
pub max_num_inputs: usize,
/// max number of output axes
pub max_num_output_axes: usize,
/// the sum of dot-product lengths needed to compute all the reductions
pub reduction_length: usize,
}
/// The strategy to use for einsum
#[derive(Debug, Clone)]
pub enum EinsumStrategy {
/// Use only base ops
BaseOps,
/// Use Freivalds' argument
Freivalds,
}
/// Analysis of a single einsum equation
#[derive(Debug, Clone)]
pub struct SingleEquationAnalysis {
/// the sanitised equation
pub equation: String,
/// number of input tensors
pub num_inputs: usize,
/// size of the largest input tensor
pub max_input_size: usize,
/// size of the output tensor
pub output_size: usize,
/// number of output axes
pub num_output_axes: usize,
/// the output axis labels
pub output_indices: Vec<char>,
/// the total dot-product length needed to compute all the reductions
pub reduction_length: usize,
/// the strategy to use for einsum
pub strategy: EinsumStrategy,
}
/// Compute the aggregate analysis over all einsum equations used by the circuit
pub fn analyze_einsum_usage(
equations: &HashMap<(usize, String), HashMap<char, usize>>,
) -> Result<EinsumAnalysis, CircuitError> {
let mut max_num_inputs = 0;
let mut max_input_size = 0;
let mut max_output_size = 0;
let mut max_num_output_axes = 0;
let mut reduction_length = 0;
for ((_, equation), input_axes_to_dim) in equations.iter() {
let analysis = analyze_single_equation(equation, input_axes_to_dim)?;
max_input_size = max_input_size.max(analysis.max_input_size);
max_output_size = max_output_size.max(analysis.output_size);
max_num_inputs = max_num_inputs.max(analysis.num_inputs);
max_num_output_axes = max_num_output_axes.max(analysis.num_output_axes);
reduction_length += analysis.reduction_length;
}
Ok(EinsumAnalysis {
max_input_size,
max_output_size,
max_num_inputs,
max_num_output_axes,
reduction_length,
})
}
/// Analyse a single equation: strip trivial axes, size its reductions, and choose a strategy
pub fn analyze_single_equation(
equation: &str,
input_axes_to_dim: &HashMap<char, usize>,
) -> Result<SingleEquationAnalysis, CircuitError> {
// Sanitise equation to remove trivial axes
let equation = {
let (inputs_str, output_str) = equation.split_once("->").unwrap();
let input_equations: Vec<&str> = inputs_str.split(',').collect();
let inputs: Vec<String> = input_equations
.iter()
.map(|input| {
input
.chars()
.filter(|char| *input_axes_to_dim.get(char).unwrap() > 1)
.collect()
})
.collect();
let output = output_str
.chars()
.filter(|c| input_axes_to_dim.get(c).is_some_and(|dim| *dim > 1))
.collect();
[inputs.join(","), output].join("->")
};
let (inputs_eq, output_eq) = equation.split_once("->").unwrap();
let input_equations: Vec<&str> = inputs_eq.split(',').collect();
let max_input_size = input_equations
.iter()
.map(|eqn| {
eqn.chars()
.map(|c| input_axes_to_dim.get(&c).unwrap())
.product()
})
.max()
.unwrap();
let output_indices: Vec<char> = output_eq.chars().collect();
let output_dims = output_indices
.iter()
.map(|c| input_axes_to_dim.get(&c).unwrap());
let output_size = output_dims.clone().product();
let output_reduction_length = {
let mut output_dims = output_dims.rev().cloned().collect_vec();
let mut total_length = 0;
for _ in 0..output_dims.len() {
let dot_product_len = output_dims.remove(0);
let num_dot_products: usize = output_dims.iter().product();
total_length += dot_product_len * num_dot_products;
}
total_length
};
let input_reductions_length = {
let input_reductions = reduction_planner::input_reductions(&equation)?;
input_reductions
.into_iter()
.map(|reduction| {
let (_, output_expr) = reduction.expression().split_once("->").unwrap();
let num_inputs = reduction.input_indices().len();
let dot_product_len = match reduction {
Reduction::RLC { axis, .. } => *input_axes_to_dim.get(&axis).unwrap(),
Reduction::Contraction { axis, .. } => *axis
.and_then(|axis| input_axes_to_dim.get(&axis))
.unwrap_or(&1),
};
let num_dot_products: usize = output_expr
.chars()
.map(|c| input_axes_to_dim.get(&c).unwrap())
.product();
// `multi_dot` does pairwise mults between input pairs plus a final summation,
// so each of the `num_inputs` tensors contributes `dot_product_len` rows
if num_inputs <= 2 {
num_dot_products * dot_product_len
} else {
num_dot_products * (dot_product_len * num_inputs)
}
})
.sum::<usize>()
};
let dispatch_to_einsum_with_base_ops = {
let mut seen = HashSet::new();
let mut common_indices_to_inputs = vec![];
for input in input_equations.iter() {
for c in input.chars() {
if !seen.contains(&c) {
seen.insert(c);
} else {
common_indices_to_inputs.push(c);
}
}
}
let non_common_indices = input_axes_to_dim
.keys()
.filter(|&x| {
!common_indices_to_inputs.contains(x)
&& input_axes_to_dim.get(x).cloned().unwrap() > 1
})
.collect::<Vec<_>>();
// fall back to base ops unless the equation has output axes, shared
// (contracted) axes, and more than one non-shared axis
output_indices.is_empty()
|| common_indices_to_inputs.is_empty()
|| non_common_indices.len() <= 1
};
let strategy = if dispatch_to_einsum_with_base_ops {
EinsumStrategy::BaseOps
} else {
EinsumStrategy::Freivalds
};
Ok(SingleEquationAnalysis {
output_size,
max_input_size,
equation: equation.to_string(),
num_inputs: input_equations.len(),
num_output_axes: output_indices.len(),
output_indices,
reduction_length: output_reduction_length + input_reductions_length,
strategy,
})
}
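// Worked example (illustrative): for an output expression "ij" with i = 2 and j = 3,
// `output_reduction_length` folds the reversed dims [3, 2]:
// step 1: dot products of length 3, one per remaining index value (2 of them) -> 6
// step 2: a single dot product of length 2 -> 2
// giving 8 accumulator rows, which matches squashing the output by RLC-ing the last
// axis first (as `assign_output` does).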

View File

@@ -0,0 +1,54 @@
use std::{collections::HashMap, marker::PhantomData};
use halo2_proofs::circuit::Value;
use halo2curves::ff::PrimeField;
use crate::{
circuit::CircuitError,
tensor::{Tensor, TensorError, TensorType},
};
/// Circuit parameter for a single einsum equation
#[derive(Clone, Debug, Default)]
pub struct SingleEinsumParams<F: PrimeField + TensorType + PartialOrd> {
/// The einsum equation
pub equation: String,
/// Map from input axes to dimensions
pub input_axes_to_dims: HashMap<char, usize>,
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> SingleEinsumParams<F> {
/// Build params from an equation and its input tensors, checking that repeated axis labels have consistent dimensions
pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
let mut eq = equation.split("->");
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
// Check that the number of inputs matches the number of inputs in the equation
if inputs.len() != inputs_eq.len() {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
let mut input_axes_to_dims = HashMap::new();
for (i, input) in inputs.iter().enumerate() {
for j in 0..inputs_eq[i].len() {
let c = inputs_eq[i]
.chars()
.nth(j)
.ok_or(CircuitError::InvalidEinsum)?;
if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
e.insert(input.dims()[j]);
} else if input_axes_to_dims[&c] != input.dims()[j] {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
}
}
Ok(Self {
equation: equation.to_owned(),
input_axes_to_dims,
_marker: PhantomData,
})
}
}
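// Example (illustrative): for "ij,jk->ik" with inputs of shape [2, 3] and [3, 4],
// `input_axes_to_dims` resolves to {i: 2, j: 3, k: 4}; a repeated axis label with a
// conflicting dimension is rejected as a DimMismatch above.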

View File

@@ -0,0 +1,359 @@
use halo2curves::ff::PrimeField;
use log::{error, trace};
use crate::{
circuit::{base::BaseOp, einsum::BaseOpInfo, region::RegionCtx, CheckMode, CircuitError},
tensor::{
get_broadcasted_shape,
ops::{accumulated, add, mult, sub},
TensorError, TensorType, ValTensor, ValType,
},
};
use super::ContractionConfig;
/// Pairwise (elementwise) op layout
pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>; 2],
op: BaseOp,
phases: &[usize; 2],
) -> Result<ValTensor<F>, CircuitError> {
let (mut lhs, mut rhs) = if phases[0] <= phases[1] {
(values[0].clone(), values[1].clone())
} else {
(values[1].clone(), values[0].clone())
};
let broadcasted_shape = get_broadcasted_shape(lhs.dims(), rhs.dims())?;
lhs.expand(&broadcasted_shape)?;
rhs.expand(&broadcasted_shape)?;
if lhs.len() != rhs.len() {
return Err(CircuitError::DimMismatch(format!(
"pairwise {} layout",
op.as_str()
)));
}
region.flush_einsum()?;
let input_vars = config.get_input_vars(phases.as_slice().into());
let output_var = config.get_output_var(phases.as_slice().into());
let inputs = [lhs, rhs]
.iter()
.zip(input_vars)
.map(|(val, var)| {
let res = region.assign_einsum(var, val)?;
Ok(res.get_inner()?)
})
.collect::<Result<Vec<_>, CircuitError>>()?;
// Apply the elementwise op to the assigned inputs
let op_result = match op {
BaseOp::Add => add(&inputs),
BaseOp::Sub => sub(&inputs),
BaseOp::Mult => mult(&inputs),
_ => return Err(CircuitError::UnsupportedOp),
}
.map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
let assigned_len = op_result.len();
let mut output = region.assign_einsum(output_var, &op_result.into())?;
// Enable the selectors
if !region.is_dummy() {
(0..assigned_len)
.map(|i| {
let (x, y, z) = output_var.cartesian_coord(region.einsum_col_coord() + i);
let op_info = BaseOpInfo {
op_kind: op.clone(),
input_phases: phases.as_slice().into(),
};
let selector = config.selectors.get(&(op_info, x, y));
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, CircuitError>>()?;
}
region.increment_einsum_col_coord(assigned_len);
output.reshape(&broadcasted_shape)?;
Ok(output)
}
/// Sum-reduction layout: accumulate a tensor down to a single scalar
pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>; 1],
phase: usize,
check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
if values[0].len() == 1 {
return Ok(values[0].clone());
}
assert!(phase == 0 || phase == 1);
region.flush_einsum()?;
let mut input = values[0].clone();
let block_width = config.block_width();
let assigned_len: usize;
let input = {
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let var = config.get_input_vars([phase].as_slice().into())[0];
let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &input)?;
assigned_len = len;
res.get_inner()?
};
// Now we can assign the running sum
let accumulated_sum = accumulated::sum(&input, block_width)?;
let output_var = config.get_output_var([phase].as_slice().into());
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
output_var,
&accumulated_sum.into(),
check_mode,
)?;
// enable the selectors
if !region.is_dummy() {
for i in 0..output_assigned_len {
let (x, _, z) = output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
// skip over duplicates at start of column
if z == 0 && i > 0 {
continue;
}
let selector = if i == 0 {
let op_info = BaseOpInfo {
op_kind: BaseOp::SumInit,
input_phases: [phase].as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
} else {
let op_info = BaseOpInfo {
op_kind: BaseOp::Sum,
input_phases: [phase].as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
};
region.enable(selector, z)?;
}
}
let last_elem = output.last()?;
region.increment_einsum_col_coord(assigned_len);
// last element is the result
Ok(last_elem)
}
/// Product-reduction layout: accumulate a tensor down to a single scalar product
pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>; 1],
phase: usize,
check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
assert!(phase == 0 || phase == 1);
region.flush_einsum()?;
let block_width = config.block_width();
let assigned_len: usize;
let input = {
let mut input = values[0].clone();
input.pad_to_zero_rem(block_width, ValType::Constant(F::ONE))?;
let var = config.get_input_vars([phase].as_slice().into())[0];
let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &input)?;
assigned_len = len;
res.get_inner()?
};
// Now we can assign the running product
let accumulated_prod = accumulated::prod(&input, block_width)?;
let output_var = config.get_output_var([phase].as_slice().into());
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
output_var,
&accumulated_prod.into(),
check_mode,
)?;
// enable the selectors
if !region.is_dummy() {
(0..output_assigned_len)
.map(|i| {
let (x, _, z) =
output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
// skip over duplicates at start of column
if z == 0 && i > 0 {
return Ok(());
}
let selector = if i == 0 {
let op_info = BaseOpInfo {
op_kind: BaseOp::CumProdInit,
input_phases: [phase].as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
} else {
let op_info = BaseOpInfo {
op_kind: BaseOp::CumProd,
input_phases: [phase].as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
};
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, CircuitError>>()?;
}
let last_elem = output.last()?;
region.increment_einsum_col_coord(assigned_len);
// last element is the result
Ok(last_elem)
}
/// Dot-product layout between two tensors of equal length
pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>; 2],
phases: &[usize; 2],
check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
if values[0].len() != values[1].len() {
return Err(TensorError::DimMismatch("dot".to_string()).into());
}
region.flush_einsum()?;
// time this function's entire run
let global_start = instant::Instant::now();
let mut values = if phases[0] <= phases[1] {
[values[0].clone(), values[1].clone()]
} else {
[values[1].clone(), values[0].clone()]
};
let vars = config.get_input_vars(phases.as_slice().into());
let mut inputs = vec![];
let block_width = config.block_width();
let mut assigned_len = 0;
for (val, var) in values.iter_mut().zip(vars) {
val.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let inp = {
let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &val)?;
assigned_len = len;
res.get_inner()?
};
inputs.push(inp);
}
// Now we can assign the dot product
let accumulated_dot = accumulated::dot(&inputs[0], &inputs[1], block_width)?;
let output_var = config.get_output_var(phases.as_slice().into());
let (output, output_assigned_len) = region
.assign_einsum_with_duplication_constrained(output_var, &accumulated_dot.into(), check_mode)?;
// enable the selectors
if !region.is_dummy() {
(0..output_assigned_len)
.map(|i| {
let (x, _, z) =
output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
// hop over duplicates at start of column
if z == 0 && i > 0 {
return Ok(());
}
let selector = if i == 0 {
let op_info = BaseOpInfo {
op_kind: BaseOp::DotInit,
input_phases: phases.as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
} else {
let op_info = BaseOpInfo {
op_kind: BaseOp::Dot,
input_phases: phases.as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
};
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, CircuitError>>()?;
}
let last_elem = output.last()?;
region.increment_einsum_col_coord(assigned_len);
let elapsed = global_start.elapsed();
trace!("dot layout took: {:?}, row {}", elapsed, region.row());
trace!("----------------------------");
Ok(last_elem)
}
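// Accumulation shape (as implied by the selector wiring above): with block width w,
// output row 0 holds the dot product of the first w pairs (checked by DotInit), and
// each subsequent row adds the next w products on top of the previous row (checked
// by Dot); the final row is the full dot product returned via `output.last()`.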
/// Dot product of more than two tensors
pub fn multi_dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>],
phases: &[usize],
check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
assert!(phases.iter().all(|phase| *phase == 0 || *phase == 1));
if !values.iter().all(|value| value.len() == values[0].len()) {
return Err(TensorError::DimMismatch("dot".to_string()).into());
}
// time this function's entire run
let global_start = instant::Instant::now();
let values: Vec<ValTensor<F>> = values.iter().copied().cloned().collect();
// do pairwise dot product between intermediate tensor and the next tensor
let (intermediate, output_phase) = values
.into_iter()
.zip(phases.iter().cloned())
.reduce(|(intermediate, intermediate_phase), (input, phase)| {
(
pairwise(
config,
region,
&[&intermediate, &input],
BaseOp::Mult,
&[intermediate_phase, phase],
)
.unwrap(),
std::cmp::max(intermediate_phase, phase),
)
})
.unwrap();
let accumulated_dot = sum(config, region, &[&intermediate], output_phase, check_mode)?;
let last_elem = accumulated_dot.last()?;
let elapsed = global_start.elapsed();
trace!("multi_dot layout took: {:?}, row {}", elapsed, region.row());
trace!("----------------------------");
Ok(last_elem)
}
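// Row accounting (illustrative): for k > 2 tensors of length L, the reduce above
// performs k - 1 pairwise Mult layouts of length L plus one Sum of length L, i.e.
// k * L assigned rows per dot product -- matching the
// `num_dot_products * (dot_product_len * num_inputs)` estimate in the einsum analysis.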

View File

@@ -0,0 +1,867 @@
use crate::circuit::base::BaseOp;
use crate::circuit::chip::einsum::analysis::{analyze_single_equation, EinsumAnalysis};
use crate::circuit::einsum::layouts::{pairwise, sum};
use crate::circuit::einsum::reduction_planner::Reduction;
use crate::circuit::layouts::einsum_with_base_ops;
use crate::circuit::region::RegionCtx;
use crate::circuit::{BaseConfig, CheckMode, CircuitError};
use crate::tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor};
use halo2_proofs::circuit::Value;
use halo2_proofs::plonk::{
Challenge, ConstraintSystem, Constraints, Expression, FirstPhase, Selector,
};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use layouts::{dot, multi_dot, prod};
use std::collections::{BTreeMap, HashMap};
use std::marker::PhantomData;
/// Einsum usage analysis
pub mod analysis;
/// Circuit parameters for einsum equations
pub mod circuit_params;
mod layouts;
mod reduction_planner;
/// The maximum number of challenges
pub const NUM_MAX_EINSUM_CHALLENGES: usize = 10;
/// A struct representing reductions for the einsums
#[derive(Clone, Debug, Default)]
pub struct Einsums<F: PrimeField + TensorType + PartialOrd> {
/// custom gate to constrain tensor contractions
contraction_gate: ContractionConfig<F>,
/// custom gate to constrain random linear combinations used by Freivalds' argument
rlc_gates: Vec<RLCConfig<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Einsums<F> {
/// Returns a dummy [Einsums] config for shape-only (dummy) layout passes
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
let dummy_contraction_gate = ContractionConfig {
inputs: [
[dummy_var.clone(), dummy_var.clone()],
[dummy_var.clone(), dummy_var.clone()],
],
outputs: [dummy_var.clone(), dummy_var.clone()],
selectors: BTreeMap::default(),
_marker: PhantomData,
};
Self {
contraction_gate: dummy_contraction_gate,
rlc_gates: (0..NUM_MAX_EINSUM_CHALLENGES)
.map(|_| RLCConfig::dummy(&dummy_var))
.collect(),
}
}
/// Configure the columns based on universal Einsum analysis
pub fn configure_universal(
meta: &mut ConstraintSystem<F>,
analysis: &EinsumAnalysis,
num_inner_cols: usize,
logrows: usize,
) -> Self {
let capacity = analysis.reduction_length;
let inputs: [VarTensor; 4] = [
VarTensor::new_advice(meta, logrows, num_inner_cols, capacity),
VarTensor::new_advice(meta, logrows, num_inner_cols, capacity),
VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
];
let outputs = [
VarTensor::new_advice(meta, logrows, num_inner_cols, capacity),
VarTensor::new_advice_in_second_phase(meta, logrows, num_inner_cols, capacity),
];
let contraction_gate = ContractionConfig::new(
meta,
&[&[&inputs[0], &inputs[1]], &[&inputs[2], &inputs[3]]],
&[&outputs[0], &outputs[1]],
);
let mut rlc_gates = vec![];
for _ in 0..analysis.max_num_output_axes {
let rlc_gate =
RLCConfig::new(meta, &[inputs[0].clone(), inputs[2].clone()], &outputs[1]);
rlc_gates.push(rlc_gate);
}
Self {
contraction_gate,
rlc_gates,
}
}
/// Returns the configured challenges; in the dummy layout phase this returns an
/// error, since dummy configs carry no challenges
pub fn challenges(&self) -> Result<Vec<Challenge>, CircuitError> {
self.rlc_gates
.iter()
.map(|gate| gate.challenge.ok_or(CircuitError::ChallengeNotSet))
.collect::<Result<Vec<_>, _>>()
}
/// Assign and constrain one einsum equation, dispatching to base ops or to Freivalds' argument per the equation's analysis
pub fn assign_einsum(
&self,
base_config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
input_tensors: &[&ValTensor<F>],
output_tensor: &ValTensor<F>,
equation: &str,
check_mode: &CheckMode,
) -> Result<(), CircuitError> {
region.set_num_einsum_inner_cols(self.contraction_gate.block_width());
let (input_exprs, _) = equation.split_once("->").unwrap();
let input_exprs = input_exprs.split(",").collect_vec();
assert_eq!(input_exprs.len(), input_tensors.len());
let mut input_tensors = input_tensors.iter().copied().cloned().collect_vec();
let mut output_tensor = output_tensor.clone();
let mut input_axes_to_dim: HashMap<char, usize> = HashMap::new();
input_exprs
.iter()
.zip(input_tensors.iter())
.for_each(|(indices, tensor)| {
indices.chars().zip(tensor.dims()).for_each(|(index, dim)| {
if let std::collections::hash_map::Entry::Vacant(e) =
input_axes_to_dim.entry(index)
{
e.insert(*dim);
}
});
});
let equation_analysis = analyze_single_equation(&equation, &input_axes_to_dim)?;
let equation = equation_analysis.equation;
// Remove trivial axes from tensors
input_tensors
.iter_mut()
.map(|tensor| tensor.remove_trivial_axes())
.collect::<Result<Vec<_>, TensorError>>()?;
output_tensor.remove_trivial_axes()?;
if matches!(
equation_analysis.strategy,
analysis::EinsumStrategy::BaseOps
) {
let _ = einsum_with_base_ops(
base_config,
region,
&input_tensors.iter().collect_vec(),
&equation,
)?;
return Ok(());
}
let output_shape = equation_analysis
.output_indices
.iter()
.map(|c| input_axes_to_dim.get(c).copied().unwrap())
.collect_vec();
let squashed_output =
self.assign_output(region, &output_tensor, output_shape, check_mode)?;
// reorder the reduction of input tensors and reduce
let reordered_input_reductions = reduction_planner::input_reductions(&equation)?;
let mut tensors = input_tensors;
let mut reduced_input_phase = 0;
for reduction in reordered_input_reductions.iter() {
let (input_expr, output_expr) = reduction.expression().split_once("->").unwrap();
let input_exprs = input_expr.split(",").collect_vec();
let remaining_axes = output_expr.chars().collect_vec();
let mut remaining_axes_indices = remaining_axes
.iter()
.map(|c| 0..input_axes_to_dim[c])
.multi_cartesian_product()
.collect_vec();
// Dummy value to ensure the for loop runs at least once
if remaining_axes.is_empty() {
remaining_axes_indices.push(vec![]);
}
let input_tensors = reduction
.input_indices()
.iter()
.map(|idx| tensors[*idx].clone())
.collect_vec();
let mut flattened_input_tensors: Vec<Vec<ValTensor<F>>> =
vec![vec![]; input_tensors.len()];
for remaining_axes_indices in remaining_axes_indices {
// corresponds to 1 running sum of input tensors
for (i, (input_tensor, input_expr)) in
input_tensors.iter().zip(input_exprs.iter()).enumerate()
{
let mut sliced_dim = vec![];
input_expr.chars().for_each(|axis| {
if let Some(pos) = remaining_axes.iter().position(|c| *c == axis) {
sliced_dim
.push(remaining_axes_indices[pos]..remaining_axes_indices[pos] + 1);
} else {
// common axis
sliced_dim.push(0..input_axes_to_dim[&axis]);
}
});
let mut sliced_input_tensor = input_tensor.get_slice(&sliced_dim)?;
sliced_input_tensor.flatten();
flattened_input_tensors[i].push(sliced_input_tensor);
}
}
let flattened_input_tensors = flattened_input_tensors
.into_iter()
.map(|tensors| {
ValTensor::from(
tensors
.into_iter()
.flat_map(|t| t.get_inner_tensor().unwrap().clone().into_iter())
.collect_vec(),
)
})
.collect_vec();
let output_dims = output_expr
.chars()
.map(|c| input_axes_to_dim[&c])
.collect_vec();
let contracted_output = match reduction {
Reduction::RLC {
axis,
input_phase,
challenge_index,
..
} => {
assert_eq!(flattened_input_tensors.len(), 1);
let rlc_len = input_axes_to_dim[axis];
let mut result = self.rlc_gates[*challenge_index].assign_rlc(
region,
&flattened_input_tensors[0],
region.challenges()[*challenge_index],
rlc_len,
*input_phase,
check_mode,
)?;
result.reshape(&output_dims)?;
result
}
Reduction::Contraction {
axis, input_phases, ..
} => match axis {
Some(axis) => {
let dot_product_len = input_axes_to_dim[axis];
assign_input_contraction(
&self.contraction_gate,
region,
flattened_input_tensors,
dot_product_len,
&output_dims,
input_phases,
check_mode,
)?
}
None => {
let mut result = assign_pairwise_mult(
&self.contraction_gate,
region,
flattened_input_tensors,
input_phases,
)?;
result.reshape(&output_dims)?;
result
}
},
};
tensors.push(contracted_output);
reduced_input_phase = reduction.output_phase();
}
// after the reduction plan runs, only the scalar end-products matter;
// drop the consumed inputs and multi-element intermediates
tensors.retain(|tensor| tensor.is_singleton());
let scalars: ValTensor<F> = tensors
.into_iter()
.map(|t| t.get_inner_tensor().unwrap().get_scalar())
.collect_vec()
.into();
let squashed_input = prod(
&self.contraction_gate,
region,
&[&scalars],
reduced_input_phase,
check_mode,
)?;
region.constrain_equal(&squashed_input, &squashed_output)
}
fn assign_output(
&self,
region: &mut RegionCtx<F>,
output: &ValTensor<F>,
output_shape: Vec<usize>,
check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
let mut intermediate_values = output.clone();
let challenges = region
.challenges()
.iter()
.take(output_shape.len())
.copied()
.collect_vec();
// Loop over the output axes
for (idx, (rlc_config, challenge)) in self
.rlc_gates
.iter()
.take(output_shape.len())
.zip(challenges.iter())
.rev()
.enumerate()
{
let rlc_len = output_shape[output_shape.len() - idx - 1];
intermediate_values.flatten();
let phase = if idx > 0 { 1 } else { 0 };
intermediate_values = rlc_config.assign_rlc(
region,
&intermediate_values,
*challenge,
rlc_len,
phase,
check_mode,
)?;
}
let phase = if challenges.is_empty() { 0 } else { 1 };
let output_var = self
.contraction_gate
.get_output_var([phase].as_slice().into());
let res = region.assign_einsum(output_var, &intermediate_values)?;
region.increment_einsum_col_coord(1);
Ok(res)
}
}
fn assign_pairwise_mult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
flattened_tensors: Vec<ValTensor<F>>,
input_phases: &[usize],
) -> Result<ValTensor<F>, CircuitError> {
assert_eq!(flattened_tensors.len(), input_phases.len());
let (result, _) = flattened_tensors
.into_iter()
.zip(input_phases.iter().cloned())
.reduce(|(acc, acc_phase), (input, phase)| {
(
pairwise(
config,
region,
&[&acc, &input],
BaseOp::Mult,
&[acc_phase, phase],
)
.unwrap(),
std::cmp::max(acc_phase, phase),
)
})
.unwrap();
Ok(result)
}
fn assign_input_contraction<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
flattened_tensors: Vec<ValTensor<F>>,
dot_product_len: usize,
output_shape: &[usize],
input_phases: &[usize],
check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
assert_eq!(flattened_tensors.len(), input_phases.len());
let num_dot_products = output_shape.iter().product();
let mut dot_product_results = vec![];
for chunk_idx in 0..num_dot_products {
let start = chunk_idx * dot_product_len;
let tensors: Vec<_> = flattened_tensors
.iter()
.map(|tensor| tensor.get_slice(&[start..(start + dot_product_len)]))
.collect::<Result<Vec<_>, TensorError>>()?;
let result = if tensors.len() == 1 {
sum(config, region, &[&tensors[0]], input_phases[0], check_mode)?
} else if tensors.len() == 2 {
dot(
config,
region,
&[&tensors[0], &tensors[1]],
&[input_phases[0], input_phases[1]],
check_mode,
)?
} else {
multi_dot(
config,
region,
tensors.iter().collect_vec().as_slice(),
input_phases,
check_mode,
)?
};
dot_product_results.push(result.get_inner_tensor()?.get_scalar());
}
let mut tensor = ValTensor::from(dot_product_results);
tensor.reshape(output_shape)?;
Ok(tensor)
}
#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, Hash)]
enum InputPhases {
FirstPhase,
SecondPhase,
BothFirstPhase, // [0, 0]
Mixed, // [0, 1] or [1, 0]
BothSecondPhase, // [1, 1]
}
impl From<&[usize]> for InputPhases {
fn from(phases: &[usize]) -> Self {
match phases {
[0] => Self::FirstPhase,
[1] => Self::SecondPhase,
[0, 0] => Self::BothFirstPhase,
[0, 1] | [1, 0] => Self::Mixed,
[1, 1] => Self::BothSecondPhase,
_ => panic!("Invalid phase combination"),
}
}
}
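// e.g. both `[0, 1]` and `[1, 0]` map to `Mixed`: callers sort operands so the
// lower-phase tensor comes first, which lines up with `get_input_vars(Mixed)`
// returning one first-phase and one second-phase column.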
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
struct BaseOpInfo {
pub op_kind: BaseOp,
pub input_phases: InputPhases,
}
/// `ContractionConfig` is the custom gate to constrain tensor contractions
#[derive(Clone, Debug, Default)]
struct ContractionConfig<F: PrimeField + TensorType + PartialOrd> {
// [[first phase, first phase], [second phase, second phase]]
inputs: [[VarTensor; 2]; 2],
// [first phase, second phase]
outputs: [VarTensor; 2],
// (BaseOpInfo, block index, inner column index) -> selector
selectors: BTreeMap<(BaseOpInfo, usize, usize), Selector>,
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> ContractionConfig<F> {
fn get_input_vars(&self, input_phases: InputPhases) -> Vec<&VarTensor> {
match input_phases {
InputPhases::FirstPhase => vec![&self.inputs[0][0]],
InputPhases::SecondPhase => vec![&self.inputs[1][0]],
InputPhases::BothFirstPhase => vec![&self.inputs[0][0], &self.inputs[0][1]],
InputPhases::Mixed => vec![&self.inputs[0][0], &self.inputs[1][0]],
InputPhases::BothSecondPhase => vec![&self.inputs[1][0], &self.inputs[1][1]],
}
}
fn get_output_var(&self, input_phases: InputPhases) -> &VarTensor {
match input_phases {
InputPhases::FirstPhase => &self.outputs[0],
InputPhases::SecondPhase => &self.outputs[1],
InputPhases::BothFirstPhase => &self.outputs[0],
InputPhases::Mixed => &self.outputs[1],
InputPhases::BothSecondPhase => &self.outputs[1],
}
}
fn block_width(&self) -> usize {
self.outputs[0].num_inner_cols()
}
fn new(
meta: &mut ConstraintSystem<F>,
inputs: &[&[&VarTensor; 2]; 2],
outputs: &[&VarTensor; 2],
) -> Self {
let mut selectors = BTreeMap::new();
let num_blocks = outputs[0].num_blocks();
let block_width = outputs[0].num_inner_cols();
for input_phases in [
InputPhases::BothFirstPhase,
InputPhases::Mixed,
InputPhases::BothSecondPhase,
] {
for i in 0..num_blocks {
for j in 0..block_width {
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::Mult,
input_phases,
},
i,
j,
),
meta.selector(),
);
}
}
// Dot selectors are allocated once per block, as a sibling of the Mult loop
for i in 0..num_blocks {
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::DotInit,
input_phases,
},
i,
0,
),
meta.selector(),
);
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::Dot,
input_phases,
},
i,
0,
),
meta.selector(),
);
}
}
for input_phases in [InputPhases::FirstPhase, InputPhases::SecondPhase] {
for i in 0..num_blocks {
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::SumInit,
input_phases,
},
i,
0,
),
meta.selector(),
);
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::Sum,
input_phases,
},
i,
0,
),
meta.selector(),
);
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::CumProdInit,
input_phases,
},
i,
0,
),
meta.selector(),
);
selectors.insert(
(
BaseOpInfo {
op_kind: BaseOp::CumProd,
input_phases,
},
i,
0,
),
meta.selector(),
);
}
}
for ((base_op, block_idx, inner_col_idx), selector) in selectors.iter() {
let inputs = match base_op.input_phases {
InputPhases::FirstPhase => vec![inputs[0][0]],
InputPhases::SecondPhase => vec![inputs[1][0]],
InputPhases::BothFirstPhase => vec![inputs[0][0], inputs[0][1]],
InputPhases::Mixed => vec![inputs[0][0], inputs[1][0]],
InputPhases::BothSecondPhase => vec![inputs[1][0], inputs[1][1]],
};
let output = match base_op.input_phases {
InputPhases::FirstPhase => outputs[0],
InputPhases::SecondPhase => outputs[1],
InputPhases::BothFirstPhase => outputs[0],
InputPhases::Mixed => outputs[1],
InputPhases::BothSecondPhase => outputs[1],
};
assert_eq!(inputs.len(), base_op.op_kind.num_inputs());
match base_op.op_kind {
BaseOp::Mult => {
meta.create_gate(base_op.op_kind.as_str(), |meta| {
let selector = meta.query_selector(*selector);
let zero = Expression::<F>::Constant(F::ZERO);
let mut qis = vec![zero; 2];
for (q_i, input) in qis.iter_mut().zip(inputs) {
*q_i = input
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("contraction config: input query failed")[0]
.clone()
}
// Get output expressions for each input channel
let (rotation_offset, rng) = base_op.op_kind.query_offset_rng();
let constraints = {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
.expect("contraction config: output query failed");
let res = base_op.op_kind.nonaccum_f((qis[0].clone(), qis[1].clone()));
vec![expected_output[base_op.op_kind.constraint_idx()].clone() - res]
};
Constraints::with_selector(selector, constraints)
});
}
_ => {
meta.create_gate(base_op.op_kind.as_str(), |meta| {
let selector = meta.query_selector(*selector);
let mut qis = vec![vec![]; 2];
for (q_i, input) in qis.iter_mut().zip(inputs) {
*q_i = input
.query_whole_block(meta, *block_idx, 0, 1)
.expect("contraction config: input query failed")
.into_iter()
.collect()
}
// Get output expressions for each input channel
let (rotation_offset, rng) = base_op.op_kind.query_offset_rng();
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, 0, rotation_offset, rng)
.expect("contraction config: output query failed");
let res = base_op.op_kind.accum_f(
expected_output[0].clone(),
qis[1].clone(),
qis[0].clone(),
);
let constraints =
vec![expected_output[base_op.op_kind.constraint_idx()].clone() - res];
Constraints::with_selector(selector, constraints)
});
}
}
}
let first_phase_inputs: [VarTensor; 2] = inputs[0]
.iter()
.copied()
.cloned()
.collect_vec()
.try_into()
.unwrap();
let second_phase_inputs: [VarTensor; 2] = inputs[1]
.iter()
.copied()
.cloned()
.collect_vec()
.try_into()
.unwrap();
Self {
inputs: [first_phase_inputs, second_phase_inputs],
outputs: [outputs[0].clone(), outputs[1].clone()],
selectors,
_marker: PhantomData,
}
}
}
/// `RLCConfig` is the custom gate used for random linear combination with the specific challenge
#[derive(Clone, Debug)]
struct RLCConfig<F: PrimeField + TensorType + PartialOrd> {
/// The challenge used for the random linear combination
/// Challenge is `None` in the dummy configuration
pub challenge: Option<Challenge>,
/// [first phase, second phase]
pub inputs: [VarTensor; 2],
pub output: VarTensor,
/// (phase of input, block index) -> (init selector, acc selector)
pub selectors: BTreeMap<(usize, usize), (Selector, Selector)>,
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RLCConfig<F> {
fn dummy(dummy_var: &VarTensor) -> Self {
let challenge = None;
let inputs = [dummy_var.clone(), dummy_var.clone()];
let output = dummy_var.clone();
let selectors = BTreeMap::new();
Self {
challenge,
inputs,
output,
selectors,
_marker: PhantomData,
}
}
fn new(meta: &mut ConstraintSystem<F>, inputs: &[VarTensor; 2], output: &VarTensor) -> Self {
let challenge = meta.challenge_usable_after(FirstPhase);
let mut selectors = BTreeMap::new();
for (phase, input) in inputs.iter().enumerate() {
for block_idx in 0..input.num_blocks() {
let selector = (meta.selector(), meta.selector());
selectors.insert((phase, block_idx), selector);
}
}
let block_width = output.num_inner_cols();
let powers_of_challenge = (0..block_width)
.scan(Expression::Constant(F::ONE), |r_power, _| {
*r_power = r_power.clone() * challenge.expr();
Some(r_power.clone())
})
.collect_vec();
for ((phase, block_idx), (init_selector, acc_selector)) in selectors.iter() {
meta.create_gate("init", |meta| {
let selector = meta.query_selector(*init_selector);
let input_exprs = inputs[*phase]
.query_whole_block(meta, *block_idx, 0, 1)
.expect("rlc config: input query failed")
.into_iter()
.collect();
let constraints = {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, 0, 0, 1)
.expect("rlc config: output query failed");
let res = BaseOp::Dot.accum_f(
Expression::Constant(F::ZERO),
powers_of_challenge.iter().cloned().rev().collect_vec(),
input_exprs,
);
vec![expected_output[0].clone() - res]
};
Constraints::with_selector(selector, constraints)
});
meta.create_gate("acc", |meta| {
let selector = meta.query_selector(*acc_selector);
let input_exprs = inputs[*phase]
.query_whole_block(meta, *block_idx, 0, 1)
.expect("rlc config: input query failed")
.into_iter()
.collect();
let constraints = {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, 0, -1, 2)
.expect("rlc config: output query failed");
let res = BaseOp::Dot.accum_f(
expected_output[0].clone() * powers_of_challenge.last().cloned().unwrap(),
powers_of_challenge.iter().cloned().rev().collect_vec(),
input_exprs,
);
vec![expected_output[1].clone() - res]
};
Constraints::with_selector(selector, constraints)
});
}
Self {
inputs: inputs.clone(),
output: output.clone(),
selectors,
challenge: Some(challenge),
_marker: PhantomData,
}
}
fn assign_rlc(
&self,
region: &mut RegionCtx<F>,
flattened_input: &ValTensor<F>,
challenge: Value<F>,
rlc_len: usize,
phase: usize,
check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
region.flush_einsum()?;
let block_width = self.output.num_inner_cols();
let powers_of_challenge = (0..block_width)
.scan(Value::known(F::ONE), |challenge_power, _| {
*challenge_power = challenge_power.clone() * challenge;
Some(challenge_power.clone())
})
.collect_vec();
let mut rlc_results: Vec<ValType<F>> = vec![];
for tensor in flattened_input.get_inner_tensor()?.chunks_exact(rlc_len) {
let running_sums = tensor
.iter()
.chunks(block_width)
.into_iter()
.scan(Value::known(F::ZERO), |state, val| {
let curr_sum: Value<F> = val
.into_iter()
.zip(powers_of_challenge.iter().rev())
.map(|(v, c_power)| {
c_power.and_then(|c_power| {
v.get_felt_eval()
.map(|v| Value::known(c_power * v))
.unwrap_or(Value::unknown())
})
})
.reduce(|acc, v| acc + v)
.unwrap();
*state = *state * powers_of_challenge.last().unwrap() + curr_sum;
Some(*state)
})
.collect_vec();
let assigned_len = {
let mut input: ValTensor<F> = tensor.iter().collect_vec().into();
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let (_, len) = region
.assign_einsum_with_duplication_unconstrained(&self.inputs[phase], &input)?;
len
};
let (assigned_output, assigned_output_len) = {
let running_sums = running_sums.into_iter().map(ValType::from).collect_vec();
region.assign_einsum_with_duplication_constrained(
&self.output,
&running_sums.into(),
check_mode,
)?
};
(0..assigned_output_len)
.map(|i| {
let (block_idx, _, z) = self
.output
.cartesian_coord(region.einsum_col_coord() + i * block_width);
if z == 0 && i > 0 {
return Ok(());
}
let selector = if i == 0 {
self.selectors
.get(&(phase, block_idx))
.map(|(init, _)| init)
} else {
self.selectors.get(&(phase, block_idx)).map(|(_, acc)| acc)
};
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, CircuitError>>()?;
rlc_results.push(assigned_output.last()?.get_inner_tensor()?.get_scalar());
region.increment_einsum_col_coord(assigned_len);
}
Ok(rlc_results.into())
}
}
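// Worked example (illustrative): with block width 2 and challenge c, a chunk
// [x0, x1, x2, x3] folds in two rows:
// row 0 (init gate): s0 = c^2 * x0 + c * x1
// row 1 (acc gate):  s1 = c^2 * s0 + c^2 * x2 + c * x3
// so s1 = c^4 * x0 + c^3 * x1 + c^2 * x2 + c * x3, the RLC of the chunk with
// descending powers of c.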

View File

@@ -0,0 +1,205 @@
use std::{collections::BTreeSet, ops::Index};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use crate::{
circuit::CircuitError,
tensor::{TensorType, ValTensor},
};
/// Example reduction plans (bracketed: the pool of remaining tensors after each step):
///
/// inj,jk->ik [inj,jk]
/// inj,i->nj => RLC [jk,nj]
/// jk,k->j => RLC [nj,j]
/// nj,j->n => Contraction [n]
/// n-> => Contraction []
///
/// bn,anm,bm->ba [bn,anm,bm]
/// bn,bm->bnm => Contraction [anm,bnm]
/// bnm,b->nm => RLC [anm,nm]
/// anm,a->nm => RLC [nm,nm]
/// nm,nm->m => Contraction [m]
/// m-> => Contraction []
#[derive(Debug)]
pub enum Reduction {
/// Random linear combination with powers of challenge along the axis
RLC {
expression: String,
axis: char,
/// Uniquely identifying index of input tensor to be reduced
input_index: TensorIndex,
/// phase of input tensor
input_phase: usize,
/// phase of output tensor
output_phase: usize,
challenge_index: usize,
},
Contraction {
expression: String,
/// when axis is `None`, the contraction is pairwise multiplication
axis: Option<char>,
/// Uniquely identifying indices of input tensors to be contracted
input_indices: Vec<TensorIndex>,
/// phases of input tensors
input_phases: Vec<usize>,
/// phase of output tensor
output_phase: usize,
},
}
#[derive(Clone, Copy, Debug)]
pub struct TensorIndex(usize);
impl<T: PrimeField + TensorType + PartialOrd> Index<TensorIndex> for Vec<ValTensor<T>> {
type Output = ValTensor<T>;
fn index(&self, index: TensorIndex) -> &Self::Output {
&self[index.0]
}
}
impl Reduction {
pub fn expression(&self) -> &str {
match self {
Reduction::Contraction { expression, .. } => expression,
Reduction::RLC { expression, .. } => expression,
}
}
pub fn input_indices(&self) -> Vec<TensorIndex> {
match self {
Reduction::Contraction { input_indices, .. } => input_indices.clone(),
Reduction::RLC { input_index, .. } => vec![*input_index],
}
}
pub fn output_phase(&self) -> usize {
match self {
Reduction::Contraction { output_phase, .. } => *output_phase,
Reduction::RLC { output_phase, .. } => *output_phase,
}
}
}
pub fn input_reductions(expression: &str) -> Result<Vec<Reduction>, CircuitError> {
let (input_exprs, output_expr) = expression.split_once("->").unwrap();
let input_exprs: Vec<_> = input_exprs.split(",").map(|eq| eq.to_string()).collect();
// (phase, expression)
let input_exprs: Vec<(usize, String)> =
input_exprs.into_iter().map(|expr| (0, expr)).collect_vec();
let mut input_tensor_counter = input_exprs.len();
let mut input_exprs: Vec<((usize, String), TensorIndex)> = input_exprs
.into_iter()
.zip((0..input_tensor_counter).map(TensorIndex))
.collect();
let mut reductions: Vec<Reduction> = vec![];
// Reduce input_exprs along given axis
let mut reduce = |input_exprs: Vec<((usize, String), TensorIndex)>,
axis: char|
-> (Reduction, Vec<((usize, String), TensorIndex)>) {
let inputs = input_exprs
.iter()
.filter(|((_, eq), _)| eq.chars().contains(&axis))
.cloned()
.collect_vec();
let (inputs_axes, input_indices): (Vec<(usize, String)>, Vec<TensorIndex>) =
inputs.iter().cloned().unzip();
let (input_phases, inputs_axes): (Vec<usize>, Vec<String>) =
inputs_axes.into_iter().unzip();
let is_output_axis = output_expr.chars().contains(&axis);
let output: String = if is_output_axis && inputs.len() > 1 {
let output: BTreeSet<char> =
inputs_axes.iter().flat_map(|input| input.chars()).collect();
output.iter().collect()
} else {
let output: BTreeSet<char> = inputs_axes
.iter()
.flat_map(|input| input.chars().filter(|&c| c != axis))
.collect();
output.iter().collect()
};
let reduction = if is_output_axis && inputs.len() == 1 {
let mut expression = inputs_axes.join(",");
expression.push_str(format!(",{axis}").as_str());
expression.push_str("->");
expression.push_str(&output);
Reduction::RLC {
expression,
axis,
input_index: input_indices[0],
input_phase: input_phases[0],
output_phase: 1,
challenge_index: output_expr.chars().position(|c| c == axis).unwrap(),
}
} else if is_output_axis {
let mut expression = inputs_axes.join(",");
let output_phase = input_phases.iter().copied().max().unwrap();
expression.push_str("->");
expression.push_str(&output);
Reduction::Contraction {
expression,
axis: None,
input_indices,
input_phases,
output_phase,
}
} else {
let mut expression = inputs_axes.join(",");
let output_phase = input_phases.iter().copied().max().unwrap();
expression.push_str("->");
expression.push_str(&output);
Reduction::Contraction {
expression,
axis: Some(axis),
input_indices,
input_phases,
output_phase,
}
};
// Mutate input_exprs
let mut input_exprs = input_exprs;
input_exprs.retain(|((_, input_eq), _)| !inputs_axes.contains(input_eq));
input_exprs.push((
(reduction.output_phase(), output.clone()),
TensorIndex(input_tensor_counter),
));
input_tensor_counter += 1;
(reduction, input_exprs)
};
let mut output_axes = output_expr.chars().collect_vec();
while let Some(axis) = output_axes.first().cloned() {
let num_inputs = input_exprs
.iter()
.filter(|((_, eq), _)| eq.chars().contains(&axis))
.count();
if num_inputs == 0 {
output_axes.remove(0);
} else {
let (reduction, new_input_exprs) = reduce(input_exprs, axis);
reductions.push(reduction);
input_exprs = new_input_exprs;
}
}
// These are not output axes and were not contracted with random vectors
let remaining_axes: BTreeSet<_> = input_exprs
.iter()
.flat_map(|((_, eq), _)| eq.chars())
.collect();
for axis in remaining_axes.iter() {
let (reduction, new_input_exprs) = reduce(input_exprs, *axis);
reductions.push(reduction);
input_exprs = new_input_exprs;
}
Ok(reductions)
}
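// Trace (illustrative) for "ij,jk->ik":
// axis 'i' (output axis, 1 input) -> RLC ij,i->j (challenge 0)
// axis 'k' (output axis, 1 input) -> RLC jk,k->j (challenge 1)
// axis 'j' (shared, 2 inputs)     -> Contraction j,j->
// mirroring the `inj,jk->ik` walkthrough documented on `Reduction` above.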

View File

@@ -46,6 +46,9 @@ pub enum CircuitError {
/// Failed to get shuffle
#[error("failed to get shuffle for op: {0}")]
GetShuffleError(String),
/// Failed to get einsum
#[error("failed to get einsum for op: {0}")]
GetEinsumError(String),
/// Failed to get constants
#[error("failed to get constants for op: {0}")]
GetConstantsError(String),
@@ -61,6 +64,9 @@ pub enum CircuitError {
/// Missing product in einsum
#[error("missing product in einsum")]
MissingEinsumProduct,
/// Missing config in einsum
#[error("missing config in einsum")]
MissingEinsumConfig,
/// Mismatched lookup length
#[error("mismatched lookup lengths: {0} and {1}")]
MismatchedLookupLength(usize, usize),
@@ -109,4 +115,7 @@ pub enum CircuitError {
/// A decomposition base overflowed
#[error("decomposition base overflowed")]
DecompositionBaseOverflow,
/// Challenge not set
#[error("challenge not set")]
ChallengeNotSet,
}

View File

@@ -22,9 +22,8 @@ use crate::{
tensor::{
create_unit_tensor, get_broadcasted_shape,
ops::{accumulated, add, mult, sub},
Tensor, TensorError, ValType,
DataFormat, KernelFormat, Tensor, TensorError, ValType,
},
tensor::{DataFormat, KernelFormat},
};
use super::*;
@@ -823,14 +822,73 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// let result = einsum::<Fp>(&dummy_config, &mut dummy_region, &[&x, &k], "mk,n->ma").unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1773]), &[1, 1]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
///
/// ```
///
pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
inputs: &[&ValTensor<F>],
equation: &str,
) -> Result<ValTensor<F>, CircuitError> {
let mut eq = equation.split("->");
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
// Check that the number of inputs matches the number of inputs in the equation
if inputs.len() != inputs_eq.len() {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
let mut indices_to_size = HashMap::new();
for (i, input) in inputs.iter().enumerate() {
for j in 0..inputs_eq[i].len() {
let c = inputs_eq[i]
.chars()
.nth(j)
.ok_or(CircuitError::InvalidEinsum)?;
if let std::collections::hash_map::Entry::Vacant(e) = indices_to_size.entry(c) {
e.insert(input.dims()[j]);
} else if indices_to_size[&c] != input.dims()[j] {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
}
}
// Track the einsum equation
region.add_used_einsum_equation(equation.to_string(), &indices_to_size)?;
if config.einsums.is_none() {
return einsum_with_base_ops(config, region, inputs, equation);
}
let input_values = inputs
.iter()
.map(|t| t.get_inner())
.collect::<Result<Vec<_>, TensorError>>()?;
let (output_tensor, _) =
crate::tensor::ops::accumulated::einsum(equation, &input_values.iter().collect_vec())?;
config.einsums.as_ref().unwrap().assign_einsum(
config,
region,
inputs,
&output_tensor.clone().into(),
equation,
&config.check_mode,
)?;
let output: ValTensor<F> = output_tensor.into();
region.increment_einsum_index(1);
Ok(output)
}
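// In short (summarising the wiring above): the claimed output is computed
// out-of-circuit, both sides are squashed to scalars with the same verifier
// challenges inside `assign_einsum`, and only the scalar equality is constrained --
// Freivalds' check, sound with overwhelming probability over the challenges.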
/// Einsum implemented directly with base ops (no Freivalds' argument)
pub fn einsum_with_base_ops<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
inputs: &[&ValTensor<F>],
equation: &str,
) -> Result<ValTensor<F>, CircuitError> {
let mut equation = equation.split("->");
let inputs_eq = equation.next().ok_or(CircuitError::InvalidEinsum)?;
@@ -1036,6 +1094,8 @@ pub fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let output: ValTensor<F> = output.into();
region.increment_einsum_index(1);
Ok(output)
}
@@ -6169,9 +6229,9 @@ pub(crate) fn recompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has
(0..num_first_dims)
.flat_map(|_| {
(0..n).rev().map(|x| {
let base = (*base).checked_pow(x as u32);
let base = (*base as IntegerRep).checked_pow(x as u32);
if let Some(base) = base {
Ok(ValType::Constant(integer_rep_to_felt(base as IntegerRep)))
Ok(ValType::Constant(integer_rep_to_felt(base)))
} else {
Err(CircuitError::DecompositionBaseOverflow)
}
@@ -6281,9 +6341,9 @@ pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Has
(0..input.len())
.flat_map(|_| {
(0..*n).rev().map(|x| {
let base = (*base).checked_pow(x as u32);
let base = (*base as IntegerRep).checked_pow(x as u32);
if let Some(base) = base {
Ok(ValType::Constant(integer_rep_to_felt(base as IntegerRep)))
Ok(ValType::Constant(integer_rep_to_felt(base)))
} else {
Err(CircuitError::DecompositionBaseOverflow)
}

View File

@@ -1,12 +1,12 @@
use crate::{
circuit::table::Range,
circuit::{einsum::NUM_MAX_EINSUM_CHALLENGES, table::Range},
fieldutils::IntegerRep,
tensor::{Tensor, TensorType, ValTensor, ValType, VarTensor},
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use colored::Colorize;
use halo2_proofs::{
circuit::Region,
circuit::{Region, Value},
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
@@ -85,6 +85,45 @@ impl ShuffleIndex {
}
}
/// Einsum index
#[derive(Clone, Debug, Default)]
pub struct EinsumIndex {
index: usize,
col_coord: usize,
// (einsum equation, input axes to dimensions map)
equations: Vec<(String, HashMap<char, usize>)>,
num_inner_cols: usize,
}
impl EinsumIndex {
/// Create a new einsum index
pub fn new(index: usize, col_coord: usize, num_inner_cols: usize) -> EinsumIndex {
EinsumIndex {
index,
col_coord,
equations: Vec::new(),
num_inner_cols,
}
}
/// Get the einsum index
pub fn index(&self) -> usize {
self.index
}
/// Get the column coord
pub fn col_coord(&self) -> usize {
self.col_coord
}
/// Update with another einsum index
pub fn update(&mut self, other: &EinsumIndex) {
self.index += other.index;
self.col_coord += other.col_coord;
self.equations.extend(other.equations.clone());
}
}
#[derive(Debug, Clone)]
/// Some settings for a region to differentiate it across the different phases of proof generation
pub struct RegionSettings {
@@ -176,9 +215,11 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
num_inner_cols: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
einsum_index: EinsumIndex,
statistics: RegionStatistics,
settings: RegionSettings,
assigned_constants: ConstantsMap<F>,
challenges: Vec<Value<F>>,
max_dynamic_input_len: usize,
}
@@ -250,6 +291,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.shuffle_index.col_coord += n;
}
/// increment the einsum index
pub fn increment_einsum_index(&mut self, n: usize) {
self.einsum_index.index += n;
}
/// increment the einsum column coordinate
pub fn increment_einsum_col_coord(&mut self, n: usize) {
self.einsum_index.col_coord += n;
}
/// Whether witness generation is active for this region
pub fn witness_gen(&self) -> bool {
self.settings.witness_gen
@@ -265,6 +316,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&self.statistics
}
/// Get the challenge values available to this region
pub fn challenges(&self) -> &[Value<F>] {
&self.challenges
}
/// Create a new region context
pub fn new(
region: Region<'a, F>,
@@ -283,9 +339,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
linear_coord,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
einsum_index: EinsumIndex::default(),
statistics: RegionStatistics::default(),
settings: RegionSettings::all_true(decomp_base, decomp_legs),
assigned_constants: HashMap::new(),
challenges: vec![],
max_dynamic_input_len: 0,
}
}
@@ -304,6 +362,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
new_self
}
/// Create a new region context with challenges
pub fn new_with_challenges(
region: Region<'a, F>,
row: usize,
num_inner_cols: usize,
decomp_base: usize,
decomp_legs: usize,
challenges: Vec<Value<F>>,
) -> RegionCtx<'a, F> {
let mut new_self = Self::new(region, row, num_inner_cols, decomp_base, decomp_legs);
new_self.challenges = challenges;
new_self
}
/// Create a new region context
pub fn new_dummy(
row: usize,
@@ -320,9 +392,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
row,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
einsum_index: EinsumIndex::default(),
statistics: RegionStatistics::default(),
settings,
assigned_constants: HashMap::new(),
challenges: vec![Value::unknown(); NUM_MAX_EINSUM_CHALLENGES],
max_dynamic_input_len: 0,
}
}
@@ -333,6 +407,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
linear_coord: usize,
num_inner_cols: usize,
settings: RegionSettings,
challenges: Vec<Value<F>>,
) -> RegionCtx<'a, F> {
let region = None;
RegionCtx {
@@ -342,9 +417,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
row,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
einsum_index: EinsumIndex::default(),
statistics: RegionStatistics::default(),
settings,
assigned_constants: HashMap::new(),
challenges,
max_dynamic_input_len: 0,
}
}
@@ -398,6 +475,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
let statistics = Arc::new(Mutex::new(self.statistics.clone()));
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
let einsum_index = Arc::new(Mutex::new(self.einsum_index.clone()));
let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
*output = output.par_enum_map(|idx, _| {
@@ -412,6 +490,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
starting_linear_coord,
self.num_inner_cols,
self.settings.clone(),
self.challenges.clone(),
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
@@ -430,6 +509,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
// update the shuffle index
let mut shuffle_index = shuffle_index.lock().unwrap();
shuffle_index.update(&local_reg.shuffle_index);
// update the einsum index
let mut einsum_index = einsum_index.lock().unwrap();
einsum_index.update(&local_reg.einsum_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants);
@@ -450,6 +532,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
.map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?;
self.einsum_index = Arc::try_unwrap(einsum_index)
.map_err(|e| CircuitError::GetEinsumError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| CircuitError::GetEinsumError(format!("{:?}", e)))?;
self.assigned_constants = Arc::try_unwrap(constants)
.map_err(|e| CircuitError::GetConstantsError(format!("{:?}", e)))?
.into_inner()
@@ -516,6 +602,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.update_max_min_lookup_range(range)
}
/// add used einsum equation
pub fn add_used_einsum_equation(
&mut self,
equation: String,
input_axes_to_dims: &HashMap<char, usize>,
) -> Result<(), CircuitError> {
self.einsum_index
.equations
.push((equation, input_axes_to_dims.clone()));
Ok(())
}
/// Get the offset
pub fn row(&self) -> usize {
self.row
@@ -551,6 +649,31 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.shuffle_index.col_coord
}
/// einsum index
pub fn einsum_index(&self) -> usize {
self.einsum_index.index
}
/// einsum column coordinate
pub fn einsum_col_coord(&self) -> usize {
self.einsum_index.col_coord
}
/// get used einsum equations
pub fn used_einsum_equations(&self) -> Vec<(String, HashMap<char, usize>)> {
self.einsum_index.equations.clone()
}
/// set the number of inner columns used in einsum custom gate
pub fn set_num_einsum_inner_cols(&mut self, num_inner_cols: usize) {
self.einsum_index.num_inner_cols = num_inner_cols;
}
/// number of inner columns used in einsum custom gate
pub fn num_einsum_inner_cols(&self) -> usize {
self.einsum_index.num_inner_cols
}
/// get used lookups
pub fn used_lookups(&self) -> HashSet<LookupOp> {
self.statistics.used_lookups.clone()
@@ -640,6 +763,28 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.assign_dynamic_lookup(var, values)
}
/// Assign a valtensor to a vartensor in einsum area
pub fn assign_einsum(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
Ok(var.assign(
&mut region.borrow_mut(),
self.einsum_col_coord(),
values,
&mut self.assigned_constants,
)?)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.par_extend(values_map);
}
Ok(values.clone())
}
}
/// Assign a valtensor to a vartensor with duplication
pub fn assign_with_duplication_unconstrained(
&mut self,
@@ -697,6 +842,63 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
}
/// Assign a valtensor to a vartensor with duplication
pub fn assign_einsum_with_duplication_unconstrained(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<(ValTensor<F>, usize), Error> {
if let Some(region) = &self.region {
// duplicates every nth element to adjust for column overflow
let (res, len) = var.assign_with_duplication_unconstrained(
&mut region.borrow_mut(),
self.einsum_col_coord(),
values,
&mut self.assigned_constants,
)?;
Ok((res, len))
} else {
let (_, len) = var.dummy_assign_with_duplication(
self.row,
self.einsum_col_coord(),
values,
false,
&mut self.assigned_constants,
)?;
Ok((values.clone(), len))
}
}
/// Assign a valtensor to a vartensor with duplication
pub fn assign_einsum_with_duplication_constrained(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
check_mode: &crate::circuit::CheckMode,
) -> Result<(ValTensor<F>, usize), Error> {
if let Some(region) = &self.region {
// duplicates every nth element to adjust for column overflow
let (res, len) = var.assign_with_duplication_constrained(
&mut region.borrow_mut(),
self.row,
self.einsum_col_coord(),
values,
check_mode,
&mut self.assigned_constants,
)?;
Ok((res, len))
} else {
let (_, len) = var.dummy_assign_with_duplication(
self.row,
self.einsum_col_coord(),
values,
true,
&mut self.assigned_constants,
)?;
Ok((values.clone(), len))
}
}
/// Enable a selector
pub fn enable(&mut self, selector: Option<&Selector>, offset: usize) -> Result<(), Error> {
match &self.region {
@@ -763,4 +965,19 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
Ok(())
}
/// flush row to the next row in einsum area
pub fn flush_einsum(&mut self) -> Result<(), CircuitError> {
// round the einsum column coordinate up to the next row boundary
let num_einsum_inner_cols = self.num_einsum_inner_cols();
let remainder = self.einsum_col_coord() % num_einsum_inner_cols;
if remainder != 0 {
let diff = num_einsum_inner_cols - remainder;
self.increment_einsum_col_coord(diff);
}
if self.einsum_col_coord() % num_einsum_inner_cols != 0 {
return Err(CircuitError::FlushError);
}
Ok(())
}
}
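`flush_einsum` is plain ceiling-to-a-multiple arithmetic on the einsum column coordinate; a small sketch with hypothetical numbers:

// col_coord = 10 with 4 inner columns: remainder 2, so pad by 2 to reach the
// next row boundary at 12; an already-aligned coordinate is left untouched.
fn flushed(col_coord: usize, num_inner_cols: usize) -> usize {
    let rem = col_coord % num_inner_cols;
    if rem == 0 { col_coord } else { col_coord + (num_inner_cols - rem) }
}

fn main() {
    assert_eq!(flushed(10, 4), 12);
    assert_eq!(flushed(12, 4), 12);
}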

View File

@@ -359,8 +359,6 @@ mod matmul_col_ultra_overflow_double_col {
&pk,
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
crate::Commitments::KZG,
crate::pfsys::TranscriptType::EVM,
None,
None,
);
@@ -480,8 +478,6 @@ mod matmul_col_ultra_overflow {
&pk,
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
crate::Commitments::KZG,
crate::pfsys::TranscriptType::EVM,
None,
None,
);
@@ -1298,8 +1294,6 @@ mod conv_col_ultra_overflow {
&pk,
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
crate::Commitments::KZG,
crate::pfsys::TranscriptType::EVM,
None,
None,
);
@@ -1467,8 +1461,6 @@ mod conv_relu_col_ultra_overflow {
&params,
&pk,
CheckMode::SAFE,
crate::Commitments::KZG,
crate::pfsys::TranscriptType::EVM,
// use safe mode to verify that the proof is correct
None,
None,
@@ -2643,8 +2635,6 @@ mod lookup_ultra_overflow {
&pk,
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
crate::Commitments::KZG,
crate::pfsys::TranscriptType::EVM,
None,
None,
);

View File

@@ -9,10 +9,9 @@ use std::path::PathBuf;
use std::str::FromStr;
use tosubcommand::{ToFlags, ToSubcommand};
use crate::{pfsys::ProofType, Commitments, RunArgs};
use crate::RunArgs;
use crate::circuit::CheckMode;
use crate::pfsys::TranscriptType;
/// The default path to the .json data file
pub const DEFAULT_DATA: &str = "input.json";
@@ -28,26 +27,16 @@ pub const DEFAULT_SETTINGS: &str = "settings.json";
pub const DEFAULT_PK: &str = "pk.key";
/// The default path to the verification key file
pub const DEFAULT_VK: &str = "vk.key";
/// The default path to the proving key file for aggregated proofs
pub const DEFAULT_PK_AGGREGATED: &str = "pk_aggr.key";
/// The default path to the verification key file for aggregated proofs
pub const DEFAULT_VK_AGGREGATED: &str = "vk_aggr.key";
/// The default path to the proof file
pub const DEFAULT_PROOF: &str = "proof.json";
/// The default path to the proof file for aggregated proofs
pub const DEFAULT_PROOF_AGGREGATED: &str = "proof_aggr.json";
/// Default for whether to split proofs
pub const DEFAULT_SPLIT: &str = "false";
/// Default verifier abi
pub const DEFAULT_VERIFIER_ABI: &str = "verifier_abi.json";
/// Default verifier abi for aggregated proofs
pub const DEFAULT_VERIFIER_AGGREGATED_ABI: &str = "verifier_aggr_abi.json";
/// Default solidity code
pub const DEFAULT_SOL_CODE: &str = "evm_deploy.sol";
/// Default calldata path
pub const DEFAULT_CALLDATA: &str = "calldata.bytes";
/// Default solidity code for aggregated proofs
pub const DEFAULT_SOL_CODE_AGGREGATED: &str = "evm_deploy_aggr.sol";
/// Default contract address
pub const DEFAULT_CONTRACT_ADDRESS: &str = "contract.address";
/// Default contract address for vk
@@ -56,8 +45,6 @@ pub const DEFAULT_CONTRACT_ADDRESS_VK: &str = "contract_vk.address";
pub const DEFAULT_CHECKMODE: &str = "safe";
/// Default calibration target
pub const DEFAULT_CALIBRATION_TARGET: &str = "resources";
/// Default logrows for aggregated proofs
pub const DEFAULT_AGGREGATED_LOGROWS: &str = "23";
/// Default optimizer runs
pub const DEFAULT_OPTIMIZER_RUNS: &str = "1";
/// Default fuzz runs
@@ -91,35 +78,6 @@ pub const DEFAULT_DECIMALS: &str = "18";
/// Default path for the vka digest file
pub const DEFAULT_VKA_DIGEST: &str = "vka.digest";
#[cfg(feature = "python-bindings")]
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
impl<'py> IntoPyObject<'py> for TranscriptType {
type Target = pyo3::PyAny;
type Output = pyo3::Bound<'py, Self::Target>;
type Error = pyo3::PyErr;
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
let result = match self {
TranscriptType::Poseidon => "poseidon",
TranscriptType::EVM => "evm",
};
Ok(result.into_pyobject(py)?.into_any())
}
}
#[cfg(feature = "python-bindings")]
/// Obtains TranscriptType from PyObject (Required for TranscriptType to be compatible with Python)
impl<'source> FromPyObject<'source> for TranscriptType {
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
let trystr = String::extract_bound(ob)?;
let strval = trystr.to_string();
match strval.to_lowercase().as_str() {
"poseidon" => Ok(TranscriptType::Poseidon),
"evm" => Ok(TranscriptType::EVM),
_ => Err(PyValueError::new_err("Invalid value for TranscriptType")),
}
}
}
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
/// Determines what the calibration pass should optimize for
pub enum CalibrationTarget {
@@ -187,7 +145,6 @@ pub enum ContractType {
/// Deploys a verifier contract tailored to the circuit and not reusable
Verifier {
/// Whether to deploy a reusable verifier. This can reduce state bloat on-chain since you need only deploy a verifying key artifact (vka) for a given circuit which is significantly smaller than the verifier contract (up to 4 times smaller for large circuits)
/// Can also be used as an alternative to aggregation for verifiers that are otherwise too large to fit on-chain.
reusable: bool,
},
}
@@ -535,9 +492,6 @@ pub enum Commands {
/// number of logrows to use for srs
#[arg(long, value_hint = clap::ValueHint::Other)]
logrows: usize,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
/// Gets an SRS from a circuit settings file.
@@ -552,9 +506,6 @@ pub enum Commands {
/// Number of logrows to use for srs. Overrides settings_path if specified.
#[arg(long, default_value = None, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// Commitment used
#[arg(long, default_value = None, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
/// Loads model and input and runs mock prover (for testing)
Mock {
@@ -566,82 +517,6 @@ pub enum Commands {
model: Option<PathBuf>,
},
/// Mock aggregate proofs
MockAggregate {
/// The path to the snarks to aggregate over (generated using the prove command with the --proof-type=for-aggr flag)
#[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
aggregation_snarks: Vec<PathBuf>,
/// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// whether the accumulated are segments of a larger proof
#[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
split_proofs: Option<bool>,
},
/// Setup aggregation circuit and generate pk and vk
SetupAggregate {
/// The path to samples of snarks that will be aggregated over (generated using the prove command with the --proof-type=for-aggr flag)
#[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
sample_snarks: Vec<PathBuf>,
/// The path to save the desired verification key file to
#[arg(long, default_value = DEFAULT_VK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to save the proving key to
#[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
pk_path: Option<PathBuf>,
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// whether the accumulated are segments of a larger proof
#[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
split_proofs: Option<bool>,
/// compress selectors
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)]
disable_selector_compression: Option<bool>,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
/// Aggregates proofs
Aggregate {
/// The path to the snarks to aggregate over (generated using the prove command with the --proof-type=for-aggr flag)
#[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
aggregation_snarks: Vec<PathBuf>,
/// The path to load the desired proving key file (generated using the setup-aggregate command)
#[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
pk_path: Option<PathBuf>,
/// The path to output the proof file to
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
#[arg(
long,
require_equals = true,
num_args = 0..=1,
default_value_t = TranscriptType::default(),
value_enum,
value_hint = clap::ValueHint::Other
)]
transcript: TranscriptType,
/// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// run sanity checks during calculations (safe or unsafe)
#[arg(long, default_value = DEFAULT_CHECKMODE, value_hint = clap::ValueHint::Other)]
check_mode: Option<CheckMode>,
/// whether the accumulated proofs are segments of a larger circuit
#[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
split_proofs: Option<bool>,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
/// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements
CompileCircuit {
/// The path to the .onnx model file
@@ -702,15 +577,6 @@ pub enum Commands {
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
#[arg(
long,
require_equals = true,
num_args = 0..=1,
default_value_t = ProofType::Single,
value_enum,
value_hint = clap::ValueHint::Other
)]
proof_type: ProofType,
/// run sanity checks during calculations (safe or unsafe)
#[arg(long, default_value = DEFAULT_CHECKMODE, value_hint = clap::ValueHint::Other)]
check_mode: Option<CheckMode>,
@@ -778,32 +644,6 @@ pub enum Commands {
decimals: Option<usize>,
},
/// Creates an Evm verifier for an aggregate proof
#[command(name = "create-evm-verifier-aggr")]
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
CreateEvmVerifierAggr {
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load the desired verification key file
#[arg(long, default_value = DEFAULT_VK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VERIFIER_AGGREGATED_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
// aggregated circuit settings paths, used to calculate the number of instances in the aggregate proof
#[arg(long, default_value = DEFAULT_SETTINGS, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
aggregation_settings: Vec<PathBuf>,
// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// Whether to render the verifier as reusable or not. If true, you will need to deploy a VK artifact, passing it as part of the calldata to the verifier.
#[cfg_attr(all(feature = "reusable-verifier", not(target_arch = "wasm32")), arg(short = 'R', long, action = clap::ArgAction::SetTrue))]
reusable: Option<bool>,
},
/// Verifies a proof, returning accept or reject
Verify {
/// The path to load circuit settings .json file from (generated using the gen-settings command)
@@ -822,27 +662,7 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
reduced_srs: Option<bool>,
},
/// Verifies an aggregate proof, returning accept or reject
VerifyAggr {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to the verification key file (generated using the setup-aggregate command)
#[arg(long, default_value = DEFAULT_VK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// reduced srs
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
reduced_srs: Option<bool>,
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// commitment
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
/// Deploys an evm contract (verifier, reusable verifier, or vk artifact) that is generated by ezkl
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
DeployEvm {

View File

@@ -1,4 +1,3 @@
use crate::pfsys::evm::EvmVerificationError;
use crate::pfsys::{encode_calldata, Snark};
use alloy::contract::CallBuilder;
use alloy::core::primitives::Address as H160;
@@ -57,8 +56,6 @@ pub enum EthError {
Wallet(#[from] WalletError),
#[error("failed to parse url {0}")]
UrlParse(String),
#[error("evm verification error: {0}")]
EvmVerification(#[from] EvmVerificationError),
#[error("Private key must be in hex format, 64 chars, without 0x prefix")]
PrivateKeyFormat,
#[error("failed to parse hex: {0}")]
@@ -100,6 +97,8 @@ pub enum EthError {
VkaData(String),
#[error("rescaledinstance mismatch: {0}")]
RescaleCheckError(#[from] RescaleCheckError),
#[error("evm verification error: {0}")]
EvmVerificationError(String),
}
pub type EthersClient = Arc<
@@ -198,7 +197,7 @@ pub async fn register_vka_via_rv(
let result = client.call(&tx).await;
if let Err(e) = result {
return Err(EvmVerificationError::SolidityExecution(e.to_string()).into());
return Err(EthError::EvmVerificationError(e.to_string()).into());
}
let result = result?;
debug!("result: {:#?}", result.to_vec());
@@ -270,7 +269,7 @@ pub async fn verify_proof_via_solidity(
let result = client.call(&tx).await;
if let Err(e) = result {
return Err(EvmVerificationError::SolidityExecution(e.to_string()).into());
return Err(EthError::EvmVerificationError(e.to_string()).into());
}
let result = result?;
debug!("result: {:#?}", result.to_vec());
@@ -306,7 +305,7 @@ pub async fn verify_proof_via_solidity(
.ok_or(EthError::NoContractOutput)?
== &1u8;
if !result {
return Err(EvmVerificationError::InvalidProof.into());
return Err(EthError::EvmVerificationError("Invalid proof".into()));
}
let gas = client.estimate_gas(&tx).await?;

File diff suppressed because it is too large

View File

@@ -64,6 +64,7 @@ use pyo3::types::PyDictMethods;
use pyo3::IntoPyObject;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::ops::Deref;
pub use utilities::*;
pub use vars::*;
@@ -438,6 +439,15 @@ pub struct ShuffleParams {
pub total_shuffle_col_size: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
/// Parameters for einsum operations
pub struct EinsumParams {
/// einsum equations
pub equations: Vec<(String, HashMap<char, usize>)>,
/// total einsum column size
pub total_einsum_col_size: usize,
}
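For a single matmul, a hypothetical `EinsumParams` entry would pair the equation with each input axis's size (dimensions invented for illustration):

use std::collections::HashMap;

fn main() {
    // "ij,jk->ik" with i=64, j=128, k=32.
    let dims: HashMap<char, usize> = HashMap::from([('i', 64), ('j', 128), ('k', 32)]);
    let equations: Vec<(String, HashMap<char, usize>)> =
        vec![("ij,jk->ik".to_string(), dims)];
    assert_eq!(equations[0].1[&'j'], 128);
}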
/// model parameters
#[derive(Clone, Debug, Default, PartialEq)]
pub struct GraphSettings {
@@ -453,6 +463,8 @@ pub struct GraphSettings {
pub dynamic_lookup_params: DynamicLookupParams,
/// shuffle parameters, flattened for backwards compatibility
pub shuffle_params: ShuffleParams,
/// einsum parameters
pub einsum_params: EinsumParams,
/// the shape of public inputs to the model (in order of appearance)
pub model_instance_shapes: Vec<Vec<usize>>,
/// model output scales
@@ -487,7 +499,7 @@ impl Serialize for GraphSettings {
if serializer.is_human_readable() {
// JSON format - use flattened fields for backwards compatibility
use serde::ser::SerializeStruct;
let mut state = serializer.serialize_struct("GraphSettings", 21)?;
let mut state = serializer.serialize_struct("GraphSettings", 22)?;
state.serialize_field("run_args", &self.run_args)?;
state.serialize_field("num_rows", &self.num_rows)?;
state.serialize_field("total_assignments", &self.total_assignments)?;
@@ -514,6 +526,9 @@ impl Serialize for GraphSettings {
&self.shuffle_params.total_shuffle_col_size,
)?;
// Serialize EinsumParams
state.serialize_field("einsum_params", &self.einsum_params)?;
state.serialize_field("model_instance_shapes", &self.model_instance_shapes)?;
state.serialize_field("model_output_scales", &self.model_output_scales)?;
state.serialize_field("model_input_scales", &self.model_input_scales)?;
@@ -530,13 +545,14 @@ impl Serialize for GraphSettings {
} else {
// Binary format (bincode) - use nested struct format
use serde::ser::SerializeTuple;
let mut state = serializer.serialize_tuple(18)?;
let mut state = serializer.serialize_tuple(19)?;
state.serialize_element(&self.run_args)?;
state.serialize_element(&self.num_rows)?;
state.serialize_element(&self.total_assignments)?;
state.serialize_element(&self.total_const_size)?;
state.serialize_element(&self.dynamic_lookup_params)?;
state.serialize_element(&self.shuffle_params)?;
state.serialize_element(&self.einsum_params)?;
state.serialize_element(&self.model_instance_shapes)?;
state.serialize_element(&self.model_output_scales)?;
state.serialize_element(&self.model_input_scales)?;
@@ -576,6 +592,8 @@ impl<'de> Deserialize<'de> for GraphSettings {
// Flattened ShuffleParams fields
NumShuffles,
TotalShuffleColSize,
// EinsumParams field
EinsumParams,
ModelInstanceShapes,
ModelOutputScales,
ModelInputScales,
@@ -615,6 +633,7 @@ impl<'de> Deserialize<'de> for GraphSettings {
let mut num_dynamic_lookups = None;
let mut num_shuffles = None;
let mut total_shuffle_col_size = None;
let mut einsum_params = None;
let mut model_instance_shapes = None;
let mut model_output_scales = None;
let mut model_input_scales = None;
@@ -684,6 +703,12 @@ impl<'de> Deserialize<'de> for GraphSettings {
}
total_shuffle_col_size = Some(map.next_value()?);
}
Field::EinsumParams => {
if einsum_params.is_some() {
return Err(de::Error::duplicate_field("einsum_params"));
}
einsum_params = Some(map.next_value()?);
}
Field::ModelInstanceShapes => {
if model_instance_shapes.is_some() {
return Err(de::Error::duplicate_field("model_instance_shapes"));
@@ -822,6 +847,7 @@ impl<'de> Deserialize<'de> for GraphSettings {
total_const_size,
dynamic_lookup_params,
shuffle_params,
einsum_params: einsum_params.unwrap_or_default(),
model_instance_shapes,
model_output_scales,
model_input_scales,
@@ -862,42 +888,45 @@ impl<'de> Deserialize<'de> for GraphSettings {
let shuffle_params = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(5, &self))?;
let model_instance_shapes = seq
let einsum_params = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(6, &self))?;
let model_output_scales = seq
let model_instance_shapes = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(7, &self))?;
let model_input_scales = seq
let model_output_scales = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(8, &self))?;
let module_sizes = seq
let model_input_scales = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(9, &self))?;
let required_lookups = seq
let module_sizes = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(10, &self))?;
let required_range_checks = seq
let required_lookups = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(11, &self))?;
let check_mode = seq
let required_range_checks = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(12, &self))?;
let version = seq
let check_mode = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(13, &self))?;
let num_blinding_factors = seq
let version = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(14, &self))?;
let timestamp = seq
let num_blinding_factors = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(15, &self))?;
let input_types = seq
let timestamp = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(16, &self))?;
let output_types = seq
let input_types = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(17, &self))?;
let output_types = seq
.next_element()?
.ok_or_else(|| Error::invalid_length(18, &self))?;
Ok(GraphSettings {
run_args,
@@ -906,6 +935,7 @@ impl<'de> Deserialize<'de> for GraphSettings {
total_const_size,
dynamic_lookup_params,
shuffle_params,
einsum_params,
model_instance_shapes,
model_output_scales,
model_input_scales,
@@ -935,6 +965,7 @@ impl<'de> Deserialize<'de> for GraphSettings {
"num_dynamic_lookups",
"num_shuffles",
"total_shuffle_col_size",
"einsum_params",
"model_instance_shapes",
"model_output_scales",
"model_input_scales",
@@ -953,7 +984,7 @@ impl<'de> Deserialize<'de> for GraphSettings {
deserializer.deserialize_struct("GraphSettings", FIELDS, GraphSettingsVisitor)
} else {
// Binary format (bincode) - use tuple deserialization
deserializer.deserialize_tuple(18, GraphSettingsVisitor)
deserializer.deserialize_tuple(19, GraphSettingsVisitor)
}
}
}
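The bincode path is positional, which is why inserting `einsum_params` bumps the tuple arity from 18 to 19 and shifts every subsequent `invalid_length` index by one. A toy illustration (not the real serde machinery) of why arity and field order must stay in lockstep in positional formats:

// A positional decoder has no field names to fall back on: a reader and
// writer that disagree on arity fail outright, and ones that disagree on
// order would silently mis-assign fields.
fn decode(fields: &[&str], expected: usize) -> Result<Vec<String>, String> {
    if fields.len() != expected {
        return Err(format!("invalid length {}, expected {}", fields.len(), expected));
    }
    Ok(fields.iter().map(|s| s.to_string()).collect())
}

fn main() {
    let old_payload = ["run_args", "num_rows", "shuffle_params"];
    assert!(decode(&old_payload, 4).is_err()); // new reader, old data
}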
@@ -1038,6 +1069,13 @@ impl GraphSettings {
.ceil() as u32
}
/// Calculates the logrows for einsum computation area in which there is no column overflow
pub fn einsum_logrows(&self) -> u32 {
(self.einsum_params.total_einsum_col_size as f64 / self.run_args.num_inner_cols as f64)
.log2()
.ceil() as u32
}
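With hypothetical sizes, `einsum_logrows` reduces to ceil(log2(total einsum cells / inner columns)):

fn einsum_logrows(total_einsum_col_size: usize, num_inner_cols: usize) -> u32 {
    (total_einsum_col_size as f64 / num_inner_cols as f64)
        .log2()
        .ceil() as u32
}

fn main() {
    // 3000 assigned einsum cells across 2 inner columns -> 1500 rows,
    // and ceil(log2(1500)) = 11 logrows.
    assert_eq!(einsum_logrows(3000, 2), 11);
}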
/// calculate the total number of instances
pub fn total_instances(&self) -> Vec<usize> {
let mut instances: Vec<usize> = self.module_sizes.num_instances();
@@ -1593,13 +1631,19 @@ impl GraphCircuit {
let model_constraint_logrows = self.settings().model_constraint_logrows_with_blinding();
let min_bits = self.table_size_logrows(safe_lookup_range, max_range_size)?;
let constants_logrows = self.settings().constants_logrows();
let einsum_logrows = self.settings().einsum_logrows();
max_logrows = std::cmp::min(
max_logrows,
// max of the model constraint logrows, min_bits, and the constants logrows is the upper limit
*[model_constraint_logrows, min_bits, constants_logrows]
.iter()
.max()
.unwrap(),
*[
model_constraint_logrows,
min_bits,
constants_logrows,
einsum_logrows,
]
.iter()
.max()
.unwrap(),
);
// we now have a min and max logrows
@@ -2180,6 +2224,7 @@ pub mod tests {
num_shuffles: 3,
total_shuffle_col_size: 256,
},
einsum_params: EinsumParams::default(),
model_instance_shapes: vec![vec![1, 2, 3]],
model_output_scales: vec![],
model_input_scales: vec![],
@@ -2254,7 +2299,8 @@ pub mod tests {
"decomp_base": 128,
"decomp_legs": 2,
"bounded_log_lookup": false,
"ignore_range_check_inputs_outputs": false
"ignore_range_check_inputs_outputs": false,
"disable_freivalds": false
},
"num_rows": 236,
"total_assignments": 472,

View File

@@ -3,6 +3,7 @@ use super::extract_const_quantized_values;
use super::node::*;
use super::vars::*;
use super::GraphSettings;
use crate::circuit::einsum::analysis::analyze_einsum_usage;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::RegionCtx;
@@ -37,7 +38,6 @@ use log::{debug, info, trace};
use serde::Deserialize;
use serde::Serialize;
use std::collections::BTreeMap;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::collections::HashMap;
use std::collections::HashSet;
use std::fs;
@@ -106,6 +106,8 @@ pub struct DummyPassRes {
pub dynamic_lookup_params: DynamicLookupParams,
/// shuffle parameters
pub shuffle_params: ShuffleParams,
/// einsum parameters
pub einsum_params: crate::graph::EinsumParams,
/// num shuffles
pub num_shuffles: usize,
/// shuffle
@@ -592,6 +594,7 @@ impl Model {
output_types: Some(self.get_output_types()),
dynamic_lookup_params: res.dynamic_lookup_params,
shuffle_params: res.shuffle_params,
einsum_params: res.einsum_params,
total_const_size: res.total_const_size,
check_mode,
version: env!("CARGO_PKG_VERSION").to_string(),
@@ -1047,6 +1050,7 @@ impl Model {
let lookup_range = settings.run_args.lookup_range;
let logrows = settings.run_args.logrows as usize;
let num_inner_cols = settings.run_args.num_inner_cols;
let required_lookups = settings.required_lookups.clone();
let required_range_checks = settings.required_range_checks.clone();
@@ -1095,6 +1099,24 @@ impl Model {
)?;
}
// Configures the circuit to use Freivalds' argument
// In the dummy phase, Freivalds' is configured by default (unless `disable-freivalds` is enabled),
// but if the einsum coordinate is 0, all einsum layouts were dispatched to use only base operations.
if settings.einsum_params.total_einsum_col_size > 0 {
debug!("configuring einsums...");
let used_einsums: HashMap<(usize, String), HashMap<char, usize>> = settings
.einsum_params
.equations
.iter()
.enumerate()
.map(|(idx, (equation, indices_to_dims))| {
((idx, equation.clone()), indices_to_dims.clone())
})
.collect();
let analysis = analyze_einsum_usage(&used_einsums)?;
base_gate.configure_einsums(meta, &analysis, num_inner_cols, logrows)?;
}
Ok(base_gate)
}
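For context, Freivalds' argument verifies a claimed matrix product A·B = C without recomputing it: draw a random vector r and check A(Br) = Cr, which costs O(n^2) instead of O(n^3), and a false claim survives a trial only with small probability. A minimal sketch over plain integers (the in-circuit version works over the field, with r drawn from the verifier challenges wired in above):

// Freivalds' check for 2x2 integer matrices: verify A*B == C by testing
// A*(B*r) == C*r for a random vector r.
fn mat_vec(m: &[[i64; 2]; 2], v: &[i64; 2]) -> [i64; 2] {
    [
        m[0][0] * v[0] + m[0][1] * v[1],
        m[1][0] * v[0] + m[1][1] * v[1],
    ]
}

fn freivalds(a: &[[i64; 2]; 2], b: &[[i64; 2]; 2], c: &[[i64; 2]; 2], r: &[i64; 2]) -> bool {
    let br = mat_vec(b, r);
    mat_vec(a, &br) == mat_vec(c, r)
}

fn main() {
    let a = [[1, 2], [3, 4]];
    let b = [[5, 6], [7, 8]];
    let c = [[19, 22], [43, 50]]; // the true product A*B
    assert!(freivalds(&a, &b, &c, &[3, 7]));
    let bad = [[19, 22], [43, 51]]; // off by one in one entry
    assert!(!freivalds(&a, &b, &bad, &[3, 7]));
}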
@@ -1147,17 +1169,30 @@ impl Model {
let original_constants = constants.clone();
let challenges = {
if let Some(einsum_config) = &config.base.einsums {
einsum_config
.challenges()?
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec()
} else {
vec![]
}
};
let outputs = layouter.assign_region(
|| "model",
|region| {
let mut thread_safe_region = RegionCtx::new_with_constants(
let mut thread_safe_region = RegionCtx::new_with_challenges(
region,
0,
run_args.num_inner_cols,
run_args.decomp_base,
run_args.decomp_legs,
original_constants.clone(),
challenges.clone(),
);
thread_safe_region.update_constants(original_constants.clone());
// we need to do this as this loop is called multiple times
vars.set_instance_idx(instance_idx);
@@ -1459,8 +1494,16 @@ impl Model {
results.insert(*input_idx, vec![inputs[i].clone()]);
}
let mut dummy_config =
PolyConfig::dummy(run_args.logrows as usize, run_args.num_inner_cols);
let mut dummy_config = {
if run_args.disable_freivalds {
PolyConfig::dummy_without_freivalds(
run_args.logrows as usize,
run_args.num_inner_cols,
)
} else {
PolyConfig::dummy(run_args.logrows as usize, run_args.num_inner_cols)
}
};
let mut model_config = ModelConfig {
base: dummy_config.clone(),
vars: ModelVars::new_dummy(),
@@ -1529,6 +1572,10 @@ impl Model {
num_shuffles: region.shuffle_index(),
total_shuffle_col_size: region.shuffle_col_coord(),
},
einsum_params: crate::graph::EinsumParams {
equations: region.used_einsum_equations(),
total_einsum_col_size: region.einsum_col_coord(),
},
total_const_size: region.total_constants(),
lookup_ops: region.used_lookups(),
range_checks: region.used_range_checks(),

View File

@@ -42,8 +42,6 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
#[derive(thiserror::Error, Debug)]
#[allow(missing_docs)]
pub enum EZKLError {
#[error("[aggregation] {0}")]
AggregationError(#[from] pfsys::evm::aggregation_kzg::AggregationError),
#[cfg(all(
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
@@ -100,18 +98,11 @@ impl From<String> for EZKLError {
EZKLError::UncategorizedError(s)
}
}
use std::str::FromStr;
use circuit::{table::Range, CheckMode};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::Args;
use fieldutils::IntegerRep;
use graph::{Visibility, MAX_PUBLIC_SRS};
use halo2_proofs::poly::{
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
};
use halo2curves::bn256::{Bn256, G1Affine};
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
@@ -130,7 +121,6 @@ pub fn version() -> &'static str {
/// Bindings management
#[cfg(any(
feature = "universal-bindings",
all(target_arch = "wasm32", target_os = "unknown"),
feature = "python-bindings"
))]
@@ -171,8 +161,6 @@ pub mod pfsys;
pub mod srs_sha;
/// An implementation of multi-dimensional tensors.
pub mod tensor;
#[cfg(feature = "ios-bindings")]
uniffi::setup_scaffolding!();
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use lazy_static::lazy_static;
@@ -198,78 +186,6 @@ const EZKL_KEY_FORMAT: &str = "raw-bytes";
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
const EZKL_BUF_CAPACITY: &usize = &8000;
#[derive(
Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize, Default, Copy,
)]
/// Commitment scheme
pub enum Commitments {
#[default]
/// KZG
KZG,
/// IPA
IPA,
}
impl From<Option<Commitments>> for Commitments {
fn from(value: Option<Commitments>) -> Self {
value.unwrap_or(Commitments::KZG)
}
}
impl FromStr for Commitments {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"kzg" => Ok(Commitments::KZG),
"ipa" => Ok(Commitments::IPA),
_ => Err("Invalid value for Commitments".to_string()),
}
}
}
impl From<KZGCommitmentScheme<Bn256>> for Commitments {
fn from(_value: KZGCommitmentScheme<Bn256>) -> Self {
Commitments::KZG
}
}
impl From<IPACommitmentScheme<G1Affine>> for Commitments {
fn from(_value: IPACommitmentScheme<G1Affine>) -> Self {
Commitments::IPA
}
}
impl std::fmt::Display for Commitments {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Commitments::KZG => write!(f, "kzg"),
Commitments::IPA => write!(f, "ipa"),
}
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Commitments {
/// Convert the struct to a subcommand string
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl From<String> for Commitments {
fn from(value: String) -> Self {
match value.to_lowercase().as_str() {
"kzg" => Commitments::KZG,
"ipa" => Commitments::IPA,
_ => {
log::error!("Invalid value for Commitments");
log::warn!("defaulting to KZG");
Commitments::KZG
}
}
}
}
/// Parameters specific to a proving run
///
/// RunArgs contains all configuration parameters needed to control the proving process,
@@ -336,10 +252,6 @@ pub struct RunArgs {
/// Controls level of constraint verification
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "unsafe", value_hint = clap::ValueHint::Other))]
pub check_mode: CheckMode,
/// Commitment scheme for circuit proving
/// Affects proof size and verification time
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other))]
pub commitment: Option<Commitments>,
/// Base for number decomposition
/// Must be a power of 2
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "16384", value_hint = clap::ValueHint::Other))]
@@ -363,6 +275,13 @@ pub struct RunArgs {
/// Optional override for epsilon value
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long))]
pub epsilon: Option<f64>,
/// Forcefully disable using Freivalds' argument in einsum operations.
/// Freivalds' argument can make the verifier bigger, so this option is useful
/// when verifier size is a concern.
/// Without this option the circuit layouter will use Freivalds' argument
/// whenever it judges it beneficial to do so.
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long))]
pub disable_freivalds: bool,
}
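A hedged sketch of the programmatic opt-out (assuming the `ezkl` crate as a dependency; the CLI equivalent is the `--disable-freivalds` flag clap derives from the field name):

use ezkl::RunArgs;

fn main() {
    // Hypothetical: disable Freivalds' argument when building RunArgs in code,
    // falling back to the defaults shown below for everything else.
    let run_args = RunArgs {
        disable_freivalds: true,
        ..RunArgs::default()
    };
    assert!(run_args.disable_freivalds);
}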
impl RunArgs {
@@ -393,11 +312,11 @@ impl Default for RunArgs {
param_visibility: Visibility::Fixed,
rebase_frac_zero_constants: false,
check_mode: CheckMode::UNSAFE,
commitment: None,
decomp_base: 16384,
decomp_legs: 2,
ignore_range_check_inputs_outputs: false,
epsilon: None,
disable_freivalds: false,
}
}
}

View File

@@ -1,442 +0,0 @@
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::CircuitSize;
use crate::pfsys::{Snark, SnarkWitness};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use colored_json::ToColoredJson;
use halo2_proofs::circuit::AssignedCell;
use halo2_proofs::plonk::{self};
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem},
};
use halo2_wrong_ecc::{
integer::rns::Rns,
maingate::{
MainGate, MainGateConfig, MainGateInstructions, RangeChip, RangeConfig, RangeInstructions,
RegionCtx,
},
EccConfig,
};
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use log::debug;
use log::trace;
use rand::rngs::OsRng;
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::loader::EcPointLoader;
use snark_verifier::{
loader,
pcs::{
kzg::{
Bdfg21, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding,
LimbsEncodingInstructions,
},
AccumulationScheme, AccumulationSchemeProver,
},
system,
util::arithmetic::fe_to_limbs,
verifier::{self, SnarkVerifier},
};
use std::rc::Rc;
use thiserror::Error;
const LIMBS: usize = 4;
const BITS: usize = 68;
type As = KzgAs<Bn256, Bdfg21>;
/// Type for aggregator verification
type PlonkSuccinctVerifier = verifier::plonk::PlonkSuccinctVerifier<As, LimbsEncoding<LIMBS, BITS>>;
const T: usize = 5;
const RATE: usize = 4;
const R_F: usize = 8;
const R_P: usize = 60;
type Svk = KzgSuccinctVerifyingKey<G1Affine>;
type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip<G1Affine, LIMBS, BITS>;
/// The loader type used in the transcript definition
type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>;
/// Application snark transcript
pub type PoseidonTranscript<L, S> =
system::halo2::transcript::halo2::PoseidonTranscript<G1Affine, L, S, T, RATE, R_F, R_P>;
#[derive(Error, Debug)]
/// Errors related to proof aggregation
pub enum AggregationError {
/// A KZG proof could not be verified
#[error("failed to verify KZG proof")]
KZGProofVerification,
/// proof read errors
#[error("Failed to read proof")]
ProofRead,
/// proof verification errors
#[error("Failed to verify proof")]
ProofVerify,
/// proof creation errors
#[error("Failed to create proof")]
ProofCreate,
}
type AggregationResult<'a> = (
// accumulator
KzgAccumulator<G1Affine, Rc<Halo2Loader<'a>>>,
// the set of assigned cells
Vec<Vec<AssignedCell<Fr, Fr>>>,
);
type LoadedProof<'a> = verifier::plonk::PlonkProof<
G1Affine,
Rc<
loader::halo2::Halo2Loader<
'a,
G1Affine,
halo2_wrong_ecc::BaseFieldEccChip<G1Affine, 4, 68>,
>,
>,
KzgAs<Bn256, Bdfg21>,
>;
/// Aggregate one or more application snarks of the same shape into a KzgAccumulator
pub fn aggregate<'a>(
svk: &Svk,
loader: &Rc<Halo2Loader<'a>>,
snarks: &[SnarkWitness<Fr, G1Affine>],
as_proof: Value<&'_ [u8]>,
split_proofs: bool,
) -> Result<AggregationResult<'a>, plonk::Error> {
let assign_instances = |instances: &[Vec<Value<Fr>>]| {
instances
.iter()
.map(|instances| {
instances
.iter()
.map(|instance| loader.assign_scalar(*instance))
.collect_vec()
})
.collect_vec()
};
let mut accumulators = vec![];
let mut snark_instances = vec![];
let mut proofs: Vec<LoadedProof<'_>> = vec![];
for snark in snarks.iter() {
let protocol = snark.protocol.as_ref().unwrap().loaded(loader);
let instances = assign_instances(&snark.instances);
// get assigned cells
snark_instances.extend(instances.iter().map(|instance| {
instance
.iter()
.map(|v| v.clone().into_assigned())
.collect_vec()
}));
let mut transcript = PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, snark.proof());
let proof = PlonkSuccinctVerifier::read_proof(svk, &protocol, &instances, &mut transcript)
.map_err(|_| plonk::Error::Synthesis)?;
if split_proofs {
let previous_proof = proofs.last();
let split_commit = match snark.clone().split {
Some(split) => split,
None => {
log::error!("Failed to split KZG commit for sequential proofs");
return Err(plonk::Error::Synthesis);
}
};
if let Some(previous_proof) = previous_proof {
// output of previous proof
let output = &previous_proof.witnesses[split_commit.start..split_commit.end];
// input of current proof
let split_commit_len = split_commit.end - split_commit.start;
let input = &proof.witnesses[..split_commit_len];
// these points were already assigned previously when loading the transcript so this is safe
// and equivalent to a copy constraint and an equality constraint
for (output, input) in output.iter().zip(input.iter()) {
loader
.ec_point_assert_eq("assert commits match", output, input)
.map_err(|e| {
log::error!(
"Failed to match KZG commits for sequential proofs: {:?}",
e
);
plonk::Error::Synthesis
})?;
}
}
proofs.push(proof.clone());
}
let mut accum = PlonkSuccinctVerifier::verify(svk, &protocol, &instances, &proof)
.map_err(|_| plonk::Error::Synthesis)?;
accumulators.append(&mut accum);
}
let accumulator = {
let mut transcript = PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, as_proof);
let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap();
As::verify(&Default::default(), &accumulators, &proof).map_err(|_| plonk::Error::Synthesis)
}?;
Ok((accumulator, snark_instances))
}
/// The Halo2 Config for the aggregation circuit
#[derive(Clone, Debug)]
pub struct AggregationConfig {
main_gate_config: MainGateConfig,
range_config: RangeConfig,
}
impl AggregationConfig {
/// Configure the aggregation circuit
pub fn configure<F: PrimeField>(
meta: &mut ConstraintSystem<F>,
composition_bits: Vec<usize>,
overflow_bits: Vec<usize>,
) -> Self {
let main_gate_config = MainGate::<F>::configure(meta);
let range_config =
RangeChip::<F>::configure(meta, &main_gate_config, composition_bits, overflow_bits);
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
{
let circuit_size = CircuitSize::from_cs(meta, 23);
// not wasm
debug!(
"circuit size: \n {}",
circuit_size
.as_json()
.unwrap()
.to_colored_json_auto()
.unwrap()
);
}
AggregationConfig {
main_gate_config,
range_config,
}
}
/// Create a MainGate from the aggregation approach
pub fn main_gate(&self) -> MainGate<Fr> {
MainGate::new(self.main_gate_config.clone())
}
/// Create a range chip to decompose and range check inputs
pub fn range_chip(&self) -> RangeChip<Fr> {
RangeChip::new(self.range_config.clone())
}
/// Create an ecc chip for ec ops
pub fn ecc_chip(&self) -> BaseFieldEccChip {
BaseFieldEccChip::new(EccConfig::new(
self.range_config.clone(),
self.main_gate_config.clone(),
))
}
}
/// Aggregation Circuit with a SuccinctVerifyingKey, application snark witnesses (each with a proof and instance variables), and the instance variables and the resulting aggregation circuit proof.
#[derive(Clone, Debug)]
pub struct AggregationCircuit {
svk: Svk,
snarks: Vec<SnarkWitness<Fr, G1Affine>>,
instances: Vec<Fr>,
as_proof: Value<Vec<u8>>,
split_proof: bool,
}
impl AggregationCircuit {
/// Create a new Aggregation Circuit with a SuccinctVerifyingKey, application snark witnesses (each with a proof and instance variables), and the instance variables and the resulting aggregation circuit proof.
pub fn new(
svk: &KzgSuccinctVerifyingKey<G1Affine>,
snarks: impl IntoIterator<Item = Snark<Fr, G1Affine>>,
split_proof: bool,
) -> Result<Self, AggregationError> {
let snarks = snarks.into_iter().collect_vec();
let mut accumulators = vec![];
for snark in snarks.iter() {
trace!("Aggregating with snark instances {:?}", snark.instances);
let mut transcript = PoseidonTranscript::<NativeLoader, _>::new(snark.proof.as_slice());
let proof = PlonkSuccinctVerifier::read_proof(
svk,
snark.protocol.as_ref().unwrap(),
&snark.instances,
&mut transcript,
)
.map_err(|e| {
log::error!("{:?}", e);
AggregationError::ProofRead
})?;
let mut accum = PlonkSuccinctVerifier::verify(
svk,
snark.protocol.as_ref().unwrap(),
&snark.instances,
&proof,
)
.map_err(|_| AggregationError::ProofVerify)?;
accumulators.append(&mut accum);
}
trace!("Accumulator");
let (accumulator, as_proof) = {
let mut transcript = PoseidonTranscript::<NativeLoader, _>::new(Vec::new());
let accumulator =
As::create_proof(&Default::default(), &accumulators, &mut transcript, OsRng)
.map_err(|_| AggregationError::ProofCreate)?;
(accumulator, transcript.finalize())
};
trace!("KzgAccumulator");
let KzgAccumulator { lhs, rhs } = accumulator;
let instances = [lhs.x, lhs.y, rhs.x, rhs.y]
.map(fe_to_limbs::<_, _, LIMBS, BITS>)
.concat();
Ok(Self {
svk: *svk,
snarks: snarks.into_iter().map_into().collect(),
instances,
as_proof: Value::known(as_proof),
split_proof,
})
}
/// Number of limbs used for decomposition
pub fn num_limbs() -> usize {
LIMBS
}
/// Number of bits used for decomposition
pub fn num_bits() -> usize {
BITS
}
/// Accumulator indices used in generating verifier.
pub fn accumulator_indices() -> Vec<(usize, usize)> {
(0..4 * LIMBS).map(|idx| (0, idx)).collect()
}
/// Number of instance variables for the aggregation circuit, used in generating verifier.
pub fn num_instance(orginal_circuit_instances: usize) -> Vec<usize> {
let accumulation_instances = 4 * LIMBS;
vec![accumulation_instances + orginal_circuit_instances]
}
/// Instance variables for the aggregation circuit, fed to verifier.
pub fn instances(&self) -> Vec<Fr> {
// also get snark instances here
let mut snark_instances: Vec<Vec<Vec<Value<Fr>>>> = self
.snarks
.iter()
.map(|snark| snark.instances.clone())
.collect_vec();
// reduce from Vec<Vec<Vec<Value<Fr>>>> to Vec<Vec<Value<Fr>>>
let mut instances: Vec<Fr> = self.instances.clone();
for snark_instance in snark_instances.iter_mut() {
for instance in snark_instance.iter_mut() {
let mut felt_evals = vec![];
for value in instance.iter_mut() {
value.map(|v| felt_evals.push(v));
}
instances.extend(felt_evals);
}
}
instances
}
fn as_proof(&self) -> Value<&[u8]> {
self.as_proof.as_ref().map(Vec::as_slice)
}
}
impl Circuit<Fr> for AggregationCircuit {
type Config = AggregationConfig;
type FloorPlanner = SimpleFloorPlanner;
type Params = ();
fn without_witnesses(&self) -> Self {
Self {
svk: self.svk,
snarks: self
.snarks
.iter()
.map(SnarkWitness::without_witnesses)
.collect(),
instances: Vec::new(),
as_proof: Value::unknown(),
split_proof: self.split_proof,
}
}
fn configure(meta: &mut ConstraintSystem<Fr>) -> Self::Config {
AggregationConfig::configure(
meta,
vec![BITS / LIMBS],
Rns::<Fq, Fr, LIMBS, BITS>::construct().overflow_lengths(),
)
}
fn synthesize(
&self,
config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), plonk::Error> {
let main_gate = config.main_gate();
let range_chip = config.range_chip();
range_chip.load_table(&mut layouter)?;
let (accumulator_limbs, snark_instances) = layouter.assign_region(
|| "",
|region| {
let ctx = RegionCtx::new(region, 0);
let ecc_chip = config.ecc_chip();
let loader = Halo2Loader::new(ecc_chip, ctx);
let (accumulator, snark_instances) = aggregate(
&self.svk,
&loader,
&self.snarks,
self.as_proof(),
self.split_proof,
)?;
let accumulator_limbs = [accumulator.lhs, accumulator.rhs]
.iter()
.map(|ec_point| {
loader
.ecc_chip()
.assign_ec_point_to_limbs(&mut loader.ctx_mut(), ec_point.assigned())
})
.collect::<Result<Vec<_>, plonk::Error>>()?
.into_iter()
.flatten();
Ok((accumulator_limbs, snark_instances))
},
)?;
let mut instance_offset = 0;
for limb in accumulator_limbs {
main_gate.expose_public(layouter.namespace(|| ""), limb, instance_offset)?;
instance_offset += 1;
}
for instance in snark_instances.into_iter() {
for elem in instance.into_iter() {
main_gate.expose_public(layouter.namespace(|| ""), elem, instance_offset)?;
instance_offset += 1;
}
}
Ok(())
}
}

View File

@@ -1,24 +0,0 @@
use thiserror::Error;
/// Aggregate proof generation for EVM using KZG
pub mod aggregation_kzg;
#[derive(Error, Debug)]
/// Errors related to evm verification
pub enum EvmVerificationError {
/// If the Solidity verifier worked but returned false
#[error("Solidity verifier found the proof invalid")]
InvalidProof,
/// If the Solidity verifier threw and error (e.g. OutOfGas)
#[error("Execution of Solidity code failed: {0}")]
SolidityExecution(String),
/// EVM verify errors
#[error("evm verification reverted: {0}")]
Reverted(String),
/// EVM verify errors
#[error("evm deployment failed: {0}")]
DeploymentFailed(String),
/// Invalid Visibility
#[error("Invalid visibility")]
InvalidVisibility,
}

View File

@@ -1,6 +1,3 @@
/// EVM related proving and verification
pub mod evm;
/// SRS generation, processing, verification and downloading
pub mod srs;
@@ -13,17 +10,11 @@ use std::borrow::Borrow;
use crate::circuit::CheckMode;
use crate::graph::GraphWitness;
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
use crate::{Commitments, EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::ValueEnum;
use halo2_proofs::circuit::Value;
use crate::{EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
use halo2_proofs::plonk::{
create_proof, keygen_pk, keygen_vk_custom, verify_proof, Circuit, ProvingKey, VerifyingKey,
};
use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier};
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer};
use halo2curves::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
@@ -37,22 +28,16 @@ use rand::rngs::OsRng;
use rand::rngs::StdRng;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use snark_verifier::verifier::plonk::PlonkProtocol;
use std::fs::File;
use std::io::{self, BufReader, BufWriter, Cursor, Write};
use std::ops::Deref;
use std::path::PathBuf;
use thiserror::Error as thisError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDictMethods;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
/// Converts a string to a `SerdeFormat`.
/// # Panics
/// Panics if the provided `s` is not a valid `SerdeFormat` (i.e. not one of "processed", "raw-bytes-unchecked", or "raw-bytes").
@@ -140,144 +125,6 @@ where
bytes
}
#[allow(missing_docs)]
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
pub enum ProofType {
#[default]
Single,
ForAggr,
}
impl std::fmt::Display for ProofType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
ProofType::Single => "single",
ProofType::ForAggr => "for-aggr",
}
)
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for ProofType {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl From<ProofType> for TranscriptType {
fn from(val: ProofType) -> Self {
match val {
ProofType::Single => TranscriptType::EVM,
ProofType::ForAggr => TranscriptType::Poseidon,
}
}
}
impl From<ProofType> for StrategyType {
fn from(val: ProofType) -> Self {
match val {
ProofType::Single => StrategyType::Single,
ProofType::ForAggr => StrategyType::Accum,
}
}
}
#[cfg(feature = "python-bindings")]
impl<'py> pyo3::IntoPyObject<'py> for ProofType {
type Target = pyo3::PyAny;
type Output = pyo3::Bound<'py, Self::Target>;
type Error = pyo3::PyErr;
fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
let result = match self {
ProofType::Single => "Single",
ProofType::ForAggr => "ForAggr",
};
Ok(result.into_pyobject(py)?.into_any())
}
}
#[cfg(feature = "python-bindings")]
/// Obtains StrategyType from PyObject (Required for StrategyType to be compatible with Python)
impl<'source> pyo3::FromPyObject<'source> for ProofType {
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> pyo3::PyResult<Self> {
let strval = String::extract_bound(ob)?;
match strval.to_lowercase().as_str() {
"single" => Ok(ProofType::Single),
"for-aggr" => Ok(ProofType::ForAggr),
_ => Err(pyo3::exceptions::PyValueError::new_err(
"Invalid value for ProofType",
)),
}
}
}
#[allow(missing_docs)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
pub enum StrategyType {
Single,
Accum,
}
impl std::fmt::Display for StrategyType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// When the `ezkl` feature is disabled or we're targeting `wasm32`, use basic string representation.
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
{
write!(
f,
"{}",
match self {
StrategyType::Single => "single",
StrategyType::Accum => "accum",
}
)
}
// When the `ezkl` feature is enabled and we're not targeting `wasm32`, use `to_possible_value`.
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
{
self.to_possible_value()
.expect("no values are skipped")
.get_name()
.fmt(f)
}
}
}
#[cfg(feature = "python-bindings")]
/// Converts StrategyType into a PyObject (Required for StrategyType to be compatible with Python)
impl<'py> pyo3::IntoPyObject<'py> for StrategyType {
type Target = pyo3::PyAny;
type Output = pyo3::Bound<'py, Self::Target>;
type Error = pyo3::PyErr;
fn into_pyobject(self, py: pyo3::Python<'py>) -> Result<Self::Output, Self::Error> {
let result = match self {
StrategyType::Single => "single",
StrategyType::Accum => "accum",
};
Ok(result.into_pyobject(py)?.into_any())
}
}
#[cfg(feature = "python-bindings")]
/// Obtains StrategyType from PyObject (Required for StrategyType to be compatible with Python)
impl<'source> pyo3::FromPyObject<'source> for StrategyType {
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> pyo3::PyResult<Self> {
let strval = String::extract_bound(ob)?;
match strval.to_lowercase().as_str() {
"single" => Ok(StrategyType::Single),
"accum" => Ok(StrategyType::Accum),
_ => Err(pyo3::exceptions::PyValueError::new_err(
"Invalid value for StrategyType",
)),
}
}
}
#[derive(thisError, Debug)]
/// Errors related to pfsys
pub enum PfSysError {
@@ -286,33 +133,8 @@ pub enum PfSysError {
PackingExponent,
}
#[allow(missing_docs)]
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)]
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
pub enum TranscriptType {
Poseidon,
#[default]
EVM,
}
impl std::fmt::Display for TranscriptType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
TranscriptType::Poseidon => "poseidon",
TranscriptType::EVM => "evm",
}
)
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for TranscriptType {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
#[cfg(feature = "python-bindings")]
use halo2curves::bn256::G1Affine;
#[cfg(feature = "python-bindings")]
///
@@ -371,7 +193,7 @@ pub struct PrettyElements {
pub outputs: Vec<Vec<String>>,
}
/// An application snark with proof and instance variables ready for aggregation (raw field element)
/// An application snark with proof and instance variables
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Snark<F: PrimeField + SerdeObject, C: CurveAffine>
where
@@ -386,16 +208,12 @@ where
pub proof: Vec<u8>,
/// hex encoded proof
pub hex_proof: Option<String>,
/// transcript type
pub transcript_type: TranscriptType,
/// the split proof
pub split: Option<ProofSplitCommit>,
/// the proof instances as rescaled floats
pub pretty_public_inputs: Option<PrettyElements>,
/// timestamp
pub timestamp: Option<u128>,
/// commitment
pub commitment: Option<Commitments>,
/// (optional) version of ezkl used to generate the proof
version: Option<String>,
}
@@ -423,8 +241,6 @@ where
dict.set_item("instances", field_elems).unwrap();
let hex_proof = hex::encode(&self.proof);
dict.set_item("proof", format!("0x{}", hex_proof)).unwrap();
dict.set_item("transcript_type", self.transcript_type.into_pyobject(py)?)
.unwrap();
Ok(dict.into_any())
}
}
@@ -437,24 +253,21 @@ where
C::Scalar: Serialize + DeserializeOwned,
C::ScalarExt: Serialize + DeserializeOwned,
{
/// Create a new application snark from proof and instance variables ready for aggregation
/// Create a new application snark from proof and instance variables
#[allow(clippy::too_many_arguments)]
pub fn new(
protocol: Option<PlonkProtocol<C>>,
instances: Vec<Vec<F>>,
proof: Vec<u8>,
hex_proof: Option<String>,
transcript_type: TranscriptType,
split: Option<ProofSplitCommit>,
pretty_public_inputs: Option<PrettyElements>,
commitment: Option<Commitments>,
) -> Self {
Self {
protocol,
instances,
proof,
hex_proof,
transcript_type,
split,
pretty_public_inputs,
// unix timestamp
@@ -464,7 +277,6 @@ where
.unwrap()
.as_millis(),
),
commitment,
version: Some(crate::version().to_string()),
}
}
@@ -560,53 +372,6 @@ impl From<GraphWitness> for Option<ProofSplitCommit> {
}
}
/// An application snark with proof and instance variables ready for aggregation (wrapped field element)
#[derive(Clone, Debug)]
pub struct SnarkWitness<F: PrimeField, C: CurveAffine> {
protocol: Option<PlonkProtocol<C>>,
instances: Vec<Vec<Value<F>>>,
proof: Value<Vec<u8>>,
split: Option<ProofSplitCommit>,
}
impl<F: PrimeField, C: CurveAffine> SnarkWitness<F, C> {
fn without_witnesses(&self) -> Self {
SnarkWitness {
protocol: self.protocol.clone(),
instances: self
.instances
.iter()
.map(|instances| vec![Value::unknown(); instances.len()])
.collect(),
proof: Value::unknown(),
split: self.split.clone(),
}
}
fn proof(&self) -> Value<&[u8]> {
self.proof.as_ref().map(Vec::as_slice)
}
}
impl<F: PrimeField + SerdeObject, C: CurveAffine> From<Snark<F, C>> for SnarkWitness<F, C>
where
C::Scalar: Serialize + DeserializeOwned,
C::ScalarExt: Serialize + DeserializeOwned,
{
fn from(snark: Snark<F, C>) -> Self {
Self {
protocol: snark.protocol,
instances: snark
.instances
.into_iter()
.map(|instances| instances.into_iter().map(Value::known).collect())
.collect(),
proof: Value::known(snark.proof),
split: snark.split,
}
}
}
/// Creates a [VerifyingKey] and [ProvingKey] for a [crate::graph::GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`).
pub fn create_keys<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
circuit: &C,
@@ -652,8 +417,6 @@ pub fn create_proof_circuit<
params: &'params Scheme::ParamsProver,
pk: &ProvingKey<Scheme::Curve>,
check_mode: CheckMode,
commitment: Commitments,
transcript_type: TranscriptType,
split: Option<ProofSplitCommit>,
protocol: Option<PlonkProtocol<Scheme::Curve>>,
) -> Result<Snark<Scheme::Scalar, Scheme::Curve>, PfsysError>
@@ -701,16 +464,7 @@ where
let proof = transcript.finalize();
let hex_proof = format!("0x{}", hex::encode(&proof));
let checkable_pf = Snark::new(
protocol,
instances,
proof,
Some(hex_proof),
transcript_type,
split,
None,
Some(commitment),
);
let checkable_pf = Snark::new(protocol, instances, proof, Some(hex_proof), split, None);
// sanity check that the generated proof is valid
if check_mode == CheckMode::SAFE {
@@ -799,44 +553,6 @@ where
Ok(proof_first_bytes)
}
/// Swap the proof commitments to a new set in the proof for KZG
pub fn swap_proof_commitments_polycommit(
snark: &Snark<Fr, G1Affine>,
commitments: &[G1Affine],
) -> Result<Snark<Fr, G1Affine>, PfsysError> {
let proof = match snark.commitment {
Some(Commitments::KZG) => match snark.transcript_type {
TranscriptType::EVM => swap_proof_commitments::<
KZGCommitmentScheme<Bn256>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(snark, commitments)?,
TranscriptType::Poseidon => swap_proof_commitments::<
KZGCommitmentScheme<Bn256>,
_,
PoseidonTranscript<NativeLoader, _>,
>(snark, commitments)?,
},
Some(Commitments::IPA) => match snark.transcript_type {
TranscriptType::EVM => swap_proof_commitments::<
IPACommitmentScheme<G1Affine>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(snark, commitments)?,
TranscriptType::Poseidon => swap_proof_commitments::<
IPACommitmentScheme<G1Affine>,
_,
PoseidonTranscript<NativeLoader, _>,
>(snark, commitments)?,
},
None => {
return Err(PfsysError::InvalidCommitmentScheme);
}
};
Ok(proof)
}
/// A wrapper around halo2's verify_proof
pub fn verify_proof_circuit<
'params,
@@ -993,13 +709,11 @@ mod tests {
let snark = Snark::<Fr, G1Affine> {
proof: vec![1, 2, 3, 4, 5, 6, 7, 8],
instances: vec![vec![Fr::from(1)], vec![Fr::from(2)]],
transcript_type: TranscriptType::EVM,
protocol: None,
hex_proof: None,
split: None,
pretty_public_inputs: None,
timestamp: None,
commitment: None,
version: None,
};
@@ -1012,6 +726,5 @@ mod tests {
.unwrap();
assert_eq!(snark.instances, snark2.instances);
assert_eq!(snark.proof, snark2.proof);
assert_eq!(snark.transcript_type, snark2.transcript_type);
}
}

View File

@@ -480,6 +480,13 @@ impl<T: Clone + TensorType> Tensor<T> {
self[index].clone()
}
/// Extracts a single value from the tensor
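/// ```
/// // minimal sketch: `get_scalar` requires a single-element tensor whose dims are all 1
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let a = Tensor::<IntegerRep>::new(Some(&[7]), &[1, 1]).unwrap();
/// assert_eq!(a.get_scalar(), 7);
/// ```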
pub fn get_scalar(&self) -> T {
assert!(self.inner.len() == 1);
assert!(self.dims.iter().all(|dim| *dim == 1));
self.inner[0].clone()
}
/// Get a mutable array index from rows / columns indices.
///
/// ```
@@ -901,6 +908,22 @@ impl<T: Clone + TensorType> Tensor<T> {
Ok(())
}
/// remove axes that have dimension 1
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 2]).unwrap();
/// let b = a.remove_trivial_axes().unwrap();
/// assert_eq!(b, expected);
/// ```
pub fn remove_trivial_axes(&self) -> Result<Self, TensorError> {
let mut result = self.clone();
let new_dims: Vec<_> = self.dims.iter().copied().filter(|dim| *dim > 1).collect();
result.reshape(&new_dims)?;
Ok(result)
}
/// Move axis of the tensor
/// ```
/// use ezkl::tensor::Tensor;

View File

@@ -5,6 +5,7 @@ use crate::{
};
use itertools::Itertools;
use maybe_rayon::{iter::ParallelIterator, prelude::IntoParallelRefIterator};
use std::collections::{HashMap, HashSet};
pub use std::ops::{Add, Mul, Neg, Sub};
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
@@ -2396,6 +2397,8 @@ pub mod nonlinearities {
/// Ops that return the transcript i.e intermediate calcs of an op
pub mod accumulated {
use maybe_rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator};
use super::*;
/// Dot product of two tensors.
@@ -2523,4 +2526,327 @@ pub mod accumulated {
Ok(transcript)
}
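/// Row-major (C-order) strides for `dims`: the last axis has stride 1 and each earlier
/// axis' stride is the product of all later dims, e.g. dims [3, 1, 2] -> strides [2, 2, 1].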
#[inline]
fn row_major_strides(dims: &[usize]) -> Vec<usize> {
let mut s = vec![0; dims.len()];
let mut acc = 1;
for (i, &d) in dims.iter().enumerate().rev() {
s[i] = acc;
acc *= d;
}
s
}
/// Generalized einsum over an arbitrary number of input tensors; returns the result
/// tensor together with the resolved size of each index label.
///
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::accumulated::einsum;
///
/// // matmul case
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[2, 3, 2, 1, 1, 1]),
/// &[3, 2],
/// ).unwrap();
/// let (result, _) = einsum::<IntegerRep>("ij,jk->ik", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[8, 9, 5, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// // element wise multiplication
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
/// &[3, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]),
/// &[3, 3],
/// ).unwrap();
/// let (result, _) = einsum::<IntegerRep>("ij,ij->ij", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 4, 9, 2, 6, 12, 3, 8, 15]), &[3, 3]).unwrap();
/// assert_eq!(result, expected);
///
///
/// // dot product of A with the transpose of B.
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
/// &[3, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]),
/// &[3, 3],
/// ).unwrap();
/// let (result, _) = einsum::<IntegerRep>("ik,jk->ij", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[14, 14, 14, 20, 20, 20, 26, 26, 26]), &[3, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // dot product
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
/// &[3, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]),
/// &[3, 3],
/// ).unwrap();
/// let (result, _) = einsum::<IntegerRep>("ik,ik->i", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[14, 20, 26]), &[3]).unwrap();
/// assert_eq!(result, expected);
///
///
/// // dot product
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3]),
/// &[3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3]),
/// &[3],
/// ).unwrap();
/// let (result, _) = einsum::<IntegerRep>("i,i->", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[14]), &[1]).unwrap();
/// assert_eq!(result, expected);
///
///
/// // higher-rank contraction: reduce over n and the shared index m
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]),
/// &[3, 3, 2],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[4, 5, 7, 8]),
/// &[2, 2],
/// ).unwrap();
/// let (result, _) = einsum::<IntegerRep>("anm,bm->ba", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[68, 80, 95, 113, 134, 158]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // three-operand contraction over n and m
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]),
/// &[3, 3, 2],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[4, 5, 7, 8]),
/// &[2, 2],
/// ).unwrap();
/// let z = Tensor::<IntegerRep>::new(
/// Some(&[4, 5, 7, 8, 9, 9]),
/// &[2, 3],
/// ).unwrap();
///
/// let (result, _) = einsum::<IntegerRep>("bn,anm,bm->ba", &[&z, &x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[390, 414, 534, 994, 1153, 1384]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
///
/// // contraction with a single common axis
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]),
/// &[3, 3, 2],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[4, 5, 7, 8]),
/// &[2, 2],
/// ).unwrap();
/// let (result, _) = einsum::<IntegerRep>("abc,cd->", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[648]), &[1]).unwrap();
/// assert_eq!(result, expected);
///
/// // contraction with no common axes (outer product)
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]),
/// &[3, 3, 2],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[4, 5, 7, 8]),
/// &[2, 2],
/// ).unwrap();
/// let (result, _) = einsum::<IntegerRep>("abc,ed->", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1296]), &[1]).unwrap();
/// assert_eq!(result, expected);
///
/// // trivial axes mapping
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[4, 5, 7, 8]),
/// &[2, 2],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[4, 5]),
/// &[2],
/// ).unwrap();
///
/// let (result, _) = einsum::<IntegerRep>("mk,k->m", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[41, 68]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
/// let (result, _) = einsum::<IntegerRep>("mk,k->mn", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[41, 68]), &[2, 1]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[0, 0, 0, 3]),
/// &[1, 4],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[213, 227, 74, 77]),
/// &[4],
/// ).unwrap();
///
/// let (result, _) = einsum::<IntegerRep>("mk,k->ma", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[231]), &[1, 1]).unwrap();
/// assert_eq!(result, expected);
/// // subtle difference: the second input's index is relabeled n, so nothing is shared and both inputs are summed independently
/// let (result, _) = einsum::<IntegerRep>("mk,n->ma", &[&x, &k]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1773]), &[1, 1]).unwrap();
/// assert_eq!(result, expected);
///
/// ```
///
pub fn einsum<T>(
equation: &str,
input_tensors: &[&Tensor<T>],
) -> Result<(Tensor<T>, HashMap<char, usize>), TensorError>
where
T: Clone + TensorType + Mul<Output = T> + Add<Output = T> + Send + Sync,
{
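// explicit-output equations only: the equation must contain "->"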
let (input_exprs, output_expr) = equation.split_once("->").unwrap();
let input_exprs: Vec<&str> = input_exprs.split(',').collect();
assert_eq!(input_exprs.len(), input_tensors.len());
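// Resolve each index label's dimension across all inputs; a size-1 axis is allowed
// to broadcast against a larger one (enforced by the debug_assert below)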
let mut dim_of: HashMap<char, usize> = HashMap::new();
for (input_expr, t) in input_exprs.iter().zip(input_tensors.iter()) {
for (c, &d) in input_expr.chars().zip(t.dims().iter()) {
let e = dim_of.entry(c).or_insert(d);
debug_assert!((*e == d) || (*e == 1) || (d == 1));
*e = (*e).max(d);
}
}
// Output dims
let out_idx: Vec<char> = output_expr.chars().collect();
let out_dims: Vec<usize> = out_idx
.iter()
.map(|c| *dim_of.get(c).unwrap_or(&1))
.collect();
// Reduction indices
let all_idx: HashSet<char> = dim_of.keys().copied().collect();
let out_set: HashSet<char> = out_idx.iter().copied().collect();
let red_idx: Vec<char> = all_idx.difference(&out_set).copied().collect();
let red_dims: Vec<usize> = red_idx.iter().map(|c| dim_of[c]).collect();
// Fast index->pos
let out_pos: HashMap<char, usize> =
out_idx.iter().enumerate().map(|(i, &c)| (c, i)).collect();
let red_pos: HashMap<char, usize> =
red_idx.iter().enumerate().map(|(i, &c)| (c, i)).collect();
// Precompute, per input, how each output index and each reduction index advances
// that input's flat (row-major) offset; broadcast (size-1) axes contribute stride 0
struct Contrib {
out_stride: Vec<usize>,
red_stride: Vec<usize>,
}
let contribs: Vec<Contrib> = input_exprs
.iter()
.zip(input_tensors.iter())
.map(|(expr, t)| {
let dims = t.dims().to_vec();
let strides = row_major_strides(&dims);
let mut out_stride = vec![0; out_idx.len()];
let mut red_stride = vec![0; red_idx.len()];
for (ax, (c, &d)) in expr.chars().zip(dims.iter()).enumerate() {
let s = if d == 1 { 0 } else { strides[ax] };
if let Some(&p) = out_pos.get(&c) {
out_stride[p] = s;
} else if let Some(&q) = red_pos.get(&c) {
red_stride[q] = s;
}
}
Contrib {
out_stride,
red_stride,
}
})
.collect();
// Prepare output buffer; a scalar (rank-0) result is stored as a [1]-shaped tensor
let mut out = if out_dims.is_empty() {
Tensor::<T>::new(None, &[1])?
} else {
Tensor::<T>::new(None, &out_dims)?
};
let out_rank = out_dims.len();
let red_rank = red_dims.len();
// Materialize output elements independently, in parallel
out.par_iter_mut()
.enumerate()
.for_each(|(out_linear_coord, out)| {
let mut out_index = vec![0usize; out_rank];
{
let mut x = out_linear_coord;
for i in (0..out_rank).rev() {
let d = out_dims[i];
out_index[i] = x % d;
x /= d;
}
}
// Base offset per input from output coordinates
let mut base_off = vec![0usize; input_tensors.len()];
for (i, c) in contribs.iter().enumerate() {
let mut off = 0usize;
for p in 0..out_rank {
off += out_index[p] * c.out_stride[p];
}
base_off[i] = off;
}
let mut acc = T::zero().unwrap();
if red_rank == 0 {
// No reduction -> just multiply corresponding elements
let mut prod = T::one().unwrap();
for (i, t) in input_tensors.iter().enumerate() {
let val = t.get_flat_index(base_off[i]);
prod = prod * val;
}
acc = acc + prod;
} else {
// Iterate over all reduction coords
let red_size = red_dims.iter().product::<usize>();
let mut red_index = vec![0usize; red_rank];
for red_linear_coord in 0..red_size {
{
let mut x = red_linear_coord;
for q in (0..red_rank).rev() {
let d = red_dims[q];
red_index[q] = x % d;
x /= d;
}
}
let mut prod = T::one().unwrap();
for (i, (t, c)) in input_tensors.iter().zip(contribs.iter()).enumerate() {
let mut off = base_off[i];
for q in 0..red_rank {
off += red_index[q] * c.red_stride[q];
}
let val = t.get_flat_index(off);
prod = prod * val;
}
acc = acc + prod;
}
}
// write result
*out = acc;
});
Ok((out, dim_of))
}
}

View File

@@ -940,6 +940,22 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}
/// remove axes that have dimension 1 (returns an error for `Instance` tensors)
pub fn remove_trivial_axes(&mut self) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
} => {
*v = v.remove_trivial_axes()?;
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(TensorError::WrongMethod);
}
};
Ok(())
}
/// Takes a slice of the tensor along a given axis
///
/// # Arguments

View File

@@ -1,3 +1,4 @@
use halo2_proofs::plonk::SecondPhase;
use log::{debug, error, warn};
use crate::circuit::{region::ConstantsMap, CheckMode};
@@ -152,6 +153,52 @@ impl VarTensor {
}
}
/// Creates a new VarTensor::Advice in `SecondPhase` with standard (blinded) columns,
/// used when the values need to be hidden in the proof.
///
/// # Arguments
/// * `cs` - The constraint system to create columns in
/// * `logrows` - Log base 2 of the total number of rows
/// * `num_inner_cols` - Number of columns in each inner block
/// * `capacity` - Total number of advice cells to allocate
///
/// # Returns
/// A new VarTensor::Advice in SecondPhase with blinded columns enabled for equality constraints
pub fn new_advice_in_second_phase<F: PrimeField>(
cs: &mut ConstraintSystem<F>,
logrows: usize,
num_inner_cols: usize,
capacity: usize,
) -> Self {
let max_rows = Self::max_rows(cs, logrows);
let max_assignments = max_rows * num_inner_cols;
let mut modulo = (capacity / max_assignments) + 1;
// we add a buffer for duplicated rows (we get at most 1 duplicated row per column)
modulo = ((capacity + modulo) / max_assignments) + 1;
let mut advices = vec![];
if modulo > 1 {
debug!("using column duplication for {} advice blocks", modulo - 1);
}
for _ in 0..modulo {
let mut inner = vec![];
for _ in 0..num_inner_cols {
let col = cs.advice_column_in(SecondPhase);
cs.enable_equality(col);
inner.push(col);
}
advices.push(inner);
}
VarTensor::Advice {
inner: advices,
num_inner_cols,
col_size: max_rows,
}
}
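// A minimal usage sketch (illustrative values; `cs` is a ConstraintSystem<Fr>):
// let advice = VarTensor::new_advice_in_second_phase::<Fr>(&mut cs, 17, 2, 1 << 16);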
/// Initializes fixed columns in the constraint system to support the VarTensor::Advice
/// Fixed columns are used for constant values that are known at circuit creation time.
///
@@ -270,7 +317,7 @@ impl VarTensor {
/// # Returns
/// A tuple of (block_index, column_index, row_index)
pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize, usize) {
// x indexes over blocks of size num_inner_cols
// x (block idx) indexes over blocks of size num_inner_cols
let x = linear_coord / self.block_size();
// y indexes over the cols inside a block
let y = linear_coord % self.num_inner_cols();
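// worked example (assuming block_size() = num_inner_cols * col_size): with
// num_inner_cols = 2 and col_size = 4, linear_coord = 11 -> block x = 1, inner col y = 1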
@@ -519,7 +566,7 @@ impl VarTensor {
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
&self,
row: usize,
_row: usize,
offset: usize,
values: &ValTensor<F>,
single_inner_col: bool,
@@ -545,7 +592,7 @@ impl VarTensor {
self.num_inner_cols()
};
let duplication_offset = if single_inner_col { row } else { offset };
let (_, _, duplication_offset) = self.cartesian_coord(offset);
// duplicates every nth element to adjust for column overflow
let mut res: ValTensor<F> = v
@@ -651,7 +698,7 @@ impl VarTensor {
>(
&self,
region: &mut Region<F>,
row: usize,
_row: usize,
offset: usize,
values: &ValTensor<F>,
check_mode: &CheckMode,
@@ -669,7 +716,7 @@ impl VarTensor {
ValTensor::Value { inner: v, dims, .. } => {
let duplication_freq = self.col_size();
let num_repeats = 1;
let duplication_offset = row;
let (_, _, duplication_offset) = self.cartesian_coord(offset);
// duplicates every nth element to adjust for column overflow
let v = v

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

Some files were not shown because too many files have changed in this diff