Compare commits

..

95 Commits

Author SHA1 Message Date
dante
1a75963705 refactor: DataSource enum -> struct 2025-04-29 12:18:24 -04:00
dante
0ef1f35e59 fix: uniffi bindings 2025-04-29 12:12:56 -04:00
dante
808ab7d0de chore: feature-gate eth (#978) 2025-04-29 12:02:23 -04:00
dante
68b2c96b97 Merge branch 'vka-hashing' of https://github.com/zkonduit/ezkl into vka-hashing 2025-04-29 11:31:20 -04:00
dante
9a0ab22fdb fix matches 2025-04-29 11:31:13 -04:00
dante
f2b1de3740 Merge branch 'main' into vka-hashing 2025-04-29 11:26:04 -04:00
dante
839030ce10 chore: rm halo2proofs patches (#976) 2025-04-29 10:58:35 -04:00
dante
cfccc5460c refactor: rm postgres (#977) 2025-04-29 08:59:14 -04:00
Ethan
dcb888ff1e fix wasm package graph data import error 2025-04-28 16:29:09 -05:00
Ethan
26f465e70c bring back zizmor analysis 2025-04-28 08:21:26 -05:00
Ethan
8eef53213d rmv data attestation 2025-04-27 19:36:54 -05:00
Ethan
a1345966d7 configure Git credentials more persistently 2025-04-27 18:09:37 -05:00
Ethan
640061c850 set git config after action checkouts 2025-04-27 17:48:30 -05:00
Ethan
da7db7d88d use git config local instead of global 2025-04-27 17:20:24 -05:00
Ethan
a55f75ff3f rmv debug statement on token 2025-04-24 11:19:19 -05:00
Ethan
bf6f704827 debug token 2025-04-24 10:56:27 -05:00
Ethan
0dbfdf4672 debug token 2025-04-24 10:54:56 -05:00
Ethan
98299356a6 *fix syntax error on yaml 2025-04-24 10:51:24 -05:00
Ethan
04805d2a91 move token env to job level 2025-04-24 10:42:35 -05:00
Ethan
ca18cf29bb set token as global env var 2025-04-24 10:36:37 -05:00
Ethan
78f8e23b55 use verification ezkl token 2025-04-24 10:26:15 -05:00
Ethan
7d40926082 activate git fetch with cli on runner 2025-04-24 09:53:20 -05:00
Ethan
e2c8182871 *update python bindings 2025-04-24 09:43:55 -05:00
Ethan
4f077c9134 *use https for loading h2 sol verifier crate 2025-04-23 21:57:07 -05:00
Ethan
038805ce02 Merge branch 'main' into vka-hashing 2025-04-23 21:32:56 -05:00
Ethan
0fb87c9a20 *update lock 2025-04-23 21:30:43 -05:00
Ethan
77423a6d07 *check that on-chain rescaled instances match what is stored in proof file. 2025-04-23 21:25:35 -05:00
dante
0de0682bfa refactor: configurable div epsilon (#968) 2025-04-23 09:12:24 +01:00
dante
bf9cf14ab7 refactor!: rpc url should be required (#965)
BREAKING CHANGE: in python the order of arguments for evm related functions has changed
2025-04-22 12:45:36 +01:00
Ethan
8b416c7a00 *comment out swift package test 2025-04-21 04:31:51 -05:00
Ethan
73ec5e549a *temporarily disable zizmor + swift package on ci. 2025-04-21 04:27:36 -05:00
Ethan
28386d8442 vka hashing + rescaling 2025-04-21 04:13:31 -05:00
dante
6818962ac2 chore: pass in raw data for gen-witness from file (#964) 2025-04-06 14:08:11 -04:00
dante
70469e3bf9 chore: add min/max to gen-random-data (#960) 2025-03-25 19:32:15 +00:00
dante
52ff187e55 refactor: command struct names should match str (#959) 2025-03-24 12:54:43 +00:00
dante
4e57a5a486 docs: link to audit (#958)
---------

Co-authored-by: Jason Morton <jason.morton@gmail.com>
2025-03-23 21:12:44 +00:00
Ethan Cemer
fe978caa85 fix!: bug fixes (#956)
BREAKING CHANGE: DA verifier no longer backwards compatible
2025-03-18 22:08:29 +00:00
dante
1bef92407c fix: recip denom epsilon can induce non opt res (#957) 2025-03-17 14:46:33 +00:00
dante
5ff1c48ede refactor: allow for negative stride downsample (#955) 2025-03-14 12:07:37 -04:00
dante
ab4997d0c2 chore: update docs and panics (#952) 2025-03-10 11:32:28 -04:00
dante
701e69dd2f fix: handle [] shapes in sort (#954) 2025-03-08 13:17:45 -05:00
dante
f631445e26 docs: document arguments better (#950) 2025-03-05 16:10:50 -05:00
dante
fcbb27677f fix: empty dim len can be 1 (#949) 2025-02-28 23:56:19 -05:00
dante
bc26691bd5 chore: smaller cat dog example (#947) 2025-02-28 10:37:08 -05:00
dante
73c813a81d feat: pass data directly in cli (#939) 2025-02-13 12:35:13 -05:00
dante
ae076aef09 refactor: rm tolerance parameter (#937) 2025-02-11 12:57:18 -05:00
dante
a7544f4060 feat: generalize conv mem layout and ND (#935) 2025-02-10 09:11:58 -05:00
dante
c19fa5218a refactor: enforce max decomp base/legs in args (#936) 2025-02-09 16:15:40 -05:00
rebustron
eb205d0c73 chore: fix typos in comments and docs (#934) 2025-02-08 19:13:17 -05:00
dante
db498f8d7c docs: cat-dog example (#929) 2025-02-08 17:30:13 -05:00
Cypher Pepe
a363c91160 fix: broken links in polycommit.rs and poseidon.rs (#932) 2025-02-08 12:40:53 -05:00
dante
f7f04415fa chore!: add model input/output types to settings (#933)
BREAKING CHANGE: compiled model serialization is not backwards compatible
2025-02-07 16:05:59 -05:00
Jseam
de8d419e5d ci: change to sha hashes (#922) 2025-02-07 12:27:35 -05:00
dante
a38d318923 fix: pypi publication pipeline (#931) 2025-02-05 23:03:21 -05:00
dante
864990fe2d fix: publishing path 2025-02-05 19:57:13 -05:00
dante
29c3e4f977 fix: bump download artifact to v4 2025-02-05 19:05:29 -05:00
dante
0689115828 fix: ezkl-gpu name (#930) 2025-02-05 18:29:28 -05:00
dante
99f741304a Revert "fix: ezkl-gpu install"
This reverts commit 20ac99fdbf.
2025-02-05 18:03:46 -05:00
dante
20ac99fdbf fix: ezkl-gpu install 2025-02-05 18:01:26 -05:00
dante
532fa65e93 fix: patch python release pipeline for v4 2025-02-05 17:59:35 -05:00
dante
cfe5db545c fix: npm and pypi releases 2025-02-05 17:26:36 -05:00
dante
21ad56aea1 refactor: serial lookup commits for metal (#928) 2025-02-05 16:54:12 -05:00
dante
4ed7e0fd29 fix: use variable len domain for poseidon (#927) 2025-02-05 16:52:28 -05:00
dante
05d1f10615 docs: advanced security notices (#926)
---------

Co-authored-by: jason <jason.morton@gmail.com>
2025-02-05 15:14:29 +00:00
dante
9a8c754e45 fix: use onnx convention when integer dividing (#925) 2025-02-05 09:32:44 +00:00
dante
d82766d413 fix: force prover det on argmax/min for collisions (#923) 2025-02-04 12:08:34 +00:00
dante
820a80122b fix: range-check graph input and outputs (#921) 2025-02-04 02:33:27 +00:00
dante
9c64e42bd3 docs: improve quality + code quality fixes (#920) 2025-01-31 10:48:25 +00:00
dante
27b5e5dde3 fix: make flushing err more informative (#919) 2025-01-28 14:53:05 -05:00
dante
83c4afce3b fix: version interpolation in npm publishing (#917) 2025-01-27 23:20:58 -05:00
dante
50740a22df fix: patch pypi whl version labels (#916) 2025-01-27 20:25:03 -05:00
dante
a2624f6303 fix: strict cvx opt bounds to stop prover non-det (#914) 2025-01-24 08:48:50 -05:00
dante
fc5be4f949 fix: syn-sel should be range-checked when overflow (#913) 2025-01-23 12:26:31 -05:00
dante
d0ba505baa fix: node parsing should not panic (#912) 2025-01-22 08:02:29 -05:00
dante
f35688917d fix: rm macos metal bindings from python (#911) 2025-01-21 00:36:57 -05:00
Artem
7ae541ed35 feat: metal acceleration for MSM solving (#909)
---------

Co-authored-by: dante <45801863+alexander-camuto@users.noreply.github.com>
2025-01-20 22:17:24 -05:00
dante
675628cd08 fix!: shuffle argument should include an incrementing index (#904)
BREAKING CHANGE: pk and vk will not be backwards compatible
2025-01-17 09:19:10 -05:00
Artem
4fe7290240 fix: rust ci issue with updating swift pm testing files (#908) 2025-01-14 12:00:55 -05:00
dante
3e027db9b6 fix: apply zizmor suggestions to CI (#906)
---------

Co-authored-by: Jseam <hello.jseam@gmail.com>
2025-01-14 12:00:31 -05:00
Artem
e566acc22a fix: swift pm ci issue with updating testing files (#905) 2025-01-13 18:08:04 -05:00
dante
75ea99e81d fix: eager exec of ok_or error prints (#903) 2025-01-11 13:50:57 -05:00
dante
c5354c382d refactor: range check sanity toggled by CHECKMODE (#902) 2025-01-10 22:58:52 +00:00
dante
bdcba5ca61 feat: add gen-random-data helpers func (#901) 2025-01-09 00:14:27 +00:00
dante
6752a05f19 refactor: pregen mv-lookup blinds (#900) 2025-01-08 17:18:46 +00:00
dante
03aefb85eb chore: version mismatch warnings for artifacts (#899) 2025-01-06 16:01:34 +00:00
dante
e86caca8b6 refactor: batched poly reads (#897) 2025-01-06 15:49:47 +00:00
dante
c839a30ae6 fix: clearer duplication functions (#895) 2024-12-31 07:28:02 -05:00
dante
352812b9ac refactor!: simplified decompose op (#892) 2024-12-30 13:44:03 -05:00
dante
d48d0b0b3e fix: get_slice should not use intermediate Vec (#894) 2024-12-27 23:26:22 -05:00
Jseam
8b223354cc fix: add version string and sed (#893) 2024-12-27 14:24:28 -05:00
dante
caa6ef8e16 fix: const filtering strat is size dependent (#891) 2024-12-27 09:43:59 -05:00
Artem
c4354c10a5 fix: ios bindings update action (#886) 2024-12-16 10:49:13 -05:00
dante
c1ce8c88d0 chore: rm wasm serialization checks (#890) 2024-12-12 22:20:29 -05:00
dante
876a9584a1 chore: optimize wasm bundle for speed over size (#889) 2024-12-12 15:35:17 -05:00
dante
7d7f049cc4 chore: neural bag of words example (#888) 2024-12-12 14:20:21 -05:00
113 changed files with 12668 additions and 11950 deletions

View File

@@ -6,23 +6,16 @@ on:
description: "Test scenario tags"
jobs:
bench_elgamal:
runs-on: self-hosted
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
components: rustfmt, clippy
- name: Bench elgamal
run: cargo bench --verbose --bench elgamal
bench_poseidon:
permissions:
contents: read
runs-on: self-hosted
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -31,11 +24,15 @@ jobs:
run: cargo bench --verbose --bench poseidon
bench_einsum_accum_matmul:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -44,11 +41,15 @@ jobs:
run: cargo bench --verbose --bench accum_einsum_matmul
bench_accum_matmul_relu:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -57,11 +58,15 @@ jobs:
run: cargo bench --verbose --bench accum_matmul_relu
bench_accum_matmul_relu_overflow:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -70,11 +75,15 @@ jobs:
run: cargo bench --verbose --bench accum_matmul_relu_overflow
bench_relu:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -83,11 +92,15 @@ jobs:
run: cargo bench --verbose --bench relu
bench_accum_dot:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -96,11 +109,15 @@ jobs:
run: cargo bench --verbose --bench accum_dot
bench_accum_conv:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -109,11 +126,15 @@ jobs:
run: cargo bench --verbose --bench accum_conv
bench_accum_sumpool:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -122,11 +143,15 @@ jobs:
run: cargo bench --verbose --bench accum_sumpool
bench_pairwise_add:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -135,11 +160,15 @@ jobs:
run: cargo bench --verbose --bench pairwise_add
bench_accum_sum:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
@@ -148,11 +177,15 @@ jobs:
run: cargo bench --verbose --bench accum_sum
bench_pairwise_pow:
permissions:
contents: read
runs-on: self-hosted
needs: [bench_poseidon]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true

View File

@@ -15,17 +15,24 @@ defaults:
working-directory: .
jobs:
publish-wasm-bindings:
permissions:
contents: read
packages: write
name: publish-wasm-bindings
env:
RELEASE_TAG: ${{ github.ref_name }}
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
toolchain: nightly-2024-07-18
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
- uses: jetli/wasm-pack-action@0d096b08b4e5a7de8c28de67e11e945404e9eefa #v0.4.0
with:
# Pin to version 0.12.1
version: 'v0.12.1'
@@ -33,7 +40,7 @@ jobs:
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2025-02-17-x86_64-unknown-linux-gnu
- name: Install binaryen
run: |
set -e
@@ -42,41 +49,41 @@ jobs:
wasm-opt --version
- name: Build wasm files for both web and nodejs compilation targets
run: |
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
wasm-pack build --release --target web --out-dir ./pkg/web . -- -Z build-std="panic_abort,std" --features web
- name: Create package.json in pkg folder
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
echo '{
"name": "@ezkljs/engine",
"version": "${{ github.ref_name }}",
"dependencies": {
"@types/json-bigint": "^1.0.1",
"json-bigint": "^1.0.0"
},
"files": [
"nodejs/ezkl_bg.wasm",
"nodejs/ezkl.js",
"nodejs/ezkl.d.ts",
"nodejs/package.json",
"nodejs/utils.js",
"web/ezkl_bg.wasm",
"web/ezkl.js",
"web/ezkl.d.ts",
"web/snippets/**/*",
"web/package.json",
"web/utils.js",
"ezkl.d.ts"
],
"main": "nodejs/ezkl.js",
"module": "web/ezkl.js",
"types": "nodejs/ezkl.d.ts",
"sideEffects": [
"web/snippets/*"
]
}' > pkg/package.json
cat > pkg/package.json << EOF
{
"name": "@ezkljs/engine",
"version": "$RELEASE_TAG",
"dependencies": {
"@types/json-bigint": "^1.0.1",
"json-bigint": "^1.0.0"
},
"files": [
"nodejs/ezkl_bg.wasm",
"nodejs/ezkl.js",
"nodejs/ezkl.d.ts",
"nodejs/package.json",
"nodejs/utils.js",
"web/ezkl_bg.wasm",
"web/ezkl.js",
"web/ezkl.d.ts",
"web/snippets/**/*",
"web/package.json",
"web/utils.js",
"ezkl.d.ts"
],
"main": "nodejs/ezkl.js",
"module": "web/ezkl.js",
"types": "nodejs/ezkl.d.ts",
"sideEffects": [
"web/snippets/*"
]
}
EOF
- name: Replace memory definition in nodejs
run: |
@@ -169,7 +176,7 @@ jobs:
curl -s "https://raw.githubusercontent.com/zkonduit/ezkljs-engine/main/README.md" > ./pkg/README.md
- name: Set up Node.js
uses: actions/setup-node@v3
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
@@ -184,21 +191,26 @@ jobs:
in-browser-evm-ver-publish:
permissions:
contents: read
packages: write
name: publish-in-browser-evm-verifier-package
needs: [publish-wasm-bindings]
runs-on: ubuntu-latest
env:
RELEASE_TAG: ${{ github.ref_name }}
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
sed -i "s|\"version\": \".*\"|\"version\": \"$RELEASE_TAG\"|" in-browser-evm-verifier/package.json
- name: Prepare tag and fetch package integrity
run: |
CLEANED_TAG=${{ github.ref_name }} # Get the tag from ref_name
CLEANED_TAG=${RELEASE_TAG} # Get the tag from ref_name
CLEANED_TAG="${CLEANED_TAG#v}" # Remove leading 'v'
echo "CLEANED_TAG=${CLEANED_TAG}" >> $GITHUB_ENV # Set it as an environment variable for later steps
ENGINE_INTEGRITY=$(npm view @ezkljs/engine@$CLEANED_TAG dist.integrity)
@@ -218,13 +230,13 @@ jobs:
NR==30{$0=" specifier: \"" tag "\""}
NR==31{$0=" version: \"" tag "\""}
NR==400{$0=" /@ezkljs/engine@" tag ":"}
NR==401{$0=" resolution: {integrity: \"" integrity "\"}"} 1' in-browser-evm-verifier/pnpm-lock.yaml > temp.yaml && mv temp.yaml in-browser-evm-verifier/pnpm-lock.yaml
NR==401{$0=" resolution: {integrity: \"" integrity "\"}"} 1' in-browser-evm-verifier/pnpm-lock.yaml > temp.yaml && mv temp.yaml in-browser-evm-verifier/pnpm-lock.yaml
- name: Use pnpm 8
uses: pnpm/action-setup@v2
uses: pnpm/action-setup@eae0cfeb286e66ffb5155f1a79b90583a127a68b #v2.4.1
with:
version: 8
- name: Set up Node.js
uses: actions/setup-node@v3
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 #v3.8.2
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
@@ -235,4 +247,4 @@ jobs:
pnpm run build
pnpm publish --no-git-checks
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

View File

@@ -6,12 +6,16 @@ on:
description: "Test scenario tags"
jobs:
large-tests:
permissions:
contents: read
runs-on: kaiju
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
toolchain: nightly-2024-07-18
persist-credentials: false
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- name: nanoGPT Mock

View File

@@ -18,38 +18,46 @@ defaults:
jobs:
linux:
permissions:
contents: read
packages: write
runs-on: GPU
strategy:
matrix:
target: [x86_64]
env:
RELEASE_TAG: ${{ github.ref_name }}
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag
- name: Set pyproject.toml version to match github tag and rename ezkl to ezkl-gpu
shell: bash
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig >pyproject.toml
sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig > pyproject.toml.tmp
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.tmp > pyproject.toml
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2023-06-27
override: true
components: rustfmt, clippy
- name: Set Cargo.toml version to match github tag
- name: Set Cargo.toml version to match github tag and rename ezkl to ezkl-gpu
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
# the ezkl substitution here looks for the first instance of name = "ezkl" and changes it to "ezkl-gpu"
run: |
mv Cargo.toml Cargo.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
sed "0,/name = \"ezkl\"/s/name = \"ezkl\"/name = \"ezkl-gpu\"/" Cargo.toml.orig > Cargo.toml.tmp
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.tmp > Cargo.toml
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig > Cargo.lock
- name: Install required libraries
shell: bash
@@ -57,7 +65,7 @@ jobs:
sudo apt-get update && sudo apt-get install -y openssl pkg-config libssl-dev
- name: Build wheels
uses: PyO3/maturin-action@v1
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
with:
target: ${{ matrix.target }}
manylinux: auto
@@ -70,7 +78,7 @@ jobs:
pip install ezkl-gpu --no-index --find-links dist --force-reinstall
- name: Upload wheels
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
with:
name: wheels
path: dist
@@ -86,7 +94,7 @@ jobs:
# needs: [ macos, windows, linux, linux-cross, musllinux, musllinux-cross ]
needs: [linux]
steps:
- uses: actions/download-artifact@v3
- uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 #v4.1.8
with:
name: wheels
- name: List Files
@@ -98,14 +106,14 @@ jobs:
# publishes to PyPI
- name: Publish package distributions to PyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@unstable/v1
uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
with:
packages-dir: ./
packages-dir: ./wheels
# publishes to TestPyPI
- name: Publish package distribution to TestPyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@unstable/v1
uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ./
packages-dir: ./wheels

View File

@@ -16,36 +16,53 @@ defaults:
jobs:
macos:
permissions:
contents: read
runs-on: macos-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
matrix:
target: [x86_64, universal2-apple-darwin]
env:
RELEASE_TAG: ${{ github.ref_name }}
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag
shell: bash
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv Cargo.toml Cargo.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- name: Build wheels
uses: PyO3/maturin-action@v1
if: matrix.target == 'universal2-apple-darwin'
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
with:
target: ${{ matrix.target }}
args: --release --out dist --features python-bindings
- name: Build wheels
if: matrix.target == 'x86_64'
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
with:
target: ${{ matrix.target }}
args: --release --out dist --features python-bindings
@@ -56,24 +73,36 @@ jobs:
python -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
with:
name: wheels
name: dist-macos-${{ matrix.target }}
path: dist
windows:
permissions:
contents: read
runs-on: windows-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
matrix:
target: [x64, x86]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: 3.12
architecture: ${{ matrix.target }}
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -84,14 +113,14 @@ jobs:
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- name: Build wheels
uses: PyO3/maturin-action@v1
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
with:
target: ${{ matrix.target }}
args: --release --out dist --features python-bindings
@@ -101,24 +130,36 @@ jobs:
python -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
with:
name: wheels
name: dist-windows-${{ matrix.target }}
path: dist
linux:
permissions:
contents: read
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
matrix:
target: [x86_64]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -129,14 +170,13 @@ jobs:
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- name: Install required libraries
shell: bash
run: |
sudo apt-get update && sudo apt-get install -y openssl pkg-config libssl-dev
- name: Build wheels
uses: PyO3/maturin-action@v1
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
with:
target: ${{ matrix.target }}
manylinux: auto
@@ -163,63 +203,14 @@ jobs:
python -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
with:
name: wheels
name: dist-linux-${{ matrix.target }}
path: dist
# There's a problem with the maturin-action toolchain for arm arch leading to failed builds
# linux-cross:
# runs-on: ubuntu-latest
# strategy:
# matrix:
# target: [aarch64, armv7]
# steps:
# - uses: actions/checkout@v4
# - uses: actions/setup-python@v4
# with:
# python-version: 3.12
# - name: Install cross-compilation tools for aarch64
# if: matrix.target == 'aarch64'
# run: |
# sudo apt-get update
# sudo apt-get install -y gcc make gcc-aarch64-linux-gnu binutils-aarch64-linux-gnu libc6-dev-arm64-cross libusb-1.0-0-dev libatomic1-arm64-cross
# - name: Install cross-compilation tools for armv7
# if: matrix.target == 'armv7'
# run: |
# sudo apt-get update
# sudo apt-get install -y gcc make gcc-arm-linux-gnueabihf binutils-arm-linux-gnueabihf libc6-dev-armhf-cross libusb-1.0-0-dev libatomic1-armhf-cross
# - name: Build wheels
# uses: PyO3/maturin-action@v1
# with:
# target: ${{ matrix.target }}
# manylinux: auto
# args: --release --out dist --features python-bindings
# - uses: uraimo/run-on-arch-action@v2.5.0
# name: Install built wheel
# with:
# arch: ${{ matrix.target }}
# distro: ubuntu20.04
# githubToken: ${{ github.token }}
# install: |
# apt-get update
# apt-get install -y --no-install-recommends python3 python3-pip
# pip3 install -U pip
# run: |
# pip3 install ezkl --no-index --find-links dist/ --force-reinstall
# python3 -c "import ezkl"
# - name: Upload wheels
# uses: actions/upload-artifact@v3
# with:
# name: wheels
# path: dist
musllinux:
permissions:
contents: read
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
@@ -227,12 +218,22 @@ jobs:
target:
- x86_64-unknown-linux-musl
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -249,7 +250,7 @@ jobs:
sudo apt-get update && sudo apt-get install -y pkg-config libssl-dev
- name: Build wheels
uses: PyO3/maturin-action@v1
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
with:
target: ${{ matrix.target }}
manylinux: musllinux_1_2
@@ -257,7 +258,7 @@ jobs:
- name: Install built wheel
if: matrix.target == 'x86_64-unknown-linux-musl'
uses: addnab/docker-run-action@v3
uses: addnab/docker-run-action@3e77f186b7a929ef010f183a9e24c0f9955ea609
with:
image: alpine:latest
options: -v ${{ github.workspace }}:/io -w /io
@@ -270,12 +271,14 @@ jobs:
python3 -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
with:
name: wheels
name: dist-musllinux-${{ matrix.target }}
path: dist
musllinux-cross:
permissions:
contents: read
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
strategy:
@@ -284,11 +287,21 @@ jobs:
- target: aarch64-unknown-linux-musl
arch: aarch64
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- uses: actions/setup-python@b64ffcaf5b410884ad320a9cfac8866006a109aa #v4.8.0
with:
python-version: 3.12
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -300,13 +313,13 @@ jobs:
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- name: Build wheels
uses: PyO3/maturin-action@v1
uses: PyO3/maturin-action@5f8a1b3b0aad13193f46c9131f9b9e663def8ce5 #v1.46.0
with:
target: ${{ matrix.platform.target }}
manylinux: musllinux_1_2
args: --release --out dist --features python-bindings
- uses: uraimo/run-on-arch-action@v2.8.1
- uses: uraimo/run-on-arch-action@5397f9e30a9b62422f302092631c99ae1effcd9e #v2.8.1
name: Install built wheel
with:
arch: ${{ matrix.platform.arch }}
@@ -321,9 +334,9 @@ jobs:
python3 -c "import ezkl"
- name: Upload wheels
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 #v4.6.0
with:
name: wheels
name: dist-musllinux-${{ matrix.platform.target }}
path: dist
pypi-publish:
@@ -332,44 +345,43 @@ jobs:
permissions:
id-token: write
if: "startsWith(github.ref, 'refs/tags/')"
# TODO: Uncomment if linux-cross is working
# needs: [ macos, windows, linux, linux-cross, musllinux, musllinux-cross ]
needs: [macos, windows, linux, musllinux, musllinux-cross]
steps:
- uses: actions/download-artifact@v3
- uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 #v4.1.8
with:
name: wheels
pattern: dist-*
merge-multiple: true
path: wheels
- name: List Files
run: ls -R
# Both publish steps will fail if there is no trusted publisher setup
# On failure the publish step will then simply continue to the next one
# # publishes to TestPyPI
# - name: Publish package distribution to TestPyPI
# uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
# with:
# repository-url: https://test.pypi.org/legacy/
# packages-dir: ./
# publishes to PyPI
- name: Publish package distributions to PyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@unstable/v1
uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc #v1.12.4
with:
packages-dir: ./
packages-dir: ./wheels
# publishes to TestPyPI
- name: Publish package distribution to TestPyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@unstable/v1
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ./
doc-publish:
permissions:
contents: read
name: Trigger ReadTheDocs Build
runs-on: ubuntu-latest
needs: pypi-publish
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Trigger RTDs build
uses: dfm/rtds-action@v1
uses: dfm/rtds-action@618148c547f4b56cdf4fa4dcf3a94c91ce025f2d
with:
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
commit_ref: ${{ github.ref_name }}
commit_ref: ${{ github.ref_name }}

View File

@@ -10,6 +10,9 @@ on:
- "*"
jobs:
create-release:
permissions:
contents: read
packages: write
name: create-release
runs-on: ubuntu-22.04
if: startsWith(github.ref, 'refs/tags/')
@@ -27,12 +30,15 @@ jobs:
- name: Create Github Release
id: create-release
uses: softprops/action-gh-release@v1
uses: softprops/action-gh-release@c95fe1489396fe8a9eb87c0abf8aa5b2ef267fda #v2.2.1
with:
token: ${{ secrets.RELEASE_TOKEN }}
tag_name: ${{ env.EZKL_VERSION }}
build-release-gpu:
permissions:
contents: read
packages: write
name: build-release-gpu
needs: ["create-release"]
runs-on: GPU
@@ -43,13 +49,16 @@ jobs:
RUST_BACKTRACE: 1
PCRE2_SYS_STATIC: 1
steps:
- uses: actions-rs/toolchain@v1
- uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- name: Checkout repo
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Get release version from tag
shell: bash
@@ -81,7 +90,7 @@ jobs:
echo "ASSET=build-artifacts/ezkl-linux-gpu.tar.gz" >> $GITHUB_ENV
- name: Upload release archive
uses: actions/upload-release-asset@v1.0.2
uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 #v1.0.2
env:
GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }}
with:
@@ -91,6 +100,10 @@ jobs:
asset_content_type: application/octet-stream
build-release:
permissions:
contents: read
packages: write
issues: write
name: build-release
needs: ["create-release"]
runs-on: ${{ matrix.os }}
@@ -106,32 +119,34 @@ jobs:
include:
- build: windows-msvc
os: windows-latest
rust: nightly-2024-07-18
rust: nightly-2025-02-17
target: x86_64-pc-windows-msvc
- build: macos
os: macos-13
rust: nightly-2024-07-18
rust: nightly-2025-02-17
target: x86_64-apple-darwin
- build: macos-aarch64
os: macos-13
rust: nightly-2024-07-18
rust: nightly-2025-02-17
target: aarch64-apple-darwin
- build: linux-musl
os: ubuntu-22.04
rust: nightly-2024-07-18
rust: nightly-2025-02-17
target: x86_64-unknown-linux-musl
- build: linux-gnu
os: ubuntu-22.04
rust: nightly-2024-07-18
rust: nightly-2025-02-17
target: x86_64-unknown-linux-gnu
- build: linux-aarch64
os: ubuntu-22.04
rust: nightly-2024-07-18
rust: nightly-2025-02-17
target: aarch64-unknown-linux-gnu
steps:
- name: Checkout repo
uses: actions/checkout@v4
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Get release version from tag
shell: bash
@@ -155,7 +170,7 @@ jobs:
fi
- name: Install Rust
uses: dtolnay/rust-toolchain@nightly
uses: dtolnay/rust-toolchain@4f94fbe7e03939b0e674bcc9ca609a16088f63ff #nightly branch, TODO: update when required
with:
target: ${{ matrix.target }}
@@ -181,14 +196,18 @@ jobs:
echo "target flag is: ${{ env.TARGET_FLAGS }}"
echo "target dir is: ${{ env.TARGET_DIR }}"
- name: Build release binary (no asm)
if: matrix.build != 'linux-gnu'
- name: Build release binary (no asm or metal)
if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
- name: Build release binary (asm)
if: matrix.build == 'linux-gnu'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
- name: Build release binary (metal)
if: matrix.build == 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal
- name: Strip release binary
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'
run: strip "target/${{ matrix.target }}/release/ezkl"
@@ -214,7 +233,7 @@ jobs:
echo "ASSET=build-artifacts/ezkl-win.zip" >> $GITHUB_ENV
- name: Upload release archive
uses: actions/upload-release-asset@v1.0.2
uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 #v1.0.2
env:
GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }}
with:

File diff suppressed because it is too large Load Diff

32
.github/workflows/static-analysis.yml vendored Normal file
View File

@@ -0,0 +1,32 @@
# Static analysis workflow: installs zizmor with cargo and runs it over the
# repository for every push to, and pull request against, `main`.
name: Static Analysis

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  analyze:
    # The job only needs to read the repository contents.
    permissions:
      contents: read
    runs-on: ubuntu-latest
    steps:
      # Check out the repository without leaving the token in the local git config.
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      # Install the pinned nightly toolchain and make it the default for later steps.
      - uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
        with:
          toolchain: nightly-2025-02-17
          override: true
          components: rustfmt, clippy

      # Run Zizmor static analysis
      - name: Install Zizmor
        run: cargo install --locked zizmor

      - name: Run Zizmor Analysis
        run: zizmor .

134
.github/workflows/swift-pm.yml vendored Normal file
View File

@@ -0,0 +1,134 @@
# Builds the EZKL iOS bindings for a SemVer tag, copies them (plus test assets)
# into a clone of the ezkl-swift-package repository, runs that package's tests
# when something changed, then commits, pushes and tags the result.
name: Build and Publish EZKL iOS SPM package

on:
  push:
    tags:
      # Only support SemVer versioning tags
      - 'v[0-9]+.[0-9]+.[0-9]+'
      - '[0-9]+.[0-9]+.[0-9]+'

jobs:
  build-and-update:
    permissions:
      contents: read
      packages: write
    runs-on: macos-latest
    env:
      # Host/path of the target repository (no scheme), reused for clone and push URLs.
      EZKL_SWIFT_PACKAGE_REPO: github.com/zkonduit/ezkl-swift-package.git
      RELEASE_TAG: ${{ github.ref_name }}

    steps:
      - name: Checkout EZKL
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
        with:
          persist-credentials: false

      - name: Extract TAG from github.ref_name
        run: |
          # github.ref_name is provided by GitHub Actions and contains the tag name directly.
          TAG="${RELEASE_TAG}"
          echo "Original TAG: $TAG"

          # Remove leading 'v' if present to match the Swift Package Manager version format.
          NEW_TAG=${TAG#v}
          echo "Stripped TAG: $NEW_TAG"
          # Export TAG to every subsequent step of this job.
          echo "TAG=$NEW_TAG" >> "$GITHUB_ENV"

      - name: Install Rust (nightly)
        uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f #v1.0.6
        with:
          toolchain: nightly
          override: true

      - name: Build EzklCoreBindings
        run: CONFIGURATION=release cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features

      - name: Clone ezkl-swift-package repository
        run: |
          # Use the job-level environment variable instead of expanding an
          # expression into the script text.
          git clone "https://${EZKL_SWIFT_PACKAGE_REPO}"

      - name: Copy EzklCoreBindings
        run: |
          rm -rf ezkl-swift-package/Sources/EzklCoreBindings
          cp -r build/EzklCoreBindings ezkl-swift-package/Sources/

      - name: Copy Test Files
        run: |
          rm -rf ezkl-swift-package/Tests/EzklAssets/
          mkdir -p ezkl-swift-package/Tests/EzklAssets/
          cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
          cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
          cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
          cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json

      - name: Check for changes
        id: check_changes
        run: |
          cd ezkl-swift-package
          # `git status --porcelain` also reports untracked files. `git diff`
          # only compares tracked files, so newly generated files would
          # otherwise be reported as "no changes" and never be committed.
          if [ -z "$(git status --porcelain -- Sources/EzklCoreBindings Tests/EzklAssets)" ]; then
            echo "no_changes=true" >> "$GITHUB_OUTPUT"
          else
            echo "no_changes=false" >> "$GITHUB_OUTPUT"
          fi

      - name: Set up Xcode environment
        if: steps.check_changes.outputs.no_changes == 'false'
        run: |
          sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
          sudo xcodebuild -license accept

      - name: Run Package Tests
        if: steps.check_changes.outputs.no_changes == 'false'
        run: |
          cd ezkl-swift-package
          xcodebuild test \
            -scheme EzklPackage \
            -destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
            -resultBundlePath ../testResults

      - name: Run Example App Tests
        if: steps.check_changes.outputs.no_changes == 'false'
        run: |
          cd ezkl-swift-package/Example
          xcodebuild test \
            -project Example.xcodeproj \
            -scheme EzklApp \
            -destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
            -parallel-testing-enabled NO \
            -resultBundlePath ../../exampleTestResults \
            -skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder

      # Runs unconditionally: the tagging step below pushes even when there
      # was nothing new to commit, so the authenticated remote is always needed.
      - name: Setup Git
        run: |
          cd ezkl-swift-package
          git config user.name "GitHub Action"
          git config user.email "action@github.com"
          git remote set-url origin "https://zkonduit:${EZKL_SWIFT_PACKAGE_REPO_TOKEN}@${EZKL_SWIFT_PACKAGE_REPO}"
        env:
          EZKL_SWIFT_PACKAGE_REPO_TOKEN: ${{ secrets.EZKL_PORTER_TOKEN }}

      - name: Commit and Push Changes
        if: steps.check_changes.outputs.no_changes == 'false'
        run: |
          cd ezkl-swift-package
          git add Sources/EzklCoreBindings Tests/EzklAssets
          git commit -m "Automatically updated EzklCoreBindings for EZKL"
          if ! git push origin; then
            echo "::error::Failed to push changes to ${EZKL_SWIFT_PACKAGE_REPO}. Please ensure that EZKL_PORTER_TOKEN has the correct permissions."
            exit 1
          fi

      - name: Tag the latest commit
        run: |
          cd ezkl-swift-package
          # TAG was written to GITHUB_ENV by an earlier step and is therefore
          # already present in this step's environment; no need to source the file.
          # Tag the latest commit on the current branch
          if git rev-parse "$TAG" >/dev/null 2>&1; then
            echo "Tag $TAG already exists locally. Skipping tag creation."
          else
            git tag "$TAG"
          fi

          if ! git push origin "$TAG"; then
            echo "::error::Failed to push tag '$TAG' to ${EZKL_SWIFT_PACKAGE_REPO}. Please ensure EZKL_PORTER_TOKEN has correct permissions."
            exit 1
          fi

View File

@@ -11,10 +11,12 @@ jobs:
contents: write
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
persist-credentials: false
- name: Bump version and push tag
id: tag_version
uses: mathieudutour/github-tag-action@v6.2
uses: mathieudutour/github-tag-action@a22cf08638b34d5badda920f9daf6e72c477b07b #v6.2
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
@@ -44,7 +46,7 @@ jobs:
git tag $RELEASE_TAG
- name: Push changes
uses: ad-m/github-push-action@master
uses: ad-m/github-push-action@77c5b412c50b723d2a4fbc6d71fb5723bcd439aa #master
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
with:

View File

@@ -1,85 +0,0 @@
name: Build and Publish EZKL iOS SPM package
on:
workflow_dispatch:
inputs:
tag:
description: "The tag to release"
required: true
push:
tags:
- "*"
jobs:
build-and-update:
runs-on: macos-latest
steps:
- name: Checkout EZKL
uses: actions/checkout@v3
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
override: true
- name: Build EzklCoreBindings
run: CONFIGURATION=release cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
- name: Clone ezkl-swift-package repository
run: |
git clone https://github.com/zkonduit/ezkl-swift-package.git
- name: Copy EzklCoreBindings
run: |
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
- name: Copy Test Files
run: |
rm -rf ezkl-swift-package/Tests/EzklAssets/*
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
- name: Set up Xcode environment
run: |
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license accept
- name: Run Package Tests
run: |
cd ezkl-swift-package
xcodebuild test \
-scheme EzklPackage \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-resultBundlePath ../testResults
- name: Run Example App Tests
run: |
cd ezkl-swift-package/Example
xcodebuild test \
-project Example.xcodeproj \
-scheme EzklApp \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-parallel-testing-enabled NO \
-resultBundlePath ../../exampleTestResults \
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
- name: Commit and Push Changes to feat/ezkl-direct-integration
run: |
cd ezkl-swift-package
git config user.name "GitHub Action"
git config user.email "action@github.com"
git add Sources/EzklCoreBindings
git add Tests/EzklAssets
git commit -m "Automatically updated EzklCoreBindings for EZKL"
git tag ${{ github.event.inputs.tag }}
git remote set-url origin https://zkonduit:${EZKL_PORTER_TOKEN}@github.com/zkonduit/ezkl-swift-package.git
git push origin
git push origin tag ${{ github.event.inputs.tag }}
env:
EZKL_PORTER_TOKEN: ${{ secrets.EZKL_PORTER_TOKEN }}

3
.gitignore vendored
View File

@@ -9,6 +9,7 @@ pkg
!AttestData.sol
!VerifierBase.sol
!LoadInstances.sol
!AttestData.t.sol
*.pf
*.vk
*.pk
@@ -49,3 +50,5 @@ timingData.json
!tests/assets/vk.key
docs/python/build
!tests/assets/vk_aggr.key
cache
out

2661
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -35,12 +35,11 @@ halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
"derive_serde",
] }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", optional = true }
halo2_solidity_verifier = { git = "https://github.com/zkonduit/verification-ezkl", branch = "vka-hash", optional = true }
maybe-rayon = { version = "0.1.1", default-features = false }
bincode = { version = "1.3.3", default-features = false }
unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = { version = "1.6.0", optional = true }
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand", optional = true }
semver = { version = "1.0.22", optional = true }
@@ -70,28 +69,24 @@ reqwest = { version = "0.12.4", default-features = false, features = [
"stream",
], optional = true }
openssl = { version = "0.10.55", features = ["vendored"], optional = true }
tokio-postgres = { version = "0.7.10", optional = true }
pg_bigdecimal = { version = "0.1.5", optional = true }
lazy_static = { version = "1.4.0", optional = true }
colored_json = { version = "3.0.1", default-features = false, optional = true }
regex = { version = "1", default-features = false, optional = true }
tokio = { version = "1.35.0", default-features = false, features = [
"macros",
"rt-multi-thread",
], optional = true }
pyo3 = { version = "0.23.2", features = [
pyo3 = { version = "0.24.2", features = [
"extension-module",
"abi3-py37",
"macros",
], default-features = false, optional = true }
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.23.0", features = [
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.24.0", features = [
"attributes",
"tokio-runtime",
], default-features = false, optional = true }
pyo3-log = { version = "0.12.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
objc = { version = "0.2.4", optional = true }
mimalloc = { version = "0.1", optional = true }
pyo3-stub-gen = { version = "0.6.0", optional = true }
@@ -147,6 +142,10 @@ shellexpand = "3.1.0"
runner = 'wasm-bindgen-test-runner'
[[bench]]
name = "zero_finder"
harness = false
[[bench]]
name = "accum_dot"
harness = false
@@ -218,15 +217,15 @@ required-features = ["python-bindings"]
[features]
web = ["wasm-bindgen-rayon"]
default = [
"eth-mv-lookup",
"ezkl",
"mv-lookup",
"precompute-coset",
"no-banner",
"parallel-poly-read",
]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-async-runtimes", "pyo3-stub-gen"]
ios-bindings = ["mv-lookup", "precompute-coset", "parallel-poly-read", "uniffi"]
ios-bindings = ["eth-mv-lookup", "precompute-coset", "parallel-poly-read", "uniffi"]
ios-bindings-test = ["ios-bindings", "uniffi/bindgen-tests"]
ezkl = [
"onnx",
@@ -235,28 +234,41 @@ ezkl = [
"tabled/color",
"serde_json/std",
"colored_json",
"dep:alloy",
"dep:foundry-compilers",
"dep:ethabi",
"dep:indicatif",
"dep:gag",
"dep:reqwest",
"dep:openssl",
"dep:tokio-postgres",
"dep:pg_bigdecimal",
"dep:lazy_static",
"dep:regex",
"dep:tokio",
"dep:openssl",
"dep:mimalloc",
"dep:chrono",
"dep:sha256",
"dep:portable-atomic",
"dep:clap_complete",
"dep:halo2_solidity_verifier",
"dep:semver",
"dep:clap",
"dep:tosubcommand",
]
eth = [
"dep:alloy",
"dep:foundry-compilers",
"dep:ethabi",
]
solidity-verifier = [
"dep:halo2_solidity_verifier",
]
solidity-verifier-mv-lookup = [
"halo2_solidity_verifier/mv-lookup",
]
eth-mv-lookup = [
"solidity-verifier-mv-lookup",
"mv-lookup",
"eth",
]
eth-original-lookup = [
"eth",
"solidity-verifier",
]
parallel-poly-read = [
"halo2_proofs/circuit-params",
"halo2_proofs/parallel-poly-read",
@@ -264,7 +276,6 @@ parallel-poly-read = [
mv-lookup = [
"halo2_proofs/mv-lookup",
"snark-verifier/mv-lookup",
"halo2_solidity_verifier/mv-lookup",
]
asm = ["halo2curves/asm", "halo2_proofs/asm"]
precompute-coset = ["halo2_proofs/precompute-coset"]
@@ -273,11 +284,10 @@ icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
no-banner = []
no-update = []
macos-metal = ["halo2_proofs/macos"]
ios-metal = ["halo2_proofs/ios"]
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b", package = "halo2_proofs" }
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
@@ -286,3 +296,11 @@ rustflags = ["-C", "relocation-model=pic"]
lto = "fat"
codegen-units = 1
# panic = "abort"
[profile.test-runs]
inherits = "dev"
opt-level = 3
[package.metadata.wasm-pack.profile.release]
wasm-opt = ["-O4", "--flexible-inline-max-function-size", "4294967295"]

View File

@@ -43,7 +43,7 @@ The generated proofs can then be verified with much less computational resources
----------------------
### getting started ⚙️
### Getting Started ⚙️
The easiest way to get started is to try out a notebook.
@@ -76,12 +76,12 @@ For more details visit the [docs](https://docs.ezkl.xyz). The CLI is faster than
Build the auto-generated rust documentation and open the docs in your browser locally. `cargo doc --open`
#### In-browser EVM verifier
#### In-browser EVM Verifier
As an alternative to running the native Halo2 verifier as a WASM binding in the browser, you can use the in-browser EVM verifier. Its source code, along with a README with instructions on how to use it, can be found in the `in-browser-evm-verifier` directory.
### building the project 🔨
### Building the Project 🔨
#### Rust CLI
@@ -96,7 +96,7 @@ cargo install --locked --path .
#### building python bindings
#### Building Python Bindings
Python bindings exist and can be built using `maturin`. You will need `rust` and `cargo` to be installed.
```bash
@@ -126,7 +126,7 @@ unset ENABLE_ICICLE_GPU
**NOTE:** Even with the above environment variable set, icicle is disabled for circuits where k <= 8. To change the value of `k` where icicle is enabled, you can set the environment variable `ICICLE_SMALL_K`.
### contributing 🌎
### Contributing 🌎
If you're interested in contributing and are unsure where to start, reach out to one of the maintainers:
@@ -144,13 +144,21 @@ More broadly:
Any contribution intentionally submitted for inclusion in the work by you shall be licensed to Zkonduit Inc. under the terms and conditions specified in the [CLA](https://github.com/zkonduit/ezkl/blob/main/cla.md), which you agree to by intentionally submitting a contribution. In particular, you have the right to submit the contribution and we can distribute it, among other terms and conditions.
### no security guarantees
Ezkl is unaudited, beta software undergoing rapid development. There may be bugs. No guarantees of security are made and it should not be relied on in production.
### Audits & Security
> NOTE: Because operations are quantized when they are converted from an onnx file to a zk-circuit, outputs in python and ezkl may differ slightly.
[v21.0.0](https://github.com/zkonduit/ezkl/releases/tag/v21.0.0) has been audited by Trail of Bits; the report can be found [here](https://github.com/trailofbits/publications/blob/master/reviews/2025-03-zkonduit-ezkl-securityreview.pdf).
### no warranty
> NOTE: Because operations are quantized when they are converted from an onnx file to a zk-circuit, outputs in python and ezkl may differ slightly.
Copyright (c) 2024 Zkonduit Inc. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
Check out `docs/advanced_security` for more advanced information on potential threat vectors that are specific to zero-knowledge inference, quantization, and to machine learning models generally.
### No Warranty
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
Copyright (c) 2025 Zkonduit Inc.

312
abis/DataAttestation.json Normal file
View File

@@ -0,0 +1,312 @@
[
{
"inputs": [
{
"internalType": "address",
"name": "_contractAddresses",
"type": "address"
},
{
"internalType": "bytes",
"name": "_callData",
"type": "bytes"
},
{
"internalType": "uint256[]",
"name": "_decimals",
"type": "uint256[]"
},
{
"internalType": "uint256[]",
"name": "_bits",
"type": "uint256[]"
},
{
"internalType": "uint8",
"name": "_instanceOffset",
"type": "uint8"
}
],
"stateMutability": "nonpayable",
"type": "constructor"
},
{
"inputs": [],
"name": "HALF_ORDER",
"outputs": [
{
"internalType": "uint256",
"name": "",
"type": "uint256"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "ORDER",
"outputs": [
{
"internalType": "uint256",
"name": "",
"type": "uint256"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "uint256[]",
"name": "instances",
"type": "uint256[]"
}
],
"name": "attestData",
"outputs": [],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "callData",
"outputs": [
{
"internalType": "bytes",
"name": "",
"type": "bytes"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "contractAddress",
"outputs": [
{
"internalType": "address",
"name": "",
"type": "address"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "bytes",
"name": "encoded",
"type": "bytes"
}
],
"name": "getInstancesCalldata",
"outputs": [
{
"internalType": "uint256[]",
"name": "instances",
"type": "uint256[]"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "bytes",
"name": "encoded",
"type": "bytes"
}
],
"name": "getInstancesMemory",
"outputs": [
{
"internalType": "uint256[]",
"name": "instances",
"type": "uint256[]"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "uint256",
"name": "index",
"type": "uint256"
}
],
"name": "getScalars",
"outputs": [
{
"components": [
{
"internalType": "uint256",
"name": "decimals",
"type": "uint256"
},
{
"internalType": "uint256",
"name": "bits",
"type": "uint256"
}
],
"internalType": "struct DataAttestation.Scalars",
"name": "",
"type": "tuple"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "instanceOffset",
"outputs": [
{
"internalType": "uint8",
"name": "",
"type": "uint8"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "uint256",
"name": "x",
"type": "uint256"
},
{
"internalType": "uint256",
"name": "y",
"type": "uint256"
},
{
"internalType": "uint256",
"name": "denominator",
"type": "uint256"
}
],
"name": "mulDiv",
"outputs": [
{
"internalType": "uint256",
"name": "result",
"type": "uint256"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "int256",
"name": "x",
"type": "int256"
},
{
"components": [
{
"internalType": "uint256",
"name": "decimals",
"type": "uint256"
},
{
"internalType": "uint256",
"name": "bits",
"type": "uint256"
}
],
"internalType": "struct DataAttestation.Scalars",
"name": "_scalars",
"type": "tuple"
}
],
"name": "quantizeData",
"outputs": [
{
"internalType": "int256",
"name": "quantized_data",
"type": "int256"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "target",
"type": "address"
},
{
"internalType": "bytes",
"name": "data",
"type": "bytes"
}
],
"name": "staticCall",
"outputs": [
{
"internalType": "bytes",
"name": "",
"type": "bytes"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "int256",
"name": "x",
"type": "int256"
}
],
"name": "toFieldElement",
"outputs": [
{
"internalType": "uint256",
"name": "field_element",
"type": "uint256"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "verifier",
"type": "address"
},
{
"internalType": "bytes",
"name": "encoded",
"type": "bytes"
}
],
"name": "verifyWithDataAttestation",
"outputs": [
{
"internalType": "bool",
"name": "",
"type": "bool"
}
],
"stateMutability": "view",
"type": "function"
}
]

View File

@@ -1,167 +0,0 @@
[
{
"inputs": [
{
"internalType": "address[]",
"name": "_contractAddresses",
"type": "address[]"
},
{
"internalType": "bytes[][]",
"name": "_callData",
"type": "bytes[][]"
},
{
"internalType": "uint256[][]",
"name": "_decimals",
"type": "uint256[][]"
},
{
"internalType": "uint256[]",
"name": "_scales",
"type": "uint256[]"
},
{
"internalType": "uint8",
"name": "_instanceOffset",
"type": "uint8"
},
{
"internalType": "address",
"name": "_admin",
"type": "address"
}
],
"stateMutability": "nonpayable",
"type": "constructor"
},
{
"inputs": [
{
"internalType": "uint256",
"name": "",
"type": "uint256"
}
],
"name": "accountCalls",
"outputs": [
{
"internalType": "address",
"name": "contractAddress",
"type": "address"
},
{
"internalType": "uint256",
"name": "callCount",
"type": "uint256"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "admin",
"outputs": [
{
"internalType": "address",
"name": "",
"type": "address"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "instanceOffset",
"outputs": [
{
"internalType": "uint8",
"name": "",
"type": "uint8"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "uint256",
"name": "",
"type": "uint256"
}
],
"name": "scales",
"outputs": [
{
"internalType": "uint256",
"name": "",
"type": "uint256"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "address[]",
"name": "_contractAddresses",
"type": "address[]"
},
{
"internalType": "bytes[][]",
"name": "_callData",
"type": "bytes[][]"
},
{
"internalType": "uint256[][]",
"name": "_decimals",
"type": "uint256[][]"
}
],
"name": "updateAccountCalls",
"outputs": [],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "_admin",
"type": "address"
}
],
"name": "updateAdmin",
"outputs": [],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "verifier",
"type": "address"
},
{
"internalType": "bytes",
"name": "encoded",
"type": "bytes"
}
],
"name": "verifyWithDataAttestation",
"outputs": [
{
"internalType": "bool",
"name": "",
"type": "bool"
}
],
"stateMutability": "view",
"type": "function"
}
]

View File

@@ -1,147 +0,0 @@
[
{
"inputs": [
{
"internalType": "address",
"name": "_contractAddresses",
"type": "address"
},
{
"internalType": "bytes",
"name": "_callData",
"type": "bytes"
},
{
"internalType": "uint256",
"name": "_decimals",
"type": "uint256"
},
{
"internalType": "uint256[]",
"name": "_scales",
"type": "uint256[]"
},
{
"internalType": "uint8",
"name": "_instanceOffset",
"type": "uint8"
},
{
"internalType": "address",
"name": "_admin",
"type": "address"
}
],
"stateMutability": "nonpayable",
"type": "constructor"
},
{
"inputs": [],
"name": "accountCall",
"outputs": [
{
"internalType": "address",
"name": "contractAddress",
"type": "address"
},
{
"internalType": "bytes",
"name": "callData",
"type": "bytes"
},
{
"internalType": "uint256",
"name": "decimals",
"type": "uint256"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "admin",
"outputs": [
{
"internalType": "address",
"name": "",
"type": "address"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "instanceOffset",
"outputs": [
{
"internalType": "uint8",
"name": "",
"type": "uint8"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "_contractAddresses",
"type": "address"
},
{
"internalType": "bytes",
"name": "_callData",
"type": "bytes"
},
{
"internalType": "uint256",
"name": "_decimals",
"type": "uint256"
}
],
"name": "updateAccountCalls",
"outputs": [],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "_admin",
"type": "address"
}
],
"name": "updateAdmin",
"outputs": [],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "verifier",
"type": "address"
},
{
"internalType": "bytes",
"name": "encoded",
"type": "bytes"
}
],
"name": "verifyWithDataAttestation",
"outputs": [
{
"internalType": "bool",
"name": "",
"type": "bool"
}
],
"stateMutability": "view",
"type": "function"
}
]

View File

@@ -73,6 +73,8 @@ impl Circuit<Fr> for MyCircuit {
padding: vec![(0, 0)],
stride: vec![1; 2],
group: 1,
data_format: DataFormat::NCHW,
kernel_format: KernelFormat::OIHW,
}),
)
.unwrap();

View File

@@ -69,6 +69,7 @@ impl Circuit<Fr> for MyCircuit {
stride: vec![1, 1],
kernel_shape: vec![2, 2],
normalized: false,
data_format: DataFormat::NCHW,
}),
)
.unwrap();

View File

@@ -23,8 +23,6 @@ use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
const L: usize = 10;
#[derive(Clone, Debug)]
struct MyCircuit {
image: ValTensor<Fr>,
@@ -40,7 +38,7 @@ impl Circuit<Fr> for MyCircuit {
}
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, 10>::configure(cs, ())
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::configure(cs, ())
}
fn synthesize(
@@ -48,7 +46,7 @@ impl Circuit<Fr> for MyCircuit {
config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L> =
let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE> =
PoseidonChip::new(config);
chip.layout(&mut layouter, &[self.image.clone()], 0, &mut HashMap::new())?;
Ok(())
@@ -59,7 +57,7 @@ fn runposeidon(c: &mut Criterion) {
let mut group = c.benchmark_group("poseidon");
for size in [64, 784, 2352, 12288].iter() {
let k = (PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L>::num_rows(*size)
let k = (PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::num_rows(*size)
as f32)
.log2()
.ceil() as u32;
@@ -67,7 +65,7 @@ fn runposeidon(c: &mut Criterion) {
let message = (0..*size).map(|_| Fr::random(OsRng)).collect::<Vec<_>>();
let _output =
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L>::run(message.to_vec())
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.to_vec())
.unwrap();
let mut image = Tensor::from(message.into_iter().map(Value::known));

117
benches/zero_finder.rs Normal file
View File

@@ -0,0 +1,117 @@
use std::thread;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use halo2curves::{bn256::Fr as F, ff::Field};
use maybe_rayon::{
iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator},
slice::ParallelSlice,
};
use rand::Rng;
/// Minimal stand-in value type used only by this benchmark.
///
/// NOTE(review): presumably mirrors the shape of the project's real value enum — confirm
/// against the library type this benchmark is meant to model.
#[derive(Clone)]
// `generate_test_data` only ever constructs `Constant`, so the remaining variants are
// never built here; the allowance silences the resulting dead-code warnings.
#[allow(dead_code)]
enum ValType {
/// A field-element constant.
Constant(F),
/// A field-element constant paired with a `usize` (presumably an assignment index — TODO confirm).
AssignedConstant(usize, F),
/// Any value that carries no constant; never matches the zero search.
Other,
}
/// Builds `size` benchmark inputs, all of them `ValType::Constant`.
///
/// Each element is drawn independently: a uniform `f64` sample is compared against
/// `zero_probability`, yielding `F::ZERO` when the sample is below it and `F::ONE`
/// otherwise.
fn generate_test_data(size: usize, zero_probability: f64) -> Vec<ValType> {
    let mut rng = rand::thread_rng();
    let mut data = Vec::with_capacity(size);
    for _ in 0..size {
        let is_zero = rng.r#gen::<f64>() < zero_probability;
        // Any non-zero value would do for the "miss" case; ONE is the simplest.
        let value = if is_zero { F::ZERO } else { F::ONE };
        data.push(ValType::Constant(value));
    }
    data
}
fn bench_zero_finding(c: &mut Criterion) {
let sizes = [
1_000, // 1K
10_000, // 10K
100_000, // 100K
256 * 256 * 2, // Our specific case
1_000_000, // 1M
10_000_000, // 10M
];
let zero_probability = 0.1; // 10% zeros
let mut group = c.benchmark_group("zero_finding");
group.sample_size(10); // Adjust based on your needs
for &size in &sizes {
let data = generate_test_data(size, zero_probability);
// Benchmark sequential version
group.bench_function(format!("sequential_{}", size), |b| {
b.iter(|| {
let result = data
.iter()
.enumerate()
.filter_map(|(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(i)
}
_ => None,
})
.collect::<Vec<_>>();
black_box(result)
})
});
// Benchmark parallel version
group.bench_function(format!("parallel_{}", size), |b| {
b.iter(|| {
let result = data
.par_iter()
.enumerate()
.filter_map(|(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(i)
}
_ => None,
})
.collect::<Vec<_>>();
black_box(result)
})
});
// Benchmark chunked parallel version
group.bench_function(format!("chunked_parallel_{}", size), |b| {
b.iter(|| {
let num_cores = thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
let chunk_size = (size / num_cores).max(100);
let result = data
.par_chunks(chunk_size)
.enumerate()
.flat_map(|(chunk_idx, chunk)| {
chunk
.par_iter() // Make sure we use par_iter() here
.enumerate()
.filter_map(move |(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(chunk_idx * chunk_size + i)
}
_ => None,
})
})
.collect::<Vec<_>>();
black_box(result)
})
});
}
group.finish();
}
// Register `bench_zero_finding` under the `benches` group and let Criterion generate
// the benchmark binary's entry point for that group.
criterion_group!(benches, bench_zero_finding);
criterion_main!(benches);

View File

@@ -8,21 +8,27 @@ contract LoadInstances {
*/
function getInstancesMemory(
bytes memory encoded
) internal pure returns (uint256[] memory instances) {
) public pure returns (uint256[] memory instances) {
bytes4 funcSig;
uint256 instances_offset;
uint256 instances_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := mload(add(encoded, 0x20))
}
if (funcSig == 0xaf83a18d) {
instances_offset = 0x64;
} else if (funcSig == 0x1e8e1e13) {
instances_offset = 0x44;
} else {
revert("Invalid function signature");
}
assembly {
// Fetch instances offset which is 4 + 32 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 +32 away for `verifyProof(address,bytes,uint256[])`
instances_offset := mload(
add(encoded, add(0x44, mul(0x20, eq(funcSig, 0xaf83a18d))))
)
instances_offset := mload(add(encoded, instances_offset))
instances_length := mload(add(add(encoded, 0x24), instances_offset))
}
@@ -41,6 +47,10 @@ contract LoadInstances {
)
}
}
require(
funcSig == 0xaf83a18d || funcSig == 0x1e8e1e13,
"Invalid function signature"
);
}
/**
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
@@ -49,23 +59,31 @@ contract LoadInstances {
*/
function getInstancesCalldata(
bytes calldata encoded
) internal pure returns (uint256[] memory instances) {
) public pure returns (uint256[] memory instances) {
bytes4 funcSig;
uint256 instances_offset;
uint256 instances_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := calldataload(encoded.offset)
}
if (funcSig == 0xaf83a18d) {
instances_offset = 0x44;
} else if (funcSig == 0x1e8e1e13) {
instances_offset = 0x24;
} else {
revert("Invalid function signature");
}
// We need to create a new assembly block in order for solidity
// to cast the funcSig to a bytes4 type. Otherwise it will load the entire first 32 bytes of the calldata
// within the block
assembly {
// Fetch instances offset which is 4 + 32 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 +32 away for `verifyProof(address,bytes,uint256[])`
instances_offset := calldataload(
add(
encoded.offset,
add(0x24, mul(0x20, eq(funcSig, 0xaf83a18d)))
)
add(encoded.offset, instances_offset)
)
instances_length := calldataload(
@@ -96,7 +114,7 @@ contract LoadInstances {
// The kzg commitments of a given model, all aggregated into a single bytes array.
// At solidity generation time, the commitments are hardcoded into the contract via the COMMITMENT_KZG constant.
// It will be used to check that the proof commitments match the expected commitments.
bytes constant COMMITMENT_KZG = hex"";
bytes constant COMMITMENT_KZG = hex"1234";
contract SwapProofCommitments {
/**
@@ -113,17 +131,20 @@ contract SwapProofCommitments {
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := calldataload(encoded.offset)
}
if (funcSig == 0xaf83a18d) {
proof_offset = 0x24;
} else if (funcSig == 0x1e8e1e13) {
proof_offset = 0x04;
} else {
revert("Invalid function signature");
}
assembly {
// Fetch proof offset which is 4 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 away for `verifyProof(address,bytes,uint256[])`
proof_offset := calldataload(
add(
encoded.offset,
add(0x04, mul(0x20, eq(funcSig, 0xaf83a18d)))
)
)
proof_offset := calldataload(add(encoded.offset, proof_offset))
proof_length := calldataload(
add(add(encoded.offset, 0x04), proof_offset)
@@ -154,7 +175,7 @@ contract SwapProofCommitments {
let wordCommitment := mload(add(commitment, i))
equal := eq(wordProof, wordCommitment)
if eq(equal, 0) {
return(0, 0)
break
}
}
}
@@ -163,36 +184,38 @@ contract SwapProofCommitments {
} /// end checkKzgCommits
}
contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
/**
* @notice Struct used to make view only call to account to fetch the data that EZKL reads from.
* @param the address of the account to make calls to
* @param the abi encoded function calls to make to the `contractAddress`
*/
struct AccountCall {
address contractAddress;
bytes callData;
contract DataAttestation is LoadInstances, SwapProofCommitments {
// the address of the account to make calls to
address public immutable contractAddress;
// the abi encoded function calls to make to the `contractAddress` that returns the attested to data
bytes public callData;
struct Scalars {
// The number of base 10 decimals to scale the data by.
// For most ERC20 tokens this is 1e18
uint256 decimals;
// The number of fractional bits of the fixed point EZKL data points.
uint256 bits;
}
AccountCall public accountCall;
uint[] scales;
Scalars[] private scalars;
address public admin;
function getScalars(uint256 index) public view returns (Scalars memory) {
return scalars[index];
}
/**
* @notice EZKL P value
* @dev In order to prevent the verifier from accepting two versions of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are strictly less than P.
* @dev The reason for this is that the assembly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
*/
uint256 constant ORDER =
uint256 public constant ORDER =
uint256(
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
);
uint256 constant INPUT_LEN = 0;
uint256 constant OUTPUT_LEN = 0;
uint256 public constant HALF_ORDER = ORDER >> 1;
uint8 public instanceOffset;
@@ -204,53 +227,27 @@ contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
constructor(
address _contractAddresses,
bytes memory _callData,
uint256 _decimals,
uint[] memory _scales,
uint8 _instanceOffset,
address _admin
uint256[] memory _decimals,
uint[] memory _bits,
uint8 _instanceOffset
) {
admin = _admin;
for (uint i; i < _scales.length; i++) {
scales.push(1 << _scales[i]);
require(
_bits.length == _decimals.length,
"Invalid scalar array lengths"
);
for (uint i; i < _bits.length; i++) {
scalars.push(Scalars(10 ** _decimals[i], 1 << _bits[i]));
}
populateAccountCalls(_contractAddresses, _callData, _decimals);
contractAddress = _contractAddresses;
callData = _callData;
instanceOffset = _instanceOffset;
}
function updateAdmin(address _admin) external {
require(msg.sender == admin, "Only admin can update admin");
if (_admin == address(0)) {
revert();
}
admin = _admin;
}
function updateAccountCalls(
address _contractAddresses,
bytes memory _callData,
uint256 _decimals
) external {
require(msg.sender == admin, "Only admin can update account calls");
populateAccountCalls(_contractAddresses, _callData, _decimals);
}
function populateAccountCalls(
address _contractAddresses,
bytes memory _callData,
uint256 _decimals
) internal {
AccountCall memory _accountCall = accountCall;
_accountCall.contractAddress = _contractAddresses;
_accountCall.callData = _callData;
_accountCall.decimals = 10 ** _decimals;
accountCall = _accountCall;
}
function mulDiv(
uint256 x,
uint256 y,
uint256 denominator
) internal pure returns (uint256 result) {
) public pure returns (uint256 result) {
unchecked {
uint256 prod0;
uint256 prod1;
@@ -298,21 +295,28 @@ contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
/**
* @dev Quantize the data returned from the account calls to the scale used by the EZKL model.
* @param x - One of the elements of the data returned from the account calls
* @param _decimals - Number of base 10 decimals to scale the data by.
* @param _scale - The base 2 scale used to convert the floating point value into a fixed point value.
* @param _scalars - The scaling factors for the data returned from the account calls.
*
*/
function quantizeData(
int x,
uint256 _decimals,
uint256 _scale
) internal pure returns (int256 quantized_data) {
Scalars memory _scalars
) public pure returns (int256 quantized_data) {
if (_scalars.bits == 1 && _scalars.decimals == 1) {
return x;
}
bool neg = x < 0;
if (neg) x = -x;
uint output = mulDiv(uint256(x), _scale, _decimals);
if (mulmod(uint256(x), _scale, _decimals) * 2 >= _decimals) {
uint output = mulDiv(uint256(x), _scalars.bits, _scalars.decimals);
if (
mulmod(uint256(x), _scalars.bits, _scalars.decimals) * 2 >=
_scalars.decimals
) {
output += 1;
}
if (output > HALF_ORDER) {
revert("Overflow field modulus");
}
quantized_data = neg ? -int256(output) : int256(output);
}
/**
@@ -324,7 +328,7 @@ contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
function staticCall(
address target,
bytes memory data
) internal view returns (bytes memory) {
) public view returns (bytes memory) {
(bool success, bytes memory returndata) = target.staticcall(data);
if (success) {
if (returndata.length == 0) {
@@ -345,7 +349,7 @@ contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
*/
function toFieldElement(
int256 x
) internal pure returns (uint256 field_element) {
) public pure returns (uint256 field_element) {
// The casting down to uint256 is safe because the order is about 2^254, and the value
// of x ranges of -2^127 to 2^127, so x + int(ORDER) is always positive.
return uint256(x + int(ORDER)) % ORDER;
@@ -355,315 +359,16 @@ contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
* @dev Make the account calls to fetch the data that EZKL reads from and attest to the data.
* @param instances - The public instances to the proof (the data in the proof that publicly accessible to the verifier).
*/
function attestData(uint256[] memory instances) internal view {
require(
instances.length >= INPUT_LEN + OUTPUT_LEN,
"Invalid public inputs length"
);
AccountCall memory _accountCall = accountCall;
uint[] memory _scales = scales;
bytes memory returnData = staticCall(
_accountCall.contractAddress,
_accountCall.callData
);
function attestData(uint256[] memory instances) public view {
bytes memory returnData = staticCall(contractAddress, callData);
int256[] memory x = abi.decode(returnData, (int256[]));
uint _offset;
int output = quantizeData(x[0], _accountCall.decimals, _scales[0]);
uint field_element = toFieldElement(output);
int output;
uint fieldElement;
for (uint i = 0; i < x.length; i++) {
if (field_element != instances[i + instanceOffset]) {
_offset += 1;
} else {
break;
}
}
uint length = x.length - _offset;
for (uint i = 1; i < length; i++) {
output = quantizeData(x[i], _accountCall.decimals, _scales[i]);
field_element = toFieldElement(output);
require(
field_element == instances[i + instanceOffset + _offset],
"Public input does not match"
);
}
}
/**
* @dev Verify the proof with the data attestation.
* @param verifier - The address of the verifier contract.
* @param encoded - The verifier calldata.
*/
function verifyWithDataAttestation(
address verifier,
bytes calldata encoded
) public view returns (bool) {
require(verifier.code.length > 0, "Address: call to non-contract");
attestData(getInstancesCalldata(encoded));
// static call the verifier contract to verify the proof
(bool success, bytes memory returndata) = verifier.staticcall(encoded);
if (success) {
return abi.decode(returndata, (bool));
} else {
revert("low-level call to verifier failed");
}
}
}
// This contract serves as a Data Attestation Verifier for the EZKL model.
// It is designed to read and attest to instances of proofs generated from a specified circuit.
// It is particularly constructed to read only int256 data from specified on-chain contracts' view functions.
// Overview of the contract functionality:
// 1. Initialization: Through the constructor, it sets up the contract calls that the EZKL model will read from.
// 2. Data Quantization: Quantizes the returned data into a scaled fixed-point representation. See the `quantizeData` method for details.
// 3. Static Calls: Makes static calls to fetch data from other contracts. See the `staticCall` method.
// 4. Field Element Conversion: The fixed-point representation is then converted into a field element modulo P using the `toFieldElement` method.
// 5. Data Attestation: The `attestData` method validates that the public instances match the data fetched and processed by the contract.
// 6. Proof Verification: The `verifyWithDataAttestationMulti` method parses the instances out of the encoded calldata and calls the `attestData` method to validate the public instances,
// 6b. Optional KZG Commitment Verification: It also checks the KZG commitments in the proof against the expected commitments using the `checkKzgCommits` method.
// then calls the `verifyProof` method to verify the proof on the verifier.
contract DataAttestationMulti is LoadInstances, SwapProofCommitments {
/**
* @notice Struct used to make view only calls to accounts to fetch the data that EZKL reads from.
* @param the address of the account to make calls to
* @param the abi encoded function calls to make to the `contractAddress`
*/
struct AccountCall {
address contractAddress;
mapping(uint256 => bytes) callData;
mapping(uint256 => uint256) decimals;
uint callCount;
}
AccountCall[] public accountCalls;
uint[] public scales;
address public admin;
/**
* @notice EZKL P value
* @dev In order to prevent the verifier from accepting two version of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are stricly less than P. a
* @dev The reason for this is that the assmebly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
*/
uint256 constant ORDER =
uint256(
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
);
uint256 constant INPUT_CALLS = 0;
uint256 constant OUTPUT_CALLS = 0;
uint8 public instanceOffset;
/**
* @dev Initialize the contract with account calls the EZKL model will read from.
* @param _contractAddresses - The calls to all the contracts EZKL reads storage from.
* @param _callData - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from.
*/
constructor(
address[] memory _contractAddresses,
bytes[][] memory _callData,
uint256[][] memory _decimals,
uint[] memory _scales,
uint8 _instanceOffset,
address _admin
) {
admin = _admin;
for (uint i; i < _scales.length; i++) {
scales.push(1 << _scales[i]);
}
populateAccountCalls(_contractAddresses, _callData, _decimals);
instanceOffset = _instanceOffset;
}
function updateAdmin(address _admin) external {
require(msg.sender == admin, "Only admin can update admin");
if (_admin == address(0)) {
revert();
}
admin = _admin;
}
function updateAccountCalls(
address[] memory _contractAddresses,
bytes[][] memory _callData,
uint256[][] memory _decimals
) external {
require(msg.sender == admin, "Only admin can update account calls");
populateAccountCalls(_contractAddresses, _callData, _decimals);
}
function populateAccountCalls(
address[] memory _contractAddresses,
bytes[][] memory _callData,
uint256[][] memory _decimals
) internal {
require(
_contractAddresses.length == _callData.length &&
accountCalls.length == _contractAddresses.length,
"Invalid input length"
);
require(
_decimals.length == _contractAddresses.length,
"Invalid number of decimals"
);
// fill in the accountCalls storage array
uint counter = 0;
for (uint256 i = 0; i < _contractAddresses.length; i++) {
AccountCall storage accountCall = accountCalls[i];
accountCall.contractAddress = _contractAddresses[i];
accountCall.callCount = _callData[i].length;
for (uint256 j = 0; j < _callData[i].length; j++) {
accountCall.callData[j] = _callData[i][j];
accountCall.decimals[j] = 10 ** _decimals[i][j];
}
// count the total number of storage reads across all of the accounts
counter += _callData[i].length;
}
require(
counter == INPUT_CALLS + OUTPUT_CALLS,
"Invalid number of calls"
);
}
function mulDiv(
uint256 x,
uint256 y,
uint256 denominator
) internal pure returns (uint256 result) {
unchecked {
uint256 prod0;
uint256 prod1;
assembly {
let mm := mulmod(x, y, not(0))
prod0 := mul(x, y)
prod1 := sub(sub(mm, prod0), lt(mm, prod0))
}
if (prod1 == 0) {
return prod0 / denominator;
}
require(denominator > prod1, "Math: mulDiv overflow");
uint256 remainder;
assembly {
remainder := mulmod(x, y, denominator)
prod1 := sub(prod1, gt(remainder, prod0))
prod0 := sub(prod0, remainder)
}
uint256 twos = denominator & (~denominator + 1);
assembly {
denominator := div(denominator, twos)
prod0 := div(prod0, twos)
twos := add(div(sub(0, twos), twos), 1)
}
prod0 |= prod1 * twos;
uint256 inverse = (3 * denominator) ^ 2;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
result = prod0 * inverse;
return result;
}
}
/**
* @dev Quantize the data returned from the account calls to the scale used by the EZKL model.
* @param data - The data returned from the account calls.
* @param decimals - The number of decimals the data returned from the account calls has (for floating point representation).
* @param scale - The scale used to convert the floating point value into a fixed point value.
*/
function quantizeData(
bytes memory data,
uint256 decimals,
uint256 scale
) internal pure returns (int256 quantized_data) {
int x = abi.decode(data, (int256));
bool neg = x < 0;
if (neg) x = -x;
uint output = mulDiv(uint256(x), scale, decimals);
if (mulmod(uint256(x), scale, decimals) * 2 >= decimals) {
output += 1;
}
quantized_data = neg ? -int256(output) : int256(output);
}
/**
* @dev Make a static call to the account to fetch the data that EZKL reads from.
* @param target - The address of the account to make calls to.
* @param data - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from.
* @return The data returned from the account calls. (Must come from either a view or pure function. Will throw an error otherwise)
*/
function staticCall(
address target,
bytes memory data
) internal view returns (bytes memory) {
(bool success, bytes memory returndata) = target.staticcall(data);
if (success) {
if (returndata.length == 0) {
require(
target.code.length > 0,
"Address: call to non-contract"
);
}
return returndata;
} else {
revert("Address: low-level call failed");
}
}
/**
* @dev Convert the fixed point quantized data into a field element.
* @param x - The quantized data.
* @return field_element - The field element.
*/
function toFieldElement(
int256 x
) internal pure returns (uint256 field_element) {
// The casting down to uint256 is safe because the order is about 2^254, and the value
// of x ranges of -2^127 to 2^127, so x + int(ORDER) is always positive.
return uint256(x + int(ORDER)) % ORDER;
}
/**
* @dev Make the account calls to fetch the data that EZKL reads from and attest to the data.
* @param instances - The public instances to the proof (the data in the proof that publicly accessible to the verifier).
*/
function attestData(uint256[] memory instances) internal view {
require(
instances.length >= INPUT_CALLS + OUTPUT_CALLS,
"Invalid public inputs length"
);
uint256 _accountCount = accountCalls.length;
uint counter = 0;
for (uint8 i = 0; i < _accountCount; ++i) {
address account = accountCalls[i].contractAddress;
for (uint8 j = 0; j < accountCalls[i].callCount; j++) {
bytes memory returnData = staticCall(
account,
accountCalls[i].callData[j]
);
uint256 scale = scales[counter];
int256 quantized_data = quantizeData(
returnData,
accountCalls[i].decimals[j],
scale
);
uint256 field_element = toFieldElement(quantized_data);
require(
field_element == instances[counter + instanceOffset],
"Public input does not match"
);
counter++;
output = quantizeData(x[i], scalars[i]);
fieldElement = toFieldElement(output);
if (fieldElement != instances[i]) {
revert("Public input does not match");
}
}
}

View File

@@ -0,0 +1,41 @@
## EZKL Security Note: Public Commitments and Low-Entropy Data
> **Disclaimer:** this is a more technical post that requires some prior knowledge of how ZK proving systems like Halo2 operate, and in particular of how these APIs are constructed. For background reading we highly recommend the [Halo2 book](https://zcash.github.io/halo2/) and [Halo2 Club](https://halo2.club/).
## Overview of commitments in EZKL
A common design pattern in a zero knowledge (zk) application is thus:
- A prover has some data which is used within a circuit.
- This data, as it may be high-dimensional or somewhat private, is pre-committed to using some hash function.
- The zk-circuit which forms the core of the application then proves (para-phrasing) a statement of the form:
>"I know some data D which when hashed corresponds to the pre-committed to value H + whatever else the circuit is proving over D".
From our own experience, we've implemented such patterns using snark-friendly hash functions like [Poseidon](https://www.poseidon-hash.info/), for which there is a relatively well vetted [implementation](https://docs.rs/halo2_gadgets/latest/halo2_gadgets/poseidon/index.html) in Halo2. Even then these hash functions can introduce lots of overhead and can be very expensive to generate proofs for if the dimensionality of the data D is large.
You can also implement such a pattern using Halo2's `Fixed` columns _if the privacy preservation of the pre-image is not necessary_. These are Halo2 columns (i.e in reality just polynomials) that are left unblinded (unlike the blinded `Advice` columns), and whose commitments are shared with the verifier by way of the verifying key for the application's zk-circuit. These commitments are much lower cost to generate than implementing a hashing function, such as Poseidon, within a circuit.
> **Note:** Blinding is the process whereby a certain set of the final elements (i.e rows) of a Halo2 column are set to random field elements. This is the mechanism by which Halo2 achieves its zero knowledge properties for `Advice` columns. By contrast `Fixed` columns aren't zero-knowledge in that they are vulnerable to dictionary attacks in the same manner a hash function is. Given some set of known or popular data D an attacker can attempt to recover the pre-image of a hash by running D through the hash function to see if the outputs match a public commitment. These attacks aren't "possible" on blinded `Advice` columns.
> **Further Note:** Without blinding, an attacker with access to `M` proofs, each of which contains an evaluation of the polynomial at a different point, can more easily recover an unblinded column's pre-image. This is because each proof generates a new query and evaluation of the polynomial represented by the column, and as such, with repetition, a clearer picture can emerge of the column's pre-image. Thus unblinded columns should only be used for privacy preservation, in the manner of a hash, if the number of proofs generated against a fixed set of values is limited. More formally, suppose `M` independent and _unique_ queries are generated: if `M` is equal to the degree + 1 of the polynomial represented by the column (i.e. the unique Lagrange interpolation of the values in the column), then the column's pre-image can be recovered. Consequently, as the logrows `K` increases, more queries are required to recover the pre-image (as 2^K unique queries are required). This assumes that the entries in the column are not structured; if they are, the number of queries required to recover the pre-image is reduced (e.g. if all rows above a certain point are known to be nil).
The annoyance in using `Fixed` columns comes from the fact that they require generating a new verifying key every time a new set of commitments is generated.
> **Example:** Say, for instance, an application leverages a zero-knowledge circuit to prove the correct execution of a neural network. Every week the neural network is finetuned or retrained on new data. If the architecture remains the same then committing to the new network parameters, along with a new proof of performance on a test set, would be an ideal setup. If we leverage `Fixed` columns to commit to the model parameters, each new commitment will require re-generating a verifying key and sharing the new key with the verifier(s). This is not ideal UX and can become expensive if the verifier is deployed on-chain.
An ideal commitment would thus have the low cost of a `Fixed` column but wouldn't require regenerating a new verifying key for each new commitment.
### Unblinded Advice Columns
A first step in designing such a commitment is to allow for optionally unblinded `Advice` columns within the Halo2 API. These won't be included in the verifying key, AND are blinded with a constant factor `1` -- such that if someone knows the pre-image to the commitment, they can recover it by running it through the corresponding polynomial commitment scheme (in ezkl's case [KZG commitments](https://dankradfeist.de/ethereum/2020/06/16/kate-polynomial-commitments.html)).
This is implemented using the `polycommit` visibility parameter in the ezkl API.
## The Vulnerability of Public Commitments
Public commitments in EZKL (both Poseidon-hashed inputs and KZG commitments) can be vulnerable to brute-force attacks when input data has low entropy. A malicious actor could reveal committed data by searching through possible input values, compromising privacy in applications like anonymous credentials. This is particularly relevant when input data comes from known finite sets (e.g., names, dates).
Example Risk: In an anonymous credential system using EZKL for ID verification, an attacker could match hashed outputs against a database of common identifying information to deanonymize users.

View File

@@ -0,0 +1,54 @@
# EZKL Security Note: Quantization-Activated Model Backdoors
## Model backdoors and provenance
Machine learning models inherently suffer from robustness issues, which can lead to various
kinds of attacks, from backdoors to evasion attacks. These vulnerabilities are a direct byproduct of how machine learning models learn and cannot be remediated.
We say a model has a backdoor whenever a specific attacker-chosen trigger in the input leads
to the model misbehaving. For instance, if we have an image classifier discriminating cats from dogs, the ability to turn any image of a cat into an image classified as a dog by changing a specific pixel pattern constitutes a backdoor.
Backdoors can be introduced using many different vectors. An attacker can introduce a
backdoor using traditional security vulnerabilities. For instance, they could directly alter the file containing model weights or dynamically hack the Python code of the model. In addition, backdoors can be introduced by the training data through a process known as poisoning. In this case, an attacker adds malicious data points to the dataset before the model is trained so that the model learns to associate the backdoor trigger with the intended misbehavior.
All these vectors constitute a whole range of provenance challenges, as any component of an
AI system can virtually be an entrypoint for a backdoor. Although provenance is already a
concern with traditional code, the issue is exacerbated with AI, as retraining a model is
cost-prohibitive. It is thus impractical to translate the “recompile it yourself” thinking to AI.
## Quantization activated backdoors
Backdoors are a generic concern in AI that is outside the scope of EZKL. However, EZKL may
activate a specific subset of backdoors. Several academic papers have demonstrated the
possibility, both in theory and in practice, of implanting undetectable and inactive backdoors in a full precision model that can be reactivated by quantization.
An external attacker may trick the user of an application running EZKL into loading a model
containing a quantization backdoor. This backdoor is active in the resulting model and circuit but not in the full-precision model supplied to EZKL, compromising the integrity of the target application and the resulting proof.
### When is this a concern for me as a user?
Any untrusted component in your AI stack may be a backdoor vector. In practice, the most
sensitive parts include:
- Datasets downloaded from the web or containing crowdsourced data
- Models downloaded from the web even after finetuning
- Untrusted software dependencies (well-known frameworks such as PyTorch can typically
be considered trusted)
- Any component loaded through an unsafe serialization format, such as Pickle.
Because backdoors are inherent to ML and cannot be eliminated, reviewing the provenance of
these sensitive components is especially important.
### Responsibilities of the user and EZKL
As EZKL cannot prevent backdoored models from being used, it is the responsibility of the user to review the provenance of all the components in their AI stack to ensure that no backdoor could have been implanted. EZKL shall not be held responsible for misleading prediction proofs resulting from using a backdoored model or for any harm caused to a system or its users due to a misbehaving model.
### Limitations:
- Attack effectiveness depends on calibration settings and internal rescaling operations.
- Further research needed on backdoor persistence through witness/proof stages.
- Can be mitigated by evaluating the quantized model (using `ezkl gen-witness`), rather than relying on the evaluation of the original model in PyTorch or onnx-runtime, as a difference in evaluation could reveal a backdoor.
References:
1. [Quantization Backdoors to Deep Learning Commercial Frameworks (Ma et al., 2021)](https://arxiv.org/abs/2108.09187)
2. [Planting Undetectable Backdoors in Machine Learning Models (Goldwasser et al., 2022)](https://arxiv.org/abs/2204.06974)

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '16.2.0'
release = '0.0.0'
version = release

View File

@@ -32,6 +32,7 @@ use mnist::*;
use rand::rngs::OsRng;
use std::marker::PhantomData;
mod params;
const K: usize = 20;
@@ -208,6 +209,8 @@ where
padding: vec![(PADDING, PADDING); 2],
stride: vec![STRIDE; 2],
group: 1,
data_format: DataFormat::NCHW,
kernel_format: KernelFormat::OIHW,
};
let x = config
.layer_config

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,13 @@
# Download the Kaggle "cat-and-dog" dataset (tongpython/cat-and-dog).
#
# Usage: <this script> [DATA_DIR]
#   DATA_DIR  directory to download into; defaults to "data" when the first
#             argument is missing or empty.
#
# The dataset is downloaded (with --unzip) into "$DATA_DIR/CATDOG" using the
# `kaggle` CLI. If that directory already exists the download is skipped.

# Fall back to "data" when no (or an empty) first argument is given.
DATA_DIR="${1:-data}"

echo "Downloading data to $DATA_DIR"

# Every expansion is quoted so that a path containing spaces or glob
# characters reaches `[` and `kaggle` as a single argument.
if [ ! -d "$DATA_DIR/CATDOG" ]; then
    kaggle datasets download tongpython/cat-and-dog -p "$DATA_DIR/CATDOG" --unzip
fi

View File

@@ -1,601 +0,0 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# data-attest-ezkl\n",
"\n",
"Here's an example leveraging EZKL whereby the inputs to the model are read and attested to from an on-chain source.\n",
"\n",
"In this setup:\n",
"- the inputs and outputs are publicly known to the prover and verifier\n",
"- the on chain inputs will be fetched and then fed directly into the circuit\n",
"- the quantization of the on-chain inputs happens within the evm and is replicated at proving time \n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"First we import the necessary dependencies and set up logging to be as informative as possible. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"\n",
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import logging\n",
"\n",
"# uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we define our model. It is a very simple PyTorch model that has just one layer, an average pooling 2D layer. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"# Defines the model\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
" self.layer = nn.AvgPool2d(2, 1, (1, 1))\n",
"\n",
" def forward(self, x):\n",
" return self.layer(x)[0]\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"# this is where you'd train your model"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We omit training for purposes of this demonstration. We've marked where training would happen in the cell above. \n",
"Now we export the model to onnx and create a corresponding (randomly generated) input. This input data will eventually be stored on chain and read from according to the call_data field in the graph input.\n",
"\n",
"You can replace the random `x` with real data if you so wish. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = 0.1*torch.rand(1,*[3, 2, 2], requires_grad=True)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"network.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w' ))\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now define a function that will create a new anvil instance which we will deploy our test contract to. This contract will contain in its storage the data that we will read from and attest to. In production you would not need to set up a local anvil instance. Instead you would replace RPC_URL with the actual RPC endpoint of the chain you are deploying your verifiers to, reading from the data on said chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import time\n",
"import threading\n",
"\n",
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"RPC_URL = \"http://localhost:3030\"\n",
"\n",
"# Save process globally\n",
"anvil_process = None\n",
"\n",
"def start_anvil():\n",
" global anvil_process\n",
" if anvil_process is None:\n",
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--code-size-limit=41943040\"])\n",
" if anvil_process.returncode is not None:\n",
" raise Exception(\"failed to start anvil process\")\n",
" time.sleep(3)\n",
"\n",
"def stop_anvil():\n",
" global anvil_process\n",
" if anvil_process is not None:\n",
" anvil_process.terminate()\n",
" anvil_process = None\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We define our `PyRunArgs` object, which contains the visibility parameters for our model. \n",
"- `input_visibility` defines the visibility of the model inputs\n",
"- `param_visibility` defines the visibility of the model weights and constants and parameters \n",
"- `output_visibility` defines the visibility of the model outputs\n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"public\"\n",
"- `param_visibility`: \"private\"\n",
"- `output_visibility`: public\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.srs')\n",
"data_path = os.path.join('input.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"public\"\n",
"run_args.param_visibility = \"private\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.num_inner_cols = 1\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"\n",
"\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a settings file. This file basically instantiates a bunch of parameters that determine their circuit shape, size etc... Because of the way we represent nonlinearities in the circuit (using Halo2's [lookup tables](https://zcash.github.io/halo2/design/proving-system/lookup.html)), it is often best to _calibrate_ this settings file as some data can fall out of range of these lookups.\n",
"\n",
"You can pass a dataset for calibration that will be representative of real inputs you might find if and when you deploy the prover. Here we create a dummy calibration dataset for demonstration purposes. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# generate a bunch of dummy calibration data\n",
"cal_data = {\n",
" \"input_data\": [(0.1*torch.rand(2, *[3, 2, 2])).flatten().tolist()],\n",
"}\n",
"\n",
"cal_path = os.path.join('val_data.json')\n",
"# save as json file\n",
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The graph input for on chain data sources is formatted completely differently compared to file based data sources.\n",
"\n",
"- For file data sources, the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit are stored. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestible by EZKL (aka field elements :-D). \n",
"Here is what the schema for an on-chain data source graph input file should look like:\n",
" \n",
"```json\n",
"{\n",
" \"input_data\": {\n",
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
" \"calls\": [\n",
" {\n",
" \"call_data\": [\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns a single on-chain data point (we only support uint256 returns for now)\n",
" 7 // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
" ],\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000001\",\n",
" 5\n",
" ],\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000002\",\n",
" 5\n",
" ]\n",
" ],\n",
" \"address\": \"5fbdb2315678afecb367f032d93f642f64180aa3\" // The address of the contract that we are calling to get the data. \n",
" }\n",
" ]\n",
" }\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"await ezkl.setup_test_evm_witness(\n",
" data_path,\n",
" compiled_model_path,\n",
" # we write the call data to the same file as the input data\n",
" data_path,\n",
" input_source=ezkl.PyTestDataSource.OnChain,\n",
" output_source=ezkl.PyTestDataSource.File,\n",
" rpc_url=RPC_URL)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
"\n",
"These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = await ezkl.get_srs( settings_path)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now need to generate the circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!export RUST_BACKTRACE=1\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a full proof. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"And verify it as a sanity check. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now create and then deploy a vanilla evm verifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"With the vanilla verifier deployed, we can now create the data attestation contract, which will read in the instances from the calldata to the verifier, attest to them, call the verifier and then return the result. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = await ezkl.create_evm_data_attestation(\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can deploy the data attest verifier contract. For security reasons, this binding will only deploy to a local anvil instance, using accounts generated by anvil. \n",
"So should only be used for testing purposes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"addr_path_da = \"addr_da.txt\"\n",
"\n",
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Call the view only verify method on the contract to verify the proof. Since it is a view function this is safe to use in production since you don't have to pass your private key."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# read the verifier address\n",
"addr_verifier = None\n",
"with open(addr_path_verifier, 'r') as f:\n",
" addr = f.read()\n",
"#read the data attestation address\n",
"addr_da = None\n",
"with open(addr_path_da, 'r') as f:\n",
" addr_da = f.read()\n",
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" addr_da,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ezkl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1,657 +0,0 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# data-attest-ezkl hashed\n",
"\n",
"Here's an example leveraging EZKL whereby the hashes of the outputs to the model are read and attested to from an on-chain source.\n",
"\n",
"In this setup:\n",
"- the hashes of outputs are publicly known to the prover and verifier\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"First we import the necessary dependencies and set up logging to be as informative as possible. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"\n",
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import logging\n",
"\n",
"# uncomment for more descriptive logging \n",
"# FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"# logging.basicConfig(format=FORMAT)\n",
"# logging.getLogger().setLevel(logging.DEBUG)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we define our model. It is a very simple PyTorch model that has just one layer, an average pooling 2D layer. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"# Defines the model\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
" self.layer = nn.AvgPool2d(2, 1, (1, 1))\n",
"\n",
" def forward(self, x):\n",
" return self.layer(x)[0]\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"# this is where you'd train your model\n",
"\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We omit training for purposes of this demonstration. We've marked where training would happen in the cell above. \n",
"Now we export the model to onnx and create a corresponding (randomly generated) input. This input data will eventually be stored on chain and read from according to the call_data field in the graph input.\n",
"\n",
"You can replace the random `x` with real data if you so wish. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = 0.1*torch.rand(1,*[3, 2, 2], requires_grad=True)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"network.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w' ))\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now define a function that will create a new anvil instance which we will deploy our test contract to. This contract will contain in its storage the data that we will read from and attest to. In production you would not need to set up a local anvil instance. Instead you would replace RPC_URL with the actual RPC endpoint of the chain you are deploying your verifiers to, reading from the data on said chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import time\n",
"import threading\n",
"\n",
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"RPC_URL = \"http://localhost:3030\"\n",
"\n",
"# Save process globally\n",
"anvil_process = None\n",
"\n",
"def start_anvil():\n",
" global anvil_process\n",
" if anvil_process is None:\n",
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--code-size-limit=41943040\"])\n",
" if anvil_process.returncode is not None:\n",
" raise Exception(\"failed to start anvil process\")\n",
" time.sleep(3)\n",
"\n",
"def stop_anvil():\n",
" global anvil_process\n",
" if anvil_process is not None:\n",
" anvil_process.terminate()\n",
" anvil_process = None\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We define our `PyRunArgs` object, which contains the visibility parameters for our model. \n",
"- `input_visibility` defines the visibility of the model inputs\n",
"- `param_visibility` defines the visibility of the model weights and constants and parameters \n",
"- `output_visibility` defines the visibility of the model outputs\n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"private\"\n",
"- `param_visibility`: \"private\"\n",
"- `output_visibility`: hashed\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.srs')\n",
"data_path = os.path.join('input.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"private\"\n",
"run_args.param_visibility = \"private\"\n",
"run_args.output_visibility = \"hashed\"\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a settings file. This file basically instantiates a bunch of parameters that determine their circuit shape, size etc... Because of the way we represent nonlinearities in the circuit (using Halo2's [lookup tables](https://zcash.github.io/halo2/design/proving-system/lookup.html)), it is often best to _calibrate_ this settings file as some data can fall out of range of these lookups.\n",
"\n",
"You can pass a dataset for calibration that will be representative of real inputs you might find if and when you deploy the prover. Here we create a dummy calibration dataset for demonstration purposes. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# generate a bunch of dummy calibration data\n",
"cal_data = {\n",
" \"input_data\": [(0.1*torch.rand(2, *[3, 2, 2])).flatten().tolist()],\n",
"}\n",
"\n",
"cal_path = os.path.join('val_data.json')\n",
"# save as json file\n",
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
"\n",
"These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = await ezkl.get_srs( settings_path)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now need to generate the circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!export RUST_BACKTRACE=1\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(ezkl.felt_to_big_endian(res['processed_outputs']['poseidon_hash'][0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now post the hashes of the outputs to the chain. This is the data that will be read from and attested to."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from web3 import Web3, HTTPProvider\n",
"from solcx import compile_standard\n",
"from decimal import Decimal\n",
"import json\n",
"import os\n",
"import torch\n",
"\n",
"\n",
"# setup web3 instance\n",
"w3 = Web3(HTTPProvider(RPC_URL))\n",
"\n",
"def test_on_chain_data(res):\n",
" # Step 0: Convert the tensor to a flat list\n",
" data = [int(ezkl.felt_to_big_endian(res['processed_outputs']['poseidon_hash'][0]), 0)]\n",
"\n",
" # Step 1: Prepare the data\n",
" # Step 2: Prepare and compile the contract.\n",
" # We are using a test contract here but in production you would\n",
" # use whatever contract you are fetching data from.\n",
" contract_source_code = '''\n",
" // SPDX-License-Identifier: UNLICENSED\n",
" pragma solidity ^0.8.17;\n",
"\n",
" contract TestReads {\n",
"\n",
" uint[] public arr;\n",
" constructor(uint256[] memory _numbers) {\n",
" for(uint256 i = 0; i < _numbers.length; i++) {\n",
" arr.push(_numbers[i]);\n",
" }\n",
" }\n",
" }\n",
" '''\n",
"\n",
" compiled_sol = compile_standard({\n",
" \"language\": \"Solidity\",\n",
" \"sources\": {\"testreads.sol\": {\"content\": contract_source_code}},\n",
" \"settings\": {\"outputSelection\": {\"*\": {\"*\": [\"metadata\", \"evm.bytecode\", \"abi\"]}}}\n",
" })\n",
"\n",
" # Get bytecode\n",
" bytecode = compiled_sol['contracts']['testreads.sol']['TestReads']['evm']['bytecode']['object']\n",
"\n",
" # Get ABI\n",
" # In production if you are reading from really large contracts you can just use\n",
" # a stripped down version of the ABI of the contract you are calling, containing only the view functions you will fetch data from.\n",
" abi = json.loads(compiled_sol['contracts']['testreads.sol']['TestReads']['metadata'])['output']['abi']\n",
"\n",
" # Step 3: Deploy the contract\n",
" TestReads = w3.eth.contract(abi=abi, bytecode=bytecode)\n",
" tx_hash = TestReads.constructor(data).transact()\n",
" tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)\n",
" # If you are deploying to production you can skip the 3 lines of code above and just instantiate the contract like this,\n",
" # passing the address and abi of the contract you are fetching data from.\n",
" contract = w3.eth.contract(address=tx_receipt['contractAddress'], abi=abi)\n",
"\n",
" # Step 4: Interact with the contract\n",
" calldata = []\n",
" for i, _ in enumerate(data):\n",
" call = contract.functions.arr(i).build_transaction()\n",
" calldata.append((call['data'][2:], 0))\n",
"\n",
" # Prepare the calls_to_account object\n",
" # If you were calling view functions across multiple contracts,\n",
" # you would have multiple entries in the calls_to_account array,\n",
" # one for each contract.\n",
" calls_to_account = [{\n",
" 'call_data': calldata,\n",
" 'address': contract.address[2:], # remove the '0x' prefix\n",
" }]\n",
"\n",
" print(f'calls_to_account: {calls_to_account}')\n",
"\n",
" return calls_to_account\n",
"\n",
"# Now let's start the Anvil process. You don't need to do this if you are deploying to a non-local chain.\n",
"start_anvil()\n",
"\n",
"# Now let's call our function, passing in the same input tensor we used to export the model 2 cells above.\n",
"calls_to_account = test_on_chain_data(res)\n",
"\n",
"data = dict(input_data = [data_array], output_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
"\n",
"# Serialize on-chain data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a full proof. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"And verify it as a sanity check. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now create and then deploy a vanilla evm verifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"With the vanilla verifier deployed, we can now create the data attestation contract, which will read in the instances from the calldata to the verifier, attest to them, call the verifier and then return the result. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = await ezkl.create_evm_data_attestation(\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can deploy the data attest verifier contract. For security reasons, this binding will only deploy to a local anvil instance, using accounts generated by anvil. \n",
"So it should only be used for testing purposes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"addr_path_da = \"addr_da.txt\"\n",
"\n",
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Call the view-only verify method on the contract to verify the proof. Because it is a view function, it is safe to use in production: you don't have to pass your private key."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# read the verifier address\n",
"addr_verifier = None\n",
"with open(addr_path_verifier, 'r') as f:\n",
" addr = f.read()\n",
"#read the data attestation address\n",
"addr_da = None\n",
"with open(addr_path_da, 'r') as f:\n",
" addr_da = f.read()\n",
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" addr_da,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ezkl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1,604 +0,0 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# data-attest-kzg-vis\n",
"\n",
"Here's an example leveraging EZKL whereby the inputs to the model are read and attested to from an on-chain source and the params and outputs are committed to using kzg-commitments. \n",
"\n",
"In this setup:\n",
"- the inputs and outputs are publicly known to the prover and verifier\n",
"- the on chain inputs will be fetched and then fed directly into the circuit\n",
"- the quantization of the on-chain inputs happens within the evm and is replicated at proving time \n",
"- The kzg commitment to the params and outputs will be read from the proof and checked to make sure it matches the expected commitment stored on-chain.\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"First we import the necessary dependencies and set up logging to be as informative as possible. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"\n",
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import logging\n",
"\n",
"# uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we define our model. It is a very simple PyTorch model that has just one layer, an average pooling 2D layer. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"# Defines the model\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
" self.layer = nn.AvgPool2d(2, 1, (1, 1))\n",
"\n",
" def forward(self, x):\n",
" return self.layer(x)[0]\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"# this is where you'd train your model"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We omit training for purposes of this demonstration. We've marked where training would happen in the cell above. \n",
"Now we export the model to onnx and create a corresponding (randomly generated) input. This input data will eventually be stored on chain and read from according to the call_data field in the graph input.\n",
"\n",
"You can replace the random `x` with real data if you so wish. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = 0.1*torch.rand(1,*[3, 2, 2], requires_grad=True)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"network.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w' ))\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now define a function that will create a new anvil instance which we will deploy our test contract to. This contract will contain in its storage the data that we will read from and attest to. In production you would not need to set up a local anvil instance. Instead you would replace RPC_URL with the actual RPC endpoint of the chain you are deploying your verifiers to, reading from the data on said chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import time\n",
"import threading\n",
"\n",
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"RPC_URL = \"http://localhost:3030\"\n",
"\n",
"# Save process globally\n",
"anvil_process = None\n",
"\n",
"def start_anvil():\n",
" global anvil_process\n",
" if anvil_process is None:\n",
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--code-size-limit=41943040\"])\n",
" if anvil_process.returncode is not None:\n",
" raise Exception(\"failed to start anvil process\")\n",
" time.sleep(3)\n",
"\n",
"def stop_anvil():\n",
" global anvil_process\n",
" if anvil_process is not None:\n",
" anvil_process.terminate()\n",
" anvil_process = None\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We define our `PyRunArgs` object, which contains the visibility parameters for our model. \n",
"- `input_visibility` defines the visibility of the model inputs\n",
"- `param_visibility` defines the visibility of the model weights and constants and parameters \n",
"- `output_visibility` defines the visibility of the model outputs\n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"public\"\n",
"- `param_visibility`: \"polycommit\" \n",
"- `output_visibility`: \"polycommit\"\n",
"\n",
"**Note**:\n",
"When we set `param_visibility` to \"polycommit\", we are saying that the model parameters are committed to using a polynomial commitment scheme. This commitment will be stored on chain as a constant in the DA contract, and the proof will contain the commitment to the parameters. The DA verification will then check that the commitment in the proof matches the commitment stored on chain. \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.srs')\n",
"data_path = os.path.join('input.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"public\"\n",
"run_args.param_visibility = \"polycommit\"\n",
"run_args.output_visibility = \"polycommit\"\n",
"run_args.num_inner_cols = 1\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"\n",
"\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a settings file. This file basically instantiates a bunch of parameters that determine the circuit's shape, size, etc. Because of the way we represent nonlinearities in the circuit (using Halo2's [lookup tables](https://zcash.github.io/halo2/design/proving-system/lookup.html)), it is often best to _calibrate_ this settings file as some data can fall out of range of these lookups.\n",
"\n",
"You can pass a dataset for calibration that will be representative of real inputs you might find if and when you deploy the prover. Here we create a dummy calibration dataset for demonstration purposes. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# generate a bunch of dummy calibration data\n",
"cal_data = {\n",
" \"input_data\": [(0.1*torch.rand(2, *[3, 2, 2])).flatten().tolist()],\n",
"}\n",
"\n",
"cal_path = os.path.join('val_data.json')\n",
"# save as json file\n",
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The graph input for on chain data sources is formatted completely differently compared to file based data sources.\n",
"\n",
"- For file data sources, the input data field stores the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestible by EZKL (aka field elements :-D). \n",
"Here is what the schema for an on-chain data source graph input file should look like:\n",
" \n",
"```json\n",
"{\n",
" \"input_data\": {\n",
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
" \"calls\": [\n",
" {\n",
" \"call_data\": [\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns a single on-chain data point (we only support uint256 returns for now)\n",
" 7 // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
" ],\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000001\",\n",
" 5\n",
" ],\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000002\",\n",
" 5\n",
" ]\n",
" ],\n",
" \"address\": \"5fbdb2315678afecb367f032d93f642f64180aa3\" // The address of the contract that we are calling to get the data. \n",
" }\n",
" ]\n",
" }\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"await ezkl.setup_test_evm_witness(\n",
" data_path,\n",
" compiled_model_path,\n",
" # we write the call data to the same file as the input data\n",
" data_path,\n",
" input_source=ezkl.PyTestDataSource.OnChain,\n",
" output_source=ezkl.PyTestDataSource.File,\n",
" rpc_url=RPC_URL)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
"\n",
"These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = await ezkl.get_srs( settings_path)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now need to generate the circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!export RUST_BACKTRACE=1\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a full proof. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"And verify it as a sanity check. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now create and then deploy a vanilla evm verifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" \n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"When deploying a DA with kzg commitments, we need to make sure to also pass a witness file that contains the commitments to the parameters and outputs. This is because the verifier will need to check that the commitments in the proof match the commitments stored on chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = await ezkl.create_evm_data_attestation(\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" witness_path = witness_path,\n",
" )"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can deploy the data attest verifier contract. For security reasons, this binding will only deploy to a local anvil instance, using accounts generated by anvil. \n",
"So it should only be used for testing purposes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"addr_path_da = \"addr_da.txt\"\n",
"\n",
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Call the view-only verify method on the contract to verify the proof. Because it is a view function, it is safe to use in production: you don't have to pass your private key."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# read the verifier address\n",
"addr_verifier = None\n",
"with open(addr_path_verifier, 'r') as f:\n",
" addr = f.read()\n",
"#read the data attestation address\n",
"addr_da = None\n",
"with open(addr_path_da, 'r') as f:\n",
" addr_da = f.read()\n",
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" addr_da,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ezkl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -54,7 +54,7 @@
" gip_run_args.param_scale = 19\n",
" gip_run_args.logrows = 8\n",
" run_args = ezkl.gen_settings(py_run_args=gip_run_args)\n",
" ezkl.get_srs(commitment=ezkl.PyCommitments.KZG)\n",
" await ezkl.get_srs(commitment=ezkl.PyCommitments.KZG)\n",
" ezkl.compile_circuit()\n",
" res = await ezkl.gen_witness()\n",
" print(res)\n",

View File

@@ -77,6 +77,7 @@
"outputs": [],
"source": [
"gip_run_args = ezkl.PyRunArgs()\n",
"gip_run_args.ignore_range_check_inputs_outputs = True\n",
"gip_run_args.input_visibility = \"polycommit\" # matrix and generalized inverse commitments\n",
"gip_run_args.output_visibility = \"fixed\" # no parameters used\n",
"gip_run_args.param_visibility = \"fixed\" # should be Tensor(True)"
@@ -335,9 +336,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -453,8 +453,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -474,8 +474,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
@@ -510,4 +510,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -462,8 +462,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -483,8 +483,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
@@ -512,4 +512,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -1,279 +1,284 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## Linear Regression\n",
"\n",
"\n",
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
"This notebook showcases how to do so using the `hummingbird-ml` python package ! "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import os\n",
"import torch\n",
"import ezkl\n",
"import json\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
"# y = 1 * x_0 + 2 * x_1 + 3\n",
"y = np.dot(X, np.array([1, 2])) + 3\n",
"reg = LinearRegression().fit(X, y)\n",
"reg.score(X, y)\n",
"\n",
"circuit = convert(reg, \"torch\", X[:1]).model\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = X.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=True)\n",
"torch_out = circuit(x)\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
" \"network.onnx\",\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names=['input'], # the model's input names\n",
" output_names=['output'], # the model's output names\n",
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
"cells": [
{
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## Linear Regression\n",
"\n",
"\n",
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
"This notebook showcases how to do so using the `hummingbird-ml` python package ! "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import os\n",
"import torch\n",
"import ezkl\n",
"import json\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
"# y = 1 * x_0 + 2 * x_1 + 3\n",
"y = np.dot(X, np.array([1, 2])) + 3\n",
"reg = LinearRegression().fit(X, y)\n",
"reg.score(X, y)\n",
"\n",
"circuit = convert(reg, \"torch\", X[:1]).model\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = X.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=True)\n",
"torch_out = circuit(x)\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
" \"network.onnx\",\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names=['input'], # the model's input names\n",
" output_names=['output'], # the model's output names\n",
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n",
"\n",
"\n",
"# note that you can also call the following function to generate random data for the model\n",
"# it is functionally equivalent to the code above\n",
"ezkl.gen_random_data()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -1,456 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Mean of ERC20 transfer amounts\n",
"\n",
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"\n",
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import getpass\n",
"import json\n",
"import time\n",
"import subprocess\n",
"\n",
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
"os.system(\"chmod +x shovel\")\n",
"\n",
"\n",
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
"\n",
"# create a config.json file with the following contents\n",
"config = {\n",
" \"pg_url\": \"$PG_URL\",\n",
" \"eth_sources\": [\n",
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
" ],\n",
" \"integrations\": [{\n",
" \"name\": \"usdc_transfer\",\n",
" \"enabled\": True,\n",
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
" \"table\": {\n",
" \"name\": \"usdc\",\n",
" \"columns\": [\n",
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
" ]\n",
" },\n",
" \"block\": [\n",
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
" {\n",
" \"name\": \"log_addr\",\n",
" \"column\": \"log_addr\",\n",
" \"filter_op\": \"contains\",\n",
" \"filter_arg\": [\n",
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
" ]\n",
" }\n",
" ],\n",
" \"event\": {\n",
" \"name\": \"Transfer\",\n",
" \"type\": \"event\",\n",
" \"anonymous\": False,\n",
" \"inputs\": [\n",
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
" ]\n",
" }\n",
" }]\n",
"}\n",
"\n",
"# write the config to a file\n",
"with open(\"config.json\", \"w\") as f:\n",
" f.write(json.dumps(config))\n",
"\n",
"\n",
"# print the two env variables\n",
"os.system(\"echo $PG_URL\")\n",
"\n",
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
"\n",
"os.system(\"echo shovel is now installed. starting:\")\n",
"\n",
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
"proc = subprocess.Popen(command)\n",
"\n",
"os.system(\"echo shovel started.\")\n",
"\n",
"time.sleep(10)\n",
"\n",
"# after we've fetched some data -- kill the process\n",
"proc.terminate()\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2wIAHwqH2_mo"
},
"source": [
"**Import Dependencies**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9Byiv2Nc2MsK"
},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import ezkl\n",
"import torch\n",
"import datetime\n",
"import pandas as pd\n",
"import requests\n",
"import json\n",
"import os\n",
"\n",
"import logging\n",
"# # uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n",
"\n",
"print(\"ezkl version: \", ezkl.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "osjj-0Ta3E8O"
},
"source": [
"**Create Computational Graph**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "x1vl9ZXF3EEW",
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
},
"outputs": [],
"source": [
"from torch import nn\n",
"import torch\n",
"\n",
"\n",
"class Model(nn.Module):\n",
" def __init__(self):\n",
" super(Model, self).__init__()\n",
"\n",
" # x is a time series \n",
" def forward(self, x):\n",
" return [torch.mean(x)]\n",
"\n",
"\n",
"\n",
"\n",
"circuit = Model()\n",
"\n",
"\n",
"\n",
"\n",
"x = 0.1*torch.rand(1,*[1,5], requires_grad=True)\n",
"\n",
"# # print(torch.__version__)\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"print(device)\n",
"\n",
"circuit.to(device)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"lol.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=11, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"# export(circuit, input_shape=[1, 20])\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E3qCeX-X5xqd"
},
"source": [
"**Set Data Source and Get Data**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6RAMplxk5xPk",
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
},
"outputs": [],
"source": [
"import getpass\n",
"# make an input.json file from the df above\n",
"input_filename = os.path.join('input.json')\n",
"\n",
"pg_input_file = dict(input_data = {\n",
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
"json_formatted_str = json.dumps(pg_input_file, indent=2)\n",
"print(json_formatted_str)\n",
"\n",
"\n",
" # Serialize data into file:\n",
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# this corresponds to 4 batches\n",
"calibration_filename = os.path.join('calibration.json')\n",
"\n",
"pg_cal_file = dict(input_data = {\n",
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
" # Serialize data into file:\n",
"json.dump( pg_cal_file, open(calibration_filename, 'w' ))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eLJ7oirQ_HQR"
},
"source": [
"**EZKL Workflow**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rNw0C9QL6W88"
},
"outputs": [],
"source": [
"import subprocess\n",
"import os\n",
"\n",
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"# Generate settings using ezkl\n",
"res = ezkl.gen_settings(onnx_filename, settings_filename)\n",
"\n",
"assert res == True\n",
"\n",
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
"\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4MmE9SX66_Il",
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
},
"outputs": [],
"source": [
"# generate settings\n",
"\n",
"\n",
"# show the settings.json\n",
"with open(\"settings.json\") as f:\n",
" data = json.load(f)\n",
" json_formatted_str = json.dumps(data, indent=2)\n",
"\n",
" print(json_formatted_str)\n",
"\n",
"assert os.path.exists(\"settings.json\")\n",
"assert os.path.exists(\"input.json\")\n",
"assert os.path.exists(\"lol.onnx\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fULvvnK7_CMb"
},
"outputs": [],
"source": [
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"\n",
"\n",
"# setup the proof\n",
"res = ezkl.setup(\n",
" compiled_filename,\n",
" vk_path,\n",
" pk_path\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_filename)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"witness_path = \"witness.json\"\n",
"\n",
"# generate the witness\n",
"res = await ezkl.gen_witness(\n",
" input_filename,\n",
" compiled_filename,\n",
" witness_path\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Oog3j6Kd-Wed",
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
},
"outputs": [],
"source": [
"# prove the zk circuit\n",
"# GENERATE A PROOF\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"\n",
"proof = ezkl.prove(\n",
" witness_path,\n",
" compiled_filename,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\"\n",
" )\n",
"\n",
"\n",
"print(\"proved\")\n",
"\n",
"assert os.path.isfile(proof_path)\n",
"\n"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@@ -504,8 +504,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -527,8 +527,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path\n",
")\n",
"assert res == True"
]
@@ -558,4 +558,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

View File

@@ -453,18 +453,18 @@
"outputs": [],
"source": [
"# now mock aggregate the proofs\n",
"proofs = []\n",
"for i in range(3):\n",
" proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
" proofs.append(proof_path)\n",
"# proofs = []\n",
"# for i in range(3):\n",
"# proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
"# proofs.append(proof_path)\n",
"\n",
"ezkl.mock_aggregate(proofs, logrows=23, split_proofs = True)"
"# ezkl.mock_aggregate(proofs, logrows=26, split_proofs = True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ezkl",
"display_name": ".env",
"language": "python",
"name": "python3"
},
@@ -478,7 +478,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.12.7"
},
"orig_nbformat": 4
},

View File

@@ -0,0 +1,766 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"This is a zk version of the tutorial found [here](https://github.com/bentrevett/pytorch-sentiment-analysis/blob/main/1%20-%20Neural%20Bag%20of%20Words.ipynb). The original tutorial is part of the PyTorch Sentiment Analysis series by Ben Trevett.\n",
"\n",
"1 - NBoW\n",
"\n",
"In this series we'll be building a machine learning model to perform sentiment analysis -- a subset of text classification where the task is to detect if a given sentence is positive or negative -- using PyTorch and torchtext. The dataset used will be movie reviews from the IMDb dataset, which we'll obtain using the datasets library.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"Preparing Data\n",
"\n",
"Before we can implement our NBoW model, we first have to perform quite a few steps to get our data ready to use. NLP usually requires quite a lot of data wrangling beforehand, though libraries such as datasets and torchtext handle most of this for us.\n",
"\n",
"The steps to take are:\n",
"\n",
" 1. importing modules\n",
" 2. loading data\n",
" 3. tokenizing data\n",
" 4. creating data splits\n",
" 5. creating a vocabulary\n",
" 6. numericalizing data\n",
" 7. creating the data loaders\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! pip install torchtext"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import collections\n",
"\n",
"import datasets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torchtext\n",
"import tqdm\n",
"\n",
"# It is usually good practice to run your experiments multiple times with different random seeds -- both to measure the variance of your model and also to avoid having results only calculated with either \"good\" or \"bad\" seeds, i.e. being very lucky or unlucky with the randomness in the training process.\n",
"\n",
"seed = 1234\n",
"\n",
"np.random.seed(seed)\n",
"torch.manual_seed(seed)\n",
"torch.cuda.manual_seed(seed)\n",
"torch.backends.cudnn.deterministic = True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_data, test_data = datasets.load_dataset(\"imdb\", split=[\"train\", \"test\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can check the features attribute of a split to get more information about the features. We can see that text is a Value of dtype=string -- in other words, it's a string -- and that label is a ClassLabel. A ClassLabel means the feature is an integer representation of which class the example belongs to. num_classes=2 means that our labels are one of two values, 0 or 1, and names=['neg', 'pos'] gives us the human-readable versions of those values. Thus, a label of 0 means the example is a negative review and a label of 1 means the example is a positive review."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_data.features\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_data[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One of the first things we need to do to our data is tokenize it. Machine learning models aren't designed to handle strings, they're designed to handle numbers. So what we need to do is break down our string into individual tokens, and then convert these tokens to numbers. We'll get to the conversion later, but first we'll look at tokenization.\n",
"\n",
"Tokenization involves using a tokenizer to process the strings in our dataset. A tokenizer is a function that goes from a string to a list of strings. There are many types of tokenizers available, but we're going to use a relatively simple one provided by torchtext called the basic_english tokenizer. We load our tokenizer as such:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = torchtext.data.utils.get_tokenizer(\"basic_english\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tokenize_example(example, tokenizer, max_length):\n",
" tokens = tokenizer(example[\"text\"])[:max_length]\n",
" return {\"tokens\": tokens}\n",
"\n",
"\n",
"max_length = 256\n",
"\n",
"train_data = train_data.map(\n",
" tokenize_example, fn_kwargs={\"tokenizer\": tokenizer, \"max_length\": max_length}\n",
")\n",
"test_data = test_data.map(\n",
" tokenize_example, fn_kwargs={\"tokenizer\": tokenizer, \"max_length\": max_length}\n",
")\n",
"\n",
"\n",
"# create validation data \n",
"# Why have both a validation set and a test set? Your test set represents the real world data that you'd see if you actually deployed this model. You won't be able to see what data your model will be fed once deployed, and your test set is supposed to reflect that. Every time we tune our model hyperparameters or training set-up to make it do a bit better on the test set, we are leaking information from the test set into the training process. If we do this too often then we begin to overfit on the test set. Hence, we need some data which can act as a \"proxy\" test set which we can look at more frequently in order to evaluate how well our model actually does on unseen data -- this is the validation set.\n",
"\n",
"test_size = 0.25\n",
"\n",
"train_valid_data = train_data.train_test_split(test_size=test_size)\n",
"train_data = train_valid_data[\"train\"]\n",
"valid_data = train_valid_data[\"test\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we have to build a vocabulary. This is a look-up table where every unique token in your dataset has a corresponding index (an integer).\n",
"\n",
"We do this as machine learning models cannot operate on strings, only numerical values. Each index is used to construct a one-hot vector for each token. A one-hot vector is a vector where all the elements are 0, except one, which is 1, and the dimensionality is the total number of unique tokens in your vocabulary, commonly denoted by V."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"min_freq = 5\n",
"special_tokens = [\"<unk>\", \"<pad>\"]\n",
"\n",
"vocab = torchtext.vocab.build_vocab_from_iterator(\n",
" train_data[\"tokens\"],\n",
" min_freq=min_freq,\n",
" specials=special_tokens,\n",
")\n",
"\n",
"# We store the indices of the unknown and padding tokens (zero and one, respectively) in variables, as we'll use these further on in this notebook.\n",
"\n",
"unk_index = vocab[\"<unk>\"]\n",
"pad_index = vocab[\"<pad>\"]\n",
"\n",
"\n",
"vocab.set_default_index(unk_index)\n",
"\n",
"# To look-up a list of tokens, we can use the vocabulary's lookup_indices method.\n",
"vocab.lookup_indices([\"hello\", \"world\", \"some_token\", \"<pad>\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we have our vocabulary, we can numericalize our data. This involves converting the tokens within our dataset into indices. Similar to how we tokenized our data using the Dataset.map method, we'll define a function that takes an example and our vocabulary, gets the index for each token in each example and then creates an ids field which contains the numericalized tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def numericalize_example(example, vocab):\n",
" ids = vocab.lookup_indices(example[\"tokens\"])\n",
" return {\"ids\": ids}\n",
"\n",
"train_data = train_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
"valid_data = valid_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
"test_data = test_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
"\n",
"train_data = train_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n",
"valid_data = valid_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n",
"test_data = test_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The final step of preparing the data is creating the data loaders. We can iterate over a data loader to retrieve batches of examples. This is also where we will perform any padding that is necessary.\n",
"\n",
"We first need to define a function to collate a batch, consisting of a list of examples, into what we want our data loader to output.\n",
"\n",
"Here, our desired output from the data loader is a dictionary with keys of \"ids\" and \"label\".\n",
"\n",
"The value of batch[\"ids\"] should be a tensor of shape [batch size, length], where length is the length of the longest sentence (in terms of tokens) within the batch, and all sentences shorter than this should be padded to that length.\n",
"\n",
"The value of batch[\"label\"] should be a tensor of shape [batch size] consisting of the label for each sentence in the batch.\n",
"\n",
"We define a function, get_collate_fn, which is passed the pad token index and returns the actual collate function. Within the actual collate function, collate_fn, we get a list of \"ids\" tensors for each example in the batch, and then use the pad_sequence function, which converts the list of tensors into the desired [batch size, length] shaped tensor and performs padding using the specified pad_index. By default, pad_sequence will return a [length, batch size] shaped tensor, but by setting batch_first=True, these two dimensions are switched. We get a list of \"label\" tensors and convert the list of tensors into a single [batch size] shaped tensor."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_collate_fn(pad_index):\n",
" def collate_fn(batch):\n",
" batch_ids = [i[\"ids\"] for i in batch]\n",
" batch_ids = nn.utils.rnn.pad_sequence(\n",
" batch_ids, padding_value=pad_index, batch_first=True\n",
" )\n",
" batch_label = [i[\"label\"] for i in batch]\n",
" batch_label = torch.stack(batch_label)\n",
" batch = {\"ids\": batch_ids, \"label\": batch_label}\n",
" return batch\n",
"\n",
" return collate_fn\n",
"\n",
"def get_data_loader(dataset, batch_size, pad_index, shuffle=False):\n",
" collate_fn = get_collate_fn(pad_index)\n",
" data_loader = torch.utils.data.DataLoader(\n",
" dataset=dataset,\n",
" batch_size=batch_size,\n",
" collate_fn=collate_fn,\n",
" shuffle=shuffle,\n",
" )\n",
" return data_loader\n",
"\n",
"\n",
"batch_size = 512\n",
"\n",
"train_data_loader = get_data_loader(train_data, batch_size, pad_index, shuffle=True)\n",
"valid_data_loader = get_data_loader(valid_data, batch_size, pad_index)\n",
"test_data_loader = get_data_loader(test_data, batch_size, pad_index)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"class NBoW(nn.Module):\n",
" def __init__(self, vocab_size, embedding_dim, output_dim, pad_index):\n",
" super().__init__()\n",
" self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)\n",
" self.fc = nn.Linear(embedding_dim, output_dim)\n",
"\n",
" def forward(self, ids):\n",
" # ids = [batch size, seq len]\n",
" embedded = self.embedding(ids)\n",
" # embedded = [batch size, seq len, embedding dim]\n",
" pooled = embedded.mean(dim=1)\n",
" # pooled = [batch size, embedding dim]\n",
" prediction = self.fc(pooled)\n",
" # prediction = [batch size, output dim]\n",
" return prediction\n",
"\n",
"\n",
"vocab_size = len(vocab)\n",
"embedding_dim = 300\n",
"output_dim = len(train_data.unique(\"label\"))\n",
"\n",
"model = NBoW(vocab_size, embedding_dim, output_dim, pad_index)\n",
"\n",
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"\n",
"print(f\"The model has {count_parameters(model):,} trainable parameters\")\n",
"\n",
"vectors = torchtext.vocab.GloVe()\n",
"\n",
"pretrained_embedding = vectors.get_vecs_by_tokens(vocab.get_itos())\n",
"\n",
"optimizer = optim.Adam(model.parameters())\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"model = model.to(device)\n",
"criterion = criterion.to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train(data_loader, model, criterion, optimizer, device):\n",
" model.train()\n",
" epoch_losses = []\n",
" epoch_accs = []\n",
" for batch in tqdm.tqdm(data_loader, desc=\"training...\"):\n",
" ids = batch[\"ids\"].to(device)\n",
" label = batch[\"label\"].to(device)\n",
" prediction = model(ids)\n",
" loss = criterion(prediction, label)\n",
" accuracy = get_accuracy(prediction, label)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" epoch_losses.append(loss.item())\n",
" epoch_accs.append(accuracy.item())\n",
" return np.mean(epoch_losses), np.mean(epoch_accs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(data_loader, model, criterion, device):\n",
" model.eval()\n",
" epoch_losses = []\n",
" epoch_accs = []\n",
" with torch.no_grad():\n",
" for batch in tqdm.tqdm(data_loader, desc=\"evaluating...\"):\n",
" ids = batch[\"ids\"].to(device)\n",
" label = batch[\"label\"].to(device)\n",
" prediction = model(ids)\n",
" loss = criterion(prediction, label)\n",
" accuracy = get_accuracy(prediction, label)\n",
" epoch_losses.append(loss.item())\n",
" epoch_accs.append(accuracy.item())\n",
" return np.mean(epoch_losses), np.mean(epoch_accs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_accuracy(prediction, label):\n",
" batch_size, _ = prediction.shape\n",
" predicted_classes = prediction.argmax(dim=-1)\n",
" correct_predictions = predicted_classes.eq(label).sum()\n",
" accuracy = correct_predictions / batch_size\n",
" return accuracy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n_epochs = 10\n",
"best_valid_loss = float(\"inf\")\n",
"\n",
"metrics = collections.defaultdict(list)\n",
"\n",
"for epoch in range(n_epochs):\n",
" train_loss, train_acc = train(\n",
" train_data_loader, model, criterion, optimizer, device\n",
" )\n",
" valid_loss, valid_acc = evaluate(valid_data_loader, model, criterion, device)\n",
" metrics[\"train_losses\"].append(train_loss)\n",
" metrics[\"train_accs\"].append(train_acc)\n",
" metrics[\"valid_losses\"].append(valid_loss)\n",
" metrics[\"valid_accs\"].append(valid_acc)\n",
" if valid_loss < best_valid_loss:\n",
" best_valid_loss = valid_loss\n",
" torch.save(model.state_dict(), \"nbow.pt\")\n",
" print(f\"epoch: {epoch}\")\n",
" print(f\"train_loss: {train_loss:.3f}, train_acc: {train_acc:.3f}\")\n",
" print(f\"valid_loss: {valid_loss:.3f}, valid_acc: {valid_acc:.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig = plt.figure(figsize=(10, 6))\n",
"ax = fig.add_subplot(1, 1, 1)\n",
"ax.plot(metrics[\"train_losses\"], label=\"train loss\")\n",
"ax.plot(metrics[\"valid_losses\"], label=\"valid loss\")\n",
"ax.set_xlabel(\"epoch\")\n",
"ax.set_ylabel(\"loss\")\n",
"ax.set_xticks(range(n_epochs))\n",
"ax.legend()\n",
"ax.grid()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig = plt.figure(figsize=(10, 6))\n",
"ax = fig.add_subplot(1, 1, 1)\n",
"ax.plot(metrics[\"train_accs\"], label=\"train accuracy\")\n",
"ax.plot(metrics[\"valid_accs\"], label=\"valid accuracy\")\n",
"ax.set_xlabel(\"epoch\")\n",
"ax.set_ylabel(\"loss\")\n",
"ax.set_xticks(range(n_epochs))\n",
"ax.legend()\n",
"ax.grid()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.load_state_dict(torch.load(\"nbow.pt\"))\n",
"\n",
"test_loss, test_acc = evaluate(test_data_loader, model, criterion, device)\n",
"\n",
"print(f\"test_loss: {test_loss:.3f}, test_acc: {test_acc:.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def predict_sentiment(text, model, tokenizer, vocab, device):\n",
" tokens = tokenizer(text)\n",
" ids = vocab.lookup_indices(tokens)\n",
" tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
" prediction = model(tensor).squeeze(dim=0)\n",
" probability = torch.softmax(prediction, dim=-1)\n",
" predicted_class = prediction.argmax(dim=-1).item()\n",
" predicted_probability = probability[predicted_class].item()\n",
" return predicted_class, predicted_probability"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"This film is terrible!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"This film is great!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"This film is not terrible, it's great!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"This film is not great, it's terrible!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def text_to_tensor(text, tokenizer, vocab, device):\n",
" tokens = tokenizer(text)\n",
" ids = vocab.lookup_indices(tokens)\n",
" tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
" return tensor\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we do onnx stuff to get the data ready for the zk-circuit."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"import json\n",
"\n",
"text = \"This film is terrible!\"\n",
"x = text_to_tensor(text, tokenizer, vocab, device)\n",
"\n",
"# Flips the neural net into inference mode\n",
"model.eval()\n",
"model.to('cpu')\n",
"\n",
"model_path = \"network.onnx\"\n",
"data_path = \"input.json\"\n",
"\n",
" # Export the model\n",
"torch.onnx.export(model, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" model_path, # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data_json = dict(input_data = [data_array])\n",
"\n",
"print(data_json)\n",
"\n",
" # Serialize data into file:\n",
"json.dump(data_json, open(data_path, 'w'))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.logrows = 23\n",
"run_args.scale_rebase_multiplier = 10\n",
"# inputs should be auditable by all\n",
"run_args.input_visibility = \"public\"\n",
"# same with outputs\n",
"run_args.output_visibility = \"public\"\n",
"# for simplicity, we'll just use the fixed model visibility: i.e it is public and can't be changed by the prover\n",
"run_args.param_visibility = \"fixed\"\n",
"\n",
"\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(py_run_args=run_args)\n",
"assert res == True\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit()\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file\n",
"res = await ezkl.gen_witness()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.mock()\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"res = ezkl.setup()\n",
"\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"res = ezkl.prove(proof_path=\"proof.json\")\n",
"\n",
"print(res)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"res = ezkl.verify()\n",
"\n",
"assert res == True\n",
"print(\"verified\")\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also verify it on chain by creating an onchain verifier"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"solc-select\"])\n",
" !solc-select install 0.8.20\n",
" !solc-select use 0.8.20\n",
" !solc --version\n",
" import os\n",
"\n",
"# rely on local installation if the notebook is not in colab\n",
"except:\n",
" import os\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = await ezkl.create_evm_verifier()\n",
"assert res == True\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You should see a `Verifier.sol`. Right-click and save it locally.\n",
"\n",
"Now go to [https://remix.ethereum.org](https://remix.ethereum.org).\n",
"\n",
"Create a new file within remix and copy the verifier code over.\n",
"\n",
"Finally, compile the code and deploy. For the demo you can deploy to the test environment within remix.\n",
"\n",
"If everything works, you would have deployed your verifer onchain! Copy the values in the cell above to the respective fields to test if the verifier is working."
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -152,9 +152,11 @@
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"run_args = ezkl.PyRunArgs()\n",
"# logrows\n",
"run_args.logrows = 20\n",
"\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True\n"
]
},
@@ -302,7 +304,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.9.13"
}
},
"nbformat": 4,

View File

@@ -125,7 +125,7 @@
"\n",
" witness_path = os.path.join(name, \"witness.json\")\n",
" sol_code_path = os.path.join(name, 'test.sol')\n",
" sol_key_code_path = os.path.join(name, 'test_key.sol')\n",
" vka_path = os.path.join(name, 'vka.bytes')\n",
" abi_path = os.path.join(name, 'test.abi')\n",
" proof_path = os.path.join(name, \"proof.json\")\n",
"\n",
@@ -177,7 +177,7 @@
" res = await ezkl.create_evm_verifier(vk_path, settings_path, sol_code_path, abi_path, reusable=True)\n",
" assert res == True\n",
"\n",
" res = await ezkl.create_evm_vka(vk_path, settings_path, sol_key_code_path, abi_path)\n",
" res = await ezkl.create_evm_vka(vk_path, settings_path, vka_path)\n",
" assert res == True\n"
]
},
@@ -220,15 +220,6 @@
"Check that the generated verifiers are identical for all models."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"start_anvil()"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -270,8 +261,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" \"verifier/reusable\"\n",
")\n",
"\n",
@@ -296,20 +287,21 @@
"source": [
"for name in names:\n",
" addr_path_vk = \"addr_vk.txt\"\n",
" sol_key_code_path = os.path.join(name, 'test_key.sol')\n",
" res = await ezkl.deploy_evm(addr_path_vk, sol_key_code_path, 'http://127.0.0.1:3030', \"vka\")\n",
" vka_path = os.path.join(name, 'vka.bytes')\n",
" res = await ezkl.register_vka(\n",
" addr,\n",
" 'http://127.0.0.1:3030',\n",
" vka_path=vka_path,\n",
" )\n",
" assert res == True\n",
"\n",
" with open(addr_path_vk, 'r') as file:\n",
" addr_vk = file.read().rstrip()\n",
" \n",
" proof_path = os.path.join(name, \"proof.json\")\n",
" sol_code_path = os.path.join(name, 'vk.sol')\n",
" res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\",\n",
" addr_vk = addr_vk\n",
" proof_path,\n",
" vka_path = vka_path\n",
" )\n",
" assert res == True"
]

View File

@@ -167,6 +167,8 @@
"run_args = ezkl.PyRunArgs()\n",
"# \"hashed/private\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
"run_args.input_visibility = \"hashed/private/0\"\n",
"# as the inputs are felts we turn off input range checks\n",
"run_args.ignore_range_check_inputs_outputs = True\n",
"# we set it to fix the set we want to check membership for\n",
"run_args.param_visibility = \"fixed\"\n",
"# the output is public -- set membership fails if it is not = 0\n",
@@ -519,4 +521,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -204,6 +204,7 @@
"run_args = ezkl.PyRunArgs()\n",
"# \"polycommit\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
"run_args.input_visibility = \"polycommit\"\n",
"run_args.ignore_range_check_inputs_outputs = True\n",
"# the parameters are public\n",
"run_args.param_visibility = \"fixed\"\n",
"# the output is public (this is the inequality test)\n",
@@ -514,4 +515,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -1,763 +0,0 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# univ3-da-ezkl\n",
"\n",
"Here's an example leveraging EZKL whereby the inputs to the model are read and attested to from an on-chain source. For this setup we make a single call to a view function that returns an array of UniV3 historical TWAP price data that we will attest to on-chain. \n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"First we import the necessary dependencies and set up logging to be as informative as possible. "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"\n",
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import logging\n",
"\n",
"# uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we define our model. It is a very simple PyTorch model that has just one layer, an average pooling 2D layer. "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"# Defines the model\n",
"\n",
"class MyModel(nn.Module):\n",
" def __init__(self):\n",
" super(MyModel, self).__init__()\n",
" self.layer = nn.AvgPool2d(2, 1, (1, 1))\n",
"\n",
" def forward(self, x):\n",
" return self.layer(x)[0]\n",
"\n",
"\n",
"circuit = MyModel()\n",
"\n",
"# this is where you'd train your model"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We omit training for purposes of this demonstration. We've marked where training would happen in the cell above. \n",
"Now we export the model to onnx and create a corresponding (randomly generated) input. This input data will eventually be stored on chain and read from according to the call_data field in the graph input.\n",
"\n",
"You can replace the random `x` with real data if you so wish. "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"x = 0.1*torch.rand(1,*[3, 2, 2], requires_grad=True)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"network.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w' ))\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now define a function that will create a new anvil instance which we will deploy our test contract too. This contract will contain in its storage the data that we will read from and attest to. In production you would not need to set up a local anvil instance. Instead you would replace RPC_URL with the actual RPC endpoint of the chain you are deploying your verifiers too, reading from the data on said chain."
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import time\n",
"import threading\n",
"\n",
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"RPC_URL = \"http://localhost:3030\"\n",
"\n",
"# Save process globally\n",
"anvil_process = None\n",
"\n",
"def start_anvil():\n",
" global anvil_process\n",
" if anvil_process is None:\n",
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--fork-url\", \"https://arb1.arbitrum.io/rpc\", \"--code-size-limit=41943040\"])\n",
" if anvil_process.returncode is not None:\n",
" raise Exception(\"failed to start anvil process\")\n",
" time.sleep(3)\n",
"\n",
"def stop_anvil():\n",
" global anvil_process\n",
" if anvil_process is not None:\n",
" anvil_process.terminate()\n",
" anvil_process = None\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We define our `PyRunArgs` objects which contains the visibility parameters for out model. \n",
"- `input_visibility` defines the visibility of the model inputs\n",
"- `param_visibility` defines the visibility of the model weights and constants and parameters \n",
"- `output_visibility` defines the visibility of the model outputs\n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"public\"\n",
"- `param_visibility`: \"private\"\n",
"- `output_visibility`: public\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.srs')\n",
"data_path = os.path.join('input.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"public\"\n",
"run_args.param_visibility = \"private\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.num_inner_cols = 1\n",
"run_args.variables = [(\"batch_size\", 1)]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a settings file. This file basically instantiates a bunch of parameters that determine their circuit shape, size etc... Because of the way we represent nonlinearities in the circuit (using Halo2's [lookup tables](https://zcash.github.io/halo2/design/proving-system/lookup.html)), it is often best to _calibrate_ this settings file as some data can fall out of range of these lookups.\n",
"\n",
"You can pass a dataset for calibration that will be representative of real inputs you might find if and when you deploy the prover. Here we create a dummy calibration dataset for demonstration purposes. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# generate a bunch of dummy calibration data\n",
"cal_data = {\n",
" \"input_data\": [(0.1*torch.rand(2, *[3, 2, 2])).flatten().tolist()],\n",
"}\n",
"\n",
"cal_path = os.path.join('val_data.json')\n",
"# save as json file\n",
"with open(cal_path, \"w\") as f:\n",
" json.dump(cal_data, f)\n",
"\n",
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The graph input for on chain data sources is formatted completely differently compared to file based data sources.\n",
"\n",
"- For file data sources, the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit are stored. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestable by EZKL (aka field elements :-D). \n",
"Here is what the schema for an on-chain data source graph input file should look like for a single call data source:\n",
" \n",
"```json\n",
"{\n",
" \"input_data\": {\n",
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
" \"calls\": {\n",
" \"call_data\": \"1f3be514000000000000000000000000c6962004f452be9203591991d15f6b388e09e8d00000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000070000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns an array of on-chain data points we are attesting to. \n",
" \"decimals\": 0, // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
" \"address\": \"9A213F53334279C128C37DA962E5472eCD90554f\", // The address of the contract that we are calling to get the data. \n",
" \"len\": 12 // The number of data points returned by the view function (the length of the array)\n",
" }\n",
" }\n",
"}\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from web3 import Web3, HTTPProvider\n",
"from solcx import compile_standard\n",
"from decimal import Decimal\n",
"import json\n",
"import os\n",
"import torch\n",
"import requests\n",
"\n",
"# This function counts the decimal places of a floating point number\n",
"def count_decimal_places(num):\n",
" num_str = str(num)\n",
" if '.' in num_str:\n",
" return len(num_str) - 1 - num_str.index('.')\n",
" else:\n",
" return 0\n",
"\n",
"# setup web3 instance\n",
"w3 = Web3(HTTPProvider(RPC_URL)) \n",
"\n",
"def set_next_block_timestamp(anvil_url, timestamp):\n",
" # Send the JSON-RPC request to Anvil\n",
" payload = {\n",
" \"jsonrpc\": \"2.0\",\n",
" \"id\": 1,\n",
" \"method\": \"evm_setNextBlockTimestamp\",\n",
" \"params\": [timestamp]\n",
" }\n",
" response = requests.post(anvil_url, json=payload)\n",
" if response.status_code == 200:\n",
" print(f\"Next block timestamp set to: {timestamp}\")\n",
" else:\n",
" print(f\"Failed to set next block timestamp: {response.text}\")\n",
"\n",
"def on_chain_data(tensor):\n",
" # Step 0: Convert the tensor to a flat list\n",
" data = tensor.view(-1).tolist()\n",
"\n",
" # Step 1: Prepare the calldata\n",
" secondsAgo = [len(data) - 1 - i for i in range(len(data))]\n",
"\n",
" # Step 2: Prepare and compile the contract UniTickAttestor contract\n",
" contract_source_code = '''\n",
" // SPDX-License-Identifier: MIT\n",
" pragma solidity ^0.8.20;\n",
"\n",
" /// @title Pool state that is not stored\n",
" /// @notice Contains view functions to provide information about the pool that is computed rather than stored on the\n",
" /// blockchain. The functions here may have variable gas costs.\n",
" interface IUniswapV3PoolDerivedState {\n",
" /// @notice Returns the cumulative tick and liquidity as of each timestamp `secondsAgo` from the current block timestamp\n",
" /// @dev To get a time weighted average tick or liquidity-in-range, you must call this with two values, one representing\n",
" /// the beginning of the period and another for the end of the period. E.g., to get the last hour time-weighted average tick,\n",
" /// you must call it with secondsAgos = [3600, 0].\n",
" /// log base sqrt(1.0001) of token1 / token0. The TickMath library can be used to go from a tick value to a ratio.\n",
" /// @dev The time weighted average tick represents the geometric time weighted average price of the pool, in\n",
" /// @param secondsAgos From how long ago each cumulative tick and liquidity value should be returned\n",
" /// @return tickCumulatives Cumulative tick values as of each `secondsAgos` from the current block timestamp\n",
" /// @return secondsPerLiquidityCumulativeX128s Cumulative seconds per liquidity-in-range value as of each `secondsAgos` from the current block\n",
" /// timestamp\n",
" function observe(\n",
" uint32[] calldata secondsAgos\n",
" )\n",
" external\n",
" view\n",
" returns (\n",
" int56[] memory tickCumulatives,\n",
" uint160[] memory secondsPerLiquidityCumulativeX128s\n",
" );\n",
" }\n",
"\n",
" /// @title Uniswap Wrapper around `pool.observe` that stores the parameters for fetching and then attesting to historical data\n",
" /// @notice Provides functions to integrate with V3 pool oracle\n",
" contract UniTickAttestor {\n",
" /**\n",
" * @notice Calculates time-weighted means of tick and liquidity for a given Uniswap V3 pool\n",
" * @param pool Address of the pool that we want to observe\n",
" * @param secondsAgo Number of seconds in the past from which to calculate the time-weighted means\n",
" * @return tickCumulatives The cumulative tick values as of each `secondsAgo` from the current block timestamp\n",
" */\n",
" function consult(\n",
" IUniswapV3PoolDerivedState pool,\n",
" uint32[] memory secondsAgo\n",
" ) public view returns (int256[] memory tickCumulatives) {\n",
" tickCumulatives = new int256[](secondsAgo.length);\n",
" (int56[] memory _ticks,) = pool.observe(secondsAgo);\n",
" for (uint256 i = 0; i < secondsAgo.length; i++) {\n",
" tickCumulatives[i] = int256(_ticks[i]);\n",
" }\n",
" }\n",
" }\n",
" '''\n",
"\n",
" compiled_sol = compile_standard({\n",
" \"language\": \"Solidity\",\n",
" \"sources\": {\"UniTickAttestor.sol\": {\"content\": contract_source_code}},\n",
" \"settings\": {\"outputSelection\": {\"*\": {\"*\": [\"metadata\", \"evm.bytecode\", \"abi\"]}}}\n",
" })\n",
"\n",
" # Get bytecode\n",
" bytecode = compiled_sol['contracts']['UniTickAttestor.sol']['UniTickAttestor']['evm']['bytecode']['object']\n",
"\n",
" # Get ABI\n",
" # In production if you are reading from really large contracts you can just use\n",
" # a stripped down version of the ABI of the contract you are calling, containing only the view functions you will fetch data from.\n",
" abi = json.loads(compiled_sol['contracts']['UniTickAttestor.sol']['UniTickAttestor']['metadata'])['output']['abi']\n",
"\n",
" # Step 3: Deploy the contract\n",
" UniTickAttestor = w3.eth.contract(abi=abi, bytecode=bytecode)\n",
" tx_hash = UniTickAttestor.constructor().transact()\n",
" tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)\n",
" # If you are deploying to production you can skip the 3 lines of code above and just instantiate the contract like this,\n",
" # passing the address and abi of the contract you are fetching data from.\n",
" contract = w3.eth.contract(address=tx_receipt['contractAddress'], abi=abi)\n",
"\n",
" # Step 4: Interact with the contract\n",
" call = contract.functions.consult(\n",
" # Address of the UniV3 usdc-weth pool 0.005 fee\n",
" \"0xC6962004f452bE9203591991D15f6b388e09E8D0\",\n",
" secondsAgo\n",
" ).build_transaction()\n",
" result = contract.functions.consult(\n",
" # Address of the UniV3 usdc-weth pool 0.005 fee\n",
" \"0xC6962004f452bE9203591991D15f6b388e09E8D0\",\n",
" secondsAgo\n",
" ).call()\n",
" \n",
" print(f'result: {result}')\n",
" calldata = call['data'][2:]\n",
"\n",
" time_stamp = w3.eth.get_block('latest')['timestamp']\n",
"\n",
" print(f'time_stamp: {time_stamp}')\n",
"\n",
" # Set the next block timestamp using the fetched time_stamp\n",
" set_next_block_timestamp(RPC_URL, time_stamp)\n",
"\n",
"\n",
" # Prepare the calls_to_account object\n",
" # If you were calling view functions across multiple contracts,\n",
" # you would have multiple entries in the calls_to_account array,\n",
" # one for each contract.\n",
" call_to_account = {\n",
" 'call_data': calldata,\n",
" 'decimals': 0,\n",
" 'address': contract.address[2:], # remove the '0x' prefix\n",
" 'len': len(data),\n",
" }\n",
"\n",
" print(f'call_to_account: {call_to_account}')\n",
"\n",
" return call_to_account\n",
"\n",
"# Now let's start the Anvil process. You don't need to do this if you are deploying to a non-local chain.\n",
"start_anvil()\n",
"\n",
"# Now let's call our function, passing in the same input tensor we used to export the model 2 cells above.\n",
"calls_to_account = on_chain_data(x)\n",
"\n",
"data = dict(input_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
"\n",
"# Serialize on-chain data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
"\n",
"These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = await ezkl.get_srs( settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We now need to generate the circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# !export RUST_BACKTRACE=1\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we set up verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we generate a full proof. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"And verify it as a sanity check. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now create and then deploy a vanilla evm verifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"With the vanilla verifier deployed, we can now create the data attestation contract, which will read in the instances from the calldata to the verifier, attest to them, call the verifier and then return the result. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = await ezkl.create_evm_data_attestation(\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can deploy the data attest verifier contract. For security reasons, this binding will only deploy to a local anvil instance, using accounts generated by anvil. \n",
"So it should only be used for testing purposes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"addr_path_da = \"addr_da.txt\"\n",
"\n",
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we need to regenerate the witness, prove and then verify all within the same cell. We do this to minimize the latency between reading on-chain state and verifying it on-chain: the attested input values read from the oracle are time sensitive (their values are derived from computing on block.timestamp) and can change between the time of reading and the time of verifying.\n",
"\n",
"Call the view only verify method on the contract to verify the proof. Since it is a view function this is safe to use in production since you don't have to pass your private key."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# !export RUST_BACKTRACE=1\n",
"\n",
"calls_to_account = on_chain_data(x)\n",
"\n",
"data = dict(input_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
"\n",
"# Serialize on-chain data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n",
"\n",
"# setup web3 instance\n",
"w3 = Web3(HTTPProvider(RPC_URL)) \n",
"\n",
"time_stamp = w3.eth.get_block('latest')['timestamp']\n",
"\n",
"print(f'time_stamp: {time_stamp}')\n",
"\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)\n",
"# read the verifier address\n",
"addr_verifier = None\n",
"with open(addr_path_verifier, 'r') as f:\n",
" addr = f.read()\n",
"#read the data attestation address\n",
"addr_da = None\n",
"with open(addr_path_da, 'r') as f:\n",
" addr_da = f.read()\n",
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" addr_da,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -666,7 +666,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -689,8 +689,8 @@
"# await\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -701,7 +701,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -722,8 +722,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
@@ -743,7 +743,8 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
@@ -756,7 +757,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.9"
}
},
"nbformat": 4,

View File

@@ -849,8 +849,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -870,8 +870,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path\n",
")\n",
"assert res == True"
]
@@ -905,4 +905,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -1,547 +0,0 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## World rotation\n",
"\n",
"Here we demonstrate how to use the EZKL package to rotate an on-chain world. \n",
"\n",
"![zk-gaming-diagram-transformed](https://hackmd.io/_uploads/HkApuQGV6.png)\n",
"> **A typical ZK application flow**. For the shape rotators out there — this is an easily digestible example. A user computes a ZK-proof that they have calculated a valid rotation of a world. They submit this proof to a verifier contract which governs an on-chain world, along with a new set of coordinates, and the world rotation updates. Observe that it's possible for one player to initiate a *global* change.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch import nn\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import torch\n",
"import math\n",
"\n",
"# these are constants for the rotation\n",
"phi = torch.tensor(5 * math.pi / 180)\n",
"s = torch.sin(phi)\n",
"c = torch.cos(phi)\n",
"\n",
"\n",
"class RotateStuff(nn.Module):\n",
" def __init__(self):\n",
" super(RotateStuff, self).__init__()\n",
"\n",
" # create a rotation matrix -- the matrix is constant and is transposed for convenience\n",
" self.rot = torch.stack([torch.stack([c, -s]),\n",
" torch.stack([s, c])]).t()\n",
"\n",
" def forward(self, x):\n",
" x_rot = x @ self.rot # same as x_rot = (rot @ x.t()).t() due to rot in O(n) (SO(n) even)\n",
" return x_rot\n",
"\n",
"\n",
"circuit = RotateStuff()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This will showcase the principal directions of rotation by plotting the rotation of a single unit vector."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import pyplot\n",
"pyplot.figure(figsize=(3, 3))\n",
"pyplot.arrow(0, 0, 1, 0, width=0.02, alpha=0.5)\n",
"pyplot.arrow(0, 0, 0, 1, width=0.02, alpha=0.5)\n",
"pyplot.arrow(0, 0, circuit.rot[0, 0].item(), circuit.rot[0, 1].item(), width=0.02)\n",
"pyplot.arrow(0, 0, circuit.rot[1, 0].item(), circuit.rot[1, 1].item(), width=0.02)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.srs')\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# initial principle vectors for the rotation are as in the plot above\n",
"x = torch.tensor([[1, 0], [0, 1]], dtype=torch.float32)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" model_path, # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" )\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump( data, open(data_path, 'w' ))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### World rotation in 2D on-chain"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For demo purposes we deploy these coordinates to a contract running locally using Anvil. This creates our on-chain world. We then rotate the world using the EZKL package and submit the proof to the contract. The contract then updates the world rotation. For demo purposes we do this repeatedly, rotating the world by 1 transform each time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"import time\n",
"import threading\n",
"\n",
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"RPC_URL = \"http://localhost:3030\"\n",
"\n",
"# Save process globally\n",
"anvil_process = None\n",
"\n",
"def start_anvil():\n",
" global anvil_process\n",
" if anvil_process is None:\n",
" anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--code-size-limit=41943040\"])\n",
" if anvil_process.returncode is not None:\n",
" raise Exception(\"failed to start anvil process\")\n",
" time.sleep(3)\n",
"\n",
"def stop_anvil():\n",
" global anvil_process\n",
" if anvil_process is not None:\n",
" anvil_process.terminate()\n",
" anvil_process = None\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We define our `PyRunArgs` object, which contains the visibility parameters for our model. \n",
"- `input_visibility` defines the visibility of the model inputs\n",
"- `param_visibility` defines the visibility of the model weights and constants and parameters \n",
"- `output_visibility` defines the visibility of the model outputs\n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"public\"\n",
"- `param_visibility`: \"fixed\"\n",
"- `output_visibility`: public"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"py_run_args = ezkl.PyRunArgs()\n",
"py_run_args.input_visibility = \"public\"\n",
"py_run_args.output_visibility = \"public\"\n",
"py_run_args.param_visibility = \"private\" # private by default\n",
"py_run_args.scale_rebase_multiplier = 10\n",
"\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We also define a contract that holds our test data. This contract will contain in its storage the data that we will read from and attest to. In production you would not need to set up a local anvil instance. Instead you would replace RPC_URL with the actual RPC endpoint of the chain you are deploying your verifiers to, reading from the data on said chain."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2007dc77",
"metadata": {},
"outputs": [],
"source": [
"ezkl.setup_test_evm_witness(\n",
" data_path,\n",
" compiled_model_path,\n",
" # we write the call data to the same file as the input data\n",
" data_path,\n",
" input_source=ezkl.PyTestDataSource.OnChain,\n",
" output_source=ezkl.PyTestDataSource.File,\n",
" rpc_url=RPC_URL)"
]
},
{
"cell_type": "markdown",
"id": "ab993958",
"metadata": {},
"source": [
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
"\n",
"These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"\n",
"witness = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "markdown",
"id": "ad58432e",
"metadata": {},
"source": [
"Here we set up verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "markdown",
"id": "1746c8d1",
"metadata": {},
"source": [
"We can now create an EVM verifier contract from our circuit. This contract will be deployed to the chain we are using. In this case we are using a local anvil instance."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1920c0f",
"metadata": {},
"outputs": [],
"source": [
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = await ezkl.create_evm_verifier(\n",
" vk_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0fd7f22b",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"addr_path_verifier = \"addr_verifier.txt\"\n",
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
]
},
{
"cell_type": "markdown",
"id": "9c0dffab",
"metadata": {},
"source": [
"With the vanilla verifier deployed, we can now create the data attestation contract, which will read in the instances from the calldata to the verifier, attest to them, call the verifier and then return the result. \n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc888848",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2db14d7",
"metadata": {},
"outputs": [],
"source": [
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"input_path = 'input.json'\n",
"\n",
"res = await ezkl.create_evm_data_attestation(\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a018ba6",
"metadata": {},
"outputs": [],
"source": [
"addr_path_da = \"addr_da.txt\"\n",
"\n",
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "2adad845",
"metadata": {},
"source": [
"Now we can pull in the data from the contract and calculate a new set of coordinates. We then rotate the world by 1 transform and submit the proof to the contract. The contract could then update the world rotation (logic not inserted here). For demo purposes we do this repeatedly, rotating the world by 1 transform. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "markdown",
"id": "90eda56e",
"metadata": {},
"source": [
"Call the view only verify method on the contract to verify the proof. Since it is a view function this is safe to use in production since you don't have to pass your private key."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# read the verifier address\n",
"addr_verifier = None\n",
"with open(addr_path_verifier, 'r') as f:\n",
" addr = f.read()\n",
"#read the data attestation address\n",
"addr_da = None\n",
"with open(addr_path_da, 'r') as f:\n",
" addr_da = f.read()\n",
"\n",
"res = ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" addr_da,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check let's plot the rotations of the unit vectors. We can see that the unit vectors rotate as expected by the output of the circuit. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"witness['outputs'][0][0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"settings = json.load(open(settings_path, 'r'))\n",
"out_scale = settings[\"model_output_scales\"][0]\n",
"\n",
"from matplotlib import pyplot\n",
"pyplot.figure(figsize=(3, 3))\n",
"pyplot.arrow(0, 0, 1, 0, width=0.02, alpha=0.5)\n",
"pyplot.arrow(0, 0, 0, 1, width=0.02, alpha=0.5)\n",
"\n",
"arrow_x = ezkl.felt_to_float(witness['outputs'][0][0], out_scale)\n",
"arrow_y = ezkl.felt_to_float(witness['outputs'][0][1], out_scale)\n",
"pyplot.arrow(0, 0, arrow_x, arrow_y, width=0.02)\n",
"arrow_x = ezkl.felt_to_float(witness['outputs'][0][2], out_scale)\n",
"arrow_y = ezkl.felt_to_float(witness['outputs'][0][3], out_scale)\n",
"pyplot.arrow(0, 0, arrow_x, arrow_y, width=0.02)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,106 @@
{
"input_data": [
[
8761,
7654,
8501,
2404,
6929,
8858,
5946,
3673,
4131,
3854,
8137,
8239,
9038,
6299,
1118,
9737,
208,
7954,
3691,
610,
3468,
3314,
8658,
8366,
2850,
477,
6114,
232,
4601,
7420,
5713,
2936,
6061,
2870,
8421,
177,
7107,
7382,
6115,
5487,
8502,
2559,
1875,
129,
8533,
8201,
8414,
4775,
9817,
3127,
8761,
7654,
8501,
2404,
6929,
8858,
5946,
3673,
4131,
3854,
8137,
8239,
9038,
6299,
1118,
9737,
208,
7954,
3691,
610,
3468,
3314,
8658,
8366,
2850,
477,
6114,
232,
4601,
7420,
5713,
2936,
6061,
2870,
8421,
177,
7107,
7382,
6115,
5487,
8502,
2559,
1875,
129,
8533,
8201,
8414,
4775,
9817,
3127
]
]
}

Binary file not shown.

View File

@@ -0,0 +1 @@
{"run_args":{"input_scale":7,"param_scale":7,"scale_rebase_multiplier":1,"lookup_range":[-32768,32768],"logrows":17,"num_inner_cols":2,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private","rebase_frac_zero_constants":false,"check_mode":"UNSAFE","commitment":"KZG","decomp_base":16384,"decomp_legs":2,"bounded_log_lookup":false,"ignore_range_check_inputs_outputs":false},"num_rows":54,"total_assignments":109,"total_const_size":4,"total_dynamic_col_size":0,"max_dynamic_input_len":0,"num_dynamic_lookups":0,"num_shuffles":0,"total_shuffle_col_size":0,"model_instance_shapes":[[1,1]],"model_output_scales":[7],"model_input_scales":[7],"module_sizes":{"polycommit":[],"poseidon":[0,[0]]},"required_lookups":[],"required_range_checks":[[-1,1],[0,16383]],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null,"timestamp":1739396322131,"input_types":["F32"],"output_types":["F32"]}

File diff suppressed because one or more lines are too long

Binary file not shown.

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
    """Minimal example model: integer-divides (`//`) its input by 3."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x // 3


circuit = MyModel()

# Random integer input with values in [0, 10) and shape (1, 2, 2, 8).
x = torch.randint(0, 10, (1, 2, 2, 8))
out = circuit(x)

print(x)
print(out)
# True division, printed for comparison with the `//` result above.
print(x/3)

torch.onnx.export(circuit, x, "network.onnx",
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=17,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                'output': {0: 'batch_size'}})

# Flatten the sample input into a plain list so it can be stored as JSON.
d1 = ((x).detach().numpy()).reshape([-1]).tolist()

data = dict(
    input_data=[d1],
)

# Serialize data into file; the context manager guarantees the handle is
# flushed and closed even if json.dump raises.
with open("input.json", 'w') as f:
    json.dump(data, f)

View File

@@ -0,0 +1 @@
{"input_data": [[3, 4, 0, 9, 2, 6, 2, 5, 1, 5, 3, 5, 5, 7, 0, 2, 6, 1, 4, 4, 1, 9, 7, 7, 5, 8, 2, 0, 1, 5, 9, 8]]}

Binary file not shown.

View File

@@ -160,30 +160,6 @@ def compile_circuit(model:str | os.PathLike | pathlib.Path,compiled_circuit:str
"""
...
def create_evm_data_attestation(input_data:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,witness_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
r"""
Creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
Arguments
---------
input_data: str
The path to the .json data file, which should contain the necessary calldata and account addresses needed to read from all the on-chain view functions that return the data that the network ingests as inputs
settings_path: str
The path to the settings file
sol_code_path: str
The path to the create the solidity verifier
abi_path: str
The path to create the ABI for the solidity verifier
Returns
-------
bool
"""
...
def create_evm_verifier(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reusable:bool) -> typing.Any:
r"""
Creates an EVM compatible verifier, you will need solc installed in your environment to run this
@@ -247,7 +223,7 @@ def create_evm_verifier_aggr(aggregation_settings:typing.Sequence[str | os.PathL
"""
...
def create_evm_vka(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
def create_evm_vka(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,vka_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
r"""
Creates an Evm VK artifact. This command generates a VK with circuit specific meta data encoding in memory for use by the reusable H2 verifier.
This is useful for deploying verifiers that were otherwise too big to fit on chain and required aggregation.
@@ -260,8 +236,8 @@ def create_evm_vka(vk_path:str | os.PathLike | pathlib.Path,settings_path:str |
settings_path: str
The path to the settings file
sol_code_path: str
The path to the create the solidity verifying key.
vka_path: str
The path to create the VKA calldata.
abi_path: str
The path to create the ABI for the solidity verifier
@@ -275,12 +251,6 @@ def create_evm_vka(vk_path:str | os.PathLike | pathlib.Path,settings_path:str |
"""
...
def deploy_da_evm(addr_path:str | os.PathLike | pathlib.Path,input_data:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],optimizer_runs:int,private_key:typing.Optional[str]) -> typing.Any:
r"""
deploys the solidity da verifier
"""
...
def deploy_evm(addr_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],contract_type:str,optimizer_runs:int,private_key:typing.Optional[str]) -> typing.Any:
r"""
deploys the solidity verifier
@@ -706,35 +676,6 @@ def setup_aggregate(sample_snarks:typing.Sequence[str | os.PathLike | pathlib.Pa
"""
...
def setup_test_evm_witness(data_path:str | os.PathLike | pathlib.Path,compiled_circuit_path:str | os.PathLike | pathlib.Path,test_data:str | os.PathLike | pathlib.Path,input_source:PyTestDataSource,output_source:PyTestDataSource,rpc_url:typing.Optional[str]) -> typing.Any:
r"""
Setup test evm witness
Arguments
---------
data_path: str
The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
compiled_circuit_path: str
The path to the compiled model file (generated using the compile-circuit command)
test_data: str
For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information derived from the file information in the data .json file. Should include both the network input (possibly private) and the network output (public input to the proof)
input_sources: str
Where the input data comes from
output_source: str
Where the output data comes from
rpc_url: str
RPC URL for an EVM compatible node, if None, uses Anvil as a local RPC node
Returns
-------
bool
"""
...
def swap_proof_commitments(proof_path:str | os.PathLike | pathlib.Path,witness_path:str | os.PathLike | pathlib.Path) -> None:
r"""
@@ -823,7 +764,7 @@ def verify_aggr(proof_path:str | os.PathLike | pathlib.Path,vk_path:str | os.Pat
"""
...
def verify_evm(addr_verifier:str,proof_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],addr_da:typing.Optional[str],addr_vk:typing.Optional[str]) -> typing.Any:
def verify_evm(addr_verifier:str,proof_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],vka_path:typing.Optional[str]) -> typing.Any:
r"""
verifies an evm compatible proof, you will need solc installed in your environment to run this
@@ -838,11 +779,8 @@ def verify_evm(addr_verifier:str,proof_path:str | os.PathLike | pathlib.Path,rpc
rpc_url: str
RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
addr_da: str
does the verifier use data attestation ?
addr_vk: str
The addess of the separate VK contract (if the verifier key is rendered as a separate contract)
vka_path: str
The path to the VKA calldata bytes file (generated using the create_evm_vka command)
Returns
-------
bool

View File

@@ -12,6 +12,7 @@ asyncio_mode = "auto"
[project]
name = "ezkl"
version = "0.0.0"
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Rust",

View File

@@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2024-07-18"
channel = "nightly-2025-02-17"
components = ["rustfmt", "clippy"]

View File

@@ -1,7 +1,11 @@
// ignore file if compiling for wasm
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc::MiMalloc;
#[global_allocator]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
static GLOBAL: MiMalloc = MiMalloc;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::{CommandFactory, Parser};
@@ -24,6 +28,8 @@ use std::env;
#[tokio::main(flavor = "current_thread")]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub async fn main() {
use log::debug;
let args = Cli::parse();
if let Some(generator) = args.generator {
@@ -38,7 +44,7 @@ pub async fn main() {
} else {
info!("Running with CPU");
}
info!(
debug!(
"command: \n {}",
&command.as_json().to_colored_json_auto().unwrap()
);

View File

@@ -4,11 +4,10 @@ use crate::circuit::modules::poseidon::{
PoseidonChip,
};
use crate::circuit::modules::Module;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::circuit::{CheckMode, Tolerance};
use crate::commands::*;
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::graph::modules::POSEIDON_LEN_GRAPH;
use crate::graph::TestDataSource;
use crate::graph::{
quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
@@ -156,9 +155,6 @@ impl pyo3::ToPyObject for PyG1Affine {
#[derive(Clone)]
#[gen_stub_pyclass]
struct PyRunArgs {
#[pyo3(get, set)]
/// float: The tolerance for error on model outputs
pub tolerance: f32,
#[pyo3(get, set)]
/// int: The denominator in the fixed point representation used when quantizing inputs
pub input_scale: crate::Scale,
@@ -207,6 +203,12 @@ struct PyRunArgs {
/// bool: Should the circuit use unbounded lookups for log
#[pyo3(get, set)]
pub bounded_log_lookup: bool,
/// bool: Should the circuit use range checks for inputs and outputs (set to false if the input is a felt)
#[pyo3(get, set)]
pub ignore_range_check_inputs_outputs: bool,
/// float: epsilon used for arguments that use division
#[pyo3(get, set)]
pub epsilon: f64,
}
/// default instantiation of PyRunArgs
@@ -223,7 +225,6 @@ impl From<PyRunArgs> for RunArgs {
fn from(py_run_args: PyRunArgs) -> Self {
RunArgs {
bounded_log_lookup: py_run_args.bounded_log_lookup,
tolerance: Tolerance::from(py_run_args.tolerance),
input_scale: py_run_args.input_scale,
param_scale: py_run_args.param_scale,
num_inner_cols: py_run_args.num_inner_cols,
@@ -239,15 +240,17 @@ impl From<PyRunArgs> for RunArgs {
commitment: Some(py_run_args.commitment.into()),
decomp_base: py_run_args.decomp_base,
decomp_legs: py_run_args.decomp_legs,
ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
epsilon: Some(py_run_args.epsilon),
}
}
}
impl Into<PyRunArgs> for RunArgs {
fn into(self) -> PyRunArgs {
let eps = self.get_epsilon();
PyRunArgs {
bounded_log_lookup: self.bounded_log_lookup,
tolerance: self.tolerance.val,
input_scale: self.input_scale,
param_scale: self.param_scale,
num_inner_cols: self.num_inner_cols,
@@ -263,6 +266,8 @@ impl Into<PyRunArgs> for RunArgs {
commitment: self.commitment.into(),
decomp_base: self.decomp_base,
decomp_legs: self.decomp_legs,
ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
epsilon: eps,
}
}
}
@@ -333,6 +338,8 @@ enum PyInputType {
Int,
///
TDim,
///
Unknown,
}
impl From<InputType> for PyInputType {
@@ -344,6 +351,7 @@ impl From<InputType> for PyInputType {
InputType::F64 => PyInputType::F64,
InputType::Int => PyInputType::Int,
InputType::TDim => PyInputType::TDim,
InputType::Unknown => PyInputType::Unknown,
}
}
}
@@ -357,6 +365,7 @@ impl From<PyInputType> for InputType {
PyInputType::F64 => InputType::F64,
PyInputType::Int => InputType::Int,
PyInputType::TDim => InputType::TDim,
PyInputType::Unknown => InputType::Unknown,
}
}
}
@@ -371,6 +380,7 @@ impl FromStr for PyInputType {
"f64" => Ok(PyInputType::F64),
"int" => Ok(PyInputType::Int),
"tdim" => Ok(PyInputType::TDim),
"unknown" => Ok(PyInputType::Unknown),
_ => Err("Invalid value for InputType".to_string()),
}
}
@@ -573,10 +583,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
.map(crate::pfsys::string_to_field::<Fr>)
.collect::<Vec<_>>();
let output =
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>::run(
message.clone(),
)
let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
.map_err(|_| PyIOError::new_err("Failed to run poseidon"))?;
let hash = output[0]
@@ -591,7 +598,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represnted as strings
/// List of field elements represented as strings
///
/// vk_path: str
/// Path to the verification key
@@ -650,7 +657,7 @@ fn kzg_commit(
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represnted as strings
/// List of field elements represented as strings
///
/// vk_path: str
/// Path to the verification key
@@ -938,6 +945,49 @@ fn gen_settings(
Ok(true)
}
/// Generates random data for the model
///
/// Arguments
/// ---------
/// model: str
/// Path to the onnx file
///
/// output: str
/// Path to create the data file
///
/// variables: list[tuple[str, int]]
/// Named integer variables forwarded to the generator, defaults to [("batch_size", 1)]
/// (presumably values for the model's symbolic dimensions -- confirm against `crate::execute::gen_random_data`)
///
/// seed: int
/// Random seed to use for generated data
///
/// min: float
/// Optional value forwarded to the generator (presumably the lower bound of the generated data -- confirm)
///
/// max: float
/// Optional value forwarded to the generator (presumably the upper bound of the generated data -- confirm)
///
/// Returns
/// -------
/// bool
///
// NOTE(review): `output` defaults to DEFAULT_SETTINGS even though it is documented as the data file path;
// confirm this default is intentional before relying on it.
#[pyfunction(signature = (
    model=PathBuf::from(DEFAULT_MODEL),
    output=PathBuf::from(DEFAULT_SETTINGS),
    variables=Vec::from([("batch_size".to_string(), 1)]),
    seed=DEFAULT_SEED.parse().unwrap(),
    min=None,
    max=None
))]
#[gen_stub_pyfunction]
fn gen_random_data(
    model: PathBuf,
    output: PathBuf,
    variables: Vec<(String, usize)>,
    seed: u64,
    min: Option<f32>,
    max: Option<f32>,
) -> Result<bool, PyErr> {
    // Surface executor failures to Python as a RuntimeError that names this command
    // (the message previously claimed the failure came from settings generation).
    crate::execute::gen_random_data(model, output, variables, seed, min, max).map_err(|e| {
        let err_str = format!("Failed to generate random data: {}", e);
        PyRuntimeError::new_err(err_str)
    })?;
    Ok(true)
}
/// Calibrates the circuit settings
///
/// Arguments
@@ -969,7 +1019,7 @@ fn gen_settings(
/// bool
///
#[pyfunction(signature = (
data = PathBuf::from(DEFAULT_CALIBRATION_FILE),
data = String::from(DEFAULT_CALIBRATION_FILE),
model = PathBuf::from(DEFAULT_MODEL),
settings = PathBuf::from(DEFAULT_SETTINGS),
target = CalibrationTarget::default(), // default is "resources
@@ -981,7 +1031,7 @@ fn gen_settings(
#[gen_stub_pyfunction]
fn calibrate_settings(
py: Python,
data: PathBuf,
data: String,
model: PathBuf,
settings: PathBuf,
target: CalibrationTarget,
@@ -990,25 +1040,22 @@ fn calibrate_settings(
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::calibrate(
model,
data,
settings,
target,
lookup_safety_margin,
scales,
scale_rebase_multiplier,
max_logrows,
)
.await
.map_err(|e| {
let err_str = format!("Failed to calibrate settings: {}", e);
PyRuntimeError::new_err(err_str)
})?;
crate::execute::calibrate(
model,
data,
settings,
target,
lookup_safety_margin,
scales,
scale_rebase_multiplier,
max_logrows,
)
.map_err(|e| {
let err_str = format!("Failed to calibrate settings: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
})
Ok(true)
}
/// Runs the forward pass operation to generate a witness
@@ -1036,7 +1083,7 @@ fn calibrate_settings(
/// Python object containing the witness values
///
#[pyfunction(signature = (
data=PathBuf::from(DEFAULT_DATA),
data=String::from(DEFAULT_DATA),
model=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
output=PathBuf::from(DEFAULT_WITNESS),
vk_path=None,
@@ -1045,21 +1092,18 @@ fn calibrate_settings(
#[gen_stub_pyfunction]
fn gen_witness(
py: Python,
data: PathBuf,
data: String,
model: PathBuf,
output: Option<PathBuf>,
vk_path: Option<PathBuf>,
srs_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let output = crate::execute::gen_witness(model, data, output, vk_path, srs_path)
.await
.map_err(|e| {
let err_str = format!("Failed to generate witness: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Python::with_gil(|py| Ok(output.to_object(py)))
})
let output =
crate::execute::gen_witness(model, data, output, vk_path, srs_path).map_err(|e| {
let err_str = format!("Failed to generate witness: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Python::with_gil(|py| Ok(output.to_object(py)))
}
/// Mocks the prover
@@ -1557,22 +1601,15 @@ fn verify_aggr(
#[pyfunction(signature = (
proof=PathBuf::from(DEFAULT_PROOF),
calldata=PathBuf::from(DEFAULT_CALLDATA),
addr_vk=None,
vka_path=None,
))]
#[gen_stub_pyfunction]
fn encode_evm_calldata<'a>(
proof: PathBuf,
calldata: PathBuf,
addr_vk: Option<&'a str>,
vka_path: Option<PathBuf>,
) -> Result<Vec<u8>, PyErr> {
let addr_vk = if let Some(addr_vk) = addr_vk {
let addr_vk = H160Flag::from(addr_vk);
Some(addr_vk)
} else {
None
};
crate::execute::encode_evm_calldata(proof, calldata, addr_vk).map_err(|e| {
crate::execute::encode_evm_calldata(proof, calldata, vka_path).map_err(|e| {
let err_str = format!("Failed to generate calldata: {}", e);
PyRuntimeError::new_err(err_str)
})
@@ -1652,15 +1689,15 @@ fn create_evm_verifier(
/// settings_path: str
/// The path to the settings file
///
/// sol_code_path: str
/// The path to create the solidity verifying key.
///
/// abi_path: str
/// The path to create the ABI for the solidity verifier
/// vka_path: str
/// The path to the verification artifact calldata bytes file.
///
/// srs_path: str
/// The path to the SRS file
///
/// decimals: int
/// The number of decimals used for the rescaling of fixed point felt instances into on-chain floats.
///
/// Returns
/// -------
/// bool
@@ -1668,21 +1705,21 @@ fn create_evm_verifier(
#[pyfunction(signature = (
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
srs_path=None
vka_path=PathBuf::from(DEFAULT_VKA),
srs_path=None,
decimals=DEFAULT_DECIMALS.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn create_evm_vka(
py: Python,
vk_path: PathBuf,
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
vka_path: PathBuf,
srs_path: Option<PathBuf>,
decimals: usize,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::create_evm_vka(vk_path, srs_path, settings_path, sol_code_path, abi_path)
crate::execute::create_evm_vka(vk_path, srs_path, settings_path, vka_path, decimals)
.await
.map_err(|e| {
let err_str = format!("Failed to run create_evm_verifier: {}", e);
@@ -1693,128 +1730,11 @@ fn create_evm_vka(
})
}
/// Creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// input_data: str
/// The path to the .json data file, which should contain the necessary calldata and account addresses needed to read from all the on-chain view functions that return the data that the network ingests as inputs
///
/// settings_path: str
/// The path to the settings file
///
/// sol_code_path: str
/// The path to create the solidity verifier
///
/// abi_path: str
/// The path to create the ABI for the solidity verifier
///
/// witness_path: str
/// Optional path to a witness file, forwarded unchanged to the underlying command (defaults to None)
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
input_data=PathBuf::from(DEFAULT_DATA),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE_DA),
abi_path=PathBuf::from(DEFAULT_VERIFIER_DA_ABI),
witness_path=None,
))]
#[gen_stub_pyfunction]
fn create_evm_data_attestation(
py: Python,
input_data: PathBuf,
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
witness_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
// The work runs as a future handed back to Python; it resolves to `true` on success.
pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Note the argument order: the executor takes `input_data` after `abi_path`,
// unlike the Python-facing signature where it comes first.
crate::execute::create_evm_data_attestation(
settings_path,
sol_code_path,
abi_path,
input_data,
witness_path,
)
.await
// Any executor error is reported to Python as a RuntimeError.
.map_err(|e| {
let err_str = format!("Failed to run create_evm_data_attestation: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
})
}
/// Setup test evm witness
///
/// Arguments
/// ---------
/// data_path: str
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
///
/// compiled_circuit_path: str
/// The path to the compiled model file (generated using the compile-circuit command)
///
/// test_data: str
/// For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information derived from the file information in the data .json file. Should include both the network input (possibly private) and the network output (public input to the proof)
///
/// input_source: str
/// Where the input data comes from
///
/// output_source: str
/// Where the output data comes from
///
/// rpc_url: str
/// RPC URL for an EVM compatible node, if None, uses Anvil as a local RPC node
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
data_path,
compiled_circuit_path,
test_data,
input_source,
output_source,
rpc_url=None,
))]
#[gen_stub_pyfunction]
fn setup_test_evm_witness(
py: Python,
data_path: PathBuf,
compiled_circuit_path: PathBuf,
test_data: PathBuf,
input_source: PyTestDataSource,
output_source: PyTestDataSource,
rpc_url: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
// The work runs as a future handed back to Python; it resolves to `true` on success.
pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Note the argument order: the executor takes `rpc_url` before the two data sources,
// and the Python-side source enums are converted with `.into()`.
crate::execute::setup_test_evm_witness(
data_path,
compiled_circuit_path,
test_data,
rpc_url,
input_source.into(),
output_source.into(),
)
.await
// Any executor error is reported to Python as a RuntimeError.
.map_err(|e| {
let err_str = format!("Failed to run setup_test_evm_witness: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
})
}
/// deploys the solidity verifier
/// Deploys the solidity verifier
#[pyfunction(signature = (
addr_path,
rpc_url,
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
rpc_url=None,
contract_type=ContractType::default(),
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None,
@@ -1823,8 +1743,8 @@ fn setup_test_evm_witness(
fn deploy_evm(
py: Python,
addr_path: PathBuf,
rpc_url: String,
sol_code_path: PathBuf,
rpc_url: Option<String>,
contract_type: ContractType,
optimizer_runs: usize,
private_key: Option<String>,
@@ -1848,46 +1768,64 @@ fn deploy_evm(
})
}
/// deploys the solidity da verifier
/// Registers a VKA on the EZKL reusable verifier contract
///
/// Arguments
/// ---------
/// addr_verifier: str
/// The reusable verifier contract's address as a hex string
///
/// rpc_url: str
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
///
/// vka_path: str
/// The path to the VKA calldata bytes file (generated using the create_evm_vka command)
///
/// vka_digest_path: str
/// The path to the VKA digest file, aka hash of the VKA calldata bytes file
///
/// private_key: str
/// The private key to use for signing the transaction. If None, will use the default private key
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
addr_path,
input_data,
settings_path=PathBuf::from(DEFAULT_SETTINGS),
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE_DA),
rpc_url=None,
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None
addr_verifier,
rpc_url,
vka_path=PathBuf::from(DEFAULT_VKA),
vka_digest_path=PathBuf::from(DEFAULT_VKA_DIGEST),
private_key=None,
))]
#[gen_stub_pyfunction]
fn deploy_da_evm(
py: Python,
addr_path: PathBuf,
input_data: PathBuf,
settings_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
optimizer_runs: usize,
fn register_vka<'a>(
py: Python<'a>,
addr_verifier: &'a str,
rpc_url: String,
vka_path: PathBuf,
vka_digest_path: PathBuf,
private_key: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
) -> PyResult<Bound<'a, PyAny>> {
let addr_verifier = H160Flag::from(addr_verifier);
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::deploy_da_evm(
input_data,
settings_path,
sol_code_path,
crate::execute::register_vka(
rpc_url,
addr_path,
optimizer_runs,
addr_verifier,
vka_path,
vka_digest_path,
private_key,
)
.await
.map_err(|e| {
let err_str = format!("Failed to run deploy_da_evm: {}", e);
let err_str = format!("Failed to run register_vka: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
})
}
/// verifies an evm compatible proof, you will need solc installed in your environment to run this
///
/// Arguments
@@ -1901,47 +1839,30 @@ fn deploy_da_evm(
/// rpc_url: str
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
///
/// addr_da: str
/// does the verifier use data attestation ?
///
/// addr_vk: str
/// The address of the separate VK contract (if the verifier key is rendered as a separate contract)
/// vka_path: str
/// The path to the VKA calldata bytes file (generated using the create_evm_vka command)
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
addr_verifier,
rpc_url,
proof_path=PathBuf::from(DEFAULT_PROOF),
rpc_url=None,
addr_da = None,
addr_vk = None,
vka_path = None,
))]
#[gen_stub_pyfunction]
fn verify_evm<'a>(
py: Python<'a>,
addr_verifier: &'a str,
rpc_url: String,
proof_path: PathBuf,
rpc_url: Option<String>,
addr_da: Option<&'a str>,
addr_vk: Option<&'a str>,
vka_path: Option<PathBuf>,
) -> PyResult<Bound<'a, PyAny>> {
let addr_verifier = H160Flag::from(addr_verifier);
let addr_da = if let Some(addr_da) = addr_da {
let addr_da = H160Flag::from(addr_da);
Some(addr_da)
} else {
None
};
let addr_vk = if let Some(addr_vk) = addr_vk {
let addr_vk = H160Flag::from(addr_vk);
Some(addr_vk)
} else {
None
};
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::verify_evm(proof_path, addr_verifier, rpc_url, addr_da, addr_vk)
crate::execute::verify_evm(proof_path, addr_verifier, rpc_url, vka_path)
.await
.map_err(|e| {
let err_str = format!("Failed to run verify_evm: {}", e);
@@ -2055,6 +1976,7 @@ fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(get_srs, m)?)?;
m.add_function(wrap_pyfunction!(gen_witness, m)?)?;
m.add_function(wrap_pyfunction!(gen_settings, m)?)?;
m.add_function(wrap_pyfunction!(gen_random_data, m)?)?;
m.add_function(wrap_pyfunction!(calibrate_settings, m)?)?;
m.add_function(wrap_pyfunction!(aggregate, m)?)?;
m.add_function(wrap_pyfunction!(mock_aggregate, m)?)?;
@@ -2064,12 +1986,10 @@ fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_vka, m)?)?;
m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;
m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
m.add_function(wrap_pyfunction!(setup_test_evm_witness, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_data_attestation, m)?)?;
m.add_function(wrap_pyfunction!(encode_evm_calldata, m)?)?;
m.add_function(wrap_pyfunction!(register_vka, m)?)?;
Ok(())
}

View File

@@ -1,6 +1,7 @@
use halo2_proofs::{
plonk::*,
poly::{
VerificationStrategy,
commitment::{CommitmentScheme, ParamsProver},
ipa::{
commitment::{IPACommitmentScheme, ParamsIPA},
@@ -12,7 +13,6 @@ use halo2_proofs::{
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy as KZGSingleStrategy,
},
VerificationStrategy,
},
};
use std::fmt::Display;
@@ -20,15 +20,15 @@ use std::io::BufReader;
use std::str::FromStr;
use crate::{
CheckMode, Commitments, EZKLError as InnerEZKLError,
circuit::region::RegionSettings,
graph::GraphSettings,
pfsys::{
create_proof_circuit,
TranscriptType, create_proof_circuit,
evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript},
verify_proof_circuit, TranscriptType,
verify_proof_circuit,
},
tensor::TensorType,
CheckMode, Commitments, EZKLError as InnerEZKLError,
};
use crate::graph::{GraphCircuit, GraphWitness};
@@ -66,26 +66,24 @@ impl From<InnerEZKLError> for EZKLError {
pub(crate) fn encode_verifier_calldata(
// TODO - should it be pub(crate) or pub or pub(super)?
proof: Vec<u8>,
vk_address: Option<Vec<u8>>,
vka: Option<Vec<u8>>,
) -> Result<Vec<u8>, EZKLError> {
let snark: crate::pfsys::Snark<Fr, G1Affine> =
serde_json::from_slice(&proof[..]).map_err(InnerEZKLError::from)?;
let vk_address: Option<[u8; 20]> = if let Some(vk_address) = vk_address {
let array: [u8; 20] =
serde_json::from_slice(&vk_address[..]).map_err(InnerEZKLError::from)?;
let vka_buf: Option<Vec<[u8; 32]>> = if let Some(vka) = vka {
let array: Vec<[u8; 32]> =
serde_json::from_slice(&vka[..]).map_err(InnerEZKLError::from)?;
Some(array)
} else {
None
};
let vka: Option<&[[u8; 32]]> = vka_buf.as_deref();
let flattened_instances = snark.instances.into_iter().flatten();
let encoded = encode_calldata(
vk_address,
&snark.proof,
&flattened_instances.collect::<Vec<_>>(),
);
let encoded = encode_calldata(vka, &snark.proof, &flattened_instances.collect::<Vec<_>>());
Ok(encoded)
}
@@ -141,10 +139,11 @@ pub(crate) fn gen_vk(
.map_err(|e| EZKLError::InternalError(format!("Failed to create verifying key: {}", e)))?;
let mut serialized_vk = Vec::new();
vk.write(&mut serialized_vk, halo2_proofs::SerdeFormat::RawBytes)
.map_err(|e| {
EZKLError::InternalError(format!("Failed to serialize verifying key: {}", e))
})?;
vk.write(
&mut serialized_vk,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize verifying key: {}", e)))?;
Ok(serialized_vk)
}
@@ -165,7 +164,7 @@ pub(crate) fn gen_pk(
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit.settings().clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
@@ -197,7 +196,7 @@ pub(crate) fn verify(
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit_settings.clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
@@ -277,7 +276,7 @@ pub(crate) fn verify_aggr(
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
@@ -365,7 +364,7 @@ pub(crate) fn prove(
let mut reader = BufReader::new(&pk[..]);
let pk = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit.settings().clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;
@@ -487,7 +486,7 @@ pub(crate) fn vk_validation(vk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKL
let mut reader = BufReader::new(&vk[..]);
let _ = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit_settings,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
@@ -504,7 +503,7 @@ pub(crate) fn pk_validation(pk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKL
let mut reader = BufReader::new(&pk[..]);
let _ = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit_settings,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;

View File

@@ -8,10 +8,7 @@ use crate::{
Module,
},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt},
graph::{
modules::POSEIDON_LEN_GRAPH, quantize_float, scale_to_multiplier, GraphCircuit,
GraphSettings,
},
graph::{quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings},
};
use console_error_panic_hook;
use halo2_proofs::{
@@ -231,10 +228,7 @@ pub fn poseidonHash(
let message: Vec<Fr> = serde_json::from_slice(&message[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?;
let output =
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>::run(
message.clone(),
)
let output = PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>::run(message.clone())
.map_err(|e| JsError::new(&format!("{}", e)))?;
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&output).map_err(

View File

@@ -1,7 +1,7 @@
/*
An easy-to-use implementation of the Poseidon Hash in the form of a Halo2 Chip. While the Poseidon Hash function
is already implemented in halo2_gadgets, there is no wrapper chip that makes it easy to use in other circuits.
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/zk_prover/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
*/
use std::collections::HashMap;

View File

@@ -1,20 +1,18 @@
/*
An easy-to-use implementation of the Poseidon Hash in the form of a Halo2 Chip. While the Poseidon Hash function
is already implemented in halo2_gadgets, there is no wrapper chip that makes it easy to use in other circuits.
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/zk_prover/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
*/
pub mod poseidon_params;
pub mod spec;
// This chip adds a set of advice columns to the gadget Chip to store the inputs of the hash
use halo2_gadgets::poseidon::{primitives::*, Hash, Pow5Chip, Pow5Config};
use halo2_proofs::arithmetic::Field;
use halo2_gadgets::poseidon::{
primitives::VariableLength, primitives::*, Hash, Pow5Chip, Pow5Config,
};
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use halo2_proofs::{circuit::*, plonk::*};
// use maybe_rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator};
use maybe_rayon::prelude::ParallelIterator;
use maybe_rayon::slice::ParallelSlice;
use std::marker::PhantomData;
@@ -40,22 +38,17 @@ pub struct PoseidonConfig<const WIDTH: usize, const RATE: usize> {
pub pow5_config: Pow5Config<Fp, WIDTH, RATE>,
}
type InputAssignments = (Vec<AssignedCell<Fp, Fp>>, AssignedCell<Fp, Fp>);
type InputAssignments = Vec<AssignedCell<Fp, Fp>>;
/// PoseidonChip is a wrapper around the Pow5Chip that adds a set of advice columns to the gadget Chip to store the inputs of the hash
#[derive(Debug, Clone)]
pub struct PoseidonChip<
S: Spec<Fp, WIDTH, RATE> + Sync,
const WIDTH: usize,
const RATE: usize,
const L: usize,
> {
pub struct PoseidonChip<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> {
config: PoseidonConfig<WIDTH, RATE>,
_marker: PhantomData<S>,
}
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, const L: usize>
PoseidonChip<S, WIDTH, RATE, L>
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize>
PoseidonChip<S, WIDTH, RATE>
{
/// Creates a new PoseidonChip
pub fn configure_with_cols(
@@ -82,8 +75,8 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
}
}
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, const L: usize>
PoseidonChip<S, WIDTH, RATE, L>
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize>
PoseidonChip<S, WIDTH, RATE>
{
/// Configuration of the PoseidonChip
pub fn configure_with_optional_instance(
@@ -100,9 +93,6 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
for input in hash_inputs.iter().take(WIDTH) {
meta.enable_equality(*input);
}
meta.enable_constant(rc_b[0]);
Self::configure_with_cols(
@@ -116,8 +106,8 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
}
}
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, const L: usize>
Module<Fp> for PoseidonChip<S, WIDTH, RATE, L>
impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize> Module<Fp>
for PoseidonChip<S, WIDTH, RATE>
{
type Config = PoseidonConfig<WIDTH, RATE>;
type InputAssignments = InputAssignments;
@@ -152,9 +142,6 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::<Vec<_>>();
for input in hash_inputs.iter().take(WIDTH) {
meta.enable_equality(*input);
}
meta.enable_constant(rc_b[0]);
let instance = meta.instance_column();
@@ -176,7 +163,10 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
message: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, ModuleError> {
assert_eq!(message.len(), 1);
if message.len() != 1 {
return Err(ModuleError::InputWrongLength(message.len()));
}
let message = message[0].clone();
let start_time = instant::Instant::now();
@@ -186,95 +176,81 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
let res = layouter.assign_region(
|| "load message",
|mut region| {
let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> =
match &message {
ValTensor::Value { inner: v, .. } => {
v.iter()
.enumerate()
.map(|(i, value)| {
let x = i % WIDTH;
let y = i / WIDTH;
let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, _> = match &message {
ValTensor::Value { inner: v, .. } => v
.iter()
.enumerate()
.map(|(i, value)| {
let x = i % WIDTH;
let y = i / WIDTH;
match value {
ValType::Value(v) => region
.assign_advice(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
|| *v,
)
.map_err(|e| e.into()),
ValType::PrevAssigned(v)
| ValType::AssignedConstant(v, ..) => Ok(v.clone()),
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants
.get(f)
.unwrap()
.assigned_cell()
.ok_or(ModuleError::ConstantNotAssigned)?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;
constants.insert(
*f,
ValType::AssignedConstant(res.clone(), *f),
);
Ok(res)
}
}
e => Err(ModuleError::WrongInputType(
format!("{:?}", e),
"PrevAssigned".to_string(),
)),
}
})
.collect()
}
ValTensor::Instance {
dims,
inner: col,
idx,
initial_offset,
..
} => {
// this should never ever fail
let num_elems = dims[*idx].iter().product::<usize>();
(0..num_elems)
.map(|i| {
let x = i % WIDTH;
let y = i / WIDTH;
region.assign_advice_from_instance(
|| "pub input anchor",
*col,
initial_offset + i,
match value {
ValType::Value(v) => region
.assign_advice(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
|| *v,
)
})
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.into())
}
};
.map_err(|e| e.into()),
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
Ok(v.clone())
}
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants
.get(f)
.unwrap()
.assigned_cell()
.ok_or(ModuleError::ConstantNotAssigned)?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;
let offset = message.len() / WIDTH + 1;
constants
.insert(*f, ValType::AssignedConstant(res.clone(), *f));
let zero_val = region
.assign_advice_from_constant(
|| "",
self.config.hash_inputs[0],
offset,
Fp::ZERO,
)
.unwrap();
Ok(res)
}
}
e => Err(ModuleError::WrongInputType(
format!("{:?}", e),
"AssignedValue".to_string(),
)),
}
})
.collect(),
ValTensor::Instance {
dims,
inner: col,
idx,
initial_offset,
..
} => {
// this should never ever fail
let num_elems = dims[*idx].iter().product::<usize>();
(0..num_elems)
.map(|i| {
let x = i % WIDTH;
let y = i / WIDTH;
region.assign_advice_from_instance(
|| "pub input anchor",
*col,
initial_offset + i,
self.config.hash_inputs[x],
y,
)
})
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.into())
}
};
Ok((assigned_message?, zero_val))
Ok(assigned_message?)
},
);
log::trace!(
@@ -295,7 +271,13 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
row_offset: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, ModuleError> {
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
let input_cells = self.layout_inputs(layouter, input, constants)?;
// empty hash case
if input_cells.is_empty() {
return Ok(input[0].clone());
}
// extract the values from the input cells
let mut assigned_input: Tensor<ValType<Fp>> =
input_cells.iter().map(|e| ValType::from(e.clone())).into();
@@ -303,52 +285,25 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
let start_time = instant::Instant::now();
let mut one_iter = false;
// do the Tree dance baby
while input_cells.len() > 1 || !one_iter {
let hashes: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> = input_cells
.chunks(L)
.enumerate()
.map(|(i, block)| {
let _start_time = instant::Instant::now();
let pow5_chip = Pow5Chip::construct(self.config.pow5_config.clone());
// initialize the hasher
let hasher = Hash::<_, _, S, VariableLength, WIDTH, RATE>::init(
pow5_chip,
layouter.namespace(|| "block_hasher"),
)?;
let mut block = block.to_vec();
let remainder = block.len() % L;
if remainder != 0 {
block.extend(vec![zero_val.clone(); L - remainder]);
}
let pow5_chip = Pow5Chip::construct(self.config.pow5_config.clone());
// initialize the hasher
let hasher = Hash::<_, _, S, ConstantLength<L>, WIDTH, RATE>::init(
pow5_chip,
layouter.namespace(|| "block_hasher"),
)?;
let hash = hasher.hash(
layouter.namespace(|| "hash"),
block.to_vec().try_into().map_err(|_| Error::Synthesis)?,
);
if i == 0 {
log::trace!("block (L={:?}) took: {:?}", L, _start_time.elapsed());
}
hash
})
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.into());
log::trace!("hashes (N={:?}) took: {:?}", len, start_time.elapsed());
one_iter = true;
input_cells = hashes?;
}
let hash: AssignedCell<Fp, Fp> = hasher.hash(
layouter.namespace(|| "hash"),
input_cells
.to_vec()
.try_into()
.map_err(|_| Error::Synthesis)?,
)?;
let duration = start_time.elapsed();
log::trace!("layout (N={:?}) took: {:?}", len, duration);
let result = Tensor::from(input_cells.iter().map(|e| ValType::from(e.clone())));
let result = Tensor::from(vec![ValType::from(hash.clone())].into_iter());
let output = match result[0].clone() {
ValType::PrevAssigned(v) => v,
@@ -387,69 +342,59 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
///
fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, ModuleError> {
let mut hash_inputs = message;
let len = hash_inputs.len();
let len = message.len();
if len == 0 {
return Ok(vec![vec![]]);
}
let start_time = instant::Instant::now();
let mut one_iter = false;
// do the Tree dance baby
while hash_inputs.len() > 1 || !one_iter {
let hashes: Vec<Fp> = hash_inputs
.par_chunks(L)
.map(|block| {
let mut block = block.to_vec();
let remainder = block.len() % L;
if remainder != 0 {
block.extend(vec![Fp::ZERO; L - remainder].iter());
}
let block_len = block.len();
let message = block
.try_into()
.map_err(|_| ModuleError::InputWrongLength(block_len))?;
Ok(halo2_gadgets::poseidon::primitives::Hash::<
_,
S,
ConstantLength<L>,
{ WIDTH },
{ RATE },
>::init()
.hash(message))
})
.collect::<Result<Vec<_>, ModuleError>>()?;
one_iter = true;
hash_inputs = hashes;
}
let hash = halo2_gadgets::poseidon::primitives::Hash::<
_,
S,
VariableLength,
{ WIDTH },
{ RATE },
>::init()
.hash(message);
let duration = start_time.elapsed();
log::trace!("run (N={:?}) took: {:?}", len, duration);
Ok(vec![hash_inputs])
Ok(vec![vec![hash]])
}
fn num_rows(mut input_len: usize) -> usize {
fn num_rows(input_len: usize) -> usize {
// this was determined by running the circuit and looking at the number of constraints
// in the test called hash_for_a_range_of_input_sizes, then regressing in python to find the slope
let fixed_cost: usize = 41 * L;
// import numpy as np
// from scipy import stats
let mut num_rows = 0;
// x = np.array([32, 64, 96, 128, 160, 192])
// y = np.array([1298, 2594, 3890, 5186, 6482, 7778])
loop {
// the number of times the input_len is divisible by L
let num_chunks = input_len / L + 1;
num_rows += num_chunks * fixed_cost;
if num_chunks == 1 {
break;
}
input_len = num_chunks;
}
// slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
num_rows
// print(f"slope: {slope}")
// print(f"intercept: {intercept}")
// print(f"R^2: {r_value**2}")
// # Predict for any x
// def predict(x):
// return slope * x + intercept
// # Test prediction
// test_x = 256
// print(f"Predicted value for x={test_x}: {predict(test_x)}")
// our output:
// slope: 40.5
// intercept: 2.0
// R^2: 1.0
// Predicted value for x=256: 10370.0
let fixed_cost: usize = 41 * input_len;
// the cost of the hash function is linear with the number of inputs
fixed_cost + 2
}
}
@@ -476,12 +421,12 @@ mod tests {
const RATE: usize = POSEIDON_RATE;
const R: usize = 240;
struct HashCircuit<S: Spec<Fp, WIDTH, RATE>, const L: usize> {
struct HashCircuit<S: Spec<Fp, WIDTH, RATE>> {
message: ValTensor<Fp>,
_spec: PhantomData<S>,
}
impl<S: Spec<Fp, WIDTH, RATE>, const L: usize> Circuit<Fp> for HashCircuit<S, L> {
impl<S: Spec<Fp, WIDTH, RATE>> Circuit<Fp> for HashCircuit<S> {
type Config = PoseidonConfig<WIDTH, RATE>;
type FloorPlanner = ModulePlanner;
type Params = ();
@@ -497,7 +442,7 @@ mod tests {
}
fn configure(meta: &mut ConstraintSystem<Fp>) -> PoseidonConfig<WIDTH, RATE> {
PoseidonChip::<PoseidonSpec, WIDTH, RATE, L>::configure(meta, ())
PoseidonChip::<PoseidonSpec, WIDTH, RATE>::configure(meta, ())
}
fn synthesize(
@@ -505,7 +450,7 @@ mod tests {
config: PoseidonConfig<WIDTH, RATE>,
mut layouter: impl Layouter<Fp>,
) -> Result<(), Error> {
let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, L> = PoseidonChip::new(config);
let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE> = PoseidonChip::new(config);
chip.layout(
&mut layouter,
&[self.message.clone()],
@@ -517,18 +462,33 @@ mod tests {
}
}
/// Hashing an empty message must succeed both off-circuit (`run`) and in-circuit
/// (mock-proved `HashCircuit`), and the two must agree on the public instances.
#[test]
fn poseidon_hash_empty() {
    let message = [];
    // `run` returns early on an empty message with a single empty output column;
    // check that here instead of silently discarding the result.
    let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.to_vec()).unwrap();
    assert!(
        output.len() == 1 && output[0].is_empty(),
        "hashing an empty message should produce one empty output column"
    );

    let message: Tensor<ValType<Fp>> =
        message.into_iter().map(|m| Value::known(m).into()).into();

    let k = 9;
    let circuit = HashCircuit::<PoseidonSpec> {
        message: message.into(),
        _spec: PhantomData,
    };
    // The off-circuit result (one empty column, asserted above) is used as the
    // public instances so the circuit is checked against what `run` produced.
    let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
    assert_eq!(prover.verify(), Ok(()))
}
#[test]
fn poseidon_hash() {
let rng = rand::rngs::OsRng;
let message = [Fp::random(rng), Fp::random(rng)];
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 2>::run(message.to_vec()).unwrap();
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.to_vec()).unwrap();
let mut message: Tensor<ValType<Fp>> =
message.into_iter().map(|m| Value::known(m).into()).into();
let k = 9;
let circuit = HashCircuit::<PoseidonSpec, 2> {
let circuit = HashCircuit::<PoseidonSpec> {
message: message.into(),
_spec: PhantomData,
};
@@ -541,13 +501,13 @@ mod tests {
let rng = rand::rngs::OsRng;
let message = [Fp::random(rng), Fp::random(rng), Fp::random(rng)];
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 3>::run(message.to_vec()).unwrap();
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.to_vec()).unwrap();
let mut message: Tensor<ValType<Fp>> =
message.into_iter().map(|m| Value::known(m).into()).into();
let k = 9;
let circuit = HashCircuit::<PoseidonSpec, 3> {
let circuit = HashCircuit::<PoseidonSpec> {
message: message.into(),
_spec: PhantomData,
};
@@ -563,23 +523,21 @@ mod tests {
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
env_logger::init();
{
let i = 32;
for i in (32..128).step_by(32) {
// print a bunch of new lines
println!(
log::info!(
"i is {} -------------------------------------------------",
i
);
let message: Vec<Fp> = (0..i).map(|_| Fp::random(rng)).collect::<Vec<_>>();
let output =
PoseidonChip::<PoseidonSpec, WIDTH, RATE, 32>::run(message.clone()).unwrap();
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.clone()).unwrap();
let mut message: Tensor<ValType<Fp>> =
message.into_iter().map(|m| Value::known(m).into()).into();
let k = 17;
let circuit = HashCircuit::<PoseidonSpec, 32> {
let circuit = HashCircuit::<PoseidonSpec> {
message: message.into(),
_spec: PhantomData,
};
@@ -596,13 +554,13 @@ mod tests {
let mut message: Vec<Fp> = (0..2048).map(|_| Fp::random(rng)).collect::<Vec<_>>();
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE, 25>::run(message.clone()).unwrap();
let output = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(message.clone()).unwrap();
let mut message: Tensor<ValType<Fp>> =
message.into_iter().map(|m| Value::known(m).into()).into();
let k = 17;
let circuit = HashCircuit::<PoseidonSpec, 25> {
let circuit = HashCircuit::<PoseidonSpec> {
message: message.into(),
_spec: PhantomData,
};

View File

@@ -17,12 +17,14 @@ pub enum BaseOp {
Sub,
SumInit,
Sum,
IsBoolean,
}
/// Matches a [BaseOp] to an operation over inputs
impl BaseOp {
/// forward func
/// forward func for non-accumulating operations
/// # Panics
/// Panics if called on an accumulating operation
/// # Examples
pub fn nonaccum_f<
T: TensorType + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Neg<Output = T>,
>(
@@ -34,12 +36,13 @@ impl BaseOp {
BaseOp::Add => a + b,
BaseOp::Sub => a - b,
BaseOp::Mult => a * b,
BaseOp::IsBoolean => b,
_ => panic!("nonaccum_f called on accumulating operation"),
}
}
/// forward func
/// forward func for accumulating operations
/// # Panics
/// Panics if called on a non-accumulating operation
pub fn accum_f<
T: TensorType + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Neg<Output = T>,
>(
@@ -74,7 +77,6 @@ impl BaseOp {
BaseOp::Mult => "MULT",
BaseOp::Sum => "SUM",
BaseOp::SumInit => "SUMINIT",
BaseOp::IsBoolean => "ISBOOLEAN",
}
}
@@ -90,7 +92,6 @@ impl BaseOp {
BaseOp::Mult => (0, 1),
BaseOp::Sum => (-1, 2),
BaseOp::SumInit => (0, 1),
BaseOp::IsBoolean => (0, 1),
}
}
@@ -106,7 +107,6 @@ impl BaseOp {
BaseOp::Mult => 2,
BaseOp::Sum => 1,
BaseOp::SumInit => 1,
BaseOp::IsBoolean => 0,
}
}
@@ -122,7 +122,6 @@ impl BaseOp {
BaseOp::SumInit => 0,
BaseOp::CumProd => 1,
BaseOp::CumProdInit => 0,
BaseOp::IsBoolean => 0,
}
}
}

View File

@@ -2,7 +2,7 @@ use std::str::FromStr;
use halo2_proofs::{
circuit::Layouter,
plonk::{ConstraintSystem, Constraints, Expression, Selector},
plonk::{ConstraintSystem, Constraints, Expression, Selector, TableColumn},
poly::Rotation,
};
use log::debug;
@@ -20,7 +20,6 @@ use crate::{
circuit::{
ops::base::BaseOp,
table::{Range, RangeCheck, Table},
utils,
},
tensor::{Tensor, TensorType, ValTensor, VarTensor},
};
@@ -75,51 +74,12 @@ impl FromStr for CheckMode {
}
}
#[allow(missing_docs)]
/// An enum representing the tolerance we can accept for the accumulated arguments, either absolute or percentage
#[derive(Clone, Default, Debug, PartialEq, PartialOrd, Serialize, Deserialize, Copy)]
pub struct Tolerance {
pub val: f32,
pub scale: utils::F32,
}
impl std::fmt::Display for Tolerance {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:.2}", self.val)
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Tolerance {
/// Convert the struct to a subcommand string
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl FromStr for Tolerance {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if let Ok(val) = s.parse::<f32>() {
Ok(Tolerance {
val,
scale: utils::F32(1.0),
})
} else {
Err(
"Invalid tolerance value provided. It should expressed as a percentage (f32)."
.to_string(),
)
}
}
}
impl From<f32> for Tolerance {
fn from(value: f32) -> Self {
Tolerance {
val: value,
scale: utils::F32(1.0),
impl CheckMode {
/// Returns the value of the check mode
pub fn is_safe(&self) -> bool {
match self {
CheckMode::SAFE => true,
CheckMode::UNSAFE => false,
}
}
}
@@ -148,29 +108,6 @@ impl<'source> FromPyObject<'source> for CheckMode {
}
}
#[cfg(feature = "python-bindings")]
/// Converts Tolerance into a PyObject (Required for Tolerance to be compatible with Python)
impl IntoPy<PyObject> for Tolerance {
fn into_py(self, py: Python) -> PyObject {
(self.val, self.scale.0).to_object(py)
}
}
#[cfg(feature = "python-bindings")]
/// Obtains Tolerance from PyObject (Required for Tolerance to be compatible with Python)
impl<'source> FromPyObject<'source> for Tolerance {
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
if let Ok((val, scale)) = <(f32, f32)>::extract_bound(ob) {
Ok(Tolerance {
val,
scale: utils::F32(scale),
})
} else {
Err(PyValueError::new_err("Invalid tolerance value provided. "))
}
}
}
/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct DynamicLookups {
@@ -205,15 +142,16 @@ impl DynamicLookups {
/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct Shuffles {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub input_selectors: BTreeMap<(usize, (usize, usize)), Selector>,
/// Selectors for the dynamic lookup tables
pub reference_selectors: Vec<Selector>,
pub output_selectors: Vec<Selector>,
/// Inputs:
pub inputs: Vec<VarTensor>,
/// tables
pub references: Vec<VarTensor>,
pub outputs: Vec<VarTensor>,
}
impl Shuffles {
@@ -224,9 +162,13 @@ impl Shuffles {
Self {
input_selectors: BTreeMap::new(),
reference_selectors: vec![],
inputs: vec![dummy_var.clone(), dummy_var.clone()],
references: vec![single_col_dummy_var.clone(), single_col_dummy_var.clone()],
output_selectors: vec![],
inputs: vec![dummy_var.clone(), dummy_var.clone(), dummy_var.clone()],
outputs: vec![
single_col_dummy_var.clone(),
single_col_dummy_var.clone(),
single_col_dummy_var.clone(),
],
}
}
}
@@ -326,6 +268,8 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
/// Activate sanity checks
pub check_mode: CheckMode,
_marker: PhantomData<F>,
/// shared table inputs
pub shared_table_inputs: Vec<TableColumn>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
@@ -338,6 +282,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
shuffles: Shuffles::dummy(col_size, num_inner_cols),
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
check_mode: CheckMode::SAFE,
shared_table_inputs: vec![],
_marker: PhantomData,
}
}
@@ -364,13 +309,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
if inputs[0].num_cols() != output.num_cols() {
log::warn!("input and output shapes do not match");
}
if inputs[0].num_inner_cols() != inputs[1].num_inner_cols() {
log::warn!("input number of inner columns do not match");
}
if inputs[0].num_inner_cols() != output.num_inner_cols() {
log::warn!("input and output number of inner columns do not match");
}
for i in 0..output.num_blocks() {
for j in 0..output.num_inner_cols() {
nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector());
}
}
@@ -404,24 +354,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
// Get output expressions for each input channel
let (rotation_offset, rng) = base_op.query_offset_rng();
let constraints = match base_op {
BaseOp::IsBoolean => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("non accum: output query failed");
let constraints = {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
.expect("non accum: output query failed");
let output = expected_output[base_op.constraint_idx()].clone();
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
}
_ => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
.expect("non accum: output query failed");
let res = base_op.nonaccum_f((qis[0].clone(), qis[1].clone()));
vec![expected_output[base_op.constraint_idx()].clone() - res]
}
let res = base_op.nonaccum_f((qis[0].clone(), qis[1].clone()));
vec![expected_output[base_op.constraint_idx()].clone() - res]
};
Constraints::with_selector(selector, constraints)
@@ -476,6 +415,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
dynamic_lookups: DynamicLookups::default(),
shuffles: Shuffles::default(),
range_checks: RangeChecks::default(),
shared_table_inputs: vec![],
check_mode,
_marker: PhantomData,
}
@@ -506,21 +446,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
return Err(CircuitError::WrongColumnType(output.name().to_string()));
}
// we borrow mutably twice so we need to do this dance
let table = if !self.static_lookups.tables.contains_key(nl) {
// as all tables have the same input we see if there's another table who's input we can reuse
let table = if let Some(table) = self.static_lookups.tables.values().next() {
Table::<F>::configure(
cs,
lookup_range,
logrows,
nl,
Some(table.table_inputs.clone()),
)
} else {
Table::<F>::configure(cs, lookup_range, logrows, nl, None)
};
let table =
Table::<F>::configure(cs, lookup_range, logrows, nl, &mut self.shared_table_inputs);
self.static_lookups.tables.insert(nl.clone(), table.clone());
table
} else {
@@ -571,9 +499,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
// this is 0 if the index is the same as the column index (starting from 1)
let col_expr = sel.clone()
* table
* (table
.selector_constructor
.get_expr_at_idx(col_idx, synthetic_sel);
.get_expr_at_idx(col_idx, synthetic_sel));
let multiplier =
table.selector_constructor.get_selector_val_at_idx(col_idx);
@@ -605,6 +533,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
res
});
}
// add a degree-k custom constraint of the following form to the range check and
// static lookup configuration:
//   multisel · ∏ (sel − i) = 0, with i ranging over 0..len,
// where sel is the synthetic_sel (so the product is over the set of overflowed columns)
// and multisel is the queried multi_col_selector gating the constraint.
// For len == 1 the constraint is the constant 0, i.e. trivially satisfied.
cs.create_gate("range_check_on_sel", |cs| {
let synthetic_sel = match len {
1 => Expression::Constant(F::from(1)),
_ => match index {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
},
};
let range_check_on_synthetic_sel = match len {
1 => Expression::Constant(F::from(0)),
_ => {
let mut initial_expr = Expression::Constant(F::from(1));
for i in 0..len {
initial_expr = initial_expr
* (synthetic_sel.clone()
- Expression::Constant(F::from(i as u64)))
}
initial_expr
}
};
let sel = cs.query_selector(multi_col_selector);
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
});
self.static_lookups
.selectors
.insert((nl.clone(), x, y), multi_col_selector);
@@ -730,8 +692,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
pub fn configure_shuffles(
&mut self,
cs: &mut ConstraintSystem<F>,
inputs: &[VarTensor; 2],
references: &[VarTensor; 2],
inputs: &[VarTensor; 3],
outputs: &[VarTensor; 3],
) -> Result<(), CircuitError>
where
F: Field,
@@ -742,14 +704,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
}
}
for t in references.iter() {
for t in outputs.iter() {
if !t.is_advice() || t.num_inner_cols() > 1 {
return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
}
}
// assert all tables have the same number of blocks
if references
if outputs
.iter()
.map(|t| t.num_blocks())
.collect::<Vec<_>>()
@@ -757,23 +719,23 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
.any(|w| w[0] != w[1])
{
return Err(CircuitError::WrongDynamicColumnType(
"references inner cols".to_string(),
"outputs inner cols".to_string(),
));
}
let one = Expression::Constant(F::ONE);
for q in 0..references[0].num_blocks() {
let s_reference = cs.complex_selector();
for q in 0..outputs[0].num_blocks() {
let s_output = cs.complex_selector();
for x in 0..inputs[0].num_blocks() {
for y in 0..inputs[0].num_inner_cols() {
let s_input = cs.complex_selector();
cs.lookup_any("lookup", |cs| {
cs.lookup_any("shuffle", |cs| {
let s_inputq = cs.query_selector(s_input);
let mut expression = vec![];
let s_referenceq = cs.query_selector(s_reference);
let s_outputq = cs.query_selector(s_output);
let mut input_queries = vec![one.clone()];
for input in inputs {
@@ -785,9 +747,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
});
}
let mut ref_queries = vec![one.clone()];
for reference in references {
ref_queries.push(match reference {
let mut output_queries = vec![one.clone()];
for output in outputs {
output_queries.push(match output {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[q][0], Rotation(0))
}
@@ -796,7 +758,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
}
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
let rhs = output_queries.into_iter().map(|c| c * s_outputq.clone());
expression.extend(lhs.zip(rhs));
expression
@@ -807,13 +769,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
.or_insert(s_input);
}
}
self.shuffles.reference_selectors.push(s_reference);
self.shuffles.output_selectors.push(s_output);
}
// if we haven't previously initialized the input/output, do so now
if self.shuffles.references.is_empty() {
debug!("assigning shuffles reference");
self.shuffles.references = references.to_vec();
if self.shuffles.outputs.is_empty() {
debug!("assigning shuffles output");
self.shuffles.outputs = outputs.to_vec();
}
if self.shuffles.inputs.is_empty() {
debug!("assigning shuffles input");
@@ -845,7 +807,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
let range_check = if let std::collections::btree_map::Entry::Vacant(e) =
self.range_checks.ranges.entry(range)
{
// as all tables have the same input we see if there's another table who's input we can reuse
let range_check = RangeCheck::<F>::configure(cs, range, logrows);
e.insert(range_check.clone());
range_check
@@ -883,9 +844,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
let default_x = range_check.get_first_element(col_idx);
let col_expr = sel.clone()
* range_check
* (range_check
.selector_constructor
.get_expr_at_idx(col_idx, synthetic_sel);
.get_expr_at_idx(col_idx, synthetic_sel));
let multiplier = range_check
.selector_constructor
@@ -908,6 +869,40 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
res
});
}
// add a degree-k custom constraint of the following form to the range check and
// static lookup configuration:
//   multisel · ∏ (sel − i) = 0, with i ranging over 0..len,
// where sel is the synthetic_sel (so the product is over the set of overflowed columns)
// and multisel is the queried multi_col_selector gating the constraint.
// For len == 1 the constraint is the constant 0, i.e. trivially satisfied.
cs.create_gate("range_check_on_sel", |cs| {
let synthetic_sel = match len {
1 => Expression::Constant(F::from(1)),
_ => match index {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
},
};
let range_check_on_synthetic_sel = match len {
1 => Expression::Constant(F::from(0)),
_ => {
let mut initial_expr = Expression::Constant(F::from(1));
for i in 0..len {
initial_expr = initial_expr
* (synthetic_sel.clone()
- Expression::Constant(F::from(i as u64)))
}
initial_expr
}
};
let sel = cs.query_selector(multi_col_selector);
Constraints::with_selector(sel, vec![range_check_on_synthetic_sel])
});
self.range_checks
.selectors
.insert((range, x, y), multi_col_selector);

View File

@@ -25,7 +25,7 @@ pub enum CircuitError {
/// This operation is unsupported
#[error("unsupported operation in graph")]
UnsupportedOp,
///
/// Invalid einsum expression
#[error("invalid einsum expression")]
InvalidEinsum,
/// Flush error
@@ -100,4 +100,13 @@ pub enum CircuitError {
#[error("invalid input type {0}")]
/// Invalid input type
InvalidInputType(String),
#[error("an element is missing from the shuffled version of the tensor")]
/// An element is missing from the shuffled version of the tensor
MissingShuffleElement,
/// Visibility has not been set
#[error("visibility has not been set")]
UnsetVisibility,
/// A decomposition base overflowed
#[error("decomposition base overflowed")]
DecompositionBaseOverflow,
}

View File

@@ -1,9 +1,9 @@
use super::*;
use crate::{
circuit::{layouts, utils, Tolerance},
fieldutils::integer_rep_to_felt,
circuit::{layouts, utils},
fieldutils::{IntegerRep, integer_rep_to_felt},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorType, ValTensor},
tensor::{self, DataFormat, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
use serde::{Deserialize, Serialize};
@@ -15,10 +15,12 @@ use serde::{Deserialize, Serialize};
pub enum HybridOp {
Ln {
scale: utils::F32,
eps: f64,
},
Rsqrt {
input_scale: utils::F32,
output_scale: utils::F32,
eps: f64,
},
Sqrt {
scale: utils::F32,
@@ -42,6 +44,7 @@ pub enum HybridOp {
Recip {
input_scale: utils::F32,
output_scale: utils::F32,
eps: f64,
},
Div {
denom: utils::F32,
@@ -57,11 +60,13 @@ pub enum HybridOp {
stride: Vec<usize>,
kernel_shape: Vec<usize>,
normalized: bool,
data_format: DataFormat,
},
MaxPool {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
pool_dims: Vec<usize>,
data_format: DataFormat,
},
ReduceMin {
axes: Vec<usize>,
@@ -75,8 +80,11 @@ pub enum HybridOp {
input_scale: utils::F32,
output_scale: utils::F32,
axes: Vec<usize>,
eps: f64,
},
Output {
decomp: bool,
},
RangeCheck(Tolerance),
Greater,
GreaterEqual,
Less,
@@ -124,12 +132,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
HybridOp::Rsqrt {
input_scale,
output_scale,
eps,
} => format!(
"RSQRT (input_scale={}, output_scale={})",
input_scale, output_scale
"RSQRT (input_scale={}, output_scale={}, eps={})",
input_scale, output_scale, eps
),
HybridOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
HybridOp::Ln { scale, eps } => format!("LN(scale={}, eps={})", scale, eps),
HybridOp::RoundHalfToEven { scale, legs } => {
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
}
@@ -142,9 +151,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
HybridOp::Recip {
input_scale,
output_scale,
eps,
} => format!(
"RECIP (input_scale={}, output_scale={})",
input_scale, output_scale
"RECIP (input_scale={}, output_scale={}, eps={})",
input_scale, output_scale, eps
),
HybridOp::Div { denom } => format!("DIV (denom={})", denom),
HybridOp::SumPool {
@@ -152,9 +162,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
stride,
kernel_shape,
normalized,
data_format,
} => format!(
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={})",
padding, stride, kernel_shape, normalized
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={}, data_format={:?})",
padding, stride, kernel_shape, normalized, data_format
),
HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
@@ -162,9 +173,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
pool_dims,
data_format,
} => format!(
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
padding, stride, pool_dims
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?}, data_format={:?})",
padding, stride, pool_dims, data_format
),
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim),
@@ -172,13 +184,16 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
input_scale,
output_scale,
axes,
eps,
} => {
format!(
"SOFTMAX (input_scale={}, output_scale={}, axes={:?})",
input_scale, output_scale, axes
"SOFTMAX (input_scale={}, output_scale={}, axes={:?}, eps={})",
input_scale, output_scale, axes, eps
)
}
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
HybridOp::Output { decomp } => {
format!("OUTPUT (decomp={})", decomp)
}
HybridOp::Greater => "GREATER".to_string(),
HybridOp::GreaterEqual => "GREATEREQUAL".to_string(),
HybridOp::Less => "LESS".to_string(),
@@ -204,17 +219,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
HybridOp::Rsqrt {
input_scale,
output_scale,
eps,
} => layouts::rsqrt(
config,
region,
values[..].try_into()?,
*input_scale,
*output_scale,
*eps,
)?,
HybridOp::Sqrt { scale } => {
layouts::sqrt(config, region, values[..].try_into()?, *scale)?
}
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
HybridOp::Ln { scale, eps } => {
layouts::ln(config, region, values[..].try_into()?, *scale, *eps)?
}
HybridOp::RoundHalfToEven { scale, legs } => {
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
}
@@ -234,6 +253,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
stride,
kernel_shape,
normalized,
data_format,
} => layouts::sumpool(
config,
region,
@@ -242,16 +262,19 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
stride,
kernel_shape,
*normalized,
*data_format,
)?,
HybridOp::Recip {
input_scale,
output_scale,
eps,
} => layouts::recip(
config,
region,
values[..].try_into()?,
integer_rep_to_felt(input_scale.0 as i128),
integer_rep_to_felt(output_scale.0 as i128),
integer_rep_to_felt(input_scale.0 as IntegerRep),
integer_rep_to_felt(output_scale.0 as IntegerRep),
*eps,
)?,
HybridOp::Div { denom, .. } => {
if denom.0.fract() == 0.0 {
@@ -259,7 +282,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
config,
region,
values[..].try_into()?,
integer_rep_to_felt(denom.0 as i128),
integer_rep_to_felt(denom.0 as IntegerRep),
)?
} else {
layouts::nonlinearity(
@@ -282,6 +305,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
pool_dims,
data_format,
} => layouts::max_pool(
config,
region,
@@ -289,6 +313,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
padding,
stride,
pool_dims,
*data_format,
)?,
HybridOp::ReduceMax { axes } => {
layouts::max_axes(config, region, values[..].try_into()?, axes)?
@@ -306,6 +331,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
input_scale,
output_scale,
axes,
eps,
} => layouts::softmax_axes(
config,
region,
@@ -313,14 +339,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
*input_scale,
*output_scale,
axes,
*eps,
)?,
HybridOp::RangeCheck(tol) => layouts::range_check_percent(
config,
region,
values[..].try_into()?,
tol.scale,
tol.val,
)?,
HybridOp::Output { decomp } => {
layouts::output(config, region, values[..].try_into()?, *decomp)?
}
HybridOp::Greater => layouts::greater(config, region, values[..].try_into()?)?,
HybridOp::GreaterEqual => {
layouts::greater_equal(config, region, values[..].try_into()?)?
@@ -357,6 +380,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
} => multiplier_to_scale((output_scale.0 * input_scale.0) as f64),
HybridOp::Ln {
scale: output_scale,
eps: _,
} => 4 * multiplier_to_scale(output_scale.0 as f64),
_ => in_scales[0],
};

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,8 @@
use std::any::Any;
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::prelude::DatumType;
use crate::{
graph::quantize_tensor,
@@ -96,6 +98,8 @@ pub enum InputType {
Int,
///
TDim,
///
Unknown,
}
impl InputType {
@@ -132,6 +136,7 @@ impl InputType {
let int_input = input.clone().to_i64().unwrap();
*input = T::from_i64(int_input).unwrap();
}
InputType::Unknown => {}
}
}
}
@@ -152,6 +157,30 @@ impl std::str::FromStr for InputType {
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl From<DatumType> for InputType {
    /// Converts a [`DatumType`] into the corresponding [`InputType`].
    ///
    /// Booleans, floats and `TDim` map one-to-one; every signed and unsigned
    /// integer width (8 to 64 bits) maps to [`InputType::Int`].
    ///
    /// # Panics
    /// Panics if the datum type is not supported
    fn from(datum_type: DatumType) -> Self {
        match datum_type {
            DatumType::Bool => Self::Bool,
            DatumType::F16 => Self::F16,
            DatumType::F32 => Self::F32,
            DatumType::F64 => Self::F64,
            DatumType::I8
            | DatumType::I16
            | DatumType::I32
            | DatumType::I64
            | DatumType::U8
            | DatumType::U16
            | DatumType::U32
            | DatumType::U64 => Self::Int,
            DatumType::TDim => Self::TDim,
            _ => unimplemented!(),
        }
    }
}
///
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Input {
@@ -159,6 +188,8 @@ pub struct Input {
pub scale: crate::Scale,
///
pub datum_type: InputType,
/// decomp check
pub decomp: bool,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
@@ -196,6 +227,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
config,
region,
values[..].try_into()?,
self.decomp,
)?)),
}
} else {
@@ -251,20 +283,26 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
///
#[serde(skip)]
pub pre_assigned_val: Option<ValTensor<F>>,
///
pub decomp: bool,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
///
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>, decomp: bool) -> Self {
Self {
quantized_values,
raw_values,
pre_assigned_val: None,
decomp,
}
}
/// Rebase the scale of the constant
pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), CircuitError> {
let visibility = self.quantized_values.visibility().unwrap();
let visibility = match self.quantized_values.visibility() {
Some(v) => v,
None => return Err(CircuitError::UnsetVisibility),
};
self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
Ok(())
}
@@ -281,13 +319,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
}
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self
@@ -308,7 +341,12 @@ impl<
self.quantized_values.clone().try_into()?
};
// we gotta constrain it once if its used multiple times
Ok(Some(layouts::identity(config, region, &[value])?))
Ok(Some(layouts::identity(
config,
region,
&[value],
self.decomp,
)?))
}
fn clone_dyn(&self) -> Box<dyn Op<F>> {

View File

@@ -4,6 +4,7 @@ use crate::{
utils::{self, F32},
},
tensor::{self, Tensor, TensorError},
tensor::{DataFormat, KernelFormat},
};
use super::{base::BaseOp, *};
@@ -43,10 +44,12 @@ pub enum PolyOp {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
group: usize,
data_format: DataFormat,
kernel_format: KernelFormat,
},
Downsample {
axis: usize,
stride: usize,
stride: isize,
modulo: usize,
},
DeConv {
@@ -54,6 +57,8 @@ pub enum PolyOp {
output_padding: Vec<usize>,
stride: Vec<usize>,
group: usize,
data_format: DataFormat,
kernel_format: KernelFormat,
},
Add,
Sub,
@@ -103,13 +108,8 @@ pub enum PolyOp {
}
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for PolyOp
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for PolyOp
{
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
@@ -165,10 +165,12 @@ impl<
stride,
padding,
group,
data_format,
kernel_format,
} => {
format!(
"CONV (stride={:?}, padding={:?}, group={})",
stride, padding, group
"CONV (stride={:?}, padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
stride, padding, group, data_format, kernel_format
)
}
PolyOp::DeConv {
@@ -176,10 +178,12 @@ impl<
padding,
output_padding,
group,
data_format,
kernel_format,
} => {
format!(
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={})",
stride, padding, output_padding, group
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
stride, padding, output_padding, group, data_format, kernel_format
)
}
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
@@ -242,6 +246,8 @@ impl<
padding,
stride,
group,
data_format,
kernel_format,
} => layouts::conv(
config,
region,
@@ -249,9 +255,17 @@ impl<
padding,
stride,
*group,
*data_format,
*kernel_format,
)?,
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
if values.len() != 1 {
return Err(TensorError::DimError(
"GatherElements only accepts single inputs".to_string(),
)
.into());
}
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
@@ -269,6 +283,12 @@ impl<
}
PolyOp::ScatterElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
if values.len() != 2 {
return Err(TensorError::DimError(
"ScatterElements requires two inputs".to_string(),
)
.into());
}
tensor::ops::scatter(
values[0].get_inner_tensor()?,
idx,
@@ -297,6 +317,8 @@ impl<
output_padding,
stride,
group,
data_format,
kernel_format,
} => layouts::deconv(
config,
region,
@@ -305,13 +327,17 @@ impl<
output_padding,
stride,
*group,
*data_format,
*kernel_format,
)?,
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
PolyOp::Mult => {
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
}
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Identity { .. } => {
layouts::identity(config, region, values[..].try_into()?, false)?
}
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
if values.len() != 1 {

View File

@@ -671,22 +671,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
}
/// Assign a valtensor to a vartensor with duplication
pub fn assign_with_duplication(
pub fn assign_with_duplication_unconstrained(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
check_mode: &crate::circuit::CheckMode,
single_inner_col: bool,
) -> Result<(ValTensor<F>, usize), Error> {
if let Some(region) = &self.region {
// duplicates every nth element to adjust for column overflow
let (res, len) = var.assign_with_duplication(
let (res, len) = var.assign_with_duplication_unconstrained(
&mut region.borrow_mut(),
self.row,
self.linear_coord,
values,
check_mode,
single_inner_col,
&mut self.assigned_constants,
)?;
Ok((res, len))
@@ -695,7 +690,37 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.row,
self.linear_coord,
values,
single_inner_col,
false,
&mut self.assigned_constants,
)?;
Ok((values.clone(), len))
}
}
/// Assign a valtensor to a vartensor with duplication
pub fn assign_with_duplication_constrained(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
check_mode: &crate::circuit::CheckMode,
) -> Result<(ValTensor<F>, usize), Error> {
if let Some(region) = &self.region {
// duplicates every nth element to adjust for column overflow
let (res, len) = var.assign_with_duplication_constrained(
&mut region.borrow_mut(),
self.row,
self.linear_coord,
values,
check_mode,
&mut self.assigned_constants,
)?;
Ok((res, len))
} else {
let (_, len) = var.dummy_assign_with_duplication(
self.row,
self.linear_coord,
values,
true,
&mut self.assigned_constants,
)?;
Ok((values.clone(), len))

View File

@@ -132,21 +132,16 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
(first_element, op_f.output[0])
}
///
/// calculates the column size given the number of rows and reserved blinding rows
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(logrows as u32) - reserved_blinding_rows
}
///
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(bits as u32) - reserved_blinding_rows
}
}
///
pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
// number of cols needed to store the range
(range_len / (col_size as IntegerRep)) as usize + 1
(range_len / col_size as IntegerRep) as usize + 1
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
@@ -168,7 +163,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
range: Range,
logrows: usize,
nonlinearity: &LookupOp,
preexisting_inputs: Option<Vec<TableColumn>>,
preexisting_inputs: &mut Vec<TableColumn>,
) -> Table<F> {
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
let col_size = Self::cal_col_size(logrows, factors);
@@ -177,28 +172,28 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
debug!("table range: {:?}", range);
let table_inputs = preexisting_inputs.unwrap_or_else(|| {
let mut cols = vec![];
for _ in 0..num_cols {
cols.push(cs.lookup_table_column());
// validate enough columns are provided to store the range
if preexisting_inputs.len() < num_cols {
// add columns to match the required number of columns
let diff = num_cols - preexisting_inputs.len();
for _ in 0..diff {
preexisting_inputs.push(cs.lookup_table_column());
}
cols
});
let num_cols = table_inputs.len();
}
let num_cols = preexisting_inputs.len();
if num_cols > 1 {
warn!("Using {} columns for non-linearity table.", num_cols);
}
let table_outputs = table_inputs
let table_outputs = preexisting_inputs
.iter()
.map(|_| cs.lookup_table_column())
.collect::<Vec<_>>();
Table {
nonlinearity: nonlinearity.clone(),
table_inputs,
table_inputs: preexisting_inputs.clone(),
table_outputs,
is_assigned: false,
selector_constructor: SelectorConstructor::new(num_cols),
@@ -355,16 +350,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0)
}
///
/// calculates the column size
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(logrows as u32) - reserved_blinding_rows
}
///
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(bits as u32) - reserved_blinding_rows
}
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in

View File

@@ -1,5 +1,6 @@
use crate::circuit::ops::poly::PolyOp;
use crate::circuit::*;
use crate::tensor::{DataFormat, KernelFormat};
use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor};
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
@@ -1040,6 +1041,10 @@ mod conv {
let a = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
let b = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
let output = VarTensor::new_advice(cs, K, 1, (LEN + 1) * LEN);
// column for constants
let _constant = VarTensor::constant_cols(cs, K, 8, false);
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
}
@@ -1061,6 +1066,8 @@ mod conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
data_format: DataFormat::default(),
kernel_format: KernelFormat::default(),
}),
)
.map_err(|_| Error::Synthesis)
@@ -1171,7 +1178,7 @@ mod conv_col_ultra_overflow {
use super::*;
const K: usize = 4;
const K: usize = 6;
const LEN: usize = 10;
#[derive(Clone)]
@@ -1191,9 +1198,10 @@ mod conv_col_ultra_overflow {
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN);
let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * LEN);
let _constant = VarTensor::constant_cols(cs, K, LEN * LEN * LEN * LEN, false);
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
}
@@ -1215,6 +1223,8 @@ mod conv_col_ultra_overflow {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
data_format: DataFormat::default(),
kernel_format: KernelFormat::default(),
}),
)
.map_err(|_| Error::Synthesis)
@@ -1372,6 +1382,8 @@ mod conv_relu_col_ultra_overflow {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
data_format: DataFormat::default(),
kernel_format: KernelFormat::default(),
}),
)
.map_err(|_| Error::Synthesis);
@@ -1776,13 +1788,18 @@ mod shuffle {
let d = VarTensor::new_advice(cs, K, 1, LEN);
let e = VarTensor::new_advice(cs, K, 1, LEN);
let f: VarTensor = VarTensor::new_advice(cs, K, 1, LEN);
let _constant = VarTensor::constant_cols(cs, K, LEN * NUM_LOOP, false);
let mut config =
Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
config
.configure_shuffles(cs, &[a.clone(), b.clone()], &[d.clone(), e.clone()])
.configure_shuffles(
cs,
&[a.clone(), b.clone(), c.clone()],
&[d.clone(), e.clone(), f.clone()],
)
.unwrap();
config
}
@@ -1803,6 +1820,7 @@ mod shuffle {
&mut region,
&self.inputs[i],
&self.references[i],
layouts::SortCollisionMode::Unsorted,
)
.map_err(|_| Error::Synthesis)?;
}
@@ -1988,7 +2006,7 @@ mod add_with_overflow_and_poseidon {
let base = BaseConfig::configure(cs, &[a, b], &output, CheckMode::SAFE);
VarTensor::constant_cols(cs, K, 2, false);
let poseidon = PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::configure(cs, ());
let poseidon = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::configure(cs, ());
MyCircuitConfig { base, poseidon }
}
@@ -1998,7 +2016,7 @@ mod add_with_overflow_and_poseidon {
mut config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, WIDTH> =
let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE> =
PoseidonChip::new(config.poseidon.clone());
let assigned_inputs_a =
@@ -2033,11 +2051,9 @@ mod add_with_overflow_and_poseidon {
let b = (0..LEN)
.map(|i| halo2curves::bn256::Fr::from(i as u64 + 1))
.collect::<Vec<_>>();
let commitment_a =
PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(a.clone()).unwrap()[0][0];
let commitment_a = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(a.clone()).unwrap()[0][0];
let commitment_b =
PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(b.clone()).unwrap()[0][0];
let commitment_b = PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(b.clone()).unwrap()[0][0];
// parameters
let a = Tensor::from(a.into_iter().map(Value::known));
@@ -2059,13 +2075,11 @@ mod add_with_overflow_and_poseidon {
let b = (0..LEN)
.map(|i| halo2curves::bn256::Fr::from(i as u64 + 1))
.collect::<Vec<_>>();
let commitment_a = PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(a.clone())
.unwrap()[0][0]
+ Fr::one();
let commitment_a =
PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(a.clone()).unwrap()[0][0] + Fr::one();
let commitment_b = PoseidonChip::<PoseidonSpec, WIDTH, RATE, WIDTH>::run(b.clone())
.unwrap()[0][0]
+ Fr::one();
let commitment_b =
PoseidonChip::<PoseidonSpec, WIDTH, RATE>::run(b.clone()).unwrap()[0][0] + Fr::one();
// parameters
let a = Tensor::from(a.into_iter().map(Value::known));

View File

@@ -1,3 +1,4 @@
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
use alloy::primitives::Address as H160;
use clap::{Command, Parser, Subcommand};
use clap_complete::{generate, Generator, Shell};
@@ -11,7 +12,6 @@ use tosubcommand::{ToFlags, ToSubcommand};
use crate::{pfsys::ProofType, Commitments, RunArgs};
use crate::circuit::CheckMode;
use crate::graph::TestDataSource;
use crate::pfsys::TranscriptType;
/// The default path to the .json data file
@@ -42,20 +42,14 @@ pub const DEFAULT_SPLIT: &str = "false";
pub const DEFAULT_VERIFIER_ABI: &str = "verifier_abi.json";
/// Default verifier abi for aggregated proofs
pub const DEFAULT_VERIFIER_AGGREGATED_ABI: &str = "verifier_aggr_abi.json";
/// Default verifier abi for data attestation
pub const DEFAULT_VERIFIER_DA_ABI: &str = "verifier_da_abi.json";
/// Default solidity code
pub const DEFAULT_SOL_CODE: &str = "evm_deploy.sol";
/// Default calldata path
pub const DEFAULT_CALLDATA: &str = "calldata.bytes";
/// Default solidity code for aggregated proofs
pub const DEFAULT_SOL_CODE_AGGREGATED: &str = "evm_deploy_aggr.sol";
/// Default solidity code for data attestation
pub const DEFAULT_SOL_CODE_DA: &str = "evm_deploy_da.sol";
/// Default contract address
pub const DEFAULT_CONTRACT_ADDRESS: &str = "contract.address";
/// Default contract address for data attestation
pub const DEFAULT_CONTRACT_ADDRESS_DA: &str = "contract_da.address";
/// Default contract address for vk
pub const DEFAULT_CONTRACT_ADDRESS_VK: &str = "contract_vk.address";
/// Default check mode
@@ -78,18 +72,24 @@ pub const DEFAULT_DISABLE_SELECTOR_COMPRESSION: &str = "false";
pub const DEFAULT_RENDER_REUSABLE: &str = "false";
/// Default contract deployment type
pub const DEFAULT_CONTRACT_DEPLOYMENT_TYPE: &str = "verifier";
/// Default VK sol path
pub const DEFAULT_VK_SOL: &str = "vk.sol";
/// Default VKA calldata path
pub const DEFAULT_VKA: &str = "vka.bytes";
/// Default VK abi path
pub const DEFAULT_VK_ABI: &str = "vk.abi";
/// Default scale rebase multipliers for calibration
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,10";
/// Default use reduced srs for verification
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
/// Default only check for range check rebase
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";
/// Default commitment
pub const DEFAULT_COMMITMENT: &str = "kzg";
/// Default seed used to generate random data
pub const DEFAULT_SEED: &str = "21242";
/// Default number of decimals for instances rescaling on-chain.
pub const DEFAULT_DECIMALS: &str = "18";
/// Default path for the vka digest file
pub const DEFAULT_VKA_DIGEST: &str = "vka.digest";
#[cfg(feature = "python-bindings")]
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
@@ -185,8 +185,6 @@ pub enum ContractType {
/// Can also be used as an alternative to aggregation for verifiers that are otherwise too large to fit on-chain.
reusable: bool,
},
/// Deploys a verifying key artifact that the reusable verifier loads into memory during runtime. Encodes the circuit specific data that was otherwise hardcoded onto the stack.
VerifyingKeyArtifact,
}
impl Default for ContractType {
@@ -205,7 +203,6 @@ impl std::fmt::Display for ContractType {
"verifier/reusable".to_string()
}
ContractType::Verifier { reusable: false } => "verifier".to_string(),
ContractType::VerifyingKeyArtifact => "vka".to_string(),
}
)
}
@@ -222,7 +219,6 @@ impl From<&str> for ContractType {
match s {
"verifier" => ContractType::Verifier { reusable: false },
"verifier/reusable" => ContractType::Verifier { reusable: true },
"vka" => ContractType::VerifyingKeyArtifact,
_ => {
log::error!("Invalid value for ContractType");
log::warn!("Defaulting to verifier");
@@ -232,24 +228,25 @@ impl From<&str> for ContractType {
}
}
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
/// Wrapper around an `H160` (the file's alias for `alloy::primitives::Address`)
/// that makes the address easy to parse from, and render back into, flag values.
pub struct H160Flag {
// the wrapped address
inner: H160,
}
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
impl From<H160Flag> for H160 {
    /// Unwraps the flag and hands back the address it carries.
    fn from(val: H160Flag) -> H160 {
        let H160Flag { inner } = val;
        inner
    }
}
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
impl ToFlags for H160Flag {
    /// Produces a single flag value: the wrapped address formatted with `{:#x}`.
    fn to_flags(&self) -> Vec<String> {
        let rendered = format!("{:#x}", self.inner);
        vec![rendered]
    }
}
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
impl From<&str> for H160Flag {
fn from(s: &str) -> Self {
Self {
@@ -297,7 +294,6 @@ impl IntoPy<PyObject> for ContractType {
match self {
ContractType::Verifier { reusable: true } => "verifier/reusable".to_object(py),
ContractType::Verifier { reusable: false } => "verifier".to_object(py),
ContractType::VerifyingKeyArtifact => "vka".to_object(py),
}
}
}
@@ -310,7 +306,6 @@ impl<'source> FromPyObject<'source> for ContractType {
match strval.to_lowercase().as_str() {
"verifier" => Ok(ContractType::Verifier { reusable: false }),
"verifier/reusable" => Ok(ContractType::Verifier { reusable: true }),
"vka" => Ok(ContractType::VerifyingKeyArtifact),
_ => Err(PyValueError::new_err("Invalid value for ContractType")),
}
}
@@ -358,8 +353,13 @@ pub fn get_styles() -> clap::builder::Styles {
}
/// Print completions for the given generator
pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
generate(gen, cmd, cmd.get_name().to_string(), &mut std::io::stdout());
pub fn print_completions<G: Generator>(r#gen: G, cmd: &mut Command) {
generate(
r#gen,
cmd,
cmd.get_name().to_string(),
&mut std::io::stdout(),
);
}
#[allow(missing_docs)]
@@ -375,6 +375,44 @@ pub struct Cli {
pub command: Option<Commands>,
}
/// Value of a `data` field that can be supplied in two ways.
///
/// When parsed through its `FromStr` implementation, an argument that starts
/// with `@` is treated as a file path and the file's contents become the
/// value; any other argument (for example a raw JSON string) is stored as-is.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
pub struct DataField(pub String);
impl FromStr for DataField {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
// Check if the input starts with '@'
if s.starts_with('@') {
// Extract the file path (remove the '@' prefix)
let file_path = &s[1..];
// Read the file content
let content = std::fs::read_to_string(file_path)
.map_err(|e| format!("Failed to read data file '{}': {}", file_path, e))?;
// Return the file content as the data field value
Ok(DataField(content))
} else {
// Use the input string directly
Ok(DataField(s.to_string()))
}
}
}
impl ToFlags for DataField {
    /// Emits the stored string as the one and only flag value.
    fn to_flags(&self) -> Vec<String> {
        let Self(raw) = self;
        vec![raw.clone()]
    }
}
impl std::fmt::Display for DataField {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[allow(missing_docs)]
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
pub enum Commands {
@@ -393,9 +431,9 @@ pub enum Commands {
/// Generates the witness from an input file.
GenWitness {
/// The path to the .json data file
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// The path to the .json data file (with @ prefix) or a raw data string of the form '{"input_data": [[1, 2, 3]]}'
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_parser = DataField::from_str)]
data: Option<DataField>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
@@ -422,12 +460,32 @@ pub enum Commands {
#[clap(flatten)]
args: RunArgs,
},
/// Generate random data for a model
GenRandomData {
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
/// The path to the .json data file
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// Hand-written parser for graph variables, e.g. batch_size->1
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'V', long, value_parser = crate::parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other))]
variables: Vec<(String, usize)>,
/// random seed for reproducibility (optional)
#[arg(long, value_hint = clap::ValueHint::Other, default_value = DEFAULT_SEED)]
seed: u64,
/// min value for random data
#[arg(long, value_hint = clap::ValueHint::Other)]
min: Option<f32>,
/// max value for random data
#[arg(long, value_hint = clap::ValueHint::Other)]
max: Option<f32>,
},
/// Calibrates the proving scale, lookup bits and logrows from a circuit settings file.
CalibrateSettings {
/// The path to the .json calibration data file.
#[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
data: Option<String>,
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
@@ -606,43 +664,6 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)]
disable_selector_compression: Option<bool>,
},
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
#[command(arg_required_else_help = true)]
SetupTestEvmData {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
/// For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information
/// derived from the file information in the data .json file.
/// Should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'T', long, value_hint = clap::ValueHint::FilePath)]
test_data: PathBuf,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
/// where the input data come from
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
input_source: TestDataSource,
/// where the output data come from
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
output_source: TestDataSource,
},
/// The Data Attestation Verifier contract stores the account calls to fetch data to feed into ezkl. This call data can be updated by an admin account. This tests that admin account is able to update this call data.
#[command(arg_required_else_help = true)]
TestUpdateAccountCalls {
/// The path to the verifier contract's address
#[arg(long, value_hint = clap::ValueHint::Other)]
addr: H160Flag,
/// The path to the .json data file.
#[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
},
/// Swaps the positions in the transcript that correspond to commitments
SwapProofCommitments {
/// The path to the proof file
@@ -685,6 +706,7 @@ pub enum Commands {
},
/// Encodes a proof into evm calldata
#[command(name = "encode-evm-calldata")]
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
EncodeEvmCalldata {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
@@ -692,12 +714,13 @@ pub enum Commands {
/// The path to save the calldata to
#[arg(long, default_value = DEFAULT_CALLDATA, value_hint = clap::ValueHint::FilePath)]
calldata_path: Option<PathBuf>,
/// The path to the verification key address (only used if the vk is rendered as a separate contract)
/// The path to the serialized VKA file
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_vk: Option<H160Flag>,
vka_path: Option<PathBuf>,
},
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-verifier")]
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
CreateEvmVerifier {
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
@@ -718,9 +741,10 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_RENDER_REUSABLE, action = clap::ArgAction::SetTrue)]
reusable: Option<bool>,
},
/// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier
/// Creates an evm verifier artifact to be used by the reusable verifier
#[command(name = "create-evm-vka")]
CreateEvmVKArtifact {
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
CreateEvmVka {
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
@@ -730,39 +754,18 @@ pub enum Commands {
/// The path to load the desired verification key file
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to output the Solidity code
#[arg(long, default_value = DEFAULT_VK_SOL, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VK_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
},
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
#[command(name = "create-evm-da")]
CreateEvmDataAttestation {
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to output the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE_DA, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VERIFIER_DA_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
/// The path to the .json data file, which should
/// contain the necessary calldata and account addresses
/// needed to read from all the on-chain
/// view functions that return the data that the network
/// ingests as inputs.
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// The path to the witness file. This is needed for proof swapping for kzg commitments.
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
witness: Option<PathBuf>,
/// The path to output the vka calldata
#[arg(long, default_value = DEFAULT_VKA, value_hint = clap::ValueHint::FilePath)]
vka_path: Option<PathBuf>,
/// The number of decimals we want to use for the rescaling of the instances into on-chain floats
/// Default is 18, which is the number of decimals used by most ERC20 tokens
#[arg(long, default_value = DEFAULT_DECIMALS, value_hint = clap::ValueHint::Other)]
decimals: Option<usize>,
},
/// Creates an Evm verifier for an aggregate proof
#[command(name = "create-evm-verifier-aggr")]
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
CreateEvmVerifierAggr {
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
@@ -826,13 +829,14 @@ pub enum Commands {
commitment: Option<Commitments>,
},
/// Deploys an evm contract (verifier, reusable verifier, or vk artifact) that is generated by ezkl
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
DeployEvm {
/// The path to the Solidity code (generated using the create-evm-verifier command)
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
/// RPC URL for an Ethereum node
#[arg(short = 'U', long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Url)]
rpc_url: String,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
/// The path to output the contract address
addr_path: Option<PathBuf>,
@@ -846,33 +850,9 @@ pub enum Commands {
#[arg(long = "contract-type", short = 'C', default_value = DEFAULT_CONTRACT_DEPLOYMENT_TYPE, value_hint = clap::ValueHint::Other)]
contract: ContractType,
},
/// Deploys an evm verifier that allows for data attestation
#[command(name = "deploy-evm-da")]
DeployEvmDataAttestation {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE_DA, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_DA, value_hint = clap::ValueHint::FilePath)]
/// The path to output the contract address
addr_path: Option<PathBuf>,
/// The optimizer runs to set on the verifier. (Lower values optimize for deployment, while higher values optimize for execution)
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS, value_hint = clap::ValueHint::Other)]
optimizer_runs: usize,
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
},
/// Verifies a proof using a local Evm executor, returning accept or reject
#[command(name = "verify-evm")]
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
VerifyEvm {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
@@ -880,15 +860,32 @@ pub enum Commands {
/// The path to verifier contract's address
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
addr_verifier: H160Flag,
/// RPC URL for an Ethereum node
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: String,
/// The path to the serialized vka file
#[arg(long, default_value = DEFAULT_VKA, value_hint = clap::ValueHint::FilePath)]
vka_path: Option<PathBuf>,
},
/// Registers a VKA, returning its digest, which is used to identify it on-chain.
#[command(name = "register-vka")]
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
RegisterVka {
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
/// does the verifier use data attestation ?
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_da: Option<H160Flag>,
// is the vk rendered seperately, if so specify an address
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_vk: Option<H160Flag>,
rpc_url: String,
/// The path to the reusable verifier contract's address
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
addr_verifier: H160Flag,
/// The path to the serialized VKA file
#[arg(long, default_value = DEFAULT_VKA, value_hint = clap::ValueHint::FilePath)]
vka_path: Option<PathBuf>,
/// The path to output the VKA digest to
#[arg(long, default_value = DEFAULT_VKA_DIGEST, value_hint = clap::ValueHint::FilePath)]
vka_digest_path: Option<PathBuf>,
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
},
#[cfg(not(feature = "no-update"))]
/// Updates ezkl binary to version specified (or latest if not specified)

1124
src/eth.rs

File diff suppressed because one or more lines are too long

View File

@@ -1,15 +1,13 @@
use crate::circuit::region::RegionSettings;
use crate::circuit::CheckMode;
use crate::commands::CalibrationTarget;
use crate::eth::{
deploy_contract_via_solidity, deploy_da_verifier_via_solidity, fix_da_multi_sol,
fix_da_single_sol,
};
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
use crate::eth::{deploy_contract_via_solidity, register_vka_via_rv};
#[allow(unused_imports)]
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
use crate::eth::{get_contract_artifacts, verify_proof_via_solidity};
use crate::graph::input::{Calls, GraphData};
use crate::graph::input::GraphData;
use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model};
use crate::graph::{TestDataSource, TestSources};
use crate::pfsys::evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript};
use crate::pfsys::{
create_keys, load_pk, load_vk, save_params, save_pk, Snark, StrategyType, TranscriptType,
@@ -41,6 +39,7 @@ use halo2_proofs::poly::kzg::{
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer};
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
use halo2_solidity_verifier;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
use halo2curves::ff::{FromUniformBytes, WithSmallOrderMulGroup};
@@ -48,6 +47,7 @@ use halo2curves::serde::SerdeObject;
use indicatif::{ProgressBar, ProgressStyle};
use instant::Instant;
use itertools::Itertools;
use lazy_static::lazy_static;
use log::debug;
use log::{info, trace, warn};
use serde::de::DeserializeOwned;
@@ -56,17 +56,20 @@ use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::compile;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use snark_verifier::system::halo2::Config;
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
use std::fs::File;
use std::io::BufWriter;
use std::io::{Cursor, Write};
use std::io::Cursor;
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
use std::io::Write;
use std::path::Path;
use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration;
use tabled::Tabled;
use thiserror::Error;
use lazy_static::lazy_static;
use tract_onnx::prelude::IntoTensor;
use tract_onnx::prelude::Tensor as TractTensor;
lazy_static! {
#[derive(Debug)]
@@ -116,7 +119,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
} => gen_srs_cmd(
srs_path,
logrows as u32,
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
commitment.unwrap_or_else(|| Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
),
Commands::GetSrs {
srs_path,
@@ -134,6 +137,21 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
args,
),
Commands::GenRandomData {
model,
data,
variables,
seed,
min,
max,
} => gen_random_data(
model.unwrap_or(DEFAULT_MODEL.into()),
data.unwrap_or(DEFAULT_DATA.into()),
variables,
seed,
min,
max,
),
Commands::CalibrateSettings {
model,
settings_path,
@@ -153,7 +171,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
scale_rebase_multiplier,
max_logrows,
)
.await
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::GenWitness {
data,
@@ -163,17 +180,17 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
srs_path,
} => gen_witness(
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
data.unwrap_or(DEFAULT_DATA.into()),
data.unwrap_or(DataField(DEFAULT_DATA.into())).to_string(),
Some(output.unwrap_or(DEFAULT_WITNESS.into())),
vk_path,
srs_path,
)
.await
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Mock { model, witness } => mock(
model.unwrap_or(DEFAULT_MODEL.into()),
witness.unwrap_or(DEFAULT_WITNESS.into()),
),
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
Commands::CreateEvmVerifier {
vk_path,
srs_path,
@@ -192,49 +209,35 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
)
.await
}
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
Commands::EncodeEvmCalldata {
proof_path,
calldata_path,
addr_vk,
vka_path,
} => encode_evm_calldata(
proof_path.unwrap_or(DEFAULT_PROOF.into()),
calldata_path.unwrap_or(DEFAULT_CALLDATA.into()),
addr_vk,
vka_path,
)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::CreateEvmVKArtifact {
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
Commands::CreateEvmVka {
vk_path,
srs_path,
settings_path,
sol_code_path,
abi_path,
vka_path,
decimals,
} => {
create_evm_vka(
vk_path.unwrap_or(DEFAULT_VK.into()),
srs_path,
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
sol_code_path.unwrap_or(DEFAULT_VK_SOL.into()),
abi_path.unwrap_or(DEFAULT_VK_ABI.into()),
)
.await
}
Commands::CreateEvmDataAttestation {
settings_path,
sol_code_path,
abi_path,
data,
witness,
} => {
create_evm_data_attestation(
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
sol_code_path.unwrap_or(DEFAULT_SOL_CODE_DA.into()),
abi_path.unwrap_or(DEFAULT_VERIFIER_DA_ABI.into()),
data.unwrap_or(DEFAULT_DATA.into()),
witness,
vka_path.unwrap_or(DEFAULT_VKA.into()),
decimals.unwrap_or(DEFAULT_DECIMALS.parse().unwrap()),
)
.await
}
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
Commands::CreateEvmVerifierAggr {
vk_path,
srs_path,
@@ -280,29 +283,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
disable_selector_compression
.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap()),
),
Commands::SetupTestEvmData {
data,
compiled_circuit,
test_data,
rpc_url,
input_source,
output_source,
} => {
setup_test_evm_witness(
data.unwrap_or(DEFAULT_DATA.into()),
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
test_data,
rpc_url,
input_source,
output_source,
)
.await
}
Commands::TestUpdateAccountCalls {
addr,
data,
rpc_url,
} => test_update_account_calls(addr, data.unwrap_or(DEFAULT_DATA.into()), rpc_url).await,
Commands::SwapProofCommitments {
proof_path,
witness_path,
@@ -411,6 +391,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
Commands::DeployEvm {
sol_code_path,
rpc_url,
@@ -429,39 +410,35 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
)
.await
}
Commands::DeployEvmDataAttestation {
data,
settings_path,
sol_code_path,
rpc_url,
addr_path,
optimizer_runs,
private_key,
} => {
deploy_da_evm(
data.unwrap_or(DEFAULT_DATA.into()),
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
sol_code_path.unwrap_or(DEFAULT_SOL_CODE_DA.into()),
rpc_url,
addr_path.unwrap_or(DEFAULT_CONTRACT_ADDRESS_DA.into()),
optimizer_runs,
private_key,
)
.await
}
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
Commands::VerifyEvm {
proof_path,
addr_verifier,
rpc_url,
addr_da,
addr_vk,
vka_path,
} => {
verify_evm(
proof_path.unwrap_or(DEFAULT_PROOF.into()),
addr_verifier,
rpc_url,
addr_da,
addr_vk,
vka_path,
)
.await
}
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
Commands::RegisterVka {
addr_verifier,
vka_path,
rpc_url,
vka_digest_path,
private_key,
} => {
register_vka(
rpc_url,
addr_verifier,
vka_path.unwrap_or(DEFAULT_VKA.into()),
vka_digest_path.unwrap_or(DEFAULT_VKA_DIGEST.into()),
private_key,
)
.await
}
@@ -503,7 +480,9 @@ fn update_ezkl_binary(version: &Option<String>) -> Result<String, EZKLError> {
.status()
.is_err()
{
log::warn!("bash is not installed on this system, trying to run the install script with sh (may fail)");
log::warn!(
"bash is not installed on this system, trying to run the install script with sh (may fail)"
);
"sh"
} else {
"bash"
@@ -710,9 +689,9 @@ pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, EZKLErr
Ok(String::new())
}
pub(crate) async fn gen_witness(
pub(crate) fn gen_witness(
compiled_circuit_path: PathBuf,
data: PathBuf,
data: String,
output: Option<PathBuf>,
vk_path: Option<PathBuf>,
srs_path: Option<PathBuf>,
@@ -720,7 +699,7 @@ pub(crate) async fn gen_witness(
// these aren't real values so the sanity checks are mostly meaningless
let mut circuit = GraphCircuit::load(compiled_circuit_path)?;
let data: GraphData = GraphData::from_path(data)?;
let data = GraphData::from_str(&data)?;
let settings = circuit.settings().clone();
let vk = if let Some(vk) = vk_path {
@@ -732,7 +711,7 @@ pub(crate) async fn gen_witness(
None
};
let mut input = circuit.load_graph_input(&data).await?;
let mut input = circuit.load_graph_input(&data)?;
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
let mut input = circuit.load_graph_input(&data)?;
@@ -828,6 +807,85 @@ pub(crate) fn gen_circuit_settings(
Ok(String::new())
}
/// Generate a circuit settings file
pub(crate) fn gen_random_data(
model_path: PathBuf,
data_path: PathBuf,
variables: Vec<(String, usize)>,
seed: u64,
min: Option<f32>,
max: Option<f32>,
) -> Result<String, EZKLError> {
let mut file = std::fs::File::open(&model_path).map_err(|e| {
crate::graph::errors::GraphError::ReadWriteFileError(
model_path.display().to_string(),
e.to_string(),
)
})?;
let (tract_model, _symbol_values) = Model::load_onnx_using_tract(&mut file, &variables)?;
let input_facts = tract_model
.input_outlets()
.map_err(|e| EZKLError::from(e.to_string()))?
.iter()
.map(|&i| tract_model.outlet_fact(i))
.collect::<tract_onnx::prelude::TractResult<Vec<_>>>()
.map_err(|e| EZKLError::from(e.to_string()))?;
let min = min.unwrap_or(0.0);
let max = max.unwrap_or(1.0);
/// Generates a random tensor of a given size and type.
fn random(
sizes: &[usize],
datum_type: tract_onnx::prelude::DatumType,
seed: u64,
min: f32,
max: f32,
) -> TractTensor {
use rand::{Rng, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
let mut tensor = TractTensor::zero::<f32>(sizes).unwrap();
let slice = tensor.as_slice_mut::<f32>().unwrap();
slice.iter_mut().for_each(|x| *x = rng.gen_range(min..max));
tensor.cast_to_dt(datum_type).unwrap().into_owned()
}
fn tensor_for_fact(
fact: &tract_onnx::prelude::TypedFact,
seed: u64,
min: f32,
max: f32,
) -> TractTensor {
if let Some(value) = &fact.konst {
return value.clone().into_tensor();
}
random(
fact.shape
.as_concrete()
.expect("Expected concrete shape, found: {fact:?}"),
fact.datum_type,
seed,
min,
max,
)
}
let generated = input_facts
.iter()
.map(|v| tensor_for_fact(v, seed, min, max))
.collect_vec();
let data = GraphData::from_tract_data(&generated)?;
data.save(data_path)?;
Ok(String::new())
}
// not for wasm targets
pub(crate) fn init_spinner() -> ProgressBar {
let pb = indicatif::ProgressBar::new_spinner();
@@ -964,9 +1022,9 @@ impl AccuracyResults {
/// Calibrate the circuit parameters to a given a dataset
#[allow(trivial_casts)]
#[allow(clippy::too_many_arguments)]
pub(crate) async fn calibrate(
pub(crate) fn calibrate(
model_path: PathBuf,
data: PathBuf,
data: String,
settings_path: PathBuf,
target: CalibrationTarget,
lookup_safety_margin: f64,
@@ -980,7 +1038,7 @@ pub(crate) async fn calibrate(
use crate::fieldutils::IntegerRep;
let data = GraphData::from_path(data)?;
let data = GraphData::from_str(&data)?;
// load the pre-generated settings
let settings = GraphSettings::load(&settings_path)?;
// now retrieve the run args
@@ -990,7 +1048,7 @@ pub(crate) async fn calibrate(
let input_shapes = model.graph.input_shapes()?;
let chunks = data.split_into_batches(input_shapes).await?;
let chunks = data.split_into_batches(input_shapes)?;
info!("num calibration batches: {}", chunks.len());
debug!("running onnx predictions...");
@@ -1101,7 +1159,7 @@ pub(crate) async fn calibrate(
let chunk = chunk.clone();
let data = circuit
.load_graph_from_file_exclusively(&chunk)
.load_graph_input(&chunk)
.map_err(|e| format!("failed to load circuit inputs: {}", e))?;
let forward_res = circuit
@@ -1356,6 +1414,7 @@ pub(crate) fn mock(
Ok(String::new())
}
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
pub(crate) async fn create_evm_verifier(
vk_path: PathBuf,
srs_path: Option<PathBuf>,
@@ -1373,7 +1432,9 @@ pub(crate) async fn create_evm_verifier(
)?;
let num_instance = settings.total_instances();
let num_instance: usize = num_instance.iter().sum::<usize>();
// create a scales array that is the same length as the number of instances, all populated with 0
let scales = vec![0; num_instance.len()];
// let poseidon_instance = settings.module_sizes.num_instances().iter().sum::<usize>();
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)?;
trace!("params computed");
@@ -1382,7 +1443,10 @@ pub(crate) async fn create_evm_verifier(
&params,
&vk,
halo2_solidity_verifier::BatchOpenScheme::Bdfg21,
num_instance,
&num_instance,
&scales,
0,
0,
);
let (verifier_solidity, name) = if reusable {
(generator.render_separately()?.0, "Halo2VerifierReusable") // ignore the rendered vk artifact for now and generate it in create_evm_vka
@@ -1400,12 +1464,13 @@ pub(crate) async fn create_evm_verifier(
Ok(String::new())
}
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
pub(crate) async fn create_evm_vka(
vk_path: PathBuf,
srs_path: Option<PathBuf>,
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
vka_path: PathBuf,
decimals: usize,
) -> Result<String, EZKLError> {
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
@@ -1415,165 +1480,55 @@ pub(crate) async fn create_evm_vka(
commitment,
)?;
let num_instance = settings.total_instances();
let num_instance: usize = num_instance.iter().sum::<usize>();
let num_poseidon_instance = settings.module_sizes.num_instances().iter().sum::<usize>();
let num_fixed_point_instance = settings
.model_instance_shapes
.iter()
.map(|x| x.iter().product::<usize>())
.collect_vec();
let scales = settings.get_model_instance_scales();
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)?;
trace!("params computed");
// assert that the decimals must be less than or equal to 38 to prevent overflow
if decimals > 38 {
return Err("decimals must be less than or equal to 38".into());
}
let generator = halo2_solidity_verifier::SolidityGenerator::new(
&params,
&vk,
halo2_solidity_verifier::BatchOpenScheme::Bdfg21,
num_instance,
&num_fixed_point_instance,
&scales,
decimals,
num_poseidon_instance,
);
let vk_solidity = generator.render_separately()?.1;
let vka_words: Vec<[u8; 32]> = generator.render_separately_vka_words()?.1;
let serialized_vka_words = bincode::serialize(&vka_words).or_else(|e| {
Err(EZKLError::from(format!(
"Failed to serialize vka words: {}",
e
)))
})?;
File::create(sol_code_path.clone())?.write_all(vk_solidity.as_bytes())?;
File::create(vka_path.clone())?.write_all(&serialized_vka_words)?;
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2VerifyingArtifact", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
// Load in the vka words and deserialize them and check that they match the original
let bytes = std::fs::read(vka_path)?;
let vka_buf: Vec<[u8; 32]> = bincode::deserialize(&bytes)
.map_err(|e| EZKLError::from(format!("Failed to deserialize vka words: {e}")))?;
if vka_buf != vka_words {
return Err("vka words do not match".into());
};
Ok(String::new())
}
/// Renders a data-attestation Solidity contract for a circuit and saves its ABI.
///
/// Loads the settings at `settings_path`, reads the on-chain call descriptions from `input`
/// (falling back to empty file-sourced data when `input` cannot be loaded), writes the generated
/// Solidity to `sol_code_path`, compiles it through `get_contract_artifacts`, and writes the
/// resulting ABI as JSON to `abi_path`.
///
/// If either the input or the output source uses a single call, the `DataAttestationSingle`
/// contract is rendered; otherwise `DataAttestationMulti` is rendered.
///
/// # Arguments
/// * `settings_path` - circuit settings file
/// * `sol_code_path` - destination of the rendered Solidity code
/// * `abi_path` - destination of the contract ABI (JSON)
/// * `input` - graph data file describing the on-chain sources
/// * `witness` - witness file, only read when some visibility is `KZGCommit`
///   (defaults to `DEFAULT_WITNESS`)
///
/// # Errors
/// Returns an error if the settings or witness cannot be loaded, if on-chain data is marked as
/// private, if rendering or compiling the contract fails, or if any file cannot be written.
pub(crate) async fn create_evm_data_attestation(
    settings_path: PathBuf,
    sol_code_path: PathBuf,
    abi_path: PathBuf,
    input: PathBuf,
    witness: Option<PathBuf>,
) -> Result<String, EZKLError> {
    #[allow(unused_imports)]
    use crate::graph::{DataSource, VarVisibility};
    use crate::{graph::Visibility, pfsys::get_proof_commitments};

    let settings = GraphSettings::load(&settings_path)?;

    let visibility = VarVisibility::from_args(&settings.run_args)?;
    trace!("params computed");

    // if the input cannot be loaded, we just instantiate dummy input data
    // (built lazily so the fallback is only constructed when it is needed)
    let data = GraphData::from_path(input)
        .unwrap_or_else(|_| GraphData::new(DataSource::File(vec![])));

    // The number of input and output instances we attest to for the single call data attestation
    let mut input_len = None;
    let mut output_len = None;

    let output_data = if let Some(DataSource::OnChain(source)) = data.output_data {
        if visibility.output.is_private() {
            return Err("private output data on chain is not supported on chain".into());
        }
        let mut on_chain_output_data = vec![];
        match source.calls {
            Calls::Multiple(calls) => {
                for call in calls {
                    on_chain_output_data.push(call);
                }
            }
            Calls::Single(call) => {
                output_len = Some(call.len);
            }
        }
        Some(on_chain_output_data)
    } else {
        None
    };

    let input_data = if let DataSource::OnChain(source) = data.input_data {
        if visibility.input.is_private() {
            return Err("private input data on chain is not supported on chain".into());
        }
        let mut on_chain_input_data = vec![];
        match source.calls {
            Calls::Multiple(calls) => {
                for call in calls {
                    on_chain_input_data.push(call);
                }
            }
            Calls::Single(call) => {
                input_len = Some(call.len);
            }
        }
        Some(on_chain_input_data)
    } else {
        None
    };

    // Read the settings file. Look if either the run_args.input_visibility,
    // run_args.output_visibility or run_args.param_visibility is KZGCommit;
    // if so, then we need to load the witness
    let commitment_bytes = if settings.run_args.input_visibility == Visibility::KZGCommit
        || settings.run_args.output_visibility == Visibility::KZGCommit
        || settings.run_args.param_visibility == Visibility::KZGCommit
    {
        let witness = GraphWitness::from_path(witness.unwrap_or(DEFAULT_WITNESS.into()))?;
        let commitments = witness.get_polycommitments();
        let proof_first_bytes = get_proof_commitments::<
            KZGCommitmentScheme<Bn256>,
            _,
            EvmTranscript<G1Affine, _, _, _>,
        >(&commitments);

        // NOTE(review): this still panics if the commitments cannot be encoded — consider
        // propagating the error once the helper's error type is confirmed.
        Some(proof_first_bytes.unwrap())
    } else {
        None
    };

    // if either input_len or output_len is Some then we are in the single call data attestation mode
    let contract_name = if input_len.is_some() || output_len.is_some() {
        let output = fix_da_single_sol(input_len, output_len)?;
        // `write_all` (not `write`) so short writes and I/O errors are reported, not dropped
        File::create(&sol_code_path)?.write_all(output.as_bytes())?;
        "DataAttestationSingle"
    } else {
        let output = fix_da_multi_sol(input_data, output_data, commitment_bytes)?;
        File::create(&sol_code_path)?.write_all(output.as_bytes())?;
        "DataAttestationMulti"
    };

    // fetch abi of the contract
    let (abi, _, _) = get_contract_artifacts(sol_code_path, contract_name, 0).await?;
    // save abi to file
    serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;

    Ok(String::new())
}
/// Deploys the data-attestation verifier contract and records the resulting address.
///
/// Forwards the settings, data and Solidity paths together with the optional RPC URL,
/// optimizer `runs` and optional private key to `deploy_da_verifier_via_solidity`, logs the
/// returned address, and writes it (pretty `Debug` formatting) to `addr_path`.
///
/// Returns an empty string on success.
pub(crate) async fn deploy_da_evm(
    data: PathBuf,
    settings_path: PathBuf,
    sol_code_path: PathBuf,
    rpc_url: Option<String>,
    addr_path: PathBuf,
    runs: usize,
    private_key: Option<String>,
) -> Result<String, EZKLError> {
    let rpc = rpc_url.as_deref();
    let signer_key = private_key.as_deref();

    let deployed_at =
        deploy_da_verifier_via_solidity(settings_path, data, sol_code_path, rpc, runs, signer_key)
            .await?;
    info!("Contract deployed at: {}", deployed_at);

    let mut addr_file = File::create(addr_path)?;
    write!(addr_file, "{:#?}", deployed_at)?;

    Ok(String::new())
}
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
pub(crate) async fn deploy_evm(
sol_code_path: PathBuf,
rpc_url: Option<String>,
rpc_url: String,
addr_path: PathBuf,
runs: usize,
private_key: Option<String>,
@@ -1582,11 +1537,10 @@ pub(crate) async fn deploy_evm(
let contract_name = match contract {
ContractType::Verifier { reusable: false } => "Halo2Verifier",
ContractType::Verifier { reusable: true } => "Halo2VerifierReusable",
ContractType::VerifyingKeyArtifact => "Halo2VerifyingArtifact",
};
let contract_address = deploy_contract_via_solidity(
sol_code_path,
rpc_url.as_deref(),
&rpc_url,
runs,
private_key.as_deref(),
contract_name,
@@ -1600,21 +1554,61 @@ pub(crate) async fn deploy_evm(
Ok(String::new())
}
/// Registers a VKA with a reusable verifier and saves the returned digest.
///
/// Reads the bincode-serialized VKA words from `vka_path`, passes them to
/// `register_vka_via_rv` together with the RPC URL, the optional private key and the verifier
/// address, logs the returned digest, and writes it (pretty `Debug` formatting) to
/// `vka_digest_path`.
///
/// Returns an empty string on success.
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
pub(crate) async fn register_vka(
    rpc_url: String,
    rv_addr: H160Flag,
    vka_path: PathBuf,
    vka_digest_path: PathBuf,
    private_key: Option<String>,
) -> Result<String, EZKLError> {
    // The VKA file holds a bincode-encoded list of 32-byte words.
    let raw_vka = std::fs::read(vka_path)?;
    let vka_words: Vec<[u8; 32]> = match bincode::deserialize(&raw_vka) {
        Ok(words) => words,
        Err(e) => {
            return Err(EZKLError::from(format!(
                "Failed to deserialize vka words: {e}"
            )))
        }
    };

    let signer_key = private_key.as_deref();
    let digest =
        register_vka_via_rv(rpc_url.as_ref(), signer_key, rv_addr.into(), &vka_words).await?;

    info!("VKA digest: {:#?}", digest);

    let mut digest_file = File::create(vka_digest_path)?;
    write!(digest_file, "{:#?}", digest)?;

    Ok(String::new())
}
/// Encodes the calldata for the EVM verifier (both aggregated and single proof)
/// TODO: Add a "RV address param" which will query the "RegisteredVKA" events to fetch the
/// VKA from the vka_digest.
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
pub(crate) fn encode_evm_calldata(
proof_path: PathBuf,
calldata_path: PathBuf,
addr_vk: Option<H160Flag>,
vka_path: Option<PathBuf>,
) -> Result<Vec<u8>, EZKLError> {
let snark = Snark::load::<IPACommitmentScheme<G1Affine>>(&proof_path)?;
let flattened_instances = snark.instances.into_iter().flatten();
// Load the vka, which is bincode serialized, from the vka_path
let vka_buf: Option<Vec<[u8; 32]>> =
match vka_path {
Some(path) => {
let bytes = std::fs::read(path)?;
Some(bincode::deserialize(&bytes).map_err(|e| {
EZKLError::from(format!("Failed to deserialize vka words: {e}"))
})?)
}
None => None,
};
let vka: Option<&[[u8; 32]]> = vka_buf.as_deref();
let encoded = halo2_solidity_verifier::encode_calldata(
addr_vk
.as_ref()
.map(|x| alloy::primitives::Address::from(*x).0)
.map(|x| x.0),
vka,
&snark.proof,
&flattened_instances.collect::<Vec<_>>(),
);
@@ -1626,35 +1620,24 @@ pub(crate) fn encode_evm_calldata(
Ok(encoded)
}
/// TODO: Add an optional vka_digest param that will allow us to fetch the associated VKA
/// from the RegisteredVKA events on the RV.
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
pub(crate) async fn verify_evm(
proof_path: PathBuf,
addr_verifier: H160Flag,
rpc_url: Option<String>,
addr_da: Option<H160Flag>,
addr_vk: Option<H160Flag>,
rpc_url: String,
vka_path: Option<PathBuf>,
) -> Result<String, EZKLError> {
use crate::eth::verify_proof_with_data_attestation;
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
let result = if let Some(addr_da) = addr_da {
verify_proof_with_data_attestation(
proof.clone(),
addr_verifier.into(),
addr_da.into(),
addr_vk.map(|s| s.into()),
rpc_url.as_deref(),
)
.await?
} else {
verify_proof_via_solidity(
proof.clone(),
addr_verifier.into(),
addr_vk.map(|s| s.into()),
rpc_url.as_deref(),
)
.await?
};
let result = verify_proof_via_solidity(
proof.clone(),
addr_verifier.into(),
vka_path.map(|s| s.into()),
rpc_url.as_ref(),
)
.await?;
info!("Solidity verification result: {}", result);
@@ -1665,6 +1648,7 @@ pub(crate) async fn verify_evm(
Ok(String::new())
}
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
pub(crate) async fn create_evm_aggregate_verifier(
vk_path: PathBuf,
srs_path: Option<PathBuf>,
@@ -1690,8 +1674,8 @@ pub(crate) async fn create_evm_aggregate_verifier(
.sum();
let num_instance = AggregationCircuit::num_instance(num_instance);
let scales = vec![0; num_instance.len()];
assert_eq!(num_instance.len(), 1);
let num_instance = num_instance[0];
let agg_vk = load_vk::<KZGCommitmentScheme<Bn256>, AggregationCircuit>(vk_path, ())?;
@@ -1699,7 +1683,10 @@ pub(crate) async fn create_evm_aggregate_verifier(
&params,
&agg_vk,
halo2_solidity_verifier::BatchOpenScheme::Bdfg21,
num_instance,
&num_instance,
&scales,
0,
0,
);
let acc_encoding = halo2_solidity_verifier::AccumulatorEncoding::new(
@@ -1788,53 +1775,7 @@ pub(crate) fn setup(
Ok(String::new())
}
/// Populates on-chain test data for a compiled circuit.
///
/// Loads the graph data at `data_path` and the circuit at `compiled_circuit_path`, rejects the
/// configuration where both the input and the output source are files, and then hands the
/// test data description to `GraphCircuit::populate_on_chain_test_data`.
///
/// Returns an empty string on success.
pub(crate) async fn setup_test_evm_witness(
    data_path: PathBuf,
    compiled_circuit_path: PathBuf,
    test_data: PathBuf,
    rpc_url: Option<String>,
    input_source: TestDataSource,
    output_source: TestDataSource,
) -> Result<String, EZKLError> {
    use crate::graph::TestOnChainData;

    let mut data = GraphData::from_path(data_path)?;
    let mut circuit = GraphCircuit::load(compiled_circuit_path)?;

    // at least one side has to come from somewhere other than a file
    let both_from_files = matches!(
        (&input_source, &output_source),
        (TestDataSource::File, TestDataSource::File)
    );
    if both_from_files {
        return Err("Both input and output cannot be from files".into());
    }

    let data_sources = TestSources {
        input: input_source,
        output: output_source,
    };
    let test_on_chain_data = TestOnChainData {
        data: test_data.clone(),
        rpc: rpc_url,
        data_sources,
    };

    circuit
        .populate_on_chain_test_data(&mut data, test_on_chain_data)
        .await?;

    Ok(String::new())
}
use crate::pfsys::ProofType;
/// Thin command wrapper around `crate::eth::update_account_calls`.
///
/// Converts `addr`, borrows the optional RPC URL as a `&str`, and forwards both together with
/// the `data` path. Returns an empty string on success.
pub(crate) async fn test_update_account_calls(
    addr: H160Flag,
    data: PathBuf,
    rpc_url: Option<String>,
) -> Result<String, EZKLError> {
    use crate::eth::update_account_calls;

    let rpc = rpc_url.as_deref();
    update_account_calls(addr.into(), data, rpc).await?;

    Ok(String::new())
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn prove(
@@ -2048,6 +1989,7 @@ pub(crate) fn mock_aggregate(
Ok(String::new())
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn setup_aggregate(
sample_snarks: Vec<PathBuf>,
vk_path: PathBuf,

View File

@@ -5,10 +5,12 @@ use halo2curves::ff::PrimeField;
/// Integer representation of a PrimeField element.
pub type IntegerRep = i128;
/// Converts an i64 to a PrimeField element.
/// Converts an integer rep to a PrimeField element.
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
if x >= 0 {
F::from_u128(x as u128)
} else if x == IntegerRep::MIN {
-F::from_u128(x.saturating_neg() as u128) - F::ONE
} else {
-F::from_u128(x.saturating_neg() as u128)
}
@@ -32,6 +34,9 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
/// Converts a PrimeField element to an integer rep.
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
if x > F::from_u128(IntegerRep::MAX as u128) {
if x == -F::from_u128(IntegerRep::MAX as u128) - F::ONE {
return IntegerRep::MIN;
}
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
@@ -51,7 +56,7 @@ mod test {
use halo2curves::pasta::Fp as F;
#[test]
fn test_conv() {
fn integerreptofelt() {
let res: F = integer_rep_to_felt(-15);
assert_eq!(res, -F::from(15));
@@ -69,8 +74,24 @@ mod test {
fn felttointegerrep() {
for x in -(2_i128.pow(16))..(2_i128.pow(16)) {
let fieldx: F = integer_rep_to_felt::<F>(x);
let xf: i128 = felt_to_integer_rep::<F>(fieldx);
let xf: IntegerRep = felt_to_integer_rep::<F>(fieldx);
assert_eq!(x, xf);
}
}
/// `IntegerRep::MIN` must survive a round trip through the field representation.
#[test]
fn felttointegerrepmin() {
    let as_felt: F = integer_rep_to_felt::<F>(IntegerRep::MIN);
    let round_tripped: IntegerRep = felt_to_integer_rep::<F>(as_felt);
    assert_eq!(IntegerRep::MIN, round_tripped);
}
/// `IntegerRep::MAX` must survive a round trip through the field representation.
#[test]
fn felttointegerrepmax() {
    let as_felt: F = integer_rep_to_felt::<F>(IntegerRep::MAX);
    let round_tripped: IntegerRep = felt_to_integer_rep::<F>(as_felt);
    assert_eq!(IntegerRep::MAX, round_tripped);
}
}

View File

@@ -11,6 +11,12 @@ pub enum GraphError {
/// Shape mismatch in circuit construction
#[error("invalid dimensions used for node {0} ({1})")]
InvalidDims(usize, String),
/// Non scalar power
#[error("we only support scalar powers")]
NonScalarPower,
/// Non scalar base for exponentiation
#[error("we only support scalar bases for exponentiation")]
NonScalarBase,
/// Wrong method was called to configure an op
#[error("wrong method was called to configure node {0} ({1})")]
WrongMethod(usize, String),
@@ -27,7 +33,7 @@ pub enum GraphError {
#[error("a node is missing required params: {0}")]
MissingParams(String),
/// A node has malformed parameters
#[error("a node is has misformed params: {0}")]
#[error("a node has misformed params: {0}")]
MisformedParams(String),
/// Error in the configuration of the visibility of variables
#[error("there should be at least one set of public variables")]
@@ -92,14 +98,13 @@ pub enum GraphError {
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
#[error("[tokio postgres] {0}")]
TokioPostgresError(#[from] tokio_postgres::Error),
/// Eth error
#[cfg(all(
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
#[error("[eth] {0}")]
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
EthError(#[from] crate::eth::EthError),
/// Json error
#[error("[json] {0}")]
@@ -113,13 +118,13 @@ pub enum GraphError {
/// Missing input for a node
#[error("missing input for node {0}")]
MissingInput(usize),
///
/// Ranges can only be constant
#[error("range only supports constant inputs in a zk circuit")]
NonConstantRange,
///
/// Trilu diagonal must be constant
#[error("trilu only supports constant diagonals in a zk circuit")]
NonConstantTrilu,
///
/// The witness was too short
#[error("insufficient witness values to generate a fixed output")]
InsufficientWitnessValues,
/// Missing scale
@@ -135,7 +140,9 @@ pub enum GraphError {
#[error("range check {0} is too large")]
RangeCheckTooLarge(usize),
///Cannot use on-chain data source as private data
#[error("cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm.")]
#[error(
"cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm."
)]
OnChainDataSource,
/// Missing data source
#[error("missing data source")]
@@ -143,4 +150,13 @@ pub enum GraphError {
/// Invalid RunArg
#[error("invalid RunArgs: {0}")]
InvalidRunArgs(String),
/// Only nearest neighbor interpolation is supported
#[error("only nearest neighbor interpolation is supported")]
InvalidInterpolation,
/// Node has a missing output
#[error("node {0} has a missing output")]
MissingOutput(usize),
/// Insufficient advice columns
#[error("insuficcient advice columns (need {0} at least)")]
InsufficientAdviceColumns(usize),
}

File diff suppressed because it is too large Load Diff

View File

@@ -6,9 +6,6 @@ pub mod model;
pub mod modules;
/// Inner elements of a computational graph that represent a single operation / constraints.
pub mod node;
/// postgres helper functions
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub mod postgres;
/// Helper functions
pub mod utilities;
/// Representations of a computational graph's variables.
@@ -28,9 +25,11 @@ use itertools::Itertools;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
use self::input::{FileSource, GraphData};
use self::errors::GraphError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use self::input::OnChainSource;
use self::input::{FileSource, GraphData};
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
use crate::circuit::lookup::LookupOp;
@@ -280,7 +279,13 @@ impl GraphWitness {
})?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::from_reader(reader).map_err(|e| e.into())
let witness: GraphWitness =
serde_json::from_reader(reader).map_err(Into::<GraphError>::into)?;
// check versions match
crate::check_version_string_matches(witness.version.as_deref().unwrap_or(""));
Ok(witness)
}
/// Save the model input to a file
@@ -449,6 +454,10 @@ pub struct GraphSettings {
pub num_blinding_factors: Option<usize>,
/// unix time timestamp
pub timestamp: Option<u128>,
/// Model inputs types (if any)
pub input_types: Option<Vec<InputType>>,
/// Model outputs types (if any)
pub output_types: Option<Vec<InputType>>,
}
impl GraphSettings {
@@ -531,16 +540,38 @@ impl GraphSettings {
/// calculate the total number of instances
pub fn total_instances(&self) -> Vec<usize> {
let mut instances: Vec<usize> = self
.model_instance_shapes
.iter()
.map(|x| x.iter().product())
.collect();
instances.extend(self.module_sizes.num_instances());
let mut instances: Vec<usize> = self.module_sizes.num_instances();
instances.extend(
self.model_instance_shapes
.iter()
.map(|x| x.iter().product::<usize>()),
);
instances
}
/// Get the scale data for the model's public instances.
///
/// Input scales come first (only when `run_args.input_visibility` is public),
/// followed by output scales (only when `run_args.output_visibility` is public).
///
/// # Returns
/// The collected scales in that order; empty when neither side is public.
pub fn get_model_instance_scales(&self) -> Vec<crate::Scale> {
    let mut scales: Vec<crate::Scale> = Vec::new();
    if self.run_args.input_visibility.is_public() {
        // Extend directly from the iterator; no temporary `Vec` is needed.
        scales.extend(self.model_input_scales.iter().cloned());
    }
    if self.run_args.output_visibility.is_public() {
        scales.extend(self.model_output_scales.iter().cloned());
    }
    scales
}
/// calculate the log2 of the total number of instances
pub fn log2_total_instances(&self) -> u32 {
let sum = self.total_instances().iter().sum::<usize>();
@@ -572,10 +603,14 @@ impl GraphSettings {
// buf reader
let reader =
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
serde_json::from_reader(reader).map_err(|e| {
let settings: GraphSettings = serde_json::from_reader(reader).map_err(|e| {
error!("failed to load settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
})
})?;
crate::check_version_string_matches(&settings.version);
Ok(settings)
}
/// Export the ezkl configuration as json
@@ -609,11 +644,6 @@ impl GraphSettings {
}
}
///
pub fn uses_modules(&self) -> bool {
!self.module_sizes.max_constraints() > 0
}
/// if any visibility is encrypted or hashed
pub fn module_requires_fixed(&self) -> bool {
self.run_args.input_visibility.is_hashed()
@@ -697,6 +727,9 @@ impl GraphCircuit {
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let result: GraphCircuit = bincode::deserialize_from(reader)?;
// check that the versions match
crate::check_version_string_matches(&result.core.settings.version);
Ok(result)
}
}
@@ -752,8 +785,8 @@ pub struct TestOnChainData {
/// The path to the test witness
pub data: std::path::PathBuf,
/// rpc endpoint
pub rpc: Option<String>,
///
pub rpc: String,
/// data sources for the on chain data
pub data_sources: TestSources,
}
@@ -919,128 +952,11 @@ impl GraphCircuit {
}
///
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
pub fn load_graph_input(&mut self, data: &GraphData) -> Result<Vec<Tensor<Fp>>, GraphError> {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
self.process_data_source(&data.input_data, shapes, scales, input_types)
}
///
pub fn load_graph_from_file_exclusively(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, GraphError> {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
debug!("input scales: {:?}", scales);
match &data.input_data {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
_ => unreachable!("cannot load from on-chain data"),
}
}
///
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub async fn load_graph_input(
&mut self,
data: &GraphData,
) -> Result<Vec<Tensor<Fp>>, GraphError> {
let shapes = self.model().graph.input_shapes()?;
let scales = self.model().graph.get_input_scales();
let input_types = self.model().graph.get_input_types()?;
debug!("input scales: {:?}", scales);
self.process_data_source(&data.input_data, shapes, scales, input_types)
.await
}
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
/// Process the data source for the model
fn process_data_source(
&mut self,
data: &DataSource,
shapes: Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
input_types: Vec<InputType>,
) -> Result<Vec<Tensor<Fp>>, GraphError> {
match &data {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
DataSource::OnChain(_) => Err(GraphError::OnChainDataSource),
}
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
/// Process the data source for the model
async fn process_data_source(
&mut self,
data: &DataSource,
shapes: Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
input_types: Vec<InputType>,
) -> Result<Vec<Tensor<Fp>>, GraphError> {
match &data {
DataSource::OnChain(source) => {
let mut per_item_scale = vec![];
for (i, shape) in shapes.iter().enumerate() {
per_item_scale.extend(vec![scales[i]; shape.iter().product::<usize>()]);
}
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
.await
}
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
DataSource::DB(pg) => {
let data = pg.fetch_and_format_as_file().await?;
self.load_file_data(&data, &shapes, scales, input_types)
}
}
}
/// Prepare on chain test data
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub async fn load_on_chain_data(
&mut self,
source: OnChainSource,
shapes: &Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
) -> Result<Vec<Tensor<Fp>>, GraphError> {
use crate::eth::{
evm_quantize_multi, evm_quantize_single, read_on_chain_inputs_multi,
read_on_chain_inputs_single, setup_eth_backend,
};
let (client, client_address) = setup_eth_backend(Some(&source.rpc), None).await?;
let quantized_evm_inputs = match source.calls {
input::Calls::Single(call) => {
let (inputs, decimals) =
read_on_chain_inputs_single(client.clone(), client_address, call).await?;
evm_quantize_single(client, scales, &inputs, decimals).await?
}
input::Calls::Multiple(calls) => {
let inputs =
read_on_chain_inputs_multi(client.clone(), client_address, &calls).await?;
evm_quantize_multi(client, scales, &inputs).await?
}
};
// on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector
let mut inputs: Vec<Tensor<Fp>> = vec![];
for (input, shape) in [quantized_evm_inputs].iter().zip(shapes) {
let mut t: Tensor<Fp> = input.iter().cloned().collect();
t.reshape(shape)?;
inputs.push(t);
}
Ok(inputs)
self.load_file_data(data.input_data.values(), &shapes, scales, input_types)
}
///
@@ -1418,75 +1334,6 @@ impl GraphCircuit {
let model = Model::from_run_args(&params.run_args, model_path)?;
Self::new_from_settings(model, params.clone(), check_mode)
}
///
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub async fn populate_on_chain_test_data(
&mut self,
data: &mut GraphData,
test_on_chain_data: TestOnChainData,
) -> Result<(), GraphError> {
// Set up local anvil instance for reading on-chain data
let input_scales = self.model().graph.get_input_scales();
let output_scales = self.model().graph.get_output_scales()?;
let input_shapes = self.model().graph.input_shapes()?;
let output_shapes = self.model().graph.output_shapes()?;
if matches!(
test_on_chain_data.data_sources.input,
TestDataSource::OnChain
) {
// if not public then fail
if self.settings().run_args.input_visibility.is_private() {
return Err(GraphError::OnChainDataSource);
}
let input_data = match &data.input_data {
DataSource::File(input_data) => input_data,
_ => {
return Err(GraphError::OnChainDataSource);
}
};
// Get the flatten length of input_data
// if the input source is a field then set scale to 0
let datam: (Vec<Tensor<Fp>>, OnChainSource) = OnChainSource::test_from_file_data(
input_data,
input_scales,
input_shapes,
test_on_chain_data.rpc.as_deref(),
)
.await?;
data.input_data = datam.1.into();
}
if matches!(
test_on_chain_data.data_sources.output,
TestDataSource::OnChain
) {
// if not public then fail
if self.settings().run_args.output_visibility.is_private() {
return Err(GraphError::OnChainDataSource);
}
let output_data = match &data.output_data {
Some(DataSource::File(output_data)) => output_data,
Some(DataSource::OnChain(_)) => return Err(GraphError::OnChainDataSource),
_ => return Err(GraphError::MissingDataSource),
};
let datum: (Vec<Tensor<Fp>>, OnChainSource) = OnChainSource::test_from_file_data(
output_data,
output_scales,
output_shapes,
test_on_chain_data.rpc.as_deref(),
)
.await?;
data.output_data = Some(datum.1.into());
}
// Save the updated GraphData struct to the data_path
data.save(test_on_chain_data.data)?;
Ok(())
}
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]

View File

@@ -1,7 +1,6 @@
use super::errors::GraphError;
use super::extract_const_quantized_values;
use super::node::*;
use super::scale_to_multiplier;
use super::vars::*;
use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
@@ -379,13 +378,18 @@ pub struct ParsedNodes {
pub nodes: BTreeMap<usize, NodeType>,
inputs: Vec<usize>,
outputs: Vec<Outlet>,
output_types: Vec<InputType>,
}
impl ParsedNodes {
/// Returns the output types of the computational graph.
///
/// # Returns
/// A clone of the stored `output_types` vector, so the caller receives an owned copy.
pub fn get_output_types(&self) -> Vec<InputType> {
self.output_types.clone()
}
/// Returns the number of the computational graph's inputs
pub fn num_inputs(&self) -> usize {
let input_nodes = self.inputs.iter();
input_nodes.len()
self.inputs.len()
}
/// Input types
@@ -425,8 +429,7 @@ impl ParsedNodes {
/// Returns the number of the computational graph's outputs
pub fn num_outputs(&self) -> usize {
let output_nodes = self.outputs.iter();
output_nodes.len()
self.outputs.len()
}
/// Returns shapes of the computational graph's outputs
@@ -493,6 +496,16 @@ impl Model {
Ok(om)
}
/// Gets the input types from the parsed nodes.
///
/// Thin wrapper that forwards to `self.graph.get_input_types()`.
///
/// # Errors
/// Propagates any `GraphError` returned by the underlying graph call.
pub fn get_input_types(&self) -> Result<Vec<InputType>, GraphError> {
self.graph.get_input_types()
}
/// Gets the output types from the parsed nodes.
///
/// Thin wrapper that forwards to `self.graph.get_output_types()`.
pub fn get_output_types(&self) -> Vec<InputType> {
self.graph.get_output_types()
}
///
pub fn save(&self, path: PathBuf) -> Result<(), GraphError> {
let f = std::fs::File::create(&path).map_err(|e| {
@@ -576,6 +589,11 @@ impl Model {
required_range_checks: res.range_checks.into_iter().collect(),
model_output_scales: self.graph.get_output_scales()?,
model_input_scales: self.graph.get_input_scales(),
input_types: match self.get_input_types() {
Ok(x) => Some(x),
Err(_) => None,
},
output_types: Some(self.get_output_types()),
num_dynamic_lookups: res.num_dynamic_lookups,
total_dynamic_col_size: res.dynamic_lookup_col_coord,
num_shuffles: res.num_shuffles,
@@ -621,19 +639,23 @@ impl Model {
/// * `scale` - The scale to use for quantization.
/// * `public_params` - Whether to make the params public.
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn load_onnx_using_tract(
pub(crate) fn load_onnx_using_tract(
reader: &mut dyn std::io::Read,
run_args: &RunArgs,
variables: &[(String, usize)],
) -> Result<TractResult, GraphError> {
use tract_onnx::tract_hir::internal::GenericFactoid;
let mut model = tract_onnx::onnx().model_for_read(reader)?;
let variables: std::collections::HashMap<String, usize> =
std::collections::HashMap::from_iter(run_args.variables.clone());
std::collections::HashMap::from_iter(variables.iter().map(|(k, v)| (k.clone(), *v)));
for (i, id) in model.clone().inputs.iter().enumerate() {
let input = model.node_mut(id.node);
if input.outputs.len() == 0 {
return Err(GraphError::MissingOutput(id.node));
}
let mut fact: InferenceFact = input.outputs[0].fact.clone();
for (i, x) in fact.clone().shape.dims().enumerate() {
@@ -655,7 +677,7 @@ impl Model {
}
let mut symbol_values = SymbolValues::default();
for (symbol, value) in run_args.variables.iter() {
for (symbol, value) in variables.iter() {
let symbol = model.symbols.sym(symbol);
symbol_values = symbol_values.with(&symbol, *value as i64);
debug!("set {} to {}", symbol, value);
@@ -683,7 +705,7 @@ impl Model {
) -> Result<ParsedNodes, GraphError> {
let start_time = instant::Instant::now();
let (model, symbol_values) = Self::load_onnx_using_tract(reader, run_args)?;
let (model, symbol_values) = Self::load_onnx_using_tract(reader, &run_args.variables)?;
let scales = VarScales::from_args(run_args);
let nodes = Self::nodes_from_graph(
@@ -702,6 +724,11 @@ impl Model {
nodes,
inputs: model.inputs.iter().map(|o| o.node).collect(),
outputs: model.outputs.iter().map(|o| (o.node, o.slot)).collect(),
output_types: model
.outputs
.iter()
.map(|o| Ok::<InputType, GraphError>(model.outlet_fact(*o)?.datum_type.into()))
.collect::<Result<Vec<_>, GraphError>>()?,
};
let duration = start_time.elapsed();
@@ -860,6 +887,15 @@ impl Model {
nodes: subgraph_nodes,
inputs: model.inputs.iter().map(|o| o.node).collect(),
outputs: model.outputs.iter().map(|o| (o.node, o.slot)).collect(),
output_types: model
.outputs
.iter()
.map(|o| {
Ok::<InputType, GraphError>(
model.outlet_fact(*o)?.datum_type.into(),
)
})
.collect::<Result<Vec<_>, GraphError>>()?,
};
let om = Model {
@@ -906,6 +942,7 @@ impl Model {
n.opkind = SupportedOp::Input(Input {
scale,
datum_type: inp.datum_type,
decomp: !run_args.ignore_range_check_inputs_outputs,
});
input_idx += 1;
n.out_scale = scale;
@@ -964,7 +1001,7 @@ impl Model {
GraphError::ReadWriteFileError(model_path.display().to_string(), e.to_string())
})?;
let (model, _) = Model::load_onnx_using_tract(&mut file, run_args)?;
let (model, _) = Model::load_onnx_using_tract(&mut file, &run_args.variables)?;
let datum_types: Vec<DatumType> = model
.input_outlets()?
@@ -1016,6 +1053,10 @@ impl Model {
let required_lookups = settings.required_lookups.clone();
let required_range_checks = settings.required_range_checks.clone();
if vars.advices.len() < 3 {
return Err(GraphError::InsufficientAdviceColumns(3));
}
let mut base_gate = PolyConfig::configure(
meta,
vars.advices[0..2].try_into()?,
@@ -1035,6 +1076,10 @@ impl Model {
}
if settings.requires_dynamic_lookup() {
if vars.advices.len() < 6 {
return Err(GraphError::InsufficientAdviceColumns(6));
}
base_gate.configure_dynamic_lookup(
meta,
vars.advices[0..3].try_into()?,
@@ -1043,10 +1088,13 @@ impl Model {
}
if settings.requires_shuffle() {
if vars.advices.len() < 6 {
return Err(GraphError::InsufficientAdviceColumns(6));
}
base_gate.configure_shuffles(
meta,
vars.advices[0..2].try_into()?,
vars.advices[3..5].try_into()?,
vars.advices[0..3].try_into()?,
vars.advices[3..6].try_into()?,
)?;
}
@@ -1061,6 +1109,7 @@ impl Model {
/// * `vars` - The variables for the circuit.
/// * `witnessed_outputs` - The values to compare against.
/// * `constants` - The constants for the circuit.
#[allow(clippy::too_many_arguments)]
pub fn layout(
&self,
mut config: ModelConfig,
@@ -1123,17 +1172,10 @@ impl Model {
})?;
if run_args.output_visibility.is_public() || run_args.output_visibility.is_fixed() {
let output_scales = self.graph.get_output_scales().map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
let res = outputs
.iter()
.enumerate()
.map(|(i, output)| {
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
let comparators = if run_args.output_visibility == Visibility::Public {
let res = vars
.instance
@@ -1155,7 +1197,9 @@ impl Model {
.layout(
&mut thread_safe_region,
&[output.clone(), comparators],
Box::new(HybridOp::RangeCheck(tolerance)),
Box::new(HybridOp::Output {
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
)
.map_err(|e| e.into())
})
@@ -1226,6 +1270,7 @@ impl Model {
values.iter().map(|v| v.dims()).collect_vec()
);
let start = instant::Instant::now();
match &node {
NodeType::Node(n) => {
let res = if node.is_constant() && node.num_uses() == 1 {
@@ -1363,6 +1408,7 @@ impl Model {
results.insert(*idx, full_results);
}
}
debug!("------------ layout of {} took {:?}", idx, start.elapsed());
}
// we do this so we can support multiple passes of the same model and have deterministic results (Non-assigned inputs etc... etc...)
@@ -1413,11 +1459,9 @@ impl Model {
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
if self.visibility.output.is_public() || self.visibility.output.is_fixed() {
let output_scales = self.graph.get_output_scales()?;
let res = outputs
.iter()
.enumerate()
.map(|(i, output)| {
.map(|output| {
let mut comparator: ValTensor<Fp> = (0..output.len())
.map(|_| {
if !self.visibility.output.is_fixed() {
@@ -1430,13 +1474,12 @@ impl Model {
.into();
comparator.reshape(output.dims())?;
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
dummy_config.layout(
&mut region,
&[output.clone(), comparator],
Box::new(HybridOp::RangeCheck(tolerance)),
Box::new(HybridOp::Output {
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
)
})
.collect::<Result<Vec<_>, _>>();
@@ -1458,7 +1501,7 @@ impl Model {
.iter()
.map(|x| {
x.get_felt_evals()
.unwrap_or(Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
.unwrap_or_else(|_| Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
})
.collect();
@@ -1528,6 +1571,7 @@ impl Model {
let mut op = crate::circuit::Constant::new(
c.quantized_values.clone(),
c.raw_values.clone(),
c.decomp,
);
op.pre_assign(consts[const_idx].clone());
n.opkind = SupportedOp::Constant(op);
@@ -1555,4 +1599,16 @@ impl Model {
}
Ok(instance_shapes)
}
/// Types of the computational graph's public instances (if any).
///
/// Public input types are listed first, followed by public output types.
/// A side that is not public contributes nothing.
///
/// # Errors
/// Propagates the `GraphError` from `self.graph.get_input_types()`.
pub fn instance_types(&self) -> Result<Vec<InputType>, GraphError> {
    // Start from the public input types, or from nothing when inputs are not public.
    let mut types = if self.visibility.input.is_public() {
        self.graph.get_input_types()?
    } else {
        Vec::new()
    };
    if self.visibility.output.is_public() {
        types.extend(self.graph.get_output_types());
    }
    Ok(types)
}
}

View File

@@ -14,14 +14,11 @@ use serde::{Deserialize, Serialize};
use super::errors::GraphError;
use super::{VarVisibility, Visibility};
/// poseidon len to hash in tree
pub const POSEIDON_LEN_GRAPH: usize = 32;
/// Poseidon number of instances
pub const POSEIDON_INSTANCES: usize = 1;
/// Poseidon module type
pub type ModulePoseidon =
PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, POSEIDON_LEN_GRAPH>;
pub type ModulePoseidon = PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE>;
/// Poseidon module config
pub type ModulePoseidonConfig = PoseidonConfig<POSEIDON_WIDTH, POSEIDON_RATE>;
@@ -284,7 +281,6 @@ impl GraphModules {
log::error!("Poseidon config not initialized");
return Err(Error::Synthesis);
}
// If the module is encrypted, then we need to encrypt the inputs
}
Ok(())

View File

@@ -1,10 +1,19 @@
// Import dependencies for scaling operations
use super::scale_to_multiplier;
// Import ONNX-specific utilities when EZKL feature is enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::utilities::node_output_shapes;
// Import scale management types for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::VarScales;
// Import visibility settings for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::Visibility;
// Import operation types for different circuit components
use crate::circuit::hybrid::HybridOp;
use crate::circuit::lookup::LookupOp;
use crate::circuit::poly::PolyOp;
@@ -13,28 +22,49 @@ use crate::circuit::Constant;
use crate::circuit::Input;
use crate::circuit::Op;
use crate::circuit::Unknown;
// Import graph error types for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::errors::GraphError;
// Import ONNX operation conversion utilities
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::new_op_from_onnx;
// Import tensor error handling
use crate::tensor::TensorError;
// Import curve-specific field type
use halo2curves::bn256::Fr as Fp;
// Import logging for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use log::trace;
// Import serialization traits
use serde::Deserialize;
use serde::Serialize;
// Import data structures for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::collections::BTreeMap;
// Import formatting traits for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use std::fmt;
// Import table display formatting for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tabled::Tabled;
// Import ONNX-specific types and traits for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::{
self,
prelude::{Node as OnnxNode, SymbolValues, TypedFact, TypedOp},
};
/// Helper function to format vectors for display
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
if !v.is_empty() {
@@ -44,29 +74,35 @@ fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
}
}
/// Helper function to format operation kinds for display
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn display_opkind(v: &SupportedOp) -> String {
v.as_string()
}
/// A wrapper for an operation that has been rescaled to handle different precision requirements.
/// This enables operations to work with inputs that have been scaled to different fixed-point representations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Rescaled {
/// The underlying operation that needs to be rescaled
pub inner: Box<SupportedOp>,
/// Vector of (index, scale) pairs defining how each input should be scaled
pub scale: Vec<(usize, u128)>,
}
/// Implementation of the Op trait for Rescaled operations
impl Op<Fp> for Rescaled {
/// Convert to Any type for runtime type checking
fn as_any(&self) -> &dyn std::any::Any {
self
}
/// Get string representation of the operation
fn as_string(&self) -> String {
format!("RESCALED INPUT ({})", self.inner.as_string())
}
/// Calculate output scale based on input scales
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let in_scales = in_scales
.into_iter()
@@ -77,6 +113,7 @@ impl Op<Fp> for Rescaled {
Op::<Fp>::out_scale(&*self.inner, in_scales)
}
/// Layout the operation in the circuit
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
@@ -93,28 +130,40 @@ impl Op<Fp> for Rescaled {
self.inner.layout(config, region, res)
}
/// Create a cloned boxed copy of this operation
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
Box::new(self.clone()) // Forward to the derive(Clone) impl
Box::new(self.clone())
}
}
/// A wrapper for operations that require scale rebasing.
/// This handles cases where operation scales need to be adjusted to a target scale
/// while preserving the numerical relationships.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RebaseScale {
/// The operation that needs to be rescaled
pub inner: Box<SupportedOp>,
/// Operation used for rebasing, typically division
pub rebase_op: HybridOp,
/// Scale that we're rebasing to
pub target_scale: i32,
/// Original scale of operation's inputs before rebasing
pub original_scale: i32,
/// Scaling multiplier used in rebasing
pub multiplier: f64,
}
impl RebaseScale {
/// Creates a rebased version of an operation if needed
///
/// # Arguments
/// * `inner` - Operation to potentially rebase
/// * `global_scale` - Base scale for the system
/// * `op_out_scale` - Current output scale of the operation
/// * `scale_rebase_multiplier` - Factor determining when rebasing should occur
///
/// # Returns
/// Original or rebased operation depending on scale relationships
pub fn rebase(
inner: SupportedOp,
global_scale: crate::Scale,
@@ -155,7 +204,15 @@ impl RebaseScale {
}
}
/// Creates a rebased operation with increased scale
///
/// # Arguments
/// * `inner` - Operation to potentially rebase
/// * `target_scale` - Scale to rebase to
/// * `op_out_scale` - Current output scale of the operation
///
/// # Returns
/// Original or rebased operation with increased scale
pub fn rebase_up(
inner: SupportedOp,
target_scale: crate::Scale,
@@ -192,10 +249,12 @@ impl RebaseScale {
}
impl Op<Fp> for RebaseScale {
/// Convert to Any type for runtime type checking
fn as_any(&self) -> &dyn std::any::Any {
self
}
/// Get string representation of the operation
fn as_string(&self) -> String {
format!(
"REBASED (div={:?}, rebasing_op={}) ({})",
@@ -205,10 +264,12 @@ impl Op<Fp> for RebaseScale {
)
}
/// Calculate output scale based on input scales
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(self.target_scale)
}
/// Layout the operation in the circuit
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
@@ -222,34 +283,40 @@ impl Op<Fp> for RebaseScale {
self.rebase_op.layout(config, region, &[original_res])
}
/// Create a cloned boxed copy of this operation
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
Box::new(self.clone()) // Forward to the derive(Clone) impl
Box::new(self.clone())
}
}
/// A single operation in a [crate::graph::Model].
/// Represents all supported operation types in the circuit;
/// each variant encapsulates a different type of operation with specific behavior.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SupportedOp {
/// Linear operations (polynomial-based)
Linear(PolyOp),
/// Nonlinear operations requiring lookup tables
Nonlinear(LookupOp),
/// Mixed operations combining different approaches
Hybrid(HybridOp),
/// Input values to the circuit
Input(Input),
/// Constant values in the circuit
Constant(Constant<Fp>),
/// Placeholder for unsupported operations
Unknown(Unknown),
/// Operations requiring rescaling of inputs
Rescaled(Rescaled),
/// Operations requiring scale rebasing
RebaseScale(RebaseScale),
}
impl SupportedOp {
/// Checks if the operation is a lookup operation
///
/// # Returns
/// * `true` if operation requires lookup table
/// * `false` otherwise
pub fn is_lookup(&self) -> bool {
match self {
SupportedOp::Nonlinear(_) => true,
@@ -257,7 +324,12 @@ impl SupportedOp {
_ => false,
}
}
/// Returns input operation if this is an input
///
/// # Returns
/// * `Some(Input)` if this is an input operation
/// * `None` otherwise
pub fn get_input(&self) -> Option<Input> {
match self {
SupportedOp::Input(op) => Some(op.clone()),
@@ -265,7 +337,11 @@ impl SupportedOp {
}
}
/// Returns reference to rebased operation if this is a rebased operation
///
/// # Returns
/// * `Some(&RebaseScale)` if this is a rebased operation
/// * `None` otherwise
pub fn get_rebased(&self) -> Option<&RebaseScale> {
match self {
SupportedOp::RebaseScale(op) => Some(op),
@@ -273,7 +349,11 @@ impl SupportedOp {
}
}
/// Returns reference to lookup operation if this is a lookup operation
///
/// # Returns
/// * `Some(&LookupOp)` if this is a lookup operation
/// * `None` otherwise
pub fn get_lookup(&self) -> Option<&LookupOp> {
match self {
SupportedOp::Nonlinear(op) => Some(op),
@@ -281,7 +361,11 @@ impl SupportedOp {
}
}
/// Returns reference to constant if this is a constant
///
/// # Returns
/// * `Some(&Constant)` if this is a constant
/// * `None` otherwise
pub fn get_constant(&self) -> Option<&Constant<Fp>> {
match self {
SupportedOp::Constant(op) => Some(op),
@@ -289,7 +373,11 @@ impl SupportedOp {
}
}
/// Returns mutable reference to constant if this is a constant
///
/// # Returns
/// * `Some(&mut Constant)` if this is a constant
/// * `None` otherwise
pub fn get_mutable_constant(&mut self) -> Option<&mut Constant<Fp>> {
match self {
SupportedOp::Constant(op) => Some(op),
@@ -297,18 +385,19 @@ impl SupportedOp {
}
}
/// Produces a homogeneously rescaled version of this operation when its inputs need it.
/// Only compiled with the `ezkl` feature on non-wasm32 targets.
///
/// # Arguments
/// * `in_scales` - Scales of the operation's inputs
///
/// # Errors
/// Propagates the `GraphError` reported by `homogenize_input_scales`.
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn homogenous_rescale(
    &self,
    in_scales: Vec<crate::Scale>,
) -> Result<Box<dyn Op<Fp>>, GraphError> {
    // Indices of the inputs that must share a scale; a rescaled op is created
    // if the inputs are not homogenous.
    let homogenous_inputs = self.requires_homogenous_input_scales();
    super::homogenize_input_scales(self.clone_dyn(), in_scales, homogenous_inputs)
}
/// Since each associated value of `SupportedOp` implements `Op`, let's define a helper method to retrieve it.
/// Returns reference to underlying Op implementation
fn as_op(&self) -> &dyn Op<Fp> {
match self {
SupportedOp::Linear(op) => op,
@@ -322,9 +411,10 @@ impl SupportedOp {
}
}
/// check if is the identity operation
/// Checks if this is an identity operation
///
/// # Returns
/// * `true` if the operation is the identity operation
/// * `true` if this operation passes input through unchanged
/// * `false` otherwise
pub fn is_identity(&self) -> bool {
match self {
@@ -361,9 +451,11 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
if let Some(op) = value.as_any().downcast_ref::<Unknown>() {
return SupportedOp::Unknown(op.clone());
};
if let Some(op) = value.as_any().downcast_ref::<Rescaled>() {
return SupportedOp::Rescaled(op.clone());
};
if let Some(op) = value.as_any().downcast_ref::<RebaseScale>() {
return SupportedOp::RebaseScale(op.clone());
};
@@ -375,6 +467,7 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
}
impl Op<Fp> for SupportedOp {
/// Layout this operation in the circuit
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
@@ -384,54 +477,61 @@ impl Op<Fp> for SupportedOp {
self.as_op().layout(config, region, values)
}
/// Reports whether the wrapped operation is an input operation.
fn is_input(&self) -> bool {
    let inner = self.as_op();
    inner.is_input()
}
/// Reports whether the wrapped operation is a constant operation.
fn is_constant(&self) -> bool {
    let inner = self.as_op();
    inner.is_constant()
}
/// Returns the indices of the inputs that must share a common scale,
/// as reported by the wrapped operation.
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
    let inner = self.as_op();
    inner.requires_homogenous_input_scales()
}
/// Produces a boxed clone of the wrapped operation.
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
    let inner = self.as_op();
    inner.clone_dyn()
}
/// Returns the string representation of the wrapped operation.
fn as_string(&self) -> String {
    let inner = self.as_op();
    inner.as_string()
}
/// Exposes this value as `&dyn Any`.
///
/// Unlike the other methods of this impl, this returns the `SupportedOp` itself rather than
/// delegating to the wrapped operation.
fn as_any(&self) -> &dyn std::any::Any {
self
}
/// Computes the output scale of the wrapped operation for the given input scales.
///
/// # Errors
/// Propagates any `CircuitError` returned by the wrapped operation.
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
    let inner = self.as_op();
    inner.out_scale(in_scales)
}
}
/// A node's input is a tensor from another node's output.
///
/// Represents a connection to another node's output: the first element is the node index,
/// the second is the output slot index.
pub type Outlet = (usize, usize);
/// A single operation in a [crate::graph::Model].
///
/// Represents a single computational node in the circuit graph and contains all information
/// needed to execute and connect operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Node {
/// [Op] i.e. what operation this node performs.
pub opkind: SupportedOp,
/// Fixed point scale factor (the denominator in the fixed point representation) for this
/// node's output. Tensors of differing scales should not be combined.
pub out_scale: i32,
// Usually there is a simple in and out shape of the node as an operator. For example, an Affine node has three input_shapes (one for the input, weight, and bias),
// but in_dim is [in], out_dim is [out]
/// Connections to other nodes' outputs that serve as this node's inputs.
pub inputs: Vec<Outlet>,
/// Shape (dimensions) of this node's output tensor.
pub out_dims: Vec<usize>,
/// Unique identifier for this node.
pub idx: usize,
/// Number of times this node's output is used.
pub num_uses: usize,
}
@@ -469,12 +569,19 @@ impl PartialEq for Node {
}
impl Node {
/// Converts a tract [OnnxNode] into an ezkl [Node].
/// # Arguments:
/// * `node` - [OnnxNode]
/// * `other_nodes` - [BTreeMap] of other previously initialized [Node]s in the computational graph.
/// * `public_params` - flag if parameters of model are public
/// * `idx` - The node's unique identifier.
/// Creates a new Node from an ONNX node
/// Only available when EZKL feature is enabled
///
/// # Arguments
/// * `node` - Source ONNX node
/// * `other_nodes` - Map of existing nodes in the graph
/// * `scales` - Scale factors for variables
/// * `idx` - Unique identifier for this node
/// * `symbol_values` - ONNX symbol values
/// * `run_args` - Runtime configuration arguments
///
/// # Returns
/// New Node instance or error if creation fails
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[allow(clippy::too_many_arguments)]
pub fn new(
@@ -612,16 +719,14 @@ impl Node {
})
}
/// check if it is a softmax node
/// Check if this node performs softmax operation
pub fn is_softmax(&self) -> bool {
if let SupportedOp::Hybrid(HybridOp::Softmax { .. }) = self.opkind {
true
} else {
false
}
matches!(self.opkind, SupportedOp::Hybrid(HybridOp::Softmax { .. }))
}
}
/// Helper function to rescale constants that are only used once
/// Only available when EZKL feature is enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn rescale_const_with_single_use(
constant: &mut Constant<Fp>,

View File

@@ -1,493 +0,0 @@
use log::{debug, error, info};
use std::fmt::Debug;
use std::net::IpAddr;
#[cfg(all(not(not(feature = "ezkl")), unix))]
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, pin::Pin};
use tokio::task::JoinHandle;
#[doc(inline)]
pub use tokio_postgres::config::{
ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs,
};
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::NoTls;
use tokio_postgres::{error::DbError, types::ToSql, Error, Row, Socket, ToStatement};
/// Connection configuration.
///
/// Configuration can be parsed from libpq-style connection strings through the [`FromStr`]
/// implementation, which delegates to the parser of `tokio_postgres::Config`. See the
/// `tokio_postgres::Config` documentation for the accepted string formats.
#[derive(Clone)]
pub struct Config {
// Wrapped driver configuration; every setter and getter delegates to it.
config: tokio_postgres::Config,
// Callback for server notices; replaced through `Config::notice_callback`.
notice_callback: Arc<dyn Fn(DbError) + Send + Sync>,
}
impl fmt::Debug for Config {
    /// Formats the wrapped driver configuration only; the notice callback is left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Config");
        builder.field("config", &self.config);
        builder.finish()
    }
}
impl Default for Config {
fn default() -> Config {
Config::new()
}
}
/// Builder-style accessors that delegate to the wrapped `tokio_postgres::Config`.
///
/// Every setter returns `&mut Config` so calls can be chained.
impl Config {
    /// Creates a new configuration.
    pub fn new() -> Config {
        tokio_postgres::Config::new().into()
    }

    /// Sets the user to authenticate with.
    ///
    /// If the user is not set, then this defaults to the user executing this process.
    pub fn user(&mut self, user: &str) -> &mut Config {
        self.config.user(user);
        self
    }

    /// Gets the user to authenticate with, if one has been configured with
    /// the `user` method.
    pub fn get_user(&self) -> Option<&str> {
        self.config.get_user()
    }

    /// Sets the password to authenticate with.
    pub fn password<T>(&mut self, password: T) -> &mut Config
    where
        T: AsRef<[u8]>,
    {
        self.config.password(password);
        self
    }

    /// Gets the password to authenticate with, if one has been configured with
    /// the `password` method.
    pub fn get_password(&self) -> Option<&[u8]> {
        self.config.get_password()
    }

    /// Sets the name of the database to connect to.
    ///
    /// Defaults to the user.
    pub fn dbname(&mut self, dbname: &str) -> &mut Config {
        self.config.dbname(dbname);
        self
    }

    /// Gets the name of the database to connect to, if one has been configured
    /// with the `dbname` method.
    pub fn get_dbname(&self) -> Option<&str> {
        self.config.get_dbname()
    }

    /// Sets command line options used to configure the server.
    pub fn options(&mut self, options: &str) -> &mut Config {
        self.config.options(options);
        self
    }

    /// Gets the command line options used to configure the server, if the
    /// options have been set with the `options` method.
    pub fn get_options(&self) -> Option<&str> {
        self.config.get_options()
    }

    /// Sets the value of the `application_name` runtime parameter.
    pub fn application_name(&mut self, application_name: &str) -> &mut Config {
        self.config.application_name(application_name);
        self
    }

    /// Gets the value of the `application_name` runtime parameter, if it has
    /// been set with the `application_name` method.
    pub fn get_application_name(&self) -> Option<&str> {
        self.config.get_application_name()
    }

    /// Sets the SSL configuration.
    ///
    /// Defaults to `prefer`.
    pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
        self.config.ssl_mode(ssl_mode);
        self
    }

    /// Gets the SSL configuration.
    pub fn get_ssl_mode(&self) -> SslMode {
        self.config.get_ssl_mode()
    }

    /// Adds a host to the configuration.
    ///
    /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
    /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets.
    /// There must be either no hosts, or the same number of hosts as hostaddrs.
    pub fn host(&mut self, host: &str) -> &mut Config {
        self.config.host(host);
        self
    }

    /// Gets the hosts that have been added to the configuration with `host`.
    pub fn get_hosts(&self) -> &[Host] {
        self.config.get_hosts()
    }

    /// Gets the hostaddrs that have been added to the configuration with `hostaddr`.
    pub fn get_hostaddrs(&self) -> &[IpAddr] {
        self.config.get_hostaddrs()
    }

    /// Adds a Unix socket host to the configuration.
    ///
    /// Unlike `host`, this method allows non-UTF8 paths.
    // `not(not(feature = "ezkl"))` is just `feature = "ezkl"`; the predicate is spelled directly.
    #[cfg(all(feature = "ezkl", unix))]
    pub fn host_path<T>(&mut self, host: T) -> &mut Config
    where
        T: AsRef<Path>,
    {
        self.config.host_path(host);
        self
    }

    /// Adds a hostaddr to the configuration.
    ///
    /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order.
    /// There must be either no hostaddrs, or the same number of hostaddrs as hosts.
    pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
        self.config.hostaddr(hostaddr);
        self
    }

    /// Adds a port to the configuration.
    ///
    /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
    /// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports
    /// as hosts.
    pub fn port(&mut self, port: u16) -> &mut Config {
        self.config.port(port);
        self
    }

    /// Gets the ports that have been added to the configuration with `port`.
    pub fn get_ports(&self) -> &[u16] {
        self.config.get_ports()
    }

    /// Sets the timeout applied to socket-level connection attempts.
    ///
    /// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
    /// host separately. Defaults to no limit.
    pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
        self.config.connect_timeout(connect_timeout);
        self
    }

    /// Gets the connection timeout, if one has been set with the
    /// `connect_timeout` method.
    pub fn get_connect_timeout(&self) -> Option<&Duration> {
        self.config.get_connect_timeout()
    }

    /// Sets the TCP user timeout.
    ///
    /// This is ignored for Unix domain socket connections. It is only supported on systems where
    /// TCP_USER_TIMEOUT is available and will default to the system default if omitted or set to 0;
    /// on other systems, it has no effect.
    pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
        self.config.tcp_user_timeout(tcp_user_timeout);
        self
    }

    /// Gets the TCP user timeout, if one has been set with the
    /// `tcp_user_timeout` method.
    pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
        self.config.get_tcp_user_timeout()
    }

    /// Controls the use of TCP keepalive.
    ///
    /// This is ignored for Unix domain socket connections. Defaults to `true`.
    pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
        self.config.keepalives(keepalives);
        self
    }

    /// Reports whether TCP keepalives will be used.
    pub fn get_keepalives(&self) -> bool {
        self.config.get_keepalives()
    }

    /// Sets the amount of idle time before a keepalive packet is sent on the connection.
    ///
    /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours.
    pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
        self.config.keepalives_idle(keepalives_idle);
        self
    }

    /// Gets the configured amount of idle time before a keepalive packet will
    /// be sent on the connection.
    pub fn get_keepalives_idle(&self) -> Duration {
        self.config.get_keepalives_idle()
    }

    /// Sets the time interval between TCP keepalive probes.
    /// On Windows, this sets the value of the tcp_keepalive struct's keepaliveinterval field.
    ///
    /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
    pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
        self.config.keepalives_interval(keepalives_interval);
        self
    }

    /// Gets the time interval between TCP keepalive probes.
    pub fn get_keepalives_interval(&self) -> Option<Duration> {
        self.config.get_keepalives_interval()
    }

    /// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
    ///
    /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
    pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
        self.config.keepalives_retries(keepalives_retries);
        self
    }

    /// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
    pub fn get_keepalives_retries(&self) -> Option<u32> {
        self.config.get_keepalives_retries()
    }

    /// Sets the requirements of the session.
    ///
    /// This can be used to connect to the primary server in a clustered database rather than one of the read-only
    /// secondary servers. Defaults to `Any`.
    pub fn target_session_attrs(
        &mut self,
        target_session_attrs: TargetSessionAttrs,
    ) -> &mut Config {
        self.config.target_session_attrs(target_session_attrs);
        self
    }

    /// Gets the requirements of the session.
    pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
        self.config.get_target_session_attrs()
    }

    /// Sets the channel binding behavior.
    ///
    /// Defaults to `prefer`.
    pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
        self.config.channel_binding(channel_binding);
        self
    }

    /// Gets the channel binding behavior.
    pub fn get_channel_binding(&self) -> ChannelBinding {
        self.config.get_channel_binding()
    }

    /// Sets the host load balancing behavior.
    ///
    /// Defaults to `disable`.
    pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
        self.config.load_balance_hosts(load_balance_hosts);
        self
    }

    /// Gets the host load balancing behavior.
    pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
        self.config.get_load_balance_hosts()
    }

    /// Sets the notice callback.
    ///
    /// The callback is intended to receive the contents of every
    /// [`AsyncMessage::Notice`] that arrives on the connection. Notices use
    /// the same structure as errors, but they are not "errors" per-se.
    ///
    /// NOTE(review): in this impl, `connect` only stores the callback on the `Config`; it is
    /// not handed to the spawned connection, so it does not appear to be invoked. Confirm
    /// before relying on it.
    ///
    /// [`AsyncMessage::Notice`]: tokio_postgres::AsyncMessage::Notice
    pub fn notice_callback<F>(&mut self, f: F) -> &mut Config
    where
        F: Fn(DbError) + Send + Sync + 'static,
    {
        self.notice_callback = Arc::new(f);
        self
    }

    /// Opens a connection to a PostgreSQL database.
    ///
    /// The connection is established without TLS (`NoTls`), and the returned [`Client`] takes
    /// ownership of the connection object.
    ///
    /// # Errors
    /// Returns the `tokio_postgres::Error` produced while connecting.
    pub async fn connect(&self) -> Result<Client, Error> {
        let (client, connection) = self.config.connect(NoTls).await?;
        let connection = Connection::new(connection);
        Ok(Client::new(client, connection))
    }
}
impl FromStr for Config {
    type Err = Error;

    /// Parses a connection string with the `tokio_postgres::Config` parser and wraps the result.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let inner: tokio_postgres::Config = s.parse()?;
        Ok(Self::from(inner))
    }
}
impl From<tokio_postgres::Config> for Config {
fn from(config: tokio_postgres::Config) -> Config {
Config {
config,
notice_callback: Arc::new(|notice| {
info!("{}: {}", notice.severity(), notice.message())
}),
}
}
}
#[allow(missing_debug_implementations, dead_code)]
/// An asynchronous PostgreSQL connection.
///
/// Holds the underlying `tokio_postgres::Connection` in a pinned box so that it is kept alive
/// and is not dropped until it is driven by [`Connection::start`].
pub struct Connection {
/// The underlying connection stream (no TLS), pinned on the heap.
connection: Pin<Box<tokio_postgres::Connection<Socket, NoTlsStream>>>,
}
impl Connection {
/// Creates a new connection.
pub fn new(connection: tokio_postgres::Connection<Socket, NoTlsStream>) -> Self {
Connection {
connection: Box::pin(connection),
}
}
/// start the connection
pub async fn start(self) {
if let Err(e) = self.connection.await {
error!("connection error: {}", e);
}
}
}
#[allow(missing_debug_implementations, dead_code)]
/// An asynchronous PostgreSQL client.
///
/// Pairs the `tokio_postgres` client handle with the spawned task that drives its connection.
pub struct Client {
// Handle of the Tokio task spawned in `Client::new` to run the connection.
connection: JoinHandle<()>,
// Driver client that statements are executed through.
client: tokio_postgres::Client,
}
impl Drop for Client {
    /// Closes the client when it goes out of scope, discarding the result of the close.
    fn drop(&mut self) {
        drop(self.close_inner());
    }
}
impl Client {
    /// Wraps a driver client and spawns its connection onto the Tokio runtime.
    ///
    /// # Panics
    /// `tokio::spawn` panics when called outside of a Tokio runtime.
    pub(crate) fn new(client: tokio_postgres::Client, connection: Connection) -> Client {
        // The connection object performs the actual communication with the database,
        // so spawn it off to run on its own.
        let thread = tokio::spawn(async move {
            connection.start().await;
        });
        Client {
            client,
            connection: thread,
        }
    }

    /// A convenience function which parses a configuration string into a `Config` and then connects to the database.
    ///
    /// See the documentation for [`Config`] for information about the connection syntax.
    ///
    /// # Errors
    /// Returns an error if `params` cannot be parsed or the connection attempt fails.
    pub async fn connect(params: &str) -> Result<Client, Error> {
        let config = params.parse::<Config>()?;
        // Never log the raw connection string: it may embed the password. Only
        // non-secret, parsed fields are reported.
        debug!(
            "Connecting to database (hosts: {:?}, dbname: {:?})",
            config.get_hosts(),
            config.get_dbname()
        );
        config.connect().await
    }

    /// Returns a new `Config` object which can be used to configure and connect to a database.
    pub fn configure() -> Config {
        Config::new()
    }

    /// Executes a statement, returning the number of rows modified.
    ///
    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
    /// provided, 1-indexed.
    ///
    /// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
    ///
    /// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
    /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front.
    ///
    /// The statement (not its parameters) is logged at `debug` level before execution.
    ///
    /// # Errors
    /// Returns the `tokio_postgres::Error` produced by the underlying client.
    pub async fn execute<T>(
        &mut self,
        query: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<u64, Error>
    where
        T: ?Sized + ToStatement + Debug,
    {
        debug!("Executing query: {:?}", query);
        self.client.execute(query, params).await
    }

    /// Executes a statement, returning the resulting rows.
    ///
    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
    /// provided, 1-indexed.
    ///
    /// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
    /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front.
    ///
    /// The statement (not its parameters) is logged at `debug` level before execution.
    ///
    /// # Errors
    /// Returns the `tokio_postgres::Error` produced by the underlying client.
    pub async fn query<T>(
        &mut self,
        query: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<Vec<Row>, Error>
    where
        T: ?Sized + ToStatement + Debug,
    {
        debug!("Executing query: {:?}", query);
        self.client.query(query, params).await
    }

    /// Determines if the client's connection has already closed.
    ///
    /// If this returns `true`, the client is no longer usable.
    pub fn is_closed(&self) -> bool {
        self.client.is_closed()
    }

    /// Closes the client's connection to the server.
    ///
    /// This is equivalent to `Client`'s `Drop` implementation, except that it returns any error encountered to the
    /// caller.
    pub fn close(mut self) -> Result<(), Error> {
        self.close_inner()
    }

    /// Asks the underlying client to close; shared by `close` and `Drop`. Currently always `Ok(())`.
    fn close_inner(&mut self) -> Result<(), Error> {
        self.client.__private_api_close();
        Ok(())
    }
}

View File

@@ -1,14 +1,14 @@
use super::errors::GraphError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::VarScales;
use super::errors::GraphError;
use super::{Rescaled, SupportedOp, Visibility};
use crate::circuit::Op;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::hybrid::HybridOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::lookup::LookupOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::poly::PolyOp;
use crate::circuit::Op;
use crate::fieldutils::IntegerRep;
use crate::tensor::{Tensor, TensorError, TensorType};
use halo2curves::bn256::Fr as Fp;
@@ -22,6 +22,7 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::ops::{
Downsample,
array::{
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
Slice, Topk,
@@ -31,7 +32,6 @@ use tract_onnx::tract_core::ops::{
einsum::EinSum,
element_wise::ElementWiseOp,
nn::{LeakyRelu, Reduce, Softmax},
Downsample,
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::{
@@ -39,16 +39,15 @@ use tract_onnx::tract_hir::{
ops::array::{Pad, PadMode, TypedConcat},
ops::cnn::PoolSpec,
ops::konst::Const,
ops::nn::DataFormat,
tract_core::ops::cast::Cast,
tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
tract_core::ops::cnn::{MaxPool, SumPool},
};
/// Quantizes an f64 element to an IntegerRep using a fixed point representation.
/// NAN gets mapped to 0. INFINITY and NEG_INFINITY error out.
/// Arguments
///
/// * `elem` - the element to quantize.
/// * `shift` - offset used in the fixed point representation.
/// * `scale` - `2^scale` used in the fixed point representation.
pub fn quantize_float(
@@ -59,7 +58,7 @@ pub fn quantize_float(
let mult = scale_to_multiplier(scale);
let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
if *elem > max_value {
if *elem > max_value || *elem < -max_value {
return Err(TensorError::SigBitTruncationError);
}
@@ -85,7 +84,7 @@ pub fn scale_to_multiplier(scale: crate::Scale) -> f64 {
f64::powf(2., scale as f64)
}
/// Converts a fixed point multiplier to a scale (log base 2).
///
/// The base-2 logarithm of `mult` is rounded to the nearest integer and then cast to
/// [crate::Scale].
pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
    let exponent = mult.log2().round();
    exponent as crate::Scale
}
@@ -142,8 +141,6 @@ use tract_onnx::prelude::SymbolValues;
pub fn extract_tensor_value(
input: Arc<tract_onnx::prelude::Tensor>,
) -> Result<Tensor<f32>, GraphError> {
use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
let dt = input.datum_type();
let dims = input.shape().to_vec();
@@ -156,7 +153,7 @@ pub fn extract_tensor_value(
match dt {
DatumType::F16 => {
let vec = input.as_slice::<tract_onnx::prelude::f16>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| (*x).into()).collect();
let cast: Vec<f32> = vec.iter().map(|x| (*x).into()).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::F32 => {
@@ -165,61 +162,61 @@ pub fn extract_tensor_value(
}
DatumType::F64 => {
let vec = input.as_slice::<f64>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::I64 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<i64>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::I32 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<i32>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::I16 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<i16>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::I8 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<i8>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::U8 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<u8>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::U16 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<u16>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::U32 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<u32>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::U64 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<u64>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::Bool => {
// Generally a shape or hyperparam
let vec = input.as_slice::<bool>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as usize as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as usize as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::TDim => {
@@ -227,13 +224,10 @@ pub fn extract_tensor_value(
let vec = input.as_slice::<tract_onnx::prelude::TDim>()?.to_vec();
let cast: Result<Vec<f32>, GraphError> = vec
.par_iter()
.iter()
.map(|x| match x.to_i64() {
Ok(v) => Ok(v as f32),
Err(_) => match x.to_i64() {
Ok(v) => Ok(v as f32),
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
},
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
})
.collect();
@@ -279,11 +273,9 @@ pub fn new_op_from_onnx(
symbol_values: &SymbolValues,
run_args: &crate::RunArgs,
) -> Result<(SupportedOp, Vec<usize>), GraphError> {
use std::f64::consts::E;
use tract_onnx::tract_core::ops::array::Trilu;
use crate::circuit::InputType;
use std::f64::consts::E;
use tract_onnx::tract_core::ops::array::Trilu;
let input_scales = inputs
.iter()
@@ -314,6 +306,9 @@ pub fn new_op_from_onnx(
let mut deleted_indices = vec![];
let node = match node.op().name().as_ref() {
"ShiftLeft" => {
if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "shift left".to_string()));
};
// load shift amount
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
@@ -326,10 +321,13 @@ pub fn new_op_from_onnx(
out_scale: Some(input_scales[0] - raw_values[0] as i32),
})
} else {
return Err(GraphError::OpMismatch(idx, "ShiftLeft".to_string()));
return Err(GraphError::OpMismatch(idx, "shift left".to_string()));
}
}
"ShiftRight" => {
if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "shift right".to_string()));
};
// load shift amount
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
@@ -342,7 +340,7 @@ pub fn new_op_from_onnx(
out_scale: Some(input_scales[0] + raw_values[0] as i32),
})
} else {
return Err(GraphError::OpMismatch(idx, "ShiftRight".to_string()));
return Err(GraphError::OpMismatch(idx, "shift right".to_string()));
}
}
"MultiBroadcastTo" => {
@@ -365,7 +363,10 @@ pub fn new_op_from_onnx(
}
}
assert_eq!(input_ops.len(), 3, "Range requires 3 inputs");
if input_ops.len() != 3 {
return Err(GraphError::InvalidDims(idx, "range".to_string()));
}
let input_ops = input_ops
.iter()
.map(|x| x.get_constant().ok_or(GraphError::NonConstantRange))
@@ -380,7 +381,11 @@ pub fn new_op_from_onnx(
// Quantize the raw value (integers)
let quantized_value = quantize_tensor(raw_value.clone(), 0, &Visibility::Fixed)?;
let c = crate::circuit::ops::Constant::new(quantized_value, raw_value);
let c = crate::circuit::ops::Constant::new(
quantized_value,
raw_value,
!run_args.ignore_range_check_inputs_outputs,
);
// Create a constant op
SupportedOp::Constant(c)
}
@@ -421,6 +426,10 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
if inputs[0].out_dims().is_empty() || inputs[0].out_dims()[0].len() <= axis {
return Err(GraphError::InvalidDims(idx, "gather".to_string()));
}
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::Gather {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| {
@@ -438,6 +447,7 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: false,
}));
inputs[1].bump_scale(0);
}
@@ -449,8 +459,17 @@ pub fn new_op_from_onnx(
"Topk" => {
let op = load_op::<Topk>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;
if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
};
// if param_visibility.is_public() {
let k = if let Some(c) = inputs[1].opkind().get_mutable_constant() {
if c.raw_values.len() != 1 {
return Err(GraphError::InvalidDims(idx, "topk".to_string()));
}
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
c.raw_values.map(|x| x as usize)[0]
@@ -490,6 +509,10 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "scatter elements".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -501,6 +524,7 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -524,6 +548,9 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "scatter nd".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
@@ -534,6 +561,7 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -557,6 +585,9 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "gather nd".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
@@ -568,6 +599,7 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -591,6 +623,9 @@ pub fn new_op_from_onnx(
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "gather elements".to_string()));
}
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -602,6 +637,7 @@ pub fn new_op_from_onnx(
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
decomp: !run_args.ignore_range_check_inputs_outputs,
}));
inputs[1].bump_scale(0);
}
@@ -676,7 +712,11 @@ pub fn new_op_from_onnx(
constant_scale,
&run_args.param_visibility,
)?;
let c = crate::circuit::ops::Constant::new(quantized_value, raw_value);
let c = crate::circuit::ops::Constant::new(
quantized_value,
raw_value,
run_args.ignore_range_check_inputs_outputs,
);
// Create a constant op
SupportedOp::Constant(c)
}
@@ -686,7 +726,9 @@ pub fn new_op_from_onnx(
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes: Vec<usize> = op.axes.into_iter().collect();
assert_eq!(axes.len(), 1, "only support argmax over one axis");
if axes.len() != 1 {
return Err(GraphError::InvalidDims(idx, "argmax".to_string()));
}
SupportedOp::Hybrid(HybridOp::ReduceArgMax { dim: axes[0] })
}
@@ -696,7 +738,9 @@ pub fn new_op_from_onnx(
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes: Vec<usize> = op.axes.into_iter().collect();
assert_eq!(axes.len(), 1, "only support argmin over one axis");
if axes.len() != 1 {
return Err(GraphError::InvalidDims(idx, "argmin".to_string()));
}
SupportedOp::Hybrid(HybridOp::ReduceArgMin { dim: axes[0] })
}
@@ -805,12 +849,16 @@ pub fn new_op_from_onnx(
}
}
"Recip" => {
if inputs.len() != 1 {
return Err(GraphError::InvalidDims(idx, "recip".to_string()));
};
let in_scale = input_scales[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
// If the input scale is larger than the params scale
SupportedOp::Hybrid(HybridOp::Recip {
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
eps: run_args.get_epsilon(),
})
}
@@ -848,11 +896,15 @@ pub fn new_op_from_onnx(
scale: scale_to_multiplier(input_scales[0]).into(),
}),
"Rsqrt" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "rsqrt".to_string()));
};
let in_scale = input_scales[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
SupportedOp::Hybrid(HybridOp::Rsqrt {
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
eps: run_args.get_epsilon(),
})
}
"Exp" => SupportedOp::Nonlinear(LookupOp::Exp {
@@ -863,6 +915,7 @@ pub fn new_op_from_onnx(
if run_args.bounded_log_lookup {
SupportedOp::Hybrid(HybridOp::Ln {
scale: scale_to_multiplier(input_scales[0]).into(),
eps: run_args.get_epsilon(),
})
} else {
SupportedOp::Nonlinear(LookupOp::Ln {
@@ -929,13 +982,19 @@ pub fn new_op_from_onnx(
DatumType::F64 => (scales.input, InputType::F64),
_ => return Err(GraphError::UnsupportedDataType(idx, format!("{:?}", dt))),
};
SupportedOp::Input(crate::circuit::ops::Input { scale, datum_type })
SupportedOp::Input(crate::circuit::ops::Input {
scale,
datum_type,
decomp: !run_args.ignore_range_check_inputs_outputs,
})
}
"Cast" => {
let op = load_op::<Cast>(node.op(), idx, node.op().name().to_string())?;
let dt = op.to;
assert_eq!(input_scales.len(), 1);
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "cast".to_string()));
};
match dt {
DatumType::Bool
@@ -985,6 +1044,11 @@ pub fn new_op_from_onnx(
if const_idx.len() == 1 {
let const_idx = const_idx[0];
if inputs.len() <= const_idx {
return Err(GraphError::InvalidDims(idx, "mul".to_string()));
}
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
if c.raw_values.len() == 1 && c.raw_values[0] < 1. {
// if not divisible by 2 then we need to add a range check
@@ -1059,6 +1123,9 @@ pub fn new_op_from_onnx(
return Err(GraphError::OpMismatch(idx, "softmax".to_string()));
}
};
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "softmax".to_string()));
}
let in_scale = input_scales[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
@@ -1067,6 +1134,7 @@ pub fn new_op_from_onnx(
input_scale: scale_to_multiplier(in_scale).into(),
output_scale: scale_to_multiplier(max_scale).into(),
axes: softmax_op.axes.to_vec(),
eps: run_args.get_epsilon(),
})
}
"MaxPool" => {
@@ -1081,13 +1149,6 @@ pub fn new_op_from_onnx(
let pool_spec: &PoolSpec = &sumpool_node.pool_spec;
// only support pytorch type formatting for now
if pool_spec.data_format != DataFormat::NCHW {
return Err(GraphError::MissingParams(
"data in wrong format".to_string(),
));
}
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, &input_dims[0])?;
let kernel_shape = &pool_spec.kernel_shape;
@@ -1096,24 +1157,45 @@ pub fn new_op_from_onnx(
padding,
stride: stride.to_vec(),
pool_dims: kernel_shape.to_vec(),
data_format: pool_spec.data_format.into(),
})
}
"Ceil" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "ceil".to_string()));
}
SupportedOp::Hybrid(HybridOp::Ceil {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"Floor" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "floor".to_string()));
}
SupportedOp::Hybrid(HybridOp::Floor {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"Round" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "round".to_string()));
}
SupportedOp::Hybrid(HybridOp::Round {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"RoundHalfToEven" => {
if input_scales.len() != 1 {
return Err(GraphError::InvalidDims(idx, "roundhalftoeven".to_string()));
}
SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
})
}
"Ceil" => SupportedOp::Hybrid(HybridOp::Ceil {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"Floor" => SupportedOp::Hybrid(HybridOp::Floor {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"Round" => SupportedOp::Hybrid(HybridOp::Round {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"RoundHalfToEven" => SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
scale: scale_to_multiplier(input_scales[0]).into(),
legs: run_args.decomp_legs,
}),
"Sign" => SupportedOp::Linear(PolyOp::Sign),
"Pow" => {
// Extract the slope layer hyperparams from a const
@@ -1123,7 +1205,9 @@ pub fn new_op_from_onnx(
inputs[1].decrement_use();
deleted_indices.push(1);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar pow")
return Err(GraphError::NonScalarPower);
} else if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
}
let exponent = c.raw_values[0];
@@ -1136,26 +1220,30 @@ pub fn new_op_from_onnx(
a: crate::circuit::utils::F32(exponent),
})
}
} else {
if let Some(c) = inputs[0].opkind().get_mutable_constant() {
inputs[0].decrement_use();
deleted_indices.push(0);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar base")
}
let base = c.raw_values[0];
SupportedOp::Nonlinear(LookupOp::Exp {
scale: scale_to_multiplier(input_scales[1]).into(),
base: base.into(),
})
} else {
unimplemented!("only support constant base or pow for now")
} else if let Some(c) = inputs[0].opkind().get_mutable_constant() {
inputs[0].decrement_use();
deleted_indices.push(0);
if c.raw_values.len() > 1 {
return Err(GraphError::NonScalarBase);
} else if c.raw_values.is_empty() {
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
}
let base = c.raw_values[0];
SupportedOp::Nonlinear(LookupOp::Exp {
scale: scale_to_multiplier(input_scales[1]).into(),
base: base.into(),
})
} else {
return Err(GraphError::InvalidDims(idx, "pow".to_string()));
}
}
"Div" => {
if inputs.len() != 2 {
return Err(GraphError::InvalidDims(idx, "div".to_string()));
}
let const_idx = inputs
.iter()
.enumerate()
@@ -1163,14 +1251,15 @@ pub fn new_op_from_onnx(
.map(|(i, _)| i)
.collect::<Vec<_>>();
if const_idx.len() > 1 {
if const_idx.len() > 1 || const_idx.is_empty() {
return Err(GraphError::InvalidDims(idx, "div".to_string()));
}
let const_idx = const_idx[0];
if const_idx != 1 {
unimplemented!("only support div with constant as second input")
return Err(GraphError::MisformedParams(
"only support div with constant as second input".to_string(),
));
}
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
@@ -1180,14 +1269,28 @@ pub fn new_op_from_onnx(
// get the non constant index
let denom = c.raw_values[0];
SupportedOp::Hybrid(HybridOp::Div {
let op = SupportedOp::Hybrid(HybridOp::Div {
denom: denom.into(),
})
});
// if the input is scale 0 we re up to the max scale
if input_scales[0] == 0 {
SupportedOp::Rescaled(Rescaled {
inner: Box::new(op),
scale: vec![(0, scale_to_multiplier(scales.get_max()) as u128)],
})
} else {
op
}
} else {
unimplemented!("only support non zero divisors of size 1")
return Err(GraphError::MisformedParams(
"only support non zero divisors of size 1".to_string(),
));
}
} else {
unimplemented!("only support div with constant as second input")
return Err(GraphError::MisformedParams(
"only support div with constant as second input".to_string(),
));
}
}
"Cube" => SupportedOp::Linear(PolyOp::Pow(3)),
@@ -1208,15 +1311,6 @@ pub fn new_op_from_onnx(
}
}
if ((conv_node.pool_spec.data_format != DataFormat::NCHW)
&& (conv_node.pool_spec.data_format != DataFormat::CHW))
|| (conv_node.kernel_fmt != KernelFormat::OIHW)
{
return Err(GraphError::MisformedParams(
"data or kernel in wrong format".to_string(),
));
}
let pool_spec = &conv_node.pool_spec;
let stride = extract_strides(pool_spec)?;
@@ -1244,6 +1338,8 @@ pub fn new_op_from_onnx(
padding,
stride,
group,
data_format: conv_node.pool_spec.data_format.into(),
kernel_format: conv_node.kernel_fmt.into(),
})
}
"Not" => SupportedOp::Linear(PolyOp::Not),
@@ -1267,14 +1363,6 @@ pub fn new_op_from_onnx(
}
}
if (deconv_node.pool_spec.data_format != DataFormat::NCHW)
|| (deconv_node.kernel_format != KernelFormat::OIHW)
{
return Err(GraphError::MisformedParams(
"data or kernel in wrong format".to_string(),
));
}
let pool_spec = &deconv_node.pool_spec;
let stride = extract_strides(pool_spec)?;
@@ -1300,6 +1388,8 @@ pub fn new_op_from_onnx(
output_padding: deconv_node.adjustments.to_vec(),
stride,
group: deconv_node.group,
data_format: deconv_node.pool_spec.data_format.into(),
kernel_format: deconv_node.kernel_format.into(),
})
}
"Downsample" => {
@@ -1312,7 +1402,7 @@ pub fn new_op_from_onnx(
SupportedOp::Linear(PolyOp::Downsample {
axis: downsample_node.axis,
stride: downsample_node.stride as usize,
stride: downsample_node.stride,
modulo: downsample_node.modulo,
})
}
@@ -1327,7 +1417,7 @@ pub fn new_op_from_onnx(
if !resize_node.contains("interpolator: Nearest")
&& !resize_node.contains("nearest: Floor")
{
unimplemented!("Only nearest neighbor interpolation is supported")
return Err(GraphError::InvalidInterpolation);
}
// check if optional scale factor is present
if inputs.len() != 2 && inputs.len() != 3 {
@@ -1383,13 +1473,6 @@ pub fn new_op_from_onnx(
let pool_spec: &PoolSpec = &sumpool_node.pool_spec;
// only support pytorch type formatting for now
if pool_spec.data_format != DataFormat::NCHW {
return Err(GraphError::MissingParams(
"data in wrong format".to_string(),
));
}
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, &input_dims[0])?;
@@ -1398,6 +1481,7 @@ pub fn new_op_from_onnx(
stride: stride.to_vec(),
kernel_shape: pool_spec.kernel_shape.to_vec(),
normalized: sumpool_node.normalize,
data_format: pool_spec.data_format.into(),
})
}
"Pad" => {
@@ -1431,6 +1515,10 @@ pub fn new_op_from_onnx(
SupportedOp::Linear(PolyOp::Reshape(output_shape))
}
"Flatten" => {
if inputs.len() != 1 || inputs[0].out_dims().is_empty() {
return Err(GraphError::InvalidDims(idx, "flatten".to_string()));
};
let new_dims: Vec<usize> = vec![inputs[0].out_dims()[0].iter().product::<usize>()];
SupportedOp::Linear(PolyOp::Flatten(new_dims))
}
@@ -1504,12 +1592,10 @@ pub fn homogenize_input_scales(
input_scales: Vec<crate::Scale>,
inputs_to_scale: Vec<usize>,
) -> Result<Box<dyn Op<Fp>>, GraphError> {
let relevant_input_scales = input_scales
.clone()
.into_iter()
.enumerate()
.filter(|(idx, _)| inputs_to_scale.contains(idx))
.map(|(_, scale)| scale)
let relevant_input_scales = inputs_to_scale
.iter()
.filter(|idx| input_scales.len() > **idx)
.map(|&idx| input_scales[idx])
.collect_vec();
if inputs_to_scale.is_empty() {
@@ -1550,10 +1636,30 @@ pub fn homogenize_input_scales(
}
#[cfg(test)]
/// tests for the utility module
pub mod tests {
use super::*;
// quantization tests
/// Quantizing the floats 0..10 at scale 0 with public visibility must yield the
/// same ten values as field elements.
#[test]
fn test_quantize_tensor() {
    let floats: Tensor<f32> = (0..10).map(|x| x as f32).into();
    let expected: Tensor<Fp> = (0..10).map(|x| x.into()).into();

    let actual: Tensor<Fp> = quantize_tensor(floats, 0, &Visibility::Public).unwrap();

    assert_eq!(actual.len(), 10);
    assert_eq!(actual, expected);
}
/// NaN quantizes to zero, while either infinity is rejected with an error.
#[test]
fn test_quantize_edge_cases() {
    let nan_quantized = quantize_float(&f64::NAN, 0.0, 0).unwrap();
    assert_eq!(nan_quantized, 0);

    for non_finite in &[f64::INFINITY, f64::NEG_INFINITY] {
        assert!(quantize_float(non_finite, 0.0, 0).is_err());
    }
}
#[test]
fn test_flatten_valtensors() {
let tensor1: Tensor<Fp> = (0..10).map(|x| x.into()).into();

View File

@@ -11,35 +11,34 @@ use log::debug;
use pyo3::{
exceptions::PyValueError, FromPyObject, IntoPy, PyObject, PyResult, Python, ToPyObject,
};
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
use self::errors::GraphError;
use super::*;
/// Defines the visibility level of values within the zero-knowledge circuit.
///
/// Labels whether model inputs, model parameters, and model outputs are public,
/// private, hashed, KZG-committed, or fixed, which controls how the values are
/// handled during proof generation and verification.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum Visibility {
/// Value is private to the prover and is not included in the proof submitted for verification.
#[default]
Private,
/// Value is public and is sent in the proof submitted for verification.
Public,
/// Value is publicly committed to: a hash of it is included in the proof.
Hashed {
/// Controls how the hash is handled in the proof:
/// - `true`: the hash is used as an instance (sent in the proof submitted for
///   verification) and the *inputs* to the hashing function are sent to the
///   computational graph.
/// - `false`: the hash is used as advice (not in the proof submitted for
///   verification) and is itself sent to the computational graph.
hash_is_public: bool,
/// Specifies which outputs this hash affects.
outlets: Vec<usize>,
},
/// Value is publicly committed to using the KZG commitment scheme; the commitment
/// is sent in the proof submitted for verification.
KZGCommit,
/// Value is assigned as a constant in the circuit.
Fixed,
}
@@ -66,15 +65,17 @@ impl Display for Visibility {
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl ToFlags for Visibility {
    /// Renders this visibility as a single command line flag value, using its
    /// `Display` representation.
    fn to_flags(&self) -> Vec<String> {
        let rendered = self.to_string();
        vec![rendered]
    }
}
impl<'a> From<&'a str> for Visibility {
/// Converts string representation to Visibility
fn from(s: &'a str) -> Self {
if s.contains("hashed/private") {
// split on last occurrence of '/'
// Split on last occurrence of '/'
let (_, outlets) = s.split_at(s.rfind('/').unwrap());
let outlets = outlets
.trim_start_matches('/')
@@ -106,8 +107,8 @@ impl<'a> From<&'a str> for Visibility {
}
#[cfg(feature = "python-bindings")]
/// Converts Visibility into a PyObject (Required for Visibility to be compatible with Python)
impl IntoPy<PyObject> for Visibility {
/// Converts Visibility to Python object
fn into_py(self, py: Python) -> PyObject {
match self {
Visibility::Private => "private".to_object(py),
@@ -134,14 +135,13 @@ impl IntoPy<PyObject> for Visibility {
}
#[cfg(feature = "python-bindings")]
/// Obtains Visibility from PyObject (Required for Visibility to be compatible with Python)
impl<'source> FromPyObject<'source> for Visibility {
/// Extracts Visibility from Python object
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
let strval = String::extract_bound(ob)?;
let strval = strval.as_str();
if strval.contains("hashed/private") {
// split on last occurrence of '/'
let (_, outlets) = strval.split_at(strval.rfind('/').unwrap());
let outlets = outlets
.trim_start_matches('/')
@@ -174,29 +174,32 @@ impl<'source> FromPyObject<'source> for Visibility {
}
impl Visibility {
#[allow(missing_docs)]
/// Returns true if visibility is Fixed
pub fn is_fixed(&self) -> bool {
matches!(&self, Visibility::Fixed)
}
#[allow(missing_docs)]
/// Returns true if visibility is Private or hashed private
pub fn is_private(&self) -> bool {
matches!(&self, Visibility::Private) || self.is_hashed_private()
}
#[allow(missing_docs)]
/// Returns true if visibility is Public
pub fn is_public(&self) -> bool {
matches!(&self, Visibility::Public)
}
#[allow(missing_docs)]
/// Returns true if visibility involves hashing
pub fn is_hashed(&self) -> bool {
matches!(&self, Visibility::Hashed { .. })
}
#[allow(missing_docs)]
/// Returns true if visibility uses KZG commitment
pub fn is_polycommit(&self) -> bool {
matches!(&self, Visibility::KZGCommit)
}
#[allow(missing_docs)]
/// Returns true if visibility is hashed with public hash
pub fn is_hashed_public(&self) -> bool {
if let Visibility::Hashed {
hash_is_public: true,
@@ -207,7 +210,8 @@ impl Visibility {
}
false
}
#[allow(missing_docs)]
/// Returns true if visibility is hashed with private hash
pub fn is_hashed_private(&self) -> bool {
if let Visibility::Hashed {
hash_is_public: false,
@@ -219,11 +223,12 @@ impl Visibility {
false
}
#[allow(missing_docs)]
/// Returns true if visibility requires additional processing
pub fn requires_processing(&self) -> bool {
matches!(&self, Visibility::Hashed { .. }) | matches!(&self, Visibility::KZGCommit)
}
#[allow(missing_docs)]
/// Returns vector of output indices that this visibility setting affects
pub fn overwrites_inputs(&self) -> Vec<usize> {
if let Visibility::Hashed { outlets, .. } = self {
return outlets.clone();
@@ -232,14 +237,14 @@ impl Visibility {
}
}
/// Manages the scaling factors for the model input and the model parameters.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)]
pub struct VarScales {
/// Scale factor for input values
pub input: crate::Scale,
/// Scale factor for parameter values
pub params: crate::Scale,
/// Multiplier for scale rebasing
pub rebase_multiplier: u32,
}
@@ -250,17 +255,17 @@ impl std::fmt::Display for VarScales {
}
impl VarScales {
/// Returns the larger of the input scale and the parameter scale.
pub fn get_max(&self) -> crate::Scale {
    self.input.max(self.params)
}
/// Returns the smaller of the input scale and the parameter scale.
pub fn get_min(&self) -> crate::Scale {
    self.input.min(self.params)
}
/// Place in [VarScales] struct.
/// Creates VarScales from runtime arguments
pub fn from_args(args: &RunArgs) -> Self {
Self {
input: args.input_scale,
@@ -270,16 +275,17 @@ impl VarScales {
}
}
/// Controls the visibility settings for the different parts of the model: whether the
/// model input, model parameters, and model output are public or private to the prover.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, PartialOrd)]
pub struct VarVisibility {
/// Visibility of the input to the model or computational graph
pub input: Visibility,
/// Visibility of the model parameters, such as weights and biases
pub params: Visibility,
/// Visibility of the output of the model or computational graph
pub output: Visibility,
}
impl std::fmt::Display for VarVisibility {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
@@ -301,8 +307,7 @@ impl Default for VarVisibility {
}
impl VarVisibility {
/// Read from cli args whether the model input, model parameters, and model output are Public or Private to the prover.
/// Place in [VarVisibility] struct.
/// Creates visibility settings from runtime arguments
pub fn from_args(args: &RunArgs) -> Result<Self, GraphError> {
let input_vis = &args.input_visibility;
let params_vis = &args.param_visibility;
@@ -313,17 +318,17 @@ impl VarVisibility {
}
if !output_vis.is_public()
& !params_vis.is_public()
& !input_vis.is_public()
& !output_vis.is_fixed()
& !params_vis.is_fixed()
& !input_vis.is_fixed()
& !output_vis.is_hashed()
& !params_vis.is_hashed()
& !input_vis.is_hashed()
& !output_vis.is_polycommit()
& !params_vis.is_polycommit()
& !input_vis.is_polycommit()
&& !params_vis.is_public()
&& !input_vis.is_public()
&& !output_vis.is_fixed()
&& !params_vis.is_fixed()
&& !input_vis.is_fixed()
&& !output_vis.is_hashed()
&& !params_vis.is_hashed()
&& !input_vis.is_hashed()
&& !output_vis.is_polycommit()
&& !params_vis.is_polycommit()
&& !input_vis.is_polycommit()
{
return Err(GraphError::Visibility);
}
@@ -335,17 +340,17 @@ impl VarVisibility {
}
}
/// Container for the circuit columns used by a model: a wrapper holding all columns
/// that will be assigned to by the model.
#[derive(Clone, Debug)]
pub struct ModelVars<F: PrimeField + TensorType + PartialOrd> {
#[allow(missing_docs)]
/// Advice columns for circuit assignments
pub advices: Vec<VarTensor>,
#[allow(missing_docs)]
/// Optional instance column for public inputs; `None` when no instance is configured
pub instance: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
/// Get instance col
/// Gets reference to instance column if it exists
pub fn get_instance_col(&self) -> Option<&Column<Instance>> {
if let Some(instance) = &self.instance {
match instance {
@@ -357,14 +362,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
}
}
/// Forwards `offset` as the initial instance offset to the instance tensor.
/// Does nothing when `self.instance` is `None`.
pub fn set_initial_instance_offset(&mut self, offset: usize) {
    if let Some(instance_tensor) = self.instance.as_mut() {
        instance_tensor.set_initial_instance_offset(offset);
    }
}
/// Get the total instance len
/// Gets total length of instance data
pub fn get_instance_len(&self) -> usize {
if let Some(instance) = &self.instance {
instance.get_total_instance_len()
@@ -373,21 +378,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
}
}
/// Advances the instance index of the instance tensor by calling its `increment_idx`.
/// Does nothing when `self.instance` is `None`.
pub fn increment_instance_idx(&mut self) {
    if let Some(instance_tensor) = self.instance.as_mut() {
        instance_tensor.increment_idx();
    }
}
/// Sets the instance index of the instance tensor to `val`.
/// Does nothing when `self.instance` is `None`.
pub fn set_instance_idx(&mut self, val: usize) {
    if let Some(instance_tensor) = self.instance.as_mut() {
        instance_tensor.set_idx(val);
    }
}
/// Get the instance offset
/// Gets current instance index
pub fn get_instance_idx(&self) -> usize {
if let Some(instance) = &self.instance {
instance.get_idx()
@@ -396,7 +401,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
}
}
///
/// Initializes instance column with specified dimensions and scale
pub fn instantiate_instance(
&mut self,
cs: &mut ConstraintSystem<F>,
@@ -417,7 +422,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
};
}
/// Allocate all columns that will be assigned to by a model.
/// Creates new ModelVars with allocated columns based on settings
pub fn new(cs: &mut ConstraintSystem<F>, params: &GraphSettings) -> Self {
debug!("number of blinding factors: {}", cs.blinding_factors());
@@ -435,7 +440,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
.collect_vec();
if requires_dynamic_lookup || requires_shuffle {
let num_cols = if requires_dynamic_lookup { 3 } else { 2 };
let num_cols = 3;
for _ in 0..num_cols {
let dynamic_lookup =
VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_and_shuffle_size);

View File

@@ -28,6 +28,9 @@
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
use log::warn;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc as _;
/// Error type
// #[cfg_attr(not(feature = "ezkl"), derive(uniffi::Error))]
@@ -41,6 +44,7 @@ pub enum EZKLError {
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
#[error("[eth] {0}")]
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
EthError(#[from] eth::EthError),
#[error("[graph] {0}")]
GraphError(#[from] graph::errors::GraphError),
@@ -94,12 +98,11 @@ impl From<String> for EZKLError {
use std::str::FromStr;
use circuit::{table::Range, CheckMode, Tolerance};
use circuit::{table::Range, CheckMode};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::Args;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use fieldutils::IntegerRep;
use graph::Visibility;
use graph::{Visibility, MAX_PUBLIC_SRS};
use halo2_proofs::poly::{
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
};
@@ -132,7 +135,7 @@ pub mod circuit;
/// CLI commands.
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub mod commands;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[cfg(all(feature = "eth", not(target_arch = "wasm32")))]
// abigen doesn't generate docs for this module
#[allow(missing_docs)]
/// Utility functions for contracts
@@ -165,7 +168,6 @@ pub mod srs_sha;
pub mod tensor;
#[cfg(feature = "ios-bindings")]
uniffi::setup_scaffolding!();
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use lazy_static::lazy_static;
@@ -180,11 +182,9 @@ lazy_static! {
.unwrap_or("8000".to_string())
.parse()
.unwrap();
/// The serialization format for the keys
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
.unwrap_or("raw-bytes".to_string());
}
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
@@ -266,80 +266,111 @@ impl From<String> for Commitments {
}
/// Parameters specific to a proving run
///
/// RunArgs contains all configuration parameters needed to control the proving process,
/// including scaling factors, visibility settings, and circuit parameters.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, PartialOrd)]
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
derive(Args, ToFlags)
)]
pub struct RunArgs {
/// The tolerance for error on model outputs
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other))]
pub tolerance: Tolerance,
/// The denominator in the fixed point representation used when quantizing inputs
/// Fixed point scaling factor for quantizing inputs
/// Higher values provide more precision but increase circuit complexity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other))]
pub input_scale: Scale,
/// The denominator in the fixed point representation used when quantizing parameters
/// Fixed point scaling factor for quantizing parameters
/// Higher values provide more precision but increase circuit complexity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "7", value_hint = clap::ValueHint::Other))]
pub param_scale: Scale,
/// if the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this a more advanced parameter, use with caution)
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))]
/// Scale rebase threshold multiplier
/// When scale exceeds input_scale * multiplier, it is rebased to input_scale
/// Advanced parameter that should be used with caution
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))]
pub scale_rebase_multiplier: u32,
/// The min and max elements in the lookup table input column
/// Range for lookup table input column values
/// Specified as (min, max) pair
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'B', long, value_parser = parse_key_val::<IntegerRep, IntegerRep>, default_value = "-32768->32768"))]
pub lookup_range: Range,
/// The log_2 number of rows
/// Log2 of the number of rows in the circuit
/// Controls circuit size and proving time
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other))]
pub logrows: u32,
/// The log_2 number of rows
/// Number of inner columns per block
/// Affects circuit layout and efficiency
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'N', long, default_value = "2", value_hint = clap::ValueHint::Other))]
pub num_inner_cols: usize,
/// Hand-written parser for graph variables, eg. batch_size=1
/// Graph variables for parameterizing the computation
/// Format: "name->value", e.g. "batch_size->1"
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other))]
pub variables: Vec<(String, usize)>,
/// Flags whether inputs are public, private, fixed, hashed, polycommit
/// Visibility setting for input values
/// Controls whether inputs are public or private in the circuit
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
pub input_visibility: Visibility,
/// Flags whether outputs are public, private, fixed, hashed, polycommit
/// Visibility setting for output values
/// Controls whether outputs are public or private in the circuit
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "public", value_hint = clap::ValueHint::Other))]
pub output_visibility: Visibility,
/// Flags whether params are fixed, private, hashed, polycommit
/// Visibility setting for parameters
/// Controls how parameters are handled in the circuit
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))]
pub param_visibility: Visibility,
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
/// Should constants with 0.0 fraction be rebased to scale 0
/// Whether to rebase constants with zero fractional part to scale 0
/// Can improve efficiency for integer constants
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
pub rebase_frac_zero_constants: bool,
/// check mode (safe, unsafe, etc)
/// Circuit checking mode
/// Controls level of constraint verification
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "unsafe", value_hint = clap::ValueHint::Other))]
pub check_mode: CheckMode,
/// commitment scheme
/// Commitment scheme for circuit proving
/// Affects proof size and verification time
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other))]
pub commitment: Option<Commitments>,
/// the base used for decompositions
/// Base for number decomposition
/// Must be a power of 2
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "16384", value_hint = clap::ValueHint::Other))]
pub decomp_base: usize,
/// Number of decomposition legs
/// Controls decomposition granularity
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "2", value_hint = clap::ValueHint::Other))]
/// the number of legs used for decompositions
pub decomp_legs: usize,
/// Whether to use bounded lookup for logarithm computation
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
/// use unbounded lookup for the log
pub bounded_log_lookup: bool,
/// Range check inputs and outputs (turn off if the inputs are felts)
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
pub ignore_range_check_inputs_outputs: bool,
/// Optional override for epsilon value
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long))]
pub epsilon: Option<f64>,
}
impl RunArgs {
    /// Returns the configured epsilon override, falling back to `f64::EPSILON`
    /// when no override is set.
    pub fn get_epsilon(&self) -> f64 {
        match self.epsilon {
            Some(overridden) => overridden,
            None => f64::EPSILON,
        }
    }
}
impl Default for RunArgs {
/// Creates a new RunArgs instance with default values
///
/// Default configuration is optimized for common use cases
/// while maintaining reasonable proving time and circuit size
fn default() -> Self {
Self {
bounded_log_lookup: false,
tolerance: Tolerance::default(),
input_scale: 7,
param_scale: 7,
scale_rebase_multiplier: 1,
@@ -355,54 +386,140 @@ impl Default for RunArgs {
commitment: None,
decomp_base: 16384,
decomp_legs: 2,
ignore_range_check_inputs_outputs: false,
epsilon: None,
}
}
}
impl RunArgs {
/// Validates the RunArgs configuration
///
/// Performs comprehensive validation of all parameters to ensure they are within
/// acceptable ranges and follow required constraints. Returns accumulated errors
/// if any validations fail.
///
/// # Returns
/// - Ok(()) if all validations pass
/// - Err(String) with detailed error message if any validation fails
pub fn validate(&self) -> Result<(), String> {
let mut errors = Vec::new();
// check if the largest represented integer in the decomposed form overflows IntegerRep
// try it with the largest possible value
let max_decomp = (self.decomp_base as IntegerRep).checked_pow(self.decomp_legs as u32);
if max_decomp.is_none() {
errors.push(format!(
"decomp_base^decomp_legs overflows IntegerRep: {}^{}",
self.decomp_base, self.decomp_legs
));
}
// Visibility validations
if self.param_visibility == Visibility::Public {
return Err(
"params cannot be public instances, you are probably trying to use `fixed` or `kzgcommit`"
.into(),
errors.push(
"Parameters cannot be public instances. Use 'fixed' or 'kzgcommit' instead"
.to_string(),
);
}
// Scale validations
if self.scale_rebase_multiplier < 1 {
return Err("scale_rebase_multiplier must be >= 1".into());
errors.push("scale_rebase_multiplier must be >= 1".to_string());
}
// if any of the scales are too small
if self.input_scale < 8 || self.param_scale < 8 {
warn!("low scale values (<8) may impact precision");
}
// Lookup range validations
if self.lookup_range.0 > self.lookup_range.1 {
return Err("lookup_range min is greater than max".into());
errors.push(format!(
"Invalid lookup range: min ({}) is greater than max ({})",
self.lookup_range.0, self.lookup_range.1
));
}
// Size validations
if self.logrows < 1 {
return Err("logrows must be >= 1".into());
errors.push("logrows must be >= 1".to_string());
}
if self.num_inner_cols < 1 {
return Err("num_inner_cols must be >= 1".into());
errors.push("num_inner_cols must be >= 1".to_string());
}
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
return Err("tolerance > 0.0 requires output_visibility to be public".into());
let batch_size = self.variables.iter().find(|(name, _)| name == "batch_size");
if let Some(batch_size) = batch_size {
if batch_size.1 == 0 {
errors.push("'batch_size' cannot be 0".to_string());
}
}
// Decomposition validations
if self.decomp_base == 0 {
errors.push("decomp_base cannot be 0".to_string());
}
if self.decomp_legs == 0 {
errors.push("decomp_legs cannot be 0".to_string());
}
// Performance validations
if self.logrows > MAX_PUBLIC_SRS {
warn!("logrows exceeds maximum public SRS size");
}
// Performance warnings
if self.input_scale > 20 || self.param_scale > 20 {
warn!("High scale values (>20) may impact performance");
}
if errors.is_empty() {
Ok(())
} else {
Err(errors.join("\n"))
}
Ok(())
}
/// Export the ezkl configuration as json
/// Exports the configuration as JSON
///
/// Serializes the RunArgs instance to a JSON string
///
/// # Returns
/// * `Ok(String)` containing JSON representation
/// * `Err` if serialization fails
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
let serialized = match serde_json::to_string(&self) {
Ok(s) => s,
Err(e) => {
return Err(Box::new(e));
}
};
Ok(serialized)
let res = serde_json::to_string(&self)?;
Ok(res)
}
/// Parse an ezkl configuration from a json
/// Parses configuration from JSON
///
/// Deserializes a RunArgs instance from a JSON string
///
/// # Arguments
/// * `arg_json` - JSON string containing configuration
///
/// # Returns
/// * `Ok(RunArgs)` if parsing succeeds
/// * `Err` if parsing fails
pub fn from_json(arg_json: &str) -> Result<Self, serde_json::Error> {
serde_json::from_str(arg_json)
}
}
/// Parse a single key-value pair
// Additional helper functions for the module
/// Parses a key-value pair from a string in the format "key->value"
///
/// # Arguments
/// * `s` - Input string in the format "key->value"
///
/// # Returns
/// * `Ok((T, U))` - Parsed key and value
/// * `Err` - If parsing fails
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
fn parse_key_val<T, U>(
s: &str,
@@ -415,8 +532,114 @@ where
{
let pos = s
.find("->")
.ok_or_else(|| format!("invalid x->y: no `->` found in `{s}`"))?;
let a = s[..pos].parse()?;
let b = s[pos + 2..].parse()?;
Ok((a, b))
.ok_or_else(|| format!("invalid KEY->VALUE: no `->` found in `{s}`"))?;
Ok((s[..pos].parse()?, s[pos + 2..].parse()?))
}
/// Compares an artifact's recorded version against the version of this build.
///
/// The function never fails; every outcome is reported through `log::warn!`:
/// - if the artifact version is empty, `"0.0.0"`, or the source-build marker,
///   the check is skipped;
/// - if this build's own version is the source-build marker, the check is
///   skipped;
/// - otherwise a warning is logged when the two versions differ.
///
/// # Arguments
/// * `artifact_version` - Version string from the artifact
pub fn check_version_string_matches(artifact_version: &str) {
    // Version string that stands in for a real release number
    // (presumably produced by builds made directly from source).
    const SOURCE_BUILD: &str = "source - no compatibility guaranteed";

    if matches!(artifact_version, "" | "0.0.0" | SOURCE_BUILD) {
        // Report the value actually seen: it is not necessarily "0.0.0".
        log::warn!(
            "Artifact is unversioned (version: {:?}), skipping version check",
            artifact_version
        );
        return;
    }

    let version = crate::version();
    if version == SOURCE_BUILD {
        log::warn!("Compiled source version is not guaranteed to match artifact version");
        return;
    }

    if version != artifact_version {
        log::warn!(
            "Version mismatch: CLI version is {} but artifact version is {}",
            version,
            artifact_version
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The stock configuration must pass validation untouched.
    #[test]
    fn test_valid_default_args() {
        let defaults = RunArgs::default();
        assert!(defaults.validate().is_ok());
    }

    /// Public parameter visibility is rejected.
    #[test]
    fn test_invalid_param_visibility() {
        let args = RunArgs {
            param_visibility: Visibility::Public,
            ..RunArgs::default()
        };
        let message = args.validate().unwrap_err();
        assert!(message.contains("Parameters cannot be public instances"));
    }

    /// A zero rebase multiplier is rejected.
    #[test]
    fn test_invalid_scale_rebase() {
        let args = RunArgs {
            scale_rebase_multiplier: 0,
            ..RunArgs::default()
        };
        let message = args.validate().unwrap_err();
        assert!(message.contains("scale_rebase_multiplier must be >= 1"));
    }

    /// A lookup range whose minimum exceeds its maximum is rejected.
    #[test]
    fn test_invalid_lookup_range() {
        let args = RunArgs {
            lookup_range: (100, -100),
            ..RunArgs::default()
        };
        let message = args.validate().unwrap_err();
        assert!(message.contains("Invalid lookup range"));
    }

    /// Zero logrows is rejected.
    #[test]
    fn test_invalid_logrows() {
        let args = RunArgs {
            logrows: 0,
            ..RunArgs::default()
        };
        let message = args.validate().unwrap_err();
        assert!(message.contains("logrows must be >= 1"));
    }

    /// Zero inner columns is rejected.
    #[test]
    fn test_invalid_inner_cols() {
        let args = RunArgs {
            num_inner_cols: 0,
            ..RunArgs::default()
        };
        let message = args.validate().unwrap_err();
        assert!(message.contains("num_inner_cols must be >= 1"));
    }

    /// A `batch_size` variable set to zero is rejected.
    #[test]
    fn test_zero_batch_size() {
        let args = RunArgs {
            variables: vec![("batch_size".to_string(), 0)],
            ..RunArgs::default()
        };
        let message = args.validate().unwrap_err();
        assert!(message.contains("'batch_size' cannot be 0"));
    }

    /// Serializing to JSON and parsing it back yields an equal value.
    #[test]
    fn test_json_serialization() {
        let original = RunArgs::default();
        let encoded = original.as_json().unwrap();
        let decoded = RunArgs::from_json(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// Several failures are reported together, separated by newlines.
    #[test]
    fn test_multiple_validation_errors() {
        let args = RunArgs {
            logrows: 0,
            lookup_range: (100, -100),
            ..RunArgs::default()
        };
        let message = args.validate().unwrap_err();
        // At least one separator means at least two messages were joined.
        assert!(message.contains('\n'));
    }
}

View File

@@ -133,7 +133,6 @@ pub fn aggregate<'a>(
.collect_vec()
}));
// loader.ctx().constrain_equal(cell_0, cell_1)
let mut transcript = PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, snark.proof());
let proof = PlonkSuccinctVerifier::read_proof(svk, &protocol, &instances, &mut transcript)
.map_err(|_| plonk::Error::Synthesis)?;
@@ -309,11 +308,11 @@ impl AggregationCircuit {
})
}
///
/// Number of limbs used for decomposition
pub fn num_limbs() -> usize {
LIMBS
}
///
/// Number of bits used for decomposition
pub fn num_bits() -> usize {
BITS
}

View File

@@ -17,16 +17,16 @@ use crate::{Commitments, EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
use clap::ValueEnum;
use halo2_proofs::circuit::Value;
use halo2_proofs::plonk::{
create_proof, keygen_pk, keygen_vk_custom, verify_proof, Circuit, ProvingKey, VerifyingKey,
Circuit, ProvingKey, VerifyingKey, create_proof, keygen_pk, keygen_vk_custom, verify_proof,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier};
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer};
use halo2curves::CurveAffine;
use halo2curves::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use halo2curves::serde::SerdeObject;
use halo2curves::CurveAffine;
use instant::Instant;
use log::{debug, info, trace};
#[cfg(not(feature = "det-prove"))]
@@ -51,6 +51,9 @@ use pyo3::types::PyDictMethods;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
/// Converts a string to a `SerdeFormat`.
/// # Panics
/// Panics if the provided `s` is not a valid `SerdeFormat` (i.e. not one of "processed", "raw-bytes-unchecked", or "raw-bytes").
fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
match s {
"processed" => halo2_proofs::SerdeFormat::Processed,
@@ -321,7 +324,7 @@ where
}
#[cfg(feature = "python-bindings")]
use pyo3::{types::PyDict, PyObject, Python, ToPyObject};
use pyo3::{PyObject, Python, ToPyObject, types::PyDict};
#[cfg(feature = "python-bindings")]
impl<F: PrimeField + SerdeObject + Serialize, C: CurveAffine + Serialize> ToPyObject for Snark<F, C>
where
@@ -345,14 +348,15 @@ where
}
impl<
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
C: CurveAffine + Serialize + DeserializeOwned,
> Snark<F, C>
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
C: CurveAffine + Serialize + DeserializeOwned,
> Snark<F, C>
where
C::Scalar: Serialize + DeserializeOwned,
C::ScalarExt: Serialize + DeserializeOwned,
{
/// Create a new application snark from proof and instance variables ready for aggregation
#[allow(clippy::too_many_arguments)]
pub fn new(
protocol: Option<PlonkProtocol<C>>,
instances: Vec<Vec<F>>,
@@ -528,7 +532,6 @@ pub fn create_keys<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
disable_selector_compression: bool,
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
C: Circuit<Scheme::Scalar>,
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
{
// Real proof
@@ -794,7 +797,6 @@ pub fn load_vk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
params: <C as Circuit<Scheme::Scalar>>::Params,
) -> Result<VerifyingKey<Scheme::Curve>, PfsysError>
where
C: Circuit<Scheme::Scalar>,
Scheme::Curve: SerdeObject + CurveAffine,
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
{
@@ -817,11 +819,11 @@ pub fn load_pk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
params: <C as Circuit<Scheme::Scalar>>::Params,
) -> Result<ProvingKey<Scheme::Curve>, PfsysError>
where
C: Circuit<Scheme::Scalar>,
Scheme::Curve: SerdeObject + CurveAffine,
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
{
debug!("loading proving key from {:?}", path);
let start = instant::Instant::now();
let f = File::open(path.clone()).map_err(|e| PfsysError::LoadPk(format!("{}", e)))?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let pk = ProvingKey::<Scheme::Curve>::read::<_, C>(
@@ -830,7 +832,8 @@ where
params,
)
.map_err(|e| PfsysError::LoadPk(format!("{}", e)))?;
info!("loaded proving key ✅");
let elapsed = start.elapsed();
info!("loaded proving key in {:?}", elapsed);
Ok(pk)
}

View File

@@ -1,6 +1,6 @@
use thiserror::Error;
use super::ops::DecompositionError;
use super::{ops::DecompositionError, DataFormat};
/// A wrapper for tensor related errors.
#[derive(Debug, Error)]
@@ -38,4 +38,13 @@ pub enum TensorError {
/// Decomposition error
#[error("decomposition error: {0}")]
DecompositionError(#[from] DecompositionError),
/// Invalid argument
#[error("invalid argument: {0}")]
InvalidArgument(String),
/// Index out of bounds
#[error("index {0} out of bounds for dimension {1}")]
IndexOutOfBounds(usize, usize),
/// Invalid data conversion
#[error("invalid data conversion from format {0} to {1}")]
InvalidDataConversion(DataFormat, DataFormat),
}

View File

@@ -9,6 +9,7 @@ pub mod var;
pub use errors::TensorError;
use core::hash::Hash;
use halo2curves::ff::PrimeField;
use maybe_rayon::{
prelude::{
@@ -24,12 +25,9 @@ use std::path::PathBuf;
pub use val::*;
pub use var::*;
#[cfg(feature = "metal")]
use instant::Instant;
use crate::{
circuit::utils,
fieldutils::{integer_rep_to_felt, IntegerRep},
fieldutils::{IntegerRep, integer_rep_to_felt},
graph::Visibility,
};
@@ -40,8 +38,6 @@ use halo2_proofs::{
poly::Rotation,
};
use itertools::Itertools;
#[cfg(feature = "metal")]
use metal::{Device, MTLResourceOptions, MTLSize};
use std::error::Error;
use std::fmt::Debug;
use std::io::Read;
@@ -49,31 +45,6 @@ use std::iter::Iterator;
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
use std::{cmp::max, ops::Rem};
#[cfg(feature = "metal")]
use std::collections::HashMap;
#[cfg(feature = "metal")]
const LIB_DATA: &[u8] = include_bytes!("metal/tensor_ops.metallib");
#[cfg(feature = "metal")]
lazy_static::lazy_static! {
static ref DEVICE: Device = Device::system_default().expect("no device found");
static ref LIB: metal::Library = DEVICE.new_library_with_data(LIB_DATA).unwrap();
static ref QUEUE: metal::CommandQueue = DEVICE.new_command_queue();
static ref PIPELINES: HashMap<String, metal::ComputePipelineState> = {
let mut map = HashMap::new();
for name in ["add", "sub", "mul"] {
let function = LIB.get_function(name, None).unwrap();
let pipeline = DEVICE.new_compute_pipeline_state_with_function(&function).unwrap();
map.insert(name.to_string(), pipeline);
}
map
};
}
/// The (inner) type of tensor elements.
pub trait TensorType: Clone + Debug + 'static {
/// Returns the zero value.
@@ -91,7 +62,7 @@ pub trait TensorType: Clone + Debug + 'static {
}
macro_rules! tensor_type {
($rust_type:ty, $tensor_type:ident, $zero:expr, $one:expr) => {
($rust_type:ty, $tensor_type:ident, $zero:expr_2021, $one:expr_2021) => {
impl TensorType for $rust_type {
fn zero() -> Option<Self> {
Some($zero)
@@ -444,7 +415,7 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
Err(_) => {
return Err(TensorError::FileLoadError(
"Failed to read tensor".to_string(),
))
));
}
}
}
@@ -638,42 +609,44 @@ impl<T: Clone + TensorType> Tensor<T> {
where
T: Send + Sync,
{
if indices.is_empty() {
// Fast path: empty indices or full tensor slice
if indices.is_empty()
|| indices.iter().map(|x| x.end - x.start).collect::<Vec<_>>() == self.dims
{
return Ok(self.clone());
}
// Validate dimensions
if self.dims.len() < indices.len() {
return Err(TensorError::DimError(format!(
"The dimensionality of the slice {:?} is greater than the tensor's {:?}",
indices, self.dims
)));
} else if indices.iter().map(|x| x.end - x.start).collect::<Vec<_>>() == self.dims {
// else if slice is the same as dims, return self
return Ok(self.clone());
}
// if indices weren't specified we fill them in as required
let mut full_indices = indices.to_vec();
// Pre-allocate the full indices vector with capacity
let mut full_indices = Vec::with_capacity(self.dims.len());
full_indices.extend_from_slice(indices);
for i in 0..(self.dims.len() - indices.len()) {
full_indices.push(0..self.dims()[indices.len() + i])
}
// Fill remaining dimensions
full_indices.extend((indices.len()..self.dims.len()).map(|i| 0..self.dims[i]));
let cartesian_coord: Vec<Vec<usize>> = full_indices
// Pre-calculate total size and allocate result vector
let total_size: usize = full_indices
.iter()
.cloned()
.multi_cartesian_product()
.collect();
let res: Vec<T> = cartesian_coord
.par_iter()
.map(|e| {
let index = self.get_index(e);
self[index].clone()
})
.collect();
.map(|range| range.end - range.start)
.product();
let mut res = Vec::with_capacity(total_size);
// Calculate new dimensions once
let dims: Vec<usize> = full_indices.iter().map(|e| e.end - e.start).collect();
// Use iterator directly without collecting into intermediate Vec
for coord in full_indices.iter().cloned().multi_cartesian_product() {
let index = self.get_index(&coord);
res.push(self[index].clone());
}
Tensor::new(Some(&res), &dims)
}
@@ -831,7 +804,13 @@ impl<T: Clone + TensorType> Tensor<T> {
num_repeats: usize,
initial_offset: usize,
) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
if n == 0 {
return Err(TensorError::InvalidArgument(
"Cannot duplicate every 0th element".to_string(),
));
}
let mut inner: Vec<T> = Vec::with_capacity(self.inner.len());
let mut offset = initial_offset;
for (i, elem) in self.inner.clone().into_iter().enumerate() {
if (i + offset + 1) % n == 0 {
@@ -860,20 +839,28 @@ impl<T: Clone + TensorType> Tensor<T> {
num_repeats: usize,
initial_offset: usize,
) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
let mut indices_to_remove = std::collections::HashSet::new();
for i in 0..self.inner.len() {
if (i + initial_offset + 1) % n == 0 {
for j in 1..(1 + num_repeats) {
indices_to_remove.insert(i + j);
}
}
if n == 0 {
return Err(TensorError::InvalidArgument(
"Cannot remove every 0th element".to_string(),
));
}
let old_inner = self.inner.clone();
for (i, elem) in old_inner.into_iter().enumerate() {
if !indices_to_remove.contains(&i) {
inner.push(elem.clone());
// Pre-calculate capacity to avoid reallocations
let estimated_size = self.inner.len() - (self.inner.len() / n) * num_repeats;
let mut inner = Vec::with_capacity(estimated_size);
// Use iterator directly instead of creating intermediate collections
let mut i = 0;
while i < self.inner.len() {
// Add the current element
inner.push(self.inner[i].clone());
// If this is an nth position (accounting for offset)
if (i + initial_offset + 1) % n == 0 {
// Skip the next num_repeats elements
i += num_repeats + 1;
} else {
i += 1;
}
}
@@ -881,7 +868,6 @@ impl<T: Clone + TensorType> Tensor<T> {
}
/// Remove indices
/// WARN: assumes indices are in ascending order for speed
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
@@ -908,7 +894,11 @@ impl<T: Clone + TensorType> Tensor<T> {
}
// remove indices
for elem in indices.iter().rev() {
inner.remove(*elem);
if *elem < self.len() {
inner.remove(*elem);
} else {
return Err(TensorError::IndexOutOfBounds(*elem, self.len()));
}
}
Tensor::new(Some(&inner), &[inner.len()])
@@ -936,6 +926,9 @@ impl<T: Clone + TensorType> Tensor<T> {
));
}
self.dims = vec![];
}
if self.dims() == &[0] && new_dims.iter().product::<usize>() == 1 {
self.dims = Vec::from(new_dims);
} else {
let product = if new_dims != [0] {
new_dims.iter().product::<usize>()
@@ -1114,6 +1107,10 @@ impl<T: Clone + TensorType> Tensor<T> {
let mut output = self.clone();
output.reshape(shape)?;
return Ok(output);
} else if self.dims() == &[0] && shape.iter().product::<usize>() == 1 {
let mut output = self.clone();
output.reshape(shape)?;
return Ok(output);
}
if self.dims().len() > shape.len() {
@@ -1264,7 +1261,7 @@ impl<T: Clone + TensorType> Tensor<T> {
None => {
return Err(TensorError::DimError(
"Cannot get last element of empty tensor".to_string(),
))
));
}
};
@@ -1289,7 +1286,7 @@ impl<T: Clone + TensorType> Tensor<T> {
None => {
return Err(TensorError::DimError(
"Cannot get first element of empty tensor".to_string(),
))
));
}
};
@@ -1400,10 +1397,6 @@ impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync> Ad
let lhs = self.expand(&broadcasted_shape).unwrap();
let rhs = rhs.expand(&broadcasted_shape).unwrap();
#[cfg(feature = "metal")]
let res = metal_tensor_op(&lhs, &rhs, "add");
#[cfg(not(feature = "metal"))]
let res = {
let mut res: Tensor<T> = lhs
.par_iter()
@@ -1501,10 +1494,6 @@ impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Su
let lhs = self.expand(&broadcasted_shape).unwrap();
let rhs = rhs.expand(&broadcasted_shape).unwrap();
#[cfg(feature = "metal")]
let res = metal_tensor_op(&lhs, &rhs, "sub");
#[cfg(not(feature = "metal"))]
let res = {
let mut res: Tensor<T> = lhs
.par_iter()
@@ -1572,10 +1561,6 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mu
let lhs = self.expand(&broadcasted_shape).unwrap();
let rhs = rhs.expand(&broadcasted_shape).unwrap();
#[cfg(feature = "metal")]
let res = metal_tensor_op(&lhs, &rhs, "mul");
#[cfg(not(feature = "metal"))]
let res = {
let mut res: Tensor<T> = lhs
.par_iter()
@@ -1681,7 +1666,9 @@ impl<T: TensorType + Div<Output = T> + std::marker::Send + std::marker::Sync> Di
}
// implement remainder
impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Rem for Tensor<T> {
impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + PartialEq> Rem
for Tensor<T>
{
type Output = Result<Tensor<T>, TensorError>;
/// Elementwise remainder of a tensor with another tensor.
@@ -1710,9 +1697,24 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Re
let mut lhs = self.expand(&broadcasted_shape).unwrap();
let rhs = rhs.expand(&broadcasted_shape).unwrap();
lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
*o = o.clone() % r;
});
lhs.par_iter_mut()
.zip(rhs)
.map(|(o, r)| match T::zero() {
Some(zero) => {
if r != zero {
*o = o.clone() % r;
Ok(())
} else {
Err(TensorError::InvalidArgument(
"Cannot divide by zero in remainder".to_string(),
))
}
}
_ => Err(TensorError::InvalidArgument(
"Undefined zero value".to_string(),
)),
})
.collect::<Result<Vec<_>, _>>()?;
Ok(lhs)
}
@@ -1747,7 +1749,6 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Re
/// assert_eq!(c, vec![2, 3]);
///
/// ```
pub fn get_broadcasted_shape(
shape_a: &[usize],
shape_b: &[usize],
@@ -1755,23 +1756,247 @@ pub fn get_broadcasted_shape(
let num_dims_a = shape_a.len();
let num_dims_b = shape_b.len();
match (num_dims_a, num_dims_b) {
(a, b) if a == b => {
let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
let max_dim = dim_a.max(dim_b);
broadcasted_shape.push(*max_dim);
}
Ok(broadcasted_shape)
if num_dims_a == num_dims_b {
let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
let max_dim = dim_a.max(dim_b);
broadcasted_shape.push(*max_dim);
}
(a, b) if a < b => Ok(shape_b.to_vec()),
(a, b) if a > b => Ok(shape_a.to_vec()),
_ => Err(TensorError::DimError(
Ok(broadcasted_shape)
} else if num_dims_a < num_dims_b {
Ok(shape_b.to_vec())
} else if num_dims_a > num_dims_b {
Ok(shape_a.to_vec())
} else {
Err(TensorError::DimError(
"Unknown condition for broadcasting".to_string(),
)),
))
}
}
////////////////////////
///
/// Axis layout of a data tensor for layout-sensitive operations.
///
/// The variants differ in whether a batch axis is present (see
/// `DataFormat::has_no_batch`) and in where the channel axis sits (see
/// `DataFormat::get_channel_dim`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]
pub enum DataFormat {
/// Has a batch axis; channel axis at index 1. This is the default layout.
#[default]
NCHW,
/// Has a batch axis; channel axis last.
NHWC,
/// No batch axis; channel axis at index 0.
CHW,
/// No batch axis; channel axis last.
HWC,
}
// Textual name of each layout
impl core::fmt::Display for DataFormat {
    /// Writes the variant's name (for example `NCHW`) verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            DataFormat::NCHW => "NCHW",
            DataFormat::NHWC => "NHWC",
            DataFormat::CHW => "CHW",
            DataFormat::HWC => "HWC",
        };
        f.write_str(label)
    }
}
impl DataFormat {
    /// Returns the canonical (channel-first) counterpart of this layout.
    ///
    /// `NHWC` maps to `NCHW` and `HWC` maps to `CHW`; the channel-first
    /// layouts are returned unchanged.
    pub fn canonical(&self) -> DataFormat {
        match self {
            DataFormat::NHWC => DataFormat::NCHW,
            DataFormat::HWC => DataFormat::CHW,
            // `DataFormat` is `Copy`, so dereferencing replaces `clone()`.
            DataFormat::NCHW | DataFormat::CHW => *self,
        }
    }

    /// Returns `true` for the layouts without a batch axis (`CHW` and `HWC`).
    pub fn has_no_batch(&self) -> bool {
        matches!(self, DataFormat::CHW | DataFormat::HWC)
    }

    /// Rearranges `tensor` in place from this layout into the canonical one
    /// (`NCHW` or `CHW`).
    ///
    /// # Errors
    /// Propagates any error returned by `ValTensor::move_axis`.
    pub fn to_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
        &self,
        tensor: &mut ValTensor<F>,
    ) -> Result<(), TensorError> {
        match self {
            DataFormat::NHWC => {
                // For ND: move channels from the last axis to the position after batch.
                // Tensors with at most two axes are left untouched.
                let ndims = tensor.dims().len();
                if ndims > 2 {
                    tensor.move_axis(ndims - 1, 1)?;
                }
            }
            DataFormat::HWC => {
                // For ND: move channels from the last axis to the first position.
                // Single-axis tensors are left untouched.
                let ndims = tensor.dims().len();
                if ndims > 1 {
                    tensor.move_axis(ndims - 1, 0)?;
                }
            }
            // Already canonical.
            DataFormat::NCHW | DataFormat::CHW => {}
        }
        Ok(())
    }

    /// Rearranges `tensor` in place from the canonical layout into this one.
    ///
    /// This issues the reverse axis move of [`DataFormat::to_canonical`].
    ///
    /// # Errors
    /// Propagates any error returned by `ValTensor::move_axis`.
    pub fn from_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
        &self,
        tensor: &mut ValTensor<F>,
    ) -> Result<(), TensorError> {
        match self {
            DataFormat::NHWC => {
                // Move channels from position 1 to the end.
                let ndims = tensor.dims().len();
                if ndims > 2 {
                    tensor.move_axis(1, ndims - 1)?;
                }
            }
            DataFormat::HWC => {
                // Move channels from position 0 to the end.
                let ndims = tensor.dims().len();
                if ndims > 1 {
                    tensor.move_axis(0, ndims - 1)?;
                }
            }
            // The canonical layouts need no conversion.
            DataFormat::NCHW | DataFormat::CHW => {}
        }
        Ok(())
    }

    /// Returns the index of the channel axis for a tensor with `ndims` axes.
    ///
    /// The channel-last layouts compute `ndims - 1` without a check, so
    /// `ndims` must be at least 1 for `NHWC` and `HWC`.
    pub fn get_channel_dim(&self, ndims: usize) -> usize {
        match self {
            DataFormat::NCHW => 1,
            DataFormat::CHW => 0,
            DataFormat::NHWC | DataFormat::HWC => ndims - 1,
        }
    }
}
/// Axis layout of a kernel tensor for layout-sensitive operations.
///
/// The variants differ in where the input-channel and output-channel axes sit
/// (see `KernelFormat::get_channel_dims`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default, Copy)]
pub enum KernelFormat {
/// Input channels second to last, output channels last.
HWIO,
/// Output channels at index 0, input channels at index 1. This is the default layout.
#[default]
OIHW,
/// Output channels at index 0, input channels last.
OHWI,
}
impl core::fmt::Display for KernelFormat {
    /// Writes the variant's name (for example `OIHW`) verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            KernelFormat::HWIO => "HWIO",
            KernelFormat::OIHW => "OIHW",
            KernelFormat::OHWI => "OHWI",
        };
        f.write_str(label)
    }
}
impl KernelFormat {
    /// Returns the canonical counterpart of this layout, which is always `OIHW`.
    pub fn canonical(&self) -> KernelFormat {
        match self {
            KernelFormat::HWIO | KernelFormat::OHWI => KernelFormat::OIHW,
            // `KernelFormat` is `Copy`, so dereferencing replaces `clone()`.
            KernelFormat::OIHW => *self,
        }
    }

    /// Rearranges `kernel` in place from this layout into the canonical `OIHW`.
    ///
    /// The non-canonical layouts compute `kdims - 1` without a check, so the
    /// kernel must have at least one axis.
    ///
    /// # Errors
    /// Propagates any error returned by `ValTensor::move_axis`.
    pub fn to_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
        &self,
        kernel: &mut ValTensor<F>,
    ) -> Result<(), TensorError> {
        match self {
            KernelFormat::HWIO => {
                let kdims = kernel.dims().len();
                // Move output channels from last to first.
                kernel.move_axis(kdims - 1, 0)?;
                // Input channels are now last; move them to the second position.
                kernel.move_axis(kdims - 1, 1)?;
            }
            KernelFormat::OHWI => {
                let kdims = kernel.dims().len();
                // Move input channels from last to the second position.
                kernel.move_axis(kdims - 1, 1)?;
            }
            // Already canonical.
            KernelFormat::OIHW => {}
        }
        Ok(())
    }

    /// Rearranges `kernel` in place from the canonical `OIHW` into this layout.
    ///
    /// The non-canonical layouts compute `kdims - 1` without a check, so the
    /// kernel must have at least one axis.
    ///
    /// # Errors
    /// Propagates any error returned by `ValTensor::move_axis`.
    pub fn from_canonical<F: PrimeField + TensorType + PartialOrd + Hash>(
        &self,
        kernel: &mut ValTensor<F>,
    ) -> Result<(), TensorError> {
        match self {
            KernelFormat::HWIO => {
                let kdims = kernel.dims().len();
                // Move input channels from the second position to last.
                kernel.move_axis(1, kdims - 1)?;
                // Move output channels from first to last.
                kernel.move_axis(0, kdims - 1)?;
            }
            KernelFormat::OHWI => {
                let kdims = kernel.dims().len();
                // Move input channels from the second position to last.
                kernel.move_axis(1, kdims - 1)?;
            }
            // The canonical layout needs no conversion.
            KernelFormat::OIHW => {}
        }
        Ok(())
    }

    /// Returns the positions of the channel axes as `(input_ch, output_ch)`
    /// for a kernel with `ndims` axes.
    ///
    /// The subtractions are unchecked: `HWIO` needs `ndims >= 2` and `OHWI`
    /// needs `ndims >= 1`.
    pub fn get_channel_dims(&self, ndims: usize) -> (usize, usize) {
        match self {
            KernelFormat::OIHW => (1, 0),
            KernelFormat::HWIO => (ndims - 2, ndims - 1),
            KernelFormat::OHWI => (ndims - 1, 0),
        }
    }
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl From<tract_onnx::tract_hir::ops::nn::DataFormat> for DataFormat {
    /// Maps each tract data layout onto the local variant of the same name.
    fn from(fmt: tract_onnx::tract_hir::ops::nn::DataFormat) -> Self {
        use tract_onnx::tract_hir::ops::nn::DataFormat as TractDataFormat;

        match fmt {
            TractDataFormat::NCHW => Self::NCHW,
            TractDataFormat::NHWC => Self::NHWC,
            TractDataFormat::CHW => Self::CHW,
            TractDataFormat::HWC => Self::HWC,
        }
    }
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl From<tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat> for KernelFormat {
    /// Maps each tract kernel layout onto the local variant of the same name.
    fn from(fmt: tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat) -> Self {
        use tract_onnx::tract_hir::tract_core::ops::cnn::conv::KernelFormat as TractKernelFormat;

        match fmt {
            TractKernelFormat::HWIO => Self::HWIO,
            TractKernelFormat::OIHW => Self::OIHW,
            TractKernelFormat::OHWI => Self::OHWI,
        }
    }
}
#[cfg(test)]
mod tests {
@@ -1807,66 +2032,4 @@ mod tests {
let b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
}
#[test]
#[cfg(feature = "metal")]
fn tensor_metal_int() {
let a = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
let b = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
let c = metal_tensor_op(&a, &b, "add");
assert_eq!(c, Tensor::new(Some(&[2, 4, 6, 8]), &[2, 2]).unwrap());
let c = metal_tensor_op(&a, &b, "sub");
assert_eq!(c, Tensor::new(Some(&[0, 0, 0, 0]), &[2, 2]).unwrap());
let c = metal_tensor_op(&a, &b, "mul");
assert_eq!(c, Tensor::new(Some(&[1, 4, 9, 16]), &[2, 2]).unwrap());
}
#[test]
#[cfg(feature = "metal")]
fn tensor_metal_felt() {
use halo2curves::bn256::Fr;
let a = Tensor::<Fr>::new(
Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
&[2, 2],
)
.unwrap();
let b = Tensor::<Fr>::new(
Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
&[2, 2],
)
.unwrap();
let c = metal_tensor_op(&a, &b, "add");
assert_eq!(
c,
Tensor::<Fr>::new(
Some(&[Fr::from(2), Fr::from(4), Fr::from(6), Fr::from(8)]),
&[2, 2],
)
.unwrap()
);
let c = metal_tensor_op(&a, &b, "sub");
assert_eq!(
c,
Tensor::<Fr>::new(
Some(&[Fr::from(0), Fr::from(0), Fr::from(0), Fr::from(0)]),
&[2, 2],
)
.unwrap()
);
let c = metal_tensor_op(&a, &b, "mul");
assert_eq!(
c,
Tensor::<Fr>::new(
Some(&[Fr::from(1), Fr::from(4), Fr::from(9), Fr::from(16)]),
&[2, 2],
)
.unwrap()
);
}
}

View File

@@ -27,7 +27,7 @@ pub fn get_rep(
n: usize,
) -> Result<Vec<IntegerRep>, DecompositionError> {
// check if x is too large
if x.abs() > (base.pow(n as u32) as IntegerRep) - 1 {
if (*x).abs() > ((base as i128).pow(n as u32)) - 1 {
return Err(DecompositionError::TooLarge(*x, base, n));
}
let mut rep = vec![0; n + 1];
@@ -43,8 +43,8 @@ pub fn get_rep(
let mut x = x.abs();
//
for i in (1..rep.len()).rev() {
rep[i] = x % base as i128;
x /= base as i128;
rep[i] = x % base as IntegerRep;
x /= base as IntegerRep;
}
Ok(rep)
@@ -127,7 +127,7 @@ pub fn decompose(
.flatten()
.collect::<Vec<IntegerRep>>();
let output = Tensor::<i128>::new(Some(&resp), &dims)?;
let output = Tensor::<IntegerRep>::new(Some(&resp), &dims)?;
Ok(output)
}
@@ -160,7 +160,7 @@ pub fn decompose(
///
/// let result = trilu(&a, 0, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 3, 4, 5, 6]), &[1, 3, 2]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let result = trilu(&a, -1, true).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 0, 6]), &[1, 3, 2]).unwrap();
@@ -168,7 +168,7 @@ pub fn decompose(
///
/// let result = trilu(&a, -1, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 3, 0, 5, 6]), &[1, 3, 2]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let a = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6]),
@@ -188,7 +188,7 @@ pub fn decompose(
///
/// let result = trilu(&a, 0, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 4, 5, 0]), &[1, 2, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let result = trilu(&a, -1, true).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[1, 2, 3]).unwrap();
@@ -196,7 +196,7 @@ pub fn decompose(
///
/// let result = trilu(&a, -1, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 4, 0, 0]), &[1, 2, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let a = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9]),
@@ -216,7 +216,7 @@ pub fn decompose(
///
/// let result = trilu(&a, 0, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 4, 5, 0, 7, 8, 9]), &[1, 3, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let result = trilu(&a, -1, true).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 8, 9]), &[1, 3, 3]).unwrap();
@@ -224,7 +224,7 @@ pub fn decompose(
///
/// let result = trilu(&a, -1, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 4, 0, 0, 7, 8, 0]), &[1, 3, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
/// ```
pub fn trilu<T: TensorType + std::marker::Send + std::marker::Sync>(
a: &Tensor<T>,
@@ -385,6 +385,12 @@ pub fn resize<T: TensorType + Send + Sync>(
pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync>(
t: &[Tensor<T>],
) -> Result<Tensor<T>, TensorError> {
if t.len() == 1 {
return Ok(t[0].clone());
} else if t.is_empty() {
return Err(TensorError::DimMismatch("add".to_string()));
}
// calculate value of output
let mut output: Tensor<T> = t[0].clone();
@@ -433,6 +439,11 @@ pub fn add<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sy
pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync>(
t: &[Tensor<T>],
) -> Result<Tensor<T>, TensorError> {
if t.len() == 1 {
return Ok(t[0].clone());
} else if t.is_empty() {
return Err(TensorError::DimMismatch("sub".to_string()));
}
// calculate value of output
let mut output: Tensor<T> = t[0].clone();
@@ -479,6 +490,11 @@ pub fn sub<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sy
pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync>(
t: &[Tensor<T>],
) -> Result<Tensor<T>, TensorError> {
if t.len() == 1 {
return Ok(t[0].clone());
} else if t.is_empty() {
return Err(TensorError::DimMismatch("mult".to_string()));
}
// calculate value of output
let mut output: Tensor<T> = t[0].clone();
@@ -519,30 +535,101 @@ pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::S
/// let result = downsample(&x, 1, 2, 2).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[3, 6]), &[2, 1]).unwrap();
/// assert_eq!(result, expected);
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6]),
/// &[2, 3],
/// ).unwrap();
///
/// // Test case 1: Negative stride along dimension 0
/// // This should flip the order along dimension 0
/// let result = downsample(&x, 0, -1, 0).unwrap();
/// let expected = Tensor::<IntegerRep>::new(
/// Some(&[4, 5, 6, 1, 2, 3]), // Flipped order of rows
/// &[2, 3]
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// // Test case 2: Negative stride along dimension 1
/// // This should flip the order along dimension 1
/// let result = downsample(&x, 1, -1, 0).unwrap();
/// let expected = Tensor::<IntegerRep>::new(
/// Some(&[3, 2, 1, 6, 5, 4]), // Flipped order of columns
/// &[2, 3]
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// // Test case 3: Negative stride with stride magnitude > 1
/// // This should both skip and flip
/// let result = downsample(&x, 1, -2, 0).unwrap();
/// let expected = Tensor::<IntegerRep>::new(
/// Some(&[3, 1, 6, 4]), // Take every 2nd element in reverse
/// &[2, 2]
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// // Test case 4: Negative stride with non-zero modulo
/// // This should start at (size - 1 - modulo) and reverse
/// let result = downsample(&x, 1, -2, 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(
/// Some(&[2, 5]), // Start at second element from end, take every 2nd in reverse
/// &[2, 1]
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// // Create a larger test case for more complex downsampling
/// let y = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
/// &[3, 4],
/// ).unwrap();
///
/// // Test case 5: Negative stride with modulo on larger tensor
/// let result = downsample(&y, 1, -2, 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(
/// Some(&[3, 1, 7, 5, 11, 9]), // Start at one after reverse, take every 2nd
/// &[3, 2]
/// ).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn downsample<T: TensorType + Send + Sync>(
input: &Tensor<T>,
dim: usize,
stride: usize,
stride: isize, // Changed from usize to isize to support negative strides
modulo: usize,
) -> Result<Tensor<T>, TensorError> {
let mut output_shape = input.dims().to_vec();
// now downsample along axis dim offset by modulo, rounding up (+1 if remaidner is non-zero)
let remainder = (input.dims()[dim] - modulo) % stride;
let div = (input.dims()[dim] - modulo) / stride;
output_shape[dim] = div + (remainder > 0) as usize;
let mut output = Tensor::<T>::new(None, &output_shape)?;
// Handle negative stride case
if stride == 0 {
return Err(TensorError::DimMismatch(
"downsample stride cannot be zero".to_string(),
));
}
if modulo > input.dims()[dim] {
let stride_abs = stride.unsigned_abs();
let mut output_shape = input.dims().to_vec();
if modulo >= input.dims()[dim] {
return Err(TensorError::DimMismatch("downsample".to_string()));
}
// now downsample along axis dim offset by modulo
// Calculate output shape based on the absolute value of stride
let remainder = (input.dims()[dim] - modulo) % stride_abs;
let div = (input.dims()[dim] - modulo) / stride_abs;
output_shape[dim] = div + (remainder > 0) as usize;
let mut output = Tensor::<T>::new(None, &output_shape)?;
// Calculate indices based on stride direction
let indices = (0..output_shape.len())
.map(|i| {
if i == dim {
let mut index = vec![0; output_shape[i]];
for (i, idx) in index.iter_mut().enumerate() {
*idx = i * stride + modulo;
for (j, idx) in index.iter_mut().enumerate() {
if stride > 0 {
// Positive stride: move forward from modulo
*idx = j * stride_abs + modulo;
} else {
// Negative stride: move backward from (size - 1 - modulo)
*idx = (input.dims()[dim] - 1 - modulo) - j * stride_abs;
}
}
index
} else {
@@ -1310,7 +1397,6 @@ pub fn pad<T: TensorType>(
///
/// # Errors
/// Returns a TensorError if the tensors in `inputs` have incompatible dimensions for concatenation along the specified `axis`.
pub fn concat<T: TensorType + Send + Sync>(
inputs: &[&Tensor<T>],
axis: usize,
@@ -1773,14 +1859,14 @@ pub mod nonlinearities {
/// Some(&[4, 25, 8, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = rsqrt(&x, 1.0);
/// let result = rsqrt(&x, 1.0, f64::EPSILON);
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 1, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn rsqrt(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
pub fn rsqrt(a: &Tensor<IntegerRep>, scale_input: f64, eps: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let kix = (a_i as f64) / scale_input;
let fout = scale_input / (kix.sqrt() + f64::EPSILON);
let fout = scale_input / (kix.sqrt() + eps);
let rounded = fout.round();
Ok::<_, TensorError>(rounded as IntegerRep)
})
@@ -2086,7 +2172,6 @@ pub mod nonlinearities {
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn tanh(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let kix = (a_i as f64) / scale_input;
@@ -2254,14 +2339,23 @@ pub mod nonlinearities {
/// &[2, 3],
/// ).unwrap();
/// let k = 2_f64;
/// let result = recip(&x, 1.0, k);
/// let result = recip(&x, 1.0, k, f64::EPSILON);
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 1, 0, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn recip(a: &Tensor<IntegerRep>, input_scale: f64, out_scale: f64) -> Tensor<IntegerRep> {
pub fn recip(
a: &Tensor<IntegerRep>,
input_scale: f64,
out_scale: f64,
eps: f64,
) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let rescaled = (a_i as f64) / input_scale;
let denom = (1_f64) / (rescaled + f64::EPSILON);
let denom = if rescaled == 0_f64 {
(1_f64) / (rescaled + eps)
} else {
(1_f64) / (rescaled)
};
let d_inv_x = out_scale * denom;
Ok::<_, TensorError>(d_inv_x.round() as IntegerRep)
})
@@ -2277,16 +2371,16 @@ pub mod nonlinearities {
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::zero_recip;
/// let k = 2_f64;
/// let result = zero_recip(1.0);
/// let result = zero_recip(1.0, f64::EPSILON);
/// let expected = Tensor::<IntegerRep>::new(Some(&[4503599627370496]), &[1]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn zero_recip(out_scale: f64) -> Tensor<IntegerRep> {
pub fn zero_recip(out_scale: f64, eps: f64) -> Tensor<IntegerRep> {
let a = Tensor::<IntegerRep>::new(Some(&[0]), &[1]).unwrap();
a.par_enum_map(|_, a_i| {
let rescaled = a_i as f64;
let denom = (1_f64) / (rescaled + f64::EPSILON);
let denom = (1_f64) / (rescaled + eps);
let d_inv_x = out_scale * denom;
Ok::<_, TensorError>(d_inv_x.round() as IntegerRep)
})

File diff suppressed because it is too large Load Diff

View File

@@ -2,36 +2,38 @@ use std::collections::HashSet;
use log::{debug, error, warn};
use crate::circuit::{region::ConstantsMap, CheckMode};
use crate::circuit::{CheckMode, region::ConstantsMap};
use super::*;
/// A wrapper around Halo2's Column types that represents a tensor of variables in the circuit.
///
/// [ValTensor]s are typically assigned to [VarTensor]s when laying out a circuit. Storage is
/// organized into blocks of columns: each block holds `num_inner_cols` columns and each column
/// offers `col_size` usable rows.
#[derive(Clone, Default, Debug, PartialEq, Eq)]
pub enum VarTensor {
/// A VarTensor for holding Advice values, which are assigned at proving time.
Advice {
/// Advice columns grouped into blocks: the outer `Vec` indexes blocks and each inner `Vec`
/// holds the columns of one block.
inner: Vec<Vec<Column<Advice>>>,
/// The number of columns in each inner block
num_inner_cols: usize,
/// Number of rows available to be used in each column of the storage
col_size: usize,
},
/// A placeholder tensor that records dimensions only and owns no columns
/// (described upstream as used for testing or temporary storage).
Dummy {
/// The number of columns in each inner block
num_inner_cols: usize,
/// Number of rows available to be used in each column of the storage
col_size: usize,
},
/// An empty tensor with no storage; this is the `Default` variant.
#[default]
Empty,
}
impl VarTensor {
/// name of the tensor
/// Returns the name of the tensor variant as a static string
pub fn name(&self) -> &'static str {
match self {
VarTensor::Advice { .. } => "Advice",
@@ -40,22 +42,35 @@ impl VarTensor {
}
}
/// Reports whether this tensor is backed by advice columns.
///
/// # Returns
/// `true` for the `Advice` variant, `false` for `Dummy` and `Empty`.
pub fn is_advice(&self) -> bool {
    match self {
        VarTensor::Advice { .. } => true,
        VarTensor::Dummy { .. } | VarTensor::Empty => false,
    }
}
/// Computes how many rows of the constraint system are usable for assignments.
///
/// # Arguments
/// * `cs` - The constraint system, queried for its number of blinding factors
/// * `logrows` - Log base 2 of the total number of rows (including system and blinding rows)
///
/// # Returns
/// `2^logrows` minus the blinding factors reported by `cs`, minus one further row
pub fn max_rows<F: PrimeField>(cs: &ConstraintSystem<F>, logrows: usize) -> usize {
    let total_rows = 2u32.pow(logrows as u32) as usize;
    let blinding_rows = cs.blinding_factors();
    total_rows - blinding_rows - 1
}
/// Create a new VarTensor::Advice that is unblinded
/// Arguments
/// * `cs` - The constraint system
/// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
/// * `capacity` - The number of advice cells to allocate
/// Creates a new VarTensor::Advice with unblinded columns. Unblinded columns are used when
/// the values do not need to be hidden in the proof.
///
/// # Arguments
/// * `cs` - The constraint system to create columns in
/// * `logrows` - Log base 2 of the total number of rows
/// * `num_inner_cols` - Number of columns in each inner block
/// * `capacity` - Total number of advice cells to allocate
///
/// # Returns
/// A new VarTensor::Advice with unblinded columns enabled for equality constraints
pub fn new_unblinded_advice<F: PrimeField>(
cs: &mut ConstraintSystem<F>,
logrows: usize,
@@ -93,11 +108,17 @@ impl VarTensor {
}
}
/// Create a new VarTensor::Advice
/// Arguments
/// * `cs` - The constraint system
/// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
/// * `capacity` - The number of advice cells to allocate
/// Creates a new VarTensor::Advice with standard (blinded) columns, used when
/// the values need to be hidden in the proof.
///
/// # Arguments
/// * `cs` - The constraint system to create columns in
/// * `logrows` - Log base 2 of the total number of rows
/// * `num_inner_cols` - Number of columns in each inner block
/// * `capacity` - Total number of advice cells to allocate
///
/// # Returns
/// A new VarTensor::Advice with blinded columns enabled for equality constraints
pub fn new_advice<F: PrimeField>(
cs: &mut ConstraintSystem<F>,
logrows: usize,
@@ -133,11 +154,17 @@ impl VarTensor {
}
}
/// Initializes fixed columns to support the VarTensor::Advice
/// Arguments
/// * `cs` - The constraint system
/// * `logrows` - log2 number of rows in the matrix, including any system and blinding rows.
/// * `capacity` - The number of advice cells to allocate
/// Initializes fixed columns in the constraint system to support the VarTensor::Advice
/// Fixed columns are used for constant values that are known at circuit creation time.
///
/// # Arguments
/// * `cs` - The constraint system to create columns in
/// * `logrows` - Log base 2 of the total number of rows
/// * `num_constants` - Number of constant values needed
/// * `module_requires_fixed` - Whether the module requires at least one fixed column
///
/// # Returns
/// The number of fixed columns created
pub fn constant_cols<F: PrimeField>(
cs: &mut ConstraintSystem<F>,
logrows: usize,
@@ -169,7 +196,14 @@ impl VarTensor {
modulo
}
/// Create a new VarTensor::Dummy
/// Creates a new dummy VarTensor for testing or temporary storage
///
/// # Arguments
/// * `logrows` - Log base 2 of the total number of rows
/// * `num_inner_cols` - Number of columns in each inner block
///
/// # Returns
/// A new VarTensor::Dummy with the specified dimensions
pub fn dummy(logrows: usize, num_inner_cols: usize) -> Self {
let base = 2u32;
let max_rows = base.pow(logrows as u32) as usize - 6;
@@ -179,7 +213,7 @@ impl VarTensor {
}
}
/// Gets the dims of the object the VarTensor represents
/// Returns the number of blocks in the tensor
pub fn num_blocks(&self) -> usize {
match self {
VarTensor::Advice { inner, .. } => inner.len(),
@@ -187,7 +221,7 @@ impl VarTensor {
}
}
/// Num inner cols
/// Returns the number of columns in each inner block
pub fn num_inner_cols(&self) -> usize {
match self {
VarTensor::Advice { num_inner_cols, .. } | VarTensor::Dummy { num_inner_cols, .. } => {
@@ -197,7 +231,7 @@ impl VarTensor {
}
}
/// Total number of columns
/// Returns the total number of columns across all blocks
pub fn num_cols(&self) -> usize {
match self {
VarTensor::Advice { inner, .. } => inner[0].len() * inner.len(),
@@ -205,7 +239,7 @@ impl VarTensor {
}
}
/// Gets the size of each column
/// Returns the maximum number of rows in each column
pub fn col_size(&self) -> usize {
match self {
VarTensor::Advice { col_size, .. } | VarTensor::Dummy { col_size, .. } => *col_size,
@@ -213,7 +247,7 @@ impl VarTensor {
}
}
/// Gets the size of each column
/// Returns the total size of each block (num_inner_cols * col_size)
pub fn block_size(&self) -> usize {
match self {
VarTensor::Advice {
@@ -230,7 +264,13 @@ impl VarTensor {
}
}
/// Take a linear coordinate and output the (column, row) position in the storage block.
/// Converts a linear coordinate to (block, column, row) coordinates in the storage
///
/// # Arguments
/// * `linear_coord` - The linear index to convert
///
/// # Returns
/// A tuple of (block_index, column_index, row_index)
pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize, usize) {
// x indexes over blocks of size num_inner_cols
let x = linear_coord / self.block_size();
@@ -243,7 +283,17 @@ impl VarTensor {
}
impl VarTensor {
/// Retrieve the value of a specific cell in the tensor.
/// Queries a range of cells in the tensor during circuit synthesis
///
/// # Arguments
/// * `meta` - Virtual cells accessor
/// * `x` - Block index
/// * `y` - Column index within block
/// * `z` - Starting row offset
/// * `rng` - Number of consecutive rows to query
///
/// # Returns
/// A tensor of expressions representing the queried cells
pub fn query_rng<F: PrimeField>(
&self,
meta: &mut VirtualCells<'_, F>,
@@ -268,7 +318,16 @@ impl VarTensor {
}
}
/// Retrieve the value of a specific block at an offset in the tensor.
/// Queries an entire block of cells at a given offset
///
/// # Arguments
/// * `meta` - Virtual cells accessor
/// * `x` - Block index
/// * `z` - Row offset
/// * `rng` - Number of consecutive rows to query
///
/// # Returns
/// A tensor of expressions representing the queried block
pub fn query_whole_block<F: PrimeField>(
&self,
meta: &mut VirtualCells<'_, F>,
@@ -293,7 +352,16 @@ impl VarTensor {
}
}
/// Assigns a constant value to a specific cell in the tensor.
/// Assigns a constant value to a specific cell in the tensor
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `offset` - Base offset for the assignment
/// * `coord` - Coordinate within the tensor
/// * `constant` - The constant value to assign
///
/// # Returns
/// The assigned cell or an error if assignment fails
pub fn assign_constant<F: PrimeField + TensorType + PartialOrd>(
&self,
region: &mut Region<F>,
@@ -313,7 +381,17 @@ impl VarTensor {
}
}
/// Assigns [ValTensor] to the columns of the inner tensor.
/// Assigns values from a ValTensor to this tensor, excluding specified positions
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `offset` - Base offset for assignments
/// * `values` - The ValTensor containing values to assign
/// * `omissions` - Set of positions to skip during assignment
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// The assigned ValTensor or an error if assignment fails
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
@@ -325,7 +403,10 @@ impl VarTensor {
let mut assigned_coord = 0;
let mut res: ValTensor<F> = match values {
ValTensor::Instance { .. } => {
unimplemented!("cannot assign instance to advice columns with omissions")
error!(
"assignment with omissions is not supported on instance columns. increase K if you require more rows."
);
Err(halo2_proofs::plonk::Error::Synthesis)
}
ValTensor::Value { inner: v, .. } => Ok::<ValTensor<F>, halo2_proofs::plonk::Error>(
v.enum_map(|coord, k| {
@@ -344,7 +425,16 @@ impl VarTensor {
Ok(res)
}
/// Assigns [ValTensor] to the columns of the inner tensor.
/// Assigns values from a ValTensor to this tensor
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `offset` - Base offset for assignments
/// * `values` - The ValTensor containing values to assign
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// The assigned ValTensor or an error if assignment fails
pub fn assign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
@@ -396,14 +486,23 @@ impl VarTensor {
Ok(res)
}
/// Helper function to get the remaining size of the column
/// Returns the remaining available space in a column for assignments
///
/// # Arguments
/// * `offset` - Current offset in the column
/// * `values` - The ValTensor to check space for
///
/// # Returns
/// The number of rows that need to be flushed or an error if space is insufficient
pub fn get_column_flush<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
offset: usize,
values: &ValTensor<F>,
) -> Result<usize, halo2_proofs::plonk::Error> {
if values.len() > self.col_size() {
error!("Values are too large for the column");
error!(
"There are too many values to flush for this column size, try setting the logrows to a higher value (eg. --logrows 22 on the cli)"
);
return Err(halo2_proofs::plonk::Error::Synthesis);
}
@@ -427,8 +526,16 @@ impl VarTensor {
Ok(flush_len)
}
/// Assigns [ValTensor] to the columns of the inner tensor. Whereby the values are assigned to a single column, without overflowing.
/// So for instance if we are assigning 10 values and we are at index 18 of the column, and the columns are of length 20, we skip the last 2 values of current column and start from the beginning of the next column.
/// Assigns values to a single column, avoiding column overflow by flushing to the next column if needed
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `offset` - Base offset for assignments
/// * `values` - The ValTensor containing values to assign
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// A tuple of (assigned ValTensor, number of rows flushed) or an error if assignment fails
pub fn assign_exact_column<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
@@ -443,8 +550,17 @@ impl VarTensor {
Ok((assigned_vals, flush_len))
}
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
/// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
/// Assigns values with duplication in dummy mode, used for testing and simulation
///
/// # Arguments
/// * `row` - Starting row for assignment
/// * `offset` - Base offset for assignments
/// * `values` - The ValTensor containing values to assign
/// * `single_inner_col` - Whether to treat as a single column
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
pub fn dummy_assign_with_duplication<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
@@ -456,8 +572,13 @@ impl VarTensor {
constants: &mut ConstantsMap<F>,
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
match values {
ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
ValTensor::Value { inner: v, dims , ..} => {
ValTensor::Instance { .. } => {
error!(
"duplication is not supported on instance columns. increase K if you require more rows."
);
Err(halo2_proofs::plonk::Error::Synthesis)
}
ValTensor::Value { inner: v, dims, .. } => {
let duplication_freq = if single_inner_col {
self.col_size()
} else {
@@ -470,21 +591,20 @@ impl VarTensor {
self.num_inner_cols()
};
let duplication_offset = if single_inner_col {
row
} else {
offset
};
let duplication_offset = if single_inner_col { row } else { offset };
// duplicates every nth element to adjust for column overflow
let mut res: ValTensor<F> = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap().into();
let mut res: ValTensor<F> = v
.duplicate_every_n(duplication_freq, num_repeats, duplication_offset)
.unwrap()
.into();
let constants_map = res.create_constants_map();
constants.extend(constants_map);
let total_used_len = res.len();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset)
.unwrap();
res.reshape(dims).unwrap();
res.set_scale(values.scale());
@@ -494,108 +614,183 @@ impl VarTensor {
}
}
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
/// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
pub fn assign_with_duplication<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// Assigns values with duplication but without enforcing constraints between duplicated values
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `offset` - Base offset for assignments
/// * `values` - The ValTensor containing values to assign
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
pub fn assign_with_duplication_unconstrained<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
&self,
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
constants: &mut ConstantsMap<F>,
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
match values {
ValTensor::Instance { .. } => {
error!(
"duplication is not supported on instance columns. increase K if you require more rows."
);
Err(halo2_proofs::plonk::Error::Synthesis)
}
ValTensor::Value { inner: v, dims, .. } => {
let duplication_freq = self.block_size();
let num_repeats = self.num_inner_cols();
let duplication_offset = offset;
// duplicates every nth element to adjust for column overflow
let v = v
.duplicate_every_n(duplication_freq, num_repeats, duplication_offset)
.map_err(|e| {
error!("Error duplicating values: {:?}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
let mut res: ValTensor<F> = {
v.enum_map(|coord, k| {
let cell =
self.assign_value(region, offset, k.clone(), coord, constants)?;
Ok::<_, halo2_proofs::plonk::Error>(cell)
})?
.into()
};
let total_used_len = res.len();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset)
.map_err(|e| {
error!("Error duplicating values: {:?}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
res.reshape(dims).map_err(|e| {
error!("Error duplicating values: {:?}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
res.set_scale(values.scale());
Ok((res, total_used_len))
}
}
}
/// Assigns values with duplication and enforces equality constraints between duplicated values
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `row` - Starting row for assignment
/// * `offset` - Base offset for assignments
/// * `values` - The ValTensor containing values to assign
/// * `check_mode` - Mode for checking equality constraints
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// A tuple of (assigned ValTensor, total length used) or an error if assignment fails
pub fn assign_with_duplication_constrained<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
&self,
region: &mut Region<F>,
row: usize,
offset: usize,
values: &ValTensor<F>,
check_mode: &CheckMode,
single_inner_col: bool,
constants: &mut ConstantsMap<F>,
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
let mut prev_cell = None;
match values {
ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
ValTensor::Value { inner: v, dims , ..} => {
let duplication_freq = if single_inner_col {
self.col_size()
} else {
self.block_size()
};
let num_repeats = if single_inner_col {
1
} else {
self.num_inner_cols()
};
let duplication_offset = if single_inner_col {
row
} else {
offset
};
ValTensor::Instance { .. } => {
error!(
"duplication is not supported on instance columns. increase K if you require more rows."
);
Err(halo2_proofs::plonk::Error::Synthesis)
}
ValTensor::Value { inner: v, dims, .. } => {
let duplication_freq = self.col_size();
let num_repeats = 1;
let duplication_offset = row;
// duplicates every nth element to adjust for column overflow
let v = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
let mut res: ValTensor<F> = {
v.enum_map(|coord, k| {
let v = v
.duplicate_every_n(duplication_freq, num_repeats, duplication_offset)
.unwrap();
let mut res: ValTensor<F> = v
.enum_map(|coord, k| {
let step = self.num_inner_cols();
let step = if !single_inner_col {
1
} else {
self.num_inner_cols()
};
let (x, y, z) = self.cartesian_coord(offset + coord * step);
if matches!(check_mode, CheckMode::SAFE) && coord > 0 && z == 0 && y == 0 {
// assert that duplication occurred correctly
assert_eq!(
Into::<IntegerRep>::into(k.clone()),
Into::<IntegerRep>::into(v[coord - 1].clone())
);
};
let (x, y, z) = self.cartesian_coord(offset + coord * step);
if matches!(check_mode, CheckMode::SAFE) && coord > 0 && z == 0 && y == 0 {
// assert that duplication occurred correctly
assert_eq!(Into::<IntegerRep>::into(k.clone()), Into::<IntegerRep>::into(v[coord - 1].clone()));
};
let cell =
self.assign_value(region, offset, k.clone(), coord * step, constants)?;
let cell = self.assign_value(region, offset, k.clone(), coord * step, constants)?;
let at_end_of_column = z == duplication_freq - 1;
let at_beginning_of_column = z == 0;
if single_inner_col {
if z == 0 {
// if we are at the end of the column, we need to copy the cell to the next column
prev_cell = Some(cell.clone());
} else if coord > 0 && z == 0 && single_inner_col {
if let Some(prev_cell) = prev_cell.as_ref() {
let cell = cell.cell().ok_or({
error!("Error getting cell: {:?}", (x,y));
halo2_proofs::plonk::Error::Synthesis})?;
let prev_cell = prev_cell.cell().ok_or({
error!("Error getting cell: {:?}", (x,y));
halo2_proofs::plonk::Error::Synthesis})?;
region.constrain_equal(prev_cell,cell)?;
} else {
error!("Error copy-constraining previous value: {:?}", (x,y));
return Err(halo2_proofs::plonk::Error::Synthesis);
if at_end_of_column {
// if we are at the end of the column, we need to copy the cell to the next column
prev_cell = Some(cell.clone());
} else if coord > 0 && at_beginning_of_column {
if let Some(prev_cell) = prev_cell.as_ref() {
let cell = if let Some(cell) = cell.cell() {
cell
} else {
error!("Error getting cell: {:?}", (x, y));
return Err(halo2_proofs::plonk::Error::Synthesis);
};
let prev_cell = if let Some(prev_cell) = prev_cell.cell() {
prev_cell
} else {
error!("Error getting prev cell: {:?}", (x, y));
return Err(halo2_proofs::plonk::Error::Synthesis);
};
region.constrain_equal(prev_cell, cell)?;
} else {
error!("Previous cell was not set");
return Err(halo2_proofs::plonk::Error::Synthesis);
}
}
}}
Ok(cell)
Ok(cell)
})?
.into();
})?.into()};
let total_used_len = res.len();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset)
.unwrap();
res.reshape(dims).unwrap();
res.set_scale(values.scale());
if matches!(check_mode, CheckMode::SAFE) {
// during key generation this will be 0 so we use this as a flag to check
// TODO: this isn't very safe and would be better to get the phase directly
let res_evals = res.int_evals().unwrap();
let is_assigned = res_evals
.iter()
.all(|&x| x == 0);
if !is_assigned {
assert_eq!(
values.int_evals().unwrap(),
res_evals
)};
}
Ok((res, total_used_len))
}
}
}
/// Assigns a single value to the tensor. This is a helper function used by other assignment methods.
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `offset` - Base offset for the assignment
/// * `k` - The value to assign
/// * `coord` - The coordinate where to assign the value
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// The assigned value or an error if assignment fails
fn assign_value<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
@@ -606,32 +801,49 @@ impl VarTensor {
) -> Result<ValType<F>, halo2_proofs::plonk::Error> {
let (x, y, z) = self.cartesian_coord(offset + coord);
let res = match k {
// Handle direct value assignment
ValType::Value(v) => match &self {
VarTensor::Advice { inner: advices, .. } => {
ValType::PrevAssigned(region.assign_advice(|| "k", advices[x][y], z, || v)?)
}
_ => unimplemented!(),
_ => {
error!("VarTensor was not initialized");
return Err(halo2_proofs::plonk::Error::Synthesis);
}
},
// Handle copying previously assigned value
ValType::PrevAssigned(v) => match &self {
VarTensor::Advice { inner: advices, .. } => {
ValType::PrevAssigned(v.copy_advice(|| "k", region, advices[x][y], z)?)
}
_ => unimplemented!(),
_ => {
error!("VarTensor was not initialized");
return Err(halo2_proofs::plonk::Error::Synthesis);
}
},
// Handle copying previously assigned constant
ValType::AssignedConstant(v, val) => match &self {
VarTensor::Advice { inner: advices, .. } => {
ValType::AssignedConstant(v.copy_advice(|| "k", region, advices[x][y], z)?, val)
}
_ => unimplemented!(),
_ => {
error!("VarTensor was not initialized");
return Err(halo2_proofs::plonk::Error::Synthesis);
}
},
// Handle assigning evaluated value
ValType::AssignedValue(v) => match &self {
VarTensor::Advice { inner: advices, .. } => ValType::PrevAssigned(
region
.assign_advice(|| "k", advices[x][y], z, || v)?
.evaluate(),
),
_ => unimplemented!(),
_ => {
error!("VarTensor was not initialized");
return Err(halo2_proofs::plonk::Error::Synthesis);
}
},
// Handle constant value assignment with caching
ValType::Constant(v) => {
if let std::collections::hash_map::Entry::Vacant(e) = constants.entry(v) {
let value = ValType::AssignedConstant(

Binary file not shown.

Some files were not shown because too many files have changed in this diff Show More