Compare commits

...

24 Commits

Author SHA1 Message Date
dante
a5bf64b1a2 feat!: ipa commitments (#740)
BREAKING CHANGE: commitment is now an added flag (see the sketch after the commit list)
2024-03-16 16:31:01 +00:00
Ethan Cemer
56e2326be1 *nuke (#742) 2024-03-14 14:11:03 -05:00
Ethan Cemer
2be181db35 feat: merge @ezkljs/verify package into core repo. (#736) 2024-03-14 01:13:14 +00:00
jmjac
de9e3f2673 Add __version__ to python bindings (#739) 2024-03-13 14:22:20 +00:00
dante
a1450f8df7 feat: gather_nd/scatter_nd support (#737) 2024-03-11 22:05:40 +00:00
dante
ea535e2ecd refactor: use linear index constraints for gather and scatter (#735) 2024-03-09 18:00:21 +00:00
Alexander Camuto
f8aa91ed08 fix: windows compile 2024-03-06 11:40:44 +00:00
dante
a59e3780b2 chore: rm recip_int helper (#733) 2024-03-05 21:51:14 +00:00
dante
345fb5672a chore: cleanup unused args (#732) 2024-03-05 13:43:29 +00:00
dante
70daaff2e4 chore: cleanup calibrate (#731) 2024-03-04 17:52:11 +00:00
dante
a437d8a51f feat: "sub"-dynamic tables (#730) 2024-03-04 10:35:28 +00:00
Ethan Cemer
fe535c1cac feat: wasm felt to little endian string (#729)
---------

Co-authored-by: Alexander Camuto <45801863+alexander-camuto@users.noreply.github.com>
2024-03-01 14:06:20 +00:00
dante
3e8dcb001a chore: test for reduced-srs on wasm bundle (#728)
---------

Co-authored-by: Ethan <tylercemer@gmail.com>
2024-03-01 13:23:07 +00:00
dante
14786acb95 feat: dynamic lookups (#727) 2024-03-01 01:44:45 +00:00
dante
80a3c44cb4 feat: lookup-less recip by default (#725) 2024-02-28 16:35:20 +00:00
dante
1656846d1a fix: transcript should serialize as lc flag (#726) 2024-02-26 22:02:47 +00:00
dante
88098b8190 fix!: cleanup felt serialization language in python and wasm (#724)
BREAKING CHANGE: python and wasm felt utilities have new names
2024-02-25 14:06:48 +00:00
dante
6c0c17c9be fix: include tol check in fwd pass (#723) 2024-02-23 01:28:59 +00:00
dante
bf69b16fc1 fix: rm optional bool flags (#722) 2024-02-21 12:45:42 +00:00
dante
74feb829da feat: parse command ast into flag strings (#720) 2024-02-21 00:38:26 +00:00
dante
d429e7edab fix: buffer data read and writes (#719) 2024-02-19 11:49:15 +00:00
dante
f0e5b82787 refactor: selectable key ser (#718) 2024-02-19 11:26:18 +00:00
dante
3f7261f50b fix: set buf capacity for witness, settings, proof (#717) 2024-02-16 21:59:20 +00:00
dante
678a249dcb feat: allow for reduced n srs for verification (#716) 2024-02-16 18:28:54 +00:00
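The headline change in this range is `feat!: ipa commitments` (#740): the commitment scheme is no longer implicitly KZG but an explicit flag threaded through the CLI and bindings. A minimal sketch of what this looks like from the Python bindings, based on the notebook diffs below (`ezkl.PyCommitments.KZG` appears verbatim there; `ezkl.PyCommitments.IPA` is an assumed variant name inferred from the commit title):

import ezkl

# The commitment scheme is now an explicit argument (BREAKING CHANGE in #740).
# PyCommitments.KZG is taken from the updated notebooks in this diff;
# PyCommitments.IPA is an assumed alternative based on the commit title.
res = ezkl.get_srs(settings_path=None, logrows=21, commitment=ezkl.PyCommitments.KZG)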
97 changed files with 7779 additions and 3354 deletions

View File

@@ -1,4 +1,4 @@
name: Build and Publish WASM<>JS Bindings
name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
on:
workflow_dispatch:
@@ -14,7 +14,7 @@ defaults:
run:
working-directory: .
jobs:
wasm-publish:
publish-wasm-bindings:
name: publish-wasm-bindings
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
@@ -174,3 +174,40 @@ jobs:
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
in-browser-evm-ver-publish:
name: publish-in-browser-evm-verifier-package
needs: ["publish-wasm-bindings"]
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update @ezkljs/engine version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update the engine import in in-browser-evm-verifier to use the @ezkljs/engine package instead of the local one
run: |
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
- name: Set up Node.js
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
- name: Publish to npm
run: |
cd in-browser-evm-verifier
npm install
npm run build
npm ci
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
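The two `sed` steps above bump the package version and its `@ezkljs/engine` dependency to the release tag by regex-rewriting `package.json`. For local testing, a hypothetical Python equivalent that edits the JSON structurally instead of by regex (the `dependencies` key placement is an assumption; the workflow itself only does string replacement):

import json, os

# Hypothetical local stand-in for the workflow's sed-based version bump.
tag = os.environ["RELEASE_TAG"]  # supplied by github.ref_name in CI
path = "in-browser-evm-verifier/package.json"
with open(path) as f:
    pkg = json.load(f)
pkg["version"] = tag                           # mirrors the first sed
pkg["dependencies"]["@ezkljs/engine"] = tag    # mirrors the second sed (assumed key)
with open(path, "w") as f:
    json.dump(pkg, f, indent=2)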

View File

@@ -236,6 +236,8 @@ jobs:
with:
crate: cargo-nextest
locked: true
- name: public outputs and tolerance > 0
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10
run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 32
- name: kzg inputs
@@ -301,12 +303,24 @@ jobs:
with:
node-version: "18.12.1"
cache: "pnpm"
- name: Install dependencies
- name: Install dependencies for js tests and in-browser-evm-verifier package
run: |
pnpm install --no-frozen-lockfile
pnpm install --dir ./in-browser-evm-verifier --no-frozen-lockfile
env:
CI: false
NODE_ENV: development
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --release --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" in-browser-evm-verifier/nodejs/ezkl.js
- name: Build @ezkljs/verify package
run: |
cd in-browser-evm-verifier
pnpm build:commonjs
cd ..
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
@@ -362,7 +376,7 @@ jobs:
with:
node-version: "18.12.1"
cache: "pnpm"
- name: Install dependencies
- name: Install dependencies for js tests
run: |
pnpm install --no-frozen-lockfile
env:
@@ -378,6 +392,10 @@ jobs:
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
- name: IPA prove and verify tests
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_::t --test-threads 1
- name: IPA prove and verify tests (ipa outputs)
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_ipa_output
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w
- name: KZG prove and verify tests single inner col
@@ -438,28 +456,6 @@ jobs:
- name: KZG prove and verify tests (hashed outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
fuzz-tests:
runs-on: ubuntu-latest-32-cores
needs: [build, library-tests, python-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
- name: fuzz tests (EVM)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_fuzz_ --test-threads 2
# - name: fuzz tests
# run: cargo nextest run --release --verbose tests::kzg_fuzz_ --test-threads 6
prove-and-verify-mock-aggr-tests:
runs-on: self-hosted
@@ -475,7 +471,7 @@ jobs:
with:
crate: cargo-nextest
locked: true
- name: Mock aggr tests
- name: Mock aggr tests (KZG)
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
prove-and-verify-aggr-tests-gpu:
@@ -498,7 +494,7 @@ jobs:
prove-and-verify-aggr-tests:
runs-on: large-self-hosted
needs: [build, library-tests, python-tests]
needs: [build, library-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -510,12 +506,14 @@ jobs:
with:
crate: cargo-nextest
locked: true
- name: KZG )tests
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 8 -- --include-ignored
- name: KZG tests
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
prove-and-verify-aggr-evm-tests:
runs-on: large-self-hosted
needs: [build, library-tests, python-tests]
needs: [build, library-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -573,7 +571,7 @@ jobs:
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
- name: Build python ezkl
run: source .env/bin/activate; maturin develop --features python-bindings --release
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Run pytest
run: source .env/bin/activate; pytest -vv
@@ -597,7 +595,7 @@ jobs:
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; maturin develop --features python-bindings --release
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Div rebase
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_div_rebase_
- name: Public inputs
@@ -615,7 +613,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.9"
python-version: "3.10"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
@@ -632,7 +630,7 @@ jobs:
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; maturin develop --features python-bindings --release
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
# - name: authenticate-kaggle-cli
# shell: bash
# env:
@@ -643,12 +641,12 @@ jobs:
# echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
- name: All notebooks
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --no-capture
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
- name: Voice tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --no-capture
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_
# - name: Postgres tutorials
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1

1
.gitignore vendored
View File

@@ -45,6 +45,7 @@ var/
*.whl
*.bak
node_modules
/dist
timingData.json
!tests/wasm/pk.key
!tests/wasm/vk.key

166
Cargo.lock generated
View File

@@ -112,7 +112,7 @@ checksum = "c0391754c09fab4eae3404d19d0d297aa1c670c1775ab51d8a5312afeca23157"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -464,7 +464,7 @@ dependencies = [
"proc-macro2",
"quote",
"serde",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -490,7 +490,7 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -843,7 +843,7 @@ dependencies = [
"anstyle",
"bitflags 1.3.2",
"clap_lex",
"strsim 0.10.0",
"strsim",
]
[[package]]
@@ -855,7 +855,7 @@ dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -1191,41 +1191,6 @@ dependencies = [
"cuda-config",
]
[[package]]
name = "darling"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d706e75d87e35569db781a9b5e2416cff1236a47ed380831f959382ccd5f858"
dependencies = [
"darling_core",
"darling_macro",
]
[[package]]
name = "darling_core"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0c960ae2da4de88a91b2d920c2a7233b400bc33cb28453a2987822d8392519b"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim 0.9.3",
"syn 1.0.109",
]
[[package]]
name = "darling_macro"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b5a2f4ac4969822c62224815d069952656cadc7084fdca9751e6d959189b72"
dependencies = [
"darling_core",
"quote",
"syn 1.0.109",
]
[[package]]
name = "der"
version = "0.7.6"
@@ -1258,31 +1223,6 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "derive_builder"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2658621297f2cf68762a6f7dc0bb7e1ff2cfd6583daef8ee0fed6f7ec468ec0"
dependencies = [
"darling",
"derive_builder_core",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "derive_builder_core"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2791ea3e372c8495c0bc2033991d76b512cd799d07491fbd6890124db9458bef"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "derive_more"
version = "0.99.17"
@@ -1491,7 +1431,7 @@ checksum = "48016319042fb7c87b78d2993084a831793a897a5cd1a2a67cab9d1eeb4b7d76"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -1657,7 +1597,7 @@ dependencies = [
"regex",
"serde",
"serde_json",
"syn 2.0.48",
"syn 2.0.50",
"toml",
"walkdir",
]
@@ -1675,7 +1615,7 @@ dependencies = [
"proc-macro2",
"quote",
"serde_json",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -1701,7 +1641,7 @@ dependencies = [
"serde",
"serde_json",
"strum",
"syn 2.0.48",
"syn 2.0.50",
"tempfile",
"thiserror",
"tiny-keccak",
@@ -1902,6 +1842,7 @@ dependencies = [
"thiserror",
"tokio",
"tokio-util",
"tosubcommand",
"tract-onnx",
"unzip-n",
"wasm-bindgen",
@@ -2103,7 +2044,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -2262,7 +2203,7 @@ dependencies = [
[[package]]
name = "halo2_gadgets"
version = "0.2.0"
source = "git+https://github.com/zkonduit/halo2?branch=main#ca603c14eb57030739b252e580a979023fa59040"
source = "git+https://github.com/zkonduit/halo2?branch=main#4d7e6ddac661283e2b73c551b2e8f0011cedd50f"
dependencies = [
"arrayvec 0.7.4",
"bitvec 1.0.1",
@@ -2279,7 +2220,7 @@ dependencies = [
[[package]]
name = "halo2_proofs"
version = "0.3.0"
source = "git+https://github.com/zkonduit/halo2?branch=main#ca603c14eb57030739b252e580a979023fa59040"
source = "git+https://github.com/zkonduit/halo2?branch=main#4d7e6ddac661283e2b73c551b2e8f0011cedd50f"
dependencies = [
"blake2b_simd",
"env_logger",
@@ -2289,12 +2230,10 @@ dependencies = [
"icicle",
"log",
"maybe-rayon",
"plotters",
"rand_chacha",
"rand_core 0.6.4",
"rustacuda",
"sha3 0.9.1",
"tabbycat",
"tracing",
]
@@ -2623,12 +2562,6 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "idna"
version = "0.4.0"
@@ -2975,7 +2908,7 @@ checksum = "fc2fb41a9bb4257a3803154bdf7e2df7d45197d1941c9b1a90ad815231630721"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -3285,7 +3218,7 @@ dependencies = [
"proc-macro-crate",
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -3369,7 +3302,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -3570,7 +3503,7 @@ dependencies = [
"pest_meta",
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -3647,7 +3580,7 @@ dependencies = [
"phf_shared 0.11.2",
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -3685,7 +3618,7 @@ checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -3821,7 +3754,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9825a04601d60621feed79c4e6b56d65db77cdca55cef43b46b0de1096d1c282"
dependencies = [
"proc-macro2",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -4011,7 +3944,7 @@ dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -4023,7 +3956,7 @@ dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -4735,7 +4668,7 @@ checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -5013,12 +4946,6 @@ dependencies = [
"unicode-normalization",
]
[[package]]
name = "strsim"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
[[package]]
name = "strsim"
version = "0.10.0"
@@ -5044,7 +4971,7 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -5079,9 +5006,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.48"
version = "2.0.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f"
checksum = "74f1bdc9872430ce9b75da68329d1c1746faf50ffac5f19e02b71e37ff881ffb"
dependencies = [
"proc-macro2",
"quote",
@@ -5109,17 +5036,6 @@ dependencies = [
"libc",
]
[[package]]
name = "tabbycat"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c45590f0f859197b4545be1b17b2bc3cc7bb075f7d1cc0ea1dc6521c0bf256a3"
dependencies = [
"anyhow",
"derive_builder",
"regex",
]
[[package]]
name = "tabled"
version = "0.12.2"
@@ -5259,7 +5175,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -5348,7 +5264,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -5444,6 +5360,24 @@ dependencies = [
"winnow 0.5.39",
]
[[package]]
name = "tosubcommand"
version = "0.1.0"
source = "git+https://github.com/zkonduit/enum_to_subcommand#42e9870f1f757932bab64ab30ebf1ff08a392265"
dependencies = [
"tosubcommand_derive",
]
[[package]]
name = "tosubcommand_derive"
version = "0.1.0"
source = "git+https://github.com/zkonduit/enum_to_subcommand#42e9870f1f757932bab64ab30ebf1ff08a392265"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.50",
]
[[package]]
name = "tower-service"
version = "0.3.2"
@@ -5470,7 +5404,7 @@ checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]
[[package]]
@@ -5854,7 +5788,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
"wasm-bindgen-shared",
]
@@ -5898,7 +5832,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
@@ -6309,5 +6243,5 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.50",
]

View File

@@ -15,71 +15,96 @@ crate-type = ["cdylib", "rlib"]
[dependencies]
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "main" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "main" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev="9fff22c", features=["derive_serde"] }
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "main" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch = "main" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "9fff22c", features = [
"derive_serde",
] }
rand = { version = "0.8", default_features = false }
itertools = { version = "0.10.3", default_features = false }
clap = { version = "4.3.3", features = ["derive"]}
clap = { version = "4.3.3", features = ["derive"] }
serde = { version = "1.0.126", features = ["derive"], optional = true }
serde_json = { version = "1.0.97", default_features = false, features = ["float_roundtrip", "raw_value"], optional = true }
serde_json = { version = "1.0.97", default_features = false, features = [
"float_roundtrip",
"raw_value",
], optional = true }
log = { version = "0.4.17", default_features = false, optional = true }
thiserror = { version = "1.0.38", default_features = false }
hex = { version = "0.4.3", default_features = false }
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features=["derive_serde"]}
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch= "main" }
maybe-rayon = { version = "0.1.1", default_features = false }
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
"derive_serde",
] }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "main" }
maybe-rayon = { version = "0.1.1", default_features = false }
bincode = { version = "1.3.3", default_features = false }
ark-std = { version = "^0.3.0", default-features = false }
unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = "1.6.0"
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
# evm related deps
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
ethers = { version = "2.0.11", default_features = false, features = ["ethers-solc"] }
indicatif = {version = "0.17.5", features = ["rayon"]}
gag = { version = "1.0.0", default_features = false}
ethers = { version = "2.0.11", default_features = false, features = [
"ethers-solc",
] }
indicatif = { version = "0.17.5", features = ["rayon"] }
gag = { version = "1.0.0", default_features = false }
instant = { version = "0.1" }
reqwest = { version = "0.11.14", default-features = false, features = ["default-tls", "multipart", "stream"] }
reqwest = { version = "0.11.14", default-features = false, features = [
"default-tls",
"multipart",
"stream",
] }
openssl = { version = "0.10.55", features = ["vendored"] }
postgres = "0.19.5"
pg_bigdecimal = "0.1.5"
lazy_static = "1.4.0"
colored_json = { version = "3.0.1", default_features = false, optional = true}
colored_json = { version = "3.0.1", default_features = false, optional = true }
plotters = { version = "0.3.0", default_features = false, optional = true }
regex = { version = "1", default_features = false }
tokio = { version = "1.26.0", default_features = false, features = ["macros", "rt"] }
tokio = { version = "1.26.0", default_features = false, features = [
"macros",
"rt",
] }
tokio-util = { version = "0.7.9", features = ["codec"] }
pyo3 = { version = "0.20.2", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
pyo3-asyncio = { version = "0.20.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
pyo3 = { version = "0.20.2", features = [
"extension-module",
"abi3-py37",
"macros",
], default_features = false, optional = true }
pyo3-asyncio = { version = "0.20.0", features = [
"attributes",
"tokio-runtime",
], default_features = false, optional = true }
pyo3-log = { version = "0.9.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
colored = { version = "2.0.0", default_features = false, optional = true}
env_logger = { version = "0.10.0", default_features = false, optional = true}
colored = { version = "2.0.0", default_features = false, optional = true }
env_logger = { version = "0.10.0", default_features = false, optional = true }
chrono = "0.4.31"
sha256 = "1.4.0"
[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2.8", features = ["js"] }
instant = { version = "0.1", features = [ "wasm-bindgen", "inaccurate" ] }
instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] }
[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies]
wasm-bindgen-rayon = { version = "1.0", optional=true }
wasm-bindgen-rayon = { version = "1.0", optional = true }
wasm-bindgen-test = "0.3.34"
serde-wasm-bindgen = "0.4"
wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"]}
wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"] }
console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"
[dev-dependencies]
criterion = {version = "0.3", features = ["html_reports"]}
criterion = { version = "0.3", features = ["html_reports"] }
tempfile = "3.3.0"
lazy_static = "1.4.0"
mnist = "0.5"
@@ -151,18 +176,32 @@ required-features = ["ezkl"]
[features]
web = ["wasm-bindgen-rayon"]
default = ["ezkl", "mv-lookup"]
render = ["halo2_proofs/dev-graph", "plotters"]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
ezkl = ["onnx", "serde", "serde_json", "log", "colored", "env_logger", "tabled/color", "colored_json", "halo2_proofs/circuit-params"]
mv-lookup = ["halo2_proofs/mv-lookup", "snark-verifier/mv-lookup", "halo2_solidity_verifier/mv-lookup"]
ezkl = [
"onnx",
"serde",
"serde_json",
"log",
"colored",
"env_logger",
"tabled/color",
"colored_json",
"halo2_proofs/circuit-params",
]
mv-lookup = [
"halo2_proofs/mv-lookup",
"snark-verifier/mv-lookup",
"halo2_solidity_verifier/mv-lookup",
]
det-prove = []
icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
no-banner = []
# icicle patch to 0.1.0 if feature icicle is enabled
[patch.'https://github.com/ingonyama-zk/icicle']
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix"}
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
[profile.release]
rustflags = [ "-C", "relocation-model=pic" ]
rustflags = ["-C", "relocation-model=pic"]

View File

@@ -74,6 +74,10 @@ For more details visit the [docs](https://docs.ezkl.xyz).
Build the auto-generated rust documentation and open the docs in your browser locally. `cargo doc --open`
#### In-browser EVM verifier
As an alternative to running the native Halo2 verifier as a WASM binding in the browser, you can use the in-browser EVM verifier. Its source code lives in the `in-browser-evm-verifier` directory, alongside a README with instructions on how to use it.
### building the project 🔨

View File

@@ -2,11 +2,13 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::srs::gen_srs;
use ezkl::pfsys::TranscriptType;
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
@@ -15,6 +17,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
static mut KERNEL_HEIGHT: usize = 2;
static mut KERNEL_WIDTH: usize = 2;
@@ -121,28 +124,35 @@ fn runcnvrl(c: &mut Criterion) {
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(
&circuit, &params, true,
)
.unwrap();
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();

View File

@@ -1,11 +1,13 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
@@ -14,6 +16,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
static mut LEN: usize = 4;
@@ -90,25 +93,35 @@ fn rundot(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();

View File

@@ -1,11 +1,13 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
@@ -14,6 +16,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
static mut LEN: usize = 4;
@@ -94,25 +97,35 @@ fn runmatmul(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();

View File

@@ -4,17 +4,20 @@ use ezkl::circuit::*;
use ezkl::circuit::lookup::LookupOp;
use ezkl::circuit::poly::PolyOp;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
const BITS: Range = (-32768, 32768);
@@ -112,25 +115,35 @@ fn runmatmul(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::SAFE,
None,
None,
);
prover.unwrap();

View File

@@ -4,17 +4,20 @@ use ezkl::circuit::*;
use ezkl::circuit::lookup::LookupOp;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::table::Range;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
const BITS: Range = (-8180, 8180);
@@ -115,25 +118,35 @@ fn runmatmul(c: &mut Criterion) {
group.throughput(Throughput::Elements(k as u64));
group.bench_with_input(BenchmarkId::new("pk", k), &k, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(k as u64));
group.bench_with_input(BenchmarkId::new("prove", k), &k, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::SAFE,
None,
None,
);
prover.unwrap();

View File

@@ -1,11 +1,13 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
@@ -14,6 +16,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
static mut LEN: usize = 4;
@@ -86,25 +89,35 @@ fn runsum(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();

View File

@@ -2,11 +2,13 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::hybrid::HybridOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::srs::gen_srs;
use ezkl::pfsys::TranscriptType;
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
@@ -15,6 +17,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
static mut IMAGE_HEIGHT: usize = 2;
static mut IMAGE_WIDTH: usize = 2;
@@ -101,28 +104,35 @@ fn runsumpool(c: &mut Criterion) {
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(
&circuit, &params, true,
)
.unwrap();
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();

View File

@@ -1,11 +1,13 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
@@ -14,6 +16,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
static mut LEN: usize = 4;
@@ -84,25 +87,35 @@ fn runadd(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::SAFE,
None,
None,
);
prover.unwrap();

View File

@@ -2,11 +2,13 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
@@ -15,6 +17,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
static mut LEN: usize = 4;
@@ -83,25 +86,35 @@ fn runpow(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::SAFE,
None,
None,
);
prover.unwrap();

View File

@@ -4,12 +4,13 @@ use ezkl::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
use ezkl::circuit::modules::Module;
use ezkl::circuit::*;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::srs::gen_srs;
use ezkl::pfsys::TranscriptType;
use ezkl::tensor::*;
use halo2_proofs::circuit::Value;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
@@ -18,6 +19,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
const L: usize = 10;
@@ -62,7 +64,7 @@ fn runposeidon(c: &mut Criterion) {
let params = gen_srs::<KZGCommitmentScheme<_>>(k);
let message = (0..*size).map(|_| Fr::random(OsRng)).collect::<Vec<_>>();
let output =
let _output =
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L>::run(message.to_vec())
.unwrap();
@@ -76,25 +78,35 @@ fn runposeidon(c: &mut Criterion) {
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
Some(output[0].clone()),
&pk,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();

View File

@@ -2,11 +2,12 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::table::Range;
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
@@ -14,6 +15,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::Rng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
const BITS: Range = (-32768, 32768);
static mut LEN: usize = 4;
@@ -91,25 +93,35 @@ fn runrelu(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
NLCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::SAFE,
None,
None,
);
prover.unwrap();

View File

@@ -6,6 +6,7 @@ use ezkl::fieldutils;
use ezkl::fieldutils::i32_to_felt;
use ezkl::tensor::*;
use halo2_proofs::dev::MockProver;
use halo2_proofs::poly::commitment::Params;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
@@ -489,6 +490,7 @@ pub fn runconv() {
strategy,
pi_for_real_prover,
&mut transcript,
params.n(),
);
assert!(verify.is_ok());

View File

@@ -309,7 +309,7 @@
"metadata": {},
"outputs": [],
"source": [
"print(ezkl.string_to_felt(res['processed_outputs']['poseidon_hash'][0]))"
"print(ezkl.felt_to_big_endian(res['processed_outputs']['poseidon_hash'][0]))"
]
},
{
@@ -325,7 +325,7 @@
"metadata": {},
"outputs": [],
"source": [
"from web3 import Web3, HTTPProvider, utils\n",
"from web3 import Web3, HTTPProvider\n",
"from solcx import compile_standard\n",
"from decimal import Decimal\n",
"import json\n",
@@ -338,7 +338,7 @@
"\n",
"def test_on_chain_data(res):\n",
" # Step 0: Convert the tensor to a flat list\n",
" data = [int(ezkl.string_to_felt(res['processed_outputs']['poseidon_hash'][0]), 0)]\n",
" data = [int(ezkl.felt_to_big_endian(res['processed_outputs']['poseidon_hash'][0]), 0)]\n",
"\n",
" # Step 1: Prepare the data\n",
" # Step 2: Prepare and compile the contract.\n",
@@ -648,7 +648,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
},
"orig_nbformat": 4
},
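The cell changes above come from `fix!: cleanup felt serialization language` (#724): the old `ezkl.string_to_felt` / `ezkl.float_to_string` names are gone, and the direction of each conversion is now explicit in the name. A minimal sketch of the renamed utilities as the notebooks use them (the felt value is a placeholder; in the notebook it comes from `res['processed_outputs']['poseidon_hash'][0]`):

import ezkl

# Renamed felt utilities (#724); names now state the conversion direction.
felt = "0x0000000000000000000000000000000000000000000000000000000000000001"  # placeholder
print(ezkl.felt_to_big_endian(felt))              # was ezkl.string_to_felt(felt)
onchain = int(ezkl.felt_to_big_endian(felt), 0)   # integer form for on-chain data
hashed = ezkl.poseidon_hash([ezkl.float_to_felt(0.25, 7)])[0]  # was float_to_string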

View File

@@ -695,7 +695,7 @@
"formatted_output = \"[\"\n",
"for i, value in enumerate(proof[\"instances\"]):\n",
" for j, field_element in enumerate(value):\n",
" onchain_input_array.append(ezkl.string_to_felt(field_element))\n",
" onchain_input_array.append(ezkl.felt_to_big_endian(field_element))\n",
" formatted_output += str(onchain_input_array[-1])\n",
" if j != len(value) - 1:\n",
" formatted_output += \", \"\n",
@@ -705,7 +705,7 @@
"# copy them over to remix and see if they verify\n",
"# What happens when you change a value?\n",
"print(\"pubInputs: \", formatted_output)\n",
"print(\"proof: \", \"0x\" + proof[\"proof\"])"
"print(\"proof: \", proof[\"proof\"])"
]
},
{

View File

@@ -10,7 +10,7 @@
"\n",
"## Generalized Inverse\n",
"\n",
"We show how to use EZKL to prove that we know matrices $A$ and its generalized inverse $B$. Since these are large we deal with the KZG commitments, with $a$ the kzgcommit of $A$, $b$ the kzgcommit of $B$, and $ABA = A$.\n"
"We show how to use EZKL to prove that we know matrices $A$ and its generalized inverse $B$. Since these are large we deal with the KZG commitments, with $a$ the polycommit of $A$, $b$ the polycommit of $B$, and $ABA = A$.\n"
]
},
{
@@ -77,7 +77,7 @@
"outputs": [],
"source": [
"gip_run_args = ezkl.PyRunArgs()\n",
"gip_run_args.input_visibility = \"kzgcommit\" # matrix and generalized inverse commitments\n",
"gip_run_args.input_visibility = \"polycommit\" # matrix and generalized inverse commitments\n",
"gip_run_args.output_visibility = \"fixed\" # no parameters used\n",
"gip_run_args.param_visibility = \"fixed\" # should be Tensor(True)"
]
@@ -340,4 +340,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -161,7 +161,7 @@
"- `fixed`: known to the prover and verifier (as a commit), but not modifiable by the prover.\n",
"- `hashed`: the hash pre-image is known to the prover, the prover and verifier know the hash. The prover proves that the they know the pre-image to the hash. \n",
"- `encrypted`: the non-encrypted element and the secret key used for decryption are known to the prover. The prover and the verifier know the encrypted element, the public key used to encrypt, and the hash of the decryption hey. The prover proves that they know the pre-image of the hashed decryption key and that this key can in fact decrypt the encrypted message.\n",
"- `kzgcommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be inserted directly within the proof bytes. \n",
"- `polycommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be inserted directly within the proof bytes. \n",
"\n",
"\n",
"Here we create the following setup:\n",
@@ -510,4 +510,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -154,11 +154,11 @@
"- `fixed`: known to the prover and verifier (as a commit), but not modifiable by the prover.\n",
"- `hashed`: the hash pre-image is known to the prover, the prover and verifier know the hash. The prover proves that the they know the pre-image to the hash. \n",
"- `encrypted`: the non-encrypted element and the secret key used for decryption are known to the prover. The prover and the verifier know the encrypted element, the public key used to encrypt, and the hash of the decryption hey. The prover proves that they know the pre-image of the hashed decryption key and that this key can in fact decrypt the encrypted message.\n",
"- `kzgcommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
"- `polycommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"kzgcommit\"\n",
"- `param_visibility`: \"kzgcommit\"\n",
"- `input_visibility`: \"polycommit\"\n",
"- `param_visibility`: \"polycommit\"\n",
"- `output_visibility`: public\n",
"\n",
"We encourage you to play around with other setups :) \n",
@@ -186,8 +186,8 @@
"data_path = os.path.join('input.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"kzgcommit\"\n",
"run_args.param_visibility = \"kzgcommit\"\n",
"run_args.input_visibility = \"polycommit\"\n",
"run_args.param_visibility = \"polycommit\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"\n",
@@ -512,4 +512,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
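Condensing the `kzgcommit` to `polycommit` rename, a sketch of the visibility setup these cells construct (argument names and values are taken directly from the cells above):

import ezkl

# Visibility setup after the kzgcommit -> polycommit rename.
run_args = ezkl.PyRunArgs()
run_args.input_visibility = "polycommit"   # committed, not exposed as a public instance
run_args.param_visibility = "polycommit"
run_args.output_visibility = "public"
run_args.variables = [("batch_size", 1)]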
}

View File

@@ -264,9 +264,9 @@
"### KZG commitment intermediate calculations\n",
"\n",
"the visibility parameters are:\n",
"- `input_visibility`: \"kzgcommit\"\n",
"- `input_visibility`: \"polycommit\"\n",
"- `param_visibility`: \"public\"\n",
"- `output_visibility`: kzgcommit"
"- `output_visibility`: polycommit"
]
},
{
@@ -280,15 +280,15 @@
"srs_path = os.path.join('kzg.srs')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"kzgcommit\"\n",
"run_args.input_visibility = \"polycommit\"\n",
"run_args.param_visibility = \"fixed\"\n",
"run_args.output_visibility = \"kzgcommit\"\n",
"run_args.output_visibility = \"polycommit\"\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"run_args.input_scale = 0\n",
"run_args.param_scale = 0\n",
"run_args.logrows = 18\n",
"\n",
"ezkl.get_srs(logrows=run_args.logrows)\n"
"ezkl.get_srs(logrows=run_args.logrows, commitment=ezkl.PyCommitments.KZG)\n"
]
},
{
@@ -343,7 +343,6 @@
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" compress_selectors=True,\n",
" )\n",
"\n",
" assert res == True\n",

View File

@@ -208,7 +208,7 @@
"- `private`: known only to the prover\n",
"- `hashed`: the hash pre-image is known to the prover, the prover and verifier know the hash. The prover proves that the they know the pre-image to the hash. \n",
"- `encrypted`: the non-encrypted element and the secret key used for decryption are known to the prover. The prover and the verifier know the encrypted element, the public key used to encrypt, and the hash of the decryption hey. The prover proves that they know the pre-image of the hashed decryption key and that this key can in fact decrypt the encrypted message.\n",
"- `kzgcommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
"- `polycommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"public\"\n",
@@ -234,7 +234,7 @@
"run_args.input_scale = 2\n",
"run_args.logrows = 8\n",
"\n",
"ezkl.get_srs(logrows=run_args.logrows)"
"ezkl.get_srs(logrows=run_args.logrows, commitment=ezkl.PyCommitments.KZG)"
]
},
{
@@ -385,9 +385,9 @@
"### KZG commitment intermediate calculations\n",
"\n",
"This time the visibility parameters are:\n",
"- `input_visibility`: \"kzgcommit\"\n",
"- `input_visibility`: \"polycommit\"\n",
"- `param_visibility`: \"public\"\n",
"- `output_visibility`: kzgcommit"
"- `output_visibility`: polycommit"
]
},
{
@@ -399,9 +399,9 @@
"import ezkl\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"kzgcommit\"\n",
"run_args.input_visibility = \"polycommit\"\n",
"run_args.param_visibility = \"fixed\"\n",
"run_args.output_visibility = \"kzgcommit\"\n",
"run_args.output_visibility = \"polycommit\"\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"run_args.input_scale = 2\n",
"run_args.logrows = 8\n"


@@ -122,8 +122,8 @@
"# Loop through each element in the y tensor\n",
"for e in y_input:\n",
" # Apply the custom function and append the result to the list\n",
" print(ezkl.float_to_string(e,7))\n",
" result.append(ezkl.poseidon_hash([ezkl.float_to_string(e, 7)])[0])\n",
" print(ezkl.float_to_felt(e,7))\n",
" result.append(ezkl.poseidon_hash([ezkl.float_to_felt(e, 7)])[0])\n",
"\n",
"y = y.unsqueeze(0)\n",
"y = y.reshape(1, 9)\n",


@@ -275,7 +275,6 @@
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
@@ -291,7 +290,7 @@
"source": [
"# Generate a larger SRS. This is needed for the aggregated proof\n",
"\n",
"res = ezkl.get_srs(settings_path=None, logrows=21)"
"res = ezkl.get_srs(settings_path=None, logrows=21, commitment=ezkl.PyCommitments.KZG)"
]
},
{


@@ -9,7 +9,7 @@
"source": [
"## Solvency demo\n",
"\n",
"Here we create a demo of a solvency calculation in the manner of [summa-solvency](https://github.com/summa-dev/summa-solvency). The aim here is to demonstrate the use of the new kzgcommit method detailed [here](https://blog.ezkl.xyz/post/commits/). \n",
"Here we create a demo of a solvency calculation in the manner of [summa-solvency](https://github.com/summa-dev/summa-solvency). The aim here is to demonstrate the use of the new polycommit method detailed [here](https://blog.ezkl.xyz/post/commits/). \n",
"\n",
"In this setup:\n",
"- the commitments to users, respective balances, and total balance are known are publicly known to the prover and verifier. \n",
@@ -126,7 +126,7 @@
"# Loop through each element in the y tensor\n",
"for e in user_preimages:\n",
" # Apply the custom function and append the result to the list\n",
" users.append(ezkl.poseidon_hash([ezkl.float_to_string(e, 0)])[0])\n",
" users.append(ezkl.poseidon_hash([ezkl.float_to_felt(e, 0)])[0])\n",
"\n",
"users_t = torch.tensor(user_preimages)\n",
"users_t = users_t.reshape(1, 6)\n",
@@ -177,10 +177,10 @@
"- `private`: known only to the prover\n",
"- `hashed`: the hash pre-image is known to the prover, the prover and verifier know the hash. The prover proves that the they know the pre-image to the hash. \n",
"- `encrypted`: the non-encrypted element and the secret key used for decryption are known to the prover. The prover and the verifier know the encrypted element, the public key used to encrypt, and the hash of the decryption hey. The prover proves that they know the pre-image of the hashed decryption key and that this key can in fact decrypt the encrypted message.\n",
"- `kzgcommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
"- `polycommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"kzgcommit\"\n",
"- `input_visibility`: \"polycommit\"\n",
"- `param_visibility`: \"public\"\n",
"- `output_visibility`: public\n",
"\n",
@@ -202,8 +202,8 @@
"outputs": [],
"source": [
"run_args = ezkl.PyRunArgs()\n",
"# \"kzgcommit\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
"run_args.input_visibility = \"kzgcommit\"\n",
"# \"polycommit\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
"run_args.input_visibility = \"polycommit\"\n",
"# the parameters are public\n",
"run_args.param_visibility = \"fixed\"\n",
"# the output is public (this is the inequality test)\n",
@@ -303,7 +303,7 @@
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",
"# this means that the output is fixed and the verifier can see it but that if the input is not in the set the output will not be 0 and the verifier will reject\n",
"witness = json.load(open(witness_path, \"r\"))\n",
"witness[\"outputs\"][0] = [ezkl.float_to_string(1.0, 0)]\n",
"witness[\"outputs\"][0] = [ezkl.float_to_felt(1.0, 0)]\n",
"json.dump(witness, open(witness_path, \"w\"))"
]
},
@@ -417,7 +417,7 @@
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",
"# this means that the output is fixed and the verifier can see it but that if the input is not in the set the output will not be 0 and the verifier will reject\n",
"witness = json.load(open(witness_path, \"r\"))\n",
"witness[\"outputs\"][0] = [ezkl.float_to_string(1.0, 0)]\n",
"witness[\"outputs\"][0] = [ezkl.float_to_felt(1.0, 0)]\n",
"json.dump(witness, open(witness_path, \"w\"))\n"
]
},
@@ -510,9 +510,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}


@@ -633,7 +633,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [11])"
]
},
{
@@ -664,7 +664,6 @@
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
")"
]
},
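`calibrate_settings` searches over whatever candidate scales it is given, so instead of pinning a single value as above you can hand it a range; a hedged sketch reusing the paths defined earlier in this notebook:

```python
import ezkl

# Sweep several fixed-point scales and let calibration pick the best one for
# the "resources" target; cal_path, model_path and settings_path are the
# variables defined earlier in the notebook.
ezkl.calibrate_settings(cal_path, model_path, settings_path,
                        "resources", scales=list(range(4, 12)))
```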


@@ -503,11 +503,11 @@
"pyplot.arrow(0, 0, 1, 0, width=0.02, alpha=0.5)\n",
"pyplot.arrow(0, 0, 0, 1, width=0.02, alpha=0.5)\n",
"\n",
"arrow_x = ezkl.string_to_float(witness['outputs'][0][0], out_scale)\n",
"arrow_y = ezkl.string_to_float(witness['outputs'][0][1], out_scale)\n",
"arrow_x = ezkl.felt_to_float(witness['outputs'][0][0], out_scale)\n",
"arrow_y = ezkl.felt_to_float(witness['outputs'][0][1], out_scale)\n",
"pyplot.arrow(0, 0, arrow_x, arrow_y, width=0.02)\n",
"arrow_x = ezkl.string_to_float(witness['outputs'][0][2], out_scale)\n",
"arrow_y = ezkl.string_to_float(witness['outputs'][0][3], out_scale)\n",
"arrow_x = ezkl.felt_to_float(witness['outputs'][0][2], out_scale)\n",
"arrow_y = ezkl.felt_to_float(witness['outputs'][0][3], out_scale)\n",
"pyplot.arrow(0, 0, arrow_x, arrow_y, width=0.02)"
]
}


@@ -0,0 +1,48 @@
import json
import numpy as np
import tf2onnx
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
# gather_nd in tf then export to onnx
x = in1 = Input((15, 18,))
w = in2 = Input((15, 1), dtype=tf.int32)
x = tf.gather_nd(x, w, batch_dims=1)
tm = Model((in1, in2), x )
tm.summary()
tm.compile(optimizer='adam', loss='mse')
shape = [1, 15, 18]
index_shape = [1, 15, 1]
# After training, export to onnx (network.onnx) and create a data file (input.json)
x = 0.1*np.random.rand(*shape)  # shape already includes the batch dim
# w = random int tensor
w = np.random.randint(0, 10, index_shape)
spec = tf.TensorSpec(shape, tf.float32, name='input_0')
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')
model_path = "network.onnx"
tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)
d = x.reshape([-1]).tolist()
d1 = w.reshape([-1]).tolist()
data = dict(
input_data=[d, d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

File diff suppressed because one or more lines are too long

Binary file not shown.


@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import sys
import json
sys.path.append("..")
class Model(nn.Module):
"""
Just one Linear layer
"""
def __init__(self, configs):
super(Model, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
# Use this line if you want to visualize the weights
# self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
self.channels = configs.enc_in
self.individual = configs.individual
if self.individual:
self.Linear = nn.ModuleList()
for i in range(self.channels):
self.Linear.append(nn.Linear(self.seq_len,self.pred_len))
else:
self.Linear = nn.Linear(self.seq_len, self.pred_len)
def forward(self, x):
# x: [Batch, Input length, Channel]
if self.individual:
output = torch.zeros([x.size(0),self.pred_len,x.size(2)],dtype=x.dtype).to(x.device)
for i in range(self.channels):
output[:,:,i] = self.Linear[i](x[:,:,i])
x = output
else:
x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
return x # [Batch, Output length, Channel]
class Configs:
def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
self.seq_len = seq_len
self.pred_len = pred_len
self.enc_in = enc_in
self.individual = individual
model = 'Linear'
seq_len = 10
pred_len = 4
enc_in = 3
configs = Configs(seq_len, pred_len, enc_in, True)
circuit = Model(configs)
x = torch.randn(1, seq_len, enc_in)  # [Batch, Input length, Channel]
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=15, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
# the model's input names
input_names=['input'],
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))


@@ -0,0 +1 @@
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}

Binary file not shown.


@@ -0,0 +1,60 @@
# inbrowser-evm-verify
We would like the Solidity verifier to be canonical and, in most cases, all you ever need. For that, we need to be able to run the same verifier in the browser.
## How to use (Node js)
```ts
import localEVMVerify from '@ezkljs/verify';
// Load in the proof file as a buffer
const proofFileBuffer = fs.readFileSync(`${path}/${example}/proof.pf`)
// Stringified EZKL evm verifier bytecode (this is just an example don't use in production)
const bytecode = '0x608060405234801561001057600080fd5b5060d38061001f6000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c8063cfae321714610046575b600080fd5b6100496100f1565b60405161005691906100f1565b60405180910390f35b'
const result = await localEVMVerify(proofFileBuffer, bytecode)
console.log('result', result)
```
**Note**: Run `ezkl create-evm-verifier` to get the Solidity verifier; once compiled, you can retrieve its bytecode. We recommend compiling for the Shanghai hardfork target; otherwise you will have to pass an additional parameter specifying the EVM version to the `localEVMVerify` function, like so (for the Paris hardfork — note the EVM version is the fourth argument, after the optional vk bytecode):
```ts
import localEVMVerify, { hardfork } from '@ezkljs/verify';
const result = await localEVMVerify(proofFileBuffer, bytecode, undefined, hardfork['Paris'])
```
**Note**: You can also verify proofs against a verifier with a separated vk using the `localEVMVerify` function. Just pass the vk contract bytecode as the third parameter, like so:
```ts
import localEVMVerify from '@ezkljs/verify';
const result = await localEVMVerify(proofFileBuffer, verifierBytecode, VKBytecode)
```
## How to use (Browser)
```ts
import localEVMVerify from '@ezkljs/verify';
// Load in the proof file as a buffer using the web apis (fetch, FileReader, etc)
// We use fetch in this example to load the proof file as a buffer
const proofFileBuffer = await fetch(`${path}/${example}/proof.pf`).then(res => res.arrayBuffer())
// Stringified EZKL evm verifier bytecode (this is just an example don't use in production)
const bytecode = '0x608060405234801561001057600080fd5b5060d38061001f6000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c8063cfae321714610046575b600080fd5b6100496100f1565b60405161005691906100f1565b60405180910390f35b'
const result = await localEVMVerify(proofFileBuffer, bytecode)
console.log('result', result)
```
Output:
```ts
result: true
```


@@ -0,0 +1,42 @@
{
"name": "@ezkljs/verify",
"version": "0.0.0",
"publishConfig": {
"access": "public"
},
"description": "Evm verify EZKL proofs in the browser.",
"main": "dist/commonjs/index.js",
"module": "dist/esm/index.js",
"types": "dist/commonjs/index.d.ts",
"files": [
"dist",
"LICENSE",
"README.md"
],
"scripts": {
"clean": "rm -r dist || true",
"build:commonjs": "tsc --project tsconfig.commonjs.json && resolve-tspaths -p tsconfig.commonjs.json",
"build:esm": "tsc --project tsconfig.esm.json && resolve-tspaths -p tsconfig.esm.json",
"build": "pnpm run clean && pnpm run build:commonjs && pnpm run build:esm"
},
"dependencies": {
"@ethereumjs/common": "^4.0.0",
"@ethereumjs/evm": "^2.0.0",
"@ethereumjs/statemanager": "^2.0.0",
"@ethereumjs/tx": "^5.0.0",
"@ethereumjs/util": "^9.0.0",
"@ethereumjs/vm": "^7.0.0",
"@ethersproject/abi": "^5.7.0",
"@ezkljs/engine": "^9.4.4",
"ethers": "^6.7.1",
"json-bigint": "^1.0.0"
},
"devDependencies": {
"@types/node": "^20.8.3",
"ts-loader": "^9.5.0",
"ts-node": "^10.9.1",
"resolve-tspaths": "^0.8.16",
"tsconfig-paths": "^4.2.0",
"typescript": "^5.2.2"
}
}

in-browser-evm-verifier/pnpm-lock.yaml generated

File diff suppressed because it is too large


@@ -0,0 +1,145 @@
import { defaultAbiCoder as AbiCoder } from '@ethersproject/abi'
import { Address, hexToBytes } from '@ethereumjs/util'
import { Chain, Common, Hardfork } from '@ethereumjs/common'
import { LegacyTransaction, LegacyTxData } from '@ethereumjs/tx'
// import { DefaultStateManager } from '@ethereumjs/statemanager'
// import { Blockchain } from '@ethereumjs/blockchain'
import { VM } from '@ethereumjs/vm'
import { EVM } from '@ethereumjs/evm'
import { buildTransaction, encodeDeployment } from './utils/tx-builder'
import { getAccountNonce, insertAccount } from './utils/account-utils'
import { encodeVerifierCalldata } from '../nodejs/ezkl';
import { error } from 'console'
async function deployContract(
vm: VM,
common: Common,
senderPrivateKey: Uint8Array,
deploymentBytecode: string
): Promise<Address> {
// Contracts are deployed by sending their deployment bytecode to the address 0
// The contract params should be abi-encoded and appended to the deployment bytecode.
// const data =
const data = encodeDeployment(deploymentBytecode)
const txData = {
data,
nonce: await getAccountNonce(vm, senderPrivateKey),
}
const tx = LegacyTransaction.fromTxData(
buildTransaction(txData) as LegacyTxData,
{ common, allowUnlimitedInitCodeSize: true },
).sign(senderPrivateKey)
const deploymentResult = await vm.runTx({
tx,
skipBlockGasLimitValidation: true,
skipNonce: true
})
if (deploymentResult.execResult.exceptionError) {
throw deploymentResult.execResult.exceptionError
}
return deploymentResult.createdAddress!
}
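/**
 * Runs the proof through the deployed verifier and decodes the returned bool.
 * The proof (and, when given, the vk contract address) is packed into calldata
 * via encodeVerifierCalldata before the call is executed on the ephemeral EVM.
 */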
async function verify(
vm: VM,
contractAddress: Address,
caller: Address,
proof: Uint8Array | Uint8ClampedArray,
vkAddress?: Address | Uint8Array,
): Promise<boolean> {
if (proof instanceof Uint8Array) {
proof = new Uint8ClampedArray(proof.buffer)
}
if (vkAddress) {
const vkAddressBytes = hexToBytes(vkAddress.toString())
const vkAddressArray = Array.from(vkAddressBytes)
// Serialize the vk address bytes as a JSON array string, then encode that
// string to bytes in the form encodeVerifierCalldata expects.
let string = JSON.stringify(vkAddressArray)
const uint8Array = new TextEncoder().encode(string);
vkAddress = new Uint8Array(uint8Array.buffer);
error('vkAddress', vkAddress) // debug log via node's console.error
}
const data = encodeVerifierCalldata(proof, vkAddress)
const verifyResult = await vm.evm.runCall({
to: contractAddress,
caller: caller,
origin: caller, // The tx.origin is also the caller here
data: data,
})
if (verifyResult.execResult.exceptionError) {
throw verifyResult.execResult.exceptionError
}
const results = AbiCoder.decode(['bool'], verifyResult.execResult.returnValue)
return results[0]
}
/**
* Spins up an ephemeral EVM instance for executing the bytecode of a solidity verifier
* @param proof Json serialized proof file
* @param bytecode_verifier The bytecode of a compiled solidity verifier.
* @param bytecode_vk The bytecode of a contract that stores the vk. (Optional, only required if the vk is stored in a separate contract)
* @param evmVersion The evm version to use for the verification. (Default: Shanghai)
* @returns The result of the evm verification.
* @throws If the verify transaction reverts
*/
export default async function localEVMVerify(
proof: Uint8Array | Uint8ClampedArray,
bytecode_verifier: string,
bytecode_vk?: string,
evmVersion?: Hardfork,
): Promise<boolean> {
try {
const hardfork = evmVersion ? evmVersion : Hardfork['Shanghai']
const common = new Common({ chain: Chain.Mainnet, hardfork })
const accountPk = hexToBytes(
'0xe331b6d69882b4cb4ea581d88e0b604039a3de5967688d3dcffdd2270c0fd109', // anvil deterministic Pk
)
const evm = new EVM({
allowUnlimitedContractSize: true,
allowUnlimitedInitCodeSize: true,
})
const vm = await VM.create({ common, evm })
const accountAddress = Address.fromPrivateKey(accountPk)
await insertAccount(vm, accountAddress)
const verifierAddress = await deployContract(
vm,
common,
accountPk,
bytecode_verifier
)
if (bytecode_vk) {
const accountPk = hexToBytes("0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80"); // anvil deterministic Pk
const accountAddress = Address.fromPrivateKey(accountPk)
await insertAccount(vm, accountAddress)
const output = await deployContract(vm, common, accountPk, bytecode_vk)
const result = await verify(vm, verifierAddress, accountAddress, proof, output)
return result
}
const result = await verify(vm, verifierAddress, accountAddress, proof)
return result
} catch (error) {
// log or re-throw the error, depending on your needs
console.error('An error occurred:', error)
throw error
}
}


@@ -0,0 +1,32 @@
import { VM } from '@ethereumjs/vm'
import { Account, Address } from '@ethereumjs/util'
export const keyPair = {
secretKey:
'0x3cd7232cd6f3fc66a57a6bedc1a8ed6c228fff0a327e169c2bcc5e869ed49511',
publicKey:
'0x0406cc661590d48ee972944b35ad13ff03c7876eae3fd191e8a2f77311b0a3c6613407b5005e63d7d8d76b89d5f900cde691497688bb281e07a5052ff61edebdc0',
}
export const insertAccount = async (vm: VM, address: Address) => {
const acctData = {
nonce: 0,
balance: BigInt('1000000000000000000'), // 1 eth
}
const account = Account.fromAccountData(acctData)
await vm.stateManager.putAccount(address, account)
}
export const getAccountNonce = async (
vm: VM,
accountPrivateKey: Uint8Array,
) => {
const address = Address.fromPrivateKey(accountPrivateKey)
const account = await vm.stateManager.getAccount(address)
if (account) {
return account.nonce
} else {
return BigInt(0)
}
}


@@ -0,0 +1,59 @@
import { Interface, defaultAbiCoder as AbiCoder } from '@ethersproject/abi'
import {
AccessListEIP2930TxData,
FeeMarketEIP1559TxData,
TxData,
} from '@ethereumjs/tx'
type TransactionsData =
| TxData
| AccessListEIP2930TxData
| FeeMarketEIP1559TxData
export const encodeFunction = (
method: string,
params?: {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
types: any[]
values: unknown[]
},
): string => {
const parameters = params?.types ?? []
const methodWithParameters = `function ${method}(${parameters.join(',')})`
const signatureHash = new Interface([methodWithParameters]).getSighash(method)
const encodedArgs = AbiCoder.encode(parameters, params?.values ?? [])
return signatureHash + encodedArgs.slice(2)
}
export const encodeDeployment = (
bytecode: string,
params?: {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
types: any[]
values: unknown[]
},
) => {
const deploymentData = '0x' + bytecode
if (params) {
const argumentsEncoded = AbiCoder.encode(params.types, params.values)
return deploymentData + argumentsEncoded.slice(2)
}
return deploymentData
}
export const buildTransaction = (
data: Partial<TransactionsData>,
): TransactionsData => {
const defaultData: Partial<TransactionsData> = {
gasLimit: 3_000_000_000_000_000,
gasPrice: 7,
value: 0,
data: '0x',
}
return {
...defaultData,
...data,
}
}


@@ -0,0 +1,7 @@
{
"extends": "./tsconfig.json",
"compilerOptions": {
"module": "CommonJS",
"outDir": "./dist/commonjs"
}
}


@@ -0,0 +1,7 @@
{
"extends": "./tsconfig.json",
"compilerOptions": {
"module": "ES2020",
"outDir": "./dist/esm"
}
}


@@ -0,0 +1,62 @@
{
"compilerOptions": {
"rootDir": "src",
"target": "es2017",
"outDir": "dist",
"declaration": true,
"lib": [
"dom",
"dom.iterable",
"esnext"
],
"allowJs": true,
"checkJs": true,
"skipLibCheck": true,
"strict": true,
"forceConsistentCasingInFileNames": true,
"noEmit": false,
"esModuleInterop": true,
"module": "CommonJS",
"moduleResolution": "node",
"resolveJsonModule": true,
"isolatedModules": true,
"jsx": "preserve",
// "incremental": true,
"noUncheckedIndexedAccess": true,
"baseUrl": ".",
"paths": {
"@/*": [
"./src/*"
]
}
},
"include": [
"src/**/*.ts",
"src/**/*.tsx",
"src/**/*.cjs",
"src/**/*.mjs"
],
"exclude": [
"node_modules"
],
// NEW: Options for file/directory watching
"watchOptions": {
// Use native file system events for files and directories
"watchFile": "useFsEvents",
"watchDirectory": "useFsEvents",
// Poll files for updates more frequently
// when they're updated a lot.
"fallbackPolling": "dynamicPriority",
// Don't coalesce watch notification
"synchronousWatchDirectory": true,
// Finally, two additional settings for reducing the amount of possible
// files to track work from these directories
"excludeDirectories": [
"**/node_modules",
"_build"
],
"excludeFiles": [
"build/fileWhichChangesOften.ts"
]
}
}


@@ -7,7 +7,7 @@
"test": "jest"
},
"devDependencies": {
"@ezkljs/engine": "^2.4.5",
"@ezkljs/engine": "^9.4.4",
"@ezkljs/verify": "^0.0.6",
"@jest/types": "^29.6.3",
"@types/file-saver": "^2.0.5",
@@ -27,4 +27,4 @@
"tsconfig-paths": "^4.2.0",
"typescript": "5.1.6"
}
}
}

pnpm-lock.yaml generated

@@ -6,8 +6,8 @@ settings:
devDependencies:
'@ezkljs/engine':
specifier: ^2.4.5
version: 2.4.5
specifier: ^9.4.4
version: 9.4.4
'@ezkljs/verify':
specifier: ^0.0.6
version: 0.0.6(buffer@6.0.3)
@@ -785,6 +785,13 @@ packages:
json-bigint: 1.0.0
dev: true
/@ezkljs/engine@9.4.4:
resolution: {integrity: sha512-kNsTmDQa8mIiQ6yjJmBMwVgAAxh4nfs4NCtnewJifonyA8Mfhs+teXwwW8WhERRDoQPUofKO2pT8BPvV/XGIDA==}
dependencies:
'@types/json-bigint': 1.0.1
json-bigint: 1.0.0
dev: true
/@ezkljs/verify@0.0.6(buffer@6.0.3):
resolution: {integrity: sha512-9DHoEhLKl1DBGuUVseXLThuMyYceY08Zymr/OsLH0zbdA9OoISYhb77j4QPm4ANRKEm5dCi8oHDqkwGbFc2xFQ==}
dependencies:


@@ -11,8 +11,8 @@ use ezkl::execute::run;
#[cfg(not(target_arch = "wasm32"))]
use ezkl::logger::init_logger;
#[cfg(not(target_arch = "wasm32"))]
use log::{error, info};
#[cfg(not(target_arch = "wasm32"))]
use log::{debug, error, info};
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
use rand::prelude::SliceRandom;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(feature = "icicle")]
@@ -25,6 +25,7 @@ use std::error::Error;
pub async fn main() -> Result<(), Box<dyn Error>> {
let args = Cli::parse();
init_logger();
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
banner();
#[cfg(feature = "icicle")]
if env::var("ENABLE_ICICLE_GPU").is_ok() {
@@ -32,7 +33,7 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
} else {
info!("Running with CPU");
}
info!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
debug!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
let res = run(args.command).await;
match &res {
Ok(_) => info!("succeeded"),
@@ -44,7 +45,7 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
#[cfg(target_arch = "wasm32")]
pub fn main() {}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
fn banner() {
let ell: Vec<&str> = vec![
"for Neural Networks",


@@ -2,7 +2,7 @@
pub mod poseidon;
///
pub mod kzg;
pub mod polycommit;
///
pub mod planner;


@@ -15,7 +15,7 @@ use halo2_proofs::{
Instance, Selector, TableColumn,
},
};
use log::{trace, warn};
use log::{debug, trace};
/// A simple [`FloorPlanner`] that performs minimal optimizations.
#[derive(Debug)]
@@ -119,7 +119,7 @@ impl<'a, F: Field, CS: Assignment<F> + 'a + SyncDeps> Layouter<F> for ModuleLayo
Error::Synthesis
})?;
if !self.regions.contains_key(&index) {
warn!("spawning module {}", index)
debug!("spawning module {}", index)
};
self.current_module = index;
}


@@ -6,10 +6,9 @@ Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/pose
// This chip adds a set of advice columns to the gadget Chip to store the inputs of the hash
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use halo2_proofs::poly::commitment::{Blind, Params};
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
use halo2_proofs::poly::commitment::{Blind, CommitmentScheme, Params};
use halo2_proofs::{circuit::*, plonk::*};
use halo2curves::bn256::{Bn256, G1Affine};
use halo2curves::bn256::G1Affine;
use halo2curves::group::prime::PrimeCurveAffine;
use halo2curves::group::Curve;
use halo2curves::CurveAffine;
@@ -18,35 +17,33 @@ use crate::tensor::{Tensor, ValTensor, ValType, VarTensor};
use super::Module;
/// The number of instance columns used by the KZG hash function
/// The number of instance columns used by the PolyCommit module
pub const NUM_INSTANCE_COLUMNS: usize = 0;
/// The number of advice columns used by the KZG hash function
/// The number of advice columns used by the PolyCommit module
pub const NUM_INNER_COLS: usize = 1;
#[derive(Debug, Clone)]
/// WIDTH, RATE and L are const generics for the struct, which represent the width, rate, and number of inputs for the Poseidon hash function, respectively.
/// This means they are values that are known at compile time and can be used to specialize the implementation of the struct.
/// The actual chip provided by halo2_gadgets is added to the parent Chip.
pub struct KZGConfig {
/// Configuration for the PolyCommit chip
pub struct PolyCommitConfig {
///
pub hash_inputs: VarTensor,
pub inputs: VarTensor,
}
type InputAssignments = ();
/// PoseidonChip is a wrapper around the Pow5Chip that adds a set of advice columns to the gadget Chip to store the inputs of the hash
///
#[derive(Debug)]
pub struct KZGChip {
config: KZGConfig,
pub struct PolyCommitChip {
config: PolyCommitConfig,
}
impl KZGChip {
impl PolyCommitChip {
/// Commit to the message using a polynomial commitment scheme
pub fn commit(
message: Vec<Fp>,
pub fn commit<Scheme: CommitmentScheme<Scalar = Fp, Curve = G1Affine>>(
message: Vec<Scheme::Scalar>,
degree: u32,
num_unusable_rows: u32,
params: &ParamsKZG<Bn256>,
params: &Scheme::ParamsProver,
) -> Vec<G1Affine> {
let k = params.k();
let domain = halo2_proofs::poly::EvaluationDomain::new(degree, k);
@@ -81,14 +78,14 @@ impl KZGChip {
}
}
impl Module<Fp> for KZGChip {
type Config = KZGConfig;
impl Module<Fp> for PolyCommitChip {
type Config = PolyCommitConfig;
type InputAssignments = InputAssignments;
type RunInputs = Vec<Fp>;
type Params = (usize, usize);
fn name(&self) -> &'static str {
"KZG"
"PolyCommit"
}
fn instance_increment_input(&self) -> Vec<usize> {
@@ -102,8 +99,8 @@ impl Module<Fp> for KZGChip {
/// Configuration of the PolyCommitChip
fn configure(meta: &mut ConstraintSystem<Fp>, params: Self::Params) -> Self::Config {
let hash_inputs = VarTensor::new_unblinded_advice(meta, params.0, NUM_INNER_COLS, params.1);
Self::Config { hash_inputs }
let inputs = VarTensor::new_unblinded_advice(meta, params.0, NUM_INNER_COLS, params.1);
Self::Config { inputs }
}
fn layout_inputs(
@@ -125,8 +122,8 @@ impl Module<Fp> for KZGChip {
) -> Result<ValTensor<Fp>, Error> {
assert_eq!(input.len(), 1);
layouter.assign_region(
|| "kzg commit",
|mut region| self.config.hash_inputs.assign(&mut region, 0, &input[0]),
|| "PolyCommit",
|mut region| self.config.inputs.assign(&mut region, 0, &input[0]),
)
}
@@ -163,7 +160,7 @@ mod tests {
}
impl Circuit<Fp> for HashCircuit {
type Config = KZGConfig;
type Config = PolyCommitConfig;
type FloorPlanner = ModulePlanner;
type Params = ();
@@ -178,7 +175,7 @@ mod tests {
fn configure(meta: &mut ConstraintSystem<Fp>) -> Self::Config {
let params = (K, R);
KZGChip::configure(meta, params)
PolyCommitChip::configure(meta, params)
}
fn synthesize(
@@ -186,8 +183,8 @@ mod tests {
config: Self::Config,
mut layouter: impl Layouter<Fp>,
) -> Result<(), Error> {
let kzg_chip = KZGChip::new(config);
kzg_chip.layout(&mut layouter, &[self.message.clone()], 0);
let polycommit_chip = PolyCommitChip::new(config);
polycommit_chip.layout(&mut layouter, &[self.message.clone()], 0);
Ok(())
}
@@ -195,7 +192,7 @@ mod tests {
#[test]
#[ignore]
fn kzg_for_a_range_of_input_sizes() {
fn polycommit_chip_for_a_range_of_input_sizes() {
let rng = rand::rngs::OsRng;
#[cfg(not(target_arch = "wasm32"))]
@@ -225,7 +222,7 @@ mod tests {
#[test]
#[ignore]
fn kzg_commit_much_longer_input() {
fn polycommit_chip_much_longer_input() {
#[cfg(not(target_arch = "wasm32"))]
env_logger::init();


@@ -12,15 +12,11 @@ pub enum BaseOp {
DotInit,
CumProdInit,
CumProd,
Identity,
Add,
Mult,
Sub,
SumInit,
Sum,
Neg,
Range { tol: i32 },
IsZero,
IsBoolean,
}
@@ -36,12 +32,8 @@ impl BaseOp {
let (a, b) = inputs;
match &self {
BaseOp::Add => a + b,
BaseOp::Identity => b,
BaseOp::Neg => -b,
BaseOp::Sub => a - b,
BaseOp::Mult => a * b,
BaseOp::Range { .. } => b,
BaseOp::IsZero => b,
BaseOp::IsBoolean => b,
_ => panic!("nonaccum_f called on accumulating operation"),
}
@@ -73,19 +65,15 @@ impl BaseOp {
/// display func
pub fn as_str(&self) -> &'static str {
match self {
BaseOp::Identity => "IDENTITY",
BaseOp::Dot => "DOT",
BaseOp::DotInit => "DOTINIT",
BaseOp::CumProdInit => "CUMPRODINIT",
BaseOp::CumProd => "CUMPROD",
BaseOp::Add => "ADD",
BaseOp::Neg => "NEG",
BaseOp::Sub => "SUB",
BaseOp::Mult => "MULT",
BaseOp::Sum => "SUM",
BaseOp::SumInit => "SUMINIT",
BaseOp::Range { .. } => "RANGE",
BaseOp::IsZero => "ISZERO",
BaseOp::IsBoolean => "ISBOOLEAN",
}
}
@@ -93,8 +81,6 @@ impl BaseOp {
/// Returns the range of the query offset for this operation.
pub fn query_offset_rng(&self) -> (i32, usize) {
match self {
BaseOp::Identity => (0, 1),
BaseOp::Neg => (0, 1),
BaseOp::DotInit => (0, 1),
BaseOp::Dot => (-1, 2),
BaseOp::CumProd => (-1, 2),
@@ -104,8 +90,6 @@ impl BaseOp {
BaseOp::Mult => (0, 1),
BaseOp::Sum => (-1, 2),
BaseOp::SumInit => (0, 1),
BaseOp::Range { .. } => (0, 1),
BaseOp::IsZero => (0, 1),
BaseOp::IsBoolean => (0, 1),
}
}
@@ -113,8 +97,6 @@ impl BaseOp {
/// Returns the number of inputs for this operation.
pub fn num_inputs(&self) -> usize {
match self {
BaseOp::Identity => 1,
BaseOp::Neg => 1,
BaseOp::DotInit => 2,
BaseOp::Dot => 2,
BaseOp::CumProdInit => 1,
@@ -124,8 +106,6 @@ impl BaseOp {
BaseOp::Mult => 2,
BaseOp::Sum => 1,
BaseOp::SumInit => 1,
BaseOp::Range { .. } => 1,
BaseOp::IsZero => 0,
BaseOp::IsBoolean => 0,
}
}
@@ -133,19 +113,15 @@ impl BaseOp {
/// Returns the number of outputs for this operation.
pub fn constraint_idx(&self) -> usize {
match self {
BaseOp::Identity => 0,
BaseOp::Neg => 0,
BaseOp::DotInit => 0,
BaseOp::Dot => 1,
BaseOp::Add => 0,
BaseOp::Sub => 0,
BaseOp::Mult => 0,
BaseOp::Range { .. } => 0,
BaseOp::Sum => 1,
BaseOp::SumInit => 0,
BaseOp::CumProd => 1,
BaseOp::CumProdInit => 0,
BaseOp::IsZero => 0,
BaseOp::IsBoolean => 0,
}
}


@@ -16,10 +16,11 @@ use pyo3::{
types::PyString,
};
use serde::{Deserialize, Serialize};
use tosubcommand::ToFlags;
use crate::{
circuit::ops::base::BaseOp,
circuit::{
ops::base::BaseOp,
table::{Range, RangeCheck, Table},
utils,
},
@@ -61,6 +62,22 @@ pub enum CheckMode {
UNSAFE,
}
impl std::fmt::Display for CheckMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CheckMode::SAFE => write!(f, "safe"),
CheckMode::UNSAFE => write!(f, "unsafe"),
}
}
}
impl ToFlags for CheckMode {
/// Convert the struct to a subcommand string
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl From<String> for CheckMode {
fn from(value: String) -> Self {
match value.to_lowercase().as_str() {
@@ -83,6 +100,19 @@ pub struct Tolerance {
pub scale: utils::F32,
}
impl std::fmt::Display for Tolerance {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:.2}", self.val)
}
}
impl ToFlags for Tolerance {
/// Convert the struct to a subcommand string
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl FromStr for Tolerance {
type Err = String;
@@ -158,31 +188,158 @@ impl<'source> FromPyObject<'source> for Tolerance {
}
}
/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct DynamicLookups {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub lookup_selectors: BTreeMap<(usize, usize), Selector>,
/// Selectors for the dynamic lookup tables
pub table_selectors: Vec<Selector>,
/// Inputs:
pub inputs: Vec<VarTensor>,
/// tables
pub tables: Vec<VarTensor>,
}
impl DynamicLookups {
/// Returns a new [DynamicLookups] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
let single_col_dummy_var = VarTensor::dummy(col_size, 1);
Self {
lookup_selectors: BTreeMap::new(),
table_selectors: vec![],
inputs: vec![dummy_var.clone(), dummy_var.clone(), dummy_var.clone()],
tables: vec![
single_col_dummy_var.clone(),
single_col_dummy_var.clone(),
single_col_dummy_var.clone(),
],
}
}
}
/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct Shuffles {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many shuffle ops.
pub input_selectors: BTreeMap<(usize, usize), Selector>,
/// Selectors for the shuffle reference columns
pub reference_selectors: Vec<Selector>,
/// Inputs:
pub inputs: Vec<VarTensor>,
/// references
pub references: Vec<VarTensor>,
}
impl Shuffles {
/// Returns a new [Shuffles] with no inputs, no selectors, and no references.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
let single_col_dummy_var = VarTensor::dummy(col_size, 1);
Self {
input_selectors: BTreeMap::new(),
reference_selectors: vec![],
inputs: vec![dummy_var.clone(), dummy_var.clone()],
references: vec![single_col_dummy_var.clone(), single_col_dummy_var.clone()],
}
}
}
/// A struct representing the selectors for the static lookup tables
#[derive(Clone, Debug, Default)]
pub struct StaticLookups<F: PrimeField + TensorType + PartialOrd> {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many static lookup ops.
pub selectors: BTreeMap<(LookupOp, usize, usize), Selector>,
/// The static lookup tables
pub tables: BTreeMap<LookupOp, Table<F>>,
///
pub index: VarTensor,
///
pub output: VarTensor,
///
pub input: VarTensor,
}
impl<F: PrimeField + TensorType + PartialOrd> StaticLookups<F> {
/// Returns a new [StaticLookups] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
Self {
selectors: BTreeMap::new(),
tables: BTreeMap::new(),
index: dummy_var.clone(),
output: dummy_var.clone(),
input: dummy_var,
}
}
}
/// A struct representing the selectors for custom gates
#[derive(Clone, Debug, Default)]
pub struct CustomGates {
/// the inputs to the accumulated operations.
pub inputs: Vec<VarTensor>,
/// the (currently singular) output of the accumulated operations.
pub output: VarTensor,
/// selector
pub selectors: BTreeMap<(BaseOp, usize, usize), Selector>,
}
impl CustomGates {
/// Returns a new [CustomGates] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
Self {
inputs: vec![dummy_var.clone(), dummy_var.clone()],
output: dummy_var,
selectors: BTreeMap::new(),
}
}
}
/// A struct representing the selectors for the range checks
#[derive(Clone, Debug, Default)]
pub struct RangeChecks<F: PrimeField + TensorType + PartialOrd> {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many range check ops.
pub selectors: BTreeMap<(Range, usize, usize), Selector>,
/// The configured range checks
pub ranges: BTreeMap<Range, RangeCheck<F>>,
///
pub index: VarTensor,
///
pub input: VarTensor,
}
impl<F: PrimeField + TensorType + PartialOrd> RangeChecks<F> {
/// Returns a new [RangeChecks] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
Self {
selectors: BTreeMap::new(),
ranges: BTreeMap::new(),
index: dummy_var.clone(),
input: dummy_var,
}
}
}
/// Configuration for an accumulated arg.
#[derive(Clone, Debug, Default)]
pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
/// the inputs to the accumulated operations.
pub inputs: Vec<VarTensor>,
/// the VarTensor reserved for lookup operations (could be an element of inputs)
/// Note that you should be careful to ensure that the lookup_input is not simultaneously assigned to by other non-lookup operations eg. in the case of composite ops.
pub lookup_input: VarTensor,
/// the (currently singular) output of the accumulated operations.
pub output: VarTensor,
/// the VarTensor reserved for lookup operations (could be an element of inputs or the same as output)
/// Note that you should be careful to ensure that the lookup_output is not simultaneously assigned to by other non-lookup operations eg. in the case of composite ops.
pub lookup_output: VarTensor,
///
pub lookup_index: VarTensor,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure [BaseOp].
pub selectors: BTreeMap<(BaseOp, usize, usize), Selector>,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
pub lookup_selectors: BTreeMap<(LookupOp, usize, usize), Selector>,
///
pub tables: BTreeMap<LookupOp, Table<F>>,
///
pub range_checks: BTreeMap<Range, RangeCheck<F>>,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
pub range_check_selectors: BTreeMap<(Range, usize, usize), Selector>,
/// Custom gates
pub custom_gates: CustomGates,
/// StaticLookups
pub static_lookups: StaticLookups<F>,
/// [Selector]s for the dynamic lookup tables
pub dynamic_lookups: DynamicLookups,
/// [Selector]s for the range checks
pub range_checks: RangeChecks<F>,
/// [Selector]s for the shuffles
pub shuffles: Shuffles,
/// Activate sanity checks
pub check_mode: CheckMode,
_marker: PhantomData<F>,
@@ -191,19 +348,12 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
Self {
inputs: vec![dummy_var.clone(), dummy_var.clone()],
lookup_input: dummy_var.clone(),
output: dummy_var.clone(),
lookup_output: dummy_var.clone(),
lookup_index: dummy_var,
selectors: BTreeMap::new(),
lookup_selectors: BTreeMap::new(),
range_check_selectors: BTreeMap::new(),
tables: BTreeMap::new(),
range_checks: BTreeMap::new(),
custom_gates: CustomGates::dummy(col_size, num_inner_cols),
static_lookups: StaticLookups::dummy(col_size, num_inner_cols),
dynamic_lookups: DynamicLookups::dummy(col_size, num_inner_cols),
shuffles: Shuffles::dummy(col_size, num_inner_cols),
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
check_mode: CheckMode::SAFE,
_marker: PhantomData,
}
@@ -236,10 +386,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
for j in 0..output.num_inner_cols() {
nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Neg, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::IsZero, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Identity, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector());
}
}
@@ -284,12 +431,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
}
BaseOp::IsZero => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("non accum: output query failed");
vec![expected_output[base_op.constraint_idx()].clone()]
}
_ => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
@@ -343,16 +484,15 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
.collect();
Self {
selectors,
lookup_selectors: BTreeMap::new(),
range_check_selectors: BTreeMap::new(),
inputs: inputs.to_vec(),
lookup_input: VarTensor::Empty,
lookup_output: VarTensor::Empty,
lookup_index: VarTensor::Empty,
tables: BTreeMap::new(),
range_checks: BTreeMap::new(),
output: output.clone(),
custom_gates: CustomGates {
inputs: inputs.to_vec(),
output: output.clone(),
selectors,
},
static_lookups: StaticLookups::default(),
dynamic_lookups: DynamicLookups::default(),
shuffles: Shuffles::default(),
range_checks: RangeChecks::default(),
check_mode,
_marker: PhantomData,
}
@@ -373,8 +513,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
where
F: Field,
{
let mut selectors = BTreeMap::new();
if !index.is_advice() {
return Err("wrong input type for lookup index".into());
}
@@ -387,9 +525,9 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
// we borrow mutably twice so we need to do this dance
let table = if !self.tables.contains_key(nl) {
let table = if !self.static_lookups.tables.contains_key(nl) {
// as all tables have the same input we see if there's another table whose input we can reuse
let table = if let Some(table) = self.tables.values().next() {
let table = if let Some(table) = self.static_lookups.tables.values().next() {
Table::<F>::configure(
cs,
lookup_range,
@@ -400,7 +538,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
} else {
Table::<F>::configure(cs, lookup_range, logrows, nl, None)
};
self.tables.insert(nl.clone(), table.clone());
self.static_lookups.tables.insert(nl.clone(), table.clone());
table
} else {
return Ok(());
@@ -484,48 +622,217 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
res
});
}
selectors.insert((nl.clone(), x, y), multi_col_selector);
self.static_lookups
.selectors
.insert((nl.clone(), x, y), multi_col_selector);
}
}
self.lookup_selectors.extend(selectors);
// if we haven't previously initialized the input/output, do so now
if let VarTensor::Empty = self.lookup_input {
if let VarTensor::Empty = self.static_lookups.input {
debug!("assigning lookup input");
self.lookup_input = input.clone();
self.static_lookups.input = input.clone();
}
if let VarTensor::Empty = self.lookup_output {
if let VarTensor::Empty = self.static_lookups.output {
debug!("assigning lookup output");
self.lookup_output = output.clone();
self.static_lookups.output = output.clone();
}
if let VarTensor::Empty = self.lookup_index {
if let VarTensor::Empty = self.static_lookups.index {
debug!("assigning lookup index");
self.lookup_index = index.clone();
self.static_lookups.index = index.clone();
}
Ok(())
}
/// Configures and creates selectors for dynamic lookups
#[allow(clippy::too_many_arguments)]
pub fn configure_dynamic_lookup(
&mut self,
cs: &mut ConstraintSystem<F>,
lookups: &[VarTensor; 3],
tables: &[VarTensor; 3],
) -> Result<(), Box<dyn Error>>
where
F: Field,
{
for l in lookups.iter() {
if !l.is_advice() {
return Err("wrong input type for dynamic lookup".into());
}
}
for t in tables.iter() {
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
return Err("wrong table type for dynamic lookup".into());
}
}
let one = Expression::Constant(F::ONE);
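// The constant `1` is prepended to both the lookup and table query tuples
// below: rows with the selector enabled map to (1, values...), while rows
// with it disabled collapse to all-zero tuples, which can only match
// all-zero (selector-off) table rows. This is the usual halo2 pattern for
// gating a dynamic lookup on a selector.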
let s_ltable = cs.complex_selector();
for x in 0..lookups[0].num_blocks() {
for y in 0..lookups[0].num_inner_cols() {
let s_lookup = cs.complex_selector();
cs.lookup_any("lookup", |cs| {
let s_lookupq = cs.query_selector(s_lookup);
let mut expression = vec![];
let s_ltableq = cs.query_selector(s_ltable);
let mut lookup_queries = vec![one.clone()];
for lookup in lookups {
lookup_queries.push(match lookup {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
});
}
let mut table_queries = vec![one.clone()];
for table in tables {
table_queries.push(match table {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[0][0], Rotation(0))
}
_ => unreachable!(),
});
}
let lhs = lookup_queries.into_iter().map(|c| c * s_lookupq.clone());
let rhs = table_queries.into_iter().map(|c| c * s_ltableq.clone());
expression.extend(lhs.zip(rhs));
expression
});
self.dynamic_lookups
.lookup_selectors
.entry((x, y))
.or_insert(s_lookup);
}
}
self.dynamic_lookups.table_selectors.push(s_ltable);
// if we haven't previously initialized the input/output, do so now
if self.dynamic_lookups.tables.is_empty() {
debug!("assigning dynamic lookup table");
self.dynamic_lookups.tables = tables.to_vec();
}
if self.dynamic_lookups.inputs.is_empty() {
debug!("assigning dynamic lookup input");
self.dynamic_lookups.inputs = lookups.to_vec();
}
Ok(())
}
/// Configures and creates selectors for shuffles
#[allow(clippy::too_many_arguments)]
pub fn configure_shuffles(
&mut self,
cs: &mut ConstraintSystem<F>,
inputs: &[VarTensor; 2],
references: &[VarTensor; 2],
) -> Result<(), Box<dyn Error>>
where
F: Field,
{
for l in inputs.iter() {
if !l.is_advice() {
return Err("wrong input type for dynamic lookup".into());
}
}
for t in references.iter() {
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
return Err("wrong table type for dynamic lookup".into());
}
}
let one = Expression::Constant(F::ONE);
let s_reference = cs.complex_selector();
for x in 0..inputs[0].num_blocks() {
for y in 0..inputs[0].num_inner_cols() {
let s_input = cs.complex_selector();
cs.lookup_any("lookup", |cs| {
let s_inputq = cs.query_selector(s_input);
let mut expression = vec![];
let s_referenceq = cs.query_selector(s_reference);
let mut input_queries = vec![one.clone()];
for input in inputs {
input_queries.push(match input {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
});
}
let mut ref_queries = vec![one.clone()];
for reference in references {
ref_queries.push(match reference {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[0][0], Rotation(0))
}
_ => unreachable!(),
});
}
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
expression.extend(lhs.zip(rhs));
expression
});
self.shuffles
.input_selectors
.entry((x, y))
.or_insert(s_input);
}
}
self.shuffles.reference_selectors.push(s_reference);
// if we haven't previously initialized the input/output, do so now
if self.shuffles.references.is_empty() {
debug!("assigning shuffles reference");
self.shuffles.references = references.to_vec();
}
if self.shuffles.inputs.is_empty() {
debug!("assigning shuffles input");
self.shuffles.inputs = inputs.to_vec();
}
Ok(())
}
/// Configures and creates selectors for range checks
#[allow(clippy::too_many_arguments)]
pub fn configure_range_check(
&mut self,
cs: &mut ConstraintSystem<F>,
input: &VarTensor,
index: &VarTensor,
range: Range,
logrows: usize,
) -> Result<(), Box<dyn Error>>
where
F: Field,
{
let mut selectors = BTreeMap::new();
if !input.is_advice() {
return Err("wrong input type for lookup input".into());
}
// we borrow mutably twice so we need to do this dance
let range_check = if let std::collections::btree_map::Entry::Vacant(e) = self.range_checks.entry(range) {
let range_check = if let std::collections::btree_map::Entry::Vacant(e) =
self.range_checks.ranges.entry(range)
{
// as all tables have the same input we see if there's another table whose input we can reuse
let range_check = RangeCheck::<F>::configure(cs, range);
let range_check = RangeCheck::<F>::configure(cs, range, logrows);
e.insert(range_check.clone());
range_check
} else {
@@ -534,39 +841,73 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
for x in 0..input.num_blocks() {
for y in 0..input.num_inner_cols() {
let single_col_sel = cs.complex_selector();
let len = range_check.selector_constructor.degree;
let multi_col_selector = cs.complex_selector();
cs.lookup("", |cs| {
let mut res = vec![];
let sel = cs.query_selector(single_col_sel);
for (col_idx, input_col) in range_check.inputs.iter().enumerate() {
cs.lookup("", |cs| {
let mut res = vec![];
let sel = cs.query_selector(multi_col_selector);
let input_query = match &input {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
};
let synthetic_sel = match len {
1 => Expression::Constant(F::from(1)),
_ => match index {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
},
};
let default_x = range_check.get_first_element();
let input_query = match &input {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
};
let not_sel = Expression::Constant(F::ONE) - sel.clone();
let default_x = range_check.get_first_element(col_idx);
res.extend([(
sel.clone() * input_query.clone()
+ not_sel.clone() * Expression::Constant(default_x),
range_check.input,
)]);
let col_expr = sel.clone()
* range_check
.selector_constructor
.get_expr_at_idx(col_idx, synthetic_sel);
res
});
selectors.insert((range, x, y), single_col_sel);
let multiplier = range_check
.selector_constructor
.get_selector_val_at_idx(col_idx);
let not_expr = Expression::Constant(multiplier) - col_expr.clone();
res.extend([(
col_expr.clone() * input_query.clone()
+ not_expr.clone() * Expression::Constant(default_x),
*input_col,
)]);
log::trace!("---------------- col {:?} ------------------", col_idx,);
log::trace!("expr: {:?}", col_expr,);
log::trace!("multiplier: {:?}", multiplier);
log::trace!("not_expr: {:?}", not_expr);
log::trace!("default x: {:?}", default_x);
res
});
}
self.range_checks
.selectors
.insert((range, x, y), multi_col_selector);
}
}
self.range_check_selectors.extend(selectors);
// if we haven't previously initialized the input/output, do so now
if let VarTensor::Empty = self.lookup_input {
debug!("assigning lookup input");
self.lookup_input = input.clone();
if let VarTensor::Empty = self.range_checks.input {
debug!("assigning range check input");
self.range_checks.input = input.clone();
}
if let VarTensor::Empty = self.range_checks.index {
debug!("assigning range check index");
self.range_checks.index = index.clone();
}
Ok(())
@@ -574,7 +915,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
/// layout_tables must be called before layout.
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
for (i, table) in self.tables.values_mut().enumerate() {
for (i, table) in self.static_lookups.tables.values_mut().enumerate() {
if !table.is_assigned {
debug!(
"laying out table for {}",
@@ -595,7 +936,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
&mut self,
layouter: &mut impl Layouter<F>,
) -> Result<(), Box<dyn Error>> {
for range_check in self.range_checks.values_mut() {
for range_check in self.range_checks.ranges.values_mut() {
if !range_check.is_assigned {
debug!("laying out range check for {:?}", range_check.range);
range_check.layout(layouter)?;


@@ -277,7 +277,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
..
} => {
if denom.0.fract() == 0.0 && *use_range_check_for_int {
layouts::div(
layouts::loop_div(
config,
region,
values[..].try_into()?,

File diff suppressed because it is too large


@@ -243,10 +243,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
LookupOp::Sign => "SIGN".into(),
LookupOp::GreaterThan { .. } => "GREATER_THAN".into(),
LookupOp::GreaterThanEqual { .. } => "GREATER_THAN_EQUAL".into(),
LookupOp::LessThan { .. } => "LESS_THAN".into(),
LookupOp::LessThanEqual { .. } => "LESS_THAN_EQUAL".into(),
LookupOp::GreaterThan { a } => format!("GREATER_THAN(a={})", a),
LookupOp::GreaterThanEqual { a } => format!("GREATER_THAN_EQUAL(a={})", a),
LookupOp::LessThan { a } => format!("LESS_THAN(a={})", a),
LookupOp::LessThanEqual { a } => format!("LESS_THAN_EQUAL(a={})", a),
LookupOp::Recip {
input_scale,
output_scale,


@@ -14,10 +14,17 @@ pub enum PolyOp {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
GatherND {
batch_dims: usize,
indices: Option<Tensor<usize>>,
},
ScatterElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
ScatterND {
constant_idx: Option<Tensor<usize>>,
},
MultiBroadcastTo {
shape: Vec<usize>,
},
@@ -60,8 +67,6 @@ pub enum PolyOp {
len_prod: usize,
},
Pow(u32),
Pack(u32, u32),
GlobalSumPool,
Concat {
axis: usize,
},
@@ -91,7 +96,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn as_string(&self) -> String {
match &self {
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
PolyOp::ScatterND { .. } => "SCATTERND".into(),
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -110,8 +117,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Sum { .. } => "SUM".into(),
PolyOp::Prod { .. } => "PROD".into(),
PolyOp::Pow(_) => "POW".into(),
PolyOp::Pack(_, _) => "PACK".into(),
PolyOp::GlobalSumPool => "GLOBALSUMPOOL".into(),
PolyOp::Conv { .. } => "CONV".into(),
PolyOp::DeConv { .. } => "DECONV".into(),
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
@@ -181,13 +186,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
output_padding,
stride,
} => tensor::ops::deconv(&inputs, *padding, *output_padding, *stride),
PolyOp::Pack(base, scale) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pack inputs".to_string()));
}
tensor::ops::pack(&inputs[0], F::from(*base as u64), *scale)
}
PolyOp::Pow(u) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pow inputs".to_string()));
@@ -206,7 +204,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
}
tensor::ops::prod_axes(&inputs[0], axes)
}
PolyOp::GlobalSumPool => unreachable!(),
PolyOp::Concat { axis } => {
tensor::ops::concat(&inputs.iter().collect::<Vec<_>>(), *axis)
}
@@ -225,6 +222,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
};
tensor::ops::gather_elements(&x, &y, *dim)
}
PolyOp::GatherND {
indices,
batch_dims,
} => {
let x = inputs[0].clone();
let y = if let Some(idx) = indices {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_nd(&x, &y, *batch_dims)
}
PolyOp::ScatterElements { dim, constant_idx } => {
let x = inputs[0].clone();
@@ -234,13 +243,28 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if let Some(_) = constant_idx {
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter(&x, &idx, &src, *dim)
}
PolyOp::ScatterND { constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter_nd(&x, &idx, &src)
}
}?;
Ok(ForwardResult { output: res })
@@ -288,7 +312,17 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
}
}
PolyOp::GatherND {
batch_dims,
indices,
} => {
if let Some(idx) = indices {
tensor::ops::gather_nd(values[0].get_inner_tensor()?, idx, *batch_dims)?.into()
} else {
layouts::gather_nd(config, region, values[..].try_into()?, *batch_dims)?.0
}
}
PolyOp::ScatterElements { dim, constant_idx } => {
@@ -304,6 +338,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
}
}
PolyOp::ScatterND { constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::scatter_nd(
values[0].get_inner_tensor()?,
idx,
values[1].get_inner_tensor()?,
)?
.into()
} else {
layouts::scatter_nd(config, region, values[..].try_into()?)?
}
}
PolyOp::DeConv {
padding,
output_padding,
@@ -334,10 +380,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
input
}
PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp)?,
PolyOp::Pack(base, scale) => {
layouts::pack(config, region, values[..].try_into()?, *base, *scale)?
}
PolyOp::GlobalSumPool => unreachable!(),
PolyOp::Concat { axis } => layouts::concat(values[..].try_into()?, axis)?,
PolyOp::Slice { axis, start, end } => {
layouts::slice(config, region, values[..].try_into()?, axis, start, end)?
@@ -405,7 +447,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
vec![1, 2]
} else if matches!(self, PolyOp::Concat { .. }) {
(0..100).collect()
} else if matches!(self, PolyOp::ScatterElements { .. }) {
} else if matches!(self, PolyOp::ScatterElements { .. })
| matches!(self, PolyOp::ScatterND { .. })
{
vec![0, 2]
} else {
vec![]

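The PolyOp hunks above add GatherND and ScatterND variants (and drop Pack and GlobalSumPool). For reference, a standalone sketch of GatherND semantics with batch_dims = 0, mirroring the ONNX definition these variants target rather than the crate's tensor internals: each row of indices selects one contiguous slice of data.

// GatherND with batch_dims = 0 over a flat buffer plus shape.
fn gather_nd(data: &[i64], data_shape: &[usize], indices: &[Vec<usize>]) -> Vec<Vec<i64>> {
    indices
        .iter()
        .map(|idx| {
            // linear offset and length of the selected slice
            let mut offset = 0;
            let mut stride = data.len();
            for (i, &ix) in idx.iter().enumerate() {
                stride /= data_shape[i];
                offset += ix * stride;
            }
            data[offset..offset + stride].to_vec()
        })
        .collect()
}

fn main() {
    // data shape [2, 2]: [[0, 1], [2, 3]]
    let data = [0, 1, 2, 3];
    // indices [[1], [0]] gather whole rows
    let out = gather_nd(&data, &[2, 2], &[vec![1], vec![0]]);
    assert_eq!(out, vec![vec![2, 3], vec![0, 1]]);
}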

@@ -20,6 +20,66 @@ use portable_atomic::AtomicI128 as AtomicInt;
use super::lookup::LookupOp;
/// Dynamic lookup index
#[derive(Clone, Debug, Default)]
pub struct DynamicLookupIndex {
index: usize,
col_coord: usize,
}
impl DynamicLookupIndex {
/// Create a new dynamic lookup index
pub fn new(index: usize, col_coord: usize) -> DynamicLookupIndex {
DynamicLookupIndex { index, col_coord }
}
/// Get the lookup index
pub fn index(&self) -> usize {
self.index
}
/// Get the column coord
pub fn col_coord(&self) -> usize {
self.col_coord
}
/// update with another dynamic lookup index
pub fn update(&mut self, other: &DynamicLookupIndex) {
self.index += other.index;
self.col_coord += other.col_coord;
}
}
/// Dynamic lookup index
#[derive(Clone, Debug, Default)]
pub struct ShuffleIndex {
index: usize,
col_coord: usize,
}
impl ShuffleIndex {
/// Create a new dynamic lookup index
pub fn new(index: usize, col_coord: usize) -> ShuffleIndex {
ShuffleIndex { index, col_coord }
}
/// Get the lookup index
pub fn index(&self) -> usize {
self.index
}
/// Get the column coord
pub fn col_coord(&self) -> usize {
self.col_coord
}
/// update with another shuffle index
pub fn update(&mut self, other: &ShuffleIndex) {
self.index += other.index;
self.col_coord += other.col_coord;
}
}
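Both index structs are deliberately additive. A standalone sketch (mirroring the definition above, illustrative only) of how update folds per-chunk tallies back into a parent region context:

#[derive(Clone, Debug, Default)]
struct DynamicLookupIndex { index: usize, col_coord: usize }

impl DynamicLookupIndex {
    // additive merge, as in the diff above
    fn update(&mut self, other: &DynamicLookupIndex) {
        self.index += other.index;
        self.col_coord += other.col_coord;
    }
}

fn main() {
    let mut parent = DynamicLookupIndex::default();
    for chunk in [
        DynamicLookupIndex { index: 1, col_coord: 4 },
        DynamicLookupIndex { index: 1, col_coord: 4 },
    ] {
        parent.update(&chunk); // fold per-chunk tallies back in
    }
    assert_eq!((parent.index, parent.col_coord), (2, 8));
}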
/// Region error
#[derive(Debug, thiserror::Error)]
pub enum RegionError {
@@ -66,10 +126,14 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
linear_coord: usize,
num_inner_cols: usize,
total_constants: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
max_lookup_inputs: i128,
min_lookup_inputs: i128,
max_range_size: i128,
throw_range_check_error: bool,
}
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
@@ -78,6 +142,31 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.total_constants += n;
}
///
pub fn increment_dynamic_lookup_index(&mut self, n: usize) {
self.dynamic_lookup_index.index += n;
}
///
pub fn increment_dynamic_lookup_col_coord(&mut self, n: usize) {
self.dynamic_lookup_index.col_coord += n;
}
///
pub fn increment_shuffle_index(&mut self, n: usize) {
self.shuffle_index.index += n;
}
///
pub fn increment_shuffle_col_coord(&mut self, n: usize) {
self.shuffle_index.col_coord += n;
}
///
pub fn throw_range_check_error(&self) -> bool {
self.throw_range_check_error
}
/// Create a new region context
pub fn new(region: Region<'a, F>, row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
let region = Some(RefCell::new(region));
@@ -89,10 +178,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
row,
linear_coord,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error: false,
}
}
/// Create a new region context from a wrapped region
@@ -100,6 +193,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
region: Option<RefCell<Region<'a, F>>>,
row: usize,
num_inner_cols: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
) -> RegionCtx<'a, F> {
let linear_coord = row * num_inner_cols;
RegionCtx {
@@ -108,15 +203,23 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index,
shuffle_index,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error: false,
}
}
/// Create a new region context
pub fn new_dummy(row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
pub fn new_dummy(
row: usize,
num_inner_cols: usize,
throw_range_check_error: bool,
) -> RegionCtx<'a, F> {
let region = None;
let linear_coord = row * num_inner_cols;
@@ -126,10 +229,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error,
}
}
@@ -139,8 +246,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord: usize,
total_constants: usize,
num_inner_cols: usize,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
throw_range_check_error: bool,
) -> RegionCtx<'a, F> {
let region = None;
RegionCtx {
@@ -149,10 +255,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord,
row,
total_constants,
used_lookups,
used_range_checks,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error,
}
}
@@ -203,12 +313,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
let row = AtomicUsize::new(self.row());
let linear_coord = AtomicUsize::new(self.linear_coord());
let constants = AtomicUsize::new(self.total_constants());
let max_lookup_inputs =
AtomicInt::new(self.max_lookup_inputs().try_into().unwrap_or_default());
let min_lookup_inputs =
AtomicInt::new(self.min_lookup_inputs().try_into().unwrap_or_default());
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
*output = output
.par_enum_map(|idx, _| {
@@ -224,8 +334,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
starting_linear_coord,
starting_constants,
self.num_inner_cols,
HashSet::new(),
HashSet::new(),
self.throw_range_check_error,
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
@@ -239,30 +348,30 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
Ordering::SeqCst,
);
let local_max_lookup_inputs =
local_reg.max_lookup_inputs().try_into().unwrap_or_default();
let local_min_lookup_inputs =
local_reg.min_lookup_inputs().try_into().unwrap_or_default();
max_lookup_inputs.fetch_max(local_max_lookup_inputs, Ordering::SeqCst);
min_lookup_inputs.fetch_min(local_min_lookup_inputs, Ordering::SeqCst);
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
// update the lookups
let mut lookups = lookups.lock().unwrap();
lookups.extend(local_reg.used_lookups());
// update the range checks
let mut range_checks = range_checks.lock().unwrap();
range_checks.extend(local_reg.used_range_checks());
// update the dynamic lookup index
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
// update the shuffle index
let mut shuffle_index = shuffle_index.lock().unwrap();
shuffle_index.update(&local_reg.shuffle_index);
res
})
.map_err(|e| {
log::error!("dummy_loop: {:?}", e);
Error::Synthesis
})?;
.map_err(|e| RegionError::from(format!("dummy_loop: {:?}", e)))?;
self.total_constants = constants.into_inner();
self.linear_coord = linear_coord.into_inner();
#[allow(trivial_numeric_casts)]
{
self.max_lookup_inputs = max_lookup_inputs.into_inner() as i128;
self.min_lookup_inputs = min_lookup_inputs.into_inner() as i128;
self.max_lookup_inputs = max_lookup_inputs.into_inner();
self.min_lookup_inputs = min_lookup_inputs.into_inner();
}
self.row = row.into_inner();
self.used_lookups = Arc::try_unwrap(lookups)
@@ -279,6 +388,28 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
})?;
self.dynamic_lookup_index = Arc::try_unwrap(dynamic_lookup_index)
.map_err(|e| {
RegionError::from(format!(
"dummy_loop: failed to get dynamic lookup index: {:?}",
e
))
})?
.into_inner()
.map_err(|e| {
RegionError::from(format!(
"dummy_loop: failed to get dynamic lookup index: {:?}",
e
))
})?;
self.shuffle_index = Arc::try_unwrap(shuffle_index)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?;
Ok(())
}
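The dummy-loop hunk above switches the per-chunk fold to plain atomic max/min on i128 (dropping the lossy try_into round-trips) and threads the new lookup and shuffle indices through mutexes. A sketch of the aggregation pattern, assuming std atomics and rayon in place of the crate's portable_atomic and maybe_rayon wrappers:

use rayon::prelude::*;
use std::sync::atomic::{AtomicI64, Ordering};

fn main() {
    let max_in = AtomicI64::new(0);
    let min_in = AtomicI64::new(0);
    // (min, max) lookup bounds reported by each parallel chunk
    let chunks: Vec<(i64, i64)> = vec![(-3, 7), (-9, 2), (0, 5)];
    chunks.par_iter().for_each(|(lo, hi)| {
        max_in.fetch_max(*hi, Ordering::SeqCst);
        min_in.fetch_min(*lo, Ordering::SeqCst);
    });
    assert_eq!(max_in.into_inner(), 7);
    assert_eq!(min_in.into_inner(), -9);
}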
@@ -307,8 +438,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
return Err("update_max_min_lookup_range: invalid range".into());
}
self.max_lookup_inputs = self.max_lookup_inputs.max(range.1);
self.min_lookup_inputs = self.min_lookup_inputs.min(range.0);
let range_size = (range.1 - range.0).abs();
self.max_range_size = self.max_range_size.max(range_size);
Ok(())
}
@@ -348,6 +480,26 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.total_constants
}
/// Get the dynamic lookup index
pub fn dynamic_lookup_index(&self) -> usize {
self.dynamic_lookup_index.index
}
/// Get the dynamic lookup column coordinate
pub fn dynamic_lookup_col_coord(&self) -> usize {
self.dynamic_lookup_index.col_coord
}
/// Get the shuffle index
pub fn shuffle_index(&self) -> usize {
self.shuffle_index.index
}
/// Get the shuffle column coordinate
pub fn shuffle_col_coord(&self) -> usize {
self.shuffle_index.col_coord
}
/// get used lookups
pub fn used_lookups(&self) -> HashSet<LookupOp> {
self.used_lookups.clone()
@@ -368,6 +520,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.min_lookup_inputs
}
/// max range check
pub fn max_range_size(&self) -> i128 {
self.max_range_size
}
/// Assign a constant value
pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
self.total_constants += 1;
@@ -392,6 +549,38 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
}
///
pub fn combined_dynamic_shuffle_coord(&self) -> usize {
self.dynamic_lookup_col_coord() + self.shuffle_col_coord()
}
/// Assign a valtensor to a vartensor
pub fn assign_dynamic_lookup(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.total_constants += values.num_constants();
if let Some(region) = &self.region {
var.assign(
&mut region.borrow_mut(),
self.combined_dynamic_shuffle_coord(),
values,
)
} else {
Ok(values.clone())
}
}
/// Assign a valtensor to a vartensor
pub fn assign_shuffle(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.assign_dynamic_lookup(var, values)
}
/// Assign a valtensor to a vartensor
pub fn assign_with_omissions(
&mut self,


@@ -130,14 +130,12 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(bits as u32) - reserved_blinding_rows
}
}
///
pub fn num_cols_required(range: Range, col_size: usize) -> usize {
// double it to be safe
let range_len = range.1 - range.0;
// number of cols needed to store the range
(range_len / (col_size as i128)) as usize + 1
}
///
pub fn num_cols_required(range_len: i128, col_size: usize) -> usize {
// number of cols needed to store the range
(range_len / (col_size as i128)) as usize + 1
}
impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
@@ -152,7 +150,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
let col_size = Self::cal_col_size(logrows, factors);
// number of cols needed to store the range
let num_cols = Self::num_cols_required(range, col_size);
let num_cols = num_cols_required((range.1 - range.0).abs(), col_size);
log::debug!("table range: {:?}", range);
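num_cols_required is now a free function taking the range length directly. A quick standalone check of the arithmetic (the col_size value below is a plausible example, not taken from the diff):

fn num_cols_required(range_len: i128, col_size: usize) -> usize {
    // number of cols needed to store the range
    (range_len / (col_size as i128)) as usize + 1
}

fn main() {
    // e.g. logrows = 17 with 6 reserved blinding rows gives col_size = 2^17 - 6
    let col_size = (1usize << 17) - 6;
    assert_eq!(num_cols_required(1 << 16, col_size), 1); // fits in one column
    assert_eq!(num_cols_required(1 << 18, col_size), 3); // spills into three
}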
@@ -265,7 +263,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
#[derive(Clone, Debug)]
pub struct RangeCheck<F: PrimeField> {
/// Input to table.
pub input: TableColumn,
pub inputs: Vec<TableColumn>,
/// col size
pub col_size: usize,
/// selector cn
pub selector_constructor: SelectorConstructor<F>,
/// Flags if table has been previously assigned to.
@@ -277,8 +277,10 @@ pub struct RangeCheck<F: PrimeField> {
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
/// get first_element of column
pub fn get_first_element(&self) -> F {
i128_to_felt(self.range.0)
pub fn get_first_element(&self, chunk: usize) -> F {
let chunk = chunk as i128;
// we index from 1 to prevent soundness issues
i128_to_felt(chunk * (self.col_size as i128) + self.range.0)
}
///
@@ -290,24 +292,58 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(bits as u32) - reserved_blinding_rows
}
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
let chunk =
(crate::fieldutils::felt_to_i128(input) - self.range.0).abs() / (self.col_size as i128);
i128_to_felt(chunk)
}
}
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
/// Configures the table.
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range) -> RangeCheck<F> {
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
log::debug!("range check range: {:?}", range);
let inputs = cs.lookup_table_column();
let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
let col_size = Self::cal_col_size(logrows, factors);
// number of cols needed to store the range
let num_cols = num_cols_required((range.1 - range.0).abs(), col_size);
let inputs = {
let mut cols = vec![];
for _ in 0..num_cols {
cols.push(cs.lookup_table_column());
}
cols
};
let num_cols = inputs.len();
if num_cols > 1 {
warn!("Using {} columns for range-check.", num_cols);
}
RangeCheck {
input: inputs,
inputs,
col_size,
is_assigned: false,
selector_constructor: SelectorConstructor::new(2),
selector_constructor: SelectorConstructor::new(num_cols),
range,
_marker: PhantomData,
}
}
/// Take a linear coordinate and output the (column, row) position in the storage block.
pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize) {
let x = linear_coord / self.col_size;
let y = linear_coord % self.col_size;
(x, y)
}
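The chunk arithmetic above is the heart of the multi-column range check: each value maps to a (column, row) cell, and each column's first element is the range start shifted by whole chunks. A standalone sketch with the same formulas (illustrative, not the circuit code):

fn get_col_index(value: i128, lo: i128, col_size: usize) -> i128 {
    // which col_size-wide chunk the value falls in
    (value - lo).abs() / (col_size as i128)
}
fn get_first_element(chunk: i128, lo: i128, col_size: usize) -> i128 {
    chunk * (col_size as i128) + lo
}
fn cartesian_coord(linear_coord: usize, col_size: usize) -> (usize, usize) {
    (linear_coord / col_size, linear_coord % col_size)
}

fn main() {
    let (lo, col_size) = (-8, 4usize);
    assert_eq!(get_col_index(-8, lo, col_size), 0); // range start: first chunk
    assert_eq!(get_col_index(3, lo, col_size), 2);  // 11 past the start: third chunk
    assert_eq!(get_first_element(2, lo, col_size), 0);
    assert_eq!(cartesian_coord(11, col_size), (2, 3));
}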
/// Assigns values to the constraints generated when calling `configure`.
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
if self.is_assigned {
@@ -318,28 +354,43 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
let largest = self.range.1;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let chunked_inputs = inputs.chunks(self.col_size);
self.is_assigned = true;
layouter.assign_table(
|| "range check table",
|mut table| {
let _ = inputs
.iter()
.enumerate()
.map(|(row_offset, input)| {
table.assign_cell(
|| format!("rc_i_col row {}", row_offset),
self.input,
row_offset,
|| Value::known(*input),
)?;
let col_multipliers: Vec<F> = (0..chunked_inputs.len())
.map(|x| self.selector_constructor.get_selector_val_at_idx(x))
.collect();
let _ = chunked_inputs
.enumerate()
.map(|(chunk_idx, inputs)| {
layouter.assign_table(
|| "range check table",
|mut table| {
let _ = inputs
.iter()
.enumerate()
.map(|(mut row_offset, input)| {
let col_multiplier = col_multipliers[chunk_idx];
row_offset += chunk_idx * self.col_size;
let (x, y) = self.cartesian_coord(row_offset);
table.assign_cell(
|| format!("rc_i_col row {}", row_offset),
self.inputs[x],
y,
|| Value::known(*input * col_multiplier),
)?;
Ok(())
})
.collect::<Result<Vec<()>, halo2_proofs::plonk::Error>>()?;
Ok(())
})
.collect::<Result<Vec<()>, halo2_proofs::plonk::Error>>()?;
Ok(())
},
)?;
},
)
})
.collect::<Result<Vec<()>, halo2_proofs::plonk::Error>>()?;
Ok(())
}
}


@@ -1,4 +1,3 @@
use crate::circuit::ops::hybrid::HybridOp;
use crate::circuit::ops::poly::PolyOp;
use crate::circuit::*;
use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor};
@@ -246,7 +245,13 @@ mod matmul_col_overflow {
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod matmul_col_ultra_overflow_double_col {
use halo2_proofs::poly::commitment::ParamsProver;
use halo2_proofs::poly::kzg::{
commitment::KZGCommitmentScheme,
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
};
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use super::*;
@@ -325,43 +330,46 @@ mod matmul_col_ultra_overflow_double_col {
let pk = crate::pfsys::create_keys::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
MatmulCircuit<F>,
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
let prover = crate::pfsys::create_proof_circuit::<
KZGCommitmentScheme<_>,
_,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
crate::pfsys::TranscriptType::EVM,
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(&params),
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
crate::Commitments::KZG,
crate::pfsys::TranscriptType::EVM,
None,
None,
);
assert!(prover.is_ok());
let proof = prover.unwrap();
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
assert!(result.is_ok());
println!("done.");
}
}
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod matmul_col_ultra_overflow {
use halo2_proofs::poly::commitment::ParamsProver;
use halo2_proofs::poly::kzg::{
commitment::KZGCommitmentScheme,
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
};
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use super::*;
@@ -439,36 +447,33 @@ mod matmul_col_ultra_overflow {
let pk = crate::pfsys::create_keys::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
MatmulCircuit<F>,
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
let prover = crate::pfsys::create_proof_circuit::<
KZGCommitmentScheme<_>,
_,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
crate::pfsys::TranscriptType::EVM,
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(&params),
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
crate::Commitments::KZG,
crate::pfsys::TranscriptType::EVM,
None,
None,
);
assert!(prover.is_ok());
let proof = prover.unwrap();
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
assert!(result.is_ok());
println!("done.");
}
}
@@ -1140,7 +1145,15 @@ mod conv {
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod conv_col_ultra_overflow {
use halo2_proofs::poly::commitment::ParamsProver;
use halo2_proofs::poly::{
kzg::strategy::SingleStrategy,
kzg::{
commitment::KZGCommitmentScheme,
multiopen::{ProverSHPLONK, VerifierSHPLONK},
},
};
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use super::*;
@@ -1238,36 +1251,33 @@ mod conv_col_ultra_overflow {
let pk = crate::pfsys::create_keys::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
ConvCircuit<F>,
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
let prover = crate::pfsys::create_proof_circuit::<
KZGCommitmentScheme<_>,
_,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
crate::pfsys::TranscriptType::EVM,
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(&params),
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
crate::Commitments::KZG,
crate::pfsys::TranscriptType::EVM,
None,
None,
);
assert!(prover.is_ok());
let proof = prover.unwrap();
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
assert!(result.is_ok());
println!("done.");
}
}
@@ -1275,7 +1285,13 @@ mod conv_col_ultra_overflow {
// not wasm 32 unknown
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod conv_relu_col_ultra_overflow {
use halo2_proofs::poly::commitment::ParamsProver;
use halo2_proofs::poly::kzg::{
commitment::KZGCommitmentScheme,
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
};
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use super::*;
@@ -1388,36 +1404,33 @@ mod conv_relu_col_ultra_overflow {
let pk = crate::pfsys::create_keys::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
ConvCircuit<F>,
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
let prover = crate::pfsys::create_proof_circuit::<
KZGCommitmentScheme<_>,
_,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
crate::pfsys::TranscriptType::EVM,
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(&params),
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
crate::Commitments::KZG,
crate::pfsys::TranscriptType::EVM,
None,
None,
);
assert!(prover.is_ok());
let proof = prover.unwrap();
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
assert!(result.is_ok());
println!("done.");
}
}
@@ -1555,6 +1568,280 @@ mod add {
}
}
#[cfg(test)]
mod dynamic_lookup {
use super::*;
const K: usize = 6;
const LEN: usize = 4;
const NUM_LOOP: usize = 5;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
tables: [[ValTensor<F>; 2]; NUM_LOOP],
lookups: [[ValTensor<F>; 2]; NUM_LOOP],
_marker: PhantomData<F>,
}
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 2, LEN);
let b = VarTensor::new_advice(cs, K, 2, LEN);
let c: VarTensor = VarTensor::new_advice(cs, K, 2, LEN);
let d = VarTensor::new_advice(cs, K, 1, LEN);
let e = VarTensor::new_advice(cs, K, 1, LEN);
let f: VarTensor = VarTensor::new_advice(cs, K, 1, LEN);
let _constant = VarTensor::constant_cols(cs, K, LEN * NUM_LOOP, false);
let mut config =
Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
config
.configure_dynamic_lookup(
cs,
&[a.clone(), b.clone(), c.clone()],
&[d.clone(), e.clone(), f.clone()],
)
.unwrap();
config
}
fn synthesize(
&self,
config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
for i in 0..NUM_LOOP {
layouts::dynamic_lookup(
&config,
&mut region,
&self.lookups[i],
&self.tables[i],
)
.map_err(|_| Error::Synthesis)?;
}
assert_eq!(
region.dynamic_lookup_col_coord(),
NUM_LOOP * self.tables[0][0].len()
);
assert_eq!(region.dynamic_lookup_index(), NUM_LOOP);
Ok(())
},
)
.unwrap();
Ok(())
}
}
#[test]
fn dynamiclookupcircuit() {
// parameters
let tables = (0..NUM_LOOP)
.map(|loop_idx| {
[
ValTensor::from(Tensor::from(
(0..LEN).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
)),
ValTensor::from(Tensor::from(
(0..LEN).map(|i| Value::known(F::from((loop_idx * i * i) as u64 + 1))),
)),
]
})
.collect::<Vec<_>>();
let lookups = (0..NUM_LOOP)
.map(|loop_idx| {
[
ValTensor::from(Tensor::from(
(0..3).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
)),
ValTensor::from(Tensor::from(
(0..3).map(|i| Value::known(F::from((loop_idx * i * i) as u64 + 1))),
)),
]
})
.collect::<Vec<_>>();
let circuit = MyCircuit::<F> {
tables: tables.clone().try_into().unwrap(),
lookups: lookups.try_into().unwrap(),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
let lookups = (0..NUM_LOOP)
.map(|loop_idx| {
let prev_idx = if loop_idx == 0 {
NUM_LOOP - 1
} else {
loop_idx - 1
};
[
ValTensor::from(Tensor::from(
(0..3).map(|i| Value::known(F::from((i * prev_idx) as u64 + 1))),
)),
ValTensor::from(Tensor::from(
(0..3).map(|i| Value::known(F::from((prev_idx * i * i) as u64 + 1))),
)),
]
})
.collect::<Vec<_>>();
let circuit = MyCircuit::<F> {
tables: tables.try_into().unwrap(),
lookups: lookups.try_into().unwrap(),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
assert!(prover.verify().is_err());
}
}
#[cfg(test)]
mod shuffle {
use super::*;
const K: usize = 6;
const LEN: usize = 4;
const NUM_LOOP: usize = 5;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [[ValTensor<F>; 1]; NUM_LOOP],
references: [[ValTensor<F>; 1]; NUM_LOOP],
_marker: PhantomData<F>,
}
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 2, LEN);
let b = VarTensor::new_advice(cs, K, 2, LEN);
let c: VarTensor = VarTensor::new_advice(cs, K, 2, LEN);
let d = VarTensor::new_advice(cs, K, 1, LEN);
let e = VarTensor::new_advice(cs, K, 1, LEN);
let _constant = VarTensor::constant_cols(cs, K, LEN * NUM_LOOP, false);
let mut config =
Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
config
.configure_shuffles(cs, &[a.clone(), b.clone()], &[d.clone(), e.clone()])
.unwrap();
config
}
fn synthesize(
&self,
config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
for i in 0..NUM_LOOP {
layouts::shuffles(
&config,
&mut region,
&self.inputs[i],
&self.references[i],
)
.map_err(|_| Error::Synthesis)?;
}
assert_eq!(
region.shuffle_col_coord(),
NUM_LOOP * self.references[0][0].len()
);
assert_eq!(region.shuffle_index(), NUM_LOOP);
Ok(())
},
)
.unwrap();
Ok(())
}
}
#[test]
fn shufflecircuit() {
// parameters
let references = (0..NUM_LOOP)
.map(|loop_idx| {
[ValTensor::from(Tensor::from((0..LEN).map(|i| {
Value::known(F::from((i * loop_idx) as u64 + 1))
})))]
})
.collect::<Vec<_>>();
let inputs = (0..NUM_LOOP)
.map(|loop_idx| {
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
Value::known(F::from((i * loop_idx) as u64 + 1))
})))]
})
.collect::<Vec<_>>();
let circuit = MyCircuit::<F> {
references: references.clone().try_into().unwrap(),
inputs: inputs.try_into().unwrap(),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
let inputs = (0..NUM_LOOP)
.map(|loop_idx| {
let prev_idx = if loop_idx == 0 {
NUM_LOOP - 1
} else {
loop_idx - 1
};
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
Value::known(F::from((i * prev_idx) as u64 + 1))
})))]
})
.collect::<Vec<_>>();
let circuit = MyCircuit::<F> {
references: references.try_into().unwrap(),
inputs: inputs.try_into().unwrap(),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
assert!(prover.verify().is_err());
}
}
#[cfg(test)]
mod add_with_overflow {
use super::*;
@@ -1958,75 +2245,6 @@ mod pow {
}
}
#[cfg(test)]
mod pack {
use super::*;
const K: usize = 8;
const LEN: usize = 4;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 1],
_marker: PhantomData<F>,
}
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN);
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
config
.layout(
&mut region,
&self.inputs.clone(),
Box::new(PolyOp::Pack(2, 1)),
)
.map_err(|_| Error::Synthesis)
},
)
.unwrap();
Ok(())
}
}
#[test]
fn packcircuit() {
// parameters
let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));
let circuit = MyCircuit::<F> {
inputs: [ValTensor::from(a)],
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
}
}
#[cfg(test)]
mod matmul_relu {
use super::*;
@@ -2120,148 +2338,6 @@ mod matmul_relu {
}
}
#[cfg(test)]
mod rangecheckpercent {
use crate::circuit::Tolerance;
use crate::{circuit, tensor::Tensor};
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
dev::MockProver,
plonk::{Circuit, ConstraintSystem, Error},
};
const RANGE: f32 = 1.0; // 1 percent error tolerance
const K: usize = 18;
const LEN: usize = 1;
const SCALE: usize = i128::pow(2, 7) as usize;
use super::*;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
input: ValTensor<F>,
output: ValTensor<F>,
_marker: PhantomData<F>,
}
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let scale = utils::F32(SCALE as f32);
let a = VarTensor::new_advice(cs, K, 1, LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN);
let mut config =
Self::Config::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE);
// set up a new GreaterThan and Recip tables
let nl = &LookupOp::GreaterThan {
a: circuit::utils::F32((RANGE * SCALE.pow(2) as f32) / 100.0),
};
config
.configure_lookup(cs, &b, &output, &a, (-32768, 32768), K, nl)
.unwrap();
config
.configure_lookup(
cs,
&b,
&output,
&a,
(-32768, 32768),
K,
&LookupOp::Recip {
input_scale: scale,
output_scale: scale,
},
)
.unwrap();
config
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
config.layout_tables(&mut layouter).unwrap();
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
config
.layout(
&mut region,
&[self.output.clone(), self.input.clone()],
Box::new(HybridOp::RangeCheck(Tolerance {
val: RANGE,
scale: SCALE.into(),
})),
)
.map_err(|_| Error::Synthesis)
},
)
.unwrap();
Ok(())
}
}
#[test]
#[allow(clippy::assertions_on_constants)]
fn test_range_check_percent() {
// Successful cases
{
let inp = Tensor::new(Some(&[Value::<F>::known(F::from(100_u64))]), &[1]).unwrap();
let out = Tensor::new(Some(&[Value::<F>::known(F::from(101_u64))]), &[1]).unwrap();
let circuit = MyCircuit::<F> {
input: ValTensor::from(inp),
output: ValTensor::from(out),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
}
{
let inp = Tensor::new(Some(&[Value::<F>::known(F::from(200_u64))]), &[1]).unwrap();
let out = Tensor::new(Some(&[Value::<F>::known(F::from(199_u64))]), &[1]).unwrap();
let circuit = MyCircuit::<F> {
input: ValTensor::from(inp),
output: ValTensor::from(out),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
}
// Unsuccessful case
{
let inp = Tensor::new(Some(&[Value::<F>::known(F::from(100_u64))]), &[1]).unwrap();
let out = Tensor::new(Some(&[Value::<F>::known(F::from(102_u64))]), &[1]).unwrap();
let circuit = MyCircuit::<F> {
input: ValTensor::from(inp),
output: ValTensor::from(out),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
match prover.verify() {
Ok(_) => {
assert!(false)
}
Err(_) => {
assert!(true)
}
}
}
}
}
#[cfg(test)]
mod relu {
use super::*;
@@ -2343,8 +2419,13 @@ mod lookup_ultra_overflow {
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
poly::commitment::ParamsProver,
poly::kzg::{
commitment::KZGCommitmentScheme,
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
},
};
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
#[derive(Clone)]
struct ReLUCircuit<F: PrimeField + TensorType + PartialOrd> {
@@ -2423,145 +2504,32 @@ mod lookup_ultra_overflow {
let pk = crate::pfsys::create_keys::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
ReLUCircuit<F>,
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
let prover = crate::pfsys::create_proof_circuit::<
KZGCommitmentScheme<_>,
_,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
crate::pfsys::TranscriptType::EVM,
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(&params),
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
crate::Commitments::KZG,
crate::pfsys::TranscriptType::EVM,
None,
None,
);
assert!(prover.is_ok());
let proof = prover.unwrap();
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy);
assert!(result.is_ok());
println!("done.");
}
}
#[cfg(test)]
mod softmax {
use super::*;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
dev::MockProver,
plonk::{Circuit, ConstraintSystem, Error},
};
const K: usize = 18;
const LEN: usize = 3;
const SCALE: f32 = 128.0;
#[derive(Clone)]
struct SoftmaxCircuit<F: PrimeField + TensorType + PartialOrd> {
pub input: ValTensor<F>,
_marker: PhantomData<F>,
}
impl Circuit<F> for SoftmaxCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN);
let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE);
let advices = (0..3)
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();
config
.configure_lookup(
cs,
&advices[0],
&advices[1],
&advices[2],
(-32768, 32768),
K,
&LookupOp::Exp {
scale: SCALE.into(),
},
)
.unwrap();
config
.configure_lookup(
cs,
&advices[0],
&advices[1],
&advices[2],
(-32768, 32768),
K,
&LookupOp::Recip {
input_scale: SCALE.into(),
output_scale: SCALE.into(),
},
)
.unwrap();
config
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
config.layout_tables(&mut layouter).unwrap();
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
let _output = config
.layout(
&mut region,
&[self.input.clone()],
Box::new(HybridOp::Softmax {
scale: SCALE.into(),
axes: vec![0],
}),
)
.unwrap();
Ok(())
},
)
.unwrap();
Ok(())
}
}
#[test]
fn softmax_circuit() {
let input = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));
let circuit = SoftmaxCircuit::<F> {
input: ValTensor::from(input),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
}
}


@@ -1,4 +1,4 @@
use clap::{Parser, Subcommand, ValueEnum};
use clap::{Parser, Subcommand};
#[cfg(not(target_arch = "wasm32"))]
use ethers::types::H160;
#[cfg(feature = "python-bindings")]
@@ -9,10 +9,11 @@ use pyo3::{
types::PyString,
};
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::path::PathBuf;
use std::{error::Error, str::FromStr};
use tosubcommand::{ToFlags, ToSubcommand};
use crate::{pfsys::ProofType, RunArgs};
use crate::{pfsys::ProofType, Commitments, RunArgs};
use crate::circuit::CheckMode;
#[cfg(not(target_arch = "wasm32"))]
@@ -76,7 +77,7 @@ pub const DEFAULT_CALIBRATION_FILE: &str = "calibration.json";
/// Default lookup safety margin
pub const DEFAULT_LOOKUP_SAFETY_MARGIN: &str = "2";
/// Default Compress selectors
pub const DEFAULT_COMPRESS_SELECTORS: &str = "false";
pub const DEFAULT_DISABLE_SELECTOR_COMPRESSION: &str = "false";
/// Default render vk separately
pub const DEFAULT_RENDER_VK_SEPERATELY: &str = "false";
/// Default VK sol path
@@ -85,15 +86,13 @@ pub const DEFAULT_VK_SOL: &str = "vk.sol";
pub const DEFAULT_VK_ABI: &str = "vk.abi";
/// Default scale rebase multipliers for calibration
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
/// Default use reduced srs for verification
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
/// Default only check for range check rebase
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";
/// Default commitment
pub const DEFAULT_COMMITMENT: &str = "kzg";
impl std::fmt::Display for TranscriptType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.to_possible_value()
.expect("no values are skipped")
.get_name()
.fmt(f)
}
}
#[cfg(feature = "python-bindings")]
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
impl IntoPy<PyObject> for TranscriptType {
@@ -138,17 +137,27 @@ impl Default for CalibrationTarget {
}
}
impl ToString for CalibrationTarget {
fn to_string(&self) -> String {
match self {
CalibrationTarget::Resources { col_overflow: true } => {
"resources/col-overflow".to_string()
impl std::fmt::Display for CalibrationTarget {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
CalibrationTarget::Resources { col_overflow: true } => {
"resources/col-overflow".to_string()
}
CalibrationTarget::Resources {
col_overflow: false,
} => "resources".to_string(),
CalibrationTarget::Accuracy => "accuracy".to_string(),
}
CalibrationTarget::Resources {
col_overflow: false,
} => "resources".to_string(),
CalibrationTarget::Accuracy => "accuracy".to_string(),
}
)
}
}
impl ToFlags for CalibrationTarget {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
@@ -169,6 +178,36 @@ impl From<&str> for CalibrationTarget {
}
}
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
/// wrapper for H160 to make it easy to parse into flag vals
pub struct H160Flag {
inner: H160,
}
#[cfg(not(target_arch = "wasm32"))]
impl From<H160Flag> for H160 {
fn from(val: H160Flag) -> H160 {
val.inner
}
}
#[cfg(not(target_arch = "wasm32"))]
impl ToFlags for H160Flag {
fn to_flags(&self) -> Vec<String> {
vec![format!("{:#x}", self.inner)]
}
}
#[cfg(not(target_arch = "wasm32"))]
impl From<&str> for H160Flag {
fn from(s: &str) -> Self {
Self {
inner: H160::from_str(s).unwrap(),
}
}
}
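H160Flag exists so addresses round-trip cleanly through the new ToFlags machinery. A sketch of that round-trip, assuming ethers' H160 FromStr and LowerHex behavior (illustrative, with the ethers crate as a dependency):

use std::str::FromStr;

fn main() {
    let addr = "0x00000000000000000000000000000000000000ff";
    let h = ethers::types::H160::from_str(addr).unwrap();
    // ToFlags renders the address back as a 0x-prefixed hex flag value
    assert_eq!(format!("{:#x}", h), addr);
}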
#[cfg(feature = "python-bindings")]
/// Converts CalibrationTarget into a PyObject (Required for CalibrationTarget to be compatible with Python)
impl IntoPy<PyObject> for CalibrationTarget {
@@ -201,7 +240,7 @@ impl<'source> FromPyObject<'source> for CalibrationTarget {
}
}
}
// not wasm
use lazy_static::lazy_static;
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
@@ -242,7 +281,7 @@ impl Cli {
}
#[allow(missing_docs)]
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd)]
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
pub enum Commands {
#[cfg(feature = "empty-cmd")]
/// Creates an empty buffer
@@ -257,21 +296,6 @@ pub enum Commands {
args: RunArgs,
},
#[cfg(feature = "render")]
/// Renders the model circuit to a .png file. For an overview of how to interpret these plots, see https://zcash.github.io/halo2/user/dev-tools.html
#[command(arg_required_else_help = true)]
RenderCircuit {
/// The path to the .onnx model file
#[arg(short = 'M', long)]
model: PathBuf,
/// Path to save the .png circuit render
#[arg(short = 'O', long)]
output: PathBuf,
/// proving arguments
#[clap(flatten)]
args: RunArgs,
},
/// Generates the witness from an input file.
GenWitness {
/// The path to the .json data file
@@ -336,9 +360,9 @@ pub enum Commands {
/// max logrows to use for calibration, 26 is the max public SRS size
#[arg(long)]
max_logrows: Option<u32>,
// whether to fix the div_rebasing value during calibration. This changes how we rebase
#[arg(long)]
div_rebasing: Option<bool>,
// whether to only range check rebases (instead of trying both range check and lookup)
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE)]
only_range_check_rebase: bool,
},
/// Generates a dummy SRS
@@ -350,6 +374,9 @@ pub enum Commands {
/// number of logrows to use for srs
#[arg(long)]
logrows: usize,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -365,9 +392,9 @@ pub enum Commands {
/// Number of logrows to use for srs. Overrides settings_path if specified.
#[arg(long, default_value = None)]
logrows: Option<u32>,
/// Check mode for SRS. Verifies downloaded srs is valid. Set to unsafe for speed.
#[arg(long, default_value = DEFAULT_CHECKMODE)]
check: CheckMode,
/// Commitment used
#[arg(long, default_value = None)]
commitment: Option<Commitments>,
},
/// Loads model and input and runs mock prover (for testing)
Mock {
@@ -413,8 +440,11 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_SPLIT)]
split_proofs: bool,
/// compress selectors
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
compress_selectors: bool,
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
disable_selector_compression: bool,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
},
/// Aggregates proofs :)
Aggregate {
@@ -434,7 +464,7 @@ pub enum Commands {
long,
require_equals = true,
num_args = 0..=1,
default_value_t = TranscriptType::EVM,
default_value_t = TranscriptType::default(),
value_enum
)]
transcript: TranscriptType,
@@ -447,6 +477,9 @@ pub enum Commands {
/// whether the accumulated proofs are segments of a larger circuit
#[arg(long, default_value = DEFAULT_SPLIT)]
split_proofs: bool,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
},
/// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements
CompileCircuit {
@@ -478,38 +511,13 @@ pub enum Commands {
#[arg(short = 'W', long)]
witness: Option<PathBuf>,
/// compress selectors
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
compress_selectors: bool,
},
#[cfg(not(target_arch = "wasm32"))]
/// Fuzzes the proof pipeline with random inputs, random parameters, and random keys
Fuzz {
/// The path to the .json witness file (generated using the gen-witness command)
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
witness: PathBuf,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
compiled_circuit: PathBuf,
#[arg(
long,
require_equals = true,
num_args = 0..=1,
default_value_t = TranscriptType::EVM,
value_enum
)]
transcript: TranscriptType,
/// number of fuzz iterations
#[arg(long, default_value = DEFAULT_FUZZ_RUNS)]
num_runs: usize,
/// compress selectors
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
compress_selectors: bool,
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
disable_selector_compression: bool,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys a test contract that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
#[command(arg_required_else_help = true)]
SetupTestEVMData {
SetupTestEvmData {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'D', long)]
data: PathBuf,
@@ -537,7 +545,7 @@ pub enum Commands {
TestUpdateAccountCalls {
/// The path to the verifier contract's address
#[arg(long)]
addr: H160,
addr: H160Flag,
/// The path to the .json data file.
#[arg(short = 'D', long)]
data: PathBuf,
@@ -587,9 +595,9 @@ pub enum Commands {
check_mode: CheckMode,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an EVM verifier for a single proof
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-verifier")]
CreateEVMVerifier {
CreateEvmVerifier {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
@@ -612,9 +620,9 @@ pub enum Commands {
render_vk_seperately: bool,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an EVM verifier for a single proof
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-vk")]
CreateEVMVK {
CreateEvmVK {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
@@ -632,9 +640,9 @@ pub enum Commands {
abi_path: PathBuf,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an EVM verifier that attests to on-chain inputs for a single proof
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
#[command(name = "create-evm-da")]
CreateEVMDataAttestation {
CreateEvmDataAttestation {
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
@@ -654,9 +662,9 @@ pub enum Commands {
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an EVM verifier for an aggregate proof
/// Creates an Evm verifier for an aggregate proof
#[command(name = "create-evm-verifier-aggr")]
CreateEVMVerifierAggr {
CreateEvmVerifierAggr {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
@@ -695,6 +703,9 @@ pub enum Commands {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)]
reduced_srs: bool,
},
/// Verifies an aggregate proof, returning accept or reject
VerifyAggr {
@@ -704,12 +715,18 @@ pub enum Commands {
/// The path to the verification key file (generated using the setup-aggregate command)
#[arg(long, default_value = DEFAULT_VK_AGGREGATED)]
vk_path: PathBuf,
/// reduced srs
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)]
reduced_srs: bool,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
/// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
logrows: u32,
/// commitment
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that is generated by ezkl
@@ -776,23 +793,23 @@ pub enum Commands {
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Verifies a proof using a local EVM executor, returning accept or reject
/// Verifies a proof using a local Evm executor, returning accept or reject
#[command(name = "verify-evm")]
VerifyEVM {
VerifyEvm {
/// The path to the proof file (generated using the prove command)
#[arg(long, default_value = DEFAULT_PROOF)]
proof_path: PathBuf,
/// The path to verifier contract's address
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS)]
addr_verifier: H160,
addr_verifier: H160Flag,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
rpc_url: Option<String>,
/// does the verifier use data attestation ?
#[arg(long)]
addr_da: Option<H160>,
addr_da: Option<H160Flag>,
// is the vk rendered separately? if so, specify an address
#[arg(long)]
addr_vk: Option<H160>,
addr_vk: Option<H160Flag>,
},
}

File diff suppressed because it is too large.


@@ -4,6 +4,7 @@ use crate::circuit::InputType;
use crate::fieldutils::i128_to_felt;
#[cfg(not(target_arch = "wasm32"))]
use crate::tensor::Tensor;
use crate::EZKL_BUF_CAPACITY;
use halo2curves::bn256::Fr as Fp;
#[cfg(not(target_arch = "wasm32"))]
use postgres::{Client, NoTls};
@@ -15,6 +16,8 @@ use pyo3::types::PyDict;
use pyo3::ToPyObject;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::io::BufReader;
use std::io::BufWriter;
use std::io::Read;
use std::panic::UnwindSafe;
#[cfg(not(target_arch = "wasm32"))]
@@ -490,16 +493,20 @@ impl GraphData {
/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let mut file = std::fs::File::open(path.clone())
.map_err(|_| format!("failed to open input at {}", path.display()))?;
let mut data = String::new();
file.read_to_string(&mut data)?;
serde_json::from_str(&data).map_err(|e| e.into())
let reader = std::fs::File::open(path)?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, reader);
let mut buf = String::new();
reader.read_to_string(&mut buf)?;
let graph_input = serde_json::from_str(&buf)?;
Ok(graph_input)
}
/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
serde_json::to_writer(std::fs::File::create(path)?, &self).map_err(|e| e.into())
// buf writer
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, self)?;
Ok(())
}
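Both paths now go through capacity-tuned buffers rather than raw File handles. A standalone sketch of the same read/write pattern (EZKL_BUF_CAPACITY is the crate's tunable; the 8 KiB below is just a placeholder):

use serde::{Deserialize, Serialize};
use std::io::{BufReader, BufWriter, Read};

#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct Data { values: Vec<f64> }

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let cap = 8 * 1024; // stand-in for EZKL_BUF_CAPACITY
    let path = std::env::temp_dir().join("ezkl_buf_demo.json");
    // buffered write; the BufWriter flushes when dropped by to_writer
    let writer = BufWriter::with_capacity(cap, std::fs::File::create(&path)?);
    serde_json::to_writer(writer, &Data { values: vec![1.0, 2.0] })?;
    // buffered read back into a string, then deserialize
    let mut reader = BufReader::with_capacity(cap, std::fs::File::open(&path)?);
    let mut buf = String::new();
    reader.read_to_string(&mut buf)?;
    let data: Data = serde_json::from_str(&buf)?;
    assert_eq!(data, Data { values: vec![1.0, 2.0] });
    Ok(())
}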
///
@@ -617,13 +624,13 @@ impl ToPyObject for DataSource {
}
#[cfg(feature = "python-bindings")]
use crate::pfsys::field_to_string_montgomery;
use crate::pfsys::field_to_string;
#[cfg(feature = "python-bindings")]
impl ToPyObject for FileSourceInner {
fn to_object(&self, py: Python) -> PyObject {
match self {
FileSourceInner::Field(data) => field_to_string_montgomery(data).to_object(py),
FileSourceInner::Field(data) => field_to_string(data).to_object(py),
FileSourceInner::Bool(data) => data.to_object(py),
FileSourceInner::Float(data) => data.to_object(py),
}


@@ -12,10 +12,13 @@ pub mod utilities;
pub mod vars;
#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::plonk::VerifyingKey;
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
use halo2_proofs::poly::commitment::CommitmentScheme;
pub use input::DataSource;
use itertools::Itertools;
use tosubcommand::ToFlags;
#[cfg(not(target_arch = "wasm32"))]
use self::input::OnChainSource;
@@ -23,21 +26,22 @@ use self::input::{FileSource, GraphData};
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::table::{Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::felt_to_f64;
use crate::pfsys::PrettyElements;
use crate::tensor::{Tensor, ValTensor};
use crate::RunArgs;
use crate::{RunArgs, EZKL_BUF_CAPACITY};
use halo2_proofs::{
circuit::Layouter,
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
};
use halo2curves::bn256::{self, Bn256, Fr as Fp, G1Affine};
use halo2curves::bn256::{self, Fr as Fp, G1Affine};
use halo2curves::ff::PrimeField;
#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;
use log::{debug, error, info, trace, warn};
use log::{debug, error, trace, warn};
use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
pub use model::*;
pub use node::*;
@@ -48,20 +52,22 @@ use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::{Deserialize, Serialize};
use std::io::{Read, Write};
use std::ops::Deref;
use thiserror::Error;
pub use utilities::*;
pub use vars::*;
#[cfg(feature = "python-bindings")]
use crate::pfsys::field_to_string_montgomery;
use crate::pfsys::field_to_string;
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i128 = 2;
/// The maximum number of columns in a lookup table.
pub const MAX_NUM_LOOKUP_COLS: usize = 12;
/// Max representation of a lookup table input
pub const MAX_LOOKUP_ABS: i128 = 8 * 2_i128.pow(MAX_PUBLIC_SRS);
pub const MAX_LOOKUP_ABS: i128 = (MAX_NUM_LOOKUP_COLS as i128) * 2_i128.pow(MAX_PUBLIC_SRS);
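For reference, assuming bn256's two-adicity Fr::S = 28 (so MAX_PUBLIC_SRS = 26), the bound grows from the previous 8 * 2^26 = 536_870_912 to:

    MAX_LOOKUP_ABS = 12 * 2^26 = 805_306_368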
#[cfg(not(target_arch = "wasm32"))]
lazy_static! {
@@ -120,7 +126,7 @@ pub enum GraphError {
#[error("failed to rescale inputs for {0}")]
RescalingError(String),
/// Error when attempting to load a model
#[error("failed to load model")]
#[error("failed to load")]
ModelLoad,
/// Packing exponent is too large
#[error("largest packing exponent exceeds max. try reducing the scale")]
@@ -133,15 +139,16 @@ pub enum GraphError {
MissingResults,
}
const ASSUMED_BLINDING_FACTORS: usize = 5;
/// The number of blinding factors assumed by halo2
pub const ASSUMED_BLINDING_FACTORS: usize = 5;
/// The minimum number of rows in the grid
pub const MIN_LOGROWS: u32 = 6;
/// Max logrows for a public SRS: bn256::Fr::S - 2 = 26
pub const MAX_PUBLIC_SRS: u32 = bn256::Fr::S - 2;
/// Lookup degree
pub const LOOKUP_DEG: usize = 5;
/// Total number of rows reserved for blinding
pub const RESERVED_BLINDING_ROWS: usize = ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD;
use std::cell::RefCell;
@@ -170,6 +177,8 @@ pub struct GraphWitness {
pub max_lookup_inputs: i128,
/// max lookup input
pub min_lookup_inputs: i128,
/// max range check size
pub max_range_size: i128,
}
impl GraphWitness {
@@ -197,6 +206,7 @@ impl GraphWitness {
processed_outputs: None,
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
}
}
@@ -207,42 +217,41 @@ impl GraphWitness {
output_scales: Vec<crate::Scale>,
visibility: VarVisibility,
) {
let mut pretty_elements = PrettyElements::default();
pretty_elements.rescaled_inputs = self
.inputs
.iter()
.enumerate()
.map(|(i, t)| {
let scale = input_scales[i];
t.iter()
.map(|x| dequantize(*x, scale, 0.).to_string())
.collect()
})
.collect();
pretty_elements.inputs = self
.inputs
.iter()
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
.collect();
pretty_elements.rescaled_outputs = self
.outputs
.iter()
.enumerate()
.map(|(i, t)| {
let scale = output_scales[i];
t.iter()
.map(|x| dequantize(*x, scale, 0.).to_string())
.collect()
})
.collect();
pretty_elements.outputs = self
.outputs
.iter()
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
.collect();
let mut pretty_elements = PrettyElements {
rescaled_inputs: self
.inputs
.iter()
.enumerate()
.map(|(i, t)| {
let scale = input_scales[i];
t.iter()
.map(|x| dequantize(*x, scale, 0.).to_string())
.collect()
})
.collect(),
inputs: self
.inputs
.iter()
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
.collect(),
rescaled_outputs: self
.outputs
.iter()
.enumerate()
.map(|(i, t)| {
let scale = output_scales[i];
t.iter()
.map(|x| dequantize(*x, scale, 0.).to_string())
.collect()
})
.collect(),
outputs: self
.outputs
.iter()
.map(|t| t.iter().map(|x| format!("{:?}", x)).collect())
.collect(),
..Default::default()
};
if let Some(processed_inputs) = self.processed_inputs.clone() {
pretty_elements.processed_inputs = processed_inputs
@@ -275,20 +284,20 @@ impl GraphWitness {
}
/// Get all polynomial commitments in the witness
pub fn get_kzg_commitments(&self) -> Vec<G1Affine> {
pub fn get_polycommitments(&self) -> Vec<G1Affine> {
let mut commitments = vec![];
if let Some(processed_inputs) = &self.processed_inputs {
if let Some(commits) = &processed_inputs.kzg_commit {
if let Some(commits) = &processed_inputs.polycommit {
commitments.extend(commits.iter().flatten());
}
}
if let Some(processed_params) = &self.processed_params {
if let Some(commits) = &processed_params.kzg_commit {
if let Some(commits) = &processed_params.polycommit {
commitments.extend(commits.iter().flatten());
}
}
if let Some(processed_outputs) = &self.processed_outputs {
if let Some(commits) = &processed_outputs.kzg_commit {
if let Some(commits) = &processed_outputs.polycommit {
commitments.extend(commits.iter().flatten());
}
}
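A small sketch of the renamed accessor, assuming a witness file on disk (path hypothetical); commitments from processed inputs, params, and outputs are flattened in that order:

    let witness = GraphWitness::from_path(std::path::PathBuf::from("witness.json"))?;
    let commitments: Vec<G1Affine> = witness.get_polycommitments();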
@@ -308,16 +317,20 @@ impl GraphWitness {
/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let mut file = std::fs::File::open(path.clone())
.map_err(|_| format!("failed to load model at {}", path.display()))?;
let mut data = String::new();
file.read_to_string(&mut data)?;
serde_json::from_str(&data).map_err(|e| e.into())
let file = std::fs::File::open(path.clone())
.map_err(|_| format!("failed to load {}", path.display()))?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::from_reader(reader).map_err(|e| e.into())
}
/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
serde_json::to_writer(std::fs::File::create(path)?, &self).map_err(|e| e.into())
// use buf writer
let writer =
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, &self).map_err(|e| e.into())
}
///
@@ -351,27 +364,31 @@ impl ToPyObject for GraphWitness {
let inputs: Vec<Vec<String>> = self
.inputs
.iter()
.map(|x| x.iter().map(field_to_string_montgomery).collect())
.map(|x| x.iter().map(field_to_string).collect())
.collect();
let outputs: Vec<Vec<String>> = self
.outputs
.iter()
.map(|x| x.iter().map(field_to_string_montgomery).collect())
.map(|x| x.iter().map(field_to_string).collect())
.collect();
dict.set_item("inputs", inputs).unwrap();
dict.set_item("outputs", outputs).unwrap();
dict.set_item("max_lookup_inputs", self.max_lookup_inputs)
.unwrap();
dict.set_item("min_lookup_inputs", self.min_lookup_inputs)
.unwrap();
dict.set_item("max_range_size", self.max_range_size)
.unwrap();
if let Some(processed_inputs) = &self.processed_inputs {
// poseidon_hash
if let Some(processed_inputs_poseidon_hash) = &processed_inputs.poseidon_hash {
insert_poseidon_hash_pydict(dict_inputs, processed_inputs_poseidon_hash).unwrap();
}
if let Some(processed_inputs_kzg_commit) = &processed_inputs.kzg_commit {
insert_kzg_commit_pydict(dict_inputs, processed_inputs_kzg_commit).unwrap();
if let Some(processed_inputs_polycommit) = &processed_inputs.polycommit {
insert_polycommit_pydict(dict_inputs, processed_inputs_polycommit).unwrap();
}
dict.set_item("processed_inputs", dict_inputs).unwrap();
@@ -381,8 +398,8 @@ impl ToPyObject for GraphWitness {
if let Some(processed_params_poseidon_hash) = &processed_params.poseidon_hash {
insert_poseidon_hash_pydict(dict_params, processed_params_poseidon_hash).unwrap();
}
if let Some(processed_params_kzg_commit) = &processed_params.kzg_commit {
insert_kzg_commit_pydict(dict_inputs, processed_params_kzg_commit).unwrap();
if let Some(processed_params_polycommit) = &processed_params.polycommit {
insert_polycommit_pydict(dict_inputs, processed_params_polycommit).unwrap();
}
dict.set_item("processed_params", dict_params).unwrap();
@@ -392,8 +409,8 @@ impl ToPyObject for GraphWitness {
if let Some(processed_outputs_poseidon_hash) = &processed_outputs.poseidon_hash {
insert_poseidon_hash_pydict(dict_outputs, processed_outputs_poseidon_hash).unwrap();
}
if let Some(processed_outputs_kzg_commit) = &processed_outputs.kzg_commit {
insert_kzg_commit_pydict(dict_inputs, processed_outputs_kzg_commit).unwrap();
if let Some(processed_outputs_polycommit) = &processed_outputs.polycommit {
insert_polycommit_pydict(dict_inputs, processed_outputs_polycommit).unwrap();
}
dict.set_item("processed_outputs", dict_outputs).unwrap();
@@ -405,23 +422,20 @@ impl ToPyObject for GraphWitness {
#[cfg(feature = "python-bindings")]
fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec<Fp>) -> Result<(), PyErr> {
let poseidon_hash: Vec<String> = poseidon_hash
.iter()
.map(field_to_string_montgomery)
.collect();
let poseidon_hash: Vec<String> = poseidon_hash.iter().map(field_to_string).collect();
pydict.set_item("poseidon_hash", poseidon_hash)?;
Ok(())
}
#[cfg(feature = "python-bindings")]
fn insert_kzg_commit_pydict(pydict: &PyDict, commits: &Vec<Vec<G1Affine>>) -> Result<(), PyErr> {
fn insert_polycommit_pydict(pydict: &PyDict, commits: &Vec<Vec<G1Affine>>) -> Result<(), PyErr> {
use crate::python::PyG1Affine;
let commitments: Vec<Vec<PyG1Affine>> = commits
.iter()
.map(|c| c.iter().map(|x| PyG1Affine::from(*x)).collect())
.collect();
pydict.set_item("kzg_commit", poseidon_hash)?;
pydict.set_item("polycommit", poseidon_hash)?;
Ok(())
}
@@ -437,6 +451,14 @@ pub struct GraphSettings {
pub total_assignments: usize,
/// total const size
pub total_const_size: usize,
/// total dynamic column size
pub total_dynamic_col_size: usize,
/// number of dynamic lookups
pub num_dynamic_lookups: usize,
/// number of shuffles
pub num_shuffles: usize,
/// total shuffle column size
pub total_shuffle_col_size: usize,
/// the shape of public inputs to the model (in order of appearance)
pub model_instance_shapes: Vec<Vec<usize>>,
/// model output scales
@@ -460,6 +482,30 @@ pub struct GraphSettings {
}
impl GraphSettings {
fn model_constraint_logrows(&self) -> u32 {
(self.num_rows as f64 + RESERVED_BLINDING_ROWS as f64)
.log2()
.ceil() as u32
}
fn dynamic_lookup_and_shuffle_logrows(&self) -> u32 {
(self.total_dynamic_col_size as f64 + self.total_shuffle_col_size as f64)
.log2()
.ceil() as u32
}
fn dynamic_lookup_and_shuffle_col_size(&self) -> usize {
self.total_dynamic_col_size + self.total_shuffle_col_size
}
fn module_constraint_logrows(&self) -> u32 {
(self.module_sizes.max_constraints() as f64).log2().ceil() as u32
}
fn constants_logrows(&self) -> u32 {
(self.total_const_size as f64).log2().ceil() as u32
}
/// calculate the total number of instances
pub fn total_instances(&self) -> Vec<usize> {
let mut instances: Vec<usize> = self
@@ -472,22 +518,33 @@ impl GraphSettings {
instances
}
/// calculate the log2 of the total number of instances
pub fn log2_total_instances(&self) -> u32 {
let sum = self.total_instances().iter().sum::<usize>();
// max between 1 and the log2 of the sum
std::cmp::max((sum as f64).log2().ceil() as u32, 1)
}
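Worked example: instance columns of sizes 3 and 5 sum to 8, so log2_total_instances = max(ceil(log2(8)), 1) = 3; a single instance of size 1 floors at the minimum, max(0, 1) = 1.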
/// save params to file
pub fn save(&self, path: &std::path::PathBuf) -> Result<(), std::io::Error> {
let encoded = serde_json::to_string(&self)?;
let mut file = std::fs::File::create(path)?;
file.write_all(encoded.as_bytes())
// buf writer
let writer =
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, &self).map_err(|e| {
error!("failed to save settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
})
}
/// load params from file
pub fn load(path: &std::path::PathBuf) -> Result<Self, std::io::Error> {
let mut file = std::fs::File::open(path).map_err(|e| {
error!("failed to open settings file at {}", e);
e
})?;
let mut data = String::new();
file.read_to_string(&mut data)?;
let res = serde_json::from_str(&data)?;
Ok(res)
// buf reader
let reader =
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
serde_json::from_reader(reader).map_err(|e| {
error!("failed to load settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
})
}
/// Export the ezkl configuration as json
@@ -533,11 +590,21 @@ impl GraphSettings {
|| self.run_args.param_visibility.is_hashed()
}
/// requires dynamic lookup
pub fn requires_dynamic_lookup(&self) -> bool {
self.num_dynamic_lookups > 0
}
/// requires dynamic shuffle
pub fn requires_shuffle(&self) -> bool {
self.num_shuffles > 0
}
/// any polycommit visibility
pub fn module_requires_kzg(&self) -> bool {
self.run_args.input_visibility.is_kzgcommit()
|| self.run_args.output_visibility.is_kzgcommit()
|| self.run_args.param_visibility.is_kzgcommit()
pub fn module_requires_polycommit(&self) -> bool {
self.run_args.input_visibility.is_polycommit()
|| self.run_args.output_visibility.is_polycommit()
|| self.run_args.param_visibility.is_polycommit()
}
}
@@ -583,7 +650,7 @@ impl GraphCircuit {
///
pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box<dyn std::error::Error>> {
let f = std::fs::File::create(path)?;
let writer = std::io::BufWriter::new(f);
let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
bincode::serialize_into(writer, &self)?;
Ok(())
}
@@ -591,11 +658,10 @@ impl GraphCircuit {
///
pub fn load(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
// read bytes from file
let mut f = std::fs::File::open(&path)?;
let metadata = std::fs::metadata(&path)?;
let mut buffer = vec![0; metadata.len() as usize];
f.read_exact(&mut buffer)?;
let result = bincode::deserialize(&buffer)?;
let f = std::fs::File::open(path)?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let result: GraphCircuit = bincode::deserialize_from(reader)?;
Ok(result)
}
}
@@ -610,6 +676,17 @@ pub enum TestDataSource {
OnChain,
}
impl std::fmt::Display for TestDataSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TestDataSource::File => write!(f, "file"),
TestDataSource::OnChain => write!(f, "on-chain"),
}
}
}
impl ToFlags for TestDataSource {}
impl From<String> for TestDataSource {
fn from(value: String) -> Self {
match value.to_lowercase().as_str() {
@@ -971,18 +1048,10 @@ impl GraphCircuit {
Ok(data)
}
fn reserved_blinding_rows() -> f64 {
(ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD) as f64
}
fn calc_safe_lookup_range(
min_lookup_inputs: i128,
max_lookup_inputs: i128,
lookup_safety_margin: i128,
) -> Range {
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
let mut margin = (
lookup_safety_margin * min_lookup_inputs,
lookup_safety_margin * max_lookup_inputs,
lookup_safety_margin * min_max_lookup.0,
lookup_safety_margin * min_max_lookup.1,
);
if lookup_safety_margin == 1 {
margin.0 += 4;
@@ -992,18 +1061,34 @@ impl GraphCircuit {
margin
}
fn calc_num_cols(safe_range: Range, max_logrows: u32) -> usize {
let max_col_size = Table::<Fp>::cal_col_size(
max_logrows as usize,
Self::reserved_blinding_rows() as usize,
);
Table::<Fp>::num_cols_required(safe_range, max_col_size)
fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize {
let max_col_size = Table::<Fp>::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS);
num_cols_required(range_len, max_col_size)
}
fn calc_min_logrows(
fn table_size_logrows(
&self,
safe_lookup_range: Range,
max_range_size: i128,
) -> Result<u32, Box<dyn std::error::Error>> {
// pick whichever absolute size is larger: the safe lookup range or max_range_size
let safe_range = std::cmp::max(
(safe_lookup_range.1 - safe_lookup_range.0).abs(),
max_range_size,
);
let min_bits = (safe_range as f64 + RESERVED_BLINDING_ROWS as f64 + 1.)
.log2()
.ceil() as u32;
Ok(min_bits)
}
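Worked example, assuming RESERVED_BLINDING_ROWS = 8 (the exact value depends on RESERVED_BLINDING_ROWS_PAD): a lookup range of (-32768, 32768) has absolute size 65536, so if that dominates max_range_size, min_bits = ceil(log2(65536 + 8 + 1)) = 17.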
/// calculate the minimum logrows required for the circuit
pub fn calc_min_logrows(
&mut self,
min_lookup_inputs: i128,
max_lookup_inputs: i128,
min_max_lookup: Range,
max_range_size: i128,
max_logrows: Option<u32>,
lookup_safety_margin: i128,
) -> Result<(), Box<dyn std::error::Error>> {
@@ -1013,54 +1098,60 @@ impl GraphCircuit {
let mut max_logrows = std::cmp::max(max_logrows, MIN_LOGROWS);
let mut min_logrows = MIN_LOGROWS;
let reserved_blinding_rows = Self::reserved_blinding_rows();
let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);
// check if has overflowed max lookup input
if max_lookup_inputs > MAX_LOOKUP_ABS / lookup_safety_margin
|| min_lookup_inputs < -MAX_LOOKUP_ABS / lookup_safety_margin
{
let err_string = format!("max lookup input ({}) is too large", max_lookup_inputs);
if (min_max_lookup.1 - min_max_lookup.0).abs() > MAX_LOOKUP_ABS / lookup_safety_margin {
let err_string = format!("max lookup input {:?} is too large", min_max_lookup);
return Err(err_string.into());
}
let safe_range = Self::calc_safe_lookup_range(
min_lookup_inputs,
max_lookup_inputs,
lookup_safety_margin,
if max_range_size.abs() > MAX_LOOKUP_ABS {
let err_string = format!("max range check size {:?} is too large", max_range_size);
return Err(err_string.into());
}
// These are hard lower limits, we can't overflow instances or modules constraints
let instance_logrows = self.settings().log2_total_instances();
let module_constraint_logrows = self.settings().module_constraint_logrows();
let dynamic_lookup_logrows = self.settings().dynamic_lookup_and_shuffle_logrows();
min_logrows = std::cmp::max(
min_logrows,
// the lower limit is the max of the instance, module-constraint, and dynamic-lookup logrows
*[
instance_logrows,
module_constraint_logrows,
dynamic_lookup_logrows,
]
.iter()
.max()
.unwrap(),
);
// These are upper limits, going above these is wasteful, but they are not hard limits
let model_constraint_logrows = self.settings().model_constraint_logrows();
let min_bits = self.table_size_logrows(safe_lookup_range, max_range_size)?;
let constants_logrows = self.settings().constants_logrows();
max_logrows = std::cmp::min(
max_logrows,
// max of the model constraint logrows, min_bits, and the constants logrows is the upper limit
*[model_constraint_logrows, min_bits, constants_logrows]
.iter()
.max()
.unwrap(),
);
// we now have a min and max logrows
max_logrows = std::cmp::max(min_logrows, max_logrows);
// degrade the max logrows until the extended k is small enough
while min_logrows < max_logrows
&& !self.extended_k_is_small_enough(
min_logrows,
Self::calc_num_cols(safe_range, min_logrows),
)
{
min_logrows += 1;
}
if !self
.extended_k_is_small_enough(min_logrows, Self::calc_num_cols(safe_range, min_logrows))
{
let err_string = format!(
"extended k is too large to accommodate the quotient polynomial with logrows {}",
min_logrows
);
debug!("{}", err_string);
return Err(err_string.into());
}
while min_logrows < max_logrows
&& !self.extended_k_is_small_enough(
max_logrows,
Self::calc_num_cols(safe_range, max_logrows),
)
&& !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size)
{
max_logrows -= 1;
}
if !self
.extended_k_is_small_enough(max_logrows, Self::calc_num_cols(safe_range, max_logrows))
{
if !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size) {
let err_string = format!(
"extended k is too large to accommodate the quotient polynomial with logrows {}",
max_logrows
@@ -1069,67 +1160,27 @@ impl GraphCircuit {
return Err(err_string.into());
}
let min_bits = ((safe_range.1 - safe_range.0) as f64 + reserved_blinding_rows + 1.)
.log2()
.ceil() as usize;
let min_rows_from_constraints = (self.settings().num_rows as f64 + reserved_blinding_rows)
.log2()
.ceil() as usize;
let mut logrows = std::cmp::max(min_bits, min_rows_from_constraints);
// if inputs or outputs are public, the instance column must fit their full length
if self.settings().run_args.input_visibility.is_public()
|| self.settings().run_args.output_visibility.is_public()
{
let mut max_instance_len = self
.model()
.instance_shapes()?
.iter()
.fold(0, |acc, x| std::cmp::max(acc, x.iter().product::<usize>()))
as f64
+ reserved_blinding_rows;
// if there are modules then we need to add the max module size
if self.settings().uses_modules() {
max_instance_len += self
.settings()
.module_sizes
.num_instances()
.iter()
.sum::<usize>() as f64;
}
let instance_len_logrows = (max_instance_len).log2().ceil() as usize;
logrows = std::cmp::max(logrows, instance_len_logrows);
// this is for fixed const columns
}
// ensure logrows is at least 4
logrows = std::cmp::max(logrows, min_logrows as usize);
logrows = std::cmp::min(logrows, max_logrows as usize);
let logrows = max_logrows;
let model = self.model().clone();
let settings_mut = self.settings_mut();
settings_mut.run_args.lookup_range = safe_range;
settings_mut.run_args.logrows = logrows as u32;
settings_mut.run_args.lookup_range = safe_lookup_range;
settings_mut.run_args.logrows = logrows;
*settings_mut = GraphCircuit::new(model, &settings_mut.run_args)?
.settings()
.clone();
// recalculate the total const size given the new logrows
let total_const_len = settings_mut.total_const_size;
let const_len_logrows = (total_const_len as f64).log2().ceil() as u32;
settings_mut.run_args.logrows =
std::cmp::max(settings_mut.run_args.logrows, const_len_logrows);
// recalculate the total number of constraints given the new logrows
let min_rows_from_constraints = (settings_mut.num_rows as f64 + reserved_blinding_rows)
.log2()
.ceil() as u32;
settings_mut.run_args.logrows =
std::cmp::max(settings_mut.run_args.logrows, min_rows_from_constraints);
settings_mut.run_args.logrows = std::cmp::min(max_logrows, settings_mut.run_args.logrows);
// recalculate the logrows if there has been overflow on the constants
settings_mut.run_args.logrows = std::cmp::max(
settings_mut.run_args.logrows,
settings_mut.constants_logrows(),
);
// recalculate the logrows if there has been overflow for the model constraints
settings_mut.run_args.logrows = std::cmp::max(
settings_mut.run_args.logrows,
settings_mut.model_constraint_logrows(),
);
debug!(
"setting lookup_range to: {:?}, setting logrows to: {}",
@@ -1140,12 +1191,48 @@ impl GraphCircuit {
Ok(())
}
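A hedged sketch of driving the new entry point during calibration; the numbers are illustrative only:

    // min/max lookup inputs observed in a forward pass, plus the largest range check
    let min_max_lookup: Range = (-4096, 4096);
    let max_range_size: i128 = 512;
    // cap the search at 2^20 rows, with a 2x safety margin on the lookup range
    circuit.calc_min_logrows(min_max_lookup, max_range_size, Some(20), 2)?;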
fn extended_k_is_small_enough(&self, k: u32, num_lookup_cols: usize) -> bool {
let max_degree = self.settings().run_args.num_inner_cols + 2;
let max_lookup_degree = LOOKUP_DEG + num_lookup_cols - 1; // num_lookup_cols - 1 is the degree of the lookup synthetic selector
fn extended_k_is_small_enough(
&self,
k: u32,
safe_lookup_range: Range,
max_range_size: i128,
) -> bool {
// if num cols is too large then the extended k is too large
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS
|| Self::calc_num_cols(max_range_size, k) > MAX_NUM_LOOKUP_COLS
{
return false;
}
let max_degree = std::cmp::max(max_degree, max_lookup_degree);
let mut settings = self.settings().clone();
settings.run_args.lookup_range = safe_lookup_range;
settings.run_args.logrows = k;
settings.required_range_checks = vec![(0, max_range_size)];
let mut cs = ConstraintSystem::default();
// on unix, gag stdout/stderr to suppress verbose output during configure
#[cfg(unix)]
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
#[cfg(unix)]
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
};
Self::configure_with_params(&mut cs, settings);
// drop the gag
#[cfg(unix)]
drop(_r);
#[cfg(unix)]
drop(_g);
#[cfg(feature = "mv-lookup")]
let cs = cs.chunk_lookups();
// quotient_poly_degree * params.n - 1 is the degree of the quotient polynomial
let max_degree = cs.degree();
let quotient_poly_degree = (max_degree - 1) as u64;
// n = 2^k
let n = 1u64 << k;
@@ -1160,29 +1247,13 @@ impl GraphCircuit {
true
}
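The check above is domain sizing: with quotient_poly_degree = cs.degree() - 1 and n = 2^k, the extended domain must hold roughly quotient_poly_degree * n evaluations, i.e. extended_k ≈ k + ceil(log2(degree - 1)). For example, degree 9 at k = 17 gives extended_k ≈ 17 + 3 = 20.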
/// Calibrate the circuit to the supplied data.
pub fn calibrate_from_min_max(
&mut self,
min_lookup_inputs: i128,
max_lookup_inputs: i128,
max_logrows: Option<u32>,
lookup_safety_margin: i128,
) -> Result<(), Box<dyn std::error::Error>> {
self.calc_min_logrows(
min_lookup_inputs,
max_lookup_inputs,
max_logrows,
lookup_safety_margin,
)?;
Ok(())
}
/// Runs the forward pass of the model / graph of computations and any associated hashing.
pub fn forward(
pub fn forward<Scheme: CommitmentScheme<Scalar = Fp, Curve = G1Affine>>(
&self,
inputs: &mut [Tensor<Fp>],
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&ParamsKZG<Bn256>>,
srs: Option<&Scheme::ParamsProver>,
throw_range_check_error: bool,
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
let original_inputs = inputs.to_vec();
@@ -1198,7 +1269,8 @@ impl GraphCircuit {
for outlet in &module_outlets {
module_inputs.push(inputs[*outlet].clone());
}
let res = GraphModules::forward(&module_inputs, &visibility.input, vk, srs)?;
let res =
GraphModules::forward::<Scheme>(&module_inputs, &visibility.input, vk, srs)?;
processed_inputs = Some(res.clone());
let module_results = res.get_result(visibility.input.clone());
@@ -1206,7 +1278,12 @@ impl GraphCircuit {
inputs[*outlet] = Tensor::from(module_results[i].clone().into_iter());
}
} else {
processed_inputs = Some(GraphModules::forward(inputs, &visibility.input, vk, srs)?);
processed_inputs = Some(GraphModules::forward::<Scheme>(
inputs,
&visibility.input,
vk,
srs,
)?);
}
}
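Call sites now select the commitment scheme with a type parameter; a minimal sketch using KZG (arguments illustrative — with no vk/srs the polycommit outputs stay None):

    use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
    use halo2curves::bn256::Bn256;

    let witness = circuit.forward::<KZGCommitmentScheme<Bn256>>(
        &mut inputs, // &mut [Tensor<Fp>]
        None,        // no verifying key
        None,        // no srs
        false,       // don't error on range-check failures
    )?;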
@@ -1214,7 +1291,7 @@ impl GraphCircuit {
let params = self.model().get_all_params();
if !params.is_empty() {
let flattened_params = Tensor::new(Some(&params), &[params.len()])?.combine()?;
processed_params = Some(GraphModules::forward(
processed_params = Some(GraphModules::forward::<Scheme>(
&[flattened_params],
&visibility.params,
vk,
@@ -1223,7 +1300,9 @@ impl GraphCircuit {
}
}
let mut model_results = self.model().forward(inputs)?;
let mut model_results =
self.model()
.forward(inputs, &self.settings().run_args, throw_range_check_error)?;
if visibility.output.requires_processing() {
let module_outlets = visibility.output.overwrites_inputs();
@@ -1232,7 +1311,8 @@ impl GraphCircuit {
for outlet in &module_outlets {
module_inputs.push(model_results.outputs[*outlet].clone());
}
let res = GraphModules::forward(&module_inputs, &visibility.output, vk, srs)?;
let res =
GraphModules::forward::<Scheme>(&module_inputs, &visibility.output, vk, srs)?;
processed_outputs = Some(res.clone());
let module_results = res.get_result(visibility.output.clone());
@@ -1241,7 +1321,7 @@ impl GraphCircuit {
Tensor::from(module_results[i].clone().into_iter());
}
} else {
processed_outputs = Some(GraphModules::forward(
processed_outputs = Some(GraphModules::forward::<Scheme>(
&model_results.outputs,
&visibility.output,
vk,
@@ -1266,6 +1346,7 @@ impl GraphCircuit {
processed_outputs,
max_lookup_inputs: model_results.max_lookup_inputs,
min_lookup_inputs: model_results.min_lookup_inputs,
max_range_size: model_results.max_range_size,
};
witness.generate_rescaled_elements(
@@ -1472,34 +1553,18 @@ impl Circuit<Fp> for GraphCircuit {
params.run_args.logrows as usize,
);
let mut vars = ModelVars::new(
cs,
params.run_args.logrows as usize,
params.total_assignments,
params.run_args.num_inner_cols,
params.total_const_size,
params.module_requires_fixed(),
);
let mut vars = ModelVars::new(cs, &params);
module_configs.configure_complex_modules(cs, visibility, params.module_sizes.clone());
vars.instantiate_instance(
cs,
params.model_instance_shapes,
params.model_instance_shapes.clone(),
params.run_args.input_scale,
module_configs.instance,
);
let base = Model::configure(
cs,
&vars,
params.run_args.lookup_range,
params.run_args.logrows as usize,
params.required_lookups,
params.required_range_checks,
params.check_mode,
)
.unwrap();
let base = Model::configure(cs, &vars, &params).unwrap();
let model_config = ModelConfig { base, vars };
@@ -1512,7 +1577,7 @@ impl Circuit<Fp> for GraphCircuit {
let circuit_size = CircuitSize::from_cs(cs, params.run_args.logrows);
#[cfg(not(target_arch = "wasm32"))]
info!(
debug!(
"circuit size: \n {}",
circuit_size
.as_json()

View File

@@ -56,6 +56,8 @@ use unzip_n::unzip_n;
unzip_n!(pub 3);
#[cfg(not(target_arch = "wasm32"))]
type TractResult = (Graph<TypedFact, Box<dyn TypedOp>>, SymbolValues);
/// The result of a forward pass.
#[derive(Clone, Debug)]
pub struct ForwardResult {
@@ -65,6 +67,8 @@ pub struct ForwardResult {
pub max_lookup_inputs: i128,
/// The minimum value of any input to a lookup operation.
pub min_lookup_inputs: i128,
/// The max range check size
pub max_range_size: i128,
}
impl From<DummyPassRes> for ForwardResult {
@@ -73,6 +77,7 @@ impl From<DummyPassRes> for ForwardResult {
outputs: res.outputs,
max_lookup_inputs: res.max_lookup_inputs,
min_lookup_inputs: res.min_lookup_inputs,
max_range_size: res.max_range_size,
}
}
}
@@ -94,6 +99,14 @@ pub type NodeGraph = BTreeMap<usize, NodeType>;
pub struct DummyPassRes {
/// number of rows used
pub num_rows: usize,
/// num dynamic lookups
pub num_dynamic_lookups: usize,
/// dynamic lookup col size
pub dynamic_lookup_col_coord: usize,
/// num shuffles
pub num_shuffles: usize,
/// shuffle column coordinate
pub shuffle_col_coord: usize,
/// linear coordinate
pub linear_coord: usize,
/// total const size
@@ -106,6 +119,8 @@ pub struct DummyPassRes {
pub max_lookup_inputs: i128,
/// min lookup inputs
pub min_lookup_inputs: i128,
/// max range check size
pub max_range_size: i128,
/// outputs
pub outputs: Vec<Tensor<Fp>>,
}
@@ -519,7 +534,7 @@ impl Model {
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
let res = self.dummy_layout(run_args, &inputs)?;
let res = self.dummy_layout(run_args, &inputs, false)?;
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
@@ -533,6 +548,10 @@ impl Model {
required_range_checks: res.range_checks.into_iter().collect(),
model_output_scales: self.graph.get_output_scales()?,
model_input_scales: self.graph.get_input_scales(),
num_dynamic_lookups: res.num_dynamic_lookups,
total_dynamic_col_size: res.dynamic_lookup_col_coord,
num_shuffles: res.num_shuffles,
total_shuffle_col_size: res.shuffle_col_coord,
total_const_size: res.total_const_size,
check_mode,
version: env!("CARGO_PKG_VERSION").to_string(),
@@ -554,12 +573,17 @@ impl Model {
/// * `reader` - A reader for an Onnx file.
/// * `model_inputs` - A vector of [Tensor]s to use as inputs to the model.
/// * `run_args` - [RunArgs]
pub fn forward(&self, model_inputs: &[Tensor<Fp>]) -> Result<ForwardResult, Box<dyn Error>> {
pub fn forward(
&self,
model_inputs: &[Tensor<Fp>],
run_args: &RunArgs,
throw_range_check_error: bool,
) -> Result<ForwardResult, Box<dyn Error>> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
.collect();
let res = self.dummy_layout(&RunArgs::default(), &valtensor_inputs)?;
let res = self.dummy_layout(run_args, &valtensor_inputs, throw_range_check_error)?;
Ok(res.into())
}
@@ -572,7 +596,7 @@ impl Model {
fn load_onnx_using_tract(
reader: &mut dyn std::io::Read,
run_args: &RunArgs,
) -> Result<(Graph<TypedFact, Box<dyn TypedOp>>, SymbolValues), Box<dyn Error>> {
) -> Result<TractResult, Box<dyn Error>> {
use tract_onnx::{
tract_core::internal::IntoArcTensor, tract_hir::internal::GenericFactoid,
};
@@ -611,7 +635,7 @@ impl Model {
for (symbol, value) in run_args.variables.iter() {
let symbol = model.symbol_table.sym(symbol);
symbol_values = symbol_values.with(&symbol, *value as i64);
info!("set {} to {}", symbol, value);
debug!("set {} to {}", symbol, value);
}
// Note: do not optimize the model, as the layout will depend on underlying hardware
@@ -946,7 +970,7 @@ impl Model {
let (model, _) = Model::load_onnx_using_tract(
&mut std::fs::File::open(model_path)
.map_err(|_| format!("failed to load model at {}", model_path.display()))?,
.map_err(|_| format!("failed to load {}", model_path.display()))?,
run_args,
)?;
@@ -982,7 +1006,7 @@ impl Model {
) -> Result<Self, Box<dyn Error>> {
Model::new(
&mut std::fs::File::open(model)
.map_err(|_| format!("failed to load model at {}", model.display()))?,
.map_err(|_| format!("failed to load {}", model.display()))?,
run_args,
)
}
@@ -991,24 +1015,24 @@ impl Model {
/// # Arguments
/// * `meta` - The constraint system.
/// * `vars` - The variables for the circuit.
/// * `run_args` - [RunArgs]
/// * `required_lookups` - The required lookup operations for the circuit.
/// * `settings` - [GraphSettings]
pub fn configure(
meta: &mut ConstraintSystem<Fp>,
vars: &ModelVars<Fp>,
lookup_range: Range,
logrows: usize,
required_lookups: Vec<LookupOp>,
required_range_checks: Vec<Range>,
check_mode: CheckMode,
settings: &GraphSettings,
) -> Result<PolyConfig<Fp>, Box<dyn Error>> {
info!("configuring model");
debug!("configuring model");
let lookup_range = settings.run_args.lookup_range;
let logrows = settings.run_args.logrows as usize;
let required_lookups = settings.required_lookups.clone();
let required_range_checks = settings.required_range_checks.clone();
let mut base_gate = PolyConfig::configure(
meta,
vars.advices[0..2].try_into()?,
&vars.advices[2],
check_mode,
settings.check_mode,
);
// set scale for HybridOp::RangeCheck and call self.conf_lookup on that op for percentage tolerance case
let input = &vars.advices[0];
@@ -1019,7 +1043,23 @@ impl Model {
}
for range in required_range_checks {
base_gate.configure_range_check(meta, input, range)?;
base_gate.configure_range_check(meta, input, index, range, logrows)?;
}
if settings.requires_dynamic_lookup() {
base_gate.configure_dynamic_lookup(
meta,
vars.advices[0..3].try_into()?,
vars.advices[3..6].try_into()?,
)?;
}
if settings.requires_shuffle() {
base_gate.configure_shuffles(
meta,
vars.advices[0..2].try_into()?,
vars.advices[3..5].try_into()?,
)?;
}
Ok(base_gate)
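With the slimmer signature, configuration reads everything from GraphSettings; a sketch, assuming a populated settings value in scope:

    let mut cs = ConstraintSystem::<Fp>::default();
    let vars = ModelVars::new(&mut cs, &settings);
    let config = Model::configure(&mut cs, &vars, &settings)?;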
@@ -1340,6 +1380,7 @@ impl Model {
&self,
run_args: &RunArgs,
inputs: &[ValTensor<Fp>],
throw_range_check_error: bool,
) -> Result<DummyPassRes, Box<dyn Error>> {
debug!("calculating num of constraints using dummy model layout...");
@@ -1358,7 +1399,7 @@ impl Model {
vars: ModelVars::new_dummy(),
};
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols);
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, throw_range_check_error);
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
@@ -1369,27 +1410,26 @@ impl Model {
ValType::Constant(Fp::ONE)
};
let comparator = outputs
let output_scales = self.graph.get_output_scales()?;
let res = outputs
.iter()
.map(|x| {
let mut v: ValTensor<Fp> =
vec![default_value.clone(); x.dims().iter().product::<usize>()].into();
v.reshape(x.dims())?;
Ok(v)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
.enumerate()
.map(|(i, output)| {
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
let mut comparator: ValTensor<Fp> =
vec![default_value.clone(); output.dims().iter().product::<usize>()].into();
comparator.reshape(output.dims())?;
let _ = outputs
.iter()
.zip(comparator)
.map(|(o, c)| {
dummy_config.layout(
&mut region,
&[o.clone(), c],
Box::new(HybridOp::RangeCheck(run_args.tolerance)),
&[output.clone(), comparator],
Box::new(HybridOp::RangeCheck(tolerance)),
)
})
.collect::<Result<Vec<_>, _>>()?;
.collect::<Result<Vec<_>, _>>();
res?;
} else if !self.visibility.output.is_private() {
for output in &outputs {
region.increment_total_constants(output.num_constants());
@@ -1401,7 +1441,7 @@ impl Model {
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
info!(
debug!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
region.row().to_string().blue(),
@@ -1426,6 +1466,11 @@ impl Model {
range_checks: region.used_range_checks(),
max_lookup_inputs: region.max_lookup_inputs(),
min_lookup_inputs: region.min_lookup_inputs(),
max_range_size: region.max_range_size(),
num_dynamic_lookups: region.dynamic_lookup_index(),
dynamic_lookup_col_coord: region.dynamic_lookup_col_coord(),
num_shuffles: region.shuffle_index(),
shuffle_col_coord: region.shuffle_col_coord(),
outputs,
};

View File

@@ -1,12 +1,12 @@
use crate::circuit::modules::kzg::{KZGChip, KZGConfig};
use crate::circuit::modules::polycommit::{PolyCommitChip, PolyCommitConfig};
use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
use crate::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
use crate::circuit::modules::Module;
use crate::tensor::{Tensor, ValTensor};
use halo2_proofs::circuit::Layouter;
use halo2_proofs::plonk::{Column, ConstraintSystem, Error, Instance, VerifyingKey};
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
use halo2curves::bn256::{Bn256, Fr as Fp, G1Affine};
use halo2_proofs::poly::commitment::CommitmentScheme;
use halo2curves::bn256::{Fr as Fp, G1Affine};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
@@ -14,9 +14,6 @@ use super::{VarVisibility, Visibility};
/// poseidon len to hash in tree
pub const POSEIDON_LEN_GRAPH: usize = 32;
/// ElGamal number of instances
pub const ELGAMAL_INSTANCES: usize = 4;
/// Poseidon number of instances
pub const POSEIDON_INSTANCES: usize = 1;
@@ -29,8 +26,8 @@ pub type ModulePoseidonConfig = PoseidonConfig<POSEIDON_WIDTH, POSEIDON_RATE>;
///
#[derive(Clone, Debug, Default)]
pub struct ModuleConfigs {
/// KZG
kzg: Vec<KZGConfig>,
/// PolyCommit
polycommit: Vec<PolyCommitConfig>,
/// Poseidon
poseidon: Option<ModulePoseidonConfig>,
/// Instance
@@ -46,8 +43,10 @@ impl ModuleConfigs {
) -> Self {
let mut config = Self::default();
for size in module_size.kzg {
config.kzg.push(KZGChip::configure(cs, (logrows, size)));
for size in module_size.polycommit {
config
.polycommit
.push(PolyCommitChip::configure(cs, (logrows, size)));
}
config
@@ -94,8 +93,8 @@ impl ModuleConfigs {
pub struct ModuleForwardResult {
/// The inputs of the forward pass for poseidon
pub poseidon_hash: Option<Vec<Fp>>,
/// The outputs of the forward pass for KZG
pub kzg_commit: Option<Vec<Vec<G1Affine>>>,
/// The outputs of the forward pass for PolyCommit
pub polycommit: Option<Vec<Vec<G1Affine>>>,
}
impl ModuleForwardResult {
@@ -126,7 +125,7 @@ impl ModuleForwardResult {
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
///
pub struct ModuleSizes {
kzg: Vec<usize>,
polycommit: Vec<usize>,
poseidon: (usize, Vec<usize>),
}
@@ -134,7 +133,7 @@ impl ModuleSizes {
/// Create new module sizes
pub fn new() -> Self {
ModuleSizes {
kzg: vec![],
polycommit: vec![],
poseidon: (
0,
vec![0; crate::circuit::modules::poseidon::NUM_INSTANCE_COLUMNS],
@@ -156,17 +155,17 @@ impl ModuleSizes {
/// Graph modules that can process inputs, params and outputs beyond the basic operations
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GraphModules {
kzg_idx: usize,
polycommit_idx: usize,
}
impl GraphModules {
///
pub fn new() -> GraphModules {
GraphModules { kzg_idx: 0 }
GraphModules { polycommit_idx: 0 }
}
///
pub fn reset_index(&mut self) {
self.kzg_idx = 0;
self.polycommit_idx = 0;
}
}
@@ -179,9 +178,9 @@ impl GraphModules {
for shape in shapes {
let total_len = shape.iter().product::<usize>();
if total_len > 0 {
if visibility.is_kzgcommit() {
// 1 constraint for each kzg commitment
sizes.kzg.push(total_len);
if visibility.is_polycommit() {
// 1 constraint for each polynomial commitment
sizes.polycommit.push(total_len);
} else if visibility.is_hashed() {
sizes.poseidon.0 += ModulePoseidon::num_rows(total_len);
// 1 constraint for the hash
@@ -236,22 +235,22 @@ impl GraphModules {
element_visibility: &Visibility,
instance_offset: &mut usize,
) -> Result<(), Error> {
if element_visibility.is_kzgcommit() && !values.is_empty() {
if element_visibility.is_polycommit() && !values.is_empty() {
// wrap each value as a standalone input
let mut inputs = values.iter_mut().map(|x| vec![x.clone()]).collect_vec();
// layout the module
inputs.iter_mut().for_each(|x| {
// create the module
let chip = KZGChip::new(configs.kzg[self.kzg_idx].clone());
// reserve module 2 onwards for kzg modules
let module_offset = 3 + self.kzg_idx;
let chip = PolyCommitChip::new(configs.polycommit[self.polycommit_idx].clone());
// reserve module 3 onwards for polycommit modules
let module_offset = 3 + self.polycommit_idx;
layouter
.assign_region(|| format!("_enter_module_{}", module_offset), |_| Ok(()))
.unwrap();
Self::layout_module(&chip, layouter, x, instance_offset).unwrap();
// increment the current index
self.kzg_idx += 1;
self.polycommit_idx += 1;
});
// replace the inputs with the outputs
@@ -288,14 +287,14 @@ impl GraphModules {
}
/// Run forward pass
pub fn forward(
inputs: &[Tensor<Fp>],
pub fn forward<Scheme: CommitmentScheme<Scalar = Fp, Curve = G1Affine>>(
inputs: &[Tensor<Scheme::Scalar>],
element_visibility: &Visibility,
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&ParamsKZG<Bn256>>,
srs: Option<&Scheme::ParamsProver>,
) -> Result<ModuleForwardResult, Box<dyn std::error::Error>> {
let mut poseidon_hash = None;
let mut kzg_commit = None;
let mut polycommit = None;
if element_visibility.is_hashed() {
let field_elements = inputs.iter().fold(vec![], |mut acc, x| {
@@ -306,11 +305,11 @@ impl GraphModules {
poseidon_hash = Some(field_elements);
}
if element_visibility.is_kzgcommit() {
if element_visibility.is_polycommit() {
if let Some(vk) = vk {
if let Some(srs) = srs {
let commitments = inputs.iter().fold(vec![], |mut acc, x| {
let res = KZGChip::commit(
let res = PolyCommitChip::commit::<Scheme>(
x.to_vec(),
vk.cs().degree() as u32,
(vk.cs().blinding_factors() + 1) as u32,
@@ -319,20 +318,20 @@ impl GraphModules {
acc.push(res);
acc
});
kzg_commit = Some(commitments);
polycommit = Some(commitments);
} else {
log::warn!("no srs provided for kzgcommit. processed value will be none");
log::warn!("no srs provided for polycommit. processed value will be none");
}
} else {
log::debug!(
"no verifying key provided for kzgcommit. processed value will be none"
"no verifying key provided for polycommit. processed value will be none"
);
}
}
Ok(ModuleForwardResult {
poseidon_hash,
kzg_commit,
polycommit,
})
}
}
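As at the graph level, module commitments are now scheme-generic; a sketch committing a tensor with KZG, assuming vk and srs are in scope:

    let res = GraphModules::forward::<KZGCommitmentScheme<Bn256>>(
        &[tensor],
        &Visibility::KZGCommit, // the variant now rendered as "polycommit"
        Some(&vk),
        Some(&srs),
    )?;
    let commitments = res.polycommit; // Option<Vec<Vec<G1Affine>>>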

View File

@@ -497,6 +497,7 @@ impl Node {
/// * `public_params` - flag if parameters of model are public
/// * `idx` - The node's unique identifier.
#[cfg(not(target_arch = "wasm32"))]
#[allow(clippy::too_many_arguments)]
pub fn new(
node: OnnxNode<TypedFact, Box<dyn TypedOp>>,
other_nodes: &mut BTreeMap<usize, super::NodeType>,

View File

@@ -23,7 +23,10 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(not(target_arch = "wasm32"))]
use tract_onnx::tract_core::ops::{
array::{Gather, GatherElements, MultiBroadcastTo, OneHot, ScatterElements, Slice, Topk},
array::{
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
Slice, Topk,
},
change_axes::AxisOp,
cnn::{Conv, Deconv},
einsum::EinSum,
@@ -467,6 +470,78 @@ pub fn new_op_from_onnx(
// Extract the max value
}
"ScatterNd" => {
if inputs.len() != 3 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"scatter nd".to_string(),
)));
};
// just verify it deserializes correctly
let _op = load_op::<ScatterNd>(node.op(), idx, node.op().name().to_string())?;
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
}
// }
if inputs[1].opkind().is_input() {
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
}));
inputs[1].bump_scale(0);
}
op
}
"GatherNd" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"gather nd".to_string(),
)));
};
let op = load_op::<GatherNd>(node.op(), idx, node.op().name().to_string())?;
let batch_dims = op.batch_dims;
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
})
}
// }
if inputs[1].opkind().is_input() {
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
}));
inputs[1].bump_scale(0);
}
op
}
"GatherElements" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(
@@ -734,7 +809,7 @@ pub fn new_op_from_onnx(
SupportedOp::Hybrid(HybridOp::Recip {
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
use_range_check_for_int: false,
use_range_check_for_int: true,
})
}

View File

@@ -1,4 +1,5 @@
use std::error::Error;
use std::fmt::Display;
use crate::tensor::TensorType;
use crate::tensor::{ValTensor, VarTensor};
@@ -14,6 +15,7 @@ use pyo3::{
};
use serde::{Deserialize, Serialize};
use tosubcommand::ToFlags;
use super::*;
@@ -40,6 +42,33 @@ pub enum Visibility {
Fixed,
}
impl Display for Visibility {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Visibility::KZGCommit => write!(f, "polycommit"),
Visibility::Private => write!(f, "private"),
Visibility::Public => write!(f, "public"),
Visibility::Fixed => write!(f, "fixed"),
Visibility::Hashed {
hash_is_public,
outlets,
} => {
if *hash_is_public {
write!(f, "hashed/public")
} else {
write!(f, "hashed/private/{}", outlets.iter().join(","))
}
}
}
}
}
impl ToFlags for Visibility {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
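Round trip through the new Display impl and the existing string conversion (note the hunk below swaps the accepted spelling from "kzgcommit" to "polycommit"):

    let v = Visibility::from("polycommit");
    assert!(v.is_polycommit());
    assert_eq!(v.to_string(), "polycommit");
    assert_eq!(v.to_flags(), vec!["polycommit".to_string()]);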
impl<'a> From<&'a str> for Visibility {
fn from(s: &'a str) -> Self {
if s.contains("hashed/private") {
@@ -59,7 +88,7 @@ impl<'a> From<&'a str> for Visibility {
match s {
"private" => Visibility::Private,
"public" => Visibility::Public,
"kzgcommit" => Visibility::KZGCommit,
"polycommit" => Visibility::KZGCommit,
"fixed" => Visibility::Fixed,
"hashed" | "hashed/public" => Visibility::Hashed {
hash_is_public: true,
@@ -82,7 +111,7 @@ impl IntoPy<PyObject> for Visibility {
Visibility::Private => "private".to_object(py),
Visibility::Public => "public".to_object(py),
Visibility::Fixed => "fixed".to_object(py),
Visibility::KZGCommit => "kzgcommit".to_object(py),
Visibility::KZGCommit => "polycommit".to_object(py),
Visibility::Hashed {
hash_is_public,
outlets,
@@ -129,7 +158,7 @@ impl<'source> FromPyObject<'source> for Visibility {
match strval.to_lowercase().as_str() {
"private" => Ok(Visibility::Private),
"public" => Ok(Visibility::Public),
"kzgcommit" => Ok(Visibility::KZGCommit),
"polycommit" => Ok(Visibility::KZGCommit),
"hashed" => Ok(Visibility::Hashed {
hash_is_public: true,
outlets: vec![],
@@ -163,7 +192,7 @@ impl Visibility {
matches!(&self, Visibility::Hashed { .. })
}
#[allow(missing_docs)]
pub fn is_kzgcommit(&self) -> bool {
pub fn is_polycommit(&self) -> bool {
matches!(&self, Visibility::KZGCommit)
}
@@ -202,17 +231,6 @@ impl Visibility {
vec![]
}
}
impl std::fmt::Display for Visibility {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Visibility::KZGCommit => write!(f, "kzgcommit"),
Visibility::Private => write!(f, "private"),
Visibility::Public => write!(f, "public"),
Visibility::Fixed => write!(f, "fixed"),
Visibility::Hashed { .. } => write!(f, "hashed"),
}
}
}
/// Represents the scale of the model input, model parameters.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)]
@@ -305,9 +323,9 @@ impl VarVisibility {
& !output_vis.is_hashed()
& !params_vis.is_hashed()
& !input_vis.is_hashed()
& !output_vis.is_kzgcommit()
& !params_vis.is_kzgcommit()
& !input_vis.is_kzgcommit()
& !output_vis.is_polycommit()
& !params_vis.is_polycommit()
& !input_vis.is_polycommit()
{
return Err(Box::new(GraphError::Visibility));
}
@@ -402,20 +420,34 @@ impl<F: PrimeField + TensorType + PartialOrd> ModelVars<F> {
}
/// Allocate all columns that will be assigned to by a model.
pub fn new(
cs: &mut ConstraintSystem<F>,
logrows: usize,
var_len: usize,
num_inner_cols: usize,
num_constants: usize,
module_requires_fixed: bool,
) -> Self {
info!("number of blinding factors: {}", cs.blinding_factors());
pub fn new(cs: &mut ConstraintSystem<F>, params: &GraphSettings) -> Self {
debug!("number of blinding factors: {}", cs.blinding_factors());
let advices = (0..3)
let logrows = params.run_args.logrows as usize;
let var_len = params.total_assignments;
let num_inner_cols = params.run_args.num_inner_cols;
let num_constants = params.total_const_size;
let module_requires_fixed = params.module_requires_fixed();
let requires_dynamic_lookup = params.requires_dynamic_lookup();
let requires_shuffle = params.requires_shuffle();
let dynamic_lookup_and_shuffle_size = params.dynamic_lookup_and_shuffle_col_size();
let mut advices = (0..3)
.map(|_| VarTensor::new_advice(cs, logrows, num_inner_cols, var_len))
.collect_vec();
if requires_dynamic_lookup || requires_shuffle {
let num_cols = if requires_dynamic_lookup { 3 } else { 2 };
for _ in 0..num_cols {
let dynamic_lookup =
VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_and_shuffle_size);
if dynamic_lookup.num_blocks() > 1 {
panic!("dynamic lookup or shuffle should only have one block");
};
advices.push(dynamic_lookup);
}
}
debug!(
"model uses {} advice blocks (size={})",
advices.iter().map(|v| v.num_blocks()).sum::<usize>(),

View File

@@ -28,10 +28,17 @@
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
use std::str::FromStr;
use circuit::{table::Range, CheckMode, Tolerance};
use clap::Args;
use graph::Visibility;
use halo2_proofs::poly::{
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
};
use halo2curves::bn256::{Bn256, G1Affine};
use serde::{Deserialize, Serialize};
use tosubcommand::ToFlags;
/// Methods for configuring tensor operations and assigning values to them in a Halo2 circuit.
pub mod circuit;
@@ -70,11 +77,99 @@ pub mod tensor;
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub mod wasm;
#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;
/// The denominator in the fixed point representation used when quantizing inputs
pub type Scale = i32;
#[cfg(not(target_arch = "wasm32"))]
// Buffered I/O capacity
lazy_static! {
/// The capacity of the buffers used for reading from and writing to disk
pub static ref EZKL_BUF_CAPACITY: usize = std::env::var("EZKL_BUF_CAPACITY")
.unwrap_or("8000".to_string())
.parse()
.unwrap();
/// The serialization format for the keys
pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT")
.unwrap_or("raw-bytes".to_string());
}
#[cfg(target_arch = "wasm32")]
const EZKL_KEY_FORMAT: &str = "raw-bytes";
#[cfg(target_arch = "wasm32")]
const EZKL_BUF_CAPACITY: &usize = &8000;
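Both knobs are read from the environment on first dereference, so any override must happen before the statics are first touched; a sketch:

    // set before the lazy statics are first dereferenced
    std::env::set_var("EZKL_BUF_CAPACITY", "65536");
    std::env::set_var("EZKL_KEY_FORMAT", "processed");
    assert_eq!(*EZKL_BUF_CAPACITY, 65536);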
#[derive(
Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize, Default, Copy,
)]
/// Commitment scheme
pub enum Commitments {
#[default]
/// KZG
KZG,
/// IPA
IPA,
}
impl FromStr for Commitments {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"kzg" => Ok(Commitments::KZG),
"ipa" => Ok(Commitments::IPA),
_ => Err("Invalid value for Commitments".to_string()),
}
}
}
impl From<KZGCommitmentScheme<Bn256>> for Commitments {
fn from(_value: KZGCommitmentScheme<Bn256>) -> Self {
Commitments::KZG
}
}
impl From<IPACommitmentScheme<G1Affine>> for Commitments {
fn from(_value: IPACommitmentScheme<G1Affine>) -> Self {
Commitments::IPA
}
}
impl std::fmt::Display for Commitments {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Commitments::KZG => write!(f, "kzg"),
Commitments::IPA => write!(f, "ipa"),
}
}
}
impl ToFlags for Commitments {
/// Convert the struct to a subcommand string
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
impl From<String> for Commitments {
fn from(value: String) -> Self {
match value.to_lowercase().as_str() {
"kzg" => Commitments::KZG,
"ipa" => Commitments::IPA,
_ => {
log::error!("Invalid value for Commitments");
log::warn!("defaulting to KZG");
Commitments::KZG
}
}
}
}
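The conversions compose as expected; a small sketch (the lossy From<String> falls back to KZG on bad input, logging an error):

    use std::str::FromStr;

    let c = Commitments::from_str("ipa").unwrap();
    assert_eq!(c, Commitments::IPA);
    assert_eq!(c.to_string(), "ipa");
    assert_eq!(Commitments::from("nope".to_string()), Commitments::KZG);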
/// Parameters specific to a proving run
#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd)]
#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd, ToFlags)]
pub struct RunArgs {
/// The tolerance for error on model outputs
#[arg(short = 'T', long, default_value = "0")]
@@ -89,7 +184,7 @@ pub struct RunArgs {
#[arg(long, default_value = "1")]
pub scale_rebase_multiplier: u32,
/// The min and max elements in the lookup table input column
#[arg(short = 'B', long, value_parser = parse_tuple::<i128>, default_value = "(-32768,32768)")]
#[arg(short = 'B', long, value_parser = parse_key_val::<i128, i128>, default_value = "-32768->32768")]
pub lookup_range: Range,
/// The log_2 number of rows
#[arg(short = 'K', long, default_value = "17")]
@@ -98,7 +193,7 @@ pub struct RunArgs {
#[arg(short = 'N', long, default_value = "2")]
pub num_inner_cols: usize,
/// Hand-written parser for graph variables, eg. batch_size=1
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size=1", value_delimiter = ',')]
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',')]
pub variables: Vec<(String, usize)>,
/// Flags whether inputs are public, private, hashed
#[arg(long, default_value = "private")]
@@ -118,6 +213,9 @@ pub struct RunArgs {
/// check mode (safe, unsafe, etc)
#[arg(long, default_value = "unsafe")]
pub check_mode: CheckMode,
/// commitment scheme
#[arg(long, default_value = "kzg")]
pub commitment: Commitments,
}
impl Default for RunArgs {
@@ -137,6 +235,7 @@ impl Default for RunArgs {
div_rebasing: false,
rebase_frac_zero_constants: false,
check_mode: CheckMode::UNSAFE,
commitment: Commitments::KZG,
}
}
}
@@ -156,6 +255,9 @@ impl RunArgs {
if self.num_inner_cols < 1 {
return Err("num_inner_cols must be >= 1".into());
}
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
return Err("tolerance > 0.0 requires output_visibility to be public".into());
}
Ok(())
}
@@ -180,34 +282,15 @@ fn parse_key_val<T, U>(
s: &str,
) -> Result<(T, U), Box<dyn std::error::Error + Send + Sync + 'static>>
where
T: std::str::FromStr,
T: std::str::FromStr + std::fmt::Debug,
T::Err: std::error::Error + Send + Sync + 'static,
U: std::str::FromStr,
U: std::str::FromStr + std::fmt::Debug,
U::Err: std::error::Error + Send + Sync + 'static,
{
let pos = s
.find('=')
.ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?;
Ok((s[..pos].parse()?, s[pos + 1..].parse()?))
}
/// Parse a tuple
fn parse_tuple<T>(s: &str) -> Result<(T, T), Box<dyn std::error::Error + Send + Sync + 'static>>
where
T: std::str::FromStr + Clone,
T::Err: std::error::Error + Send + Sync + 'static,
{
let res = s.trim_matches(|p| p == '(' || p == ')').split(',');
let res = res
.map(|x| {
// remove blank space
let x = x.trim();
x.parse::<T>()
})
.collect::<Result<Vec<_>, _>>()?;
if res.len() != 2 {
return Err("invalid tuple".into());
}
Ok((res[0].clone(), res[1].clone()))
.find("->")
.ok_or_else(|| format!("invalid x->y: no `->` found in `{s}`"))?;
let a = s[..pos].parse()?;
let b = s[pos + 2..].parse()?;
Ok((a, b))
}
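With the unified `->` syntax, graph variables and the lookup range both go through parse_key_val; illustrative invocations:

    let (k, v) = parse_key_val::<String, usize>("batch_size->1")?;
    assert_eq!((k.as_str(), v), ("batch_size", 1));
    let range = parse_key_val::<i128, i128>("-32768->32768")?;
    assert_eq!(range, (-32768, 32768));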

View File

@@ -1,7 +1,7 @@
use thiserror::Error;
/// Aggregate proof generation for EVM
pub mod aggregation;
/// Aggregate proof generation for EVM using KZG
pub mod aggregation_kzg;
#[derive(Error, Debug)]
/// Errors related to evm verification

View File

@@ -6,16 +6,16 @@ pub mod srs;
use crate::circuit::CheckMode;
use crate::graph::GraphWitness;
use crate::pfsys::evm::aggregation::PoseidonTranscript;
use crate::tensor::TensorType;
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
use crate::{Commitments, EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
use clap::ValueEnum;
use halo2_proofs::circuit::Value;
use halo2_proofs::plonk::{
create_proof, keygen_pk, keygen_vk_custom, verify_proof, Circuit, ProvingKey, VerifyingKey,
};
use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier};
use halo2_proofs::poly::kzg::commitment::{KZGCommitmentScheme, ParamsKZG};
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer};
use halo2curves::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
@@ -31,7 +31,6 @@ use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use snark_verifier::system::halo2::{compile, Config};
use snark_verifier::verifier::plonk::PlonkProtocol;
use std::error::Error;
use std::fs::File;
@@ -39,25 +38,19 @@ use std::io::{self, BufReader, BufWriter, Cursor, Write};
use std::ops::Deref;
use std::path::PathBuf;
use thiserror::Error as thisError;
// not wasm
#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;
use tosubcommand::ToFlags;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
#[cfg(not(target_arch = "wasm32"))]
// Buf writer capacity
lazy_static! {
static ref EZKL_BUF_CAPACITY: usize = std::env::var("EZKL_BUF_CAPACITY")
.unwrap_or("8000".to_string())
.parse()
.unwrap();
fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
match s {
"processed" => halo2_proofs::SerdeFormat::Processed,
"raw-bytes-unchecked" => halo2_proofs::SerdeFormat::RawBytesUnchecked,
"raw-bytes" => halo2_proofs::SerdeFormat::RawBytes,
_ => panic!("invalid serde format"),
}
}
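A self-contained sketch of the same string-to-format dispatch; `Format` stands in for halo2's `SerdeFormat`, and the `raw-bytes` fallback is an assumption mirroring `EZKL_KEY_FORMAT`'s default:

// Sketch of env-driven key-format selection (names illustrative).
#[derive(Debug)]
enum Format {
    Processed,
    RawBytesUnchecked,
    RawBytes,
}

fn format_from_str(s: &str) -> Format {
    match s {
        "processed" => Format::Processed,
        "raw-bytes-unchecked" => Format::RawBytesUnchecked,
        "raw-bytes" => Format::RawBytes,
        _ => panic!("invalid serde format"),
    }
}

fn main() {
    let fmt = std::env::var("EZKL_KEY_FORMAT").unwrap_or_else(|_| "raw-bytes".into());
    println!("selected format: {:?}", format_from_str(&fmt));
}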
#[cfg(target_arch = "wasm32")]
const EZKL_BUF_CAPACITY: &usize = &8000;
#[allow(missing_docs)]
#[derive(
ValueEnum, Copy, Clone, Default, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd,
@@ -68,6 +61,25 @@ pub enum ProofType {
ForAggr,
}
impl std::fmt::Display for ProofType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
ProofType::Single => "single",
ProofType::ForAggr => "for-aggr",
}
)
}
}
impl ToFlags for ProofType {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
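The `ToFlags` impls simply reuse `Display`, so the flag string can never drift from the printed name; a minimal sketch of the pattern:

use std::fmt;

enum ProofType {
    Single,
    ForAggr,
}

impl fmt::Display for ProofType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{}",
            match self {
                ProofType::Single => "single",
                ProofType::ForAggr => "for-aggr",
            }
        )
    }
}

// to_flags is just the Display form, as in the impls above.
fn to_flags(p: &ProofType) -> Vec<String> {
    vec![format!("{p}")]
}

fn main() {
    assert_eq!(to_flags(&ProofType::ForAggr), vec!["for-aggr".to_string()]);
}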
impl From<ProofType> for TranscriptType {
fn from(val: ProofType) -> Self {
match val {
@@ -170,6 +182,25 @@ pub enum TranscriptType {
EVM,
}
impl std::fmt::Display for TranscriptType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
TranscriptType::Poseidon => "poseidon",
TranscriptType::EVM => "evm",
}
)
}
}
impl ToFlags for TranscriptType {
fn to_flags(&self) -> Vec<String> {
vec![format!("{}", self)]
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for TranscriptType {
fn to_object(&self, py: Python) -> PyObject {
@@ -183,8 +214,8 @@ impl ToPyObject for TranscriptType {
#[cfg(feature = "python-bindings")]
/// Converts a G1Affine point into a python dict
pub fn g1affine_to_pydict(g1affine_dict: &PyDict, g1affine: &G1Affine) {
let g1affine_x = field_to_string_montgomery(&g1affine.x);
let g1affine_y = field_to_string_montgomery(&g1affine.y);
let g1affine_x = field_to_string(&g1affine.x);
let g1affine_y = field_to_string(&g1affine.y);
g1affine_dict.set_item("x", g1affine_x).unwrap();
g1affine_dict.set_item("y", g1affine_y).unwrap();
}
@@ -194,23 +225,23 @@ use halo2curves::bn256::G1;
#[cfg(feature = "python-bindings")]
/// Converts a G1 projective point into a python dict
pub fn g1_to_pydict(g1_dict: &PyDict, g1: &G1) {
let g1_x = field_to_string_montgomery(&g1.x);
let g1_y = field_to_string_montgomery(&g1.y);
let g1_z = field_to_string_montgomery(&g1.z);
let g1_x = field_to_string(&g1.x);
let g1_y = field_to_string(&g1.y);
let g1_z = field_to_string(&g1.z);
g1_dict.set_item("x", g1_x).unwrap();
g1_dict.set_item("y", g1_y).unwrap();
g1_dict.set_item("z", g1_z).unwrap();
}
/// converts fp into `Vec<u64>` in Montgomery form
pub fn field_to_string_montgomery<F: PrimeField + SerdeObject + Serialize>(fp: &F) -> String {
/// converts fp into a little-endian hex string
pub fn field_to_string<F: PrimeField + SerdeObject + Serialize>(fp: &F) -> String {
let repr = serde_json::to_string(&fp).unwrap();
let b: String = serde_json::from_str(&repr).unwrap();
b
}
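For intuition, a crate-independent sketch of what little-endian hex means for these felt strings: the least-significant byte is written first.

// Little-endian hex encoding of an integer: low byte first.
fn to_le_hex(x: u128) -> String {
    x.to_le_bytes().iter().map(|b| format!("{:02x}", b)).collect()
}

fn main() {
    let s = to_le_hex(1);
    assert!(s.starts_with("01")); // the value 1 leads with its low byte
    assert_eq!(s.len(), 32); // 16 bytes -> 32 hex characters
}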
/// converts `Vec<u64>` in Montgomery form into fp
pub fn string_to_field_montgomery<F: PrimeField + SerdeObject + Serialize + DeserializeOwned>(
/// converts a little-endian hex string into a field element
pub fn string_to_field<F: PrimeField + SerdeObject + Serialize + DeserializeOwned>(
b: &String,
) -> F {
let repr = serde_json::to_string(&b).unwrap();
@@ -260,6 +291,8 @@ where
pub pretty_public_inputs: Option<PrettyElements>,
/// timestamp
pub timestamp: Option<u128>,
/// commitment
pub commitment: Option<Commitments>,
}
#[cfg(feature = "python-bindings")]
@@ -275,7 +308,7 @@ where
let field_elems: Vec<Vec<String>> = self
.instances
.iter()
.map(|x| x.iter().map(|fp| field_to_string_montgomery(fp)).collect())
.map(|x| x.iter().map(|fp| field_to_string(fp)).collect())
.collect::<Vec<_>>();
dict.set_item("instances", field_elems).unwrap();
let hex_proof = hex::encode(&self.proof);
@@ -303,6 +336,7 @@ where
transcript_type: TranscriptType,
split: Option<ProofSplitCommit>,
pretty_public_inputs: Option<PrettyElements>,
commitment: Option<Commitments>,
) -> Self {
Self {
protocol,
@@ -319,6 +353,7 @@ where
.unwrap()
.as_millis(),
),
commitment,
}
}
@@ -344,8 +379,10 @@ where
<C as CurveAffine>::ScalarExt: FromUniformBytes<64>,
{
trace!("reading proof");
let data = std::fs::read_to_string(proof_path)?;
serde_json::from_str(&data).map_err(|e| e.into())
let file = std::fs::File::open(proof_path)?;
let reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
let proof: Self = serde_json::from_reader(reader)?;
Ok(proof)
}
}
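The loader now streams JSON through a sized `BufReader` instead of reading the whole file into a `String` first; a minimal sketch of the pattern (serde_json assumed on the dependency list, as elsewhere in the crate; the 8000-byte capacity mirrors `EZKL_BUF_CAPACITY`'s assumed default):

use std::fs::File;
use std::io::{BufReader, Write};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // write a small JSON fixture
    let path = std::env::temp_dir().join("snark_demo.json");
    File::create(&path)?.write_all(br#"{"proof": [1, 2, 3]}"#)?;

    // read it back through a sized buffered reader, as Snark::load does above
    let reader = BufReader::with_capacity(8000, File::open(&path)?);
    let value: serde_json::Value = serde_json::from_reader(reader)?;
    assert_eq!(value["proof"][0], 1);
    Ok(())
}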
@@ -363,27 +400,36 @@ impl From<GraphWitness> for Option<ProofSplitCommit> {
let mut elem_offset = 0;
if let Some(input) = witness.processed_inputs {
if let Some(kzg) = input.kzg_commit {
if let Some(polycommit) = input.polycommit {
// flatten and count number of elements
let num_elements = kzg.iter().map(|kzg| kzg.len()).sum::<usize>();
let num_elements = polycommit
.iter()
.map(|polycommit| polycommit.len())
.sum::<usize>();
elem_offset += num_elements;
}
}
if let Some(params) = witness.processed_params {
if let Some(kzg) = params.kzg_commit {
if let Some(polycommit) = params.polycommit {
// flatten and count number of elements
let num_elements = kzg.iter().map(|kzg| kzg.len()).sum::<usize>();
let num_elements = polycommit
.iter()
.map(|polycommit| polycommit.len())
.sum::<usize>();
elem_offset += num_elements;
}
}
if let Some(output) = witness.processed_outputs {
if let Some(kzg) = output.kzg_commit {
if let Some(polycommit) = output.polycommit {
// flatten and count number of elements
let num_elements = kzg.iter().map(|kzg| kzg.len()).sum::<usize>();
let num_elements = polycommit
.iter()
.map(|polycommit| polycommit.len())
.sum::<usize>();
Some(ProofSplitCommit {
start: elem_offset,
@@ -446,22 +492,22 @@ where
}
/// Creates a [VerifyingKey] and [ProvingKey] for a [crate::graph::GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`).
pub fn create_keys<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
pub fn create_keys<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
circuit: &C,
params: &'_ Scheme::ParamsProver,
compress_selectors: bool,
disable_selector_compression: bool,
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
C: Circuit<Scheme::Scalar>,
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
{
// Real proof
let empty_circuit = <C as Circuit<F>>::without_witnesses(circuit);
let empty_circuit = <C as Circuit<Scheme::Scalar>>::without_witnesses(circuit);
// Initialize verifying key
let now = Instant::now();
trace!("preparing VK");
let vk = keygen_vk_custom(params, &empty_circuit, compress_selectors)?;
let vk = keygen_vk_custom(params, &empty_circuit, !disable_selector_compression)?;
let elapsed = now.elapsed();
info!("VK took {}.{}", elapsed.as_secs(), elapsed.subsec_millis());
@@ -478,8 +524,7 @@ where
pub fn create_proof_circuit<
'params,
Scheme: CommitmentScheme,
F: PrimeField + TensorType,
C: Circuit<F>,
C: Circuit<Scheme::Scalar>,
P: Prover<'params, Scheme>,
V: Verifier<'params, Scheme>,
Strategy: VerificationStrategy<'params, Scheme, V>,
@@ -491,40 +536,28 @@ pub fn create_proof_circuit<
instances: Vec<Vec<Scheme::Scalar>>,
params: &'params Scheme::ParamsProver,
pk: &ProvingKey<Scheme::Curve>,
strategy: Strategy,
check_mode: CheckMode,
commitment: Commitments,
transcript_type: TranscriptType,
split: Option<ProofSplitCommit>,
set_protocol: bool,
protocol: Option<PlonkProtocol<Scheme::Curve>>,
) -> Result<Snark<Scheme::Scalar, Scheme::Curve>, Box<dyn Error>>
where
C: Circuit<Scheme::Scalar>,
Scheme::ParamsVerifier: 'params,
Scheme::Scalar: Serialize
+ DeserializeOwned
+ SerdeObject
+ PrimeField
+ FromUniformBytes<64>
+ WithSmallOrderMulGroup<3>
+ Ord,
+ WithSmallOrderMulGroup<3>,
Scheme::Curve: Serialize + DeserializeOwned,
{
let strategy = Strategy::new(params.verifier_params());
let mut transcript = TranscriptWriterBuffer::<_, Scheme::Curve, _>::init(vec![]);
#[cfg(feature = "det-prove")]
let mut rng = <StdRng as rand::SeedableRng>::from_seed([0u8; 32]);
#[cfg(not(feature = "det-prove"))]
let mut rng = OsRng;
let number_instance = instances.iter().map(|x| x.len()).collect();
trace!("number_instance {:?}", number_instance);
let mut protocol = None;
if set_protocol {
protocol = Some(compile(
params,
pk.get_vk(),
Config::kzg().with_num_instance(number_instance),
))
}
let pi_inner = instances
.iter()
@@ -560,17 +593,19 @@ where
transcript_type,
split,
None,
Some(commitment),
);
// sanity check that the generated proof is valid
if check_mode == CheckMode::SAFE {
debug!("verifying generated proof");
let verifier_params = params.verifier_params();
verify_proof_circuit::<F, V, Scheme, Strategy, E, TR>(
verify_proof_circuit::<V, Scheme, Strategy, E, TR>(
&checkable_pf,
verifier_params,
pk.get_vk(),
strategy,
verifier_params.n(),
)?;
}
let elapsed = now.elapsed();
@@ -585,7 +620,6 @@ where
/// Swaps the proof commitments to a new set in the proof
pub fn swap_proof_commitments<
F: PrimeField,
Scheme: CommitmentScheme,
E: EncodedChallenge<Scheme::Curve>,
TW: TranscriptWriterBuffer<Vec<u8>, Scheme::Curve, E>,
@@ -605,7 +639,7 @@ where
{
let mut transcript_new: TW = TranscriptWriterBuffer::<_, Scheme::Curve, _>::init(vec![]);
// kzg commitments are the first set of points in the proof, this we'll always be the first set of advice
// polycommit commitments are the first set of points in the proof, so they will always be the first set of advice
for commit in commitments {
transcript_new
.write_point(*commit)
@@ -623,31 +657,46 @@ where
}
/// Swap the proof commitments to a new set in the proof for KZG
pub fn swap_proof_commitments_kzg(
pub fn swap_proof_commitments_polycommit(
snark: &Snark<Fr, G1Affine>,
commitments: &[G1Affine],
) -> Result<Snark<Fr, G1Affine>, Box<dyn Error>> {
let proof = match snark.transcript_type {
TranscriptType::EVM => swap_proof_commitments::<
Fr,
KZGCommitmentScheme<Bn256>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(snark, commitments)?,
TranscriptType::Poseidon => swap_proof_commitments::<
Fr,
KZGCommitmentScheme<Bn256>,
_,
PoseidonTranscript<NativeLoader, _>,
>(snark, commitments)?,
let proof = match snark.commitment {
Some(Commitments::KZG) => match snark.transcript_type {
TranscriptType::EVM => swap_proof_commitments::<
KZGCommitmentScheme<Bn256>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(snark, commitments)?,
TranscriptType::Poseidon => swap_proof_commitments::<
KZGCommitmentScheme<Bn256>,
_,
PoseidonTranscript<NativeLoader, _>,
>(snark, commitments)?,
},
Some(Commitments::IPA) => match snark.transcript_type {
TranscriptType::EVM => swap_proof_commitments::<
IPACommitmentScheme<G1Affine>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(snark, commitments)?,
TranscriptType::Poseidon => swap_proof_commitments::<
IPACommitmentScheme<G1Affine>,
_,
PoseidonTranscript<NativeLoader, _>,
>(snark, commitments)?,
},
None => {
return Err("commitment scheme not found".into());
}
};
Ok(proof)
}
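A reduced sketch of the dispatch shape above, with types simplified: the proof carries an optional commitment tag, and a missing tag is a hard error rather than a silent default.

#[derive(Clone, Copy)]
enum Commitments {
    KZG,
    IPA,
}

// Pick a processing path from the proof's optional commitment tag.
fn dispatch(tag: Option<Commitments>) -> Result<&'static str, String> {
    match tag {
        Some(Commitments::KZG) => Ok("kzg"),
        Some(Commitments::IPA) => Ok("ipa"),
        None => Err("commitment scheme not found".into()),
    }
}

fn main() {
    assert_eq!(dispatch(Some(Commitments::IPA)).unwrap(), "ipa");
    assert!(dispatch(None).is_err());
}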
/// A wrapper around halo2's verify_proof
pub fn verify_proof_circuit<
'params,
F: PrimeField,
V: Verifier<'params, Scheme>,
Scheme: CommitmentScheme,
Strategy: VerificationStrategy<'params, Scheme, V>,
@@ -658,13 +707,13 @@ pub fn verify_proof_circuit<
params: &'params Scheme::ParamsVerifier,
vk: &VerifyingKey<Scheme::Curve>,
strategy: Strategy,
orig_n: u64,
) -> Result<Strategy::Output, halo2_proofs::plonk::Error>
where
Scheme::Scalar: SerdeObject
+ PrimeField
+ FromUniformBytes<64>
+ WithSmallOrderMulGroup<3>
+ Ord
+ Serialize
+ DeserializeOwned,
Scheme::Curve: Serialize + DeserializeOwned,
@@ -678,11 +727,11 @@ where
trace!("instances {:?}", instances);
let mut transcript = TranscriptReadBuffer::init(Cursor::new(snark.proof.clone()));
verify_proof::<Scheme, V, _, TR, _>(params, vk, strategy, instances, &mut transcript)
verify_proof::<Scheme, V, _, TR, _>(params, vk, strategy, instances, &mut transcript, orig_n)
}
/// Loads a [VerifyingKey] at `path`.
pub fn load_vk<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
pub fn load_vk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
path: PathBuf,
params: <C as Circuit<Scheme::Scalar>>::Params,
) -> Result<VerifyingKey<Scheme::Curve>, Box<dyn Error>>
@@ -695,16 +744,17 @@ where
let f =
File::open(path.clone()).map_err(|_| format!("failed to load vk at {}", path.display()))?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
VerifyingKey::<Scheme::Curve>::read::<_, C>(
let vk = VerifyingKey::<Scheme::Curve>::read::<_, C>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
serde_format_from_str(&EZKL_KEY_FORMAT),
params,
)
.map_err(Box::<dyn Error>::from)
)?;
info!("done loading verification key ✅");
Ok(vk)
}
/// Loads a [ProvingKey] at `path`.
pub fn load_pk<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
pub fn load_pk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
path: PathBuf,
params: <C as Circuit<Scheme::Scalar>>::Params,
) -> Result<ProvingKey<Scheme::Curve>, Box<dyn Error>>
@@ -717,45 +767,46 @@ where
let f =
File::open(path.clone()).map_err(|_| format!("failed to load pk at {}", path.display()))?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
ProvingKey::<Scheme::Curve>::read::<_, C>(
let pk = ProvingKey::<Scheme::Curve>::read::<_, C>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
serde_format_from_str(&EZKL_KEY_FORMAT),
params,
)
.map_err(Box::<dyn Error>::from)
)?;
info!("done loading proving key ✅");
Ok(pk)
}
/// Saves a [ProvingKey] to `path`.
pub fn save_pk<Scheme: CommitmentScheme>(
pub fn save_pk<C: SerdeObject + CurveAffine>(
path: &PathBuf,
vk: &ProvingKey<Scheme::Curve>,
pk: &ProvingKey<C>,
) -> Result<(), io::Error>
where
Scheme::Curve: SerdeObject + CurveAffine,
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
C::ScalarExt: FromUniformBytes<64> + SerdeObject,
{
info!("saving proving key 💾");
let f = File::create(path)?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
vk.write(&mut writer, halo2_proofs::SerdeFormat::RawBytes)?;
pk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
writer.flush()?;
info!("done saving proving key ✅");
Ok(())
}
/// Saves a [VerifyingKey] to `path`.
pub fn save_vk<Scheme: CommitmentScheme>(
pub fn save_vk<C: CurveAffine + SerdeObject>(
path: &PathBuf,
vk: &VerifyingKey<Scheme::Curve>,
vk: &VerifyingKey<C>,
) -> Result<(), io::Error>
where
Scheme::Curve: SerdeObject + CurveAffine,
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
C::ScalarExt: FromUniformBytes<64> + SerdeObject,
{
info!("saving verification key 💾");
let f = File::create(path)?;
let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
vk.write(&mut writer, halo2_proofs::SerdeFormat::RawBytes)?;
vk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
writer.flush()?;
info!("done saving verification key ✅");
Ok(())
}
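Both savers share the buffered write-then-flush shape; a standalone sketch, with the capacity value again an assumption mirroring the crate default:

use std::fs::File;
use std::io::{BufWriter, Write};

fn main() -> std::io::Result<()> {
    let path = std::env::temp_dir().join("vk_demo.bin");
    let mut writer = BufWriter::with_capacity(8000, File::create(&path)?);
    writer.write_all(b"serialized key bytes")?;
    writer.flush()?; // flush explicitly so write errors surface here, not on drop
    Ok(())
}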
@@ -772,155 +823,25 @@ pub fn save_params<Scheme: CommitmentScheme>(
Ok(())
}
/// helper function
#[allow(clippy::too_many_arguments)]
pub fn create_proof_circuit_kzg<
'params,
C: Circuit<Fr>,
Strategy: VerificationStrategy<'params, KZGCommitmentScheme<Bn256>, VerifierSHPLONK<'params, Bn256>>,
>(
circuit: C,
params: &'params ParamsKZG<Bn256>,
public_inputs: Option<Vec<Fr>>,
pk: &ProvingKey<G1Affine>,
transcript: TranscriptType,
strategy: Strategy,
check_mode: CheckMode,
split: Option<ProofSplitCommit>,
) -> Result<Snark<Fr, G1Affine>, Box<dyn Error>> {
let public_inputs = if let Some(public_inputs) = public_inputs {
if !public_inputs.is_empty() {
vec![public_inputs]
} else {
vec![vec![]]
}
} else {
vec![]
};
match transcript {
TranscriptType::EVM => create_proof_circuit::<
KZGCommitmentScheme<_>,
Fr,
_,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
_,
_,
EvmTranscript<G1Affine, _, _, _>,
EvmTranscript<G1Affine, _, _, _>,
>(
circuit,
public_inputs,
params,
pk,
strategy,
check_mode,
transcript,
split,
false,
)
.map_err(Box::<dyn Error>::from),
TranscriptType::Poseidon => create_proof_circuit::<
KZGCommitmentScheme<_>,
Fr,
_,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
_,
_,
PoseidonTranscript<NativeLoader, _>,
PoseidonTranscript<NativeLoader, _>,
>(
circuit,
public_inputs,
params,
pk,
strategy,
check_mode,
transcript,
split,
true,
)
.map_err(Box::<dyn Error>::from),
}
}
#[allow(unused)]
/// helper function
pub(crate) fn verify_proof_circuit_kzg<
'params,
Strategy: VerificationStrategy<'params, KZGCommitmentScheme<Bn256>, VerifierSHPLONK<'params, Bn256>>,
>(
params: &'params ParamsKZG<Bn256>,
proof: Snark<Fr, G1Affine>,
vk: &VerifyingKey<G1Affine>,
strategy: Strategy,
) -> Result<Strategy::Output, halo2_proofs::plonk::Error> {
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
Fr,
VerifierSHPLONK<'_, Bn256>,
_,
_,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, params, vk, strategy),
TranscriptType::Poseidon => verify_proof_circuit::<
Fr,
VerifierSHPLONK<'_, Bn256>,
_,
_,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, params, vk, strategy),
}
}
////////////////////////
#[cfg(test)]
#[cfg(not(target_arch = "wasm32"))]
mod tests {
use std::io::copy;
use super::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
use tempfile::Builder;
#[tokio::test]
async fn test_can_load_pre_generated_srs() {
let tmp_dir = Builder::new().prefix("example").tempdir().unwrap();
// let's hope this link never rots
let target = "https://trusted-setup-halo2kzg.s3.eu-central-1.amazonaws.com/hermez-raw-1";
let response = reqwest::get(target).await.unwrap();
let fname = response
.url()
.path_segments()
.and_then(|segments| segments.last())
.and_then(|name| if name.is_empty() { None } else { Some(name) })
.unwrap_or("tmp.bin");
info!("file to download: '{}'", fname);
let fname = tmp_dir.path().join(fname);
info!("will be located under: '{:?}'", fname);
let mut dest = File::create(fname.clone()).unwrap();
let content = response.bytes().await.unwrap();
copy(&mut &content[..], &mut dest).unwrap();
let res = srs::load_srs::<KZGCommitmentScheme<Bn256>>(fname);
assert!(res.is_ok())
}
#[tokio::test]
async fn test_can_load_saved_srs() {
let tmp_dir = Builder::new().prefix("example").tempdir().unwrap();
let fname = tmp_dir.path().join("kzg.params");
let fname = tmp_dir.path().join("polycommit.params");
let srs = srs::gen_srs::<KZGCommitmentScheme<Bn256>>(1);
let res = save_params::<KZGCommitmentScheme<Bn256>>(&fname, &srs);
assert!(res.is_ok());
let res = srs::load_srs::<KZGCommitmentScheme<Bn256>>(fname);
let res = srs::load_srs_prover::<KZGCommitmentScheme<Bn256>>(fname);
assert!(res.is_ok())
}
@@ -935,6 +856,7 @@ mod tests {
split: None,
pretty_public_inputs: None,
timestamp: None,
commitment: None,
};
snark

View File

@@ -17,7 +17,7 @@ pub fn gen_srs<Scheme: CommitmentScheme>(k: u32) -> Scheme::ParamsProver {
}
/// Loads the [CommitmentScheme::ParamsVerifier] at `path`.
pub fn load_srs<Scheme: CommitmentScheme>(
pub fn load_srs_verifier<Scheme: CommitmentScheme>(
path: PathBuf,
) -> Result<Scheme::ParamsVerifier, Box<dyn Error>> {
info!("loading srs from {:?}", path);
@@ -26,3 +26,14 @@ pub fn load_srs<Scheme: CommitmentScheme>(
let mut reader = BufReader::new(f);
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(Box::<dyn Error>::from)
}
/// Loads the [CommitmentScheme::ParamsProver] at `path`.
pub fn load_srs_prover<Scheme: CommitmentScheme>(
path: PathBuf,
) -> Result<Scheme::ParamsProver, Box<dyn Error>> {
info!("loading srs from {:?}", path);
let f = File::open(path.clone())
.map_err(|_| format!("failed to load srs at {}", path.display()))?;
let mut reader = BufReader::new(f);
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(Box::<dyn Error>::from)
}

View File

@@ -1,4 +1,4 @@
use crate::circuit::modules::kzg::KZGChip;
use crate::circuit::modules::polycommit::PolyCommitChip;
use crate::circuit::modules::poseidon::{
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
PoseidonChip,
@@ -12,13 +12,14 @@ use crate::graph::TestDataSource;
use crate::graph::{
quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
};
use crate::pfsys::evm::aggregation::AggregationCircuit;
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::{
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs, ProofType,
TranscriptType,
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
ProofType, TranscriptType,
};
use crate::Commitments;
use crate::RunArgs;
use ethers::types::H160;
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
use pyo3::exceptions::{PyIOError, PyRuntimeError};
@@ -65,9 +66,9 @@ struct PyG1 {
impl From<G1> for PyG1 {
fn from(g1: G1) -> Self {
PyG1 {
x: crate::pfsys::field_to_string_montgomery::<Fq>(&g1.x),
y: crate::pfsys::field_to_string_montgomery::<Fq>(&g1.y),
z: crate::pfsys::field_to_string_montgomery::<Fq>(&g1.z),
x: crate::pfsys::field_to_string::<Fq>(&g1.x),
y: crate::pfsys::field_to_string::<Fq>(&g1.y),
z: crate::pfsys::field_to_string::<Fq>(&g1.z),
}
}
}
@@ -75,9 +76,9 @@ impl From<G1> for PyG1 {
impl From<PyG1> for G1 {
fn from(val: PyG1) -> Self {
G1 {
x: crate::pfsys::string_to_field_montgomery::<Fq>(&val.x),
y: crate::pfsys::string_to_field_montgomery::<Fq>(&val.y),
z: crate::pfsys::string_to_field_montgomery::<Fq>(&val.z),
x: crate::pfsys::string_to_field::<Fq>(&val.x),
y: crate::pfsys::string_to_field::<Fq>(&val.y),
z: crate::pfsys::string_to_field::<Fq>(&val.z),
}
}
}
@@ -108,8 +109,8 @@ pub struct PyG1Affine {
impl From<G1Affine> for PyG1Affine {
fn from(g1: G1Affine) -> Self {
PyG1Affine {
x: crate::pfsys::field_to_string_montgomery::<Fq>(&g1.x),
y: crate::pfsys::field_to_string_montgomery::<Fq>(&g1.y),
x: crate::pfsys::field_to_string::<Fq>(&g1.x),
y: crate::pfsys::field_to_string::<Fq>(&g1.y),
}
}
}
@@ -117,8 +118,8 @@ impl From<G1Affine> for PyG1Affine {
impl From<PyG1Affine> for G1Affine {
fn from(val: PyG1Affine) -> Self {
G1Affine {
x: crate::pfsys::string_to_field_montgomery::<Fq>(&val.x),
y: crate::pfsys::string_to_field_montgomery::<Fq>(&val.y),
x: crate::pfsys::string_to_field::<Fq>(&val.x),
y: crate::pfsys::string_to_field::<Fq>(&val.y),
}
}
}
@@ -165,6 +166,8 @@ struct PyRunArgs {
pub rebase_frac_zero_constants: bool,
#[pyo3(get, set)]
pub check_mode: CheckMode,
#[pyo3(get, set)]
pub commitment: PyCommitments,
}
/// default instantiation of PyRunArgs
@@ -194,6 +197,7 @@ impl From<PyRunArgs> for RunArgs {
div_rebasing: py_run_args.div_rebasing,
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
check_mode: py_run_args.check_mode,
commitment: py_run_args.commitment.into(),
}
}
}
@@ -215,57 +219,95 @@ impl Into<PyRunArgs> for RunArgs {
div_rebasing: self.div_rebasing,
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
check_mode: self.check_mode,
commitment: self.commitment.into(),
}
}
}
/// Converts 4 u64s to a field element
#[pyfunction(signature = (
array,
))]
fn string_to_felt(array: PyFelt) -> PyResult<String> {
Ok(format!(
"{:?}",
crate::pfsys::string_to_field_montgomery::<Fr>(&array)
))
#[pyclass]
#[derive(Debug, Clone)]
/// Pyclass marking the type of commitment
pub enum PyCommitments {
/// KZG commitment
KZG,
/// IPA commitment
IPA,
}
/// Converts 4 u64s representing a field element directly to an integer
impl From<PyCommitments> for Commitments {
fn from(py_commitments: PyCommitments) -> Self {
match py_commitments {
PyCommitments::KZG => Commitments::KZG,
PyCommitments::IPA => Commitments::IPA,
}
}
}
impl Into<PyCommitments> for Commitments {
fn into(self) -> PyCommitments {
match self {
Commitments::KZG => PyCommitments::KZG,
Commitments::IPA => PyCommitments::IPA,
}
}
}
impl FromStr for PyCommitments {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"kzg" => Ok(PyCommitments::KZG),
"ipa" => Ok(PyCommitments::IPA),
_ => Err("Invalid value for Commitments".to_string()),
}
}
}
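Because `PyCommitments` implements `FromStr`, the standard `str::parse` works on it; a self-contained sketch of the same case-insensitive parsing:

use std::str::FromStr;

#[derive(Debug, PartialEq)]
enum Commitments {
    KZG,
    IPA,
}

impl FromStr for Commitments {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "kzg" => Ok(Commitments::KZG),
            "ipa" => Ok(Commitments::IPA),
            _ => Err("Invalid value for Commitments".to_string()),
        }
    }
}

fn main() {
    assert_eq!("KZG".parse::<Commitments>(), Ok(Commitments::KZG));
    assert!("foo".parse::<Commitments>().is_err());
}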
/// Converts a felt to a big endian string
#[pyfunction(signature = (
felt,
))]
fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
Ok(format!("{:?}", felt))
}
/// Converts a field element hex string to an integer
#[pyfunction(signature = (
array,
))]
fn string_to_int(array: PyFelt) -> PyResult<i128> {
let felt = crate::pfsys::string_to_field_montgomery::<Fr>(&array);
fn felt_to_int(array: PyFelt) -> PyResult<i128> {
let felt = crate::pfsys::string_to_field::<Fr>(&array);
let int_rep = felt_to_i128(felt);
Ok(int_rep)
}
/// Converts 4 u64s representing a field element directly to a (rescaled from fixed point scaling) floating point
/// Converts a field element hex string to a floating point number
#[pyfunction(signature = (
array,
scale
))]
fn string_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field_montgomery::<Fr>(&array);
fn felt_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&array);
let int_rep = felt_to_i128(felt);
let multiplier = scale_to_multiplier(scale);
let float_rep = int_rep as f64 / multiplier;
Ok(float_rep)
}
/// Converts a floating point element to 4 u64s representing a fixed point field element
/// Converts a floating point element to a field element hex string
#[pyfunction(signature = (
input,
scale
))]
fn float_to_string(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
let int_rep = quantize_float(&input, 0.0, scale)
.map_err(|_| PyIOError::new_err("Failed to quantize input"))?;
let felt = i128_to_felt(int_rep);
Ok(crate::pfsys::field_to_string_montgomery::<Fr>(&felt))
Ok(crate::pfsys::field_to_string::<Fr>(&felt))
}
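The two conversions above are inverses up to rounding: quantizing multiplies by the scale's multiplier and rounds, dequantizing divides. A sketch assuming `scale_to_multiplier(scale)` is `2^scale`:

// Fixed-point round trip under a power-of-two multiplier (assumption noted above).
fn quantize(x: f64, scale: i32) -> i128 {
    (x * 2f64.powi(scale)).round() as i128
}

fn dequantize(q: i128, scale: i32) -> f64 {
    q as f64 / 2f64.powi(scale)
}

fn main() {
    let q = quantize(3.25, 3); // 3.25 * 8 = 26
    assert_eq!(q, 26);
    assert_eq!(dequantize(q, 3), 3.25);
}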
/// Converts a buffer to vector of 4 u64s representing a fixed point field element
/// Converts a buffer to a vector of field elements
#[pyfunction(signature = (
buffer
))]
@@ -318,7 +360,10 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
.map(|x| PrimeField::from_u128(u8_array_to_u128_le(*x)))
.collect();
let field_elements: Vec<String> = field_elements.iter().map(|x| format!("{:?}", x)).collect();
let field_elements: Vec<String> = field_elements
.iter()
.map(|x| crate::pfsys::field_to_string::<Fr>(x))
.collect();
Ok(field_elements)
}
@@ -330,7 +375,7 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
let message: Vec<Fr> = message
.iter()
.map(crate::pfsys::string_to_field_montgomery::<Fr>)
.map(crate::pfsys::string_to_field::<Fr>)
.collect::<Vec<_>>();
let output =
@@ -341,7 +386,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
let hash = output[0]
.iter()
.map(crate::pfsys::field_to_string_montgomery::<Fr>)
.map(crate::pfsys::field_to_string::<Fr>)
.collect::<Vec<_>>();
Ok(hash)
}
@@ -361,21 +406,62 @@ fn kzg_commit(
) -> PyResult<Vec<PyG1Affine>> {
let message: Vec<Fr> = message
.iter()
.map(crate::pfsys::string_to_field_montgomery::<Fr>)
.map(crate::pfsys::string_to_field::<Fr>)
.collect::<Vec<_>>();
let settings = GraphSettings::load(&settings_path)
.map_err(|_| PyIOError::new_err("Failed to load circuit settings"))?;
let srs_path = crate::execute::get_srs_path(settings.run_args.logrows, srs_path);
let srs_path =
crate::execute::get_srs_path(settings.run_args.logrows, srs_path, Commitments::KZG);
let srs = load_srs::<KZGCommitmentScheme<Bn256>>(srs_path)
let srs = load_srs_prover::<KZGCommitmentScheme<Bn256>>(srs_path)
.map_err(|_| PyIOError::new_err("Failed to load srs"))?;
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk_path, settings)
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)
.map_err(|_| PyIOError::new_err("Failed to load vk"))?;
let output = KZGChip::commit(
let output = PolyCommitChip::commit::<KZGCommitmentScheme<Bn256>>(
message,
vk.cs().degree() as u32,
(vk.cs().blinding_factors() + 1) as u32,
&srs,
);
Ok(output.iter().map(|x| (*x).into()).collect::<Vec<_>>())
}
/// Generate an ipa commitment.
#[pyfunction(signature = (
message,
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
fn ipa_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
settings_path: PathBuf,
srs_path: Option<PathBuf>,
) -> PyResult<Vec<PyG1Affine>> {
let message: Vec<Fr> = message
.iter()
.map(crate::pfsys::string_to_field::<Fr>)
.collect::<Vec<_>>();
let settings = GraphSettings::load(&settings_path)
.map_err(|_| PyIOError::new_err("Failed to load circuit settings"))?;
let srs_path =
crate::execute::get_srs_path(settings.run_args.logrows, srs_path, Commitments::IPA);
let srs = load_srs_prover::<IPACommitmentScheme<G1Affine>>(srs_path)
.map_err(|_| PyIOError::new_err("Failed to load srs"))?;
let vk = load_vk::<IPACommitmentScheme<G1Affine>, GraphCircuit>(vk_path, settings)
.map_err(|_| PyIOError::new_err("Failed to load vk"))?;
let output = PolyCommitChip::commit::<IPACommitmentScheme<G1Affine>>(
message,
vk.cs().degree() as u32,
(vk.cs().blinding_factors() + 1) as u32,
@@ -391,7 +477,7 @@ fn kzg_commit(
witness_path=PathBuf::from(DEFAULT_WITNESS),
))]
fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResult<()> {
crate::execute::swap_proof_commitments(proof_path, witness_path)
crate::execute::swap_proof_commitments_cmd(proof_path, witness_path)
.map_err(|_| PyIOError::new_err("Failed to swap commitments"))?;
Ok(())
@@ -411,13 +497,13 @@ fn gen_vk_from_pk_single(
let settings = GraphSettings::load(&circuit_settings_path)
.map_err(|_| PyIOError::new_err("Failed to load circuit settings"))?;
let pk = load_pk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(path_to_pk, settings)
let pk = load_pk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(path_to_pk, settings)
.map_err(|_| PyIOError::new_err("Failed to load pk"))?;
let vk = pk.get_vk();
// now save
save_vk::<KZGCommitmentScheme<Bn256>>(&vk_output_path, vk)
save_vk::<G1Affine>(&vk_output_path, vk)
.map_err(|_| PyIOError::new_err("Failed to save vk"))?;
Ok(true)
@@ -429,13 +515,13 @@ fn gen_vk_from_pk_single(
vk_output_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
))]
fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult<bool> {
let pk = load_pk::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(path_to_pk, ())
let pk = load_pk::<KZGCommitmentScheme<Bn256>, AggregationCircuit>(path_to_pk, ())
.map_err(|_| PyIOError::new_err("Failed to load pk"))?;
let vk = pk.get_vk();
// now save
save_vk::<KZGCommitmentScheme<Bn256>>(&vk_output_path, vk)
save_vk::<G1Affine>(&vk_output_path, vk)
.map_err(|_| PyIOError::new_err("Failed to save vk"))?;
Ok(true)
@@ -472,20 +558,27 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
#[pyfunction(signature = (
settings_path=PathBuf::from(DEFAULT_SETTINGS),
logrows=None,
srs_path=None
srs_path=None,
commitment=None,
))]
fn get_srs(
settings_path: Option<PathBuf>,
logrows: Option<u32>,
srs_path: Option<PathBuf>,
commitment: Option<PyCommitments>,
) -> PyResult<bool> {
let commitment: Option<Commitments> = commitment.map(|c| c.into());
Runtime::new()
.unwrap()
.block_on(crate::execute::get_srs_cmd(
srs_path,
settings_path,
logrows,
CheckMode::SAFE,
commitment,
))
.map_err(|e| {
let err_str = format!("Failed to get srs: {}", e);
@@ -525,7 +618,7 @@ fn gen_settings(
scales = None,
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
max_logrows = None,
div_rebasing = None,
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
))]
fn calibrate_settings(
data: PathBuf,
@@ -536,7 +629,7 @@ fn calibrate_settings(
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
div_rebasing: Option<bool>,
only_range_check_rebase: bool,
) -> Result<bool, PyErr> {
crate::execute::calibrate(
model,
@@ -546,7 +639,7 @@ fn calibrate_settings(
lookup_safety_margin,
scales,
scale_rebase_multiplier,
div_rebasing,
only_range_check_rebase,
max_logrows,
)
.map_err(|e| {
@@ -623,7 +716,7 @@ fn mock_aggregate(
pk_path=PathBuf::from(DEFAULT_PK),
srs_path=None,
witness_path = None,
compress_selectors=DEFAULT_COMPRESS_SELECTORS.parse().unwrap(),
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
))]
fn setup(
model: PathBuf,
@@ -631,7 +724,7 @@ fn setup(
pk_path: PathBuf,
srs_path: Option<PathBuf>,
witness_path: Option<PathBuf>,
compress_selectors: bool,
disable_selector_compression: bool,
) -> Result<bool, PyErr> {
crate::execute::setup(
model,
@@ -639,7 +732,7 @@ fn setup(
vk_path,
pk_path,
witness_path,
compress_selectors,
disable_selector_compression,
)
.map_err(|e| {
let err_str = format!("Failed to run setup: {}", e);
@@ -689,14 +782,23 @@ fn prove(
settings_path=PathBuf::from(DEFAULT_SETTINGS),
vk_path=PathBuf::from(DEFAULT_VK),
srs_path=None,
non_reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse::<bool>().unwrap(),
))]
fn verify(
proof_path: PathBuf,
settings_path: PathBuf,
vk_path: PathBuf,
srs_path: Option<PathBuf>,
non_reduced_srs: bool,
) -> Result<bool, PyErr> {
crate::execute::verify(proof_path, settings_path, vk_path, srs_path).map_err(|e| {
crate::execute::verify(
proof_path,
settings_path,
vk_path,
srs_path,
non_reduced_srs,
)
.map_err(|e| {
let err_str = format!("Failed to run verify: {}", e);
PyRuntimeError::new_err(err_str)
})?;
@@ -711,7 +813,8 @@ fn verify(
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
split_proofs = false,
srs_path = None,
compress_selectors=DEFAULT_COMPRESS_SELECTORS.parse().unwrap(),
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
))]
fn setup_aggregate(
sample_snarks: Vec<PathBuf>,
@@ -720,7 +823,8 @@ fn setup_aggregate(
logrows: u32,
split_proofs: bool,
srs_path: Option<PathBuf>,
compress_selectors: bool,
disable_selector_compression: bool,
commitment: PyCommitments,
) -> Result<bool, PyErr> {
crate::execute::setup_aggregate(
sample_snarks,
@@ -729,7 +833,8 @@ fn setup_aggregate(
srs_path,
logrows,
split_proofs,
compress_selectors,
disable_selector_compression,
commitment.into(),
)
.map_err(|e| {
let err_str = format!("Failed to setup aggregate: {}", e);
@@ -767,6 +872,7 @@ fn compile_circuit(
check_mode=CheckMode::UNSAFE,
split_proofs = false,
srs_path=None,
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
))]
fn aggregate(
aggregation_snarks: Vec<PathBuf>,
@@ -777,6 +883,7 @@ fn aggregate(
check_mode: CheckMode,
split_proofs: bool,
srs_path: Option<PathBuf>,
commitment: PyCommitments,
) -> Result<bool, PyErr> {
// the K used for the aggregation circuit
crate::execute::aggregate(
@@ -788,6 +895,7 @@ fn aggregate(
logrows,
check_mode,
split_proofs,
commitment.into(),
)
.map_err(|e| {
let err_str = format!("Failed to run aggregate: {}", e);
@@ -802,15 +910,27 @@ fn aggregate(
proof_path=PathBuf::from(DEFAULT_PROOF_AGGREGATED),
vk_path=PathBuf::from(DEFAULT_VK),
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse().unwrap(),
srs_path=None,
))]
fn verify_aggr(
proof_path: PathBuf,
vk_path: PathBuf,
logrows: u32,
commitment: PyCommitments,
reduced_srs: bool,
srs_path: Option<PathBuf>,
) -> Result<bool, PyErr> {
crate::execute::verify_aggr(proof_path, vk_path, srs_path, logrows).map_err(|e| {
crate::execute::verify_aggr(
proof_path,
vk_path,
srs_path,
logrows,
reduced_srs,
commitment.into(),
)
.map_err(|e| {
let err_str = format!("Failed to run verify_aggr: {}", e);
PyRuntimeError::new_err(err_str)
})?;
@@ -1022,24 +1142,15 @@ fn verify_evm(
addr_da: Option<&str>,
addr_vk: Option<&str>,
) -> Result<bool, PyErr> {
let addr_verifier = H160::from_str(addr_verifier).map_err(|e| {
let err_str = format!("address is invalid: {}", e);
PyRuntimeError::new_err(err_str)
})?;
let addr_verifier = H160Flag::from(addr_verifier);
let addr_da = if let Some(addr_da) = addr_da {
let addr_da = H160::from_str(addr_da).map_err(|e| {
let err_str = format!("address is invalid: {}", e);
PyRuntimeError::new_err(err_str)
})?;
let addr_da = H160Flag::from(addr_da);
Some(addr_da)
} else {
None
};
let addr_vk = if let Some(addr_vk) = addr_vk {
let addr_vk = H160::from_str(addr_vk).map_err(|e| {
let err_str = format!("address is invalid: {}", e);
PyRuntimeError::new_err(err_str)
})?;
let addr_vk = H160Flag::from(addr_vk);
Some(addr_vk)
} else {
None
@@ -1105,13 +1216,16 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyG1Affine>()?;
m.add_class::<PyG1>()?;
m.add_class::<PyTestDataSource>()?;
m.add_function(wrap_pyfunction!(string_to_felt, m)?)?;
m.add_function(wrap_pyfunction!(string_to_int, m)?)?;
m.add_function(wrap_pyfunction!(string_to_float, m)?)?;
m.add_class::<PyCommitments>()?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_float, m)?)?;
m.add_function(wrap_pyfunction!(kzg_commit, m)?)?;
m.add_function(wrap_pyfunction!(ipa_commit, m)?)?;
m.add_function(wrap_pyfunction!(swap_proof_commitments, m)?)?;
m.add_function(wrap_pyfunction!(poseidon_hash, m)?)?;
m.add_function(wrap_pyfunction!(float_to_string, m)?)?;
m.add_function(wrap_pyfunction!(float_to_felt, m)?)?;
m.add_function(wrap_pyfunction!(buffer_to_felts, m)?)?;
m.add_function(wrap_pyfunction!(gen_vk_from_pk_aggr, m)?)?;
m.add_function(wrap_pyfunction!(gen_vk_from_pk_single, m)?)?;

View File

@@ -673,6 +673,68 @@ impl<T: Clone + TensorType> Tensor<T> {
Tensor::new(Some(&res), &dims)
}
/// Set a slice of the Tensor.
/// ```
/// use ezkl::tensor::Tensor;
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
/// let b = Tensor::<i32>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
/// a.set_slice(&[1..2], &Tensor::<i32>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
/// assert_eq!(a, b);
/// ```
pub fn set_slice(
&mut self,
indices: &[Range<usize>],
value: &Tensor<T>,
) -> Result<(), TensorError>
where
T: Send + Sync,
{
if indices.is_empty() {
return Ok(());
}
if self.dims.len() < indices.len() {
return Err(TensorError::DimError(format!(
"The dimensionality of the slice {:?} is greater than the tensor's {:?}",
indices, self.dims
)));
}
// if indices weren't specified we fill them in as required
let mut full_indices = indices.to_vec();
let omitted_dims = (indices.len()..self.dims.len())
.map(|i| self.dims[i])
.collect::<Vec<_>>();
for dim in &omitted_dims {
full_indices.push(0..*dim);
}
let full_dims = full_indices
.iter()
.map(|x| x.end - x.start)
.collect::<Vec<_>>();
// now broadcast the value to the full dims
let value = value.expand(&full_dims)?;
let cartesian_coord: Vec<Vec<usize>> = full_indices
.iter()
.cloned()
.multi_cartesian_product()
.collect();
let _ = cartesian_coord
.iter()
.enumerate()
.map(|(i, e)| {
self.set(e, value[i].clone());
})
.collect::<Vec<_>>();
Ok(())
}
/// Get the array index from rows / columns indices.
///
/// ```
@@ -1526,18 +1588,20 @@ pub fn get_broadcasted_shape(
let num_dims_a = shape_a.len();
let num_dims_b = shape_b.len();
// rewrite the below using match
if num_dims_a == num_dims_b {
let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
let max_dim = dim_a.max(dim_b);
broadcasted_shape.push(*max_dim);
match (num_dims_a, num_dims_b) {
(a, b) if a == b => {
let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
let max_dim = dim_a.max(dim_b);
broadcasted_shape.push(*max_dim);
}
Ok(broadcasted_shape)
}
Ok(broadcasted_shape)
} else if num_dims_a < num_dims_b {
Ok(shape_b.to_vec())
} else {
Ok(shape_a.to_vec())
(a, b) if a < b => Ok(shape_b.to_vec()),
(a, b) if a > b => Ok(shape_a.to_vec()),
_ => Err(Box::new(TensorError::DimError(
"Unknown condition for broadcasting".to_string(),
))),
}
}
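A reduced sketch of the broadcast rule above: equal ranks take the element-wise max, otherwise the higher-rank shape wins wholesale (simpler than full NumPy-style broadcasting):

// Rank-based broadcast rule, mirroring get_broadcasted_shape.
fn broadcast(a: &[usize], b: &[usize]) -> Vec<usize> {
    if a.len() == b.len() {
        a.iter().zip(b).map(|(x, y)| *x.max(y)).collect()
    } else if a.len() < b.len() {
        b.to_vec()
    } else {
        a.to_vec()
    }
}

fn main() {
    assert_eq!(broadcast(&[2, 1], &[1, 3]), vec![2, 3]);
    assert_eq!(broadcast(&[4, 4], &[4]), vec![4, 4]);
}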
////////////////////////

View File

@@ -2,7 +2,7 @@ use super::TensorError;
use crate::tensor::{Tensor, TensorType};
use itertools::Itertools;
use maybe_rayon::{
iter::IndexedParallelIterator, iter::IntoParallelRefMutIterator, iter::ParallelIterator,
iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
prelude::IntoParallelRefIterator,
};
use std::collections::{HashMap, HashSet};
@@ -1328,6 +1328,316 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
Ok(output)
}
/// Gather ND.
/// # Arguments
/// * `input` - Tensor
/// * `index` - Tensor of indices to gather
/// * `batch_dims` - Number of batch dimensions
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::gather_nd;
/// let x = Tensor::<i128>::new(
/// Some(&[0, 1, 2, 3]),
/// &[2, 2],
/// ).unwrap();
/// let index = Tensor::<usize>::new(
/// Some(&[0, 0, 1, 1]),
/// &[2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 3]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[1, 0]),
/// &[2, 1],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[0, 1, 2, 3, 4, 5, 6, 7]),
/// &[2, 2, 2],
/// ).unwrap();
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 1, 0]),
/// &[2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 1, 0]),
/// &[2, 1, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 1, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[1, 0]),
/// &[2, 1],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 1).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1]),
/// &[2, 2, 3],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 0, 1, 1, 1, 0]),
/// &[2, 2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1, 6, 7, 4, 5]), &[2, 2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 7]), &[2]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn gather_nd<T: TensorType + Send + Sync>(
input: &Tensor<T>,
index: &Tensor<usize>,
batch_dims: usize,
) -> Result<Tensor<T>, TensorError> {
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
let last_value = index_dims
.last()
.ok_or(TensorError::DimMismatch("gather_nd".to_string()))?;
if last_value > &(input_dims.len() - batch_dims) {
return Err(TensorError::DimMismatch("gather_nd".to_string()));
}
let output_size =
// If indices_shape[-1] == r-b, since the rank of indices is q,
// indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b,
// where N is an integer equal to the product of all the elements in the batch dimensions of indices_shape.
// Let us think of each such r-b ranked tensor as indices_slice.
// Each scalar value corresponding to data[0:b-1, indices_slice] is filled into
// the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor.
// If indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension < r-b.
// Let us think of each such tensor as indices_slice.
// Each tensor slice corresponding to data[0:b-1, indices_slice, :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor.
{
let output_rank = input_dims.len() + index_dims.len() - 1 - batch_dims - last_value;
let mut dims = index_dims[..index_dims.len() - 1].to_vec();
let input_offset = batch_dims + last_value;
dims.extend(input_dims[input_offset..input_dims.len()].to_vec());
assert_eq!(output_rank, dims.len());
dims
};
// cartesian coord over batch dims
let mut batch_cartesian_coord = input_dims[0..batch_dims]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
if batch_cartesian_coord.is_empty() {
batch_cartesian_coord.push(vec![]);
}
let outputs = batch_cartesian_coord
.par_iter()
.map(|batch_coord| {
let batch_slice = batch_coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let mut index_slice = index.get_slice(&batch_slice)?;
index_slice.reshape(&index.dims()[batch_dims..])?;
let mut input_slice = input.get_slice(&batch_slice)?;
input_slice.reshape(&input.dims()[batch_dims..])?;
let mut inner_cartesian_coord = index_slice.dims()[0..index_slice.dims().len() - 1]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
if inner_cartesian_coord.is_empty() {
inner_cartesian_coord.push(vec![]);
}
let output = inner_cartesian_coord
.iter()
.map(|coord| {
let slice = coord
.iter()
.map(|x| *x..*x + 1)
.chain(batch_coord.iter().map(|x| *x..*x + 1))
.collect::<Vec<_>>();
let index_slice = index_slice
.get_slice(&slice)
.unwrap()
.iter()
.map(|x| *x..*x + 1)
.collect::<Vec<_>>();
input_slice.get_slice(&index_slice).unwrap()
})
.collect::<Tensor<_>>();
output.combine()
})
.collect::<Result<Vec<_>, _>>()?;
let mut outputs = outputs.into_iter().flatten().collect::<Tensor<_>>();
outputs.reshape(&output_size)?;
Ok(outputs)
}
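A tiny illustration, separate from the implementation above, of `gather_nd` with `batch_dims = 0` on a 2-D input; each row of `index` picks one element, matching the first doc example:

// 2-D special case: index rows are full coordinates into the input.
fn gather_nd_2d(input: &[Vec<i64>], index: &[[usize; 2]]) -> Vec<i64> {
    index.iter().map(|[r, c]| input[*r][*c]).collect()
}

fn main() {
    let x = vec![vec![0, 1], vec![2, 3]];
    let idx = [[0, 0], [1, 1]];
    assert_eq!(gather_nd_2d(&x, &idx), vec![0, 3]); // same result as the doctest
}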
/// Scatter ND.
/// This operator is the inverse of GatherND.
/// # Arguments
/// * `input` - Tensor
/// * `index` - Tensor of indices to scatter
/// * `src` - Tensor of src
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::scatter_nd;
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[8],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[4, 3, 1, 7]),
/// &[4, 1],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9, 10, 11, 12]),
/// &[4],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 11, 3, 10, 9, 6, 7, 12]), &[8]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
/// &[4, 4, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 2]),
/// &[2, 1],
/// ).unwrap();
///
/// let src = Tensor::<i128>::new(
/// Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
/// 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
/// ]),
/// &[2, 4, 4],
/// ).unwrap();
///
/// let result = scatter_nd(&x, &index, &src).unwrap();
///
/// let expected = Tensor::<i128>::new(
/// Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
/// 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
/// &[4, 4, 4],
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[2, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1]),
/// &[2, 1],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9, 10]),
/// &[2],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[9, 9, 9, 9, 10, 10, 10, 10]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[2, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1]),
/// &[1, 1, 2],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9]),
/// &[1, 1],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 9, 3, 4, 5, 6, 7, 8]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
/// ```
///
pub fn scatter_nd<T: TensorType + Send + Sync>(
input: &Tensor<T>,
index: &Tensor<usize>,
src: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
let last_value = index_dims
.last()
.ok_or(TensorError::DimMismatch("scatter_nd".to_string()))?;
if last_value > &input_dims.len() {
return Err(TensorError::DimMismatch("scatter_nd".to_string()));
}
let mut output = input.clone();
let cartesian_coord = index_dims[0..index_dims.len() - 1]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
cartesian_coord
.iter()
.map(|coord| {
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let index_val = index.get_slice(&slice)?;
let index_slice = index_val.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let src_val = src.get_slice(&slice)?;
output.set_slice(&index_slice, &src_val)?;
Ok(())
})
.collect::<Result<Vec<_>, _>>()?;
Ok(output)
}
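A tiny illustration, separate from the implementation above, of the 1-D `scatter_nd` case: each row of `index` names a position to overwrite with the matching element of `src`.

// 1-D special case of scatter: point updates into a copy of the input.
fn scatter_1d(input: &[i64], index: &[usize], src: &[i64]) -> Vec<i64> {
    let mut out = input.to_vec();
    for (i, &pos) in index.iter().enumerate() {
        out[pos] = src[i];
    }
    out
}

fn main() {
    let x = [1, 2, 3, 4, 5, 6, 7, 8];
    let out = scatter_1d(&x, &[4, 3, 1, 7], &[9, 10, 11, 12]);
    assert_eq!(out, vec![1, 11, 3, 10, 9, 6, 7, 12]); // same result as the doctest
}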
fn axes_op<T: TensorType + Send + Sync>(
a: &Tensor<T>,
axes: &[usize],
@@ -3773,6 +4083,30 @@ pub mod nonlinearities {
.unwrap()
}
/// Elementwise reciprocal of zero (saturates at 1/f64::EPSILON).
/// # Arguments
/// * `out_scale` - output scale multiplier applied to the saturated reciprocal
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::nonlinearities::zero_recip;
/// let result = zero_recip(1.0);
/// let expected = Tensor::<i128>::new(Some(&[4503599627370496]), &[1]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn zero_recip(out_scale: f64) -> Tensor<i128> {
let a = Tensor::<i128>::new(Some(&[0]), &[1]).unwrap();
a.par_enum_map(|_, a_i| {
let rescaled = a_i as f64;
let denom = (1_f64) / (rescaled + f64::EPSILON);
let d_inv_x = out_scale * denom;
Ok::<_, TensorError>(d_inv_x.round() as i128)
})
.unwrap()
}
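Why the doc example yields 4503599627370496: `1 / f64::EPSILON` is exactly 2^52, the saturation value used for the reciprocal of zero above.

fn main() {
    let denom = 1.0_f64 / (0.0 + f64::EPSILON);
    let saturated = denom.round() as i128;
    assert_eq!(saturated, 4_503_599_627_370_496); // 2^52, as in the doctest above
}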
/// Elementwise greater than
/// # Arguments
///

View File

@@ -4,6 +4,37 @@ use super::{
};
use halo2_proofs::{arithmetic::Field, plonk::Instance};
pub(crate) fn create_constant_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
>(
val: F,
len: usize,
) -> ValTensor<F> {
let mut constant = Tensor::from(vec![ValType::Constant(val); len].into_iter());
constant.set_visibility(&crate::graph::Visibility::Fixed);
ValTensor::from(constant)
}
pub(crate) fn create_unit_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
>(
len: usize,
) -> ValTensor<F> {
let mut unit = Tensor::from(vec![ValType::Constant(F::ONE); len].into_iter());
unit.set_visibility(&crate::graph::Visibility::Fixed);
ValTensor::from(unit)
}
pub(crate) fn create_zero_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
>(
len: usize,
) -> ValTensor<F> {
let mut zero = Tensor::from(vec![ValType::Constant(F::ZERO); len].into_iter());
zero.set_visibility(&crate::graph::Visibility::Fixed);
ValTensor::from(zero)
}
#[derive(Debug, Clone)]
/// A [ValType] is a wrapper around Halo2 value(s).
pub enum ValType<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd> {
@@ -318,6 +349,19 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
matches!(self, ValTensor::Instance { .. })
}
/// reverse order of elements whilst preserving the shape
pub fn reverse(&mut self) -> Result<(), Box<dyn Error>> {
match self {
ValTensor::Value { inner: v, .. } => {
v.reverse();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
}
};
Ok(())
}
/// Sets the initial instance offset
pub fn set_initial_instance_offset(&mut self, offset: usize) {
if let ValTensor::Instance { initial_offset, .. } = self {
@@ -450,7 +494,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
_ => return Err(Box::new(TensorError::WrongMethod)),
};
Ok(integer_evals.into_iter().into())
let mut tensor: Tensor<i128> = integer_evals.into_iter().into();
// best-effort reshape back to the original dims; ignore failure
let _ = tensor.reshape(self.dims());
Ok(tensor)
}
/// Calls `pad_to_zero_rem` on the inner tensor.
@@ -672,7 +721,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
Ok(indices)
}
ValTensor::Instance { .. } => Err(TensorError::WrongMethod),
ValTensor::Instance { .. } => Ok(vec![]),
}
}
@@ -690,7 +739,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
Ok(indices)
}
ValTensor::Instance { .. } => Err(TensorError::WrongMethod),
ValTensor::Instance { .. } => Ok(vec![]),
}
}
@@ -709,7 +758,11 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(TensorError::WrongMethod);
if indices.is_empty() {
return Ok(());
} else {
return Err(TensorError::WrongMethod);
}
}
}
Ok(())

View File

@@ -33,10 +33,7 @@ pub enum VarTensor {
impl VarTensor {
/// Returns true if the tensor is an advice column tensor
pub fn is_advice(&self) -> bool {
match self {
VarTensor::Advice { .. } => true,
_ => false,
}
matches!(self, VarTensor::Advice { .. })
}
///

View File

@@ -6,17 +6,33 @@ use crate::fieldutils::i128_to_felt;
use crate::graph::modules::POSEIDON_LEN_GRAPH;
use crate::graph::quantize_float;
use crate::graph::scale_to_multiplier;
use crate::graph::{GraphCircuit, GraphSettings};
use crate::pfsys::create_proof_circuit;
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
use crate::pfsys::verify_proof_circuit;
use crate::pfsys::TranscriptType;
use crate::tensor::TensorType;
use crate::CheckMode;
use crate::Commitments;
use halo2_proofs::plonk::*;
use halo2_proofs::poly::commitment::{CommitmentScheme, ParamsProver};
use halo2_proofs::poly::ipa::multiopen::{ProverIPA, VerifierIPA};
use halo2_proofs::poly::ipa::{
commitment::{IPACommitmentScheme, ParamsIPA},
strategy::SingleStrategy as IPASingleStrategy,
};
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::{
commitment::{KZGCommitmentScheme, ParamsKZG},
strategy::SingleStrategy as KZGSingleStrategy,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
use halo2curves::ff::{FromUniformBytes, PrimeField};
use crate::tensor::TensorType;
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use wasm_bindgen::prelude::*;
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
@@ -37,9 +53,6 @@ pub fn init_panic_hook() {
console_error_panic_hook::set_once();
}
use crate::graph::{GraphCircuit, GraphSettings};
use crate::pfsys::{create_proof_circuit_kzg, verify_proof_circuit_kzg};
/// Wrapper around the halo2 encode call data method
#[wasm_bindgen]
#[allow(non_snake_case)]
@@ -69,19 +82,30 @@ pub fn encodeVerifierCalldata(
Ok(encoded)
}
/// Converts 4 u64s to a field element
/// Converts a felt to a big endian string
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn stringToFelt(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
pub fn feltToBigEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
Ok(format!("{:?}", felt))
}
/// Converts 4 u64s representing a field element directly to an integer
/// Converts a felt to a little endian string
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn stringToInt(
pub fn feltToLittleEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
let repr = serde_json::to_string(&felt)
.map_err(|e| JsError::new(&format!("Failed to serialize field element: {}", e)))?;
let b: String = serde_json::from_str(&repr)
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
Ok(b)
}
/// Converts a felt to an integer
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToInt(
array: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
@@ -92,10 +116,10 @@ pub fn stringToInt(
))
}
/// Converts 4 u64s representing a field element directly to a (rescaled from fixed point scaling) floating point
/// Converts a felt to a floating point number
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn stringToFloat(
pub fn feltToFloat(
array: wasm_bindgen::Clamped<Vec<u8>>,
scale: crate::Scale,
) -> Result<f64, JsError> {
@@ -106,26 +130,26 @@ pub fn stringToFloat(
Ok(int_rep as f64 / multiplier)
}
/// Converts a floating point element to 4 u64s representing a fixed point field element
/// Converts a floating point number to a hex string representing a fixed point field element
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn floatTostring(
pub fn floatToFelt(
input: f64,
scale: crate::Scale,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let int_rep =
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
let felt = i128_to_felt(int_rep);
let vec = crate::pfsys::field_to_string_montgomery::<halo2curves::bn256::Fr>(&felt);
let vec = crate::pfsys::field_to_string::<halo2curves::bn256::Fr>(&felt);
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err(
|e| JsError::new(&format!("Failed to serialize string_montgomery{}", e)),
|e| JsError::new(&format!("Failed to serialize a float to felt{}", e)),
)?))
}
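Both directions of the fixed point mapping share a single multiplier: quantize_float scales the input by 2^scale and rounds, and feltToFloat divides the recovered integer by the same factor. A plain-Python sketch, assuming scale_to_multiplier(scale) == 2**scale:

def quantize(x: float, scale: int) -> int:
    return round(x * 2**scale)

def dequantize(q: int, scale: int) -> float:
    return q / 2**scale

# 890 * 2**7 and -700 * 2**7 are integral, so these round trips are exact
assert dequantize(quantize(890.0, 7), 7) == 890.0
assert dequantize(quantize(-700.0, 7), 7) == -700.0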
/// Converts a buffer to a vector of field elements
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn bufferToVecOfstring(
pub fn bufferToVecOfFelt(
buffer: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
// Convert the buffer to a slice
@@ -211,7 +235,7 @@ pub fn genWitness(
.map_err(|e| JsError::new(&format!("{}", e)))?;
let witness = circuit
.forward(&mut input, None, None)
.forward::<KZGCommitmentScheme<Bn256>>(&mut input, None, None, false)
.map_err(|e| JsError::new(&format!("{}", e)))?;
serde_json::to_vec(&witness)
@@ -296,28 +320,74 @@ pub fn verify(
settings: wasm_bindgen::Clamped<Vec<u8>>,
srs: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<bool, JsError> {
let mut reader = std::io::BufReader::new(&srs[..]);
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?;
let snark: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof_js[..])
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof_js[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
let mut reader = std::io::BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
circuit_settings,
circuit_settings.clone(),
)
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
let orig_n = 1 << circuit_settings.run_args.logrows;
let result = verify_proof_circuit_kzg(params.verifier_params(), snark, &vk, strategy);
let mut reader = std::io::BufReader::new(&srs[..]);
let result = match circuit_settings.run_args.commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
Commitments::IPA => {
let params: ParamsIPA<_> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let strategy = IPASingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
};
match result {
Ok(_) => Ok(true),
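Verification now dispatches on two recorded choices instead of assuming KZG: the commitment scheme is read from the circuit settings (the new "commitment" run-arg, visible in the settings fixture further below) and the transcript type from the proof itself, while orig_n = 1 << logrows is passed separately so the supplied SRS can be smaller than the one used for proving. A JSON-level sketch of the same dispatch, with hypothetical file names:

import json

with open("settings.json") as f:              # hypothetical paths
    settings = json.load(f)
with open("proof.json") as f:
    proof = json.load(f)

scheme = settings["run_args"]["commitment"]   # "KZG" or "IPA"
transcript = proof["transcript_type"]         # "EVM" or "Poseidon"
print(f"verify with {scheme} params and a {transcript} transcript")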
@@ -337,11 +407,6 @@ pub fn prove(
log::set_max_level(log::LevelFilter::Debug);
#[cfg(not(feature = "det-prove"))]
log::set_max_level(log::LevelFilter::Info);
// read in kzg params
let mut reader = std::io::BufReader::new(&srs[..]);
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize srs: {}", e)))?;
// read in circuit
let mut circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..])
@@ -369,17 +434,63 @@ pub fn prove(
.map_err(|e| JsError::new(&format!("{}", e)))?;
let proof_split_commits: Option<crate::pfsys::ProofSplitCommit> = data.into();
let strategy = KZGSingleStrategy::new(&params);
let proof = create_proof_circuit_kzg(
circuit,
&params,
Some(public_inputs),
&pk,
crate::pfsys::TranscriptType::EVM,
strategy,
crate::circuit::CheckMode::UNSAFE,
proof_split_commits,
)
// read in kzg params
let mut reader = std::io::BufReader::new(&srs[..]);
// creates and verifies the proof
let proof = match circuit.settings().run_args.commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize srs: {}", e)))?;
create_proof_circuit::<
KZGCommitmentScheme<Bn256>,
_,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
KZGSingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit,
vec![public_inputs],
&params,
&pk,
CheckMode::UNSAFE,
crate::Commitments::KZG,
TranscriptType::EVM,
proof_split_commits,
None,
)
}
Commitments::IPA => {
let params: ParamsIPA<_> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize srs: {}", e)))?;
create_proof_circuit::<
IPACommitmentScheme<G1Affine>,
_,
ProverIPA<_>,
VerifierIPA<_>,
IPASingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit,
vec![public_inputs],
&params,
&pk,
CheckMode::UNSAFE,
crate::Commitments::IPA,
TranscriptType::EVM,
proof_split_commits,
None,
)
}
}
.map_err(|e| JsError::new(&format!("{}", e)))?;
Ok(serde_json::to_string(&proof)
@@ -387,15 +498,6 @@ pub fn prove(
.into_bytes())
}
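The proving path mirrors the verify-side dispatch but pins two knobs: the wasm binding always emits an EVM transcript and runs with CheckMode::UNSAFE; only the commitment scheme follows the compiled circuit's settings. A sketch of locating the matching SRS before proving, assuming the ~/.ezkl/srs/kzg{logrows}.srs layout the test helpers below rely on:

import json
from pathlib import Path

with open("settings.json") as f:              # hypothetical path
    run_args = json.load(f)["run_args"]

assert run_args["commitment"] in ("KZG", "IPA")
srs_path = Path.home() / ".ezkl" / "srs" / f"kzg{run_args['logrows']}.srs"
print(srs_path)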
/// print hex representation of a proof
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn printProofHex(proof: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
let hex_str = hex::encode(proof.proof);
Ok(format!("0x{}", hex_str))
}
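With printProofHex removed above, the same hex rendering can be recovered from the proof JSON directly, since the proof field is a plain byte array (the deleted fixture that follows serializes it as [1,2,3,4,5,6,7,8]). A sketch with a hypothetical path:

import json

with open("proof.json") as f:
    snark = json.load(f)
print("0x" + bytes(snark["proof"]).hex())     # 0x0102030405060708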
// VALIDATION FUNCTIONS
/// Witness file validation


@@ -1 +0,0 @@
{"protocol":null,"instances":[[[12436184717236109307,3962172157175319849,7381016538464732718,1011752739694698287]],[[6425625360762666998,7924344314350639699,14762033076929465436,2023505479389396574]]],"proof":[1,2,3,4,5,6,7,8],"transcript_type":"EVM"}

File diff suppressed because it is too large


@@ -12,7 +12,7 @@ def get_ezkl_output(witness_file, settings_file):
outputs = witness_output['outputs']
with open(settings_file) as f:
settings = json.load(f)
ezkl_outputs = [[ezkl.string_to_float(
ezkl_outputs = [[ezkl.felt_to_float(
outputs[i][j], settings['model_output_scales'][i]) for j in range(len(outputs[i]))] for i in range(len(outputs))]
return ezkl_outputs
@@ -91,9 +91,7 @@ def compare_outputs(zk_output, onnx_output):
print("------- zk_output: ", list1_i)
print("------- onnx_output: ", list2_i)
return np.mean(np.abs(res))
return res
if __name__ == '__main__':
@@ -113,6 +111,9 @@ if __name__ == '__main__':
onnx_output = get_onnx_output(model_file, input_file)
# compare the outputs
percentage_difference = compare_outputs(ezkl_output, onnx_output)
mean_percentage_difference = np.mean(np.abs(percentage_difference))
max_percentage_difference = np.max(np.abs(percentage_difference))
# print the percentage difference
print("mean percent diff: ", percentage_difference)
assert percentage_difference < target, "Percentage difference is too high"
print("mean percent diff: ", mean_percentage_difference)
print("max percent diff: ", max_percentage_difference)
assert mean_percentage_difference < target, "Mean percentage difference is too high"
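compare_outputs now returns the raw per-element percentage differences and leaves aggregation to the caller, which reports both the mean and the max. A sketch of that aggregation with hypothetical values:

import numpy as np

res = np.array([0.1, -0.4, 0.2])      # per-element % differences (hypothetical)
mean_diff = np.mean(np.abs(res))      # what the assert above checks against target
max_diff = np.max(np.abs(res))
print(mean_diff, max_diff)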


@@ -61,6 +61,7 @@ mod py_tests {
let status = Command::new("pip")
.args([
"install",
"torch-geometric==2.5.0",
"torch==2.0.1",
"torchvision==0.15.2",
"pandas==2.0.3",


@@ -56,9 +56,9 @@ def test_poseidon_hash():
Test for poseidon_hash
"""
message = [1.0, 2.0, 3.0, 4.0]
message = [ezkl.float_to_string(x, 7) for x in message]
message = [ezkl.float_to_felt(x, 7) for x in message]
res = ezkl.poseidon_hash(message)
assert ezkl.string_to_felt(
assert ezkl.felt_to_big_endian(
res[0]) == "0x0da7e5e5c8877242fa699f586baf770d731defd54f952d4adeb85047a0e32f45"
@@ -70,14 +70,14 @@ def test_field_serialization():
input = 890
scale = 7
felt = ezkl.float_to_string(input, scale)
roundtrip_input = ezkl.string_to_float(felt, scale)
felt = ezkl.float_to_felt(input, scale)
roundtrip_input = ezkl.felt_to_float(felt, scale)
assert input == roundtrip_input
input = -700
scale = 7
felt = ezkl.float_to_string(input, scale)
roundtrip_input = ezkl.string_to_float(felt, scale)
felt = ezkl.float_to_felt(input, scale)
roundtrip_input = ezkl.felt_to_float(felt, scale)
assert input == roundtrip_input
@@ -88,12 +88,12 @@ def test_buffer_to_felts():
buffer = bytearray("a sample string!", 'utf-8')
felts = ezkl.buffer_to_felts(buffer)
ref_felt_1 = "0x0000000000000000000000000000000021676e6972747320656c706d61732061"
assert felts == [ref_felt_1]
assert ezkl.felt_to_big_endian(felts[0]) == ref_felt_1
buffer = bytearray("a sample string!"+"high", 'utf-8')
felts = ezkl.buffer_to_felts(buffer)
ref_felt_2 = "0x0000000000000000000000000000000000000000000000000000000068676968"
assert felts == [ref_felt_1, ref_felt_2]
assert [ezkl.felt_to_big_endian(felts[0]), ezkl.felt_to_big_endian(felts[1])] == [ref_felt_1, ref_felt_2]
def test_gen_srs():
@@ -242,7 +242,7 @@ def test_get_srs():
another_srs_path = os.path.join(folder_path, "kzg_test_k8.params")
res = ezkl.get_srs(logrows=8, srs_path=another_srs_path)
res = ezkl.get_srs(logrows=8, srs_path=another_srs_path, commitment=ezkl.PyCommitments.KZG)
assert os.path.isfile(another_srs_path)


@@ -8,9 +8,9 @@ mod wasm32 {
use ezkl::graph::GraphWitness;
use ezkl::pfsys;
use ezkl::wasm::{
bufferToVecOfstring, compiledCircuitValidation, encodeVerifierCalldata, genPk, genVk,
genWitness, inputValidation, pkValidation, poseidonHash, printProofHex, proofValidation,
prove, settingsValidation, srsValidation, stringToFelt, stringToFloat, stringToInt,
bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian,
feltToFloat, feltToInt, feltToLittleEndian, genPk, genVk, genWitness, inputValidation,
pkValidation, poseidonHash, proofValidation, prove, settingsValidation, srsValidation,
u8_array_to_u128_le, verify, vkValidation, witnessValidation,
};
use halo2_solidity_verifier::encode_calldata;
@@ -76,22 +76,29 @@ mod wasm32 {
for i in 0..32 {
let field_element = Fr::from(i);
let serialized = serde_json::to_vec(&field_element).unwrap();
let clamped = wasm_bindgen::Clamped(serialized);
let scale = 2;
let floating_point = stringToFloat(clamped.clone(), scale)
let floating_point = feltToFloat(clamped.clone(), scale)
.map_err(|_| "failed")
.unwrap();
assert_eq!(floating_point, (i as f64) / 4.0);
let integer: i128 = serde_json::from_slice(
&stringToInt(clamped.clone()).map_err(|_| "failed").unwrap(),
)
.unwrap();
let integer: i128 =
serde_json::from_slice(&feltToInt(clamped.clone()).map_err(|_| "failed").unwrap())
.unwrap();
assert_eq!(integer, i as i128);
let hex_string = format!("{:?}", field_element);
let returned_string = stringToFelt(clamped).map_err(|_| "failed").unwrap();
let hex_string = format!("{:?}", field_element.clone());
let returned_string: String = feltToBigEndian(clamped.clone())
.map_err(|_| "failed")
.unwrap();
assert_eq!(hex_string, returned_string);
let repr = serde_json::to_string(&field_element).unwrap();
let little_endian_string: String = serde_json::from_str(&repr).unwrap();
let returned_string: String =
feltToLittleEndian(clamped).map_err(|_| "failed").unwrap();
assert_eq!(little_endian_string, returned_string);
}
}
@@ -101,7 +108,7 @@ mod wasm32 {
let mut buffer = string_high.clone().into_bytes();
let clamped = wasm_bindgen::Clamped(buffer.clone());
let field_elements_ser = bufferToVecOfstring(clamped).map_err(|_| "failed").unwrap();
let field_elements_ser = bufferToVecOfFelt(clamped).map_err(|_| "failed").unwrap();
let field_elements: Vec<Fr> = serde_json::from_slice(&field_elements_ser[..]).unwrap();
@@ -118,7 +125,7 @@ mod wasm32 {
let buffer = string_sample.clone().into_bytes();
let clamped = wasm_bindgen::Clamped(buffer.clone());
let field_elements_ser = bufferToVecOfstring(clamped).map_err(|_| "failed").unwrap();
let field_elements_ser = bufferToVecOfFelt(clamped).map_err(|_| "failed").unwrap();
let field_elements: Vec<Fr> = serde_json::from_slice(&field_elements_ser[..]).unwrap();
@@ -133,7 +140,7 @@ mod wasm32 {
let buffer = string_concat.into_bytes();
let clamped = wasm_bindgen::Clamped(buffer.clone());
let field_elements_ser = bufferToVecOfstring(clamped).map_err(|_| "failed").unwrap();
let field_elements_ser = bufferToVecOfFelt(clamped).map_err(|_| "failed").unwrap();
let field_elements: Vec<Fr> = serde_json::from_slice(&field_elements_ser[..]).unwrap();

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long


@@ -24,9 +24,14 @@
"param_visibility": "Private",
"div_rebasing": false,
"rebase_frac_zero_constants": false,
"check_mode": "UNSAFE"
"check_mode": "UNSAFE",
"commitment": "KZG"
},
"num_rows": 16,
"total_dynamic_col_size": 0,
"num_dynamic_lookups": 0,
"num_shuffles": 0,
"total_shuffle_col_size": 0,
"total_assignments": 32,
"total_const_size": 8,
"model_instance_shapes": [
@@ -42,7 +47,7 @@
0
],
"module_sizes": {
"kzg": [],
"polycommit": [],
"poseidon": [
0,
[


@@ -1,28 +1,42 @@
import localEVMVerify, { Hardfork } from '@ezkljs/verify'
import localEVMVerify from '../../in-browser-evm-verifier/src/index'
import { serialize, deserialize } from '@ezkljs/engine/nodejs'
import { compileContracts } from './utils'
import * as fs from 'fs'
exports.USER_NAME = require("minimist")(process.argv.slice(2))["example"];
exports.EXAMPLE = require("minimist")(process.argv.slice(2))["example"];
exports.PATH = require("minimist")(process.argv.slice(2))["dir"];
exports.VK = require("minimist")(process.argv.slice(2))["vk"];
describe('localEVMVerify', () => {
let bytecode: string
let bytecode_verifier: string
let bytecode_vk: string | undefined = undefined
let proof: any
const example = exports.USER_NAME || "1l_mlp"
const example = exports.EXAMPLE || "1l_mlp"
const path = exports.PATH || "../ezkl/examples/onnx"
const vk = exports.VK || false
beforeEach(() => {
let solcOutput = compileContracts(path, example)
const solcOutput = compileContracts(path, example, 'kzg')
bytecode =
bytecode_verifier =
solcOutput.contracts['artifacts/Verifier.sol']['Halo2Verifier'].evm.bytecode
.object
console.log('size', bytecode.length)
if (vk) {
const solcOutput_vk = compileContracts(path, example, 'vk')
bytecode_vk =
solcOutput_vk.contracts['artifacts/Verifier.sol']['Halo2VerifyingKey'].evm.bytecode
.object
console.log('size of verifier bytecode', bytecode_verifier.length)
}
console.log('verifier bytecode', bytecode_verifier)
})
it('should return true when verification succeeds', async () => {
@@ -30,7 +44,9 @@ describe('localEVMVerify', () => {
proof = deserialize(proofFileBuffer)
const result = await localEVMVerify(proofFileBuffer, bytecode)
const result = await localEVMVerify(proofFileBuffer, bytecode_verifier, bytecode_vk)
console.log('result', result)
expect(result).toBe(true)
})
@@ -39,13 +55,16 @@ describe('localEVMVerify', () => {
let result: boolean = true
console.log(proof.proof)
try {
let index = Math.floor(Math.random() * (proof.proof.length - 2)) + 2
let number = (proof.proof[index] + 1) % 16
let index = Math.round((Math.random() * (proof.proof.length))) % proof.proof.length
console.log('index', index)
console.log('value at index', proof.proof[index])
let number = (proof.proof[index] + 1) % 256
console.log('new number', number)
proof.proof[index] = number
console.log('index post', proof.proof[index])
const proofModified = serialize(proof)
result = await localEVMVerify(proofModified, bytecode)
result = await localEVMVerify(proofModified, bytecode_verifier, bytecode_vk)
} catch (error) {
// Check if the error thrown is the "out of gas" error.
expect(error).toEqual(


@@ -38,7 +38,10 @@ describe('Generate witness, prove and verify', () => {
let pk = await readEzklArtifactsFile(path, example, 'key.pk');
let circuit_ser = await readEzklArtifactsFile(path, example, 'network.compiled');
circuit_settings_ser = await readEzklArtifactsFile(path, example, 'settings.json');
params_ser = await readEzklSrsFile(path, example);
// get the log rows from the circuit settings
const circuit_settings = deserialize(circuit_settings_ser) as any;
const logrows = circuit_settings.run_args.logrows as string;
params_ser = await readEzklSrsFile(logrows);
const startTimeProve = Date.now();
result = wasmFunctions.prove(witness, pk, circuit_ser, params_ser);
const endTimeProve = Date.now();
@@ -54,6 +57,7 @@ describe('Generate witness, prove and verify', () => {
let result
const vk = await readEzklArtifactsFile(path, example, 'key.vk');
const startTimeVerify = Date.now();
params_ser = await readEzklSrsFile("1");
result = wasmFunctions.verify(proof_ser, vk, circuit_settings_ser, params_ser);
const result_ref = wasmFunctions.verify(proof_ser_ref, vk, circuit_settings_ser, params_ser);
const endTimeVerify = Date.now();


@@ -16,15 +16,7 @@ export async function readEzklArtifactsFile(path: string, example: string, filen
return new Uint8ClampedArray(buffer.buffer);
}
export async function readEzklSrsFile(path: string, example: string): Promise<Uint8ClampedArray> {
// const settingsPath = path.join(__dirname, '..', '..', 'ezkl', 'examples', 'onnx', example, 'settings.json');
const settingsPath = `${path}/${example}/settings.json`
const settingsBuffer = await fs.readFile(settingsPath, { encoding: 'utf-8' });
const settings = JSONBig.parse(settingsBuffer);
const logrows = settings.run_args.logrows;
// const filePath = path.join(__dirname, '..', '..', 'ezkl', 'examples', 'onnx', `kzg${logrows}.srs`);
// srs path is at $HOME/.ezkl/srs
export async function readEzklSrsFile(logrows: string): Promise<Uint8ClampedArray> {
const filePath = `${userHomeDir}/.ezkl/srs/kzg${logrows}.srs`
const buffer = await fs.readFile(filePath);
return new Uint8ClampedArray(buffer.buffer);
@@ -51,21 +43,21 @@ export function serialize(data: object | string): Uint8ClampedArray { // data is
return new Uint8ClampedArray(uint8Array.buffer);
}
export function getSolcInput(path: string, example: string) {
export function getSolcInput(path: string, example: string, name: string) {
return {
language: 'Solidity',
sources: {
'artifacts/Verifier.sol': {
content: fsSync.readFileSync(`${path}/${example}/kzg.sol`, 'utf-8'),
content: fsSync.readFileSync(`${path}/${example}/${name}.sol`, 'utf-8'),
},
// If more contracts were to be compiled, they should have their own entries here
},
settings: {
optimizer: {
enabled: true,
runs: 200,
runs: 1,
},
evmVersion: 'london',
evmVersion: 'shanghai',
outputSelection: {
'*': {
'*': ['abi', 'evm.bytecode'],
@@ -75,8 +67,8 @@ export function getSolcInput(path: string, example: string) {
}
}
export function compileContracts(path: string, example: string) {
const input = getSolcInput(path, example)
export function compileContracts(path: string, example: string, name: string) {
const input = getSolcInput(path, example, name)
const output = JSON.parse(solc.compile(JSON.stringify(input)))
let compilationFailed = false

Binary file not shown.


@@ -1 +1 @@
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1}
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1,"max_range_size":0}

verifier_abi.json Normal file

@@ -0,0 +1 @@
[{"type":"function","name":"verifyProof","inputs":[{"internalType":"bytes","name":"proof","type":"bytes"},{"internalType":"uint256[]","name":"instances","type":"uint256[]"}],"outputs":[{"internalType":"bool","name":"","type":"bool"}],"stateMutability":"nonpayable"}]

vk.abi Normal file

@@ -0,0 +1 @@
[{"type":"constructor","inputs":[]}]