Compare commits

..

28 Commits

Author SHA1 Message Date
dante
1a75963705 refactor: DataSource enum -> struct 2025-04-29 12:18:24 -04:00
dante
0ef1f35e59 fix: uniffi bindings 2025-04-29 12:12:56 -04:00
dante
808ab7d0de chore: feature-gate eth (#978) 2025-04-29 12:02:23 -04:00
dante
68b2c96b97 Merge branch 'vka-hashing' of https://github.com/zkonduit/ezkl into vka-hashing 2025-04-29 11:31:20 -04:00
dante
9a0ab22fdb fix matches 2025-04-29 11:31:13 -04:00
dante
f2b1de3740 Merge branch 'main' into vka-hashing 2025-04-29 11:26:04 -04:00
Ethan
dcb888ff1e fix wasm package graph data import error 2025-04-28 16:29:09 -05:00
Ethan
26f465e70c bring back zizmor analysis 2025-04-28 08:21:26 -05:00
Ethan
8eef53213d rmv data attestation 2025-04-27 19:36:54 -05:00
Ethan
a1345966d7 configure Git credentials more persistently 2025-04-27 18:09:37 -05:00
Ethan
640061c850 set git config after action checkouts 2025-04-27 17:48:30 -05:00
Ethan
da7db7d88d use git config local instead of global 2025-04-27 17:20:24 -05:00
Ethan
a55f75ff3f rmv debug statement on token 2025-04-24 11:19:19 -05:00
Ethan
bf6f704827 debug token 2025-04-24 10:56:27 -05:00
Ethan
0dbfdf4672 debug token 2025-04-24 10:54:56 -05:00
Ethan
98299356a6 *fix syntax error on yaml 2025-04-24 10:51:24 -05:00
Ethan
04805d2a91 move token env to job level 2025-04-24 10:42:35 -05:00
Ethan
ca18cf29bb set token as global env var 2025-04-24 10:36:37 -05:00
Ethan
78f8e23b55 use verification ezkl token 2025-04-24 10:26:15 -05:00
Ethan
7d40926082 activate git fetch with cli on runner 2025-04-24 09:53:20 -05:00
Ethan
e2c8182871 *update python bindings 2025-04-24 09:43:55 -05:00
Ethan
4f077c9134 *use https for loading h2 sol verifier crate 2025-04-23 21:57:07 -05:00
Ethan
038805ce02 Merge branch 'main' into vka-hashing 2025-04-23 21:32:56 -05:00
Ethan
0fb87c9a20 *update lock 2025-04-23 21:30:43 -05:00
Ethan
77423a6d07 *check that on-chain rescaled instances match what is stored in proof file. 2025-04-23 21:25:35 -05:00
Ethan
8b416c7a00 *comment out swift package test 2025-04-21 04:31:51 -05:00
Ethan
73ec5e549a *temporarily disable zizmor + swift package on ci. 2025-04-21 04:27:36 -05:00
Ethan
28386d8442 vka hashing + rescaling 2025-04-21 04:13:31 -05:00
142 changed files with 9086 additions and 46844 deletions

View File

@@ -15,7 +15,7 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -32,7 +32,7 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -49,7 +49,7 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -66,7 +66,7 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -83,7 +83,7 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -100,7 +100,7 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -117,7 +117,7 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -134,7 +134,7 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -151,7 +151,7 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -168,7 +168,7 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -185,7 +185,7 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true

View File

@@ -18,32 +18,29 @@ jobs:
permissions:
contents: read
packages: write
id-token: write # Required for provenance
name: publish-wasm-bindings
env:
RELEASE_TAG: ${{ github.ref_name }}
RUSTFLAGS: "-C target-feature=+atomics,+bulk-memory"
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2025-05-01
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
cache: false
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
with:
# Pin to version 0.12.1
version: "v0.12.1"
version: 'v0.12.1'
- name: Add wasm32-unknown-unknown target
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2025-05-01-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
- name: Install binaryen
run: |
set -e
@@ -52,45 +49,45 @@ jobs:
wasm-opt --version
- name: Build wasm files for both web and nodejs compilation targets
run: |
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
- name: Create package.json in pkg folder
shell: bash
run: |
cat > pkg/package.json << EOF
{
"name": "@ezkljs/engine",
"version": "$RELEASE_TAG",
"dependencies": {
"@types/json-bigint": "^1.0.1",
"json-bigint": "^1.0.0"
},
"files": [
"nodejs/ezkl_bg.wasm",
"nodejs/ezkl.js",
"nodejs/ezkl.d.ts",
"nodejs/package.json",
"nodejs/utils.js",
"web/ezkl_bg.wasm",
"web/ezkl.js",
"web/ezkl.d.ts",
"web/snippets/**/*",
"web/package.json",
"web/utils.js",
"ezkl.d.ts"
],
"main": "nodejs/ezkl.js",
"module": "web/ezkl.js",
"types": "nodejs/ezkl.d.ts",
"sideEffects": [
"web/snippets/*"
]
}
EOF
cat > pkg/package.json << EOF
{
"name": "@ezkljs/engine",
"version": "$RELEASE_TAG",
"dependencies": {
"@types/json-bigint": "^1.0.1",
"json-bigint": "^1.0.0"
},
"files": [
"nodejs/ezkl_bg.wasm",
"nodejs/ezkl.js",
"nodejs/ezkl.d.ts",
"nodejs/package.json",
"nodejs/utils.js",
"web/ezkl_bg.wasm",
"web/ezkl.js",
"web/ezkl.d.ts",
"web/snippets/**/*",
"web/package.json",
"web/utils.js",
"ezkl.d.ts"
],
"main": "nodejs/ezkl.js",
"module": "web/ezkl.js",
"types": "nodejs/ezkl.d.ts",
"sideEffects": [
"web/snippets/*"
]
}
EOF
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:21,maximum:65536,shared:true})}|" pkg/nodejs/ezkl.js
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" pkg/nodejs/ezkl.js
- name: Replace `import.meta.url` with `import.meta.resolve` definition in workerHelpers.js
run: |
@@ -178,17 +175,76 @@ jobs:
run: |
curl -s "https://raw.githubusercontent.com/zkonduit/ezkljs-engine/main/README.md" > ./pkg/README.md
# zizmor: ignore cache-poisoning
- name: Set up Node.js
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
package-manager-cache: false
- name: Publish to npm with provenance
- name: Publish to npm
run: |
cd pkg
npm publish --provenance --access public
npm install
npm ci
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
in-browser-evm-ver-publish:
permissions:
contents: read
packages: write
name: publish-in-browser-evm-verifier-package
needs: [publish-wasm-bindings]
runs-on: ubuntu-latest
env:
RELEASE_TAG: ${{ github.ref_name }}
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Update version in package.json
shell: bash
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"$RELEASE_TAG\"|" in-browser-evm-verifier/package.json
- name: Prepare tag and fetch package integrity
run: |
CLEANED_TAG=${RELEASE_TAG} # Get the tag from ref_name
CLEANED_TAG="${CLEANED_TAG#v}" # Remove leading 'v'
echo "CLEANED_TAG=${CLEANED_TAG}" >> $GITHUB_ENV # Set it as an environment variable for later steps
ENGINE_INTEGRITY=$(npm view @ezkljs/engine@$CLEANED_TAG dist.integrity)
echo "ENGINE_INTEGRITY=$ENGINE_INTEGRITY" >> $GITHUB_ENV
- name: Update @ezkljs/engine version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"$CLEANED_TAG\"|" in-browser-evm-verifier/package.json
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
run: |
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
- name: Update pnpm-lock.yaml versions and integrity
run: |
awk -v integrity="$ENGINE_INTEGRITY" -v tag="$CLEANED_TAG" '
NR==30{$0=" specifier: \"" tag "\""}
NR==31{$0=" version: \"" tag "\""}
NR==400{$0=" /@ezkljs/engine@" tag ":"}
NR==401{$0=" resolution: {integrity: \"" integrity "\"}"} 1' in-browser-evm-verifier/pnpm-lock.yaml > temp.yaml && mv temp.yaml in-browser-evm-verifier/pnpm-lock.yaml
- name: Use pnpm 8
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
with:
version: 8
- name: Set up Node.js
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
- name: Publish to npm
run: |
cd in-browser-evm-verifier
pnpm install --frozen-lockfile
pnpm run build
pnpm publish --no-git-checks
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

View File

@@ -13,9 +13,9 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2025-05-01
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- name: nanoGPT Mock

View File

@@ -27,8 +27,6 @@ jobs:
target: [x86_64]
env:
RELEASE_TAG: ${{ github.ref_name }}
RUSTFLAGS: "-C linker=gcc"
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
@@ -38,16 +36,6 @@ jobs:
python-version: 3.12
architecture: x64
- name: Install build dependencies
run: |
sudo apt-get update
sudo apt-get install -y build-essential g++ gcc cmake libclang-dev llvm-dev libstdc++-12-dev libc6 libc6-dev libssl-dev pkg-config
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- name: Set pyproject.toml version to match github tag and rename ezkl to ezkl-gpu
shell: bash
run: |
@@ -55,12 +43,11 @@ jobs:
sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig > pyproject.toml.tmp
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.tmp > pyproject.toml
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
components: rustfmt, clippy
cache: false
- name: Set Cargo.toml version to match github tag and rename ezkl to ezkl-gpu
shell: bash
@@ -83,7 +70,7 @@ jobs:
target: ${{ matrix.target }}
manylinux: auto
container: off
args: --release --out dist --features python-bindings,gpu-accelerated
args: --release --out dist --features python-bindings,icicle
- name: Install built wheel
if: matrix.target == 'x86_64'

View File

@@ -48,12 +48,11 @@ jobs:
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2025-05-01
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
cache: false
- name: Build wheels
if: matrix.target == 'universal2-apple-darwin'
@@ -114,12 +113,11 @@ jobs:
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2025-05-01
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
cache: false
- name: Build wheels
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0

View File

@@ -26,6 +26,7 @@ jobs:
shell: bash
run: |
echo "EZKL_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
echo "version is: ${{ env.EZKL_VERSION }}"
- name: Create Github Release
id: create-release
@@ -47,32 +48,23 @@ jobs:
TARGET_DIR: ./target
RUST_BACKTRACE: 1
PCRE2_SYS_STATIC: 1
RUSTFLAGS: "-C linker=gcc"
OPENSSL_NO_VENDOR: 1
steps:
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2025-05-01
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
cache: false
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Install build dependencies
run: |
sudo apt-get update
sudo apt-get install -y build-essential g++ gcc cmake libclang-dev llvm-dev libstdc++-12-dev libc6 libc6-dev libssl-dev pkg-config
- name: Force rebuild icicle dependencies
run: cargo clean -p icicle-runtime -p icicle-core -p icicle-hash -p icicle-bn254
- name: Get release version from tag
shell: bash
run: |
echo "EZKL_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
echo "version is: ${{ env.EZKL_VERSION }}"
- name: Set Cargo.toml version to match github tag
shell: bash
@@ -88,7 +80,7 @@ jobs:
sudo apt-get update
- name: Build release binary
run: cargo build --release -Z sparse-registry --features gpu-accelerated
run: cargo build --release -Z sparse-registry --features icicle
- name: Build archive
shell: bash
@@ -127,27 +119,27 @@ jobs:
include:
- build: windows-msvc
os: windows-latest
rust: nightly-2025-05-01
rust: nightly-2025-02-17
target: x86_64-pc-windows-msvc
- build: macos
os: macos-13
rust: nightly-2025-05-01
rust: nightly-2025-02-17
target: x86_64-apple-darwin
- build: macos-aarch64
os: macos-13
rust: nightly-2025-05-01
rust: nightly-2025-02-17
target: aarch64-apple-darwin
- build: linux-musl
os: ubuntu-22.04
rust: nightly-2025-05-01
rust: nightly-2025-02-17
target: x86_64-unknown-linux-musl
- build: linux-gnu
os: ubuntu-22.04
rust: nightly-2025-05-01
rust: nightly-2025-02-17
target: x86_64-unknown-linux-gnu
- build: linux-aarch64
os: ubuntu-22.04
rust: nightly-2025-05-01
rust: nightly-2025-02-17
target: aarch64-unknown-linux-gnu
steps:
@@ -160,6 +152,7 @@ jobs:
shell: bash
run: |
echo "EZKL_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
echo "version is: ${{ env.EZKL_VERSION }}"
- name: Set Cargo.toml version to match github tag
shell: bash
@@ -205,15 +198,15 @@ jobs:
- name: Build release binary (no asm or metal)
if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features mimalloc
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
- name: Build release binary (asm)
if: matrix.build == 'linux-gnu'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features mimalloc
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
- name: Build release binary (metal)
if: matrix.build == 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal,mimalloc
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal
- name: Strip release binary
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'

File diff suppressed because it is too large Load Diff

View File

@@ -15,9 +15,9 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rust-lang/setup-rust-toolchain@fb51252c7ba57d633bc668f941da052e410add48 #v1.0.6
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2025-05-01
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy

2
.gitignore vendored
View File

@@ -52,5 +52,3 @@ docs/python/build
!tests/assets/vk_aggr.key
cache
out
!tests/assets/wasm.code
!tests/assets/wasm.sol

469
Cargo.lock generated
View File

@@ -126,27 +126,6 @@ dependencies = [
"winnow 0.6.26",
]
[[package]]
name = "alloy-eip2930"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0069cf0642457f87a01a014f6dc29d5d893cd4fd8fddf0c3cdfad1bb3ebafc41"
dependencies = [
"alloy-primitives 0.8.25",
"alloy-rlp",
]
[[package]]
name = "alloy-eip7702"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea59dc42102bc9a1905dc57901edc6dd48b9f38115df86c7d252acba70d71d04"
dependencies = [
"alloy-primitives 0.8.25",
"alloy-rlp",
"k256",
]
[[package]]
name = "alloy-eips"
version = "0.1.0"
@@ -238,7 +217,7 @@ dependencies = [
"bytes",
"cfg-if",
"const-hex",
"derive_more 0.99.20",
"derive_more",
"hex-literal",
"itoa",
"ruint",
@@ -255,7 +234,7 @@ dependencies = [
"bytes",
"cfg-if",
"const-hex",
"derive_more 0.99.20",
"derive_more",
"getrandom 0.2.16",
"hex-literal",
"itoa",
@@ -268,25 +247,6 @@ dependencies = [
"tiny-keccak",
]
[[package]]
name = "alloy-primitives"
version = "0.8.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c77490fe91a0ce933a1f219029521f20fc28c2c0ca95d53fa4da9c00b8d9d4e"
dependencies = [
"alloy-rlp",
"bytes",
"cfg-if",
"const-hex",
"derive_more 2.0.1",
"hashbrown 0.15.2",
"itoa",
"k256",
"paste",
"ruint",
"tiny-keccak",
]
[[package]]
name = "alloy-provider"
version = "0.1.0"
@@ -882,13 +842,14 @@ dependencies = [
]
[[package]]
name = "aurora-engine-modexp"
version = "1.2.0"
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "518bc5745a6264b5fd7b09dffb9667e400ee9e2bbe18555fac75e1fe9afa0df9"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hex",
"num",
"hermit-abi 0.1.19",
"libc",
"winapi",
]
[[package]]
@@ -959,6 +920,29 @@ dependencies = [
"serde",
]
[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags 2.9.0",
"cexpr",
"clang-sys",
"itertools 0.11.0",
"lazy_static",
"lazycell",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.101",
"which",
]
[[package]]
name = "bit-set"
version = "0.5.3"
@@ -1177,6 +1161,15 @@ dependencies = [
"shlex",
]
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]
[[package]]
name = "cfg-if"
version = "1.0.0"
@@ -1227,7 +1220,29 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
dependencies = [
"ciborium-io",
"half",
"half 2.6.0",
]
[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]]
name = "clap"
version = "2.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c"
dependencies = [
"bitflags 1.3.2",
"textwrap 0.11.0",
"unicode-width 0.1.14",
]
[[package]]
@@ -1258,7 +1273,7 @@ version = "4.5.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c06f5378ea264ad4f82bbc826628b5aad714a75abf6ece087e923010eb937fb6"
dependencies = [
"clap",
"clap 4.5.37",
]
[[package]]
@@ -1301,7 +1316,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c"
dependencies = [
"lazy_static",
"windows-sys 0.59.0",
"windows-sys 0.48.0",
]
[[package]]
@@ -1445,6 +1460,32 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "criterion"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f"
dependencies = [
"atty",
"cast",
"clap 2.34.0",
"criterion-plot 0.4.5",
"csv",
"itertools 0.10.5",
"lazy_static",
"num-traits",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_cbor",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]
[[package]]
name = "criterion"
version = "0.5.1"
@@ -1454,8 +1495,8 @@ dependencies = [
"anes",
"cast",
"ciborium",
"clap",
"criterion-plot",
"clap 4.5.37",
"criterion-plot 0.5.0",
"is-terminal",
"itertools 0.10.5",
"num-traits",
@@ -1471,6 +1512,16 @@ dependencies = [
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876"
dependencies = [
"cast",
"itertools 0.10.5",
]
[[package]]
name = "criterion-plot"
version = "0.5.0"
@@ -1543,6 +1594,27 @@ dependencies = [
"typenum",
]
[[package]]
name = "csv"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf"
dependencies = [
"csv-core",
"itoa",
"ryu",
"serde",
]
[[package]]
name = "csv-core"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d"
dependencies = [
"memchr",
]
[[package]]
name = "dashmap"
version = "5.5.3"
@@ -1621,27 +1693,6 @@ dependencies = [
"syn 2.0.101",
]
[[package]]
name = "derive_more"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678"
dependencies = [
"derive_more-impl",
]
[[package]]
name = "derive_more-impl"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.101",
"unicode-xid",
]
[[package]]
name = "digest"
version = "0.9.0"
@@ -1888,7 +1939,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e"
dependencies = [
"libc",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -1943,12 +1994,12 @@ dependencies = [
"bincode",
"camino",
"chrono",
"clap",
"clap 4.5.37",
"clap_complete",
"colored",
"colored_json",
"console_error_panic_hook",
"criterion",
"criterion 0.5.1",
"ecc",
"env_logger 0.10.2",
"ethabi",
@@ -1960,11 +2011,9 @@ dependencies = [
"halo2_solidity_verifier",
"halo2curves 0.7.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
"hex",
"icicle-runtime",
"indicatif",
"instant",
"itertools 0.10.5",
"jemallocator",
"lazy_static",
"log",
"maybe-rayon",
@@ -2207,7 +2256,7 @@ version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8640e34b88f7652208ce9e88b1a37a2ae95227d84abec377ccd3c5cfeb141ed4"
dependencies = [
"rustix",
"rustix 1.0.5",
"windows-sys 0.59.0",
]
@@ -2400,6 +2449,12 @@ dependencies = [
"subtle",
]
[[package]]
name = "half"
version = "1.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"
[[package]]
name = "half"
version = "2.6.0"
@@ -2414,7 +2469,7 @@ dependencies = [
[[package]]
name = "halo2_gadgets"
version = "0.2.0"
source = "git+https://github.com/zkonduit/halo2?branch=ac%2Fconditional-compilation-icicle2#1dd2090741f006fd031a07da7f3c9dfce5e0015e"
source = "git+https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d"
dependencies = [
"arrayvec 0.7.6",
"bitvec",
@@ -2431,7 +2486,7 @@ dependencies = [
[[package]]
name = "halo2_proofs"
version = "0.3.0"
source = "git+https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e?branch=ac%2Fconditional-compilation-icicle2#01c88842679b4308e43ae5ed91c4183e861669bd"
source = "git+https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d"
dependencies = [
"bincode",
"blake2b_simd",
@@ -2441,7 +2496,7 @@ dependencies = [
"halo2curves 0.7.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=b753a832e92d5c86c5c997327a9cf9de86a18851)",
"icicle-bn254",
"icicle-core",
"icicle-runtime",
"icicle-cuda-runtime",
"instant",
"lazy_static",
"log",
@@ -2449,7 +2504,7 @@ dependencies = [
"mopro-msm",
"rand_chacha 0.3.1",
"rand_core 0.6.4",
"rustc-hash",
"rustc-hash 2.1.1",
"serde",
"sha3 0.9.1",
"tracing",
@@ -2458,7 +2513,7 @@ dependencies = [
[[package]]
name = "halo2_solidity_verifier"
version = "0.1.0"
source = "git+https://github.com/zkonduit/ezkl-verifier?branch=main#a518a917f076adb851a1ae39e09527f8dbde5000"
source = "git+https://github.com/zkonduit/verification-ezkl?branch=vka-hash#409f977e461b435b9afc33ed38edba09fe2eaee4"
dependencies = [
"askama",
"blake2b_simd",
@@ -2466,7 +2521,6 @@ dependencies = [
"hex",
"itertools 0.11.0",
"regex",
"revm 14.0.3",
"ruint",
"sha3 0.10.8",
]
@@ -2664,6 +2718,15 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]
[[package]]
name = "hermit-abi"
version = "0.3.9"
@@ -2855,45 +2918,33 @@ dependencies = [
[[package]]
name = "icicle-bn254"
version = "3.7.0"
source = "git+https://github.com/ingonyama-zk/icicle?branch=emir%2Fgate_eval_2#012e00694f4cf399fe7a42d9cfbfa6cd7a60f876"
version = "2.8.0"
source = "git+https://github.com/ingonyama-zk/icicle?branch=ezkl-icicle2#5dfe006a0f1bc62ea82ca297709bbf3d22a2ca25"
dependencies = [
"cmake",
"criterion 0.3.6",
"icicle-core",
"icicle-hash",
"icicle-runtime",
"icicle-cuda-runtime",
]
[[package]]
name = "icicle-core"
version = "3.7.0"
source = "git+https://github.com/ingonyama-zk/icicle?branch=emir%2Fgate_eval_2#012e00694f4cf399fe7a42d9cfbfa6cd7a60f876"
version = "2.8.0"
source = "git+https://github.com/ingonyama-zk/icicle?branch=ezkl-icicle2#5dfe006a0f1bc62ea82ca297709bbf3d22a2ca25"
dependencies = [
"criterion 0.3.6",
"hex",
"icicle-runtime",
"once_cell",
"rand 0.8.5",
"icicle-cuda-runtime",
"rayon",
]
[[package]]
name = "icicle-hash"
version = "3.7.0"
source = "git+https://github.com/ingonyama-zk/icicle?branch=emir%2Fgate_eval_2#012e00694f4cf399fe7a42d9cfbfa6cd7a60f876"
name = "icicle-cuda-runtime"
version = "2.8.0"
source = "git+https://github.com/ingonyama-zk/icicle?branch=ezkl-icicle2#5dfe006a0f1bc62ea82ca297709bbf3d22a2ca25"
dependencies = [
"cmake",
"icicle-core",
"icicle-runtime",
"rand 0.8.5",
]
[[package]]
name = "icicle-runtime"
version = "3.7.0"
source = "git+https://github.com/ingonyama-zk/icicle?branch=emir%2Fgate_eval_2#012e00694f4cf399fe7a42d9cfbfa6cd7a60f876"
dependencies = [
"cmake",
"once_cell",
"bindgen",
"bitflags 1.3.2",
]
[[package]]
@@ -3151,7 +3202,7 @@ checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9"
dependencies = [
"hermit-abi 0.5.0",
"libc",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -3211,26 +3262,6 @@ version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "jemalloc-sys"
version = "0.5.4+5.3.0-patched"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac6c1946e1cea1788cbfde01c993b52a10e2da07f4bac608228d1bed20bfebf2"
dependencies = [
"cc",
"libc",
]
[[package]]
name = "jemallocator"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0de374a9f8e63150e6f5e8a60cc14c668226d7a347d8aee1a45766e3c4dd3bc"
dependencies = [
"jemalloc-sys",
"libc",
]
[[package]]
name = "jiff"
version = "0.2.10"
@@ -3346,12 +3377,28 @@ dependencies = [
"spin",
]
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "libc"
version = "0.2.172"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa"
[[package]]
name = "libloading"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if",
"windows-targets 0.48.5",
]
[[package]]
name = "libm"
version = "0.2.13"
@@ -3379,6 +3426,12 @@ dependencies = [
"redox_syscall",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
[[package]]
name = "linux-raw-sys"
version = "0.9.4"
@@ -3817,7 +3870,7 @@ dependencies = [
"num-traits",
"pyo3",
"pyo3-build-config",
"rustc-hash",
"rustc-hash 2.1.1",
]
[[package]]
@@ -4248,6 +4301,16 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
[[package]]
name = "prettyplease"
version = "0.2.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "664ec5419c51e34154eec046ebcba56312d5a2fc3b09a06da188e1ad21afadf6"
dependencies = [
"proc-macro2",
"syn 2.0.101",
]
[[package]]
name = "primal-check"
version = "0.3.4"
@@ -4500,7 +4563,7 @@ dependencies = [
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustc-hash 2.1.1",
"rustls",
"socket2",
"thiserror 2.0.12",
@@ -4519,7 +4582,7 @@ dependencies = [
"getrandom 0.3.2",
"rand 0.9.1",
"ring",
"rustc-hash",
"rustc-hash 2.1.1",
"rustls",
"rustls-pki-types",
"slab",
@@ -4540,7 +4603,7 @@ dependencies = [
"once_cell",
"socket2",
"tracing",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -4835,21 +4898,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68f4ca8ae0345104523b4af1a8a7ea97cfa1865cdb7a7c25d23c1a18d9b48598"
dependencies = [
"auto_impl",
"revm-interpreter 1.3.0",
"revm-precompile 2.2.0",
]
[[package]]
name = "revm"
version = "14.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "641702b12847f9ed418d552f4fcabe536d867a2c980e96b6e7e25d7b992f929f"
dependencies = [
"auto_impl",
"cfg-if",
"dyn-clone",
"revm-interpreter 10.0.3",
"revm-precompile 11.0.3",
"revm-interpreter",
"revm-precompile",
]
[[package]]
@@ -4858,16 +4908,7 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f959cafdf64a7f89b014fa73dc2325001cf654b3d9400260b212d19a2ebe3da0"
dependencies = [
"revm-primitives 1.3.0",
]
[[package]]
name = "revm-interpreter"
version = "10.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e5e14002afae20b5bf1566f22316122f42f57517000e559c55b25bf7a49cba2"
dependencies = [
"revm-primitives 10.0.0",
"revm-primitives",
]
[[package]]
@@ -4879,23 +4920,7 @@ dependencies = [
"k256",
"num",
"once_cell",
"revm-primitives 1.3.0",
"ripemd",
"sha2",
"substrate-bn",
]
[[package]]
name = "revm-precompile"
version = "11.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3198c06247e8d4ad0d1312591edf049b0de4ddffa9fecb625c318fd67db8639b"
dependencies = [
"aurora-engine-modexp",
"cfg-if",
"k256",
"once_cell",
"revm-primitives 10.0.0",
"revm-primitives",
"ripemd",
"sha2",
"substrate-bn",
@@ -4917,24 +4942,6 @@ dependencies = [
"hex",
]
[[package]]
name = "revm-primitives"
version = "10.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f1525851a03aff9a9d6a1d018b414d76252d6802ab54695b27093ecd7e7a101"
dependencies = [
"alloy-eip2930",
"alloy-eip7702",
"alloy-primitives 0.8.25",
"auto_impl",
"bitflags 2.9.0",
"bitvec",
"cfg-if",
"dyn-clone",
"enumn",
"hex",
]
[[package]]
name = "rfc6979"
version = "0.4.0"
@@ -5017,6 +5024,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc-hash"
version = "2.1.1"
@@ -5062,6 +5075,19 @@ dependencies = [
"version_check",
]
[[package]]
name = "rustix"
version = "0.38.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
dependencies = [
"bitflags 2.9.0",
"errno",
"libc",
"linux-raw-sys 0.4.15",
"windows-sys 0.52.0",
]
[[package]]
name = "rustix"
version = "1.0.5"
@@ -5071,8 +5097,8 @@ dependencies = [
"bitflags 2.9.0",
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.59.0",
"linux-raw-sys 0.9.4",
"windows-sys 0.52.0",
]
[[package]]
@@ -5319,6 +5345,16 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_cbor"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5"
dependencies = [
"half 1.8.3",
"serde",
]
[[package]]
name = "serde_derive"
version = "1.0.219"
@@ -5499,7 +5535,7 @@ dependencies = [
"num-traits",
"poseidon",
"rand 0.8.5",
"revm 3.5.0",
"revm",
"serde",
"sha3 0.10.8",
]
@@ -5765,8 +5801,8 @@ dependencies = [
"fastrand",
"getrandom 0.3.2",
"once_cell",
"rustix",
"windows-sys 0.59.0",
"rustix 1.0.5",
"windows-sys 0.52.0",
]
[[package]]
@@ -5811,6 +5847,15 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "textwrap"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
dependencies = [
"unicode-width 0.1.14",
]
[[package]]
name = "textwrap"
version = "0.16.2"
@@ -6204,7 +6249,7 @@ dependencies = [
"downcast-rs",
"dyn-clone",
"dyn-hash",
"half",
"half 2.6.0",
"itertools 0.12.1",
"lazy_static",
"maplit",
@@ -6239,7 +6284,7 @@ dependencies = [
"downcast-rs",
"dyn-clone",
"dyn-hash",
"half",
"half 2.6.0",
"lazy_static",
"liquid",
"liquid-core",
@@ -6420,7 +6465,7 @@ dependencies = [
"once_cell",
"paste",
"serde",
"textwrap",
"textwrap 0.16.2",
"toml 0.5.11",
"uniffi_meta",
"uniffi_testing",
@@ -6512,7 +6557,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cef408229a3a407fafa4c36dc4f6ece78a6fb258ab28d2b64bddd49c8cb680f6"
dependencies = [
"anyhow",
"textwrap",
"textwrap 0.16.2",
"uniffi_meta",
"uniffi_testing",
"weedle2",
@@ -6852,6 +6897,18 @@ dependencies = [
"nom",
]
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix 0.38.44",
]
[[package]]
name = "winapi"
version = "0.3.9"
@@ -6874,7 +6931,7 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.48.0",
]
[[package]]
@@ -7229,7 +7286,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d65cbf2f12c15564212d48f4e3dfb87923d25d611f2aed18f4cb23f0413d89e"
dependencies = [
"libc",
"rustix",
"rustix 1.0.5",
]
[[package]]

View File

@@ -16,12 +16,12 @@ crate-type = ["cdylib", "rlib", "staticlib"]
[dependencies]
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "ac/conditional-compilation-icicle2" }
halo2_gadgets = { git = "https://github.com/zkonduit/halo2" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
"derive_serde",
] }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", branch= "ac/conditional-compilation-icicle2", features = [
"circuit-params", "mv-lookup"
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", features = [
"circuit-params",
] }
rand = { version = "0.8", default-features = false }
itertools = { version = "0.10.3", default-features = false }
@@ -33,11 +33,9 @@ thiserror = { version = "1.0.38", default-features = false }
hex = { version = "0.4.3", default-features = false }
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
"derive_serde", "mv-lookup"
] }
halo2_solidity_verifier = { git = "https://github.com/zkonduit/ezkl-verifier", branch = "main", optional = true, features = [
"evm", "mv-lookup",
"derive_serde",
] }
halo2_solidity_verifier = { git = "https://github.com/zkonduit/verification-ezkl", branch = "vka-hash", optional = true }
maybe-rayon = { version = "0.1.1", default-features = false }
bincode = { version = "1.3.3", default-features = false }
unzip-n = "0.1.2"
@@ -45,12 +43,10 @@ num = "0.4.1"
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand", optional = true }
semver = { version = "1.0.22", optional = true }
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
# evm related deps
serde_json = { version = "1.0.97", features = ["float_roundtrip", "raw_value"] }
# evm related deps
alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev = "5fbf57bac99edef9d8475190109a7ea9fb7e5e83", features = [
"provider-http",
"signers",
@@ -60,7 +56,6 @@ alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev = "5
"node-bindings",
], optional = true }
foundry-compilers = { version = "0.4.1", features = [
"svm-solc",
], optional = true }
@@ -93,9 +88,9 @@ pyo3-log = { version = "0.12.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
objc = { version = "0.2.4", optional = true }
pyo3-stub-gen = { version = "0.6.0", optional = true }
jemallocator = { version = "0.5", optional = true }
mimalloc = { version = "0.1", optional = true }
pyo3-stub-gen = { version = "0.6.0", optional = true }
# universal bindings
uniffi = { version = "=0.28.0", optional = true }
getrandom = { version = "0.2.8", optional = true }
@@ -103,10 +98,6 @@ uniffi_bindgen = { version = "=0.28.0", optional = true }
camino = { version = "^1.1", optional = true }
uuid = { version = "1.10.0", features = ["v4"], optional = true }
# GPU / device related things (optional - only enabled with gpu-accelerated feature)
icicle-runtime = { git = "https://github.com/ingonyama-zk/icicle", branch="emir/gate_eval_2", package="icicle-runtime", optional = true }
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
colored = { version = "2.0.0", default-features = false, optional = true }
env_logger = { version = "0.10.0", default-features = false, optional = true }
@@ -226,28 +217,20 @@ required-features = ["python-bindings"]
[features]
web = ["wasm-bindgen-rayon"]
default = [
"eth",
"dep:halo2_solidity_verifier",
"eth-mv-lookup",
"ezkl",
"precompute-coset",
"no-banner",
"parallel-poly-read",
"reusable-verifier",
]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-async-runtimes", "pyo3-stub-gen"]
universal-bindings = [
"uniffi",
"mv-lookup",
"precompute-coset",
"parallel-poly-read",
"dep:halo2_solidity_verifier"
]
logging = ["dep:colored", "dep:env_logger", "dep:chrono"]
ios-bindings = ["universal-bindings"]
ios-bindings = ["eth-mv-lookup", "precompute-coset", "parallel-poly-read", "uniffi"]
ios-bindings-test = ["ios-bindings", "uniffi/bindgen-tests"]
ezkl = [
"onnx",
"dep:colored",
"dep:env_logger",
"tabled/color",
"serde_json/std",
"colored_json",
@@ -258,50 +241,57 @@ ezkl = [
"dep:lazy_static",
"dep:tokio",
"dep:openssl",
"dep:mimalloc",
"dep:chrono",
"dep:sha256",
"dep:clap_complete",
"dep:semver",
"dep:clap",
"dep:tosubcommand",
"logging",
]
eth = ["dep:alloy", "dep:foundry-compilers", "dep:ethabi"]
eth = [
"dep:alloy",
"dep:foundry-compilers",
"dep:ethabi",
]
solidity-verifier = [
"dep:halo2_solidity_verifier",
]
solidity-verifier-mv-lookup = [
"halo2_solidity_verifier/mv-lookup",
]
eth-mv-lookup = [
"solidity-verifier-mv-lookup",
"mv-lookup",
"eth",
]
eth-original-lookup = [
"eth",
"solidity-verifier",
]
parallel-poly-read = [
"halo2_proofs/circuit-params",
"halo2_proofs/parallel-poly-read",
]
mv-lookup = ["halo2_proofs/mv-lookup", "snark-verifier/mv-lookup"]
mv-lookup = [
"halo2_proofs/mv-lookup",
"snark-verifier/mv-lookup",
]
asm = ["halo2curves/asm", "halo2_proofs/asm"]
precompute-coset = ["halo2_proofs/precompute-coset"]
det-prove = []
gpu-accelerated = ["halo2_proofs/gpu-accelerated", "dep:icicle-runtime"]
icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
no-banner = []
no-update = []
macos-metal = ["halo2_proofs/macos"]
ios-metal = ["halo2_proofs/ios"]
jemalloc = ["dep:jemallocator"]
mimalloc = ["dep:mimalloc"]
reusable-verifier = []
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e", package = "halo2_proofs", branch= "ac/conditional-compilation-icicle2", features = [
"circuit-params", "mv-lookup"
] }
[patch.'https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#1dd2090741f006fd031a07da7f3c9dfce5e0015e", package = "halo2_proofs", branch= "ac/conditional-compilation-icicle2", features = [
"circuit-params", "mv-lookup"
] }
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
[profile.release]
# debug = true
rustflags = ["-C", "relocation-model=pic"]
lto = "fat"
codegen-units = 1
@@ -314,5 +304,3 @@ opt-level = 3
[package.metadata.wasm-pack.profile.release]
wasm-opt = ["-O4", "--flexible-inline-max-function-size", "4294967295"]

View File

@@ -76,6 +76,11 @@ For more details visit the [docs](https://docs.ezkl.xyz). The CLI is faster than
Build the auto-generated rust documentation and open the docs in your browser locally. `cargo doc --open`
#### In-browser EVM Verifier
As an alternative to running the native Halo2 verifier as a WASM binding in the browser, you can use the in-browser EVM verifier. The source code of which you can find in the `in-browser-evm-verifier` directory and a README with instructions on how to use it.
### Building the Project 🔨
#### Rust CLI

312
abis/DataAttestation.json Normal file
View File

@@ -0,0 +1,312 @@
[
{
"inputs": [
{
"internalType": "address",
"name": "_contractAddresses",
"type": "address"
},
{
"internalType": "bytes",
"name": "_callData",
"type": "bytes"
},
{
"internalType": "uint256[]",
"name": "_decimals",
"type": "uint256[]"
},
{
"internalType": "uint256[]",
"name": "_bits",
"type": "uint256[]"
},
{
"internalType": "uint8",
"name": "_instanceOffset",
"type": "uint8"
}
],
"stateMutability": "nonpayable",
"type": "constructor"
},
{
"inputs": [],
"name": "HALF_ORDER",
"outputs": [
{
"internalType": "uint256",
"name": "",
"type": "uint256"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "ORDER",
"outputs": [
{
"internalType": "uint256",
"name": "",
"type": "uint256"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "uint256[]",
"name": "instances",
"type": "uint256[]"
}
],
"name": "attestData",
"outputs": [],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "callData",
"outputs": [
{
"internalType": "bytes",
"name": "",
"type": "bytes"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "contractAddress",
"outputs": [
{
"internalType": "address",
"name": "",
"type": "address"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "bytes",
"name": "encoded",
"type": "bytes"
}
],
"name": "getInstancesCalldata",
"outputs": [
{
"internalType": "uint256[]",
"name": "instances",
"type": "uint256[]"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "bytes",
"name": "encoded",
"type": "bytes"
}
],
"name": "getInstancesMemory",
"outputs": [
{
"internalType": "uint256[]",
"name": "instances",
"type": "uint256[]"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "uint256",
"name": "index",
"type": "uint256"
}
],
"name": "getScalars",
"outputs": [
{
"components": [
{
"internalType": "uint256",
"name": "decimals",
"type": "uint256"
},
{
"internalType": "uint256",
"name": "bits",
"type": "uint256"
}
],
"internalType": "struct DataAttestation.Scalars",
"name": "",
"type": "tuple"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "instanceOffset",
"outputs": [
{
"internalType": "uint8",
"name": "",
"type": "uint8"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "uint256",
"name": "x",
"type": "uint256"
},
{
"internalType": "uint256",
"name": "y",
"type": "uint256"
},
{
"internalType": "uint256",
"name": "denominator",
"type": "uint256"
}
],
"name": "mulDiv",
"outputs": [
{
"internalType": "uint256",
"name": "result",
"type": "uint256"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "int256",
"name": "x",
"type": "int256"
},
{
"components": [
{
"internalType": "uint256",
"name": "decimals",
"type": "uint256"
},
{
"internalType": "uint256",
"name": "bits",
"type": "uint256"
}
],
"internalType": "struct DataAttestation.Scalars",
"name": "_scalars",
"type": "tuple"
}
],
"name": "quantizeData",
"outputs": [
{
"internalType": "int256",
"name": "quantized_data",
"type": "int256"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "target",
"type": "address"
},
{
"internalType": "bytes",
"name": "data",
"type": "bytes"
}
],
"name": "staticCall",
"outputs": [
{
"internalType": "bytes",
"name": "",
"type": "bytes"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "int256",
"name": "x",
"type": "int256"
}
],
"name": "toFieldElement",
"outputs": [
{
"internalType": "uint256",
"name": "field_element",
"type": "uint256"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "verifier",
"type": "address"
},
{
"internalType": "bytes",
"name": "encoded",
"type": "bytes"
}
],
"name": "verifyWithDataAttestation",
"outputs": [
{
"internalType": "bool",
"name": "",
"type": "bool"
}
],
"stateMutability": "view",
"type": "function"
}
]

98
abis/QuantizeData.json Normal file
View File

@@ -0,0 +1,98 @@
[
{
"inputs": [
{
"internalType": "int256[]",
"name": "quantized_data",
"type": "int256[]"
}
],
"name": "check_is_valid_field_element",
"outputs": [
{
"internalType": "uint256[]",
"name": "output",
"type": "uint256[]"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "bytes[]",
"name": "data",
"type": "bytes[]"
},
{
"internalType": "uint256[]",
"name": "decimals",
"type": "uint256[]"
},
{
"internalType": "uint256[]",
"name": "scales",
"type": "uint256[]"
}
],
"name": "quantize_data_multi",
"outputs": [
{
"internalType": "int256[]",
"name": "quantized_data",
"type": "int256[]"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "bytes",
"name": "data",
"type": "bytes"
},
{
"internalType": "uint256",
"name": "decimals",
"type": "uint256"
},
{
"internalType": "uint256[]",
"name": "scales",
"type": "uint256[]"
}
],
"name": "quantize_data_single",
"outputs": [
{
"internalType": "int256[]",
"name": "quantized_data",
"type": "int256[]"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "int64[]",
"name": "quantized_data",
"type": "int64[]"
}
],
"name": "to_field_element",
"outputs": [
{
"internalType": "uint256[]",
"name": "output",
"type": "uint256[]"
}
],
"stateMutability": "pure",
"type": "function"
}
]

32
abis/TestReads.json Normal file
View File

@@ -0,0 +1,32 @@
[
{
"inputs": [
{
"internalType": "int256[]",
"name": "_numbers",
"type": "int256[]"
}
],
"stateMutability": "nonpayable",
"type": "constructor"
},
{
"inputs": [
{
"internalType": "uint256",
"name": "",
"type": "uint256"
}
],
"name": "arr",
"outputs": [
{
"internalType": "int256",
"name": "",
"type": "int256"
}
],
"stateMutability": "view",
"type": "function"
}
]

View File

@@ -68,7 +68,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&[&self.image, &self.kernel, &self.bias],
&[self.image.clone(), self.kernel.clone(), self.bias.clone()],
Box::new(PolyOp::Conv {
padding: vec![(0, 0)],
stride: vec![1; 2],

View File

@@ -15,7 +15,6 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -60,7 +59,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs,
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),

View File

@@ -1,78 +1,52 @@
use criterion::{
criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion, PlotConfiguration,
Throughput,
};
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::einsum::circuit_params::SingleEinsumParams;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::srs::gen_srs;
use ezkl::pfsys::{create_keys, create_proof_circuit, TranscriptType};
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, Value},
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::collections::HashMap;
use std::marker::PhantomData;
static mut LEN: usize = 4;
static mut K: usize = 15;
const K: usize = 16;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 2],
einsum_params: SingleEinsumParams<F>,
struct MyCircuit {
inputs: [ValTensor<Fr>; 2],
_marker: PhantomData<Fr>,
}
impl Circuit<Fr> for MyCircuit<Fr> {
impl Circuit<Fr> for MyCircuit {
type Config = BaseConfig<Fr>;
type FloorPlanner = V1;
type Params = SingleEinsumParams<Fr>;
type FloorPlanner = SimpleFloorPlanner;
type Params = ();
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
let mut config = Self::Config::default();
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
let len = unsafe { LEN };
let mut equations = HashMap::new();
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
unsafe {
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 2, false);
}
let a = VarTensor::new_advice(cs, K, 1, len * len);
config
}
let b = VarTensor::new_advice(cs, K, 1, len * len);
fn params(&self) -> Self::Params {
SingleEinsumParams::<Fr>::new(
&self.einsum_params.equation,
&[
&self.inputs[0].get_inner().unwrap(),
&self.inputs[1].get_inner().unwrap(),
],
)
.unwrap()
}
let output = VarTensor::new_advice(cs, K, 1, (len + 1) * len);
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unimplemented!("call configure_with_params instead")
Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE)
}
fn synthesize(
@@ -80,33 +54,16 @@ impl Circuit<Fr> for MyCircuit<Fr> {
mut config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let challenges = config
.einsums
.as_ref()
.ok_or(Error::Synthesis)?
.challenges()
.unwrap()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new_with_challenges(
region,
0,
1,
1024,
2,
challenges.clone(),
);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs,
Box::new(PolyOp::Einsum {
equation: self.einsum_params.equation.clone(),
equation: "ab,bc->ac".to_string(),
}),
)
.unwrap();
@@ -119,49 +76,41 @@ impl Circuit<Fr> for MyCircuit<Fr> {
fn runmatmul(c: &mut Criterion) {
let mut group = c.benchmark_group("accum_einsum_matmul");
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Linear));
group.sampling_mode(criterion::SamplingMode::Flat);
group.sample_size(10);
let len = 128;
unsafe {
LEN = len;
}
for k in 16..17 {
let params = unsafe {
K = k;
gen_srs::<KZGCommitmentScheme<_>>(K as u32)
let params = gen_srs::<KZGCommitmentScheme<_>>(17);
for &len in [4, 32].iter() {
unsafe {
LEN = len;
};
let mut a = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
a.reshape(&[len, len]).unwrap();
// parameters
let mut b = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
b.reshape(&[len, len]).unwrap();
let einsum_params = SingleEinsumParams::<Fr>::new("ij,jk->ik", &[&a, &b]).unwrap();
let circuit = MyCircuit {
inputs: [ValTensor::from(a), ValTensor::from(b)],
einsum_params,
_marker: PhantomData,
};
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", k), &k, |b, &_| {
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit<Fr>>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit<Fr>>(&circuit, &params, false)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", k), &k, |b, &_| {
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit<Fr>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,

View File

@@ -17,7 +17,6 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -87,13 +86,13 @@ impl Circuit<Fr> for MyCircuit {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.layout(&mut region, &self.inputs, Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[&output.unwrap()],
&[output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -17,7 +17,6 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -88,13 +87,13 @@ impl Circuit<Fr> for MyCircuit {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.layout(&mut region, &self.inputs, Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[&output.unwrap()],
&[output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -15,7 +15,6 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -60,7 +59,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs,
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.unwrap();

View File

@@ -63,7 +63,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&[&self.image],
&[self.image.clone()],
Box::new(HybridOp::SumPool {
padding: vec![(0, 0); 2],
stride: vec![1, 1],

View File

@@ -15,7 +15,6 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -58,11 +57,7 @@ impl Circuit<Fr> for MyCircuit {
|region| {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.layout(&mut region, &self.inputs, Box::new(PolyOp::Add))
.unwrap();
Ok(())
},

View File

@@ -16,7 +16,6 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -59,11 +58,7 @@ impl Circuit<Fr> for MyCircuit {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Pow(4)),
)
.layout(&mut region, &self.inputs, Box::new(PolyOp::Pow(4)))
.unwrap();
Ok(())
},

View File

@@ -70,7 +70,7 @@ impl Circuit<Fr> for NLCircuit {
config
.layout(
&mut region,
&[&self.input],
&[self.input.clone()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,

View File

@@ -67,7 +67,7 @@ impl Circuit<Fr> for NLCircuit {
config
.layout(
&mut region,
&[&self.input],
&[self.input.clone()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

397
contracts/AttestData.sol Normal file
View File

@@ -0,0 +1,397 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.20;
contract LoadInstances {
/**
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
* @notice must pass encoded bytes from memory
* @param encoded - verifier calldata
*/
function getInstancesMemory(
bytes memory encoded
) public pure returns (uint256[] memory instances) {
bytes4 funcSig;
uint256 instances_offset;
uint256 instances_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := mload(add(encoded, 0x20))
}
if (funcSig == 0xaf83a18d) {
instances_offset = 0x64;
} else if (funcSig == 0x1e8e1e13) {
instances_offset = 0x44;
} else {
revert("Invalid function signature");
}
assembly {
// Fetch instances offset which is 4 + 32 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 +32 away for `verifyProof(address,bytes,uint256[])`
instances_offset := mload(add(encoded, instances_offset))
instances_length := mload(add(add(encoded, 0x24), instances_offset))
}
instances = new uint256[](instances_length); // Allocate memory for the instances array.
assembly {
// Now instances points to the start of the array data
// (right after the length field).
for {
let i := 0x20
} lt(i, add(mul(instances_length, 0x20), 0x20)) {
i := add(i, 0x20)
} {
mstore(
add(instances, i),
mload(add(add(encoded, add(i, 0x24)), instances_offset))
)
}
}
require(
funcSig == 0xaf83a18d || funcSig == 0x1e8e1e13,
"Invalid function signature"
);
}
/**
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
* @notice must pass encoded bytes from calldata
* @param encoded - verifier calldata
*/
function getInstancesCalldata(
bytes calldata encoded
) public pure returns (uint256[] memory instances) {
bytes4 funcSig;
uint256 instances_offset;
uint256 instances_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := calldataload(encoded.offset)
}
if (funcSig == 0xaf83a18d) {
instances_offset = 0x44;
} else if (funcSig == 0x1e8e1e13) {
instances_offset = 0x24;
} else {
revert("Invalid function signature");
}
// We need to create a new assembly block in order for solidity
// to cast the funcSig to a bytes4 type. Otherwise it will load the entire first 32 bytes of the calldata
// within the block
assembly {
// Fetch instances offset which is 4 + 32 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 +32 away for `verifyProof(address,bytes,uint256[])`
instances_offset := calldataload(
add(encoded.offset, instances_offset)
)
instances_length := calldataload(
add(add(encoded.offset, 0x04), instances_offset)
)
}
instances = new uint256[](instances_length); // Allocate memory for the instances array.
assembly {
// Now instances points to the start of the array data
// (right after the length field).
for {
let i := 0x20
} lt(i, add(mul(instances_length, 0x20), 0x20)) {
i := add(i, 0x20)
} {
mstore(
add(instances, i),
calldataload(
add(add(encoded.offset, add(i, 0x04)), instances_offset)
)
)
}
}
}
}
// The kzg commitments of a given model, all aggregated into a single bytes array.
// At solidity generation time, the commitments are hardcoded into the contract via the COMMITMENT_KZG constant.
// It will be used to check that the proof commitments match the expected commitments.
bytes constant COMMITMENT_KZG = hex"1234";
contract SwapProofCommitments {
/**
* @dev Swap the proof commitments
* @notice must pass encoded bytes from memory
* @param encoded - verifier calldata
*/
function checkKzgCommits(
bytes calldata encoded
) internal pure returns (bool equal) {
bytes4 funcSig;
uint256 proof_offset;
uint256 proof_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := calldataload(encoded.offset)
}
if (funcSig == 0xaf83a18d) {
proof_offset = 0x24;
} else if (funcSig == 0x1e8e1e13) {
proof_offset = 0x04;
} else {
revert("Invalid function signature");
}
assembly {
// Fetch proof offset which is 4 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 away for `verifyProof(address,bytes,uint256[])`
proof_offset := calldataload(add(encoded.offset, proof_offset))
proof_length := calldataload(
add(add(encoded.offset, 0x04), proof_offset)
)
}
// Check the length of the commitment against the proof bytes
if (proof_length < COMMITMENT_KZG.length) {
return false;
}
// Load COMMITMENT_KZG into memory
bytes memory commitment = COMMITMENT_KZG;
// Compare the first N bytes of the proof with COMMITMENT_KZG
uint words = (commitment.length + 31) / 32; // Calculate the number of 32-byte words
assembly {
// Now we compare the commitment with the proof,
// ensuring that the commitments divided up into 32 byte words are all equal.
for {
let i := 0x20
} lt(i, add(mul(words, 0x20), 0x20)) {
i := add(i, 0x20)
} {
let wordProof := calldataload(
add(add(encoded.offset, add(i, 0x04)), proof_offset)
)
let wordCommitment := mload(add(commitment, i))
equal := eq(wordProof, wordCommitment)
if eq(equal, 0) {
break
}
}
}
return equal; // Return true if the commitment comparison passed
} /// end checkKzgCommits
}
contract DataAttestation is LoadInstances, SwapProofCommitments {
// the address of the account to make calls to
address public immutable contractAddress;
// the abi encoded function calls to make to the `contractAddress` that returns the attested to data
bytes public callData;
struct Scalars {
// The number of base 10 decimals to scale the data by.
// For most ERC20 tokens this is 1e18
uint256 decimals;
// The number of fractional bits of the fixed point EZKL data points.
uint256 bits;
}
Scalars[] private scalars;
function getScalars(uint256 index) public view returns (Scalars memory) {
return scalars[index];
}
/**
* @notice EZKL P value
* @dev In order to prevent the verifier from accepting two version of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are stricly less than P. a
* @dev The reason for this is that the assmebly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
*/
uint256 public constant ORDER =
uint256(
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
);
uint256 public constant HALF_ORDER = ORDER >> 1;
uint8 public instanceOffset;
/**
* @dev Initialize the contract with account calls the EZKL model will read from.
* @param _contractAddresses - The calls to all the contracts EZKL reads storage from.
* @param _callData - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from.
*/
constructor(
address _contractAddresses,
bytes memory _callData,
uint256[] memory _decimals,
uint[] memory _bits,
uint8 _instanceOffset
) {
require(
_bits.length == _decimals.length,
"Invalid scalar array lengths"
);
for (uint i; i < _bits.length; i++) {
scalars.push(Scalars(10 ** _decimals[i], 1 << _bits[i]));
}
contractAddress = _contractAddresses;
callData = _callData;
instanceOffset = _instanceOffset;
}
function mulDiv(
uint256 x,
uint256 y,
uint256 denominator
) public pure returns (uint256 result) {
unchecked {
uint256 prod0;
uint256 prod1;
assembly {
let mm := mulmod(x, y, not(0))
prod0 := mul(x, y)
prod1 := sub(sub(mm, prod0), lt(mm, prod0))
}
if (prod1 == 0) {
return prod0 / denominator;
}
require(denominator > prod1, "Math: mulDiv overflow");
uint256 remainder;
assembly {
remainder := mulmod(x, y, denominator)
prod1 := sub(prod1, gt(remainder, prod0))
prod0 := sub(prod0, remainder)
}
uint256 twos = denominator & (~denominator + 1);
assembly {
denominator := div(denominator, twos)
prod0 := div(prod0, twos)
twos := add(div(sub(0, twos), twos), 1)
}
prod0 |= prod1 * twos;
uint256 inverse = (3 * denominator) ^ 2;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
result = prod0 * inverse;
return result;
}
}
/**
* @dev Quantize the data returned from the account calls to the scale used by the EZKL model.
* @param x - One of the elements of the data returned from the account calls
* @param _scalars - The scaling factors for the data returned from the account calls.
*
*/
function quantizeData(
int x,
Scalars memory _scalars
) public pure returns (int256 quantized_data) {
if (_scalars.bits == 1 && _scalars.decimals == 1) {
return x;
}
bool neg = x < 0;
if (neg) x = -x;
uint output = mulDiv(uint256(x), _scalars.bits, _scalars.decimals);
if (
mulmod(uint256(x), _scalars.bits, _scalars.decimals) * 2 >=
_scalars.decimals
) {
output += 1;
}
if (output > HALF_ORDER) {
revert("Overflow field modulus");
}
quantized_data = neg ? -int256(output) : int256(output);
}
/**
* @dev Make a static call to the account to fetch the data that EZKL reads from.
* @param target - The address of the account to make calls to.
* @param data - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from.
* @return The data returned from the account calls. (Must come from either a view or pure function. Will throw an error otherwise)
*/
function staticCall(
address target,
bytes memory data
) public view returns (bytes memory) {
(bool success, bytes memory returndata) = target.staticcall(data);
if (success) {
if (returndata.length == 0) {
require(
target.code.length > 0,
"Address: call to non-contract"
);
}
return returndata;
} else {
revert("Address: low-level call failed");
}
}
/**
* @dev Convert the fixed point quantized data into a field element.
* @param x - The quantized data.
* @return field_element - The field element.
*/
function toFieldElement(
int256 x
) public pure returns (uint256 field_element) {
// The casting down to uint256 is safe because the order is about 2^254, and the value
// of x ranges of -2^127 to 2^127, so x + int(ORDER) is always positive.
return uint256(x + int(ORDER)) % ORDER;
}
/**
* @dev Make the account calls to fetch the data that EZKL reads from and attest to the data.
* @param instances - The public instances to the proof (the data in the proof that publicly accessible to the verifier).
*/
function attestData(uint256[] memory instances) public view {
bytes memory returnData = staticCall(contractAddress, callData);
int256[] memory x = abi.decode(returnData, (int256[]));
int output;
uint fieldElement;
for (uint i = 0; i < x.length; i++) {
output = quantizeData(x[i], scalars[i]);
fieldElement = toFieldElement(output);
if (fieldElement != instances[i]) {
revert("Public input does not match");
}
}
}
/**
* @dev Verify the proof with the data attestation.
* @param verifier - The address of the verifier contract.
* @param encoded - The verifier calldata.
*/
function verifyWithDataAttestation(
address verifier,
bytes calldata encoded
) public view returns (bool) {
require(verifier.code.length > 0, "Address: call to non-contract");
attestData(getInstancesCalldata(encoded));
require(checkKzgCommits(encoded), "Invalid KZG commitments");
// static call the verifier contract to verify the proof
(bool success, bytes memory returndata) = verifier.staticcall(encoded);
if (success) {
return abi.decode(returndata, (bool));
} else {
revert("low-level call to verifier failed");
}
}
}

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '22.3.0'
release = '0.0.0'
version = release

View File

@@ -1,171 +0,0 @@
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::Fr;
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use std::collections::HashMap;
use std::marker::PhantomData;
const K: usize = 13;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 2],
einsum: Einsum<F>,
}
#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
equation: String,
input_axes_to_dims: HashMap<char, usize>,
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
let mut eq = equation.split("->");
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
// Check that the number of inputs matches the number of inputs in the equation
if inputs.len() != inputs_eq.len() {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
let mut input_axes_to_dims = HashMap::new();
for (i, input) in inputs.iter().enumerate() {
for j in 0..inputs_eq[i].len() {
let c = inputs_eq[i]
.chars()
.nth(j)
.ok_or(CircuitError::InvalidEinsum)?;
if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
e.insert(input.dims()[j]);
} else if input_axes_to_dims[&c] != input.dims()[j] {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
}
}
Ok(Self {
equation: equation.to_owned(),
input_axes_to_dims,
_marker: PhantomData,
})
}
}
impl Circuit<Fr> for MyCircuit<Fr> {
type Config = BaseConfig<Fr>;
type FloorPlanner = V1;
type Params = Einsum<Fr>;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
let mut config = Self::Config::default();
let mut equations = HashMap::new();
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
config
}
fn params(&self) -> Self::Params {
Einsum::<Fr>::new(
&self.einsum.equation,
&[
&self.inputs[0].get_inner().unwrap(),
&self.inputs[1].get_inner().unwrap(),
],
)
.unwrap()
}
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unimplemented!("call configure_with_params instead")
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let challenges = config
.einsums
.as_ref()
.ok_or(Error::Synthesis)?
.challenges()
.unwrap()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new_with_challenges(
region,
0,
1,
1024,
2,
challenges.clone(),
);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: self.einsum.equation.clone(),
}),
)
.unwrap();
Ok(())
},
)?;
Ok(())
}
}
fn runmatmul() {
let len = 64;
let mut a = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
a.reshape(&[len, len]).unwrap();
// parameters
let mut b = Tensor::from((0..len * len).map(|_| Value::known(Fr::random(OsRng))));
b.reshape(&[len, len]).unwrap();
let einsum = Einsum::<Fr>::new("ij,jk->ik", &[&a, &b]).unwrap();
let circuit = MyCircuit {
inputs: [ValTensor::from(a), ValTensor::from(b)],
einsum,
};
let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
mock_prover.assert_satisfied();
}
pub fn main() {
runmatmul()
}

View File

@@ -1,179 +0,0 @@
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::Fr;
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use std::collections::HashMap;
use std::marker::PhantomData;
static mut LEN: usize = 4;
const K: usize = 11;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 2],
einsum: Einsum<F>,
}
#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
equation: String,
input_axes_to_dims: HashMap<char, usize>,
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
let mut eq = equation.split("->");
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
// Check that the number of inputs matches the number of inputs in the equation
if inputs.len() != inputs_eq.len() {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
let mut input_axes_to_dims = HashMap::new();
for (i, input) in inputs.iter().enumerate() {
for j in 0..inputs_eq[i].len() {
let c = inputs_eq[i]
.chars()
.nth(j)
.ok_or(CircuitError::InvalidEinsum)?;
if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
e.insert(input.dims()[j]);
} else if input_axes_to_dims[&c] != input.dims()[j] {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
}
}
Ok(Self {
equation: equation.to_owned(),
input_axes_to_dims,
_marker: PhantomData,
})
}
}
impl Circuit<Fr> for MyCircuit<Fr> {
type Config = BaseConfig<Fr>;
type FloorPlanner = V1;
type Params = Einsum<Fr>;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
let len = unsafe { LEN };
let a = VarTensor::new_advice(cs, K, 1, len);
let b = VarTensor::new_advice(cs, K, 1, len);
let output = VarTensor::new_advice(cs, K, 1, len);
let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE);
let mut equations = HashMap::new();
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 1;
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
config
}
fn params(&self) -> Self::Params {
Einsum::<Fr>::new(
&self.einsum.equation,
&[
&self.inputs[0].get_inner().unwrap(),
&self.inputs[1].get_inner().unwrap(),
],
)
.unwrap()
}
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unimplemented!("call configure_with_params instead")
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let challenges = config
.einsums
.as_ref()
.ok_or(Error::Synthesis)?
.challenges()
.unwrap()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new_with_challenges(
region,
0,
1,
1024,
2,
challenges.clone(),
);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: self.einsum.equation.clone(),
}),
)
.unwrap();
Ok(())
},
)?;
Ok(())
}
}
fn runbatchmatmul() {
let batch_size = 5;
let len = 12;
let mut a = Tensor::from((0..batch_size * len * len).map(|_| Value::known(Fr::random(OsRng))));
a.reshape(&[batch_size, len, len]).unwrap();
// parameters
let mut b = Tensor::from((0..batch_size * len * len).map(|_| Value::known(Fr::random(OsRng))));
b.reshape(&[batch_size, len, len]).unwrap();
let einsum = Einsum::<Fr>::new("ijk,ikl->ijl", &[&a, &b]).unwrap();
let circuit = MyCircuit {
inputs: [ValTensor::from(a), ValTensor::from(b)],
einsum,
};
let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
mock_prover.assert_satisfied();
}
pub fn main() {
runbatchmatmul()
}

View File

@@ -32,6 +32,7 @@ use mnist::*;
use rand::rngs::OsRng;
use std::marker::PhantomData;
mod params;
const K: usize = 20;
@@ -215,7 +216,11 @@ where
.layer_config
.layout(
&mut region,
&[&self.input, &self.l0_params[0], &self.l0_params[1]],
&[
self.input.clone(),
self.l0_params[0].clone(),
self.l0_params[1].clone(),
],
Box::new(op),
)
.unwrap();
@@ -224,7 +229,7 @@ where
.layer_config
.layout(
&mut region,
&[&x.unwrap()],
&[x.unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -236,7 +241,7 @@ where
.layer_config
.layout(
&mut region,
&[&x.unwrap()],
&[x.unwrap()],
Box::new(LookupOp::Div { denom: 32.0.into() }),
)
.unwrap()
@@ -248,7 +253,7 @@ where
.layer_config
.layout(
&mut region,
&[&self.l2_params[0], &x],
&[self.l2_params[0].clone(), x],
Box::new(PolyOp::Einsum {
equation: "ij,j->ik".to_string(),
}),
@@ -260,7 +265,7 @@ where
.layer_config
.layout(
&mut region,
&[&x, &self.l2_params[1]],
&[x, self.l2_params[1].clone()],
Box::new(PolyOp::Add),
)
.unwrap()

View File

@@ -117,7 +117,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[&self.l0_params[0].clone().try_into().unwrap(), &self.input],
&[
self.l0_params[0].clone().try_into().unwrap(),
self.input.clone(),
],
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),
@@ -132,7 +135,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[&x, &self.l0_params[1].clone().try_into().unwrap()],
&[x, self.l0_params[1].clone().try_into().unwrap()],
Box::new(PolyOp::Add),
)
.unwrap()
@@ -144,7 +147,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[&x],
&[x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
@@ -160,7 +163,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[&self.l2_params[0].clone().try_into().unwrap(), &x],
&[self.l2_params[0].clone().try_into().unwrap(), x],
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),
@@ -175,7 +178,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[&x, &self.l2_params[1].clone().try_into().unwrap()],
&[x, self.l2_params[1].clone().try_into().unwrap()],
Box::new(PolyOp::Add),
)
.unwrap()
@@ -187,7 +190,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[&x],
&[x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
@@ -200,7 +203,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[&x.unwrap()],
&[x.unwrap()],
Box::new(LookupOp::Div {
denom: ezkl::circuit::utils::F32::from(128.),
}),

View File

@@ -866,7 +866,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
@@ -879,7 +879,6 @@
"run_args.input_visibility = \"private\"\n",
"run_args.param_visibility = \"fixed\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.disable_freivalds = True\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"\n",
"\n",
@@ -905,7 +904,7 @@
"outputs": [],
"source": [
"\n",
"res = ezkl.calibrate_settings(\"input.json\", target=\"resources\", scales = [4])\n",
"res = await ezkl.calibrate_settings(\"input.json\", target=\"resources\", scales = [4])\n",
"assert res == True\n",
"print(\"verified\")\n"
]
@@ -955,7 +954,7 @@
"source": [
"\n",
"\n",
"res = ezkl.gen_witness()\n"
"res = await ezkl.gen_witness()\n"
]
},
{

View File

@@ -150,7 +150,7 @@
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True"
]
},
@@ -170,7 +170,7 @@
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -204,7 +204,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -437,7 +437,7 @@
"\n",
"# Optimize for resources, we cap logrows at 12 to reduce setup and proving time, at the expense of accuracy\n",
"# You may want to increase the max logrows if accuracy is a concern\n",
"res = ezkl.calibrate_settings(target = \"resources\", max_logrows = 12, scales = [2])"
"res = await ezkl.calibrate_settings(target = \"resources\", max_logrows = 12, scales = [2])"
]
},
{
@@ -526,7 +526,7 @@
"# now generate the witness file\n",
"witness_path = os.path.join('witness.json')\n",
"\n",
"res = ezkl.gen_witness()\n",
"res = await ezkl.gen_witness()\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -736,4 +736,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -467,7 +467,7 @@
"outputs": [],
"source": [
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True"
]
},
@@ -508,7 +508,7 @@
"source": [
"# now generate the witness file\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -196,7 +196,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -237,7 +237,7 @@
"source": [
"# now generate the witness file\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -341,4 +341,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -179,7 +179,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -214,7 +214,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -241,7 +241,7 @@
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -291,7 +291,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
]
},
{
@@ -510,4 +510,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -152,7 +152,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -188,7 +188,7 @@
"# now generate the witness file \n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -155,7 +155,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -190,7 +190,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -233,7 +233,7 @@
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -315,7 +315,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n"
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n"
]
},
{
@@ -512,4 +512,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -193,7 +193,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -228,7 +228,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -1,284 +1,284 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## Linear Regression\n",
"\n",
"\n",
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
"This notebook showcases how to do so using the `hummingbird-ml` python package ! "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import os\n",
"import torch\n",
"import ezkl\n",
"import json\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
"# y = 1 * x_0 + 2 * x_1 + 3\n",
"y = np.dot(X, np.array([1, 2])) + 3\n",
"reg = LinearRegression().fit(X, y)\n",
"reg.score(X, y)\n",
"\n",
"circuit = convert(reg, \"torch\", X[:1]).model\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = X.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=True)\n",
"torch_out = circuit(x)\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
" \"network.onnx\",\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names=['input'], # the model's input names\n",
" output_names=['output'], # the model's output names\n",
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n",
"\n",
"\n",
"# note that you can also call the following function to generate random data for the model\n",
"# it is functionally equivalent to the code above\n",
"ezkl.gen_random_data()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
"cells": [
{
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## Linear Regression\n",
"\n",
"\n",
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
"This notebook showcases how to do so using the `hummingbird-ml` python package ! "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import os\n",
"import torch\n",
"import ezkl\n",
"import json\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
"# y = 1 * x_0 + 2 * x_1 + 3\n",
"y = np.dot(X, np.array([1, 2])) + 3\n",
"reg = LinearRegression().fit(X, y)\n",
"reg.score(X, y)\n",
"\n",
"circuit = convert(reg, \"torch\", X[:1]).model\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = X.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=True)\n",
"torch_out = circuit(x)\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
" \"network.onnx\",\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names=['input'], # the model's input names\n",
" output_names=['output'], # the model's output names\n",
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n",
"\n",
"\n",
"# note that you can also call the following function to generate random data for the model\n",
"# it is functionally equivalent to the code above\n",
"ezkl.gen_random_data()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -347,7 +347,7 @@
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -383,7 +383,7 @@
"# now generate the witness file \n",
"witness_path = \"gan_witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -142,7 +142,7 @@
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -177,7 +177,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -276,4 +276,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -139,7 +139,7 @@
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
@@ -193,7 +193,7 @@
"# now generate the witness file \n",
"witness_path = \"lstmwitness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -323,7 +323,7 @@
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[2,7])\n",
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[2,7])\n",
"assert res == True"
]
},
@@ -362,7 +362,7 @@
"# now generate the witness file\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -558,4 +558,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

View File

@@ -289,7 +289,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[0,6])"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[0,6])"
]
},
{
@@ -321,7 +321,7 @@
"# now generate the witness file \n",
"witness_path = \"gan_witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -425,4 +425,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -341,7 +341,7 @@
"\n",
" # generate settings for the current model\n",
" res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
" res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", scales=[run_args.input_scale], max_logrows=run_args.logrows)\n",
" res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", scales=[run_args.input_scale], max_logrows=run_args.logrows)\n",
" assert res == True\n",
"\n",
" # load settings and print them to the console\n",
@@ -361,7 +361,7 @@
" assert res == True\n",
" assert os.path.isfile(vk_path)\n",
" assert os.path.isfile(pk_path)\n",
" res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
" res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
" run_args.input_scale = settings[\"model_output_scales\"][0]\n",
"\n",
"for i in range(3):\n",
@@ -484,4 +484,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -215,7 +215,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -247,7 +247,7 @@
"# now generate the witness file\n",
"witness_path = \"ae_witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -451,7 +451,7 @@
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n",
"print(\"verified\")"
]
@@ -485,7 +485,7 @@
"# now generate the witness file \n",
"witness_path = \"vae_witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -590,4 +590,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -845,7 +845,7 @@
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", max_logrows = 20, scales = [3])\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", max_logrows = 20, scales = [3])\n",
"assert res == True"
]
},
@@ -881,7 +881,7 @@
},
"outputs": [],
"source": [
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

File diff suppressed because it is too large Load Diff

View File

@@ -282,7 +282,7 @@
"\n",
" # generate settings for the current model\n",
" res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
" res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", scales=[run_args.input_scale], max_logrows=run_args.logrows)\n",
" res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", scales=[run_args.input_scale], max_logrows=run_args.logrows)\n",
" assert res == True\n",
"\n",
" # load settings and print them to the console\n",
@@ -303,7 +303,7 @@
" assert os.path.isfile(vk_path)\n",
" assert os.path.isfile(pk_path)\n",
"\n",
" res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
" res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
" run_args.input_scale = settings[\"model_output_scales\"][0]\n",
"\n",
"for i in range(2):\n",
@@ -472,4 +472,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -176,7 +176,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -210,7 +210,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -309,4 +309,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -1,336 +1,331 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reusable Verifiers \n",
"\n",
"TODO: Update the reusable verifier solidity contract name.. Make it less generic to H2 and more bespoke to us.\n",
"\n",
"This notebook demonstrates how to create and reuse the same set of separated verifiers for different models. Specifically, we will use the same verifier for the following four models:\n",
"\n",
"- `1l_mlp sigmoid`\n",
"- `1l_mlp relu`\n",
"- `1l_conv sigmoid`\n",
"- `1l_conv relu`\n",
"\n",
"When deploying EZKL verifiers on the blockchain, each associated model typically requires its own unique verifier, leading to increased on-chain state usage. \n",
"However, with the reusable verifier, we can deploy a single verifier that can be used to verify proofs for any valid H2 circuit. This notebook shows how to do so. \n",
"\n",
"By reusing the same verifier across multiple models, we significantly reduce the amount of state bloat on the blockchain. Instead of deploying a unique verifier for each model, we register a unique and much smaller verifying key artifact (VKA) on the reusable verifier contract for each model while sharing a common separated verifier. The VKA contains the VK for the model as well circuit specific metadata that was otherwise hardcoded into the stack of the original non-reusable verifier. The VKA is passed as a parameter to the verifyProof method. This VKA calldata needs to be d with the reusable verifier before it can start verifying proofs by calling the registerVKA method. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.onnx\n",
"\n",
"# Define the models\n",
"class MLP_Sigmoid(nn.Module):\n",
" def __init__(self):\n",
" super(MLP_Sigmoid, self).__init__()\n",
" self.fc = nn.Linear(3, 3)\n",
" self.sigmoid = nn.Sigmoid()\n",
"\n",
" def forward(self, x):\n",
" x = self.fc(x)\n",
" x = self.sigmoid(x)\n",
" return x\n",
"\n",
"class MLP_Relu(nn.Module):\n",
" def __init__(self):\n",
" super(MLP_Relu, self).__init__()\n",
" self.fc = nn.Linear(3, 3)\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, x):\n",
" x = self.fc(x)\n",
" x = self.relu(x)\n",
" return x\n",
"\n",
"class Conv_Sigmoid(nn.Module):\n",
" def __init__(self):\n",
" super(Conv_Sigmoid, self).__init__()\n",
" self.conv = nn.Conv1d(1, 1, kernel_size=3, stride=1)\n",
" self.sigmoid = nn.Sigmoid()\n",
"\n",
" def forward(self, x):\n",
" x = self.conv(x)\n",
" x = self.sigmoid(x)\n",
" return x\n",
"\n",
"class Conv_Relu(nn.Module):\n",
" def __init__(self):\n",
" super(Conv_Relu, self).__init__()\n",
" self.conv = nn.Conv1d(1, 1, kernel_size=3, stride=1)\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, x):\n",
" x = self.conv(x)\n",
" x = self.relu(x)\n",
" return x\n",
"\n",
"# Instantiate the models\n",
"mlp_sigmoid = MLP_Sigmoid()\n",
"mlp_relu = MLP_Relu()\n",
"conv_sigmoid = Conv_Sigmoid()\n",
"conv_relu = Conv_Relu()\n",
"\n",
"# Dummy input tensor for mlp\n",
"dummy_input_mlp = torch.tensor([[-1.5737053155899048, -1.708398461341858, 0.19544155895709991]])\n",
"input_mlp_path = 'mlp_input.json'\n",
"\n",
"# Dummy input tensor for conv\n",
"dummy_input_conv = torch.tensor([[[1.4124163389205933, 0.6938204169273376, 1.0664031505584717]]])\n",
"input_conv_path = 'conv_input.json'"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"names = ['mlp_sigmoid', 'mlp_relu', 'conv_sigmoid', 'conv_relu']\n",
"models = [mlp_sigmoid, mlp_relu, conv_sigmoid, conv_relu]\n",
"inputs = [dummy_input_mlp, dummy_input_mlp, dummy_input_conv, dummy_input_conv]\n",
"input_paths = [input_mlp_path, input_mlp_path, input_conv_path, input_conv_path]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import torch\n",
"import ezkl\n",
"\n",
"for name, model, x, input_path in zip(names, models, inputs, input_paths):\n",
" # Create a new directory for the model if it doesn't exist\n",
" if not os.path.exists(name):\n",
" os.mkdir(name)\n",
" # Store the paths in each of their respective directories\n",
" model_path = os.path.join(name, \"network.onnx\")\n",
" compiled_model_path = os.path.join(name, \"network.compiled\")\n",
" pk_path = os.path.join(name, \"test.pk\")\n",
" vk_path = os.path.join(name, \"test.vk\")\n",
" settings_path = os.path.join(name, \"settings.json\")\n",
"\n",
" witness_path = os.path.join(name, \"witness.json\")\n",
" sol_code_path = os.path.join(name, 'test.sol')\n",
" vka_path = os.path.join(name, 'vka.bytes')\n",
" abi_path = os.path.join(name, 'test.abi')\n",
" proof_path = os.path.join(name, \"proof.json\")\n",
"\n",
" # Flips the neural net into inference mode\n",
" model.eval()\n",
"\n",
" # Export the model\n",
" torch.onnx.export(model, x, model_path, export_params=True, opset_version=10,\n",
" do_constant_folding=True, input_names=['input'],\n",
" output_names=['output'], dynamic_axes={'input': {0: 'batch_size'},\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
" data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
" data = dict(input_data=[data_array])\n",
" json.dump(data, open(input_path, 'w'))\n",
"\n",
" py_run_args = ezkl.PyRunArgs()\n",
" py_run_args.input_visibility = \"private\"\n",
" py_run_args.output_visibility = \"public\"\n",
" py_run_args.param_visibility = \"fixed\" # private by default\n",
"\n",
" res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)\n",
" assert res == True\n",
"\n",
" ezkl.calibrate_settings(input_path, model_path, settings_path, \"resources\")\n",
"\n",
" res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
" assert res == True\n",
"\n",
" res = await ezkl.get_srs(settings_path)\n",
" assert res == True\n",
"\n",
" # now generate the witness file\n",
" res = ezkl.gen_witness(input_path, compiled_model_path, witness_path)\n",
" assert os.path.isfile(witness_path) == True\n",
"\n",
" # SETUP \n",
" # We recommend disabling selector compression for the setup as it decreases the size of the VK artifact\n",
" res = ezkl.setup(compiled_model_path, vk_path, pk_path, disable_selector_compression=True)\n",
" assert res == True\n",
" assert os.path.isfile(vk_path)\n",
" assert os.path.isfile(pk_path)\n",
" assert os.path.isfile(settings_path)\n",
"\n",
" # GENERATE A PROOF\n",
" res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, \"single\")\n",
" assert os.path.isfile(proof_path)\n",
"\n",
" res = await ezkl.create_evm_verifier(vk_path, settings_path, sol_code_path, abi_path, reusable=True)\n",
" # TODO: Add a flag force equals true to in the deprication process to preserve OG single purpose verifier?\n",
" assert res == True\n",
"\n",
" # TODO: \n",
" res = await ezkl.create_evm_vka(vk_path, settings_path, vka_path, decimals=18)\n",
" assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import time\n",
"\n",
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"RPC_URL = \"http://localhost:3030\"\n",
"\n",
"# Save process globally\n",
"anvil_process = None\n",
"\n",
"def start_anvil():\n",
" global anvil_process\n",
" if anvil_process is None:\n",
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--code-size-limit=41943040\"])\n",
" if anvil_process.returncode is not None:\n",
" raise Exception(\"failed to start anvil process\")\n",
" time.sleep(3)\n",
"\n",
"def stop_anvil():\n",
" global anvil_process\n",
" if anvil_process is not None:\n",
" anvil_process.terminate()\n",
" anvil_process = None\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check that the generated verifiers are identical for all models."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import filecmp\n",
"\n",
"def compare_files(file1, file2):\n",
" return filecmp.cmp(file1, file2, shallow=False)\n",
"\n",
"sol_code_path_0 = os.path.join(\"mlp_sigmoid\", 'test.sol')\n",
"sol_code_path_1 = os.path.join(\"mlp_relu\", 'test.sol')\n",
"\n",
"sol_code_path_2 = os.path.join(\"conv_sigmoid\", 'test.sol')\n",
"sol_code_path_3 = os.path.join(\"conv_relu\", 'test.sol')\n",
"\n",
"\n",
"assert compare_files(sol_code_path_0, sol_code_path_1) == True\n",
"assert compare_files(sol_code_path_2, sol_code_path_3) == True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we deploy reusable verifier that will be shared by the four models. We picked the `1l_mlp sigmoid` model as an example but you could have used any of the generated verifiers since they are all identical. "
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"import os \n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"sol_code_path = os.path.join(\"mlp_sigmoid\", 'test.sol')\n",
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" \"verifier/reusable\" # TODO deprecate this option for selecting the type of verifier you want to deploy. \n",
" # verifier, verifier/reusable, vka\n",
")\n",
"\n",
"assert res == True\n",
"\n",
"with open(addr_path_verifier, 'r') as file:\n",
" addr = file.read().rstrip()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally we deploy each of the unique VK-artifacts and verify them using the shared verifier deployed in the previous step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for name in names:\n",
" addr_path_vk = \"addr_vk.txt\"\n",
" vka_path = os.path.join(name, 'vka.bytes')\n",
" res = await ezkl.register_vka(\n",
" addr, # address of the reusable verifier. TODO: If we deploy the RV across all chains to a single canoncial address, we can hardcode that address and remove this param.\n",
" 'http://127.0.0.1:3030',\n",
" vka_path=vka_path, # TODO: Pass in private key and potentially create new command that both creates and registers the vka. Simplify testing pipeline for us and other folks. \n",
" )\n",
" assert res == True\n",
" \n",
" proof_path = os.path.join(name, \"proof.json\")\n",
" res = await ezkl.verify_evm(\n",
" addr,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" vka_path = vka_path # TODO: Turn this from optional to required if we deprecate the orignal verifier. \n",
" # TODO: Make it where the use only needs to deply a vka. \n",
" )\n",
" assert res == True"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reusable Verifiers \n",
"\n",
"This notebook demonstrates how to create and reuse the same set of separated verifiers for different models. Specifically, we will use the same verifier for the following four models:\n",
"\n",
"- `1l_mlp sigmoid`\n",
"- `1l_mlp relu`\n",
"- `1l_conv sigmoid`\n",
"- `1l_conv relu`\n",
"\n",
"When deploying EZKL verifiers on the blockchain, each associated model typically requires its own unique verifier, leading to increased on-chain state usage. \n",
"However, with the reusable verifier, we can deploy a single verifier that can be used to verify proofs for any valid H2 circuit. This notebook shows how to do so. \n",
"\n",
"By reusing the same verifier across multiple models, we significantly reduce the amount of state bloat on the blockchain. Instead of deploying a unique verifier for each model, we deploy a unique and much smaller verifying key artifact (VKA) contract for each model while sharing a common separated verifier. The VKA contains the VK for the model as well circuit specific metadata that was otherwise hardcoded into the stack of the original non-reusable verifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.onnx\n",
"\n",
"# Define the models\n",
"class MLP_Sigmoid(nn.Module):\n",
" def __init__(self):\n",
" super(MLP_Sigmoid, self).__init__()\n",
" self.fc = nn.Linear(3, 3)\n",
" self.sigmoid = nn.Sigmoid()\n",
"\n",
" def forward(self, x):\n",
" x = self.fc(x)\n",
" x = self.sigmoid(x)\n",
" return x\n",
"\n",
"class MLP_Relu(nn.Module):\n",
" def __init__(self):\n",
" super(MLP_Relu, self).__init__()\n",
" self.fc = nn.Linear(3, 3)\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, x):\n",
" x = self.fc(x)\n",
" x = self.relu(x)\n",
" return x\n",
"\n",
"class Conv_Sigmoid(nn.Module):\n",
" def __init__(self):\n",
" super(Conv_Sigmoid, self).__init__()\n",
" self.conv = nn.Conv1d(1, 1, kernel_size=3, stride=1)\n",
" self.sigmoid = nn.Sigmoid()\n",
"\n",
" def forward(self, x):\n",
" x = self.conv(x)\n",
" x = self.sigmoid(x)\n",
" return x\n",
"\n",
"class Conv_Relu(nn.Module):\n",
" def __init__(self):\n",
" super(Conv_Relu, self).__init__()\n",
" self.conv = nn.Conv1d(1, 1, kernel_size=3, stride=1)\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, x):\n",
" x = self.conv(x)\n",
" x = self.relu(x)\n",
" return x\n",
"\n",
"# Instantiate the models\n",
"mlp_sigmoid = MLP_Sigmoid()\n",
"mlp_relu = MLP_Relu()\n",
"conv_sigmoid = Conv_Sigmoid()\n",
"conv_relu = Conv_Relu()\n",
"\n",
"# Dummy input tensor for mlp\n",
"dummy_input_mlp = torch.tensor([[-1.5737053155899048, -1.708398461341858, 0.19544155895709991]])\n",
"input_mlp_path = 'mlp_input.json'\n",
"\n",
"# Dummy input tensor for conv\n",
"dummy_input_conv = torch.tensor([[[1.4124163389205933, 0.6938204169273376, 1.0664031505584717]]])\n",
"input_conv_path = 'conv_input.json'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"names = ['mlp_sigmoid', 'mlp_relu', 'conv_sigmoid', 'conv_relu']\n",
"models = [mlp_sigmoid, mlp_relu, conv_sigmoid, conv_relu]\n",
"inputs = [dummy_input_mlp, dummy_input_mlp, dummy_input_conv, dummy_input_conv]\n",
"input_paths = [input_mlp_path, input_mlp_path, input_conv_path, input_conv_path]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import torch\n",
"import ezkl\n",
"\n",
"for name, model, x, input_path in zip(names, models, inputs, input_paths):\n",
" # Create a new directory for the model if it doesn't exist\n",
" if not os.path.exists(name):\n",
" os.mkdir(name)\n",
" # Store the paths in each of their respective directories\n",
" model_path = os.path.join(name, \"network.onnx\")\n",
" compiled_model_path = os.path.join(name, \"network.compiled\")\n",
" pk_path = os.path.join(name, \"test.pk\")\n",
" vk_path = os.path.join(name, \"test.vk\")\n",
" settings_path = os.path.join(name, \"settings.json\")\n",
"\n",
" witness_path = os.path.join(name, \"witness.json\")\n",
" sol_code_path = os.path.join(name, 'test.sol')\n",
" vka_path = os.path.join(name, 'vka.bytes')\n",
" abi_path = os.path.join(name, 'test.abi')\n",
" proof_path = os.path.join(name, \"proof.json\")\n",
"\n",
" # Flips the neural net into inference mode\n",
" model.eval()\n",
"\n",
" # Export the model\n",
" torch.onnx.export(model, x, model_path, export_params=True, opset_version=10,\n",
" do_constant_folding=True, input_names=['input'],\n",
" output_names=['output'], dynamic_axes={'input': {0: 'batch_size'},\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
" data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
" data = dict(input_data=[data_array])\n",
" json.dump(data, open(input_path, 'w'))\n",
"\n",
" py_run_args = ezkl.PyRunArgs()\n",
" py_run_args.input_visibility = \"private\"\n",
" py_run_args.output_visibility = \"public\"\n",
" py_run_args.param_visibility = \"fixed\" # private by default\n",
"\n",
" res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)\n",
" assert res == True\n",
"\n",
" await ezkl.calibrate_settings(input_path, model_path, settings_path, \"resources\")\n",
"\n",
" res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
" assert res == True\n",
"\n",
" res = await ezkl.get_srs(settings_path)\n",
" assert res == True\n",
"\n",
" # now generate the witness file\n",
" res = await ezkl.gen_witness(input_path, compiled_model_path, witness_path)\n",
" assert os.path.isfile(witness_path) == True\n",
"\n",
" # SETUP \n",
" # We recommend disabling selector compression for the setup as it decreases the size of the VK artifact\n",
" res = ezkl.setup(compiled_model_path, vk_path, pk_path, disable_selector_compression=True)\n",
" assert res == True\n",
" assert os.path.isfile(vk_path)\n",
" assert os.path.isfile(pk_path)\n",
" assert os.path.isfile(settings_path)\n",
"\n",
" # GENERATE A PROOF\n",
" res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, \"single\")\n",
" assert os.path.isfile(proof_path)\n",
"\n",
" res = await ezkl.create_evm_verifier(vk_path, settings_path, sol_code_path, abi_path, reusable=True)\n",
" assert res == True\n",
"\n",
" res = await ezkl.create_evm_vka(vk_path, settings_path, vka_path)\n",
" assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import time\n",
"\n",
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"RPC_URL = \"http://localhost:3030\"\n",
"\n",
"# Save process globally\n",
"anvil_process = None\n",
"\n",
"def start_anvil():\n",
" global anvil_process\n",
" if anvil_process is None:\n",
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--code-size-limit=41943040\"])\n",
" if anvil_process.returncode is not None:\n",
" raise Exception(\"failed to start anvil process\")\n",
" time.sleep(3)\n",
"\n",
"def stop_anvil():\n",
" global anvil_process\n",
" if anvil_process is not None:\n",
" anvil_process.terminate()\n",
" anvil_process = None\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check that the generated verifiers are identical for all models."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import filecmp\n",
"\n",
"def compare_files(file1, file2):\n",
" return filecmp.cmp(file1, file2, shallow=False)\n",
"\n",
"sol_code_path_0 = os.path.join(\"mlp_sigmoid\", 'test.sol')\n",
"sol_code_path_1 = os.path.join(\"mlp_relu\", 'test.sol')\n",
"\n",
"sol_code_path_2 = os.path.join(\"conv_sigmoid\", 'test.sol')\n",
"sol_code_path_3 = os.path.join(\"conv_relu\", 'test.sol')\n",
"\n",
"\n",
"assert compare_files(sol_code_path_0, sol_code_path_1) == True\n",
"assert compare_files(sol_code_path_2, sol_code_path_3) == True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we deploy separate verifier that will be shared by the four models. We picked the `1l_mlp sigmoid` model as an example but you could have used any of the generated verifiers since they are all identical. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os \n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"sol_code_path = os.path.join(\"mlp_sigmoid\", 'test.sol')\n",
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" \"verifier/reusable\"\n",
")\n",
"\n",
"assert res == True\n",
"\n",
"with open(addr_path_verifier, 'r') as file:\n",
" addr = file.read().rstrip()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally we deploy each of the unique VK-artifacts and verify them using the shared verifier deployed in the previous step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for name in names:\n",
" addr_path_vk = \"addr_vk.txt\"\n",
" vka_path = os.path.join(name, 'vka.bytes')\n",
" res = await ezkl.register_vka(\n",
" addr,\n",
" 'http://127.0.0.1:3030',\n",
" vka_path=vka_path,\n",
" )\n",
" assert res == True\n",
" \n",
" proof_path = os.path.join(name, \"proof.json\")\n",
" sol_code_path = os.path.join(name, 'vk.sol')\n",
" res = await ezkl.verify_evm(\n",
" addr,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" vka_path = vka_path\n",
" )\n",
" assert res == True"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -231,7 +231,7 @@
"source": [
"# now generate the witness file\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -267,7 +267,7 @@
" # Serialize data into file:\n",
"json.dump( data, open(data_path_faulty, 'w' ))\n",
"\n",
"res = ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path_faulty)\n",
"res = await ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path_faulty)\n",
"assert os.path.isfile(witness_path_faulty)"
]
},
@@ -312,7 +312,7 @@
"# Serialize data into file:\n",
"json.dump( data, open(data_path_truthy, 'w' ))\n",
"\n",
"res = ezkl.gen_witness(data_path_truthy, compiled_model_path, witness_path_truthy)\n",
"res = await ezkl.gen_witness(data_path_truthy, compiled_model_path, witness_path_truthy)\n",
"assert os.path.isfile(witness_path_truthy)"
]
},

View File

@@ -171,7 +171,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -205,7 +205,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -404,4 +404,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -171,7 +171,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -205,7 +205,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -304,4 +304,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -169,7 +169,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -203,7 +203,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -302,4 +302,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -170,7 +170,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -204,7 +204,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -303,4 +303,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -149,7 +149,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -183,7 +183,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -298,7 +298,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
"assert os.path.isfile(witness_path)\n",
"\n",
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",
@@ -412,7 +412,7 @@
"source": [
"# now generate the witness file\n",
"\n",
"res = ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path, vk_path)\n",
"res = await ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path, vk_path)\n",
"assert os.path.isfile(witness_path)\n",
"\n",
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",

View File

@@ -167,7 +167,7 @@
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True"
]
},
@@ -187,7 +187,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -221,7 +221,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

View File

@@ -152,7 +152,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -186,7 +186,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -392,7 +392,7 @@
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True"
]
}
@@ -418,4 +418,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -637,7 +637,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [11])"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [11])"
]
},
{
@@ -683,7 +683,7 @@
" data = json.load(f)\n",
" print(len(data['input_data'][0]))\n",
"\n",
"ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
"await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
]
},
{
@@ -758,4 +758,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}

View File

@@ -525,7 +525,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
]
},
{
@@ -572,7 +572,7 @@
" data = json.load(f)\n",
" print(len(data['input_data'][0]))\n",
"\n",
"ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
"await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
]
},
{
@@ -647,4 +647,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}

View File

@@ -458,7 +458,7 @@
"\n",
"\n",
"ezkl.gen_settings(onnx_filename, settings_filename)\n",
"ezkl.calibrate_settings(\n",
"await ezkl.calibrate_settings(\n",
" input_filename, onnx_filename, settings_filename, \"resources\", scales = [4])\n",
"res = await ezkl.get_srs(settings_filename)\n",
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)\n",
@@ -527,7 +527,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n",
"res = await ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -762,4 +762,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

View File

@@ -629,7 +629,7 @@
"source": [
"\n",
"\n",
"res = ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\", scales = [4])\n",
"res = await ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\", scales = [4])\n",
"assert res == True\n",
"print(\"verified\")\n"
]
@@ -680,7 +680,7 @@
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
@@ -905,4 +905,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -193,7 +193,7 @@
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
@@ -227,7 +227,7 @@
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},

File diff suppressed because it is too large Load Diff

View File

@@ -1,79 +0,0 @@
from torch import nn
import torch.nn.init as init
import torch
import json
N = 100
class Model(nn.Module):
def __init__(self, inplace=False):
super(Model, self).__init__()
self.aff1 = nn.Linear(N,N)
self.aff2 = nn.Linear(N,N)
self.aff3 = nn.Linear(N,N)
self.aff4 = nn.Linear(N,N)
self.aff5 = nn.Linear(N,N)
self.aff6 = nn.Linear(N,N)
self.aff7 = nn.Linear(N,N)
self.aff8 = nn.Linear(N,N)
self.aff9 = nn.Linear(N,N)
self.relu = nn.ReLU()
self._initialize_weights()
def forward(self, x):
# concat 10 x along dim 0
x = x.repeat(10, 1)
x = self.aff1(x)
x = self.relu(x)
x = self.aff2(x)
x = self.relu(x)
x = self.aff3(x)
x = self.relu(x)
x = self.aff4(x)
x = self.relu(x)
x = self.aff5(x)
x = self.relu(x)
x = self.aff6(x)
x = self.relu(x)
x = self.aff7(x)
x = self.relu(x)
x = self.aff8(x)
x = self.relu(x)
x = self.aff9(x)
return x
def _initialize_weights(self):
init.orthogonal_(self.aff1.weight)
model = Model()
# Flips the neural net into inference mode
model.eval()
model.to('cpu')
x = torch.randn(1, N)
# Export the model
torch.onnx.export(model, # model being run
# model input (or a tuple for multiple inputs)
x,
# where to save the model (can be a file or file-like object)
"network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=12, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
data_array = ((x).detach().numpy()).reshape([-1]).tolist()
data_json = dict(input_data=[data_array])
print(data_json)
# Serialize data into file:
json.dump(data_json, open("input.json", 'w'))

View File

@@ -1 +0,0 @@
{"input_data": [[0.33088353276252747, -0.8819183707237244, 1.245591163635254, -1.807046890258789, 1.9922369718551636, -0.3360576629638672, 0.4529011845588684, -0.3590165674686432, 0.08356846123933792, 0.5126393437385559, 0.44627535343170166, 1.4916497468948364, 0.49731069803237915, -0.9748706817626953, -0.4923185408115387, 1.3548223972320557, 0.2306872010231018, 1.125955581665039, -1.7063908576965332, 0.3777385354042053, -2.7988760471343994, -1.1846797466278076, 0.7473157048225403, 1.490412950515747, 0.017497723922133446, 2.113945245742798, -1.2141249179840088, -0.16120357811450958, 0.021127669140696526, 0.7207374572753906, -1.369688868522644, -0.7369781732559204, -0.630584180355072, -0.4520200788974762, 0.29123976826667786, 0.6334688067436218, -0.869332492351532, -1.258501648902893, 0.3012596666812897, -0.5507447123527527, 0.669975757598877, 0.15088629722595215, -0.1050339788198471, 0.5505334138870239, -0.1287376880645752, -1.4297826290130615, -0.01703289896249771, -1.2296998500823975, 0.5122153162956238, -0.16924428939819336, -0.415036678314209, -1.1979341506958008, 0.05831022188067436, -0.4411357045173645, 2.0713791847229004, 1.4611141681671143, -0.9357407093048096, -0.333297461271286, -0.676478385925293, 1.390028476715088, -0.05827632546424866, 1.535687804222107, 0.3060210347175598, -0.03171076253056526, -0.614985466003418, 1.2040390968322754, 0.31318482756614685, -1.2134959697723389, 0.13110508024692535, -1.4880926609039307, 1.7007993459701538, 1.5412729978561401, 0.09260450303554535, 0.7649128437042236, -0.5009126663208008, -0.5356241464614868, -0.069572813808918, -0.011717632412910461, 0.21314217150211334, -0.1985170543193817, -0.0223808903247118, 1.2128918170928955, 0.8334696888923645, 1.9029873609542847, -0.11491120606660843, -0.10303237289190292, -0.2467050403356552, 1.557223916053772, -1.1108328104019165, -0.9065343141555786, -0.2271333783864975, 0.6959827542304993, -0.48698121309280396, 0.5689510703086853, 1.115319013595581, -0.8907430768013, -0.24722427129745483, -0.7437837719917297, 0.6742106676101685, -1.7830933332443237]]}

Binary file not shown.

View File

@@ -104,5 +104,5 @@ json.dump(data, open("input.json", 'w'))
# ezkl.gen_settings("network.onnx", "settings.json")
# !RUST_LOG = full
# res = ezkl.calibrate_settings(
# res = await ezkl.calibrate_settings(
# "input.json", "network.onnx", "settings.json", "resources")

View File

@@ -1,182 +0,0 @@
use ezkl::circuit::einsum::analysis::analyze_einsum_usage;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::tensor::*;
use halo2_proofs::circuit::floor_planner::V1;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::Fr;
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use rand::rngs::OsRng;
use std::collections::HashMap;
use std::marker::PhantomData;
static mut LEN: usize = 4;
const K: usize = 11;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 2],
einsum: Einsum<F>,
}
#[derive(Clone, Default)]
struct Einsum<F: PrimeField + TensorType + PartialOrd> {
equation: String,
input_axes_to_dims: HashMap<char, usize>,
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Einsum<F> {
pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
let mut eq = equation.split("->");
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
// Check that the number of inputs matches the number of inputs in the equation
if inputs.len() != inputs_eq.len() {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
let mut input_axes_to_dims = HashMap::new();
for (i, input) in inputs.iter().enumerate() {
for j in 0..inputs_eq[i].len() {
let c = inputs_eq[i]
.chars()
.nth(j)
.ok_or(CircuitError::InvalidEinsum)?;
if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
e.insert(input.dims()[j]);
} else if input_axes_to_dims[&c] != input.dims()[j] {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
}
}
Ok(Self {
equation: equation.to_owned(),
input_axes_to_dims,
_marker: PhantomData,
})
}
}
impl Circuit<Fr> for MyCircuit<Fr> {
type Config = BaseConfig<Fr>;
type FloorPlanner = V1;
type Params = Einsum<Fr>;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure_with_params(cs: &mut ConstraintSystem<Fr>, params: Self::Params) -> Self::Config {
let len = unsafe { LEN };
let a = VarTensor::new_advice(cs, K, 1, len);
let b = VarTensor::new_advice(cs, K, 1, len);
let output = VarTensor::new_advice(cs, K, 1, len);
let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE);
let mut equations = HashMap::new();
equations.insert((0, params.equation), params.input_axes_to_dims);
let analysis = analyze_einsum_usage(&equations).unwrap();
let num_einsum_inner_cols = 2;
config
.configure_einsums(cs, &analysis, num_einsum_inner_cols, K)
.unwrap();
let _constant = VarTensor::constant_cols(cs, K, 2, false);
config
}
fn params(&self) -> Self::Params {
Einsum::<Fr>::new(
&self.einsum.equation,
&[
&self.inputs[0].get_inner().unwrap(),
&self.inputs[1].get_inner().unwrap(),
],
)
.unwrap()
}
fn configure(_cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unimplemented!("call configure_with_params instead")
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let challenges = config
.einsums
.as_ref()
.ok_or(Error::Synthesis)?
.challenges()
.unwrap()
.iter()
.map(|c| layouter.get_challenge(*c))
.collect_vec();
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new_with_challenges(
region,
0,
1,
1024,
2,
challenges.clone(),
);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: self.einsum.equation.clone(),
}),
)
.unwrap();
Ok(())
},
)?;
Ok(())
}
}
fn runmatmul() {
let i = 10;
let n = 10;
let j = 40;
let k = 10;
let mut a = Tensor::from((0..i * n * j).map(|_| Value::known(Fr::random(OsRng))));
a.reshape(&[i, n, j]).unwrap();
// parameters
let mut b = Tensor::from((0..j * k).map(|_| Value::known(Fr::random(OsRng))));
b.reshape(&[j, k]).unwrap();
let einsum = Einsum::<Fr>::new("inj,jk->ik", &[&a, &b]).unwrap();
let circuit = MyCircuit {
inputs: [ValTensor::from(a), ValTensor::from(b)],
einsum,
};
let mock_prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
mock_prover.assert_satisfied();
}
pub fn main() {
runmatmul()
}

View File

@@ -0,0 +1,60 @@
# inbrowser-evm-verify
We would like the Solidity verifier to be canonical and usually all you ever need. For this, we need to be able to run that verifier in browser.
## How to use (Node js)
```ts
import localEVMVerify from '@ezkljs/verify';
// Load in the proof file as a buffer
const proofFileBuffer = fs.readFileSync(`${path}/${example}/proof.pf`)
// Stringified EZKL evm verifier bytecode (this is just an example don't use in production)
const bytecode = '0x608060405234801561001057600080fd5b5060d38061001f6000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c8063cfae321714610046575b600080fd5b6100496100f1565b60405161005691906100f1565b60405180910390f35b'
const result = await localEVMVerify(proofFileBuffer, bytecode)
console.log('result', result)
```
**Note**: Run `ezkl create-evm-verifier` to get the Solidity verifier, with which you can retrieve the bytecode once compiled. We recommend compiling to the Shanghai hardfork target, else you will have to pass an additional parameter specifying the EVM version to the `localEVMVerify` function like so (for Paris hardfork):
```ts
import localEVMVerify, { hardfork } from '@ezkljs/verify';
const result = await localEVMVerify(proofFileBuffer, bytecode, hardfork['Paris'])
```
**Note**: You can also verify separated vk verifiers using the `localEVMVerify` function. Just pass the vk verifier bytecode as the third parameter like so:
```ts
import localEVMVerify from '@ezkljs/verify';
const result = await localEVMVerify(proofFileBuffer, verifierBytecode, VKBytecode)
```
## How to use (Browser)
```ts
import localEVMVerify from '@ezkljs/verify';
// Load in the proof file as a buffer using the web apis (fetch, FileReader, etc)
// We use fetch in this example to load the proof file as a buffer
const proofFileBuffer = await fetch(`${path}/${example}/proof.pf`).then(res => res.arrayBuffer())
// Stringified EZKL evm verifier bytecode (this is just an example don't use in production)
const bytecode = '0x608060405234801561001057600080fd5b5060d38061001f6000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c8063cfae321714610046575b600080fd5b6100496100f1565b60405161005691906100f1565b60405180910390f35b'
const result = await browserEVMVerify(proofFileBuffer, bytecode)
console.log('result', result)
```
Output:
```ts
result: true
```

View File

@@ -0,0 +1,42 @@
{
"name": "@ezkljs/verify",
"version": "v10.4.2",
"publishConfig": {
"access": "public"
},
"description": "Evm verify EZKL proofs in the browser.",
"main": "dist/commonjs/index.js",
"module": "dist/esm/index.js",
"types": "dist/commonjs/index.d.ts",
"files": [
"dist",
"LICENSE",
"README.md"
],
"scripts": {
"clean": "rm -r dist || true",
"build:commonjs": "tsc --project tsconfig.commonjs.json && resolve-tspaths -p tsconfig.commonjs.json",
"build:esm": "tsc --project tsconfig.esm.json && resolve-tspaths -p tsconfig.esm.json",
"build": "npm run clean && npm run build:commonjs && npm run build:esm"
},
"dependencies": {
"@ethereumjs/common": "4.0.0",
"@ethereumjs/evm": "2.0.0",
"@ethereumjs/statemanager": "2.0.0",
"@ethereumjs/tx": "5.0.0",
"@ethereumjs/util": "9.0.0",
"@ethereumjs/vm": "7.0.0",
"@ethersproject/abi": "5.7.0",
"@ezkljs/engine": "10.4.2",
"ethers": "6.7.1",
"json-bigint": "1.0.0"
},
"devDependencies": {
"@types/node": "^20.8.3",
"ts-loader": "^9.5.0",
"ts-node": "^10.9.1",
"resolve-tspaths": "^0.8.16",
"tsconfig-paths": "^4.2.0",
"typescript": "^5.2.2"
}
}

1479
in-browser-evm-verifier/pnpm-lock.yaml generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,144 @@
import { defaultAbiCoder as AbiCoder } from '@ethersproject/abi'
import { Address, hexToBytes } from '@ethereumjs/util'
import { Chain, Common, Hardfork } from '@ethereumjs/common'
import { LegacyTransaction, LegacyTxData } from '@ethereumjs/tx'
// import { DefaultStateManager } from '@ethereumjs/statemanager'
// import { Blockchain } from '@ethereumjs/blockchain'
import { VM } from '@ethereumjs/vm'
import { EVM } from '@ethereumjs/evm'
import { buildTransaction, encodeDeployment } from './utils/tx-builder'
import { getAccountNonce, insertAccount } from './utils/account-utils'
import { encodeVerifierCalldata } from '../nodejs/ezkl';
async function deployContract(
vm: VM,
common: Common,
senderPrivateKey: Uint8Array,
deploymentBytecode: string
): Promise<Address> {
// Contracts are deployed by sending their deployment bytecode to the address 0
// The contract params should be abi-encoded and appended to the deployment bytecode.
// const data =
const data = encodeDeployment(deploymentBytecode)
const txData = {
data,
nonce: await getAccountNonce(vm, senderPrivateKey),
}
const tx = LegacyTransaction.fromTxData(
buildTransaction(txData) as LegacyTxData,
{ common, allowUnlimitedInitCodeSize: true },
).sign(senderPrivateKey)
const deploymentResult = await vm.runTx({
tx,
skipBlockGasLimitValidation: true,
skipNonce: true
})
if (deploymentResult.execResult.exceptionError) {
throw deploymentResult.execResult.exceptionError
}
return deploymentResult.createdAddress!
}
async function verify(
vm: VM,
contractAddress: Address,
caller: Address,
proof: Uint8Array | Uint8ClampedArray,
vkAddress?: Address | Uint8Array,
): Promise<boolean> {
if (proof instanceof Uint8Array) {
proof = new Uint8ClampedArray(proof.buffer)
}
if (vkAddress) {
const vkAddressBytes = hexToBytes(vkAddress.toString())
const vkAddressArray = Array.from(vkAddressBytes)
let string = JSON.stringify(vkAddressArray)
const uint8Array = new TextEncoder().encode(string);
// Step 3: Convert to Uint8ClampedArray
vkAddress = new Uint8Array(uint8Array.buffer);
// convert uitn8array of length
console.error('vkAddress', vkAddress)
}
const data = encodeVerifierCalldata(proof, vkAddress)
const verifyResult = await vm.evm.runCall({
to: contractAddress,
caller: caller,
origin: caller, // The tx.origin is also the caller here
data: data,
})
if (verifyResult.execResult.exceptionError) {
throw verifyResult.execResult.exceptionError
}
const results = AbiCoder.decode(['bool'], verifyResult.execResult.returnValue)
return results[0]
}
/**
* Spins up an ephemeral EVM instance for executing the bytecode of a solidity verifier
* @param proof Json serialized proof file
* @param bytecode The bytecode of a compiled solidity verifier.
* @param bytecode_vk The bytecode of a contract that stores the vk. (Optional, only required if the vk is stored in a separate contract)
* @param evmVersion The evm version to use for the verification. (Default: London)
* @returns The result of the evm verification.
* @throws If the verify transaction reverts
*/
export default async function localEVMVerify(
proof: Uint8Array | Uint8ClampedArray,
bytecode_verifier: string,
bytecode_vk?: string,
evmVersion?: Hardfork,
): Promise<boolean> {
try {
const hardfork = evmVersion ? evmVersion : Hardfork['Shanghai']
const common = new Common({ chain: Chain.Mainnet, hardfork })
const accountPk = hexToBytes(
'0xe331b6d69882b4cb4ea581d88e0b604039a3de5967688d3dcffdd2270c0fd109', // anvil deterministic Pk
)
const evm = new EVM({
allowUnlimitedContractSize: true,
allowUnlimitedInitCodeSize: true,
})
const vm = await VM.create({ common, evm })
const accountAddress = Address.fromPrivateKey(accountPk)
await insertAccount(vm, accountAddress)
const verifierAddress = await deployContract(
vm,
common,
accountPk,
bytecode_verifier
)
if (bytecode_vk) {
const accountPk = hexToBytes("0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80"); // anvil deterministic Pk
const accountAddress = Address.fromPrivateKey(accountPk)
await insertAccount(vm, accountAddress)
const output = await deployContract(vm, common, accountPk, bytecode_vk)
const result = await verify(vm, verifierAddress, accountAddress, proof, output)
return true
}
const result = await verify(vm, verifierAddress, accountAddress, proof)
return result
} catch (error) {
// log or re-throw the error, depending on your needs
console.error('An error occurred:', error)
throw error
}
}

View File

@@ -0,0 +1,32 @@
import { VM } from '@ethereumjs/vm'
import { Account, Address } from '@ethereumjs/util'
export const keyPair = {
secretKey:
'0x3cd7232cd6f3fc66a57a6bedc1a8ed6c228fff0a327e169c2bcc5e869ed49511',
publicKey:
'0x0406cc661590d48ee972944b35ad13ff03c7876eae3fd191e8a2f77311b0a3c6613407b5005e63d7d8d76b89d5f900cde691497688bb281e07a5052ff61edebdc0',
}
export const insertAccount = async (vm: VM, address: Address) => {
const acctData = {
nonce: 0,
balance: BigInt('1000000000000000000'), // 1 eth
}
const account = Account.fromAccountData(acctData)
await vm.stateManager.putAccount(address, account)
}
export const getAccountNonce = async (
vm: VM,
accountPrivateKey: Uint8Array,
) => {
const address = Address.fromPrivateKey(accountPrivateKey)
const account = await vm.stateManager.getAccount(address)
if (account) {
return account.nonce
} else {
return BigInt(0)
}
}

View File

@@ -0,0 +1,59 @@
import { Interface, defaultAbiCoder as AbiCoder } from '@ethersproject/abi'
import {
AccessListEIP2930TxData,
FeeMarketEIP1559TxData,
TxData,
} from '@ethereumjs/tx'
type TransactionsData =
| TxData
| AccessListEIP2930TxData
| FeeMarketEIP1559TxData
export const encodeFunction = (
method: string,
params?: {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
types: any[]
values: unknown[]
},
): string => {
const parameters = params?.types ?? []
const methodWithParameters = `function ${method}(${parameters.join(',')})`
const signatureHash = new Interface([methodWithParameters]).getSighash(method)
const encodedArgs = AbiCoder.encode(parameters, params?.values ?? [])
return signatureHash + encodedArgs.slice(2)
}
export const encodeDeployment = (
bytecode: string,
params?: {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
types: any[]
values: unknown[]
},
) => {
const deploymentData = '0x' + bytecode
if (params) {
const argumentsEncoded = AbiCoder.encode(params.types, params.values)
return deploymentData + argumentsEncoded.slice(2)
}
return deploymentData
}
export const buildTransaction = (
data: Partial<TransactionsData>,
): TransactionsData => {
const defaultData: Partial<TransactionsData> = {
gasLimit: 3_000_000_000_000_000,
gasPrice: 7,
value: 0,
data: '0x',
}
return {
...defaultData,
...data,
}
}

View File

@@ -0,0 +1,7 @@
{
"extends": "./tsconfig.json",
"compilerOptions": {
"module": "CommonJS",
"outDir": "./dist/commonjs"
}
}

View File

@@ -0,0 +1,7 @@
{
"extends": "./tsconfig.json",
"compilerOptions": {
"module": "ES2020",
"outDir": "./dist/esm"
}
}

View File

@@ -0,0 +1,62 @@
{
"compilerOptions": {
"rootDir": "src",
"target": "es2017",
"outDir": "dist",
"declaration": true,
"lib": [
"dom",
"dom.iterable",
"esnext"
],
"allowJs": true,
"checkJs": true,
"skipLibCheck": true,
"strict": true,
"forceConsistentCasingInFileNames": true,
"noEmit": false,
"esModuleInterop": true,
"module": "CommonJS",
"moduleResolution": "node",
"resolveJsonModule": true,
"isolatedModules": true,
"jsx": "preserve",
// "incremental": true,
"noUncheckedIndexedAccess": true,
"baseUrl": ".",
"paths": {
"@/*": [
"./src/*"
]
}
},
"include": [
"src/**/*.ts",
"src/**/*.tsx",
"src/**/*.cjs",
"src/**/*.mjs"
],
"exclude": [
"node_modules"
],
// NEW: Options for file/directory watching
"watchOptions": {
// Use native file system events for files and directories
"watchFile": "useFsEvents",
"watchDirectory": "useFsEvents",
// Poll files for updates more frequently
// when they're updated a lot.
"fallbackPolling": "dynamicPriority",
// Don't coalesce watch notification
"synchronousWatchDirectory": true,
// Finally, two additional settings for reducing the amount of possible
// files to track work from these directories
"excludeDirectories": [
"**/node_modules",
"_build"
],
"excludeFiles": [
"build/fileWhichChangesOften.ts"
]
}
}

View File

@@ -9,6 +9,6 @@ pytest==8.1.1
tomli==2.0.1
typing-extensions==4.10.0
zipp==3.18.1
onnx==1.17.0
onnx==1.15.0
onnxruntime==1.17.1
numpy==1.26.4

View File

@@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2025-05-01"
channel = "nightly-2025-02-17"
components = ["rustfmt", "clippy"]

View File

@@ -1,229 +0,0 @@
#!/bin/bash
set -e
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
# Default installation directory
DEFAULT_INSTALL_DIR="/opt/icicle/lib/backend/halo2"
# Halo2 repository details
HALO2_REPO="https://github.com/zkonduit/halo2"
HALO2_BRANCH="ac/conditional-compilation-icicle2"
# Parse command line arguments
AUTO_YES=false
for arg in "$@"; do
case $arg in
-y|--yes)
AUTO_YES=true
shift
;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " -y, --yes Automatically answer 'yes' to all prompts"
echo " -h, --help Show this help message"
exit 0
;;
*)
echo "Unknown option: $arg"
echo "Use -h or --help for usage information"
exit 1
;;
esac
done
echo -e "${GREEN}EZKL GPU Setup Script${NC}"
echo -e "${GREEN}=====================${NC}"
echo ""
# Parse commit hash from Cargo.lock
echo "Parsing halo2 commit hash from Cargo.lock..."
if [ ! -f "Cargo.lock" ]; then
echo -e "${RED}Error: Cargo.lock not found. Please run this script from the project root.${NC}"
exit 1
fi
HALO2_COMMIT=$(grep "github\.com/zkonduit/halo2?" Cargo.lock | grep -v "halo2wrong" | head -1 | grep -o "#[a-f0-9]\{40\}" | cut -c2-)
if [ -z "$HALO2_COMMIT" ]; then
echo -e "${RED}Error: Could not parse halo2 commit hash from Cargo.lock${NC}"
exit 1
fi
echo -e "${GREEN}Found halo2 commit: $HALO2_COMMIT${NC}"
echo ""
echo "This script will:"
echo "1. Sparse checkout the halo2 repository at commit $HALO2_COMMIT"
echo "2. Extract only the icicle/backend/cuda/ directory"
echo "3. Set the ICICLE_BACKEND_INSTALL_DIR environment variable"
echo ""
# Check if user wants to override the default directory
if [ "$AUTO_YES" = true ]; then
INSTALL_DIR="$DEFAULT_INSTALL_DIR"
echo -e "${GREEN}Using default installation directory: ${INSTALL_DIR}${NC}"
else
echo -e "${YELLOW}Default installation directory: ${DEFAULT_INSTALL_DIR}${NC}"
read -p "Do you want to use a different directory? [y/N]: " -n 1 -r
echo
if [[ $REPLY =~ ^[Yy]$ ]]; then
read -p "Enter the installation directory: " INSTALL_DIR
INSTALL_DIR="${INSTALL_DIR/#\~/$HOME}" # Expand ~ to $HOME
else
INSTALL_DIR="$DEFAULT_INSTALL_DIR"
fi
# Confirm the installation directory
echo ""
echo -e "${YELLOW}Installation directory: ${INSTALL_DIR}${NC}"
read -p "Continue with this directory? [y/N]: " -n 1 -r
echo
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
echo -e "${RED}Setup cancelled by user.${NC}"
exit 1
fi
fi
# Check if ICICLE_BACKEND_INSTALL_DIR is already set
if [ ! -z "$ICICLE_BACKEND_INSTALL_DIR" ] && [ "$AUTO_YES" = false ]; then
echo ""
echo -e "${YELLOW}Warning: ICICLE_BACKEND_INSTALL_DIR is already set to: $ICICLE_BACKEND_INSTALL_DIR${NC}"
read -p "Do you want to override it? [y/N]: " -n 1 -r
echo
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
echo -e "${RED}Setup cancelled by user.${NC}"
exit 1
fi
elif [ ! -z "$ICICLE_BACKEND_INSTALL_DIR" ] && [ "$AUTO_YES" = true ]; then
echo -e "${GREEN}Overriding existing ICICLE_BACKEND_INSTALL_DIR (was: $ICICLE_BACKEND_INSTALL_DIR)${NC}"
fi
echo ""
echo -e "${GREEN}Starting GPU setup...${NC}"
# Create installation directory
echo "Creating installation directory..."
mkdir -p "$INSTALL_DIR"
# Create temporary directory for sparse checkout
TEMP_DIR=$(mktemp -d)
echo "Using temporary directory: $TEMP_DIR"
# Clone with sparse checkout
echo "Cloning halo2 repository with sparse checkout..."
cd "$TEMP_DIR"
git clone --filter=blob:none --sparse "$HALO2_REPO" halo2
cd halo2
# Checkout the specific branch and commit
echo "Checking out branch $HALO2_BRANCH at commit $HALO2_COMMIT..."
git checkout "$HALO2_BRANCH"
git checkout "$HALO2_COMMIT"
# Configure sparse checkout
echo "Configuring sparse checkout for icicle/backend/cuda/..."
git sparse-checkout init --cone
git sparse-checkout set icicle/backend/cuda/
# Copy the icicle directory to the installation location
if [ -d "icicle/backend/cuda" ]; then
echo "Copying icicle/backend/cuda/ to $INSTALL_DIR..."
cp -r icicle/backend/cuda/* "$INSTALL_DIR/"
echo -e "${GREEN}Files copied successfully!${NC}"
else
echo -e "${RED}Error: icicle/backend/cuda directory not found in the repository${NC}"
exit 1
fi
# Clean up temporary directory
echo "Cleaning up temporary files..."
rm -rf "$TEMP_DIR"
# Ask user about setting environment variable permanently
SETUP_PERMANENT_ENV=false
if [ "$AUTO_YES" = true ]; then
SETUP_PERMANENT_ENV=true
echo ""
echo -e "${GREEN}Setting ICICLE_BACKEND_INSTALL_DIR environment variable permanently...${NC}"
else
echo ""
echo -e "${YELLOW}Do you want to set ICICLE_BACKEND_INSTALL_DIR environment variable permanently?${NC}"
echo "This will add 'export ICICLE_BACKEND_INSTALL_DIR=\"$INSTALL_DIR\"' to your shell configuration file."
read -p "Set environment variable permanently? [y/N]: " -n 1 -r
echo
if [[ $REPLY =~ ^[Yy]$ ]]; then
SETUP_PERMANENT_ENV=true
fi
fi
if [ "$SETUP_PERMANENT_ENV" = true ]; then
echo "Setting ICICLE_BACKEND_INSTALL_DIR environment variable..."
# Detect shell and set environment variable accordingly
if [ -n "$ZSH_VERSION" ]; then
SHELL_RC="$HOME/.zshrc"
elif [ -n "$BASH_VERSION" ]; then
SHELL_RC="$HOME/.bashrc"
else
# Try to detect based on $SHELL
case "$SHELL" in
*/zsh)
SHELL_RC="$HOME/.zshrc"
;;
*/bash)
SHELL_RC="$HOME/.bashrc"
;;
*)
SHELL_RC="$HOME/.profile"
;;
esac
fi
# Add environment variable to shell configuration
ENV_EXPORT="export ICICLE_BACKEND_INSTALL_DIR=\"$INSTALL_DIR\""
# Check if the variable is already set in the file
if [ -f "$SHELL_RC" ] && grep -q "ICICLE_BACKEND_INSTALL_DIR" "$SHELL_RC"; then
# Replace existing line
if [[ "$OSTYPE" == "darwin"* ]]; then
# macOS
sed -i '' "s|export ICICLE_BACKEND_INSTALL_DIR=.*|$ENV_EXPORT|" "$SHELL_RC"
else
# Linux
sed -i "s|export ICICLE_BACKEND_INSTALL_DIR=.*|$ENV_EXPORT|" "$SHELL_RC"
fi
echo "Updated existing ICICLE_BACKEND_INSTALL_DIR in $SHELL_RC"
else
# Add new line
echo "$ENV_EXPORT" >> "$SHELL_RC"
echo "Added ICICLE_BACKEND_INSTALL_DIR to $SHELL_RC"
fi
echo -e "${GREEN}Environment variable set permanently.${NC}"
else
echo "Skipping permanent environment variable setup."
fi
# Export for current session regardless
export ICICLE_BACKEND_INSTALL_DIR="$INSTALL_DIR"
echo "Environment variable set for current session."
echo ""
echo -e "${GREEN}GPU setup completed successfully!${NC}"
echo ""
echo -e "${YELLOW}Important:${NC}"
echo "1. The ICICLE_BACKEND_INSTALL_DIR environment variable has been set to: $INSTALL_DIR"
if [ "$SETUP_PERMANENT_ENV" = true ]; then
echo "2. Please restart your terminal or run: source $SHELL_RC"
else
echo "2. To use GPU features, set: export ICICLE_BACKEND_INSTALL_DIR=\"$INSTALL_DIR\""
fi
echo "3. You can now build with GPU support using: cargo build --features gpu-accelerated"
echo ""
echo -e "${GREEN}Setup complete!${NC}"

View File

@@ -1,5 +1,12 @@
// ignore file if compiling for wasm
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc::MiMalloc;
#[global_allocator]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
static GLOBAL: MiMalloc = MiMalloc;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::{CommandFactory, Parser};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
@@ -15,6 +22,9 @@ use log::{error, info};
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
use rand::prelude::SliceRandom;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[cfg(feature = "icicle")]
use std::env;
#[tokio::main(flavor = "current_thread")]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub async fn main() {
@@ -28,7 +38,12 @@ pub async fn main() {
init_logger();
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
banner();
#[cfg(feature = "icicle")]
if env::var("ENABLE_ICICLE_GPU").is_ok() {
info!("Running with ICICLE GPU");
} else {
info!("Running with CPU");
}
debug!(
"command: \n {}",
&command.as_json().to_colored_json_auto().unwrap()

View File

@@ -3,7 +3,7 @@
pub mod python;
/// Universal bindings for all platforms
#[cfg(any(
feature = "universal-bindings",
feature = "ios-bindings",
all(target_arch = "wasm32", target_os = "unknown")
))]
pub mod universal;

View File

@@ -93,6 +93,17 @@ impl From<PyG1> for G1 {
}
}
impl pyo3::ToPyObject for PyG1 {
fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject {
let g1_dict = pyo3::types::PyDict::new(py);
g1_dict.set_item("x", self.x.to_object(py)).unwrap();
g1_dict.set_item("y", self.y.to_object(py)).unwrap();
g1_dict.set_item("z", self.z.to_object(py)).unwrap();
g1_dict.into()
}
}
/// pyclass containing the struct used for G1
#[pyclass]
#[derive(Debug, Clone)]
@@ -124,6 +135,16 @@ impl From<PyG1Affine> for G1Affine {
}
}
impl pyo3::ToPyObject for PyG1Affine {
fn to_object(&self, py: pyo3::Python) -> pyo3::PyObject {
let g1_dict = pyo3::types::PyDict::new(py);
g1_dict.set_item("x", self.x.to_object(py)).unwrap();
g1_dict.set_item("y", self.y.to_object(py)).unwrap();
g1_dict.into()
}
}
/// Python class containing the struct used for run_args
///
/// Returns
@@ -140,9 +161,6 @@ struct PyRunArgs {
#[pyo3(get, set)]
/// int: The denominator in the fixed point representation used when quantizing parameters
pub param_scale: crate::Scale,
/// int: The scale to rebase to (optional). If None, we rebase to the max of input_scale and param_scale
/// This is an advanced parameter that should be used with caution
pub rebase_scale: Option<crate::Scale>,
#[pyo3(get, set)]
/// int: If the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this a more advanced parameter, use with caution)
pub scale_rebase_multiplier: u32,
@@ -191,9 +209,6 @@ struct PyRunArgs {
/// float: epsilon used for arguments that use division
#[pyo3(get, set)]
pub epsilon: f64,
/// bool: Whether to disable using Freivalds' argument in einsum operations
#[pyo3(get, set)]
pub disable_freivalds: bool,
}
/// default instantiation of PyRunArgs
@@ -212,7 +227,6 @@ impl From<PyRunArgs> for RunArgs {
bounded_log_lookup: py_run_args.bounded_log_lookup,
input_scale: py_run_args.input_scale,
param_scale: py_run_args.param_scale,
rebase_scale: py_run_args.rebase_scale,
num_inner_cols: py_run_args.num_inner_cols,
scale_rebase_multiplier: py_run_args.scale_rebase_multiplier,
lookup_range: py_run_args.lookup_range,
@@ -228,7 +242,6 @@ impl From<PyRunArgs> for RunArgs {
decomp_legs: py_run_args.decomp_legs,
ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
epsilon: Some(py_run_args.epsilon),
disable_freivalds: py_run_args.disable_freivalds,
}
}
}
@@ -240,7 +253,6 @@ impl Into<PyRunArgs> for RunArgs {
bounded_log_lookup: self.bounded_log_lookup,
input_scale: self.input_scale,
param_scale: self.param_scale,
rebase_scale: self.rebase_scale,
num_inner_cols: self.num_inner_cols,
scale_rebase_multiplier: self.scale_rebase_multiplier,
lookup_range: self.lookup_range,
@@ -256,7 +268,6 @@ impl Into<PyRunArgs> for RunArgs {
decomp_legs: self.decomp_legs,
ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
epsilon: eps,
disable_freivalds: self.disable_freivalds,
}
}
}
@@ -683,7 +694,7 @@ fn ipa_commit(
.map_err(|_| PyIOError::new_err("Failed to load circuit settings"))?;
let srs_path =
crate::execute::get_srs_path(settings.run_args.logrows, srs_path, Commitments::IPA);
crate::execute::get_srs_path(settings.run_args.logrows, srs_path, Commitments::KZG);
let srs = load_srs_prover::<IPACommitmentScheme<G1Affine>>(srs_path)
.map_err(|_| PyIOError::new_err("Failed to load srs"))?;
@@ -873,7 +884,7 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
))]
#[gen_stub_pyfunction]
fn get_srs(
py: Python<'_>,
py: Python,
settings_path: Option<PathBuf>,
logrows: Option<u32>,
srs_path: Option<PathBuf>,
@@ -1019,6 +1030,7 @@ fn gen_random_data(
))]
#[gen_stub_pyfunction]
fn calibrate_settings(
py: Python,
data: String,
model: PathBuf,
settings: PathBuf,
@@ -1027,7 +1039,7 @@ fn calibrate_settings(
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
) -> PyResult<bool> {
) -> PyResult<Bound<'_, PyAny>> {
crate::execute::calibrate(
model,
data,
@@ -1079,18 +1091,19 @@ fn calibrate_settings(
))]
#[gen_stub_pyfunction]
fn gen_witness(
py: Python,
data: String,
model: PathBuf,
output: Option<PathBuf>,
vk_path: Option<PathBuf>,
srs_path: Option<PathBuf>,
) -> PyResult<PyObject> {
) -> PyResult<Bound<'_, PyAny>> {
let output =
crate::execute::gen_witness(model, data, output, vk_path, srs_path).map_err(|e| {
let err_str = format!("Failed to generate witness: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Python::with_gil(|py| Ok(output.into_pyobject(py).unwrap().into()))
Python::with_gil(|py| Ok(output.to_object(py)))
}
/// Mocks the prover
@@ -1272,7 +1285,7 @@ fn prove(
PyRuntimeError::new_err(err_str)
})?;
Python::with_gil(|py| Ok(snark.into_pyobject(py).unwrap().into()))
Python::with_gil(|py| Ok(snark.to_object(py)))
}
/// Verifies a given proof
@@ -1638,7 +1651,7 @@ fn encode_evm_calldata<'a>(
))]
#[gen_stub_pyfunction]
fn create_evm_verifier(
py: Python<'_>,
py: Python,
vk_path: PathBuf,
settings_path: PathBuf,
sol_code_path: PathBuf,
@@ -1665,7 +1678,6 @@ fn create_evm_verifier(
})
}
#[cfg(feature = "reusable-verifier")]
/// Creates an Evm VK artifact. This command generated a VK with circuit specific meta data encoding in memory for use by the reusable H2 verifier.
/// This is useful for deploying verifier that were otherwise too big to fit on chain and required aggregation.
///
@@ -1699,7 +1711,7 @@ fn create_evm_verifier(
))]
#[gen_stub_pyfunction]
fn create_evm_vka(
py: Python<'_>,
py: Python,
vk_path: PathBuf,
settings_path: PathBuf,
vka_path: PathBuf,
@@ -1729,7 +1741,7 @@ fn create_evm_vka(
))]
#[gen_stub_pyfunction]
fn deploy_evm(
py: Python<'_>,
py: Python,
addr_path: PathBuf,
rpc_url: String,
sol_code_path: PathBuf,
@@ -1756,7 +1768,6 @@ fn deploy_evm(
})
}
#[cfg(feature = "reusable-verifier")]
/// Registers a VKA on the EZKL reusable verifier contract
///
/// Arguments
@@ -1830,9 +1841,6 @@ fn register_vka<'a>(
///
/// vka_path: str
/// The path to the VKA calldata bytes file (generated using the create_evm_vka command)
///
/// encoded_calldata: str
/// The path to the encoded calldata bytes file (generated using the encode calldata command)
/// Returns
/// -------
/// bool
@@ -1842,7 +1850,6 @@ fn register_vka<'a>(
rpc_url,
proof_path=PathBuf::from(DEFAULT_PROOF),
vka_path = None,
encoded_calldata = None,
))]
#[gen_stub_pyfunction]
fn verify_evm<'a>(
@@ -1851,23 +1858,16 @@ fn verify_evm<'a>(
rpc_url: String,
proof_path: PathBuf,
vka_path: Option<PathBuf>,
encoded_calldata: Option<PathBuf>,
) -> PyResult<Bound<'a, PyAny>> {
let addr_verifier = H160Flag::from(addr_verifier);
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::verify_evm(
proof_path,
addr_verifier,
rpc_url,
vka_path,
encoded_calldata,
)
.await
.map_err(|e| {
let err_str = format!("Failed to run verify_evm: {}", e);
PyRuntimeError::new_err(err_str)
})?;
crate::execute::verify_evm(proof_path, addr_verifier, rpc_url, vka_path)
.await
.map_err(|e| {
let err_str = format!("Failed to run verify_evm: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
})
@@ -1913,7 +1913,7 @@ fn verify_evm<'a>(
))]
#[gen_stub_pyfunction]
fn create_evm_verifier_aggr(
py: Python<'_>,
py: Python,
aggregation_settings: Vec<PathBuf>,
vk_path: PathBuf,
sol_code_path: PathBuf,
@@ -1984,13 +1984,11 @@ fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(compile_circuit, m)?)?;
m.add_function(wrap_pyfunction!(verify_aggr, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?;
#[cfg(feature = "reusable-verifier")]
m.add_function(wrap_pyfunction!(create_evm_vka, m)?)?;
m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?;
m.add_function(wrap_pyfunction!(encode_evm_calldata, m)?)?;
#[cfg(feature = "reusable-verifier")]
m.add_function(wrap_pyfunction!(register_vka, m)?)?;
Ok(())
}

View File

@@ -1,6 +1,7 @@
use halo2_proofs::{
plonk::*,
poly::{
VerificationStrategy,
commitment::{CommitmentScheme, ParamsProver},
ipa::{
commitment::{IPACommitmentScheme, ParamsIPA},
@@ -12,7 +13,6 @@ use halo2_proofs::{
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy as KZGSingleStrategy,
},
VerificationStrategy,
},
};
use std::fmt::Display;
@@ -20,23 +20,19 @@ use std::io::BufReader;
use std::str::FromStr;
use crate::{
CheckMode, Commitments, EZKLError as InnerEZKLError,
circuit::region::RegionSettings,
graph::GraphSettings,
pfsys::{
create_proof_circuit, encode_calldata,
TranscriptType, create_proof_circuit,
evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript},
verify_proof_circuit, TranscriptType,
verify_proof_circuit,
},
tensor::TensorType,
CheckMode, Commitments, EZKLError as InnerEZKLError,
};
use crate::circuit::modules::poseidon::{
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
PoseidonChip,
};
use crate::circuit::modules::Module;
use crate::graph::{GraphCircuit, GraphWitness};
use halo2_solidity_verifier::encode_calldata;
use halo2curves::{
bn256::{Bn256, Fr, G1Affine},
ff::{FromUniformBytes, PrimeField},
@@ -65,32 +61,10 @@ impl From<InnerEZKLError> for EZKLError {
}
}
/// Hash the input message with poseidon
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn poseidon_hash(message: Vec<u8>) -> Result<Vec<u8>, EZKLError> {
let message: Vec<Fr> = serde_json::from_slice(&message[..]).map_err(InnerEZKLError::from)?;
let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
.map_err(InnerEZKLError::from)?;
Ok(serde_json::to_vec(&output).map_err(InnerEZKLError::from)?)
}
/// Hash the input message with poseidon without converting to Fr
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn poseidon_hash_no_felt(message: Vec<u8>) -> Result<Vec<u8>, EZKLError> {
let message: Vec<Fr> = message.iter().map(|x| Fr::from(*x as u64)).collect();
let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
.map_err(InnerEZKLError::from)?;
Ok(serde_json::to_vec(&output).map_err(InnerEZKLError::from)?)
}
/// Encode verifier calldata from proof and ethereum vk_address
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn encode_verifier_calldata(
// TODO - shuold it be pub or pub or pub(super)?
pub(crate) fn encode_verifier_calldata(
// TODO - shuold it be pub(crate) or pub or pub(super)?
proof: Vec<u8>,
vka: Option<Vec<u8>>,
) -> Result<Vec<u8>, EZKLError> {
@@ -116,23 +90,18 @@ pub fn encode_verifier_calldata(
/// Generate witness from compiled circuit and input json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn gen_witness(compiled_circuit: Vec<u8>, input: Vec<u8>) -> Result<Vec<u8>, EZKLError> {
println!("[circuit]");
pub(crate) fn gen_witness(compiled_circuit: Vec<u8>, input: Vec<u8>) -> Result<Vec<u8>, EZKLError> {
let mut circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..])
.map_err(|e| {
EZKLError::InternalError(format!("Failed to deserialize compiled model: {}", e))
})?;
println!("[input]");
let input: crate::graph::input::GraphData = serde_json::from_slice(&input[..])
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize input: {}", e)))?;
println!("[load graph input]");
let mut input = circuit
.load_graph_input(&input)
.map_err(|e| EZKLError::InternalError(format!("{}", e)))?;
println!("[load graph witness]");
let witness = circuit
.forward::<KZGCommitmentScheme<Bn256>>(
&mut input,
@@ -145,14 +114,13 @@ pub fn gen_witness(compiled_circuit: Vec<u8>, input: Vec<u8>) -> Result<Vec<u8>,
)
.map_err(|e| EZKLError::InternalError(format!("{}", e)))?;
println!("[serialize witness]");
serde_json::to_vec(&witness)
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize witness: {}", e)))
}
/// Generate verifying key from compiled circuit, and parameters srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn gen_vk(
pub(crate) fn gen_vk(
compiled_circuit: Vec<u8>,
srs: Vec<u8>,
compress_selectors: bool,
@@ -182,7 +150,11 @@ pub fn gen_vk(
/// Generate proving key from vk, compiled circuit and parameters srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn gen_pk(vk: Vec<u8>, compiled_circuit: Vec<u8>, srs: Vec<u8>) -> Result<Vec<u8>, EZKLError> {
pub(crate) fn gen_pk(
vk: Vec<u8>,
compiled_circuit: Vec<u8>,
srs: Vec<u8>,
) -> Result<Vec<u8>, EZKLError> {
let mut reader = BufReader::new(&srs[..]);
let params: ParamsKZG<Bn256> = get_params(&mut reader)?;
@@ -209,7 +181,7 @@ pub fn gen_pk(vk: Vec<u8>, compiled_circuit: Vec<u8>, srs: Vec<u8>) -> Result<Ve
/// Verify proof with vk, proof json, circuit settings json and srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn verify(
pub(crate) fn verify(
proof: Vec<u8>,
vk: Vec<u8>,
settings: Vec<u8>,
@@ -291,7 +263,7 @@ pub fn verify(
/// Verify aggregate proof with vk, proof, circuit settings and srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn verify_aggr(
pub(crate) fn verify_aggr(
proof: Vec<u8>,
vk: Vec<u8>,
logrows: u64,
@@ -373,7 +345,7 @@ pub fn verify_aggr(
/// Prove in browser with compiled circuit, witness json, proving key, and srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn prove(
pub(crate) fn prove(
witness: Vec<u8>,
pk: Vec<u8>,
compiled_circuit: Vec<u8>,
@@ -471,7 +443,7 @@ pub fn prove(
/// Validate the witness json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn witness_validation(witness: Vec<u8>) -> Result<bool, EZKLError> {
pub(crate) fn witness_validation(witness: Vec<u8>) -> Result<bool, EZKLError> {
let _: GraphWitness = serde_json::from_slice(&witness[..]).map_err(InnerEZKLError::from)?;
Ok(true)
@@ -479,7 +451,7 @@ pub fn witness_validation(witness: Vec<u8>) -> Result<bool, EZKLError> {
/// Validate the compiled circuit
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn compiled_circuit_validation(compiled_circuit: Vec<u8>) -> Result<bool, EZKLError> {
pub(crate) fn compiled_circuit_validation(compiled_circuit: Vec<u8>) -> Result<bool, EZKLError> {
let _: GraphCircuit = bincode::deserialize(&compiled_circuit[..]).map_err(|e| {
EZKLError::InternalError(format!("Failed to deserialize compiled circuit: {}", e))
})?;
@@ -489,7 +461,7 @@ pub fn compiled_circuit_validation(compiled_circuit: Vec<u8>) -> Result<bool, EZ
/// Validate the input json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn input_validation(input: Vec<u8>) -> Result<bool, EZKLError> {
pub(crate) fn input_validation(input: Vec<u8>) -> Result<bool, EZKLError> {
let _: crate::graph::input::GraphData =
serde_json::from_slice(&input[..]).map_err(InnerEZKLError::from)?;
@@ -498,7 +470,7 @@ pub fn input_validation(input: Vec<u8>) -> Result<bool, EZKLError> {
/// Validate the proof json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn proof_validation(proof: Vec<u8>) -> Result<bool, EZKLError> {
pub(crate) fn proof_validation(proof: Vec<u8>) -> Result<bool, EZKLError> {
let _: crate::pfsys::Snark<Fr, G1Affine> =
serde_json::from_slice(&proof[..]).map_err(InnerEZKLError::from)?;
@@ -507,7 +479,7 @@ pub fn proof_validation(proof: Vec<u8>) -> Result<bool, EZKLError> {
/// Validate the verifying key given the settings json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn vk_validation(vk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKLError> {
pub(crate) fn vk_validation(vk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKLError> {
let circuit_settings: GraphSettings =
serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;
@@ -524,7 +496,7 @@ pub fn vk_validation(vk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKLError>
/// Validate the proving key given the settings json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn pk_validation(pk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKLError> {
pub(crate) fn pk_validation(pk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKLError> {
let circuit_settings: GraphSettings =
serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;
@@ -541,7 +513,7 @@ pub fn pk_validation(pk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKLError>
/// Validate the settings json
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn settings_validation(settings: Vec<u8>) -> Result<bool, EZKLError> {
pub(crate) fn settings_validation(settings: Vec<u8>) -> Result<bool, EZKLError> {
let _: GraphSettings = serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?;
Ok(true)
@@ -549,7 +521,7 @@ pub fn settings_validation(settings: Vec<u8>) -> Result<bool, EZKLError> {
/// Validate the srs
#[cfg_attr(feature = "ios-bindings", uniffi::export)]
pub fn srs_validation(srs: Vec<u8>) -> Result<bool, EZKLError> {
pub(crate) fn srs_validation(srs: Vec<u8>) -> Result<bool, EZKLError> {
let mut reader = BufReader::new(&srs[..]);
let _: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(|e| {

View File

@@ -1,5 +1,12 @@
use crate::{
circuit::modules::polycommit::PolyCommitChip,
circuit::modules::{
polycommit::PolyCommitChip,
poseidon::{
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
PoseidonChip,
},
Module,
},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt},
graph::{quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings},
};
@@ -8,7 +15,6 @@ use halo2_proofs::{
plonk::*,
poly::kzg::commitment::{KZGCommitmentScheme, ParamsKZG},
};
use halo2_solidity_verifier::Evm;
use halo2curves::{
bn256::{Bn256, Fr, G1Affine},
ff::PrimeField,
@@ -219,9 +225,15 @@ pub fn bufferToVecOfFelt(
pub fn poseidonHash(
message: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
super::universal::poseidon_hash(message.0)
.map_err(JsError::from)
.map(|x| wasm_bindgen::Clamped(x.clone()))
let message: Vec<Fr> = serde_json::from_slice(&message[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?;
let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
.map_err(|e| JsError::new(&format!("{}", e)))?;
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&output).map_err(
|e| JsError::new(&format!("Failed to serialize poseidon hash output: {}", e)),
)?))
}
/// Generate a witness file from input.json, compiled model and a settings.json file.
@@ -267,33 +279,6 @@ pub fn verify(
super::universal::verify(proof_js.0, vk.0, settings.0, srs.0).map_err(JsError::from)
}
/// Verify proof in browser evm using wasm
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn verifyEVM(
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
bytecode_verifier: Vec<u8>,
bytecode_vka: Option<Vec<u8>>,
) -> Result<bool, JsError> {
let mut evm = Evm::unlimited();
let decoded_verifier = utf8_bytes_to_hex_decoded(&bytecode_verifier)?;
let (verifier_address, _) = evm.create(decoded_verifier);
// if bytecode_vk is Some, then create the vk contract
let vk_address = if let Some(bytecode_vka) = bytecode_vka {
let decoded_vka = utf8_bytes_to_hex_decoded(&bytecode_vka)?;
let (address, _) = evm.create(decoded_vka);
Some(address.as_slice().to_vec())
// check if bytecode_verifier is none and if so then generate the
// reusable verifier
} else {
None
};
let calldata = encode_verifier_calldata(proof_js.0, vk_address).map_err(JsError::from);
let output = evm.call(verifier_address, calldata?).1;
let true_word = [vec![0; 31], vec![1]].concat();
Ok(output == true_word)
}
/// Verify aggregate proof in browser using wasm
#[wasm_bindgen]
#[allow(non_snake_case)]
@@ -386,13 +371,3 @@ pub fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 {
}
n
}
///
pub fn utf8_bytes_to_hex_decoded(input: &[u8]) -> Result<Vec<u8>, JsError> {
let string = std::str::from_utf8(input)?.trim();
let hex_string = if string.starts_with("0x") {
&string[2..]
} else {
string
};
hex::decode(hex_string).map_err(JsError::from)
}

View File

@@ -7,14 +7,17 @@ use halo2_proofs::{
};
use log::debug;
#[cfg(feature = "python-bindings")]
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*, IntoPyObject};
use pyo3::{
conversion::{FromPyObject, IntoPy},
exceptions::PyValueError,
prelude::*,
};
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
use crate::{
circuit::{
chip::einsum::analysis::EinsumAnalysis,
ops::base::BaseOp,
table::{Range, RangeCheck, Table},
},
@@ -25,9 +28,6 @@ use std::{collections::BTreeMap, marker::PhantomData};
use super::{lookup::LookupOp, region::RegionCtx, CircuitError, Op};
use halo2curves::ff::{Field, PrimeField};
///
pub mod einsum;
#[allow(missing_docs)]
/// An enum representing activating the sanity checks we can perform on the accumulated arguments
#[derive(
@@ -86,17 +86,12 @@ impl CheckMode {
#[cfg(feature = "python-bindings")]
/// Converts CheckMode into a PyObject (Required for CheckMode to be compatible with Python)
impl<'py> IntoPyObject<'py> for CheckMode {
type Target = pyo3::PyAny;
type Output = pyo3::Bound<'py, Self::Target>;
type Error = pyo3::PyErr;
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
let result = match self {
CheckMode::SAFE => "safe",
CheckMode::UNSAFE => "unsafe",
};
Ok(result.into_pyobject(py)?.into_any())
impl IntoPy<PyObject> for CheckMode {
fn into_py(self, py: Python) -> PyObject {
match self {
CheckMode::SAFE => "safe".to_object(py),
CheckMode::UNSAFE => "unsafe".to_object(py),
}
}
}
@@ -270,8 +265,6 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
pub range_checks: RangeChecks<F>,
/// [Selector]s for the shuffles
pub shuffles: Shuffles,
/// Einsum-specific configuration
pub einsums: Option<einsum::Einsums<F>>,
/// Activate sanity checks
pub check_mode: CheckMode,
_marker: PhantomData<F>,
@@ -286,22 +279,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
custom_gates: CustomGates::dummy(col_size, num_inner_cols),
static_lookups: StaticLookups::dummy(col_size, num_inner_cols),
dynamic_lookups: DynamicLookups::dummy(col_size, num_inner_cols),
einsums: Some(einsum::Einsums::<F>::dummy(col_size, num_inner_cols)),
shuffles: Shuffles::dummy(col_size, num_inner_cols),
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
check_mode: CheckMode::SAFE,
shared_table_inputs: vec![],
_marker: PhantomData,
}
}
/// Returns a new [BaseConfig] with no inputs, no selectors, no tables, and no Freivalds' argument.
pub fn dummy_without_freivalds(col_size: usize, num_inner_cols: usize) -> Self {
Self {
custom_gates: CustomGates::dummy(col_size, num_inner_cols),
static_lookups: StaticLookups::dummy(col_size, num_inner_cols),
dynamic_lookups: DynamicLookups::dummy(col_size, num_inner_cols),
einsums: None,
shuffles: Shuffles::dummy(col_size, num_inner_cols),
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
check_mode: CheckMode::SAFE,
@@ -436,7 +413,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
},
static_lookups: StaticLookups::default(),
dynamic_lookups: DynamicLookups::default(),
einsums: None,
shuffles: Shuffles::default(),
range_checks: RangeChecks::default(),
shared_table_inputs: vec![],
@@ -711,27 +687,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
Ok(())
}
/// Configures and creates einsums
#[allow(clippy::too_many_arguments)]
pub fn configure_einsums(
&mut self,
cs: &mut ConstraintSystem<F>,
analysis: &EinsumAnalysis,
num_inner_cols: usize,
logrows: usize,
) -> Result<(), CircuitError>
where
F: Field,
{
self.einsums = Some(einsum::Einsums::configure_universal(
cs,
analysis,
num_inner_cols,
logrows,
));
Ok(())
}
/// Configures and creates lookup selectors
#[allow(clippy::too_many_arguments)]
pub fn configure_shuffles(
@@ -1007,7 +962,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
pub fn layout(
&mut self,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>],
values: &[ValTensor<F>],
op: Box<dyn Op<F>>,
) -> Result<Option<ValTensor<F>>, CircuitError> {
op.layout(self, region, values)

View File

@@ -1,210 +0,0 @@
use std::collections::{HashMap, HashSet};
use itertools::Itertools;
use crate::circuit::{
einsum::reduction_planner::{self, Reduction},
CircuitError,
};
///
#[derive(Debug, Clone)]
pub struct EinsumAnalysis {
/// max size of input tensors
pub max_input_size: usize,
/// max size of output tensors
pub max_output_size: usize,
/// max number of input tensors
pub max_num_inputs: usize,
/// max number of output axes
pub max_num_output_axes: usize,
/// the sum of the lengths of dot product to compute all the reductions
pub reduction_length: usize,
}
/// The strategy to use for einsum
#[derive(Debug, Clone)]
pub enum EinsumStrategy {
/// Use only base ops
BaseOps,
/// Use Freivalds' argument
Freivalds,
}
///
#[derive(Debug, Clone)]
pub struct SingleEquationAnalysis {
///
pub equation: String,
///
pub num_inputs: usize,
///
pub max_input_size: usize,
///
pub output_size: usize,
///
pub num_output_axes: usize,
///
pub output_indices: Vec<char>,
/// the length of dot product to compute all the reductions
pub reduction_length: usize,
/// the strategy to use for einsum
pub strategy: EinsumStrategy,
}
///
pub fn analyze_einsum_usage(
equations: &HashMap<(usize, String), HashMap<char, usize>>,
) -> Result<EinsumAnalysis, CircuitError> {
let mut max_num_inputs = 0;
let mut max_input_size = 0;
let mut max_output_size = 0;
let mut max_num_output_axes = 0;
let mut reduction_length = 0;
for ((_, equation), input_axes_to_dim) in equations.iter() {
let analysis = analyze_single_equation(equation, input_axes_to_dim)?;
max_input_size = max_input_size.max(analysis.max_input_size);
max_output_size = max_output_size.max(analysis.output_size);
max_num_inputs = max_num_inputs.max(analysis.num_inputs);
max_num_output_axes = max_num_output_axes.max(analysis.num_output_axes);
reduction_length += analysis.reduction_length;
}
Ok(EinsumAnalysis {
max_input_size,
max_output_size,
max_num_inputs,
max_num_output_axes,
reduction_length,
})
}
///
pub fn analyze_single_equation(
equation: &str,
input_axes_to_dim: &HashMap<char, usize>,
) -> Result<SingleEquationAnalysis, CircuitError> {
// Sanitise equation to remove trivial axes
let equation = {
let (inputs_str, output_str) = equation.split_once("->").unwrap();
let input_equations: Vec<&str> = inputs_str.split(',').collect();
let inputs: Vec<String> = input_equations
.iter()
.map(|input| {
input
.chars()
.filter(|char| *input_axes_to_dim.get(char).unwrap() > 1)
.collect()
})
.collect();
let output = output_str
.chars()
.filter(|c| {
input_axes_to_dim.get(c).is_some() && *input_axes_to_dim.get(c).unwrap() > 1
})
.collect();
[inputs.join(","), output].join("->")
};
let (inputs_eq, output_eq) = equation.split_once("->").unwrap();
let input_equations: Vec<&str> = inputs_eq.split(',').collect();
let max_input_size = input_equations
.iter()
.map(|eqn| {
eqn.chars()
.map(|c| input_axes_to_dim.get(&c).unwrap())
.product()
})
.max()
.unwrap();
let output_indices: Vec<char> = output_eq.chars().collect();
let output_dims = output_indices
.iter()
.map(|c| input_axes_to_dim.get(&c).unwrap());
let output_size = output_dims.clone().product();
let output_reduction_length = {
let mut output_dims = output_dims.rev().cloned().collect_vec();
let mut total_length = 0;
for _ in 0..output_dims.len() {
let dot_product_len = output_dims.remove(0);
let num_dot_products: usize = output_dims.iter().product();
total_length += dot_product_len * num_dot_products;
}
total_length
};
let input_reductions_length = {
let input_reductions = reduction_planner::input_reductions(&equation)?;
input_reductions
.into_iter()
.map(|reduction| {
let (_, output_expr) = reduction.expression().split_once("->").unwrap();
let num_inputs = reduction.input_indices().len();
let dot_product_len = match reduction {
Reduction::RLC { axis, .. } => *input_axes_to_dim.get(&axis).unwrap(),
Reduction::Contraction { axis, .. } => *axis
.and_then(|axis| input_axes_to_dim.get(&axis))
.unwrap_or(&1),
};
let num_dot_products: usize = output_expr
.chars()
.map(|c| input_axes_to_dim.get(&c).unwrap())
.product();
// since `multi_dot` does pairwise mult between input pairs and final summation
if num_inputs <= 2 {
num_dot_products * dot_product_len
} else {
num_dot_products * (dot_product_len * num_inputs)
}
})
.sum::<usize>()
};
let dispatch_to_einsum_with_base_ops = {
let mut seen = HashSet::new();
let mut common_indices_to_inputs = vec![];
for input in input_equations.iter() {
for c in input.chars() {
if !seen.contains(&c) {
seen.insert(c);
} else {
common_indices_to_inputs.push(c);
}
}
}
let non_common_indices = input_axes_to_dim
.keys()
.filter(|&x| {
!common_indices_to_inputs.contains(x)
&& input_axes_to_dim.get(x).cloned().unwrap() > 1
})
.collect::<Vec<_>>();
!(output_indices.len() > 0
&& common_indices_to_inputs.len() > 0
&& non_common_indices.len() > 1)
};
let strategy = if dispatch_to_einsum_with_base_ops {
EinsumStrategy::BaseOps
} else {
EinsumStrategy::Freivalds
};
Ok(SingleEquationAnalysis {
output_size,
max_input_size,
equation: equation.to_string(),
num_inputs: input_equations.len(),
num_output_axes: output_indices.len(),
output_indices,
reduction_length: output_reduction_length + input_reductions_length,
strategy,
})
}

View File

@@ -1,54 +0,0 @@
use std::{collections::HashMap, marker::PhantomData};
use halo2_proofs::circuit::Value;
use halo2curves::ff::PrimeField;
use crate::{
circuit::CircuitError,
tensor::{Tensor, TensorError, TensorType},
};
/// Circuit parameter for a single einsum equation
#[derive(Clone, Debug, Default)]
pub struct SingleEinsumParams<F: PrimeField + TensorType + PartialOrd> {
///
pub equation: String,
/// Map from input axes to dimensions
pub input_axes_to_dims: HashMap<char, usize>,
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> SingleEinsumParams<F> {
///
pub fn new(equation: &str, inputs: &[&Tensor<Value<F>>]) -> Result<Self, CircuitError> {
let mut eq = equation.split("->");
let inputs_eq = eq.next().ok_or(CircuitError::InvalidEinsum)?;
let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();
// Check that the number of inputs matches the number of inputs in the equation
if inputs.len() != inputs_eq.len() {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
let mut input_axes_to_dims = HashMap::new();
for (i, input) in inputs.iter().enumerate() {
for j in 0..inputs_eq[i].len() {
let c = inputs_eq[i]
.chars()
.nth(j)
.ok_or(CircuitError::InvalidEinsum)?;
if let std::collections::hash_map::Entry::Vacant(e) = input_axes_to_dims.entry(c) {
e.insert(input.dims()[j]);
} else if input_axes_to_dims[&c] != input.dims()[j] {
return Err(TensorError::DimMismatch("einsum".to_string()).into());
}
}
}
Ok(Self {
equation: equation.to_owned(),
input_axes_to_dims,
_marker: PhantomData,
})
}
}

View File

@@ -1,359 +0,0 @@
use halo2curves::ff::PrimeField;
use log::{error, trace};
use crate::{
circuit::{base::BaseOp, einsum::BaseOpInfo, region::RegionCtx, CheckMode, CircuitError},
tensor::{
get_broadcasted_shape,
ops::{accumulated, add, mult, sub},
TensorError, TensorType, ValTensor, ValType,
},
};
use super::ContractionConfig;
/// Pairwise (elementwise) op layout
pub fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>; 2],
op: BaseOp,
phases: &[usize; 2],
) -> Result<ValTensor<F>, CircuitError> {
let (mut lhs, mut rhs) = if phases[0] <= phases[1] {
(values[0].clone(), values[1].clone())
} else {
(values[1].clone(), values[0].clone())
};
let broadcasted_shape = get_broadcasted_shape(lhs.dims(), rhs.dims())?;
lhs.expand(&broadcasted_shape)?;
rhs.expand(&broadcasted_shape)?;
if lhs.len() != rhs.len() {
return Err(CircuitError::DimMismatch(format!(
"pairwise {} layout",
op.as_str()
)));
}
region.flush_einsum()?;
let input_vars = config.get_input_vars(phases.as_slice().into());
let output_var = config.get_output_var(phases.as_slice().into());
let inputs = [lhs, rhs]
.iter()
.zip(input_vars)
.map(|(val, var)| {
let res = region.assign_einsum(var, val)?;
Ok(res.get_inner()?)
})
.collect::<Result<Vec<_>, CircuitError>>()?;
// Now we can assign the dot product
// time the calc
let op_result = match op {
BaseOp::Add => add(&inputs),
BaseOp::Sub => sub(&inputs),
BaseOp::Mult => mult(&inputs),
_ => return Err(CircuitError::UnsupportedOp),
}
.map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
let assigned_len = op_result.len();
let mut output = region.assign_einsum(output_var, &op_result.into())?;
// Enable the selectors
if !region.is_dummy() {
(0..assigned_len)
.map(|i| {
let (x, y, z) = output_var.cartesian_coord(region.einsum_col_coord() + i);
let op_info = BaseOpInfo {
op_kind: op.clone(),
input_phases: phases.as_slice().into(),
};
let selector = config.selectors.get(&(op_info, x, y));
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, CircuitError>>()?;
}
region.increment_einsum_col_coord(assigned_len);
output.reshape(&broadcasted_shape)?;
Ok(output)
}
pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>; 1],
phase: usize,
check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
if values[0].len() == 1 {
return Ok(values[0].clone());
}
assert!(phase == 0 || phase == 1);
region.flush_einsum()?;
let mut input = values[0].clone();
let block_width = config.block_width();
let assigned_len: usize;
let input = {
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let var = config.get_input_vars([phase].as_slice().into())[0];
let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &input)?;
assigned_len = len;
res.get_inner()?
};
// Now we can assign the dot product
let accumulated_sum = accumulated::sum(&input, block_width)?;
let output_var = config.get_output_var([phase].as_slice().into());
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
output_var,
&accumulated_sum.into(),
check_mode,
)?;
// enable the selectors
if !region.is_dummy() {
for i in 0..output_assigned_len {
let (x, _, z) = output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
// skip over duplicates at start of column
if z == 0 && i > 0 {
continue;
}
let selector = if i == 0 {
let op_info = BaseOpInfo {
op_kind: BaseOp::SumInit,
input_phases: [phase].as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
} else {
let op_info = BaseOpInfo {
op_kind: BaseOp::Sum,
input_phases: [phase].as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
};
region.enable(selector, z)?;
}
}
let last_elem = output.last()?;
region.increment_einsum_col_coord(assigned_len);
// last element is the result
Ok(last_elem)
}
pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>; 1],
phase: usize,
check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
assert!(phase == 0 || phase == 1);
region.flush_einsum()?;
let block_width = config.block_width();
let assigned_len: usize;
let input = {
let mut input = values[0].clone();
input.pad_to_zero_rem(block_width, ValType::Constant(F::ONE))?;
let var = config.get_input_vars([phase].as_slice().into())[0];
let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &input)?;
assigned_len = len;
res.get_inner()?
};
// Now we can assign the dot product
let accumulated_prod = accumulated::prod(&input, block_width)?;
let output_var = config.get_output_var([phase].as_slice().into());
let (output, output_assigned_len) = region.assign_einsum_with_duplication_constrained(
output_var,
&accumulated_prod.into(),
check_mode,
)?;
// enable the selectors
if !region.is_dummy() {
(0..output_assigned_len)
.map(|i| {
let (x, _, z) =
output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
// skip over duplicates at start of column
if z == 0 && i > 0 {
return Ok(());
}
let selector = if i == 0 {
let op_info = BaseOpInfo {
op_kind: BaseOp::CumProdInit,
input_phases: [phase].as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
} else {
let op_info = BaseOpInfo {
op_kind: BaseOp::CumProd,
input_phases: [phase].as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
};
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, CircuitError>>()?;
}
let last_elem = output.last()?;
region.increment_einsum_col_coord(assigned_len);
// last element is the result
Ok(last_elem)
}
pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>; 2],
phases: &[usize; 2],
check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
if values[0].len() != values[1].len() {
return Err(TensorError::DimMismatch("dot".to_string()).into());
}
region.flush_einsum()?;
// time this entire function run
let global_start = instant::Instant::now();
let mut values = if phases[0] <= phases[1] {
[values[0].clone(), values[1].clone()]
} else {
[values[1].clone(), values[0].clone()]
};
let vars = config.get_input_vars(phases.as_slice().into());
let mut inputs = vec![];
let block_width = config.block_width();
let mut assigned_len = 0;
for (val, var) in values.iter_mut().zip(vars) {
val.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
let inp = {
let (res, len) = region.assign_einsum_with_duplication_unconstrained(var, &val)?;
assigned_len = len;
res.get_inner()?
};
inputs.push(inp);
}
// Now we can assign the dot product
// time this step
let accumulated_dot = accumulated::dot(&inputs[0], &inputs[1], block_width)?;
let output_var = config.get_output_var(phases.as_slice().into());
let (output, output_assigned_len) = region
.assign_einsum_with_duplication_constrained(output_var, &accumulated_dot.into(), check_mode)
.expect("failed to assign einsum with duplication constrained");
// enable the selectors
if !region.is_dummy() {
(0..output_assigned_len)
.map(|i| {
let (x, _, z) =
output_var.cartesian_coord(region.einsum_col_coord() + i * block_width);
// hop over duplicates at start of column
if z == 0 && i > 0 {
return Ok(());
}
let selector = if i == 0 {
let op_info = BaseOpInfo {
op_kind: BaseOp::DotInit,
input_phases: phases.as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
} else {
let op_info = BaseOpInfo {
op_kind: BaseOp::Dot,
input_phases: phases.as_slice().into(),
};
config.selectors.get(&(op_info, x, 0))
};
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, CircuitError>>()?;
}
let last_elem = output.last()?;
region.increment_einsum_col_coord(assigned_len);
let elapsed = global_start.elapsed();
trace!("dot layout took: {:?}, row {}", elapsed, region.row());
trace!("----------------------------");
Ok(last_elem)
}
/// Dot product of more than two tensors
pub fn multi_dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &ContractionConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>],
phases: &[usize],
check_mode: &CheckMode,
) -> Result<ValTensor<F>, CircuitError> {
assert!(phases.iter().all(|phase| *phase == 0 || *phase == 1));
if !values.iter().all(|value| value.len() == values[0].len()) {
return Err(TensorError::DimMismatch("dot".to_string()).into());
}
// time this entire function run
let global_start = instant::Instant::now();
let values: Vec<ValTensor<F>> = values.iter().copied().cloned().collect();
// do pairwise dot product between intermediate tensor and the next tensor
let (intermediate, output_phase) = values
.into_iter()
.zip(phases.iter().cloned())
.reduce(|(intermediate, intermediate_phase), (input, phase)| {
(
pairwise(
config,
region,
&[&intermediate, &input],
BaseOp::Mult,
&[intermediate_phase, phase],
)
.unwrap(),
std::cmp::max(intermediate_phase, phase),
)
})
.unwrap();
let accumulated_dot = sum(config, region, &[&intermediate], output_phase, check_mode)?;
let last_elem = accumulated_dot.last()?;
let elapsed = global_start.elapsed();
trace!("multi_dot layout took: {:?}, row {}", elapsed, region.row());
trace!("----------------------------");
Ok(last_elem)
}

Some files were not shown because too many files have changed in this diff Show More