Compare commits

...

14 Commits

Author SHA1 Message Date
github-actions[bot]
1066b1ebcf ci: update version string in docs 2025-06-21 17:24:45 +00:00
dante
e81d93a73a chore: display calibration fail reasons (#983) 2025-06-21 19:24:30 +02:00
dante
40ce9dfde9 chore: rm lots of clones (#980) 2025-05-26 10:54:09 -04:00
dante
839030ce10 chore: rm halo2proofs patches (#976) 2025-04-29 10:58:35 -04:00
dante
cfccc5460c refactor: rm postgres (#977) 2025-04-29 08:59:14 -04:00
dante
0de0682bfa refactor: configurable div epsilon (#968) 2025-04-23 09:12:24 +01:00
dante
bf9cf14ab7 refactor!: rpc url should be required (#965)
BREAKING CHANGE: in python the order of arguments for evm related functions has changed
2025-04-22 12:45:36 +01:00
dante
6818962ac2 chore: pass in raw data for gen-witness from file (#964) 2025-04-06 14:08:11 -04:00
dante
70469e3bf9 chore: add min/max to gen-random-data (#960) 2025-03-25 19:32:15 +00:00
dante
52ff187e55 refactor: command struct names should match str (#959) 2025-03-24 12:54:43 +00:00
dante
4e57a5a486 docs: link to audit (#958)
---------

Co-authored-by: Jason Morton <jason.morton@gmail.com>
2025-03-23 21:12:44 +00:00
Ethan Cemer
fe978caa85 fix!: bug fixes (#956)
BREAKING CHANGE: DA verifier no longer backwards compatible
2025-03-18 22:08:29 +00:00
dante
1bef92407c fix: recip denom epsilon can induce non opt res (#957) 2025-03-17 14:46:33 +00:00
dante
5ff1c48ede refactor: allow for negative stride downsample (#955) 2025-03-14 12:07:37 -04:00
77 changed files with 42076 additions and 5475 deletions
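The breaking change in #965 makes the RPC URL a required positional argument that moves forward in the EVM-related Python bindings, as the notebook diffs below reflect. A minimal notebook-style sketch of the new call order, assuming the usual tutorial path variables (`addr_path_verifier`, `sol_code_path`, `proof_path`, `addr`, `addr_da`) are already defined:

```python
# Hedged sketch of the post-#965 calling convention; the variable names follow
# the notebook diffs below and are assumptions, not a definitive signature.
import ezkl

RPC_URL = "http://127.0.0.1:3030"

# The RPC URL now comes right after the address path, ahead of the Solidity code path:
res = await ezkl.deploy_evm(addr_path_verifier, RPC_URL, sol_code_path)
assert res == True

# Same reordering for verification: the RPC URL now precedes the proof path.
res = await ezkl.verify_evm(addr, RPC_URL, proof_path, addr_da)
```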

View File

@@ -258,7 +258,7 @@ jobs:
- name: Install built wheel
if: matrix.target == 'x86_64-unknown-linux-musl'
uses: addnab/docker-run-action@v3
uses: addnab/docker-run-action@3e77f186b7a929ef010f183a9e24c0f9955ea609
with:
image: alpine:latest
options: -v ${{ github.workspace }}:/io -w /io
@@ -380,7 +380,7 @@ jobs:
with:
persist-credentials: false
- name: Trigger RTDs build
uses: dfm/rtds-action@v1
uses: dfm/rtds-action@618148c547f4b56cdf4fa4dcf3a94c91ce025f2d
with:
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}

View File

@@ -198,15 +198,15 @@ jobs:
- name: Build release binary (no asm or metal)
if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features mimalloc
- name: Build release binary (asm)
if: matrix.build == 'linux-gnu'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm,mimalloc
- name: Build release binary (metal)
if: matrix.build == 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal,mimalloc
- name: Strip release binary
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'

View File

@@ -33,7 +33,7 @@ jobs:
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3
with:
crate: cargo-nextest
locked: true
@@ -233,7 +233,7 @@ jobs:
with:
# Pin to version 0.12.1
version: "v0.12.1"
- uses: nanasess/setup-chromedriver@e93e57b843c0c92788f22483f1a31af8ee48db25 #v2.3.0
- uses: nanasess/setup-chromedriver@affb1ea8848cbb080be372c1e8d7a5c173e9298f #v2.3.0
# with:
# chromedriver-version: "115.0.5790.102"
- name: Install wasm32-unknown-unknown
@@ -353,9 +353,6 @@ jobs:
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" in-browser-evm-verifier/nodejs/ezkl.js
- name: Build @ezkljs/verify package
run: |
cd in-browser-evm-verifier
@@ -478,9 +475,6 @@ jobs:
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
@@ -725,24 +719,6 @@ jobs:
permissions:
contents: read
runs-on: large-self-hosted
services:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres
env:
POSTGRES_USER: ubuntu
POSTGRES_HOST_AUTH_METHOD: trust
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
-v /var/run/postgresql:/var/run/postgresql
ports:
# Maps tcp port 5432 on service container to the host
- 5432:5432
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
@@ -779,8 +755,6 @@ jobs:
run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
- name: Felt conversion
run: source .env/bin/activate; cargo nextest run py_tests::tests::felt_conversion_test_ --no-capture
- name: Postgres tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
# - name: authenticate-kaggle-cli
@@ -875,4 +849,4 @@ jobs:
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-parallel-testing-enabled NO \
-resultBundlePath ../../exampleTestResults \
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder

.gitignore (vendored, 3 lines changed)
View File

@@ -9,6 +9,7 @@ pkg
!AttestData.sol
!VerifierBase.sol
!LoadInstances.sol
!AttestData.t.sol
*.pf
*.vk
*.pk
@@ -49,3 +50,5 @@ timingData.json
!tests/assets/vk.key
docs/python/build
!tests/assets/vk_aggr.key
cache
out

Cargo.lock (generated, 2589 lines changed)

File diff suppressed because it is too large

View File

@@ -3,7 +3,7 @@ cargo-features = ["profile-rustflags"]
[package]
name = "ezkl"
version = "0.0.0"
edition = "2024"
edition = "2021"
default-run = "ezkl"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -69,20 +69,18 @@ reqwest = { version = "0.12.4", default-features = false, features = [
"stream",
], optional = true }
openssl = { version = "0.10.55", features = ["vendored"], optional = true }
tokio-postgres = { version = "0.7.10", optional = true }
pg_bigdecimal = { version = "0.1.5", optional = true }
lazy_static = { version = "1.4.0", optional = true }
colored_json = { version = "3.0.1", default-features = false, optional = true }
tokio = { version = "1.35.0", default-features = false, features = [
"macros",
"rt-multi-thread",
], optional = true }
pyo3 = { version = "0.23.2", features = [
pyo3 = { version = "0.24.2", features = [
"extension-module",
"abi3-py37",
"macros",
], default-features = false, optional = true }
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.23.0", features = [
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.24.0", features = [
"attributes",
"tokio-runtime",
], default-features = false, optional = true }
@@ -90,9 +88,9 @@ pyo3-log = { version = "0.12.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
objc = { version = "0.2.4", optional = true }
mimalloc = { version = "0.1", optional = true }
pyo3-stub-gen = { version = "0.6.0", optional = true }
jemallocator = { version = "0.5", optional = true }
mimalloc = { version = "0.1", optional = true }
# universal bindings
uniffi = { version = "=0.28.0", optional = true }
getrandom = { version = "0.2.8", optional = true }
@@ -242,12 +240,9 @@ ezkl = [
"dep:indicatif",
"dep:gag",
"dep:reqwest",
"dep:tokio-postgres",
"dep:pg_bigdecimal",
"dep:lazy_static",
"dep:tokio",
"dep:openssl",
"dep:mimalloc",
"dep:chrono",
"dep:sha256",
"dep:clap_complete",
@@ -274,22 +269,19 @@ no-banner = []
no-update = []
macos-metal = ["halo2_proofs/macos"]
ios-metal = ["halo2_proofs/ios"]
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }
[patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }
jemalloc = ["dep:jemallocator"]
mimalloc = ["dep:mimalloc"]
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
[profile.release]
# debug = true
rustflags = ["-C", "relocation-model=pic"]
lto = "fat"
codegen-units = 1
#panic = "abort"
# panic = "abort"
[profile.test-runs]
@@ -297,8 +289,4 @@ inherits = "dev"
opt-level = 3
[package.metadata.wasm-pack.profile.release]
wasm-opt = [
"-O4",
"--flexible-inline-max-function-size",
"4294967295",
]
wasm-opt = ["-O4", "--flexible-inline-max-function-size", "4294967295"]

View File

@@ -43,7 +43,7 @@ The generated proofs can then be verified with much less computational resources
----------------------
### getting started ⚙️
### Getting Started ⚙️
The easiest way to get started is to try out a notebook.
@@ -76,12 +76,12 @@ For more details visit the [docs](https://docs.ezkl.xyz). The CLI is faster than
Build the auto-generated rust documentation and open the docs in your browser locally. `cargo doc --open`
#### In-browser EVM verifier
#### In-browser EVM Verifier
As an alternative to running the native Halo2 verifier as a WASM binding in the browser, you can use the in-browser EVM verifier. Its source code lives in the `in-browser-evm-verifier` directory, along with a README with instructions on how to use it.
### building the project 🔨
### Building the Project 🔨
#### Rust CLI
@@ -96,7 +96,7 @@ cargo install --locked --path .
#### building python bindings
#### Building Python Bindings
Python bindings exist and can be built using `maturin`. You will need `rust` and `cargo` to be installed.
```bash
@@ -126,7 +126,7 @@ unset ENABLE_ICICLE_GPU
**NOTE:** Even with the above environment variable set, icicle is disabled for circuits where k <= 8. To change the value of `k` where icicle is enabled, you can set the environment variable `ICICLE_SMALL_K`.
### contributing 🌎
### Contributing 🌎
If you're interested in contributing and are unsure where to start, reach out to one of the maintainers:
@@ -144,20 +144,21 @@ More broadly:
Any contribution intentionally submitted for inclusion in the work by you shall be licensed to Zkonduit Inc. under the terms and conditions specified in the [CLA](https://github.com/zkonduit/ezkl/blob/main/cla.md), which you agree to by intentionally submitting a contribution. In particular, you have the right to submit the contribution and we can distribute it, among other terms and conditions.
### no security guarantees
Ezkl is unaudited, beta software undergoing rapid development. There may be bugs. No guarantees of security are made and it should not be relied on in production.
### Audits & Security
> NOTE: Because operations are quantized when they are converted from an onnx file to a zk-circuit, outputs in python and ezkl may differ slightly.
[v21.0.0](https://github.com/zkonduit/ezkl/releases/tag/v21.0.0) has been audited by Trail of Bits, the report can be found [here](https://github.com/trailofbits/publications/blob/master/reviews/2025-03-zkonduit-ezkl-securityreview.pdf).
> NOTE: Because operations are quantized when they are converted from an onnx file to a zk-circuit, outputs in python and ezkl may differ slightly.
### Advanced security topics
Check out `docs/advanced_security` for more advanced information on potential threat vectors.
Check out `docs/advanced_security` for more advanced information on potential threat vectors that are specific to zero-knowledge inference, quantization, and to machine learning models generally.
### No Warranty
### no warranty
Copyright (c) 2024 Zkonduit Inc. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
Copyright (c) 2025 Zkonduit Inc.

abis/DataAttestation.json (new file, 312 lines)
View File

@@ -0,0 +1,312 @@
[
{
"inputs": [
{
"internalType": "address",
"name": "_contractAddresses",
"type": "address"
},
{
"internalType": "bytes",
"name": "_callData",
"type": "bytes"
},
{
"internalType": "uint256[]",
"name": "_decimals",
"type": "uint256[]"
},
{
"internalType": "uint256[]",
"name": "_bits",
"type": "uint256[]"
},
{
"internalType": "uint8",
"name": "_instanceOffset",
"type": "uint8"
}
],
"stateMutability": "nonpayable",
"type": "constructor"
},
{
"inputs": [],
"name": "HALF_ORDER",
"outputs": [
{
"internalType": "uint256",
"name": "",
"type": "uint256"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "ORDER",
"outputs": [
{
"internalType": "uint256",
"name": "",
"type": "uint256"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "uint256[]",
"name": "instances",
"type": "uint256[]"
}
],
"name": "attestData",
"outputs": [],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "callData",
"outputs": [
{
"internalType": "bytes",
"name": "",
"type": "bytes"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "contractAddress",
"outputs": [
{
"internalType": "address",
"name": "",
"type": "address"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "bytes",
"name": "encoded",
"type": "bytes"
}
],
"name": "getInstancesCalldata",
"outputs": [
{
"internalType": "uint256[]",
"name": "instances",
"type": "uint256[]"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "bytes",
"name": "encoded",
"type": "bytes"
}
],
"name": "getInstancesMemory",
"outputs": [
{
"internalType": "uint256[]",
"name": "instances",
"type": "uint256[]"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "uint256",
"name": "index",
"type": "uint256"
}
],
"name": "getScalars",
"outputs": [
{
"components": [
{
"internalType": "uint256",
"name": "decimals",
"type": "uint256"
},
{
"internalType": "uint256",
"name": "bits",
"type": "uint256"
}
],
"internalType": "struct DataAttestation.Scalars",
"name": "",
"type": "tuple"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "instanceOffset",
"outputs": [
{
"internalType": "uint8",
"name": "",
"type": "uint8"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "uint256",
"name": "x",
"type": "uint256"
},
{
"internalType": "uint256",
"name": "y",
"type": "uint256"
},
{
"internalType": "uint256",
"name": "denominator",
"type": "uint256"
}
],
"name": "mulDiv",
"outputs": [
{
"internalType": "uint256",
"name": "result",
"type": "uint256"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "int256",
"name": "x",
"type": "int256"
},
{
"components": [
{
"internalType": "uint256",
"name": "decimals",
"type": "uint256"
},
{
"internalType": "uint256",
"name": "bits",
"type": "uint256"
}
],
"internalType": "struct DataAttestation.Scalars",
"name": "_scalars",
"type": "tuple"
}
],
"name": "quantizeData",
"outputs": [
{
"internalType": "int256",
"name": "quantized_data",
"type": "int256"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "target",
"type": "address"
},
{
"internalType": "bytes",
"name": "data",
"type": "bytes"
}
],
"name": "staticCall",
"outputs": [
{
"internalType": "bytes",
"name": "",
"type": "bytes"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "int256",
"name": "x",
"type": "int256"
}
],
"name": "toFieldElement",
"outputs": [
{
"internalType": "uint256",
"name": "field_element",
"type": "uint256"
}
],
"stateMutability": "pure",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "verifier",
"type": "address"
},
{
"internalType": "bytes",
"name": "encoded",
"type": "bytes"
}
],
"name": "verifyWithDataAttestation",
"outputs": [
{
"internalType": "bool",
"name": "",
"type": "bool"
}
],
"stateMutability": "view",
"type": "function"
}
]
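Several helpers in this ABI (`getScalars`, `quantizeData`, `attestData`, `mulDiv`) are now externally callable. A hedged web3.py sketch of reading them from a deployed instance, mirroring the `w3.eth.contract` pattern used in the notebooks further down; the address and instances values are placeholders:

```python
# Hedged sketch, not an official example: exercising the DataAttestation ABI
# above with web3.py. The address and `instances` below are placeholders.
import json
from web3 import Web3, HTTPProvider

w3 = Web3(HTTPProvider("http://127.0.0.1:3030"))
abi = json.load(open("abis/DataAttestation.json"))
da = w3.eth.contract(address="0x0000000000000000000000000000000000000000", abi=abi)

instances = [0]  # placeholder: the public instances from your proof file
decimals, bits = da.functions.getScalars(0).call()  # Scalars struct as a (decimals, bits) tuple
da.functions.attestData(instances).call()           # view call; reverts if instances don't match
```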

View File

@@ -1,167 +0,0 @@
[
{
"inputs": [
{
"internalType": "address[]",
"name": "_contractAddresses",
"type": "address[]"
},
{
"internalType": "bytes[][]",
"name": "_callData",
"type": "bytes[][]"
},
{
"internalType": "uint256[][]",
"name": "_decimals",
"type": "uint256[][]"
},
{
"internalType": "uint256[]",
"name": "_scales",
"type": "uint256[]"
},
{
"internalType": "uint8",
"name": "_instanceOffset",
"type": "uint8"
},
{
"internalType": "address",
"name": "_admin",
"type": "address"
}
],
"stateMutability": "nonpayable",
"type": "constructor"
},
{
"inputs": [
{
"internalType": "uint256",
"name": "",
"type": "uint256"
}
],
"name": "accountCalls",
"outputs": [
{
"internalType": "address",
"name": "contractAddress",
"type": "address"
},
{
"internalType": "uint256",
"name": "callCount",
"type": "uint256"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "admin",
"outputs": [
{
"internalType": "address",
"name": "",
"type": "address"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "instanceOffset",
"outputs": [
{
"internalType": "uint8",
"name": "",
"type": "uint8"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "uint256",
"name": "",
"type": "uint256"
}
],
"name": "scales",
"outputs": [
{
"internalType": "uint256",
"name": "",
"type": "uint256"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "address[]",
"name": "_contractAddresses",
"type": "address[]"
},
{
"internalType": "bytes[][]",
"name": "_callData",
"type": "bytes[][]"
},
{
"internalType": "uint256[][]",
"name": "_decimals",
"type": "uint256[][]"
}
],
"name": "updateAccountCalls",
"outputs": [],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "_admin",
"type": "address"
}
],
"name": "updateAdmin",
"outputs": [],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "verifier",
"type": "address"
},
{
"internalType": "bytes",
"name": "encoded",
"type": "bytes"
}
],
"name": "verifyWithDataAttestation",
"outputs": [
{
"internalType": "bool",
"name": "",
"type": "bool"
}
],
"stateMutability": "view",
"type": "function"
}
]

View File

@@ -1,147 +0,0 @@
[
{
"inputs": [
{
"internalType": "address",
"name": "_contractAddresses",
"type": "address"
},
{
"internalType": "bytes",
"name": "_callData",
"type": "bytes"
},
{
"internalType": "uint256",
"name": "_decimals",
"type": "uint256"
},
{
"internalType": "uint256[]",
"name": "_scales",
"type": "uint256[]"
},
{
"internalType": "uint8",
"name": "_instanceOffset",
"type": "uint8"
},
{
"internalType": "address",
"name": "_admin",
"type": "address"
}
],
"stateMutability": "nonpayable",
"type": "constructor"
},
{
"inputs": [],
"name": "accountCall",
"outputs": [
{
"internalType": "address",
"name": "contractAddress",
"type": "address"
},
{
"internalType": "bytes",
"name": "callData",
"type": "bytes"
},
{
"internalType": "uint256",
"name": "decimals",
"type": "uint256"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "admin",
"outputs": [
{
"internalType": "address",
"name": "",
"type": "address"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [],
"name": "instanceOffset",
"outputs": [
{
"internalType": "uint8",
"name": "",
"type": "uint8"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "_contractAddresses",
"type": "address"
},
{
"internalType": "bytes",
"name": "_callData",
"type": "bytes"
},
{
"internalType": "uint256",
"name": "_decimals",
"type": "uint256"
}
],
"name": "updateAccountCalls",
"outputs": [],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "_admin",
"type": "address"
}
],
"name": "updateAdmin",
"outputs": [],
"stateMutability": "nonpayable",
"type": "function"
},
{
"inputs": [
{
"internalType": "address",
"name": "verifier",
"type": "address"
},
{
"internalType": "bytes",
"name": "encoded",
"type": "bytes"
}
],
"name": "verifyWithDataAttestation",
"outputs": [
{
"internalType": "bool",
"name": "",
"type": "bool"
}
],
"stateMutability": "view",
"type": "function"
}
]

View File

@@ -68,7 +68,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&[self.image.clone(), self.kernel.clone(), self.bias.clone()],
&[&self.image, &self.kernel, &self.bias],
Box::new(PolyOp::Conv {
padding: vec![(0, 0)],
stride: vec![1; 2],

View File

@@ -15,6 +15,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -59,7 +60,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),

View File

@@ -15,6 +15,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -61,7 +62,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),

View File

@@ -17,6 +17,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -86,13 +87,13 @@ impl Circuit<Fr> for MyCircuit {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[output.unwrap()],
&[&output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -17,6 +17,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -87,13 +88,13 @@ impl Circuit<Fr> for MyCircuit {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[output.unwrap()],
&[&output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -15,6 +15,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -59,7 +60,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.unwrap();

View File

@@ -63,7 +63,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&[self.image.clone()],
&[&self.image],
Box::new(HybridOp::SumPool {
padding: vec![(0, 0); 2],
stride: vec![1, 1],

View File

@@ -15,6 +15,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -57,7 +58,11 @@ impl Circuit<Fr> for MyCircuit {
|region| {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &self.inputs, Box::new(PolyOp::Add))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.unwrap();
Ok(())
},

View File

@@ -16,6 +16,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -58,7 +59,11 @@ impl Circuit<Fr> for MyCircuit {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &self.inputs, Box::new(PolyOp::Pow(4)))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Pow(4)),
)
.unwrap();
Ok(())
},

View File

@@ -70,7 +70,7 @@ impl Circuit<Fr> for NLCircuit {
config
.layout(
&mut region,
&[self.input.clone()],
&[&self.input],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,

View File

@@ -67,7 +67,7 @@ impl Circuit<Fr> for NLCircuit {
config
.layout(
&mut region,
&[self.input.clone()],
&[&self.input],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -8,21 +8,27 @@ contract LoadInstances {
*/
function getInstancesMemory(
bytes memory encoded
) internal pure returns (uint256[] memory instances) {
) public pure returns (uint256[] memory instances) {
bytes4 funcSig;
uint256 instances_offset;
uint256 instances_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := mload(add(encoded, 0x20))
}
if (funcSig == 0xaf83a18d) {
instances_offset = 0x64;
} else if (funcSig == 0x1e8e1e13) {
instances_offset = 0x44;
} else {
revert("Invalid function signature");
}
assembly {
// Fetch instances offset which is 4 + 32 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 + 32 away for `verifyProof(address,bytes,uint256[])`
instances_offset := mload(
add(encoded, add(0x44, mul(0x20, eq(funcSig, 0xaf83a18d))))
)
instances_offset := mload(add(encoded, instances_offset))
instances_length := mload(add(add(encoded, 0x24), instances_offset))
}
@@ -41,6 +47,10 @@ contract LoadInstances {
)
}
}
require(
funcSig == 0xaf83a18d || funcSig == 0x1e8e1e13,
"Invalid function signature"
);
}
/**
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
@@ -49,23 +59,31 @@ contract LoadInstances {
*/
function getInstancesCalldata(
bytes calldata encoded
) internal pure returns (uint256[] memory instances) {
) public pure returns (uint256[] memory instances) {
bytes4 funcSig;
uint256 instances_offset;
uint256 instances_length;
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := calldataload(encoded.offset)
}
if (funcSig == 0xaf83a18d) {
instances_offset = 0x44;
} else if (funcSig == 0x1e8e1e13) {
instances_offset = 0x24;
} else {
revert("Invalid function signature");
}
// We need to create a new assembly block in order for solidity
// to cast the funcSig to a bytes4 type. Otherwise it will load the entire first 32 bytes of the calldata
// within the block
assembly {
// Fetch instances offset which is 4 + 32 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 + 32 away for `verifyProof(address,bytes,uint256[])`
instances_offset := calldataload(
add(
encoded.offset,
add(0x24, mul(0x20, eq(funcSig, 0xaf83a18d)))
)
add(encoded.offset, instances_offset)
)
instances_length := calldataload(
@@ -96,7 +114,7 @@ contract LoadInstances {
// The kzg commitments of a given model, all aggregated into a single bytes array.
// At solidity generation time, the commitments are hardcoded into the contract via the COMMITMENT_KZG constant.
// It will be used to check that the proof commitments match the expected commitments.
bytes constant COMMITMENT_KZG = hex"";
bytes constant COMMITMENT_KZG = hex"1234";
contract SwapProofCommitments {
/**
@@ -113,17 +131,20 @@ contract SwapProofCommitments {
assembly {
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
funcSig := calldataload(encoded.offset)
}
if (funcSig == 0xaf83a18d) {
proof_offset = 0x24;
} else if (funcSig == 0x1e8e1e13) {
proof_offset = 0x04;
} else {
revert("Invalid function signature");
}
assembly {
// Fetch proof offset which is 4 + 32 bytes away from
// start of encoded for `verifyProof(bytes,uint256[])`,
// and 4 + 32 + 32 away for `verifyProof(address,bytes,uint256[])`
proof_offset := calldataload(
add(
encoded.offset,
add(0x04, mul(0x20, eq(funcSig, 0xaf83a18d)))
)
)
proof_offset := calldataload(add(encoded.offset, proof_offset))
proof_length := calldataload(
add(add(encoded.offset, 0x04), proof_offset)
@@ -154,7 +175,7 @@ contract SwapProofCommitments {
let wordCommitment := mload(add(commitment, i))
equal := eq(wordProof, wordCommitment)
if eq(equal, 0) {
return(0, 0)
break
}
}
}
@@ -163,36 +184,38 @@ contract SwapProofCommitments {
} /// end checkKzgCommits
}
contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
/**
* @notice Struct used to make view only call to account to fetch the data that EZKL reads from.
* @param the address of the account to make calls to
* @param the abi encoded function calls to make to the `contractAddress`
*/
struct AccountCall {
address contractAddress;
bytes callData;
contract DataAttestation is LoadInstances, SwapProofCommitments {
// the address of the account to make calls to
address public immutable contractAddress;
// the abi encoded function calls to make to the `contractAddress` that returns the attested to data
bytes public callData;
struct Scalars {
// The number of base 10 decimals to scale the data by.
// For most ERC20 tokens this is 1e18
uint256 decimals;
// The number of fractional bits of the fixed point EZKL data points.
uint256 bits;
}
AccountCall public accountCall;
uint[] scales;
Scalars[] private scalars;
address public admin;
function getScalars(uint256 index) public view returns (Scalars memory) {
return scalars[index];
}
/**
* @notice EZKL P value
* @dev In order to prevent the verifier from accepting two versions of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are strictly less than P.
* @dev The reason for this is that the assembly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
*/
uint256 constant ORDER =
uint256 public constant ORDER =
uint256(
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
);
uint256 constant INPUT_LEN = 0;
uint256 constant OUTPUT_LEN = 0;
uint256 public constant HALF_ORDER = ORDER >> 1;
uint8 public instanceOffset;
@@ -204,53 +227,27 @@ contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
constructor(
address _contractAddresses,
bytes memory _callData,
uint256 _decimals,
uint[] memory _scales,
uint8 _instanceOffset,
address _admin
uint256[] memory _decimals,
uint[] memory _bits,
uint8 _instanceOffset
) {
admin = _admin;
for (uint i; i < _scales.length; i++) {
scales.push(1 << _scales[i]);
require(
_bits.length == _decimals.length,
"Invalid scalar array lengths"
);
for (uint i; i < _bits.length; i++) {
scalars.push(Scalars(10 ** _decimals[i], 1 << _bits[i]));
}
populateAccountCalls(_contractAddresses, _callData, _decimals);
contractAddress = _contractAddresses;
callData = _callData;
instanceOffset = _instanceOffset;
}
function updateAdmin(address _admin) external {
require(msg.sender == admin, "Only admin can update admin");
if (_admin == address(0)) {
revert();
}
admin = _admin;
}
function updateAccountCalls(
address _contractAddresses,
bytes memory _callData,
uint256 _decimals
) external {
require(msg.sender == admin, "Only admin can update account calls");
populateAccountCalls(_contractAddresses, _callData, _decimals);
}
function populateAccountCalls(
address _contractAddresses,
bytes memory _callData,
uint256 _decimals
) internal {
AccountCall memory _accountCall = accountCall;
_accountCall.contractAddress = _contractAddresses;
_accountCall.callData = _callData;
_accountCall.decimals = 10 ** _decimals;
accountCall = _accountCall;
}
function mulDiv(
uint256 x,
uint256 y,
uint256 denominator
) internal pure returns (uint256 result) {
) public pure returns (uint256 result) {
unchecked {
uint256 prod0;
uint256 prod1;
@@ -298,21 +295,28 @@ contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
/**
* @dev Quantize the data returned from the account calls to the scale used by the EZKL model.
* @param x - One of the elements of the data returned from the account calls
* @param _decimals - Number of base 10 decimals to scale the data by.
* @param _scale - The base 2 scale used to convert the floating point value into a fixed point value.
* @param _scalars - The scaling factors for the data returned from the account calls.
*
*/
function quantizeData(
int x,
uint256 _decimals,
uint256 _scale
) internal pure returns (int256 quantized_data) {
Scalars memory _scalars
) public pure returns (int256 quantized_data) {
if (_scalars.bits == 1 && _scalars.decimals == 1) {
return x;
}
bool neg = x < 0;
if (neg) x = -x;
uint output = mulDiv(uint256(x), _scale, _decimals);
if (mulmod(uint256(x), _scale, _decimals) * 2 >= _decimals) {
uint output = mulDiv(uint256(x), _scalars.bits, _scalars.decimals);
if (
mulmod(uint256(x), _scalars.bits, _scalars.decimals) * 2 >=
_scalars.decimals
) {
output += 1;
}
if (output > HALF_ORDER) {
revert("Overflow field modulus");
}
quantized_data = neg ? -int256(output) : int256(output);
}
/**
@@ -324,7 +328,7 @@ contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
function staticCall(
address target,
bytes memory data
) internal view returns (bytes memory) {
) public view returns (bytes memory) {
(bool success, bytes memory returndata) = target.staticcall(data);
if (success) {
if (returndata.length == 0) {
@@ -345,7 +349,7 @@ contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
*/
function toFieldElement(
int256 x
) internal pure returns (uint256 field_element) {
) public pure returns (uint256 field_element) {
// The casting down to uint256 is safe because the order is about 2^254, and the value
// of x ranges of -2^127 to 2^127, so x + int(ORDER) is always positive.
return uint256(x + int(ORDER)) % ORDER;
@@ -355,315 +359,16 @@ contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
* @dev Make the account calls to fetch the data that EZKL reads from and attest to the data.
* @param instances - The public instances to the proof (the data in the proof that is publicly accessible to the verifier).
*/
function attestData(uint256[] memory instances) internal view {
require(
instances.length >= INPUT_LEN + OUTPUT_LEN,
"Invalid public inputs length"
);
AccountCall memory _accountCall = accountCall;
uint[] memory _scales = scales;
bytes memory returnData = staticCall(
_accountCall.contractAddress,
_accountCall.callData
);
function attestData(uint256[] memory instances) public view {
bytes memory returnData = staticCall(contractAddress, callData);
int256[] memory x = abi.decode(returnData, (int256[]));
uint _offset;
int output = quantizeData(x[0], _accountCall.decimals, _scales[0]);
uint field_element = toFieldElement(output);
int output;
uint fieldElement;
for (uint i = 0; i < x.length; i++) {
if (field_element != instances[i + instanceOffset]) {
_offset += 1;
} else {
break;
}
}
uint length = x.length - _offset;
for (uint i = 1; i < length; i++) {
output = quantizeData(x[i], _accountCall.decimals, _scales[i]);
field_element = toFieldElement(output);
require(
field_element == instances[i + instanceOffset + _offset],
"Public input does not match"
);
}
}
/**
* @dev Verify the proof with the data attestation.
* @param verifier - The address of the verifier contract.
* @param encoded - The verifier calldata.
*/
function verifyWithDataAttestation(
address verifier,
bytes calldata encoded
) public view returns (bool) {
require(verifier.code.length > 0, "Address: call to non-contract");
attestData(getInstancesCalldata(encoded));
// static call the verifier contract to verify the proof
(bool success, bytes memory returndata) = verifier.staticcall(encoded);
if (success) {
return abi.decode(returndata, (bool));
} else {
revert("low-level call to verifier failed");
}
}
}
// This contract serves as a Data Attestation Verifier for the EZKL model.
// It is designed to read and attest to instances of proofs generated from a specified circuit.
// It is particularly constructed to read only int256 data from specified on-chain contracts' view functions.
// Overview of the contract functionality:
// 1. Initialization: Through the constructor, it sets up the contract calls that the EZKL model will read from.
// 2. Data Quantization: Quantizes the returned data into a scaled fixed-point representation. See the `quantizeData` method for details.
// 3. Static Calls: Makes static calls to fetch data from other contracts. See the `staticCall` method.
// 4. Field Element Conversion: The fixed-point representation is then converted into a field element modulo P using the `toFieldElement` method.
// 5. Data Attestation: The `attestData` method validates that the public instances match the data fetched and processed by the contract.
// 6. Proof Verification: The `verifyWithDataAttestationMulti` method parses the instances out of the encoded calldata and calls the `attestData` method to validate the public instances,
// 6b. Optional KZG Commitment Verification: It also checks the KZG commitments in the proof against the expected commitments using the `checkKzgCommits` method.
// then calls the `verifyProof` method to verify the proof on the verifier.
contract DataAttestationMulti is LoadInstances, SwapProofCommitments {
/**
* @notice Struct used to make view only calls to accounts to fetch the data that EZKL reads from.
* @param the address of the account to make calls to
* @param the abi encoded function calls to make to the `contractAddress`
*/
struct AccountCall {
address contractAddress;
mapping(uint256 => bytes) callData;
mapping(uint256 => uint256) decimals;
uint callCount;
}
AccountCall[] public accountCalls;
uint[] public scales;
address public admin;
/**
* @notice EZKL P value
* @dev In order to prevent the verifier from accepting two versions of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are strictly less than P.
* @dev The reason for this is that the assembly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
*/
uint256 constant ORDER =
uint256(
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
);
uint256 constant INPUT_CALLS = 0;
uint256 constant OUTPUT_CALLS = 0;
uint8 public instanceOffset;
/**
* @dev Initialize the contract with account calls the EZKL model will read from.
* @param _contractAddresses - The calls to all the contracts EZKL reads storage from.
* @param _callData - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from.
*/
constructor(
address[] memory _contractAddresses,
bytes[][] memory _callData,
uint256[][] memory _decimals,
uint[] memory _scales,
uint8 _instanceOffset,
address _admin
) {
admin = _admin;
for (uint i; i < _scales.length; i++) {
scales.push(1 << _scales[i]);
}
populateAccountCalls(_contractAddresses, _callData, _decimals);
instanceOffset = _instanceOffset;
}
function updateAdmin(address _admin) external {
require(msg.sender == admin, "Only admin can update admin");
if (_admin == address(0)) {
revert();
}
admin = _admin;
}
function updateAccountCalls(
address[] memory _contractAddresses,
bytes[][] memory _callData,
uint256[][] memory _decimals
) external {
require(msg.sender == admin, "Only admin can update account calls");
populateAccountCalls(_contractAddresses, _callData, _decimals);
}
function populateAccountCalls(
address[] memory _contractAddresses,
bytes[][] memory _callData,
uint256[][] memory _decimals
) internal {
require(
_contractAddresses.length == _callData.length &&
accountCalls.length == _contractAddresses.length,
"Invalid input length"
);
require(
_decimals.length == _contractAddresses.length,
"Invalid number of decimals"
);
// fill in the accountCalls storage array
uint counter = 0;
for (uint256 i = 0; i < _contractAddresses.length; i++) {
AccountCall storage accountCall = accountCalls[i];
accountCall.contractAddress = _contractAddresses[i];
accountCall.callCount = _callData[i].length;
for (uint256 j = 0; j < _callData[i].length; j++) {
accountCall.callData[j] = _callData[i][j];
accountCall.decimals[j] = 10 ** _decimals[i][j];
}
// count the total number of storage reads across all of the accounts
counter += _callData[i].length;
}
require(
counter == INPUT_CALLS + OUTPUT_CALLS,
"Invalid number of calls"
);
}
function mulDiv(
uint256 x,
uint256 y,
uint256 denominator
) internal pure returns (uint256 result) {
unchecked {
uint256 prod0;
uint256 prod1;
assembly {
let mm := mulmod(x, y, not(0))
prod0 := mul(x, y)
prod1 := sub(sub(mm, prod0), lt(mm, prod0))
}
if (prod1 == 0) {
return prod0 / denominator;
}
require(denominator > prod1, "Math: mulDiv overflow");
uint256 remainder;
assembly {
remainder := mulmod(x, y, denominator)
prod1 := sub(prod1, gt(remainder, prod0))
prod0 := sub(prod0, remainder)
}
uint256 twos = denominator & (~denominator + 1);
assembly {
denominator := div(denominator, twos)
prod0 := div(prod0, twos)
twos := add(div(sub(0, twos), twos), 1)
}
prod0 |= prod1 * twos;
uint256 inverse = (3 * denominator) ^ 2;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
result = prod0 * inverse;
return result;
}
}
/**
* @dev Quantize the data returned from the account calls to the scale used by the EZKL model.
* @param data - The data returned from the account calls.
* @param decimals - The number of decimals the data returned from the account calls has (for floating point representation).
* @param scale - The scale used to convert the floating point value into a fixed point value.
*/
function quantizeData(
bytes memory data,
uint256 decimals,
uint256 scale
) internal pure returns (int256 quantized_data) {
int x = abi.decode(data, (int256));
bool neg = x < 0;
if (neg) x = -x;
uint output = mulDiv(uint256(x), scale, decimals);
if (mulmod(uint256(x), scale, decimals) * 2 >= decimals) {
output += 1;
}
quantized_data = neg ? -int256(output) : int256(output);
}
/**
* @dev Make a static call to the account to fetch the data that EZKL reads from.
* @param target - The address of the account to make calls to.
* @param data - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from.
* @return The data returned from the account calls. (Must come from either a view or pure function. Will throw an error otherwise)
*/
function staticCall(
address target,
bytes memory data
) internal view returns (bytes memory) {
(bool success, bytes memory returndata) = target.staticcall(data);
if (success) {
if (returndata.length == 0) {
require(
target.code.length > 0,
"Address: call to non-contract"
);
}
return returndata;
} else {
revert("Address: low-level call failed");
}
}
/**
* @dev Convert the fixed point quantized data into a field element.
* @param x - The quantized data.
* @return field_element - The field element.
*/
function toFieldElement(
int256 x
) internal pure returns (uint256 field_element) {
// The casting down to uint256 is safe because the order is about 2^254, and the value
// of x ranges of -2^127 to 2^127, so x + int(ORDER) is always positive.
return uint256(x + int(ORDER)) % ORDER;
}
/**
* @dev Make the account calls to fetch the data that EZKL reads from and attest to the data.
* @param instances - The public instances to the proof (the data in the proof that is publicly accessible to the verifier).
*/
function attestData(uint256[] memory instances) internal view {
require(
instances.length >= INPUT_CALLS + OUTPUT_CALLS,
"Invalid public inputs length"
);
uint256 _accountCount = accountCalls.length;
uint counter = 0;
for (uint8 i = 0; i < _accountCount; ++i) {
address account = accountCalls[i].contractAddress;
for (uint8 j = 0; j < accountCalls[i].callCount; j++) {
bytes memory returnData = staticCall(
account,
accountCalls[i].callData[j]
);
uint256 scale = scales[counter];
int256 quantized_data = quantizeData(
returnData,
accountCalls[i].decimals[j],
scale
);
uint256 field_element = toFieldElement(quantized_data);
require(
field_element == instances[counter + instanceOffset],
"Public input does not match"
);
counter++;
output = quantizeData(x[i], scalars[i]);
fieldElement = toFieldElement(output);
if (fieldElement != instances[i]) {
revert("Public input does not match");
}
}
}
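The refactored `getInstancesCalldata` above now dispatches on the 4-byte selector before entering assembly. A hedged plain-Python model of the same ABI offset arithmetic (head words sit at byte 4 + 32·n from the start of calldata, and the array offset is relative to the end of the selector):

```python
# Hedged model of getInstancesCalldata's offset math; selectors as per the contract.
def get_instances(encoded: bytes) -> list[int]:
    selector = encoded[:4].hex()
    if selector == "af83a18d":    # verifyProof(address,bytes,uint256[])
        head = 4 + 32 + 32        # instances head word is the third argument slot
    elif selector == "1e8e1e13":  # verifyProof(bytes,uint256[])
        head = 4 + 32             # instances head word is the second argument slot
    else:
        raise ValueError("Invalid function signature")
    offset = int.from_bytes(encoded[head:head + 32], "big")  # relative to byte 4
    length = int.from_bytes(encoded[4 + offset:4 + offset + 32], "big")
    body = encoded[4 + offset + 32:]
    return [int.from_bytes(body[32 * i:32 * (i + 1)], "big") for i in range(length)]
```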
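And for the attestation arithmetic, a hedged plain-Python reference model of `quantizeData` and `toFieldElement` as defined above (round-half-away-from-zero fixed-point rescaling, then reduction into the field); a sketch for intuition, not the canonical implementation:

```python
# Hedged reference model of DataAttestation.quantizeData / toFieldElement.
ORDER = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
HALF_ORDER = ORDER >> 1

def quantize_data(x: int, decimals: int, bits: int) -> int:
    """Rescale a base-10 fixed-point int by 2**bits / 10**decimals, rounding half away from zero."""
    scale, denom = 1 << bits, 10 ** decimals
    if scale == 1 and denom == 1:  # mirrors the contract's early return
        return x
    neg = x < 0
    if neg:
        x = -x
    out = x * scale // denom
    if (x * scale) % denom * 2 >= denom:
        out += 1
    if out > HALF_ORDER:
        raise OverflowError("Overflow field modulus")
    return -out if neg else out

def to_field_element(x: int) -> int:
    # map a signed quantized value into [0, ORDER)
    return (x + ORDER) % ORDER

# -15 with one base-10 decimal is -1.5; at 1 fractional bit that quantizes to -3.
assert to_field_element(quantize_data(-15, decimals=1, bits=1)) == ORDER - 3
```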

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '0.0.0'
release = '22.0.5'
version = release

View File

@@ -32,7 +32,6 @@ use mnist::*;
use rand::rngs::OsRng;
use std::marker::PhantomData;
mod params;
const K: usize = 20;
@@ -216,11 +215,7 @@ where
.layer_config
.layout(
&mut region,
&[
self.input.clone(),
self.l0_params[0].clone(),
self.l0_params[1].clone(),
],
&[&self.input, &self.l0_params[0], &self.l0_params[1]],
Box::new(op),
)
.unwrap();
@@ -229,7 +224,7 @@ where
.layer_config
.layout(
&mut region,
&[x.unwrap()],
&[&x.unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -241,7 +236,7 @@ where
.layer_config
.layout(
&mut region,
&[x.unwrap()],
&[&x.unwrap()],
Box::new(LookupOp::Div { denom: 32.0.into() }),
)
.unwrap()
@@ -253,7 +248,7 @@ where
.layer_config
.layout(
&mut region,
&[self.l2_params[0].clone(), x],
&[&self.l2_params[0], &x],
Box::new(PolyOp::Einsum {
equation: "ij,j->ik".to_string(),
}),
@@ -265,7 +260,7 @@ where
.layer_config
.layout(
&mut region,
&[x, self.l2_params[1].clone()],
&[&x, &self.l2_params[1]],
Box::new(PolyOp::Add),
)
.unwrap()

View File

@@ -117,10 +117,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[
self.l0_params[0].clone().try_into().unwrap(),
self.input.clone(),
],
&[&self.l0_params[0].clone().try_into().unwrap(), &self.input],
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),
@@ -135,7 +132,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x, self.l0_params[1].clone().try_into().unwrap()],
&[&x, &self.l0_params[1].clone().try_into().unwrap()],
Box::new(PolyOp::Add),
)
.unwrap()
@@ -147,7 +144,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x],
&[&x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
@@ -163,7 +160,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[self.l2_params[0].clone().try_into().unwrap(), x],
&[&self.l2_params[0].clone().try_into().unwrap(), &x],
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),
@@ -178,7 +175,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x, self.l2_params[1].clone().try_into().unwrap()],
&[&x, &self.l2_params[1].clone().try_into().unwrap()],
Box::new(PolyOp::Add),
)
.unwrap()
@@ -190,7 +187,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x],
&[&x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
@@ -203,7 +200,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x.unwrap()],
&[&x.unwrap()],
Box::new(LookupOp::Div {
denom: ezkl::circuit::utils::F32::from(128.),
}),

View File

@@ -1088,7 +1088,7 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" rpc_url='http://127.0.0.1:3030'\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",

View File

@@ -272,33 +272,21 @@
"\n",
"- For file data sources, the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit are stored. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestable by EZKL (aka field elements :-D). \n",
"Here is what the schema for an on-chain data source graph input file should look like:\n",
"Here is what the schema for an on-chain data source graph input file should look like for a single call data source:\n",
" \n",
"```json\n",
"{\n",
" \"input_data\": {\n",
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
" \"calls\": [\n",
" {\n",
" \"call_data\": [\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns a single on-chain data point (we only support uint256 returns for now)\n",
" 7 // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
" ],\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000001\",\n",
" 5\n",
" ],\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000002\",\n",
" 5\n",
" ]\n",
" ],\n",
" \"address\": \"5fbdb2315678afecb367f032d93f642f64180aa3\" // The address of the contract that we are calling to get the data. \n",
" }\n",
" ]\n",
" }\n",
"}"
" \"input_data\": {\n",
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
" \"calls\": {\n",
" \"call_data\": \"1f3be514000000000000000000000000c6962004f452be9203591991d15f6b388e09e8d00000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000070000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns an array of on-chain data points we are attesting to. \n",
" \"decimals\": 0, // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
" \"address\": \"9A213F53334279C128C37DA962E5472eCD90554f\", // The address of the contract that we are calling to get the data. \n",
" \"len\": 12 // The number of data points returned by the view function (the length of the array)\n",
" }\n",
" }\n",
"}\n",
"```"
]
},
{
@@ -307,7 +295,7 @@
"metadata": {},
"outputs": [],
"source": [
"await ezkl.setup_test_evm_witness(\n",
"await ezkl.setup_test_evm_data(\n",
" data_path,\n",
" compiled_model_path,\n",
" # we write the call data to the same file as the input data\n",
@@ -484,8 +472,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
@@ -538,9 +526,9 @@
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" RPC_URL,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
@@ -569,8 +557,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" proof_path,\n",
" addr_da,\n",
")"
]
@@ -578,7 +566,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "ezkl",
"display_name": ".env",
"language": "python",
"name": "python3"
},
@@ -592,7 +580,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.12.9"
},
"orig_nbformat": 4
},

View File

@@ -337,6 +337,7 @@
"w3 = Web3(HTTPProvider(RPC_URL))\n",
"\n",
"def test_on_chain_data(res):\n",
" print(f'poseidon_hash: {res[\"processed_outputs\"][\"poseidon_hash\"]}')\n",
" # Step 0: Convert the tensor to a flat list\n",
" data = [int(ezkl.felt_to_big_endian(res['processed_outputs']['poseidon_hash'][0]), 0)]\n",
"\n",
@@ -356,6 +357,9 @@
" arr.push(_numbers[i]);\n",
" }\n",
" }\n",
" function getArr() public view returns (uint[] memory) {\n",
" return arr;\n",
" }\n",
" }\n",
" '''\n",
"\n",
@@ -382,31 +386,30 @@
" contract = w3.eth.contract(address=tx_receipt['contractAddress'], abi=abi)\n",
"\n",
" # Step 4: Interact with the contract\n",
" calldata = []\n",
" for i, _ in enumerate(data):\n",
" call = contract.functions.arr(i).build_transaction()\n",
" calldata.append((call['data'][2:], 0))\n",
" calldata = contract.functions.getArr().build_transaction()['data'][2:]\n",
"\n",
" # Prepare the calls_to_account object\n",
" # If you were calling view functions across multiple contracts,\n",
" # you would have multiple entries in the calls_to_account array,\n",
" # one for each contract.\n",
" calls_to_account = [{\n",
" decimals = [0] * len(data)\n",
" call_to_account = {\n",
" 'call_data': calldata,\n",
" 'decimals': decimals,\n",
" 'address': contract.address[2:], # remove the '0x' prefix\n",
" }]\n",
" }\n",
"\n",
" print(f'calls_to_account: {calls_to_account}')\n",
" print(f'call_to_account: {call_to_account}')\n",
"\n",
" return calls_to_account\n",
" return call_to_account\n",
"\n",
"# Now let's start the Anvil process. You don't need to do this if you are deploying to a non-local chain.\n",
"start_anvil()\n",
"\n",
"# Now let's call our function, passing in the same input tensor we used to export the model 2 cells above.\n",
"calls_to_account = test_on_chain_data(res)\n",
"call_to_account = test_on_chain_data(res)\n",
"\n",
"data = dict(input_data = [data_array], output_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
"data = dict(input_data = [data_array], output_data = {'rpc': RPC_URL, 'call': call_to_account })\n",
"\n",
"# Serialize on-chain data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n",
@@ -540,8 +543,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
@@ -594,9 +597,9 @@
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" RPC_URL,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
@@ -625,8 +628,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" proof_path,\n",
" addr_da,\n",
")"
]
@@ -634,7 +637,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "ezkl",
"display_name": ".env",
"language": "python",
"name": "python3"
},
@@ -648,7 +651,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.12.9"
},
"orig_nbformat": 4
},

View File

@@ -276,33 +276,21 @@
"\n",
"- For file data sources, the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit are stored. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestable by EZKL (aka field elements :-D). \n",
"Here is what the schema for an on-chain data source graph input file should look like:\n",
"Here is what the schema for an on-chain data source graph input file should look like for a single call data source:\n",
" \n",
"```json\n",
"{\n",
" \"input_data\": {\n",
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
" \"calls\": [\n",
" {\n",
" \"call_data\": [\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns a single on-chain data point (we only support uint256 returns for now)\n",
" 7 // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
" ],\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000001\",\n",
" 5\n",
" ],\n",
" [\n",
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000002\",\n",
" 5\n",
" ]\n",
" ],\n",
" \"address\": \"5fbdb2315678afecb367f032d93f642f64180aa3\" // The address of the contract that we are calling to get the data. \n",
" }\n",
" ]\n",
" }\n",
"}"
" \"input_data\": {\n",
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
" \"calls\": {\n",
" \"call_data\": \"1f3be514000000000000000000000000c6962004f452be9203591991d15f6b388e09e8d00000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000070000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns an array of on-chain data points we are attesting to. \n",
" \"decimals\": 0, // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
" \"address\": \"9A213F53334279C128C37DA962E5472eCD90554f\", // The address of the contract that we are calling to get the data. \n",
" \"len\": 3 // The number of data points returned by the view function (the length of the array)\n",
" }\n",
" }\n",
"}\n",
"```"
]
},
{
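For readers assembling such a file by hand, a minimal sketch of producing the selector portion of `call_data` with web3.py; `getArr()` and the contract address are placeholder assumptions, not part of the diff:

```python
from web3 import Web3

# 4-byte function selector; getArr() takes no arguments, so the selector
# is the entire calldata. bytes(...).hex() yields hex with no '0x' prefix.
selector = bytes(Web3.keccak(text="getArr()"))[:4].hex()

call = {
    "call_data": selector,   # abi-encoded calldata for the view function
    "decimals": 0,           # scaling applied to the returned uint256 values
    "address": "9A213F53334279C128C37DA962E5472eCD90554f",  # no '0x' prefix
    "len": 3,                # number of uint256 values the call returns
}

data = {"input_data": {"rpc": "http://localhost:3030", "call": call}}
```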
@@ -311,7 +299,7 @@
"metadata": {},
"outputs": [],
"source": [
"await ezkl.setup_test_evm_witness(\n",
"await ezkl.setup_test_evm_data(\n",
" data_path,\n",
" compiled_model_path,\n",
" # we write the call data to the same file as the input data\n",
@@ -337,7 +325,7 @@
"metadata": {},
"outputs": [],
"source": [
"res = await ezkl.get_srs( settings_path)\n"
"res = await ezkl.get_srs( settings_path)"
]
},
{
@@ -348,27 +336,6 @@
"We now need to generate the circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!export RUST_BACKTRACE=1\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -391,6 +358,27 @@
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!export RUST_BACKTRACE=1\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
]
},
{
"attachments": {},
"cell_type": "markdown",
@@ -486,8 +474,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
@@ -541,9 +529,9 @@
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" RPC_URL,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
@@ -572,8 +560,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" proof_path,\n",
" addr_da,\n",
")"
]
@@ -581,7 +569,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "ezkl",
"display_name": ".env",
"language": "python",
"name": "python3"
},
@@ -595,7 +583,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.11.5"
},
"orig_nbformat": 4
},

View File

@@ -453,8 +453,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -474,8 +474,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
@@ -510,4 +510,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -462,8 +462,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -483,8 +483,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
@@ -512,4 +512,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -1,462 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Mean of ERC20 transfer amounts\n",
"\n",
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"\n",
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import getpass\n",
"import json\n",
"import time\n",
"import subprocess\n",
"\n",
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
"os.system(\"chmod +x shovel\")\n",
"\n",
"\n",
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
"\n",
"# create a config.json file with the following contents\n",
"config = {\n",
" \"pg_url\": \"$PG_URL\",\n",
" \"eth_sources\": [\n",
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
" ],\n",
" \"integrations\": [{\n",
" \"name\": \"usdc_transfer\",\n",
" \"enabled\": True,\n",
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
" \"table\": {\n",
" \"name\": \"usdc\",\n",
" \"columns\": [\n",
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
" ]\n",
" },\n",
" \"block\": [\n",
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
" {\n",
" \"name\": \"log_addr\",\n",
" \"column\": \"log_addr\",\n",
" \"filter_op\": \"contains\",\n",
" \"filter_arg\": [\n",
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
" ]\n",
" }\n",
" ],\n",
" \"event\": {\n",
" \"name\": \"Transfer\",\n",
" \"type\": \"event\",\n",
" \"anonymous\": False,\n",
" \"inputs\": [\n",
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
" ]\n",
" }\n",
" }]\n",
"}\n",
"\n",
"# write the config to a file\n",
"with open(\"config.json\", \"w\") as f:\n",
" f.write(json.dumps(config))\n",
"\n",
"\n",
"# print the two env variables\n",
"os.system(\"echo $PG_URL\")\n",
"\n",
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
"\n",
"os.system(\"echo shovel is now installed. starting:\")\n",
"\n",
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
"proc = subprocess.Popen(command)\n",
"\n",
"os.system(\"echo shovel started.\")\n",
"\n",
"time.sleep(10)\n",
"\n",
"# after we've fetched some data -- kill the process\n",
"proc.terminate()\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2wIAHwqH2_mo"
},
"source": [
"**Import Dependencies**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9Byiv2Nc2MsK"
},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import ezkl\n",
"import torch\n",
"import datetime\n",
"import pandas as pd\n",
"import requests\n",
"import json\n",
"import os\n",
"\n",
"import logging\n",
"# # uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n",
"\n",
"print(\"ezkl version: \", ezkl.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "osjj-0Ta3E8O"
},
"source": [
"**Create Computational Graph**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "x1vl9ZXF3EEW",
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
},
"outputs": [],
"source": [
"from torch import nn\n",
"import torch\n",
"\n",
"\n",
"class Model(nn.Module):\n",
" def __init__(self):\n",
" super(Model, self).__init__()\n",
"\n",
" # x is a time series \n",
" def forward(self, x):\n",
" return [torch.mean(x)]\n",
"\n",
"\n",
"\n",
"\n",
"circuit = Model()\n",
"\n",
"\n",
"\n",
"\n",
"x = 0.1*torch.rand(1,*[1,5], requires_grad=True)\n",
"\n",
"# # print(torch.__version__)\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"print(device)\n",
"\n",
"circuit.to(device)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"lol.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=11, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"# export(circuit, input_shape=[1, 20])\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E3qCeX-X5xqd"
},
"source": [
"**Set Data Source and Get Data**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6RAMplxk5xPk",
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
},
"outputs": [],
"source": [
"import getpass\n",
"# make an input.json file from the df above\n",
"input_filename = os.path.join('input.json')\n",
"\n",
"pg_input_file = dict(input_data = {\n",
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
"json_formatted_str = json.dumps(pg_input_file, indent=2)\n",
"print(json_formatted_str)\n",
"\n",
"\n",
" # Serialize data into file:\n",
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# this corresponds to 4 batches\n",
"calibration_filename = os.path.join('calibration.json')\n",
"\n",
"pg_cal_file = dict(input_data = {\n",
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
" # Serialize data into file:\n",
"json.dump( pg_cal_file, open(calibration_filename, 'w' ))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eLJ7oirQ_HQR"
},
"source": [
"**EZKL Workflow**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rNw0C9QL6W88"
},
"outputs": [],
"source": [
"import subprocess\n",
"import os\n",
"\n",
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.decomp_legs = 4\n",
"\n",
"# Generate settings using ezkl\n",
"res = ezkl.gen_settings(onnx_filename, settings_filename, py_run_args=run_args)\n",
"\n",
"assert res == True\n",
"\n",
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
"\n",
"assert res == True\n",
"\n",
"await ezkl.get_srs(settings_filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4MmE9SX66_Il",
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
},
"outputs": [],
"source": [
"# generate settings\n",
"\n",
"\n",
"# show the settings.json\n",
"with open(\"settings.json\") as f:\n",
" data = json.load(f)\n",
" json_formatted_str = json.dumps(data, indent=2)\n",
"\n",
" print(json_formatted_str)\n",
"\n",
"assert os.path.exists(\"settings.json\")\n",
"assert os.path.exists(\"input.json\")\n",
"assert os.path.exists(\"lol.onnx\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fULvvnK7_CMb"
},
"outputs": [],
"source": [
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"\n",
"\n",
"# setup the proof\n",
"res = ezkl.setup(\n",
" compiled_filename,\n",
" vk_path,\n",
" pk_path\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_filename)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"witness_path = \"witness.json\"\n",
"\n",
"# generate the witness\n",
"res = await ezkl.gen_witness(\n",
" input_filename,\n",
" compiled_filename,\n",
" witness_path\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Oog3j6Kd-Wed",
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
},
"outputs": [],
"source": [
"# prove the zk circuit\n",
"# GENERATE A PROOF\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"\n",
"proof = ezkl.prove(\n",
" witness_path,\n",
" compiled_filename,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\"\n",
" )\n",
"\n",
"\n",
"print(\"proved\")\n",
"\n",
"assert os.path.isfile(proof_path)\n",
"\n"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@@ -504,8 +504,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -527,8 +527,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path\n",
")\n",
"assert res == True"
]
@@ -558,4 +558,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

View File

@@ -220,15 +220,6 @@
"Check that the generated verifiers are identical for all models."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"start_anvil()"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -270,8 +261,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" \"verifier/reusable\"\n",
")\n",
"\n",
@@ -297,7 +288,7 @@
"for name in names:\n",
" addr_path_vk = \"addr_vk.txt\"\n",
" sol_key_code_path = os.path.join(name, 'test_key.sol')\n",
" res = await ezkl.deploy_evm(addr_path_vk, sol_key_code_path, 'http://127.0.0.1:3030', \"vka\")\n",
" res = await ezkl.deploy_evm(addr_path_vk, 'http://127.0.0.1:3030', sol_key_code_path, \"vka\")\n",
" assert res == True\n",
"\n",
" with open(addr_path_vk, 'r') as file:\n",
@@ -307,8 +298,8 @@
" sol_code_path = os.path.join(name, 'vk.sol')\n",
" res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" addr_vk = addr_vk\n",
" )\n",
" assert res == True"

View File

@@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -60,7 +60,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -94,7 +94,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -134,7 +134,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -183,7 +183,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -201,7 +201,7 @@
"run_args.input_visibility = \"public\"\n",
"run_args.param_visibility = \"private\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.decomp_legs=6\n",
"run_args.decomp_legs=5\n",
"run_args.num_inner_cols = 1\n",
"run_args.variables = [(\"batch_size\", 1)]"
]
@@ -270,7 +270,7 @@
"{\n",
" \"input_data\": {\n",
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
" \"calls\": {\n",
" \"call\": {\n",
" \"call_data\": \"1f3be514000000000000000000000000c6962004f452be9203591991d15f6b388e09e8d00000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000070000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns an array of on-chain data points we are attesting to. \n",
" \"decimals\": 0, // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
" \"address\": \"9A213F53334279C128C37DA962E5472eCD90554f\", // The address of the contract that we are calling to get the data. \n",
@@ -295,7 +295,6 @@
"import torch\n",
"import requests\n",
"\n",
"# This function counts the decimal places of a floating point number\n",
"def count_decimal_places(num):\n",
" num_str = str(num)\n",
" if '.' in num_str:\n",
@@ -303,69 +302,28 @@
" else:\n",
" return 0\n",
"\n",
"# setup web3 instance\n",
"w3 = Web3(HTTPProvider(RPC_URL)) \n",
"\n",
"def set_next_block_timestamp(anvil_url, timestamp):\n",
" # Send the JSON-RPC request to Anvil\n",
" payload = {\n",
" \"jsonrpc\": \"2.0\",\n",
" \"id\": 1,\n",
" \"method\": \"evm_setNextBlockTimestamp\",\n",
" \"params\": [timestamp]\n",
" }\n",
" response = requests.post(anvil_url, json=payload)\n",
" if response.status_code == 200:\n",
" print(f\"Next block timestamp set to: {timestamp}\")\n",
" else:\n",
" print(f\"Failed to set next block timestamp: {response.text}\")\n",
"\n",
"def on_chain_data(tensor):\n",
" # Step 0: Convert the tensor to a flat list\n",
" data = tensor.view(-1).tolist()\n",
"\n",
" # Step 1: Prepare the calldata\n",
" secondsAgo = [len(data) - 1 - i for i in range(len(data))]\n",
"\n",
" # Step 2: Prepare and compile the contract UniTickAttestor contract\n",
" contract_source_code = '''\n",
" // SPDX-License-Identifier: MIT\n",
" pragma solidity ^0.8.20;\n",
"\n",
" /// @title Pool state that is not stored\n",
" /// @notice Contains view functions to provide information about the pool that is computed rather than stored on the\n",
" /// blockchain. The functions here may have variable gas costs.\n",
" interface IUniswapV3PoolDerivedState {\n",
" /// @notice Returns the cumulative tick and liquidity as of each timestamp `secondsAgo` from the current block timestamp\n",
" /// @dev To get a time weighted average tick or liquidity-in-range, you must call this with two values, one representing\n",
" /// the beginning of the period and another for the end of the period. E.g., to get the last hour time-weighted average tick,\n",
" /// you must call it with secondsAgos = [3600, 0].\n",
" /// log base sqrt(1.0001) of token1 / token0. The TickMath library can be used to go from a tick value to a ratio.\n",
" /// @dev The time weighted average tick represents the geometric time weighted average price of the pool, in\n",
" /// @param secondsAgos From how long ago each cumulative tick and liquidity value should be returned\n",
" /// @return tickCumulatives Cumulative tick values as of each `secondsAgos` from the current block timestamp\n",
" /// @return secondsPerLiquidityCumulativeX128s Cumulative seconds per liquidity-in-range value as of each `secondsAgos` from the current block\n",
" /// timestamp\n",
" function observe(\n",
" uint32[] calldata secondsAgos\n",
" )\n",
" external\n",
" view\n",
" returns (\n",
" int56[] memory tickCumulatives,\n",
" uint160[] memory secondsPerLiquidityCumulativeX128s\n",
" );\n",
" ) external view returns (\n",
" int56[] memory tickCumulatives,\n",
" uint160[] memory secondsPerLiquidityCumulativeX128s\n",
" );\n",
" }\n",
"\n",
" /// @title Uniswap Wrapper around `pool.observe` that stores the parameters for fetching and then attesting to historical data\n",
" /// @notice Provides functions to integrate with V3 pool oracle\n",
" contract UniTickAttestor {\n",
" /**\n",
" * @notice Calculates time-weighted means of tick and liquidity for a given Uniswap V3 pool\n",
" * @param pool Address of the pool that we want to observe\n",
" * @param secondsAgo Number of seconds in the past from which to calculate the time-weighted means\n",
" * @return tickCumulatives The cumulative tick values as of each `secondsAgo` from the current block timestamp\n",
" */\n",
" int256[] private cachedTicks;\n",
"\n",
" function consult(\n",
" IUniswapV3PoolDerivedState pool,\n",
" uint32[] memory secondsAgo\n",
@@ -376,6 +334,21 @@
" tickCumulatives[i] = int256(_ticks[i]);\n",
" }\n",
" }\n",
"\n",
" function cache_price(\n",
" IUniswapV3PoolDerivedState pool,\n",
" uint32[] memory secondsAgo\n",
" ) public {\n",
" (int56[] memory _ticks,) = pool.observe(secondsAgo);\n",
" cachedTicks = new int256[](_ticks.length);\n",
" for (uint256 i = 0; i < _ticks.length; i++) {\n",
" cachedTicks[i] = int256(_ticks[i]);\n",
" }\n",
" }\n",
"\n",
" function readPriceCache() public view returns (int256[] memory) {\n",
" return cachedTicks;\n",
" }\n",
" }\n",
" '''\n",
"\n",
@@ -385,69 +358,44 @@
" \"settings\": {\"outputSelection\": {\"*\": {\"*\": [\"metadata\", \"evm.bytecode\", \"abi\"]}}}\n",
" })\n",
"\n",
" # Get bytecode\n",
" bytecode = compiled_sol['contracts']['UniTickAttestor.sol']['UniTickAttestor']['evm']['bytecode']['object']\n",
"\n",
" # Get ABI\n",
" # In production if you are reading from really large contracts you can just use\n",
" # a stripped down version of the ABI of the contract you are calling, containing only the view functions you will fetch data from.\n",
" abi = json.loads(compiled_sol['contracts']['UniTickAttestor.sol']['UniTickAttestor']['metadata'])['output']['abi']\n",
"\n",
" # Step 3: Deploy the contract\n",
" # Deploy contract\n",
" UniTickAttestor = w3.eth.contract(abi=abi, bytecode=bytecode)\n",
" tx_hash = UniTickAttestor.constructor().transact()\n",
" tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)\n",
" # If you are deploying to production you can skip the 3 lines of code above and just instantiate the contract like this,\n",
" # passing the address and abi of the contract you are fetching data from.\n",
" contract = w3.eth.contract(address=tx_receipt['contractAddress'], abi=abi)\n",
"\n",
" # Step 4: Interact with the contract\n",
" call = contract.functions.consult(\n",
" # Address of the UniV3 usdc-weth pool 0.005 fee\n",
" # Step 4: Store data via cache_price transaction\n",
" tx_hash = contract.functions.cache_price(\n",
" \"0xC6962004f452bE9203591991D15f6b388e09E8D0\",\n",
" secondsAgo\n",
" ).build_transaction()\n",
" result = contract.functions.consult(\n",
" # Address of the UniV3 usdc-weth pool 0.005 fee\n",
" \"0xC6962004f452bE9203591991D15f6b388e09E8D0\",\n",
" secondsAgo\n",
" ).call()\n",
" \n",
" print(f'result: {result}')\n",
" ).transact()\n",
" tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)\n",
"\n",
" # Step 5: Prepare calldata for readPriceCache\n",
" call = contract.functions.readPriceCache().build_transaction()\n",
" calldata = call['data'][2:]\n",
"\n",
" time_stamp = w3.eth.get_block('latest')['timestamp']\n",
" # Get stored data\n",
" result = contract.functions.readPriceCache().call()\n",
" print(f'Cached ticks: {result}')\n",
"\n",
" print(f'time_stamp: {time_stamp}')\n",
" decimals = [0] * len(data)\n",
"\n",
" # Set the next block timestamp using the fetched time_stamp\n",
" set_next_block_timestamp(RPC_URL, time_stamp)\n",
"\n",
"\n",
" # Prepare the calls_to_account object\n",
" # If you were calling view functions across multiple contracts,\n",
" # you would have multiple entries in the calls_to_account array,\n",
" # one for each contract.\n",
" call_to_account = {\n",
" 'call_data': calldata,\n",
" 'decimals': 0,\n",
" 'address': contract.address[2:], # remove the '0x' prefix\n",
" 'len': len(data),\n",
" 'decimals': decimals,\n",
" 'address': contract.address[2:],\n",
" }\n",
"\n",
" print(f'call_to_account: {call_to_account}')\n",
"\n",
" return call_to_account\n",
"\n",
"# Now let's start the Anvil process. You don't need to do this if you are deploying to a non-local chain.\n",
"start_anvil()\n",
"call_to_account = on_chain_data(x)\n",
"\n",
"# Now let's call our function, passing in the same input tensor we used to export the model 2 cells above.\n",
"calls_to_account = on_chain_data(x)\n",
"\n",
"data = dict(input_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
"\n",
"# Serialize on-chain data into file:\n",
"data = dict(input_data = {'rpc': RPC_URL, 'call': call_to_account })\n",
"json.dump(data, open(\"input.json\", 'w'))"
]
},
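A note on the design above (inferred from the diff, not from upstream docs): the data-attestation flow reads values by calling a view function, so data that originates from an external call like the pool's `observe` is first persisted on-chain with the `cache_price` transaction and then served back by the `readPriceCache` view; the calldata for that view function is what ends up in `call_data`.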
@@ -614,8 +562,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
@@ -668,9 +616,9 @@
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" RPC_URL,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
@@ -692,34 +640,7 @@
"source": [
"# !export RUST_BACKTRACE=1\n",
"\n",
"calls_to_account = on_chain_data(x)\n",
"\n",
"data = dict(input_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
"\n",
"# Serialize on-chain data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n",
"\n",
"# setup web3 instance\n",
"w3 = Web3(HTTPProvider(RPC_URL)) \n",
"\n",
"time_stamp = w3.eth.get_block('latest')['timestamp']\n",
"\n",
"print(f'time_stamp: {time_stamp}')\n",
"\n",
"\n",
"witness_path = \"witness.json\"\n",
"\n",
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"# print(res)\n",
"assert os.path.isfile(proof_path)\n",
"# read the verifier address\n",
"addr_verifier = None\n",
@@ -732,8 +653,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" proof_path,\n",
" addr_da,\n",
")"
]

View File

@@ -666,7 +666,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -689,8 +689,8 @@
"# await\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -701,7 +701,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -722,8 +722,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
@@ -743,7 +743,8 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
@@ -756,7 +757,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.9"
}
},
"nbformat": 4,

View File

@@ -849,8 +849,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -870,8 +870,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path\n",
")\n",
"assert res == True"
]
@@ -905,4 +905,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -246,7 +246,7 @@
"metadata": {},
"outputs": [],
"source": [
"ezkl.setup_test_evm_witness(\n",
"ezkl.setup_test_evm_data(\n",
" data_path,\n",
" compiled_model_path,\n",
" # we write the call data to the same file as the input data\n",
@@ -358,8 +358,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
@@ -374,14 +374,6 @@
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc888848",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
@@ -413,9 +405,9 @@
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" RPC_URL,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )"
]
},
@@ -478,8 +470,8 @@
"\n",
"res = ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" proof_path,\n",
" addr_da,\n",
")"
]
@@ -525,7 +517,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".env",
"language": "python",
"name": "python3"
},
@@ -539,7 +531,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.9"
}
},
"nbformat": 4,

Binary file not shown.

examples/onnx/fr_age/lol.txt (new file, 37795 lines)

File diff suppressed because it is too large.

View File

@@ -706,9 +706,9 @@ def setup_aggregate(sample_snarks:typing.Sequence[str | os.PathLike | pathlib.Pa
"""
...
def setup_test_evm_witness(data_path:str | os.PathLike | pathlib.Path,compiled_circuit_path:str | os.PathLike | pathlib.Path,test_data:str | os.PathLike | pathlib.Path,input_source:PyTestDataSource,output_source:PyTestDataSource,rpc_url:typing.Optional[str]) -> typing.Any:
def setup_test_evm_data(data_path:str | os.PathLike | pathlib.Path,compiled_circuit_path:str | os.PathLike | pathlib.Path,test_data:str | os.PathLike | pathlib.Path,input_source:PyTestDataSource,output_source:PyTestDataSource,rpc_url:typing.Optional[str]) -> typing.Any:
r"""
Setup test evm witness
Setup test evm data
Arguments
---------

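A minimal usage sketch of the renamed binding, assuming notebook-style top-level `await` and that `PyTestDataSource` exposes `File` and `OnChain` variants (hypothetical names if they differ); paths and the URL are placeholders:

```python
import ezkl

# rpc_url is now required (see the Rust signature change further below)
await ezkl.setup_test_evm_data(
    "input.json",            # data_path
    "network.compiled",      # compiled_circuit_path
    "input.json",            # test_data, written back in place
    input_source=ezkl.PyTestDataSource.OnChain,
    output_source=ezkl.PyTestDataSource.File,
    rpc_url="http://127.0.0.1:3030",
)
```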
View File

@@ -1,12 +1,5 @@
// ignore file if compiling for wasm
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc::MiMalloc;
#[global_allocator]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
static GLOBAL: MiMalloc = MiMalloc;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::{CommandFactory, Parser};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]

View File

@@ -206,6 +206,9 @@ struct PyRunArgs {
/// bool: Should the circuit use range checks for inputs and outputs (set to false if the input is a felt)
#[pyo3(get, set)]
pub ignore_range_check_inputs_outputs: bool,
/// float: epsilon used for arguments that use division
#[pyo3(get, set)]
pub epsilon: f64,
}
/// default instantiation of PyRunArgs
@@ -238,12 +241,14 @@ impl From<PyRunArgs> for RunArgs {
decomp_base: py_run_args.decomp_base,
decomp_legs: py_run_args.decomp_legs,
ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
epsilon: Some(py_run_args.epsilon),
}
}
}
impl Into<PyRunArgs> for RunArgs {
fn into(self) -> PyRunArgs {
let eps = self.get_epsilon();
PyRunArgs {
bounded_log_lookup: self.bounded_log_lookup,
input_scale: self.input_scale,
@@ -262,6 +267,7 @@ impl Into<PyRunArgs> for RunArgs {
decomp_base: self.decomp_base,
decomp_legs: self.decomp_legs,
ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
epsilon: eps,
}
}
}
@@ -962,6 +968,8 @@ fn gen_settings(
output=PathBuf::from(DEFAULT_SETTINGS),
variables=Vec::from([("batch_size".to_string(), 1)]),
seed=DEFAULT_SEED.parse().unwrap(),
min=None,
max=None
))]
#[gen_stub_pyfunction]
fn gen_random_data(
@@ -969,8 +977,10 @@ fn gen_random_data(
output: PathBuf,
variables: Vec<(String, usize)>,
seed: u64,
min: Option<f32>,
max: Option<f32>,
) -> Result<bool, PyErr> {
crate::execute::gen_random_data(model, output, variables, seed).map_err(|e| {
crate::execute::gen_random_data(model, output, variables, seed, min, max).map_err(|e| {
let err_str = format!("Failed to generate settings: {}", e);
PyRuntimeError::new_err(err_str)
})?;
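A sketch of the two knobs added above, with placeholder file names; `1e-6` is an arbitrary example value, not the library default:

```python
import ezkl

run_args = ezkl.PyRunArgs()
run_args.epsilon = 1e-6   # epsilon used by ops that divide (recip, softmax, ...)

# min/max bound the sampled values; both are optional
ezkl.gen_random_data(
    "network.onnx",                  # model
    "random_input.json",             # output
    variables=[("batch_size", 1)],
    seed=42,
    min=-1.0,
    max=1.0,
)
```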
@@ -1819,20 +1829,20 @@ fn create_evm_data_attestation(
test_data,
input_source,
output_source,
rpc_url=None,
rpc_url,
))]
#[gen_stub_pyfunction]
fn setup_test_evm_witness(
fn setup_test_evm_data(
py: Python,
data_path: String,
compiled_circuit_path: PathBuf,
test_data: PathBuf,
input_source: PyTestDataSource,
output_source: PyTestDataSource,
rpc_url: Option<String>,
rpc_url: String,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::setup_test_evm_witness(
crate::execute::setup_test_evm_data(
data_path,
compiled_circuit_path,
test_data,
@@ -1842,7 +1852,7 @@ fn setup_test_evm_witness(
)
.await
.map_err(|e| {
let err_str = format!("Failed to run setup_test_evm_witness: {}", e);
let err_str = format!("Failed to run setup_test_evm_data: {}", e);
PyRuntimeError::new_err(err_str)
})?;
@@ -1853,8 +1863,8 @@ fn setup_test_evm_witness(
/// deploys the solidity verifier
#[pyfunction(signature = (
addr_path,
rpc_url,
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
rpc_url=None,
contract_type=ContractType::default(),
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None,
@@ -1863,8 +1873,8 @@ fn setup_test_evm_witness(
fn deploy_evm(
py: Python,
addr_path: PathBuf,
rpc_url: String,
sol_code_path: PathBuf,
rpc_url: Option<String>,
contract_type: ContractType,
optimizer_runs: usize,
private_key: Option<String>,
@@ -1892,9 +1902,9 @@ fn deploy_evm(
#[pyfunction(signature = (
addr_path,
input_data,
rpc_url,
settings_path=PathBuf::from(DEFAULT_SETTINGS),
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE_DA),
rpc_url=None,
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None
))]
@@ -1903,9 +1913,9 @@ fn deploy_da_evm(
py: Python,
addr_path: PathBuf,
input_data: String,
rpc_url: String,
settings_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
optimizer_runs: usize,
private_key: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
@@ -1952,8 +1962,8 @@ fn deploy_da_evm(
///
#[pyfunction(signature = (
addr_verifier,
rpc_url,
proof_path=PathBuf::from(DEFAULT_PROOF),
rpc_url=None,
addr_da = None,
addr_vk = None,
))]
@@ -1961,8 +1971,8 @@ fn deploy_da_evm(
fn verify_evm<'a>(
py: Python<'a>,
addr_verifier: &'a str,
rpc_url: String,
proof_path: PathBuf,
rpc_url: Option<String>,
addr_da: Option<&'a str>,
addr_vk: Option<&'a str>,
) -> PyResult<Bound<'a, PyAny>> {
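Putting the reordered signatures together, a hedged sketch of the new call shapes; the URL, paths, and addresses are placeholders:

```python
# rpc_url is now required and comes directly after the address/path argument
res = await ezkl.deploy_evm(
    "addr_verifier.txt",      # addr_path
    "http://127.0.0.1:3030",  # rpc_url
    "verifier.sol",           # sol_code_path
)

res = await ezkl.deploy_da_evm(
    "addr_da.txt",            # addr_path
    "input.json",             # input_data
    "http://127.0.0.1:3030",  # rpc_url
    "settings.json",          # settings_path
    "da.sol",                 # sol_code_path
)

res = await ezkl.verify_evm(
    addr,                     # verifier address read from addr_path
    "http://127.0.0.1:3030",  # rpc_url
    "proof.json",             # proof_path
)
```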
@@ -2107,7 +2117,7 @@ fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;
m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
m.add_function(wrap_pyfunction!(setup_test_evm_witness, m)?)?;
m.add_function(wrap_pyfunction!(setup_test_evm_data, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_data_attestation, m)?)?;
m.add_function(wrap_pyfunction!(encode_evm_calldata, m)?)?;

View File

@@ -962,7 +962,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
pub fn layout(
&mut self,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
op: Box<dyn Op<F>>,
) -> Result<Option<ValTensor<F>>, CircuitError> {
op.layout(self, region, values)

View File

@@ -15,10 +15,12 @@ use serde::{Deserialize, Serialize};
pub enum HybridOp {
Ln {
scale: utils::F32,
eps: f64,
},
Rsqrt {
input_scale: utils::F32,
output_scale: utils::F32,
eps: f64,
},
Sqrt {
scale: utils::F32,
@@ -42,6 +44,7 @@ pub enum HybridOp {
Recip {
input_scale: utils::F32,
output_scale: utils::F32,
eps: f64,
},
Div {
denom: utils::F32,
@@ -77,6 +80,7 @@ pub enum HybridOp {
input_scale: utils::F32,
output_scale: utils::F32,
axes: Vec<usize>,
eps: f64,
},
Output {
decomp: bool,
@@ -105,13 +109,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
HybridOp::Greater { .. }
| HybridOp::Less { .. }
| HybridOp::Equals { .. }
| HybridOp::GreaterEqual { .. }
HybridOp::Greater
| HybridOp::Less
| HybridOp::Equals
| HybridOp::GreaterEqual
| HybridOp::Max
| HybridOp::Min
| HybridOp::LessEqual { .. } => {
| HybridOp::LessEqual => {
vec![0, 1]
}
_ => vec![],
@@ -128,12 +132,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
HybridOp::Rsqrt {
input_scale,
output_scale,
eps,
} => format!(
"RSQRT (input_scale={}, output_scale={})",
input_scale, output_scale
"RSQRT (input_scale={}, output_scale={}, eps={})",
input_scale, output_scale, eps
),
HybridOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
HybridOp::Ln { scale, eps } => format!("LN(scale={}, eps={})", scale, eps),
HybridOp::RoundHalfToEven { scale, legs } => {
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
}
@@ -146,16 +151,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
HybridOp::Recip {
input_scale,
output_scale,
eps,
} => format!(
"RECIP (input_scale={}, output_scale={})",
input_scale, output_scale
"RECIP (input_scale={}, output_scale={}, eps={})",
input_scale, output_scale, eps
),
HybridOp::Div { denom } => format!("DIV (denom={})", denom),
HybridOp::SumPool {
padding,
stride,
kernel_shape,
normalized, data_format
normalized,
data_format,
} => format!(
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={}, data_format={:?})",
padding, stride, kernel_shape, normalized, data_format
@@ -177,10 +184,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
input_scale,
output_scale,
axes,
eps,
} => {
format!(
"SOFTMAX (input_scale={}, output_scale={}, axes={:?})",
input_scale, output_scale, axes
"SOFTMAX (input_scale={}, output_scale={}, axes={:?}, eps={})",
input_scale, output_scale, axes, eps
)
}
HybridOp::Output { decomp } => {
@@ -205,23 +213,27 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
HybridOp::Rsqrt {
input_scale,
output_scale,
eps,
} => layouts::rsqrt(
config,
region,
values[..].try_into()?,
*input_scale,
*output_scale,
*eps,
)?,
HybridOp::Sqrt { scale } => {
layouts::sqrt(config, region, values[..].try_into()?, *scale)?
}
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
HybridOp::Ln { scale, eps } => {
layouts::ln(config, region, values[..].try_into()?, *scale, *eps)?
}
HybridOp::RoundHalfToEven { scale, legs } => {
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
}
@@ -255,12 +267,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
HybridOp::Recip {
input_scale,
output_scale,
eps,
} => layouts::recip(
config,
region,
values[..].try_into()?,
integer_rep_to_felt(input_scale.0 as IntegerRep),
integer_rep_to_felt(output_scale.0 as IntegerRep),
*eps,
)?,
HybridOp::Div { denom, .. } => {
if denom.0.fract() == 0.0 {
@@ -317,6 +331,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
input_scale,
output_scale,
axes,
eps,
} => layouts::softmax_axes(
config,
region,
@@ -324,6 +339,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
*input_scale,
*output_scale,
axes,
*eps,
)?,
HybridOp::Output { decomp } => {
layouts::output(config, region, values[..].try_into()?, *decomp)?
@@ -346,10 +362,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
HybridOp::Greater { .. }
| HybridOp::GreaterEqual { .. }
| HybridOp::Less { .. }
| HybridOp::LessEqual { .. }
HybridOp::Greater
| HybridOp::GreaterEqual
| HybridOp::Less
| HybridOp::LessEqual
| HybridOp::ReduceArgMax { .. }
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,
@@ -364,6 +380,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
} => multiplier_to_scale((output_scale.0 * input_scale.0) as f64),
HybridOp::Ln {
scale: output_scale,
eps: _,
} => 4 * multiplier_to_scale(output_scale.0 as f64),
_ => in_scales[0],
};

File diff suppressed because it is too large.

View File

@@ -186,7 +186,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(layouts::nonlinearity(
config,

View File

@@ -49,7 +49,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError>;
/// Returns the scale of the output of the operation.
@@ -209,7 +209,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
let value = values[0].clone();
if !value.all_prev_assigned() {
@@ -223,12 +223,29 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
true,
)?))
}
_ => Ok(Some(super::layouts::identity(
config,
region,
values[..].try_into()?,
self.decomp,
)?)),
_ => {
if self.decomp {
log::debug!("constraining input to be decomp");
Ok(Some(
super::layouts::decompose(
config,
region,
values[..].try_into()?,
&region.base(),
&region.legs(),
false,
)?
.1,
))
} else {
log::debug!("constraining input to be identity");
Ok(Some(super::layouts::identity(
config,
region,
values[..].try_into()?,
)?))
}
}
}
} else {
Ok(Some(value))
@@ -263,7 +280,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknow
&self,
_: &mut crate::circuit::BaseConfig<F>,
_: &mut RegionCtx<F>,
_: &[ValTensor<F>],
_: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Err(super::CircuitError::UnsupportedOp)
}
@@ -319,8 +336,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
}
impl<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self
@@ -333,20 +355,20 @@ impl<
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
_: &[ValTensor<F>],
_: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
let value = if let Some(value) = &self.pre_assigned_val {
value.clone()
} else {
self.quantized_values.clone().try_into()?
};
// we gotta constrain it once if its used multiple times
Ok(Some(layouts::identity(
config,
region,
&[value],
self.decomp,
)?))
Ok(Some(if self.decomp {
log::debug!("constraining constant to be decomp");
super::layouts::decompose(config, region, &[&value], &region.base(), &region.legs(), false)?.1
} else {
log::debug!("constraining constant to be identity");
super::layouts::identity(config, region, &[&value])?
}))
}
fn clone_dyn(&self) -> Box<dyn Op<F>> {

View File

@@ -49,7 +49,7 @@ pub enum PolyOp {
},
Downsample {
axis: usize,
stride: usize,
stride: isize,
modulo: usize,
},
DeConv {
@@ -188,7 +188,8 @@ impl<
} => {
format!(
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
stride, padding, output_padding, group, data_format, kernel_format)
stride, padding, output_padding, group, data_format, kernel_format
)
}
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
PolyOp::Slice { axis, start, end } => {
@@ -207,11 +208,11 @@ impl<
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
PolyOp::Abs => layouts::abs(config, region, values[..].try_into()?)?,
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?)?,
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?, true)?,
PolyOp::LeakyReLU { slope, scale } => {
layouts::leaky_relu(config, region, values[..].try_into()?, slope, scale)?
}
@@ -339,9 +340,7 @@ impl<
PolyOp::Mult => {
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
}
PolyOp::Identity { .. } => {
layouts::identity(config, region, values[..].try_into()?, false)?
}
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
if values.len() != 1 {
@@ -420,14 +419,14 @@ impl<
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
PolyOp::Sign { .. } => 0,
PolyOp::Sign => 0,
_ => in_scales[0],
};
Ok(scale)
}
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
if matches!(self, PolyOp::Add { .. } | PolyOp::Sub) {
if matches!(self, PolyOp::Add | PolyOp::Sub) {
vec![0, 1]
} else if matches!(self, PolyOp::Iff) {
vec![1, 2]

View File

@@ -10,7 +10,6 @@ use halo2_proofs::{
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use maybe_rayon::iter::ParallelExtend;
use std::{
cell::RefCell,
@@ -462,15 +461,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
/// Update the max and min from inputs
pub fn update_max_min_lookup_inputs(
&mut self,
inputs: &[ValTensor<F>],
inputs: &ValTensor<F>,
) -> Result<(), CircuitError> {
let (mut min, mut max) = (0, 0);
for i in inputs {
max = max.max(i.int_evals()?.into_iter().max().unwrap_or_default());
min = min.min(i.int_evals()?.into_iter().min().unwrap_or_default());
}
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
let int_eval = inputs.int_evals()?;
let max = int_eval.iter().max().unwrap_or(&0);
let min = int_eval.iter().min().unwrap_or(&0);
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(*max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(*min);
Ok(())
}
@@ -505,10 +503,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
/// add used lookup
pub fn add_used_lookup(
&mut self,
lookup: LookupOp,
inputs: &[ValTensor<F>],
lookup: &LookupOp,
inputs: &ValTensor<F>,
) -> Result<(), CircuitError> {
self.statistics.used_lookups.insert(lookup);
self.statistics.used_lookups.insert(lookup.clone());
self.update_max_min_lookup_inputs(inputs)
}
@@ -642,34 +640,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.assign_dynamic_lookup(var, values)
}
/// Assign a valtensor to a vartensor
pub fn assign_with_omissions(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
ommissions: &HashSet<usize>,
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
Ok(var.assign_with_omissions(
&mut region.borrow_mut(),
self.linear_coord,
values,
ommissions,
&mut self.assigned_constants,
)?)
} else {
let mut values_clone = values.clone();
let mut indices = ommissions.clone().into_iter().collect_vec();
values_clone.remove_indices(&mut indices, false)?;
let values_map = values.create_constants_map();
self.assigned_constants.par_extend(values_map);
Ok(values.clone())
}
}
/// Assign a valtensor to a vartensor with duplication
pub fn assign_with_duplication_unconstrained(
&mut self,

View File

@@ -9,6 +9,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::Fr as F;
use halo2curves::ff::{Field, PrimeField};
use itertools::Itertools;
#[cfg(not(any(
all(target_arch = "wasm32", target_os = "unknown"),
not(feature = "ezkl")
@@ -64,7 +65,7 @@ mod matmul {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -141,7 +142,7 @@ mod matmul_col_overflow_double_col {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -215,7 +216,7 @@ mod matmul_col_overflow {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -302,7 +303,7 @@ mod matmul_col_ultra_overflow_double_col {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -380,6 +381,7 @@ mod matmul_col_ultra_overflow {
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use super::*;
@@ -422,7 +424,7 @@ mod matmul_col_ultra_overflow {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -533,7 +535,7 @@ mod dot {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -610,7 +612,7 @@ mod dot_col_overflow_triple_col {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -683,7 +685,7 @@ mod dot_col_overflow {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -756,7 +758,7 @@ mod sum {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.map_err(|_| Error::Synthesis)
@@ -826,7 +828,7 @@ mod sum_col_overflow_double_col {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.map_err(|_| Error::Synthesis)
@@ -895,7 +897,7 @@ mod sum_col_overflow {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.map_err(|_| Error::Synthesis)
@@ -966,7 +968,7 @@ mod composition {
let _ = config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -975,7 +977,7 @@ mod composition {
let _ = config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -984,7 +986,7 @@ mod composition {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -1061,7 +1063,7 @@ mod conv {
config
.layout(
&mut region,
&self.inputs,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
@@ -1218,7 +1220,7 @@ mod conv_col_ultra_overflow {
config
.layout(
&mut region,
&[self.image.clone(), self.kernel.clone()],
&[&self.image, &self.kernel],
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
@@ -1377,7 +1379,7 @@ mod conv_relu_col_ultra_overflow {
let output = config
.layout(
&mut region,
&[self.image.clone(), self.kernel.clone()],
&[&self.image, &self.kernel],
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
@@ -1390,7 +1392,7 @@ mod conv_relu_col_ultra_overflow {
let _output = config
.layout(
&mut region,
&[output.unwrap().unwrap()],
&[&output.unwrap().unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -1517,7 +1519,11 @@ mod add_w_shape_casting {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -1584,7 +1590,11 @@ mod add {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -1671,8 +1681,8 @@ mod dynamic_lookup {
layouts::dynamic_lookup(
&config,
&mut region,
&self.lookups[i],
&self.tables[i],
&self.lookups[i].iter().collect_vec().try_into().unwrap(),
&self.tables[i].iter().collect_vec().try_into().unwrap(),
)
.map_err(|_| Error::Synthesis)?;
}
@@ -1767,8 +1777,8 @@ mod shuffle {
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [[ValTensor<F>; 1]; NUM_LOOP],
references: [[ValTensor<F>; 1]; NUM_LOOP],
inputs: [ValTensor<F>; NUM_LOOP],
references: [ValTensor<F>; NUM_LOOP],
_marker: PhantomData<F>,
}
@@ -1818,15 +1828,15 @@ mod shuffle {
layouts::shuffles(
&config,
&mut region,
&self.inputs[i],
&self.references[i],
&[&self.inputs[i]],
&[&self.references[i]],
layouts::SortCollisionMode::Unsorted,
)
.map_err(|_| Error::Synthesis)?;
}
assert_eq!(
region.shuffle_col_coord(),
NUM_LOOP * self.references[0][0].len()
NUM_LOOP * self.references[0].len()
);
assert_eq!(region.shuffle_index(), NUM_LOOP);
@@ -1843,17 +1853,19 @@ mod shuffle {
// parameters
let references = (0..NUM_LOOP)
.map(|loop_idx| {
[ValTensor::from(Tensor::from((0..LEN).map(|i| {
Value::known(F::from((i * loop_idx) as u64 + 1))
})))]
ValTensor::from(Tensor::from(
(0..LEN).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
))
})
.collect::<Vec<_>>();
let inputs = (0..NUM_LOOP)
.map(|loop_idx| {
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
Value::known(F::from((i * loop_idx) as u64 + 1))
})))]
ValTensor::from(Tensor::from(
(0..LEN)
.rev()
.map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
))
})
.collect::<Vec<_>>();
@@ -1873,9 +1885,11 @@ mod shuffle {
} else {
loop_idx - 1
};
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
Value::known(F::from((i * prev_idx) as u64 + 1))
})))]
ValTensor::from(Tensor::from(
(0..LEN)
.rev()
.map(|i| Value::known(F::from((i * prev_idx) as u64 + 1))),
))
})
.collect::<Vec<_>>();
@@ -1931,7 +1945,11 @@ mod add_with_overflow {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2026,7 +2044,7 @@ mod add_with_overflow_and_poseidon {
layouter.assign_region(|| "_new_module", |_| Ok(()))?;
let inputs = vec![assigned_inputs_a, assigned_inputs_b];
let inputs = vec![&assigned_inputs_a, &assigned_inputs_b];
layouter.assign_region(
|| "model",
@@ -2135,7 +2153,11 @@ mod sub {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Sub))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sub),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2202,7 +2224,11 @@ mod mult {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Mult))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Mult),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2269,7 +2295,11 @@ mod pow {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Pow(5)))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Pow(5)),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2360,13 +2390,13 @@ mod matmul_relu {
};
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[output.unwrap()],
&[&output.unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -2465,7 +2495,7 @@ mod relu {
Ok(config
.layout(
&mut region,
&[self.input.clone()],
&[&self.input],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -2563,7 +2593,7 @@ mod lookup_ultra_overflow {
config
.layout(
&mut region,
&[self.input.clone()],
&[&self.input],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.map_err(|_| Error::Synthesis)


@@ -1,6 +1,6 @@
use alloy::primitives::Address as H160;
use clap::{Command, Parser, Subcommand};
use clap_complete::{Generator, Shell, generate};
use clap_complete::{generate, Generator, Shell};
#[cfg(feature = "python-bindings")]
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
@@ -8,7 +8,7 @@ use std::path::PathBuf;
use std::str::FromStr;
use tosubcommand::{ToFlags, ToSubcommand};
use crate::{Commitments, RunArgs, pfsys::ProofType};
use crate::{pfsys::ProofType, Commitments, RunArgs};
use crate::circuit::CheckMode;
use crate::graph::TestDataSource;
@@ -382,6 +382,42 @@ pub struct Cli {
pub command: Option<Commands>,
}
/// Custom parser for data field that handles both direct JSON strings and file paths with '@' prefix
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
pub struct DataField(pub String);
impl FromStr for DataField {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
// Check if the input starts with '@'
if let Some(file_path) = s.strip_prefix('@') {
// Extract the file path (remove the '@' prefix)
// Read the file content
let content = std::fs::read_to_string(file_path)
.map_err(|e| format!("Failed to read data file '{}': {}", file_path, e))?;
// Return the file content as the data field value
Ok(DataField(content))
} else {
// Use the input string directly
Ok(DataField(s.to_string()))
}
}
}
impl ToFlags for DataField {
fn to_flags(&self) -> Vec<String> {
vec![self.0.clone()]
}
}
impl std::fmt::Display for DataField {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
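A quick sketch of how this parser behaves at the call site, assuming the DataField type above is in scope (the file path below is hypothetical):

use std::str::FromStr;

fn main() {
    // Raw JSON is passed through unchanged.
    let raw = DataField::from_str(r#"{"input_data": [[1, 2, 3]]}"#).unwrap();
    assert_eq!(raw.0, r#"{"input_data": [[1, 2, 3]]}"#);

    // An '@'-prefixed argument is read as a file; a missing file is an error.
    assert!(DataField::from_str("@/no/such/file.json").is_err());
}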
#[allow(missing_docs)]
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
pub enum Commands {
@@ -400,10 +436,9 @@ pub enum Commands {
/// Generates the witness from an input file.
GenWitness {
/// The path to the .json data file
/// You can also pass the input data as a string, eg. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<String>,
/// The path to the .json data file (with @ prefix) or a raw data string of the form '{"input_data": [[1, 2, 3]]}'
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_parser = DataField::from_str)]
data: Option<DataField>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
@@ -435,7 +470,7 @@ pub enum Commands {
/// The path to the .onnx model file
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
model: Option<PathBuf>,
/// The path to the .json data file to output
/// The path to the .json data file
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
/// Hand-written parser for graph variables, e.g. batch_size=1
@@ -444,11 +479,16 @@ pub enum Commands {
/// random seed for reproducibility (optional)
#[arg(long, value_hint = clap::ValueHint::Other, default_value = DEFAULT_SEED)]
seed: u64,
/// min value for random data
#[arg(long, value_hint = clap::ValueHint::Other)]
min: Option<f32>,
/// max value for random data
#[arg(long, value_hint = clap::ValueHint::Other)]
max: Option<f32>,
},
/// Calibrates the proving scale, lookup bits and logrows from a circuit settings file.
CalibrateSettings {
/// The path to the .json calibration data file.
/// You can also pass the input data as a string, e.g. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
#[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE, value_hint = clap::ValueHint::FilePath)]
data: Option<String>,
/// The path to the .onnx model file
@@ -633,7 +673,6 @@ pub enum Commands {
#[command(arg_required_else_help = true)]
SetupTestEvmData {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
/// You can also pass the input data as a string, eg. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
#[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
data: Option<String>,
/// The path to the compiled model file (generated using the compile-circuit command)
@@ -644,9 +683,9 @@ pub enum Commands {
/// Should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'T', long, value_hint = clap::ValueHint::FilePath)]
test_data: PathBuf,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
/// RPC URL for an Ethereum node
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
rpc_url: String,
/// where the input data come from
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
input_source: TestDataSource,
@@ -654,20 +693,6 @@ pub enum Commands {
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
output_source: TestDataSource,
},
/// The Data Attestation Verifier contract stores the account calls to fetch data to feed into ezkl. This call data can be updated by an admin account. This tests that admin account is able to update this call data.
#[command(arg_required_else_help = true)]
TestUpdateAccountCalls {
/// The path to the verifier contract's address
#[arg(long, value_hint = clap::ValueHint::Other)]
addr: H160Flag,
/// The path to the .json data file.
/// You can also pass the input data as a string, eg. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
#[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
data: Option<String>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
},
/// Swaps the positions in the transcript that correspond to commitments
SwapProofCommitments {
/// The path to the proof file
@@ -745,7 +770,7 @@ pub enum Commands {
},
/// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier
#[command(name = "create-evm-vka")]
CreateEvmVKArtifact {
CreateEvmVka {
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
#[arg(long, value_hint = clap::ValueHint::FilePath)]
srs_path: Option<PathBuf>,
@@ -764,7 +789,7 @@ pub enum Commands {
},
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
#[command(name = "create-evm-da")]
CreateEvmDataAttestation {
CreateEvmDa {
/// The path to load circuit settings .json file from (generated using the gen-settings command)
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS, value_hint = clap::ValueHint::FilePath)]
settings_path: Option<PathBuf>,
@@ -855,9 +880,9 @@ pub enum Commands {
/// The path to the Solidity code (generated using the create-evm-verifier command)
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
/// RPC URL for an Ethereum node
#[arg(short = 'U', long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Url)]
rpc_url: String,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
/// The path to output the contract address
addr_path: Option<PathBuf>,
@@ -873,9 +898,8 @@ pub enum Commands {
},
/// Deploys an evm verifier that allows for data attestation
#[command(name = "deploy-evm-da")]
DeployEvmDataAttestation {
DeployEvmDa {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
/// You can also pass the input data as a string, eg. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<String>,
/// The path to load circuit settings .json file from (generated using the gen-settings command)
@@ -884,9 +908,9 @@ pub enum Commands {
/// The path to the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE_DA, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
/// RPC URL for an Ethereum node
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
rpc_url: String,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_DA, value_hint = clap::ValueHint::FilePath)]
/// The path to output the contract address
addr_path: Option<PathBuf>,
@@ -906,9 +930,9 @@ pub enum Commands {
/// The path to verifier contract's address
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
addr_verifier: H160Flag,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
/// RPC URL for an Ethereum node
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
rpc_url: String,
/// does the verifier use data attestation?
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_da: Option<H160Flag>,

File diff suppressed because one or more lines are too long


@@ -1,33 +1,29 @@
use crate::EZKL_BUF_CAPACITY;
use crate::circuit::CheckMode;
use crate::circuit::region::RegionSettings;
use crate::circuit::CheckMode;
use crate::commands::CalibrationTarget;
use crate::eth::{
deploy_contract_via_solidity, deploy_da_verifier_via_solidity, fix_da_multi_sol,
fix_da_single_sol,
};
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity, fix_da_sol};
#[allow(unused_imports)]
use crate::eth::{get_contract_artifacts, verify_proof_via_solidity};
use crate::graph::input::{Calls, GraphData};
use crate::graph::input::GraphData;
use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model};
use crate::graph::{TestDataSource, TestSources};
use crate::pfsys::evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript};
use crate::pfsys::{
ProofSplitCommit, create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit,
create_keys, load_pk, load_vk, save_params, save_pk, Snark, StrategyType, TranscriptType,
};
use crate::pfsys::{
Snark, StrategyType, TranscriptType, create_keys, load_pk, load_vk, save_params, save_pk,
create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit, ProofSplitCommit,
};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::EZKL_BUF_CAPACITY;
use crate::{commands::*, EZKLError};
use crate::{Commitments, RunArgs};
use crate::{EZKLError, commands::*};
use colored::Colorize;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
use halo2_proofs::plonk::{self, Circuit};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::poly::commitment::{CommitmentScheme, Params};
use halo2_proofs::poly::commitment::{ParamsProver, Verifier};
use halo2_proofs::poly::ipa::commitment::{IPACommitmentScheme, ParamsIPA};
@@ -40,6 +36,7 @@ use halo2_proofs::poly::kzg::strategy::AccumulatorStrategy as KZGAccumulatorStra
use halo2_proofs::poly::kzg::{
commitment::ParamsKZG, strategy::SingleStrategy as KZGSingleStrategy,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer};
use halo2_solidity_verifier;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
@@ -48,14 +45,15 @@ use halo2curves::serde::SerdeObject;
use indicatif::{ProgressBar, ProgressStyle};
use instant::Instant;
use itertools::Itertools;
use lazy_static::lazy_static;
use log::debug;
use log::{info, trace, warn};
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde::Serialize;
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::Config;
use snark_verifier::system::halo2::compile;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use snark_verifier::system::halo2::Config;
use std::fs::File;
use std::io::BufWriter;
use std::io::{Cursor, Write};
@@ -68,8 +66,6 @@ use thiserror::Error;
use tract_onnx::prelude::IntoTensor;
use tract_onnx::prelude::Tensor as TractTensor;
use lazy_static::lazy_static;
lazy_static! {
#[derive(Debug)]
/// The path to the ezkl related data.
@@ -141,11 +137,15 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
data,
variables,
seed,
min,
max,
} => gen_random_data(
model.unwrap_or(DEFAULT_MODEL.into()),
data.unwrap_or(DEFAULT_DATA.into()),
variables,
seed,
min,
max,
),
Commands::CalibrateSettings {
model,
@@ -176,7 +176,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
srs_path,
} => gen_witness(
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
data.unwrap_or(DEFAULT_DATA.into()),
data.unwrap_or(DataField(DEFAULT_DATA.into())).to_string(),
Some(output.unwrap_or(DEFAULT_WITNESS.into())),
vk_path,
srs_path,
@@ -215,8 +215,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
addr_vk,
)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::CreateEvmVKArtifact {
Commands::CreateEvmVka {
vk_path,
srs_path,
settings_path,
@@ -232,7 +231,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
)
.await
}
Commands::CreateEvmDataAttestation {
Commands::CreateEvmDa {
settings_path,
sol_code_path,
abi_path,
@@ -301,7 +300,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
input_source,
output_source,
} => {
setup_test_evm_witness(
setup_test_evm_data(
data.unwrap_or(DEFAULT_DATA.into()),
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
test_data,
@@ -311,11 +310,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
)
.await
}
Commands::TestUpdateAccountCalls {
addr,
data,
rpc_url,
} => test_update_account_calls(addr, data.unwrap_or(DEFAULT_DATA.into()), rpc_url).await,
Commands::SwapProofCommitments {
proof_path,
witness_path,
@@ -442,7 +436,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
)
.await
}
Commands::DeployEvmDataAttestation {
Commands::DeployEvmDa {
data,
settings_path,
sol_code_path,
@@ -849,6 +843,8 @@ pub(crate) fn gen_random_data(
data_path: PathBuf,
variables: Vec<(String, usize)>,
seed: u64,
min: Option<f32>,
max: Option<f32>,
) -> Result<String, EZKLError> {
let mut file = std::fs::File::open(&model_path).map_err(|e| {
crate::graph::errors::GraphError::ReadWriteFileError(
@@ -867,22 +863,32 @@ pub(crate) fn gen_random_data(
.collect::<tract_onnx::prelude::TractResult<Vec<_>>>()
.map_err(|e| EZKLError::from(e.to_string()))?;
let min = min.unwrap_or(0.0);
let max = max.unwrap_or(1.0);
/// Generates a random tensor of a given size and type.
fn random(
sizes: &[usize],
datum_type: tract_onnx::prelude::DatumType,
seed: u64,
min: f32,
max: f32,
) -> TractTensor {
use rand::{Rng, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
let mut tensor = TractTensor::zero::<f32>(sizes).unwrap();
let slice = tensor.as_slice_mut::<f32>().unwrap();
slice.iter_mut().for_each(|x| *x = rng.r#gen());
slice.iter_mut().for_each(|x| *x = rng.gen_range(min..max));
tensor.cast_to_dt(datum_type).unwrap().into_owned()
}
fn tensor_for_fact(fact: &tract_onnx::prelude::TypedFact, seed: u64) -> TractTensor {
fn tensor_for_fact(
fact: &tract_onnx::prelude::TypedFact,
seed: u64,
min: f32,
max: f32,
) -> TractTensor {
if let Some(value) = &fact.konst {
return value.clone().into_tensor();
}
@@ -893,12 +899,14 @@ pub(crate) fn gen_random_data(
.expect("Expected concrete shape, found: {fact:?}"),
fact.datum_type,
seed,
min,
max,
)
}
let generated = input_facts
.iter()
.map(|v| tensor_for_fact(v, seed))
.map(|v| tensor_for_fact(v, seed, min, max))
.collect_vec();
let data = GraphData::from_tract_data(&generated)?;
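The new min/max flags bound the uniform sampling above. A self-contained sketch of the same idea, with a plain Vec<f32> standing in for the tract tensor (rand crate assumed; gen_range panics if min >= max):

use rand::{Rng, SeedableRng};

fn random_data(len: usize, seed: u64, min: f32, max: f32) -> Vec<f32> {
    // A seeded StdRng keeps the generated inputs reproducible across runs.
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    // Uniform in [min, max); with the defaults min = 0.0, max = 1.0 this
    // matches the old rng.gen() behaviour.
    (0..len).map(|_| rng.gen_range(min..max)).collect()
}

fn main() {
    assert_eq!(random_data(4, 42, -1.0, 1.0), random_data(4, 42, -1.0, 1.0));
}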
@@ -1131,6 +1139,7 @@ pub(crate) async fn calibrate(
let mut num_failed = 0;
let mut num_passed = 0;
let mut failure_reasons = vec![];
for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid {
pb.set_message(format!(
@@ -1155,20 +1164,21 @@ pub(crate) async fn calibrate(
// if unix get a gag
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
let _r = Gag::stdout().ok();
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
};
let _g = Gag::stderr().ok();
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
Err(e) => {
error!("circuit creation from run args failed: {:?}", e);
failure_reasons.push(format!(
"i-scale: {}, p-scale: {}, rebase-(x): {}, reason: {}",
input_scale.to_string().blue(),
param_scale.to_string().blue(),
scale_rebase_multiplier.to_string().yellow(),
e
));
pb.inc(1);
num_failed += 1;
continue;
@@ -1213,6 +1223,13 @@ pub(crate) async fn calibrate(
error!("forward pass failed: {:?}", e);
pb.inc(1);
num_failed += 1;
failure_reasons.push(format!(
"i-scale: {}, p-scale: {}, rebase-(x): {}, reason: {}",
input_scale.to_string().blue(),
param_scale.to_string().blue(),
scale_rebase_multiplier.to_string().yellow(),
e
));
continue;
}
}
@@ -1278,7 +1295,14 @@ pub(crate) async fn calibrate(
found_settings.as_json()?.to_colored_json_auto()?
);
num_passed += 1;
} else {
} else if let Err(res) = res {
failure_reasons.push(format!(
"i-scale: {}, p-scale: {}, rebase-(x): {}, reason: {}",
input_scale.to_string().blue(),
param_scale.to_string().blue(),
scale_rebase_multiplier.to_string().yellow(),
res.to_string().red()
));
num_failed += 1;
}
@@ -1288,6 +1312,13 @@ pub(crate) async fn calibrate(
pb.finish_with_message("Calibration Done.");
if found_params.is_empty() {
if !failure_reasons.is_empty() {
error!("Calibration failed for the following reasons:");
for reason in failure_reasons {
error!("{}", reason);
}
}
return Err("calibration failed, could not find any suitable parameters given the calibration dataset".into());
}
@@ -1321,7 +1352,7 @@ pub(crate) async fn calibrate(
.clone()
}
CalibrationTarget::Accuracy => {
let param_iterator = found_params.iter().sorted_by_key(|p| {
let mut param_iterator = found_params.iter().sorted_by_key(|p| {
(
p.run_args.input_scale,
p.run_args.param_scale,
@@ -1330,7 +1361,7 @@ pub(crate) async fn calibrate(
)
});
let last = param_iterator.last().ok_or("no params found")?;
let last = param_iterator.next_back().ok_or("no params found")?;
let max_scale = (
last.run_args.input_scale,
last.run_args.param_scale,
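The switch from .last() to .next_back() works because itertools' sorted_by_key returns a double-ended iterator: Iterator::last would drain it element by element just to reach the end, while next_back pops the final (maximum) element directly, which is also why the binding becomes mut. A tiny illustration:

use itertools::Itertools;

fn main() {
    let mut it = vec![3, 1, 2].into_iter().sorted();
    assert_eq!(it.next_back(), Some(3)); // O(1) on a double-ended iterator
    assert_eq!(it.next(), Some(1)); // the rest is still available
}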
@@ -1540,50 +1571,28 @@ pub(crate) async fn create_evm_data_attestation(
let data =
GraphData::from_str(&input).unwrap_or_else(|_| GraphData::new(DataSource::File(vec![])));
debug!("data attestation data: {:?}", data);
// The number of input and output instances we attest to for the single call data attestation
let mut input_len = None;
let mut output_len = None;
let output_data = if let Some(DataSource::OnChain(source)) = data.output_data {
if let Some(DataSource::OnChain(source)) = data.output_data {
if visibility.output.is_private() {
return Err("private output data is not supported on chain".into());
}
let mut on_chain_output_data = vec![];
match source.calls {
Calls::Multiple(calls) => {
for call in calls {
on_chain_output_data.push(call);
}
}
Calls::Single(call) => {
output_len = Some(call.len);
}
}
Some(on_chain_output_data)
} else {
None
output_len = Some(source.call.decimals.len());
};
let input_data = if let DataSource::OnChain(source) = data.input_data {
if let DataSource::OnChain(source) = data.input_data {
if visibility.input.is_private() {
return Err("private input data is not supported on chain".into());
}
let mut on_chain_input_data = vec![];
match source.calls {
Calls::Multiple(calls) => {
for call in calls {
on_chain_input_data.push(call);
}
}
Calls::Single(call) => {
input_len = Some(call.len);
}
}
Some(on_chain_input_data)
} else {
None
input_len = Some(source.call.decimals.len());
};
// If both model inputs and outputs are attested to, read the settings file and
// check whether run_args.input_visibility, run_args.output_visibility or
// run_args.param_visibility is KZGCommit; if so, we need to load the witness
@@ -1604,24 +1613,16 @@ pub(crate) async fn create_evm_data_attestation(
None
};
// if either input_len or output_len is Some then we are in the single call data attestation mode
if input_len.is_some() || output_len.is_some() {
let output = fix_da_single_sol(input_len, output_len)?;
let mut f = File::create(sol_code_path.clone())?;
let _ = f.write(output.as_bytes());
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestationSingle", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
} else {
let output = fix_da_multi_sol(input_data, output_data, commitment_bytes)?;
let mut f = File::create(sol_code_path.clone())?;
let _ = f.write(output.as_bytes());
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestationMulti", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
}
let output: String = fix_da_sol(
commitment_bytes,
input_len.is_none() && output_len.is_none(),
)?;
let mut f = File::create(sol_code_path.clone())?;
let _ = f.write(output.as_bytes());
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestation", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
Ok(String::new())
}
@@ -1630,7 +1631,7 @@ pub(crate) async fn deploy_da_evm(
data: String,
settings_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
rpc_url: String,
addr_path: PathBuf,
runs: usize,
private_key: Option<String>,
@@ -1639,7 +1640,7 @@ pub(crate) async fn deploy_da_evm(
settings_path,
data,
sol_code_path,
rpc_url.as_deref(),
&rpc_url,
runs,
private_key.as_deref(),
)
@@ -1654,7 +1655,7 @@ pub(crate) async fn deploy_da_evm(
pub(crate) async fn deploy_evm(
sol_code_path: PathBuf,
rpc_url: Option<String>,
rpc_url: String,
addr_path: PathBuf,
runs: usize,
private_key: Option<String>,
@@ -1667,7 +1668,7 @@ pub(crate) async fn deploy_evm(
};
let contract_address = deploy_contract_via_solidity(
sol_code_path,
rpc_url.as_deref(),
&rpc_url,
runs,
private_key.as_deref(),
contract_name,
@@ -1710,7 +1711,7 @@ pub(crate) fn encode_evm_calldata(
pub(crate) async fn verify_evm(
proof_path: PathBuf,
addr_verifier: H160Flag,
rpc_url: Option<String>,
rpc_url: String,
addr_da: Option<H160Flag>,
addr_vk: Option<H160Flag>,
) -> Result<String, EZKLError> {
@@ -1724,7 +1725,7 @@ pub(crate) async fn verify_evm(
addr_verifier.into(),
addr_da.into(),
addr_vk.map(|s| s.into()),
rpc_url.as_deref(),
&rpc_url,
)
.await?
} else {
@@ -1732,7 +1733,7 @@ pub(crate) async fn verify_evm(
proof.clone(),
addr_verifier.into(),
addr_vk.map(|s| s.into()),
rpc_url.as_deref(),
&rpc_url,
)
.await?
};
@@ -1869,11 +1870,11 @@ pub(crate) fn setup(
Ok(String::new())
}
pub(crate) async fn setup_test_evm_witness(
pub(crate) async fn setup_test_evm_data(
data_path: String,
compiled_circuit_path: PathBuf,
test_data: PathBuf,
rpc_url: Option<String>,
rpc_url: String,
input_source: TestDataSource,
output_source: TestDataSource,
) -> Result<String, EZKLError> {
@@ -1905,17 +1906,6 @@ pub(crate) async fn setup_test_evm_witness(
}
use crate::pfsys::ProofType;
pub(crate) async fn test_update_account_calls(
addr: H160Flag,
data: String,
rpc_url: Option<String>,
) -> Result<String, EZKLError> {
use crate::eth::update_account_calls;
update_account_calls(addr.into(), data, rpc_url.as_deref()).await?;
Ok(String::new())
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn prove(


@@ -67,8 +67,11 @@ pub enum GraphError {
#[error("invalid input types")]
InvalidInputTypes,
/// Missing results
#[error("missing results")]
MissingResults,
#[error("missing result for node {0}")]
MissingResults(usize),
/// Missing input
#[error("missing input {0}")]
MissingInputForNode(usize),
/// Tensor error
#[error("[tensor] {0}")]
TensorError(#[from] crate::tensor::TensorError),
@@ -98,8 +101,6 @@ pub enum GraphError {
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
#[error("[tokio postgres] {0}")]
TokioPostgresError(#[from] tokio_postgres::Error),
/// Eth error
#[cfg(all(
feature = "ezkl",
@@ -141,7 +142,9 @@ pub enum GraphError {
#[error("range check {0} is too large")]
RangeCheckTooLarge(usize),
///Cannot use on-chain data source as private data
#[error("cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm.")]
#[error(
"cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm."
)]
OnChainDataSource,
/// Missing data source
#[error("missing data source")]


@@ -2,10 +2,6 @@ use super::errors::GraphError;
use super::quantize_float;
use crate::circuit::InputType;
use crate::fieldutils::integer_rep_to_felt;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::postgres::Client;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::tensor::Tensor;
use crate::EZKL_BUF_CAPACITY;
use halo2curves::bn256::Fr as Fp;
#[cfg(feature = "python-bindings")]
@@ -25,9 +21,6 @@ use tract_onnx::tract_core::{
value::TValue,
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::tract_num_traits::ToPrimitive;
type Decimals = u8;
type Call = String;
type RPCUrl = String;
@@ -168,85 +161,26 @@ impl<'de> Deserialize<'de> for FileSourceInner {
/// Organized as a vector of vectors where each inner vector represents a row/entry
pub type FileSource = Vec<Vec<FileSourceInner>>;
/// Represents different types of calls for fetching on-chain data
#[derive(Clone, Debug, PartialOrd, PartialEq)]
pub enum Calls {
/// Multiple calls to different accounts, each returning individual values
Multiple(Vec<CallsToAccount>),
/// Single call returning an array of values
Single(CallToAccount),
}
/// Represents which parts of the model (input/output) are attested to on-chain
pub type InputOutput = (bool, bool);
impl Default for Calls {
fn default() -> Self {
Calls::Multiple(Vec::new())
}
}
impl Serialize for Calls {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match self {
Calls::Single(data) => data.serialize(serializer),
Calls::Multiple(data) => data.serialize(serializer),
}
}
}
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
impl<'de> Deserialize<'de> for Calls {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
let multiple_try: Result<Vec<CallsToAccount>, _> = serde_json::from_str(this_json.get());
if let Ok(t) = multiple_try {
return Ok(Calls::Multiple(t));
}
let single_try: Result<CallToAccount, _> = serde_json::from_str(this_json.get());
if let Ok(t) = single_try {
return Ok(Calls::Single(t));
}
Err(serde::de::Error::custom("failed to deserialize Calls"))
}
}
/// Configuration for accessing on-chain data sources
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct OnChainSource {
/// Call specifications for fetching data
pub calls: Calls,
pub call: CallToAccount,
/// RPC endpoint URL for accessing the chain
pub rpc: RPCUrl,
}
impl OnChainSource {
/// Creates a new OnChainSource with multiple calls
///
/// # Arguments
/// * `calls` - Vector of call specifications
/// * `rpc` - RPC endpoint URL
pub fn new_multiple(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
OnChainSource {
calls: Calls::Multiple(calls),
rpc,
}
}
/// Creates a new OnChainSource with a single call
/// Creates a new OnChainSource
///
/// # Arguments
/// * `call` - Call specification
/// * `rpc` - RPC endpoint URL
pub fn new_single(call: CallToAccount, rpc: RPCUrl) -> Self {
OnChainSource {
calls: Calls::Single(call),
rpc,
}
pub fn new(call: CallToAccount, rpc: RPCUrl) -> Self {
OnChainSource { call, rpc }
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
@@ -262,12 +196,9 @@ impl OnChainSource {
data: &FileSource,
scales: Vec<crate::Scale>,
mut shapes: Vec<Vec<usize>>,
rpc: Option<&str>,
) -> Result<(Vec<Tensor<Fp>>, Self), GraphError> {
use crate::eth::{
evm_quantize_multi, read_on_chain_inputs_multi, test_on_chain_data,
DEFAULT_ANVIL_ENDPOINT,
};
rpc: &str,
) -> Result<Self, GraphError> {
use crate::eth::{read_on_chain_inputs, test_on_chain_data};
use log::debug;
// Set up local anvil instance for reading on-chain data
@@ -281,46 +212,15 @@ impl OnChainSource {
shapes[idx] = vec![i.len()];
}
}
let used_rpc = rpc.to_string();
let calls_to_accounts = test_on_chain_data(client.clone(), data).await?;
debug!("Calls to accounts: {:?}", calls_to_accounts);
let inputs =
read_on_chain_inputs_multi(client.clone(), client_address, &calls_to_accounts).await?;
let call_to_account = test_on_chain_data(client.clone(), data).await?;
debug!("Call to account: {:?}", call_to_account);
let inputs = read_on_chain_inputs(client.clone(), client_address, &call_to_account).await?;
debug!("Inputs: {:?}", inputs);
let mut quantized_evm_inputs = vec![];
let mut prev = 0;
for (idx, i) in data.iter().enumerate() {
quantized_evm_inputs.extend(
evm_quantize_multi(
client.clone(),
vec![scales[idx]; i.len()],
&(
inputs.0[prev..i.len()].to_vec(),
inputs.1[prev..i.len()].to_vec(),
),
)
.await?,
);
prev += i.len();
}
// on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector
let mut inputs: Vec<Tensor<Fp>> = vec![];
for (input, shape) in [quantized_evm_inputs].iter().zip(shapes) {
let mut t: Tensor<Fp> = input.iter().cloned().collect();
t.reshape(&shape)?;
inputs.push(t);
}
let used_rpc = rpc.unwrap_or(DEFAULT_ANVIL_ENDPOINT).to_string();
// Fill the input_data field of the GraphData struct
Ok((
inputs,
OnChainSource::new_multiple(calls_to_accounts.clone(), used_rpc),
))
Ok(OnChainSource::new(call_to_account, used_rpc))
}
}
@@ -342,11 +242,9 @@ pub struct CallToAccount {
/// ABI-encoded function call data
pub call_data: Call,
/// Number of decimal places for float conversion
pub decimals: Decimals,
pub decimals: Vec<Decimals>,
/// Contract address to call
pub address: String,
/// Expected length of returned array
pub len: usize,
}
/// Represents different sources of input/output data for the EZKL model
@@ -357,9 +255,6 @@ pub enum DataSource {
File(FileSource),
/// Data fetched from blockchain contracts
OnChain(OnChainSource),
/// Data from a PostgreSQL database
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
DB(PostgresSource),
}
impl Default for DataSource {
@@ -420,15 +315,6 @@ impl<'de> Deserialize<'de> for DataSource {
return Ok(DataSource::OnChain(t));
}
// Try deserializing as PostgresSource if feature enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
{
let third_try: Result<PostgresSource, _> = serde_json::from_str(this_json.get());
if let Ok(t) = third_try {
return Ok(DataSource::DB(t));
}
}
Err(serde::de::Error::custom("failed to deserialize DataSource"))
}
}
@@ -478,7 +364,7 @@ impl GraphData {
return Err(GraphError::InvalidDims(
0,
"non file data cannot be split into batches".to_string(),
))
));
}
}
Ok(inputs)
@@ -531,17 +417,15 @@ impl GraphData {
/// Loads graph input data from a string, first seeing if it is a file path or JSON data
/// If it is a file path, it will load the data from the file
/// Otherwise, it will attempt to parse the string as JSON data
///
///
/// # Arguments
/// * `data` - String containing the input data
/// # Returns
/// A new GraphData instance containing the loaded data
pub fn from_str(data: &str) -> Result<Self, GraphError> {
let graph_input = serde_json::from_str(data);
let graph_input = serde_json::from_str(data);
match graph_input {
Ok(graph_input) => {
return Ok(graph_input);
}
Ok(graph_input) => Ok(graph_input),
Err(_) => {
let path = std::path::PathBuf::from(data);
GraphData::from_path(path)
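The restructured from_str tries to parse the argument as JSON first and only then treats it as a file path. A standalone sketch of the same fallback, with serde_json::Value standing in for GraphData:

fn load(data: &str) -> Result<serde_json::Value, String> {
    match serde_json::from_str(data) {
        // The argument is a JSON payload itself.
        Ok(v) => Ok(v),
        // Otherwise treat it as a path and parse the file's contents.
        Err(_) => {
            let text = std::fs::read_to_string(data)
                .map_err(|e| format!("not JSON and not a readable file: {}", e))?;
            serde_json::from_str(&text).map_err(|e| e.to_string())
        }
    }
}

fn main() {
    assert!(load(r#"{"input_data": [[1.0, 2.0]]}"#).is_ok());
    assert!(load("/definitely/missing.json").is_err());
}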
@@ -612,13 +496,8 @@ impl GraphData {
return Err(GraphError::InvalidDims(
0,
"on-chain data cannot be split into batches".to_string(),
))
));
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
GraphData {
input_data: DataSource::DB(data),
output_data: _,
} => data.fetch_and_format_as_file().await?,
};
// Process each input tensor according to its shape
@@ -635,7 +514,6 @@ impl GraphData {
input.len(),
input_size
),
));
}
@@ -683,45 +561,12 @@ impl GraphData {
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for CallsToAccount {
/// Converts CallsToAccount to Python object
fn to_object(&self, py: Python) -> PyObject {
let dict = PyDict::new(py);
dict.set_item("account", &self.address).unwrap();
dict.set_item("call_data", &self.call_data).unwrap();
dict.to_object(py)
}
}
// Additional Python bindings for various types...
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_postgres_source_new() {
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
{
let source = PostgresSource::new(
"localhost".to_string(),
"5432".to_string(),
"user".to_string(),
"SELECT * FROM table".to_string(),
"database".to_string(),
"password".to_string(),
);
assert_eq!(source.host, "localhost");
assert_eq!(source.port, "5432");
assert_eq!(source.user, "user");
assert_eq!(source.query, "SELECT * FROM table");
assert_eq!(source.dbname, "database");
assert_eq!(source.password, "password");
}
}
#[test]
fn test_data_source_serialization_round_trip() {
// Test backwards compatibility with old format
@@ -764,95 +609,6 @@ mod tests {
}
}
/// Source data from a PostgreSQL database
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct PostgresSource {
/// Database host address
pub host: RPCUrl,
/// Database user name
pub user: String,
/// Database password
pub password: String,
/// SQL query to execute
pub query: String,
/// Database name
pub dbname: String,
/// Database port
pub port: String,
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl PostgresSource {
/// Creates a new PostgreSQL data source
pub fn new(
host: RPCUrl,
port: String,
user: String,
query: String,
dbname: String,
password: String,
) -> Self {
PostgresSource {
host,
user,
password,
query,
dbname,
port,
}
}
/// Fetches data from the PostgreSQL database
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
// Configuration string
let config = if self.password.is_empty() {
format!(
"host={} user={} dbname={} port={}",
self.host, self.user, self.dbname, self.port
)
} else {
format!(
"host={} user={} dbname={} port={} password={}",
self.host, self.user, self.dbname, self.port, self.password
)
};
let mut client = Client::connect(&config).await?;
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// Extract rows from query
for row in client.query(&self.query, &[]).await? {
for i in 0..row.len() {
res.push(row.get(i));
}
}
Ok(vec![res])
}
/// Fetches and formats data as FileSource
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
Ok(self
.fetch()
.await?
.iter()
.map(|d| {
d.iter()
.map(|d| {
FileSourceInner::Float(
d.n.as_ref()
.unwrap()
.to_f64()
.ok_or("could not convert decimal to f64")
.unwrap(),
)
})
.collect()
})
.collect())
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for CallToAccount {
fn to_object(&self, py: Python) -> PyObject {
@@ -860,21 +616,10 @@ impl ToPyObject for CallToAccount {
dict.set_item("account", &self.address).unwrap();
dict.set_item("call_data", &self.call_data).unwrap();
dict.set_item("decimals", &self.decimals).unwrap();
dict.set_item("len", &self.len).unwrap();
dict.to_object(py)
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for Calls {
fn to_object(&self, py: Python) -> PyObject {
match self {
Calls::Multiple(calls) => calls.to_object(py),
Calls::Single(call) => call.to_object(py),
}
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for DataSource {
fn to_object(&self, py: Python) -> PyObject {
@@ -883,18 +628,10 @@ impl ToPyObject for DataSource {
DataSource::OnChain(source) => {
let dict = PyDict::new(py);
dict.set_item("rpc_url", &source.rpc).unwrap();
dict.set_item("calls_to_accounts", &source.calls.to_object(py))
dict.set_item("calls_to_accounts", &source.call.to_object(py))
.unwrap();
dict.to_object(py)
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
DataSource::DB(source) => {
let dict = PyDict::new(py);
dict.set_item("host", &source.host).unwrap();
dict.set_item("user", &source.user).unwrap();
dict.set_item("query", &source.query).unwrap();
dict.to_object(py)
}
}
}
}


@@ -6,9 +6,6 @@ pub mod model;
pub mod modules;
/// Inner elements of a computational graph that represent a single operation / constraints.
pub mod node;
/// postgres helper functions
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub mod postgres;
/// Helper functions
pub mod utilities;
/// Representations of a computational graph's variables.
@@ -574,7 +571,7 @@ impl GraphSettings {
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, &self).map_err(|e| {
error!("failed to save settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
std::io::Error::other(e)
})
}
/// load params from file
@@ -584,7 +581,7 @@ impl GraphSettings {
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
let settings: GraphSettings = serde_json::from_reader(reader).map_err(|e| {
error!("failed to load settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
std::io::Error::other(e)
})?;
crate::check_version_string_matches(&settings.version);
@@ -764,7 +761,7 @@ pub struct TestOnChainData {
/// The path to the test witness
pub data: std::path::PathBuf,
/// rpc endpoint
pub rpc: Option<String>,
pub rpc: String,
/// data sources for the on chain data
pub data_sources: TestSources,
}
@@ -1011,10 +1008,6 @@ impl GraphCircuit {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
DataSource::DB(pg) => {
let data = pg.fetch_and_format_as_file().await?;
self.load_file_data(&data, &shapes, scales, input_types)
}
}
}
@@ -1026,24 +1019,11 @@ impl GraphCircuit {
shapes: &Vec<Vec<usize>>,
scales: Vec<crate::Scale>,
) -> Result<Vec<Tensor<Fp>>, GraphError> {
use crate::eth::{
evm_quantize_multi, evm_quantize_single, read_on_chain_inputs_multi,
read_on_chain_inputs_single, setup_eth_backend,
};
let (client, client_address) = setup_eth_backend(Some(&source.rpc), None).await?;
let quantized_evm_inputs = match source.calls {
input::Calls::Single(call) => {
let (inputs, decimals) =
read_on_chain_inputs_single(client.clone(), client_address, call).await?;
evm_quantize_single(client, scales, &inputs, decimals).await?
}
input::Calls::Multiple(calls) => {
let inputs =
read_on_chain_inputs_multi(client.clone(), client_address, &calls).await?;
evm_quantize_multi(client, scales, &inputs).await?
}
};
use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend};
let (client, client_address) = setup_eth_backend(&source.rpc, None).await?;
let input = read_on_chain_inputs(client.clone(), client_address, &source.call).await?;
let quantized_evm_inputs =
evm_quantize(client, scales, &input, &source.call.decimals).await?;
// on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector
let mut inputs: Vec<Tensor<Fp>> = vec![];
for (input, shape) in [quantized_evm_inputs].iter().zip(shapes) {
@@ -1252,15 +1232,9 @@ impl GraphCircuit {
let mut cs = ConstraintSystem::default();
// if unix get a gag
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
let _r = Gag::stdout().ok();
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
};
let _g = Gag::stderr().ok();
Self::configure_with_params(&mut cs, settings);
@@ -1444,6 +1418,8 @@ impl GraphCircuit {
let output_scales = self.model().graph.get_output_scales()?;
let input_shapes = self.model().graph.input_shapes()?;
let output_shapes = self.model().graph.output_shapes()?;
let mut input_data = None;
let mut output_data = None;
if matches!(
test_on_chain_data.data_sources.input,
@@ -1454,23 +1430,12 @@ impl GraphCircuit {
return Err(GraphError::OnChainDataSource);
}
let input_data = match &data.input_data {
DataSource::File(input_data) => input_data,
input_data = match &data.input_data {
DataSource::File(input_data) => Some(input_data),
_ => {
return Err(GraphError::OnChainDataSource);
return Err(GraphError::MissingDataSource);
}
};
// Get the flatten length of input_data
// if the input source is a field then set scale to 0
let datam: (Vec<Tensor<Fp>>, OnChainSource) = OnChainSource::test_from_file_data(
input_data,
input_scales,
input_shapes,
test_on_chain_data.rpc.as_deref(),
)
.await?;
data.input_data = datam.1.into();
}
if matches!(
test_on_chain_data.data_sources.output,
@@ -1481,20 +1446,39 @@ impl GraphCircuit {
return Err(GraphError::OnChainDataSource);
}
let output_data = match &data.output_data {
Some(DataSource::File(output_data)) => output_data,
Some(DataSource::OnChain(_)) => return Err(GraphError::OnChainDataSource),
output_data = match &data.output_data {
Some(DataSource::File(output_data)) => Some(output_data),
_ => return Err(GraphError::MissingDataSource),
};
let datum: (Vec<Tensor<Fp>>, OnChainSource) = OnChainSource::test_from_file_data(
output_data,
output_scales,
output_shapes,
test_on_chain_data.rpc.as_deref(),
)
.await?;
data.output_data = Some(datum.1.into());
}
// Merge the input and output data
let mut file_data: Vec<Vec<input::FileSourceInner>> = vec![];
let mut scales: Vec<crate::Scale> = vec![];
let mut shapes: Vec<Vec<usize>> = vec![];
if let Some(input_data) = input_data {
file_data.extend(input_data.clone());
scales.extend(input_scales.clone());
shapes.extend(input_shapes.clone());
}
if let Some(output_data) = output_data {
file_data.extend(output_data.clone());
scales.extend(output_scales.clone());
shapes.extend(output_shapes.clone());
};
// print file data
debug!("file data: {:?}", file_data);
let on_chain_data: OnChainSource =
OnChainSource::test_from_file_data(&file_data, scales, shapes, &test_on_chain_data.rpc)
.await?;
// Here we update the GraphData struct with the on-chain data
if input_data.is_some() {
data.input_data = on_chain_data.clone().into();
}
if output_data.is_some() {
data.output_data = Some(on_chain_data.into());
}
debug!("test on-chain data: {:?}", data);
// Save the updated GraphData struct to the data_path
data.save(test_on_chain_data.data)?;
Ok(())
@@ -1586,13 +1570,13 @@ impl Circuit<Fp> for GraphCircuit {
let mut module_configs = ModuleConfigs::from_visibility(
cs,
params.module_sizes.clone(),
&params.module_sizes,
params.run_args.logrows as usize,
);
let mut vars = ModelVars::new(cs, &params);
module_configs.configure_complex_modules(cs, visibility, params.module_sizes.clone());
module_configs.configure_complex_modules(cs, &visibility, &params.module_sizes);
vars.instantiate_instance(
cs,


@@ -200,7 +200,7 @@ fn number_of_iterations(mappings: &[InputMapping], dims: Vec<&[usize]>) -> usize
InputMapping::Stacked { axis, chunk } => Some(
// number of iterations given the dim size along the axis
// and the chunk size
(dims[*axis] + chunk - 1) / chunk,
dims[*axis].div_ceil(*chunk), // (dims[*axis] + chunk - 1) / chunk,
),
_ => None,
});
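A worked check that the two forms agree, e.g. a dimension of 13 split into chunks of 4 needs ceil(13 / 4) = 4 iterations (usize::div_ceil is stable since Rust 1.73):

fn main() {
    let (dim, chunk) = (13usize, 4usize);
    assert_eq!((dim + chunk - 1) / chunk, 4); // manual round-up
    assert_eq!(dim.div_ceil(chunk), 4); // standard-library equivalent
}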
@@ -589,10 +589,7 @@ impl Model {
required_range_checks: res.range_checks.into_iter().collect(),
model_output_scales: self.graph.get_output_scales()?,
model_input_scales: self.graph.get_input_scales(),
input_types: match self.get_input_types() {
Ok(x) => Some(x),
Err(_) => None,
},
input_types: self.get_input_types().ok(),
output_types: Some(self.get_output_types()),
num_dynamic_lookups: res.num_dynamic_lookups,
total_dynamic_col_size: res.dynamic_lookup_col_coord,
@@ -650,10 +647,13 @@ impl Model {
let variables: std::collections::HashMap<String, usize> =
std::collections::HashMap::from_iter(variables.iter().map(|(k, v)| (k.clone(), *v)));
for (i, id) in model.clone().inputs.iter().enumerate() {
let inputs = model.inputs.clone();
let outputs = model.outputs.clone();
for (i, id) in inputs.iter().enumerate() {
let input = model.node_mut(id.node);
if input.outputs.len() == 0 {
if input.outputs.is_empty() {
return Err(GraphError::MissingOutput(id.node));
}
let mut fact: InferenceFact = input.outputs[0].fact.clone();
@@ -672,7 +672,7 @@ impl Model {
model.set_input_fact(i, fact)?;
}
for (i, _) in model.clone().outputs.iter().enumerate() {
for (i, _) in outputs.iter().enumerate() {
model.set_output_fact(i, InferenceFact::default())?;
}
@@ -1196,7 +1196,7 @@ impl Model {
.base
.layout(
&mut thread_safe_region,
&[output.clone(), comparators],
&[output, &comparators],
Box::new(HybridOp::Output {
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
@@ -1257,12 +1257,27 @@ impl Model {
node.inputs()
.iter()
.map(|(idx, outlet)| {
Ok(results.get(idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
// check node is not an output
let is_output = self.graph.outputs.iter().any(|(o_idx, _)| *idx == *o_idx);
let res = if self.graph.nodes[idx].num_uses() == 1 && !is_output {
let res = results.remove(idx);
res.ok_or(GraphError::MissingResults(*idx))?[*outlet].clone()
} else {
results.get(idx).ok_or(GraphError::MissingResults(*idx))?[*outlet]
.clone()
};
Ok(res)
})
.collect::<Result<Vec<_>, GraphError>>()?
} else {
// we re-assign inputs, always from the 0 outlet
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
if self.graph.nodes[idx].num_uses() == 1 {
let res = results.remove(idx);
vec![res.ok_or(GraphError::MissingInput(*idx))?[0].clone()]
} else {
vec![results.get(idx).ok_or(GraphError::MissingInput(*idx))?[0].clone()]
}
};
trace!("output dims: {:?}", node.out_dims());
trace!(
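The new branch moves a node's output out of the results map when this node is its only consumer (and not a graph output), rather than cloning it. A minimal sketch of the take-or-clone pattern, with a HashMap<usize, Vec<String>> standing in for the real results map:

use std::collections::HashMap;

fn take_or_clone(
    results: &mut HashMap<usize, Vec<String>>,
    idx: usize,
    num_uses: usize,
) -> Option<String> {
    if num_uses == 1 {
        // Last consumer: move the entry out of the map; nothing is cloned
        // and the memory is freed as soon as the value is dropped.
        results.remove(&idx).map(|mut outputs| outputs.swap_remove(0))
    } else {
        // Other consumers still need it, so borrow and clone.
        results.get(&idx).map(|outputs| outputs[0].clone())
    }
}

fn main() {
    let mut results = HashMap::from([(7usize, vec!["node-7 output".to_string()])]);
    assert_eq!(take_or_clone(&mut results, 7, 1).as_deref(), Some("node-7 output"));
    assert!(!results.contains_key(&7)); // consumed and freed early
}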
@@ -1273,7 +1288,7 @@ impl Model {
let start = instant::Instant::now();
match &node {
NodeType::Node(n) => {
let res = if node.is_constant() && node.num_uses() == 1 {
let mut res = if node.is_constant() && node.num_uses() == 1 {
log::debug!("node {} is a constant with 1 use", n.idx);
let mut node = n.clone();
let c = node
@@ -1284,19 +1299,19 @@ impl Model {
} else {
config
.base
.layout(region, &values, n.opkind.clone_dyn())
.layout(region, &values.iter().collect_vec(), n.opkind.clone_dyn())
.map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?
};
if let Some(mut vt) = res {
if let Some(vt) = &mut res {
vt.reshape(&node.out_dims()[0])?;
// we get the max as for fused nodes this corresponds to the node output
results.insert(*idx, vec![vt.clone()]);
//only use with mock prover
debug!("------------ output node {:?}: {:?}", idx, vt.show());
// we get the max as for fused nodes this corresponds to the node output
results.insert(*idx, vec![vt.clone()]);
}
}
NodeType::SubGraph {
@@ -1340,7 +1355,7 @@ impl Model {
.inputs
.clone()
.into_iter()
.zip(values.clone().into_iter().map(|v| vec![v])),
.zip(values.iter().map(|v| vec![v.clone()])),
);
let res = model.layout_nodes(config, region, &mut subgraph_results)?;
@@ -1421,7 +1436,7 @@ impl Model {
);
let outputs = output_nodes
.map(|(idx, outlet)| {
Ok(results.get(idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
Ok(results.get(idx).ok_or(GraphError::MissingResults(*idx))?[*outlet].clone())
})
.collect::<Result<Vec<_>, GraphError>>()?;
@@ -1476,7 +1491,7 @@ impl Model {
dummy_config.layout(
&mut region,
&[output.clone(), comparator],
&[output, &comparator],
Box::new(HybridOp::Output {
decomp: !run_args.ignore_range_check_inputs_outputs,
}),


@@ -37,15 +37,15 @@ impl ModuleConfigs {
/// Create new module configs from visibility of each variable
pub fn from_visibility(
cs: &mut ConstraintSystem<Fp>,
module_size: ModuleSizes,
module_size: &ModuleSizes,
logrows: usize,
) -> Self {
let mut config = Self::default();
for size in module_size.polycommit {
for size in &module_size.polycommit {
config
.polycommit
.push(PolyCommitChip::configure(cs, (logrows, size)));
.push(PolyCommitChip::configure(cs, (logrows, *size)));
}
config
@@ -55,8 +55,8 @@ impl ModuleConfigs {
pub fn configure_complex_modules(
&mut self,
cs: &mut ConstraintSystem<Fp>,
visibility: VarVisibility,
module_size: ModuleSizes,
visibility: &VarVisibility,
module_size: &ModuleSizes,
) {
if (visibility.input.is_hashed()
|| visibility.output.is_hashed()


@@ -37,6 +37,7 @@ use crate::tensor::TensorError;
// Import curve-specific field type
use halo2curves::bn256::Fr as Fp;
use itertools::Itertools;
// Import logging for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use log::trace;
@@ -118,16 +119,15 @@ impl Op<Fp> for Rescaled {
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
values: &[&crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
if self.scale.len() != values.len() {
return Err(TensorError::DimMismatch("rescaled inputs".to_string()).into());
}
let res =
&crate::circuit::layouts::rescale(config, region, values[..].try_into()?, &self.scale)?
[..];
self.inner.layout(config, region, res)
crate::circuit::layouts::rescale(config, region, values[..].try_into()?, &self.scale)?;
self.inner.layout(config, region, &res.iter().collect_vec())
}
/// Create a cloned boxed copy of this operation
@@ -274,13 +274,13 @@ impl Op<Fp> for RebaseScale {
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
values: &[&crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
let original_res = self
.inner
.layout(config, region, values)?
.ok_or(CircuitError::MissingLayout(self.as_string()))?;
self.rebase_op.layout(config, region, &[original_res])
self.rebase_op.layout(config, region, &[&original_res])
}
/// Create a cloned boxed copy of this operation
@@ -472,7 +472,7 @@ impl Op<Fp> for SupportedOp {
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
values: &[&crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
self.as_op().layout(config, region, values)
}


@@ -1,493 +0,0 @@
use log::{debug, error, info};
use std::fmt::Debug;
use std::net::IpAddr;
#[cfg(all(not(not(feature = "ezkl")), unix))]
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, pin::Pin};
use tokio::task::JoinHandle;
#[doc(inline)]
pub use tokio_postgres::config::{
ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs,
};
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::NoTls;
use tokio_postgres::{error::DbError, types::ToSql, Error, Row, Socket, ToStatement};
/// Connection configuration.
///
/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
///
///
#[derive(Clone)]
pub struct Config {
config: tokio_postgres::Config,
notice_callback: Arc<dyn Fn(DbError) + Send + Sync>,
}
impl fmt::Debug for Config {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Config")
.field("config", &self.config)
.finish()
}
}
impl Default for Config {
fn default() -> Config {
Config::new()
}
}
impl Config {
/// Creates a new configuration.
pub fn new() -> Config {
tokio_postgres::Config::new().into()
}
/// Sets the user to authenticate with.
///
/// If the user is not set, then this defaults to the user executing this process.
pub fn user(&mut self, user: &str) -> &mut Config {
self.config.user(user);
self
}
/// Gets the user to authenticate with, if one has been configured with
/// the `user` method.
pub fn get_user(&self) -> Option<&str> {
self.config.get_user()
}
/// Sets the password to authenticate with.
pub fn password<T>(&mut self, password: T) -> &mut Config
where
T: AsRef<[u8]>,
{
self.config.password(password);
self
}
/// Gets the password to authenticate with, if one has been configured with
/// the `password` method.
pub fn get_password(&self) -> Option<&[u8]> {
self.config.get_password()
}
/// Sets the name of the database to connect to.
///
/// Defaults to the user.
pub fn dbname(&mut self, dbname: &str) -> &mut Config {
self.config.dbname(dbname);
self
}
/// Gets the name of the database to connect to, if one has been configured
/// with the `dbname` method.
pub fn get_dbname(&self) -> Option<&str> {
self.config.get_dbname()
}
/// Sets command line options used to configure the server.
pub fn options(&mut self, options: &str) -> &mut Config {
self.config.options(options);
self
}
/// Gets the command line options used to configure the server, if the
/// options have been set with the `options` method.
pub fn get_options(&self) -> Option<&str> {
self.config.get_options()
}
/// Sets the value of the `application_name` runtime parameter.
pub fn application_name(&mut self, application_name: &str) -> &mut Config {
self.config.application_name(application_name);
self
}
/// Gets the value of the `application_name` runtime parameter, if it has
/// been set with the `application_name` method.
pub fn get_application_name(&self) -> Option<&str> {
self.config.get_application_name()
}
/// Sets the SSL configuration.
///
/// Defaults to `prefer`.
pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
self.config.ssl_mode(ssl_mode);
self
}
/// Gets the SSL configuration.
pub fn get_ssl_mode(&self) -> SslMode {
self.config.get_ssl_mode()
}
/// Adds a host to the configuration.
///
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
/// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets.
/// There must be either no hosts, or the same number of hosts as hostaddrs.
pub fn host(&mut self, host: &str) -> &mut Config {
self.config.host(host);
self
}
/// Gets the hosts that have been added to the configuration with `host`.
pub fn get_hosts(&self) -> &[Host] {
self.config.get_hosts()
}
/// Gets the hostaddrs that have been added to the configuration with `hostaddr`.
pub fn get_hostaddrs(&self) -> &[IpAddr] {
self.config.get_hostaddrs()
}
/// Adds a Unix socket host to the configuration.
///
/// Unlike `host`, this method allows non-UTF8 paths.
#[cfg(all(not(not(feature = "ezkl")), unix))]
pub fn host_path<T>(&mut self, host: T) -> &mut Config
where
T: AsRef<Path>,
{
self.config.host_path(host);
self
}
/// Adds a hostaddr to the configuration.
///
/// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order.
/// There must be either no hostaddrs, or the same number of hostaddrs as hosts.
pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
self.config.hostaddr(hostaddr);
self
}
/// Adds a port to the configuration.
///
/// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
/// case the default of 5432 is used, a single port, in which case it is used for all hosts, or the same number of ports
/// as hosts.
pub fn port(&mut self, port: u16) -> &mut Config {
self.config.port(port);
self
}
/// Gets the ports that have been added to the configuration with `port`.
pub fn get_ports(&self) -> &[u16] {
self.config.get_ports()
}
/// Sets the timeout applied to socket-level connection attempts.
///
/// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
/// host separately. Defaults to no limit.
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
self.config.connect_timeout(connect_timeout);
self
}
/// Gets the connection timeout, if one has been set with the
/// `connect_timeout` method.
pub fn get_connect_timeout(&self) -> Option<&Duration> {
self.config.get_connect_timeout()
}
/// Sets the TCP user timeout.
///
/// This is ignored for Unix domain socket connections. It is only supported on systems where
/// TCP_USER_TIMEOUT is available and will default to the system default if omitted or set to 0;
/// on other systems, it has no effect.
pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
self.config.tcp_user_timeout(tcp_user_timeout);
self
}
/// Gets the TCP user timeout, if one has been set with the
/// `tcp_user_timeout` method.
pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
self.config.get_tcp_user_timeout()
}
/// Controls the use of TCP keepalive.
///
/// This is ignored for Unix domain socket connections. Defaults to `true`.
pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
self.config.keepalives(keepalives);
self
}
/// Reports whether TCP keepalives will be used.
pub fn get_keepalives(&self) -> bool {
self.config.get_keepalives()
}
/// Sets the amount of idle time before a keepalive packet is sent on the connection.
///
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours.
pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
self.config.keepalives_idle(keepalives_idle);
self
}
/// Gets the configured amount of idle time before a keepalive packet will
/// be sent on the connection.
pub fn get_keepalives_idle(&self) -> Duration {
self.config.get_keepalives_idle()
}
/// Sets the time interval between TCP keepalive probes.
/// On Windows, this sets the value of the tcp_keepalive struct's keepaliveinterval field.
///
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
self.config.keepalives_interval(keepalives_interval);
self
}
/// Gets the time interval between TCP keepalive probes.
pub fn get_keepalives_interval(&self) -> Option<Duration> {
self.config.get_keepalives_interval()
}
/// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
///
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
self.config.keepalives_retries(keepalives_retries);
self
}
/// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
pub fn get_keepalives_retries(&self) -> Option<u32> {
self.config.get_keepalives_retries()
}
/// Sets the requirements of the session.
///
/// This can be used to connect to the primary server in a clustered database rather than one of the read-only
/// secondary servers. Defaults to `Any`.
pub fn target_session_attrs(
&mut self,
target_session_attrs: TargetSessionAttrs,
) -> &mut Config {
self.config.target_session_attrs(target_session_attrs);
self
}
/// Gets the requirements of the session.
pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
self.config.get_target_session_attrs()
}
/// Sets the channel binding behavior.
///
/// Defaults to `prefer`.
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
self.config.channel_binding(channel_binding);
self
}
/// Gets the channel binding behavior.
pub fn get_channel_binding(&self) -> ChannelBinding {
self.config.get_channel_binding()
}
/// Sets the host load balancing behavior.
///
/// Defaults to `disable`.
pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
self.config.load_balance_hosts(load_balance_hosts);
self
}
/// Gets the host load balancing behavior.
pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
self.config.get_load_balance_hosts()
}
/// Sets the notice callback.
///
/// This callback will be invoked with the contents of every
/// [`AsyncMessage::Notice`] that is received by the connection. Notices use
/// the same structure as errors, but they are not "errors" per se.
///
/// Notices are distinct from notifications, which are instead accessible
/// via the [`Notifications`] API.
///
/// [`AsyncMessage::Notice`]: tokio_postgres::AsyncMessage::Notice
/// [`Notifications`]: crate::Notifications
pub fn notice_callback<F>(&mut self, f: F) -> &mut Config
where
F: Fn(DbError) + Send + Sync + 'static,
{
self.notice_callback = Arc::new(f);
self
}
/// Opens a connection to a PostgreSQL database.
pub async fn connect(&self) -> Result<Client, Error> {
let (client, connection) = self.config.connect(NoTls).await?;
let connection = Connection::new(connection);
Ok(Client::new(client, connection))
}
}
impl FromStr for Config {
type Err = Error;
fn from_str(s: &str) -> Result<Config, Error> {
s.parse::<tokio_postgres::Config>().map(Config::from)
}
}
impl From<tokio_postgres::Config> for Config {
fn from(config: tokio_postgres::Config) -> Config {
Config {
config,
notice_callback: Arc::new(|notice| {
info!("{}: {}", notice.severity(), notice.message())
}),
}
}
}
#[allow(missing_debug_implementations, dead_code)]
/// An asynchronous PostgreSQL connection. We use this to keep the connection alive / keep it pinned so that it doesn't
/// get dropped.
pub struct Connection {
/// The underlying connection stream.
connection: Pin<Box<tokio_postgres::Connection<Socket, NoTlsStream>>>,
}
impl Connection {
/// Creates a new connection.
pub fn new(connection: tokio_postgres::Connection<Socket, NoTlsStream>) -> Self {
Connection {
connection: Box::pin(connection),
}
}
/// start the connection
pub async fn start(self) {
if let Err(e) = self.connection.await {
error!("connection error: {}", e);
}
}
}
#[allow(missing_debug_implementations, dead_code)]
/// An asynchronous PostgreSQL client.
pub struct Client {
connection: JoinHandle<()>,
client: tokio_postgres::Client,
}
impl Drop for Client {
fn drop(&mut self) {
let _ = self.close_inner();
}
}
impl Client {
pub(crate) fn new(client: tokio_postgres::Client, connection: Connection) -> Client {
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
let thread = tokio::spawn(async move {
connection.start().await;
});
Client {
client,
connection: thread,
}
}
/// A convenience function which parses a configuration string into a `Config` and then connects to the database.
///
/// See the documentation for [`Config`] for information about the connection syntax.
///
/// [`Config`]: config/struct.Config.html
pub async fn connect(params: &str) -> Result<Client, Error> {
debug!("Connecting to database with params: {}", params);
params.parse::<Config>()?.connect().await
}
/// Returns a new `Config` object which can be used to configure and connect to a database.
pub fn configure() -> Config {
Config::new()
}
/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
///
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
pub async fn execute<T>(
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<u64, Error>
where
T: ?Sized + ToStatement + Debug,
{
debug!("Executing query: {:?}", query);
self.client.execute(query, params).await
}
/// Executes a statement, returning the resulting rows.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Examples
///
pub async fn query<T>(
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Vec<Row>, Error>
where
T: ?Sized + ToStatement + Debug,
{
debug!("Executing query: {:?}", query);
self.client.query(query, params).await
}
/// Determines if the client's connection has already closed.
///
/// If this returns `true`, the client is no longer usable.
pub fn is_closed(&self) -> bool {
self.client.is_closed()
}
/// Closes the client's connection to the server.
///
/// This is equivalent to `Client`'s `Drop` implementation, except that it returns any error encountered to the
/// caller.
pub fn close(mut self) -> Result<(), Error> {
self.close_inner()
}
fn close_inner(&mut self) -> Result<(), Error> {
self.client.__private_api_close();
Ok(())
}
}
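With this bundled sync wrapper deleted, callers go to tokio-postgres directly. A hedged sketch of that, reproducing the wrapper's spawn-the-connection pattern; assumes Cargo deps `tokio = { version = "1", features = ["full"] }` and `tokio-postgres = "0.7"`, and the connection string is illustrative:

```rust
use tokio_postgres::NoTls;

#[tokio::main]
async fn main() -> Result<(), tokio_postgres::Error> {
    let (client, connection) =
        tokio_postgres::connect("host=localhost user=postgres", NoTls).await?;

    // The connection future drives the socket; park it on its own task,
    // just as the deleted `Client::new` did with its `JoinHandle`.
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {}", e);
        }
    });

    let rows = client.query("SELECT 1::INT4", &[]).await?;
    let one: i32 = rows[0].get(0);
    assert_eq!(one, 1);
    Ok(())
}
```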

View File

@@ -858,6 +858,7 @@ pub fn new_op_from_onnx(
SupportedOp::Hybrid(HybridOp::Recip {
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
eps: run_args.get_epsilon(),
})
}
@@ -903,6 +904,7 @@ pub fn new_op_from_onnx(
SupportedOp::Hybrid(HybridOp::Rsqrt {
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
eps: run_args.get_epsilon(),
})
}
"Exp" => SupportedOp::Nonlinear(LookupOp::Exp {
@@ -913,6 +915,7 @@ pub fn new_op_from_onnx(
if run_args.bounded_log_lookup {
SupportedOp::Hybrid(HybridOp::Ln {
scale: scale_to_multiplier(input_scales[0]).into(),
eps: run_args.get_epsilon(),
})
} else {
SupportedOp::Nonlinear(LookupOp::Ln {
@@ -1131,6 +1134,7 @@ pub fn new_op_from_onnx(
input_scale: scale_to_multiplier(in_scale).into(),
output_scale: scale_to_multiplier(max_scale).into(),
axes: softmax_op.axes.to_vec(),
eps: run_args.get_epsilon(),
})
}
"MaxPool" => {
@@ -1398,7 +1402,7 @@ pub fn new_op_from_onnx(
SupportedOp::Linear(PolyOp::Downsample {
axis: downsample_node.axis,
stride: downsample_node.stride as usize,
stride: downsample_node.stride,
modulo: downsample_node.modulo,
})
}

View File

@@ -29,8 +29,14 @@
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
use log::warn;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc as _;
#[global_allocator]
#[cfg(all(feature = "jemalloc", not(target_arch = "wasm32")))]
static GLOBAL: jemallocator::Jemalloc = jemallocator::Jemalloc;
#[global_allocator]
#[cfg(all(feature = "mimalloc", not(target_arch = "wasm32")))]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
/// Error type
// #[cfg_attr(not(feature = "ezkl"), derive(uniffi::Error))]
@@ -350,6 +356,16 @@ pub struct RunArgs {
arg(long, default_value = "false")
)]
pub ignore_range_check_inputs_outputs: bool,
/// Optional override for epsilon value
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long))]
pub epsilon: Option<f64>,
}
impl RunArgs {
/// Returns the epsilon value
pub fn get_epsilon(&self) -> f64 {
self.epsilon.unwrap_or(f64::EPSILON)
}
}
impl Default for RunArgs {
@@ -376,6 +392,7 @@ impl Default for RunArgs {
decomp_base: 16384,
decomp_legs: 2,
ignore_range_check_inputs_outputs: false,
epsilon: None,
}
}
}
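A pure-std sketch of the override plumbing added above: the flag is optional, and every consumer goes through one getter, so the default (`f64::EPSILON`) lives in a single place:

```rust
#[derive(Default)]
struct RunArgs {
    /// Optional override for the epsilon used by recip/rsqrt/ln/softmax
    epsilon: Option<f64>,
}

impl RunArgs {
    fn get_epsilon(&self) -> f64 {
        self.epsilon.unwrap_or(f64::EPSILON)
    }
}

fn main() {
    assert_eq!(RunArgs::default().get_epsilon(), f64::EPSILON);
    let overridden = RunArgs { epsilon: Some(1e-6) };
    assert_eq!(overridden.get_epsilon(), 1e-6);
}
```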

View File

@@ -17,16 +17,16 @@ use crate::{Commitments, EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
use clap::ValueEnum;
use halo2_proofs::circuit::Value;
use halo2_proofs::plonk::{
Circuit, ProvingKey, VerifyingKey, create_proof, keygen_pk, keygen_vk_custom, verify_proof,
create_proof, keygen_pk, keygen_vk_custom, verify_proof, Circuit, ProvingKey, VerifyingKey,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier};
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer};
use halo2curves::CurveAffine;
use halo2curves::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use halo2curves::serde::SerdeObject;
use halo2curves::CurveAffine;
use instant::Instant;
use log::{debug, info, trace};
#[cfg(not(feature = "det-prove"))]
@@ -324,7 +324,7 @@ where
}
#[cfg(feature = "python-bindings")]
use pyo3::{PyObject, Python, ToPyObject, types::PyDict};
use pyo3::{types::PyDict, PyObject, Python, ToPyObject};
#[cfg(feature = "python-bindings")]
impl<F: PrimeField + SerdeObject + Serialize, C: CurveAffine + Serialize> ToPyObject for Snark<F, C>
where
@@ -348,9 +348,9 @@ where
}
impl<
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
C: CurveAffine + Serialize + DeserializeOwned,
> Snark<F, C>
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
C: CurveAffine + Serialize + DeserializeOwned,
> Snark<F, C>
where
C::Scalar: Serialize + DeserializeOwned,
C::ScalarExt: Serialize + DeserializeOwned,

View File

@@ -27,7 +27,7 @@ pub use var::*;
use crate::{
circuit::utils,
fieldutils::{IntegerRep, integer_rep_to_felt},
fieldutils::{integer_rep_to_felt, IntegerRep},
graph::Visibility,
};
@@ -42,11 +42,11 @@ use std::error::Error;
use std::fmt::Debug;
use std::io::Read;
use std::iter::Iterator;
use std::ops::Rem;
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
use std::{cmp::max, ops::Rem};
/// The (inner) type of tensor elements.
pub trait TensorType: Clone + Debug + 'static {
pub trait TensorType: Clone + Debug {
/// Returns the zero value.
fn zero() -> Option<Self> {
None
@@ -55,10 +55,6 @@ pub trait TensorType: Clone + Debug + 'static {
fn one() -> Option<Self> {
None
}
/// Max operator for ordering values.
fn tmax(&self, _: &Self) -> Option<Self> {
None
}
}
macro_rules! tensor_type {
@@ -70,10 +66,6 @@ macro_rules! tensor_type {
fn one() -> Option<Self> {
Some($one)
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some(max(*self, *other))
}
}
};
}
@@ -82,46 +74,12 @@ impl TensorType for f32 {
fn zero() -> Option<Self> {
Some(0.0)
}
// f32 doesn't impl Ord so we can't just use max like we can for IntegerRep, usize.
// A comparison between f32s needs to handle NAN values.
fn tmax(&self, other: &Self) -> Option<Self> {
match (self.is_nan(), other.is_nan()) {
(true, true) => Some(f32::NAN),
(true, false) => Some(*other),
(false, true) => Some(*self),
(false, false) => {
if self >= other {
Some(*self)
} else {
Some(*other)
}
}
}
}
}
impl TensorType for f64 {
fn zero() -> Option<Self> {
Some(0.0)
}
// f64 doesn't impl Ord so we can't just use max like we can for IntegerRep, usize.
// A comparison between f64s needs to handle NAN values.
fn tmax(&self, other: &Self) -> Option<Self> {
match (self.is_nan(), other.is_nan()) {
(true, true) => Some(f64::NAN),
(true, false) => Some(*other),
(false, true) => Some(*self),
(false, false) => {
if self >= other {
Some(*self)
} else {
Some(*other)
}
}
}
}
}
tensor_type!(bool, Bool, false, true);
@@ -147,14 +105,6 @@ impl<T: TensorType> TensorType for Value<T> {
fn one() -> Option<Self> {
Some(Value::known(T::one().unwrap()))
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some(
(self.clone())
.zip(other.clone())
.map(|(a, b)| a.tmax(&b).unwrap()),
)
}
}
impl<F: PrimeField + PartialOrd> TensorType for Assigned<F>
@@ -168,14 +118,6 @@ where
fn one() -> Option<Self> {
Some(F::ONE.into())
}
fn tmax(&self, other: &Self) -> Option<Self> {
if self.evaluate() >= other.evaluate() {
Some(*self)
} else {
Some(*other)
}
}
}
impl<F: PrimeField> TensorType for Expression<F>
@@ -189,42 +131,14 @@ where
fn one() -> Option<Self> {
Some(Expression::Constant(F::ONE))
}
fn tmax(&self, _: &Self) -> Option<Self> {
todo!()
}
}
impl TensorType for Column<Advice> {}
impl TensorType for Column<Fixed> {}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<Assigned<F>, F> {
fn tmax(&self, other: &Self) -> Option<Self> {
let mut output: Option<Self> = None;
self.value_field().zip(other.value_field()).map(|(a, b)| {
if a.evaluate() >= b.evaluate() {
output = Some(self.clone());
} else {
output = Some(other.clone());
}
});
output
}
}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<Assigned<F>, F> {}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<F, F> {
fn tmax(&self, other: &Self) -> Option<Self> {
let mut output: Option<Self> = None;
self.value().zip(other.value()).map(|(a, b)| {
if a >= b {
output = Some(self.clone());
} else {
output = Some(other.clone());
}
});
output
}
}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<F, F> {}
// specific types
impl TensorType for halo2curves::pasta::Fp {
@@ -235,10 +149,6 @@ impl TensorType for halo2curves::pasta::Fp {
fn one() -> Option<Self> {
Some(halo2curves::pasta::Fp::one())
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some((*self).max(*other))
}
}
impl TensorType for halo2curves::bn256::Fr {
@@ -249,9 +159,15 @@ impl TensorType for halo2curves::bn256::Fr {
fn one() -> Option<Self> {
Some(halo2curves::bn256::Fr::one())
}
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some((*self).max(*other))
impl<F: TensorType> TensorType for &F {
fn zero() -> Option<Self> {
None
}
fn one() -> Option<Self> {
None
}
}
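Dropping the `'static` bound (and the `tmax` method) is what makes the new blanket impl work: `&F` is not `'static` for borrowed data, but it is `Clone + Debug`, so tensors of references can satisfy the same bounds as owned tensors. A minimal sketch of the pattern, with hypothetical stand-in names:

```rust
trait TensorType: Clone + std::fmt::Debug {
    fn zero() -> Option<Self> {
        None
    }
    fn one() -> Option<Self> {
        None
    }
}

impl TensorType for i32 {
    fn zero() -> Option<Self> {
        Some(0)
    }
    fn one() -> Option<Self> {
        Some(1)
    }
}

// References are Clone + Debug but have no meaningful zero/one, so the
// defaults stay `None` -- the same choice the hunk's blanket impl makes.
impl<F: TensorType> TensorType for &F {}

fn describe<T: TensorType>(xs: &[T]) -> usize {
    xs.len()
}

fn main() {
    let owned = vec![1, 2, 3];
    let borrowed: Vec<&i32> = owned.iter().collect();
    // both owned and borrowed element types satisfy the same bound
    assert_eq!(describe(&owned), describe(&borrowed));
}
```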
@@ -374,7 +290,7 @@ impl<T: Clone + TensorType + std::marker::Send + std::marker::Sync>
}
}
impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync>
impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync + 'data>
maybe_rayon::iter::IntoParallelRefMutIterator<'data> for Tensor<T>
{
type Iter = maybe_rayon::slice::IterMut<'data, T>;
@@ -423,6 +339,14 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
}
}
impl<T: Clone + TensorType> Tensor<&T> {
/// Clones the tensor values into a new tensor.
pub fn cloned(&self) -> Tensor<T> {
let inner = self.inner.clone().into_iter().cloned().collect::<Vec<T>>();
Tensor::new(Some(&inner), &self.dims).unwrap()
}
}
impl<T: Clone + TensorType> Tensor<T> {
/// Sets (copies) the tensor values to the provided ones.
pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result<Self, TensorError> {
@@ -554,7 +478,6 @@ impl<T: Clone + TensorType> Tensor<T> {
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(4, 0).unwrap(), expected);
///
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(9, 0).unwrap(), expected);
/// ```
@@ -631,23 +554,23 @@ impl<T: Clone + TensorType> Tensor<T> {
// Fill remaining dimensions
full_indices.extend((indices.len()..self.dims.len()).map(|i| 0..self.dims[i]));
// Pre-calculate total size and allocate result vector
let total_size: usize = full_indices
.iter()
.map(|range| range.end - range.start)
.product();
let mut res = Vec::with_capacity(total_size);
// Calculate new dimensions once
let dims: Vec<usize> = full_indices.iter().map(|e| e.end - e.start).collect();
// Use iterator directly without collecting into intermediate Vec
for coord in full_indices.iter().cloned().multi_cartesian_product() {
let index = self.get_index(&coord);
res.push(self[index].clone());
}
let mut output = Tensor::new(None, &dims)?;
Tensor::new(Some(&res), &dims)
let cartesian_coord: Vec<Vec<usize>> = full_indices
.iter()
.cloned()
.multi_cartesian_product()
.collect();
output.par_iter_mut().enumerate().for_each(|(i, e)| {
let coord = &cartesian_coord[i];
*e = self.get(coord);
});
Ok(output)
}
/// Set a slice of the Tensor.
@@ -753,7 +676,7 @@ impl<T: Clone + TensorType> Tensor<T> {
/// ```
pub fn get_every_n(&self, n: usize) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
for (i, elem) in self.inner.clone().into_iter().enumerate() {
for (i, elem) in self.inner.iter().enumerate() {
if i % n == 0 {
inner.push(elem.clone());
}
@@ -776,7 +699,7 @@ impl<T: Clone + TensorType> Tensor<T> {
/// ```
pub fn exclude_every_n(&self, n: usize) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
for (i, elem) in self.inner.clone().into_iter().enumerate() {
for (i, elem) in self.inner.iter().enumerate() {
if i % n != 0 {
inner.push(elem.clone());
}
@@ -812,9 +735,9 @@ impl<T: Clone + TensorType> Tensor<T> {
let mut inner: Vec<T> = Vec::with_capacity(self.inner.len());
let mut offset = initial_offset;
for (i, elem) in self.inner.clone().into_iter().enumerate() {
for (i, elem) in self.inner.iter().enumerate() {
if (i + offset + 1) % n == 0 {
inner.extend(vec![elem; 1 + num_repeats]);
inner.extend(vec![elem.clone(); 1 + num_repeats]);
offset += num_repeats;
} else {
inner.push(elem.clone());
@@ -871,16 +794,16 @@ impl<T: Clone + TensorType> Tensor<T> {
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 6]), &[4]).unwrap();
/// let mut indices = vec![3, 4];
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), expected);
/// assert_eq!(a.remove_indices(&mut indices, false).unwrap(), expected);
///
///
/// let a = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, 1, -335, 244, 188, -4, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[82]).unwrap();
/// let b = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, -335, 244, 188, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[80]).unwrap();
/// let mut indices = vec![63, 67];
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), b);
/// assert_eq!(a.remove_indices(&mut indices, false).unwrap(), b);
/// ```
pub fn remove_indices(
&self,
@@ -927,7 +850,7 @@ impl<T: Clone + TensorType> Tensor<T> {
}
self.dims = vec![];
}
if self.dims() == &[0] && new_dims.iter().product::<usize>() == 1 {
if self.dims() == [0] && new_dims.iter().product::<usize>() == 1 {
self.dims = Vec::from(new_dims);
} else {
let product = if new_dims != [0] {
@@ -1292,33 +1215,6 @@ impl<T: Clone + TensorType> Tensor<T> {
Tensor::new(Some(&[res]), &[1])
}
/// Maps a function to tensors and enumerates in parallel
/// ```
/// use ezkl::tensor::{Tensor, TensorError};
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
/// ```
pub fn par_enum_map_mut_filtered<
F: Fn(usize) -> Result<T, E> + std::marker::Send + std::marker::Sync,
E: Error + std::marker::Send + std::marker::Sync,
>(
&mut self,
filter_indices: &std::collections::HashSet<usize>,
f: F,
) -> Result<(), E>
where
T: std::marker::Send + std::marker::Sync,
{
self.inner
.par_iter_mut()
.enumerate()
.filter(|(i, _)| filter_indices.contains(i))
.for_each(move |(i, e)| *e = f(i).unwrap());
Ok(())
}
}
impl<T: Clone + TensorType> Tensor<Tensor<T>> {
@@ -1335,9 +1231,9 @@ impl<T: Clone + TensorType> Tensor<Tensor<T>> {
pub fn combine(&self) -> Result<Tensor<T>, TensorError> {
let mut dims = 0;
let mut inner = Vec::new();
for t in self.inner.clone().into_iter() {
for t in self.inner.iter() {
dims += t.len();
inner.extend(t.inner);
inner.extend(t.inner.clone());
}
Tensor::new(Some(&inner), &[dims])
}
@@ -1808,7 +1704,7 @@ impl DataFormat {
match self {
DataFormat::NHWC => DataFormat::NCHW,
DataFormat::HWC => DataFormat::CHW,
_ => self.clone(),
_ => *self,
}
}
@@ -1908,7 +1804,7 @@ impl KernelFormat {
match self {
KernelFormat::HWIO => KernelFormat::OIHW,
KernelFormat::OHWI => KernelFormat::OIHW,
_ => self.clone(),
_ => *self,
}
}
@@ -2030,6 +1926,9 @@ mod tests {
fn tensor_slice() {
let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
let b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
assert_eq!(
a.get_slice(&[0..2, 0..1]).unwrap(),
b.get_slice(&[0..2, 0..1]).unwrap()
);
}
}
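A standalone sketch of the `Tensor<&T> -> Tensor<T>` adapter added in this file: with the `&[&ValTensor]` calling convention, ops often end up holding tensors of references and want one explicit clone point. Simplified stand-in `Tensor` type, illustrative only:

```rust
#[derive(Debug)]
struct Tensor<T> {
    inner: Vec<T>,
    dims: Vec<usize>,
}

impl<T: Clone> Tensor<&T> {
    /// Clones the referenced values into an owned tensor (mirrors `cloned`).
    fn cloned(&self) -> Tensor<T> {
        let inner = self.inner.clone().into_iter().cloned().collect::<Vec<T>>();
        Tensor {
            inner,
            dims: self.dims.clone(),
        }
    }
}

fn main() {
    let owned = vec![1, 2, 3];
    let view = Tensor {
        inner: owned.iter().collect::<Vec<&i32>>(),
        dims: vec![3],
    };
    assert_eq!(view.cloned().inner, owned);
}
```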

View File

@@ -160,7 +160,7 @@ pub fn decompose(
///
/// let result = trilu(&a, 0, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 3, 4, 5, 6]), &[1, 3, 2]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let result = trilu(&a, -1, true).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 0, 6]), &[1, 3, 2]).unwrap();
@@ -168,7 +168,7 @@ pub fn decompose(
///
/// let result = trilu(&a, -1, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 3, 0, 5, 6]), &[1, 3, 2]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let a = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6]),
@@ -188,7 +188,7 @@ pub fn decompose(
///
/// let result = trilu(&a, 0, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 4, 5, 0]), &[1, 2, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let result = trilu(&a, -1, true).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[1, 2, 3]).unwrap();
@@ -196,7 +196,7 @@ pub fn decompose(
///
/// let result = trilu(&a, -1, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 4, 0, 0]), &[1, 2, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let a = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9]),
@@ -216,7 +216,7 @@ pub fn decompose(
///
/// let result = trilu(&a, 0, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 4, 5, 0, 7, 8, 9]), &[1, 3, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let result = trilu(&a, -1, true).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 8, 9]), &[1, 3, 3]).unwrap();
@@ -224,7 +224,7 @@ pub fn decompose(
///
/// let result = trilu(&a, -1, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 4, 0, 0, 7, 8, 0]), &[1, 3, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
/// ```
pub fn trilu<T: TensorType + std::marker::Send + std::marker::Sync>(
a: &Tensor<T>,
@@ -535,30 +535,101 @@ pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::S
/// let result = downsample(&x, 1, 2, 2).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[3, 6]), &[2, 1]).unwrap();
/// assert_eq!(result, expected);
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6]),
/// &[2, 3],
/// ).unwrap();
///
/// // Test case 1: Negative stride along dimension 0
/// // This should flip the order along dimension 0
/// let result = downsample(&x, 0, -1, 0).unwrap();
/// let expected = Tensor::<IntegerRep>::new(
/// Some(&[4, 5, 6, 1, 2, 3]), // Flipped order of rows
/// &[2, 3]
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// // Test case 2: Negative stride along dimension 1
/// // This should flip the order along dimension 1
/// let result = downsample(&x, 1, -1, 0).unwrap();
/// let expected = Tensor::<IntegerRep>::new(
/// Some(&[3, 2, 1, 6, 5, 4]), // Flipped order of columns
/// &[2, 3]
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// // Test case 3: Negative stride with stride magnitude > 1
/// // This should both skip and flip
/// let result = downsample(&x, 1, -2, 0).unwrap();
/// let expected = Tensor::<IntegerRep>::new(
/// Some(&[3, 1, 6, 4]), // Take every 2nd element in reverse
/// &[2, 2]
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// // Test case 4: Negative stride with non-zero modulo
/// // This should start at (size - 1 - modulo) and reverse
/// let result = downsample(&x, 1, -2, 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(
/// Some(&[2, 5]), // Start at second element from end, take every 2nd in reverse
/// &[2, 1]
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// // Create a larger test case for more complex downsampling
/// let y = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
/// &[3, 4],
/// ).unwrap();
///
/// // Test case 5: Negative stride with modulo on larger tensor
/// let result = downsample(&y, 1, -2, 1).unwrap();
/// let expected = Tensor::<IntegerRep>::new(
/// Some(&[3, 1, 7, 5, 11, 9]), // Start one element from the end, take every 2nd in reverse
/// &[3, 2]
/// ).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn downsample<T: TensorType + Send + Sync>(
input: &Tensor<T>,
dim: usize,
stride: usize,
stride: isize, // Changed from usize to isize to support negative strides
modulo: usize,
) -> Result<Tensor<T>, TensorError> {
let mut output_shape = input.dims().to_vec();
// now downsample along axis dim offset by modulo, rounding up (+1 if remainder is non-zero)
let remainder = (input.dims()[dim] - modulo) % stride;
let div = (input.dims()[dim] - modulo) / stride;
output_shape[dim] = div + (remainder > 0) as usize;
let mut output = Tensor::<T>::new(None, &output_shape)?;
// Handle negative stride case
if stride == 0 {
return Err(TensorError::DimMismatch(
"downsample stride cannot be zero".to_string(),
));
}
if modulo > input.dims()[dim] {
let stride_abs = stride.unsigned_abs();
let mut output_shape = input.dims().to_vec();
if modulo >= input.dims()[dim] {
return Err(TensorError::DimMismatch("downsample".to_string()));
}
// now downsample along axis dim offset by modulo
// Calculate output shape based on the absolute value of stride
let remainder = (input.dims()[dim] - modulo) % stride_abs;
let div = (input.dims()[dim] - modulo) / stride_abs;
output_shape[dim] = div + (remainder > 0) as usize;
let mut output = Tensor::<T>::new(None, &output_shape)?;
// Calculate indices based on stride direction
let indices = (0..output_shape.len())
.map(|i| {
if i == dim {
let mut index = vec![0; output_shape[i]];
for (i, idx) in index.iter_mut().enumerate() {
*idx = i * stride + modulo;
for (j, idx) in index.iter_mut().enumerate() {
if stride > 0 {
// Positive stride: move forward from modulo
*idx = j * stride_abs + modulo;
} else {
// Negative stride: move backward from (size - 1 - modulo)
*idx = (input.dims()[dim] - 1 - modulo) - j * stride_abs;
}
}
index
} else {
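The index arithmetic from the hunk above, extracted into a self-contained function and checked against the doctest values: for a negative stride the walk starts at `size - 1 - modulo` and steps backwards by `|stride|` (standalone sketch, not ezkl's API):

```rust
fn downsample_indices(size: usize, stride: isize, modulo: usize) -> Result<Vec<usize>, String> {
    if stride == 0 {
        return Err("downsample stride cannot be zero".to_string());
    }
    if modulo >= size {
        return Err("downsample: modulo out of range".to_string());
    }
    let stride_abs = stride.unsigned_abs();
    let remainder = (size - modulo) % stride_abs;
    let out_len = (size - modulo) / stride_abs + (remainder > 0) as usize;
    Ok((0..out_len)
        .map(|j| {
            if stride > 0 {
                // positive stride: walk forward from `modulo`
                j * stride_abs + modulo
            } else {
                // negative stride: walk backward from `size - 1 - modulo`
                (size - 1 - modulo) - j * stride_abs
            }
        })
        .collect())
}

fn main() {
    // along a length-3 axis, mirroring the doctest cases above
    assert_eq!(downsample_indices(3, -1, 0).unwrap(), vec![2, 1, 0]); // flip
    assert_eq!(downsample_indices(3, -2, 0).unwrap(), vec![2, 0]);    // skip + flip
    assert_eq!(downsample_indices(3, -2, 1).unwrap(), vec![1]);       // offset from the end
    assert_eq!(downsample_indices(3, 2, 2).unwrap(), vec![2]);        // old positive path
}
```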
@@ -845,11 +916,14 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 7]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
pub fn gather_nd<T: TensorType + Send + Sync>(
input: &Tensor<T>,
pub fn gather_nd<'a, T: TensorType + Send + Sync + 'a>(
input: &'a Tensor<T>,
index: &Tensor<usize>,
batch_dims: usize,
) -> Result<Tensor<T>, TensorError> {
) -> Result<Tensor<T>, TensorError>
where
&'a T: TensorType,
{
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
@@ -1037,11 +1111,14 @@ pub fn gather_nd<T: TensorType + Send + Sync>(
/// assert_eq!(result, expected);
/// ````
///
pub fn scatter_nd<T: TensorType + Send + Sync>(
pub fn scatter_nd<'a, T: TensorType + Send + Sync + 'a>(
input: &Tensor<T>,
index: &Tensor<usize>,
src: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
src: &'a Tensor<T>,
) -> Result<Tensor<T>, TensorError>
where
&'a T: TensorType,
{
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
@@ -1112,12 +1189,12 @@ pub fn abs<T: TensorType + Add<Output = T> + std::cmp::Ord + Neg<Output = T>>(
/// use ezkl::tensor::ops::intercalate_values;
///
/// let tensor = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
/// let result = intercalate_values(&tensor, 0, 2, 1).unwrap();
/// let result = intercalate_values(&tensor, &0, 2, 1).unwrap();
///
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 2, 3, 0, 4]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// let result = intercalate_values(&expected, 0, 2, 0).unwrap();
/// let result = intercalate_values(&expected, &0, 2, 0).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 2, 0, 0, 0, 3, 0, 4]), &[3, 3]).unwrap();
///
/// assert_eq!(result, expected);
@@ -1125,7 +1202,7 @@ pub fn abs<T: TensorType + Add<Output = T> + std::cmp::Ord + Neg<Output = T>>(
/// ```
pub fn intercalate_values<T: TensorType>(
tensor: &Tensor<T>,
value: T,
value: &T,
stride: usize,
axis: usize,
) -> Result<Tensor<T>, TensorError> {
@@ -1423,7 +1500,7 @@ pub fn slice<T: TensorType + Send + Sync>(
}
}
t.get_slice(&slice)
Ok(t.get_slice(&slice)?)
}
// ---------------------------------------------------------------------------------------------------------
@@ -1788,14 +1865,14 @@ pub mod nonlinearities {
/// Some(&[4, 25, 8, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = rsqrt(&x, 1.0);
/// let result = rsqrt(&x, 1.0, f64::EPSILON);
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 1, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn rsqrt(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
pub fn rsqrt(a: &Tensor<IntegerRep>, scale_input: f64, eps: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let kix = (a_i as f64) / scale_input;
let fout = scale_input / (kix.sqrt() + f64::EPSILON);
let fout = scale_input / (kix.sqrt() + eps);
let rounded = fout.round();
Ok::<_, TensorError>(rounded as IntegerRep)
})
@@ -2268,14 +2345,23 @@ pub mod nonlinearities {
/// &[2, 3],
/// ).unwrap();
/// let k = 2_f64;
/// let result = recip(&x, 1.0, k);
/// let result = recip(&x, 1.0, k, f64::EPSILON);
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 1, 0, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn recip(a: &Tensor<IntegerRep>, input_scale: f64, out_scale: f64) -> Tensor<IntegerRep> {
pub fn recip(
a: &Tensor<IntegerRep>,
input_scale: f64,
out_scale: f64,
eps: f64,
) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let rescaled = (a_i as f64) / input_scale;
let denom = (1_f64) / (rescaled + f64::EPSILON);
let denom = if rescaled == 0_f64 {
(1_f64) / (rescaled + eps)
} else {
(1_f64) / (rescaled)
};
let d_inv_x = out_scale * denom;
Ok::<_, TensorError>(d_inv_x.round() as IntegerRep)
})
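A standalone sketch of the patched reciprocal: epsilon now only guards the singular point at zero instead of perturbing every denominator, which is what previously induced the non-optimal results. Simplified integer type in place of `IntegerRep`, illustrative only:

```rust
fn recip(a: &[i64], input_scale: f64, out_scale: f64, eps: f64) -> Vec<i64> {
    a.iter()
        .map(|&a_i| {
            let rescaled = a_i as f64 / input_scale;
            // only the singular point needs the epsilon guard
            let denom = if rescaled == 0.0 {
                1.0 / (rescaled + eps)
            } else {
                1.0 / rescaled
            };
            (out_scale * denom).round() as i64
        })
        .collect()
}

fn main() {
    // 1/2 at unit scale rounds to 1; zero maps to the huge "zero recip" value
    assert_eq!(recip(&[2], 1.0, 1.0, f64::EPSILON), vec![1]);
    assert_eq!(recip(&[0], 1.0, 1.0, f64::EPSILON), vec![(1.0 / f64::EPSILON) as i64]);
}
```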
@@ -2291,16 +2377,16 @@ pub mod nonlinearities {
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::zero_recip;
/// let k = 2_f64;
/// let result = zero_recip(1.0);
/// let result = zero_recip(1.0, f64::EPSILON);
/// let expected = Tensor::<IntegerRep>::new(Some(&[4503599627370496]), &[1]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn zero_recip(out_scale: f64) -> Tensor<IntegerRep> {
pub fn zero_recip(out_scale: f64, eps: f64) -> Tensor<IntegerRep> {
let a = Tensor::<IntegerRep>::new(Some(&[0]), &[1]).unwrap();
a.par_enum_map(|_, a_i| {
let rescaled = a_i as f64;
let denom = (1_f64) / (rescaled + f64::EPSILON);
let denom = (1_f64) / (rescaled + eps);
let d_inv_x = out_scale * denom;
Ok::<_, TensorError>(d_inv_x.round() as IntegerRep)
})
@@ -2334,20 +2420,20 @@ pub mod accumulated {
/// Some(&[25, 35]),
/// &[2],
/// ).unwrap();
/// assert_eq!(dot(&[x, y], 1).unwrap(), expected);
/// assert_eq!(dot(&x, &y, 1).unwrap(), expected);
/// ```
pub fn dot<T: TensorType + Mul<Output = T> + Add<Output = T>>(
inputs: &[Tensor<T>; 2],
a: &Tensor<T>,
b: &Tensor<T>,
chunk_size: usize,
) -> Result<Tensor<T>, TensorError> {
if inputs[0].clone().len() != inputs[1].clone().len() {
if a.len() != b.len() {
return Err(TensorError::DimMismatch("dot".to_string()));
}
let (a, b): (Tensor<T>, Tensor<T>) = (inputs[0].clone(), inputs[1].clone());
let transcript: Tensor<T> = a
.iter()
.zip(b)
.zip(b.iter())
.chunks(chunk_size)
.into_iter()
.scan(T::zero().unwrap(), |acc, chunk| {
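The new two-argument `dot` borrows both tensors instead of consuming an owned `[Tensor; 2]`, removing two clones per call. A self-contained sketch of the same chunked running accumulation over slices (assumes the `itertools` crate, as the hunk does):

```rust
use itertools::Itertools;

fn dot(a: &[i64], b: &[i64], chunk_size: usize) -> Result<Vec<i64>, String> {
    if a.len() != b.len() {
        return Err("dot: dim mismatch".to_string());
    }
    let partials: Vec<i64> = a
        .iter()
        .zip(b.iter())
        .chunks(chunk_size)
        .into_iter()
        .scan(0i64, |acc, chunk| {
            *acc += chunk.map(|(x, y)| x * y).sum::<i64>();
            Some(*acc) // one running partial sum per chunk
        })
        .collect();
    Ok(partials)
}

fn main() {
    // 5*1 = 5, then + 2*5 = 15: a running accumulation, one entry per chunk
    assert_eq!(dot(&[5, 2], &[1, 5], 1).unwrap(), vec![5, 15]);
}
```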

View File

@@ -280,9 +280,17 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Vec<ValType<F>>> for ValTenso
fn from(t: Vec<ValType<F>>) -> ValTensor<F> {
ValTensor::Value {
inner: t.clone().into_iter().into(),
dims: vec![t.len()],
scale: 1,
}
}
}
impl<F: PrimeField + TensorType + PartialOrd> From<Vec<&ValType<F>>> for ValTensor<F> {
fn from(t: Vec<&ValType<F>>) -> ValTensor<F> {
ValTensor::Value {
inner: t.clone().into_iter().cloned().into(),
dims: vec![t.len()],
scale: 1,
}
}
@@ -640,6 +648,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// # Returns
/// A tensor containing the base-n decomposition of each value
pub fn decompose(&self, base: usize, n: usize) -> Result<Self, TensorError> {
let mut dims = self.dims().to_vec();
dims.push(n + 1);
let res = self
.get_inner()?
.par_iter()
@@ -665,9 +676,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
})
.collect::<Result<Vec<_>, _>>();
let mut tensor = Tensor::from(res?.into_iter().flatten().collect::<Vec<_>>().into_iter());
let mut dims = self.dims().to_vec();
dims.push(n + 1);
let mut tensor = res?.into_iter().flatten().collect::<Tensor<_>>();
tensor.reshape(&dims)?;
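For context on the `dims.push(n + 1)` moved above the map: each scalar decomposes into `n + 1` legs, and computing the output shape from `self.dims()` up front lets the result collect straight into a `Tensor`. A hypothetical standalone sketch, assuming (as the extra slot suggests) one sign leg plus `n` base-`b` digits:

```rust
fn decompose(x: i64, base: u32, n: u32) -> Vec<i64> {
    let mut legs = vec![x.signum()];
    let mut rem = x.unsigned_abs();
    for i in (0..n).rev() {
        let p = (base as u64).pow(i);
        legs.push((rem / p) as i64);
        rem %= p;
    }
    legs // always n + 1 values, hence `dims.push(n + 1)`
}

fn main() {
    assert_eq!(decompose(305, 10, 3), vec![1, 3, 0, 5]);
    assert_eq!(decompose(-47, 10, 2), vec![-1, 4, 7]);
    // a flat input of k scalars reshapes to dims [k, n + 1]
    let flat: Vec<i64> = [305, -47].iter().flat_map(|&v| decompose(v, 10, 2)).collect();
    assert_eq!(flat.len(), 2 * (2 + 1));
}
```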
@@ -1043,119 +1052,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}
/// Removes constant zero values from the tensor
/// Uses parallel processing for tensors larger than a threshold
pub fn remove_const_zero_values(&mut self) {
let size_threshold = 1_000_000; // Tuned using benchmarks
if self.len() < size_threshold {
// Single-threaded for small tensors
match self {
ValTensor::Value { inner: v, dims, .. } => {
*v = v
.clone()
.into_iter()
.filter_map(|e| {
if let ValType::Constant(r) = e {
if r == F::ZERO {
return None;
}
} else if let ValType::AssignedConstant(_, r) = e {
if r == F::ZERO {
return None;
}
}
Some(e)
})
.collect();
*dims = v.dims().to_vec();
}
ValTensor::Instance { .. } => {}
}
} else {
// Parallel processing for large tensors
let num_cores = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
let chunk_size = (self.len() / num_cores).max(100_000);
match self {
ValTensor::Value { inner: v, dims, .. } => {
*v = v
.par_chunks_mut(chunk_size)
.flat_map(|chunk| {
chunk.par_iter_mut().filter_map(|e| {
if let ValType::Constant(r) = e {
if *r == F::ZERO {
return None;
}
} else if let ValType::AssignedConstant(_, r) = e {
if *r == F::ZERO {
return None;
}
}
Some(e.clone())
})
})
.collect();
*dims = v.dims().to_vec();
}
ValTensor::Instance { .. } => {}
}
}
}
/// Gets the indices of all constant zero values
/// Uses parallel processing for large tensors
///
/// # Returns
/// A vector of indices where constant zero values are located
pub fn get_const_zero_indices(&self) -> Vec<usize> {
let size_threshold = 1_000_000;
if self.len() < size_threshold {
// Single-threaded for small tensors
match &self {
ValTensor::Value { inner: v, .. } => v
.iter()
.enumerate()
.filter_map(|(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(i)
}
_ => None,
})
.collect(),
ValTensor::Instance { .. } => vec![],
}
} else {
// Parallel processing for large tensors
let num_cores = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
let chunk_size = (self.len() / num_cores).max(100_000);
match &self {
ValTensor::Value { inner: v, .. } => v
.par_chunks(chunk_size)
.enumerate()
.flat_map(|(chunk_idx, chunk)| {
chunk
.par_iter()
.enumerate()
.filter_map(move |(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(chunk_idx * chunk_size + i)
}
_ => None,
})
})
.collect::<Vec<_>>(),
ValTensor::Instance { .. } => vec![],
}
}
}
/// Gets the indices of all constant values
/// Uses parallel processing for large tensors
///
@@ -1276,7 +1172,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Returns an error if called on an Instance tensor
pub fn intercalate_values(
&mut self,
value: ValType<F>,
value: &ValType<F>,
stride: usize,
axis: usize,
) -> Result<(), TensorError> {
@@ -1368,10 +1264,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
/// Concatenates two tensors along the first dimension
pub fn concat(&self, other: Self) -> Result<Self, TensorError> {
pub fn concat(&self, other: &Self) -> Result<Self, TensorError> {
let res = match (self, other) {
(ValTensor::Value { inner: v1, .. }, ValTensor::Value { inner: v2, .. }) => {
ValTensor::from(Tensor::new(Some(&[v1.clone(), v2]), &[2])?.combine()?)
ValTensor::from(Tensor::new(Some(&[v1.clone(), v2.clone()]), &[2])?.combine()?)
}
_ => {
return Err(TensorError::WrongMethod);

View File

@@ -1,8 +1,6 @@
use std::collections::HashSet;
use log::{debug, error, warn};
use crate::circuit::{CheckMode, region::ConstantsMap};
use crate::circuit::{region::ConstantsMap, CheckMode};
use super::*;
/// A wrapper around Halo2's Column types that represents a tensor of variables in the circuit.
@@ -381,50 +379,6 @@ impl VarTensor {
}
}
/// Assigns values from a ValTensor to this tensor, excluding specified positions
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `offset` - Base offset for assignments
/// * `values` - The ValTensor containing values to assign
/// * `omissions` - Set of positions to skip during assignment
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// The assigned ValTensor or an error if assignment fails
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
omissions: &HashSet<usize>,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
let mut assigned_coord = 0;
let mut res: ValTensor<F> = match values {
ValTensor::Instance { .. } => {
error!(
"assignment with omissions is not supported on instance columns. increase K if you require more rows."
);
Err(halo2_proofs::plonk::Error::Synthesis)
}
ValTensor::Value { inner: v, .. } => Ok::<ValTensor<F>, halo2_proofs::plonk::Error>(
v.enum_map(|coord, k| {
if omissions.contains(&coord) {
return Ok::<_, halo2_proofs::plonk::Error>(k);
}
let cell =
self.assign_value(region, offset, k.clone(), assigned_coord, constants)?;
assigned_coord += 1;
Ok::<_, halo2_proofs::plonk::Error>(cell)
})?
.into(),
),
}?;
res.set_scale(values.scale());
Ok(res)
}
/// Assigns values from a ValTensor to this tensor
///
/// # Arguments

Binary file not shown.

14
tests/foundry/.gitignore vendored Normal file
View File

@@ -0,0 +1,14 @@
# Compiler files
cache/
out/
# Ignores development broadcast logs
!/broadcast
/broadcast/*/31337/
/broadcast/**/dry-run/
# Docs
docs/
# Dotenv file
.env

66
tests/foundry/README.md Normal file
View File

@@ -0,0 +1,66 @@
## Foundry
**Foundry is a blazing fast, portable and modular toolkit for Ethereum application development written in Rust.**
Foundry consists of:
- **Forge**: Ethereum testing framework (like Truffle, Hardhat and DappTools).
- **Cast**: Swiss army knife for interacting with EVM smart contracts, sending transactions and getting chain data.
- **Anvil**: Local Ethereum node, akin to Ganache, Hardhat Network.
- **Chisel**: Fast, utilitarian, and verbose solidity REPL.
## Documentation
https://book.getfoundry.sh/
## Usage
### Build
```shell
$ forge build
```
### Test
```shell
$ forge test
```
### Format
```shell
$ forge fmt
```
### Gas Snapshots
```shell
$ forge snapshot
```
### Anvil
```shell
$ anvil
```
### Deploy
```shell
$ forge script script/Counter.s.sol:CounterScript --rpc-url <your_rpc_url> --private-key <your_private_key>
```
### Cast
```shell
$ cast <subcommand>
```
### Help
```shell
$ forge --help
$ anvil --help
$ cast --help
```

View File

@@ -0,0 +1,6 @@
[profile.default]
src = "../../contracts"
out = "out"
libs = ["lib"]
# See more config options https://github.com/foundry-rs/foundry/blob/master/crates/config/README.md#all-options

View File

@@ -0,0 +1 @@
contracts/=../../contracts/

View File

@@ -0,0 +1,429 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.20;
import "forge-std/Test.sol";
import {console} from "forge-std/console.sol";
import "contracts/AttestData.sol" as AttestData;
contract MockVKA {
constructor() {}
}
contract MockVerifier {
bool public shouldVerify;
constructor(bool _shouldVerify) {
shouldVerify = _shouldVerify;
}
function verifyProof(
bytes calldata,
uint256[] calldata
) external view returns (bool) {
require(shouldVerify, "Verification failed");
return shouldVerify;
}
}
contract MockVerifierSeperate {
bool public shouldVerify;
constructor(bool _shouldVerify) {
shouldVerify = _shouldVerify;
}
function verifyProof(
address,
bytes calldata,
uint256[] calldata
) external view returns (bool) {
require(shouldVerify, "Verification failed");
return shouldVerify;
}
}
contract MockTargetContract {
int256[] public data;
constructor(int256[] memory _data) {
data = _data;
}
function setData(int256[] memory _data) external {
data = _data;
}
function getData() external view returns (int256[] memory) {
return data;
}
}
contract DataAttestationTest is Test {
AttestData.DataAttestation das;
MockVerifier verifier;
MockVerifierSeperate verifierSeperate;
MockVKA vka;
MockTargetContract target;
int256[] mockData = [int256(1e18), -int256(5e17)];
uint256[] decimals = [18, 18];
uint256[] bits = [13, 13];
uint8 instanceOffset = 0;
bytes callData;
function setUp() public {
target = new MockTargetContract(mockData);
verifier = new MockVerifier(true);
verifierSeperate = new MockVerifierSeperate(true);
vka = new MockVKA();
callData = abi.encodeWithSignature("getData()");
das = new AttestData.DataAttestation(
address(target),
callData,
decimals,
bits,
instanceOffset
);
}
// Fork of mulDivRound which doesn't revert on overflow and instead returns a boolean indicating overflow
function mulDivRound(
uint256 x,
uint256 y,
uint256 denominator
) public pure returns (uint256 result, bool overflow) {
unchecked {
uint256 prod0;
uint256 prod1;
assembly {
let mm := mulmod(x, y, not(0))
prod0 := mul(x, y)
prod1 := sub(sub(mm, prod0), lt(mm, prod0))
}
uint256 remainder = mulmod(x, y, denominator);
bool addOne;
if (remainder * 2 >= denominator) {
addOne = true;
}
if (prod1 == 0) {
if (addOne) {
return ((prod0 / denominator) + 1, false);
}
return (prod0 / denominator, false);
}
if (denominator > prod1) {
return (0, true);
}
assembly {
prod1 := sub(prod1, gt(remainder, prod0))
prod0 := sub(prod0, remainder)
}
uint256 twos = denominator & (~denominator + 1);
assembly {
denominator := div(denominator, twos)
prod0 := div(prod0, twos)
twos := add(div(sub(0, twos), twos), 1)
}
prod0 |= prod1 * twos;
uint256 inverse = (3 * denominator) ^ 2;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
inverse *= 2 - denominator * inverse;
result = prod0 * inverse;
if (addOne) {
result += 1;
}
return (result, false);
}
}
struct SampleAttestation {
int256 mockData;
uint8 decimals;
uint8 bits;
}
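// The fuzz target below checks the end-to-end quantization identity:
//   q = round(mockData * 2^bits / 10^decimals)   (computed via mulDivRound)
//   instance = q mod ORDER                        (negatives as ORDER - |q|)
// and then asserts that quantizeData + toFieldElement reproduce the same
// field element before handing it to the mock verifier.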
function test_fuzzAttestedData(
SampleAttestation[] memory _attestations
) public {
vm.assume(_attestations.length == 1);
int256[] memory _mockData = new int256[](1);
uint256[] memory _decimals = new uint256[](1);
uint256[] memory _bits = new uint256[](1);
uint256[] memory _instances = new uint256[](1);
for (uint256 i = 0; i < 1; i++) {
SampleAttestation memory attestation = _attestations[i];
_mockData[i] = attestation.mockData;
vm.assume(attestation.mockData != type(int256).min); /// Will overflow int256 during negation op
vm.assume(attestation.decimals < 77); /// Else will exceed uint256 bounds
vm.assume(attestation.bits < 128); /// Else will exceed EZKL fixed point bounds for int128 type
bool neg = attestation.mockData < 0;
if (neg) {
attestation.mockData = -attestation.mockData;
}
(uint256 _result, bool overflow) = mulDivRound(
uint256(attestation.mockData),
uint256(1 << attestation.bits),
uint256(10 ** attestation.decimals)
);
vm.assume(!overflow);
vm.assume(_result < das.HALF_ORDER());
if (neg) {
// No possibility of overflow here since output is less than or equal to HALF_ORDER
// and therefore falls within the max range of int256 without overflow
vm.assume(-int256(_result) > type(int128).min);
_instances[i] =
uint256(int(das.ORDER()) - int256(_result)) %
das.ORDER();
} else {
vm.assume(_result < uint128(type(int128).max));
_instances[i] = _result;
}
_decimals[i] = attestation.decimals;
_bits[i] = attestation.bits;
}
// Update the attested data
target.setData(_mockData);
// Deploy the new data attestation contract
AttestData.DataAttestation dasNew = new AttestData.DataAttestation(
address(target),
callData,
_decimals,
_bits,
instanceOffset
);
bytes memory proof = hex"1234"; // Would normally contain commitments
bytes memory encoded = abi.encodeWithSignature(
"verifyProof(bytes,uint256[])",
proof,
_instances
);
AttestData.DataAttestation.Scalars memory _scalars = AttestData
.DataAttestation
.Scalars(10 ** _decimals[0], 1 << _bits[0]);
int256 output = dasNew.quantizeData(_mockData[0], _scalars);
console.log("output: ", output);
uint256 fieldElement = dasNew.toFieldElement(output);
// output should equal to _instances[0]
assertEq(fieldElement, _instances[0]);
bool verificationResult = dasNew.verifyWithDataAttestation(
address(verifier),
encoded
);
assertTrue(verificationResult);
}
// Test deployment parameters
function testDeployment() public view {
assertEq(das.contractAddress(), address(target));
assertEq(das.callData(), abi.encodeWithSignature("getData()"));
assertEq(das.instanceOffset(), instanceOffset);
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
assertEq(scalar.decimals, 1e18);
assertEq(scalar.bits, 1 << 13);
}
// Test quantizeData function
function testQuantizeData() public view {
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
int256 positive = das.quantizeData(1e18, scalar);
assertEq(positive, int256(scalar.bits));
int256 negative = das.quantizeData(-1e18, scalar);
assertEq(negative, -int256(scalar.bits));
// Test rounding
int half = int(0.5e18 / scalar.bits);
int256 rounded = das.quantizeData(half, scalar);
assertEq(rounded, 1);
}
// Test staticCall functionality
function testStaticCall() public view {
bytes memory result = das.staticCall(
address(target),
abi.encodeWithSignature("getData()")
);
int256[] memory decoded = abi.decode(result, (int256[]));
assertEq(decoded[0], mockData[0]);
assertEq(decoded[1], mockData[1]);
}
// Test attestData validation
function testAttestDataSuccess() public view {
uint256[] memory instances = new uint256[](2);
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
instances[0] = das.toFieldElement(int(scalar.bits));
instances[1] = das.toFieldElement(-int(scalar.bits >> 1));
das.attestData(instances); // Should not revert
}
function testAttestDataFailure() public {
uint256[] memory instances = new uint256[](2);
instances[0] = das.toFieldElement(1e18); // Incorrect value
instances[1] = das.toFieldElement(5e17);
vm.expectRevert("Public input does not match");
das.attestData(instances);
}
// Test full verification flow
function testSuccessfulVerification() public view {
// Prepare valid instances
uint256[] memory instances = new uint256[](2);
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
instances[0] = das.toFieldElement(int(scalar.bits));
instances[1] = das.toFieldElement(-int(scalar.bits >> 1));
// Create valid calldata (mock)
bytes memory proof = hex"1234"; // Would normally contain commitments
bytes memory encoded = abi.encodeWithSignature(
"verifyProof(bytes,uint256[])",
proof,
instances
);
bytes memory encoded_vka = abi.encodeWithSignature(
"verifyProof(address,bytes,uint256[])",
address(vka),
proof,
instances
);
bool result = das.verifyWithDataAttestation(address(verifier), encoded);
assertTrue(result);
result = das.verifyWithDataAttestation(
address(verifierSeperate),
encoded_vka
);
assertTrue(result);
}
function testLoadInstances() public view {
uint256[] memory instances = new uint256[](2);
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
instances[0] = das.toFieldElement(int(scalar.bits));
instances[1] = das.toFieldElement(-int(scalar.bits >> 1));
// Create valid calldata (mock)
bytes memory proof = hex"1234"; // Would normally contain commitments
bytes memory encoded = abi.encodeWithSignature(
"verifyProof(bytes,uint256[])",
proof,
instances
);
bytes memory encoded_vka = abi.encodeWithSignature(
"verifyProof(address,bytes,uint256[])",
address(vka),
proof,
instances
);
// Load encoded instances from calldata
uint256[] memory extracted_instances_calldata = das
.getInstancesCalldata(encoded);
assertEq(extracted_instances_calldata[0], instances[0]);
assertEq(extracted_instances_calldata[1], instances[1]);
// Load encoded instances from memory
uint256[] memory extracted_instances_memory = das.getInstancesMemory(
encoded
);
assertEq(extracted_instances_memory[0], instances[0]);
assertEq(extracted_instances_memory[1], instances[1]);
// Load instances from the vk-encoded calldata
uint256[] memory extracted_instances_calldata_vk = das
.getInstancesCalldata(encoded_vka);
assertEq(extracted_instances_calldata_vk[0], instances[0]);
assertEq(extracted_instances_calldata_vk[1], instances[1]);
// Load instances from the vk-encoded bytes in memory
uint256[] memory extracted_instances_memory_vk = das.getInstancesMemory(
encoded_vka
);
assertEq(extracted_instances_memory_vk[0], instances[0]);
assertEq(extracted_instances_memory_vk[1], instances[1]);
}
function testInvalidCommitments() public {
// Create calldata with invalid commitments
bytes memory invalidProof = hex"5678";
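// hex"5678" is arbitrary: any bytes that fail to match the KZG commitments
// the DA contract expects should trigger the revert below.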
uint256[] memory instances = new uint256[](2);
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
instances[0] = das.toFieldElement(int(scalar.bits));
instances[1] = das.toFieldElement(-int(scalar.bits >> 1));
bytes memory encoded = abi.encodeWithSignature(
"verifyProof(bytes,uint256[])",
invalidProof,
instances
);
vm.expectRevert("Invalid KZG commitments");
das.verifyWithDataAttestation(address(verifier), encoded);
}
function testInvalidVerifier() public {
MockVerifier invalidVerifier = new MockVerifier(false);
uint256[] memory instances = new uint256[](2);
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
instances[0] = das.toFieldElement(int(scalar.bits));
instances[1] = das.toFieldElement(-int(scalar.bits >> 1));
bytes memory encoded = abi.encodeWithSignature(
"verifyProof(bytes,uint256[])",
hex"1234",
instances
);
vm.expectRevert("low-level call to verifier failed");
das.verifyWithDataAttestation(address(invalidVerifier), encoded);
}
// Test edge cases
function testZeroValueQuantization() public view {
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
int256 zero = das.quantizeData(0, scalar);
assertEq(zero, 0);
}
function testOverflowProtection() public {
int256 order = int(
uint256(
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
)
);
// int256 half_order = int(order >> 1);
AttestData.DataAttestation.Scalars memory scalar = AttestData
.DataAttestation
.Scalars(1, 1 << 2);
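// `order` here is the BN254 scalar-field modulus itself, so with decimals = 1
// and bits = 1 << 2 the quantized value cannot fit in the field and the
// "Overflow field modulus" guard should revert.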
vm.expectRevert("Overflow field modulus");
das.quantizeData(order, scalar); // Value that would overflow
}
function testInvalidFunctionSignature() public {
uint256[] memory instances = new uint256[](2);
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
instances[0] = das.toFieldElement(int(scalar.bits));
instances[1] = das.toFieldElement(-int(scalar.bits >> 1));
bytes memory encoded_invalid_sig = abi.encodeWithSignature(
"verifyProofff(bytes,uint256[])",
hex"1234",
instances
);
vm.expectRevert("Invalid function signature");
das.verifyWithDataAttestation(address(verifier), encoded_invalid_sig);
}
}
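Taken together, the quantization tests above amount to a small round trip: quantize a fixed-point input, then map it into the scalar field. A minimal sketch of that flow (a hypothetical helper, not part of the diff; `das` and `AttestData` refer to the fixtures in this test contract, and the rounding behavior is inferred from the assertions rather than from the implementation):
function quantizeRoundTrip(int256 datum) internal view returns (uint256) {
    AttestData.DataAttestation.Scalars memory s = das.getScalars(0);
    int256 q = das.quantizeData(datum, s); // scale datum by bits / decimals
    return das.toFieldElement(q);          // negatives wrap to ORDER - |q|
}
With the setUp scalars (decimals = 1e18, bits = 1 << 13), quantizeRoundTrip(1e18) should land on scalar.bits, which is exactly the instance value testAttestDataSuccess attests.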

View File

@@ -3,10 +3,10 @@
mod native_tests {
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
use ezkl::Commitments;
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
use ezkl::pfsys::Snark;
use ezkl::Commitments;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::Bn256;
use lazy_static::lazy_static;
@@ -2440,23 +2440,43 @@ mod native_tests {
));
}
input.save(data_path.clone().into()).unwrap();
let args = vec![
"setup-test-evm-data",
"-D",
data_path.as_str(),
"-M",
&model_path,
"--test-data",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
test_input_source.as_str(),
test_output_source.as_str(),
];
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"setup-test-evm-data",
"-D",
data_path.as_str(),
"-M",
&model_path,
"--test-data",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
test_input_source.as_str(),
test_output_source.as_str(),
])
.args(args)
.status()
.expect("failed to execute process");
assert!(status.success());
// generate the witness, passing the vk path to generate the necessary kzg commits only
// if input visibility is NOT hashed
if input_visibility != "hashed" {
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"gen-witness",
"-D",
&test_on_chain_data_path,
"-M",
&model_path,
"-O",
&witness_path,
"--vk-path",
&format!("{}/{}/key.vk", test_dir, example_name),
])
.status()
.expect("failed to execute process");
assert!(status.success());
}
}
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
@@ -2599,56 +2619,6 @@ mod native_tests {
.status()
.expect("failed to execute process");
assert!(status.success());
// Create a new set of test on chain data only for the on-chain input source
if input_source != "file" || output_source != "file" {
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args([
"setup-test-evm-data",
"-D",
data_path.as_str(),
"-M",
&model_path,
"--test-data",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
test_input_source.as_str(),
test_output_source.as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
let deployed_addr_arg = format!("--addr={}", addr_da);
let args: Vec<&str> = vec![
"test-update-account-calls",
deployed_addr_arg.as_str(),
"-D",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
];
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(&args)
.status()
.expect("failed to execute process");
assert!(status.success());
}
// As sanity check, add example that should fail.
let args = vec![
"verify-evm",
"--proof-path",
PF_FAILURE,
deployed_addr_verifier_arg.as_str(),
deployed_addr_da_arg.as_str(),
rpc_arg.as_str(),
];
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
.args(args)
.status()
.expect("failed to execute process");
assert!(!status.success());
}
fn build_ezkl() {

View File

@@ -272,16 +272,6 @@ mod py_tests {
anvil_child.kill().unwrap();
}
#[test]
fn postgres_notebook_() {
crate::py_tests::init_binary();
let test_dir: TempDir = TempDir::new("mean_postgres").unwrap();
let path = test_dir.path().to_str().unwrap();
crate::py_tests::mv_test_(path, "mean_postgres.ipynb");
run_notebook(path, "mean_postgres.ipynb");
test_dir.close().unwrap();
}
#[test]
fn tictactoe_autoencoder_notebook_() {
crate::py_tests::init_binary();

View File

@@ -479,15 +479,15 @@ async def test_deploy_evm_reusable_and_vka():
res = await ezkl.deploy_evm(
addr_path_verifier,
sol_code_path,
anvil_url,
sol_code_path,
"verifier/reusable",
)
res = await ezkl.deploy_evm(
addr_path_vk,
vk_code_path,
anvil_url,
vk_code_path,
"vka",
)
@@ -506,8 +506,8 @@ async def test_deploy_evm():
res = await ezkl.deploy_evm(
addr_path,
sol_code_path,
anvil_url,
sol_code_path,
)
assert res == True
@@ -528,8 +528,8 @@ async def test_deploy_evm_with_private_key():
res = await ezkl.deploy_evm(
addr_path,
anvil_url,
sol_code_path,
rpc_url=anvil_url,
private_key=anvil_default_private_key
)
@@ -540,8 +540,8 @@ async def test_deploy_evm_with_private_key():
with pytest.raises(RuntimeError, match="Failed to run deploy_evm"):
res = await ezkl.deploy_evm(
addr_path,
anvil_url,
sol_code_path,
rpc_url=anvil_url,
private_key=custom_zero_balance_private_key
)
@@ -564,8 +564,8 @@ async def test_verify_evm():
res = await ezkl.verify_evm(
addr,
anvil_url,
proof_path,
rpc_url=anvil_url,
# sol_code_path
# optimizer_runs
)
@@ -604,8 +604,8 @@ async def test_verify_evm_separate_vk():
res = await ezkl.verify_evm(
addr_verifier,
anvil_url,
proof_path,
rpc_url=anvil_url,
addr_vk=addr_vk,
# sol_code_path
# optimizer_runs
@@ -831,8 +831,8 @@ async def test_evm_aggregate_and_verify_aggr():
res = await ezkl.deploy_evm(
addr_path,
anvil_url,
sol_code_path,
rpc_url=anvil_url,
)
# as a sanity check