Compare commits

...

19 Commits

Author SHA1 Message Date
dante
5cb303b149 Merge branch 'main' into example-reorg 2024-02-05 14:43:01 +00:00
dante
2a1ee1102c refactor: range check recip (#703) 2024-02-05 14:42:26 +00:00
Sofia Wawrzyniak
9fb78c36e0 readding examples 2024-02-05 09:41:01 -05:00
Sofia Wawrzyniak
074db5d229 preliminary bucketing of examples 2024-02-05 09:09:41 -05:00
dante
95d4fd4a70 feat: power of 2 div using type system (#702) 2024-02-04 02:43:38 +00:00
dante
e0d3f4f145 fix: uncomparable values in acc table (#701) 2024-02-02 15:13:29 +00:00
dante
bceac2fab5 ci: make gpu tests single threaded (#700) 2024-01-31 18:19:29 +00:00
dante
04d7b5feaa chore: fold div_rebasing parameter into calibration (#699) 2024-01-31 10:03:12 +00:00
dante
45fd12a04f refactor!: make rebasing multiplicative by default (#698)
BREAKING CHANGE: adds a `required_range_checks` field to `cs`
2024-01-30 18:37:57 +00:00
dante
bc7c33190f feat: allow for separate vk render on-chain (#697) 2024-01-25 19:48:13 +00:00
dante
df72e01414 feat: make selector compression optional (#696) 2024-01-24 00:09:00 +00:00
Tobin South
172e26c00d fix: link for CLI auto-install (#695) 2024-01-22 13:00:27 +00:00
Jason Morton
11ac120f23 fix: large test numbering(#689)
Co-authored-by: dante <45801863+alexander-camuto@users.noreply.github.com>
2024-01-21 21:01:46 +00:00
Jseam
0fdd92e9f3 fix: move install_ezkl_cli.sh into main repo (#694)
Co-authored-by: dante <45801863+alexander-camuto@users.noreply.github.com>
2024-01-21 20:59:39 +00:00
Alexander Camuto
31f58056a5 chore: bump py03 2024-01-21 20:58:32 +00:00
dante
ddbcc1d2d8 fix: calibration should only consider local scales (#691) 2024-01-18 23:28:49 +00:00
Vehorny
feccc5feed chore(examples): proofreading the notebooks (#687)
---------

Co-authored-by: dante <45801863+alexander-camuto@users.noreply.github.com>
2024-01-18 14:48:02 +00:00
dante
db24577c5d fix: calibrate from total min/max on lookups rather than individual x (#690) 2024-01-17 23:59:15 +00:00
Jseam
bb482e3cac fix: set max_logrows for calibrate_settings (#688) 2024-01-16 17:40:45 +00:00
121 changed files with 8635 additions and 15503 deletions

View File

@@ -6,7 +6,7 @@ on:
description: "Test scenario tags"
jobs:
large-tests:
runs-on: self-hosted
runs-on: kaiju
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -23,6 +23,6 @@ jobs:
- name: Self Attention KZG prove and verify large tests
run: cargo test --release --verbose tests::large_kzg_prove_and_verify_::large_tests_0_expects -- --include-ignored
- name: mobilenet Mock
run: cargo test --release --verbose tests::large_mock_::large_tests_2_expects -- --include-ignored
run: cargo test --release --verbose tests::large_mock_::large_tests_3_expects -- --include-ignored
- name: mobilenet KZG prove and verify large tests
run: cargo test --release --verbose tests::large_kzg_prove_and_verify_::large_tests_2_expects -- --include-ignored
run: cargo test --release --verbose tests::large_kzg_prove_and_verify_::large_tests_3_expects -- --include-ignored

View File

@@ -313,6 +313,8 @@ jobs:
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
- name: KZG prove and verify tests (EVM + VK rendered seperately)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
- name: KZG prove and verify tests (EVM + kzg all)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_kzg_all_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + kzg inputs)
@@ -425,21 +427,21 @@ jobs:
crate: cargo-nextest
locked: true
- name: KZG prove and verify tests (kzg outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
- name: KZG prove and verify tests (public inputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
- name: KZG prove and verify tests (fixed params)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
- name: KZG prove and verify tests (hashed outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 2
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
fuzz-tests:
runs-on: ubuntu-latest-32-cores
@@ -465,7 +467,7 @@ jobs:
# run: cargo nextest run --release --verbose tests::kzg_fuzz_ --test-threads 6
prove-and-verify-mock-aggr-tests:
runs-on: ubuntu-latest-32-cores
runs-on: self-hosted
needs: [build, library-tests]
steps:
- uses: actions/checkout@v4
@@ -601,6 +603,8 @@ jobs:
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; maturin develop --features python-bindings --release
- name: Div rebase
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_div_rebase_
- name: Public inputs
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_inputs_
- name: fixed params
@@ -611,27 +615,7 @@ jobs:
run: source .env/bin/activate; cargo nextest run --release --verbose tests::resources_accuracy_measurement_public_outputs_
python-integration-tests:
runs-on:
large-self-hosted
# Service containers to run with `container-job`
services:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres
env:
POSTGRES_USER: ubuntu
POSTGRES_HOST_AUTH_METHOD: trust
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
# Maps tcp port 5432 on service container to the host
- 5432:5432
# needs: [build, library-tests, docs]
runs-on: large-self-hosted
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
@@ -663,13 +647,13 @@ jobs:
# # now dump the contents of the file into a file called kaggle.json
# echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
# - name: Postgres tutorials
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1
- name: All notebooks
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --no-capture
- name: Voice tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --no-capture
# - name: Postgres tutorials
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1

141
Cargo.lock generated
View File

@@ -1058,16 +1058,6 @@ dependencies = [
"itertools 0.10.5",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200"
dependencies = [
"cfg-if",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.3"
@@ -1088,7 +1078,7 @@ dependencies = [
"autocfg",
"cfg-if",
"crossbeam-utils",
"memoffset 0.9.0",
"memoffset",
"scopeguard",
]
@@ -1378,7 +1368,7 @@ checksum = "68b0cf012f1230e43cd00ebb729c6bb58707ecfa8ad08b52ef3a4ccd2697fc30"
[[package]]
name = "ecc"
version = "0.1.0"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#c1d7551c82953829caee30fe218759b0d2657d26"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
dependencies = [
"integer",
"num-bigint",
@@ -1862,7 +1852,7 @@ dependencies = [
"halo2_gadgets",
"halo2_proofs",
"halo2_solidity_verifier",
"halo2curves 0.1.0",
"halo2curves 0.6.0",
"hex",
"indicatif",
"instant",
@@ -2253,7 +2243,7 @@ dependencies = [
[[package]]
name = "halo2_gadgets"
version = "0.2.0"
source = "git+https://github.com/zkonduit/halo2?branch=ac/lookup-modularity#57b9123835aa7d8482f4182ede3e8f4b0aea5c0a"
source = "git+https://github.com/zkonduit/halo2?branch=main#6a2b9ada9804807ddba03bbadaf6e63822cec275"
dependencies = [
"arrayvec 0.7.4",
"bitvec 1.0.1",
@@ -2269,14 +2259,14 @@ dependencies = [
[[package]]
name = "halo2_proofs"
version = "0.2.0"
source = "git+https://github.com/zkonduit/halo2?branch=ac/lookup-modularity#57b9123835aa7d8482f4182ede3e8f4b0aea5c0a"
version = "0.3.0"
source = "git+https://github.com/zkonduit/halo2?branch=main#6a2b9ada9804807ddba03bbadaf6e63822cec275"
dependencies = [
"blake2b_simd",
"env_logger",
"ff",
"group",
"halo2curves 0.1.0",
"halo2curves 0.6.0",
"icicle",
"log",
"maybe-rayon",
@@ -2292,7 +2282,7 @@ dependencies = [
[[package]]
name = "halo2_solidity_verifier"
version = "0.1.0"
source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=ac/lookup-modularity#cf9a3128bb583680dd4c418defd8d37bd8e5c3f1"
source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=main#eb04be1f7d005e5b9dd3ff41efa30aeb5e0c34a3"
dependencies = [
"askama",
"blake2b_simd",
@@ -2319,8 +2309,6 @@ dependencies = [
"paste",
"rand 0.8.5",
"rand_core 0.6.4",
"serde",
"serde_arrays",
"static_assertions",
"subtle",
]
@@ -2343,10 +2331,35 @@ dependencies = [
"subtle",
]
[[package]]
name = "halo2curves"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3675880dc0cc7cd468943266297198a28f88210ba60ca5e0e04d121edf86b46"
dependencies = [
"blake2b_simd",
"ff",
"group",
"hex",
"lazy_static",
"num-bigint",
"num-traits",
"pairing",
"pasta_curves",
"paste",
"rand 0.8.5",
"rand_core 0.6.4",
"rayon",
"serde",
"serde_arrays",
"static_assertions",
"subtle",
]
[[package]]
name = "halo2wrong"
version = "0.1.0"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#c1d7551c82953829caee30fe218759b0d2657d26"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
dependencies = [
"halo2_proofs",
"num-bigint",
@@ -2423,6 +2436,9 @@ name = "hex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
dependencies = [
"serde",
]
[[package]]
name = "hex-literal"
@@ -2668,9 +2684,9 @@ dependencies = [
[[package]]
name = "indoc"
version = "1.0.9"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306"
checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8"
[[package]]
name = "inout"
@@ -2696,7 +2712,7 @@ dependencies = [
[[package]]
name = "integer"
version = "0.1.0"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#c1d7551c82953829caee30fe218759b0d2657d26"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
dependencies = [
"maingate",
"num-bigint",
@@ -2941,7 +2957,7 @@ checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4"
[[package]]
name = "maingate"
version = "0.1.0"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#c1d7551c82953829caee30fe218759b0d2657d26"
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/chunked-mv-lookup#b43ebe30e84825d0d004fa27803d99c4187d419f"
dependencies = [
"halo2wrong",
"num-bigint",
@@ -3001,15 +3017,6 @@ dependencies = [
"libc",
]
[[package]]
name = "memoffset"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1"
dependencies = [
"autocfg",
]
[[package]]
name = "memoffset"
version = "0.9.0"
@@ -3335,6 +3342,15 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "pairing"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81fec4625e73cf41ef4bb6846cafa6d44736525f442ba45e407c4a000a13996f"
dependencies = [
"group",
]
[[package]]
name = "papergrid"
version = "0.9.1"
@@ -3840,14 +3856,14 @@ dependencies = [
[[package]]
name = "pyo3"
version = "0.18.3"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3b1ac5b3731ba34fdaa9785f8d74d17448cd18f30cf19e0c7e7b1fdb5272109"
checksum = "9a89dc7a5850d0e983be1ec2a463a171d20990487c3cfcd68b5363f1ee3d6fe0"
dependencies = [
"cfg-if",
"indoc",
"libc",
"memoffset 0.8.0",
"memoffset",
"parking_lot",
"pyo3-build-config",
"pyo3-ffi",
@@ -3857,9 +3873,9 @@ dependencies = [
[[package]]
name = "pyo3-asyncio"
version = "0.18.0"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3564762e37035cfc486228e10b0528460fa026d681b5763873c693aa0d5c260"
checksum = "6ea6b68e93db3622f3bb3bf363246cf948ed5375afe7abff98ccbdd50b184995"
dependencies = [
"futures",
"once_cell",
@@ -3871,9 +3887,9 @@ dependencies = [
[[package]]
name = "pyo3-asyncio-macros"
version = "0.18.0"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be72d4cd43a27530306bd0d20d3932182fbdd072c6b98d3638bc37efb9d559dd"
checksum = "56c467178e1da6252c95c29ecf898b133f742e9181dca5def15dc24e19d45a39"
dependencies = [
"proc-macro2",
"quote",
@@ -3882,9 +3898,9 @@ dependencies = [
[[package]]
name = "pyo3-build-config"
version = "0.18.3"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cb946f5ac61bb61a5014924910d936ebd2b23b705f7a4a3c40b05c720b079a3"
checksum = "07426f0d8fe5a601f26293f300afd1a7b1ed5e78b2a705870c5f30893c5163be"
dependencies = [
"once_cell",
"target-lexicon",
@@ -3892,9 +3908,9 @@ dependencies = [
[[package]]
name = "pyo3-ffi"
version = "0.18.3"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd4d7c5337821916ea2a1d21d1092e8443cf34879e53a0ac653fbb98f44ff65c"
checksum = "dbb7dec17e17766b46bca4f1a4215a85006b4c2ecde122076c562dd058da6cf1"
dependencies = [
"libc",
"pyo3-build-config",
@@ -3902,9 +3918,9 @@ dependencies = [
[[package]]
name = "pyo3-log"
version = "0.8.2"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c94ff6535a6bae58d7d0b85e60d4c53f7f84d0d0aa35d6a28c3f3e70bfe51444"
checksum = "4c10808ee7250403bedb24bc30c32493e93875fef7ba3e4292226fe924f398bd"
dependencies = [
"arc-swap",
"log",
@@ -3913,25 +3929,26 @@ dependencies = [
[[package]]
name = "pyo3-macros"
version = "0.18.3"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9d39c55dab3fc5a4b25bbd1ac10a2da452c4aca13bb450f22818a002e29648d"
checksum = "05f738b4e40d50b5711957f142878cfa0f28e054aa0ebdfc3fd137a843f74ed3"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 1.0.109",
"syn 2.0.22",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.18.3"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97daff08a4c48320587b5224cc98d609e3c27b6d437315bd40b605c98eeb5918"
checksum = "0fc910d4851847827daf9d6cdd4a823fbdaab5b8818325c5e97a86da79e8881f"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 1.0.109",
"syn 2.0.22",
]
[[package]]
@@ -4040,9 +4057,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "rayon"
version = "1.7.0"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b"
checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051"
dependencies = [
"either",
"rayon-core",
@@ -4050,14 +4067,12 @@ dependencies = [
[[package]]
name = "rayon-core"
version = "1.11.0"
version = "1.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d"
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
dependencies = [
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-utils",
"num_cpus",
]
[[package]]
@@ -4745,11 +4760,11 @@ dependencies = [
[[package]]
name = "snark-verifier"
version = "0.1.1"
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/chunked-mv-lookup#22ee76bee1a24f3732e994b72b10ec09939348de"
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/chunked-mv-lookup#574b65ea6b4d43eebac5565146519a95b435815c"
dependencies = [
"ecc",
"halo2_proofs",
"halo2curves 0.1.0",
"halo2curves 0.6.0",
"hex",
"itertools 0.10.5",
"lazy_static",
@@ -5552,9 +5567,9 @@ checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
[[package]]
name = "unindent"
version = "0.1.11"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
[[package]]
name = "unzip-n"

View File

@@ -15,9 +15,9 @@ crate-type = ["cdylib", "rlib"]
[dependencies]
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "ac/lookup-modularity" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "ac/lookup-modularity" }
halo2curves = { version = "0.1.0" }
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "main" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "main" }
halo2curves = { version = "0.6.0", features = ["derive_serde"] }
rand = { version = "0.8", default_features = false }
itertools = { version = "0.10.3", default_features = false }
clap = { version = "4.3.3", features = ["derive"]}
@@ -28,7 +28,7 @@ thiserror = { version = "1.0.38", default_features = false }
hex = { version = "0.4.3", default_features = false }
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features=["derive_serde"]}
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch= "ac/lookup-modularity" }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch= "main" }
maybe-rayon = { version = "0.1.1", default_features = false }
bincode = { version = "1.3.3", default_features = false }
ark-std = { version = "^0.3.0", default-features = false }
@@ -51,9 +51,9 @@ plotters = { version = "0.3.0", default_features = false, optional = true }
regex = { version = "1", default_features = false }
tokio = { version = "1.26.0", default_features = false, features = ["macros", "rt"] }
tokio-util = { version = "0.7.9", features = ["codec"] }
pyo3 = { version = "0.18.3", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
pyo3-asyncio = { version = "0.18.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
pyo3-log = { version = "0.8.1", default_features = false, optional = true }
pyo3 = { version = "0.20.2", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
pyo3-asyncio = { version = "0.20.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
pyo3-log = { version = "0.9.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }

View File

@@ -64,8 +64,8 @@ More notebook tutorials can be found within `examples/notebooks`.
#### CLI
Install the CLI
```bash
curl https://hub.ezkl.xyz/install_ezkl_cli.sh | bash
``` shell
curl https://raw.githubusercontent.com/zkonduit/ezkl/main/install_ezkl_cli.sh | bash
```
https://user-images.githubusercontent.com/45801863/236771676-5bbbbfd1-ba6f-418a-902e-20738ce0e9f0.mp4

View File

@@ -121,13 +121,16 @@ fn runcnvrl(c: &mut Criterion) {
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
.unwrap();
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(
&circuit, &params, true,
)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {

View File

@@ -90,13 +90,13 @@ fn rundot(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {

View File

@@ -94,13 +94,13 @@ fn runmatmul(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {

View File

@@ -1,4 +1,5 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::table::Range;
use ezkl::circuit::*;
use ezkl::circuit::lookup::LookupOp;
@@ -16,7 +17,7 @@ use halo2_proofs::{
use halo2curves::bn256::{Bn256, Fr};
use std::marker::PhantomData;
const BITS: (i128, i128) = (-32768, 32768);
const BITS: Range = (-32768, 32768);
static mut LEN: usize = 4;
const K: usize = 16;
@@ -111,13 +112,13 @@ fn runmatmul(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {

View File

@@ -3,6 +3,7 @@ use ezkl::circuit::*;
use ezkl::circuit::lookup::LookupOp;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::table::Range;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
@@ -16,7 +17,7 @@ use halo2_proofs::{
use halo2curves::bn256::{Bn256, Fr};
use std::marker::PhantomData;
const BITS: (i128, i128) = (-8180, 8180);
const BITS: Range = (-8180, 8180);
static mut LEN: usize = 4;
static mut K: usize = 16;
@@ -114,13 +115,13 @@ fn runmatmul(c: &mut Criterion) {
group.throughput(Throughput::Elements(k as u64));
group.bench_with_input(BenchmarkId::new("pk", k), &k, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(k as u64));
group.bench_with_input(BenchmarkId::new("prove", k), &k, |b, &_| {

View File

@@ -86,13 +86,13 @@ fn runsum(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {

View File

@@ -101,13 +101,16 @@ fn runsumpool(c: &mut Criterion) {
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
.unwrap();
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(
&circuit, &params, true,
)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {

View File

@@ -84,13 +84,13 @@ fn runadd(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {

View File

@@ -83,13 +83,13 @@ fn runpow(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {

View File

@@ -76,13 +76,13 @@ fn runposeidon(c: &mut Criterion) {
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {

View File

@@ -1,5 +1,6 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::table::Range;
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::TranscriptType;
@@ -14,7 +15,7 @@ use halo2_proofs::{
use halo2curves::bn256::{Bn256, Fr};
use rand::Rng;
const BITS: (i128, i128) = (-32768, 32768);
const BITS: Range = (-32768, 32768);
static mut LEN: usize = 4;
const K: usize = 16;
@@ -90,13 +91,13 @@ fn runrelu(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, &params)
create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, &params).unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {

View File

@@ -271,7 +271,7 @@
"The graph input for on chain data sources is formatted completely differently compared to file based data sources.\n",
"\n",
"- For file data sources, the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit are stored. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestable by EZKL (aka field elemenets :-D). \n",
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestable by EZKL (aka field elements :-D). \n",
"Here is what the schema for an on-chain data source graph input file should look like:\n",
" \n",
"```json\n",

View File

@@ -309,7 +309,7 @@
"metadata": {},
"outputs": [],
"source": [
"print(ezkl.vecu64_to_felt(res['processed_outputs']['poseidon_hash'][0]))"
"print(ezkl.string_to_felt(res['processed_outputs']['poseidon_hash'][0]))"
]
},
{
@@ -338,7 +338,7 @@
"\n",
"def test_on_chain_data(res):\n",
" # Step 0: Convert the tensor to a flat list\n",
" data = [int(ezkl.vecu64_to_felt(res['processed_outputs']['poseidon_hash'][0]), 0)]\n",
" data = [int(ezkl.string_to_felt(res['processed_outputs']['poseidon_hash'][0]), 0)]\n",
"\n",
" # Step 1: Prepare the data\n",
" # Step 2: Prepare and compile the contract.\n",

View File

@@ -42,7 +42,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"id": "gvQ5HL1bTDWF"
},
@@ -441,7 +441,9 @@
"# Serialize calibration data into file:\n",
"json.dump(data, open(cal_data_path, 'w'))\n",
"\n",
"res = ezkl.calibrate_settings(cal_data_path, model_path, settings_path, \"resources\") # Optimize for resources"
"# Optimize for resources, we cap logrows at 12 to reduce setup and proving time, at the expense of accuracy\n",
"# You may want to increase the max logrows if accuracy is a concern\n",
"res = ezkl.calibrate_settings(cal_data_path, model_path, settings_path, \"resources\", max_logrows = 12, scales = [2])"
]
},
{
@@ -506,9 +508,8 @@
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
" \n",
"\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
@@ -563,7 +564,6 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
@@ -695,7 +695,7 @@
"formatted_output = \"[\"\n",
"for i, value in enumerate(proof[\"instances\"]):\n",
" for j, field_element in enumerate(value):\n",
" onchain_input_array.append(ezkl.vecu64_to_felt(field_element))\n",
" onchain_input_array.append(ezkl.string_to_felt(field_element))\n",
" formatted_output += str(onchain_input_array[-1])\n",
" if j != len(value) - 1:\n",
" formatted_output += \", \"\n",

File diff suppressed because one or more lines are too long

View File

@@ -302,7 +302,7 @@
" assert res == True\n",
" assert os.path.isfile(vk_path)\n",
" assert os.path.isfile(pk_path)\n",
" \n",
"\n",
" res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
" run_args.input_scale = settings[\"model_output_scales\"][0]\n",
"\n",
@@ -330,14 +330,14 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"for-aggr\",\n",
" )\n",
"\n",
" print(res)\n",
" res_1_proof = res[\"proof\"]\n",
" assert os.path.isfile(proof_path)\n",
"\n",
" # Verify the proof\n",
" # # Verify the proof\n",
" if i > 0:\n",
" print(\"swapping commitments\")\n",
" # swap the proof commitments if we are not the first model\n",
@@ -356,12 +356,19 @@
"\n",
" res = ezkl.swap_proof_commitments(proof_path, witness_path)\n",
" print(res)\n",
" \n",
" # load proof and then print \n",
" proof = json.load(open(proof_path, 'r'))\n",
" res_2_proof = proof[\"hex_proof\"]\n",
" # show diff in hex strings\n",
" print(res_1_proof)\n",
" print(res_2_proof)\n",
" assert res_1_proof == res_2_proof\n",
"\n",
" res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
" assert res == True\n",
@@ -439,7 +446,7 @@
" proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
" proofs.append(proof_path)\n",
"\n",
"ezkl.mock_aggregate(proofs, logrows=23, split_proofs = True)"
"ezkl.mock_aggregate(proofs, logrows=22, split_proofs = True)"
]
}
],

View File

@@ -8,7 +8,7 @@
"source": [
"## EZKL Jupyter Notebook Demo \n",
"\n",
"Here we demonstrate how to use the EZKL package to run a publicly known / committted to network on some private data, producing a public output.\n"
"Here we demonstrate how to use the EZKL package to run a publicly known / committed to network on some private data, producing a public output.\n"
]
},
{
@@ -210,7 +210,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "b1c561a8",
"metadata": {},
"outputs": [],

View File

@@ -126,7 +126,7 @@
"# Loop through each element in the y tensor\n",
"for e in user_preimages:\n",
" # Apply the custom function and append the result to the list\n",
" users.append(ezkl.poseidon_hash([ezkl.float_to_vecu64(e, 0)])[0])\n",
" users.append(ezkl.poseidon_hash([ezkl.float_to_string(e, 0)])[0])\n",
"\n",
"users_t = torch.tensor(user_preimages)\n",
"users_t = users_t.reshape(1, 6)\n",
@@ -303,7 +303,7 @@
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",
"# this means that the output is fixed and the verifier can see it but that if the input is not in the set the output will not be 0 and the verifier will reject\n",
"witness = json.load(open(witness_path, \"r\"))\n",
"witness[\"outputs\"][0] = [ezkl.float_to_vecu64(1.0, 0)]\n",
"witness[\"outputs\"][0] = [ezkl.float_to_string(1.0, 0)]\n",
"json.dump(witness, open(witness_path, \"w\"))"
]
},
@@ -417,7 +417,7 @@
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",
"# this means that the output is fixed and the verifier can see it but that if the input is not in the set the output will not be 0 and the verifier will reject\n",
"witness = json.load(open(witness_path, \"r\"))\n",
"witness[\"outputs\"][0] = [ezkl.float_to_vecu64(1.0, 0)]\n",
"witness[\"outputs\"][0] = [ezkl.float_to_string(1.0, 0)]\n",
"json.dump(witness, open(witness_path, \"w\"))\n"
]
},

View File

@@ -633,7 +633,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
]
},
{

View File

@@ -154,7 +154,7 @@
"source": [
"## Create a neural net to verify the execution of the tic tac toe model\n",
"\n",
"1. Given the data generated above classify whether the tic tac toe games are valid. This approach uses a binary classification as the tic tac toe state space is fairly small. For larger state spaces we will want to use anomaly detection based approachs"
"1. Given the data generated above classify whether the tic tac toe games are valid. This approach uses a binary classification as the tic tac toe state space is fairly small. For larger state spaces, we will want to use anomaly detection based approaches."
]
},
{
@@ -520,7 +520,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
]
},
{
@@ -636,7 +636,8 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,

View File

@@ -237,7 +237,7 @@
"\n",
"ezkl.gen_settings(onnx_filename, settings_filename)\n",
"ezkl.calibrate_settings(\n",
" input_filename, onnx_filename, settings_filename, \"resources\")\n",
" input_filename, onnx_filename, settings_filename, \"resources\", scales = [4])\n",
"res = ezkl.get_srs(settings_filename)\n",
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)\n",
"\n",
@@ -255,7 +255,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {
"id": "fULvvnK7_CMb"
},
@@ -451,7 +451,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
}
},
"nbformat": 4,

View File

@@ -25,17 +25,9 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"voice_data_dir: .\n"
]
}
],
"outputs": [],
"source": [
"\n",
"import os\n",
@@ -43,7 +35,7 @@
"\n",
"voice_data_dir = os.environ.get('VOICE_DATA_DIR')\n",
"\n",
"# if is none set to \"\" \n",
"# if is none set to \"\"\n",
"if voice_data_dir is None:\n",
" voice_data_dir = \"\"\n",
"\n",
@@ -637,7 +629,7 @@
"source": [
"\n",
"\n",
"res = ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\")\n",
"res = ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\", scales = [4])\n",
"assert res == True\n",
"print(\"verified\")\n"
]
@@ -908,7 +900,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
}
},
"nbformat": 4,

View File

@@ -49,7 +49,7 @@
"import torch\n",
"import math\n",
"\n",
"# these are constatns for the rotation\n",
"# these are constants for the rotation\n",
"phi = torch.tensor(5 * math.pi / 180)\n",
"s = torch.sin(phi)\n",
"c = torch.cos(phi)\n",
@@ -503,11 +503,11 @@
"pyplot.arrow(0, 0, 1, 0, width=0.02, alpha=0.5)\n",
"pyplot.arrow(0, 0, 0, 1, width=0.02, alpha=0.5)\n",
"\n",
"arrow_x = ezkl.vecu64_to_float(witness['outputs'][0][0], out_scale)\n",
"arrow_y = ezkl.vecu64_to_float(witness['outputs'][0][1], out_scale)\n",
"arrow_x = ezkl.string_to_float(witness['outputs'][0][0], out_scale)\n",
"arrow_y = ezkl.string_to_float(witness['outputs'][0][1], out_scale)\n",
"pyplot.arrow(0, 0, arrow_x, arrow_y, width=0.02)\n",
"arrow_x = ezkl.vecu64_to_float(witness['outputs'][0][2], out_scale)\n",
"arrow_y = ezkl.vecu64_to_float(witness['outputs'][0][3], out_scale)\n",
"arrow_x = ezkl.string_to_float(witness['outputs'][0][2], out_scale)\n",
"arrow_y = ezkl.string_to_float(witness['outputs'][0][3], out_scale)\n",
"pyplot.arrow(0, 0, arrow_x, arrow_y, width=0.02)"
]
}

View File

@@ -7,7 +7,7 @@
"source": [
"# kzg-ezkl\n",
"\n",
"Here's an example leveraging EZKL whereby the inputs to the model, and the model params themselves, are commited to using kzg-commitments inside a circuit.\n",
"Here's an example leveraging EZKL whereby the inputs to the model, and the model params themselves, are committed to using kzg-commitments inside a circuit.\n",
"\n",
"In this setup:\n",
"- the commitments are publicly known to the prover and verifier\n",
@@ -166,7 +166,7 @@
"Shoutouts: \n",
"\n",
"- [summa-solvency](https://github.com/summa-dev/summa-solvency) for their help with the poseidon hashing chip. \n",
"- [timeofey](https://github.com/timoftime) for providing inspiration in our developement of the el-gamal encryption circuit in Halo2. "
"- [timeofey](https://github.com/timoftime) for providing inspiration in our development of the el-gamal encryption circuit in Halo2. "
]
},
{

View File

@@ -78,7 +78,7 @@
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
@@ -122,8 +122,8 @@
"# Loop through each element in the y tensor\n",
"for e in y_input:\n",
" # Apply the custom function and append the result to the list\n",
" print(ezkl.float_to_vecu64(e,7))\n",
" result.append(ezkl.poseidon_hash([ezkl.float_to_vecu64(e, 7)])[0])\n",
" print(ezkl.float_to_string(e,7))\n",
" result.append(ezkl.poseidon_hash([ezkl.float_to_string(e, 7)])[0])\n",
"\n",
"y = y.unsqueeze(0)\n",
"y = y.reshape(1, 9)\n",
@@ -343,7 +343,7 @@
"# we force the output to be 0 this corresponds to the set membership test being true -- and we set this to a fixed vis output\n",
"# this means that the output is fixed and the verifier can see it but that if the input is not in the set the output will not be 0 and the verifier will reject\n",
"witness = json.load(open(witness_path, \"r\"))\n",
"witness[\"outputs\"][0] = [[0, 0, 0, 0]]\n",
"witness[\"outputs\"][0] = [\"0000000000000000000000000000000000000000000000000000000000000000\"]\n",
"json.dump(witness, open(witness_path, \"w\"))\n",
"\n",
"witness = json.load(open(witness_path, \"r\"))\n",
@@ -353,7 +353,6 @@
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" witness_path = witness_path,\n",
" )\n",
"\n",
@@ -520,4 +519,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -300,13 +300,14 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# iterate over each submodel gen-settings, compile circuit and setup zkSNARK\n",
"\n",
"def setup(i):\n",
" print(\"Setting up split model \"+str(i))\n",
" # file names\n",
" model_path = os.path.join('network_split_'+str(i)+'.onnx')\n",
" settings_path = os.path.join('settings_split_'+str(i)+'.json')\n",
@@ -342,12 +343,12 @@
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" compress_selectors=True,\n",
" )\n",
"\n",
" assert res == True\n",
" assert os.path.isfile(vk_path)\n",
" assert os.path.isfile(pk_path)\n",
" \n",
" res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
" run_args.input_scale = settings[\"model_output_scales\"][0]\n",
"\n",
@@ -383,7 +384,6 @@
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"for-aggr\",\n",
" )\n",
"\n",
@@ -413,7 +413,6 @@
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
" assert res == True\n",
@@ -442,7 +441,7 @@
" proof_path = os.path.join('proof_split_'+str(i)+'.json')\n",
" proofs.append(proof_path)\n",
"\n",
"ezkl.mock_aggregate(proofs, logrows=23, split_proofs = True)"
"ezkl.mock_aggregate(proofs, logrows=22, split_proofs = True)"
]
}
],

View File

@@ -780,7 +780,7 @@
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
@@ -845,7 +845,7 @@
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", max_logrows = 20, scales = [5,6])\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", max_logrows = 20, scales = [3])\n",
"assert res == True"
]
},
@@ -887,11 +887,28 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 30,
"metadata": {
"id": "12YIcFr85X9-"
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"spawning module 2\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"quotient_poly_degree 4\n",
"n 262144\n",
"extended_k 20\n"
]
}
],
"source": [
"res = ezkl.setup(\n",
" compiled_model_path,\n",
@@ -971,9 +988,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
}

File diff suppressed because it is too large Load Diff

View File

Before

Width:  |  Height:  |  Size: 109 KiB

After

Width:  |  Height:  |  Size: 109 KiB

View File

@@ -0,0 +1,39 @@
from torch import nn
import torch
import json
class Circuit(nn.Module):
def __init__(self, inplace=False):
super(Circuit, self).__init__()
def forward(self, x):
return x/ 10000
circuit = Circuit()
x = torch.empty(1, 8).random_(0, 2)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]]}

Binary file not shown.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

170
install_ezkl_cli.sh Normal file
View File

@@ -0,0 +1,170 @@
#!/usr/bin/env bash
set -e
BASE_DIR=${XDG_CONFIG_HOME:-$HOME}
EZKL_DIR=${EZKL_DIR-"$BASE_DIR/.ezkl"}
# Create the .ezkl bin directory if it doesn't exit
mkdir -p $EZKL_DIR
# Store the correct profile file (i.e. .profile for bash or .zshenv for ZSH).
case $SHELL in
*/zsh)
PROFILE=${ZDOTDIR-"$HOME"}/.zshenv
PREF_SHELL=zsh
;;
*/bash)
PROFILE=$HOME/.bashrc
PREF_SHELL=bash
;;
*/fish)
PROFILE=$HOME/.config/fish/config.fish
PREF_SHELL=fish
;;
*/ash)
PROFILE=$HOME/.profile
PREF_SHELL=ash
;;
*)
echo "NOTICE: Shell could not be detected, you will need to manually add ${EZKL_DIR} to your PATH."
esac
# Check for non standard installation of ezkl
if [ "$(which ezkl)s" != "s" ] && [ "$(which ezkl)" != "$EZKL_DIR/ezkl" ] ; then
echo "ezkl is installed in a non-standard directory, $(which ezkl). To use the automated installer, remove the existing ezkl from path to prevent conflicts"
exit 1
fi
if [[ ":$PATH:" != *":${EZKl_DIR}:"* ]]; then
# Add the ezkl directory to the path and ensure the old PATH variables remain.
echo >> $PROFILE && echo "export PATH=\"\$PATH:$EZKL_DIR\"" >> $PROFILE
fi
# Install latest ezkl version
# Get the right release URL
if [ -z "$1" ]
then
RELEASE_URL="https://api.github.com/repos/zkonduit/ezkl/releases/latest"
echo "No version tags provided, installing the latest ezkl version"
else
RELEASE_URL="https://api.github.com/repos/zkonduit/ezkl/releases/tags/$1"
echo "Installing ezkl version $1"
fi
PLATFORM=""
case "$(uname -s)" in
Darwin*)
PLATFORM="macos"
;;
Linux*Microsoft*)
PLATFORM="linux"
;;
Linux*)
PLATFORM="linux"
;;
CYGWIN*|MINGW*|MINGW32*|MSYS*)
PLATFORM="windows-msvc"
;;
*)
echo "Platform is not supported. If you would need support for the platform please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
exit 1
;;
esac
# Check arch
ARCHITECTURE="$(uname -m)"
if [ "${ARCHITECTURE}" = "x86_64" ]; then
# Redirect stderr to /dev/null to avoid printing errors if non Rosetta.
if [ "$(sysctl -n sysctl.proc_translated 2>/dev/null)" = "1" ]; then
ARCHITECTURE="arm64" # Rosetta.
else
ARCHITECTURE="amd64" # Intel.
fi
elif [ "${ARCHITECTURE}" = "arm64" ] ||[ "${ARCHITECTURE}" = "aarch64" ]; then
ARCHITECTURE="aarch64" # Arm.
elif [ "${ARCHITECTURE}" = "amd64" ]; then
ARCHITECTURE="amd64" # Amd
else
echo "Architecture is not supported. If you would need support for the architecture please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
exit 1
fi
# Remove existing ezkl
echo "Removing old ezkl binary if it exists"
[ -e file ] && rm file
# download the release and unpack the right tarball
if [ "$PLATFORM" == "windows-msvc" ]; then
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-windows-msvc.tar.gz")
echo "Downloading package"
curl -L "$FILE_URL" -o "$EZKL_DIR/build-artifacts.ezkl-windows-msvc.tar.gz"
echo "Unpacking package"
tar -xzf "$EZKL_DIR/build-artifacts.ezkl-windows-msvc.tar.gz" -C "$EZKL_DIR"
echo "Cleaning up"
rm "$EZKL_DIR/build-artifacts.ezkl-windows-msvc.tar.gz"
elif [ "$PLATFORM" == "macos" ]; then
if [ "$ARCHITECTURE" == "aarch64" ] || [ "$ARCHITECTURE" == "arm64" ]; then
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-macos-aarch64.tar.gz")
echo "Downloading package"
curl -L "$FILE_URL" -o "$EZKL_DIR/build-artifacts.ezkl-macos-aarch64.tar.gz"
echo "Unpacking package"
tar -xzf "$EZKL_DIR/build-artifacts.ezkl-macos-aarch64.tar.gz" -C "$EZKL_DIR"
echo "Cleaning up"
rm "$EZKL_DIR/build-artifacts.ezkl-macos-aarch64.tar.gz"
else
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-macos.tar.gz")
echo "Downloading package"
curl -L "$FILE_URL" -o "$EZKL_DIR/build-artifacts.ezkl-macos.tar.gz"
echo "Unpacking package"
tar -xzf "$EZKL_DIR/build-artifacts.ezkl-macos.tar.gz" -C "$EZKL_DIR"
echo "Cleaning up"
rm "$EZKL_DIR/build-artifacts.ezkl-macos.tar.gz"
fi
elif [ "$PLATFORM" == "linux" ]; then
if [ "${ARCHITECTURE}" = "amd64" ]; then
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-linux-gnu.tar.gz")
echo "Downloading package"
curl -L "$FILE_URL" -o "$EZKL_DIR/build-artifacts.ezkl-linux-gnu.tar.gz"
echo "Unpacking package"
tar -xzf "$EZKL_DIR/build-artifacts.ezkl-linux-gnu.tar.gz" -C "$EZKL_DIR"
echo "Cleaning up"
rm "$EZKL_DIR/build-artifacts.ezkl-linux-gnu.tar.gz"
else
echo "ARM architectures are not supported for Linux at the moment. If you would need support for the ARM architectures on linux please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
exit 1
fi
else
echo "Platform and Architecture is not supported. If you would need support for the platform and architecture please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
exit 1
fi
echo && echo "Successfully downloaded ezkl at ${EZKL_DIR}"
echo "We detected that your preferred shell is ${PREF_SHELL} and added ezkl to PATH. Run 'source ${PROFILE}' or start a new terminal session to use ezkl."

View File

@@ -219,7 +219,7 @@ mod tests {
};
let prover = halo2_proofs::dev::MockProver::run(K as u32, &circuit, vec![]).unwrap();
assert_eq!(prover.verify_par(), Ok(()))
assert_eq!(prover.verify(), Ok(()))
}
}
@@ -240,6 +240,6 @@ mod tests {
message: message.into(),
};
let prover = halo2_proofs::dev::MockProver::run(K as u32, &circuit, vec![]).unwrap();
assert_eq!(prover.verify_par(), Ok(()))
assert_eq!(prover.verify(), Ok(()))
}
}

View File

@@ -499,7 +499,7 @@ mod tests {
_spec: PhantomData,
};
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
assert_eq!(prover.verify_par(), Ok(()))
assert_eq!(prover.verify(), Ok(()))
}
#[test]
@@ -518,7 +518,7 @@ mod tests {
_spec: PhantomData,
};
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
assert_eq!(prover.verify_par(), Ok(()))
assert_eq!(prover.verify(), Ok(()))
}
#[test]
@@ -551,7 +551,7 @@ mod tests {
};
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
assert_eq!(prover.verify_par(), Ok(()))
assert_eq!(prover.verify(), Ok(()))
}
}
@@ -573,6 +573,6 @@ mod tests {
_spec: PhantomData,
};
let prover = halo2_proofs::dev::MockProver::run(k, &circuit, output).unwrap();
assert_eq!(prover.verify_par(), Ok(()))
assert_eq!(prover.verify(), Ok(()))
}
}

View File

@@ -19,7 +19,10 @@ use serde::{Deserialize, Serialize};
use crate::{
circuit::ops::base::BaseOp,
circuit::{table::Table, utils},
circuit::{
table::{Range, RangeCheck, Table},
utils,
},
tensor::{Tensor, TensorType, ValTensor, VarTensor},
};
use std::{collections::BTreeMap, error::Error, marker::PhantomData};
@@ -176,6 +179,10 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
pub lookup_selectors: BTreeMap<(LookupOp, usize, usize), Selector>,
///
pub tables: BTreeMap<LookupOp, Table<F>>,
///
pub range_checks: BTreeMap<Range, RangeCheck<F>>,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
pub range_check_selectors: BTreeMap<(Range, usize, usize), Selector>,
/// Activate sanity checks
pub check_mode: CheckMode,
_marker: PhantomData<F>,
@@ -194,7 +201,9 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
lookup_index: dummy_var,
selectors: BTreeMap::new(),
lookup_selectors: BTreeMap::new(),
range_check_selectors: BTreeMap::new(),
tables: BTreeMap::new(),
range_checks: BTreeMap::new(),
check_mode: CheckMode::SAFE,
_marker: PhantomData,
}
@@ -325,11 +334,13 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
Self {
selectors,
lookup_selectors: BTreeMap::new(),
range_check_selectors: BTreeMap::new(),
inputs: inputs.to_vec(),
lookup_input: VarTensor::Empty,
lookup_output: VarTensor::Empty,
lookup_index: VarTensor::Empty,
tables: BTreeMap::new(),
range_checks: BTreeMap::new(),
output: output.clone(),
check_mode,
_marker: PhantomData,
@@ -344,7 +355,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
input: &VarTensor,
output: &VarTensor,
index: &VarTensor,
lookup_range: (i128, i128),
lookup_range: Range,
logrows: usize,
nl: &LookupOp,
) -> Result<(), Box<dyn Error>>
@@ -482,6 +493,74 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
Ok(())
}
/// Configures and creates lookup selectors
#[allow(clippy::too_many_arguments)]
pub fn configure_range_check(
&mut self,
cs: &mut ConstraintSystem<F>,
input: &VarTensor,
range: Range,
) -> Result<(), Box<dyn Error>>
where
F: Field,
{
let mut selectors = BTreeMap::new();
if !input.is_advice() {
return Err("wrong input type for lookup input".into());
}
// we borrow mutably twice so we need to do this dance
let range_check = if !self.range_checks.contains_key(&range) {
// as all tables have the same input we see if there's another table who's input we can reuse
let range_check = RangeCheck::<F>::configure(cs, range);
self.range_checks.insert(range, range_check.clone());
range_check
} else {
return Ok(());
};
for x in 0..input.num_blocks() {
for y in 0..input.num_inner_cols() {
let single_col_sel = cs.complex_selector();
cs.lookup("", |cs| {
let mut res = vec![];
let sel = cs.query_selector(single_col_sel);
let input_query = match &input {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
};
let default_x = range_check.get_first_element();
let not_sel = Expression::Constant(F::ONE) - sel.clone();
res.extend([(
sel.clone() * input_query.clone()
+ not_sel.clone() * Expression::Constant(default_x),
range_check.input,
)]);
res
});
selectors.insert((range, x, y), single_col_sel);
}
}
self.range_check_selectors.extend(selectors);
// if we haven't previously initialized the input/output, do so now
if let VarTensor::Empty = self.lookup_input {
debug!("assigning lookup input");
self.lookup_input = input.clone();
}
Ok(())
}
/// layout_tables must be called before layout.
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
for (i, table) in self.tables.values_mut().enumerate() {
@@ -500,6 +579,20 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
Ok(())
}
/// layout_range_checks must be called before layout.
pub fn layout_range_checks(
&mut self,
layouter: &mut impl Layouter<F>,
) -> Result<(), Box<dyn Error>> {
for range_check in self.range_checks.values_mut() {
if !range_check.is_assigned {
debug!("laying out range check for {:?}", range_check.range);
range_check.layout(layouter)?;
}
}
Ok(())
}
/// Assigns variables to the regions created when calling `configure`.
/// # Arguments
/// * `values` - The explicit values to the operations.

View File

@@ -1,7 +1,8 @@
use super::*;
use crate::{
circuit::{self, layouts, utils, Tolerance},
circuit::{layouts, utils, Tolerance},
fieldutils::{felt_to_i128, i128_to_felt},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
@@ -13,6 +14,15 @@ use serde::{Deserialize, Serialize};
/// An enum representing the operations that consist of both lookups and arithmetic operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HybridOp {
Recip {
input_scale: utils::F32,
output_scale: utils::F32,
use_range_check_for_int: bool,
},
Div {
denom: utils::F32,
use_range_check_for_int: bool,
},
ReduceMax {
axes: Vec<usize>,
},
@@ -75,6 +85,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
match self {
HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1],
HybridOp::ScatterElements { .. } => vec![0, 2],
HybridOp::GreaterEqual | HybridOp::LessEqual => vec![0, 1],
_ => vec![],
}
}
@@ -113,6 +124,40 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
(res.clone(), vec![inter_1, inter_2])
}
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
let res = crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64);
// if denom is a round number and use_range_check_for_int is true, use range check check
if denom.0.fract() == 0.0 && *use_range_check_for_int {
let divisor = Tensor::from(vec![denom.0 as i128 / 2].into_iter());
(res, vec![-divisor.clone(), divisor])
} else {
(res, vec![x])
}
}
HybridOp::Recip {
input_scale,
output_scale,
use_range_check_for_int,
} => {
let res = crate::tensor::ops::nonlinearities::recip(
&x,
input_scale.0 as f64,
output_scale.0 as f64,
);
// if scale is a round number and use_range_check_for_int is true, use range check check
if input_scale.0.fract() == 0.0 && *use_range_check_for_int {
let err_tol = Tensor::from(
vec![(output_scale.0 * input_scale.0) as i128 / 2].into_iter(),
);
(res, vec![-err_tol.clone(), err_tol])
} else {
(res, vec![x])
}
}
HybridOp::ReduceArgMax { dim } => {
let res = tensor::ops::argmax_axes(&x, *dim)?;
let indices = Tensor::from(0..x.dims()[*dim] as i128);
@@ -272,6 +317,21 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
fn as_string(&self) -> String {
match self {
HybridOp::Recip {
input_scale,
output_scale,
use_range_check_for_int,
} => format!(
"RECIP (input_scale={}, output_scale={}, use_range_check_for_int={})",
input_scale, output_scale, use_range_check_for_int
),
HybridOp::Div {
denom,
use_range_check_for_int,
} => format!(
"DIV (denom={}, use_range_check_for_int={})",
denom, use_range_check_for_int
),
HybridOp::SumPool {
padding,
stride,
@@ -335,6 +395,57 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
*kernel_shape,
*normalized,
)?,
HybridOp::Recip {
input_scale,
output_scale,
use_range_check_for_int,
} => {
if input_scale.0.fract() == 0.0
&& output_scale.0.fract() == 0.0
&& *use_range_check_for_int
{
layouts::recip(
config,
region,
values[..].try_into()?,
i128_to_felt(input_scale.0 as i128),
i128_to_felt(output_scale.0 as i128),
)?
} else {
layouts::nonlinearity(
config,
region,
values.try_into()?,
&LookupOp::Recip {
input_scale: *input_scale,
output_scale: *output_scale,
},
)?
}
}
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
if denom.0.fract() == 0.0 && *use_range_check_for_int {
layouts::div(
config,
region,
values[..].try_into()?,
i128_to_felt(denom.0 as i128),
)?
} else {
layouts::nonlinearity(
config,
region,
values.try_into()?,
&LookupOp::Div {
denom: denom.clone(),
},
)?
}
}
HybridOp::Gather { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather(values[0].get_inner_tensor()?, idx, *dim)?.into()
@@ -422,86 +533,12 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,
HybridOp::Softmax { .. } => 2 * in_scales[0],
HybridOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.0 as f64),
_ => in_scales[0],
};
Ok(scale)
}
fn required_lookups(&self) -> Vec<LookupOp> {
match self {
HybridOp::ReduceMax { .. }
| HybridOp::ReduceMin { .. }
| HybridOp::MaxPool2d { .. } => Op::<F>::required_lookups(&LookupOp::ReLU),
HybridOp::Softmax { scale, .. } => {
vec![
LookupOp::Exp { scale: *scale },
LookupOp::Recip {
scale: scale.0.powf(2.0).into(),
},
]
}
HybridOp::RangeCheck(tol) => {
let mut lookups = vec![];
if tol.val > 0.0 {
let scale_squared = tol.scale.0.powf(2.0);
lookups.extend([
LookupOp::Recip {
scale: scale_squared.into(),
},
LookupOp::GreaterThan {
a: circuit::utils::F32((tol.val * scale_squared) / 100.0),
},
]);
}
lookups
}
HybridOp::Greater { .. } | HybridOp::Less { .. } => {
vec![LookupOp::GreaterThan {
a: circuit::utils::F32(0.),
}]
}
HybridOp::GreaterEqual { .. } | HybridOp::LessEqual { .. } => {
vec![LookupOp::GreaterThanEqual {
a: circuit::utils::F32(0.),
}]
}
HybridOp::TopK { .. } => {
vec![
LookupOp::GreaterThan {
a: circuit::utils::F32(0.),
},
LookupOp::KroneckerDelta,
]
}
HybridOp::Gather {
constant_idx: None, ..
}
| HybridOp::OneHot { .. }
| HybridOp::GatherElements {
constant_idx: None, ..
}
| HybridOp::ScatterElements {
constant_idx: None, ..
}
| HybridOp::Equals { .. } => {
vec![LookupOp::KroneckerDelta]
}
HybridOp::ReduceArgMax { .. } | HybridOp::ReduceArgMin { .. } => {
vec![LookupOp::ReLU, LookupOp::KroneckerDelta]
}
HybridOp::SumPool {
kernel_shape,
normalized: true,
..
} => {
vec![LookupOp::Div {
denom: utils::F32((kernel_shape.0 * kernel_shape.1) as f32),
}]
}
_ => vec![],
}
}
fn clone_dyn(&self) -> Box<dyn Op<F>> {
Box::new(self.clone()) // Forward to the derive(Clone) impl
}

View File

@@ -18,8 +18,11 @@ use super::{
region::RegionCtx,
};
use crate::{
circuit::{ops::base::BaseOp, utils},
fieldutils::i128_to_felt,
circuit::{
ops::base::BaseOp,
utils::{self},
},
fieldutils::{felt_to_i128, i128_to_felt},
tensor::{
get_broadcasted_shape,
ops::{accumulated, add, mult, sub},
@@ -51,6 +54,144 @@ pub fn overflowed_len(starting_idx: usize, mut total_len: usize, column_len: usi
total_len
}
/// Div accumulated layout
pub fn div<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
div: F,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let input = value[0].clone();
let input_dims = input.dims();
let range_check_bracket = felt_to_i128(div) / 2;
let mut divisor = Tensor::from(vec![ValType::Constant(div)].into_iter());
divisor.set_visibility(&crate::graph::Visibility::Fixed);
let divisor = region.assign(&config.inputs[1], &divisor.into())?;
region.increment(divisor.len());
let is_assigned = !input.any_unknowns()? && !divisor.any_unknowns()?;
let mut claimed_output: ValTensor<F> = if is_assigned {
let input_evals = input.get_int_evals()?;
tensor::ops::nonlinearities::const_div(&input_evals.clone(), felt_to_i128(div) as f64)
.iter()
.map(|x| Ok(Value::known(i128_to_felt(*x))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
.into()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); input.len()]),
&[input.len()],
)?
.into()
};
claimed_output.reshape(input_dims)?;
let product = pairwise(
config,
region,
&[claimed_output.clone(), divisor.clone()],
BaseOp::Mult,
)?;
log::debug!("product: {:?}", product.get_int_evals()?);
let diff_with_input = pairwise(
config,
region,
&[product.clone(), input.clone()],
BaseOp::Sub,
)?;
range_check(
config,
region,
&[diff_with_input],
&(-range_check_bracket, range_check_bracket),
)?;
Ok(claimed_output)
}
/// recip accumulated layout
pub fn recip<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
input_scale: F,
output_scale: F,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let input = value[0].clone();
let input_dims = input.dims();
let range_check_bracket = felt_to_i128(output_scale * input_scale) / 2;
let mut scaled_unit =
Tensor::from(vec![ValType::Constant(output_scale * input_scale)].into_iter());
scaled_unit.set_visibility(&crate::graph::Visibility::Fixed);
let scaled_unit = region.assign(&config.inputs[1], &scaled_unit.into())?;
region.increment(scaled_unit.len());
let is_assigned = !input.any_unknowns()? && !scaled_unit.any_unknowns()?;
let mut claimed_output: ValTensor<F> = if is_assigned {
let input_evals = input.get_int_evals()?;
tensor::ops::nonlinearities::recip(
&input_evals,
felt_to_i128(input_scale) as f64,
felt_to_i128(output_scale) as f64,
)
.iter()
.map(|x| Ok(Value::known(i128_to_felt(*x))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
.into()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); input.len()]),
&[input.len()],
)?
.into()
};
claimed_output.reshape(input_dims)?;
// this is now of scale 2 * scale
let product = pairwise(
config,
region,
&[claimed_output.clone(), input.clone()],
BaseOp::Mult,
)?;
log::debug!("product: {:?}", product.get_int_evals()?);
// this is now of scale 2 * scale hence why we rescaled the unit scale
let diff_with_input = pairwise(
config,
region,
&[product.clone(), scaled_unit.clone()],
BaseOp::Sub,
)?;
log::debug!("scaled_unit: {:?}", scaled_unit.get_int_evals()?);
// debug print the diff
log::debug!("diff_with_input: {:?}", diff_with_input.get_int_evals()?);
log::debug!("range_check_bracket: {:?}", range_check_bracket);
// at most the error should be in the original unit scale's range
range_check(
config,
region,
&[diff_with_input],
&(-range_check_bracket, range_check_bracket),
)?;
Ok(claimed_output)
}
/// Dot product accumulated layout
pub fn dot<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -1837,15 +1978,6 @@ pub fn deconv<F: PrimeField + TensorType + PartialOrd + std::marker::Send + std:
)));
}
if has_bias {
let bias = &inputs[2];
if (bias.dims().len() != 1) || (bias.dims()[0] != kernel.dims()[0]) {
return Err(Box::new(TensorError::DimMismatch(
"deconv bias".to_string(),
)));
}
}
let (kernel_height, kernel_width) = (kernel.dims()[2], kernel.dims()[3]);
let null_val = ValType::Constant(F::ZERO);
@@ -2313,6 +2445,52 @@ pub fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
/// layout for range check.
pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
range: &crate::circuit::table::Range,
) -> Result<ValTensor<F>, Box<dyn Error>> {
region.add_used_range_check(*range);
// time the entire operation
let timer = instant::Instant::now();
let x = values[0].clone();
let w = region.assign(&config.lookup_input, &x)?;
let assigned_len = x.len();
let is_dummy = region.is_dummy();
if !is_dummy {
(0..assigned_len)
.map(|i| {
let (x, y, z) = config
.lookup_input
.cartesian_coord(region.linear_coord() + i);
let selector = config.range_check_selectors.get(&(range.clone(), x, y));
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
region.increment(assigned_len);
let elapsed = timer.elapsed();
trace!(
"range check {:?} layout took {:?}, row: {:?}",
range,
elapsed,
region.row()
);
Ok(w)
}
/// layout for nonlinearity check.
pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -2320,6 +2498,8 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
values: &[ValTensor<F>; 1],
nl: &LookupOp,
) -> Result<ValTensor<F>, Box<dyn Error>> {
region.add_used_lookup(nl.clone());
// time the entire operation
let timer = instant::Instant::now();
@@ -2789,7 +2969,8 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd>(
&[denom],
// we set to input scale + output_scale so the output scale is output)scale
&LookupOp::Recip {
scale: scale.0.powf(2.0).into(),
input_scale: scale,
output_scale: scale,
},
)?;
@@ -2817,19 +2998,22 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
// Calculate the difference between the expected output and actual output
let diff = pairwise(config, region, values, BaseOp::Sub)?;
let scale_squared = scale.0.powf(2.0);
// Calculate the reciprocal of the expected output tensor, scaling by double the scaling factor
let recip = nonlinearity(
config,
region,
&[values[0].clone()],
&LookupOp::Recip {
scale: scale_squared.into(),
input_scale: scale,
output_scale: scale,
},
)?;
// Multiply the difference by the recip
let product = pairwise(config, region, &[diff, recip], BaseOp::Mult)?;
let scale_squared = scale.0 * scale.0;
// Use the greater than look up table to check if the percent error is within the tolerance for upper bound
let tol = tol / 100.0;
let upper_bound = nonlinearity(

View File

@@ -3,9 +3,9 @@ use serde::{Deserialize, Serialize};
use std::error::Error;
use crate::{
circuit::{layouts, utils},
circuit::{layouts, table::Range, utils},
fieldutils::{felt_to_i128, i128_to_felt},
graph::{multiplier_to_scale, scale_to_multiplier},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorError, TensorType},
};
@@ -17,47 +17,117 @@ use halo2curves::ff::PrimeField;
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
pub enum LookupOp {
Abs,
Div { denom: utils::F32 },
Cast { scale: utils::F32 },
Div {
denom: utils::F32,
},
Cast {
scale: utils::F32,
},
ReLU,
Max { scale: utils::F32, a: utils::F32 },
Min { scale: utils::F32, a: utils::F32 },
Ceil { scale: utils::F32 },
Floor { scale: utils::F32 },
Round { scale: utils::F32 },
RoundHalfToEven { scale: utils::F32 },
Sqrt { scale: utils::F32 },
Rsqrt { scale: utils::F32 },
Recip { scale: utils::F32 },
LeakyReLU { slope: utils::F32 },
Sigmoid { scale: utils::F32 },
Ln { scale: utils::F32 },
Exp { scale: utils::F32 },
Cos { scale: utils::F32 },
ACos { scale: utils::F32 },
Cosh { scale: utils::F32 },
ACosh { scale: utils::F32 },
Sin { scale: utils::F32 },
ASin { scale: utils::F32 },
Sinh { scale: utils::F32 },
ASinh { scale: utils::F32 },
Tan { scale: utils::F32 },
ATan { scale: utils::F32 },
Tanh { scale: utils::F32 },
ATanh { scale: utils::F32 },
Erf { scale: utils::F32 },
GreaterThan { a: utils::F32 },
LessThan { a: utils::F32 },
GreaterThanEqual { a: utils::F32 },
LessThanEqual { a: utils::F32 },
Max {
scale: utils::F32,
a: utils::F32,
},
Min {
scale: utils::F32,
a: utils::F32,
},
Ceil {
scale: utils::F32,
},
Floor {
scale: utils::F32,
},
Round {
scale: utils::F32,
},
RoundHalfToEven {
scale: utils::F32,
},
Sqrt {
scale: utils::F32,
},
Rsqrt {
scale: utils::F32,
},
Recip {
input_scale: utils::F32,
output_scale: utils::F32,
},
LeakyReLU {
slope: utils::F32,
},
Sigmoid {
scale: utils::F32,
},
Ln {
scale: utils::F32,
},
Exp {
scale: utils::F32,
},
Cos {
scale: utils::F32,
},
ACos {
scale: utils::F32,
},
Cosh {
scale: utils::F32,
},
ACosh {
scale: utils::F32,
},
Sin {
scale: utils::F32,
},
ASin {
scale: utils::F32,
},
Sinh {
scale: utils::F32,
},
ASinh {
scale: utils::F32,
},
Tan {
scale: utils::F32,
},
ATan {
scale: utils::F32,
},
Tanh {
scale: utils::F32,
},
ATanh {
scale: utils::F32,
},
Erf {
scale: utils::F32,
},
GreaterThan {
a: utils::F32,
},
LessThan {
a: utils::F32,
},
GreaterThanEqual {
a: utils::F32,
},
LessThanEqual {
a: utils::F32,
},
Sign,
KroneckerDelta,
Pow { scale: utils::F32, a: utils::F32 },
Pow {
scale: utils::F32,
a: utils::F32,
},
}
impl LookupOp {
/// Returns the range of values that can be represented by the table
pub fn bit_range(max_len: usize) -> (i128, i128) {
pub fn bit_range(max_len: usize) -> Range {
let range = (max_len - 1) as f64 / 2_f64;
let range = range as i128;
(-range, range)
@@ -120,7 +190,14 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
&x,
f32::from(*scale).into(),
)),
LookupOp::Recip { scale } => Ok(tensor::ops::nonlinearities::recip(&x, scale.into())),
LookupOp::Recip {
input_scale,
output_scale,
} => Ok(tensor::ops::nonlinearities::recip(
&x,
input_scale.into(),
output_scale.into(),
)),
LookupOp::ReLU => Ok(tensor::ops::nonlinearities::leakyrelu(&x, 0_f64)),
LookupOp::LeakyReLU { slope: a } => {
@@ -173,7 +250,13 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
LookupOp::GreaterThanEqual { .. } => "GREATER_THAN_EQUAL".into(),
LookupOp::LessThan { .. } => "LESS_THAN".into(),
LookupOp::LessThanEqual { .. } => "LESS_THAN_EQUAL".into(),
LookupOp::Recip { scale, .. } => format!("RECIP(scale={})", scale),
LookupOp::Recip {
input_scale,
output_scale,
} => format!(
"RECIP(input_scale={}, output_scale={})",
input_scale, output_scale
),
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
@@ -220,12 +303,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
let in_scale = inputs_scale[0];
in_scale + multiplier_to_scale(1. / scale.0 as f64)
}
LookupOp::Recip { scale } => {
let mut out_scale = inputs_scale[0];
out_scale +=
multiplier_to_scale(scale.0 as f64 / scale_to_multiplier(out_scale).powf(2.0));
out_scale
}
LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()),
LookupOp::Sign
| LookupOp::GreaterThan { .. }
| LookupOp::LessThan { .. }
@@ -237,10 +315,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
Ok(scale)
}
fn required_lookups(&self) -> Vec<LookupOp> {
vec![self.clone()]
}
fn clone_dyn(&self) -> Box<dyn Op<F>> {
Box::new(self.clone()) // Forward to the derive(Clone) impl
}

View File

@@ -55,11 +55,6 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send +
vec![]
}
/// Returns the lookups required by the operation.
fn required_lookups(&self) -> Vec<LookupOp> {
vec![]
}
/// Returns true if the operation is an input.
fn is_input(&self) -> bool {
false

View File

@@ -33,7 +33,9 @@ pub enum PolyOp {
Sub,
Neg,
Mult,
Identity,
Identity {
out_scale: Option<crate::Scale>,
},
Reshape(Vec<usize>),
MoveAxis {
source: usize,
@@ -85,7 +87,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Resize { .. } => "RESIZE".into(),
PolyOp::Iff => "IFF".into(),
PolyOp::Einsum { equation, .. } => format!("EINSUM {}", equation),
PolyOp::Identity => "IDENTITY".into(),
PolyOp::Identity { out_scale } => {
format!("IDENTITY (out_scale={:?})", out_scale)
}
PolyOp::Reshape(shape) => format!("RESHAPE (shape={:?})", shape),
PolyOp::Flatten(_) => "FLATTEN".into(),
PolyOp::Pad(_) => "PAD".into(),
@@ -135,7 +139,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
PolyOp::Iff => tensor::ops::iff(&inputs[0], &inputs[1], &inputs[2]),
PolyOp::Einsum { equation } => tensor::ops::einsum(equation, &inputs),
PolyOp::Identity => Ok(inputs[0].clone()),
PolyOp::Identity { .. } => Ok(inputs[0].clone()),
PolyOp::Reshape(new_dims) => {
let mut t = inputs[0].clone();
t.reshape(new_dims)?;
@@ -237,7 +241,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
}
PolyOp::Neg => layouts::neg(config, region, values[..].try_into()?)?,
PolyOp::Iff => layouts::iff(config, region, values[..].try_into()?)?,
PolyOp::Einsum { equation } => layouts::einsum(config, region, &values, equation)?,
PolyOp::Einsum { equation } => layouts::einsum(config, region, values, equation)?,
PolyOp::Sum { axes } => {
layouts::sum_axes(config, region, values[..].try_into()?, axes)?
}
@@ -264,7 +268,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Mult => {
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
}
PolyOp::Identity => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
if values.len() != 1 {
@@ -290,12 +294,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
let scale = match self {
PolyOp::MultiBroadcastTo { .. } => in_scales[0],
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
PolyOp::Neg => in_scales[0],
PolyOp::MoveAxis { .. } => in_scales[0],
PolyOp::Downsample { .. } => in_scales[0],
PolyOp::Resize { .. } => in_scales[0],
PolyOp::Iff => in_scales[1],
PolyOp::Einsum { .. } => {
let mut scale = in_scales[0];
@@ -327,9 +326,8 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
output_scale
}
PolyOp::Add => {
let mut scale_a = 0;
let scale_b = in_scales[0];
scale_a += in_scales[1];
let scale_a = in_scales[0];
let scale_b = in_scales[1];
assert_eq!(scale_a, scale_b);
scale_a
}
@@ -339,26 +337,21 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
scale += in_scales[1];
scale
}
PolyOp::Identity => in_scales[0],
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
PolyOp::Pad(_) => in_scales[0],
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
PolyOp::Pack(_, _) => in_scales[0],
PolyOp::GlobalSumPool => in_scales[0],
PolyOp::Concat { axis: _ } => in_scales[0],
PolyOp::Slice { .. } => in_scales[0],
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
_ => in_scales[0],
};
Ok(scale)
}
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
if matches!(
self,
PolyOp::Add { .. } | PolyOp::Sub | PolyOp::Concat { .. }
) {
if matches!(self, PolyOp::Add { .. } | PolyOp::Sub) {
vec![0, 1]
} else if matches!(self, PolyOp::Iff) {
vec![1, 2]
} else if matches!(self, PolyOp::Concat { .. }) {
(0..100).collect()
} else {
vec![]
}

View File

@@ -1,4 +1,7 @@
use crate::tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor};
use crate::{
circuit::table::Range,
tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor},
};
use halo2_proofs::{
circuit::Region,
plonk::{Error, Selector},
@@ -7,9 +10,14 @@ use halo2curves::ff::PrimeField;
use std::{
cell::RefCell,
collections::HashSet,
sync::atomic::{AtomicUsize, Ordering},
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
};
use super::lookup::LookupOp;
/// Region error
#[derive(Debug, thiserror::Error)]
pub enum RegionError {
@@ -56,6 +64,8 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
linear_coord: usize,
num_inner_cols: usize,
total_constants: usize,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
}
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
@@ -75,6 +85,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
row,
linear_coord,
total_constants: 0,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
}
}
/// Create a new region context from a wrapped region
@@ -90,6 +102,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord,
row,
total_constants: 0,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
}
}
@@ -104,6 +118,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord,
row,
total_constants: 0,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
}
}
@@ -111,8 +127,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
pub fn new_dummy_with_constants(
row: usize,
linear_coord: usize,
constants: usize,
total_constants: usize,
num_inner_cols: usize,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
) -> RegionCtx<'a, F> {
let region = None;
RegionCtx {
@@ -120,7 +138,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants: constants,
total_constants,
used_lookups,
used_range_checks,
}
}
@@ -170,6 +190,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
let row = AtomicUsize::new(self.row());
let linear_coord = AtomicUsize::new(self.linear_coord());
let constants = AtomicUsize::new(self.total_constants());
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
*output = output
.par_enum_map(|idx, _| {
@@ -177,12 +199,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
let starting_offset = row.load(Ordering::SeqCst);
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
let starting_constants = constants.load(Ordering::SeqCst);
// get inner value of the locked lookups
// we need to make sure that the region is not shared between threads
let mut local_reg = Self::new_dummy_with_constants(
starting_offset,
starting_linear_coord,
starting_constants,
self.num_inner_cols,
HashSet::new(),
HashSet::new(),
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
@@ -195,6 +221,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
local_reg.total_constants() - starting_constants,
Ordering::SeqCst,
);
// update the lookups
let mut lookups = lookups.lock().unwrap();
lookups.extend(local_reg.used_lookups());
let mut range_checks = range_checks.lock().unwrap();
range_checks.extend(local_reg.used_range_checks());
res
})
.map_err(|e| {
@@ -204,6 +235,21 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.total_constants = constants.into_inner();
self.linear_coord = linear_coord.into_inner();
self.row = row.into_inner();
self.used_lookups = Arc::try_unwrap(lookups)
.map_err(|e| RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e)))?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e))
})?;
self.used_range_checks = Arc::try_unwrap(range_checks)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
})?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
})?;
Ok(())
}
@@ -212,15 +258,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.region.is_none()
}
/// duplicate_dummy
pub fn duplicate_dummy(&self) -> Self {
Self {
region: None,
linear_coord: self.linear_coord,
num_inner_cols: self.num_inner_cols,
row: self.row,
total_constants: self.total_constants,
}
/// add used lookup
pub fn add_used_lookup(&mut self, lookup: LookupOp) {
self.used_lookups.insert(lookup);
}
/// add used range check
pub fn add_used_range_check(&mut self, range: Range) {
self.used_range_checks.insert(range);
}
/// Get the offset
@@ -238,6 +283,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.total_constants
}
/// get used lookups
pub fn used_lookups(&self) -> HashSet<LookupOp> {
self.used_lookups.clone()
}
/// get used range checks
pub fn used_range_checks(&self) -> HashSet<Range> {
self.used_range_checks.clone()
}
/// Assign a constant value
pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
self.total_constants += 1;

View File

@@ -19,6 +19,9 @@ use crate::circuit::lookup::LookupOp;
use super::Op;
/// The range of the lookup table.
pub type Range = (i128, i128);
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i128 = 2;
/// The safety factor offset for the number of rows in the lookup table.
@@ -91,7 +94,7 @@ pub struct Table<F: PrimeField> {
/// Flags if table has been previously assigned to.
pub is_assigned: bool,
/// Number of bits used in lookup table.
pub range: (i128, i128),
pub range: Range,
_marker: PhantomData<F>,
}
@@ -129,7 +132,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
}
///
pub fn num_cols_required(range: (i128, i128), col_size: usize) -> usize {
pub fn num_cols_required(range: Range, col_size: usize) -> usize {
// double it to be safe
let range_len = range.1 - range.0;
// number of cols needed to store the range
@@ -141,7 +144,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
/// Configures the table.
pub fn configure(
cs: &mut ConstraintSystem<F>,
range: (i128, i128),
range: Range,
logrows: usize,
nonlinearity: &LookupOp,
preexisting_inputs: Option<Vec<TableColumn>>,
@@ -257,3 +260,86 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
Ok(())
}
}
/// Halo2 range check column
#[derive(Clone, Debug)]
pub struct RangeCheck<F: PrimeField> {
/// Input to table.
pub input: TableColumn,
/// selector cn
pub selector_constructor: SelectorConstructor<F>,
/// Flags if table has been previously assigned to.
pub is_assigned: bool,
/// Number of bits used in lookup table.
pub range: Range,
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
/// get first_element of column
pub fn get_first_element(&self) -> F {
i128_to_felt(self.range.0)
}
///
pub fn cal_col_size(logrows: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(logrows as u32) - reserved_blinding_rows
}
///
pub fn cal_bit_range(bits: usize, reserved_blinding_rows: usize) -> usize {
2usize.pow(bits as u32) - reserved_blinding_rows
}
}
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
/// Configures the table.
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range) -> RangeCheck<F> {
log::debug!("range check range: {:?}", range);
let inputs = cs.lookup_table_column();
RangeCheck {
input: inputs,
is_assigned: false,
selector_constructor: SelectorConstructor::new(2),
range,
_marker: PhantomData,
}
}
/// Assigns values to the constraints generated when calling `configure`.
pub fn layout(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
if self.is_assigned {
return Err(Box::new(CircuitError::TableAlreadyAssigned));
}
let smallest = self.range.0;
let largest = self.range.1;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
self.is_assigned = true;
layouter.assign_table(
|| "range check table",
|mut table| {
let _ = inputs
.iter()
.enumerate()
.map(|(row_offset, input)| {
table.assign_cell(
|| format!("rc_i_col row {}", row_offset),
self.input,
row_offset,
|| Value::known(*input),
)?;
Ok(())
})
.collect::<Result<Vec<()>, halo2_proofs::plonk::Error>>()?;
Ok(())
},
)?;
Ok(())
}
}

View File

@@ -90,7 +90,7 @@ mod matmul {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -165,7 +165,7 @@ mod matmul_col_overflow_double_col {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -239,7 +239,7 @@ mod matmul_col_overflow {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -327,7 +327,7 @@ mod matmul_col_ultra_overflow_double_col {
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
MatmulCircuit<F>,
>(&circuit, &params)
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
@@ -441,7 +441,7 @@ mod matmul_col_ultra_overflow {
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
MatmulCircuit<F>,
>(&circuit, &params)
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
@@ -543,7 +543,7 @@ mod dot {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -620,7 +620,7 @@ mod dot_col_overflow_triple_col {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -693,7 +693,7 @@ mod dot_col_overflow {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -762,7 +762,7 @@ mod sum {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -832,7 +832,7 @@ mod sum_col_overflow_double_col {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -901,7 +901,7 @@ mod sum_col_overflow {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -994,7 +994,7 @@ mod composition {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -1095,7 +1095,7 @@ mod conv {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
#[test]
@@ -1133,7 +1133,7 @@ mod conv {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -1240,7 +1240,7 @@ mod conv_col_ultra_overflow {
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
ConvCircuit<F>,
>(&circuit, &params)
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
@@ -1390,7 +1390,7 @@ mod conv_relu_col_ultra_overflow {
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
ConvCircuit<F>,
>(&circuit, &params)
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
@@ -1484,7 +1484,7 @@ mod add_w_shape_casting {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -1551,7 +1551,7 @@ mod add {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -1618,7 +1618,7 @@ mod add_with_overflow {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -1727,7 +1727,7 @@ mod add_with_overflow_and_poseidon {
let prover =
MockProver::run(K as u32, &circuit, vec![vec![commitment_a, commitment_b]]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
#[test]
@@ -1822,7 +1822,7 @@ mod sub {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -1889,7 +1889,7 @@ mod mult {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -1954,7 +1954,7 @@ mod pow {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -2023,7 +2023,7 @@ mod pack {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -2116,7 +2116,7 @@ mod matmul_relu {
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -2154,7 +2154,7 @@ mod rangecheckpercent {
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let scale = utils::F32(SCALE.pow(2) as f32);
let scale = utils::F32(SCALE as f32);
let a = VarTensor::new_advice(cs, K, 1, LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN);
@@ -2162,11 +2162,12 @@ mod rangecheckpercent {
Self::Config::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE);
// set up a new GreaterThan and Recip tables
let nl = &LookupOp::GreaterThan {
a: circuit::utils::F32((RANGE * scale.0) / 100.0),
a: circuit::utils::F32((RANGE * SCALE.pow(2) as f32) / 100.0),
};
config
.configure_lookup(cs, &b, &output, &a, (-32768, 32768), K, nl)
.unwrap();
config
.configure_lookup(
cs,
@@ -2175,7 +2176,10 @@ mod rangecheckpercent {
&a,
(-32768, 32768),
K,
&LookupOp::Recip { scale },
&LookupOp::Recip {
input_scale: scale,
output_scale: scale,
},
)
.unwrap();
config
@@ -2222,7 +2226,7 @@ mod rangecheckpercent {
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
{
let inp = Tensor::new(Some(&[Value::<F>::known(F::from(200_u64))]), &[1]).unwrap();
@@ -2233,7 +2237,7 @@ mod rangecheckpercent {
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
// Unsuccessful case
@@ -2328,7 +2332,7 @@ mod relu {
};
let prover = MockProver::run(4_u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}
@@ -2421,7 +2425,7 @@ mod lookup_ultra_overflow {
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
ReLUCircuit<F>,
>(&circuit, &params)
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
@@ -2511,7 +2515,8 @@ mod softmax {
(-32768, 32768),
K,
&LookupOp::Recip {
scale: SCALE.powf(2.0).into(),
input_scale: SCALE.into(),
output_scale: SCALE.into(),
},
)
.unwrap();
@@ -2557,6 +2562,6 @@ mod softmax {
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied_par();
prover.assert_satisfied();
}
}

View File

@@ -59,6 +59,8 @@ pub const DEFAULT_SOL_CODE_DA: &str = "evm_deploy_da.sol";
pub const DEFAULT_CONTRACT_ADDRESS: &str = "contract.address";
/// Default contract address for data attestation
pub const DEFAULT_CONTRACT_ADDRESS_DA: &str = "contract_da.address";
/// Default contract address for vk
pub const DEFAULT_CONTRACT_ADDRESS_VK: &str = "contract_vk.address";
/// Default check mode
pub const DEFAULT_CHECKMODE: &str = "safe";
/// Default calibration target
@@ -73,6 +75,16 @@ pub const DEFAULT_FUZZ_RUNS: &str = "10";
pub const DEFAULT_CALIBRATION_FILE: &str = "calibration.json";
/// Default lookup safety margin
pub const DEFAULT_LOOKUP_SAFETY_MARGIN: &str = "2";
/// Default Compress selectors
pub const DEFAULT_COMPRESS_SELECTORS: &str = "false";
/// Default render vk seperately
pub const DEFAULT_RENDER_VK_SEPERATELY: &str = "false";
/// Default VK sol path
pub const DEFAULT_VK_SOL: &str = "vk.sol";
/// Default VK abi path
pub const DEFAULT_VK_ABI: &str = "vk.abi";
/// Default scale rebase multipliers for calibration
pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
impl std::fmt::Display for TranscriptType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -313,9 +325,20 @@ pub enum Commands {
/// Optional scales to specifically try for calibration. Example, --scales 0,4
#[arg(long, value_delimiter = ',', allow_hyphen_values = true)]
scales: Option<Vec<crate::Scale>>,
/// Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale. Example, --scale-rebase-multipliers 0,4
#[arg(
long,
value_delimiter = ',',
allow_hyphen_values = true,
default_value = DEFAULT_SCALE_REBASE_MULTIPLIERS
)]
scale_rebase_multiplier: Vec<u32>,
/// max logrows to use for calibration, 26 is the max public SRS size
#[arg(long)]
max_logrows: Option<u32>,
// whether to fix the div_rebasing value truthiness during calibration. this changes how we rebase
#[arg(long)]
div_rebasing: Option<bool>,
},
/// Generates a dummy SRS
@@ -389,6 +412,9 @@ pub enum Commands {
/// whether the accumulated are segments of a larger proof
#[arg(long, default_value = DEFAULT_SPLIT)]
split_proofs: bool,
/// compress selectors
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
compress_selectors: bool,
},
/// Aggregates proofs :)
Aggregate {
@@ -451,6 +477,9 @@ pub enum Commands {
/// The graph witness (optional - used to override fixed values in the circuit)
#[arg(short = 'W', long)]
witness: Option<PathBuf>,
/// compress selectors
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
compress_selectors: bool,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -473,6 +502,9 @@ pub enum Commands {
/// number of fuzz iterations
#[arg(long, default_value = DEFAULT_FUZZ_RUNS)]
num_runs: usize,
/// compress selectors
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
compress_selectors: bool,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
@@ -573,6 +605,31 @@ pub enum Commands {
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VERIFIER_ABI)]
abi_path: PathBuf,
/// Whether the verifier key should be rendered as a separate contract.
/// We recommend disabling selector compression if this is enabled.
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY)]
render_vk_seperately: bool,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an EVM verifier for a single proof
#[command(name = "create-evm-vk")]
CreateEVMVK {
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
settings_path: PathBuf,
/// The path to load the desired verification key file
#[arg(long, default_value = DEFAULT_VK)]
vk_path: PathBuf,
/// The path to output the Solidity code
#[arg(long, default_value = DEFAULT_VK_SOL)]
sol_code_path: PathBuf,
/// The path to output the Solidity verifier ABI
#[arg(long, default_value = DEFAULT_VK_ABI)]
abi_path: PathBuf,
},
#[cfg(not(target_arch = "wasm32"))]
/// Creates an EVM verifier that attests to on-chain inputs for a single proof
@@ -618,6 +675,11 @@ pub enum Commands {
// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
logrows: u32,
/// Whether the verifier key should be rendered as a separate contract.
/// We recommend disabling selector compression if this is enabled.
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY)]
render_vk_seperately: bool,
},
/// Verifies a proof, returning accept or reject
Verify {
@@ -669,6 +731,25 @@ pub enum Commands {
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that is generated by ezkl
DeployEvmVK {
/// The path to the Solidity code (generated using the create-evm-verifier command)
#[arg(long, default_value = DEFAULT_VK_SOL)]
sol_code_path: PathBuf,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long)]
rpc_url: Option<String>,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_VK)]
/// The path to output the contract address
addr_path: PathBuf,
/// The optimizer runs to set on the verifier. Lower values optimize for deployment cost, while higher values optimize for gas cost.
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)]
optimizer_runs: usize,
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
#[arg(short = 'P', long)]
private_key: Option<String>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that allows for data attestation
#[command(name = "deploy-evm-da")]
DeployEvmDataAttestation {
@@ -710,6 +791,9 @@ pub enum Commands {
/// does the verifier use data attestation ?
#[arg(long)]
addr_da: Option<H160>,
// is the vk rendered seperately, if so specify an address
#[arg(long)]
addr_vk: Option<H160>,
},
/// Print the proof in hexadecimal

View File

@@ -101,17 +101,18 @@ pub async fn setup_eth_backend(
}
///
pub async fn deploy_verifier_via_solidity(
pub async fn deploy_contract_via_solidity(
sol_code_path: PathBuf,
rpc_url: Option<&str>,
runs: usize,
private_key: Option<&str>,
contract_name: &str,
) -> Result<ethers::types::Address, Box<dyn Error>> {
// anvil instance must be alive at least until the factory completes the deploy
let (anvil, client) = setup_eth_backend(rpc_url, private_key).await?;
let (abi, bytecode, runtime_bytecode) =
get_contract_artifacts(sol_code_path, "Halo2Verifier", runs)?;
get_contract_artifacts(sol_code_path, contract_name, runs)?;
let factory = get_sol_contract_factory(abi, bytecode, runtime_bytecode, client.clone())?;
let contract = factory.deploy(())?.send().await?;
@@ -335,11 +336,16 @@ pub async fn update_account_calls(
pub async fn verify_proof_via_solidity(
proof: Snark<Fr, G1Affine>,
addr: ethers::types::Address,
addr_vk: Option<H160>,
rpc_url: Option<&str>,
) -> Result<bool, Box<dyn Error>> {
let flattened_instances = proof.instances.into_iter().flatten();
let encoded = encode_calldata(None, &proof.proof, &flattened_instances.collect::<Vec<_>>());
let encoded = encode_calldata(
addr_vk.as_ref().map(|x| x.0),
&proof.proof,
&flattened_instances.collect::<Vec<_>>(),
);
info!("encoded: {:#?}", hex::encode(&encoded));
let (anvil, client) = setup_eth_backend(rpc_url, None).await?;
@@ -439,6 +445,7 @@ pub async fn verify_proof_with_data_attestation(
proof: Snark<Fr, G1Affine>,
addr_verifier: ethers::types::Address,
addr_da: ethers::types::Address,
addr_vk: Option<H160>,
rpc_url: Option<&str>,
) -> Result<bool, Box<dyn Error>> {
use ethers::abi::{Function, Param, ParamType, StateMutability, Token};
@@ -452,8 +459,11 @@ pub async fn verify_proof_with_data_attestation(
public_inputs.push(u);
}
let encoded_verifier =
encode_calldata(None, &proof.proof, &flattened_instances.collect::<Vec<_>>());
let encoded_verifier = encode_calldata(
addr_vk.as_ref().map(|x| x.0),
&proof.proof,
&flattened_instances.collect::<Vec<_>>(),
);
info!("encoded: {:#?}", hex::encode(&encoded_verifier));

View File

@@ -3,7 +3,7 @@ use crate::circuit::CheckMode;
use crate::commands::CalibrationTarget;
use crate::commands::Commands;
#[cfg(not(target_arch = "wasm32"))]
use crate::eth::{deploy_da_verifier_via_solidity, deploy_verifier_via_solidity};
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity};
#[cfg(not(target_arch = "wasm32"))]
#[allow(unused_imports)]
use crate::eth::{fix_da_sol, get_contract_artifacts, verify_proof_via_solidity};
@@ -140,8 +140,14 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
compiled_circuit,
transcript,
num_runs,
} => fuzz(compiled_circuit, witness, transcript, num_runs),
compress_selectors,
} => fuzz(
compiled_circuit,
witness,
transcript,
num_runs,
compress_selectors,
),
Commands::GenSrs { srs_path, logrows } => gen_srs_cmd(srs_path, logrows as u32),
#[cfg(not(target_arch = "wasm32"))]
Commands::GetSrs {
@@ -170,7 +176,9 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
target,
lookup_safety_margin,
scales,
scale_rebase_multiplier,
max_logrows,
div_rebasing,
} => calibrate(
model,
data,
@@ -178,6 +186,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
target,
lookup_safety_margin,
scales,
scale_rebase_multiplier,
div_rebasing,
max_logrows,
)
.map(|e| serde_json::to_string(&e).unwrap()),
@@ -198,7 +208,22 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
settings_path,
sol_code_path,
abi_path,
} => create_evm_verifier(vk_path, srs_path, settings_path, sol_code_path, abi_path),
render_vk_seperately,
} => create_evm_verifier(
vk_path,
srs_path,
settings_path,
sol_code_path,
abi_path,
render_vk_seperately,
),
Commands::CreateEVMVK {
vk_path,
srs_path,
settings_path,
sol_code_path,
abi_path,
} => create_evm_vk(vk_path, srs_path, settings_path, sol_code_path, abi_path),
#[cfg(not(target_arch = "wasm32"))]
Commands::CreateEVMDataAttestation {
settings_path,
@@ -214,6 +239,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
abi_path,
aggregation_settings,
logrows,
render_vk_seperately,
} => create_evm_aggregate_verifier(
vk_path,
srs_path,
@@ -221,6 +247,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
abi_path,
aggregation_settings,
logrows,
render_vk_seperately,
),
Commands::CompileCircuit {
model,
@@ -233,7 +260,15 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
vk_path,
pk_path,
witness,
} => setup(compiled_circuit, srs_path, vk_path, pk_path, witness),
compress_selectors,
} => setup(
compiled_circuit,
srs_path,
vk_path,
pk_path,
witness,
compress_selectors,
),
#[cfg(not(target_arch = "wasm32"))]
Commands::SetupTestEVMData {
data,
@@ -296,6 +331,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
logrows,
split_proofs,
compress_selectors,
} => setup_aggregate(
sample_snarks,
vk_path,
@@ -303,6 +339,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
logrows,
split_proofs,
compress_selectors,
),
Commands::Aggregate {
proof_path,
@@ -352,6 +389,25 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
addr_path,
optimizer_runs,
private_key,
"Halo2Verifier",
)
.await
}
#[cfg(not(target_arch = "wasm32"))]
Commands::DeployEvmVK {
sol_code_path,
rpc_url,
addr_path,
optimizer_runs,
private_key,
} => {
deploy_evm(
sol_code_path,
rpc_url,
addr_path,
optimizer_runs,
private_key,
"Halo2VerifyingKey",
)
.await
}
@@ -382,7 +438,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
addr_verifier,
rpc_url,
addr_da,
} => verify_evm(proof_path, addr_verifier, rpc_url, addr_da).await,
addr_vk,
} => verify_evm(proof_path, addr_verifier, rpc_url, addr_da, addr_vk).await,
Commands::PrintProofHex { proof_path } => print_proof_hex(proof_path),
}
}
@@ -432,7 +489,7 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
let path = get_srs_path(logrows, srs_path);
let hash = sha256::digest(&std::fs::read(path.clone())?);
let hash = sha256::digest(std::fs::read(path.clone())?);
info!("SRS hash: {}", hash);
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
@@ -440,7 +497,7 @@ fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box
None => return Err(format!("SRS (k={}) hash not found in public set", logrows).into()),
};
if hash != predefined_hash.to_string() {
if hash != *predefined_hash {
// delete file
warn!("removing SRS file at {}", path.display());
std::fs::remove_file(path)?;
@@ -573,6 +630,10 @@ pub(crate) async fn gen_witness(
if let Some(output_path) = output {
serde_json::to_writer(&File::create(output_path)?, &witness)?;
}
// print the witness in debug
debug!("witness: \n {}", witness.as_json()?.to_colored_json_auto()?);
Ok(witness)
}
@@ -662,8 +723,7 @@ impl AccuracyResults {
let error = (original.clone() - calibrated.clone())?;
let abs_error = error.map(|x| x.abs());
let squared_error = error.map(|x| x.powi(2));
let percentage_error =
error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i].clone()))?;
let percentage_error = error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i]))?;
let abs_percentage_error = percentage_error.map(|x| x.abs());
errors.extend(error.into_iter());
@@ -679,29 +739,25 @@ impl AccuracyResults {
abs_percentage_errors.iter().sum::<f32>() / abs_percentage_errors.len() as f32;
let mean_error = errors.iter().sum::<f32>() / errors.len() as f32;
let median_error = errors[errors.len() / 2];
let max_error = errors
let max_error = *errors
.iter()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
.clone();
let min_error = errors
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap();
let min_error = *errors
.iter()
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
.clone();
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap();
let mean_abs_error = abs_errors.iter().sum::<f32>() / abs_errors.len() as f32;
let median_abs_error = abs_errors[abs_errors.len() / 2];
let max_abs_error = abs_errors
let max_abs_error = *abs_errors
.iter()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
.clone();
let min_abs_error = abs_errors
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap();
let min_abs_error = *abs_errors
.iter()
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
.clone();
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap();
let mean_squared_error = squared_errors.iter().sum::<f32>() / squared_errors.len() as f32;
@@ -731,6 +787,8 @@ pub(crate) fn calibrate(
target: CalibrationTarget,
lookup_safety_margin: i128,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
div_rebasing: Option<bool>,
max_logrows: Option<u32>,
) -> Result<GraphSettings, Box<dyn Error>> {
use std::collections::HashMap;
@@ -774,9 +832,13 @@ pub(crate) fn calibrate(
}
};
let mut found_params: Vec<GraphSettings> = vec![];
let div_rebasing = if let Some(div_rebasing) = div_rebasing {
vec![div_rebasing]
} else {
vec![true, false]
};
let scale_rebase_multiplier = [1, 2, 10];
let mut found_params: Vec<GraphSettings> = vec![];
// 2 x 2 grid
let range_grid = range
@@ -813,18 +875,22 @@ pub(crate) fn calibrate(
.map(|(a, b)| (*a, *b))
.collect::<Vec<((crate::Scale, crate::Scale), u32)>>();
let range_grid = range_grid
.iter()
.cartesian_product(div_rebasing.iter())
.map(|(a, b)| (*a, *b))
.collect::<Vec<(((crate::Scale, crate::Scale), u32), bool)>>();
let mut forward_pass_res = HashMap::new();
let pb = init_bar(range_grid.len() as u64);
pb.set_message("calibrating...");
for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid {
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
pb.set_message(format!(
"input scale: {}, param scale: {}, scale rebase multiplier: {}",
input_scale, param_scale, scale_rebase_multiplier
"input scale: {}, param scale: {}, scale rebase multiplier: {}, div rebasing: {}",
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
));
// vec of settings copied chunks.len() times
let run_args_iterable = vec![settings.run_args.clone(); chunks.len()];
#[cfg(unix)]
let _r = match Gag::stdout() {
@@ -836,41 +902,42 @@ pub(crate) fn calibrate(
Ok(r) => Some(r),
Err(_) => None,
};
let key = (input_scale, param_scale, scale_rebase_multiplier);
forward_pass_res.insert(key, vec![]);
let tasks = chunks
let local_run_args = RunArgs {
input_scale,
param_scale,
scale_rebase_multiplier,
div_rebasing,
..settings.run_args.clone()
};
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
Err(e) => {
// drop the gag
#[cfg(unix)]
std::mem::drop(_r);
#[cfg(unix)]
std::mem::drop(_q);
debug!("circuit creation from run args failed: {:?}", e);
continue;
}
};
chunks
.iter()
.zip(run_args_iterable)
.map(|(chunk, run_args)| {
// we need to create a new run args for each chunk
// time it
.map(|chunk| {
let chunk = chunk.clone();
let local_run_args = RunArgs {
input_scale,
param_scale,
scale_rebase_multiplier,
..run_args.clone()
};
let original_settings = settings.clone();
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
Err(_) => {
return Err(format!("failed to create circuit from run args"))
as Result<GraphSettings, String>
}
};
let data = circuit
.load_graph_from_file_exclusively(&chunk)
.map_err(|e| format!("failed to load circuit inputs: {}", e))?;
let forward_res = circuit
.calibrate(&data, max_logrows, lookup_safety_margin)
.map_err(|e| format!("failed to calibrate: {}", e))?;
.forward(&mut data.clone(), None, None)
.map_err(|e| format!("failed to forward: {}", e))?;
// push result to the hashmap
forward_pass_res
@@ -878,38 +945,32 @@ pub(crate) fn calibrate(
.ok_or("key not found")?
.push(forward_res);
let settings = circuit.settings().clone();
let found_run_args = RunArgs {
input_scale: settings.run_args.input_scale,
param_scale: settings.run_args.param_scale,
lookup_range: settings.run_args.lookup_range,
logrows: settings.run_args.logrows,
scale_rebase_multiplier: settings.run_args.scale_rebase_multiplier,
..run_args.clone()
};
let found_settings = GraphSettings {
run_args: found_run_args,
required_lookups: settings.required_lookups,
model_output_scales: settings.model_output_scales,
model_input_scales: settings.model_input_scales,
num_rows: settings.num_rows,
total_assignments: settings.total_assignments,
total_const_size: settings.total_const_size,
..original_settings.clone()
};
Ok(found_settings) as Result<GraphSettings, String>
Ok(()) as Result<(), String>
})
.collect::<Vec<Result<GraphSettings, String>>>();
.collect::<Result<Vec<()>, String>>()?;
let mut res: Vec<GraphSettings> = vec![];
for task in tasks {
if let Ok(task) = task {
res.push(task);
}
}
let min_lookup_range = forward_pass_res
.get(&key)
.unwrap()
.iter()
.map(|x| x.min_lookup_inputs)
.min()
.unwrap_or(0);
let max_lookup_range = forward_pass_res
.get(&key)
.unwrap()
.iter()
.map(|x| x.max_lookup_inputs)
.max()
.unwrap_or(0);
let res = circuit.calibrate_from_min_max(
min_lookup_range,
max_lookup_range,
max_logrows,
lookup_safety_margin,
);
// drop the gag
#[cfg(unix)]
@@ -917,31 +978,39 @@ pub(crate) fn calibrate(
#[cfg(unix)]
std::mem::drop(_q);
let max_lookup_range = res
.iter()
.map(|x| x.run_args.lookup_range.1)
.max()
.unwrap_or(0);
let min_lookup_range = res
.iter()
.map(|x| x.run_args.lookup_range.0)
.min()
.unwrap_or(0);
if res.is_ok() {
let new_settings = circuit.settings().clone();
let found_run_args = RunArgs {
input_scale: new_settings.run_args.input_scale,
param_scale: new_settings.run_args.param_scale,
div_rebasing: new_settings.run_args.div_rebasing,
lookup_range: new_settings.run_args.lookup_range,
logrows: new_settings.run_args.logrows,
scale_rebase_multiplier: new_settings.run_args.scale_rebase_multiplier,
..settings.run_args.clone()
};
let found_settings = GraphSettings {
run_args: found_run_args,
required_lookups: new_settings.required_lookups,
required_range_checks: new_settings.required_range_checks,
model_output_scales: new_settings.model_output_scales,
model_input_scales: new_settings.model_input_scales,
num_rows: new_settings.num_rows,
total_assignments: new_settings.total_assignments,
total_const_size: new_settings.total_const_size,
..settings.clone()
};
found_params.push(found_settings.clone());
if let Some(mut best) = res.into_iter().max_by_key(|p| {
(
p.run_args.logrows,
p.run_args.input_scale,
p.run_args.param_scale,
)
}) {
best.run_args.lookup_range = (min_lookup_range, max_lookup_range);
// pick the one with the largest logrows
found_params.push(best.clone());
debug!(
"found settings: \n {}",
best.as_json()?.to_colored_json_auto()?
found_settings.as_json()?.to_colored_json_auto()?
);
} else {
debug!("calibration failed {}", res.err().unwrap());
}
pb.inc(1);
@@ -1034,7 +1103,7 @@ pub(crate) fn calibrate(
let tear_sheet_table = Table::new(vec![accuracy_res]);
println!(
warn!(
"\n\n <------------- Numerical Fidelity Report (input_scale: {}, param_scale: {}, scale_input_multiplier: {}) ------------->\n\n{}\n\n",
best_params.run_args.input_scale,
best_params.run_args.param_scale,
@@ -1098,7 +1167,7 @@ pub(crate) fn mock(
)
.map_err(Box::<dyn Error>::from)?;
prover
.verify_par()
.verify()
.map_err(|e| Box::<dyn Error>::from(ExecutionError::VerifyError(e)))?;
Ok(String::new())
}
@@ -1143,6 +1212,7 @@ pub(crate) fn create_evm_verifier(
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
render_vk_seperately: bool,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let circuit_settings = GraphSettings::load(&settings_path)?;
@@ -1160,7 +1230,11 @@ pub(crate) fn create_evm_verifier(
halo2_solidity_verifier::BatchOpenScheme::Bdfg21,
num_instance,
);
let verifier_solidity = generator.render()?;
let verifier_solidity = if render_vk_seperately {
generator.render_separately()?.0 // ignore the rendered vk for now and generate it in create_evm_vk
} else {
generator.render()?
};
File::create(sol_code_path.clone())?.write_all(verifier_solidity.as_bytes())?;
@@ -1172,6 +1246,43 @@ pub(crate) fn create_evm_verifier(
Ok(String::new())
}
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn create_evm_vk(
vk_path: PathBuf,
srs_path: Option<PathBuf>,
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let circuit_settings = GraphSettings::load(&settings_path)?;
let params = load_params_cmd(srs_path, circuit_settings.run_args.logrows)?;
let num_instance = circuit_settings.total_instances();
let num_instance: usize = num_instance.iter().sum::<usize>();
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk_path, circuit_settings)?;
trace!("params computed");
let generator = halo2_solidity_verifier::SolidityGenerator::new(
&params,
&vk,
halo2_solidity_verifier::BatchOpenScheme::Bdfg21,
num_instance,
);
let vk_solidity = generator.render_separately()?.1;
File::create(sol_code_path.clone())?.write_all(vk_solidity.as_bytes())?;
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2VerifyingKey", 0)?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
Ok(String::new())
}
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn create_evm_data_attestation(
settings_path: PathBuf,
@@ -1267,13 +1378,15 @@ pub(crate) async fn deploy_evm(
addr_path: PathBuf,
runs: usize,
private_key: Option<String>,
contract_name: &str,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let contract_address = deploy_verifier_via_solidity(
let contract_address = deploy_contract_via_solidity(
sol_code_path,
rpc_url.as_deref(),
runs,
private_key.as_deref(),
contract_name,
)
.await?;
@@ -1290,6 +1403,7 @@ pub(crate) async fn verify_evm(
addr_verifier: H160,
rpc_url: Option<String>,
addr_da: Option<H160>,
addr_vk: Option<H160>,
) -> Result<String, Box<dyn Error>> {
use crate::eth::verify_proof_with_data_attestation;
check_solc_requirement();
@@ -1301,11 +1415,12 @@ pub(crate) async fn verify_evm(
proof.clone(),
addr_verifier,
addr_da,
addr_vk,
rpc_url.as_deref(),
)
.await?
} else {
verify_proof_via_solidity(proof.clone(), addr_verifier, rpc_url.as_deref()).await?
verify_proof_via_solidity(proof.clone(), addr_verifier, addr_vk, rpc_url.as_deref()).await?
};
info!("Solidity verification result: {}", result);
@@ -1325,6 +1440,7 @@ pub(crate) fn create_evm_aggregate_verifier(
abi_path: PathBuf,
circuit_settings: Vec<PathBuf>,
logrows: u32,
render_vk_seperately: bool,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let srs_path = get_srs_path(logrows, srs_path);
@@ -1363,7 +1479,11 @@ pub(crate) fn create_evm_aggregate_verifier(
generator = generator.set_acc_encoding(Some(acc_encoding));
let verifier_solidity = generator.render()?;
let verifier_solidity = if render_vk_seperately {
generator.render_separately()?.0 // ignore the rendered vk for now and generate it in create_evm_vk
} else {
generator.render()?
};
File::create(sol_code_path.clone())?.write_all(verifier_solidity.as_bytes())?;
@@ -1392,6 +1512,7 @@ pub(crate) fn setup(
vk_path: PathBuf,
pk_path: PathBuf,
witness: Option<PathBuf>,
compress_selectors: bool,
) -> Result<String, Box<dyn Error>> {
// these aren't real values so the sanity checks are mostly meaningless
let mut circuit = GraphCircuit::load(compiled_circuit)?;
@@ -1402,8 +1523,12 @@ pub(crate) fn setup(
let params = load_params_cmd(srs_path, circuit.settings().run_args.logrows)?;
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(&circuit, &params)
.map_err(Box::<dyn Error>::from)?;
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&params,
compress_selectors,
)
.map_err(Box::<dyn Error>::from)?;
save_vk::<KZGCommitmentScheme<Bn256>>(&vk_path, pk.get_vk())?;
save_pk::<KZGCommitmentScheme<Bn256>>(&pk_path, &pk)?;
@@ -1542,6 +1667,7 @@ pub(crate) fn fuzz(
data_path: PathBuf,
transcript: TranscriptType,
num_runs: usize,
compress_selectors: bool,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let passed = AtomicBool::new(true);
@@ -1557,8 +1683,12 @@ pub(crate) fn fuzz(
let data = GraphWitness::from_path(data_path)?;
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(&circuit, &params)
.map_err(Box::<dyn Error>::from)?;
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&params,
compress_selectors,
)
.map_err(Box::<dyn Error>::from)?;
circuit.load_graph_witness(&data)?;
@@ -1574,9 +1704,12 @@ pub(crate) fn fuzz(
let fuzz_pk = || {
let new_params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
let bad_pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(&circuit, &new_params)
.map_err(|_| ())?;
let bad_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&new_params,
compress_selectors,
)
.map_err(|_| ())?;
let bad_proof = create_proof_circuit_kzg(
circuit.clone(),
@@ -1647,9 +1780,12 @@ pub(crate) fn fuzz(
let fuzz_vk = || {
let new_params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
let bad_pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(&circuit, &new_params)
.map_err(|_| ())?;
let bad_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&new_params,
compress_selectors,
)
.map_err(|_| ())?;
let bad_vk = bad_pk.get_vk();
@@ -1809,7 +1945,7 @@ pub(crate) fn mock_aggregate(
let prover = halo2_proofs::dev::MockProver::run(logrows, &circuit, vec![circuit.instances()])
.map_err(Box::<dyn Error>::from)?;
prover
.verify_par()
.verify()
.map_err(|e| Box::<dyn Error>::from(ExecutionError::VerifyError(e)))?;
#[cfg(not(target_arch = "wasm32"))]
pb.finish_with_message("Done.");
@@ -1823,6 +1959,7 @@ pub(crate) fn setup_aggregate(
srs_path: Option<PathBuf>,
logrows: u32,
split_proofs: bool,
compress_selectors: bool,
) -> Result<String, Box<dyn Error>> {
// the K used for the aggregation circuit
let params = load_params_cmd(srs_path, logrows)?;
@@ -1833,8 +1970,11 @@ pub(crate) fn setup_aggregate(
}
let agg_circuit = AggregationCircuit::new(&params.get_g()[0].into(), snarks, split_proofs)?;
let agg_pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(&agg_circuit, &params)?;
let agg_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(
&agg_circuit,
&params,
compress_selectors,
)?;
let agg_vk = agg_pk.get_vk();

View File

@@ -617,13 +617,13 @@ impl ToPyObject for DataSource {
}
#[cfg(feature = "python-bindings")]
use crate::pfsys::field_to_vecu64_montgomery;
use crate::pfsys::field_to_string_montgomery;
#[cfg(feature = "python-bindings")]
impl ToPyObject for FileSourceInner {
fn to_object(&self, py: Python) -> PyObject {
match self {
FileSourceInner::Field(data) => field_to_vecu64_montgomery(data).to_object(py),
FileSourceInner::Field(data) => field_to_string_montgomery(data).to_object(py),
FileSourceInner::Bool(data) => data.to_object(py),
FileSourceInner::Float(data) => data.to_object(py),
}

View File

@@ -23,7 +23,7 @@ use self::input::{FileSource, GraphData};
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::table::{Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::table::{Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::felt_to_f64;
use crate::pfsys::PrettyElements;
@@ -53,7 +53,7 @@ pub use utilities::*;
pub use vars::*;
#[cfg(feature = "python-bindings")]
use crate::pfsys::field_to_vecu64_montgomery;
use crate::pfsys::field_to_string_montgomery;
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i128 = 2;
@@ -332,16 +332,16 @@ impl ToPyObject for GraphWitness {
let dict_params = PyDict::new(py);
let dict_outputs = PyDict::new(py);
let inputs: Vec<Vec<[u64; 4]>> = self
let inputs: Vec<Vec<String>> = self
.inputs
.iter()
.map(|x| x.iter().map(field_to_vecu64_montgomery).collect())
.map(|x| x.iter().map(field_to_string_montgomery).collect())
.collect();
let outputs: Vec<Vec<[u64; 4]>> = self
let outputs: Vec<Vec<String>> = self
.outputs
.iter()
.map(|x| x.iter().map(field_to_vecu64_montgomery).collect())
.map(|x| x.iter().map(field_to_string_montgomery).collect())
.collect();
dict.set_item("inputs", inputs).unwrap();
@@ -389,9 +389,9 @@ impl ToPyObject for GraphWitness {
#[cfg(feature = "python-bindings")]
fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec<Fp>) -> Result<(), PyErr> {
let poseidon_hash: Vec<[u64; 4]> = poseidon_hash
let poseidon_hash: Vec<String> = poseidon_hash
.iter()
.map(field_to_vecu64_montgomery)
.map(field_to_string_montgomery)
.collect();
pydict.set_item("poseidon_hash", poseidon_hash)?;
@@ -431,6 +431,8 @@ pub struct GraphSettings {
pub module_sizes: ModuleSizes,
/// required_lookups
pub required_lookups: Vec<LookupOp>,
/// required range_checks
pub required_range_checks: Vec<Range>,
/// check mode
pub check_mode: CheckMode,
/// ezkl version used
@@ -639,7 +641,7 @@ impl GraphCircuit {
}
// dummy module settings, must load from GraphData after
let mut settings = model.gen_params(run_args, CheckMode::UNSAFE)?;
let mut settings = model.gen_params(run_args, run_args.check_mode)?;
let mut num_params = 0;
if !model.const_shapes().is_empty() {
@@ -763,18 +765,18 @@ impl GraphCircuit {
if self.settings().run_args.input_visibility.is_public() {
public_inputs.rescaled_inputs = elements.rescaled_inputs.clone();
public_inputs.inputs = elements.inputs.clone();
} else if let Some(_) = &data.processed_inputs {
} else if data.processed_inputs.is_some() {
public_inputs.processed_inputs = elements.processed_inputs.clone();
}
if let Some(_) = &data.processed_params {
if data.processed_params.is_some() {
public_inputs.processed_params = elements.processed_params.clone();
}
if self.settings().run_args.output_visibility.is_public() {
public_inputs.rescaled_outputs = elements.rescaled_outputs.clone();
public_inputs.outputs = elements.outputs.clone();
} else if let Some(_) = &data.processed_outputs {
} else if data.processed_outputs.is_some() {
public_inputs.processed_outputs = elements.processed_outputs.clone();
}
@@ -956,19 +958,24 @@ impl GraphCircuit {
(ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD) as f64
}
fn calc_safe_lookup_range(res: &GraphWitness, lookup_safety_margin: i128) -> (i128, i128) {
fn calc_safe_lookup_range(
min_lookup_inputs: i128,
max_lookup_inputs: i128,
lookup_safety_margin: i128,
) -> Range {
let mut margin = (
lookup_safety_margin * res.min_lookup_inputs,
lookup_safety_margin * res.max_lookup_inputs,
lookup_safety_margin * min_lookup_inputs,
lookup_safety_margin * max_lookup_inputs,
);
if lookup_safety_margin == 1 {
margin.0 -= 1;
margin.1 += 1;
margin.0 += 4;
margin.1 += 4;
}
margin
}
fn calc_num_cols(safe_range: (i128, i128), max_logrows: u32) -> usize {
fn calc_num_cols(safe_range: Range, max_logrows: u32) -> usize {
let max_col_size = Table::<Fp>::cal_col_size(
max_logrows as usize,
Self::reserved_blinding_rows() as usize,
@@ -978,7 +985,8 @@ impl GraphCircuit {
fn calc_min_logrows(
&mut self,
res: &GraphWitness,
min_lookup_inputs: i128,
max_lookup_inputs: i128,
max_logrows: Option<u32>,
lookup_safety_margin: i128,
) -> Result<(), Box<dyn std::error::Error>> {
@@ -986,19 +994,23 @@ impl GraphCircuit {
let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS);
let max_logrows = std::cmp::min(max_logrows, MAX_PUBLIC_SRS);
let mut max_logrows = std::cmp::max(max_logrows, MIN_LOGROWS);
let mut min_logrows = MIN_LOGROWS;
let reserved_blinding_rows = Self::reserved_blinding_rows();
// check if has overflowed max lookup input
if res.max_lookup_inputs > MAX_LOOKUP_ABS / lookup_safety_margin
|| res.min_lookup_inputs < -MAX_LOOKUP_ABS / lookup_safety_margin
if max_lookup_inputs > MAX_LOOKUP_ABS / lookup_safety_margin
|| min_lookup_inputs < -MAX_LOOKUP_ABS / lookup_safety_margin
{
let err_string = format!("max lookup input ({}) is too large", res.max_lookup_inputs);
error!("{}", err_string);
let err_string = format!("max lookup input ({}) is too large", max_lookup_inputs);
return Err(err_string.into());
}
let safe_range = Self::calc_safe_lookup_range(res, lookup_safety_margin);
let mut min_logrows = MIN_LOGROWS;
let safe_range = Self::calc_safe_lookup_range(
min_lookup_inputs,
max_lookup_inputs,
lookup_safety_margin,
);
// degrade the max logrows until the extended k is small enough
while min_logrows < max_logrows
&& !self.extended_k_is_small_enough(
@@ -1020,8 +1032,7 @@ impl GraphCircuit {
return Err(err_string.into());
}
// degrade the max logrows until the extended k is small enough
while max_logrows > min_logrows
while min_logrows < max_logrows
&& !self.extended_k_is_small_enough(
max_logrows,
Self::calc_num_cols(safe_range, max_logrows),
@@ -1030,6 +1041,17 @@ impl GraphCircuit {
max_logrows -= 1;
}
if !self
.extended_k_is_small_enough(max_logrows, Self::calc_num_cols(safe_range, max_logrows))
{
let err_string = format!(
"extended k is too large to accommodate the quotient polynomial with logrows {}",
max_logrows
);
error!("{}", err_string);
return Err(err_string.into());
}
let min_bits = ((safe_range.1 - safe_range.0) as f64 + reserved_blinding_rows + 1.)
.log2()
.ceil() as usize;
@@ -1111,22 +1133,31 @@ impl GraphCircuit {
// n = 2^k
let n = 1u64 << k;
let mut extended_k = k;
while (1 << extended_k) < (n * quotient_poly_degree) {
extended_k += 1;
if extended_k > bn256::Fr::S {
return false;
}
}
extended_k <= bn256::Fr::S
true
}
/// Calibrate the circuit to the supplied data.
pub fn calibrate(
pub fn calibrate_from_min_max(
&mut self,
input: &[Tensor<Fp>],
min_lookup_inputs: i128,
max_lookup_inputs: i128,
max_logrows: Option<u32>,
lookup_safety_margin: i128,
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
let res = self.forward(&mut input.to_vec(), None, None)?;
self.calc_min_logrows(&res, max_logrows, lookup_safety_margin)?;
Ok(res)
) -> Result<(), Box<dyn std::error::Error>> {
self.calc_min_logrows(
min_lookup_inputs,
max_lookup_inputs,
max_logrows,
lookup_safety_margin,
)?;
Ok(())
}
/// Runs the forward pass of the model / graph of computations and any associated hashing.
@@ -1428,6 +1459,7 @@ impl Circuit<Fp> for GraphCircuit {
params.run_args.lookup_range,
params.run_args.logrows as usize,
params.required_lookups,
params.required_range_checks,
params.check_mode,
)
.unwrap();

View File

@@ -6,6 +6,7 @@ use super::GraphError;
use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::RegionCtx;
use crate::circuit::table::Range;
use crate::circuit::Input;
use crate::circuit::InputType;
use crate::circuit::Unknown;
@@ -79,6 +80,21 @@ pub struct ModelConfig {
/// Representation of execution graph
pub type NodeGraph = BTreeMap<usize, NodeType>;
/// A struct for loading from an Onnx file and converting a computational graph to a circuit.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct DummyPassRes {
/// number of rows use
pub num_rows: usize,
/// linear coordinate
pub linear_coord: usize,
/// total const size
pub total_const_size: usize,
/// lookup ops
pub lookup_ops: HashSet<LookupOp>,
/// range checks
pub range_checks: HashSet<Range>,
}
/// A struct for loading from an Onnx file and converting a computational graph to a circuit.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
@@ -233,13 +249,7 @@ impl NodeType {
NodeType::SubGraph { out_dims, .. } => out_dims.clone(),
}
}
/// Returns the lookups required by a graph
pub fn required_lookups(&self) -> Vec<LookupOp> {
match self {
NodeType::Node(n) => n.opkind.required_lookups(),
NodeType::SubGraph { model, .. } => model.required_lookups(),
}
}
/// Returns the scales of the node's output.
pub fn out_scales(&self) -> Vec<crate::Scale> {
match self {
@@ -424,14 +434,6 @@ impl ParsedNodes {
}
impl Model {
fn required_lookups(&self) -> Vec<LookupOp> {
self.graph
.nodes
.values()
.flat_map(|n| n.required_lookups())
.collect_vec()
}
/// Creates a `Model` from a specified path to an Onnx file.
/// # Arguments
/// * `reader` - A reader for an Onnx file.
@@ -484,36 +486,21 @@ impl Model {
);
// this is the total number of variables we will need to allocate
// for the circuit
let (num_rows, linear_coord, total_const_size) =
self.dummy_layout(run_args, &self.graph.input_shapes()?)?;
// extract the requisite lookup ops from the model
let mut lookup_ops: Vec<LookupOp> = self.required_lookups();
let res = self.dummy_layout(run_args, &self.graph.input_shapes()?)?;
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
if run_args.tolerance.val > 0.0 {
for scale in self.graph.get_output_scales()? {
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(scale).into();
let opkind: Box<dyn Op<Fp>> = Box::new(HybridOp::RangeCheck(tolerance));
lookup_ops.extend(opkind.required_lookups());
}
}
let set: HashSet<_> = lookup_ops.drain(..).collect(); // dedup
lookup_ops.extend(set.into_iter().sorted());
Ok(GraphSettings {
run_args: run_args.clone(),
model_instance_shapes: instance_shapes,
module_sizes: crate::graph::modules::ModuleSizes::default(),
num_rows,
total_assignments: linear_coord,
required_lookups: lookup_ops,
num_rows: res.num_rows,
total_assignments: res.linear_coord,
required_lookups: res.lookup_ops.into_iter().collect(),
required_range_checks: res.range_checks.into_iter().collect(),
model_output_scales: self.graph.get_output_scales()?,
model_input_scales: self.graph.get_input_scales(),
total_const_size,
total_const_size: res.total_const_size,
check_mode,
version: env!("CARGO_PKG_VERSION").to_string(),
num_blinding_factors: None,
@@ -568,6 +555,8 @@ impl Model {
inputs.iter().map(|x| x.dims()).collect::<Vec<_>>()
);
debug!("input nodes: {:?}", n.inputs());
if n.is_lookup() {
let (mut min, mut max) = (0, 0);
for i in &inputs {
@@ -611,7 +600,7 @@ impl Model {
debug!("intermediate min lookup inputs: {}", min);
}
debug!(
"------------ output node int {}: {} \n ------------ float: {} \n ------------ max: {} \n ------------ min: {} ------------ scale: {}",
"------------ output node int {}: {} \n ------------ float: {} \n ------------ max: {} \n ------------ min: {} \n ------------ scale: {}",
idx,
res.output.map(crate::fieldutils::felt_to_i32).show(),
res.output
@@ -1042,6 +1031,8 @@ impl Model {
&run_args.param_visibility,
i,
symbol_values,
run_args.div_rebasing,
run_args.rebase_frac_zero_constants,
)?;
if let Some(ref scales) = override_input_scales {
if let Some(inp) = n.opkind.get_input() {
@@ -1058,9 +1049,20 @@ impl Model {
if scales.contains_key(&i) {
let scale_diff = n.out_scale - scales[&i];
n.opkind = if scale_diff > 0 {
RebaseScale::rebase(n.opkind, scales[&i], n.out_scale, 1)
RebaseScale::rebase(
n.opkind,
scales[&i],
n.out_scale,
1,
run_args.div_rebasing,
)
} else {
RebaseScale::rebase_up(n.opkind, scales[&i], n.out_scale)
RebaseScale::rebase_up(
n.opkind,
scales[&i],
n.out_scale,
run_args.div_rebasing,
)
};
n.out_scale = scales[&i];
}
@@ -1155,9 +1157,10 @@ impl Model {
pub fn configure(
meta: &mut ConstraintSystem<Fp>,
vars: &ModelVars<Fp>,
lookup_range: (i128, i128),
lookup_range: Range,
logrows: usize,
required_lookups: Vec<LookupOp>,
required_range_checks: Vec<Range>,
check_mode: CheckMode,
) -> Result<PolyConfig<Fp>, Box<dyn Error>> {
info!("configuring model");
@@ -1176,6 +1179,10 @@ impl Model {
base_gate.configure_lookup(meta, input, output, index, lookup_range, logrows, &op)?;
}
for range in required_range_checks {
base_gate.configure_range_check(meta, input, range)?;
}
Ok(base_gate)
}
@@ -1216,6 +1223,7 @@ impl Model {
let instance_idx = vars.get_instance_idx();
config.base.layout_tables(layouter)?;
config.base.layout_range_checks(layouter)?;
let mut num_rows = 0;
let mut linear_coord = 0;
@@ -1482,7 +1490,7 @@ impl Model {
&self,
run_args: &RunArgs,
input_shapes: &[Vec<usize>],
) -> Result<(usize, usize, usize), Box<dyn Error>> {
) -> Result<DummyPassRes, Box<dyn Error>> {
info!("calculating num of constraints using dummy model layout...");
let start_time = instant::Instant::now();
@@ -1567,11 +1575,15 @@ impl Model {
region.total_constants().to_string().red()
);
Ok((
region.row(),
region.linear_coord(),
region.total_constants(),
))
let res = DummyPassRes {
num_rows: region.row(),
linear_coord: region.linear_coord(),
total_const_size: region.total_constants(),
lookup_ops: region.used_lookups(),
range_checks: region.used_range_checks(),
};
Ok(res)
}
/// Retrieves all constants from the model.

View File

@@ -12,16 +12,12 @@ use crate::circuit::Constant;
use crate::circuit::Input;
use crate::circuit::Op;
use crate::circuit::Unknown;
use crate::fieldutils::felt_to_i128;
use crate::fieldutils::i128_to_felt;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::new_op_from_onnx;
use crate::tensor::Tensor;
use crate::tensor::TensorError;
use halo2curves::bn256::Fr as Fp;
#[cfg(not(target_arch = "wasm32"))]
use itertools::Itertools;
#[cfg(not(target_arch = "wasm32"))]
use log::trace;
use serde::Deserialize;
use serde::Serialize;
@@ -94,10 +90,6 @@ impl Op<Fp> for Rescaled {
Op::<Fp>::out_scale(&*self.inner, in_scales)
}
fn required_lookups(&self) -> Vec<LookupOp> {
self.inner.required_lookups()
}
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
@@ -126,12 +118,14 @@ impl Op<Fp> for Rescaled {
pub struct RebaseScale {
/// The operation that has to be rescaled.
pub inner: Box<SupportedOp>,
/// the multiplier applied to the node output
pub multiplier: f64,
/// rebase op
pub rebase_op: HybridOp,
/// scale being rebased to
pub target_scale: i32,
/// The original scale of the operation's inputs.
pub original_scale: i32,
/// multiplier
pub multiplier: f64,
}
impl RebaseScale {
@@ -141,6 +135,7 @@ impl RebaseScale {
global_scale: crate::Scale,
op_out_scale: crate::Scale,
scale_rebase_multiplier: u32,
div_rebasing: bool,
) -> SupportedOp {
if (op_out_scale > (global_scale * scale_rebase_multiplier as i32))
&& !inner.is_constant()
@@ -149,10 +144,15 @@ impl RebaseScale {
let multiplier =
scale_to_multiplier(op_out_scale - global_scale * scale_rebase_multiplier as i32);
if let Some(op) = inner.get_rebased() {
let multiplier = op.multiplier * multiplier;
SupportedOp::RebaseScale(RebaseScale {
inner: op.inner.clone(),
target_scale: op.target_scale,
multiplier: op.multiplier * multiplier,
multiplier: multiplier,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32((multiplier) as f32),
use_range_check_for_int: !div_rebasing,
},
original_scale: op.original_scale,
})
} else {
@@ -160,6 +160,10 @@ impl RebaseScale {
inner: Box::new(inner),
target_scale: global_scale * scale_rebase_multiplier as i32,
multiplier,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32(multiplier as f32),
use_range_check_for_int: !div_rebasing,
},
original_scale: op_out_scale,
})
}
@@ -173,15 +177,21 @@ impl RebaseScale {
inner: SupportedOp,
target_scale: crate::Scale,
op_out_scale: crate::Scale,
div_rebasing: bool,
) -> SupportedOp {
if (op_out_scale < (target_scale)) && !inner.is_constant() && !inner.is_input() {
let multiplier = scale_to_multiplier(op_out_scale - target_scale);
if let Some(op) = inner.get_rebased() {
let multiplier = op.multiplier * multiplier;
SupportedOp::RebaseScale(RebaseScale {
inner: op.inner.clone(),
target_scale: op.target_scale,
multiplier: op.multiplier * multiplier,
multiplier,
original_scale: op.original_scale,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32((multiplier) as f32),
use_range_check_for_int: !div_rebasing,
},
})
} else {
SupportedOp::RebaseScale(RebaseScale {
@@ -189,6 +199,10 @@ impl RebaseScale {
target_scale,
multiplier,
original_scale: op_out_scale,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32(multiplier as f32),
use_range_check_for_int: !div_rebasing,
},
})
}
} else {
@@ -203,19 +217,19 @@ impl Op<Fp> for RebaseScale {
}
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
let mut res = Op::<Fp>::f(&*self.inner, x)?;
let ri = res.output.map(felt_to_i128);
let rescaled = crate::tensor::ops::nonlinearities::const_div(&ri, self.multiplier);
res.output = rescaled.map(i128_to_felt);
res.intermediate_lookups.push(ri);
let rebase_res = Op::<Fp>::f(&self.rebase_op, &[res.output])?;
res.output = rebase_res.output;
res.intermediate_lookups
.extend(rebase_res.intermediate_lookups);
Ok(res)
}
fn as_string(&self) -> String {
format!(
"REBASED (div={:?}) ({})",
"REBASED (div={:?}, rebasing_op={}) ({})",
self.multiplier,
<HybridOp as Op<Fp>>::as_string(&self.rebase_op),
self.inner.as_string()
)
}
@@ -224,14 +238,6 @@ impl Op<Fp> for RebaseScale {
Ok(self.target_scale)
}
fn required_lookups(&self) -> Vec<LookupOp> {
let mut lookups = self.inner.required_lookups();
lookups.push(LookupOp::Div {
denom: crate::circuit::utils::F32(self.multiplier as f32),
});
lookups
}
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
@@ -241,16 +247,8 @@ impl Op<Fp> for RebaseScale {
let original_res = self
.inner
.layout(config, region, values)?
.ok_or("no layout")?;
Ok(Some(crate::circuit::layouts::nonlinearity(
config,
region,
&[original_res],
&LookupOp::Div {
denom: crate::circuit::utils::F32(self.multiplier as f32),
},
)?))
.ok_or("no inner layout")?;
self.rebase_op.layout(config, region, &[original_res])
}
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
@@ -433,10 +431,6 @@ impl Op<Fp> for SupportedOp {
self
}
fn required_lookups(&self) -> Vec<LookupOp> {
self.as_op().required_lookups()
}
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
self.as_op().out_scale(in_scales)
}
@@ -470,14 +464,7 @@ impl Tabled for Node {
fn headers() -> Vec<std::borrow::Cow<'static, str>> {
let mut headers = Vec::with_capacity(Self::LENGTH);
for i in [
"idx",
"opkind",
"out_scale",
"inputs",
"out_dims",
"required_lookups",
] {
for i in ["idx", "opkind", "out_scale", "inputs", "out_dims"] {
headers.push(std::borrow::Cow::Borrowed(i));
}
headers
@@ -490,14 +477,6 @@ impl Tabled for Node {
fields.push(std::borrow::Cow::Owned(self.out_scale.to_string()));
fields.push(std::borrow::Cow::Owned(display_vector(&self.inputs)));
fields.push(std::borrow::Cow::Owned(display_vector(&self.out_dims)));
fields.push(std::borrow::Cow::Owned(format!(
"{:?}",
self.opkind
.required_lookups()
.iter()
.map(<LookupOp as Op<Fp>>::as_string)
.collect_vec()
)));
fields
}
}
@@ -527,9 +506,9 @@ impl Node {
param_visibility: &Visibility,
idx: usize,
symbol_values: &SymbolValues,
div_rebasing: bool,
rebase_frac_zero_constants: bool,
) -> Result<Self, Box<dyn Error>> {
use log::warn;
trace!("Create {:?}", node);
trace!("Create op {:?}", node.op);
@@ -567,6 +546,7 @@ impl Node {
node.clone(),
&mut inputs,
symbol_values,
rebase_frac_zero_constants,
)?; // parses the op name
// we can only take the inputs as mutable once -- so we need to collect them first
@@ -622,8 +602,6 @@ impl Node {
input_node.bump_scale(out_scale);
in_scales[input] = out_scale;
}
} else {
warn!("input {} not found for rescaling, skipping ...", input);
}
}
@@ -631,7 +609,13 @@ impl Node {
let mut out_scale = opkind.out_scale(in_scales.clone())?;
// rescale the inputs if necessary to get consistent fixed points, we select the largest scale (highest precision)
let global_scale = scales.get_max();
opkind = RebaseScale::rebase(opkind, global_scale, out_scale, scales.rebase_multiplier);
opkind = RebaseScale::rebase(
opkind,
global_scale,
out_scale,
scales.rebase_multiplier,
div_rebasing,
);
out_scale = opkind.out_scale(in_scales)?;

Some files were not shown because too many files have changed in this diff Show More