mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
8 Commits
release-v2
...
ac/rm-pg
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fce07b4e57 | ||
|
|
84a5693b0e | ||
|
|
0de0682bfa | ||
|
|
bf9cf14ab7 | ||
|
|
6818962ac2 | ||
|
|
70469e3bf9 | ||
|
|
52ff187e55 | ||
|
|
4e57a5a486 |
4
.github/workflows/pypi.yml
vendored
4
.github/workflows/pypi.yml
vendored
@@ -258,7 +258,7 @@ jobs:
|
||||
|
||||
- name: Install built wheel
|
||||
if: matrix.target == 'x86_64-unknown-linux-musl'
|
||||
uses: addnab/docker-run-action@v3
|
||||
uses: addnab/docker-run-action@3e77f186b7a929ef010f183a9e24c0f9955ea609
|
||||
with:
|
||||
image: alpine:latest
|
||||
options: -v ${{ github.workspace }}:/io -w /io
|
||||
@@ -380,7 +380,7 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Trigger RTDs build
|
||||
uses: dfm/rtds-action@v1
|
||||
uses: dfm/rtds-action@618148c547f4b56cdf4fa4dcf3a94c91ce025f2d
|
||||
with:
|
||||
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
|
||||
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
|
||||
|
||||
10
.github/workflows/rust.yml
vendored
10
.github/workflows/rust.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
toolchain: nightly-2025-02-17
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
@@ -233,7 +233,7 @@ jobs:
|
||||
with:
|
||||
# Pin to version 0.12.1
|
||||
version: "v0.12.1"
|
||||
- uses: nanasess/setup-chromedriver@e93e57b843c0c92788f22483f1a31af8ee48db25 #v2.3.0
|
||||
- uses: nanasess/setup-chromedriver@affb1ea8848cbb080be372c1e8d7a5c173e9298f #v2.3.0
|
||||
# with:
|
||||
# chromedriver-version: "115.0.5790.102"
|
||||
- name: Install wasm32-unknown-unknown
|
||||
@@ -256,10 +256,10 @@ jobs:
|
||||
submodules: recursive
|
||||
|
||||
- name: Install Foundry
|
||||
uses: foundry-rs/foundry-toolchain@v1
|
||||
uses: foundry-rs/foundry-toolchain@3b74dacdda3c0b763089addb99ed86bc3800e68b
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
run: |
|
||||
cd tests/foundry
|
||||
forge install https://github.com/foundry-rs/forge-std --no-git --no-commit
|
||||
forge test -vvvv --fuzz-runs 64
|
||||
@@ -798,8 +798,6 @@ jobs:
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
|
||||
- name: Felt conversion
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::felt_conversion_test_ --no-capture
|
||||
- name: Postgres tutorials
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
|
||||
- name: Tictactoe tutorials
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
|
||||
# - name: authenticate-kaggle-cli
|
||||
|
||||
147
Cargo.lock
generated
147
Cargo.lock
generated
@@ -467,7 +467,7 @@ version = "0.1.0"
|
||||
source = "git+https://github.com/alloy-rs/alloy?rev=5fbf57bac99edef9d8475190109a7ea9fb7e5e83#5fbf57bac99edef9d8475190109a7ea9fb7e5e83"
|
||||
dependencies = [
|
||||
"alloy-json-rpc",
|
||||
"base64 0.22.1",
|
||||
"base64",
|
||||
"futures-util",
|
||||
"futures-utils-wasm",
|
||||
"serde",
|
||||
@@ -888,12 +888,6 @@ version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.21.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.22.1"
|
||||
@@ -915,17 +909,6 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bigdecimal"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a6773ddc0eafc0e509fb60e48dff7f450f8e674a0686ae8605e8d9901bd5eefa"
|
||||
dependencies = [
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bincode"
|
||||
version = "1.3.3"
|
||||
@@ -1967,7 +1950,6 @@ dependencies = [
|
||||
"num",
|
||||
"objc",
|
||||
"openssl",
|
||||
"pg_bigdecimal",
|
||||
"pyo3",
|
||||
"pyo3-async-runtimes",
|
||||
"pyo3-log",
|
||||
@@ -1988,7 +1970,6 @@ dependencies = [
|
||||
"test-case",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tosubcommand",
|
||||
"tract-onnx",
|
||||
"uniffi",
|
||||
@@ -2001,12 +1982,6 @@ dependencies = [
|
||||
"wasm-bindgen-test",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fallible-iterator"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.0.1"
|
||||
@@ -2058,12 +2033,6 @@ dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "finl_unicode"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6"
|
||||
|
||||
[[package]]
|
||||
name = "fixed-hash"
|
||||
version = "0.8.0"
|
||||
@@ -3862,19 +3831,6 @@ dependencies = [
|
||||
"indexmap",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pg_bigdecimal"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f9855a94c74528af62c0ea236577af5e601263c1c404a6ac939b07c97c8e0216"
|
||||
dependencies = [
|
||||
"bigdecimal",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"num",
|
||||
"postgres",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "phf"
|
||||
version = "0.11.2"
|
||||
@@ -4032,49 +3988,6 @@ dependencies = [
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "postgres"
|
||||
version = "0.19.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7915b33ed60abc46040cbcaa25ffa1c7ec240668e0477c4f3070786f5916d451"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "postgres-protocol"
|
||||
version = "0.6.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49b6c5ef183cd3ab4ba005f1ca64c21e8bd97ce4699cfea9e8d9a2c4958ca520"
|
||||
dependencies = [
|
||||
"base64 0.21.7",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"hmac",
|
||||
"md-5",
|
||||
"memchr",
|
||||
"rand 0.8.5",
|
||||
"sha2",
|
||||
"stringprep",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "postgres-types"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8d2234cdee9408b523530a9b6d2d6b373d1db34f6a8e51dc03ded1828d7fb67c"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"postgres-protocol",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "powerfmt"
|
||||
version = "0.2.0"
|
||||
@@ -4584,7 +4497,7 @@ version = "0.12.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"base64",
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
@@ -4857,7 +4770,7 @@ version = "2.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"base64",
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
@@ -5324,17 +5237,6 @@ dependencies = [
|
||||
"precomputed-hash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "stringprep"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bb41d74e231a107a1b4ee36bd1214b11285b77768d2e3824aedafa988fd36ee6"
|
||||
dependencies = [
|
||||
"finl_unicode",
|
||||
"unicode-bidi",
|
||||
"unicode-normalization",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.11.0"
|
||||
@@ -5698,32 +5600,6 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-postgres"
|
||||
version = "0.7.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d340244b32d920260ae7448cb72b6e238bddc3d4f7603394e7dd46ed8e48f5b8"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"log",
|
||||
"parking_lot",
|
||||
"percent-encoding",
|
||||
"phf",
|
||||
"pin-project-lite",
|
||||
"postgres-protocol",
|
||||
"postgres-types",
|
||||
"rand 0.8.5",
|
||||
"socket2",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"whoami",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.26.0"
|
||||
@@ -6388,12 +6264,6 @@ version = "0.11.0+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
|
||||
|
||||
[[package]]
|
||||
name = "wasite"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b"
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen"
|
||||
version = "0.2.92"
|
||||
@@ -6573,17 +6443,6 @@ dependencies = [
|
||||
"rustix",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "whoami"
|
||||
version = "1.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a44ab49fad634e88f55bf8f9bb3abd2f27d7204172a112c7c9987e01c1c94ea9"
|
||||
dependencies = [
|
||||
"redox_syscall",
|
||||
"wasite",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
|
||||
@@ -69,8 +69,6 @@ reqwest = { version = "0.12.4", default-features = false, features = [
|
||||
"stream",
|
||||
], optional = true }
|
||||
openssl = { version = "0.10.55", features = ["vendored"], optional = true }
|
||||
tokio-postgres = { version = "0.7.10", optional = true }
|
||||
pg_bigdecimal = { version = "0.1.5", optional = true }
|
||||
lazy_static = { version = "1.4.0", optional = true }
|
||||
colored_json = { version = "3.0.1", default-features = false, optional = true }
|
||||
tokio = { version = "1.35.0", default-features = false, features = [
|
||||
@@ -242,8 +240,6 @@ ezkl = [
|
||||
"dep:indicatif",
|
||||
"dep:gag",
|
||||
"dep:reqwest",
|
||||
"dep:tokio-postgres",
|
||||
"dep:pg_bigdecimal",
|
||||
"dep:lazy_static",
|
||||
"dep:tokio",
|
||||
"dep:openssl",
|
||||
|
||||
29
README.md
29
README.md
@@ -43,7 +43,7 @@ The generated proofs can then be verified with much less computational resources
|
||||
|
||||
----------------------
|
||||
|
||||
### getting started ⚙️
|
||||
### Getting Started ⚙️
|
||||
|
||||
The easiest way to get started is to try out a notebook.
|
||||
|
||||
@@ -76,12 +76,12 @@ For more details visit the [docs](https://docs.ezkl.xyz). The CLI is faster than
|
||||
|
||||
Build the auto-generated rust documentation and open the docs in your browser locally. `cargo doc --open`
|
||||
|
||||
#### In-browser EVM verifier
|
||||
#### In-browser EVM Verifier
|
||||
|
||||
As an alternative to running the native Halo2 verifier as a WASM binding in the browser, you can use the in-browser EVM verifier. The source code of which you can find in the `in-browser-evm-verifier` directory and a README with instructions on how to use it.
|
||||
|
||||
|
||||
### building the project 🔨
|
||||
### Building the Project 🔨
|
||||
|
||||
#### Rust CLI
|
||||
|
||||
@@ -96,7 +96,7 @@ cargo install --locked --path .
|
||||
|
||||
|
||||
|
||||
#### building python bindings
|
||||
#### Building Python Bindings
|
||||
Python bindings exists and can be built using `maturin`. You will need `rust` and `cargo` to be installed.
|
||||
|
||||
```bash
|
||||
@@ -126,7 +126,7 @@ unset ENABLE_ICICLE_GPU
|
||||
|
||||
**NOTE:** Even with the above environment variable set, icicle is disabled for circuits where k <= 8. To change the value of `k` where icicle is enabled, you can set the environment variable `ICICLE_SMALL_K`.
|
||||
|
||||
### contributing 🌎
|
||||
### Contributing 🌎
|
||||
|
||||
If you're interested in contributing and are unsure where to start, reach out to one of the maintainers:
|
||||
|
||||
@@ -144,20 +144,21 @@ More broadly:
|
||||
|
||||
Any contribution intentionally submitted for inclusion in the work by you shall be licensed to Zkonduit Inc. under the terms and conditions specified in the [CLA](https://github.com/zkonduit/ezkl/blob/main/cla.md), which you agree to by intentionally submitting a contribution. In particular, you have the right to submit the contribution and we can distribute it, among other terms and conditions.
|
||||
|
||||
### no security guarantees
|
||||
|
||||
Ezkl is unaudited, beta software undergoing rapid development. There may be bugs. No guarantees of security are made and it should not be relied on in production.
|
||||
### Audits & Security
|
||||
|
||||
> NOTE: Because operations are quantized when they are converted from an onnx file to a zk-circuit, outputs in python and ezkl may differ slightly.
|
||||
[v21.0.0](https://github.com/zkonduit/ezkl/releases/tag/v21.0.0) has been audited by Trail of Bits, the report can be found [here](https://github.com/trailofbits/publications/blob/master/reviews/2025-03-zkonduit-ezkl-securityreview.pdf).
|
||||
|
||||
> NOTE: Because operations are quantized when they are converted from an onnx file to a zk-circuit, outputs in python and ezkl may differ slightly.
|
||||
|
||||
|
||||
### Advanced security topics
|
||||
|
||||
Check out `docs/advanced_security` for more advanced information on potential threat vectors.
|
||||
Check out `docs/advanced_security` for more advanced information on potential threat vectors that are specific to zero-knowledge inference, quantization, and to machine learning models generally.
|
||||
|
||||
|
||||
### No Warranty
|
||||
|
||||
### no warranty
|
||||
|
||||
Copyright (c) 2024 Zkonduit Inc. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
Copyright (c) 2025 Zkonduit Inc.
|
||||
|
||||
|
||||
@@ -1088,7 +1088,7 @@
|
||||
"\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" address_path,\n",
|
||||
" rpc_url='http://127.0.0.1:3030'\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
|
||||
@@ -472,8 +472,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" addr_path_verifier,\n",
|
||||
" 'http://127.0.0.1:3030',\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assert res == True"
|
||||
@@ -526,9 +526,9 @@
|
||||
"res = await ezkl.deploy_da_evm(\n",
|
||||
" addr_path_da,\n",
|
||||
" input_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" settings_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
@@ -557,8 +557,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" proof_path,\n",
|
||||
" addr_da,\n",
|
||||
")"
|
||||
]
|
||||
@@ -566,7 +566,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "ezkl",
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -580,7 +580,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
"version": "3.12.9"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
|
||||
@@ -543,8 +543,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" addr_path_verifier,\n",
|
||||
" 'http://127.0.0.1:3030',\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assert res == True"
|
||||
@@ -597,9 +597,9 @@
|
||||
"res = await ezkl.deploy_da_evm(\n",
|
||||
" addr_path_da,\n",
|
||||
" input_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" settings_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
@@ -628,8 +628,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" proof_path,\n",
|
||||
" addr_da,\n",
|
||||
")"
|
||||
]
|
||||
@@ -651,7 +651,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
"version": "3.12.9"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
|
||||
@@ -474,8 +474,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" addr_path_verifier,\n",
|
||||
" 'http://127.0.0.1:3030',\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assert res == True"
|
||||
@@ -529,9 +529,9 @@
|
||||
"res = await ezkl.deploy_da_evm(\n",
|
||||
" addr_path_da,\n",
|
||||
" input_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" settings_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
@@ -560,8 +560,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" proof_path,\n",
|
||||
" addr_da,\n",
|
||||
")"
|
||||
]
|
||||
|
||||
@@ -453,8 +453,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" address_path,\n",
|
||||
" 'http://127.0.0.1:3030',\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
@@ -474,8 +474,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" \"http://127.0.0.1:3030\",\n",
|
||||
" proof_path,\n",
|
||||
" \"http://127.0.0.1:3030\"\n",
|
||||
")\n",
|
||||
"assert res == True"
|
||||
]
|
||||
@@ -510,4 +510,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
|
||||
@@ -462,8 +462,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" address_path,\n",
|
||||
" 'http://127.0.0.1:3030',\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
@@ -483,8 +483,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" \"http://127.0.0.1:3030\",\n",
|
||||
" proof_path,\n",
|
||||
" \"http://127.0.0.1:3030\"\n",
|
||||
")\n",
|
||||
"assert res == True"
|
||||
]
|
||||
@@ -512,4 +512,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,462 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Mean of ERC20 transfer amounts\n",
|
||||
"\n",
|
||||
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
|
||||
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
|
||||
"\n",
|
||||
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import getpass\n",
|
||||
"import json\n",
|
||||
"import time\n",
|
||||
"import subprocess\n",
|
||||
"\n",
|
||||
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
|
||||
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
|
||||
"os.system(\"chmod +x shovel\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
|
||||
"\n",
|
||||
"# create a config.json file with the following contents\n",
|
||||
"config = {\n",
|
||||
" \"pg_url\": \"$PG_URL\",\n",
|
||||
" \"eth_sources\": [\n",
|
||||
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
|
||||
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
|
||||
" ],\n",
|
||||
" \"integrations\": [{\n",
|
||||
" \"name\": \"usdc_transfer\",\n",
|
||||
" \"enabled\": True,\n",
|
||||
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
|
||||
" \"table\": {\n",
|
||||
" \"name\": \"usdc\",\n",
|
||||
" \"columns\": [\n",
|
||||
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
|
||||
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"block\": [\n",
|
||||
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
|
||||
" {\n",
|
||||
" \"name\": \"log_addr\",\n",
|
||||
" \"column\": \"log_addr\",\n",
|
||||
" \"filter_op\": \"contains\",\n",
|
||||
" \"filter_arg\": [\n",
|
||||
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
|
||||
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" \"event\": {\n",
|
||||
" \"name\": \"Transfer\",\n",
|
||||
" \"type\": \"event\",\n",
|
||||
" \"anonymous\": False,\n",
|
||||
" \"inputs\": [\n",
|
||||
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
|
||||
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
|
||||
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" }]\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# write the config to a file\n",
|
||||
"with open(\"config.json\", \"w\") as f:\n",
|
||||
" f.write(json.dumps(config))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# print the two env variables\n",
|
||||
"os.system(\"echo $PG_URL\")\n",
|
||||
"\n",
|
||||
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel is now installed. starting:\")\n",
|
||||
"\n",
|
||||
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
|
||||
"proc = subprocess.Popen(command)\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel started.\")\n",
|
||||
"\n",
|
||||
"time.sleep(10)\n",
|
||||
"\n",
|
||||
"# after we've fetched some data -- kill the process\n",
|
||||
"proc.terminate()\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "2wIAHwqH2_mo"
|
||||
},
|
||||
"source": [
|
||||
"**Import Dependencies**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "9Byiv2Nc2MsK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"import ezkl\n",
|
||||
"import torch\n",
|
||||
"import datetime\n",
|
||||
"import pandas as pd\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import logging\n",
|
||||
"# # uncomment for more descriptive logging \n",
|
||||
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
|
||||
"logging.basicConfig(format=FORMAT)\n",
|
||||
"logging.getLogger().setLevel(logging.DEBUG)\n",
|
||||
"\n",
|
||||
"print(\"ezkl version: \", ezkl.__version__)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "osjj-0Ta3E8O"
|
||||
},
|
||||
"source": [
|
||||
"**Create Computational Graph**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "x1vl9ZXF3EEW",
|
||||
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch import nn\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Model(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Model, self).__init__()\n",
|
||||
"\n",
|
||||
" # x is a time series \n",
|
||||
" def forward(self, x):\n",
|
||||
" return [torch.mean(x)]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"circuit = Model()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"x = 0.1*torch.rand(1,*[1,5], requires_grad=True)\n",
|
||||
"\n",
|
||||
"# # print(torch.__version__)\n",
|
||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"\n",
|
||||
"print(device)\n",
|
||||
"\n",
|
||||
"circuit.to(device)\n",
|
||||
"\n",
|
||||
"# Flips the neural net into inference mode\n",
|
||||
"circuit.eval()\n",
|
||||
"\n",
|
||||
"# Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" x, # model input (or a tuple for multiple inputs)\n",
|
||||
" \"lol.onnx\", # where to save the model (can be a file or file-like object)\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=11, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names = ['input'], # the model's input names\n",
|
||||
" output_names = ['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
|
||||
" 'output' : {0 : 'batch_size'}})\n",
|
||||
"\n",
|
||||
"# export(circuit, input_shape=[1, 20])\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "E3qCeX-X5xqd"
|
||||
},
|
||||
"source": [
|
||||
"**Set Data Source and Get Data**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "6RAMplxk5xPk",
|
||||
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"# make an input.json file from the df above\n",
|
||||
"input_filename = os.path.join('input.json')\n",
|
||||
"\n",
|
||||
"pg_input_file = dict(input_data = {\n",
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
"json_formatted_str = json.dumps(pg_input_file, indent=2)\n",
|
||||
"print(json_formatted_str)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# this corresponds to 4 batches\n",
|
||||
"calibration_filename = os.path.join('calibration.json')\n",
|
||||
"\n",
|
||||
"pg_cal_file = dict(input_data = {\n",
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump( pg_cal_file, open(calibration_filename, 'w' ))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "eLJ7oirQ_HQR"
|
||||
},
|
||||
"source": [
|
||||
"**EZKL Workflow**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "rNw0C9QL6W88"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import subprocess\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"onnx_filename = os.path.join('lol.onnx')\n",
|
||||
"compiled_filename = os.path.join('lol.compiled')\n",
|
||||
"settings_filename = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"run_args = ezkl.PyRunArgs()\n",
|
||||
"run_args.decomp_legs = 4\n",
|
||||
"\n",
|
||||
"# Generate settings using ezkl\n",
|
||||
"res = ezkl.gen_settings(onnx_filename, settings_filename, py_run_args=run_args)\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"await ezkl.get_srs(settings_filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "4MmE9SX66_Il",
|
||||
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# generate settings\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# show the settings.json\n",
|
||||
"with open(\"settings.json\") as f:\n",
|
||||
" data = json.load(f)\n",
|
||||
" json_formatted_str = json.dumps(data, indent=2)\n",
|
||||
"\n",
|
||||
" print(json_formatted_str)\n",
|
||||
"\n",
|
||||
"assert os.path.exists(\"settings.json\")\n",
|
||||
"assert os.path.exists(\"input.json\")\n",
|
||||
"assert os.path.exists(\"lol.onnx\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "fULvvnK7_CMb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# setup the proof\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_filename,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_filename)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"# generate the witness\n",
|
||||
"res = await ezkl.gen_witness(\n",
|
||||
" input_filename,\n",
|
||||
" compiled_filename,\n",
|
||||
" witness_path\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Oog3j6Kd-Wed",
|
||||
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# prove the zk circuit\n",
|
||||
"# GENERATE A PROOF\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"proof = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_filename,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \"single\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"proved\")\n",
|
||||
"\n",
|
||||
"assert os.path.isfile(proof_path)\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -504,8 +504,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" address_path,\n",
|
||||
" 'http://127.0.0.1:3030',\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
@@ -527,8 +527,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" \"http://127.0.0.1:3030\"\n",
|
||||
" \"http://127.0.0.1:3030\",\n",
|
||||
" proof_path\n",
|
||||
")\n",
|
||||
"assert res == True"
|
||||
]
|
||||
@@ -558,4 +558,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,8 +261,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" addr_path_verifier,\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030',\n",
|
||||
" sol_code_path,\n",
|
||||
" \"verifier/reusable\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
@@ -288,7 +288,7 @@
|
||||
"for name in names:\n",
|
||||
" addr_path_vk = \"addr_vk.txt\"\n",
|
||||
" sol_key_code_path = os.path.join(name, 'test_key.sol')\n",
|
||||
" res = await ezkl.deploy_evm(addr_path_vk, sol_key_code_path, 'http://127.0.0.1:3030', \"vka\")\n",
|
||||
" res = await ezkl.deploy_evm(addr_path_vk, 'http://127.0.0.1:3030', sol_key_code_path, \"vka\")\n",
|
||||
" assert res == True\n",
|
||||
"\n",
|
||||
" with open(addr_path_vk, 'r') as file:\n",
|
||||
@@ -298,8 +298,8 @@
|
||||
" sol_code_path = os.path.join(name, 'vk.sol')\n",
|
||||
" res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" \"http://127.0.0.1:3030\",\n",
|
||||
" proof_path,\n",
|
||||
" addr_vk = addr_vk\n",
|
||||
" )\n",
|
||||
" assert res == True"
|
||||
|
||||
@@ -562,8 +562,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" addr_path_verifier,\n",
|
||||
" 'http://127.0.0.1:3030',\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assert res == True"
|
||||
@@ -616,9 +616,9 @@
|
||||
"res = await ezkl.deploy_da_evm(\n",
|
||||
" addr_path_da,\n",
|
||||
" input_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" settings_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
@@ -653,8 +653,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" proof_path,\n",
|
||||
" addr_da,\n",
|
||||
")"
|
||||
]
|
||||
|
||||
@@ -666,7 +666,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -689,8 +689,8 @@
|
||||
"# await\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" address_path,\n",
|
||||
" 'http://127.0.0.1:3030',\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
@@ -701,7 +701,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -722,8 +722,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" \"http://127.0.0.1:3030\",\n",
|
||||
" proof_path,\n",
|
||||
" \"http://127.0.0.1:3030\"\n",
|
||||
")\n",
|
||||
"assert res == True"
|
||||
]
|
||||
@@ -743,7 +743,8 @@
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
@@ -756,7 +757,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
"version": "3.12.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -849,8 +849,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" address_path,\n",
|
||||
" 'http://127.0.0.1:3030',\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
@@ -870,8 +870,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" \"http://127.0.0.1:3030\"\n",
|
||||
" \"http://127.0.0.1:3030\",\n",
|
||||
" proof_path\n",
|
||||
")\n",
|
||||
"assert res == True"
|
||||
]
|
||||
@@ -905,4 +905,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
|
||||
@@ -358,8 +358,8 @@
|
||||
"\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" addr_path_verifier,\n",
|
||||
" 'http://127.0.0.1:3030',\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"assert res == True"
|
||||
@@ -405,9 +405,9 @@
|
||||
"res = await ezkl.deploy_da_evm(\n",
|
||||
" addr_path_da,\n",
|
||||
" input_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" settings_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
@@ -470,8 +470,8 @@
|
||||
"\n",
|
||||
"res = ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" RPC_URL,\n",
|
||||
" proof_path,\n",
|
||||
" addr_da,\n",
|
||||
")"
|
||||
]
|
||||
@@ -531,7 +531,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
"version": "3.12.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -1,34 +1,34 @@
|
||||
use crate::circuit::modules::polycommit::PolyCommitChip;
|
||||
use crate::circuit::modules::poseidon::{
|
||||
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
|
||||
PoseidonChip,
|
||||
};
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::Commitments;
|
||||
use crate::RunArgs;
|
||||
use crate::circuit::CheckMode;
|
||||
use crate::circuit::InputType;
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::circuit::modules::polycommit::PolyCommitChip;
|
||||
use crate::circuit::modules::poseidon::{
|
||||
PoseidonChip,
|
||||
spec::{POSEIDON_RATE, POSEIDON_WIDTH, PoseidonSpec},
|
||||
};
|
||||
use crate::commands::*;
|
||||
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
|
||||
use crate::fieldutils::{IntegerRep, felt_to_integer_rep, integer_rep_to_felt};
|
||||
use crate::graph::TestDataSource;
|
||||
use crate::graph::{
|
||||
quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
|
||||
GraphCircuit, GraphSettings, Model, Visibility, quantize_float, scale_to_multiplier,
|
||||
};
|
||||
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
|
||||
use crate::pfsys::{
|
||||
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
|
||||
ProofType, TranscriptType,
|
||||
ProofType, TranscriptType, load_pk, load_vk, save_params, save_vk,
|
||||
srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
|
||||
};
|
||||
use crate::Commitments;
|
||||
use crate::RunArgs;
|
||||
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
|
||||
use halo2curves::bn256::{Bn256, Fq, Fr, G1, G1Affine};
|
||||
use pyo3::exceptions::{PyIOError, PyRuntimeError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::wrap_pyfunction;
|
||||
use pyo3_log;
|
||||
use pyo3_stub_gen::{
|
||||
define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
|
||||
derive::gen_stub_pyfunction, TypeInfo,
|
||||
TypeInfo, define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
|
||||
derive::gen_stub_pyfunction,
|
||||
};
|
||||
use snark_verifier::util::arithmetic::PrimeField;
|
||||
use std::collections::HashSet;
|
||||
@@ -206,6 +206,9 @@ struct PyRunArgs {
|
||||
/// bool: Should the circuit use range checks for inputs and outputs (set to false if the input is a felt)
|
||||
#[pyo3(get, set)]
|
||||
pub ignore_range_check_inputs_outputs: bool,
|
||||
/// float: epsilon used for arguments that use division
|
||||
#[pyo3(get, set)]
|
||||
pub epsilon: f64,
|
||||
}
|
||||
|
||||
/// default instantiation of PyRunArgs
|
||||
@@ -238,12 +241,14 @@ impl From<PyRunArgs> for RunArgs {
|
||||
decomp_base: py_run_args.decomp_base,
|
||||
decomp_legs: py_run_args.decomp_legs,
|
||||
ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
|
||||
epsilon: Some(py_run_args.epsilon),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<PyRunArgs> for RunArgs {
|
||||
fn into(self) -> PyRunArgs {
|
||||
let eps = self.get_epsilon();
|
||||
PyRunArgs {
|
||||
bounded_log_lookup: self.bounded_log_lookup,
|
||||
input_scale: self.input_scale,
|
||||
@@ -262,6 +267,7 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
decomp_base: self.decomp_base,
|
||||
decomp_legs: self.decomp_legs,
|
||||
ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
|
||||
epsilon: eps,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -962,6 +968,8 @@ fn gen_settings(
|
||||
output=PathBuf::from(DEFAULT_SETTINGS),
|
||||
variables=Vec::from([("batch_size".to_string(), 1)]),
|
||||
seed=DEFAULT_SEED.parse().unwrap(),
|
||||
min=None,
|
||||
max=None
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn gen_random_data(
|
||||
@@ -969,8 +977,10 @@ fn gen_random_data(
|
||||
output: PathBuf,
|
||||
variables: Vec<(String, usize)>,
|
||||
seed: u64,
|
||||
min: Option<f32>,
|
||||
max: Option<f32>,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::gen_random_data(model, output, variables, seed).map_err(|e| {
|
||||
crate::execute::gen_random_data(model, output, variables, seed, min, max).map_err(|e| {
|
||||
let err_str = format!("Failed to generate settings: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
@@ -1819,7 +1829,7 @@ fn create_evm_data_attestation(
|
||||
test_data,
|
||||
input_source,
|
||||
output_source,
|
||||
rpc_url=None
|
||||
rpc_url,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn setup_test_evm_data(
|
||||
@@ -1829,7 +1839,7 @@ fn setup_test_evm_data(
|
||||
test_data: PathBuf,
|
||||
input_source: PyTestDataSource,
|
||||
output_source: PyTestDataSource,
|
||||
rpc_url: Option<String>,
|
||||
rpc_url: String,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::setup_test_evm_data(
|
||||
@@ -1853,8 +1863,8 @@ fn setup_test_evm_data(
|
||||
/// deploys the solidity verifier
|
||||
#[pyfunction(signature = (
|
||||
addr_path,
|
||||
rpc_url,
|
||||
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
|
||||
rpc_url=None,
|
||||
contract_type=ContractType::default(),
|
||||
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
|
||||
private_key=None,
|
||||
@@ -1863,8 +1873,8 @@ fn setup_test_evm_data(
|
||||
fn deploy_evm(
|
||||
py: Python,
|
||||
addr_path: PathBuf,
|
||||
rpc_url: String,
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
contract_type: ContractType,
|
||||
optimizer_runs: usize,
|
||||
private_key: Option<String>,
|
||||
@@ -1892,9 +1902,9 @@ fn deploy_evm(
|
||||
#[pyfunction(signature = (
|
||||
addr_path,
|
||||
input_data,
|
||||
rpc_url,
|
||||
settings_path=PathBuf::from(DEFAULT_SETTINGS),
|
||||
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE_DA),
|
||||
rpc_url=None,
|
||||
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
|
||||
private_key=None
|
||||
))]
|
||||
@@ -1903,9 +1913,9 @@ fn deploy_da_evm(
|
||||
py: Python,
|
||||
addr_path: PathBuf,
|
||||
input_data: String,
|
||||
rpc_url: String,
|
||||
settings_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
optimizer_runs: usize,
|
||||
private_key: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
@@ -1952,8 +1962,8 @@ fn deploy_da_evm(
|
||||
///
|
||||
#[pyfunction(signature = (
|
||||
addr_verifier,
|
||||
rpc_url,
|
||||
proof_path=PathBuf::from(DEFAULT_PROOF),
|
||||
rpc_url=None,
|
||||
addr_da = None,
|
||||
addr_vk = None,
|
||||
))]
|
||||
@@ -1961,8 +1971,8 @@ fn deploy_da_evm(
|
||||
fn verify_evm<'a>(
|
||||
py: Python<'a>,
|
||||
addr_verifier: &'a str,
|
||||
rpc_url: String,
|
||||
proof_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
addr_da: Option<&'a str>,
|
||||
addr_vk: Option<&'a str>,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::*;
|
||||
use crate::{
|
||||
circuit::{layouts, utils},
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
fieldutils::{IntegerRep, integer_rep_to_felt},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, DataFormat, Tensor, TensorType, ValTensor},
|
||||
};
|
||||
@@ -15,10 +15,12 @@ use serde::{Deserialize, Serialize};
|
||||
pub enum HybridOp {
|
||||
Ln {
|
||||
scale: utils::F32,
|
||||
eps: f64,
|
||||
},
|
||||
Rsqrt {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
eps: f64,
|
||||
},
|
||||
Sqrt {
|
||||
scale: utils::F32,
|
||||
@@ -42,6 +44,7 @@ pub enum HybridOp {
|
||||
Recip {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
eps: f64,
|
||||
},
|
||||
Div {
|
||||
denom: utils::F32,
|
||||
@@ -77,6 +80,7 @@ pub enum HybridOp {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
axes: Vec<usize>,
|
||||
eps: f64,
|
||||
},
|
||||
Output {
|
||||
decomp: bool,
|
||||
@@ -128,12 +132,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
HybridOp::Rsqrt {
|
||||
input_scale,
|
||||
output_scale,
|
||||
eps,
|
||||
} => format!(
|
||||
"RSQRT (input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
"RSQRT (input_scale={}, output_scale={}, eps={})",
|
||||
input_scale, output_scale, eps
|
||||
),
|
||||
HybridOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
|
||||
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
HybridOp::Ln { scale, eps } => format!("LN(scale={}, eps={})", scale, eps),
|
||||
HybridOp::RoundHalfToEven { scale, legs } => {
|
||||
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
|
||||
}
|
||||
@@ -146,16 +151,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
eps,
|
||||
} => format!(
|
||||
"RECIP (input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
"RECIP (input_scale={}, output_scale={}, eps={})",
|
||||
input_scale, output_scale, eps
|
||||
),
|
||||
HybridOp::Div { denom } => format!("DIV (denom={})", denom),
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
kernel_shape,
|
||||
normalized, data_format
|
||||
normalized,
|
||||
data_format,
|
||||
} => format!(
|
||||
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={}, data_format={:?})",
|
||||
padding, stride, kernel_shape, normalized, data_format
|
||||
@@ -177,10 +184,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
input_scale,
|
||||
output_scale,
|
||||
axes,
|
||||
eps,
|
||||
} => {
|
||||
format!(
|
||||
"SOFTMAX (input_scale={}, output_scale={}, axes={:?})",
|
||||
input_scale, output_scale, axes
|
||||
"SOFTMAX (input_scale={}, output_scale={}, axes={:?}, eps={})",
|
||||
input_scale, output_scale, axes, eps
|
||||
)
|
||||
}
|
||||
HybridOp::Output { decomp } => {
|
||||
@@ -211,17 +219,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
HybridOp::Rsqrt {
|
||||
input_scale,
|
||||
output_scale,
|
||||
eps,
|
||||
} => layouts::rsqrt(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
*input_scale,
|
||||
*output_scale,
|
||||
*eps,
|
||||
)?,
|
||||
HybridOp::Sqrt { scale } => {
|
||||
layouts::sqrt(config, region, values[..].try_into()?, *scale)?
|
||||
}
|
||||
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
|
||||
HybridOp::Ln { scale, eps } => {
|
||||
layouts::ln(config, region, values[..].try_into()?, *scale, *eps)?
|
||||
}
|
||||
HybridOp::RoundHalfToEven { scale, legs } => {
|
||||
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
@@ -255,12 +267,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
eps,
|
||||
} => layouts::recip(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(input_scale.0 as IntegerRep),
|
||||
integer_rep_to_felt(output_scale.0 as IntegerRep),
|
||||
*eps,
|
||||
)?,
|
||||
HybridOp::Div { denom, .. } => {
|
||||
if denom.0.fract() == 0.0 {
|
||||
@@ -317,6 +331,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
input_scale,
|
||||
output_scale,
|
||||
axes,
|
||||
eps,
|
||||
} => layouts::softmax_axes(
|
||||
config,
|
||||
region,
|
||||
@@ -324,6 +339,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
*input_scale,
|
||||
*output_scale,
|
||||
axes,
|
||||
*eps,
|
||||
)?,
|
||||
HybridOp::Output { decomp } => {
|
||||
layouts::output(config, region, values[..].try_into()?, *decomp)?
|
||||
@@ -364,6 +380,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
} => multiplier_to_scale((output_scale.0 * input_scale.0) as f64),
|
||||
HybridOp::Ln {
|
||||
scale: output_scale,
|
||||
eps: _,
|
||||
} => 4 * multiplier_to_scale(output_scale.0 as f64),
|
||||
_ => in_scales[0],
|
||||
};
|
||||
|
||||
@@ -303,6 +303,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
value: &[ValTensor<F>; 1],
|
||||
input_scale: F,
|
||||
output_scale: F,
|
||||
eps: f64,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let input = value[0].clone();
|
||||
let input_dims = input.dims();
|
||||
@@ -317,6 +318,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&input_evals,
|
||||
felt_to_integer_rep(input_scale) as f64,
|
||||
felt_to_integer_rep(output_scale) as f64,
|
||||
eps,
|
||||
)
|
||||
.par_iter()
|
||||
.map(|x| Value::known(integer_rep_to_felt(*x)))
|
||||
@@ -335,7 +337,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
let claimed_output = identity(config, region, &[claimed_output], true)?;
|
||||
// divide by input_scale
|
||||
let zero_inverse_val =
|
||||
tensor::ops::nonlinearities::zero_recip(felt_to_integer_rep(output_scale) as f64)[0];
|
||||
tensor::ops::nonlinearities::zero_recip(felt_to_integer_rep(output_scale) as f64, eps)[0];
|
||||
let zero_inverse = create_constant_tensor(integer_rep_to_felt(zero_inverse_val), 1);
|
||||
|
||||
let equal_zero_mask = equals_zero(config, region, &[input.clone()])?;
|
||||
@@ -473,7 +475,7 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]),
|
||||
/// &[3, 3],
|
||||
/// ).unwrap());
|
||||
/// let result = rsqrt::<Fp>(&dummy_config, &mut dummy_region, &[x], 1.0.into(), 1.0.into()).unwrap();
|
||||
/// let result = rsqrt::<Fp>(&dummy_config, &mut dummy_region, &[x], 1.0.into(), 1.0.into(), f64::EPSILON).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 1, 1, 1, 1, 1, 1, 1]), &[3, 3]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
@@ -483,13 +485,21 @@ pub fn rsqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
value: &[ValTensor<F>; 1],
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
eps: f64,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let sqrt = sqrt(config, region, value, input_scale)?;
|
||||
|
||||
let felt_output_scale = integer_rep_to_felt(output_scale.0 as IntegerRep);
|
||||
let felt_input_scale = integer_rep_to_felt(input_scale.0 as IntegerRep);
|
||||
|
||||
let recip = recip(config, region, &[sqrt], felt_input_scale, felt_output_scale)?;
|
||||
let recip = recip(
|
||||
config,
|
||||
region,
|
||||
&[sqrt],
|
||||
felt_input_scale,
|
||||
felt_output_scale,
|
||||
eps,
|
||||
)?;
|
||||
|
||||
Ok(recip)
|
||||
}
|
||||
@@ -1547,7 +1557,7 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
|
||||
/// * Creates pairs: (index_input, value_input) for original elements
|
||||
/// * Creates pairs: (index_output, value_output) for permuted elements
|
||||
/// * index_input is a fixed sequence 0,1,2... corresponding to input positions
|
||||
///
|
||||
///
|
||||
/// - Core permutation verification:
|
||||
/// * For each (index_input, value_input), verify there exists exactly one
|
||||
/// (index_output, value_output) such that value_input = value_output
|
||||
@@ -5702,7 +5712,7 @@ pub fn ceil<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
///
|
||||
/// let result = ln::<Fp>(&dummy_config, &mut dummy_region, &[x], 2.0.into()).unwrap();
|
||||
/// let result = ln::<Fp>(&dummy_config, &mut dummy_region, &[x], 2.0.into(), f64::EPSILON).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 0, 4, -8]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
///
|
||||
@@ -5712,6 +5722,7 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
scale: utils::F32,
|
||||
eps: f64,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// first generate the claimed val
|
||||
|
||||
@@ -5882,6 +5893,7 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&[pow2_prior_to_claimed_distance],
|
||||
scale_as_felt,
|
||||
scale_as_felt * scale_as_felt,
|
||||
eps,
|
||||
)?;
|
||||
|
||||
let interpolated_distance = pairwise(
|
||||
@@ -5910,6 +5922,7 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&[pow2_next_to_claimed_distance],
|
||||
scale_as_felt,
|
||||
scale_as_felt * scale_as_felt,
|
||||
eps,
|
||||
)?;
|
||||
|
||||
let interpolated_distance_next = pairwise(
|
||||
@@ -6698,12 +6711,13 @@ pub(crate) fn softmax_axes<F: PrimeField + TensorType + PartialOrd + std::hash::
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
axes: &[usize],
|
||||
eps: f64,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let soft_max_at_scale = move |config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1]|
|
||||
-> Result<ValTensor<F>, CircuitError> {
|
||||
softmax(config, region, values, input_scale, output_scale)
|
||||
softmax(config, region, values, input_scale, output_scale, eps)
|
||||
};
|
||||
|
||||
let output = multi_dim_axes_op(config, region, values, axes, soft_max_at_scale)?;
|
||||
@@ -6718,6 +6732,7 @@ pub(crate) fn percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
|
||||
values: &[ValTensor<F>; 1],
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
eps: f64,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let is_assigned = values[0].all_prev_assigned();
|
||||
let mut input = values[0].clone();
|
||||
@@ -6736,6 +6751,7 @@ pub(crate) fn percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
|
||||
&[denom],
|
||||
input_felt_scale,
|
||||
output_felt_scale,
|
||||
eps,
|
||||
)?;
|
||||
// product of num * (1 / denom) = input_scale * output_scale
|
||||
pairwise(config, region, &[input, inv_denom], BaseOp::Mult)
|
||||
@@ -6760,7 +6776,7 @@ pub(crate) fn percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
|
||||
/// Some(&[2, 2, 3, 2, 2, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap());
|
||||
/// let result = softmax::<Fp>(&dummy_config, &mut dummy_region, &[x], 128.0.into(), (128.0 * 128.0).into()).unwrap();
|
||||
/// let result = softmax::<Fp>(&dummy_config, &mut dummy_region, &[x], 128.0.into(), (128.0 * 128.0).into(), f64::EPSILON).unwrap();
|
||||
/// // doubles the scale of the input
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[350012, 350012, 352768, 350012, 350012, 344500]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
@@ -6771,6 +6787,7 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
values: &[ValTensor<F>; 1],
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
eps: f64,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// get the max then subtract it
|
||||
let max_val = max(config, region, values)?;
|
||||
@@ -6787,7 +6804,14 @@ pub fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
},
|
||||
)?;
|
||||
|
||||
percent(config, region, &[ex.clone()], input_scale, output_scale)
|
||||
percent(
|
||||
config,
|
||||
region,
|
||||
&[ex.clone()],
|
||||
input_scale,
|
||||
output_scale,
|
||||
eps,
|
||||
)
|
||||
}
|
||||
|
||||
/// Checks that the percent error between the expected public output and the actual output value
|
||||
|
||||
@@ -382,6 +382,44 @@ pub struct Cli {
|
||||
pub command: Option<Commands>,
|
||||
}
|
||||
|
||||
/// Custom parser for data field that handles both direct JSON strings and file paths with '@' prefix
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
|
||||
pub struct DataField(pub String);
|
||||
|
||||
impl FromStr for DataField {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
// Check if the input starts with '@'
|
||||
if s.starts_with('@') {
|
||||
// Extract the file path (remove the '@' prefix)
|
||||
let file_path = &s[1..];
|
||||
|
||||
// Read the file content
|
||||
let content = std::fs::read_to_string(file_path)
|
||||
.map_err(|e| format!("Failed to read data file '{}': {}", file_path, e))?;
|
||||
|
||||
// Return the file content as the data field value
|
||||
Ok(DataField(content))
|
||||
} else {
|
||||
// Use the input string directly
|
||||
Ok(DataField(s.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for DataField {
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![self.0.clone()]
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DataField {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
|
||||
pub enum Commands {
|
||||
@@ -400,9 +438,9 @@ pub enum Commands {
|
||||
|
||||
/// Generates the witness from an input file.
|
||||
GenWitness {
|
||||
/// The path to the .json data file
|
||||
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
|
||||
data: Option<String>,
|
||||
/// The path to the .json data file (with @ prefix) or a raw data string of the form '{"input_data": [[1, 2, 3]]}'
|
||||
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_parser = DataField::from_str)]
|
||||
data: Option<DataField>,
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
|
||||
compiled_circuit: Option<PathBuf>,
|
||||
@@ -443,6 +481,12 @@ pub enum Commands {
|
||||
/// random seed for reproducibility (optional)
|
||||
#[arg(long, value_hint = clap::ValueHint::Other, default_value = DEFAULT_SEED)]
|
||||
seed: u64,
|
||||
/// min value for random data
|
||||
#[arg(long, value_hint = clap::ValueHint::Other)]
|
||||
min: Option<f32>,
|
||||
/// max value for random data
|
||||
#[arg(long, value_hint = clap::ValueHint::Other)]
|
||||
max: Option<f32>,
|
||||
},
|
||||
/// Calibrates the proving scale, lookup bits and logrows from a circuit settings file.
|
||||
CalibrateSettings {
|
||||
@@ -641,9 +685,9 @@ pub enum Commands {
|
||||
/// Should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
#[arg(short = 'T', long, value_hint = clap::ValueHint::FilePath)]
|
||||
test_data: PathBuf,
|
||||
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
/// RPC URL for an Ethereum node
|
||||
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
|
||||
rpc_url: Option<String>,
|
||||
rpc_url: String,
|
||||
/// where the input data come from
|
||||
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
|
||||
input_source: TestDataSource,
|
||||
@@ -728,7 +772,7 @@ pub enum Commands {
|
||||
},
|
||||
/// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier
|
||||
#[command(name = "create-evm-vka")]
|
||||
CreateEvmVKArtifact {
|
||||
CreateEvmVka {
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
@@ -747,7 +791,7 @@ pub enum Commands {
|
||||
},
|
||||
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
|
||||
#[command(name = "create-evm-da")]
|
||||
CreateEvmDataAttestation {
|
||||
CreateEvmDa {
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
|
||||
settings_path: Option<PathBuf>,
|
||||
@@ -838,9 +882,9 @@ pub enum Commands {
|
||||
/// The path to the Solidity code (generated using the create-evm-verifier command)
|
||||
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
|
||||
sol_code_path: Option<PathBuf>,
|
||||
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
|
||||
rpc_url: Option<String>,
|
||||
/// RPC URL for an Ethereum node
|
||||
#[arg(short = 'U', long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Url)]
|
||||
rpc_url: String,
|
||||
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
|
||||
/// The path to output the contract address
|
||||
addr_path: Option<PathBuf>,
|
||||
@@ -856,7 +900,7 @@ pub enum Commands {
|
||||
},
|
||||
/// Deploys an evm verifier that allows for data attestation
|
||||
#[command(name = "deploy-evm-da")]
|
||||
DeployEvmDataAttestation {
|
||||
DeployEvmDa {
|
||||
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
|
||||
data: Option<String>,
|
||||
@@ -866,9 +910,9 @@ pub enum Commands {
|
||||
/// The path to the Solidity code
|
||||
#[arg(long, default_value = DEFAULT_SOL_CODE_DA, value_hint = clap::ValueHint::FilePath)]
|
||||
sol_code_path: Option<PathBuf>,
|
||||
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
/// RPC URL for an Ethereum node
|
||||
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
|
||||
rpc_url: Option<String>,
|
||||
rpc_url: String,
|
||||
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_DA, value_hint = clap::ValueHint::FilePath)]
|
||||
/// The path to output the contract address
|
||||
addr_path: Option<PathBuf>,
|
||||
@@ -888,9 +932,9 @@ pub enum Commands {
|
||||
/// The path to verifier contract's address
|
||||
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
|
||||
addr_verifier: H160Flag,
|
||||
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
/// RPC URL for an Ethereum node
|
||||
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
|
||||
rpc_url: Option<String>,
|
||||
rpc_url: String,
|
||||
/// does the verifier use data attestation ?
|
||||
#[arg(long, value_hint = clap::ValueHint::Other)]
|
||||
addr_da: Option<H160Flag>,
|
||||
|
||||
28
src/eth.rs
28
src/eth.rs
@@ -12,7 +12,6 @@ use alloy::dyn_abi::abi::TokenSeq;
|
||||
use alloy::dyn_abi::abi::token::{DynSeqToken, PackedSeqToken, WordToken};
|
||||
// use alloy::providers::Middleware;
|
||||
use alloy::json_abi::JsonAbi;
|
||||
use alloy::node_bindings::Anvil;
|
||||
use alloy::primitives::ruint::ParseError;
|
||||
use alloy::primitives::{B256, I256, ParseSignedError};
|
||||
use alloy::providers::ProviderBuilder;
|
||||
@@ -313,25 +312,12 @@ pub type ContractFactory<M> = CallBuilder<Http<Client>, Arc<M>, ()>;
|
||||
|
||||
/// Return an instance of Anvil and a client for the given RPC URL. If none is provided, a local client is used.
|
||||
pub async fn setup_eth_backend(
|
||||
rpc_url: Option<&str>,
|
||||
rpc_url: &str,
|
||||
private_key: Option<&str>,
|
||||
) -> Result<(EthersClient, alloy::primitives::Address), EthError> {
|
||||
// Launch anvil
|
||||
|
||||
let endpoint: String;
|
||||
if let Some(rpc_url) = rpc_url {
|
||||
endpoint = rpc_url.to_string();
|
||||
} else {
|
||||
let anvil = Anvil::new()
|
||||
.args([
|
||||
"--code-size-limit=41943040",
|
||||
"--disable-block-gas-limit",
|
||||
"-p",
|
||||
"8545",
|
||||
])
|
||||
.spawn();
|
||||
endpoint = anvil.endpoint();
|
||||
}
|
||||
let endpoint = rpc_url.to_string();
|
||||
|
||||
// Instantiate the wallet
|
||||
let wallet: LocalWallet;
|
||||
@@ -365,7 +351,7 @@ pub async fn setup_eth_backend(
|
||||
///
|
||||
pub async fn deploy_contract_via_solidity(
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<&str>,
|
||||
rpc_url: &str,
|
||||
runs: usize,
|
||||
private_key: Option<&str>,
|
||||
contract_name: &str,
|
||||
@@ -387,7 +373,7 @@ pub async fn deploy_da_verifier_via_solidity(
|
||||
settings_path: PathBuf,
|
||||
input: String,
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<&str>,
|
||||
rpc_url: &str,
|
||||
runs: usize,
|
||||
private_key: Option<&str>,
|
||||
) -> Result<H160, EthError> {
|
||||
@@ -574,7 +560,7 @@ pub async fn verify_proof_via_solidity(
|
||||
proof: Snark<Fr, G1Affine>,
|
||||
addr: H160,
|
||||
addr_vk: Option<H160>,
|
||||
rpc_url: Option<&str>,
|
||||
rpc_url: &str,
|
||||
) -> Result<bool, EthError> {
|
||||
let flattened_instances = proof.instances.into_iter().flatten();
|
||||
|
||||
@@ -676,7 +662,7 @@ pub async fn verify_proof_with_data_attestation(
|
||||
addr_verifier: H160,
|
||||
addr_da: H160,
|
||||
addr_vk: Option<H160>,
|
||||
rpc_url: Option<&str>,
|
||||
rpc_url: &str,
|
||||
) -> Result<bool, EthError> {
|
||||
use ethabi::{Function, Param, ParamType, StateMutability, Token};
|
||||
|
||||
@@ -1015,7 +1001,7 @@ pub fn fix_da_sol(commitment_bytes: Option<Vec<u8>>, only_kzg: bool) -> Result<S
|
||||
require(checkKzgCommits(encoded), "Invalid KZG commitments");
|
||||
// static call the verifier contract to verify the proof
|
||||
(bool success, bytes memory returndata) = verifier.staticcall(encoded);
|
||||
|
||||
|
||||
if (success) {
|
||||
return abi.decode(returndata, (bool));
|
||||
} else {
|
||||
|
||||
@@ -45,6 +45,7 @@ use halo2curves::serde::SerdeObject;
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use instant::Instant;
|
||||
use itertools::Itertools;
|
||||
use lazy_static::lazy_static;
|
||||
use log::debug;
|
||||
use log::{info, trace, warn};
|
||||
use serde::Serialize;
|
||||
@@ -65,8 +66,6 @@ use thiserror::Error;
|
||||
use tract_onnx::prelude::IntoTensor;
|
||||
use tract_onnx::prelude::Tensor as TractTensor;
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
lazy_static! {
|
||||
#[derive(Debug)]
|
||||
/// The path to the ezkl related data.
|
||||
@@ -138,11 +137,15 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
data,
|
||||
variables,
|
||||
seed,
|
||||
min,
|
||||
max,
|
||||
} => gen_random_data(
|
||||
model.unwrap_or(DEFAULT_MODEL.into()),
|
||||
data.unwrap_or(DEFAULT_DATA.into()),
|
||||
variables,
|
||||
seed,
|
||||
min,
|
||||
max,
|
||||
),
|
||||
Commands::CalibrateSettings {
|
||||
model,
|
||||
@@ -173,7 +176,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
srs_path,
|
||||
} => gen_witness(
|
||||
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
|
||||
data.unwrap_or(DEFAULT_DATA.into()),
|
||||
data.unwrap_or(DataField(DEFAULT_DATA.into())).to_string(),
|
||||
Some(output.unwrap_or(DEFAULT_WITNESS.into())),
|
||||
vk_path,
|
||||
srs_path,
|
||||
@@ -212,8 +215,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
addr_vk,
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
|
||||
Commands::CreateEvmVKArtifact {
|
||||
Commands::CreateEvmVka {
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
@@ -229,7 +231,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
)
|
||||
.await
|
||||
}
|
||||
Commands::CreateEvmDataAttestation {
|
||||
Commands::CreateEvmDa {
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
@@ -434,7 +436,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
)
|
||||
.await
|
||||
}
|
||||
Commands::DeployEvmDataAttestation {
|
||||
Commands::DeployEvmDa {
|
||||
data,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
@@ -841,6 +843,8 @@ pub(crate) fn gen_random_data(
|
||||
data_path: PathBuf,
|
||||
variables: Vec<(String, usize)>,
|
||||
seed: u64,
|
||||
min: Option<f32>,
|
||||
max: Option<f32>,
|
||||
) -> Result<String, EZKLError> {
|
||||
let mut file = std::fs::File::open(&model_path).map_err(|e| {
|
||||
crate::graph::errors::GraphError::ReadWriteFileError(
|
||||
@@ -859,22 +863,32 @@ pub(crate) fn gen_random_data(
|
||||
.collect::<tract_onnx::prelude::TractResult<Vec<_>>>()
|
||||
.map_err(|e| EZKLError::from(e.to_string()))?;
|
||||
|
||||
let min = min.unwrap_or(0.0);
|
||||
let max = max.unwrap_or(1.0);
|
||||
|
||||
/// Generates a random tensor of a given size and type.
|
||||
fn random(
|
||||
sizes: &[usize],
|
||||
datum_type: tract_onnx::prelude::DatumType,
|
||||
seed: u64,
|
||||
min: f32,
|
||||
max: f32,
|
||||
) -> TractTensor {
|
||||
use rand::{Rng, SeedableRng};
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
|
||||
|
||||
let mut tensor = TractTensor::zero::<f32>(sizes).unwrap();
|
||||
let slice = tensor.as_slice_mut::<f32>().unwrap();
|
||||
slice.iter_mut().for_each(|x| *x = rng.r#gen());
|
||||
slice.iter_mut().for_each(|x| *x = rng.gen_range(min..max));
|
||||
tensor.cast_to_dt(datum_type).unwrap().into_owned()
|
||||
}
|
||||
|
||||
fn tensor_for_fact(fact: &tract_onnx::prelude::TypedFact, seed: u64) -> TractTensor {
|
||||
fn tensor_for_fact(
|
||||
fact: &tract_onnx::prelude::TypedFact,
|
||||
seed: u64,
|
||||
min: f32,
|
||||
max: f32,
|
||||
) -> TractTensor {
|
||||
if let Some(value) = &fact.konst {
|
||||
return value.clone().into_tensor();
|
||||
}
|
||||
@@ -885,12 +899,14 @@ pub(crate) fn gen_random_data(
|
||||
.expect("Expected concrete shape, found: {fact:?}"),
|
||||
fact.datum_type,
|
||||
seed,
|
||||
min,
|
||||
max,
|
||||
)
|
||||
}
|
||||
|
||||
let generated = input_facts
|
||||
.iter()
|
||||
.map(|v| tensor_for_fact(v, seed))
|
||||
.map(|v| tensor_for_fact(v, seed, min, max))
|
||||
.collect_vec();
|
||||
|
||||
let data = GraphData::from_tract_data(&generated)?;
|
||||
@@ -1592,7 +1608,7 @@ pub(crate) async fn deploy_da_evm(
|
||||
data: String,
|
||||
settings_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
rpc_url: String,
|
||||
addr_path: PathBuf,
|
||||
runs: usize,
|
||||
private_key: Option<String>,
|
||||
@@ -1601,7 +1617,7 @@ pub(crate) async fn deploy_da_evm(
|
||||
settings_path,
|
||||
data,
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
&rpc_url,
|
||||
runs,
|
||||
private_key.as_deref(),
|
||||
)
|
||||
@@ -1616,7 +1632,7 @@ pub(crate) async fn deploy_da_evm(
|
||||
|
||||
pub(crate) async fn deploy_evm(
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
rpc_url: String,
|
||||
addr_path: PathBuf,
|
||||
runs: usize,
|
||||
private_key: Option<String>,
|
||||
@@ -1629,7 +1645,7 @@ pub(crate) async fn deploy_evm(
|
||||
};
|
||||
let contract_address = deploy_contract_via_solidity(
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
&rpc_url,
|
||||
runs,
|
||||
private_key.as_deref(),
|
||||
contract_name,
|
||||
@@ -1672,7 +1688,7 @@ pub(crate) fn encode_evm_calldata(
|
||||
pub(crate) async fn verify_evm(
|
||||
proof_path: PathBuf,
|
||||
addr_verifier: H160Flag,
|
||||
rpc_url: Option<String>,
|
||||
rpc_url: String,
|
||||
addr_da: Option<H160Flag>,
|
||||
addr_vk: Option<H160Flag>,
|
||||
) -> Result<String, EZKLError> {
|
||||
@@ -1686,7 +1702,7 @@ pub(crate) async fn verify_evm(
|
||||
addr_verifier.into(),
|
||||
addr_da.into(),
|
||||
addr_vk.map(|s| s.into()),
|
||||
rpc_url.as_deref(),
|
||||
&rpc_url,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
@@ -1694,7 +1710,7 @@ pub(crate) async fn verify_evm(
|
||||
proof.clone(),
|
||||
addr_verifier.into(),
|
||||
addr_vk.map(|s| s.into()),
|
||||
rpc_url.as_deref(),
|
||||
&rpc_url,
|
||||
)
|
||||
.await?
|
||||
};
|
||||
@@ -1835,7 +1851,7 @@ pub(crate) async fn setup_test_evm_data(
|
||||
data_path: String,
|
||||
compiled_circuit_path: PathBuf,
|
||||
test_data: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
rpc_url: String,
|
||||
input_source: TestDataSource,
|
||||
output_source: TestDataSource,
|
||||
) -> Result<String, EZKLError> {
|
||||
|
||||
@@ -98,8 +98,6 @@ pub enum GraphError {
|
||||
feature = "ezkl",
|
||||
not(all(target_arch = "wasm32", target_os = "unknown"))
|
||||
))]
|
||||
#[error("[tokio postgres] {0}")]
|
||||
TokioPostgresError(#[from] tokio_postgres::Error),
|
||||
/// Eth error
|
||||
#[cfg(all(
|
||||
feature = "ezkl",
|
||||
@@ -141,7 +139,9 @@ pub enum GraphError {
|
||||
#[error("range check {0} is too large")]
|
||||
RangeCheckTooLarge(usize),
|
||||
///Cannot use on-chain data source as private data
|
||||
#[error("cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm.")]
|
||||
#[error(
|
||||
"cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm."
|
||||
)]
|
||||
OnChainDataSource,
|
||||
/// Missing data source
|
||||
#[error("missing data source")]
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
use super::errors::GraphError;
|
||||
use super::quantize_float;
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
use crate::circuit::InputType;
|
||||
use crate::fieldutils::integer_rep_to_felt;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::graph::postgres::Client;
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::prelude::*;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::types::PyDict;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::io::BufReader;
|
||||
use std::io::BufWriter;
|
||||
@@ -19,13 +17,10 @@ use std::io::Read;
|
||||
use std::panic::UnwindSafe;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_core::{
|
||||
tract_data::{prelude::Tensor as TractTensor, TVec},
|
||||
tract_data::{TVec, prelude::Tensor as TractTensor},
|
||||
value::TValue,
|
||||
};
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_hir::tract_num_traits::ToPrimitive;
|
||||
|
||||
type Decimals = u8;
|
||||
type Call = String;
|
||||
type RPCUrl = String;
|
||||
@@ -201,9 +196,9 @@ impl OnChainSource {
|
||||
data: &FileSource,
|
||||
scales: Vec<crate::Scale>,
|
||||
mut shapes: Vec<Vec<usize>>,
|
||||
rpc: Option<&str>,
|
||||
rpc: &str,
|
||||
) -> Result<Self, GraphError> {
|
||||
use crate::eth::{read_on_chain_inputs, test_on_chain_data, DEFAULT_ANVIL_ENDPOINT};
|
||||
use crate::eth::{read_on_chain_inputs, test_on_chain_data};
|
||||
use log::debug;
|
||||
|
||||
// Set up local anvil instance for reading on-chain data
|
||||
@@ -217,7 +212,7 @@ impl OnChainSource {
|
||||
shapes[idx] = vec![i.len()];
|
||||
}
|
||||
}
|
||||
let used_rpc = rpc.unwrap_or(DEFAULT_ANVIL_ENDPOINT).to_string();
|
||||
let used_rpc = rpc.to_string();
|
||||
|
||||
let call_to_account = test_on_chain_data(client.clone(), data).await?;
|
||||
debug!("Call to account: {:?}", call_to_account);
|
||||
@@ -260,9 +255,6 @@ pub enum DataSource {
|
||||
File(FileSource),
|
||||
/// Data fetched from blockchain contracts
|
||||
OnChain(OnChainSource),
|
||||
/// Data from a PostgreSQL database
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
DB(PostgresSource),
|
||||
}
|
||||
|
||||
impl Default for DataSource {
|
||||
@@ -323,15 +315,6 @@ impl<'de> Deserialize<'de> for DataSource {
|
||||
return Ok(DataSource::OnChain(t));
|
||||
}
|
||||
|
||||
// Try deserializing as PostgresSource if feature enabled
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
{
|
||||
let third_try: Result<PostgresSource, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = third_try {
|
||||
return Ok(DataSource::DB(t));
|
||||
}
|
||||
}
|
||||
|
||||
Err(serde::de::Error::custom("failed to deserialize DataSource"))
|
||||
}
|
||||
}
|
||||
@@ -381,7 +364,7 @@ impl GraphData {
|
||||
return Err(GraphError::InvalidDims(
|
||||
0,
|
||||
"non file data cannot be split into batches".to_string(),
|
||||
))
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(inputs)
|
||||
@@ -434,13 +417,13 @@ impl GraphData {
|
||||
/// Loads graph input data from a string, first seeing if it is a file path or JSON data
|
||||
/// If it is a file path, it will load the data from the file
|
||||
/// Otherwise, it will attempt to parse the string as JSON data
|
||||
///
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `data` - String containing the input data
|
||||
/// # Returns
|
||||
/// A new GraphData instance containing the loaded data
|
||||
pub fn from_str(data: &str) -> Result<Self, GraphError> {
|
||||
let graph_input = serde_json::from_str(data);
|
||||
let graph_input = serde_json::from_str(data);
|
||||
match graph_input {
|
||||
Ok(graph_input) => {
|
||||
return Ok(graph_input);
|
||||
@@ -515,13 +498,8 @@ impl GraphData {
|
||||
return Err(GraphError::InvalidDims(
|
||||
0,
|
||||
"on-chain data cannot be split into batches".to_string(),
|
||||
))
|
||||
));
|
||||
}
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
GraphData {
|
||||
input_data: DataSource::DB(data),
|
||||
output_data: _,
|
||||
} => data.fetch_and_format_as_file().await?,
|
||||
};
|
||||
|
||||
// Process each input tensor according to its shape
|
||||
@@ -538,7 +516,6 @@ impl GraphData {
|
||||
input.len(),
|
||||
input_size
|
||||
),
|
||||
|
||||
));
|
||||
}
|
||||
|
||||
@@ -592,28 +569,6 @@ impl GraphData {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_postgres_source_new() {
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
{
|
||||
let source = PostgresSource::new(
|
||||
"localhost".to_string(),
|
||||
"5432".to_string(),
|
||||
"user".to_string(),
|
||||
"SELECT * FROM table".to_string(),
|
||||
"database".to_string(),
|
||||
"password".to_string(),
|
||||
);
|
||||
|
||||
assert_eq!(source.host, "localhost");
|
||||
assert_eq!(source.port, "5432");
|
||||
assert_eq!(source.user, "user");
|
||||
assert_eq!(source.query, "SELECT * FROM table");
|
||||
assert_eq!(source.dbname, "database");
|
||||
assert_eq!(source.password, "password");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_data_source_serialization_round_trip() {
|
||||
// Test backwards compatibility with old format
|
||||
@@ -656,95 +611,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
/// Source data from a PostgreSQL database
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct PostgresSource {
|
||||
/// Database host address
|
||||
pub host: RPCUrl,
|
||||
/// Database user name
|
||||
pub user: String,
|
||||
/// Database password
|
||||
pub password: String,
|
||||
/// SQL query to execute
|
||||
pub query: String,
|
||||
/// Database name
|
||||
pub dbname: String,
|
||||
/// Database port
|
||||
pub port: String,
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl PostgresSource {
|
||||
/// Creates a new PostgreSQL data source
|
||||
pub fn new(
|
||||
host: RPCUrl,
|
||||
port: String,
|
||||
user: String,
|
||||
query: String,
|
||||
dbname: String,
|
||||
password: String,
|
||||
) -> Self {
|
||||
PostgresSource {
|
||||
host,
|
||||
user,
|
||||
password,
|
||||
query,
|
||||
dbname,
|
||||
port,
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetches data from the PostgreSQL database
|
||||
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
|
||||
// Configuration string
|
||||
let config = if self.password.is_empty() {
|
||||
format!(
|
||||
"host={} user={} dbname={} port={}",
|
||||
self.host, self.user, self.dbname, self.port
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"host={} user={} dbname={} port={} password={}",
|
||||
self.host, self.user, self.dbname, self.port, self.password
|
||||
)
|
||||
};
|
||||
|
||||
let mut client = Client::connect(&config).await?;
|
||||
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
|
||||
|
||||
// Extract rows from query
|
||||
for row in client.query(&self.query, &[]).await? {
|
||||
for i in 0..row.len() {
|
||||
res.push(row.get(i));
|
||||
}
|
||||
}
|
||||
Ok(vec![res])
|
||||
}
|
||||
|
||||
/// Fetches and formats data as FileSource
|
||||
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
|
||||
Ok(self
|
||||
.fetch()
|
||||
.await?
|
||||
.iter()
|
||||
.map(|d| {
|
||||
d.iter()
|
||||
.map(|d| {
|
||||
FileSourceInner::Float(
|
||||
d.n.as_ref()
|
||||
.unwrap()
|
||||
.to_f64()
|
||||
.ok_or("could not convert decimal to f64")
|
||||
.unwrap(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for CallToAccount {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
@@ -768,14 +634,6 @@ impl ToPyObject for DataSource {
|
||||
.unwrap();
|
||||
dict.to_object(py)
|
||||
}
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
DataSource::DB(source) => {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("host", &source.host).unwrap();
|
||||
dict.set_item("user", &source.user).unwrap();
|
||||
dict.set_item("query", &source.query).unwrap();
|
||||
dict.to_object(py)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,9 +6,6 @@ pub mod model;
|
||||
pub mod modules;
|
||||
/// Inner elements of a computational graph that represent a single operation / constraints.
|
||||
pub mod node;
|
||||
/// postgres helper functions
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
pub mod postgres;
|
||||
/// Helper functions
|
||||
pub mod utilities;
|
||||
/// Representations of a computational graph's variables.
|
||||
@@ -36,12 +33,12 @@ use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSize
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::modules::ModulePlanner;
|
||||
use crate::circuit::region::{ConstantsMap, RegionSettings};
|
||||
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
|
||||
use crate::circuit::table::{RESERVED_BLINDING_ROWS_PAD, Range, Table, num_cols_required};
|
||||
use crate::circuit::{CheckMode, InputType};
|
||||
use crate::fieldutils::{felt_to_f64, IntegerRep};
|
||||
use crate::fieldutils::{IntegerRep, felt_to_f64};
|
||||
use crate::pfsys::PrettyElements;
|
||||
use crate::tensor::{Tensor, ValTensor};
|
||||
use crate::{RunArgs, EZKL_BUF_CAPACITY};
|
||||
use crate::{EZKL_BUF_CAPACITY, RunArgs};
|
||||
|
||||
use halo2_proofs::{
|
||||
circuit::Layouter,
|
||||
@@ -56,13 +53,13 @@ use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
|
||||
pub use model::*;
|
||||
pub use node::*;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::prelude::*;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::types::PyDict;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::types::PyDictMethods;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::ops::Deref;
|
||||
@@ -764,7 +761,7 @@ pub struct TestOnChainData {
|
||||
/// The path to the test witness
|
||||
pub data: std::path::PathBuf,
|
||||
/// rpc endpoint
|
||||
pub rpc: Option<String>,
|
||||
pub rpc: String,
|
||||
/// data sources for the on chain data
|
||||
pub data_sources: TestSources,
|
||||
}
|
||||
@@ -1011,10 +1008,6 @@ impl GraphCircuit {
|
||||
DataSource::File(file_data) => {
|
||||
self.load_file_data(file_data, &shapes, scales, input_types)
|
||||
}
|
||||
DataSource::DB(pg) => {
|
||||
let data = pg.fetch_and_format_as_file().await?;
|
||||
self.load_file_data(&data, &shapes, scales, input_types)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1027,7 +1020,7 @@ impl GraphCircuit {
|
||||
scales: Vec<crate::Scale>,
|
||||
) -> Result<Vec<Tensor<Fp>>, GraphError> {
|
||||
use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend};
|
||||
let (client, client_address) = setup_eth_backend(Some(&source.rpc), None).await?;
|
||||
let (client, client_address) = setup_eth_backend(&source.rpc, None).await?;
|
||||
let input = read_on_chain_inputs(client.clone(), client_address, &source.call).await?;
|
||||
let quantized_evm_inputs =
|
||||
evm_quantize(client, scales, &input, &source.call.decimals).await?;
|
||||
@@ -1481,13 +1474,9 @@ impl GraphCircuit {
|
||||
// print file data
|
||||
debug!("file data: {:?}", file_data);
|
||||
|
||||
let on_chain_data: OnChainSource = OnChainSource::test_from_file_data(
|
||||
&file_data,
|
||||
scales,
|
||||
shapes,
|
||||
test_on_chain_data.rpc.as_deref(),
|
||||
)
|
||||
.await?;
|
||||
let on_chain_data: OnChainSource =
|
||||
OnChainSource::test_from_file_data(&file_data, scales, shapes, &test_on_chain_data.rpc)
|
||||
.await?;
|
||||
// Here we update the GraphData struct with the on-chain data
|
||||
if input_data.is_some() {
|
||||
data.input_data = on_chain_data.clone().into();
|
||||
|
||||
@@ -1,493 +0,0 @@
|
||||
use log::{debug, error, info};
|
||||
use std::fmt::Debug;
|
||||
use std::net::IpAddr;
|
||||
#[cfg(all(not(not(feature = "ezkl")), unix))]
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, pin::Pin};
|
||||
use tokio::task::JoinHandle;
|
||||
#[doc(inline)]
|
||||
pub use tokio_postgres::config::{
|
||||
ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs,
|
||||
};
|
||||
use tokio_postgres::tls::NoTlsStream;
|
||||
use tokio_postgres::NoTls;
|
||||
use tokio_postgres::{error::DbError, types::ToSql, Error, Row, Socket, ToStatement};
|
||||
|
||||
/// Connection configuration.
|
||||
///
|
||||
/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
|
||||
///
|
||||
///
|
||||
#[derive(Clone)]
|
||||
pub struct Config {
|
||||
config: tokio_postgres::Config,
|
||||
notice_callback: Arc<dyn Fn(DbError) + Send + Sync>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for Config {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
fmt.debug_struct("Config")
|
||||
.field("config", &self.config)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Config {
|
||||
Config::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Creates a new configuration.
|
||||
pub fn new() -> Config {
|
||||
tokio_postgres::Config::new().into()
|
||||
}
|
||||
|
||||
/// Sets the user to authenticate with.
|
||||
///
|
||||
/// If the user is not set, then this defaults to the user executing this process.
|
||||
pub fn user(&mut self, user: &str) -> &mut Config {
|
||||
self.config.user(user);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the user to authenticate with, if one has been configured with
|
||||
/// the `user` method.
|
||||
pub fn get_user(&self) -> Option<&str> {
|
||||
self.config.get_user()
|
||||
}
|
||||
|
||||
/// Sets the password to authenticate with.
|
||||
pub fn password<T>(&mut self, password: T) -> &mut Config
|
||||
where
|
||||
T: AsRef<[u8]>,
|
||||
{
|
||||
self.config.password(password);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the password to authenticate with, if one has been configured with
|
||||
/// the `password` method.
|
||||
pub fn get_password(&self) -> Option<&[u8]> {
|
||||
self.config.get_password()
|
||||
}
|
||||
|
||||
/// Sets the name of the database to connect to.
|
||||
///
|
||||
/// Defaults to the user.
|
||||
pub fn dbname(&mut self, dbname: &str) -> &mut Config {
|
||||
self.config.dbname(dbname);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the name of the database to connect to, if one has been configured
|
||||
/// with the `dbname` method.
|
||||
pub fn get_dbname(&self) -> Option<&str> {
|
||||
self.config.get_dbname()
|
||||
}
|
||||
|
||||
/// Sets command line options used to configure the server.
|
||||
pub fn options(&mut self, options: &str) -> &mut Config {
|
||||
self.config.options(options);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the command line options used to configure the server, if the
|
||||
/// options have been set with the `options` method.
|
||||
pub fn get_options(&self) -> Option<&str> {
|
||||
self.config.get_options()
|
||||
}
|
||||
|
||||
/// Sets the value of the `application_name` runtime parameter.
|
||||
pub fn application_name(&mut self, application_name: &str) -> &mut Config {
|
||||
self.config.application_name(application_name);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the value of the `application_name` runtime parameter, if it has
|
||||
/// been set with the `application_name` method.
|
||||
pub fn get_application_name(&self) -> Option<&str> {
|
||||
self.config.get_application_name()
|
||||
}
|
||||
|
||||
/// Sets the SSL configuration.
|
||||
///
|
||||
/// Defaults to `prefer`.
|
||||
pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
|
||||
self.config.ssl_mode(ssl_mode);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the SSL configuration.
|
||||
pub fn get_ssl_mode(&self) -> SslMode {
|
||||
self.config.get_ssl_mode()
|
||||
}
|
||||
|
||||
/// Adds a host to the configuration.
|
||||
///
|
||||
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
|
||||
/// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets.
|
||||
/// There must be either no hosts, or the same number of hosts as hostaddrs.
|
||||
pub fn host(&mut self, host: &str) -> &mut Config {
|
||||
self.config.host(host);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the hosts that have been added to the configuration with `host`.
|
||||
pub fn get_hosts(&self) -> &[Host] {
|
||||
self.config.get_hosts()
|
||||
}
|
||||
|
||||
/// Gets the hostaddrs that have been added to the configuration with `hostaddr`.
|
||||
pub fn get_hostaddrs(&self) -> &[IpAddr] {
|
||||
self.config.get_hostaddrs()
|
||||
}
|
||||
|
||||
/// Adds a Unix socket host to the configuration.
|
||||
///
|
||||
/// Unlike `host`, this method allows non-UTF8 paths.
|
||||
#[cfg(all(not(not(feature = "ezkl")), unix))]
|
||||
pub fn host_path<T>(&mut self, host: T) -> &mut Config
|
||||
where
|
||||
T: AsRef<Path>,
|
||||
{
|
||||
self.config.host_path(host);
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds a hostaddr to the configuration.
|
||||
///
|
||||
/// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order.
|
||||
/// There must be either no hostaddrs, or the same number of hostaddrs as hosts.
|
||||
pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
|
||||
self.config.hostaddr(hostaddr);
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds a port to the configuration.
|
||||
///
|
||||
/// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
|
||||
/// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports
|
||||
/// as hosts.
|
||||
pub fn port(&mut self, port: u16) -> &mut Config {
|
||||
self.config.port(port);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the ports that have been added to the configuration with `port`.
|
||||
pub fn get_ports(&self) -> &[u16] {
|
||||
self.config.get_ports()
|
||||
}
|
||||
|
||||
/// Sets the timeout applied to socket-level connection attempts.
|
||||
///
|
||||
/// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
|
||||
/// host separately. Defaults to no limit.
|
||||
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
|
||||
self.config.connect_timeout(connect_timeout);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the connection timeout, if one has been set with the
|
||||
/// `connect_timeout` method.
|
||||
pub fn get_connect_timeout(&self) -> Option<&Duration> {
|
||||
self.config.get_connect_timeout()
|
||||
}
|
||||
|
||||
/// Sets the TCP user timeout.
|
||||
///
|
||||
/// This is ignored for Unix domain socket connections. It is only supported on systems where
|
||||
/// TCP_USER_TIMEOUT is available and will default to the system default if omitted or set to 0;
|
||||
/// on other systems, it has no effect.
|
||||
pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
|
||||
self.config.tcp_user_timeout(tcp_user_timeout);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the TCP user timeout, if one has been set with the
|
||||
/// `user_timeout` method.
|
||||
pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
|
||||
self.config.get_tcp_user_timeout()
|
||||
}
|
||||
|
||||
/// Controls the use of TCP keepalive.
|
||||
///
|
||||
/// This is ignored for Unix domain socket connections. Defaults to `true`.
|
||||
pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
|
||||
self.config.keepalives(keepalives);
|
||||
self
|
||||
}
|
||||
|
||||
/// Reports whether TCP keepalives will be used.
|
||||
pub fn get_keepalives(&self) -> bool {
|
||||
self.config.get_keepalives()
|
||||
}
|
||||
|
||||
/// Sets the amount of idle time before a keepalive packet is sent on the connection.
|
||||
///
|
||||
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours.
|
||||
pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
|
||||
self.config.keepalives_idle(keepalives_idle);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the configured amount of idle time before a keepalive packet will
|
||||
/// be sent on the connection.
|
||||
pub fn get_keepalives_idle(&self) -> Duration {
|
||||
self.config.get_keepalives_idle()
|
||||
}
|
||||
|
||||
/// Sets the time interval between TCP keepalive probes.
|
||||
/// On Windows, this sets the value of the tcp_keepalive struct’s keepaliveinterval field.
|
||||
///
|
||||
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
|
||||
pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
|
||||
self.config.keepalives_interval(keepalives_interval);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the time interval between TCP keepalive probes.
|
||||
pub fn get_keepalives_interval(&self) -> Option<Duration> {
|
||||
self.config.get_keepalives_interval()
|
||||
}
|
||||
|
||||
/// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
|
||||
///
|
||||
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
|
||||
pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
|
||||
self.config.keepalives_retries(keepalives_retries);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
|
||||
pub fn get_keepalives_retries(&self) -> Option<u32> {
|
||||
self.config.get_keepalives_retries()
|
||||
}
|
||||
|
||||
/// Sets the requirements of the session.
|
||||
///
|
||||
/// This can be used to connect to the primary server in a clustered database rather than one of the read-only
|
||||
/// secondary servers. Defaults to `Any`.
|
||||
pub fn target_session_attrs(
|
||||
&mut self,
|
||||
target_session_attrs: TargetSessionAttrs,
|
||||
) -> &mut Config {
|
||||
self.config.target_session_attrs(target_session_attrs);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the requirements of the session.
|
||||
pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
|
||||
self.config.get_target_session_attrs()
|
||||
}
|
||||
|
||||
/// Sets the channel binding behavior.
|
||||
///
|
||||
/// Defaults to `prefer`.
|
||||
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
|
||||
self.config.channel_binding(channel_binding);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the channel binding behavior.
|
||||
pub fn get_channel_binding(&self) -> ChannelBinding {
|
||||
self.config.get_channel_binding()
|
||||
}
|
||||
|
||||
/// Sets the host load balancing behavior.
|
||||
///
|
||||
/// Defaults to `disable`.
|
||||
pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
|
||||
self.config.load_balance_hosts(load_balance_hosts);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the host load balancing behavior.
|
||||
pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
|
||||
self.config.get_load_balance_hosts()
|
||||
}
|
||||
|
||||
/// Sets the notice callback.
|
||||
///
|
||||
/// This callback will be invoked with the contents of every
|
||||
/// [`AsyncMessage::Notice`] that is received by the connection. Notices use
|
||||
/// the same structure as errors, but they are not "errors" per-se.
|
||||
///
|
||||
/// Notices are distinct from notifications, which are instead accessible
|
||||
/// via the [`Notifications`] API.
|
||||
///
|
||||
/// [`AsyncMessage::Notice`]: tokio_postgres::AsyncMessage::Notice
|
||||
/// [`Notifications`]: crate::Notifications
|
||||
pub fn notice_callback<F>(&mut self, f: F) -> &mut Config
|
||||
where
|
||||
F: Fn(DbError) + Send + Sync + 'static,
|
||||
{
|
||||
self.notice_callback = Arc::new(f);
|
||||
self
|
||||
}
|
||||
|
||||
/// Opens a connection to a PostgreSQL database.
|
||||
pub async fn connect(&self) -> Result<Client, Error> {
|
||||
let (client, connection) = self.config.connect(NoTls).await?;
|
||||
|
||||
let connection = Connection::new(connection);
|
||||
|
||||
Ok(Client::new(client, connection))
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Config {
|
||||
type Err = Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Config, Error> {
|
||||
s.parse::<tokio_postgres::Config>().map(Config::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tokio_postgres::Config> for Config {
|
||||
fn from(config: tokio_postgres::Config) -> Config {
|
||||
Config {
|
||||
config,
|
||||
notice_callback: Arc::new(|notice| {
|
||||
info!("{}: {}", notice.severity(), notice.message())
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_debug_implementations, dead_code)]
|
||||
/// An asynchronous PostgreSQL connection. We use this to keep the connection alive / keep it pinned so that it doesn't
|
||||
/// get dropped.
|
||||
pub struct Connection {
|
||||
/// The underlying connection stream.
|
||||
connection: Pin<Box<tokio_postgres::Connection<Socket, NoTlsStream>>>,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
/// Creates a new connection.
|
||||
pub fn new(connection: tokio_postgres::Connection<Socket, NoTlsStream>) -> Self {
|
||||
Connection {
|
||||
connection: Box::pin(connection),
|
||||
}
|
||||
}
|
||||
|
||||
/// start the connection
|
||||
pub async fn start(self) {
|
||||
if let Err(e) = self.connection.await {
|
||||
error!("connection error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_debug_implementations, dead_code)]
|
||||
/// An asynchronous PostgreSQL client.
|
||||
pub struct Client {
|
||||
connection: JoinHandle<()>,
|
||||
client: tokio_postgres::Client,
|
||||
}
|
||||
|
||||
impl Drop for Client {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.close_inner();
|
||||
}
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub(crate) fn new(client: tokio_postgres::Client, connection: Connection) -> Client {
|
||||
// The connection object performs the actual communication with the database,
|
||||
// so spawn it off to run on its own.
|
||||
let thread = tokio::spawn(async move {
|
||||
connection.start().await;
|
||||
});
|
||||
|
||||
Client {
|
||||
client,
|
||||
connection: thread,
|
||||
}
|
||||
}
|
||||
|
||||
/// A convenience function which parses a configuration string into a `Config` and then connects to the database.
|
||||
///
|
||||
/// See the documentation for [`Config`] for information about the connection syntax.
|
||||
///
|
||||
/// [`Config`]: config/struct.Config.html
|
||||
pub async fn connect(params: &str) -> Result<Client, Error> {
|
||||
debug!("Connecting to database with params: {}", params);
|
||||
params.parse::<Config>()?.connect().await
|
||||
}
|
||||
|
||||
/// Returns a new `Config` object which can be used to configure and connect to a database.
|
||||
pub fn configure() -> Config {
|
||||
Config::new()
|
||||
}
|
||||
|
||||
/// Executes a statement, returning the number of rows modified.
|
||||
///
|
||||
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
|
||||
/// provided, 1-indexed.
|
||||
///
|
||||
/// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
|
||||
///
|
||||
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
|
||||
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
|
||||
/// with the `prepare` method.
|
||||
///
|
||||
pub async fn execute<T>(
|
||||
&mut self,
|
||||
query: &T,
|
||||
params: &[&(dyn ToSql + Sync)],
|
||||
) -> Result<u64, Error>
|
||||
where
|
||||
T: ?Sized + ToStatement + Debug,
|
||||
{
|
||||
debug!("Executing query: {:?}", query);
|
||||
self.client.execute(query, params).await
|
||||
}
|
||||
|
||||
/// Executes a statement, returning the resulting rows.
|
||||
///
|
||||
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
|
||||
/// provided, 1-indexed.
|
||||
///
|
||||
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
|
||||
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
|
||||
/// with the `prepare` method.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
pub async fn query<T>(
|
||||
&mut self,
|
||||
query: &T,
|
||||
params: &[&(dyn ToSql + Sync)],
|
||||
) -> Result<Vec<Row>, Error>
|
||||
where
|
||||
T: ?Sized + ToStatement + Debug,
|
||||
{
|
||||
debug!("Executing query: {:?}", query);
|
||||
self.client.query(query, params).await
|
||||
}
|
||||
|
||||
/// Determines if the client's connection has already closed.
|
||||
///
|
||||
/// If this returns `true`, the client is no longer usable.
|
||||
pub fn is_closed(&self) -> bool {
|
||||
self.client.is_closed()
|
||||
}
|
||||
|
||||
/// Closes the client's connection to the server.
|
||||
///
|
||||
/// This is equivalent to `Client`'s `Drop` implementation, except that it returns any error encountered to the
|
||||
/// caller.
|
||||
pub fn close(mut self) -> Result<(), Error> {
|
||||
self.close_inner()
|
||||
}
|
||||
|
||||
fn close_inner(&mut self) -> Result<(), Error> {
|
||||
self.client.__private_api_close();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -858,6 +858,7 @@ pub fn new_op_from_onnx(
|
||||
SupportedOp::Hybrid(HybridOp::Recip {
|
||||
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
|
||||
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
|
||||
eps: run_args.get_epsilon(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -903,6 +904,7 @@ pub fn new_op_from_onnx(
|
||||
SupportedOp::Hybrid(HybridOp::Rsqrt {
|
||||
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
|
||||
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
|
||||
eps: run_args.get_epsilon(),
|
||||
})
|
||||
}
|
||||
"Exp" => SupportedOp::Nonlinear(LookupOp::Exp {
|
||||
@@ -913,6 +915,7 @@ pub fn new_op_from_onnx(
|
||||
if run_args.bounded_log_lookup {
|
||||
SupportedOp::Hybrid(HybridOp::Ln {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
eps: run_args.get_epsilon(),
|
||||
})
|
||||
} else {
|
||||
SupportedOp::Nonlinear(LookupOp::Ln {
|
||||
@@ -1131,6 +1134,7 @@ pub fn new_op_from_onnx(
|
||||
input_scale: scale_to_multiplier(in_scale).into(),
|
||||
output_scale: scale_to_multiplier(max_scale).into(),
|
||||
axes: softmax_op.axes.to_vec(),
|
||||
eps: run_args.get_epsilon(),
|
||||
})
|
||||
}
|
||||
"MaxPool" => {
|
||||
|
||||
15
src/lib.rs
15
src/lib.rs
@@ -97,11 +97,11 @@ impl From<String> for EZKLError {
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use circuit::{table::Range, CheckMode};
|
||||
use circuit::{CheckMode, table::Range};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use clap::Args;
|
||||
use fieldutils::IntegerRep;
|
||||
use graph::{Visibility, MAX_PUBLIC_SRS};
|
||||
use graph::{MAX_PUBLIC_SRS, Visibility};
|
||||
use halo2_proofs::poly::{
|
||||
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
|
||||
};
|
||||
@@ -350,6 +350,16 @@ pub struct RunArgs {
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
pub ignore_range_check_inputs_outputs: bool,
|
||||
/// Optional override for epsilon value
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long))]
|
||||
pub epsilon: Option<f64>,
|
||||
}
|
||||
|
||||
impl RunArgs {
|
||||
/// Returns the epsilon value
|
||||
pub fn get_epsilon(&self) -> f64 {
|
||||
self.epsilon.unwrap_or(f64::EPSILON)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RunArgs {
|
||||
@@ -376,6 +386,7 @@ impl Default for RunArgs {
|
||||
decomp_base: 16384,
|
||||
decomp_legs: 2,
|
||||
ignore_range_check_inputs_outputs: false,
|
||||
epsilon: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,7 +160,7 @@ pub fn decompose(
|
||||
///
|
||||
/// let result = trilu(&a, 0, false).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 3, 4, 5, 6]), &[1, 3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, -1, true).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 0, 6]), &[1, 3, 2]).unwrap();
|
||||
@@ -168,7 +168,7 @@ pub fn decompose(
|
||||
///
|
||||
/// let result = trilu(&a, -1, false).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 3, 0, 5, 6]), &[1, 3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let a = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
@@ -188,7 +188,7 @@ pub fn decompose(
|
||||
///
|
||||
/// let result = trilu(&a, 0, false).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 4, 5, 0]), &[1, 2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, -1, true).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[1, 2, 3]).unwrap();
|
||||
@@ -196,7 +196,7 @@ pub fn decompose(
|
||||
///
|
||||
/// let result = trilu(&a, -1, false).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 4, 0, 0]), &[1, 2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let a = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9]),
|
||||
@@ -216,7 +216,7 @@ pub fn decompose(
|
||||
///
|
||||
/// let result = trilu(&a, 0, false).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 4, 5, 0, 7, 8, 9]), &[1, 3, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, -1, true).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 8, 9]), &[1, 3, 3]).unwrap();
|
||||
@@ -224,7 +224,7 @@ pub fn decompose(
|
||||
///
|
||||
/// let result = trilu(&a, -1, false).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 4, 0, 0, 7, 8, 0]), &[1, 3, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn trilu<T: TensorType + std::marker::Send + std::marker::Sync>(
|
||||
a: &Tensor<T>,
|
||||
@@ -1859,14 +1859,14 @@ pub mod nonlinearities {
|
||||
/// Some(&[4, 25, 8, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = rsqrt(&x, 1.0);
|
||||
/// let result = rsqrt(&x, 1.0, f64::EPSILON);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 1, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn rsqrt(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
|
||||
pub fn rsqrt(a: &Tensor<IntegerRep>, scale_input: f64, eps: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale_input;
|
||||
let fout = scale_input / (kix.sqrt() + f64::EPSILON);
|
||||
let fout = scale_input / (kix.sqrt() + eps);
|
||||
let rounded = fout.round();
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
@@ -2339,15 +2339,20 @@ pub mod nonlinearities {
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2_f64;
|
||||
/// let result = recip(&x, 1.0, k);
|
||||
/// let result = recip(&x, 1.0, k, f64::EPSILON);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 1, 0, 2, 2]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn recip(a: &Tensor<IntegerRep>, input_scale: f64, out_scale: f64) -> Tensor<IntegerRep> {
|
||||
pub fn recip(
|
||||
a: &Tensor<IntegerRep>,
|
||||
input_scale: f64,
|
||||
out_scale: f64,
|
||||
eps: f64,
|
||||
) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let rescaled = (a_i as f64) / input_scale;
|
||||
let denom = if rescaled == 0_f64 {
|
||||
(1_f64) / (rescaled + f64::EPSILON)
|
||||
(1_f64) / (rescaled + eps)
|
||||
} else {
|
||||
(1_f64) / (rescaled)
|
||||
};
|
||||
@@ -2366,16 +2371,16 @@ pub mod nonlinearities {
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::zero_recip;
|
||||
/// let k = 2_f64;
|
||||
/// let result = zero_recip(1.0);
|
||||
/// let result = zero_recip(1.0, f64::EPSILON);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4503599627370496]), &[1]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn zero_recip(out_scale: f64) -> Tensor<IntegerRep> {
|
||||
pub fn zero_recip(out_scale: f64, eps: f64) -> Tensor<IntegerRep> {
|
||||
let a = Tensor::<IntegerRep>::new(Some(&[0]), &[1]).unwrap();
|
||||
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let rescaled = a_i as f64;
|
||||
let denom = (1_f64) / (rescaled + f64::EPSILON);
|
||||
let denom = (1_f64) / (rescaled + eps);
|
||||
let d_inv_x = out_scale * denom;
|
||||
Ok::<_, TensorError>(d_inv_x.round() as IntegerRep)
|
||||
})
|
||||
|
||||
Binary file not shown.
@@ -272,16 +272,6 @@ mod py_tests {
|
||||
anvil_child.kill().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn postgres_notebook_() {
|
||||
crate::py_tests::init_binary();
|
||||
let test_dir: TempDir = TempDir::new("mean_postgres").unwrap();
|
||||
let path = test_dir.path().to_str().unwrap();
|
||||
crate::py_tests::mv_test_(path, "mean_postgres.ipynb");
|
||||
run_notebook(path, "mean_postgres.ipynb");
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tictactoe_autoencoder_notebook_() {
|
||||
crate::py_tests::init_binary();
|
||||
|
||||
@@ -479,15 +479,15 @@ async def test_deploy_evm_reusable_and_vka():
|
||||
|
||||
res = await ezkl.deploy_evm(
|
||||
addr_path_verifier,
|
||||
sol_code_path,
|
||||
anvil_url,
|
||||
sol_code_path,
|
||||
"verifier/reusable",
|
||||
)
|
||||
|
||||
res = await ezkl.deploy_evm(
|
||||
addr_path_vk,
|
||||
vk_code_path,
|
||||
anvil_url,
|
||||
vk_code_path,
|
||||
"vka",
|
||||
)
|
||||
|
||||
@@ -506,8 +506,8 @@ async def test_deploy_evm():
|
||||
|
||||
res = await ezkl.deploy_evm(
|
||||
addr_path,
|
||||
sol_code_path,
|
||||
anvil_url,
|
||||
sol_code_path,
|
||||
)
|
||||
|
||||
assert res == True
|
||||
@@ -528,8 +528,8 @@ async def test_deploy_evm_with_private_key():
|
||||
|
||||
res = await ezkl.deploy_evm(
|
||||
addr_path,
|
||||
anvil_url,
|
||||
sol_code_path,
|
||||
rpc_url=anvil_url,
|
||||
private_key=anvil_default_private_key
|
||||
)
|
||||
|
||||
@@ -540,8 +540,8 @@ async def test_deploy_evm_with_private_key():
|
||||
with pytest.raises(RuntimeError, match="Failed to run deploy_evm"):
|
||||
res = await ezkl.deploy_evm(
|
||||
addr_path,
|
||||
anvil_url,
|
||||
sol_code_path,
|
||||
rpc_url=anvil_url,
|
||||
private_key=custom_zero_balance_private_key
|
||||
)
|
||||
|
||||
@@ -564,8 +564,8 @@ async def test_verify_evm():
|
||||
|
||||
res = await ezkl.verify_evm(
|
||||
addr,
|
||||
anvil_url,
|
||||
proof_path,
|
||||
rpc_url=anvil_url,
|
||||
# sol_code_path
|
||||
# optimizer_runs
|
||||
)
|
||||
@@ -604,8 +604,8 @@ async def test_verify_evm_separate_vk():
|
||||
|
||||
res = await ezkl.verify_evm(
|
||||
addr_verifier,
|
||||
anvil_url,
|
||||
proof_path,
|
||||
rpc_url=anvil_url,
|
||||
addr_vk=addr_vk,
|
||||
# sol_code_path
|
||||
# optimizer_runs
|
||||
@@ -831,8 +831,8 @@ async def test_evm_aggregate_and_verify_aggr():
|
||||
|
||||
res = await ezkl.deploy_evm(
|
||||
addr_path,
|
||||
anvil_url,
|
||||
sol_code_path,
|
||||
rpc_url=anvil_url,
|
||||
)
|
||||
|
||||
# as a sanity check
|
||||
|
||||
Reference in New Issue
Block a user