mirror of https://github.com/zkonduit/ezkl.git
synced 2026-01-13 00:08:12 -05:00

Compare commits (14 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 3640c9aa6d |  |
|  | 00d6873f9a |  |
|  | c97ff84198 |  |
|  | f5f8ef56f7 |  |
|  | 685487c853 |  |
|  | 33d41c7e49 |  |
|  | e04c959662 |  |
|  | 27b1f2e9d4 |  |
|  | 4a172877af |  |
|  | 5a8498894d |  |
|  | 095c0ca5b4 |  |
|  | 3fa482c9ef |  |
|  | 6be3b1d663 |  |
|  | d5a1d1439c |  |
55  .github/workflows/engine.yml  vendored
@@ -178,3 +178,58 @@ jobs:
           npm publish
         env:
           NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
+
+
+  in-browser-evm-ver-publish:
+    name: publish-in-browser-evm-verifier-package
+    needs: [publish-wasm-bindings]
+    runs-on: ubuntu-latest
+    if: startsWith(github.ref, 'refs/tags/')
+    steps:
+      - uses: actions/checkout@v4
+      - name: Update version in package.json
+        shell: bash
+        env:
+          RELEASE_TAG: ${{ github.ref_name }}
+        run: |
+          sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
+      - name: Prepare tag and fetch package integrity
+        run: |
+          CLEANED_TAG=${{ github.ref_name }} # Get the tag from ref_name
+          CLEANED_TAG="${CLEANED_TAG#v}" # Remove leading 'v'
+          echo "CLEANED_TAG=${CLEANED_TAG}" >> $GITHUB_ENV # Set it as an environment variable for later steps
+          ENGINE_INTEGRITY=$(npm view @ezkljs/engine@$CLEANED_TAG dist.integrity)
+          echo "ENGINE_INTEGRITY=$ENGINE_INTEGRITY" >> $GITHUB_ENV
+      - name: Update @ezkljs/engine version in package.json
+        shell: bash
+        env:
+          RELEASE_TAG: ${{ github.ref_name }}
+        run: |
+          sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"$CLEANED_TAG\"|" in-browser-evm-verifier/package.json
+      - name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
+        run: |
+          sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
+      - name: Update pnpm-lock.yaml versions and integrity
+        run: |
+          awk -v integrity="$ENGINE_INTEGRITY" -v tag="$CLEANED_TAG" '
+          NR==30{$0="      specifier: \"" tag "\""}
+          NR==31{$0="      version: \"" tag "\""}
+          NR==400{$0="  /@ezkljs/engine@" tag ":"}
+          NR==401{$0="    resolution: {integrity: \"" integrity "\"}"} 1' in-browser-evm-verifier/pnpm-lock.yaml > temp.yaml && mv temp.yaml in-browser-evm-verifier/pnpm-lock.yaml
+      - name: Use pnpm 8
+        uses: pnpm/action-setup@v2
+        with:
+          version: 8
+      - name: Set up Node.js
+        uses: actions/setup-node@v3
+        with:
+          node-version: "18.12.1"
+          registry-url: "https://registry.npmjs.org"
+      - name: Publish to npm
+        run: |
+          cd in-browser-evm-verifier
+          pnpm install --frozen-lockfile
+          pnpm run build
+          pnpm publish --no-git-checks
+        env:
+          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
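The tag-cleaning step above relies on POSIX parameter expansion (`${CLEANED_TAG#v}`), since npm versions carry no leading `v` while git tags usually do. A minimal sketch of the same normalization in Rust terms (the function name is ours, purely illustrative):

```rust
// Illustrative only: mirrors the workflow's `${CLEANED_TAG#v}` expansion,
// which strips at most one leading 'v' (e.g. "v11.3.3" -> "11.3.3") and
// leaves already-clean tags untouched.
fn clean_tag(tag: &str) -> &str {
    tag.strip_prefix('v').unwrap_or(tag)
}

fn main() {
    assert_eq!(clean_tag("v11.3.3"), "11.3.3");
    assert_eq!(clean_tag("11.3.3"), "11.3.3");
}
```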
4  .github/workflows/rust.yml  vendored
@@ -341,6 +341,10 @@ jobs:
         run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM + on chain inputs & outputs)
         run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_prove_and_verify --test-threads 1
+      - name: KZG prove and verify tests (EVM + on chain inputs & kzg outputs + params)
+        run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_kzg_output_kzg_params_prove_and_verify --test-threads 1
+      - name: KZG prove and verify tests (EVM + on chain outputs & kzg inputs + params)
+        run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_kzg_input_kzg_params_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM + on chain inputs & outputs hashes)
         run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_hashed_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM)
68  .github/workflows/verify.yml  vendored
@@ -1,68 +0,0 @@
-name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
-
-on:
-  workflow_dispatch:
-    inputs:
-      tag:
-        description: "The tag to release"
-        required: true
-  push:
-    tags:
-      - "*"
-
-defaults:
-  run:
-    working-directory: .
-jobs:
-  in-browser-evm-ver-publish:
-    name: publish-in-browser-evm-verifier-package
-    runs-on: ubuntu-latest
-    if: startsWith(github.ref, 'refs/tags/')
-    steps:
-      - uses: actions/checkout@v4
-      - name: Update version in package.json
-        shell: bash
-        env:
-          RELEASE_TAG: ${{ github.ref_name }}
-        run: |
-          sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
-      - name: Prepare tag and fetch package integrity
-        run: |
-          CLEANED_TAG=${{ github.ref_name }} # Get the tag from ref_name
-          CLEANED_TAG="${CLEANED_TAG#v}" # Remove leading 'v'
-          echo "CLEANED_TAG=${CLEANED_TAG}" >> $GITHUB_ENV # Set it as an environment variable for later steps
-          ENGINE_INTEGRITY=$(npm view @ezkljs/engine@$CLEANED_TAG dist.integrity)
-          echo "ENGINE_INTEGRITY=$ENGINE_INTEGRITY" >> $GITHUB_ENV
-      - name: Update @ezkljs/engine version in package.json
-        shell: bash
-        env:
-          RELEASE_TAG: ${{ github.ref_name }}
-        run: |
-          sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"$CLEANED_TAG\"|" in-browser-evm-verifier/package.json
-      - name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
-        run: |
-          sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
-      - name: Update pnpm-lock.yaml versions and integrity
-        run: |
-          awk -v integrity="$ENGINE_INTEGRITY" -v tag="$CLEANED_TAG" '
-          NR==30{$0="      specifier: \"" tag "\""}
-          NR==31{$0="      version: \"" tag "\""}
-          NR==400{$0="  /@ezkljs/engine@" tag ":"}
-          NR==401{$0="    resolution: {integrity: \"" integrity "\"}"} 1' in-browser-evm-verifier/pnpm-lock.yaml > temp.yaml && mv temp.yaml in-browser-evm-verifier/pnpm-lock.yaml
-      - name: Use pnpm 8
-        uses: pnpm/action-setup@v2
-        with:
-          version: 8
-      - name: Set up Node.js
-        uses: actions/setup-node@v3
-        with:
-          node-version: "18.12.1"
-          registry-url: "https://registry.npmjs.org"
-      - name: Publish to npm
-        run: |
-          cd in-browser-evm-verifier
-          pnpm install --frozen-lockfile
-          pnpm run build
-          pnpm publish --no-git-checks
-        env:
-          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
46  Cargo.lock  generated
@@ -58,7 +58,7 @@ checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
 [[package]]
 name = "alloy"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-consensus",
  "alloy-contract",
@@ -80,7 +80,7 @@ dependencies = [
 [[package]]
 name = "alloy-consensus"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-eips",
  "alloy-primitives 0.7.2",
@@ -93,7 +93,7 @@ dependencies = [
 [[package]]
 name = "alloy-contract"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-dyn-abi",
  "alloy-json-abi",
@@ -140,7 +140,7 @@ dependencies = [
 [[package]]
 name = "alloy-eips"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-primitives 0.7.2",
  "alloy-rlp",
@@ -154,7 +154,7 @@ dependencies = [
 [[package]]
 name = "alloy-genesis"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-primitives 0.7.2",
  "alloy-serde",
@@ -177,7 +177,7 @@ dependencies = [
 [[package]]
 name = "alloy-json-rpc"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-primitives 0.7.2",
  "serde",
@@ -189,7 +189,7 @@ dependencies = [
 [[package]]
 name = "alloy-network"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-consensus",
  "alloy-eips",
@@ -206,7 +206,7 @@ dependencies = [
 [[package]]
 name = "alloy-node-bindings"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-genesis",
  "alloy-primitives 0.7.2",
@@ -261,7 +261,7 @@ dependencies = [
 [[package]]
 name = "alloy-provider"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-eips",
  "alloy-json-rpc",
@@ -281,6 +281,7 @@ dependencies = [
  "futures",
  "futures-utils-wasm",
+ "lru",
  "pin-project",
  "reqwest",
  "serde_json",
  "tokio",
@@ -313,7 +314,7 @@ dependencies = [
 [[package]]
 name = "alloy-rpc-client"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-json-rpc",
  "alloy-transport",
@@ -333,7 +334,7 @@ dependencies = [
 [[package]]
 name = "alloy-rpc-types"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-consensus",
  "alloy-eips",
@@ -351,7 +352,7 @@ dependencies = [
 [[package]]
 name = "alloy-rpc-types-trace"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-primitives 0.7.2",
  "alloy-rpc-types",
@@ -363,7 +364,7 @@ dependencies = [
 [[package]]
 name = "alloy-serde"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-primitives 0.7.2",
  "serde",
@@ -373,7 +374,7 @@ dependencies = [
 [[package]]
 name = "alloy-signer"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-primitives 0.7.2",
  "async-trait",
@@ -386,7 +387,7 @@ dependencies = [
 [[package]]
 name = "alloy-signer-wallet"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-consensus",
  "alloy-network",
@@ -459,7 +460,7 @@ dependencies = [
 [[package]]
 name = "alloy-transport"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-json-rpc",
  "base64 0.22.1",
@@ -477,7 +478,7 @@ dependencies = [
 [[package]]
 name = "alloy-transport-http"
 version = "0.1.0"
-source = "git+https://github.com/alloy-rs/alloy#e60d64cff86d995657f0acea85c2bbd52f9bd810"
+source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
 dependencies = [
  "alloy-json-rpc",
  "alloy-transport",
@@ -1207,6 +1208,15 @@ dependencies = [
  "strsim",
 ]

+[[package]]
+name = "clap_complete"
+version = "4.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dd79504325bf38b10165b02e89b4347300f855f273c4cb30c4a3209e6583275e"
+dependencies = [
+ "clap 4.5.3",
+]
+
 [[package]]
 name = "clap_derive"
 version = "4.5.3"
@@ -1831,6 +1841,7 @@ dependencies = [
  "bincode",
  "chrono",
  "clap 4.5.3",
+ "clap_complete",
  "colored",
  "colored_json",
  "console_error_panic_hook",
@@ -1867,6 +1878,7 @@ dependencies = [
  "rand 0.8.5",
  "regex",
  "reqwest",
+ "semver 1.0.22",
  "seq-macro",
  "serde",
  "serde-wasm-bindgen",
Cargo.toml
@@ -23,6 +23,7 @@ halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curv
 rand = { version = "0.8", default_features = false }
 itertools = { version = "0.10.3", default_features = false }
 clap = { version = "4.5.3", features = ["derive"] }
+clap_complete = "4.5.2"
 serde = { version = "1.0.126", features = ["derive"], optional = true }
 serde_json = { version = "1.0.97", default_features = false, features = [
     "float_roundtrip",
@@ -44,10 +45,11 @@ num = "0.4.1"
 portable-atomic = "1.6.0"
 tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
 metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
+semver = "1.0.22"

 # evm related deps
 [target.'cfg(not(target_arch = "wasm32"))'.dependencies]
-alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", features = ["provider-http", "signers", "contract", "rpc-types-eth", "signer-wallet", "node-bindings"] }
+alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev="5fbf57bac99edef9d8475190109a7ea9fb7e5e83", features = ["provider-http", "signers", "contract", "rpc-types-eth", "signer-wallet", "node-bindings"] }
 foundry-compilers = {version = "0.4.1", features = ["svm-solc"]}
 ethabi = "18"
 indicatif = { version = "0.17.5", features = ["rayon"] }
@@ -93,6 +93,79 @@ contract LoadInstances {
     }
 }

+// Contract that checks that the COMMITMENT_KZG bytes is equal to the first part of the proof.
+pragma solidity ^0.8.0;
+
+// The kzg commitments of a given model, all aggregated into a single bytes array.
+// At solidity generation time, the commitments are hardcoded into the contract via the COMMITMENT_KZG constant.
+// It will be used to check that the proof commitments match the expected commitments.
+bytes constant COMMITMENT_KZG = hex"";
+
+contract SwapProofCommitments {
+    /**
+     * @dev Swap the proof commitments
+     * @notice must pass encoded bytes from memory
+     * @param encoded - verifier calldata
+     */
+    function checkKzgCommits(
+        bytes calldata encoded
+    ) internal pure returns (bool equal) {
+        bytes4 funcSig;
+        uint256 proof_offset;
+        uint256 proof_length;
+        assembly {
+            // fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
+            funcSig := calldataload(encoded.offset)
+
+            // Fetch proof offset which is 4 + 32 bytes away from
+            // start of encoded for `verifyProof(bytes,uint256[])`,
+            // and 4 + 32 + 32 away for `verifyProof(address,bytes,uint256[])`
+
+            proof_offset := calldataload(
+                add(
+                    encoded.offset,
+                    add(0x04, mul(0x20, eq(funcSig, 0xaf83a18d)))
+                )
+            )
+
+            proof_length := calldataload(
+                add(add(encoded.offset, 0x04), proof_offset)
+            )
+        }
+        // Check the length of the commitment against the proof bytes
+        if (proof_length < COMMITMENT_KZG.length) {
+            return false;
+        }
+
+        // Load COMMITMENT_KZG into memory
+        bytes memory commitment = COMMITMENT_KZG;
+
+        // Compare the first N bytes of the proof with COMMITMENT_KZG
+        uint words = (commitment.length + 31) / 32; // Calculate the number of 32-byte words
+
+        assembly {
+            // Now we compare the commitment with the proof,
+            // ensuring that the commitments divided up into 32 byte words are all equal.
+            for {
+                let i := 0x20
+            } lt(i, add(mul(words, 0x20), 0x20)) {
+                i := add(i, 0x20)
+            } {
+                let wordProof := calldataload(
+                    add(add(encoded.offset, add(i, 0x04)), proof_offset)
+                )
+                let wordCommitment := mload(add(commitment, i))
+                equal := eq(wordProof, wordCommitment)
+                if eq(equal, 0) {
+                    return(0, 0)
+                }
+            }
+        }
+
+        return equal; // Return true if the commitment comparison passed
+    }
+}
+
 // This contract serves as a Data Attestation Verifier for the EZKL model.
 // It is designed to read and attest to instances of proofs generated from a specified circuit.
 // It is particularly constructed to read only int256 data from specified on-chain contracts' view functions.
@@ -104,9 +177,10 @@ contract LoadInstances {
 // 4. Field Element Conversion: The fixed-point representation is then converted into a field element modulo P using the `toFieldElement` method.
 // 5. Data Attestation: The `attestData` method validates that the public instances match the data fetched and processed by the contract.
 // 6. Proof Verification: The `verifyWithDataAttestation` method parses the instances out of the encoded calldata and calls the `attestData` method to validate the public instances,
+// 6b. Optional KZG Commitment Verification: It also checks the KZG commitments in the proof against the expected commitments using the `checkKzgCommits` method.
 // then calls the `verifyProof` method to verify the proof on the verifier.

-contract DataAttestation is LoadInstances {
+contract DataAttestation is LoadInstances, SwapProofCommitments {
     /**
      * @notice Struct used to make view only calls to accounts to fetch the data that EZKL reads from.
      * @param the address of the account to make calls to
@@ -350,12 +424,18 @@ contract DataAttestation is LoadInstances {
         }
     }

     /**
      * @dev Verify the proof with the data attestation.
      * @param verifier - The address of the verifier contract.
      * @param encoded - The verifier calldata.
      */
     function verifyWithDataAttestation(
         address verifier,
         bytes calldata encoded
     ) public view returns (bool) {
+        require(verifier.code.length > 0, "Address: call to non-contract");
         attestData(getInstancesCalldata(encoded));
+        require(checkKzgCommits(encoded), "Invalid KZG commitments");
         // static call the verifier contract to verify the proof
         (bool success, bytes memory returndata) = verifier.staticcall(encoded);
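The selector-dependent arithmetic in `checkKzgCommits` is the subtle part: the calldata word that stores the proof's offset sits 4 bytes in for `verifyProof(bytes,uint256[])`, and one 32-byte head word later when the `0xaf83a18d` overload taking an address is used (per the contract's own comment). A hedged sketch of that computation in Rust terms, with function and constant names of our own invention:

```rust
// Illustrative sketch of the assembly expression
// add(0x04, mul(0x20, eq(funcSig, 0xaf83a18d))).
const VERIFY_PROOF_WITH_ADDR: u32 = 0xaf83a18d; // assumed selector, per the contract comment

/// Byte position (within `encoded`) of the word holding the proof's offset.
fn proof_offset_word_position(selector: u32) -> usize {
    // 4 selector bytes, plus one extra 32-byte head word (the address)
    // when the address-taking overload is used.
    4 + 32 * usize::from(selector == VERIFY_PROOF_WITH_ADDR)
}

fn main() {
    assert_eq!(proof_offset_word_position(0xdeadbeef), 4); // verifyProof(bytes,uint256[])
    assert_eq!(proof_offset_word_position(VERIFY_PROOF_WITH_ADDR), 36);
}
```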
@@ -1,4 +1,4 @@
-ezkl==0.0.0
+ezkl==11.3.3
 sphinx
 sphinx-rtd-theme
 sphinxcontrib-napoleon
@@ -1,7 +1,7 @@
 import ezkl

 project = 'ezkl'
-release = '0.0.0'
+release = '11.3.3'
 version = release
604  examples/notebooks/data_attest_kzg_vis.ipynb  Normal file
@@ -0,0 +1,604 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# data-attest-kzg-vis\n",
    "\n",
    "Here's an example leveraging EZKL whereby the inputs to the model are read and attested to from an on-chain source and the params and outputs are committed to using kzg-commitments. \n",
    "\n",
    "In this setup:\n",
    "- the inputs and outputs are publicly known to the prover and verifier\n",
    "- the on chain inputs will be fetched and then fed directly into the circuit\n",
    "- the quantization of the on-chain inputs happens within the evm and is replicated at proving time \n",
    "- The kzg commitment to the params and inputs will be read from the proof and checked to make sure it matches the expected commitment stored on-chain.\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we import the necessary dependencies and set up logging to be as informative as possible. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check if notebook is in colab\n",
    "try:\n",
    "    # install ezkl\n",
    "    import google.colab\n",
    "    import subprocess\n",
    "    import sys\n",
    "    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
    "    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
    "\n",
    "# rely on local installation of ezkl if the notebook is not in colab\n",
    "except:\n",
    "    pass\n",
    "\n",
    "\n",
    "from torch import nn\n",
    "import ezkl\n",
    "import os\n",
    "import json\n",
    "import logging\n",
    "\n",
    "# uncomment for more descriptive logging \n",
    "FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
    "logging.basicConfig(format=FORMAT)\n",
    "logging.getLogger().setLevel(logging.DEBUG)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we define our model. It is a very simple PyTorch model that has just one layer, an average pooling 2D layer. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "# Defines the model\n",
    "\n",
    "class MyModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MyModel, self).__init__()\n",
    "        self.layer = nn.AvgPool2d(2, 1, (1, 1))\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.layer(x)[0]\n",
    "\n",
    "\n",
    "circuit = MyModel()\n",
    "\n",
    "# this is where you'd train your model"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We omit training for purposes of this demonstration. We've marked where training would happen in the cell above. \n",
    "Now we export the model to onnx and create a corresponding (randomly generated) input. This input data will eventually be stored on chain and read from according to the call_data field in the graph input.\n",
    "\n",
    "You can replace the random `x` with real data if you so wish. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = 0.1*torch.rand(1,*[3, 2, 2], requires_grad=True)\n",
    "\n",
    "# Flips the neural net into inference mode\n",
    "circuit.eval()\n",
    "\n",
    "# Export the model\n",
    "torch.onnx.export(circuit,  # model being run\n",
    "                  x,  # model input (or a tuple for multiple inputs)\n",
    "                  \"network.onnx\",  # where to save the model (can be a file or file-like object)\n",
    "                  export_params=True,  # store the trained parameter weights inside the model file\n",
    "                  opset_version=10,  # the ONNX version to export the model to\n",
    "                  do_constant_folding=True,  # whether to execute constant folding for optimization\n",
    "                  input_names = ['input'],  # the model's input names\n",
    "                  output_names = ['output'],  # the model's output names\n",
    "                  dynamic_axes={'input' : {0 : 'batch_size'},  # variable length axes\n",
    "                                'output' : {0 : 'batch_size'}})\n",
    "\n",
    "data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
    "\n",
    "data = dict(input_data = [data_array])\n",
    "\n",
    "# Serialize data into file:\n",
    "json.dump(data, open(\"input.json\", 'w' ))\n",
    "\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now define a function that will create a new anvil instance to which we will deploy our test contract. This contract will contain in its storage the data that we will read from and attest to. In production you would not need to set up a local anvil instance. Instead, you would replace RPC_URL with the actual RPC endpoint of the chain you are deploying your verifiers to, reading the data from that chain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "import time\n",
    "import threading\n",
    "\n",
    "# make sure anvil is running locally\n",
    "# $ anvil -p 3030\n",
    "\n",
    "RPC_URL = \"http://localhost:3030\"\n",
    "\n",
    "# Save process globally\n",
    "anvil_process = None\n",
    "\n",
    "def start_anvil():\n",
    "    global anvil_process\n",
    "    if anvil_process is None:\n",
    "        anvil_process = subprocess.Popen([\"anvil\", \"-p\", \"3030\", \"--code-size-limit=41943040\"])\n",
    "        if anvil_process.returncode is not None:\n",
    "            raise Exception(\"failed to start anvil process\")\n",
    "        time.sleep(3)\n",
    "\n",
    "def stop_anvil():\n",
    "    global anvil_process\n",
    "    if anvil_process is not None:\n",
    "        anvil_process.terminate()\n",
    "        anvil_process = None\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define our `PyRunArgs` object, which contains the visibility parameters for our model. \n",
    "- `input_visibility` defines the visibility of the model inputs\n",
    "- `param_visibility` defines the visibility of the model weights, constants, and parameters \n",
    "- `output_visibility` defines the visibility of the model outputs\n",
    "\n",
    "Here we create the following setup:\n",
    "- `input_visibility`: \"public\"\n",
    "- `param_visibility`: \"polycommit\" \n",
    "- `output_visibility`: \"polycommit\"\n",
    "\n",
    "**Note**:\n",
    "When we set this to `polycommit`, we are saying that the model parameters are committed to using a polynomial commitment scheme. This commitment will be stored on chain as a constant stored in the DA contract, and the proof will contain the commitment to the parameters. The DA verification will then check that the commitment in the proof matches the commitment stored on chain. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ezkl\n",
    "\n",
    "model_path = os.path.join('network.onnx')\n",
    "compiled_model_path = os.path.join('network.compiled')\n",
    "pk_path = os.path.join('test.pk')\n",
    "vk_path = os.path.join('test.vk')\n",
    "settings_path = os.path.join('settings.json')\n",
    "srs_path = os.path.join('kzg.srs')\n",
    "data_path = os.path.join('input.json')\n",
    "\n",
    "run_args = ezkl.PyRunArgs()\n",
    "run_args.input_visibility = \"public\"\n",
    "run_args.param_visibility = \"polycommit\"\n",
    "run_args.output_visibility = \"polycommit\"\n",
    "run_args.num_inner_cols = 1\n",
    "run_args.variables = [(\"batch_size\", 1)]\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we generate a settings file. This file basically instantiates a bunch of parameters that determine the circuit's shape, size, etc. Because of the way we represent nonlinearities in the circuit (using Halo2's [lookup tables](https://zcash.github.io/halo2/design/proving-system/lookup.html)), it is often best to _calibrate_ this settings file as some data can fall out of range of these lookups.\n",
    "\n",
    "You can pass a dataset for calibration that will be representative of real inputs you might find if and when you deploy the prover. Here we create a dummy calibration dataset for demonstration purposes. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!RUST_LOG=trace\n",
    "# TODO: Dictionary outputs\n",
    "res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
    "assert res == True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate a bunch of dummy calibration data\n",
    "cal_data = {\n",
    "    \"input_data\": [(0.1*torch.rand(2, *[3, 2, 2])).flatten().tolist()],\n",
    "}\n",
    "\n",
    "cal_path = os.path.join('val_data.json')\n",
    "# save as json file\n",
    "with open(cal_path, \"w\") as f:\n",
    "    json.dump(cal_data, f)\n",
    "\n",
    "res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
    "assert res == True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The graph input for on chain data sources is formatted completely differently compared to file based data sources.\n",
    "\n",
    "- For file data sources, the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit are stored. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
    "- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestible by EZKL (aka field elements :-D). \n",
    "Here is what the schema for an on-chain data source graph input file should look like:\n",
    " \n",
    "```json\n",
    "{\n",
    "  \"input_data\": {\n",
    "    \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
    "    \"calls\": [\n",
    "      {\n",
    "        \"call_data\": [\n",
    "          [\n",
    "            \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns a single on-chain data point (we only support uint256 returns for now)\n",
    "            7 // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
    "          ],\n",
    "          [\n",
    "            \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000001\",\n",
    "            5\n",
    "          ],\n",
    "          [\n",
    "            \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000002\",\n",
    "            5\n",
    "          ]\n",
    "        ],\n",
    "        \"address\": \"5fbdb2315678afecb367f032d93f642f64180aa3\" // The address of the contract that we are calling to get the data. \n",
    "      }\n",
    "    ]\n",
    "  }\n",
    "}\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "await ezkl.setup_test_evm_witness(\n",
    "    data_path,\n",
    "    compiled_model_path,\n",
    "    # we write the call data to the same file as the input data\n",
    "    data_path,\n",
    "    input_source=ezkl.PyTestDataSource.OnChain,\n",
    "    output_source=ezkl.PyTestDataSource.File,\n",
    "    rpc_url=RPC_URL)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
    "\n",
    "These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = await ezkl.get_srs( settings_path)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now need to generate the circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!export RUST_BACKTRACE=1\n",
    "\n",
    "witness_path = \"witness.json\"\n",
    "\n",
    "res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# HERE WE SETUP THE CIRCUIT PARAMS\n",
    "# WE GOT KEYS\n",
    "# WE GOT CIRCUIT PARAMETERS\n",
    "# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
    "res = ezkl.setup(\n",
    "        compiled_model_path,\n",
    "        vk_path,\n",
    "        pk_path,\n",
    "    )\n",
    "\n",
    "assert res == True\n",
    "assert os.path.isfile(vk_path)\n",
    "assert os.path.isfile(pk_path)\n",
    "assert os.path.isfile(settings_path)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we generate a full proof. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GENERATE A PROOF\n",
    "\n",
    "proof_path = os.path.join('test.pf')\n",
    "\n",
    "res = ezkl.prove(\n",
    "        witness_path,\n",
    "        compiled_model_path,\n",
    "        pk_path,\n",
    "        proof_path,\n",
    "        \n",
    "        \"single\",\n",
    "    )\n",
    "\n",
    "print(res)\n",
    "assert os.path.isfile(proof_path)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And verify it as a sanity check. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# VERIFY IT\n",
    "\n",
    "res = ezkl.verify(\n",
    "        proof_path,\n",
    "        settings_path,\n",
    "        vk_path,\n",
    "        \n",
    "    )\n",
    "\n",
    "assert res == True\n",
    "print(\"verified\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now create and then deploy a vanilla evm verifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "abi_path = 'test.abi'\n",
    "sol_code_path = 'test.sol'\n",
    "\n",
    "res = await ezkl.create_evm_verifier(\n",
    "        vk_path,\n",
    "        \n",
    "        settings_path,\n",
    "        sol_code_path,\n",
    "        abi_path,\n",
    "    )\n",
    "assert res == True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "addr_path_verifier = \"addr_verifier.txt\"\n",
    "\n",
    "res = await ezkl.deploy_evm(\n",
    "    addr_path_verifier,\n",
    "    sol_code_path,\n",
    "    'http://127.0.0.1:3030'\n",
    ")\n",
    "\n",
    "assert res == True"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When deploying a DA with kzg commitments, we need to make sure to also pass a witness file that contains the commitments to the parameters and inputs. This is because the verifier will need to check that the commitments in the proof match the commitments stored on chain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "abi_path = 'test.abi'\n",
    "sol_code_path = 'test.sol'\n",
    "input_path = 'input.json'\n",
    "\n",
    "res = await ezkl.create_evm_data_attestation(\n",
    "        input_path,\n",
    "        settings_path,\n",
    "        sol_code_path,\n",
    "        abi_path,\n",
    "        witness_path = witness_path,\n",
    "    )"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can deploy the data attest verifier contract. For security reasons, this binding will only deploy to a local anvil instance, using accounts generated by anvil, \n",
    "so it should only be used for testing purposes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "addr_path_da = \"addr_da.txt\"\n",
    "\n",
    "res = await ezkl.deploy_da_evm(\n",
    "        addr_path_da,\n",
    "        input_path,\n",
    "        settings_path,\n",
    "        sol_code_path,\n",
    "        RPC_URL,\n",
    "    )\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Call the view only verify method on the contract to verify the proof. Since it is a view function, this is safe to use in production: you don't have to pass your private key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read the verifier address\n",
    "addr_verifier = None\n",
    "with open(addr_path_verifier, 'r') as f:\n",
    "    addr = f.read()\n",
    "#read the data attestation address\n",
    "addr_da = None\n",
    "with open(addr_path_da, 'r') as f:\n",
    "    addr_da = f.read()\n",
    "\n",
    "res = await ezkl.verify_evm(\n",
    "    addr,\n",
    "    proof_path,\n",
    "    RPC_URL,\n",
    "    addr_da,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ezkl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
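The `decimals` field in the notebook's call-data schema is how large integer wei-style values stand in for floating points on chain. As a rough sketch (helper names are our own, and the power-of-two fixed-point scale is the usual EZKL convention rather than something this diff shows): the on-chain value is first scaled down by 10^decimals, then re-quantized to the circuit's fixed-point scale.

```rust
// Hedged sketch: turn an on-chain integer with a decimals hint back into a
// float, then into the fixed-point integer a circuit would consume.
fn onchain_to_float(raw: u128, decimals: u32) -> f64 {
    raw as f64 / 10f64.powi(decimals as i32)
}

fn quantize(x: f64, scale: u32) -> i128 {
    // assumed power-of-two fixed-point scale
    (x * (1u128 << scale) as f64).round() as i128
}

fn main() {
    let x = onchain_to_float(1_234_567, 7); // 0.1234567
    assert_eq!(quantize(x, 7), 16); // 0.1234567 * 128 ≈ 15.8, rounds to 16
}
```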
@@ -482,7 +482,7 @@
    "source": [
     "import pytest\n",
     "def test_verification():\n",
-    "    with pytest.raises(RuntimeError, match='Failed to run verify: The constraint system is not satisfied'):\n",
+    "    with pytest.raises(RuntimeError, match='Failed to run verify: \\\\[halo2\\\\] The constraint system is not satisfied'):\n",
     "        ezkl.verify(\n",
     "            proof_path_faulty,\n",
     "            settings_path,\n",
@@ -514,9 +514,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.15"
+   "version": "3.12.3"
   }
  },
 "nbformat": 4,
 "nbformat_minor": 5
}
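The updated `match` string here follows directly from the new `ModuleError` type introduced later in this change set: its `#[error("[halo2] {0}")]` attribute prefixes wrapped halo2 failures with `[halo2]`, which the pytest regex must now escape. A small self-contained illustration of that thiserror formatting behavior:

```rust
use thiserror::Error;

// Illustrative stand-in for ModuleError::Halo2Error's #[error] attribute.
#[derive(Error, Debug)]
#[error("[halo2] {0}")]
struct Halo2Wrap(String);

fn main() {
    let e = Halo2Wrap("The constraint system is not satisfied".into());
    assert_eq!(e.to_string(), "[halo2] The constraint system is not satisfied");
}
```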
@@ -157,6 +157,7 @@
    {
     "cell_type": "code",
     "execution_count": null,
+    "id": "b78d3cbf",
     "metadata": {},
     "outputs": [],
     "source": [
@@ -192,7 +193,7 @@
    "outputs": [],
    "source": [
     "# srs path\n",
-    "res = await ezkl.get_srs( settings_path)"
+    "res = ezkl.get_srs( settings_path)"
    ]
   },
   {
@@ -298,7 +299,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.15"
+   "version": "3.12.3"
   }
  },
 "nbformat": 4,
@@ -169,7 +169,7 @@
    "json.dump(data, open(cal_path, 'w'))\n",
    "\n",
    "\n",
-   "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
+   "await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
   ]
  },
  {
@@ -302,4 +302,4 @@
  },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -170,7 +170,7 @@
    "json.dump(data, open(cal_path, 'w'))\n",
    "\n",
    "\n",
-   "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
+   "await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
   ]
  },
  {
@@ -478,12 +478,11 @@
    "import pytest\n",
    "\n",
    "def test_verification():\n",
-   "    with pytest.raises(RuntimeError, match='Failed to run verify: The constraint system is not satisfied'):\n",
+   "    with pytest.raises(RuntimeError, match='Failed to run verify: \\\\[halo2\\\\] The constraint system is not satisfied'):\n",
    "        ezkl.verify(\n",
    "            proof_path,\n",
    "            settings_path,\n",
    "            vk_path,\n",
-   "        \n",
    "        )\n",
    "\n",
    "# Run the test function\n",
@@ -510,9 +509,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.15"
+   "version": "3.12.3"
   }
  },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -99,6 +99,10 @@ fi
 echo "Removing old ezkl binary if it exists"
 [ -e file ] && rm file

+# echo platform and architecture
+echo "Platform: $PLATFORM"
+echo "Architecture: $ARCHITECTURE"
+
 # download the release and unpack the right tarball
 if [ "$PLATFORM" == "windows-msvc" ]; then
     JSON_RESPONSE=$(curl -s "$RELEASE_URL")
@@ -126,7 +130,6 @@ elif [ "$PLATFORM" == "macos" ]; then

     echo "Cleaning up"
     rm "$EZKL_DIR/build-artifacts.ezkl-macos-aarch64.tar.gz"
-
 else
     JSON_RESPONSE=$(curl -s "$RELEASE_URL")
     FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-macos.tar.gz")
@@ -155,7 +158,7 @@ elif [ "$PLATFORM" == "linux" ]; then

     echo "Cleaning up"
     rm "$EZKL_DIR/build-artifacts.ezkl-linux-gnu.tar.gz"
-else if [ "$ARCHITECTURE" == "aarch64" ]; then
+elif [ "$ARCHITECTURE" == "aarch64" ]; then
     JSON_RESPONSE=$(curl -s "$RELEASE_URL")
     FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-linux-aarch64.tar.gz")
@@ -1,7 +1,7 @@
 // ignore file if compiling for wasm

 #[cfg(not(target_arch = "wasm32"))]
-use clap::Parser;
+use clap::{CommandFactory, Parser};
 #[cfg(not(target_arch = "wasm32"))]
 use colored_json::ToColoredJson;
 #[cfg(not(target_arch = "wasm32"))]
@@ -17,29 +17,43 @@ use rand::prelude::SliceRandom;
 #[cfg(not(target_arch = "wasm32"))]
 #[cfg(feature = "icicle")]
 use std::env;
 #[cfg(not(target_arch = "wasm32"))]
 use std::error::Error;

 #[tokio::main(flavor = "current_thread")]
 #[cfg(not(target_arch = "wasm32"))]
-pub async fn main() -> Result<(), Box<dyn Error>> {
+pub async fn main() {
     let args = Cli::parse();
-    init_logger();
-    #[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
-    banner();
-    #[cfg(feature = "icicle")]
-    if env::var("ENABLE_ICICLE_GPU").is_ok() {
-        info!("Running with ICICLE GPU");
-    } else {
-        info!("Running with CPU");
-    }
-    info!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
-    let res = run(args.command).await;
-    match &res {
-        Ok(_) => info!("succeeded"),
-        Err(e) => error!("failed: {}", e),
-    };
-    res.map(|_| ())
+
+    if let Some(generator) = args.generator {
+        ezkl::commands::print_completions(generator, &mut Cli::command());
+    } else if let Some(command) = args.command {
+        init_logger();
+        #[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
+        banner();
+        #[cfg(feature = "icicle")]
+        if env::var("ENABLE_ICICLE_GPU").is_ok() {
+            info!("Running with ICICLE GPU");
+        } else {
+            info!("Running with CPU");
+        }
+        info!(
+            "command: \n {}",
+            &command.as_json().to_colored_json_auto().unwrap()
+        );
+        let res = run(command).await;
+        match &res {
+            Ok(_) => {
+                info!("succeeded");
+            }
+            Err(e) => {
+                error!("{}", e);
+                std::process::exit(1)
+            }
+        }
+    } else {
+        init_logger();
+        error!("No command provided");
+        std::process::exit(1)
+    }
 }

 #[cfg(target_arch = "wasm32")]
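The new `main` dispatches to `ezkl::commands::print_completions` when a shell generator is requested. That helper isn't shown in this diff; with `clap_complete` (added to Cargo.toml above), such a helper is typically a one-liner along these lines:

```rust
// Sketch of a typical clap_complete-based helper; ezkl's actual
// implementation is not part of this diff.
use clap::Command;
use clap_complete::{generate, Generator};

pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
    // Writes a completion script for the requested shell to stdout.
    generate(gen, cmd, cmd.get_name().to_string(), &mut std::io::stdout());
}
```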
25  src/circuit/modules/errors.rs  Normal file
@@ -0,0 +1,25 @@
use halo2_proofs::plonk::Error as PlonkError;
use thiserror::Error;

/// Error type for the circuit module
#[derive(Error, Debug)]
pub enum ModuleError {
    /// Halo 2 error
    #[error("[halo2] {0}")]
    Halo2Error(#[from] PlonkError),
    /// Wrong input type for a module
    #[error("wrong input type {0} must be {1}")]
    WrongInputType(String, String),
    /// A constant was not previously assigned
    #[error("constant was not previously assigned")]
    ConstantNotAssigned,
    /// Input length is wrong
    #[error("input length is wrong {0}")]
    InputWrongLength(usize),
}

impl From<ModuleError> for PlonkError {
    fn from(_e: ModuleError) -> PlonkError {
        PlonkError::Synthesis
    }
}
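The pair of `From` impls (the derived `#[from]` on `Halo2Error` and the manual `impl From<ModuleError> for PlonkError`) is what lets the refactored modules below use `?` in both directions: halo2 errors bubble up into `ModuleError`, and any `ModuleError` collapses back to `PlonkError::Synthesis` at the synthesis boundary. A small sketch of the round trip (illustrative function names only):

```rust
// Illustrative: PlonkError -> ModuleError via the derived #[from] ...
fn module_step() -> Result<(), ModuleError> {
    let failing: Result<(), PlonkError> = Err(PlonkError::Synthesis);
    failing?; // converts into ModuleError::Halo2Error
    Ok(())
}

// ... and ModuleError -> PlonkError via the manual impl, as code handed
// back to halo2 must return.
fn synthesis_boundary() -> Result<(), PlonkError> {
    module_step()?; // any ModuleError becomes PlonkError::Synthesis
    Ok(())
}
```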
@@ -6,10 +6,11 @@ pub mod polycommit;

///
pub mod planner;
-use halo2_proofs::{
-    circuit::Layouter,
-    plonk::{ConstraintSystem, Error},
-};
+///
+pub mod errors;
+
+use halo2_proofs::{circuit::Layouter, plonk::ConstraintSystem};
 use halo2curves::ff::PrimeField;
 pub use planner::*;

@@ -35,14 +36,14 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
     /// Name
     fn name(&self) -> &'static str;
     /// Run the operation the module represents
-    fn run(input: Self::RunInputs) -> Result<Vec<Vec<F>>, Box<dyn std::error::Error>>;
+    fn run(input: Self::RunInputs) -> Result<Vec<Vec<F>>, errors::ModuleError>;
     /// Layout inputs
     fn layout_inputs(
         &self,
         layouter: &mut impl Layouter<F>,
         input: &[ValTensor<F>],
         constants: &mut ConstantsMap<F>,
-    ) -> Result<Self::InputAssignments, Error>;
+    ) -> Result<Self::InputAssignments, errors::ModuleError>;
     /// Layout
     fn layout(
         &self,
@@ -50,7 +51,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
         input: &[ValTensor<F>],
         row_offset: usize,
         constants: &mut ConstantsMap<F>,
-    ) -> Result<ValTensor<F>, Error>;
+    ) -> Result<ValTensor<F>, errors::ModuleError>;
     /// Number of instance values the module uses every time it is applied
     fn instance_increment_input(&self) -> Vec<usize>;
     /// Number of rows used by the module
@@ -18,6 +18,7 @@ use halo2curves::CurveAffine;
 use crate::circuit::region::ConstantsMap;
 use crate::tensor::{Tensor, ValTensor, ValType, VarTensor};

+use super::errors::ModuleError;
 use super::Module;

 /// The number of instance columns used by the PolyCommit hash function
@@ -110,7 +111,7 @@ impl Module<Fp> for PolyCommitChip {
         _: &mut impl Layouter<Fp>,
         _: &[ValTensor<Fp>],
         _: &mut ConstantsMap<Fp>,
-    ) -> Result<Self::InputAssignments, Error> {
+    ) -> Result<Self::InputAssignments, ModuleError> {
         Ok(())
     }

@@ -123,28 +124,30 @@ impl Module<Fp> for PolyCommitChip {
         input: &[ValTensor<Fp>],
         _: usize,
         constants: &mut ConstantsMap<Fp>,
-    ) -> Result<ValTensor<Fp>, Error> {
+    ) -> Result<ValTensor<Fp>, ModuleError> {
         assert_eq!(input.len(), 1);

         let local_constants = constants.clone();
-        layouter.assign_region(
-            || "PolyCommit",
-            |mut region| {
-                let mut local_inner_constants = local_constants.clone();
-                let res = self.config.inputs.assign(
-                    &mut region,
-                    0,
-                    &input[0],
-                    &mut local_inner_constants,
-                )?;
-                *constants = local_inner_constants;
-                Ok(res)
-            },
-        )
+        layouter
+            .assign_region(
+                || "PolyCommit",
+                |mut region| {
+                    let mut local_inner_constants = local_constants.clone();
+                    let res = self.config.inputs.assign(
+                        &mut region,
+                        0,
+                        &input[0],
+                        &mut local_inner_constants,
+                    )?;
+                    *constants = local_inner_constants;
+                    Ok(res)
+                },
+            )
+            .map_err(|e| e.into())
     }

     ///
-    fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, Box<dyn std::error::Error>> {
+    fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, ModuleError> {
         Ok(vec![message])
     }
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ use std::marker::PhantomData;
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::tensor::{Tensor, ValTensor, ValType};
|
||||
|
||||
use super::errors::ModuleError;
|
||||
use super::Module;
|
||||
|
||||
/// The number of instance columns used by the Poseidon hash function
|
||||
@@ -174,7 +175,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
layouter: &mut impl Layouter<Fp>,
|
||||
message: &[ValTensor<Fp>],
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<Self::InputAssignments, Error> {
|
||||
) -> Result<Self::InputAssignments, ModuleError> {
|
||||
assert_eq!(message.len(), 1);
|
||||
let message = message[0].clone();

@@ -185,78 +186,82 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
        let res = layouter.assign_region(
            || "load message",
            |mut region| {
-                let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, Error> = match &message {
-                    ValTensor::Value { inner: v, .. } => v
-                        .iter()
-                        .enumerate()
-                        .map(|(i, value)| {
-                            let x = i % WIDTH;
-                            let y = i / WIDTH;
+                let assigned_message: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> =
+                    match &message {
+                        ValTensor::Value { inner: v, .. } => {
+                            v.iter()
+                                .enumerate()
+                                .map(|(i, value)| {
+                                    let x = i % WIDTH;
+                                    let y = i / WIDTH;

-                            match value {
-                                ValType::Value(v) => region.assign_advice(
-                                    || format!("load message_{}", i),
-                                    self.config.hash_inputs[x],
-                                    y,
-                                    || *v,
-                                ),
-                                ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
-                                    Ok(v.clone())
-                                }
-                                ValType::Constant(f) => {
-                                    if local_constants.contains_key(f) {
-                                        Ok(constants.get(f).unwrap().assigned_cell().ok_or({
-                                            log::error!("constant not previously assigned");
-                                            Error::Synthesis
-                                        })?)
-                                    } else {
-                                        let res = region.assign_advice_from_constant(
-                                            || format!("load message_{}", i),
-                                            self.config.hash_inputs[x],
-                                            y,
-                                            *f,
-                                        )?;
+                                    match value {
+                                        ValType::Value(v) => region
+                                            .assign_advice(
+                                                || format!("load message_{}", i),
+                                                self.config.hash_inputs[x],
+                                                y,
+                                                || *v,
+                                            )
+                                            .map_err(|e| e.into()),
+                                        ValType::PrevAssigned(v)
+                                        | ValType::AssignedConstant(v, ..) => Ok(v.clone()),
+                                        ValType::Constant(f) => {
+                                            if local_constants.contains_key(f) {
+                                                Ok(constants
+                                                    .get(f)
+                                                    .unwrap()
+                                                    .assigned_cell()
+                                                    .ok_or(ModuleError::ConstantNotAssigned)?)
+                                            } else {
+                                                let res = region.assign_advice_from_constant(
+                                                    || format!("load message_{}", i),
+                                                    self.config.hash_inputs[x],
+                                                    y,
+                                                    *f,
+                                                )?;

-                                        constants
-                                            .insert(*f, ValType::AssignedConstant(res.clone(), *f));
+                                                constants.insert(
+                                                    *f,
+                                                    ValType::AssignedConstant(res.clone(), *f),
+                                                );

-                                        Ok(res)
+                                                Ok(res)
                                            }
                                        }
+                                        e => Err(ModuleError::WrongInputType(
+                                            format!("{:?}", e),
+                                            "PrevAssigned".to_string(),
+                                        )),
                                    }
                                }
-                                e => {
-                                    log::error!(
-                                        "wrong input type {:?}, must be previously assigned",
-                                        e
-                                    );
-                                    Err(Error::Synthesis)
-                                }
-                            }
-                        })
-                        .collect(),
-                    ValTensor::Instance {
-                        dims,
-                        inner: col,
-                        idx,
-                        initial_offset,
-                        ..
-                    } => {
-                        // this should never ever fail
-                        let num_elems = dims[*idx].iter().product::<usize>();
-                        (0..num_elems)
-                            .map(|i| {
-                                let x = i % WIDTH;
-                                let y = i / WIDTH;
-                                region.assign_advice_from_instance(
-                                    || "pub input anchor",
-                                    *col,
-                                    initial_offset + i,
-                                    self.config.hash_inputs[x],
-                                    y,
-                                )
-                            })
-                            .collect()
-                    }
-                };
+                                })
+                                .collect()
+                        }
+                        ValTensor::Instance {
+                            dims,
+                            inner: col,
+                            idx,
+                            initial_offset,
+                            ..
+                        } => {
+                            // this should never ever fail
+                            let num_elems = dims[*idx].iter().product::<usize>();
+                            (0..num_elems)
+                                .map(|i| {
+                                    let x = i % WIDTH;
+                                    let y = i / WIDTH;
+                                    region.assign_advice_from_instance(
+                                        || "pub input anchor",
+                                        *col,
+                                        initial_offset + i,
+                                        self.config.hash_inputs[x],
+                                        y,
+                                    )
+                                })
+                                .collect::<Result<Vec<_>, _>>()
+                                .map_err(|e| e.into())
+                        }
+                    };

        let offset = message.len() / WIDTH + 1;
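
Worth noting how the message is packed above: element i of the flat message goes to advice column i % WIDTH at row i / WIDTH. A quick standalone sketch of the mapping (WIDTH = 4 is an illustrative stand-in for the const generic):

    // element i of the flat message lands in column i % WIDTH, row i / WIDTH,
    // so a message of length n fills rows 0 ..= (n - 1) / WIDTH
    const WIDTH: usize = 4; // illustrative value only

    fn main() {
        let n = 10;
        for i in 0..n {
            let (x, y) = (i % WIDTH, i / WIDTH);
            println!("message_{} -> column {}, row {}", i, x, y);
        }
        // the `offset = message.len() / WIDTH + 1` computed above reserves
        // at least one row past the last row touched by the message
        assert_eq!(n / WIDTH + 1, 3);
    }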

@@ -277,7 +282,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
            message.len(),
            start_time.elapsed()
        );
-        res
+        res.map_err(|e| e.into())
    }

    /// L is the number of inputs to the hash function
@@ -289,7 +294,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
        input: &[ValTensor<Fp>],
        row_offset: usize,
        constants: &mut ConstantsMap<Fp>,
-    ) -> Result<ValTensor<Fp>, Error> {
+    ) -> Result<ValTensor<Fp>, ModuleError> {
        let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
        // extract the values from the input cells
        let mut assigned_input: Tensor<ValType<Fp>> =

@@ -301,7 +306,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
        let mut one_iter = false;
        // do the Tree dance baby
        while input_cells.len() > 1 || !one_iter {
-            let hashes: Result<Vec<AssignedCell<Fp, Fp>>, Error> = input_cells
+            let hashes: Result<Vec<AssignedCell<Fp, Fp>>, ModuleError> = input_cells
                .chunks(L)
                .enumerate()
                .map(|(i, block)| {
@@ -332,7 +337,8 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con

                    hash
                })
-                .collect();
+                .collect::<Result<Vec<_>, _>>()
+                .map_err(|e| e.into());

            log::trace!("hashes (N={:?}) took: {:?}", len, start_time.elapsed());
            one_iter = true;

@@ -348,7 +354,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
                ValType::PrevAssigned(v) => v,
                _ => {
                    log::error!("wrong input type, must be previously assigned");
-                    return Err(Error::Synthesis);
+                    return Err(Error::Synthesis.into());
                }
            };

@@ -380,7 +386,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
    }

    ///
-    fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, Box<dyn std::error::Error>> {
+    fn run(message: Vec<Fp>) -> Result<Vec<Vec<Fp>>, ModuleError> {
        let mut hash_inputs = message;

        let len = hash_inputs.len();

@@ -400,7 +406,11 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
                block.extend(vec![Fp::ZERO; L - remainder].iter());
            }

-            let message = block.try_into().map_err(|_| Error::Synthesis)?;
+            let block_len = block.len();
+
+            let message = block
+                .try_into()
+                .map_err(|_| ModuleError::InputWrongLength(block_len))?;

            Ok(halo2_gadgets::poseidon::primitives::Hash::<
                _,

@@ -411,7 +421,7 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
            >::init()
            .hash(message))
        })
-        .collect::<Result<Vec<_>, Error>>()?;
+        .collect::<Result<Vec<_>, ModuleError>>()?;
        one_iter = true;
        hash_inputs = hashes;
    }
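
The `run` function above folds the input as a tree: chunk into blocks of L, zero-pad the trailing block, hash each block, and repeat until one digest remains. A shape-only sketch of that loop with a toy hash standing in for Poseidon (L = 3 is an illustrative stand-in for the const generic):

    const L: usize = 3; // illustrative value only

    fn toy_hash(block: &[u64]) -> u64 {
        block.iter().fold(0u64, |acc, x| acc.wrapping_mul(31).wrapping_add(*x))
    }

    fn tree_fold(mut inputs: Vec<u64>) -> u64 {
        let mut one_iter = false;
        // same loop condition as `run`: keep folding until a single value,
        // but always run at least once so short inputs are still hashed
        while inputs.len() > 1 || !one_iter {
            inputs = inputs
                .chunks(L)
                .map(|chunk| {
                    let mut block = chunk.to_vec();
                    block.resize(L, 0); // zero-pad the trailing block
                    toy_hash(&block)
                })
                .collect();
            one_iter = true;
        }
        inputs[0]
    }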

@@ -1,7 +1,5 @@
use std::str::FromStr;

-use thiserror::Error;
-
use halo2_proofs::{
    circuit::Layouter,
    plonk::{ConstraintSystem, Constraints, Expression, Selector},

@@ -26,31 +24,11 @@ use crate::{
    },
    tensor::{IntoI64, Tensor, TensorType, ValTensor, VarTensor},
};
-use std::{collections::BTreeMap, error::Error, marker::PhantomData};
+use std::{collections::BTreeMap, marker::PhantomData};

-use super::{lookup::LookupOp, region::RegionCtx, Op};
+use super::{lookup::LookupOp, region::RegionCtx, CircuitError, Op};
use halo2curves::ff::{Field, PrimeField};

-/// circuit related errors.
-#[derive(Debug, Error)]
-pub enum CircuitError {
-    /// Shape mismatch in circuit construction
-    #[error("dimension mismatch in circuit construction for op: {0}")]
-    DimMismatch(String),
-    /// Error when instantiating lookup tables
-    #[error("failed to instantiate lookup tables")]
-    LookupInstantiation,
-    /// A lookup table was already assigned
-    #[error("attempting to initialize an already instantiated lookup table")]
-    TableAlreadyAssigned,
-    /// This operation is unsupported
-    #[error("unsupported operation in graph")]
-    UnsupportedOp,
-    ///
-    #[error("invalid einsum expression")]
-    InvalidEinsum,
-}
-
#[allow(missing_docs)]
/// An enum representing activating the sanity checks we can perform on the accumulated arguments
#[derive(

@@ -513,18 +491,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
        lookup_range: Range,
        logrows: usize,
        nl: &LookupOp,
-    ) -> Result<(), Box<dyn Error>>
+    ) -> Result<(), CircuitError>
    where
        F: Field,
    {
        if !index.is_advice() {
-            return Err("wrong input type for lookup index".into());
+            return Err(CircuitError::WrongColumnType(index.name().to_string()));
        }
        if !input.is_advice() {
-            return Err("wrong input type for lookup input".into());
+            return Err(CircuitError::WrongColumnType(input.name().to_string()));
        }
        if !output.is_advice() {
-            return Err("wrong input type for lookup output".into());
+            return Err(CircuitError::WrongColumnType(output.name().to_string()));
        }

        // we borrow mutably twice so we need to do this dance

@@ -654,19 +632,19 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
        cs: &mut ConstraintSystem<F>,
        lookups: &[VarTensor; 3],
        tables: &[VarTensor; 3],
-    ) -> Result<(), Box<dyn Error>>
+    ) -> Result<(), CircuitError>
    where
        F: Field,
    {
        for l in lookups.iter() {
            if !l.is_advice() {
-                return Err("wrong input type for dynamic lookup".into());
+                return Err(CircuitError::WrongDynamicColumnType(l.name().to_string()));
            }
        }

        for t in tables.iter() {
            if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
-                return Err("wrong table type for dynamic lookup".into());
+                return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
            }
        }

@@ -737,19 +715,19 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
        cs: &mut ConstraintSystem<F>,
        inputs: &[VarTensor; 2],
        references: &[VarTensor; 2],
-    ) -> Result<(), Box<dyn Error>>
+    ) -> Result<(), CircuitError>
    where
        F: Field,
    {
        for l in inputs.iter() {
            if !l.is_advice() {
-                return Err("wrong input type for dynamic lookup".into());
+                return Err(CircuitError::WrongDynamicColumnType(l.name().to_string()));
            }
        }

        for t in references.iter() {
            if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
-                return Err("wrong table type for dynamic lookup".into());
+                return Err(CircuitError::WrongDynamicColumnType(t.name().to_string()));
            }
        }

@@ -822,12 +800,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
        index: &VarTensor,
        range: Range,
        logrows: usize,
-    ) -> Result<(), Box<dyn Error>>
+    ) -> Result<(), CircuitError>
    where
        F: Field,
    {
        if !input.is_advice() {
-            return Err("wrong input type for lookup input".into());
+            return Err(CircuitError::WrongColumnType(input.name().to_string()));
        }

        // we borrow mutably twice so we need to do this dance

@@ -918,7 +896,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
    }

    /// layout_tables must be called before layout.
-    pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
+    pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), CircuitError> {
        for (i, table) in self.static_lookups.tables.values_mut().enumerate() {
            if !table.is_assigned {
                debug!(
@@ -939,7 +917,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
    pub fn layout_range_checks(
        &mut self,
        layouter: &mut impl Layouter<F>,
-    ) -> Result<(), Box<dyn Error>> {
+    ) -> Result<(), CircuitError> {
        for range_check in self.range_checks.ranges.values_mut() {
            if !range_check.is_assigned {
                debug!("laying out range check for {:?}", range_check.range);
@@ -959,7 +937,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseCo
        region: &mut RegionCtx<F>,
        values: &[ValTensor<F>],
        op: Box<dyn Op<F>>,
-    ) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
+    ) -> Result<Option<ValTensor<F>>, CircuitError> {
        op.layout(self, region, values)
    }
}

94
src/circuit/ops/errors.rs
Normal file
@@ -0,0 +1,94 @@
use std::convert::Infallible;

use crate::tensor::TensorError;
use halo2_proofs::plonk::Error as PlonkError;
use thiserror::Error;

/// Error type for the circuit module
#[derive(Error, Debug)]
pub enum CircuitError {
    /// Halo 2 error
    #[error("[halo2] {0}")]
    Halo2Error(#[from] PlonkError),
    /// Tensor error
    #[error("[tensor] {0}")]
    TensorError(#[from] TensorError),
    /// Shape mismatch in circuit construction
    #[error("dimension mismatch in circuit construction for op: {0}")]
    DimMismatch(String),
    /// Error when instantiating lookup tables
    #[error("failed to instantiate lookup tables")]
    LookupInstantiation,
    /// A lookup table was already assigned
    #[error("attempting to initialize an already instantiated lookup table")]
    TableAlreadyAssigned,
    /// This operation is unsupported
    #[error("unsupported operation in graph")]
    UnsupportedOp,
    ///
    #[error("invalid einsum expression")]
    InvalidEinsum,
    /// Flush error
    #[error("failed to flush, linear coord is not aligned with the next row")]
    FlushError,
    /// Constrain error
    #[error("constrain_equal: one of the tensors is assigned and the other is not")]
    ConstrainError,
    /// Failed to get lookups
    #[error("failed to get lookups for op: {0}")]
    GetLookupsError(String),
    /// Failed to get range checks
    #[error("failed to get range checks for op: {0}")]
    GetRangeChecksError(String),
    /// Failed to get dynamic lookup
    #[error("failed to get dynamic lookup for op: {0}")]
    GetDynamicLookupError(String),
    /// Failed to get shuffle
    #[error("failed to get shuffle for op: {0}")]
    GetShuffleError(String),
    /// Failed to get constants
    #[error("failed to get constants for op: {0}")]
    GetConstantsError(String),
    /// Slice length mismatch
    #[error("slice length mismatch: {0}")]
    SliceLengthMismatch(#[from] std::array::TryFromSliceError),
    /// Bad conversion
    #[error("invalid conversion: {0}")]
    InvalidConversion(#[from] Infallible),
    /// Invalid min/max lookup range
    #[error("invalid min/max lookup range: min: {0}, max: {1}")]
    InvalidMinMaxRange(i64, i64),
    /// Missing product in einsum
    #[error("missing product in einsum")]
    MissingEinsumProduct,
    /// Mismatched lookup length
    #[error("mismatched lookup lengths: {0} and {1}")]
    MismatchedLookupLength(usize, usize),
    /// Mismatched shuffle length
    #[error("mismatched shuffle lengths: {0} and {1}")]
    MismatchedShuffleLength(usize, usize),
    /// Mismatched lookup table lengths
    #[error("mismatched lookup table lengths: {0} and {1}")]
    MismatchedLookupTableLength(usize, usize),
    /// Wrong column type for lookup
    #[error("wrong column type for lookup: {0}")]
    WrongColumnType(String),
    /// Wrong column type for dynamic lookup
    #[error("wrong column type for dynamic lookup: {0}")]
    WrongDynamicColumnType(String),
    /// Missing selectors
    #[error("missing selectors for op: {0}")]
    MissingSelectors(String),
    /// Table lookup error
    #[error("value ({0}) out of range: ({1}, {2})")]
    TableOOR(i64, i64, i64),
    /// Lookup not configured
    #[error("lookup not configured: {0}")]
    LookupNotConfigured(String),
    /// Range check not configured
    #[error("range check not configured: {0}")]
    RangeCheckNotConfigured(String),
    /// Missing layout
    #[error("missing layout for op: {0}")]
    MissingLayout(String),
}
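
For context on how the `#[from]` variants above are meant to be used: callers can now propagate both tensor and halo2 failures with `?` instead of boxing them. A minimal sketch, where the two helper functions are hypothetical stand-ins:

    fn layout_step() -> Result<(), CircuitError> {
        do_tensor_op()?; // TensorError -> CircuitError via #[from]
        do_halo2_op()?;  // PlonkError -> CircuitError via #[from]
        Err(CircuitError::DimMismatch("einsum".to_string()))
    }

    fn do_tensor_op() -> Result<(), crate::tensor::TensorError> {
        Ok(())
    }

    fn do_halo2_op() -> Result<(), halo2_proofs::plonk::Error> {
        Ok(())
    }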

@@ -155,7 +155,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
        config: &mut crate::circuit::BaseConfig<F>,
        region: &mut RegionCtx<F>,
        values: &[ValTensor<F>],
-    ) -> Result<Option<ValTensor<F>>, Box<dyn std::error::Error>> {
+    ) -> Result<Option<ValTensor<F>>, CircuitError> {
        Ok(Some(match self {
            HybridOp::SumPool {
                padding,
@@ -287,7 +287,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
        }))
    }

-    fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
+    fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
        let scale = match self {
            HybridOp::Greater { .. }
            | HybridOp::GreaterEqual { .. }

File diff suppressed because it is too large
@@ -1,6 +1,5 @@
use super::*;
use serde::{Deserialize, Serialize};
-use std::error::Error;

use crate::{
    circuit::{layouts, table::Range, utils},
@@ -295,7 +294,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
        config: &mut crate::circuit::BaseConfig<F>,
        region: &mut RegionCtx<F>,
        values: &[ValTensor<F>],
-    ) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
+    ) -> Result<Option<ValTensor<F>>, CircuitError> {
        Ok(Some(layouts::nonlinearity(
            config,
            region,
@@ -305,7 +304,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
    }

    /// Returns the scale of the output of the operation.
-    fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
+    fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
        let scale = match self {
            LookupOp::Cast { scale } => {
                let in_scale = inputs_scale[0];

@@ -1,4 +1,4 @@
-use std::{any::Any, error::Error};
+use std::any::Any;

use serde::{Deserialize, Serialize};

@@ -15,6 +15,8 @@ pub mod base;
///
pub mod chip;
///
+pub mod errors;
+///
pub mod hybrid;
/// Layouts for specific functions (composed of base ops)
pub mod layouts;
@@ -25,6 +27,8 @@ pub mod poly;
///
pub mod region;

+pub use errors::CircuitError;
+
/// A struct representing the result of a forward pass.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
@@ -44,10 +48,10 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64
        config: &mut crate::circuit::BaseConfig<F>,
        region: &mut RegionCtx<F>,
        values: &[ValTensor<F>],
-    ) -> Result<Option<ValTensor<F>>, Box<dyn Error>>;
+    ) -> Result<Option<ValTensor<F>>, CircuitError>;

    /// Returns the scale of the output of the operation.
-    fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>>;
+    fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError>;

    /// Do any of the inputs to this op require homogenous input scales?
    fn requires_homogenous_input_scales(&self) -> Vec<usize> {

@@ -139,7 +143,7 @@ pub struct Input {
}

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Input {
-    fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
+    fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
        Ok(self.scale)
    }

@@ -156,7 +160,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
        config: &mut crate::circuit::BaseConfig<F>,
        region: &mut RegionCtx<F>,
        values: &[ValTensor<F>],
-    ) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
+    ) -> Result<Option<ValTensor<F>>, CircuitError> {
        let value = values[0].clone();
        if !value.all_prev_assigned() {
            match self.datum_type {
@@ -194,7 +198,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
pub struct Unknown;

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Unknown {
-    fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
+    fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
        Ok(0)
    }
    fn as_any(&self) -> &dyn Any {

@@ -209,8 +213,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
        _: &mut crate::circuit::BaseConfig<F>,
        _: &mut RegionCtx<F>,
        _: &[ValTensor<F>],
-    ) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
-        Err(Box::new(super::CircuitError::UnsupportedOp))
+    ) -> Result<Option<ValTensor<F>>, CircuitError> {
+        Err(super::CircuitError::UnsupportedOp)
    }

    fn clone_dyn(&self) -> Box<dyn Op<F>> {
@@ -240,7 +244,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Consta
        }
    }
    /// Rebase the scale of the constant
-    pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), Box<dyn Error>> {
+    pub fn rebase_scale(&mut self, new_scale: crate::Scale) -> Result<(), CircuitError> {
        let visibility = self.quantized_values.visibility().unwrap();
        self.quantized_values = quantize_tensor(self.raw_values.clone(), new_scale, &visibility)?;
        Ok(())
@@ -279,7 +283,7 @@ impl<
        config: &mut crate::circuit::BaseConfig<F>,
        region: &mut RegionCtx<F>,
        _: &[ValTensor<F>],
-    ) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
+    ) -> Result<Option<ValTensor<F>>, CircuitError> {
        let value = if let Some(value) = &self.pre_assigned_val {
            value.clone()
        } else {
@@ -293,7 +297,7 @@ impl<
        Box::new(self.clone()) // Forward to the derive(Clone) impl
    }

-    fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
+    fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
        Ok(self.quantized_values.scale().unwrap())
    }

@@ -179,7 +179,7 @@ impl<
        config: &mut crate::circuit::BaseConfig<F>,
        region: &mut RegionCtx<F>,
        values: &[ValTensor<F>],
-    ) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
+    ) -> Result<Option<ValTensor<F>>, CircuitError> {
        Ok(Some(match self {
            PolyOp::MultiBroadcastTo { shape } => {
                layouts::expand(config, region, values[..].try_into()?, shape)?
@@ -278,9 +278,10 @@ impl<
            PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
            PolyOp::Pad(p) => {
                if values.len() != 1 {
-                    return Err(Box::new(TensorError::DimError(
+                    return Err(TensorError::DimError(
                        "Pad operation requires a single input".to_string(),
-                    )));
+                    )
+                    .into());
                }
                let mut input = values[0].clone();
                input.pad(p.clone(), 0)?;
@@ -297,7 +298,7 @@ impl<
        }))
    }

-    fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
+    fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
        let scale = match self {
            PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
            PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
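
The `out_scale` arithmetic follows fixed-point semantics: a value at scale s is encoded as round(r * 2^s), so a product of two scale-s operands lands at scale 2s, which is why MeanOfSquares reports 2 * in_scales[0] while boolean ops report 0. A standalone worked example of that bookkeeping:

    // multiplying two scale-7 encodings yields a scale-14 encoding
    fn main() {
        let s: i32 = 7;
        let (a, b) = (1.5_f64, 2.25_f64);
        let qa = (a * f64::powi(2.0, s)) as i64; // 192
        let qb = (b * f64::powi(2.0, s)) as i64; // 288
        let product = qa * qb; // encoded at scale 2 * s
        let decoded = product as f64 / f64::powi(2.0, 2 * s);
        assert!((decoded - a * b).abs() < 1e-9);
        println!("scale in: {}, scale out: {}, decoded: {}", s, 2 * s, decoded);
    }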

@@ -1,6 +1,6 @@
use crate::{
    circuit::table::Range,
-    tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor},
+    tensor::{Tensor, TensorType, ValTensor, ValType, VarTensor},
};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
@@ -19,7 +19,7 @@ use std::{
    },
};

-use super::lookup::LookupOp;
+use super::{lookup::LookupOp, CircuitError};

/// Constants map
pub type ConstantsMap<F> = HashMap<F, ValType<F>>;

@@ -84,44 +84,6 @@ impl ShuffleIndex {
    }
}

-/// Region error
-#[derive(Debug, thiserror::Error)]
-pub enum RegionError {
-    /// wrap other regions
-    #[error("Wrapped region: {0}")]
-    Wrapped(String),
-}
-
-impl From<String> for RegionError {
-    fn from(e: String) -> Self {
-        Self::Wrapped(e)
-    }
-}
-
-impl From<&str> for RegionError {
-    fn from(e: &str) -> Self {
-        Self::Wrapped(e.to_string())
-    }
-}
-
-impl From<TensorError> for RegionError {
-    fn from(e: TensorError) -> Self {
-        Self::Wrapped(format!("{:?}", e))
-    }
-}
-
-impl From<Error> for RegionError {
-    fn from(e: Error) -> Self {
-        Self::Wrapped(format!("{:?}", e))
-    }
-}
-
-impl From<Box<dyn std::error::Error>> for RegionError {
-    fn from(e: Box<dyn std::error::Error>) -> Self {
-        Self::Wrapped(format!("{:?}", e))
-    }
-}
-
#[derive(Debug)]
/// A context for a region
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {

@@ -317,10 +279,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
    pub fn apply_in_loop<T: TensorType + Send + Sync>(
        &mut self,
        output: &mut Tensor<T>,
-        inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, RegionError>
+        inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>
            + Send
            + Sync,
-    ) -> Result<(), RegionError> {
+    ) -> Result<(), CircuitError> {
        if self.is_dummy() {
            self.dummy_loop(output, inner_loop_function)?;
        } else {
@@ -333,8 +295,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
    pub fn real_loop<T: TensorType + Send + Sync>(
        &mut self,
        output: &mut Tensor<T>,
-        inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, RegionError>,
-    ) -> Result<(), RegionError> {
+        inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>,
+    ) -> Result<(), CircuitError> {
        output
            .iter_mut()
            .enumerate()
@@ -342,7 +304,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
                *o = inner_loop_function(i, self)?;
                Ok(())
            })
-            .collect::<Result<Vec<_>, RegionError>>()?;
+            .collect::<Result<Vec<_>, CircuitError>>()?;

        Ok(())
    }

@@ -353,10 +315,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
    pub fn dummy_loop<T: TensorType + Send + Sync>(
        &mut self,
        output: &mut Tensor<T>,
-        inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, RegionError>
+        inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>
            + Send
            + Sync,
-    ) -> Result<(), RegionError> {
+    ) -> Result<(), CircuitError> {
        let row = AtomicUsize::new(self.row());
        let linear_coord = AtomicUsize::new(self.linear_coord());
        let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
@@ -367,50 +329,48 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
        let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));

-        *output = output
-            .par_enum_map(|idx, _| {
-                // we kick off the loop with the current offset
-                let starting_offset = row.load(Ordering::SeqCst);
-                let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
-                // get inner value of the locked lookups
+        *output = output.par_enum_map(|idx, _| {
+            // we kick off the loop with the current offset
+            let starting_offset = row.load(Ordering::SeqCst);
+            let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
+            // get inner value of the locked lookups

-                // we need to make sure that the region is not shared between threads
-                let mut local_reg = Self::new_dummy_with_linear_coord(
-                    starting_offset,
-                    starting_linear_coord,
-                    self.num_inner_cols,
-                    self.witness_gen,
-                    self.check_lookup_range,
-                );
-                let res = inner_loop_function(idx, &mut local_reg);
-                // we update the offset and constants
-                row.fetch_add(local_reg.row() - starting_offset, Ordering::SeqCst);
-                linear_coord.fetch_add(
-                    local_reg.linear_coord() - starting_linear_coord,
-                    Ordering::SeqCst,
-                );
+            // we need to make sure that the region is not shared between threads
+            let mut local_reg = Self::new_dummy_with_linear_coord(
+                starting_offset,
+                starting_linear_coord,
+                self.num_inner_cols,
+                self.witness_gen,
+                self.check_lookup_range,
+            );
+            let res = inner_loop_function(idx, &mut local_reg);
+            // we update the offset and constants
+            row.fetch_add(local_reg.row() - starting_offset, Ordering::SeqCst);
+            linear_coord.fetch_add(
+                local_reg.linear_coord() - starting_linear_coord,
+                Ordering::SeqCst,
+            );

-                max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
-                min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
-                // update the lookups
-                let mut lookups = lookups.lock().unwrap();
-                lookups.extend(local_reg.used_lookups());
-                // update the range checks
-                let mut range_checks = range_checks.lock().unwrap();
-                range_checks.extend(local_reg.used_range_checks());
-                // update the dynamic lookup index
-                let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
-                dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
-                // update the shuffle index
-                let mut shuffle_index = shuffle_index.lock().unwrap();
-                shuffle_index.update(&local_reg.shuffle_index);
-                // update the constants
-                let mut constants = constants.lock().unwrap();
-                constants.extend(local_reg.assigned_constants);
+            max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
+            min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
+            // update the lookups
+            let mut lookups = lookups.lock().unwrap();
+            lookups.extend(local_reg.used_lookups());
+            // update the range checks
+            let mut range_checks = range_checks.lock().unwrap();
+            range_checks.extend(local_reg.used_range_checks());
+            // update the dynamic lookup index
+            let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
+            dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
+            // update the shuffle index
+            let mut shuffle_index = shuffle_index.lock().unwrap();
+            shuffle_index.update(&local_reg.shuffle_index);
+            // update the constants
+            let mut constants = constants.lock().unwrap();
+            constants.extend(local_reg.assigned_constants);

-                res
-            })
-            .map_err(|e| RegionError::from(format!("dummy_loop: {:?}", e)))?;
+            res
+        })?;
        self.linear_coord = linear_coord.into_inner();
        #[allow(trivial_numeric_casts)]
        {

@@ -419,49 +379,25 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        }
        self.row = row.into_inner();
        self.used_lookups = Arc::try_unwrap(lookups)
-            .map_err(|e| RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e)))?
+            .map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?
            .into_inner()
-            .map_err(|e| {
-                RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e))
-            })?;
+            .map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?;
        self.used_range_checks = Arc::try_unwrap(range_checks)
-            .map_err(|e| {
-                RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
-            })?
+            .map_err(|e| CircuitError::GetRangeChecksError(format!("{:?}", e)))?
            .into_inner()
-            .map_err(|e| {
-                RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
-            })?;
+            .map_err(|e| CircuitError::GetRangeChecksError(format!("{:?}", e)))?;
        self.dynamic_lookup_index = Arc::try_unwrap(dynamic_lookup_index)
-            .map_err(|e| {
-                RegionError::from(format!(
-                    "dummy_loop: failed to get dynamic lookup index: {:?}",
-                    e
-                ))
-            })?
+            .map_err(|e| CircuitError::GetDynamicLookupError(format!("{:?}", e)))?
            .into_inner()
-            .map_err(|e| {
-                RegionError::from(format!(
-                    "dummy_loop: failed to get dynamic lookup index: {:?}",
-                    e
-                ))
-            })?;
+            .map_err(|e| CircuitError::GetDynamicLookupError(format!("{:?}", e)))?;
        self.shuffle_index = Arc::try_unwrap(shuffle_index)
-            .map_err(|e| {
-                RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
-            })?
+            .map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?
            .into_inner()
-            .map_err(|e| {
-                RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
-            })?;
+            .map_err(|e| CircuitError::GetShuffleError(format!("{:?}", e)))?;
        self.assigned_constants = Arc::try_unwrap(constants)
-            .map_err(|e| {
-                RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
-            })?
+            .map_err(|e| CircuitError::GetConstantsError(format!("{:?}", e)))?
            .into_inner()
-            .map_err(|e| {
-                RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
-            })?;
+            .map_err(|e| CircuitError::GetConstantsError(format!("{:?}", e)))?;

        Ok(())
    }
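
The `Arc::try_unwrap(...)?.into_inner()?` chains above reclaim the state gathered by the parallel dummy loop once all worker threads are done. A minimal sketch of the pattern with illustrative types and String errors:

    use std::sync::{Arc, Mutex};

    fn reclaim(shared: Arc<Mutex<Vec<u64>>>) -> Result<Vec<u64>, String> {
        Arc::try_unwrap(shared)
            // fails if a worker thread still holds a clone of the Arc
            .map_err(|e| format!("outstanding references: {:?}", e))?
            .into_inner()
            // fails if the mutex was poisoned by a panicking thread
            .map_err(|e| format!("poisoned mutex: {:?}", e))
    }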

@@ -470,7 +406,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
    pub fn update_max_min_lookup_inputs(
        &mut self,
        inputs: &[ValTensor<F>],
-    ) -> Result<(), Box<dyn std::error::Error>> {
+    ) -> Result<(), CircuitError> {
        let (mut min, mut max) = (0, 0);
        for i in inputs {
            max = max.max(i.get_int_evals()?.into_iter().max().unwrap_or_default());
@@ -482,12 +418,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
    }

    /// Update the max and min from inputs
-    pub fn update_max_min_lookup_range(
-        &mut self,
-        range: Range,
-    ) -> Result<(), Box<dyn std::error::Error>> {
+    pub fn update_max_min_lookup_range(&mut self, range: Range) -> Result<(), CircuitError> {
        if range.0 > range.1 {
-            return Err(format!("update_max_min_lookup_range: invalid range {:?}", range).into());
+            return Err(CircuitError::InvalidMinMaxRange(range.0, range.1));
        }

        let range_size = (range.1 - range.0).abs();
@@ -506,13 +439,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
        &mut self,
        lookup: LookupOp,
        inputs: &[ValTensor<F>],
-    ) -> Result<(), Box<dyn std::error::Error>> {
+    ) -> Result<(), CircuitError> {
        self.used_lookups.insert(lookup);
        self.update_max_min_lookup_inputs(inputs)
    }

    /// add used range check
-    pub fn add_used_range_check(&mut self, range: Range) -> Result<(), Box<dyn std::error::Error>> {
+    pub fn add_used_range_check(&mut self, range: Range) -> Result<(), CircuitError> {
        self.used_range_checks.insert(range);
        self.update_max_min_lookup_range(range)
    }

@@ -707,7 +640,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
    }

    /// constrain equal
-    pub fn constrain_equal(&mut self, a: &ValTensor<F>, b: &ValTensor<F>) -> Result<(), Error> {
+    pub fn constrain_equal(
+        &mut self,
+        a: &ValTensor<F>,
+        b: &ValTensor<F>,
+    ) -> Result<(), CircuitError> {
        if let Some(region) = &self.region {
            let a = a.get_inner_tensor().unwrap();
            let b = b.get_inner_tensor().unwrap();
@@ -717,12 +654,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
            let b = b.get_prev_assigned();
            // if they're both assigned, we can constrain them
            if let (Some(a), Some(b)) = (&a, &b) {
-                region.borrow_mut().constrain_equal(a.cell(), b.cell())
+                region
+                    .borrow_mut()
+                    .constrain_equal(a.cell(), b.cell())
+                    .map_err(|e| e.into())
            } else if a.is_some() || b.is_some() {
-                log::error!(
-                    "constrain_equal: one of the tensors is assigned and the other is not"
-                );
-                return Err(Error::Synthesis);
+                return Err(CircuitError::ConstrainError);
            } else {
                Ok(())
            }

@@ -748,7 +685,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
    }

    /// flush row to the next row
-    pub fn flush(&mut self) -> Result<(), Box<dyn std::error::Error>> {
+    pub fn flush(&mut self) -> Result<(), CircuitError> {
        // increment by the difference between the current linear coord and the next row
        let remainder = self.linear_coord % self.num_inner_cols;
        if remainder != 0 {
@@ -756,7 +693,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
            self.increment(diff);
        }
        if self.linear_coord % self.num_inner_cols != 0 {
-            return Err("flush: linear coord is not aligned with the next row".into());
+            return Err(CircuitError::FlushError);
        }
        Ok(())
    }
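
The flush check is plain modular arithmetic: a linear coordinate is row-aligned iff it is divisible by num_inner_cols. A standalone sketch, assuming `diff` is the complement of the remainder as the comment above suggests (the exact computation of `diff` is elided at the hunk boundary):

    fn main() {
        let num_inner_cols = 4;
        let linear_coord = 10; // a mid-row position
        let remainder = linear_coord % num_inner_cols; // 2
        let diff = num_inner_cols - remainder; // 2 cells to the next row boundary
        let flushed = linear_coord + diff;
        assert_eq!(flushed % num_inner_cols, 0); // aligned, otherwise FlushError
        println!("flushed from {} to {}", linear_coord, flushed);
    }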

@@ -1,4 +1,4 @@
-use std::{error::Error, marker::PhantomData};
+use std::marker::PhantomData;

use halo2curves::ff::PrimeField;

@@ -194,9 +194,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<
        &mut self,
        layouter: &mut impl Layouter<F>,
        preassigned_input: bool,
-    ) -> Result<(), Box<dyn Error>> {
+    ) -> Result<(), CircuitError> {
        if self.is_assigned {
-            return Err(Box::new(CircuitError::TableAlreadyAssigned));
+            return Err(CircuitError::TableAlreadyAssigned);
        }

        let smallest = self.range.0;
@@ -342,9 +342,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeC
    }

    /// Assigns values to the constraints generated when calling `configure`.
-    pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
+    pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), CircuitError> {
        if self.is_assigned {
-            return Err(Box::new(CircuitError::TableAlreadyAssigned));
+            return Err(CircuitError::TableAlreadyAssigned);
        }

        let smallest = self.range.0;

365
src/commands.rs
@@ -1,6 +1,7 @@
#[cfg(not(target_arch = "wasm32"))]
use alloy::primitives::Address as H160;
-use clap::{Parser, Subcommand};
+use clap::{Command, Parser, Subcommand};
+use clap_complete::{generate, Generator, Shell};
#[cfg(feature = "python-bindings")]
use pyo3::{
    conversion::{FromPyObject, PyTryFrom},
@@ -10,7 +11,7 @@ use pyo3::{
};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
-use std::{error::Error, str::FromStr};
+use std::str::FromStr;
use tosubcommand::{ToFlags, ToSubcommand};

use crate::{pfsys::ProofType, Commitments, RunArgs};

@@ -52,6 +53,8 @@ pub const DEFAULT_VERIFIER_AGGREGATED_ABI: &str = "verifier_aggr_abi.json";
pub const DEFAULT_VERIFIER_DA_ABI: &str = "verifier_da_abi.json";
/// Default solidity code
pub const DEFAULT_SOL_CODE: &str = "evm_deploy.sol";
+/// Default calldata path
+pub const DEFAULT_CALLDATA: &str = "calldata.bytes";
/// Default solidity code for aggregated proofs
pub const DEFAULT_SOL_CODE_AGGREGATED: &str = "evm_deploy_aggr.sol";
/// Default solidity code for data attestation
@@ -78,7 +81,7 @@ pub const DEFAULT_CALIBRATION_FILE: &str = "calibration.json";
pub const DEFAULT_LOOKUP_SAFETY_MARGIN: &str = "2";
/// Default Compress selectors
pub const DEFAULT_DISABLE_SELECTOR_COMPRESSION: &str = "false";
-/// Default render vk seperately
+/// Default render vk separately
pub const DEFAULT_RENDER_VK_SEPERATELY: &str = "false";
/// Default VK sol path
pub const DEFAULT_VK_SOL: &str = "vk.sol";

@@ -253,33 +256,66 @@ lazy_static! {
    };
}

-#[allow(missing_docs)]
-#[derive(Parser, Debug, Clone, Deserialize, Serialize)]
-#[command(author, about, long_about = None)]
-#[clap(version = *VERSION)]
-pub struct Cli {
-    #[command(subcommand)]
-    #[allow(missing_docs)]
-    pub command: Commands,
+/// Get the styles for the CLI
+pub fn get_styles() -> clap::builder::Styles {
+    clap::builder::Styles::styled()
+        .usage(
+            clap::builder::styling::Style::new()
+                .bold()
+                .underline()
+                .fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Cyan))),
+        )
+        .header(
+            clap::builder::styling::Style::new()
+                .bold()
+                .underline()
+                .fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Cyan))),
+        )
+        .literal(
+            clap::builder::styling::Style::new().fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Magenta))),
+        )
+        .invalid(
+            clap::builder::styling::Style::new()
+                .bold()
+                .fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red))),
+        )
+        .error(
+            clap::builder::styling::Style::new()
+                .bold()
+                .fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red))),
+        )
+        .valid(
+            clap::builder::styling::Style::new()
+                .bold()
+                .underline()
+                .fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Green))),
+        )
+        .placeholder(
+            clap::builder::styling::Style::new().fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::White))),
+        )
}

impl Cli {
    /// Export the ezkl configuration as json
    pub fn as_json(&self) -> Result<String, Box<dyn Error>> {
        let serialized = match serde_json::to_string(&self) {
            Ok(s) => s,
            Err(e) => {
                return Err(Box::new(e));
            }
        };
        Ok(serialized)
    }
    /// Parse an ezkl configuration from a json
    pub fn from_json(arg_json: &str) -> Result<Self, serde_json::Error> {
        serde_json::from_str(arg_json)
    }
+}

+/// Print completions for the given generator
+pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
+    generate(gen, cmd, cmd.get_name().to_string(), &mut std::io::stdout());
+}

+#[allow(missing_docs)]
+#[derive(Parser, Debug, Clone)]
+#[command(author, about, long_about = None)]
+#[clap(version = *VERSION, styles = get_styles(), trailing_var_arg = true)]
+pub struct Cli {
+    /// If provided, outputs the completion file for given shell
+    #[clap(long = "generate", value_parser)]
+    pub generator: Option<Shell>,
+    #[command(subcommand)]
+    #[allow(missing_docs)]
+    pub command: Option<Commands>,
}
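
A sketch of how `main` might wire the new `--generate` flag to `print_completions`; this assumes clap's `Parser` derive also brings in `CommandFactory::command()` (true in clap v4), and the dispatch code itself is not part of this diff:

    use clap::{CommandFactory, Parser};

    fn main() {
        let cli = Cli::parse();
        if let Some(generator) = cli.generator {
            let mut cmd = Cli::command();
            // e.g. `ezkl --generate bash > ezkl.bash` to install completions
            print_completions(generator, &mut cmd);
            return;
        }
        // ... dispatch on cli.command as before
    }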

#[allow(missing_docs)]
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
pub enum Commands {
@@ -289,7 +325,7 @@ pub enum Commands {
    /// Loads model and prints model table
    Table {
        /// The path to the .onnx model file
-        #[arg(short = 'M', long, default_value = DEFAULT_MODEL)]
+        #[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
        model: Option<PathBuf>,
        /// proving arguments
        #[clap(flatten)]

@@ -299,29 +335,29 @@ pub enum Commands {
    /// Generates the witness from an input file.
    GenWitness {
        /// The path to the .json data file
-        #[arg(short = 'D', long, default_value = DEFAULT_DATA)]
+        #[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
        data: Option<PathBuf>,
        /// The path to the compiled model file (generated using the compile-circuit command)
-        #[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
+        #[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
        compiled_circuit: Option<PathBuf>,
        /// Path to output the witness .json file
-        #[arg(short = 'O', long, default_value = DEFAULT_WITNESS)]
+        #[arg(short = 'O', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
        output: Option<PathBuf>,
        /// Path to the verification key file (optional - solely used to generate kzg commits)
-        #[arg(short = 'V', long)]
+        #[arg(short = 'V', long, value_hint = clap::ValueHint::FilePath)]
        vk_path: Option<PathBuf>,
        /// Path to the srs file (optional - solely used to generate kzg commits)
-        #[arg(short = 'P', long)]
+        #[arg(short = 'P', long, value_hint = clap::ValueHint::FilePath)]
        srs_path: Option<PathBuf>,
    },

    /// Produces the proving hyperparameters, from run-args
    GenSettings {
        /// The path to the .onnx model file
-        #[arg(short = 'M', long, default_value = DEFAULT_MODEL)]
+        #[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
        model: Option<PathBuf>,
        /// The path to generate the circuit settings .json file to
-        #[arg(short = 'O', long, default_value = DEFAULT_SETTINGS)]
+        #[arg(short = 'O', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
        settings_path: Option<PathBuf>,
        /// proving arguments
        #[clap(flatten)]

@@ -332,33 +368,34 @@ pub enum Commands {
    #[cfg(not(target_arch = "wasm32"))]
    CalibrateSettings {
        /// The path to the .json calibration data file.
-        #[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE)]
+        #[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE, value_hint = clap::ValueHint::FilePath)]
        data: Option<PathBuf>,
        /// The path to the .onnx model file
-        #[arg(short = 'M', long, default_value = DEFAULT_MODEL)]
+        #[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
        model: Option<PathBuf>,
        /// The path to load circuit settings .json file AND overwrite (generated using the gen-settings command).
-        #[arg(short = 'O', long, default_value = DEFAULT_SETTINGS)]
+        #[arg(short = 'O', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
        settings_path: Option<PathBuf>,
-        #[arg(long = "target", default_value = DEFAULT_CALIBRATION_TARGET)]
+        #[arg(long = "target", default_value = DEFAULT_CALIBRATION_TARGET, value_hint = clap::ValueHint::Other)]
        /// Target for calibration. Set to "resources" to optimize for computational resource. Otherwise, set to "accuracy" to optimize for accuracy.
        target: CalibrationTarget,
        /// the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. larger = safer but slower
-        #[arg(long, default_value = DEFAULT_LOOKUP_SAFETY_MARGIN)]
+        #[arg(long, default_value = DEFAULT_LOOKUP_SAFETY_MARGIN, value_hint = clap::ValueHint::Other)]
        lookup_safety_margin: i64,
        /// Optional scales to specifically try for calibration. Example, --scales 0,4
-        #[arg(long, value_delimiter = ',', allow_hyphen_values = true)]
+        #[arg(long, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::Other)]
        scales: Option<Vec<crate::Scale>>,
        /// Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale. Example, --scale-rebase-multipliers 0,4
        #[arg(
            long,
            value_delimiter = ',',
            allow_hyphen_values = true,
-            default_value = DEFAULT_SCALE_REBASE_MULTIPLIERS
+            default_value = DEFAULT_SCALE_REBASE_MULTIPLIERS,
+            value_hint = clap::ValueHint::Other
        )]
        scale_rebase_multiplier: Vec<u32>,
        /// max logrows to use for calibration, 26 is the max public SRS size
-        #[arg(long)]
+        #[arg(long, value_hint = clap::ValueHint::Other)]
        max_logrows: Option<u32>,
        // whether to only range check rebases (instead of trying both range check and lookup)
        #[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE, action = clap::ArgAction::SetTrue)]

@@ -369,13 +406,13 @@ pub enum Commands {
    #[command(name = "gen-srs", arg_required_else_help = true)]
    GenSrs {
        /// The path to output the generated SRS
-        #[arg(long)]
+        #[arg(long, value_hint = clap::ValueHint::FilePath)]
        srs_path: PathBuf,
        /// number of logrows to use for srs
-        #[arg(long)]
+        #[arg(long, value_hint = clap::ValueHint::Other)]
        logrows: usize,
        /// commitment used
-        #[arg(long, default_value = DEFAULT_COMMITMENT)]
+        #[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
        commitment: Option<Commitments>,
    },

@@ -384,57 +421,57 @@ pub enum Commands {
    #[command(name = "get-srs")]
    GetSrs {
        /// The path to output the desired srs file, if set to None will save to $EZKL_REPO_PATH/srs
-        #[arg(long)]
+        #[arg(long, default_value = None, value_hint = clap::ValueHint::FilePath)]
        srs_path: Option<PathBuf>,
        /// Path to the circuit settings .json file to read in logrows from. Overridden by logrows if specified.
-        #[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
+        #[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
        settings_path: Option<PathBuf>,
        /// Number of logrows to use for srs. Overrides settings_path if specified.
-        #[arg(long, default_value = None)]
+        #[arg(long, default_value = None, value_hint = clap::ValueHint::Other)]
        logrows: Option<u32>,
        /// Commitment used
-        #[arg(long, default_value = None)]
+        #[arg(long, default_value = None, value_hint = clap::ValueHint::Other)]
        commitment: Option<Commitments>,
    },
    /// Loads model and input and runs mock prover (for testing)
    Mock {
        /// The path to the .json witness file (generated using the gen-witness command)
-        #[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
+        #[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
        witness: Option<PathBuf>,
        /// The path to the compiled model file (generated using the compile-circuit command)
-        #[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
+        #[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
        model: Option<PathBuf>,
    },

    /// Mock aggregate proofs
    MockAggregate {
        /// The path to the snarks to aggregate over (generated using the prove command with the --proof-type=for-aggr flag)
-        #[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true)]
+        #[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
        aggregation_snarks: Vec<PathBuf>,
        /// logrows used for aggregation circuit
-        #[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
+        #[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
        logrows: Option<u32>,
        /// whether the accumulated are segments of a larger proof
        #[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
        split_proofs: Option<bool>,
    },

-    /// setup aggregation circuit :)
+    /// Setup aggregation circuit and generate pk and vk
    SetupAggregate {
        /// The path to samples of snarks that will be aggregated over (generated using the prove command with the --proof-type=for-aggr flag)
-        #[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true)]
+        #[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
        sample_snarks: Vec<PathBuf>,
        /// The path to save the desired verification key file to
-        #[arg(long, default_value = DEFAULT_VK_AGGREGATED)]
+        #[arg(long, default_value = DEFAULT_VK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
        vk_path: Option<PathBuf>,
        /// The path to save the proving key to
-        #[arg(long, default_value = DEFAULT_PK_AGGREGATED)]
+        #[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
        pk_path: Option<PathBuf>,
        /// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
-        #[arg(long)]
+        #[arg(long, value_hint = clap::ValueHint::FilePath)]
        srs_path: Option<PathBuf>,
        /// logrows used for aggregation circuit
-        #[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
+        #[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
        logrows: Option<u32>,
        /// whether the accumulated are segments of a larger proof
        #[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]

@@ -443,19 +480,19 @@ pub enum Commands {
        #[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)]
        disable_selector_compression: Option<bool>,
        /// commitment used
-        #[arg(long, default_value = DEFAULT_COMMITMENT)]
+        #[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
        commitment: Option<Commitments>,
    },
-    /// Aggregates proofs :)
+    /// Aggregates proofs
    Aggregate {
        /// The path to the snarks to aggregate over (generated using the prove command with the --proof-type=for-aggr flag)
-        #[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true)]
+        #[arg(long, default_value = DEFAULT_PROOF, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
        aggregation_snarks: Vec<PathBuf>,
        /// The path to load the desired proving key file (generated using the setup-aggregate command)
-        #[arg(long, default_value = DEFAULT_PK_AGGREGATED)]
+        #[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
        pk_path: Option<PathBuf>,
        /// The path to output the proof file to
-        #[arg(long, default_value = DEFAULT_PROOF_AGGREGATED)]
+        #[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
        proof_path: Option<PathBuf>,
        /// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
        #[arg(long)]
|
||||
@@ -465,50 +502,51 @@ pub enum Commands {
|
||||
require_equals = true,
|
||||
num_args = 0..=1,
|
||||
default_value_t = TranscriptType::default(),
|
||||
value_enum
|
||||
value_enum,
|
||||
value_hint = clap::ValueHint::Other
|
||||
)]
|
||||
transcript: TranscriptType,
|
||||
/// logrows used for aggregation circuit
|
||||
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
|
||||
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
|
||||
logrows: Option<u32>,
|
||||
/// run sanity checks during calculations (safe or unsafe)
|
||||
#[arg(long, default_value = DEFAULT_CHECKMODE)]
|
||||
#[arg(long, default_value = DEFAULT_CHECKMODE, value_hint = clap::ValueHint::Other)]
|
||||
check_mode: Option<CheckMode>,
|
||||
/// whether the accumulated proofs are segments of a larger circuit
|
||||
#[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
|
||||
split_proofs: Option<bool>,
|
||||
/// commitment used
|
||||
#[arg(long, default_value = DEFAULT_COMMITMENT)]
|
||||
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
|
||||
commitment: Option<Commitments>,
|
||||
},
|
||||
/// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements
|
||||
CompileCircuit {
|
||||
/// The path to the .onnx model file
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_MODEL)]
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
|
||||
model: Option<PathBuf>,
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
#[arg(long, default_value = DEFAULT_COMPILED_CIRCUIT)]
|
||||
#[arg(long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
|
||||
compiled_circuit: Option<PathBuf>,
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
|
||||
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
|
||||
settings_path: Option<PathBuf>,
|
||||
},
|
||||
/// Creates pk and vk
|
||||
Setup {
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
|
||||
compiled_circuit: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to output the verification key file to
|
||||
#[arg(long, default_value = DEFAULT_VK)]
|
||||
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
|
||||
vk_path: Option<PathBuf>,
|
||||
/// The path to output the proving key file to
|
||||
#[arg(long, default_value = DEFAULT_PK)]
|
||||
#[arg(long, default_value = DEFAULT_PK, value_hint = clap::ValueHint::FilePath)]
|
||||
pk_path: Option<PathBuf>,
|
||||
/// The graph witness (optional - used to override fixed values in the circuit)
|
||||
#[arg(short = 'W', long)]
|
||||
#[arg(short = 'W', long, value_hint = clap::ValueHint::FilePath)]
|
||||
witness: Option<PathBuf>,
|
||||
/// compress selectors
|
||||
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)]
|
||||
@@ -519,24 +557,24 @@ pub enum Commands {
|
||||
#[command(arg_required_else_help = true)]
|
||||
SetupTestEvmData {
|
||||
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
#[arg(short = 'D', long)]
|
||||
#[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
|
||||
data: Option<PathBuf>,
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
#[arg(short = 'M', long)]
|
||||
#[arg(short = 'M', long, value_hint = clap::ValueHint::FilePath)]
|
||||
compiled_circuit: Option<PathBuf>,
|
||||
/// For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information
|
||||
/// derived from the file information in the data .json file.
|
||||
/// Should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
#[arg(short = 'T', long)]
|
||||
#[arg(short = 'T', long, value_hint = clap::ValueHint::FilePath)]
|
||||
test_data: PathBuf,
|
||||
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
#[arg(short = 'U', long)]
|
||||
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
|
||||
rpc_url: Option<String>,
|
||||
/// where the input data comes from
#[arg(long, default_value = "on-chain")]
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
input_source: TestDataSource,
/// where the output data comes from
#[arg(long, default_value = "on-chain")]
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
output_source: TestDataSource,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -544,23 +582,23 @@ pub enum Commands {
#[command(arg_required_else_help = true)]
TestUpdateAccountCalls {
/// The path to the verifier contract's address
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::Other)]
addr: H160Flag,
/// The path to the .json data file.
#[arg(short = 'D', long)]
#[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Swaps the positions in the transcript that correspond to commitments
SwapProofCommitments {
/// The path to the proof file
#[arg(short = 'P', long, default_value = DEFAULT_PROOF)]
#[arg(short = 'P', long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to the witness file
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
witness_path: Option<PathBuf>,
},

@@ -568,50 +606,65 @@ pub enum Commands {
/// Loads model, data, and creates proof
Prove {
/// The path to the .json witness file (generated using the gen-witness command)
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
witness: Option<PathBuf>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
/// The path to load the desired proving key file (generated using the setup command)
#[arg(long, default_value = DEFAULT_PK)]
#[arg(long, default_value = DEFAULT_PK, value_hint = clap::ValueHint::FilePath)]
pk_path: Option<PathBuf>,
/// The path to output the proof file to
#[arg(long, default_value = DEFAULT_PROOF)]
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
#[arg(
long,
require_equals = true,
num_args = 0..=1,
default_value_t = ProofType::Single,
value_enum
value_enum,
value_hint = clap::ValueHint::Other
)]
proof_type: ProofType,
/// run sanity checks during calculations (safe or unsafe)
#[arg(long, default_value = DEFAULT_CHECKMODE)]
#[arg(long, default_value = DEFAULT_CHECKMODE, value_hint = clap::ValueHint::Other)]
check_mode: Option<CheckMode>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Encodes a proof into evm calldata
#[command(name = "encode-evm-calldata")]
EncodeEvmCalldata {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to save the calldata to
#[arg(long, default_value = DEFAULT_CALLDATA, value_hint = clap::ValueHint::FilePath)]
calldata_path: Option<PathBuf>,
/// The path to the verification key address (only used if the vk is rendered as a separate contract)
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_vk: Option<H160Flag>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-verifier")]
CreateEvmVerifier {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to load the desired verification key file
#[arg(long, default_value = DEFAULT_VK)]
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to output the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE)]
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VERIFIER_ABI)]
#[arg(long, default_value = DEFAULT_VERIFIER_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
/// Whether the verifier key should be rendered as a separate contract.
/// We recommend disabling selector compression if this is enabled.
@@ -624,19 +677,19 @@ pub enum Commands {
#[command(name = "create-evm-vk")]
CreateEvmVK {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to load the desired verification key file
#[arg(long, default_value = DEFAULT_VK)]
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to output the Solidity code
#[arg(long, default_value = DEFAULT_VK_SOL)]
#[arg(long, default_value = DEFAULT_VK_SOL, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VK_ABI)]
#[arg(long, default_value = DEFAULT_VK_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -644,21 +697,24 @@ pub enum Commands {
#[command(name = "create-evm-da")]
CreateEvmDataAttestation {
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to output the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE_DA)]
#[arg(long, default_value = DEFAULT_SOL_CODE_DA, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VERIFIER_DA_ABI)]
#[arg(long, default_value = DEFAULT_VERIFIER_DA_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
/// The path to the .json data file, which should
/// contain the necessary calldata and account addresses
/// needed to read from all the on-chain
/// view functions that return the data that the network
/// ingests as inputs.
#[arg(short = 'D', long, default_value = DEFAULT_DATA)]
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// The path to the witness file. This is needed for proof swapping for kzg commitments.
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
witness: Option<PathBuf>,
},

#[cfg(not(target_arch = "wasm32"))]
@@ -666,22 +722,22 @@ pub enum Commands {
#[command(name = "create-evm-verifier-aggr")]
CreateEvmVerifierAggr {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// The path to load the desired verification key file
#[arg(long, default_value = DEFAULT_VK_AGGREGATED)]
#[arg(long, default_value = DEFAULT_VK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE_AGGREGATED)]
#[arg(long, default_value = DEFAULT_SOL_CODE_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VERIFIER_AGGREGATED_ABI)]
#[arg(long, default_value = DEFAULT_VERIFIER_AGGREGATED_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
// aggregated circuit settings paths, used to calculate the number of instances in the aggregate proof
#[arg(long, default_value = DEFAULT_SETTINGS, value_delimiter = ',', allow_hyphen_values = true)]
#[arg(long, default_value = DEFAULT_SETTINGS, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::FilePath)]
aggregation_settings: Vec<PathBuf>,
// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// Whether the verifier key should be rendered as a separate contract.
/// We recommend disabling selector compression if this is enabled.
@@ -692,16 +748,16 @@ pub enum Commands {
/// Verifies a proof, returning accept or reject
Verify {
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF)]
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to the verification key file (generated using the setup command)
#[arg(long, default_value = DEFAULT_VK)]
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
@@ -710,60 +766,60 @@ pub enum Commands {
/// Verifies an aggregate proof, returning accept or reject
VerifyAggr {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED)]
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to the verification key file (generated using the setup-aggregate command)
#[arg(long, default_value = DEFAULT_VK_AGGREGATED)]
#[arg(long, default_value = DEFAULT_VK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
vk_path: Option<PathBuf>,
/// reduced srs
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
reduced_srs: Option<bool>,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
/// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)]
logrows: Option<u32>,
/// commitment
#[arg(long, default_value = DEFAULT_COMMITMENT)]
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that is generated by ezkl
DeployEvmVerifier {
/// The path to the Solidity code (generated using the create-evm-verifier command)
#[arg(long, default_value = DEFAULT_SOL_CODE)]
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS)]
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
/// The path to output the contract address
addr_path: Option<PathBuf>,
/// The optimizer runs to set on the verifier. Lower values optimize for deployment cost, while higher values optimize for gas cost.
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)]
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS, value_hint = clap::ValueHint::Other)]
optimizer_runs: usize,
/// Private secp256k1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
#[arg(short = 'P', long)]
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm vk that is generated by ezkl
DeployEvmVK {
/// The path to the Solidity code (generated using the create-evm-vk command)
#[arg(long, default_value = DEFAULT_VK_SOL)]
#[arg(long, default_value = DEFAULT_VK_SOL, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_VK)]
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_VK, value_hint = clap::ValueHint::Other)]
/// The path to output the contract address
addr_path: Option<PathBuf>,
/// The optimizer runs to set on the verifier. Lower values optimize for deployment cost, while higher values optimize for gas cost.
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)]
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS, value_hint = clap::ValueHint::Other)]
optimizer_runs: usize,
/// Private secp256k1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
#[arg(short = 'P', long)]
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -771,25 +827,25 @@ pub enum Commands {
#[command(name = "deploy-evm-da")]
DeployEvmDataAttestation {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'D', long, default_value = DEFAULT_DATA)]
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(long, default_value = DEFAULT_SETTINGS)]
#[arg(long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
/// The path to the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE_DA)]
#[arg(long, default_value = DEFAULT_SOL_CODE_DA, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_DA)]
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_DA, value_hint = clap::ValueHint::FilePath)]
/// The path to output the contract address
addr_path: Option<PathBuf>,
/// The optimizer runs to set on the verifier. (Lower values optimize for deployment, while higher values optimize for execution)
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)]
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS, value_hint = clap::ValueHint::Other)]
optimizer_runs: usize,
/// Private secp256k1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
#[arg(short = 'P', long)]
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -797,19 +853,38 @@ pub enum Commands {
#[command(name = "verify-evm")]
VerifyEvm {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF)]
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
proof_path: Option<PathBuf>,
/// The path to the verifier contract's address
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS)]
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
addr_verifier: H160Flag,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
/// does the verifier use data attestation?
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_da: Option<H160Flag>,
// is the vk rendered separately, if so specify an address
#[arg(long)]
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_vk: Option<H160Flag>,
},
/// Updates ezkl binary to version specified (or latest if not specified)
Update {
/// The version to update to
#[arg(value_hint = clap::ValueHint::Other, short='v', long)]
version: Option<String>,
},
}


impl Commands {
/// Converts the commands to a json string
pub fn as_json(&self) -> String {
serde_json::to_string(self).unwrap()
}

/// Converts a json string to a Commands struct
pub fn from_json(json: &str) -> Self {
serde_json::from_str(json).unwrap()
}
}
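
The `as_json`/`from_json` pair above serializes whole CLI invocations, which lets the language bindings hand a command to `run` without re-parsing argv. A minimal round-trip sketch, assuming `Commands` derives serde's `Serialize`/`Deserialize` (implied by the `serde_json` calls above); the exact JSON shape depends on any serde attributes on the enum:

// Round-trip a command through its JSON form; `Update` is taken from the enum above.
let cmd = Commands::Update { version: Some("v1.2.3".to_string()) };
let json = cmd.as_json();                 // e.g. {"Update":{"version":"v1.2.3"}}
let parsed = Commands::from_json(&json);  // panics on malformed input, as both helpers unwrap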
179
src/eth.rs
@@ -16,20 +16,23 @@ use alloy::prelude::Wallet;
// use alloy::providers::Middleware;
use alloy::json_abi::JsonAbi;
use alloy::node_bindings::Anvil;
use alloy::primitives::{B256, I256};
use alloy::primitives::ruint::ParseError;
use alloy::primitives::{ParseSignedError, B256, I256};
use alloy::providers::fillers::{
ChainIdFiller, FillProvider, GasFiller, JoinFill, NonceFiller, SignerFiller,
};
use alloy::providers::network::{Ethereum, EthereumSigner};
use alloy::providers::ProviderBuilder;
use alloy::providers::{Identity, Provider, RootProvider};
use alloy::rpc::types::eth::BlockId;
use alloy::rpc::types::eth::TransactionInput;
use alloy::rpc::types::eth::TransactionRequest;
use alloy::signers::wallet::LocalWallet;
use alloy::signers::k256::ecdsa;
use alloy::signers::wallet::{LocalWallet, WalletError};
use alloy::sol as abigen;
use alloy::transports::http::Http;
use alloy::transports::{RpcError, TransportErrorKind};
use foundry_compilers::artifacts::Settings as SolcSettings;
use foundry_compilers::error::{SolcError, SolcIoError};
use foundry_compilers::Solc;
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::{Fr, G1Affine};
@@ -37,7 +40,6 @@ use halo2curves::group::ff::PrimeField;
use itertools::Itertools;
use log::{debug, info, warn};
use reqwest::Client;
use std::error::Error;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
@@ -214,6 +216,57 @@ abigen!(
}
);

#[derive(Debug, thiserror::Error)]
pub enum EthError {
#[error("a transport error occurred: {0}")]
Transport(#[from] RpcError<TransportErrorKind>),
#[error("a contract error occurred: {0}")]
Contract(#[from] alloy::contract::Error),
#[error("a wallet error occurred: {0}")]
Wallet(#[from] WalletError),
#[error("failed to parse url {0}")]
UrlParse(String),
#[error("evm verification error: {0}")]
EvmVerification(#[from] EvmVerificationError),
#[error("Private key must be in hex format, 64 chars, without 0x prefix")]
PrivateKeyFormat,
#[error("failed to parse hex: {0}")]
HexParse(#[from] hex::FromHexError),
#[error("ecdsa error: {0}")]
Ecdsa(#[from] ecdsa::Error),
#[error("failed to load graph data")]
GraphData,
#[error("failed to load graph settings")]
GraphSettings,
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("Data source for either input_data or output_data must be OnChain")]
OnChainDataSource,
#[error("failed to parse signed integer: {0}")]
SignedIntegerParse(#[from] ParseSignedError),
#[error("failed to parse unsigned integer: {0}")]
UnSignedIntegerParse(#[from] ParseError),
#[error("updateAccountCalls should have failed")]
UpdateAccountCalls,
#[error("ethabi error: {0}")]
EthAbi(#[from] ethabi::Error),
#[error("conversion error: {0}")]
Conversion(#[from] std::convert::Infallible),
// Constructor arguments provided but no constructor found
#[error("constructor arguments provided but no constructor found")]
NoConstructor,
#[error("contract not found at path: {0}")]
ContractNotFound(String),
#[error("solc error: {0}")]
Solc(#[from] SolcError),
#[error("solc io error: {0}")]
SolcIo(#[from] SolcIoError),
#[error("svm error: {0}")]
Svm(String),
#[error("no contract output found")]
NoContractOutput,
}

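Every `#[from]` variant above is what lets the refactor below swap `Box<dyn Error>` returns for `EthError` while keeping plain `?` propagation. A minimal sketch of the mechanism, using only the `HexParse` variant shown:

// `#[from] hex::FromHexError` generates the From impl that `?` relies on here.
fn decode_key_bytes(key: &str) -> Result<Vec<u8>, EthError> {
    let bytes = hex::decode(key)?; // hex::FromHexError -> EthError::HexParse
    Ok(bytes)
}
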
// we have to generate these two contracts differently because they are generated dynamically, and hence the static compilation above does not suit them
const ATTESTDATA_SOL: &str = include_str!("../contracts/AttestData.sol");

@@ -236,7 +289,7 @@ pub type ContractFactory<M> = CallBuilder<Http<Client>, Arc<M>, ()>;
pub async fn setup_eth_backend(
rpc_url: Option<&str>,
private_key: Option<&str>,
) -> Result<(EthersClient, alloy::primitives::Address), Box<dyn Error>> {
) -> Result<(EthersClient, alloy::primitives::Address), EthError> {
// Launch anvil

let endpoint: String;
@@ -258,11 +311,8 @@ pub async fn setup_eth_backend(
let wallet: LocalWallet;
if let Some(private_key) = private_key {
debug!("using private key {}", private_key);
// Sanity checks for private_key
let private_key_format_error =
"Private key must be in hex format, 64 chars, without 0x prefix";
if private_key.len() != 64 {
return Err(private_key_format_error.into());
return Err(EthError::PrivateKeyFormat);
}
let private_key_buffer = hex::decode(private_key)?;
wallet = LocalWallet::from_slice(&private_key_buffer)?;
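
Isolated from the surrounding setup, the key check is one length guard plus two fallible conversions, each mapped by the `EthError` variants above; a sketch under that assumption:

// 32-byte secp256k1 key = exactly 64 hex chars, no 0x prefix.
fn wallet_from_hex(private_key: &str) -> Result<LocalWallet, EthError> {
    if private_key.len() != 64 {
        return Err(EthError::PrivateKeyFormat);
    }
    let buf = hex::decode(private_key)?;   // -> EthError::HexParse
    Ok(LocalWallet::from_slice(&buf)?)     // -> EthError::Ecdsa
}
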
@@ -277,7 +327,11 @@ pub async fn setup_eth_backend(
ProviderBuilder::new()
.with_recommended_fillers()
.signer(EthereumSigner::from(wallet))
.on_http(endpoint.parse()?),
.on_http(
endpoint
.parse()
.map_err(|_| EthError::UrlParse(endpoint.clone()))?,
),
);

let chain_id = client.get_chain_id().await?;
@@ -293,7 +347,7 @@ pub async fn deploy_contract_via_solidity(
runs: usize,
private_key: Option<&str>,
contract_name: &str,
) -> Result<H160, Box<dyn Error>> {
) -> Result<H160, EthError> {
// anvil instance must be alive at least until the factory completes the deploy
let (client, _) = setup_eth_backend(rpc_url, private_key).await?;

@@ -315,12 +369,12 @@ pub async fn deploy_da_verifier_via_solidity(
rpc_url: Option<&str>,
runs: usize,
private_key: Option<&str>,
) -> Result<H160, Box<dyn Error>> {
) -> Result<H160, EthError> {
let (client, client_address) = setup_eth_backend(rpc_url, private_key).await?;

let input = GraphData::from_path(input)?;
let input = GraphData::from_path(input).map_err(|_| EthError::GraphData)?;

let settings = GraphSettings::load(&settings_path)?;
let settings = GraphSettings::load(&settings_path).map_err(|_| EthError::GraphSettings)?;

let mut scales: Vec<u32> = vec![];
// The data that will be stored in the test contracts that will eventually be read from.
@@ -340,7 +394,7 @@ pub async fn deploy_da_verifier_via_solidity(
}

if settings.run_args.param_visibility.is_hashed() {
return Err(Box::new(EvmVerificationError::InvalidVisibility));
return Err(EvmVerificationError::InvalidVisibility.into());
}

if settings.run_args.output_visibility.is_hashed() {
@@ -401,7 +455,7 @@ pub async fn deploy_da_verifier_via_solidity(
let (contract_addresses, call_data, decimals) = if !calls_to_accounts.is_empty() {
parse_calls_to_accounts(calls_to_accounts)?
} else {
return Err("Data source for either input_data or output_data must be OnChain".into());
return Err(EthError::OnChainDataSource);
};

let (abi, bytecode, runtime_bytecode) =
@@ -470,7 +524,7 @@ type ParsedCallsToAccount = (Vec<H160>, Vec<Vec<Bytes>>, Vec<Vec<U256>>);

fn parse_calls_to_accounts(
calls_to_accounts: Vec<CallsToAccount>,
) -> Result<ParsedCallsToAccount, Box<dyn Error>> {
) -> Result<ParsedCallsToAccount, EthError> {
let mut contract_addresses = vec![];
let mut call_data = vec![];
let mut decimals: Vec<Vec<U256>> = vec![];
@@ -493,8 +547,8 @@ pub async fn update_account_calls(
addr: H160,
input: PathBuf,
rpc_url: Option<&str>,
) -> Result<(), Box<dyn Error>> {
let input = GraphData::from_path(input)?;
) -> Result<(), EthError> {
let input = GraphData::from_path(input).map_err(|_| EthError::GraphData)?;

// The data that will be stored in the test contracts that will eventually be read from.
let mut calls_to_accounts = vec![];
@@ -514,7 +568,7 @@ pub async fn update_account_calls(
let (contract_addresses, call_data, decimals) = if !calls_to_accounts.is_empty() {
parse_calls_to_accounts(calls_to_accounts)?
} else {
return Err("Data source for either input_data or output_data must be OnChain".into());
return Err(EthError::OnChainDataSource);
};

let (client, client_address) = setup_eth_backend(rpc_url, None).await?;
@@ -548,7 +602,7 @@ pub async fn update_account_calls(
{
info!("updateAccountCalls failed as expected");
} else {
return Err("updateAccountCalls should have failed".into());
return Err(EthError::UpdateAccountCalls);
}

Ok(())
@@ -561,7 +615,7 @@ pub async fn verify_proof_via_solidity(
addr: H160,
addr_vk: Option<H160>,
rpc_url: Option<&str>,
) -> Result<bool, Box<dyn Error>> {
) -> Result<bool, EthError> {
let flattened_instances = proof.instances.into_iter().flatten();

let encoded = encode_calldata(
@@ -580,18 +634,18 @@ pub async fn verify_proof_via_solidity(

let result = client.call(&tx).await;

if result.is_err() {
return Err(Box::new(EvmVerificationError::SolidityExecution));
if let Err(e) = result {
return Err(EvmVerificationError::SolidityExecution(e.to_string()).into());
}
let result = result?;
debug!("result: {:#?}", result.to_vec());
// decode return bytes value into uint8
let result = result.to_vec().last().ok_or("no contract output")? == &1u8;
let result = result.to_vec().last().ok_or(EthError::NoContractOutput)? == &1u8;
if !result {
return Err(Box::new(EvmVerificationError::InvalidProof));
return Err(EvmVerificationError::InvalidProof.into());
}

let gas = client.estimate_gas(&tx, BlockId::default()).await?;
let gas = client.estimate_gas(&tx).await?;

info!("estimated verify gas cost: {:#?}", gas);

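The decoding above treats the verifier's return data as a uint8 flag where only the last byte matters; as a standalone sketch:

// Success iff the contract returned ...01; empty output is its own error.
fn decode_success(output: &[u8]) -> Result<bool, EthError> {
    let flag = output.last().ok_or(EthError::NoContractOutput)?;
    Ok(*flag == 1u8)
}
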
@@ -627,7 +681,7 @@ fn count_decimal_places(num: f32) -> usize {
pub async fn setup_test_contract<M: 'static + Provider<Http<Client>, Ethereum>>(
client: Arc<M>,
data: &[Vec<FileSourceInner>],
) -> Result<(TestReads::TestReadsInstance<Http<Client>, Arc<M>>, Vec<u8>), Box<dyn Error>> {
) -> Result<(TestReads::TestReadsInstance<Http<Client>, Arc<M>>, Vec<u8>), EthError> {
let mut decimals = vec![];
let mut scaled_by_decimals_data = vec![];
for input in &data[0] {
@@ -664,7 +718,7 @@ pub async fn verify_proof_with_data_attestation(
addr_da: H160,
addr_vk: Option<H160>,
rpc_url: Option<&str>,
) -> Result<bool, Box<dyn Error>> {
) -> Result<bool, EthError> {
use ethabi::{Function, Param, ParamType, StateMutability, Token};

let mut public_inputs: Vec<U256> = vec![];
@@ -725,19 +779,19 @@ pub async fn verify_proof_with_data_attestation(
debug!("transaction {:#?}", tx);
info!(
"estimated verify gas cost: {:#?}",
client.estimate_gas(&tx, BlockId::default()).await?
client.estimate_gas(&tx).await?
);

let result = client.call(&tx).await;
if result.is_err() {
return Err(Box::new(EvmVerificationError::SolidityExecution));
if let Err(e) = result {
return Err(EvmVerificationError::SolidityExecution(e.to_string()).into());
}
let result = result?;
debug!("result: {:#?}", result);
// decode return bytes value into uint8
let result = result.to_vec().last().ok_or("no contract output")? == &1u8;
let result = result.to_vec().last().ok_or(EthError::NoContractOutput)? == &1u8;
if !result {
return Err(Box::new(EvmVerificationError::InvalidProof));
return Err(EvmVerificationError::InvalidProof.into());
}

Ok(true)
@@ -749,7 +803,7 @@ pub async fn verify_proof_with_data_attestation(
pub async fn test_on_chain_data<M: 'static + Provider<Http<Client>, Ethereum>>(
client: Arc<M>,
data: &[Vec<FileSourceInner>],
) -> Result<Vec<CallsToAccount>, Box<dyn Error>> {
) -> Result<Vec<CallsToAccount>, EthError> {
let (contract, decimals) = setup_test_contract(client.clone(), data).await?;

// Get the encoded call data for each input
@@ -775,7 +829,7 @@ pub async fn read_on_chain_inputs<M: 'static + Provider<Http<Client>, Ethereum>>
client: Arc<M>,
address: H160,
data: &Vec<CallsToAccount>,
) -> Result<(Vec<Bytes>, Vec<u8>), Box<dyn Error>> {
) -> Result<(Vec<Bytes>, Vec<u8>), EthError> {
// Iterate over all on-chain inputs

let mut fetched_inputs = vec![];
@@ -809,9 +863,7 @@ pub async fn evm_quantize<M: 'static + Provider<Http<Client>, Ethereum>>(
client: Arc<M>,
scales: Vec<crate::Scale>,
data: &(Vec<Bytes>, Vec<u8>),
) -> Result<Vec<Fr>, Box<dyn Error>> {
use alloy::primitives::ParseSignedError;

) -> Result<Vec<Fr>, EthError> {
let contract = QuantizeData::deploy(&client).await?;

let fetched_inputs = data.0.clone();
@@ -871,7 +923,7 @@ fn get_sol_contract_factory<'a, M: 'static + Provider<Http<Client>, Ethereum>, T
runtime_bytecode: Bytes,
client: Arc<M>,
params: Option<T>,
) -> Result<ContractFactory<M>, Box<dyn Error>> {
) -> Result<ContractFactory<M>, EthError> {
const MAX_RUNTIME_BYTECODE_SIZE: usize = 24577;
let size = runtime_bytecode.len();
debug!("runtime bytecode size: {:#?}", size);
@@ -889,7 +941,7 @@ fn get_sol_contract_factory<'a, M: 'static + Provider<Http<Client>, Ethereum>, T
// Encode the constructor args & concatenate with the bytecode if necessary
let data: Bytes = match (abi.constructor(), params.is_none()) {
(None, false) => {
return Err("Constructor arguments provided but no constructor found".into())
return Err(EthError::NoConstructor);
}
(None, true) => bytecode.clone(),
(Some(_), _) => {
@@ -912,7 +964,7 @@ pub async fn get_contract_artifacts(
sol_code_path: PathBuf,
contract_name: &str,
runs: usize,
) -> Result<(JsonAbi, Bytes, Bytes), Box<dyn Error>> {
) -> Result<(JsonAbi, Bytes, Bytes), EthError> {
use foundry_compilers::{
artifacts::{output_selection::OutputSelection, Optimizer},
compilers::CompilerInput,
@@ -920,16 +972,20 @@ pub async fn get_contract_artifacts(
};

if !sol_code_path.exists() {
return Err(format!("file not found: {:#?}", sol_code_path).into());
return Err(EthError::ContractNotFound(
sol_code_path.to_string_lossy().to_string(),
));
}

let mut settings = SolcSettings::default();
settings.optimizer = Optimizer {
enabled: Some(true),
runs: Some(runs),
details: None,
let settings = SolcSettings {
optimizer: Optimizer {
enabled: Some(true),
runs: Some(runs),
details: None,
},
output_selection: OutputSelection::default_output_selection(),
..Default::default()
};
settings.output_selection = OutputSelection::default_output_selection();

let input = SolcInput::build(
std::collections::BTreeMap::from([(
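
The `SolcSettings` rewrite above trades field-by-field mutation for struct-update syntax, so every unspecified field is filled from `Default` in a single expression. The same idiom on a stand-in type (illustrative only, not a foundry-compilers API):

#[derive(Default)]
struct BuildSettings {
    optimize: bool,
    runs: usize,
    warnings_as_errors: bool,
}

fn example_settings(runs: usize) -> BuildSettings {
    BuildSettings {
        optimize: true,
        runs,
        ..Default::default() // warnings_as_errors falls back to false
    }
}
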
@@ -945,7 +1001,9 @@ pub async fn get_contract_artifacts(
Some(solc) => solc,
None => {
info!("required solc version is missing ... installing");
Solc::install(&SHANGHAI_SOLC).await?
Solc::install(&SHANGHAI_SOLC)
.await
.map_err(|e| EthError::Svm(e.to_string()))?
}
};

@@ -954,7 +1012,7 @@ pub async fn get_contract_artifacts(
let (abi, bytecode, runtime_bytecode) = match compiled.find(contract_name) {
Some(c) => c.into_parts_or_default(),
None => {
return Err("could not find contract".into());
return Err(EthError::ContractNotFound(contract_name.to_string()));
}
};

@@ -965,7 +1023,8 @@ pub async fn get_contract_artifacts(
pub fn fix_da_sol(
input_data: Option<Vec<CallsToAccount>>,
output_data: Option<Vec<CallsToAccount>>,
) -> Result<String, Box<dyn Error>> {
commitment_bytes: Option<Vec<u8>>,
) -> Result<String, EthError> {
let mut accounts_len = 0;
let mut contract = ATTESTDATA_SOL.to_string();

@@ -989,5 +1048,21 @@ pub fn fix_da_sol(
}
contract = contract.replace("AccountCall[]", &format!("AccountCall[{}]", accounts_len));

if commitment_bytes.clone().is_some() && !commitment_bytes.clone().unwrap().is_empty() {
let commitment_bytes = commitment_bytes.unwrap();
let hex_string = hex::encode(commitment_bytes);
contract = contract.replace(
"bytes constant COMMITMENT_KZG = hex\"\";",
&format!("bytes constant COMMITMENT_KZG = hex\"{}\";", hex_string),
);
} else {
// Remove the SwapProofCommitments inheritance and the checkKzgCommits function call if no commitment is provided
contract = contract.replace(", SwapProofCommitments", "");
contract = contract.replace(
"require(checkKzgCommits(encoded), \"Invalid KZG commitments\");",
"",
);
}

Ok(contract)
}

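`fix_da_sol`'s commitment splice clones and unwraps the same `Option` twice; an equivalent formulation (a sketch only, not what the commit ships) folds both branches into one match on `as_deref`:

match commitment_bytes.as_deref() {
    Some(bytes) if !bytes.is_empty() => {
        // splice the hex-encoded commitment into the template constant
        contract = contract.replace(
            "bytes constant COMMITMENT_KZG = hex\"\";",
            &format!("bytes constant COMMITMENT_KZG = hex\"{}\";", hex::encode(bytes)),
        );
    }
    _ => {
        // no commitment: drop the inheritance and the require-check
        contract = contract.replace(", SwapProofCommitments", "");
        contract = contract
            .replace("require(checkKzgCommits(encoded), \"Invalid KZG commitments\");", "");
    }
}
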
283
src/execute.rs
@@ -1,7 +1,6 @@
use crate::circuit::CheckMode;
#[cfg(not(target_arch = "wasm32"))]
use crate::commands::CalibrationTarget;
use crate::commands::*;
#[cfg(not(target_arch = "wasm32"))]
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity};
#[cfg(not(target_arch = "wasm32"))]
@@ -23,6 +22,7 @@ use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
#[cfg(not(target_arch = "wasm32"))]
use crate::EZKL_BUF_CAPACITY;
use crate::{commands::*, EZKLError};
use crate::{Commitments, RunArgs};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
@@ -63,7 +63,6 @@ use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::compile;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use snark_verifier::system::halo2::Config;
use std::error::Error;
use std::fs::File;
#[cfg(not(target_arch = "wasm32"))]
use std::io::BufWriter;
@@ -92,12 +91,15 @@ lazy_static! {

}

/// A wrapper for tensor related errors.
/// A wrapper for execution errors
#[derive(Debug, Error)]
pub enum ExecutionError {
/// Shape mismatch in a operation
#[error("verification failed")]
/// verification failed
#[error("verification failed:\n{}", .0.iter().map(|e| e.to_string()).collect::<Vec<_>>().join("\n"))]
VerifyError(Vec<VerifyFailure>),
/// Prover error
#[error("[mock] {0}")]
MockProverError(String),
}
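
The new `#[error(...)]` attribute on `VerifyError` folds every `VerifyFailure` into one multi-line message instead of the old fixed string. The formatting idiom, shown on a hypothetical error type with a `Vec<String>` payload:

#[derive(Debug, thiserror::Error)]
enum DemoError {
    // `.0` is the tuple field; each entry lands on its own line in the Display output
    #[error("verification failed:\n{}", .0.join("\n"))]
    Failures(Vec<String>),
}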

lazy_static::lazy_static! {
@@ -109,7 +111,7 @@ lazy_static::lazy_static! {
}

/// Run an ezkl command with given args
pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
pub async fn run(command: Commands) -> Result<String, EZKLError> {
// set working dir
std::env::set_current_dir(WORKING_DIR.as_path())?;

@@ -123,7 +125,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
} => gen_srs_cmd(
srs_path,
logrows as u32,
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT)?),
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT).unwrap()),
),
#[cfg(not(target_arch = "wasm32"))]
Commands::GetSrs {
@@ -161,7 +163,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
lookup_safety_margin,
scales,
scale_rebase_multiplier,
only_range_check_rebase.unwrap_or(DEFAULT_ONLY_RANGE_CHECK_REBASE.parse()?),
only_range_check_rebase.unwrap_or(DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap()),
max_logrows,
)
.await
@@ -200,10 +202,22 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
sol_code_path.unwrap_or(DEFAULT_SOL_CODE.into()),
abi_path.unwrap_or(DEFAULT_VERIFIER_ABI.into()),
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse()?),
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap()),
)
.await
}
#[cfg(not(target_arch = "wasm32"))]
Commands::EncodeEvmCalldata {
proof_path,
calldata_path,
addr_vk,
} => encode_evm_calldata(
proof_path.unwrap_or(DEFAULT_PROOF.into()),
calldata_path.unwrap_or(DEFAULT_CALLDATA.into()),
addr_vk,
)
.map(|e| serde_json::to_string(&e).unwrap()),

Commands::CreateEvmVK {
vk_path,
srs_path,
@@ -226,12 +240,14 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
sol_code_path,
abi_path,
data,
witness,
} => {
create_evm_data_attestation(
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
sol_code_path.unwrap_or(DEFAULT_SOL_CODE_DA.into()),
abi_path.unwrap_or(DEFAULT_VERIFIER_DA_ABI.into()),
data.unwrap_or(DEFAULT_DATA.into()),
witness,
)
.await
}
@@ -251,8 +267,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
sol_code_path.unwrap_or(DEFAULT_SOL_CODE_AGGREGATED.into()),
abi_path.unwrap_or(DEFAULT_VERIFIER_AGGREGATED_ABI.into()),
aggregation_settings,
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse()?),
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap()),
)
.await
}
@@ -278,7 +294,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
vk_path.unwrap_or(DEFAULT_VK.into()),
pk_path.unwrap_or(DEFAULT_PK.into()),
witness,
disable_selector_compression.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse()?),
disable_selector_compression
.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap()),
),
#[cfg(not(target_arch = "wasm32"))]
Commands::SetupTestEvmData {
@@ -331,7 +348,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
Some(proof_path.unwrap_or(DEFAULT_PROOF.into())),
srs_path,
proof_type,
check_mode.unwrap_or(DEFAULT_CHECKMODE.parse()?),
check_mode.unwrap_or(DEFAULT_CHECKMODE.parse().unwrap()),
)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::MockAggregate {
@@ -340,8 +357,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
split_proofs,
} => mock_aggregate(
aggregation_snarks,
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
split_proofs.unwrap_or(DEFAULT_SPLIT.parse()?),
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
split_proofs.unwrap_or(DEFAULT_SPLIT.parse().unwrap()),
),
Commands::SetupAggregate {
sample_snarks,
@@ -357,9 +374,10 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
vk_path.unwrap_or(DEFAULT_VK_AGGREGATED.into()),
pk_path.unwrap_or(DEFAULT_PK_AGGREGATED.into()),
srs_path,
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
split_proofs.unwrap_or(DEFAULT_SPLIT.parse()?),
disable_selector_compression.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse()?),
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
split_proofs.unwrap_or(DEFAULT_SPLIT.parse().unwrap()),
disable_selector_compression
.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap()),
commitment.into(),
),
Commands::Aggregate {
@@ -378,9 +396,9 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
pk_path.unwrap_or(DEFAULT_PK_AGGREGATED.into()),
srs_path,
transcript,
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
check_mode.unwrap_or(DEFAULT_CHECKMODE.parse()?),
split_proofs.unwrap_or(DEFAULT_SPLIT.parse()?),
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
check_mode.unwrap_or(DEFAULT_CHECKMODE.parse().unwrap()),
split_proofs.unwrap_or(DEFAULT_SPLIT.parse().unwrap()),
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
@@ -395,7 +413,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
vk_path.unwrap_or(DEFAULT_VK.into()),
srs_path,
reduced_srs.unwrap_or(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse()?),
reduced_srs.unwrap_or(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse().unwrap()),
)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::VerifyAggr {
@@ -409,8 +427,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
proof_path.unwrap_or(DEFAULT_PROOF_AGGREGATED.into()),
vk_path.unwrap_or(DEFAULT_VK_AGGREGATED.into()),
srs_path,
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
reduced_srs.unwrap_or(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse()?),
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse().unwrap()),
reduced_srs.unwrap_or(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse().unwrap()),
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
@@ -488,6 +506,65 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
)
.await
}
Commands::Update { version } => update_ezkl_binary(&version).map(|e| e.to_string()),
}
}

/// Assert that the version is valid
fn assert_version_is_valid(version: &str) -> Result<(), EZKLError> {
let err_string = "Invalid version string. Must be in the format v0.0.0";
if version.is_empty() {
return Err(err_string.into());
}
// safe to unwrap since we know the length is not 0
if !version.starts_with('v') {
return Err(err_string.into());
}

semver::Version::parse(&version[1..])
.map_err(|_| "Invalid version string. Must be in the format v0.0.0")?;

Ok(())
}
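
So `assert_version_is_valid` accepts exactly a leading `v` followed by a full semver triple. A few expected outcomes, sketched against the same `semver` crate:

// What "v1.2.3" reduces to after the prefix is stripped:
assert!(semver::Version::parse("1.2.3").is_ok());
// Rejected: semver requires all three components.
assert!(semver::Version::parse("1.2").is_err());
// Rejected: the 'v' must be stripped before parsing.
assert!(semver::Version::parse("v1.2.3").is_err());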

const INSTALL_BYTES: &[u8] = include_bytes!("../install_ezkl_cli.sh");

fn update_ezkl_binary(version: &Option<String>) -> Result<String, EZKLError> {
// run the install script with the version
let install_script = std::str::from_utf8(INSTALL_BYTES)?;
// now run it as a shell script with the version as an argument

// check if bash is installed
let command = if std::process::Command::new("bash")
.arg("--version")
.status()
.is_err()
{
log::warn!("bash is not installed on this system, trying to run the install script with sh (may fail)");
"sh"
} else {
"bash"
};

let mut command = std::process::Command::new(command);
let mut command = command.arg("-c").arg(install_script);

if let Some(version) = version {
assert_version_is_valid(version)?;
command = command.arg(version)
};
let output = command.output()?;

if output.status.success() {
info!("updated binary");
Ok("".to_string())
} else {
Err(format!(
"failed to update binary: {}, {}",
std::str::from_utf8(&output.stdout)?,
std::str::from_utf8(&output.stderr)?
)
.into())
}
}

@@ -514,7 +591,7 @@ pub(crate) fn gen_srs_cmd(
srs_path: PathBuf,
logrows: u32,
commitment: Commitments,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
match commitment {
Commitments::KZG => {
let params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
@@ -529,7 +606,7 @@ pub(crate) fn gen_srs_cmd(
}

#[cfg(not(target_arch = "wasm32"))]
async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
async fn fetch_srs(uri: &str) -> Result<Vec<u8>, EZKLError> {
let pb = {
let pb = init_spinner();
pb.set_message("Downloading SRS (this may take a while) ...");
@@ -549,7 +626,7 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
}

#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, Box<dyn Error>> {
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, EZKLError> {
use std::io::Read;
let file = std::fs::File::open(path)?;
let mut reader = std::io::BufReader::new(file);
@@ -572,7 +649,7 @@ fn check_srs_hash(
logrows: u32,
srs_path: Option<PathBuf>,
commitment: Commitments,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
let path = get_srs_path(logrows, srs_path, commitment);
let hash = get_file_hash(&path)?;

@@ -599,7 +676,7 @@ pub(crate) async fn get_srs_cmd(
settings_path: Option<PathBuf>,
logrows: Option<u32>,
commitment: Option<Commitments>,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
// logrows overrides settings

let err_string = "You will need to provide a valid settings file to use the settings option. You should run gen-settings to generate a settings file (and calibrate-settings to pick optimal logrows).";
@@ -667,7 +744,7 @@ pub(crate) async fn get_srs_cmd(
Ok(String::new())
}

pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, Box<dyn Error>> {
pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, EZKLError> {
let model = Model::from_run_args(&run_args, &model)?;
info!("\n {}", model.table_nodes());
Ok(String::new())
@@ -679,11 +756,11 @@ pub(crate) async fn gen_witness(
output: Option<PathBuf>,
vk_path: Option<PathBuf>,
srs_path: Option<PathBuf>,
) -> Result<GraphWitness, Box<dyn Error>> {
) -> Result<GraphWitness, EZKLError> {
// these aren't real values so the sanity checks are mostly meaningless

let mut circuit = GraphCircuit::load(compiled_circuit_path)?;
let data = GraphData::from_path(data)?;
let data: GraphData = GraphData::from_path(data)?;
let settings = circuit.settings().clone();

let vk = if let Some(vk) = vk_path {
@@ -780,7 +857,7 @@ pub(crate) fn gen_circuit_settings(
model_path: PathBuf,
params_output: PathBuf,
run_args: RunArgs,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
let circuit = GraphCircuit::from_run_args(&run_args, &model_path)?;
let params = circuit.settings();
params.save(&params_output)?;
@@ -848,7 +925,7 @@ impl AccuracyResults {
pub fn new(
mut original_preds: Vec<crate::tensor::Tensor<f32>>,
mut calibrated_preds: Vec<crate::tensor::Tensor<f32>>,
) -> Result<Self, Box<dyn Error>> {
) -> Result<Self, EZKLError> {
let mut errors = vec![];
let mut abs_errors = vec![];
let mut squared_errors = vec![];
@@ -937,7 +1014,7 @@ pub(crate) async fn calibrate(
scale_rebase_multiplier: Vec<u32>,
only_range_check_rebase: bool,
max_logrows: Option<u32>,
) -> Result<GraphSettings, Box<dyn Error>> {
) -> Result<GraphSettings, EZKLError> {
use log::error;
use std::collections::HashMap;
use tabled::Table;
@@ -1178,7 +1255,6 @@ pub(crate) async fn calibrate(
);
num_passed += 1;
} else {
error!("calibration failed {}", res.err().unwrap());
num_failed += 1;
}

@@ -1310,7 +1386,7 @@ pub(crate) async fn calibrate(
pub(crate) fn mock(
compiled_circuit_path: PathBuf,
data_path: PathBuf,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
// mock should catch any issues by default so we set it to safe
let mut circuit = GraphCircuit::load(compiled_circuit_path)?;

@@ -1327,10 +1403,9 @@ pub(crate) fn mock(
&circuit,
vec![public_inputs],
)
.map_err(Box::<dyn Error>::from)?;
prover
.verify()
.map_err(|e| Box::<dyn Error>::from(ExecutionError::VerifyError(e)))?;
.map_err(|e| ExecutionError::MockProverError(e.to_string()))?;

prover.verify().map_err(ExecutionError::VerifyError)?;
Ok(String::new())
}

@@ -1342,7 +1417,7 @@ pub(crate) async fn create_evm_verifier(
sol_code_path: PathBuf,
abi_path: PathBuf,
render_vk_seperately: bool,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
@@ -1386,7 +1461,7 @@ pub(crate) async fn create_evm_vk(
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
@@ -1426,16 +1501,19 @@ pub(crate) async fn create_evm_data_attestation(
_sol_code_path: PathBuf,
_abi_path: PathBuf,
_input: PathBuf,
) -> Result<String, Box<dyn Error>> {
_witness: Option<PathBuf>,
) -> Result<String, EZKLError> {
#[allow(unused_imports)]
use crate::graph::{DataSource, VarVisibility};
use crate::{graph::Visibility, pfsys::get_proof_commitments};

let settings = GraphSettings::load(&settings_path)?;

let visibility = VarVisibility::from_args(&settings.run_args)?;
trace!("params computed");

let data = GraphData::from_path(_input)?;
// if input is not provided, we just instantiate dummy input data
let data = GraphData::from_path(_input).unwrap_or(GraphData::new(DataSource::File(vec![])));

let output_data = if let Some(DataSource::OnChain(source)) = data.output_data {
if visibility.output.is_private() {
@@ -1463,19 +1541,34 @@ pub(crate) async fn create_evm_data_attestation(
|
||||
None
|
||||
};
|
||||
|
||||
if input_data.is_some() || output_data.is_some() {
|
||||
let output = fix_da_sol(input_data, output_data)?;
|
||||
let mut f = File::create(_sol_code_path.clone())?;
|
||||
let _ = f.write(output.as_bytes());
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(_sol_code_path, "DataAttestation", 0).await?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(_abi_path)?, &abi)?;
|
||||
// Read the settings file. Look if either the run_ars.input_visibility, run_args.output_visibility or run_args.param_visibility is KZGCommit
|
||||
// if so, then we need to load the witness
|
||||
|
||||
let commitment_bytes = if settings.run_args.input_visibility == Visibility::KZGCommit
|
||||
|| settings.run_args.output_visibility == Visibility::KZGCommit
|
||||
|| settings.run_args.param_visibility == Visibility::KZGCommit
|
||||
{
|
||||
let witness = GraphWitness::from_path(_witness.unwrap_or(DEFAULT_WITNESS.into()))?;
|
||||
let commitments = witness.get_polycommitments();
|
||||
let proof_first_bytes = get_proof_commitments::<
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(&commitments);
|
||||
|
||||
Some(proof_first_bytes.unwrap())
|
||||
} else {
|
||||
return Err(
|
||||
"Neither input or output data source is on-chain. Atleast one must be on chain.".into(),
|
||||
);
|
||||
}
|
||||
None
|
||||
};
|
||||
|
||||
let output = fix_da_sol(input_data, output_data, commitment_bytes)?;
|
||||
let mut f = File::create(_sol_code_path.clone())?;
|
||||
let _ = f.write(output.as_bytes());
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(_sol_code_path, "DataAttestation", 0).await?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(_abi_path)?, &abi)?;
|
||||
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
@@ -1488,7 +1581,7 @@ pub(crate) async fn deploy_da_evm(
|
||||
addr_path: PathBuf,
|
||||
runs: usize,
|
||||
private_key: Option<String>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
let contract_address = deploy_da_verifier_via_solidity(
|
||||
settings_path,
|
||||
data,
|
||||
@@ -1514,7 +1607,7 @@ pub(crate) async fn deploy_evm(
|
||||
runs: usize,
|
||||
private_key: Option<String>,
|
||||
contract_name: &str,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<String, EZKLError> {
|
||||
let contract_address = deploy_contract_via_solidity(
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
@@ -1531,6 +1624,32 @@ pub(crate) async fn deploy_evm(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
/// Encodes the calldata for the EVM verifier (both aggregated and single proof)
|
||||
pub(crate) fn encode_evm_calldata(
|
||||
proof_path: PathBuf,
|
||||
calldata_path: PathBuf,
|
||||
addr_vk: Option<H160Flag>,
|
||||
) -> Result<Vec<u8>, EZKLError> {
|
||||
let snark = Snark::load::<IPACommitmentScheme<G1Affine>>(&proof_path)?;
|
||||
|
||||
let flattened_instances = snark.instances.into_iter().flatten();
|
||||
|
||||
let encoded = halo2_solidity_verifier::encode_calldata(
|
||||
addr_vk
|
||||
.as_ref()
|
||||
.map(|x| alloy::primitives::Address::from(*x).0)
|
||||
.map(|x| x.0),
|
||||
&snark.proof,
|
||||
&flattened_instances.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
log::debug!("Encoded calldata: {:?}", encoded);
|
||||
|
||||
File::create(calldata_path)?.write_all(encoded.as_slice())?;
|
||||
|
||||
Ok(encoded)
|
||||
}
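A brief usage sketch of the new helper (not part of the diff; the file paths below are placeholders). Since the function is crate-internal, a caller inside ezkl might drive it like this:

    // Hypothetical caller, assuming `encode_evm_calldata` is in scope.
    use std::path::PathBuf;

    fn send_proof() -> Result<(), EZKLError> {
        // Writes the raw calldata to disk and also returns it.
        let calldata = encode_evm_calldata(
            PathBuf::from("proof.json"),     // proof produced by `prove`
            PathBuf::from("calldata.bytes"), // raw bytes land here
            None,                            // no separately deployed vk contract
        )?;
        // The bytes can be used verbatim as the `data` field of an
        // eth_call / eth_sendTransaction against the verifier contract.
        assert!(!calldata.is_empty());
        Ok(())
    }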

#[cfg(not(target_arch = "wasm32"))]
pub(crate) async fn verify_evm(
    proof_path: PathBuf,
@@ -1538,7 +1657,7 @@ pub(crate) async fn verify_evm(
    rpc_url: Option<String>,
    addr_da: Option<H160Flag>,
    addr_vk: Option<H160Flag>,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
    use crate::eth::verify_proof_with_data_attestation;

    let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
@@ -1580,7 +1699,7 @@ pub(crate) async fn create_evm_aggregate_verifier(
    circuit_settings: Vec<PathBuf>,
    logrows: u32,
    render_vk_seperately: bool,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
    let srs_path = get_srs_path(logrows, srs_path, Commitments::KZG);
    let params: ParamsKZG<Bn256> = load_srs_verifier::<KZGCommitmentScheme<Bn256>>(srs_path)?;

@@ -1637,7 +1756,7 @@ pub(crate) fn compile_circuit(
    model_path: PathBuf,
    compiled_circuit: PathBuf,
    settings_path: PathBuf,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
    let settings = GraphSettings::load(&settings_path)?;
    let circuit = GraphCircuit::from_settings(&settings, &model_path, CheckMode::UNSAFE)?;
    circuit.save(compiled_circuit)?;
@@ -1651,7 +1770,7 @@ pub(crate) fn setup(
    pk_path: PathBuf,
    witness: Option<PathBuf>,
    disable_selector_compression: bool,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
    // these aren't real values so the sanity checks are mostly meaningless

    let mut circuit = GraphCircuit::load(compiled_circuit)?;
@@ -1703,7 +1822,7 @@ pub(crate) async fn setup_test_evm_witness(
    rpc_url: Option<String>,
    input_source: TestDataSource,
    output_source: TestDataSource,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
    use crate::graph::TestOnChainData;

    let mut data = GraphData::from_path(data_path)?;
@@ -1738,7 +1857,7 @@ pub(crate) async fn test_update_account_calls(
    addr: H160Flag,
    data: PathBuf,
    rpc_url: Option<String>,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
    use crate::eth::update_account_calls;

    update_account_calls(addr.into(), data, rpc_url.as_deref()).await?;
@@ -1756,7 +1875,7 @@ pub(crate) fn prove(
    srs_path: Option<PathBuf>,
    proof_type: ProofType,
    check_mode: CheckMode,
) -> Result<Snark<Fr, G1Affine>, Box<dyn Error>> {
) -> Result<Snark<Fr, G1Affine>, EZKLError> {
    let data = GraphWitness::from_path(data_path)?;
    let mut circuit = GraphCircuit::load(compiled_circuit_path)?;

@@ -1908,15 +2027,11 @@ pub(crate) fn prove(
pub(crate) fn swap_proof_commitments_cmd(
    proof_path: PathBuf,
    witness: PathBuf,
) -> Result<Snark<Fr, G1Affine>, Box<dyn Error>> {
) -> Result<Snark<Fr, G1Affine>, EZKLError> {
    let snark = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
    let witness = GraphWitness::from_path(witness)?;
    let commitments = witness.get_polycommitments();

    if commitments.is_empty() {
        log::warn!("no commitments found in witness");
    }

    let snark_new = swap_proof_commitments_polycommit(&snark, &commitments)?;

    if snark_new.proof != *snark.proof {
@@ -1931,7 +2046,7 @@ pub(crate) fn mock_aggregate(
    aggregation_snarks: Vec<PathBuf>,
    logrows: u32,
    split_proofs: bool,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
    let mut snarks = vec![];
    for proof_path in aggregation_snarks.iter() {
        match Snark::load::<KZGCommitmentScheme<Bn256>>(proof_path) {
@@ -1958,10 +2073,8 @@ pub(crate) fn mock_aggregate(
    let circuit = AggregationCircuit::new(&G1Affine::generator().into(), snarks, split_proofs)?;

    let prover = halo2_proofs::dev::MockProver::run(logrows, &circuit, vec![circuit.instances()])
        .map_err(Box::<dyn Error>::from)?;
    prover
        .verify()
        .map_err(|e| Box::<dyn Error>::from(ExecutionError::VerifyError(e)))?;
        .map_err(|e| ExecutionError::MockProverError(e.to_string()))?;
    prover.verify().map_err(ExecutionError::VerifyError)?;
    #[cfg(not(target_arch = "wasm32"))]
    pb.finish_with_message("Done.");
    Ok(String::new())
@@ -1976,7 +2089,7 @@ pub(crate) fn setup_aggregate(
    split_proofs: bool,
    disable_selector_compression: bool,
    commitment: Commitments,
) -> Result<String, Box<dyn Error>> {
) -> Result<String, EZKLError> {
    let mut snarks = vec![];
    for proof_path in sample_snarks.iter() {
        match Snark::load::<KZGCommitmentScheme<Bn256>>(proof_path) {
@@ -2039,7 +2152,7 @@ pub(crate) fn aggregate(
    check_mode: CheckMode,
    split_proofs: bool,
    commitment: Commitments,
) -> Result<Snark<Fr, G1Affine>, Box<dyn Error>> {
) -> Result<Snark<Fr, G1Affine>, EZKLError> {
    let mut snarks = vec![];
    for proof_path in aggregation_snarks.iter() {
        match Snark::load::<KZGCommitmentScheme<Bn256>>(proof_path) {
@@ -2219,7 +2332,7 @@ pub(crate) fn verify(
    vk_path: PathBuf,
    srs_path: Option<PathBuf>,
    reduced_srs: bool,
) -> Result<bool, Box<dyn Error>> {
) -> Result<bool, EZKLError> {
    let circuit_settings = GraphSettings::load(&settings_path)?;

    let logrows = circuit_settings.run_args.logrows;
@@ -2313,7 +2426,7 @@ fn verify_commitment<
    vk_path: PathBuf,
    params: &'a Scheme::ParamsVerifier,
    logrows: u32,
) -> Result<bool, Box<dyn Error>>
) -> Result<bool, EZKLError>
where
    Scheme::Scalar: FromUniformBytes<64>
        + SerdeObject
@@ -2349,7 +2462,7 @@ pub(crate) fn verify_aggr(
    logrows: u32,
    reduced_srs: bool,
    commitment: Commitments,
) -> Result<bool, Box<dyn Error>> {
) -> Result<bool, EZKLError> {
    match commitment {
        Commitments::KZG => {
            let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
@@ -2424,11 +2537,11 @@ pub(crate) fn load_params_verifier<Scheme: CommitmentScheme>(
    srs_path: Option<PathBuf>,
    logrows: u32,
    commitment: Commitments,
) -> Result<Scheme::ParamsVerifier, Box<dyn Error>> {
) -> Result<Scheme::ParamsVerifier, EZKLError> {
    let srs_path = get_srs_path(logrows, srs_path, commitment);
    let mut params = load_srs_verifier::<Scheme>(srs_path)?;
    info!("downsizing params to {} logrows", logrows);
    if logrows < params.k() {
        info!("downsizing params to {} logrows", logrows);
        params.downsize(logrows);
    }
    Ok(params)
@@ -2439,11 +2552,11 @@ pub(crate) fn load_params_prover<Scheme: CommitmentScheme>(
    srs_path: Option<PathBuf>,
    logrows: u32,
    commitment: Commitments,
) -> Result<Scheme::ParamsProver, Box<dyn Error>> {
) -> Result<Scheme::ParamsProver, EZKLError> {
    let srs_path = get_srs_path(logrows, srs_path, commitment);
    let mut params = load_srs_prover::<Scheme>(srs_path)?;
    info!("downsizing params to {} logrows", logrows);
    if logrows < params.k() {
        info!("downsizing params to {} logrows", logrows);
        params.downsize(logrows);
    }
    Ok(params)
137
src/graph/errors.rs
Normal file
@@ -0,0 +1,137 @@
use std::convert::Infallible;

use thiserror::Error;

/// circuit related errors.
#[derive(Debug, Error)]
pub enum GraphError {
    /// The wrong inputs were passed to a lookup node
    #[error("invalid inputs for a lookup node")]
    InvalidLookupInputs,
    /// Shape mismatch in circuit construction
    #[error("invalid dimensions used for node {0} ({1})")]
    InvalidDims(usize, String),
    /// Wrong method was called to configure an op
    #[error("wrong method was called to configure node {0} ({1})")]
    WrongMethod(usize, String),
    /// A requested node is missing in the graph
    #[error("a requested node is missing in the graph: {0}")]
    MissingNode(usize),
    /// The wrong method was called on an operation
    #[error("an unsupported method was called on node {0} ({1})")]
    OpMismatch(usize, String),
    /// This operation is unsupported
    #[error("unsupported datatype in graph node {0} ({1})")]
    UnsupportedDataType(usize, String),
    /// A node has missing parameters
    #[error("a node is missing required params: {0}")]
    MissingParams(String),
    /// A node has missing parameters
    #[error("a node is has misformed params: {0}")]
    MisformedParams(String),
    /// Error in the configuration of the visibility of variables
    #[error("there should be at least one set of public variables")]
    Visibility,
    /// Ezkl only supports divisions by constants
    #[error("ezkl currently only supports division by constants")]
    NonConstantDiv,
    /// Ezkl only supports constant powers
    #[error("ezkl currently only supports constant exponents")]
    NonConstantPower,
    /// Error when attempting to rescale an operation
    #[error("failed to rescale inputs for {0}")]
    RescalingError(String),
    /// Error when attempting to load a model from a file
    #[error("failed to load model")]
    ModelLoad(#[from] std::io::Error),
    /// Model serialization error
    #[error("failed to ser/deser model: {0}")]
    ModelSerialize(#[from] bincode::Error),
    /// Tract error
    #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
    #[error("[tract] {0}")]
    TractError(#[from] tract_onnx::tract_core::anyhow::Error),
    /// Packing exponent is too large
    #[error("largest packing exponent exceeds max. try reducing the scale")]
    PackingExponent,
    /// Invalid Input Types
    #[error("invalid input types")]
    InvalidInputTypes,
    /// Missing results
    #[error("missing results")]
    MissingResults,
    /// Tensor error
    #[error("[tensor] {0}")]
    TensorError(#[from] crate::tensor::TensorError),
    /// Public visibility for params is deprecated
    #[error("public visibility for params is deprecated, please use `fixed` instead")]
    ParamsPublicVisibility,
    /// Slice length mismatch
    #[error("slice length mismatch: {0}")]
    SliceLengthMismatch(#[from] std::array::TryFromSliceError),
    /// Bad conversion
    #[error("invalid conversion: {0}")]
    InvalidConversion(#[from] Infallible),
    /// Circuit error
    #[error("[circuit] {0}")]
    CircuitError(#[from] crate::circuit::CircuitError),
    /// Halo2 error
    #[error("[halo2] {0}")]
    Halo2Error(#[from] halo2_proofs::plonk::Error),
    /// System time error
    #[error("[system time] {0}")]
    SystemTimeError(#[from] std::time::SystemTimeError),
    /// Missing Batch Size
    #[error("unknown dimension batch_size in model inputs, set batch_size in variables")]
    MissingBatchSize,
    /// Tokio postgres error
    #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
    #[error("[tokio postgres] {0}")]
    TokioPostgresError(#[from] tokio_postgres::Error),
    /// Eth error
    #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
    #[error("[eth] {0}")]
    EthError(#[from] crate::eth::EthError),
    /// Json error
    #[error("[json] {0}")]
    JsonError(#[from] serde_json::Error),
    /// Missing instances
    #[error("missing instances")]
    MissingInstances,
    /// Missing constants
    #[error("missing constants")]
    MissingConstants,
    /// Missing input for a node
    #[error("missing input for node {0}")]
    MissingInput(usize),
    ///
    #[error("range only supports constant inputs in a zk circuit")]
    NonConstantRange,
    ///
    #[error("trilu only supports constant diagonals in a zk circuit")]
    NonConstantTrilu,
    ///
    #[error("insufficient witness values to generate a fixed output")]
    InsufficientWitnessValues,
    /// Missing scale
    #[error("missing scale")]
    MissingScale,
    /// Extended k is too large
    #[error("extended k is too large to accommodate the quotient polynomial with logrows {0}")]
    ExtendedKTooLarge(u32),
    /// Max lookup input is too large
    #[error("lookup range {0} is too large")]
    LookupRangeTooLarge(usize),
    /// Max range check input is too large
    #[error("range check {0} is too large")]
    RangeCheckTooLarge(usize),
    /// Cannot use on-chain data source as private data
    #[error("cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm.")]
    OnChainDataSource,
    /// Missing data source
    #[error("missing data source")]
    MissingDataSource,
    /// Invalid RunArg
    #[error("invalid RunArgs: {0}")]
    InvalidRunArgs(String),
}
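The new file leans entirely on thiserror's `#[from]` attribute, which derives a `From` impl so that `?` converts source errors into `GraphError` automatically (and, presumably, `GraphError` is folded into the top-level `EZKLError` the same way). A minimal sketch of the pattern, with illustrative names that are not from the crate:

    use thiserror::Error;

    #[derive(Debug, Error)]
    pub enum DemoError {
        #[error("[io] {0}")]
        Io(#[from] std::io::Error),
        #[error("[json] {0}")]
        Json(#[from] serde_json::Error),
    }

    fn load(path: &str) -> Result<serde_json::Value, DemoError> {
        let file = std::fs::File::open(path)?; // io::Error -> DemoError via From
        Ok(serde_json::from_reader(file)?)     // serde_json::Error -> DemoError
    }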

@@ -1,5 +1,5 @@
use super::errors::GraphError;
use super::quantize_float;
use super::GraphError;
use crate::circuit::InputType;
use crate::fieldutils::i64_to_felt;
#[cfg(not(target_arch = "wasm32"))]
@@ -211,9 +211,7 @@ impl PostgresSource {
    }

    /// Fetch data from postgres
    pub async fn fetch(
        &self,
    ) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, Box<dyn std::error::Error>> {
    pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
        // clone to move into thread
        let user = self.user.clone();
        let host = self.host.clone();
@@ -247,9 +245,7 @@ impl PostgresSource {
    }

    /// Fetch data from postgres and format it as a FileSource
    pub async fn fetch_and_format_as_file(
        &self,
    ) -> Result<Vec<Vec<FileSourceInner>>, Box<dyn std::error::Error>> {
    pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
        Ok(self
            .fetch()
            .await?
@@ -279,7 +275,7 @@ impl OnChainSource {
    scales: Vec<crate::Scale>,
    mut shapes: Vec<Vec<usize>>,
    rpc: Option<&str>,
) -> Result<(Vec<Tensor<Fp>>, Self), Box<dyn std::error::Error>> {
) -> Result<(Vec<Tensor<Fp>>, Self), GraphError> {
    use crate::eth::{
        evm_quantize, read_on_chain_inputs, test_on_chain_data, DEFAULT_ANVIL_ENDPOINT,
    };
@@ -455,7 +451,7 @@ impl GraphData {
    &self,
    shapes: &[Vec<usize>],
    datum_types: &[tract_onnx::prelude::DatumType],
) -> Result<TVec<TValue>, Box<dyn std::error::Error>> {
) -> Result<TVec<TValue>, GraphError> {
    let mut inputs = TVec::new();
    match &self.input_data {
        DataSource::File(data) => {
@@ -470,10 +466,10 @@ impl GraphData {
            }
        }
        _ => {
            return Err(Box::new(GraphError::InvalidDims(
            return Err(GraphError::InvalidDims(
                0,
                "non file data cannot be split into batches".to_string(),
            )))
            ))
        }
    }
    Ok(inputs)
@@ -488,7 +484,7 @@ impl GraphData {
    }

    /// Load the model input from a file
    pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
    pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
        let reader = std::fs::File::open(path)?;
        let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, reader);
        let mut buf = String::new();
@@ -498,7 +494,7 @@ impl GraphData {
    }

    /// Save the model input to a file
    pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
    pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
        // buf writer
        let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
        serde_json::to_writer(writer, self)?;
@@ -509,7 +505,7 @@ impl GraphData {
    pub async fn split_into_batches(
        &self,
        input_shapes: Vec<Vec<usize>>,
    ) -> Result<Vec<Self>, Box<dyn std::error::Error>> {
    ) -> Result<Vec<Self>, GraphError> {
        // split input data into batches
        let mut batched_inputs = vec![];

@@ -522,10 +518,10 @@ impl GraphData {
        input_data: DataSource::OnChain(_),
        output_data: _,
    } => {
        return Err(Box::new(GraphError::InvalidDims(
        return Err(GraphError::InvalidDims(
            0,
            "on-chain data cannot be split into batches".to_string(),
        )))
        ))
    }
    #[cfg(not(target_arch = "wasm32"))]
    GraphData {
@@ -539,11 +535,11 @@ impl GraphData {
    let input_size = shape.clone().iter().product::<usize>();
    let input = &iterable[i];
    if input.len() % input_size != 0 {
        return Err(Box::new(GraphError::InvalidDims(
        return Err(GraphError::InvalidDims(
            0,
            "calibration data length must be evenly divisible by the original input_size"
                .to_string(),
        )));
        ));
    }
    let mut batches = vec![];
    for batch in input.chunks(input_size) {
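The batching rule enforced in the hunk above is plain divisibility: a flat calibration buffer must contain a whole number of inputs. A standalone illustration (the function name and error type are placeholders, not crate code):

    // One input occupies shape.iter().product() elements; the buffer is
    // carved into consecutive chunks of exactly that size.
    fn split_flat(buf: &[f32], shape: &[usize]) -> Result<Vec<&[f32]>, String> {
        let input_size: usize = shape.iter().product();
        if buf.len() % input_size != 0 {
            return Err("buffer length must be divisible by input_size".into());
        }
        Ok(buf.chunks(input_size).collect())
    }

For example, a buffer of 12 floats with shape [2, 3] (input_size 6) splits into two batches, while a buffer of 13 floats is rejected.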
187
src/graph/mod.rs
@@ -14,6 +14,9 @@ pub mod utilities;

/// Representations of a computational graph's variables.
pub mod vars;

/// errors for the graph
pub mod errors;

#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;
#[cfg(unix)]
@@ -24,6 +27,7 @@ pub use input::DataSource;
use itertools::Itertools;
use tosubcommand::ToFlags;

use self::errors::GraphError;
#[cfg(not(target_arch = "wasm32"))]
use self::input::OnChainSource;
use self::input::{FileSource, GraphData};
@@ -58,7 +62,6 @@ use pyo3::types::PyDict;
use pyo3::ToPyObject;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
use thiserror::Error;
pub use utilities::*;
pub use vars::*;

@@ -88,62 +91,6 @@ lazy_static! {
#[cfg(target_arch = "wasm32")]
const EZKL_MAX_CIRCUIT_AREA: Option<usize> = None;

/// circuit related errors.
#[derive(Debug, Error)]
pub enum GraphError {
    /// The wrong inputs were passed to a lookup node
    #[error("invalid inputs for a lookup node")]
    InvalidLookupInputs,
    /// Shape mismatch in circuit construction
    #[error("invalid dimensions used for node {0} ({1})")]
    InvalidDims(usize, String),
    /// Wrong method was called to configure an op
    #[error("wrong method was called to configure node {0} ({1})")]
    WrongMethod(usize, String),
    /// A requested node is missing in the graph
    #[error("a requested node is missing in the graph: {0}")]
    MissingNode(usize),
    /// The wrong method was called on an operation
    #[error("an unsupported method was called on node {0} ({1})")]
    OpMismatch(usize, String),
    /// This operation is unsupported
    #[error("unsupported operation in graph")]
    UnsupportedOp,
    /// This operation is unsupported
    #[error("unsupported datatype in graph")]
    UnsupportedDataType,
    /// A node has missing parameters
    #[error("a node is missing required params: {0}")]
    MissingParams(String),
    /// A node has missing parameters
    #[error("a node is has misformed params: {0}")]
    MisformedParams(String),
    /// Error in the configuration of the visibility of variables
    #[error("there should be at least one set of public variables")]
    Visibility,
    /// Ezkl only supports divisions by constants
    #[error("ezkl currently only supports division by constants")]
    NonConstantDiv,
    /// Ezkl only supports constant powers
    #[error("ezkl currently only supports constant exponents")]
    NonConstantPower,
    /// Error when attempting to rescale an operation
    #[error("failed to rescale inputs for {0}")]
    RescalingError(String),
    /// Error when attempting to load a model
    #[error("failed to load")]
    ModelLoad,
    /// Packing exponent is too large
    #[error("largest packing exponent exceeds max. try reducing the scale")]
    PackingExponent,
    /// Invalid Input Types
    #[error("invalid input types")]
    InvalidInputTypes,
    /// Missing results
    #[error("missing results")]
    MissingResults,
}

///
pub const ASSUMED_BLINDING_FACTORS: usize = 5;
/// The minimum number of rows in the grid
@@ -310,27 +257,24 @@ impl GraphWitness {
    }

    /// Export the ezkl witness as json
    pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
    pub fn as_json(&self) -> Result<String, GraphError> {
        let serialized = match serde_json::to_string(&self) {
            Ok(s) => s,
            Err(e) => {
                return Err(Box::new(e));
            }
            Err(e) => return Err(e.into()),
        };
        Ok(serialized)
    }

    /// Load the model input from a file
    pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
        let file = std::fs::File::open(path.clone())
            .map_err(|_| format!("failed to load {}", path.display()))?;
    pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
        let file = std::fs::File::open(path.clone())?;

        let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
        serde_json::from_reader(reader).map_err(|e| e.into())
    }

    /// Save the model input to a file
    pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
    pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
        // use buf writer
        let writer =
            std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
@@ -595,11 +539,11 @@ impl GraphSettings {
    }

    /// Export the ezkl configuration as json
    pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
    pub fn as_json(&self) -> Result<String, GraphError> {
        let serialized = match serde_json::to_string(&self) {
            Ok(s) => s,
            Err(e) => {
                return Err(Box::new(e));
                return Err(e.into());
            }
        };
        Ok(serialized)
@@ -695,7 +639,7 @@ impl GraphCircuit {
        &self.core.model
    }
    ///
    pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
    pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
        let f = std::fs::File::create(path)?;
        let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
        bincode::serialize_into(writer, &self)?;
@@ -703,7 +647,7 @@ impl GraphCircuit {
    }

    ///
    pub fn load(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
    pub fn load(path: std::path::PathBuf) -> Result<Self, GraphError> {
        // read bytes from file
        let f = std::fs::File::open(path)?;
        let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
@@ -770,10 +714,7 @@ pub struct TestOnChainData {

impl GraphCircuit {
    ///
    pub fn new(
        model: Model,
        run_args: &RunArgs,
    ) -> Result<GraphCircuit, Box<dyn std::error::Error>> {
    pub fn new(model: Model, run_args: &RunArgs) -> Result<GraphCircuit, GraphError> {
        // // placeholder dummy inputs - must call prepare_public_inputs to load data afterwards
        let mut inputs: Vec<Vec<Fp>> = vec![];
        for shape in model.graph.input_shapes()? {
@@ -820,7 +761,7 @@ impl GraphCircuit {
        model: Model,
        mut settings: GraphSettings,
        check_mode: CheckMode,
    ) -> Result<GraphCircuit, Box<dyn std::error::Error>> {
    ) -> Result<GraphCircuit, GraphError> {
        // placeholder dummy inputs - must call prepare_public_inputs to load data afterwards
        let mut inputs: Vec<Vec<Fp>> = vec![];
        for shape in model.graph.input_shapes()? {
@@ -844,20 +785,14 @@ impl GraphCircuit {
    }

    /// load inputs and outputs for the model
    pub fn load_graph_witness(
        &mut self,
        data: &GraphWitness,
    ) -> Result<(), Box<dyn std::error::Error>> {
    pub fn load_graph_witness(&mut self, data: &GraphWitness) -> Result<(), GraphError> {
        self.graph_witness = data.clone();
        // load the module settings
        Ok(())
    }

    /// Prepare the public inputs for the circuit.
    pub fn prepare_public_inputs(
        &self,
        data: &GraphWitness,
    ) -> Result<Vec<Fp>, Box<dyn std::error::Error>> {
    pub fn prepare_public_inputs(&self, data: &GraphWitness) -> Result<Vec<Fp>, GraphError> {
        // the ordering here is important, we want the inputs to come before the outputs
        // as they are configured in that order as Column<Instances>
        let mut public_inputs: Vec<Fp> = vec![];
@@ -890,7 +825,7 @@ impl GraphCircuit {
    pub fn pretty_public_inputs(
        &self,
        data: &GraphWitness,
    ) -> Result<Option<PrettyElements>, Box<dyn std::error::Error>> {
    ) -> Result<Option<PrettyElements>, GraphError> {
        // dequantize the supplied data using the provided scale.
        // the ordering here is important, we want the inputs to come before the outputs
        // as they are configured in that order as Column<Instances>
@@ -932,10 +867,7 @@ impl GraphCircuit {

    ///
    #[cfg(target_arch = "wasm32")]
    pub fn load_graph_input(
        &mut self,
        data: &GraphData,
    ) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
    pub fn load_graph_input(&mut self, data: &GraphData) -> Result<Vec<Tensor<Fp>>, GraphError> {
        let shapes = self.model().graph.input_shapes()?;
        let scales = self.model().graph.get_input_scales();
        let input_types = self.model().graph.get_input_types()?;
@@ -946,7 +878,7 @@ impl GraphCircuit {
    pub fn load_graph_from_file_exclusively(
        &mut self,
        data: &GraphData,
    ) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
    ) -> Result<Vec<Tensor<Fp>>, GraphError> {
        let shapes = self.model().graph.input_shapes()?;
        let scales = self.model().graph.get_input_scales();
        let input_types = self.model().graph.get_input_types()?;
@@ -956,7 +888,7 @@ impl GraphCircuit {
            DataSource::File(file_data) => {
                self.load_file_data(file_data, &shapes, scales, input_types)
            }
            _ => Err("Cannot use non-file data source as input for this method.".into()),
            _ => unreachable!("cannot load from on-chain data"),
        }
    }

@@ -965,7 +897,7 @@ impl GraphCircuit {
    pub async fn load_graph_input(
        &mut self,
        data: &GraphData,
    ) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
    ) -> Result<Vec<Tensor<Fp>>, GraphError> {
        let shapes = self.model().graph.input_shapes()?;
        let scales = self.model().graph.get_input_scales();
        let input_types = self.model().graph.get_input_types()?;
@@ -983,14 +915,12 @@ impl GraphCircuit {
        shapes: Vec<Vec<usize>>,
        scales: Vec<crate::Scale>,
        input_types: Vec<InputType>,
    ) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
    ) -> Result<Vec<Tensor<Fp>>, GraphError> {
        match &data {
            DataSource::File(file_data) => {
                self.load_file_data(file_data, &shapes, scales, input_types)
            }
            DataSource::OnChain(_) => {
                Err("Cannot use on-chain data source as input for this method.".into())
            }
            DataSource::OnChain(_) => Err(GraphError::OnChainDataSource),
        }
    }

@@ -1002,7 +932,7 @@ impl GraphCircuit {
        shapes: Vec<Vec<usize>>,
        scales: Vec<crate::Scale>,
        input_types: Vec<InputType>,
    ) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
    ) -> Result<Vec<Tensor<Fp>>, GraphError> {
        match &data {
            DataSource::OnChain(source) => {
                let mut per_item_scale = vec![];
@@ -1030,7 +960,7 @@ impl GraphCircuit {
        source: OnChainSource,
        shapes: &Vec<Vec<usize>>,
        scales: Vec<crate::Scale>,
    ) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
    ) -> Result<Vec<Tensor<Fp>>, GraphError> {
        use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend};
        let (client, client_address) = setup_eth_backend(Some(&source.rpc), None).await?;
        let inputs = read_on_chain_inputs(client.clone(), client_address, &source.calls).await?;
@@ -1054,7 +984,7 @@ impl GraphCircuit {
        shapes: &Vec<Vec<usize>>,
        scales: Vec<crate::Scale>,
        input_types: Vec<InputType>,
    ) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
    ) -> Result<Vec<Tensor<Fp>>, GraphError> {
        // quantize the supplied data using the provided scale.
        let mut data: Vec<Tensor<Fp>> = vec![];
        for (((d, shape), scale), input_type) in file_data
@@ -1085,7 +1015,7 @@ impl GraphCircuit {
        &mut self,
        file_data: &[Vec<Fp>],
        shapes: &[Vec<usize>],
    ) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
    ) -> Result<Vec<Tensor<Fp>>, GraphError> {
        // quantize the supplied data using the provided scale.
        let mut data: Vec<Tensor<Fp>> = vec![];
        for (d, shape) in file_data.iter().zip(shapes) {
@@ -1112,7 +1042,7 @@ impl GraphCircuit {
        &self,
        safe_lookup_range: Range,
        max_range_size: i64,
    ) -> Result<u32, Box<dyn std::error::Error>> {
    ) -> Result<u32, GraphError> {
        // pick the range with the largest absolute size safe_lookup_range or max_range_size
        let safe_range = std::cmp::max(
            (safe_lookup_range.1 - safe_lookup_range.0).abs(),
@@ -1133,7 +1063,7 @@ impl GraphCircuit {
        max_range_size: i64,
        max_logrows: Option<u32>,
        lookup_safety_margin: i64,
    ) -> Result<(), Box<dyn std::error::Error>> {
    ) -> Result<(), GraphError> {
        // load the max logrows
        let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS);
        let max_logrows = std::cmp::min(max_logrows, MAX_PUBLIC_SRS);
@@ -1142,15 +1072,18 @@ impl GraphCircuit {

    let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);

    let lookup_size = (safe_lookup_range.1 - safe_lookup_range.0).abs();
    // check if has overflowed max lookup input
    if (min_max_lookup.1 - min_max_lookup.0).abs() > MAX_LOOKUP_ABS / lookup_safety_margin {
        let err_string = format!("max lookup input {:?} is too large", min_max_lookup);
        return Err(err_string.into());
    if lookup_size > MAX_LOOKUP_ABS / lookup_safety_margin {
        return Err(GraphError::LookupRangeTooLarge(
            lookup_size.unsigned_abs() as usize
        ));
    }

    if max_range_size.abs() > MAX_LOOKUP_ABS {
        let err_string = format!("max range check size {:?} is too large", max_range_size);
        return Err(err_string.into());
        return Err(GraphError::RangeCheckTooLarge(
            max_range_size.unsigned_abs() as usize,
        ));
    }

    // These are hard lower limits, we can't overflow instances or modules constraints
@@ -1194,12 +1127,7 @@ impl GraphCircuit {
    }

    if !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size) {
        let err_string = format!(
            "extended k is too large to accommodate the quotient polynomial with logrows {}",
            max_logrows
        );
        debug!("{}", err_string);
        return Err(err_string.into());
        return Err(GraphError::ExtendedKTooLarge(max_logrows));
    }

    let logrows = max_logrows;
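Both of the sizing checks above compare an absolute interval size against MAX_LOOKUP_ABS. A worked sketch of the rule with made-up numbers (the real value of MAX_LOOKUP_ABS is defined elsewhere in the crate):

    fn lookup_fits(range: (i64, i64), safety_margin: i64, max_abs: i64) -> bool {
        // e.g. the range (-3, 5) has absolute size 8
        let size = (range.1 - range.0).abs();
        size <= max_abs / safety_margin
    }

    fn main() {
        assert!(lookup_fits((-3, 5), 2, 32));  // 8 <= 32 / 2
        assert!(!lookup_fits((-3, 5), 8, 32)); // 8 > 32 / 8
    }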
@@ -1286,7 +1214,7 @@ impl GraphCircuit {
    srs: Option<&Scheme::ParamsProver>,
    witness_gen: bool,
    check_lookup: bool,
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
) -> Result<GraphWitness, GraphError> {
    let original_inputs = inputs.to_vec();

    let visibility = VarVisibility::from_args(&self.settings().run_args)?;
@@ -1401,7 +1329,7 @@ impl GraphCircuit {
    pub fn from_run_args(
        run_args: &RunArgs,
        model_path: &std::path::Path,
    ) -> Result<Self, Box<dyn std::error::Error>> {
    ) -> Result<Self, GraphError> {
        let model = Model::from_run_args(run_args, model_path)?;
        Self::new(model, run_args)
    }
@@ -1412,8 +1340,11 @@ impl GraphCircuit {
        params: &GraphSettings,
        model_path: &std::path::Path,
        check_mode: CheckMode,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        params.run_args.validate()?;
    ) -> Result<Self, GraphError> {
        params
            .run_args
            .validate()
            .map_err(GraphError::InvalidRunArgs)?;
        let model = Model::from_run_args(&params.run_args, model_path)?;
        Self::new_from_settings(model, params.clone(), check_mode)
    }
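The `map_err(GraphError::InvalidRunArgs)` line works because, as the hunk suggests, `validate()` returns a `Result<_, String>` and `InvalidRunArgs` carries a `String`: a tuple-variant constructor is itself a function and can be passed directly as the mapping closure. A reduced sketch under that assumption, with illustrative names:

    #[derive(Debug)]
    enum DemoError {
        InvalidRunArgs(String),
    }

    fn validate(logrows: u32) -> Result<(), String> {
        if logrows == 0 {
            return Err("logrows must be nonzero".to_string());
        }
        Ok(())
    }

    fn build(logrows: u32) -> Result<(), DemoError> {
        // `DemoError::InvalidRunArgs` acts as a fn(String) -> DemoError here.
        validate(logrows).map_err(DemoError::InvalidRunArgs)?;
        Ok(())
    }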
@@ -1424,7 +1355,7 @@ impl GraphCircuit {
    &mut self,
    data: &mut GraphData,
    test_on_chain_data: TestOnChainData,
) -> Result<(), Box<dyn std::error::Error>> {
) -> Result<(), GraphError> {
    // Set up local anvil instance for reading on-chain data

    let input_scales = self.model().graph.get_input_scales();
@@ -1438,15 +1369,13 @@ impl GraphCircuit {
    ) {
        // if not public then fail
        if self.settings().run_args.input_visibility.is_private() {
            return Err("Cannot use on-chain data source as private data".into());
            return Err(GraphError::OnChainDataSource);
        }

        let input_data = match &data.input_data {
            DataSource::File(input_data) => input_data,
            _ => {
                return Err("Cannot use non file source as input for on-chain test.
                Manually populate on-chain data from file source instead"
                    .into())
                return Err(GraphError::OnChainDataSource);
            }
        };
        // Get the flatten length of input_data
@@ -1467,19 +1396,13 @@ impl GraphCircuit {
    ) {
        // if not public then fail
        if self.settings().run_args.output_visibility.is_private() {
            return Err("Cannot use on-chain data source as private data".into());
            return Err(GraphError::OnChainDataSource);
        }

        let output_data = match &data.output_data {
            Some(DataSource::File(output_data)) => output_data,
            Some(DataSource::OnChain(_)) => {
                return Err(
                    "Cannot use on-chain data source as output for on-chain test.
                    Will manually populate on-chain data from file source instead"
                        .into(),
                )
            }
            _ => return Err("No output data found".into()),
            Some(DataSource::OnChain(_)) => return Err(GraphError::OnChainDataSource),
            _ => return Err(GraphError::MissingDataSource),
        };
        let datum: (Vec<Tensor<Fp>>, OnChainSource) = OnChainSource::test_from_file_data(
            output_data,
@@ -1522,12 +1445,10 @@ impl CircuitSize {

    #[cfg(not(target_arch = "wasm32"))]
    /// Export the ezkl configuration as json
    pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
    pub fn as_json(&self) -> Result<String, GraphError> {
        let serialized = match serde_json::to_string(&self) {
            Ok(s) => s,
            Err(e) => {
                return Err(Box::new(e));
            }
            Err(e) => return Err(e.into()),
        };
        Ok(serialized)
    }
@@ -1,8 +1,8 @@
use super::errors::GraphError;
use super::extract_const_quantized_values;
use super::node::*;
use super::scale_to_multiplier;
use super::vars::*;
use super::GraphError;
use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::ConstantsMap;
@@ -37,7 +37,6 @@ use std::collections::BTreeMap;
#[cfg(not(target_arch = "wasm32"))]
use std::collections::HashMap;
use std::collections::HashSet;
use std::error::Error;
use std::fs;
use std::io::Read;
use std::path::PathBuf;
@@ -396,7 +395,7 @@ impl ParsedNodes {
    }

    /// Returns shapes of the computational graph's inputs
    pub fn input_shapes(&self) -> Result<Vec<Vec<usize>>, Box<dyn Error>> {
    pub fn input_shapes(&self) -> Result<Vec<Vec<usize>>, GraphError> {
        let mut inputs = vec![];

        for input in self.inputs.iter() {
@@ -470,7 +469,7 @@ impl Model {
    /// * `reader` - A reader for an Onnx file.
    /// * `run_args` - [RunArgs]
    #[cfg(not(target_arch = "wasm32"))]
    pub fn new(reader: &mut dyn std::io::Read, run_args: &RunArgs) -> Result<Self, Box<dyn Error>> {
    pub fn new(reader: &mut dyn std::io::Read, run_args: &RunArgs) -> Result<Self, GraphError> {
        let visibility = VarVisibility::from_args(run_args)?;

        let graph = Self::load_onnx_model(reader, run_args, &visibility)?;
@@ -483,7 +482,7 @@ impl Model {
    }

    ///
    pub fn save(&self, path: PathBuf) -> Result<(), Box<dyn Error>> {
    pub fn save(&self, path: PathBuf) -> Result<(), GraphError> {
        let f = std::fs::File::create(path)?;
        let writer = std::io::BufWriter::new(f);
        bincode::serialize_into(writer, &self)?;
@@ -491,7 +490,7 @@ impl Model {
    }

    ///
    pub fn load(path: PathBuf) -> Result<Self, Box<dyn Error>> {
    pub fn load(path: PathBuf) -> Result<Self, GraphError> {
        // read bytes from file
        let mut f = std::fs::File::open(&path)?;
        let metadata = fs::metadata(&path)?;
@@ -506,7 +505,7 @@ impl Model {
        &self,
        run_args: &RunArgs,
        check_mode: CheckMode,
    ) -> Result<GraphSettings, Box<dyn Error>> {
    ) -> Result<GraphSettings, GraphError> {
        let instance_shapes = self.instance_shapes()?;
        #[cfg(not(target_arch = "wasm32"))]
        debug!(
@@ -536,7 +535,7 @@ impl Model {
        t.reshape(shape)?;
        Ok(t)
    })
    .collect::<Result<Vec<_>, Box<dyn Error>>>()?;
    .collect::<Result<Vec<_>, GraphError>>()?;

    let res = self.dummy_layout(run_args, &inputs, false, false)?;

@@ -583,7 +582,7 @@ impl Model {
    run_args: &RunArgs,
    witness_gen: bool,
    check_lookup: bool,
) -> Result<ForwardResult, Box<dyn Error>> {
) -> Result<ForwardResult, GraphError> {
    let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
        .iter()
        .map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
@@ -601,15 +600,12 @@ impl Model {
    fn load_onnx_using_tract(
        reader: &mut dyn std::io::Read,
        run_args: &RunArgs,
    ) -> Result<TractResult, Box<dyn Error>> {
    ) -> Result<TractResult, GraphError> {
        use tract_onnx::{
            tract_core::internal::IntoArcTensor, tract_hir::internal::GenericFactoid,
        };

        let mut model = tract_onnx::onnx().model_for_read(reader).map_err(|e| {
            error!("Error loading model: {}", e);
            GraphError::ModelLoad
        })?;
        let mut model = tract_onnx::onnx().model_for_read(reader)?;

        let variables: std::collections::HashMap<String, usize> =
            std::collections::HashMap::from_iter(run_args.variables.clone());
@@ -622,7 +618,7 @@ impl Model {
    if matches!(x, GenericFactoid::Any) {
        let batch_size = match variables.get("batch_size") {
            Some(x) => x,
            None => return Err("Unknown dimension batch_size in model inputs, set batch_size in variables".into()),
            None => return Err(GraphError::MissingBatchSize),
        };
        fact.shape
            .set_dim(i, tract_onnx::prelude::TDim::Val(*batch_size as i64));
@@ -680,12 +676,12 @@ impl Model {
        reader: &mut dyn std::io::Read,
        run_args: &RunArgs,
        visibility: &VarVisibility,
    ) -> Result<ParsedNodes, Box<dyn Error>> {
    ) -> Result<ParsedNodes, GraphError> {
        let start_time = instant::Instant::now();

        let (model, symbol_values) = Self::load_onnx_using_tract(reader, run_args)?;

        let scales = VarScales::from_args(run_args)?;
        let scales = VarScales::from_args(run_args);
        let nodes = Self::nodes_from_graph(
            &model,
            run_args,
@@ -762,7 +758,7 @@ impl Model {
        symbol_values: &SymbolValues,
        override_input_scales: Option<Vec<crate::Scale>>,
        override_output_scales: Option<HashMap<usize, crate::Scale>>,
    ) -> Result<BTreeMap<usize, NodeType>, Box<dyn Error>> {
    ) -> Result<BTreeMap<usize, NodeType>, GraphError> {
        use crate::graph::node_output_shapes;

        let mut nodes = BTreeMap::<usize, NodeType>::new();
@@ -976,14 +972,11 @@ impl Model {
        model_path: &std::path::Path,
        data_chunks: &[GraphData],
        input_shapes: Vec<Vec<usize>>,
    ) -> Result<Vec<Vec<Tensor<f32>>>, Box<dyn Error>> {
    ) -> Result<Vec<Vec<Tensor<f32>>>, GraphError> {
        use tract_onnx::tract_core::internal::IntoArcTensor;

        let (model, _) = Model::load_onnx_using_tract(
            &mut std::fs::File::open(model_path)
                .map_err(|_| format!("failed to load {}", model_path.display()))?,
            run_args,
        )?;
        let (model, _) =
            Model::load_onnx_using_tract(&mut std::fs::File::open(model_path)?, run_args)?;

        let datum_types: Vec<DatumType> = model
            .input_outlets()?
@@ -1011,15 +1004,8 @@ impl Model {
    /// # Arguments
    /// * `params` - A [GraphSettings] struct holding parsed CLI arguments.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn from_run_args(
        run_args: &RunArgs,
        model: &std::path::Path,
    ) -> Result<Self, Box<dyn Error>> {
        Model::new(
            &mut std::fs::File::open(model)
                .map_err(|_| format!("failed to load {}", model.display()))?,
            run_args,
        )
    pub fn from_run_args(run_args: &RunArgs, model: &std::path::Path) -> Result<Self, GraphError> {
        Model::new(&mut std::fs::File::open(model)?, run_args)
    }

    /// Configures a model for the circuit
@@ -1031,7 +1017,7 @@ impl Model {
        meta: &mut ConstraintSystem<Fp>,
        vars: &ModelVars<Fp>,
        settings: &GraphSettings,
    ) -> Result<PolyConfig<Fp>, Box<dyn Error>> {
    ) -> Result<PolyConfig<Fp>, GraphError> {
        debug!("configuring model");

        let lookup_range = settings.run_args.lookup_range;
@@ -1093,7 +1079,7 @@ impl Model {
        vars: &mut ModelVars<Fp>,
        witnessed_outputs: &[ValTensor<Fp>],
        constants: &mut ConstantsMap<Fp>,
    ) -> Result<Vec<ValTensor<Fp>>, Box<dyn Error>> {
    ) -> Result<Vec<ValTensor<Fp>>, GraphError> {
        info!("model layout...");

        let start_time = instant::Instant::now();
@@ -1103,7 +1089,11 @@ impl Model {
    let input_shapes = self.graph.input_shapes()?;
    for (i, input_idx) in self.graph.inputs.iter().enumerate() {
        if self.visibility.input.is_public() {
            let instance = vars.instance.as_ref().ok_or("no instance")?.clone();
            let instance = vars
                .instance
                .as_ref()
                .ok_or(GraphError::MissingInstances)?
                .clone();
            results.insert(*input_idx, vec![instance]);
            vars.increment_instance_idx();
        } else {
@@ -1123,7 +1113,12 @@ impl Model {
    let outputs = layouter.assign_region(
        || "model",
        |region| {
            let mut thread_safe_region = RegionCtx::new_with_constants(region, 0, run_args.num_inner_cols, original_constants.clone());
            let mut thread_safe_region = RegionCtx::new_with_constants(
                region,
                0,
                run_args.num_inner_cols,
                original_constants.clone(),
            );
            // we need to do this as this loop is called multiple times
            vars.set_instance_idx(instance_idx);

@@ -1147,24 +1142,31 @@ impl Model {
    tolerance.scale = scale_to_multiplier(output_scales[i]).into();

    let comparators = if run_args.output_visibility == Visibility::Public {
        let res = vars.instance.as_ref().ok_or("no instance")?.clone();
        let res = vars
            .instance
            .as_ref()
            .ok_or(GraphError::MissingInstances)?
            .clone();
        vars.increment_instance_idx();
        res
    } else {
        // if witnessed_outputs is of len less than i error
        if witnessed_outputs.len() <= i {
            return Err("you provided insufficient witness values to generate a fixed output".into());
            return Err(GraphError::InsufficientWitnessValues);
        }
        witnessed_outputs[i].clone()
    };

    config.base.layout(
        &mut thread_safe_region,
        &[output.clone(), comparators],
        Box::new(HybridOp::RangeCheck(tolerance)),
    )
    config
        .base
        .layout(
            &mut thread_safe_region,
            &[output.clone(), comparators],
            Box::new(HybridOp::RangeCheck(tolerance)),
        )
        .map_err(|e| e.into())
    })
    .collect::<Result<Vec<_>,_>>();
    .collect::<Result<Vec<_>, GraphError>>();
    res.map_err(|e| {
        error!("{}", e);
        halo2_proofs::plonk::Error::Synthesis
@@ -1178,7 +1180,6 @@ impl Model {

    Ok(outputs)
    },

    )?;

    let duration = start_time.elapsed();
@@ -1192,7 +1193,7 @@ impl Model {
    config: &mut ModelConfig,
    region: &mut RegionCtx<Fp>,
    results: &mut BTreeMap<usize, Vec<ValTensor<Fp>>>,
) -> Result<Vec<ValTensor<Fp>>, Box<dyn Error>> {
) -> Result<Vec<ValTensor<Fp>>, GraphError> {
    // index over results to get original inputs
    let orig_inputs: BTreeMap<usize, _> = results
        .clone()
@@ -1237,7 +1238,10 @@ impl Model {
    let res = if node.is_constant() && node.num_uses() == 1 {
        log::debug!("node {} is a constant with 1 use", n.idx);
        let mut node = n.clone();
        let c = node.opkind.get_mutable_constant().ok_or("no constant")?;
        let c = node
            .opkind
            .get_mutable_constant()
            .ok_or(GraphError::MissingConstants)?;
        Some(c.quantized_values.clone().try_into()?)
    } else {
        config
@@ -1394,7 +1398,7 @@ impl Model {
    inputs: &[ValTensor<Fp>],
    witness_gen: bool,
    check_lookup: bool,
) -> Result<DummyPassRes, Box<dyn Error>> {
) -> Result<DummyPassRes, GraphError> {
    debug!("calculating num of constraints using dummy model layout...");

    let start_time = instant::Instant::now();
@@ -1549,7 +1553,7 @@ impl Model {
    }

    /// Shapes of the computational graph's public inputs (if any)
    pub fn instance_shapes(&self) -> Result<Vec<Vec<usize>>, Box<dyn Error>> {
    pub fn instance_shapes(&self) -> Result<Vec<Vec<usize>>, GraphError> {
        let mut instance_shapes = vec![];
        if self.visibility.input.is_public() {
            instance_shapes.extend(self.graph.input_shapes()?);

@@ -11,11 +11,12 @@ use halo2curves::bn256::{Fr as Fp, G1Affine};
use itertools::Itertools;
use serde::{Deserialize, Serialize};

use super::errors::GraphError;
use super::{VarVisibility, Visibility};

/// poseidon len to hash in tree
pub const POSEIDON_LEN_GRAPH: usize = 32;
/// Poseidon number of instancess
/// Poseidon number of instances
pub const POSEIDON_INSTANCES: usize = 1;

/// Poseidon module type
@@ -295,7 +296,7 @@ impl GraphModules {
    element_visibility: &Visibility,
    vk: Option<&VerifyingKey<G1Affine>>,
    srs: Option<&Scheme::ParamsProver>,
) -> Result<ModuleForwardResult, Box<dyn std::error::Error>> {
) -> Result<ModuleForwardResult, GraphError> {
    let mut poseidon_hash = None;
    let mut polycommit = None;


@@ -8,11 +8,14 @@ use super::Visibility;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::lookup::LookupOp;
use crate::circuit::poly::PolyOp;
use crate::circuit::CircuitError;
use crate::circuit::Constant;
use crate::circuit::Input;
use crate::circuit::Op;
use crate::circuit::Unknown;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::errors::GraphError;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::new_op_from_onnx;
use crate::tensor::TensorError;
use halo2curves::bn256::Fr as Fp;
@@ -22,7 +25,6 @@ use serde::Deserialize;
use serde::Serialize;
#[cfg(not(target_arch = "wasm32"))]
use std::collections::BTreeMap;
use std::error::Error;
#[cfg(not(target_arch = "wasm32"))]
use std::fmt;
#[cfg(not(target_arch = "wasm32"))]
@@ -65,7 +67,7 @@ impl Op<Fp> for Rescaled {
        format!("RESCALED INPUT ({})", self.inner.as_string())
    }

    fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
    fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
        let in_scales = in_scales
            .into_iter()
            .zip(self.scale.iter())
@@ -80,11 +82,9 @@ impl Op<Fp> for Rescaled {
        config: &mut crate::circuit::BaseConfig<Fp>,
        region: &mut crate::circuit::region::RegionCtx<Fp>,
        values: &[crate::tensor::ValTensor<Fp>],
    ) -> Result<Option<crate::tensor::ValTensor<Fp>>, Box<dyn Error>> {
    ) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
        if self.scale.len() != values.len() {
            return Err(Box::new(TensorError::DimMismatch(
                "rescaled inputs".to_string(),
            )));
            return Err(TensorError::DimMismatch("rescaled inputs".to_string()).into());
        }

        let res =
@@ -210,7 +210,7 @@ impl Op<Fp> for RebaseScale {
        )
    }

    fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
    fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
        Ok(self.target_scale)
    }
|
||||
|
||||
@@ -219,11 +219,11 @@ impl Op<Fp> for RebaseScale {
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
region: &mut crate::circuit::region::RegionCtx<Fp>,
|
||||
values: &[crate::tensor::ValTensor<Fp>],
|
||||
) -> Result<Option<crate::tensor::ValTensor<Fp>>, Box<dyn Error>> {
|
||||
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
|
||||
let original_res = self
|
||||
.inner
|
||||
.layout(config, region, values)?
|
||||
.ok_or("no inner layout")?;
|
||||
.ok_or(CircuitError::MissingLayout(self.as_string()))?;
|
||||
self.rebase_op.layout(config, region, &[original_res])
|
||||
}
|
||||
|
||||
@@ -306,7 +306,7 @@ impl SupportedOp {
|
||||
fn homogenous_rescale(
|
||||
&self,
|
||||
in_scales: Vec<crate::Scale>,
|
||||
) -> Result<Box<dyn Op<Fp>>, Box<dyn Error>> {
|
||||
) -> Result<Box<dyn Op<Fp>>, GraphError> {
|
||||
let inputs_to_scale = self.requires_homogenous_input_scales();
|
||||
// creates a rescaled op if the inputs are not homogenous
|
||||
let op = self.clone_dyn();
|
||||
@@ -372,7 +372,7 @@ impl Op<Fp> for SupportedOp {
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
region: &mut crate::circuit::region::RegionCtx<Fp>,
|
||||
values: &[crate::tensor::ValTensor<Fp>],
|
||||
) -> Result<Option<crate::tensor::ValTensor<Fp>>, Box<dyn Error>> {
|
||||
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
|
||||
self.as_op().layout(config, region, values)
|
||||
}
|
||||
|
||||
@@ -400,7 +400,7 @@ impl Op<Fp> for SupportedOp {
|
||||
self
|
||||
}
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
self.as_op().out_scale(in_scales)
|
||||
}
|
||||
}
|
||||
@@ -478,7 +478,7 @@ impl Node {
|
||||
symbol_values: &SymbolValues,
|
||||
div_rebasing: bool,
|
||||
rebase_frac_zero_constants: bool,
|
||||
) -> Result<Self, Box<dyn Error>> {
|
||||
) -> Result<Self, GraphError> {
|
||||
trace!("Create {:?}", node);
|
||||
trace!("Create op {:?}", node.op);
|
||||
|
||||
@@ -504,10 +504,15 @@ impl Node {
|
||||
input_ids
|
||||
.iter()
|
||||
.map(|(i, _)| {
|
||||
inputs.push(other_nodes.get(i).ok_or("input not found")?.clone());
|
||||
inputs.push(
|
||||
other_nodes
|
||||
.get(i)
|
||||
.ok_or(GraphError::MissingInput(idx))?
|
||||
.clone(),
|
||||
);
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
.collect::<Result<Vec<_>, GraphError>>()?;
|
||||
|
||||
let (mut opkind, deleted_indices) = new_op_from_onnx(
|
||||
idx,
|
||||
@@ -544,13 +549,13 @@ impl Node {
|
||||
let idx = inputs
|
||||
.iter()
|
||||
.position(|x| *idx == x.idx())
|
||||
.ok_or("input not found")?;
|
||||
.ok_or(GraphError::MissingInput(*idx))?;
|
||||
Ok(inputs[idx].out_scales()[*outlet])
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
.collect::<Result<Vec<_>, GraphError>>()?;
|
||||
|
||||
let homogenous_inputs = opkind.requires_homogenous_input_scales();
|
||||
// autoamtically increases a constant's scale if it is only used once and
|
||||
// automatically increases a constant's scale if it is only used once and
|
||||
for input in homogenous_inputs
|
||||
.into_iter()
|
||||
.filter(|i| !deleted_indices.contains(i))
|
||||
@@ -558,7 +563,7 @@ impl Node {
|
||||
if inputs.len() > input {
|
||||
let input_node = other_nodes
|
||||
.get_mut(&inputs[input].idx())
|
||||
.ok_or("input not found")?;
|
||||
.ok_or(GraphError::MissingInput(idx))?;
|
||||
let input_opkind = &mut input_node.opkind();
|
||||
if let Some(constant) = input_opkind.get_mutable_constant() {
|
||||
rescale_const_with_single_use(
|
||||
@@ -615,10 +620,10 @@ fn rescale_const_with_single_use(
|
||||
in_scales: Vec<crate::Scale>,
|
||||
param_visibility: &Visibility,
|
||||
num_uses: usize,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
) -> Result<(), GraphError> {
|
||||
if num_uses == 1 {
|
||||
let current_scale = constant.out_scale(vec![])?;
|
||||
let scale_max = in_scales.iter().max().ok_or("no scales")?;
|
||||
let scale_max = in_scales.iter().max().ok_or(GraphError::MissingScale)?;
|
||||
if scale_max > ¤t_scale {
|
||||
let raw_values = constant.raw_values.clone();
|
||||
constant.quantized_values =
|
||||
|
||||
@@ -1,5 +1,4 @@
#[cfg(not(target_arch = "wasm32"))]
use super::GraphError;
use super::errors::GraphError;
#[cfg(not(target_arch = "wasm32"))]
use super::VarScales;
use super::{Rescaled, SupportedOp, Visibility};
@@ -16,7 +15,6 @@ use halo2curves::ff::PrimeField;
use itertools::Itertools;
#[cfg(not(target_arch = "wasm32"))]
use log::{debug, warn};
use std::error::Error;
#[cfg(not(target_arch = "wasm32"))]
use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
@@ -92,7 +90,7 @@ pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
pub fn node_output_shapes(
node: &OnnxNode<TypedFact, Box<dyn TypedOp>>,
symbol_values: &SymbolValues,
) -> Result<Vec<Vec<usize>>, Box<dyn std::error::Error>> {
) -> Result<Vec<Vec<usize>>, GraphError> {
let mut shapes = Vec::new();
let outputs = node.outputs.to_vec();
for output in outputs {
@@ -109,7 +107,7 @@ use tract_onnx::prelude::SymbolValues;
/// Extracts the raw values from a tensor.
pub fn extract_tensor_value(
input: Arc<tract_onnx::prelude::Tensor>,
) -> Result<Tensor<f32>, Box<dyn std::error::Error>> {
) -> Result<Tensor<f32>, GraphError> {
use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};

let dt = input.datum_type();
@@ -194,20 +192,20 @@ pub fn extract_tensor_value(
// Generally a shape or hyperparam
let vec = input.as_slice::<tract_onnx::prelude::TDim>()?.to_vec();

let cast: Result<Vec<f32>, &str> = vec
let cast: Result<Vec<f32>, GraphError> = vec
.par_iter()
.map(|x| match x.to_i64() {
Ok(v) => Ok(v as f32),
Err(_) => match x.to_i64() {
Ok(v) => Ok(v as f32),
Err(_) => Err("could not evaluate tdim"),
Err(_) => Err(GraphError::UnsupportedDataType(0, "TDim".to_string())),
},
})
.collect();

const_value = Tensor::<f32>::new(Some(&cast?), &dims)?;
}
_ => return Err("unsupported data type".into()),
_ => return Err(GraphError::UnsupportedDataType(0, format!("{:?}", dt))),
}
const_value.reshape(&dims)?;

@@ -219,12 +217,12 @@ fn load_op<C: tract_onnx::prelude::Op + Clone>(
op: &dyn tract_onnx::prelude::Op,
idx: usize,
name: String,
) -> Result<C, Box<dyn std::error::Error>> {
) -> Result<C, GraphError> {
// Extract the slope layer hyperparams
let op: &C = match op.downcast_ref::<C>() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(idx, name)));
return Err(GraphError::OpMismatch(idx, name));
}
};

@@ -247,7 +245,7 @@ pub fn new_op_from_onnx(
inputs: &mut [super::NodeType],
symbol_values: &SymbolValues,
rebase_frac_zero_constants: bool,
) -> Result<(SupportedOp, Vec<usize>), Box<dyn std::error::Error>> {
) -> Result<(SupportedOp, Vec<usize>), GraphError> {
use tract_onnx::tract_core::ops::array::Trilu;

use crate::circuit::InputType;
@@ -260,7 +258,7 @@ pub fn new_op_from_onnx(
let mut replace_const = |scale: crate::Scale,
index: usize,
default_op: SupportedOp|
-> Result<SupportedOp, Box<dyn std::error::Error>> {
-> Result<SupportedOp, GraphError> {
let mut constant = inputs[index].opkind();
let constant = constant.get_mutable_constant();
if let Some(c) = constant {
@@ -285,19 +283,13 @@ pub fn new_op_from_onnx(
deleted_indices.push(1);
let raw_values = &c.raw_values;
if raw_values.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"shift left".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "shift left".to_string()));
}
SupportedOp::Linear(PolyOp::Identity {
out_scale: Some(input_scales[0] - raw_values[0] as i32),
})
} else {
return Err(Box::new(GraphError::OpMismatch(
idx,
"ShiftLeft".to_string(),
)));
return Err(GraphError::OpMismatch(idx, "ShiftLeft".to_string()));
}
}
"ShiftRight" => {
@@ -307,19 +299,13 @@ pub fn new_op_from_onnx(
deleted_indices.push(1);
let raw_values = &c.raw_values;
if raw_values.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"shift right".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "shift right".to_string()));
}
SupportedOp::Linear(PolyOp::Identity {
out_scale: Some(input_scales[0] + raw_values[0] as i32),
})
} else {
return Err(Box::new(GraphError::OpMismatch(
idx,
"ShiftRight".to_string(),
)));
return Err(GraphError::OpMismatch(idx, "ShiftRight".to_string()));
}
}
"MultiBroadcastTo" => {
@@ -337,7 +323,7 @@ pub fn new_op_from_onnx(

for (i, input) in inputs.iter_mut().enumerate() {
if !input.opkind().is_constant() {
return Err("Range only supports constant inputs in a zk circuit".into());
return Err(GraphError::NonConstantRange);
} else {
input.decrement_use();
deleted_indices.push(i);
@@ -348,7 +334,7 @@ pub fn new_op_from_onnx(
assert_eq!(input_ops.len(), 3, "Range requires 3 inputs");
let input_ops = input_ops
.iter()
.map(|x| x.get_constant().ok_or("Range requires constant inputs"))
.map(|x| x.get_constant().ok_or(GraphError::NonConstantRange))
.collect::<Result<Vec<_>, _>>()?;

let start = input_ops[0].raw_values.map(|x| x as usize)[0];
@@ -375,11 +361,11 @@ pub fn new_op_from_onnx(
deleted_indices.push(1);
let raw_values = &c.raw_values;
if raw_values.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "trilu".to_string())));
return Err(GraphError::InvalidDims(idx, "trilu".to_string()));
}
raw_values[0] as i32
} else {
return Err("we only support constant inputs for trilu diagonal".into());
return Err(GraphError::NonConstantTrilu);
};

SupportedOp::Linear(PolyOp::Trilu { upper, k: diagonal })
@@ -387,7 +373,7 @@ pub fn new_op_from_onnx(

"Gather" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(idx, "gather".to_string())));
return Err(GraphError::InvalidDims(idx, "gather".to_string()));
};
let op = load_op::<Gather>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;
@@ -456,10 +442,7 @@ pub fn new_op_from_onnx(
}
"ScatterElements" => {
if inputs.len() != 3 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"scatter elements".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "scatter elements".to_string()));
};
let op = load_op::<ScatterElements>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;
@@ -494,10 +477,7 @@ pub fn new_op_from_onnx(
}
"ScatterNd" => {
if inputs.len() != 3 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"scatter nd".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "scatter nd".to_string()));
};
// just verify it deserializes correctly
let _op = load_op::<ScatterNd>(node.op(), idx, node.op().name().to_string())?;
@@ -529,10 +509,7 @@ pub fn new_op_from_onnx(

"GatherNd" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"gather nd".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "gather nd".to_string()));
};
let op = load_op::<GatherNd>(node.op(), idx, node.op().name().to_string())?;
let batch_dims = op.batch_dims;
@@ -566,10 +543,7 @@ pub fn new_op_from_onnx(

"GatherElements" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"gather elements".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "gather elements".to_string()));
};
let op = load_op::<GatherElements>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;
@@ -615,10 +589,7 @@ pub fn new_op_from_onnx(
}

_ => {
return Err(Box::new(GraphError::OpMismatch(
idx,
"MoveAxis".to_string(),
)))
return Err(GraphError::OpMismatch(idx, "MoveAxis".to_string()));
}
}
}
@@ -654,7 +625,9 @@ pub fn new_op_from_onnx(
| DatumType::U32
| DatumType::U64 => 0,
DatumType::F16 | DatumType::F32 | DatumType::F64 => scales.params,
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
_ => {
return Err(GraphError::UnsupportedDataType(idx, format!("{:?}", dt)));
}
};

// if all raw_values are round then set scale to 0
@@ -672,7 +645,7 @@ pub fn new_op_from_onnx(
}
"Reduce<ArgMax(false)>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "argmax".to_string())));
return Err(GraphError::InvalidDims(idx, "argmax".to_string()));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes: Vec<usize> = op.axes.into_iter().collect();
@@ -682,7 +655,7 @@ pub fn new_op_from_onnx(
}
"Reduce<ArgMin(false)>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "argmin".to_string())));
return Err(GraphError::InvalidDims(idx, "argmin".to_string()));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes: Vec<usize> = op.axes.into_iter().collect();
@@ -692,7 +665,7 @@ pub fn new_op_from_onnx(
}
"Reduce<Min>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "min".to_string())));
return Err(GraphError::InvalidDims(idx, "min".to_string()));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes = op.axes.into_iter().collect();
@@ -701,7 +674,7 @@ pub fn new_op_from_onnx(
}
"Reduce<Max>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "max".to_string())));
return Err(GraphError::InvalidDims(idx, "max".to_string()));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes = op.axes.into_iter().collect();
@@ -710,7 +683,7 @@ pub fn new_op_from_onnx(
}
"Reduce<Prod>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "prod".to_string())));
return Err(GraphError::InvalidDims(idx, "prod".to_string()));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes: Vec<usize> = op.axes.into_iter().collect();
@@ -727,7 +700,7 @@ pub fn new_op_from_onnx(
}
"Reduce<Sum>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "sum".to_string())));
return Err(GraphError::InvalidDims(idx, "sum".to_string()));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes = op.axes.into_iter().collect();
@@ -736,10 +709,7 @@ pub fn new_op_from_onnx(
}
"Reduce<MeanOfSquares>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"mean of squares".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "mean of squares".to_string()));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes = op.axes.into_iter().collect();
@@ -759,7 +729,7 @@ pub fn new_op_from_onnx(
.collect::<Vec<_>>();

if const_inputs.len() != 1 {
return Err(Box::new(GraphError::OpMismatch(idx, "Max".to_string())));
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
}

let const_idx = const_inputs[0];
@@ -768,10 +738,10 @@ pub fn new_op_from_onnx(
if c.len() == 1 {
c[0]
} else {
return Err(Box::new(GraphError::InvalidDims(idx, "max".to_string())));
return Err(GraphError::InvalidDims(idx, "max".to_string()));
}
} else {
return Err(Box::new(GraphError::OpMismatch(idx, "Max".to_string())));
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
};

if inputs.len() == 2 {
@@ -790,7 +760,7 @@ pub fn new_op_from_onnx(
})
}
} else {
return Err(Box::new(GraphError::InvalidDims(idx, "max".to_string())));
return Err(GraphError::InvalidDims(idx, "max".to_string()));
}
}
"Min" => {
@@ -805,7 +775,7 @@ pub fn new_op_from_onnx(
.collect::<Vec<_>>();

if const_inputs.len() != 1 {
return Err(Box::new(GraphError::OpMismatch(idx, "Min".to_string())));
return Err(GraphError::OpMismatch(idx, "Min".to_string()));
}

let const_idx = const_inputs[0];
@@ -814,10 +784,10 @@ pub fn new_op_from_onnx(
if c.len() == 1 {
c[0]
} else {
return Err(Box::new(GraphError::InvalidDims(idx, "min".to_string())));
return Err(GraphError::InvalidDims(idx, "min".to_string()));
}
} else {
return Err(Box::new(GraphError::OpMismatch(idx, "Min".to_string())));
return Err(GraphError::OpMismatch(idx, "Min".to_string()));
};

if inputs.len() == 2 {
@@ -834,7 +804,7 @@ pub fn new_op_from_onnx(
a: crate::circuit::utils::F32(unit),
})
} else {
return Err(Box::new(GraphError::InvalidDims(idx, "min".to_string())));
return Err(GraphError::InvalidDims(idx, "min".to_string()));
}
}
"Recip" => {
@@ -855,10 +825,7 @@ pub fn new_op_from_onnx(
let leaky_op: &LeakyRelu = match leaky_op.0.downcast_ref::<LeakyRelu>() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(
idx,
"leaky relu".to_string(),
)));
return Err(GraphError::OpMismatch(idx, "leaky relu".to_string()));
}
};

@@ -867,7 +834,7 @@ pub fn new_op_from_onnx(
})
}
"Scan" => {
return Err("scan should never be analyzed explicitly".into());
unreachable!();
}
"QuantizeLinearU8" | "DequantizeLinearF32" => {
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
@@ -932,7 +899,9 @@ pub fn new_op_from_onnx(
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Source" => {
let (scale, datum_type) = match node.outputs[0].fact.datum_type {
let dt = node.outputs[0].fact.datum_type;

let (scale, datum_type) = match dt {
DatumType::Bool => (0, InputType::Bool),
DatumType::TDim => (0, InputType::TDim),
DatumType::I64
@@ -946,7 +915,7 @@ pub fn new_op_from_onnx(
DatumType::F16 => (scales.input, InputType::F16),
DatumType::F32 => (scales.input, InputType::F32),
DatumType::F64 => (scales.input, InputType::F64),
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
_ => return Err(GraphError::UnsupportedDataType(idx, format!("{:?}", dt))),
};
SupportedOp::Input(crate::circuit::ops::Input { scale, datum_type })
}
@@ -985,7 +954,7 @@ pub fn new_op_from_onnx(
DatumType::F16 | DatumType::F32 | DatumType::F64 => {
SupportedOp::Linear(PolyOp::Identity { out_scale: None })
}
_ => return Err(Box::new(GraphError::UnsupportedDataType)),
_ => return Err(GraphError::UnsupportedDataType(idx, format!("{:?}", dt))),
}
}
"Add" => SupportedOp::Linear(PolyOp::Add),
@@ -1001,7 +970,7 @@ pub fn new_op_from_onnx(
.collect::<Vec<_>>();

if const_idx.len() > 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "mul".to_string())));
return Err(GraphError::InvalidDims(idx, "mul".to_string()));
}

if const_idx.len() == 1 {
@@ -1027,17 +996,14 @@ pub fn new_op_from_onnx(
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::Less)
} else {
return Err(Box::new(GraphError::InvalidDims(idx, "less".to_string())));
return Err(GraphError::InvalidDims(idx, "less".to_string()));
}
}
"LessEqual" => {
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::LessEqual)
} else {
return Err(Box::new(GraphError::InvalidDims(
idx,
"less equal".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "less equal".to_string()));
}
}
"Greater" => {
@@ -1045,10 +1011,7 @@ pub fn new_op_from_onnx(
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::Greater)
} else {
return Err(Box::new(GraphError::InvalidDims(
idx,
"greater".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "greater".to_string()));
}
}
"GreaterEqual" => {
@@ -1056,10 +1019,7 @@ pub fn new_op_from_onnx(
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::GreaterEqual)
} else {
return Err(Box::new(GraphError::InvalidDims(
idx,
"greater equal".to_string(),
)));
return Err(GraphError::InvalidDims(idx, "greater equal".to_string()));
}
}
"EinSum" => {
@@ -1067,7 +1027,7 @@ pub fn new_op_from_onnx(
let op: &EinSum = match node.op().downcast_ref::<EinSum>() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(idx, "einsum".to_string())));
return Err(GraphError::OpMismatch(idx, "einsum".to_string()));
}
};

@@ -1081,7 +1041,7 @@ pub fn new_op_from_onnx(
let softmax_op: &Softmax = match node.op().downcast_ref::<Softmax>() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(idx, "softmax".to_string())));
return Err(GraphError::OpMismatch(idx, "softmax".to_string()));
}
};

@@ -1100,7 +1060,7 @@ pub fn new_op_from_onnx(
let sumpool_node: &MaxPool = match op.downcast_ref() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(idx, "Maxpool".to_string())));
return Err(GraphError::OpMismatch(idx, "Maxpool".to_string()));
}
};

@@ -1108,9 +1068,9 @@ pub fn new_op_from_onnx(

// only support pytorch type formatting for now
if pool_spec.data_format != DataFormat::NCHW {
return Err(Box::new(GraphError::MissingParams(
return Err(GraphError::MissingParams(
"data in wrong format".to_string(),
)));
));
}

let stride = pool_spec
@@ -1122,7 +1082,7 @@ pub fn new_op_from_onnx(
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
return Err(GraphError::MissingParams("padding".to_string()));
}
};
let kernel_shape = &pool_spec.kernel_shape;
@@ -1170,15 +1130,15 @@ pub fn new_op_from_onnx(
let conv_node: &Conv = match node.op().downcast_ref::<Conv>() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(idx, "conv".to_string())));
return Err(GraphError::OpMismatch(idx, "conv".to_string()));
}
};

if let Some(dilations) = &conv_node.pool_spec.dilations {
if dilations.iter().any(|x| *x != 1) {
return Err(Box::new(GraphError::MisformedParams(
return Err(GraphError::MisformedParams(
"non unit dilations not supported".to_string(),
)));
));
}
}

@@ -1186,15 +1146,15 @@ pub fn new_op_from_onnx(
&& (conv_node.pool_spec.data_format != DataFormat::CHW))
|| (conv_node.kernel_fmt != KernelFormat::OIHW)
{
return Err(Box::new(GraphError::MisformedParams(
return Err(GraphError::MisformedParams(
"data or kernel in wrong format".to_string(),
)));
));
}

let stride = match conv_node.pool_spec.strides.clone() {
Some(s) => s.to_vec(),
None => {
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
return Err(GraphError::MissingParams("strides".to_string()));
}
};

@@ -1203,7 +1163,7 @@ pub fn new_op_from_onnx(
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
return Err(GraphError::MissingParams("padding".to_string()));
}
};

@@ -1234,30 +1194,30 @@ pub fn new_op_from_onnx(
let deconv_node: &Deconv = match node.op().downcast_ref::<Deconv>() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(idx, "deconv".to_string())));
return Err(GraphError::OpMismatch(idx, "deconv".to_string()));
}
};

if let Some(dilations) = &deconv_node.pool_spec.dilations {
if dilations.iter().any(|x| *x != 1) {
return Err(Box::new(GraphError::MisformedParams(
return Err(GraphError::MisformedParams(
"non unit dilations not supported".to_string(),
)));
));
}
}

if (deconv_node.pool_spec.data_format != DataFormat::NCHW)
|| (deconv_node.kernel_format != KernelFormat::OIHW)
{
return Err(Box::new(GraphError::MisformedParams(
return Err(GraphError::MisformedParams(
"data or kernel in wrong format".to_string(),
)));
));
}

let stride = match deconv_node.pool_spec.strides.clone() {
Some(s) => s.to_vec(),
None => {
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
return Err(GraphError::MissingParams("strides".to_string()));
}
};
let padding = match &deconv_node.pool_spec.padding {
@@ -1265,7 +1225,7 @@ pub fn new_op_from_onnx(
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
return Err(GraphError::MissingParams("padding".to_string()));
}
};

@@ -1295,10 +1255,7 @@ pub fn new_op_from_onnx(
let downsample_node: Downsample = match node.op().downcast_ref::<Downsample>() {
Some(b) => b.clone(),
None => {
return Err(Box::new(GraphError::OpMismatch(
idx,
"downsample".to_string(),
)));
return Err(GraphError::OpMismatch(idx, "downsample".to_string()));
}
};

@@ -1323,7 +1280,7 @@ pub fn new_op_from_onnx(
}
// check if optional scale factor is present
if inputs.len() != 2 && inputs.len() != 3 {
return Err(Box::new(GraphError::OpMismatch(idx, "Resize".to_string())));
return Err(GraphError::OpMismatch(idx, "Resize".to_string()));
}

let scale_factor_node = // find optional_scales_input in the string and extract the value inside the Some
@@ -1337,7 +1294,7 @@ pub fn new_op_from_onnx(
.collect::<Vec<_>>()[1]
.split(')')
.collect::<Vec<_>>()[0]
.parse::<usize>()?)
.parse::<usize>().map_err(|_| GraphError::OpMismatch(idx, "Resize".to_string()))?)
};

let scale_factor = if let Some(scale_factor_node) = scale_factor_node {
@@ -1345,7 +1302,7 @@ pub fn new_op_from_onnx(
if let Some(c) = extract_const_raw_values(boxed_op) {
c.map(|x| x as usize).into_iter().collect::<Vec<usize>>()
} else {
return Err(Box::new(GraphError::OpMismatch(idx, "Resize".to_string())));
return Err(GraphError::OpMismatch(idx, "Resize".to_string()));
}
} else {
// default
@@ -1369,7 +1326,7 @@ pub fn new_op_from_onnx(
let sumpool_node: &SumPool = match op.downcast_ref() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(idx, "sumpool".to_string())));
return Err(GraphError::OpMismatch(idx, "sumpool".to_string()));
}
};

@@ -1377,9 +1334,9 @@ pub fn new_op_from_onnx(

// only support pytorch type formatting for now
if pool_spec.data_format != DataFormat::NCHW {
return Err(Box::new(GraphError::MissingParams(
return Err(GraphError::MissingParams(
"data in wrong format".to_string(),
)));
));
}

let stride = pool_spec
@@ -1391,7 +1348,7 @@ pub fn new_op_from_onnx(
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
return Err(GraphError::MissingParams("padding".to_string()));
}
};

@@ -1411,7 +1368,7 @@ pub fn new_op_from_onnx(
let pad_node: &Pad = match node.op().downcast_ref::<Pad>() {
Some(b) => b,
None => {
return Err(Box::new(GraphError::OpMismatch(idx, "pad".to_string())));
return Err(GraphError::OpMismatch(idx, "pad".to_string()));
}
};
// we only support constant 0 padding
@@ -1420,9 +1377,9 @@ pub fn new_op_from_onnx(
tract_onnx::prelude::Tensor::zero::<f32>(&[])?,
))
{
return Err(Box::new(GraphError::MisformedParams(
return Err(GraphError::MisformedParams(
"pad mode or pad type".to_string(),
)));
));
}

SupportedOp::Linear(PolyOp::Pad(pad_node.pads.to_vec()))
@@ -1473,7 +1430,7 @@ pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
const_value: Tensor<f32>,
scale: crate::Scale,
visibility: &Visibility,
) -> Result<Tensor<F>, Box<dyn std::error::Error>> {
) -> Result<Tensor<F>, TensorError> {
let mut value: Tensor<F> = const_value.par_enum_map(|_, x| {
Ok::<_, TensorError>(crate::fieldutils::i64_to_felt::<F>(quantize_float(
&(x).into(),
@@ -1492,7 +1449,7 @@ use crate::tensor::ValTensor;
pub(crate) fn split_valtensor(
values: &ValTensor<Fp>,
shapes: Vec<Vec<usize>>,
) -> Result<Vec<ValTensor<Fp>>, Box<dyn std::error::Error>> {
) -> Result<Vec<ValTensor<Fp>>, GraphError> {
let mut tensors: Vec<ValTensor<Fp>> = Vec::new();
let mut start = 0;
for shape in shapes {
@@ -1510,7 +1467,7 @@ pub fn homogenize_input_scales(
op: Box<dyn Op<Fp>>,
input_scales: Vec<crate::Scale>,
inputs_to_scale: Vec<usize>,
) -> Result<Box<dyn Op<Fp>>, Box<dyn Error>> {
) -> Result<Box<dyn Op<Fp>>, GraphError> {
let relevant_input_scales = input_scales
.clone()
.into_iter()
@@ -1529,7 +1486,7 @@ pub fn homogenize_input_scales(

let mut multipliers: Vec<u128> = vec![1; input_scales.len()];

let max_scale = input_scales.iter().max().ok_or("no max scale")?;
let max_scale = input_scales.iter().max().ok_or(GraphError::MissingScale)?;
let _ = input_scales
.iter()
.enumerate()

@@ -1,4 +1,3 @@
use std::error::Error;
use std::fmt::Display;

use crate::tensor::TensorType;
@@ -17,6 +16,8 @@ use pyo3::{
use serde::{Deserialize, Serialize};
use tosubcommand::ToFlags;

use self::errors::GraphError;

use super::*;

/// Label enum to track whether model input, model parameters, and model output are public, private, or hashed
@@ -72,7 +73,7 @@ impl ToFlags for Visibility {
impl<'a> From<&'a str> for Visibility {
fn from(s: &'a str) -> Self {
if s.contains("hashed/private") {
// split on last occurence of '/'
// split on last occurrence of '/'
let (_, outlets) = s.split_at(s.rfind('/').unwrap());
let outlets = outlets
.trim_start_matches('/')
@@ -261,12 +262,12 @@ impl VarScales {
}

/// Place in [VarScales] struct.
pub fn from_args(args: &RunArgs) -> Result<Self, Box<dyn Error>> {
Ok(Self {
pub fn from_args(args: &RunArgs) -> Self {
Self {
input: args.input_scale,
params: args.param_scale,
rebase_multiplier: args.scale_rebase_multiplier,
})
}
}
}

@@ -303,15 +304,13 @@ impl Default for VarVisibility {
impl VarVisibility {
/// Read from cli args whether the model input, model parameters, and model output are Public or Private to the prover.
/// Place in [VarVisibility] struct.
pub fn from_args(args: &RunArgs) -> Result<Self, Box<dyn Error>> {
pub fn from_args(args: &RunArgs) -> Result<Self, GraphError> {
let input_vis = &args.input_visibility;
let params_vis = &args.param_visibility;
let output_vis = &args.output_visibility;

if params_vis.is_public() {
return Err(
"public visibility for params is deprecated, please use `fixed` instead".into(),
);
return Err(GraphError::ParamsPublicVisibility);
}

if !output_vis.is_public()
@@ -327,7 +326,7 @@ impl VarVisibility {
& !params_vis.is_polycommit()
& !input_vis.is_polycommit()
{
return Err(Box::new(GraphError::Visibility));
return Err(GraphError::Visibility);
}
Ok(Self {
input: input_vis.clone(),

79
src/lib.rs
@@ -28,6 +28,59 @@
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!

/// Error type
#[derive(thiserror::Error, Debug)]
#[allow(missing_docs)]
pub enum EZKLError {
#[error("[aggregation] {0}")]
AggregationError(#[from] pfsys::evm::aggregation_kzg::AggregationError),
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[error("[eth] {0}")]
EthError(#[from] eth::EthError),
#[error("[graph] {0}")]
GraphError(#[from] graph::errors::GraphError),
#[error("[pfsys] {0}")]
PfsysError(#[from] pfsys::errors::PfsysError),
#[error("[circuit] {0}")]
CircuitError(#[from] circuit::errors::CircuitError),
#[error("[tensor] {0}")]
TensorError(#[from] tensor::errors::TensorError),
#[error("[module] {0}")]
ModuleError(#[from] circuit::modules::errors::ModuleError),
#[error("[io] {0}")]
IoError(#[from] std::io::Error),
#[error("[json] {0}")]
JsonError(#[from] serde_json::Error),
#[error("[utf8] {0}")]
Utf8Error(#[from] std::str::Utf8Error),
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[error("[reqwest] {0}")]
ReqwestError(#[from] reqwest::Error),
#[error("[fmt] {0}")]
FmtError(#[from] std::fmt::Error),
#[error("[halo2] {0}")]
Halo2Error(#[from] halo2_proofs::plonk::Error),
#[error("[Uncategorized] {0}")]
UncategorizedError(String),
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[error("[execute] {0}")]
ExecutionError(#[from] execute::ExecutionError),
#[error("[srs] {0}")]
SrsError(#[from] pfsys::srs::SrsError),
}

impl From<&str> for EZKLError {
    fn from(s: &str) -> Self {
        EZKLError::UncategorizedError(s.to_string())
    }
}

impl From<String> for EZKLError {
    fn from(s: String) -> Self {
        EZKLError::UncategorizedError(s)
    }
}
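
The `#[from]` attributes above are what let `?` lift each module-level error into `EZKLError` without explicit conversions. A minimal standalone sketch of the same thiserror pattern, with an illustrative two-variant error (names are not from the crate):

```rust
use thiserror::Error;

#[derive(Error, Debug)]
enum TopError {
    // `#[from]` generates `impl From<std::io::Error> for TopError`,
    // so `?` converts the underlying error automatically.
    #[error("[io] {0}")]
    Io(#[from] std::io::Error),
    #[error("[json] {0}")]
    Json(#[from] serde_json::Error),
}

fn read_config(path: &str) -> Result<serde_json::Value, TopError> {
    let bytes = std::fs::read(path)?; // io::Error -> TopError via From
    let value = serde_json::from_slice(&bytes)?; // serde_json::Error -> TopError
    Ok(value)
}
```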
use std::str::FromStr;

use circuit::{table::Range, CheckMode, Tolerance};
@@ -178,37 +231,37 @@ impl From<String> for Commitments {
#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd, ToFlags)]
pub struct RunArgs {
/// The tolerance for error on model outputs
#[arg(short = 'T', long, default_value = "0")]
#[arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other)]
pub tolerance: Tolerance,
/// The denominator in the fixed point representation used when quantizing inputs
#[arg(short = 'S', long, default_value = "7", allow_hyphen_values = true)]
#[arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other)]
pub input_scale: Scale,
/// The denominator in the fixed point representation used when quantizing parameters
#[arg(long, default_value = "7", allow_hyphen_values = true)]
#[arg(long, default_value = "7", value_hint = clap::ValueHint::Other)]
pub param_scale: Scale,
/// if the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this a more advanced parameter, use with caution)
#[arg(long, default_value = "1")]
#[arg(long, default_value = "1", value_hint = clap::ValueHint::Other)]
pub scale_rebase_multiplier: u32,
/// The min and max elements in the lookup table input column
#[arg(short = 'B', long, value_parser = parse_key_val::<i64, i64>, default_value = "-32768->32768")]
pub lookup_range: Range,
/// The log_2 number of rows
#[arg(short = 'K', long, default_value = "17")]
#[arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other)]
pub logrows: u32,
/// The log_2 number of rows
#[arg(short = 'N', long, default_value = "2")]
#[arg(short = 'N', long, default_value = "2", value_hint = clap::ValueHint::Other)]
pub num_inner_cols: usize,
/// Hand-written parser for graph variables, eg. batch_size=1
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',')]
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other)]
pub variables: Vec<(String, usize)>,
/// Flags whether inputs are public, private, fixed, hashed, polycommit
#[arg(long, default_value = "private")]
#[arg(long, default_value = "private", value_hint = clap::ValueHint::Other)]
pub input_visibility: Visibility,
/// Flags whether outputs are public, private, fixed, hashed, polycommit
#[arg(long, default_value = "public")]
#[arg(long, default_value = "public", value_hint = clap::ValueHint::Other)]
pub output_visibility: Visibility,
/// Flags whether params are fixed, private, hashed, polycommit
#[arg(long, default_value = "private")]
#[arg(long, default_value = "private", value_hint = clap::ValueHint::Other)]
pub param_visibility: Visibility,
#[arg(long, default_value = "false")]
/// Rebase the scale using lookup table for division instead of using a range check
@@ -217,10 +270,10 @@ pub struct RunArgs {
#[arg(long, default_value = "false")]
pub rebase_frac_zero_constants: bool,
/// check mode (safe, unsafe, etc)
#[arg(long, default_value = "unsafe")]
#[arg(long, default_value = "unsafe", value_hint = clap::ValueHint::Other)]
pub check_mode: CheckMode,
/// commitment scheme
#[arg(long, default_value = "kzg")]
#[arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other)]
pub commitment: Option<Commitments>,
}
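
For context, `value_hint` is a clap 4 derive attribute consumed by shell-completion generators; it does not change how values are parsed. A minimal standalone sketch of the pattern (the struct and field names here are illustrative, not the crate's):

```rust
use clap::{Parser, ValueHint};

#[derive(Parser)]
struct Args {
    /// Mirrors the `--logrows` flag above; `ValueHint::Other` marks the
    /// value as free-form so completion engines don't suggest file paths.
    #[arg(short = 'K', long, default_value = "17", value_hint = ValueHint::Other)]
    logrows: u32,
}

fn main() {
    let args = Args::parse();
    println!("logrows = {}", args.logrows);
}
```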
@@ -248,7 +301,7 @@ impl Default for RunArgs {

impl RunArgs {
///
pub fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
pub fn validate(&self) -> Result<(), String> {
if self.param_visibility == Visibility::Public {
return Err(
"params cannot be public instances, you are probably trying to use `fixed` or `kzgcommit`"

27
src/pfsys/errors.rs
Normal file
@@ -0,0 +1,27 @@
use thiserror::Error;

/// Error type for the pfsys module
#[derive(Error, Debug)]
pub enum PfsysError {
    /// Failed to save the proof
    #[error("failed to save the proof: {0}")]
    SaveProof(String),
    /// Failed to load the proof
    #[error("failed to load the proof: {0}")]
    LoadProof(String),
    /// Halo2 error
    #[error("[halo2] {0}")]
    Halo2Error(#[from] halo2_proofs::plonk::Error),
    /// Failed to write point to transcript
    #[error("failed to write point to transcript: {0}")]
    WritePoint(String),
    /// Invalid commitment scheme
    #[error("invalid commitment scheme")]
    InvalidCommitmentScheme,
    /// Failed to load vk from file
    #[error("failed to load vk from file: {0}")]
    LoadVk(String),
    /// Failed to load pk from file
    #[error("failed to load pk from file: {0}")]
    LoadPk(String),
}
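
Unlike the `#[from]` variants, the string-carrying variants above are populated with `map_err` at each call site, so the same underlying error type (e.g. `std::io::Error`) can land in different variants depending on context. A small sketch of that pattern (the `FileError` type and helpers are hypothetical):

```rust
use thiserror::Error;

#[derive(Error, Debug)]
enum FileError {
    #[error("failed to save: {0}")]
    Save(String),
    #[error("failed to load: {0}")]
    Load(String),
}

// The same io::Error type is routed to different variants by call site.
fn save_bytes(path: &str, bytes: &[u8]) -> Result<(), FileError> {
    std::fs::write(path, bytes).map_err(|e| FileError::Save(format!("{}", e)))
}

fn load_bytes(path: &str) -> Result<Vec<u8>, FileError> {
    std::fs::read(path).map_err(|e| FileError::Load(format!("{}", e)))
}
```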
@@ -10,18 +10,15 @@ pub enum EvmVerificationError {
#[error("Solidity verifier found the proof invalid")]
InvalidProof,
/// If the Solidity verifier threw and error (e.g. OutOfGas)
#[error("Execution of Solidity code failed")]
SolidityExecution,
/// EVM execution errors
#[error("EVM execution of raw code failed")]
RawExecution,
#[error("Execution of Solidity code failed: {0}")]
SolidityExecution(String),
/// EVM verify errors
#[error("evm verification reverted")]
Reverted,
#[error("evm verification reverted: {0}")]
Reverted(String),
/// EVM verify errors
#[error("evm deployment failed")]
Deploy,
/// Invalid Visibilit
#[error("evm deployment failed: {0}")]
DeploymentFailed(String),
/// Invalid Visibility
#[error("Invalid visibility")]
InvalidVisibility,
}

101
src/pfsys/mod.rs
@@ -4,6 +4,11 @@ pub mod evm;
/// SRS generation, processing, verification and downloading
pub mod srs;

/// errors related to pfsys
pub mod errors;

pub use errors::PfsysError;

use crate::circuit::CheckMode;
use crate::graph::GraphWitness;
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
@@ -32,7 +37,6 @@ use serde::{Deserialize, Serialize};
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use snark_verifier::verifier::plonk::PlonkProtocol;
use std::error::Error;
use std::fs::File;
use std::io::{self, BufReader, BufWriter, Cursor, Write};
use std::ops::Deref;
@@ -364,24 +368,28 @@ where
}

/// Saves the Proof to a specified `proof_path`.
pub fn save(&self, proof_path: &PathBuf) -> Result<(), Box<dyn Error>> {
let file = std::fs::File::create(proof_path)?;
pub fn save(&self, proof_path: &PathBuf) -> Result<(), PfsysError> {
let file = std::fs::File::create(proof_path)
.map_err(|e| PfsysError::SaveProof(format!("{}", e)))?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::to_writer(&mut writer, &self)?;
serde_json::to_writer(&mut writer, &self)
.map_err(|e| PfsysError::SaveProof(format!("{}", e)))?;
Ok(())
}

/// Load a json serialized proof from the provided path.
pub fn load<Scheme: CommitmentScheme<Curve = C, Scalar = F>>(
proof_path: &PathBuf,
) -> Result<Self, Box<dyn Error>>
) -> Result<Self, PfsysError>
where
<C as CurveAffine>::ScalarExt: FromUniformBytes<64>,
{
trace!("reading proof");
let file = std::fs::File::open(proof_path)?;
let file =
std::fs::File::open(proof_path).map_err(|e| PfsysError::LoadProof(format!("{}", e)))?;
let reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
let proof: Self = serde_json::from_reader(reader)?;
let proof: Self =
serde_json::from_reader(reader).map_err(|e| PfsysError::LoadProof(format!("{}", e)))?;
Ok(proof)
}
}
@@ -541,7 +549,7 @@ pub fn create_proof_circuit<
transcript_type: TranscriptType,
split: Option<ProofSplitCommit>,
protocol: Option<PlonkProtocol<Scheme::Curve>>,
) -> Result<Snark<Scheme::Scalar, Scheme::Curve>, Box<dyn Error>>
) -> Result<Snark<Scheme::Scalar, Scheme::Curve>, PfsysError>
where
Scheme::ParamsVerifier: 'params,
Scheme::Scalar: Serialize
@@ -626,7 +634,35 @@ pub fn swap_proof_commitments<
>(
snark: &Snark<Scheme::Scalar, Scheme::Curve>,
commitments: &[Scheme::Curve],
) -> Result<Snark<Scheme::Scalar, Scheme::Curve>, Box<dyn Error>>
) -> Result<Snark<Scheme::Scalar, Scheme::Curve>, PfsysError>
where
Scheme::Scalar: SerdeObject
+ PrimeField
+ FromUniformBytes<64>
+ WithSmallOrderMulGroup<3>
+ Ord
+ Serialize
+ DeserializeOwned,
Scheme::Curve: Serialize + DeserializeOwned,
{
let proof_first_bytes = get_proof_commitments::<Scheme, E, TW>(commitments)?;

let mut snark_new = snark.clone();
// swap the proof bytes for the new ones
snark_new.proof[..proof_first_bytes.len()].copy_from_slice(&proof_first_bytes);
snark_new.create_hex_proof();

Ok(snark_new)
}

/// Returns the bytes encoded proof commitments
pub fn get_proof_commitments<
Scheme: CommitmentScheme,
E: EncodedChallenge<Scheme::Curve>,
TW: TranscriptWriterBuffer<Vec<u8>, Scheme::Curve, E>,
>(
commitments: &[Scheme::Curve],
) -> Result<Vec<u8>, PfsysError>
where
Scheme::Scalar: SerdeObject
+ PrimeField
@@ -639,28 +675,27 @@ where
{
let mut transcript_new: TW = TranscriptWriterBuffer::<_, Scheme::Curve, _>::init(vec![]);

// polycommit commitments are the first set of points in the proof, this we'll always be the first set of advice
// polycommit commitments are the first set of points in the proof, this will always be the first set of advice
for commit in commitments {
transcript_new
.write_point(*commit)
.map_err(|_| "failed to write point")?;
.map_err(|e| PfsysError::WritePoint(format!("{}", e)))?;
}

let proof_first_bytes = transcript_new.finalize();

let mut snark_new = snark.clone();
// swap the proof bytes for the new ones
snark_new.proof[..proof_first_bytes.len()].copy_from_slice(&proof_first_bytes);
snark_new.create_hex_proof();
if commitments.is_empty() {
log::warn!("no commitments found in witness");
}

Ok(snark_new)
Ok(proof_first_bytes)
}

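The refactor above splits the prefix computation out of `swap_proof_commitments`: the polycommit advice commitments are serialized at the very start of the proof bytes, so replacing them is a fixed-prefix overwrite. A stripped-down sketch of just that step, with a plain byte slice standing in for the transcript-encoded commitments:

```rust
/// Overwrite the leading bytes of `proof` with a freshly encoded
/// commitment prefix; panics if the prefix is longer than the proof.
fn swap_commitment_prefix(proof: &mut [u8], new_prefix: &[u8]) {
    proof[..new_prefix.len()].copy_from_slice(new_prefix);
}

fn main() {
    let mut proof = vec![0u8; 8];
    swap_commitment_prefix(&mut proof, &[0xAA, 0xBB]);
    assert_eq!(&proof[..2], &[0xAA, 0xBB]);
}
```
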
/// Swap the proof commitments to a new set in the proof for KZG
pub fn swap_proof_commitments_polycommit(
snark: &Snark<Fr, G1Affine>,
commitments: &[G1Affine],
) -> Result<Snark<Fr, G1Affine>, Box<dyn Error>> {
) -> Result<Snark<Fr, G1Affine>, PfsysError> {
let proof = match snark.commitment {
Some(Commitments::KZG) => match snark.transcript_type {
TranscriptType::EVM => swap_proof_commitments::<
@@ -687,7 +722,7 @@ pub fn swap_proof_commitments_polycommit(
>(snark, commitments)?,
},
None => {
return Err("commitment scheme not found".into());
return Err(PfsysError::InvalidCommitmentScheme);
}
};

@@ -734,22 +769,22 @@ where
pub fn load_vk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
path: PathBuf,
params: <C as Circuit<Scheme::Scalar>>::Params,
) -> Result<VerifyingKey<Scheme::Curve>, Box<dyn Error>>
) -> Result<VerifyingKey<Scheme::Curve>, PfsysError>
where
C: Circuit<Scheme::Scalar>,
Scheme::Curve: SerdeObject + CurveAffine,
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
{
info!("loading verification key from {:?}", path);
let f =
File::open(path.clone()).map_err(|_| format!("failed to load vk at {}", path.display()))?;
debug!("loading verification key from {:?}", path);
let f = File::open(path.clone()).map_err(|e| PfsysError::LoadVk(format!("{}", e)))?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let vk = VerifyingKey::<Scheme::Curve>::read::<_, C>(
&mut reader,
serde_format_from_str(&EZKL_KEY_FORMAT),
params,
)?;
info!("done loading verification key ✅");
)
.map_err(|e| PfsysError::LoadVk(format!("{}", e)))?;
info!("loaded verification key ✅");
Ok(vk)
}

@@ -757,22 +792,22 @@ where
pub fn load_pk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
path: PathBuf,
params: <C as Circuit<Scheme::Scalar>>::Params,
) -> Result<ProvingKey<Scheme::Curve>, Box<dyn Error>>
) -> Result<ProvingKey<Scheme::Curve>, PfsysError>
where
C: Circuit<Scheme::Scalar>,
Scheme::Curve: SerdeObject + CurveAffine,
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
{
info!("loading proving key from {:?}", path);
let f =
File::open(path.clone()).map_err(|_| format!("failed to load pk at {}", path.display()))?;
debug!("loading proving key from {:?}", path);
let f = File::open(path.clone()).map_err(|e| PfsysError::LoadPk(format!("{}", e)))?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let pk = ProvingKey::<Scheme::Curve>::read::<_, C>(
&mut reader,
serde_format_from_str(&EZKL_KEY_FORMAT),
params,
)?;
info!("done loading proving key ✅");
)
.map_err(|e| PfsysError::LoadPk(format!("{}", e)))?;
info!("loaded proving key ✅");
Ok(pk)
}

@@ -784,7 +819,7 @@ pub fn save_pk<C: SerdeObject + CurveAffine>(
where
C::ScalarExt: FromUniformBytes<64> + SerdeObject,
{
info!("saving proving key 💾");
debug!("saving proving key 💾");
let f = File::create(path)?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
pk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
@@ -801,7 +836,7 @@ pub fn save_vk<C: CurveAffine + SerdeObject>(
where
C::ScalarExt: FromUniformBytes<64> + SerdeObject,
{
info!("saving verification key 💾");
debug!("saving verification key 💾");
let f = File::create(path)?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
vk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
@@ -815,7 +850,7 @@ pub fn save_params<Scheme: CommitmentScheme>(
path: &PathBuf,
params: &'_ Scheme::ParamsVerifier,
) -> Result<(), io::Error> {
info!("saving parameters 💾");
debug!("saving parameters 💾");
let f = File::create(path)?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
params.write(&mut writer)?;

@@ -1,8 +1,7 @@
use halo2_proofs::poly::commitment::CommitmentScheme;
use halo2_proofs::poly::commitment::Params;
use halo2_proofs::poly::commitment::ParamsProver;
use log::info;
use std::error::Error;
use log::debug;
use std::fs::File;
use std::io::BufReader;
use std::path::PathBuf;
@@ -16,24 +15,33 @@ pub fn gen_srs<Scheme: CommitmentScheme>(k: u32) -> Scheme::ParamsProver {
Scheme::ParamsProver::new(k)
}

#[derive(thiserror::Error, Debug)]
#[allow(missing_docs)]
pub enum SrsError {
    #[error("failed to download srs from {0}")]
    DownloadError(String),
    #[error("failed to load srs from {0}")]
    LoadError(PathBuf),
    #[error("failed to read srs {0}")]
    ReadError(String),
}

/// Loads the [CommitmentScheme::ParamsVerifier] at `path`.
pub fn load_srs_verifier<Scheme: CommitmentScheme>(
path: PathBuf,
) -> Result<Scheme::ParamsVerifier, Box<dyn Error>> {
info!("loading srs from {:?}", path);
let f = File::open(path.clone())
.map_err(|_| format!("failed to load srs at {}", path.display()))?;
) -> Result<Scheme::ParamsVerifier, SrsError> {
debug!("loading srs from {:?}", path);
let f = File::open(path.clone()).map_err(|_| SrsError::LoadError(path))?;
let mut reader = BufReader::new(f);
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(Box::<dyn Error>::from)
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(|e| SrsError::ReadError(e.to_string()))
}

/// Loads the [CommitmentScheme::ParamsVerifier] at `path`.
pub fn load_srs_prover<Scheme: CommitmentScheme>(
path: PathBuf,
) -> Result<Scheme::ParamsProver, Box<dyn Error>> {
info!("loading srs from {:?}", path);
let f = File::open(path.clone())
.map_err(|_| format!("failed to load srs at {}", path.display()))?;
) -> Result<Scheme::ParamsProver, SrsError> {
debug!("loading srs from {:?}", path);
let f = File::open(path.clone()).map_err(|_| SrsError::LoadError(path.clone()))?;
let mut reader = BufReader::new(f);
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(Box::<dyn Error>::from)
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(|e| SrsError::ReadError(e.to_string()))
}

@@ -466,7 +466,7 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represnted as strings
/// List of field elements represented as strings
///
/// Returns
/// -------
@@ -1430,6 +1430,47 @@ fn verify_aggr(
Ok(true)
}

/// Creates encoded evm calldata from a proof file
///
/// Arguments
/// ---------
/// proof: str
/// Path to the proof file
///
/// calldata: str
/// Path to the calldata file to save
///
/// addr_vk: str
/// The address of the verification key contract (if the verifier key is to be rendered as a separate contract)
///
/// Returns
/// -------
/// vec[u8]
/// The encoded calldata
///
#[pyfunction(signature = (
proof=PathBuf::from(DEFAULT_PROOF),
calldata=PathBuf::from(DEFAULT_CALLDATA),
addr_vk=None,
))]
fn encode_evm_calldata<'a>(
proof: PathBuf,
calldata: PathBuf,
addr_vk: Option<&'a str>,
) -> Result<Vec<u8>, PyErr> {
let addr_vk = if let Some(addr_vk) = addr_vk {
let addr_vk = H160Flag::from(addr_vk);
Some(addr_vk)
} else {
None
};

crate::execute::encode_evm_calldata(proof, calldata, addr_vk).map_err(|e| {
let err_str = format!("Failed to generate calldata: {}", e);
PyRuntimeError::new_err(err_str)
})
}

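For orientation, a hedged sketch of what the new binding boils down to on the Rust side; the file names are made up, and the snippet mirrors the wrapper above rather than adding anything new:

use std::path::PathBuf;

fn demo_calldata() -> Result<(), String> {
    // Pass Some(H160Flag::from("0x...")) instead of None when the VK lives in a separate contract.
    let calldata = crate::execute::encode_evm_calldata(
        PathBuf::from("proof.json"),      // illustrative path
        PathBuf::from("calldata.bytes"),  // illustrative path
        None,
    )
    .map_err(|e| format!("failed to generate calldata: {}", e))?;
    assert!(!calldata.is_empty());
    Ok(())
}
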
/// Creates an EVM compatible verifier, you will need solc installed in your environment to run this
///
/// Arguments
@@ -1517,6 +1558,7 @@ fn create_evm_verifier(
settings_path=PathBuf::from(DEFAULT_SETTINGS),
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE_DA),
abi_path=PathBuf::from(DEFAULT_VERIFIER_DA_ABI),
witness_path=None,
))]
fn create_evm_data_attestation(
py: Python,
@@ -1524,6 +1566,7 @@ fn create_evm_data_attestation(
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
witness_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::create_evm_data_attestation(
@@ -1531,6 +1574,7 @@ fn create_evm_data_attestation(
sol_code_path,
abi_path,
input_data,
witness_path,
)
.await
.map_err(|e| {
@@ -1888,6 +1932,6 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(setup_test_evm_witness, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_data_attestation, m)?)?;

m.add_function(wrap_pyfunction!(encode_evm_calldata, m)?)?;
Ok(())
}

30
src/tensor/errors.rs
Normal file
@@ -0,0 +1,30 @@
use thiserror::Error;

/// A wrapper for tensor related errors.
#[derive(Debug, Error)]
pub enum TensorError {
/// Shape mismatch in a operation
#[error("dimension mismatch in tensor op: {0}")]
DimMismatch(String),
/// Shape when instantiating
#[error("dimensionality error when manipulating a tensor: {0}")]
DimError(String),
/// wrong method was called on a tensor-like struct
#[error("wrong method called")]
WrongMethod,
/// Significant bit truncation when instantiating
#[error("significant bit truncation when instantiating, try lowering the scale")]
SigBitTruncationError,
/// Failed to convert to field element tensor
#[error("failed to convert to field element tensor")]
FeltError,
/// Unsupported operation
#[error("unsupported operation on a tensor type")]
Unsupported,
/// Overflow
#[error("unsigned integer overflow or underflow error in op: {0}")]
Overflow(String),
/// Unset visibility
#[error("unset visibility")]
UnsetVisibility,
}
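As a usage sketch, the dedicated module lets call sites return a concrete error instead of a boxed trait object; the helper below is hypothetical and not part of the commit:

fn require_same_dims(a: &[usize], b: &[usize]) -> Result<(), TensorError> {
    // Hypothetical check illustrating the concrete error type.
    if a != b {
        return Err(TensorError::DimMismatch(format!("{:?} vs {:?}", a, b)));
    }
    Ok(())
}
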
@@ -1,3 +1,5 @@
/// Tensor related errors.
pub mod errors;
/// Implementations of common operations on tensors.
pub mod ops;
/// A wrapper around a tensor of circuit variables / advices.
@@ -5,6 +7,8 @@ pub mod val;
/// A wrapper around a tensor of Halo2 Value types.
pub mod var;

pub use errors::TensorError;

use halo2curves::{bn256::Fr, ff::PrimeField};
use maybe_rayon::{
prelude::{
@@ -40,40 +44,10 @@ use std::fmt::Debug;
use std::iter::Iterator;
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
use std::{cmp::max, ops::Rem};
use thiserror::Error;

#[cfg(feature = "metal")]
use std::collections::HashMap;

/// A wrapper for tensor related errors.
#[derive(Debug, Error)]
pub enum TensorError {
/// Shape mismatch in a operation
#[error("dimension mismatch in tensor op: {0}")]
DimMismatch(String),
/// Shape when instantiating
#[error("dimensionality error when manipulating a tensor: {0}")]
DimError(String),
/// wrong method was called on a tensor-like struct
#[error("wrong method called")]
WrongMethod,
/// Significant bit truncation when instantiating
#[error("Significant bit truncation when instantiating, try lowering the scale")]
SigBitTruncationError,
/// Failed to convert to field element tensor
#[error("Failed to convert to field element tensor")]
FeltError,
/// Table lookup error
#[error("Table lookup error")]
TableLookupError,
/// Unsupported operation
#[error("Unsupported operation on a tensor type")]
Unsupported,
/// Overflow
#[error("Unsigned integer overflow or underflow error in op: {0}")]
Overflow(String),
}

#[cfg(feature = "metal")]
const LIB_DATA: &[u8] = include_bytes!("metal/tensor_ops.metallib");

@@ -400,9 +374,7 @@ impl IntoI64 for () {
fn into_i64(self) -> i64 {
0
}
fn from_i64(_: i64) -> Self {

}
fn from_i64(_: i64) -> Self {}
}

impl IntoI64 for Fr {
@@ -1852,7 +1824,7 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Re
pub fn get_broadcasted_shape(
shape_a: &[usize],
shape_b: &[usize],
) -> Result<Vec<usize>, Box<dyn Error>> {
) -> Result<Vec<usize>, TensorError> {
let num_dims_a = shape_a.len();
let num_dims_b = shape_b.len();

@@ -1867,9 +1839,9 @@ pub fn get_broadcasted_shape(
}
(a, b) if a < b => Ok(shape_b.to_vec()),
(a, b) if a > b => Ok(shape_a.to_vec()),
_ => Err(Box::new(TensorError::DimError(
_ => Err(TensorError::DimError(
"Unknown condition for broadcasting".to_string(),
))),
)),
}
}
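Only the differing-rank arms of the broadcasting match are visible in this hunk; under those arms the higher-rank shape wins. A sketch of the expected behaviour, assuming the equal-rank cases are handled earlier in the (elided) function body:

fn broadcast_examples() -> Result<(), TensorError> {
    // The higher-rank shape is returned by the arms shown above; illustrative only.
    assert_eq!(get_broadcasted_shape(&[3], &[2, 3])?, vec![2, 3]);
    assert_eq!(get_broadcasted_shape(&[4, 2, 3], &[2, 3])?, vec![4, 2, 3]);
    Ok(())
}
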
////////////////////////

@@ -256,23 +256,23 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Vec<ValType<F>>> for ValTenso
}

impl<F: PrimeField + TensorType + PartialOrd> TryFrom<Tensor<F>> for ValTensor<F> {
type Error = Box<dyn Error>;
fn try_from(t: Tensor<F>) -> Result<ValTensor<F>, Box<dyn Error>> {
type Error = TensorError;
fn try_from(t: Tensor<F>) -> Result<ValTensor<F>, TensorError> {
let visibility = t.visibility.clone();
let dims = t.dims().to_vec();
let inner = t.into_iter().map(|x| {
if let Some(vis) = &visibility {
match vis {
Visibility::Fixed => Ok(ValType::Constant(x)),
_ => {
Ok(Value::known(x).into())
let inner = t
.into_iter()
.map(|x| {
if let Some(vis) = &visibility {
match vis {
Visibility::Fixed => Ok(ValType::Constant(x)),
_ => Ok(Value::known(x).into()),
}
} else {
Err(TensorError::UnsetVisibility)
}
}
else {
Err("visibility should be set to convert a tensor of field elements to a ValTensor.".into())
}
}).collect::<Result<Vec<_>, Box<dyn Error>>>()?;
})
.collect::<Result<Vec<_>, TensorError>>()?;

let mut inner: Tensor<ValType<F>> = inner.into_iter().into();
inner.reshape(&dims)?;
@@ -378,13 +378,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}

/// reverse order of elements whilst preserving the shape
pub fn reverse(&mut self) -> Result<(), Box<dyn Error>> {
pub fn reverse(&mut self) -> Result<(), TensorError> {
match self {
ValTensor::Value { inner: v, .. } => {
v.reverse();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
return Err(TensorError::WrongMethod);
}
};
Ok(())
@@ -420,7 +420,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}

///
pub fn any_unknowns(&self) -> Result<bool, Box<dyn Error>> {
pub fn any_unknowns(&self) -> Result<bool, TensorError> {
match self {
ValTensor::Instance { .. } => Ok(true),
_ => Ok(self.get_inner()?.iter().any(|&x| {
@@ -491,7 +491,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}

/// Fetch the underlying [Tensor] of field elements.
pub fn get_felt_evals(&self) -> Result<Tensor<F>, Box<dyn Error>> {
pub fn get_felt_evals(&self) -> Result<Tensor<F>, TensorError> {
let mut felt_evals: Vec<F> = vec![];
match self {
ValTensor::Value {
@@ -504,7 +504,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
});
}
_ => return Err(Box::new(TensorError::WrongMethod)),
_ => return Err(TensorError::WrongMethod),
};

let mut res: Tensor<F> = felt_evals.into_iter().into();
@@ -521,7 +521,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}

/// Calls `int_evals` on the inner tensor.
pub fn get_int_evals(&self) -> Result<Tensor<i64>, Box<dyn Error>> {
pub fn get_int_evals(&self) -> Result<Tensor<i64>, TensorError> {
// finally convert to vector of integers
let mut integer_evals: Vec<i64> = vec![];
match self {
@@ -547,7 +547,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
});
}
_ => return Err(Box::new(TensorError::WrongMethod)),
_ => return Err(TensorError::WrongMethod),
};
let mut tensor: Tensor<i64> = integer_evals.into_iter().into();
match tensor.reshape(self.dims()) {
@@ -558,7 +558,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}

/// Calls `pad_to_zero_rem` on the inner tensor.
pub fn pad_to_zero_rem(&mut self, n: usize, pad: ValType<F>) -> Result<(), Box<dyn Error>> {
pub fn pad_to_zero_rem(&mut self, n: usize, pad: ValType<F>) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
@@ -567,14 +567,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
return Err(TensorError::WrongMethod);
}
};
Ok(())
}

/// Calls `get_slice` on the inner tensor.
pub fn get_slice(&self, indices: &[Range<usize>]) -> Result<ValTensor<F>, Box<dyn Error>> {
pub fn get_slice(&self, indices: &[Range<usize>]) -> Result<ValTensor<F>, TensorError> {
if indices.iter().map(|x| x.end - x.start).collect::<Vec<_>>() == self.dims() {
return Ok(self.clone());
}
@@ -592,13 +592,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
scale: *scale,
}
}
_ => return Err(Box::new(TensorError::WrongMethod)),
_ => return Err(TensorError::WrongMethod),
};
Ok(slice)
}

/// Calls `get_single_elem` on the inner tensor.
pub fn get_single_elem(&self, index: usize) -> Result<ValTensor<F>, Box<dyn Error>> {
pub fn get_single_elem(&self, index: usize) -> Result<ValTensor<F>, TensorError> {
let slice = match self {
ValTensor::Value {
inner: v,
@@ -612,7 +612,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
scale: *scale,
}
}
_ => return Err(Box::new(TensorError::WrongMethod)),
_ => return Err(TensorError::WrongMethod),
};
Ok(slice)
}
@@ -648,7 +648,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
})
}
/// Calls `expand` on the inner tensor.
pub fn expand(&mut self, dims: &[usize]) -> Result<(), Box<dyn Error>> {
pub fn expand(&mut self, dims: &[usize]) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
@@ -657,14 +657,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
return Err(TensorError::WrongMethod);
}
};
Ok(())
}

/// Calls `move_axis` on the inner tensor.
pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<(), Box<dyn Error>> {
pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
@@ -673,14 +673,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
return Err(TensorError::WrongMethod);
}
};
Ok(())
}

/// Sets the [ValTensor]'s shape.
pub fn reshape(&mut self, new_dims: &[usize]) -> Result<(), Box<dyn Error>> {
pub fn reshape(&mut self, new_dims: &[usize]) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
@@ -690,10 +690,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
ValTensor::Instance { dims: d, idx, .. } => {
if d[*idx].iter().product::<usize>() != new_dims.iter().product::<usize>() {
return Err(Box::new(TensorError::DimError(format!(
return Err(TensorError::DimError(format!(
"Cannot reshape {:?} to {:?} as they have number of elements",
d[*idx], new_dims
))));
)));
}
d[*idx] = new_dims.to_vec();
}
@@ -702,12 +702,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}

/// Sets the [ValTensor]'s shape.
pub fn slice(
&mut self,
axis: &usize,
start: &usize,
end: &usize,
) -> Result<(), Box<dyn Error>> {
pub fn slice(&mut self, axis: &usize, start: &usize, end: &usize) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
@@ -716,7 +711,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
return Err(TensorError::WrongMethod);
}
};
Ok(())
@@ -982,7 +977,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {

impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
/// inverts the inner values
pub fn inverse(&self) -> Result<ValTensor<F>, Box<dyn Error>> {
pub fn inverse(&self) -> Result<ValTensor<F>, TensorError> {
let mut cloned_self = self.clone();

match &mut cloned_self {
@@ -1000,7 +995,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
return Err(TensorError::WrongMethod);
}
};
Ok(cloned_self)

@@ -31,6 +31,15 @@ pub enum VarTensor {
}

impl VarTensor {
/// name of the tensor
pub fn name(&self) -> &'static str {
match self {
VarTensor::Advice { .. } => "Advice",
VarTensor::Dummy { .. } => "Dummy",
VarTensor::Empty => "Empty",
}
}

///
pub fn is_advice(&self) -> bool {
matches!(self, VarTensor::Advice { .. })

@@ -1015,7 +1015,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(true, hardfork);
kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "on-chain", "file", "public", "private");
kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "on-chain", "file", "public", "private", "private");
// test_dir.close().unwrap();
}

@@ -1025,7 +1025,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(true, Hardfork::Latest);
kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "file", "on-chain", "private", "public");
kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "file", "on-chain", "private", "public", "private");
// test_dir.close().unwrap();
}

@@ -1035,7 +1035,7 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(true, Hardfork::Latest);
kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "on-chain", "on-chain", "public", "public");
kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "on-chain", "on-chain", "public", "public", "private");
test_dir.close().unwrap();
}

@@ -1045,7 +1045,25 @@ mod native_tests {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(true, Hardfork::Latest);
kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "on-chain", "on-chain", "hashed", "hashed");
kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "on-chain", "on-chain", "hashed", "hashed", "private");
test_dir.close().unwrap();
}
#(#[test_case(TESTS_ON_CHAIN_INPUT[N])])*
fn kzg_evm_on_chain_input_kzg_output_kzg_params_prove_and_verify_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(true, Hardfork::Latest);
kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "on-chain", "file", "public", "polycommit", "polycommit");
test_dir.close().unwrap();
}
#(#[test_case(TESTS_ON_CHAIN_INPUT[N])])*
fn kzg_evm_on_chain_output_kzg_input_kzg_params_prove_and_verify_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(true, Hardfork::Latest);
kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "file", "on-chain", "polycommit", "public", "polycommit");
test_dir.close().unwrap();
}
});
@@ -1814,6 +1832,18 @@ mod native_tests {
let settings_arg = format!("{}/{}/settings.json", test_dir, example_name);
let private_key = format!("--private-key={}", *ANVIL_DEFAULT_PRIVATE_KEY);

// create encoded calldata
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"encode-evm-calldata",
"--proof-path",
&format!("{}/{}/aggr.pf", test_dir, example_name),
])
.status()
.expect("failed to execute process");

assert!(status.success());

let base_args = vec![
"create-evm-verifier-aggr",
"--vk-path",
@@ -2036,6 +2066,18 @@ mod native_tests {
let addr_path_arg = format!("--addr-path={}/{}/addr.txt", test_dir, example_name);
let settings_arg = format!("--settings-path={}", settings_path);

// create encoded calldata
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"encode-evm-calldata",
"--proof-path",
&format!("{}/{}/proof.pf", test_dir, example_name),
])
.status()
.expect("failed to execute process");

assert!(status.success());

// create the verifier
let mut args = vec!["create-evm-verifier", "--vk-path", &vk_arg, &settings_arg];

@@ -2205,6 +2247,19 @@ mod native_tests {

let deployed_addr_arg_vk = format!("--addr-vk={}", addr_vk);

// create encoded calldata
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"encode-evm-calldata",
"--proof-path",
&format!("{}/{}/proof.pf", test_dir, example_name),
&deployed_addr_arg_vk,
])
.status()
.expect("failed to execute process");

assert!(status.success());

// now verify the proof
let pf_arg = format!("{}/{}/proof.pf", test_dir, example_name);
let mut args = vec![
@@ -2254,12 +2309,13 @@ mod native_tests {
output_source: &str,
input_visibility: &str,
output_visibility: &str,
param_visibility: &str,
) {
gen_circuit_settings_and_witness(
test_dir,
example_name.clone(),
input_visibility,
"private",
param_visibility,
output_visibility,
1,
"resources",
@@ -2376,6 +2432,18 @@ mod native_tests {

let settings_arg = format!("--settings-path={}", settings_path);

// create encoded calldata
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"encode-evm-calldata",
"--proof-path",
&format!("{}/{}/proof.pf", test_dir, example_name),
])
.status()
.expect("failed to execute process");

assert!(status.success());

// create the verifier
let mut args = vec!["create-evm-verifier", "--vk-path", &vk_arg, &settings_arg];

@@ -2413,15 +2481,23 @@ mod native_tests {

let sol_arg = format!("{}/{}/kzg.sol", test_dir, example_name);

let mut create_da_args = vec![
"create-evm-da",
&settings_arg,
"--sol-code-path",
sol_arg.as_str(),
"-W",
&witness_path,
];

// if there is a on-chain source we add the data
if input_source != "file" || output_source != "file" {
create_da_args.push("-D");
create_da_args.push(test_on_chain_data_path.as_str());
}

let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"create-evm-da",
&settings_arg,
"--sol-code-path",
sol_arg.as_str(),
"-D",
test_on_chain_data_path.as_str(),
])
.args(&create_da_args)
.status()
.expect("failed to execute process");
assert!(status.success());

@@ -402,6 +402,15 @@ async def test_create_evm_verifier():
settings_path = os.path.join(folder_path, 'settings.json')
sol_code_path = os.path.join(folder_path, 'test.sol')
abi_path = os.path.join(folder_path, 'test.abi')
proof_path = os.path.join(folder_path, 'test_evm.pf')
calldata_path = os.path.join(folder_path, 'calldata.bytes')

# res is now a vector of bytes
res = ezkl.encode_evm_calldata(proof_path, calldata_path)

assert os.path.isfile(calldata_path)
assert len(res) > 0


res = await ezkl.create_evm_verifier(
vk_path,