mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 16:27:59 -05:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5125aaa090 | ||
|
|
f1950e6cd0 | ||
|
|
998ca22c2a | ||
|
|
5c574adc31 | ||
|
|
749e0ba652 | ||
|
|
d464ddf6b6 | ||
|
|
8f6c0aced5 | ||
|
|
860e9700a8 | ||
|
|
32dd4a854f | ||
|
|
924f7c0420 |
4
.github/workflows/pypi.yml
vendored
4
.github/workflows/pypi.yml
vendored
@@ -40,7 +40,7 @@ jobs:
|
||||
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
|
||||
@@ -85,7 +85,7 @@ jobs:
|
||||
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-06-27
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
|
||||
|
||||
18
.github/workflows/release.yml
vendored
18
.github/workflows/release.yml
vendored
@@ -102,28 +102,32 @@ jobs:
|
||||
PCRE2_SYS_STATIC: 1
|
||||
strategy:
|
||||
matrix:
|
||||
build: [windows-msvc, macos, macos-aarch64, linux-musl, linux-gnu]
|
||||
build: [windows-msvc, macos, macos-aarch64, linux-musl, linux-gnu, linux-aarch64]
|
||||
include:
|
||||
- build: windows-msvc
|
||||
os: windows-latest
|
||||
rust: nightly-2023-06-27
|
||||
rust: nightly-2024-02-06
|
||||
target: x86_64-pc-windows-msvc
|
||||
- build: macos
|
||||
os: macos-13
|
||||
rust: nightly-2023-06-27
|
||||
rust: nightly-2024-02-06
|
||||
target: x86_64-apple-darwin
|
||||
- build: macos-aarch64
|
||||
os: macos-13
|
||||
rust: nightly-2023-06-27
|
||||
rust: nightly-2024-02-06
|
||||
target: aarch64-apple-darwin
|
||||
- build: linux-musl
|
||||
os: ubuntu-22.04
|
||||
rust: nightly-2023-06-27
|
||||
rust: nightly-2024-02-06
|
||||
target: x86_64-unknown-linux-musl
|
||||
- build: linux-gnu
|
||||
os: ubuntu-22.04
|
||||
rust: nightly-2023-06-27
|
||||
rust: nightly-2024-02-06
|
||||
target: x86_64-unknown-linux-gnu
|
||||
- build: linux-aarch64
|
||||
os: ubuntu-22.04
|
||||
rust: nightly-2024-02-06
|
||||
target: aarch64-unknown-linux-gnu
|
||||
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
@@ -181,7 +185,7 @@ jobs:
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
|
||||
|
||||
- name: Strip release binary
|
||||
if: matrix.build != 'windows-msvc'
|
||||
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'
|
||||
run: strip "target/${{ matrix.target }}/release/ezkl"
|
||||
|
||||
- name: Strip release binary (Windows)
|
||||
|
||||
4
.github/workflows/rust.yml
vendored
4
.github/workflows/rust.yml
vendored
@@ -547,8 +547,6 @@ jobs:
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Download MNIST
|
||||
run: sh data.sh
|
||||
- name: Examples
|
||||
run: cargo nextest run --release tests_examples
|
||||
|
||||
@@ -576,7 +574,7 @@ jobs:
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
|
||||
- name: Run pytest
|
||||
run: source .env/bin/activate; pytest -vv
|
||||
run: source .env/bin/activate; pip install pytest-asyncio; pytest -vv
|
||||
|
||||
accuracy-measurement-tests:
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
|
||||
13
.github/workflows/verify.yml
vendored
13
.github/workflows/verify.yml
vendored
@@ -31,10 +31,21 @@ jobs:
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
|
||||
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name#v }}\"|" in-browser-evm-verifier/package.json
|
||||
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
|
||||
run: |
|
||||
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
|
||||
- name: Fetch integrity
|
||||
run: |
|
||||
ENGINE_INTEGRITY=$(npm view @ezkljs/engine@${{ github.ref_name#v }} dist.integrity)
|
||||
echo "ENGINE_INTEGRITY=$ENGINE_INTEGRITY" >> $GITHUB_ENV
|
||||
- name: Update pnpm-lock.yaml versions and integrity
|
||||
run: |
|
||||
awk -v integrity="$ENGINE_INTEGRITY" -v tag="${{ github.ref_name#v }}" '
|
||||
NR==30{$0=" specifier: \"" tag "\""}
|
||||
NR==31{$0=" version: \"" tag "\""}
|
||||
NR==400{$0=" /@ezkljs/engine@" tag ":"}
|
||||
NR==401{$0=" resolution: {integrity: \"" integrity "\"}"} 1' in-browser-evm-verifier/pnpm-lock.yaml > temp.yaml && mv temp.yaml in-browser-evm-verifier/pnpm-lock.yaml
|
||||
- name: Use pnpm 8
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,6 +1,5 @@
|
||||
target
|
||||
pkg
|
||||
data
|
||||
*.csv
|
||||
!examples/notebooks/eth_price.csv
|
||||
*.ipynb_checkpoints
|
||||
|
||||
1917
Cargo.lock
generated
1917
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
30
Cargo.toml
30
Cargo.toml
@@ -43,46 +43,49 @@ unzip-n = "0.1.2"
|
||||
num = "0.4.1"
|
||||
portable-atomic = "1.6.0"
|
||||
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
|
||||
|
||||
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
|
||||
|
||||
# evm related deps
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||
ethers = { version = "2.0.11", default_features = false, features = [
|
||||
"ethers-solc",
|
||||
] }
|
||||
alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", features = ["provider-http", "signers", "contract", "rpc-types-eth", "signer-wallet", "node-bindings"] }
|
||||
foundry-compilers = {version = "0.4.1", features = ["svm-solc"]}
|
||||
ethabi = "18"
|
||||
indicatif = { version = "0.17.5", features = ["rayon"] }
|
||||
gag = { version = "1.0.0", default_features = false }
|
||||
instant = { version = "0.1" }
|
||||
reqwest = { version = "0.11.14", default-features = false, features = [
|
||||
reqwest = { version = "0.12.4", default-features = false, features = [
|
||||
"default-tls",
|
||||
"multipart",
|
||||
"stream",
|
||||
] }
|
||||
openssl = { version = "0.10.55", features = ["vendored"] }
|
||||
postgres = "0.19.5"
|
||||
tokio-postgres = "0.7.10"
|
||||
pg_bigdecimal = "0.1.5"
|
||||
futures-util = "0.3.30"
|
||||
lazy_static = "1.4.0"
|
||||
colored_json = { version = "3.0.1", default_features = false, optional = true }
|
||||
plotters = { version = "0.3.0", default_features = false, optional = true }
|
||||
regex = { version = "1", default_features = false }
|
||||
tokio = { version = "1.26.0", default_features = false, features = [
|
||||
tokio = { version = "1.35", default_features = false, features = [
|
||||
"macros",
|
||||
"rt",
|
||||
"rt-multi-thread"
|
||||
] }
|
||||
tokio-util = { version = "0.7.9", features = ["codec"] }
|
||||
pyo3 = { version = "0.20.2", features = [
|
||||
pyo3 = { version = "0.21.2", features = [
|
||||
"extension-module",
|
||||
"abi3-py37",
|
||||
"macros",
|
||||
], default_features = false, optional = true }
|
||||
pyo3-asyncio = { version = "0.20.0", features = [
|
||||
pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch="migration-pyo3-0.21", features = [
|
||||
"attributes",
|
||||
"tokio-runtime",
|
||||
], default_features = false, optional = true }
|
||||
pyo3-log = { version = "0.9.0", default_features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "681a096f02c9d7d363102d9fb0e446d1710ac2c8", default_features = false, optional = true }
|
||||
pyo3-log = { version = "0.10.0", default_features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "05ebf550aa9922b221af4635c21a67a8d2af12a9", default_features = false, optional = true }
|
||||
tabled = { version = "0.12.0", optional = true }
|
||||
|
||||
objc = { version = "0.2.4", optional = true }
|
||||
|
||||
|
||||
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
|
||||
colored = { version = "2.0.0", default_features = false, optional = true }
|
||||
@@ -175,7 +178,7 @@ required-features = ["ezkl"]
|
||||
|
||||
[features]
|
||||
web = ["wasm-bindgen-rayon"]
|
||||
default = ["ezkl", "mv-lookup"]
|
||||
default = ["ezkl", "mv-lookup", "no-banner"]
|
||||
onnx = ["dep:tract-onnx"]
|
||||
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
|
||||
ezkl = [
|
||||
@@ -198,6 +201,7 @@ det-prove = []
|
||||
icicle = ["halo2_proofs/icicle_gpu"]
|
||||
empty-cmd = []
|
||||
no-banner = []
|
||||
metal = ["dep:metal", "dep:objc"]
|
||||
|
||||
# icicle patch to 0.1.0 if feature icicle is enabled
|
||||
[patch.'https://github.com/ingonyama-zk/icicle']
|
||||
|
||||
@@ -91,9 +91,9 @@ You can install the library from source
|
||||
cargo install --locked --path .
|
||||
```
|
||||
|
||||
You will need a functioning installation of `solc` in order to run `ezkl` properly.
|
||||
[solc-select](https://github.com/crytic/solc-select) is recommended.
|
||||
Follow the instructions on [solc-select](https://github.com/crytic/solc-select) to activate `solc` in your environment.
|
||||
`ezkl` now auto-manages solc installation for you.
|
||||
|
||||
|
||||
|
||||
|
||||
#### building python bindings
|
||||
|
||||
@@ -20,9 +20,9 @@
|
||||
"name": "quantize_data",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "int128[]",
|
||||
"internalType": "int64[]",
|
||||
"name": "quantized_data",
|
||||
"type": "int128[]"
|
||||
"type": "int64[]"
|
||||
}
|
||||
],
|
||||
"stateMutability": "pure",
|
||||
@@ -31,9 +31,9 @@
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "int128[]",
|
||||
"internalType": "int64[]",
|
||||
"name": "quantized_data",
|
||||
"type": "int128[]"
|
||||
"type": "int64[]"
|
||||
}
|
||||
],
|
||||
"name": "to_field_element",
|
||||
|
||||
@@ -1,6 +1,97 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
pragma solidity ^0.8.20;
|
||||
import './LoadInstances.sol';
|
||||
contract LoadInstances {
|
||||
/**
|
||||
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
|
||||
* @notice must pass encoded bytes from memory
|
||||
* @param encoded - verifier calldata
|
||||
*/
|
||||
function getInstancesMemory(
|
||||
bytes memory encoded
|
||||
) internal pure returns (uint256[] memory instances) {
|
||||
bytes4 funcSig;
|
||||
uint256 instances_offset;
|
||||
uint256 instances_length;
|
||||
assembly {
|
||||
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
|
||||
funcSig := mload(add(encoded, 0x20))
|
||||
|
||||
// Fetch instances offset which is 4 + 32 + 32 bytes away from
|
||||
// start of encoded for `verifyProof(bytes,uint256[])`,
|
||||
// and 4 + 32 + 32 +32 away for `verifyProof(address,bytes,uint256[])`
|
||||
|
||||
instances_offset := mload(
|
||||
add(encoded, add(0x44, mul(0x20, eq(funcSig, 0xaf83a18d))))
|
||||
)
|
||||
|
||||
instances_length := mload(add(add(encoded, 0x24), instances_offset))
|
||||
}
|
||||
instances = new uint256[](instances_length); // Allocate memory for the instances array.
|
||||
assembly {
|
||||
// Now instances points to the start of the array data
|
||||
// (right after the length field).
|
||||
for {
|
||||
let i := 0x20
|
||||
} lt(i, add(mul(instances_length, 0x20), 0x20)) {
|
||||
i := add(i, 0x20)
|
||||
} {
|
||||
mstore(
|
||||
add(instances, i),
|
||||
mload(add(add(encoded, add(i, 0x24)), instances_offset))
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
|
||||
* @notice must pass encoded bytes from calldata
|
||||
* @param encoded - verifier calldata
|
||||
*/
|
||||
function getInstancesCalldata(
|
||||
bytes calldata encoded
|
||||
) internal pure returns (uint256[] memory instances) {
|
||||
bytes4 funcSig;
|
||||
uint256 instances_offset;
|
||||
uint256 instances_length;
|
||||
assembly {
|
||||
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
|
||||
funcSig := calldataload(encoded.offset)
|
||||
|
||||
// Fetch instances offset which is 4 + 32 + 32 bytes away from
|
||||
// start of encoded for `verifyProof(bytes,uint256[])`,
|
||||
// and 4 + 32 + 32 +32 away for `verifyProof(address,bytes,uint256[])`
|
||||
|
||||
instances_offset := calldataload(
|
||||
add(
|
||||
encoded.offset,
|
||||
add(0x24, mul(0x20, eq(funcSig, 0xaf83a18d)))
|
||||
)
|
||||
)
|
||||
|
||||
instances_length := calldataload(
|
||||
add(add(encoded.offset, 0x04), instances_offset)
|
||||
)
|
||||
}
|
||||
instances = new uint256[](instances_length); // Allocate memory for the instances array.
|
||||
assembly {
|
||||
// Now instances points to the start of the array data
|
||||
// (right after the length field).
|
||||
|
||||
for {
|
||||
let i := 0x20
|
||||
} lt(i, add(mul(instances_length, 0x20), 0x20)) {
|
||||
i := add(i, 0x20)
|
||||
} {
|
||||
mstore(
|
||||
add(instances, i),
|
||||
calldataload(
|
||||
add(add(encoded.offset, add(i, 0x04)), instances_offset)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This contract serves as a Data Attestation Verifier for the EZKL model.
|
||||
// It is designed to read and attest to instances of proofs generated from a specified circuit.
|
||||
@@ -34,11 +125,14 @@ contract DataAttestation is LoadInstances {
|
||||
address public admin;
|
||||
|
||||
/**
|
||||
* @notice EZKL P value
|
||||
* @notice EZKL P value
|
||||
* @dev In order to prevent the verifier from accepting two version of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are stricly less than P. a
|
||||
* @dev The reason for this is that the assmebly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
|
||||
*/
|
||||
uint256 constant ORDER = uint256(0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001);
|
||||
uint256 constant ORDER =
|
||||
uint256(
|
||||
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
|
||||
);
|
||||
|
||||
uint256 constant INPUT_CALLS = 0;
|
||||
|
||||
@@ -69,7 +163,7 @@ contract DataAttestation is LoadInstances {
|
||||
|
||||
function updateAdmin(address _admin) external {
|
||||
require(msg.sender == admin, "Only admin can update admin");
|
||||
if(_admin == address(0)) {
|
||||
if (_admin == address(0)) {
|
||||
revert();
|
||||
}
|
||||
admin = _admin;
|
||||
@@ -80,7 +174,7 @@ contract DataAttestation is LoadInstances {
|
||||
bytes[][] memory _callData,
|
||||
uint256[][] memory _decimals
|
||||
) external {
|
||||
require(msg.sender == admin, "Only admin can update instanceOffset");
|
||||
require(msg.sender == admin, "Only admin can update account calls");
|
||||
populateAccountCalls(_contractAddresses, _callData, _decimals);
|
||||
}
|
||||
|
||||
@@ -111,7 +205,10 @@ contract DataAttestation is LoadInstances {
|
||||
// count the total number of storage reads across all of the accounts
|
||||
counter += _callData[i].length;
|
||||
}
|
||||
require(counter == INPUT_CALLS + OUTPUT_CALLS, "Invalid number of calls");
|
||||
require(
|
||||
counter == INPUT_CALLS + OUTPUT_CALLS,
|
||||
"Invalid number of calls"
|
||||
);
|
||||
}
|
||||
|
||||
function mulDiv(
|
||||
@@ -167,7 +264,7 @@ contract DataAttestation is LoadInstances {
|
||||
* @dev Quantize the data returned from the account calls to the scale used by the EZKL model.
|
||||
* @param data - The data returned from the account calls.
|
||||
* @param decimals - The number of decimals the data returned from the account calls has (for floating point representation).
|
||||
* @param scale - The scale used to convert the floating point value into a fixed point value.
|
||||
* @param scale - The scale used to convert the floating point value into a fixed point value.
|
||||
*/
|
||||
function quantizeData(
|
||||
bytes memory data,
|
||||
@@ -181,7 +278,7 @@ contract DataAttestation is LoadInstances {
|
||||
if (mulmod(uint256(x), scale, decimals) * 2 >= decimals) {
|
||||
output += 1;
|
||||
}
|
||||
quantized_data = neg ? -int256(output): int256(output);
|
||||
quantized_data = neg ? -int256(output) : int256(output);
|
||||
}
|
||||
/**
|
||||
* @dev Make a static call to the account to fetch the data that EZKL reads from.
|
||||
@@ -211,7 +308,9 @@ contract DataAttestation is LoadInstances {
|
||||
* @param x - The quantized data.
|
||||
* @return field_element - The field element.
|
||||
*/
|
||||
function toFieldElement(int256 x) internal pure returns (uint256 field_element) {
|
||||
function toFieldElement(
|
||||
int256 x
|
||||
) internal pure returns (uint256 field_element) {
|
||||
// The casting down to uint256 is safe because the order is about 2^254, and the value
|
||||
// of x ranges of -2^127 to 2^127, so x + int(ORDER) is always positive.
|
||||
return uint256(x + int(ORDER)) % ORDER;
|
||||
@@ -251,12 +350,11 @@ contract DataAttestation is LoadInstances {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function verifyWithDataAttestation(
|
||||
address verifier,
|
||||
bytes calldata encoded
|
||||
) public view returns (bool) {
|
||||
require(verifier.code.length > 0,"Address: call to non-contract");
|
||||
require(verifier.code.length > 0, "Address: call to non-contract");
|
||||
attestData(getInstancesCalldata(encoded));
|
||||
// static call the verifier contract to verify the proof
|
||||
(bool success, bytes memory returndata) = verifier.staticcall(encoded);
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
pragma solidity ^0.8.20;
|
||||
contract LoadInstances {
|
||||
/**
|
||||
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
|
||||
* @notice must pass encoded bytes from memory
|
||||
* @param encoded - verifier calldata
|
||||
*/
|
||||
function getInstancesMemory(
|
||||
bytes memory encoded
|
||||
) internal pure returns (uint256[] memory instances) {
|
||||
bytes4 funcSig;
|
||||
uint256 instances_offset;
|
||||
uint256 instances_length;
|
||||
assembly {
|
||||
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
|
||||
funcSig := mload(add(encoded, 0x20))
|
||||
|
||||
// Fetch instances offset which is 4 + 32 + 32 bytes away from
|
||||
// start of encoded for `verifyProof(bytes,uint256[])`,
|
||||
// and 4 + 32 + 32 +32 away for `verifyProof(address,bytes,uint256[])`
|
||||
|
||||
instances_offset := mload(
|
||||
add(encoded, add(0x44, mul(0x20, eq(funcSig, 0xaf83a18d))))
|
||||
)
|
||||
|
||||
instances_length := mload(add(add(encoded, 0x24), instances_offset))
|
||||
}
|
||||
instances = new uint256[](instances_length); // Allocate memory for the instances array.
|
||||
assembly {
|
||||
// Now instances points to the start of the array data
|
||||
// (right after the length field).
|
||||
for {
|
||||
let i := 0x20
|
||||
} lt(i, add(mul(instances_length, 0x20), 0x20)) {
|
||||
i := add(i, 0x20)
|
||||
} {
|
||||
mstore(
|
||||
add(instances, i),
|
||||
mload(add(add(encoded, add(i, 0x24)), instances_offset))
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
|
||||
* @notice must pass encoded bytes from calldata
|
||||
* @param encoded - verifier calldata
|
||||
*/
|
||||
function getInstancesCalldata(
|
||||
bytes calldata encoded
|
||||
) internal pure returns (uint256[] memory instances) {
|
||||
bytes4 funcSig;
|
||||
uint256 instances_offset;
|
||||
uint256 instances_length;
|
||||
assembly {
|
||||
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
|
||||
funcSig := calldataload(encoded.offset)
|
||||
|
||||
// Fetch instances offset which is 4 + 32 + 32 bytes away from
|
||||
// start of encoded for `verifyProof(bytes,uint256[])`,
|
||||
// and 4 + 32 + 32 +32 away for `verifyProof(address,bytes,uint256[])`
|
||||
|
||||
instances_offset := calldataload(
|
||||
add(
|
||||
encoded.offset,
|
||||
add(0x24, mul(0x20, eq(funcSig, 0xaf83a18d)))
|
||||
)
|
||||
)
|
||||
|
||||
instances_length := calldataload(add(add(encoded.offset, 0x04), instances_offset))
|
||||
}
|
||||
instances = new uint256[](instances_length); // Allocate memory for the instances array.
|
||||
assembly{
|
||||
// Now instances points to the start of the array data
|
||||
// (right after the length field).
|
||||
|
||||
for {
|
||||
let i := 0x20
|
||||
} lt(i, add(mul(instances_length, 0x20), 0x20)) {
|
||||
i := add(i, 0x20)
|
||||
} {
|
||||
mstore(
|
||||
add(instances, i),
|
||||
calldataload(
|
||||
add(add(encoded.offset, add(i, 0x04)), instances_offset)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,135 +0,0 @@
|
||||
// SPDX-License-Identifier: GPL-3.0
|
||||
|
||||
pragma solidity ^0.8.17;
|
||||
|
||||
contract QuantizeData {
|
||||
/**
|
||||
* @notice EZKL P value
|
||||
* @dev In order to prevent the verifier from accepting two version of the same instance, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are stricly less than P. a
|
||||
* @dev The reason for this is that the assmebly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
|
||||
*/
|
||||
uint256 constant ORDER =
|
||||
uint256(
|
||||
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
|
||||
);
|
||||
|
||||
/**
|
||||
* @notice Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or denominator == 0
|
||||
* @dev Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv)
|
||||
* with further edits by Uniswap Labs also under MIT license.
|
||||
*/
|
||||
function mulDiv(
|
||||
uint256 x,
|
||||
uint256 y,
|
||||
uint256 denominator
|
||||
) internal pure returns (uint256 result) {
|
||||
unchecked {
|
||||
// 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2^256 and mod 2^256 - 1, then use
|
||||
// use the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256
|
||||
// variables such that product = prod1 * 2^256 + prod0.
|
||||
uint256 prod0; // Least significant 256 bits of the product
|
||||
uint256 prod1; // Most significant 256 bits of the product
|
||||
assembly {
|
||||
let mm := mulmod(x, y, not(0))
|
||||
prod0 := mul(x, y)
|
||||
prod1 := sub(sub(mm, prod0), lt(mm, prod0))
|
||||
}
|
||||
|
||||
// Handle non-overflow cases, 256 by 256 division.
|
||||
if (prod1 == 0) {
|
||||
// Solidity will revert if denominator == 0, unlike the div opcode on its own.
|
||||
// The surrounding unchecked block does not change this fact.
|
||||
// See https://docs.soliditylang.org/en/latest/control-structures.html#checked-or-unchecked-arithmetic.
|
||||
return prod0 / denominator;
|
||||
}
|
||||
|
||||
// Make sure the result is less than 2^256. Also prevents denominator == 0.
|
||||
require(denominator > prod1, "Math: mulDiv overflow");
|
||||
|
||||
///////////////////////////////////////////////
|
||||
// 512 by 256 division.
|
||||
///////////////////////////////////////////////
|
||||
|
||||
// Make division exact by subtracting the remainder from [prod1 prod0].
|
||||
uint256 remainder;
|
||||
assembly {
|
||||
// Compute remainder using mulmod.
|
||||
remainder := mulmod(x, y, denominator)
|
||||
|
||||
// Subtract 256 bit number from 512 bit number.
|
||||
prod1 := sub(prod1, gt(remainder, prod0))
|
||||
prod0 := sub(prod0, remainder)
|
||||
}
|
||||
|
||||
// Factor powers of two out of denominator and compute largest power of two divisor of denominator. Always >= 1.
|
||||
// See https://cs.stackexchange.com/q/138556/92363.
|
||||
|
||||
// Does not overflow because the denominator cannot be zero at this stage in the function.
|
||||
uint256 twos = denominator & (~denominator + 1);
|
||||
assembly {
|
||||
// Divide denominator by twos.
|
||||
denominator := div(denominator, twos)
|
||||
|
||||
// Divide [prod1 prod0] by twos.
|
||||
prod0 := div(prod0, twos)
|
||||
|
||||
// Flip twos such that it is 2^256 / twos. If twos is zero, then it becomes one.
|
||||
twos := add(div(sub(0, twos), twos), 1)
|
||||
}
|
||||
|
||||
// Shift in bits from prod1 into prod0.
|
||||
prod0 |= prod1 * twos;
|
||||
|
||||
// Invert denominator mod 2^256. Now that denominator is an odd number, it has an inverse modulo 2^256 such
|
||||
// that denominator * inv = 1 mod 2^256. Compute the inverse by starting with a seed that is correct for
|
||||
// four bits. That is, denominator * inv = 1 mod 2^4.
|
||||
uint256 inverse = (3 * denominator) ^ 2;
|
||||
|
||||
// Use the Newton-Raphson iteration to improve the precision. Thanks to Hensel's lifting lemma, this also works
|
||||
// in modular arithmetic, doubling the correct bits in each step.
|
||||
inverse *= 2 - denominator * inverse; // inverse mod 2^8
|
||||
inverse *= 2 - denominator * inverse; // inverse mod 2^16
|
||||
inverse *= 2 - denominator * inverse; // inverse mod 2^32
|
||||
inverse *= 2 - denominator * inverse; // inverse mod 2^64
|
||||
inverse *= 2 - denominator * inverse; // inverse mod 2^128
|
||||
inverse *= 2 - denominator * inverse; // inverse mod 2^256
|
||||
|
||||
// Because the division is now exact we can divide by multiplying with the modular inverse of denominator.
|
||||
// This will give us the correct result modulo 2^256. Since the preconditions guarantee that the outcome is
|
||||
// less than 2^256, this is the final result. We don't need to compute the high bits of the result and prod1
|
||||
// is no longer required.
|
||||
result = prod0 * inverse;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
function quantize_data(
|
||||
bytes[] memory data,
|
||||
uint256[] memory decimals,
|
||||
uint256[] memory scales
|
||||
) external pure returns (int256[] memory quantized_data) {
|
||||
quantized_data = new int256[](data.length);
|
||||
for (uint i; i < data.length; i++) {
|
||||
int x = abi.decode(data[i], (int256));
|
||||
bool neg = x < 0;
|
||||
if (neg) x = -x;
|
||||
uint denom = 10 ** decimals[i];
|
||||
uint scale = 1 << scales[i];
|
||||
uint output = mulDiv(uint256(x), scale, denom);
|
||||
if (mulmod(uint256(x), scale, denom) * 2 >= denom) {
|
||||
output += 1;
|
||||
}
|
||||
|
||||
quantized_data[i] = neg ? -int256(output) : int256(output);
|
||||
}
|
||||
}
|
||||
|
||||
function to_field_element(
|
||||
int128[] memory quantized_data
|
||||
) public pure returns (uint256[] memory output) {
|
||||
output = new uint256[](quantized_data.length);
|
||||
for (uint i; i < quantized_data.length; i++) {
|
||||
output[i] = uint256(quantized_data[i] + int(ORDER)) % ORDER;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
// SPDX-License-Identifier: UNLICENSED
|
||||
pragma solidity ^0.8.17;
|
||||
|
||||
contract TestReads {
|
||||
int[] public arr;
|
||||
|
||||
constructor(int256[] memory _numbers) {
|
||||
for (uint256 i = 0; i < _numbers.length; i++) {
|
||||
arr.push(_numbers[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
11
data.sh
11
data.sh
@@ -1,11 +0,0 @@
|
||||
#! /bin/bash
|
||||
|
||||
mkdir data
|
||||
cd data
|
||||
|
||||
wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
|
||||
wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
|
||||
wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
|
||||
wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
|
||||
|
||||
gzip -d *.gz
|
||||
@@ -42,8 +42,8 @@ const NUM_INNER_COLS: usize = 1;
|
||||
struct Config<
|
||||
const LEN: usize, //LEN = CHOUT x OH x OW flattened //not supported yet in rust stable
|
||||
const CLASSES: usize,
|
||||
const LOOKUP_MIN: i128,
|
||||
const LOOKUP_MAX: i128,
|
||||
const LOOKUP_MIN: i64,
|
||||
const LOOKUP_MAX: i64,
|
||||
// Convolution
|
||||
const KERNEL_HEIGHT: usize,
|
||||
const KERNEL_WIDTH: usize,
|
||||
@@ -66,8 +66,8 @@ struct Config<
|
||||
struct MyCircuit<
|
||||
const LEN: usize, //LEN = CHOUT x OH x OW flattened
|
||||
const CLASSES: usize,
|
||||
const LOOKUP_MIN: i128,
|
||||
const LOOKUP_MAX: i128,
|
||||
const LOOKUP_MIN: i64,
|
||||
const LOOKUP_MAX: i64,
|
||||
// Convolution
|
||||
const KERNEL_HEIGHT: usize,
|
||||
const KERNEL_WIDTH: usize,
|
||||
@@ -90,8 +90,8 @@ struct MyCircuit<
|
||||
impl<
|
||||
const LEN: usize,
|
||||
const CLASSES: usize,
|
||||
const LOOKUP_MIN: i128,
|
||||
const LOOKUP_MAX: i128,
|
||||
const LOOKUP_MIN: i64,
|
||||
const LOOKUP_MAX: i64,
|
||||
// Convolution
|
||||
const KERNEL_HEIGHT: usize,
|
||||
const KERNEL_WIDTH: usize,
|
||||
@@ -308,6 +308,7 @@ pub fn runconv() {
|
||||
tst_lbl: _,
|
||||
..
|
||||
} = MnistBuilder::new()
|
||||
.base_path("examples/data")
|
||||
.label_format_digit()
|
||||
.training_set_length(50_000)
|
||||
.validation_set_length(10_000)
|
||||
|
||||
BIN
examples/data/t10k-images-idx3-ubyte
Normal file
BIN
examples/data/t10k-images-idx3-ubyte
Normal file
Binary file not shown.
BIN
examples/data/t10k-labels-idx1-ubyte
Normal file
BIN
examples/data/t10k-labels-idx1-ubyte
Normal file
Binary file not shown.
BIN
examples/data/train-images-idx3-ubyte
Normal file
BIN
examples/data/train-images-idx3-ubyte
Normal file
Binary file not shown.
BIN
examples/data/train-labels-idx1-ubyte
Normal file
BIN
examples/data/train-labels-idx1-ubyte
Normal file
Binary file not shown.
@@ -23,8 +23,8 @@ struct MyConfig {
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit<
|
||||
const LEN: usize, //LEN = CHOUT x OH x OW flattened
|
||||
const LOOKUP_MIN: i128,
|
||||
const LOOKUP_MAX: i128,
|
||||
const LOOKUP_MIN: i64,
|
||||
const LOOKUP_MAX: i64,
|
||||
> {
|
||||
// Given the stateless MyConfig type information, a DNN trace is determined by its input and the parameters of its layers.
|
||||
// Computing the trace still requires a forward pass. The intermediate activations are stored only by the layouter.
|
||||
@@ -34,7 +34,7 @@ struct MyCircuit<
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<const LEN: usize, const LOOKUP_MIN: i128, const LOOKUP_MAX: i128> Circuit<F>
|
||||
impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
|
||||
for MyCircuit<LEN, LOOKUP_MIN, LOOKUP_MAX>
|
||||
{
|
||||
type Config = MyConfig;
|
||||
|
||||
@@ -251,7 +251,7 @@
|
||||
"with open(cal_path, \"w\") as f:\n",
|
||||
" json.dump(cal_data, f)\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -307,7 +307,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ezkl.setup_test_evm_witness(\n",
|
||||
"await ezkl.setup_test_evm_witness(\n",
|
||||
" data_path,\n",
|
||||
" compiled_model_path,\n",
|
||||
" # we write the call data to the same file as the input data\n",
|
||||
@@ -333,7 +333,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.get_srs( settings_path)\n"
|
||||
"res = await ezkl.get_srs( settings_path)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -354,7 +354,7 @@
|
||||
"\n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -462,7 +462,7 @@
|
||||
"abi_path = 'test.abi'\n",
|
||||
"sol_code_path = 'test.sol'\n",
|
||||
"\n",
|
||||
"res = ezkl.create_evm_verifier(\n",
|
||||
"res = await ezkl.create_evm_verifier(\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" settings_path,\n",
|
||||
@@ -482,7 +482,7 @@
|
||||
"\n",
|
||||
"addr_path_verifier = \"addr_verifier.txt\"\n",
|
||||
"\n",
|
||||
"res = ezkl.deploy_evm(\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" addr_path_verifier,\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
@@ -510,7 +510,7 @@
|
||||
"sol_code_path = 'test.sol'\n",
|
||||
"input_path = 'input.json'\n",
|
||||
"\n",
|
||||
"res = ezkl.create_evm_data_attestation(\n",
|
||||
"res = await ezkl.create_evm_data_attestation(\n",
|
||||
" input_path,\n",
|
||||
" settings_path,\n",
|
||||
" sol_code_path,\n",
|
||||
@@ -535,7 +535,7 @@
|
||||
"source": [
|
||||
"addr_path_da = \"addr_da.txt\"\n",
|
||||
"\n",
|
||||
"res = ezkl.deploy_da_evm(\n",
|
||||
"res = await ezkl.deploy_da_evm(\n",
|
||||
" addr_path_da,\n",
|
||||
" input_path,\n",
|
||||
" settings_path,\n",
|
||||
@@ -567,7 +567,7 @@
|
||||
"with open(addr_path_da, 'r') as f:\n",
|
||||
" addr_da = f.read()\n",
|
||||
"\n",
|
||||
"res = ezkl.verify_evm(\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" RPC_URL,\n",
|
||||
@@ -592,7 +592,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.2"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
|
||||
@@ -249,7 +249,7 @@
|
||||
"with open(cal_path, \"w\") as f:\n",
|
||||
" json.dump(cal_data, f)\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -278,7 +278,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.get_srs( settings_path)\n"
|
||||
"res = await ezkl.get_srs( settings_path)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -299,7 +299,7 @@
|
||||
"\n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
@@ -518,7 +518,7 @@
|
||||
"abi_path = 'test.abi'\n",
|
||||
"sol_code_path = 'test.sol'\n",
|
||||
"\n",
|
||||
"res = ezkl.create_evm_verifier(\n",
|
||||
"res = await ezkl.create_evm_verifier(\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" settings_path,\n",
|
||||
@@ -538,7 +538,7 @@
|
||||
"\n",
|
||||
"addr_path_verifier = \"addr_verifier.txt\"\n",
|
||||
"\n",
|
||||
"res = ezkl.deploy_evm(\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" addr_path_verifier,\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
@@ -566,7 +566,7 @@
|
||||
"sol_code_path = 'test.sol'\n",
|
||||
"input_path = 'input.json'\n",
|
||||
"\n",
|
||||
"res = ezkl.create_evm_data_attestation(\n",
|
||||
"res = await ezkl.create_evm_data_attestation(\n",
|
||||
" input_path,\n",
|
||||
" settings_path,\n",
|
||||
" sol_code_path,\n",
|
||||
@@ -591,7 +591,7 @@
|
||||
"source": [
|
||||
"addr_path_da = \"addr_da.txt\"\n",
|
||||
"\n",
|
||||
"res = ezkl.deploy_da_evm(\n",
|
||||
"res = await ezkl.deploy_da_evm(\n",
|
||||
" addr_path_da,\n",
|
||||
" input_path,\n",
|
||||
" settings_path,\n",
|
||||
@@ -623,7 +623,7 @@
|
||||
"with open(addr_path_da, 'r') as f:\n",
|
||||
" addr_da = f.read()\n",
|
||||
"\n",
|
||||
"res = ezkl.verify_evm(\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" RPC_URL,\n",
|
||||
@@ -654,4 +654,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
@@ -150,7 +150,7 @@
|
||||
"res = ezkl.gen_settings(model_path, settings_path)\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
@@ -170,7 +170,7 @@
|
||||
"with open(cal_path, \"w\") as f:\n",
|
||||
" json.dump(cal_data, f)\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -192,7 +192,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -204,7 +204,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -303,4 +303,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -352,14 +352,8 @@
|
||||
"# Specify all the files we need\n",
|
||||
"\n",
|
||||
"model_path = os.path.join('network.onnx')\n",
|
||||
"compiled_model_path = os.path.join('network.ezkl')\n",
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"settings_path = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"witness_path = os.path.join('witness.json')\n",
|
||||
"data_path = os.path.join('input.json')\n",
|
||||
"cal_data_path = os.path.join('cal_data.json')"
|
||||
"cal_data_path = os.path.join('calibration.json')"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -424,7 +418,7 @@
|
||||
"source": [
|
||||
"!RUST_LOG=trace\n",
|
||||
"# TODO: Dictionary outputs\n",
|
||||
"res = ezkl.gen_settings(model_path, settings_path)\n",
|
||||
"res = ezkl.gen_settings()\n",
|
||||
"assert res == True\n",
|
||||
"\n"
|
||||
]
|
||||
@@ -443,7 +437,7 @@
|
||||
"\n",
|
||||
"# Optimize for resources, we cap logrows at 12 to reduce setup and proving time, at the expense of accuracy\n",
|
||||
"# You may want to increase the max logrows if accuracy is a concern\n",
|
||||
"res = ezkl.calibrate_settings(cal_data_path, model_path, settings_path, \"resources\", max_logrows = 12, scales = [2])"
|
||||
"res = await ezkl.calibrate_settings(target = \"resources\", max_logrows = 12, scales = [2])"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -463,7 +457,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
|
||||
"res = ezkl.compile_circuit()\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
@@ -484,7 +478,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -504,17 +498,10 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_model_path,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path,\n",
|
||||
" )\n",
|
||||
"res = ezkl.setup()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_path)"
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -539,7 +526,7 @@
|
||||
"# now generate the witness file\n",
|
||||
"witness_path = os.path.join('witness.json')\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness()\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -559,13 +546,7 @@
|
||||
"\n",
|
||||
"proof_path = os.path.join('proof.json')\n",
|
||||
"\n",
|
||||
"proof = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_model_path,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \"single\",\n",
|
||||
" )\n",
|
||||
"proof = ezkl.prove(proof_type=\"single\", proof_path=proof_path)\n",
|
||||
"\n",
|
||||
"print(proof)\n",
|
||||
"assert os.path.isfile(proof_path)"
|
||||
@@ -585,11 +566,7 @@
|
||||
"source": [
|
||||
"# verify our proof\n",
|
||||
"\n",
|
||||
"res = ezkl.verify(\n",
|
||||
" proof_path,\n",
|
||||
" settings_path,\n",
|
||||
" vk_path,\n",
|
||||
" )\n",
|
||||
"res = ezkl.verify()\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")"
|
||||
@@ -664,12 +641,9 @@
|
||||
"sol_code_path = os.path.join('Verifier.sol')\n",
|
||||
"abi_path = os.path.join('Verifier.abi')\n",
|
||||
"\n",
|
||||
"res = ezkl.create_evm_verifier(\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" settings_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" abi_path\n",
|
||||
"res = await ezkl.create_evm_verifier(\n",
|
||||
" sol_code_path=sol_code_path,\n",
|
||||
" abi_path=abi_path, \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
@@ -757,7 +731,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -467,7 +467,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
@@ -494,7 +494,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -508,7 +508,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -625,4 +625,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -195,7 +195,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
@@ -222,7 +222,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -236,7 +236,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -179,7 +179,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
@@ -202,7 +202,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -214,7 +214,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -313,4 +313,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -241,7 +241,7 @@
|
||||
"with open(cal_path, \"w\") as f:\n",
|
||||
" json.dump(cal_data, f)\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -270,7 +270,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.get_srs( settings_path)\n"
|
||||
"res = await ezkl.get_srs( settings_path)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -291,7 +291,7 @@
|
||||
"\n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -420,7 +420,7 @@
|
||||
"abi_path = 'test.abi'\n",
|
||||
"sol_code_path = 'test.sol'\n",
|
||||
"\n",
|
||||
"res = ezkl.create_evm_verifier(\n",
|
||||
"res = await ezkl.create_evm_verifier(\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" settings_path,\n",
|
||||
@@ -451,7 +451,7 @@
|
||||
"\n",
|
||||
"address_path = os.path.join(\"address.json\")\n",
|
||||
"\n",
|
||||
"res = ezkl.deploy_evm(\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" address_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
@@ -472,7 +472,7 @@
|
||||
"# make sure anvil is running locally\n",
|
||||
"# $ anvil -p 3030\n",
|
||||
"\n",
|
||||
"res = ezkl.verify_evm(\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" \"http://127.0.0.1:3030\"\n",
|
||||
|
||||
@@ -152,7 +152,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
@@ -175,7 +175,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs(settings_path = settings_path)"
|
||||
"res = await ezkl.get_srs(settings_path = settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -188,7 +188,7 @@
|
||||
"# now generate the witness file \n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -284,4 +284,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -155,7 +155,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
@@ -178,7 +178,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -190,7 +190,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -289,4 +289,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -233,7 +233,7 @@
|
||||
"with open(cal_path, \"w\") as f:\n",
|
||||
" json.dump(cal_data, f)\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -262,7 +262,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.get_srs( settings_path)\n"
|
||||
"res = await ezkl.get_srs( settings_path)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -315,7 +315,7 @@
|
||||
"\n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n"
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -429,7 +429,7 @@
|
||||
"abi_path = 'test.abi'\n",
|
||||
"sol_code_path = 'test.sol'\n",
|
||||
"\n",
|
||||
"res = ezkl.create_evm_verifier(\n",
|
||||
"res = await ezkl.create_evm_verifier(\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" settings_path,\n",
|
||||
@@ -460,7 +460,7 @@
|
||||
"\n",
|
||||
"address_path = os.path.join(\"address.json\")\n",
|
||||
"\n",
|
||||
"res = ezkl.deploy_evm(\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" address_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
@@ -481,7 +481,7 @@
|
||||
"# make sure anvil is running locally\n",
|
||||
"# $ anvil -p 3030\n",
|
||||
"\n",
|
||||
"res = ezkl.verify_evm(\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" \"http://127.0.0.1:3030\"\n",
|
||||
|
||||
@@ -193,7 +193,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
@@ -216,7 +216,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -228,7 +228,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -347,4 +347,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -142,7 +142,7 @@
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
@@ -165,7 +165,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -177,7 +177,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -276,4 +276,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -347,7 +347,7 @@
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
@@ -370,7 +370,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -383,7 +383,7 @@
|
||||
"# now generate the witness file \n",
|
||||
"witness_path = \"gan_witness.json\"\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -490,4 +490,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
279
examples/notebooks/logistic_regression.ipynb
Normal file
279
examples/notebooks/logistic_regression.ipynb
Normal file
@@ -0,0 +1,279 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Logistic Regression\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Sklearn based models are slightly finicky to get into a suitable onnx format. \n",
|
||||
"This notebook showcases how to do so using the `hummingbird-ml` python package for a Logistic Regression model. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "95613ee9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"hummingbird-ml\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import torch\n",
|
||||
"import ezkl\n",
|
||||
"import json\n",
|
||||
"from hummingbird.ml import convert\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# here we create and (potentially train a model)\n",
|
||||
"\n",
|
||||
"# make sure you have the dependencies required here already installed\n",
|
||||
"import numpy as np\n",
|
||||
"from sklearn.linear_model import LogisticRegression\n",
|
||||
"X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])\n",
|
||||
"# y = 1 * x_0 + 2 * x_1 + 3\n",
|
||||
"y = np.dot(X, np.array([1, 2])) + 3\n",
|
||||
"reg = LogisticRegression().fit(X, y)\n",
|
||||
"reg.score(X, y)\n",
|
||||
"\n",
|
||||
"circuit = convert(reg, \"torch\", X[:1]).model\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b37637c4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_path = os.path.join('network.onnx')\n",
|
||||
"compiled_model_path = os.path.join('network.compiled')\n",
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"settings_path = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"witness_path = os.path.join('witness.json')\n",
|
||||
"data_path = os.path.join('input.json')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "82db373a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"# export to onnx format\n",
|
||||
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
|
||||
"\n",
|
||||
"# Input to the model\n",
|
||||
"shape = X.shape[1:]\n",
|
||||
"x = torch.rand(1, *shape, requires_grad=True)\n",
|
||||
"torch_out = circuit(x)\n",
|
||||
"# Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" # model input (or a tuple for multiple inputs)\n",
|
||||
" x,\n",
|
||||
" # where to save the model (can be a file or file-like object)\n",
|
||||
" \"network.onnx\",\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=10, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names=['input'], # the model's input names\n",
|
||||
" output_names=['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
|
||||
" 'output': {0: 'batch_size'}})\n",
|
||||
"\n",
|
||||
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_shapes=[shape],\n",
|
||||
" input_data=[d],\n",
|
||||
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
|
||||
"\n",
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(\"input.json\", 'w'))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d5e374a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!RUST_LOG=trace\n",
|
||||
"# TODO: Dictionary outputs\n",
|
||||
"res = ezkl.gen_settings(model_path, settings_path)\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cal_path = os.path.join(\"calibration.json\")\n",
|
||||
"\n",
|
||||
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data = dict(input_data = [data_array])\n",
|
||||
"\n",
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3aa4f090",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8b74dcee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "18c8b7c7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b1c561a8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
|
||||
"# WE GOT KEYS\n",
|
||||
"# WE GOT CIRCUIT PARAMETERS\n",
|
||||
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_model_path,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c384cbc8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# GENERATE A PROOF\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"res = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_model_path,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \n",
|
||||
" \"single\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"print(res)\n",
|
||||
"assert os.path.isfile(proof_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "76f00d41",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# VERIFY IT\n",
|
||||
"\n",
|
||||
"res = ezkl.verify(\n",
|
||||
" proof_path,\n",
|
||||
" settings_path,\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -139,7 +139,7 @@
|
||||
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
@@ -180,7 +180,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -193,7 +193,7 @@
|
||||
"# now generate the witness file \n",
|
||||
"witness_path = \"lstmwitness.json\"\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -300,4 +300,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -91,11 +91,14 @@
|
||||
"os.system(\"echo shovel is now installed. starting:\")\n",
|
||||
"\n",
|
||||
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
|
||||
"subprocess.Popen(command)\n",
|
||||
"proc = subprocess.Popen(command)\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel started.\")\n",
|
||||
"\n",
|
||||
"time.sleep(5)\n",
|
||||
"time.sleep(10)\n",
|
||||
"\n",
|
||||
"# after we've fetched some data -- kill the process\n",
|
||||
"proc.terminate()\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
@@ -310,7 +313,7 @@
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
|
||||
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
|
||||
"\n",
|
||||
"assert res == True"
|
||||
]
|
||||
@@ -387,7 +390,7 @@
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"# generate the witness\n",
|
||||
"res = ezkl.gen_witness(\n",
|
||||
"res = await ezkl.gen_witness(\n",
|
||||
" input_filename,\n",
|
||||
" compiled_filename,\n",
|
||||
" witness_path\n",
|
||||
@@ -425,16 +428,6 @@
|
||||
"assert os.path.isfile(proof_path)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# kill all shovel process \n",
|
||||
"os.system(\"pkill -f shovel\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -323,7 +323,7 @@
|
||||
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[2,7])\n",
|
||||
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[2,7])\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
@@ -348,7 +348,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs(settings_path)"
|
||||
"res = await ezkl.get_srs(settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -362,7 +362,7 @@
|
||||
"# now generate the witness file\n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -469,7 +469,7 @@
|
||||
"abi_path = 'test.abi'\n",
|
||||
"sol_code_path = 'test_1.sol'\n",
|
||||
"\n",
|
||||
"res = ezkl.create_evm_verifier(\n",
|
||||
"res = await ezkl.create_evm_verifier(\n",
|
||||
" vk_path,\n",
|
||||
" settings_path,\n",
|
||||
" sol_code_path,\n",
|
||||
@@ -502,7 +502,7 @@
|
||||
"\n",
|
||||
"address_path = os.path.join(\"address.json\")\n",
|
||||
"\n",
|
||||
"res = ezkl.deploy_evm(\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" address_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
@@ -525,7 +525,7 @@
|
||||
"# make sure anvil is running locally\n",
|
||||
"# $ anvil -p 3030\n",
|
||||
"\n",
|
||||
"res = ezkl.verify_evm(\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" \"http://127.0.0.1:3030\"\n",
|
||||
@@ -558,4 +558,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
}
|
||||
@@ -289,7 +289,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[0,6])"
|
||||
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[0,6])"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -309,7 +309,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -321,7 +321,7 @@
|
||||
"# now generate the witness file \n",
|
||||
"witness_path = \"gan_witness.json\"\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -215,7 +215,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -235,7 +235,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -247,7 +247,7 @@
|
||||
"# now generate the witness file\n",
|
||||
"witness_path = \"ae_witness.json\"\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -451,7 +451,7 @@
|
||||
"res = ezkl.gen_settings(model_path, settings_path)\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")"
|
||||
]
|
||||
@@ -473,7 +473,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -485,7 +485,7 @@
|
||||
"# now generate the witness file \n",
|
||||
"witness_path = \"vae_witness.json\"\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -845,7 +845,7 @@
|
||||
"res = ezkl.gen_settings(model_path, settings_path)\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", max_logrows = 20, scales = [3])\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", max_logrows = 20, scales = [3])\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
@@ -870,7 +870,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -881,7 +881,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -993,4 +993,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
}
|
||||
@@ -261,7 +261,7 @@
|
||||
"source": [
|
||||
"# iterate over each submodel gen-settings, compile circuit and setup zkSNARK\n",
|
||||
"\n",
|
||||
"def setup(i):\n",
|
||||
"async def setup(i):\n",
|
||||
" # file names\n",
|
||||
" model_path = os.path.join('network_split_'+str(i)+'.onnx')\n",
|
||||
" settings_path = os.path.join('settings_split_'+str(i)+'.json')\n",
|
||||
@@ -282,7 +282,7 @@
|
||||
"\n",
|
||||
" # generate settings for the current model\n",
|
||||
" res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
|
||||
" res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", scales=[run_args.input_scale], max_logrows=run_args.logrows)\n",
|
||||
" res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", scales=[run_args.input_scale], max_logrows=run_args.logrows)\n",
|
||||
" assert res == True\n",
|
||||
"\n",
|
||||
" # load settings and print them to the console\n",
|
||||
@@ -303,11 +303,11 @@
|
||||
" assert os.path.isfile(vk_path)\n",
|
||||
" assert os.path.isfile(pk_path)\n",
|
||||
"\n",
|
||||
" res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
|
||||
" res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
|
||||
" run_args.input_scale = settings[\"model_output_scales\"][0]\n",
|
||||
"\n",
|
||||
"for i in range(2):\n",
|
||||
" setup(i)\n"
|
||||
" await setup(i)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -414,7 +414,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for i in range(2):\n",
|
||||
" setup(i)"
|
||||
" await setup(i)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -466,7 +466,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.2"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
|
||||
@@ -174,7 +174,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -196,7 +196,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -208,7 +208,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -215,7 +215,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -229,7 +229,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -265,7 +265,7 @@
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump( data, open(data_path_faulty, 'w' ))\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path_faulty)\n",
|
||||
"res = await ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path_faulty)\n",
|
||||
"assert os.path.isfile(witness_path_faulty)"
|
||||
]
|
||||
},
|
||||
@@ -310,7 +310,7 @@
|
||||
"# Serialize data into file:\n",
|
||||
"json.dump( data, open(data_path_truthy, 'w' ))\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path_truthy, compiled_model_path, witness_path_truthy)\n",
|
||||
"res = await ezkl.gen_witness(data_path_truthy, compiled_model_path, witness_path_truthy)\n",
|
||||
"assert os.path.isfile(witness_path_truthy)"
|
||||
]
|
||||
},
|
||||
@@ -519,4 +519,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -193,7 +193,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -205,7 +205,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -290,7 +290,7 @@
|
||||
"source": [
|
||||
"# Generate a larger SRS. This is needed for the aggregated proof\n",
|
||||
"\n",
|
||||
"res = ezkl.get_srs(settings_path=None, logrows=21, commitment=ezkl.PyCommitments.KZG)"
|
||||
"res = await ezkl.get_srs(settings_path=None, logrows=21, commitment=ezkl.PyCommitments.KZG)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -374,7 +374,7 @@
|
||||
"sol_code_path = os.path.join(\"Verifier.sol\")\n",
|
||||
"abi_path = os.path.join(\"Verifier_ABI.json\")\n",
|
||||
"\n",
|
||||
"res = ezkl.create_evm_verifier_aggr(\n",
|
||||
"res = await ezkl.create_evm_verifier_aggr(\n",
|
||||
" [settings_path],\n",
|
||||
" aggregate_vk_path,\n",
|
||||
" sol_code_path,\n",
|
||||
@@ -404,4 +404,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -170,7 +170,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -192,7 +192,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -204,7 +204,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -191,7 +191,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -203,7 +203,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -302,4 +302,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -192,7 +192,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -204,7 +204,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -149,7 +149,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -171,7 +171,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -183,7 +183,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -282,4 +282,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -250,7 +250,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -297,7 +297,7 @@
|
||||
"\n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)\n",
|
||||
"assert os.path.isfile(witness_path)\n",
|
||||
"\n",
|
||||
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",
|
||||
@@ -411,7 +411,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path, vk_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path, vk_path)\n",
|
||||
"assert os.path.isfile(witness_path)\n",
|
||||
"\n",
|
||||
"# we force the output to be 1 this corresponds to the solvency test being true -- and we set this to a fixed vis output\n",
|
||||
|
||||
@@ -167,7 +167,7 @@
|
||||
"res = ezkl.gen_settings(model_path, settings_path)\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
@@ -187,7 +187,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -209,7 +209,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -221,7 +221,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -320,4 +320,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -180,7 +180,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -202,7 +202,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -214,7 +214,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -420,7 +420,7 @@
|
||||
"res = ezkl.gen_settings(model_path, settings_path)\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"res = await ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True"
|
||||
]
|
||||
}
|
||||
@@ -446,4 +446,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -57,7 +57,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -119,7 +119,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -163,7 +163,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -217,7 +217,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -637,7 +637,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [11])"
|
||||
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [11])"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -646,7 +646,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ezkl.get_srs( settings_path)"
|
||||
"await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -683,7 +683,7 @@
|
||||
" data = json.load(f)\n",
|
||||
" print(len(data['input_data'][0]))\n",
|
||||
"\n",
|
||||
"ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
|
||||
"await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -525,7 +525,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
|
||||
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -572,7 +572,7 @@
|
||||
" data = json.load(f)\n",
|
||||
" print(len(data['input_data'][0]))\n",
|
||||
"\n",
|
||||
"ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
|
||||
"await ezkl.gen_witness(data_path, compiled_model_path, witness_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"id": "9Byiv2Nc2MsK"
|
||||
},
|
||||
@@ -49,7 +49,11 @@
|
||||
"import pandas as pd\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"import os"
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import logging\n",
|
||||
"\n",
|
||||
"logging.basicConfig(level=logging.INFO)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -63,7 +67,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
@@ -71,7 +75,15 @@
|
||||
"id": "x1vl9ZXF3EEW",
|
||||
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"cpu\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from torch import nn\n",
|
||||
"import torch\n",
|
||||
@@ -133,7 +145,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
@@ -141,7 +153,18 @@
|
||||
"id": "6RAMplxk5xPk",
|
||||
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1715422870\n",
|
||||
"1714818070\n",
|
||||
"https://api.coingecko.com/api/v3/coins/ethereum/market_chart/range?vs_currency=usd&from=1714818070&to=1715422870\n",
|
||||
"<Response [200]>\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"\n",
|
||||
"def get_url(coin, currency, start, end):\n",
|
||||
@@ -174,7 +197,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@@ -183,7 +206,115 @@
|
||||
"id": "WSj1Uxln65vf",
|
||||
"outputId": "51422d71-9680-4b51-c4df-e400d20f988b"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>time</th>\n",
|
||||
" <th>prices</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>1714820485367</td>\n",
|
||||
" <td>3146.785806</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>1714824033868</td>\n",
|
||||
" <td>3127.968728</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>1714828058243</td>\n",
|
||||
" <td>3156.141681</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>1714831650751</td>\n",
|
||||
" <td>3124.834064</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>1714834972229</td>\n",
|
||||
" <td>3133.115333</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>163</th>\n",
|
||||
" <td>1715407579346</td>\n",
|
||||
" <td>2918.049749</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>164</th>\n",
|
||||
" <td>1715411090715</td>\n",
|
||||
" <td>2920.330834</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>165</th>\n",
|
||||
" <td>1715414554830</td>\n",
|
||||
" <td>2923.986611</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>166</th>\n",
|
||||
" <td>1715418419843</td>\n",
|
||||
" <td>2910.537671</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>167</th>\n",
|
||||
" <td>1715421675338</td>\n",
|
||||
" <td>2907.702307</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>168 rows × 2 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" time prices\n",
|
||||
"0 1714820485367 3146.785806\n",
|
||||
"1 1714824033868 3127.968728\n",
|
||||
"2 1714828058243 3156.141681\n",
|
||||
"3 1714831650751 3124.834064\n",
|
||||
"4 1714834972229 3133.115333\n",
|
||||
".. ... ...\n",
|
||||
"163 1715407579346 2918.049749\n",
|
||||
"164 1715411090715 2920.330834\n",
|
||||
"165 1715414554830 2923.986611\n",
|
||||
"166 1715418419843 2910.537671\n",
|
||||
"167 1715421675338 2907.702307\n",
|
||||
"\n",
|
||||
"[168 rows x 2 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df = pd.DataFrame(new_data)\n",
|
||||
"df\n"
|
||||
@@ -200,7 +331,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -217,7 +348,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
@@ -225,7 +356,98 @@
|
||||
"id": "4MmE9SX66_Il",
|
||||
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:ezkl.execute:SRS already exists at that path\n",
|
||||
"INFO:ezkl.execute:num calibration batches: 1\n",
|
||||
"INFO:ezkl.execute:read 16777476 bytes from file (vector of len = 16777476)\n",
|
||||
"WARNING:ezkl.execute:\n",
|
||||
"\n",
|
||||
" <------------- Numerical Fidelity Report (input_scale: 4, param_scale: 4, scale_input_multiplier: 10) ------------->\n",
|
||||
"\n",
|
||||
"+------------+--------------+-----------+-----------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+\n",
|
||||
"| mean_error | median_error | max_error | min_error | mean_abs_error | median_abs_error | max_abs_error | min_abs_error | mean_squared_error | mean_percent_error | mean_abs_percent_error |\n",
|
||||
"+------------+--------------+-----------+-----------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+\n",
|
||||
"| -727.9929 | -727.9929 | -727.9929 | -727.9929 | 727.9929 | 727.9929 | 727.9929 | 727.9929 | 529973.7 | -0.24999964 | 0.24999964 |\n",
|
||||
"+------------+--------------+-----------+-----------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"INFO:ezkl.execute:file hash: 41509f380362a8d14401c5ae92073154922fe23e45459ce6f696f58607655db7\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{\n",
|
||||
" \"run_args\": {\n",
|
||||
" \"tolerance\": {\n",
|
||||
" \"val\": 0.0,\n",
|
||||
" \"scale\": 1.0\n",
|
||||
" },\n",
|
||||
" \"input_scale\": 4,\n",
|
||||
" \"param_scale\": 4,\n",
|
||||
" \"scale_rebase_multiplier\": 10,\n",
|
||||
" \"lookup_range\": [\n",
|
||||
" 0,\n",
|
||||
" 0\n",
|
||||
" ],\n",
|
||||
" \"logrows\": 6,\n",
|
||||
" \"num_inner_cols\": 2,\n",
|
||||
" \"variables\": [\n",
|
||||
" [\n",
|
||||
" \"batch_size\",\n",
|
||||
" 1\n",
|
||||
" ]\n",
|
||||
" ],\n",
|
||||
" \"input_visibility\": \"Private\",\n",
|
||||
" \"output_visibility\": \"Public\",\n",
|
||||
" \"param_visibility\": \"Private\",\n",
|
||||
" \"div_rebasing\": false,\n",
|
||||
" \"rebase_frac_zero_constants\": false,\n",
|
||||
" \"check_mode\": \"UNSAFE\",\n",
|
||||
" \"commitment\": \"KZG\"\n",
|
||||
" },\n",
|
||||
" \"num_rows\": 21,\n",
|
||||
" \"total_assignments\": 42,\n",
|
||||
" \"total_const_size\": 0,\n",
|
||||
" \"total_dynamic_col_size\": 0,\n",
|
||||
" \"num_dynamic_lookups\": 0,\n",
|
||||
" \"num_shuffles\": 0,\n",
|
||||
" \"total_shuffle_col_size\": 0,\n",
|
||||
" \"model_instance_shapes\": [\n",
|
||||
" [\n",
|
||||
" 1\n",
|
||||
" ]\n",
|
||||
" ],\n",
|
||||
" \"model_output_scales\": [\n",
|
||||
" 8\n",
|
||||
" ],\n",
|
||||
" \"model_input_scales\": [\n",
|
||||
" 4\n",
|
||||
" ],\n",
|
||||
" \"module_sizes\": {\n",
|
||||
" \"polycommit\": [],\n",
|
||||
" \"poseidon\": [\n",
|
||||
" 0,\n",
|
||||
" [\n",
|
||||
" 0\n",
|
||||
" ]\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"required_lookups\": [],\n",
|
||||
" \"required_range_checks\": [],\n",
|
||||
" \"check_mode\": \"UNSAFE\",\n",
|
||||
" \"version\": \"0.0.0\",\n",
|
||||
" \"num_blinding_factors\": null,\n",
|
||||
" \"timestamp\": 1715422871248\n",
|
||||
"}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# generate settings\n",
|
||||
"onnx_filename = os.path.join('lol.onnx')\n",
|
||||
@@ -236,9 +458,9 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.gen_settings(onnx_filename, settings_filename)\n",
|
||||
"ezkl.calibrate_settings(\n",
|
||||
"await ezkl.calibrate_settings(\n",
|
||||
" input_filename, onnx_filename, settings_filename, \"resources\", scales = [4])\n",
|
||||
"res = ezkl.get_srs(settings_filename)\n",
|
||||
"res = await ezkl.get_srs(settings_filename)\n",
|
||||
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)\n",
|
||||
"\n",
|
||||
"# show the settings.json\n",
|
||||
@@ -259,7 +481,24 @@
|
||||
"metadata": {
|
||||
"id": "fULvvnK7_CMb"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:ezkl.pfsys.srs:loading srs from \"/Users/dante/.ezkl/srs/kzg6.srs\"\n",
|
||||
"INFO:ezkl.execute:downsizing params to 6 logrows\n",
|
||||
"INFO:ezkl.graph.model:model layout...\n",
|
||||
"INFO:ezkl.pfsys:VK took 0.8\n",
|
||||
"INFO:ezkl.graph.model:model layout...\n",
|
||||
"INFO:ezkl.pfsys:PK took 0.2\n",
|
||||
"INFO:ezkl.pfsys:saving verification key 💾\n",
|
||||
"INFO:ezkl.pfsys:done saving verification key ✅\n",
|
||||
"INFO:ezkl.pfsys:saving proving key 💾\n",
|
||||
"INFO:ezkl.pfsys:done saving proving key ✅\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
@@ -281,20 +520,20 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(input_filename, compiled_filename, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
@@ -302,7 +541,34 @@
|
||||
"id": "Oog3j6Kd-Wed",
|
||||
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:ezkl.pfsys:loading proving key from \"test.pk\"\n",
|
||||
"INFO:ezkl.pfsys:done loading proving key ✅\n",
|
||||
"INFO:ezkl.pfsys.srs:loading srs from \"/Users/dante/.ezkl/srs/kzg6.srs\"\n",
|
||||
"INFO:ezkl.execute:downsizing params to 6 logrows\n",
|
||||
"INFO:ezkl.pfsys:proof started...\n",
|
||||
"INFO:ezkl.graph.model:model layout...\n",
|
||||
"INFO:ezkl.pfsys:proof took 0.15\n",
|
||||
"INFO:ezkl.pfsys.srs:loading srs from \"/Users/dante/.ezkl/srs/kzg6.srs\"\n",
|
||||
"INFO:ezkl.execute:downsizing params to 6 logrows\n",
|
||||
"INFO:ezkl.pfsys:loading verification key from \"test.vk\"\n",
|
||||
"INFO:ezkl.pfsys:done loading verification key ✅\n",
|
||||
"INFO:ezkl.execute:verify took 0.2\n",
|
||||
"INFO:ezkl.execute:verified: true\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"verified\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# prove the zk circuit\n",
|
||||
"# GENERATE A PROOF\n",
|
||||
@@ -351,7 +617,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@@ -360,7 +626,26 @@
|
||||
"id": "fodkNgwS70FM",
|
||||
"outputId": "827b5efd-f74f-44de-c114-861b3a86daf2"
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:ezkl.pfsys.srs:loading srs from \"/Users/dante/.ezkl/srs/kzg6.srs\"\n",
|
||||
"INFO:ezkl.execute:downsizing params to 6 logrows\n",
|
||||
"INFO:ezkl.pfsys:loading verification key from \"test.vk\"\n",
|
||||
"INFO:ezkl.pfsys:done loading verification key ✅\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"test.vk\n",
|
||||
"settings.json\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# first we need to create evm verifier\n",
|
||||
"print(vk_path)\n",
|
||||
@@ -370,7 +655,7 @@
|
||||
"abi_path = 'test.abi'\n",
|
||||
"sol_code_path = 'test.sol'\n",
|
||||
"\n",
|
||||
"res = ezkl.create_evm_verifier(\n",
|
||||
"res = await ezkl.create_evm_verifier(\n",
|
||||
" vk_path,\n",
|
||||
" settings_filename,\n",
|
||||
" sol_code_path,\n",
|
||||
@@ -381,9 +666,18 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:ezkl.eth:using chain 31337\n",
|
||||
"INFO:ezkl.execute:Contract deployed at: 0x998abeb3e57409262ae5b751f60747921b33613e\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Make sure anvil is running locally first\n",
|
||||
"# run with $ anvil -p 3030\n",
|
||||
@@ -391,8 +685,9 @@
|
||||
"import json\n",
|
||||
"\n",
|
||||
"address_path = os.path.join(\"address.json\")\n",
|
||||
"\n",
|
||||
"res = ezkl.deploy_evm(\n",
|
||||
"sol_code_path = 'test.sol'\n",
|
||||
"# await\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" address_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
@@ -406,16 +701,26 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:ezkl.eth:using chain 31337\n",
|
||||
"INFO:ezkl.eth:estimated verify gas cost: 399775\n",
|
||||
"INFO:ezkl.execute:Solidity verification result: true\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# read the address from addr_path\n",
|
||||
"addr = None\n",
|
||||
"with open(address_path, 'r') as f:\n",
|
||||
" addr = f.read()\n",
|
||||
"\n",
|
||||
"res = ezkl.verify_evm(\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" \"http://127.0.0.1:3030\"\n",
|
||||
@@ -451,7 +756,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
1
examples/notebooks/verifier_abi.json
Normal file
1
examples/notebooks/verifier_abi.json
Normal file
@@ -0,0 +1 @@
|
||||
[{"type":"function","name":"verifyProof","inputs":[{"name":"proof","type":"bytes","internalType":"bytes"},{"name":"instances","type":"uint256[]","internalType":"uint256[]"}],"outputs":[{"name":"","type":"bool","internalType":"bool"}],"stateMutability":"nonpayable"}]
|
||||
@@ -629,7 +629,7 @@
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\", scales = [4])\n",
|
||||
"res = await ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\", scales = [4])\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")\n"
|
||||
]
|
||||
@@ -660,7 +660,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.get_srs(settings_path)"
|
||||
"res = await ezkl.get_srs(settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -680,7 +680,7 @@
|
||||
"\n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -807,7 +807,7 @@
|
||||
"settings_path = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"res = ezkl.create_evm_verifier(\n",
|
||||
"res = await ezkl.create_evm_verifier(\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" settings_path,\n",
|
||||
@@ -847,7 +847,7 @@
|
||||
"\n",
|
||||
"address_path = os.path.join(\"address.json\")\n",
|
||||
"\n",
|
||||
"res = ezkl.deploy_evm(\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" address_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
@@ -868,7 +868,7 @@
|
||||
"# make sure anvil is running locally\n",
|
||||
"# $ anvil -p 3030\n",
|
||||
"\n",
|
||||
"res = ezkl.verify_evm(\n",
|
||||
"res = await ezkl.verify_evm(\n",
|
||||
" addr,\n",
|
||||
" proof_path,\n",
|
||||
" \"http://127.0.0.1:3030\"\n",
|
||||
@@ -905,4 +905,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
@@ -242,6 +242,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2007dc77",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -257,6 +258,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ab993958",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"As we use Halo2 with KZG-commitments we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS given the calibrated settings file from [here](https://github.com/han0110/halo2-kzg-srs). \n",
|
||||
@@ -272,7 +274,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -284,12 +286,13 @@
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"witness = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"witness = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ad58432e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
|
||||
@@ -317,6 +320,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1746c8d1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can now create an EVM verifier contract from our circuit. This contract will be deployed to the chain we are using. In this case we are using a local anvil instance."
|
||||
@@ -325,15 +329,15 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d1920c0f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"abi_path = 'test.abi'\n",
|
||||
"sol_code_path = 'test.sol'\n",
|
||||
"\n",
|
||||
"res = ezkl.create_evm_verifier(\n",
|
||||
"res = await ezkl.create_evm_verifier(\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" settings_path,\n",
|
||||
" sol_code_path,\n",
|
||||
" abi_path,\n",
|
||||
@@ -344,6 +348,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0fd7f22b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -351,7 +356,7 @@
|
||||
"\n",
|
||||
"addr_path_verifier = \"addr_verifier.txt\"\n",
|
||||
"\n",
|
||||
"res = ezkl.deploy_evm(\n",
|
||||
"res = await ezkl.deploy_evm(\n",
|
||||
" addr_path_verifier,\n",
|
||||
" sol_code_path,\n",
|
||||
" 'http://127.0.0.1:3030'\n",
|
||||
@@ -362,6 +367,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9c0dffab",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"With the vanilla verifier deployed, we can now create the data attestation contract, which will read in the instances from the calldata to the verifier, attest to them, call the verifier and then return the result. \n",
|
||||
@@ -371,6 +377,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cc888848",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
@@ -378,6 +385,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c2db14d7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -385,7 +393,7 @@
|
||||
"sol_code_path = 'test.sol'\n",
|
||||
"input_path = 'input.json'\n",
|
||||
"\n",
|
||||
"res = ezkl.create_evm_data_attestation(\n",
|
||||
"res = await ezkl.create_evm_data_attestation(\n",
|
||||
" input_path,\n",
|
||||
" settings_path,\n",
|
||||
" sol_code_path,\n",
|
||||
@@ -396,12 +404,13 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5a018ba6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"addr_path_da = \"addr_da.txt\"\n",
|
||||
"\n",
|
||||
"res = ezkl.deploy_da_evm(\n",
|
||||
"res = await ezkl.deploy_da_evm(\n",
|
||||
" addr_path_da,\n",
|
||||
" input_path,\n",
|
||||
" settings_path,\n",
|
||||
@@ -412,6 +421,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2adad845",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we can pull in the data from the contract and calculate a new set of coordinates. We then rotate the world by 1 transform and submit the proof to the contract. The contract could then update the world rotation (logic not inserted here). For demo purposes we do this repeatedly, rotating the world by 1 transform. "
|
||||
@@ -444,6 +454,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "90eda56e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Call the view only verify method on the contract to verify the proof. Since it is a view function this is safe to use in production since you don't have to pass your private key."
|
||||
@@ -528,7 +539,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -193,7 +193,7 @@
|
||||
"with open(cal_path, \"w\") as f:\n",
|
||||
" json.dump(cal_data, f)\n",
|
||||
"\n",
|
||||
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -215,7 +215,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = ezkl.get_srs( settings_path)"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -227,7 +227,7 @@
|
||||
"source": [
|
||||
"# now generate the witness file \n",
|
||||
"\n",
|
||||
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"assert os.path.isfile(witness_path)"
|
||||
]
|
||||
},
|
||||
@@ -346,4 +346,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
9
examples/onnx/lstm_medium/input.json
Normal file
9
examples/onnx/lstm_medium/input.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"input_data": [
|
||||
[
|
||||
1.514470100402832, 1.519423007965088, 1.5182757377624512,
|
||||
1.5262789726257324, 1.5298409461975098
|
||||
]
|
||||
],
|
||||
"output_data": [[-0.1862019]]
|
||||
}
|
||||
BIN
examples/onnx/lstm_medium/network.onnx
Normal file
BIN
examples/onnx/lstm_medium/network.onnx
Normal file
Binary file not shown.
@@ -104,5 +104,5 @@ json.dump(data, open("input.json", 'w'))
|
||||
# ezkl.gen_settings("network.onnx", "settings.json")
|
||||
|
||||
# !RUST_LOG = full
|
||||
# res = ezkl.calibrate_settings(
|
||||
# res = await ezkl.calibrate_settings(
|
||||
# "input.json", "network.onnx", "settings.json", "resources")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@ezkljs/verify",
|
||||
"version": "0.0.0",
|
||||
"version": "v10.4.2",
|
||||
"publishConfig": {
|
||||
"access": "public"
|
||||
},
|
||||
@@ -27,7 +27,7 @@
|
||||
"@ethereumjs/util": "9.0.0",
|
||||
"@ethereumjs/vm": "7.0.0",
|
||||
"@ethersproject/abi": "5.7.0",
|
||||
"@ezkljs/engine": "^9.4.4",
|
||||
"@ezkljs/engine": "10.4.2",
|
||||
"ethers": "6.7.1",
|
||||
"json-bigint": "1.0.0"
|
||||
},
|
||||
|
||||
8
in-browser-evm-verifier/pnpm-lock.yaml
generated
8
in-browser-evm-verifier/pnpm-lock.yaml
generated
@@ -27,8 +27,8 @@ dependencies:
|
||||
specifier: 5.7.0
|
||||
version: 5.7.0
|
||||
'@ezkljs/engine':
|
||||
specifier: ^9.4.4
|
||||
version: 9.4.4
|
||||
specifier: "10.4.2"
|
||||
version: "10.4.2"
|
||||
ethers:
|
||||
specifier: 6.7.1
|
||||
version: 6.7.1
|
||||
@@ -397,8 +397,8 @@ packages:
|
||||
'@ethersproject/strings': 5.7.0
|
||||
dev: false
|
||||
|
||||
/@ezkljs/engine@9.4.4:
|
||||
resolution: {integrity: sha512-kNsTmDQa8mIiQ6yjJmBMwVgAAxh4nfs4NCtnewJifonyA8Mfhs+teXwwW8WhERRDoQPUofKO2pT8BPvV/XGIDA==}
|
||||
/@ezkljs/engine@10.4.2:
|
||||
resolution: {integrity: "sha512-1GNB4vChbaQ1ALcYbEbM/AFoh4QWtswpzGCO/g9wL8Ep6NegM2gQP/uWICU7Utl0Lj1DncXomD7PUhFSXhtx8A=="}
|
||||
dependencies:
|
||||
'@types/json-bigint': 1.0.2
|
||||
json-bigint: 1.0.0
|
||||
|
||||
@@ -143,7 +143,7 @@ elif [ "$PLATFORM" == "macos" ]; then
|
||||
fi
|
||||
|
||||
elif [ "$PLATFORM" == "linux" ]; then
|
||||
if [ "${ARCHITECTURE}" = "amd64" ]; then
|
||||
if [ "$ARCHITECTURE" == "amd64" ]; then
|
||||
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
|
||||
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-linux-gnu.tar.gz")
|
||||
|
||||
@@ -155,9 +155,20 @@ elif [ "$PLATFORM" == "linux" ]; then
|
||||
|
||||
echo "Cleaning up"
|
||||
rm "$EZKL_DIR/build-artifacts.ezkl-linux-gnu.tar.gz"
|
||||
else if [ "$ARCHITECTURE" == "aarch64" ]; then
|
||||
JSON_RESPONSE=$(curl -s "$RELEASE_URL")
|
||||
FILE_URL=$(echo "$JSON_RESPONSE" | grep -o 'https://github.com[^"]*' | grep "build-artifacts.ezkl-linux-aarch64.tar.gz")
|
||||
|
||||
echo "Downloading package"
|
||||
curl -L "$FILE_URL" -o "$EZKL_DIR/build-artifacts.ezkl-linux-aarch64.tar.gz"
|
||||
|
||||
echo "Unpacking package"
|
||||
tar -xzf "$EZKL_DIR/build-artifacts.ezkl-linux-aarch64.tar.gz" -C "$EZKL_DIR"
|
||||
|
||||
echo "Cleaning up"
|
||||
rm "$EZKL_DIR/build-artifacts.ezkl-linux-aarch64.tar.gz"
|
||||
else
|
||||
echo "ARM architectures are not supported for Linux at the moment. If you would need support for the ARM architectures on linux please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
|
||||
echo "Non aarch ARM architectures are not supported for Linux at the moment. If you would need support for the ARM architectures on linux please submit an issue https://github.com/zkonduit/ezkl/issues/new/choose"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
|
||||
@@ -8,6 +8,7 @@ addopts = "-rfEX -p pytester --strict-markers"
|
||||
testpaths = [
|
||||
"tests/python/*_tests.py",
|
||||
]
|
||||
asyncio_mode = "auto"
|
||||
|
||||
[project]
|
||||
name = "ezkl"
|
||||
|
||||
@@ -11,7 +11,7 @@ use ezkl::execute::run;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use ezkl::logger::init_logger;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use log::{debug, error, info};
|
||||
use log::{error, info};
|
||||
#[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))]
|
||||
use rand::prelude::SliceRandom;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -33,7 +33,7 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
|
||||
} else {
|
||||
info!("Running with CPU");
|
||||
}
|
||||
debug!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
|
||||
info!("command: \n {}", &args.as_json()?.to_colored_json_auto()?);
|
||||
let res = run(args.command).await;
|
||||
match &res {
|
||||
Ok(_) => info!("succeeded"),
|
||||
|
||||
@@ -44,12 +44,11 @@ impl PolyCommitChip {
|
||||
/// Commit to the message using the KZG commitment scheme
|
||||
pub fn commit<Scheme: CommitmentScheme<Scalar = Fp, Curve = G1Affine>>(
|
||||
message: Vec<Scheme::Scalar>,
|
||||
degree: u32,
|
||||
num_unusable_rows: u32,
|
||||
params: &Scheme::ParamsProver,
|
||||
) -> Vec<G1Affine> {
|
||||
let k = params.k();
|
||||
let domain = halo2_proofs::poly::EvaluationDomain::new(degree, k);
|
||||
let domain = halo2_proofs::poly::EvaluationDomain::new(2, k);
|
||||
let n = 2_u64.pow(k) - num_unusable_rows as u64;
|
||||
let num_poly = (message.len() / n as usize) + 1;
|
||||
let mut poly = vec![domain.empty_lagrange(); num_poly];
|
||||
|
||||
@@ -24,7 +24,7 @@ use crate::{
|
||||
table::{Range, RangeCheck, Table},
|
||||
utils,
|
||||
},
|
||||
tensor::{Tensor, TensorType, ValTensor, VarTensor},
|
||||
tensor::{IntoI64, Tensor, TensorType, ValTensor, VarTensor},
|
||||
};
|
||||
use std::{collections::BTreeMap, error::Error, marker::PhantomData};
|
||||
|
||||
@@ -80,14 +80,18 @@ impl ToFlags for CheckMode {
|
||||
|
||||
impl From<String> for CheckMode {
|
||||
fn from(value: String) -> Self {
|
||||
match value.to_lowercase().as_str() {
|
||||
"safe" => CheckMode::SAFE,
|
||||
"unsafe" => CheckMode::UNSAFE,
|
||||
_ => {
|
||||
log::error!("Invalid value for CheckMode");
|
||||
log::warn!("defaulting to SAFE");
|
||||
CheckMode::SAFE
|
||||
}
|
||||
Self::from_str(value.as_str()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for CheckMode {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"safe" => Ok(CheckMode::SAFE),
|
||||
"unsafe" => Ok(CheckMode::UNSAFE),
|
||||
_ => Err("Invalid value for CheckMode".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -345,7 +349,7 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseConfig<F> {
|
||||
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
|
||||
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
|
||||
Self {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::*;
|
||||
use crate::{
|
||||
circuit::{layouts, utils, Tolerance},
|
||||
fieldutils::i128_to_felt,
|
||||
fieldutils::i64_to_felt,
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorType, ValTensor},
|
||||
};
|
||||
@@ -71,7 +71,7 @@ pub enum HybridOp {
|
||||
},
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for HybridOp {
|
||||
///
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
match self {
|
||||
@@ -184,8 +184,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
i128_to_felt(input_scale.0 as i128),
|
||||
i128_to_felt(output_scale.0 as i128),
|
||||
i64_to_felt(input_scale.0 as i64),
|
||||
i64_to_felt(output_scale.0 as i64),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
@@ -209,7 +209,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
i128_to_felt(denom.0 as i128),
|
||||
i64_to_felt(denom.0 as i64),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,9 +4,9 @@ use std::error::Error;
|
||||
|
||||
use crate::{
|
||||
circuit::{layouts, table::Range, utils},
|
||||
fieldutils::{felt_to_i128, i128_to_felt},
|
||||
fieldutils::{felt_to_i64, i64_to_felt},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorError, TensorType},
|
||||
tensor::{self, IntoI64, Tensor, TensorError, TensorType},
|
||||
};
|
||||
|
||||
use super::Op;
|
||||
@@ -132,16 +132,16 @@ impl LookupOp {
|
||||
/// Returns the range of values that can be represented by the table
|
||||
pub fn bit_range(max_len: usize) -> Range {
|
||||
let range = (max_len - 1) as f64 / 2_f64;
|
||||
let range = range as i128;
|
||||
let range = range as i64;
|
||||
(-range, range)
|
||||
}
|
||||
|
||||
/// Matches a [Op] to an operation in the `tensor::ops` module.
|
||||
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
|
||||
&self,
|
||||
x: &[Tensor<F>],
|
||||
) -> Result<ForwardResult<F>, TensorError> {
|
||||
let x = x[0].clone().map(|x| felt_to_i128(x));
|
||||
let x = x[0].clone().map(|x| felt_to_i64(x));
|
||||
let res = match &self {
|
||||
LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
|
||||
LookupOp::Ceil { scale } => Ok(tensor::ops::nonlinearities::ceil(&x, scale.into())),
|
||||
@@ -228,13 +228,13 @@ impl LookupOp {
|
||||
}
|
||||
}?;
|
||||
|
||||
let output = res.map(|x| i128_to_felt(x));
|
||||
let output = res.map(|x| i64_to_felt(x));
|
||||
|
||||
Ok(ForwardResult { output })
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for LookupOp {
|
||||
/// Returns a reference to the Any trait.
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
|
||||
@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
graph::quantize_tensor,
|
||||
tensor::{self, Tensor, TensorType, ValTensor},
|
||||
tensor::{self, IntoI64, Tensor, TensorType, ValTensor},
|
||||
};
|
||||
use halo2curves::ff::PrimeField;
|
||||
|
||||
@@ -27,12 +27,12 @@ pub mod region;
|
||||
|
||||
/// A struct representing the result of a forward pass.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
|
||||
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
|
||||
pub(crate) output: Tensor<F>,
|
||||
}
|
||||
|
||||
/// A trait representing operations that can be represented as constraints in a circuit.
|
||||
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
|
||||
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>:
|
||||
std::fmt::Debug + Send + Sync + Any
|
||||
{
|
||||
/// Returns a string representation of the operation.
|
||||
@@ -71,7 +71,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
|
||||
fn as_any(&self) -> &dyn Any;
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Clone for Box<dyn Op<F>> {
|
||||
fn clone(&self) -> Self {
|
||||
self.clone_dyn()
|
||||
}
|
||||
@@ -122,8 +122,8 @@ impl InputType {
|
||||
*input = T::from_f64(f64_input).unwrap();
|
||||
}
|
||||
InputType::Int | InputType::TDim => {
|
||||
let int_input = input.clone().to_i128().unwrap();
|
||||
*input = T::from_i128(int_input).unwrap();
|
||||
let int_input = input.clone().to_i64().unwrap();
|
||||
*input = T::from_i64(int_input).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -138,7 +138,7 @@ pub struct Input {
|
||||
pub datum_type: InputType,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Input {
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
Ok(self.scale)
|
||||
}
|
||||
@@ -193,7 +193,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct Unknown;
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknown {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Unknown {
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
Ok(0)
|
||||
}
|
||||
@@ -220,7 +220,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknow
|
||||
|
||||
///
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
|
||||
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
|
||||
///
|
||||
pub quantized_values: Tensor<F>,
|
||||
///
|
||||
@@ -230,7 +230,7 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
|
||||
pub pre_assigned_val: Option<ValTensor<F>>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Constant<F> {
|
||||
///
|
||||
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
|
||||
Self {
|
||||
@@ -263,7 +263,8 @@ impl<
|
||||
+ PartialOrd
|
||||
+ std::hash::Hash
|
||||
+ Serialize
|
||||
+ for<'de> Deserialize<'de>,
|
||||
+ for<'de> Deserialize<'de>
|
||||
+ IntoI64,
|
||||
> Op<F> for Constant<F>
|
||||
{
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
|
||||
@@ -97,7 +97,8 @@ impl<
|
||||
+ PartialOrd
|
||||
+ std::hash::Hash
|
||||
+ Serialize
|
||||
+ for<'de> Deserialize<'de>,
|
||||
+ for<'de> Deserialize<'de>
|
||||
+ IntoI64,
|
||||
> Op<F> for PolyOp
|
||||
{
|
||||
/// Returns a reference to the Any trait.
|
||||
|
||||
@@ -9,7 +9,7 @@ use halo2_proofs::{
|
||||
plonk::{Error, Selector},
|
||||
};
|
||||
use halo2curves::ff::PrimeField;
|
||||
use portable_atomic::AtomicI128 as AtomicInt;
|
||||
use portable_atomic::AtomicI64 as AtomicInt;
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
collections::{HashMap, HashSet},
|
||||
@@ -133,10 +133,11 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
|
||||
shuffle_index: ShuffleIndex,
|
||||
used_lookups: HashSet<LookupOp>,
|
||||
used_range_checks: HashSet<Range>,
|
||||
max_lookup_inputs: i128,
|
||||
min_lookup_inputs: i128,
|
||||
max_range_size: i128,
|
||||
max_lookup_inputs: i64,
|
||||
min_lookup_inputs: i64,
|
||||
max_range_size: i64,
|
||||
witness_gen: bool,
|
||||
check_lookup_range: bool,
|
||||
assigned_constants: ConstantsMap<F>,
|
||||
}
|
||||
|
||||
@@ -191,6 +192,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
self.witness_gen
|
||||
}
|
||||
|
||||
///
|
||||
pub fn check_lookup_range(&self) -> bool {
|
||||
self.check_lookup_range
|
||||
}
|
||||
|
||||
/// Create a new region context
|
||||
pub fn new(region: Region<'a, F>, row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
|
||||
let region = Some(RefCell::new(region));
|
||||
@@ -209,6 +215,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
witness_gen: true,
|
||||
check_lookup_range: true,
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
@@ -246,12 +253,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
witness_gen: false,
|
||||
check_lookup_range: false,
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new region context
|
||||
pub fn new_dummy(row: usize, num_inner_cols: usize, witness_gen: bool) -> RegionCtx<'a, F> {
|
||||
pub fn new_dummy(
|
||||
row: usize,
|
||||
num_inner_cols: usize,
|
||||
witness_gen: bool,
|
||||
check_lookup_range: bool,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let region = None;
|
||||
let linear_coord = row * num_inner_cols;
|
||||
|
||||
@@ -268,6 +281,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
witness_gen,
|
||||
check_lookup_range,
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
@@ -278,6 +292,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
linear_coord: usize,
|
||||
num_inner_cols: usize,
|
||||
witness_gen: bool,
|
||||
check_lookup_range: bool,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let region = None;
|
||||
RegionCtx {
|
||||
@@ -293,6 +308,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
witness_gen,
|
||||
check_lookup_range,
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
@@ -364,6 +380,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
starting_linear_coord,
|
||||
self.num_inner_cols,
|
||||
self.witness_gen,
|
||||
self.check_lookup_range,
|
||||
);
|
||||
let res = inner_loop_function(idx, &mut local_reg);
|
||||
// we update the offset and constants
|
||||
@@ -546,17 +563,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
}
|
||||
|
||||
/// max lookup inputs
|
||||
pub fn max_lookup_inputs(&self) -> i128 {
|
||||
pub fn max_lookup_inputs(&self) -> i64 {
|
||||
self.max_lookup_inputs
|
||||
}
|
||||
|
||||
/// min lookup inputs
|
||||
pub fn min_lookup_inputs(&self) -> i128 {
|
||||
pub fn min_lookup_inputs(&self) -> i64 {
|
||||
self.min_lookup_inputs
|
||||
}
|
||||
|
||||
/// max range check
|
||||
pub fn max_range_size(&self) -> i128 {
|
||||
pub fn max_range_size(&self) -> i64 {
|
||||
self.max_range_size
|
||||
}
|
||||
|
||||
|
||||
@@ -11,17 +11,17 @@ use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator};
|
||||
|
||||
use crate::{
|
||||
circuit::CircuitError,
|
||||
fieldutils::i128_to_felt,
|
||||
tensor::{Tensor, TensorType},
|
||||
fieldutils::i64_to_felt,
|
||||
tensor::{IntoI64, Tensor, TensorType},
|
||||
};
|
||||
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
|
||||
/// The range of the lookup table.
|
||||
pub type Range = (i128, i128);
|
||||
pub type Range = (i64, i64);
|
||||
|
||||
/// The safety factor for the range of the lookup table.
|
||||
pub const RANGE_MULTIPLIER: i128 = 2;
|
||||
pub const RANGE_MULTIPLIER: i64 = 2;
|
||||
/// The safety factor offset for the number of rows in the lookup table.
|
||||
pub const RESERVED_BLINDING_ROWS_PAD: usize = 3;
|
||||
|
||||
@@ -96,21 +96,21 @@ pub struct Table<F: PrimeField> {
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<F> {
|
||||
/// get column index given input
|
||||
pub fn get_col_index(&self, input: F) -> F {
|
||||
// range is split up into chunks of size col_size, find the chunk that input is in
|
||||
let chunk =
|
||||
(crate::fieldutils::felt_to_i128(input) - self.range.0).abs() / (self.col_size as i128);
|
||||
(crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64);
|
||||
|
||||
i128_to_felt(chunk)
|
||||
i64_to_felt(chunk)
|
||||
}
|
||||
|
||||
/// get first_element of column
|
||||
pub fn get_first_element(&self, chunk: usize) -> (F, F) {
|
||||
let chunk = chunk as i128;
|
||||
let chunk = chunk as i64;
|
||||
// we index from 1 to prevent soundness issues
|
||||
let first_element = i128_to_felt(chunk * (self.col_size as i128) + self.range.0);
|
||||
let first_element = i64_to_felt(chunk * (self.col_size as i64) + self.range.0);
|
||||
let op_f = self
|
||||
.nonlinearity
|
||||
.f(&[Tensor::from(vec![first_element].into_iter())])
|
||||
@@ -130,12 +130,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
}
|
||||
|
||||
///
|
||||
pub fn num_cols_required(range_len: i128, col_size: usize) -> usize {
|
||||
pub fn num_cols_required(range_len: i64, col_size: usize) -> usize {
|
||||
// number of cols needed to store the range
|
||||
(range_len / (col_size as i128)) as usize + 1
|
||||
(range_len / (col_size as i64)) as usize + 1
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<F> {
|
||||
/// Configures the table.
|
||||
pub fn configure(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
@@ -202,7 +202,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
let smallest = self.range.0;
|
||||
let largest = self.range.1;
|
||||
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x));
|
||||
let evals = self.nonlinearity.f(&[inputs.clone()])?;
|
||||
let chunked_inputs = inputs.chunks(self.col_size);
|
||||
|
||||
@@ -272,12 +272,12 @@ pub struct RangeCheck<F: PrimeField> {
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeCheck<F> {
|
||||
/// get first_element of column
|
||||
pub fn get_first_element(&self, chunk: usize) -> F {
|
||||
let chunk = chunk as i128;
|
||||
let chunk = chunk as i64;
|
||||
// we index from 1 to prevent soundness issues
|
||||
i128_to_felt(chunk * (self.col_size as i128) + self.range.0)
|
||||
i64_to_felt(chunk * (self.col_size as i64) + self.range.0)
|
||||
}
|
||||
|
||||
///
|
||||
@@ -294,13 +294,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
pub fn get_col_index(&self, input: F) -> F {
|
||||
// range is split up into chunks of size col_size, find the chunk that input is in
|
||||
let chunk =
|
||||
(crate::fieldutils::felt_to_i128(input) - self.range.0).abs() / (self.col_size as i128);
|
||||
(crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64);
|
||||
|
||||
i128_to_felt(chunk)
|
||||
i64_to_felt(chunk)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeCheck<F> {
|
||||
/// Configures the table.
|
||||
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
|
||||
log::debug!("range check range: {:?}", range);
|
||||
@@ -350,7 +350,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
let smallest = self.range.0;
|
||||
let largest = self.range.1;
|
||||
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x));
|
||||
let chunked_inputs = inputs.chunks(self.col_size);
|
||||
|
||||
self.is_assigned = true;
|
||||
|
||||
180
src/commands.rs
180
src/commands.rs
@@ -1,6 +1,6 @@
|
||||
use clap::{Parser, Subcommand};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use ethers::types::H160;
|
||||
use alloy::primitives::Address as H160;
|
||||
use clap::{Parser, Subcommand};
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::{
|
||||
conversion::{FromPyObject, PyTryFrom},
|
||||
@@ -290,7 +290,7 @@ pub enum Commands {
|
||||
Table {
|
||||
/// The path to the .onnx model file
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_MODEL)]
|
||||
model: PathBuf,
|
||||
model: Option<PathBuf>,
|
||||
/// proving arguments
|
||||
#[clap(flatten)]
|
||||
args: RunArgs,
|
||||
@@ -300,13 +300,13 @@ pub enum Commands {
|
||||
GenWitness {
|
||||
/// The path to the .json data file
|
||||
#[arg(short = 'D', long, default_value = DEFAULT_DATA)]
|
||||
data: PathBuf,
|
||||
data: Option<PathBuf>,
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
|
||||
compiled_circuit: PathBuf,
|
||||
compiled_circuit: Option<PathBuf>,
|
||||
/// Path to output the witness .json file
|
||||
#[arg(short = 'O', long, default_value = DEFAULT_WITNESS)]
|
||||
output: PathBuf,
|
||||
output: Option<PathBuf>,
|
||||
/// Path to the verification key file (optional - solely used to generate kzg commits)
|
||||
#[arg(short = 'V', long)]
|
||||
vk_path: Option<PathBuf>,
|
||||
@@ -319,10 +319,10 @@ pub enum Commands {
|
||||
GenSettings {
|
||||
/// The path to the .onnx model file
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_MODEL)]
|
||||
model: PathBuf,
|
||||
model: Option<PathBuf>,
|
||||
/// The path to generate the circuit settings .json file to
|
||||
#[arg(short = 'O', long, default_value = DEFAULT_SETTINGS)]
|
||||
settings_path: PathBuf,
|
||||
settings_path: Option<PathBuf>,
|
||||
/// proving arguments
|
||||
#[clap(flatten)]
|
||||
args: RunArgs,
|
||||
@@ -333,19 +333,19 @@ pub enum Commands {
|
||||
CalibrateSettings {
|
||||
/// The path to the .json calibration data file.
|
||||
#[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE)]
|
||||
data: PathBuf,
|
||||
data: Option<PathBuf>,
|
||||
/// The path to the .onnx model file
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_MODEL)]
|
||||
model: PathBuf,
|
||||
model: Option<PathBuf>,
|
||||
/// The path to load circuit settings .json file AND overwrite (generated using the gen-settings command).
|
||||
#[arg(short = 'O', long, default_value = DEFAULT_SETTINGS)]
|
||||
settings_path: PathBuf,
|
||||
settings_path: Option<PathBuf>,
|
||||
#[arg(long = "target", default_value = DEFAULT_CALIBRATION_TARGET)]
|
||||
/// Target for calibration. Set to "resources" to optimize for computational resource. Otherwise, set to "accuracy" to optimize for accuracy.
|
||||
target: CalibrationTarget,
|
||||
/// the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. larger = safer but slower
|
||||
#[arg(long, default_value = DEFAULT_LOOKUP_SAFETY_MARGIN)]
|
||||
lookup_safety_margin: i128,
|
||||
lookup_safety_margin: i64,
|
||||
/// Optional scales to specifically try for calibration. Example, --scales 0,4
|
||||
#[arg(long, value_delimiter = ',', allow_hyphen_values = true)]
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
@@ -361,8 +361,8 @@ pub enum Commands {
|
||||
#[arg(long)]
|
||||
max_logrows: Option<u32>,
|
||||
// whether to only range check rebases (instead of trying both range check and lookup)
|
||||
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE)]
|
||||
only_range_check_rebase: bool,
|
||||
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE, action = clap::ArgAction::SetTrue)]
|
||||
only_range_check_rebase: Option<bool>,
|
||||
},
|
||||
|
||||
/// Generates a dummy SRS
|
||||
@@ -376,7 +376,7 @@ pub enum Commands {
|
||||
logrows: usize,
|
||||
/// commitment used
|
||||
#[arg(long, default_value = DEFAULT_COMMITMENT)]
|
||||
commitment: Commitments,
|
||||
commitment: Option<Commitments>,
|
||||
},
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -400,10 +400,10 @@ pub enum Commands {
|
||||
Mock {
|
||||
/// The path to the .json witness file (generated using the gen-witness command)
|
||||
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
|
||||
witness: PathBuf,
|
||||
witness: Option<PathBuf>,
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
|
||||
model: PathBuf,
|
||||
model: Option<PathBuf>,
|
||||
},
|
||||
|
||||
/// Mock aggregate proofs
|
||||
@@ -413,10 +413,10 @@ pub enum Commands {
|
||||
aggregation_snarks: Vec<PathBuf>,
|
||||
/// logrows used for aggregation circuit
|
||||
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
|
||||
logrows: u32,
|
||||
logrows: Option<u32>,
|
||||
/// whether the accumulated are segments of a larger proof
|
||||
#[arg(long, default_value = DEFAULT_SPLIT)]
|
||||
split_proofs: bool,
|
||||
#[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
|
||||
split_proofs: Option<bool>,
|
||||
},
|
||||
|
||||
/// setup aggregation circuit :)
|
||||
@@ -426,22 +426,22 @@ pub enum Commands {
|
||||
sample_snarks: Vec<PathBuf>,
|
||||
/// The path to save the desired verification key file to
|
||||
#[arg(long, default_value = DEFAULT_VK_AGGREGATED)]
|
||||
vk_path: PathBuf,
|
||||
vk_path: Option<PathBuf>,
|
||||
/// The path to save the proving key to
|
||||
#[arg(long, default_value = DEFAULT_PK_AGGREGATED)]
|
||||
pk_path: PathBuf,
|
||||
pk_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// logrows used for aggregation circuit
|
||||
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
|
||||
logrows: u32,
|
||||
logrows: Option<u32>,
|
||||
/// whether the accumulated are segments of a larger proof
|
||||
#[arg(long, default_value = DEFAULT_SPLIT)]
|
||||
split_proofs: bool,
|
||||
#[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
|
||||
split_proofs: Option<bool>,
|
||||
/// compress selectors
|
||||
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
|
||||
disable_selector_compression: bool,
|
||||
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)]
|
||||
disable_selector_compression: Option<bool>,
|
||||
/// commitment used
|
||||
#[arg(long, default_value = DEFAULT_COMMITMENT)]
|
||||
commitment: Option<Commitments>,
|
||||
@@ -453,10 +453,10 @@ pub enum Commands {
|
||||
aggregation_snarks: Vec<PathBuf>,
|
||||
/// The path to load the desired proving key file (generated using the setup-aggregate command)
|
||||
#[arg(long, default_value = DEFAULT_PK_AGGREGATED)]
|
||||
pk_path: PathBuf,
|
||||
pk_path: Option<PathBuf>,
|
||||
/// The path to output the proof file to
|
||||
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED)]
|
||||
proof_path: PathBuf,
|
||||
proof_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
@@ -470,13 +470,13 @@ pub enum Commands {
|
||||
transcript: TranscriptType,
|
||||
/// logrows used for aggregation circuit
|
||||
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
|
||||
logrows: u32,
|
||||
logrows: Option<u32>,
|
||||
/// run sanity checks during calculations (safe or unsafe)
|
||||
#[arg(long, default_value = DEFAULT_CHECKMODE)]
|
||||
check_mode: CheckMode,
|
||||
check_mode: Option<CheckMode>,
|
||||
/// whether the accumulated proofs are segments of a larger circuit
|
||||
#[arg(long, default_value = DEFAULT_SPLIT)]
|
||||
split_proofs: bool,
|
||||
#[arg(long, default_value = DEFAULT_SPLIT, action = clap::ArgAction::SetTrue)]
|
||||
split_proofs: Option<bool>,
|
||||
/// commitment used
|
||||
#[arg(long, default_value = DEFAULT_COMMITMENT)]
|
||||
commitment: Option<Commitments>,
|
||||
@@ -485,34 +485,34 @@ pub enum Commands {
|
||||
CompileCircuit {
|
||||
/// The path to the .onnx model file
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_MODEL)]
|
||||
model: PathBuf,
|
||||
model: Option<PathBuf>,
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
#[arg(long, default_value = DEFAULT_COMPILED_CIRCUIT)]
|
||||
compiled_circuit: PathBuf,
|
||||
compiled_circuit: Option<PathBuf>,
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
|
||||
settings_path: PathBuf,
|
||||
settings_path: Option<PathBuf>,
|
||||
},
|
||||
/// Creates pk and vk
|
||||
Setup {
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
|
||||
compiled_circuit: PathBuf,
|
||||
compiled_circuit: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to output the verification key file to
|
||||
#[arg(long, default_value = DEFAULT_VK)]
|
||||
vk_path: PathBuf,
|
||||
vk_path: Option<PathBuf>,
|
||||
/// The path to output the proving key file to
|
||||
#[arg(long, default_value = DEFAULT_PK)]
|
||||
pk_path: PathBuf,
|
||||
pk_path: Option<PathBuf>,
|
||||
/// The graph witness (optional - used to override fixed values in the circuit)
|
||||
#[arg(short = 'W', long)]
|
||||
witness: Option<PathBuf>,
|
||||
/// compress selectors
|
||||
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
|
||||
disable_selector_compression: bool,
|
||||
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)]
|
||||
disable_selector_compression: Option<bool>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
|
||||
@@ -520,10 +520,10 @@ pub enum Commands {
|
||||
SetupTestEvmData {
|
||||
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
#[arg(short = 'D', long)]
|
||||
data: PathBuf,
|
||||
data: Option<PathBuf>,
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
#[arg(short = 'M', long)]
|
||||
compiled_circuit: PathBuf,
|
||||
compiled_circuit: Option<PathBuf>,
|
||||
/// For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information
|
||||
/// derived from the file information in the data .json file.
|
||||
/// Should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
@@ -548,7 +548,7 @@ pub enum Commands {
|
||||
addr: H160Flag,
|
||||
/// The path to the .json data file.
|
||||
#[arg(short = 'D', long)]
|
||||
data: PathBuf,
|
||||
data: Option<PathBuf>,
|
||||
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
#[arg(short = 'U', long)]
|
||||
rpc_url: Option<String>,
|
||||
@@ -558,10 +558,10 @@ pub enum Commands {
|
||||
SwapProofCommitments {
|
||||
/// The path to the proof file
|
||||
#[arg(short = 'P', long, default_value = DEFAULT_PROOF)]
|
||||
proof_path: PathBuf,
|
||||
proof_path: Option<PathBuf>,
|
||||
/// The path to the witness file
|
||||
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
|
||||
witness_path: PathBuf,
|
||||
witness_path: Option<PathBuf>,
|
||||
},
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -569,16 +569,16 @@ pub enum Commands {
|
||||
Prove {
|
||||
/// The path to the .json witness file (generated using the gen-witness command)
|
||||
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
|
||||
witness: PathBuf,
|
||||
witness: Option<PathBuf>,
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
|
||||
compiled_circuit: PathBuf,
|
||||
compiled_circuit: Option<PathBuf>,
|
||||
/// The path to load the desired proving key file (generated using the setup command)
|
||||
#[arg(long, default_value = DEFAULT_PK)]
|
||||
pk_path: PathBuf,
|
||||
pk_path: Option<PathBuf>,
|
||||
/// The path to output the proof file to
|
||||
#[arg(long, default_value = DEFAULT_PROOF)]
|
||||
proof_path: PathBuf,
|
||||
proof_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
@@ -592,7 +592,7 @@ pub enum Commands {
|
||||
proof_type: ProofType,
|
||||
/// run sanity checks during calculations (safe or unsafe)
|
||||
#[arg(long, default_value = DEFAULT_CHECKMODE)]
|
||||
check_mode: CheckMode,
|
||||
check_mode: Option<CheckMode>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an Evm verifier for a single proof
|
||||
@@ -603,21 +603,21 @@ pub enum Commands {
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
|
||||
settings_path: PathBuf,
|
||||
settings_path: Option<PathBuf>,
|
||||
/// The path to load the desired verification key file
|
||||
#[arg(long, default_value = DEFAULT_VK)]
|
||||
vk_path: PathBuf,
|
||||
vk_path: Option<PathBuf>,
|
||||
/// The path to output the Solidity code
|
||||
#[arg(long, default_value = DEFAULT_SOL_CODE)]
|
||||
sol_code_path: PathBuf,
|
||||
sol_code_path: Option<PathBuf>,
|
||||
/// The path to output the Solidity verifier ABI
|
||||
#[arg(long, default_value = DEFAULT_VERIFIER_ABI)]
|
||||
abi_path: PathBuf,
|
||||
abi_path: Option<PathBuf>,
|
||||
/// Whether the verifier key should be rendered as a separate contract.
|
||||
/// We recommend disabling selector compression if this is enabled.
|
||||
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
|
||||
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY)]
|
||||
render_vk_seperately: bool,
|
||||
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY, action = clap::ArgAction::SetTrue)]
|
||||
render_vk_seperately: Option<bool>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an Evm verifier for a single proof
|
||||
@@ -628,16 +628,16 @@ pub enum Commands {
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
|
||||
settings_path: PathBuf,
|
||||
settings_path: Option<PathBuf>,
|
||||
/// The path to load the desired verification key file
|
||||
#[arg(long, default_value = DEFAULT_VK)]
|
||||
vk_path: PathBuf,
|
||||
vk_path: Option<PathBuf>,
|
||||
/// The path to output the Solidity code
|
||||
#[arg(long, default_value = DEFAULT_VK_SOL)]
|
||||
sol_code_path: PathBuf,
|
||||
sol_code_path: Option<PathBuf>,
|
||||
/// The path to output the Solidity verifier ABI
|
||||
#[arg(long, default_value = DEFAULT_VK_ABI)]
|
||||
abi_path: PathBuf,
|
||||
abi_path: Option<PathBuf>,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
|
||||
@@ -645,20 +645,20 @@ pub enum Commands {
|
||||
CreateEvmDataAttestation {
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
|
||||
settings_path: PathBuf,
|
||||
settings_path: Option<PathBuf>,
|
||||
/// The path to output the Solidity code
|
||||
#[arg(long, default_value = DEFAULT_SOL_CODE_DA)]
|
||||
sol_code_path: PathBuf,
|
||||
sol_code_path: Option<PathBuf>,
|
||||
/// The path to output the Solidity verifier ABI
|
||||
#[arg(long, default_value = DEFAULT_VERIFIER_DA_ABI)]
|
||||
abi_path: PathBuf,
|
||||
abi_path: Option<PathBuf>,
|
||||
/// The path to the .json data file, which should
|
||||
/// contain the necessary calldata and account addresses
|
||||
/// needed to read from all the on-chain
|
||||
/// view functions that return the data that the network
|
||||
/// ingests as inputs.
|
||||
#[arg(short = 'D', long, default_value = DEFAULT_DATA)]
|
||||
data: PathBuf,
|
||||
data: Option<PathBuf>,
|
||||
},
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -670,60 +670,60 @@ pub enum Commands {
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load the desired verification key file
|
||||
#[arg(long, default_value = DEFAULT_VK_AGGREGATED)]
|
||||
vk_path: PathBuf,
|
||||
vk_path: Option<PathBuf>,
|
||||
/// The path to the Solidity code
|
||||
#[arg(long, default_value = DEFAULT_SOL_CODE_AGGREGATED)]
|
||||
sol_code_path: PathBuf,
|
||||
sol_code_path: Option<PathBuf>,
|
||||
/// The path to output the Solidity verifier ABI
|
||||
#[arg(long, default_value = DEFAULT_VERIFIER_AGGREGATED_ABI)]
|
||||
abi_path: PathBuf,
|
||||
abi_path: Option<PathBuf>,
|
||||
// aggregated circuit settings paths, used to calculate the number of instances in the aggregate proof
|
||||
#[arg(long, default_value = DEFAULT_SETTINGS, value_delimiter = ',', allow_hyphen_values = true)]
|
||||
aggregation_settings: Vec<PathBuf>,
|
||||
// logrows used for aggregation circuit
|
||||
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
|
||||
logrows: u32,
|
||||
logrows: Option<u32>,
|
||||
/// Whether the verifier key should be rendered as a separate contract.
|
||||
/// We recommend disabling selector compression if this is enabled.
|
||||
/// To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command.
|
||||
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY)]
|
||||
render_vk_seperately: bool,
|
||||
#[arg(long, default_value = DEFAULT_RENDER_VK_SEPERATELY, action = clap::ArgAction::SetTrue)]
|
||||
render_vk_seperately: Option<bool>,
|
||||
},
|
||||
/// Verifies a proof, returning accept or reject
|
||||
Verify {
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
#[arg(short = 'S', long, default_value = DEFAULT_SETTINGS)]
|
||||
settings_path: PathBuf,
|
||||
settings_path: Option<PathBuf>,
|
||||
/// The path to the proof file (generated using the prove command)
|
||||
#[arg(long, default_value = DEFAULT_PROOF)]
|
||||
proof_path: PathBuf,
|
||||
proof_path: Option<PathBuf>,
|
||||
/// The path to the verification key file (generated using the setup command)
|
||||
#[arg(long, default_value = DEFAULT_VK)]
|
||||
vk_path: PathBuf,
|
||||
vk_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
|
||||
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)]
|
||||
reduced_srs: bool,
|
||||
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
|
||||
reduced_srs: Option<bool>,
|
||||
},
|
||||
/// Verifies an aggregate proof, returning accept or reject
|
||||
VerifyAggr {
|
||||
/// The path to the proof file (generated using the prove command)
|
||||
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED)]
|
||||
proof_path: PathBuf,
|
||||
proof_path: Option<PathBuf>,
|
||||
/// The path to the verification key file (generated using the setup-aggregate command)
|
||||
#[arg(long, default_value = DEFAULT_VK_AGGREGATED)]
|
||||
vk_path: PathBuf,
|
||||
vk_path: Option<PathBuf>,
|
||||
/// reduced srs
|
||||
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)]
|
||||
reduced_srs: bool,
|
||||
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
|
||||
reduced_srs: Option<bool>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// logrows used for aggregation circuit
|
||||
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
|
||||
logrows: u32,
|
||||
logrows: Option<u32>,
|
||||
/// commitment
|
||||
#[arg(long, default_value = DEFAULT_COMMITMENT)]
|
||||
commitment: Option<Commitments>,
|
||||
@@ -733,13 +733,13 @@ pub enum Commands {
|
||||
DeployEvmVerifier {
|
||||
/// The path to the Solidity code (generated using the create-evm-verifier command)
|
||||
#[arg(long, default_value = DEFAULT_SOL_CODE)]
|
||||
sol_code_path: PathBuf,
|
||||
sol_code_path: Option<PathBuf>,
|
||||
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
#[arg(short = 'U', long)]
|
||||
rpc_url: Option<String>,
|
||||
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS)]
|
||||
/// The path to output the contract address
|
||||
addr_path: PathBuf,
|
||||
addr_path: Option<PathBuf>,
|
||||
/// The optimizer runs to set on the verifier. Lower values optimize for deployment cost, while higher values optimize for gas cost.
|
||||
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)]
|
||||
optimizer_runs: usize,
|
||||
@@ -752,13 +752,13 @@ pub enum Commands {
|
||||
DeployEvmVK {
|
||||
/// The path to the Solidity code (generated using the create-evm-verifier command)
|
||||
#[arg(long, default_value = DEFAULT_VK_SOL)]
|
||||
sol_code_path: PathBuf,
|
||||
sol_code_path: Option<PathBuf>,
|
||||
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
#[arg(short = 'U', long)]
|
||||
rpc_url: Option<String>,
|
||||
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_VK)]
|
||||
/// The path to output the contract address
|
||||
addr_path: PathBuf,
|
||||
addr_path: Option<PathBuf>,
|
||||
/// The optimizer runs to set on the verifier. Lower values optimize for deployment cost, while higher values optimize for gas cost.
|
||||
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)]
|
||||
optimizer_runs: usize,
|
||||
@@ -772,19 +772,19 @@ pub enum Commands {
|
||||
DeployEvmDataAttestation {
|
||||
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
#[arg(short = 'D', long, default_value = DEFAULT_DATA)]
|
||||
data: PathBuf,
|
||||
data: Option<PathBuf>,
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
#[arg(long, default_value = DEFAULT_SETTINGS)]
|
||||
settings_path: PathBuf,
|
||||
settings_path: Option<PathBuf>,
|
||||
/// The path to the Solidity code
|
||||
#[arg(long, default_value = DEFAULT_SOL_CODE_DA)]
|
||||
sol_code_path: PathBuf,
|
||||
sol_code_path: Option<PathBuf>,
|
||||
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
#[arg(short = 'U', long)]
|
||||
rpc_url: Option<String>,
|
||||
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_DA)]
|
||||
/// The path to output the contract address
|
||||
addr_path: PathBuf,
|
||||
addr_path: Option<PathBuf>,
|
||||
/// The optimizer runs to set on the verifier. (Lower values optimize for deployment, while higher values optimize for execution)
|
||||
#[arg(long, default_value = DEFAULT_OPTIMIZER_RUNS)]
|
||||
optimizer_runs: usize,
|
||||
@@ -798,7 +798,7 @@ pub enum Commands {
|
||||
VerifyEvm {
|
||||
/// The path to the proof file (generated using the prove command)
|
||||
#[arg(long, default_value = DEFAULT_PROOF)]
|
||||
proof_path: PathBuf,
|
||||
proof_path: Option<PathBuf>,
|
||||
/// The path to verifier contract's address
|
||||
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS)]
|
||||
addr_verifier: H160Flag,
|
||||
|
||||
686
src/eth.rs
686
src/eth.rs
File diff suppressed because one or more lines are too long
314
src/execute.rs
314
src/execute.rs
@@ -1,9 +1,7 @@
|
||||
use crate::circuit::CheckMode;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::commands::CalibrationTarget;
|
||||
use crate::commands::Commands;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::commands::H160Flag;
|
||||
use crate::commands::*;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -23,6 +21,8 @@ use crate::pfsys::{
|
||||
};
|
||||
use crate::pfsys::{save_vk, srs::*};
|
||||
use crate::tensor::TensorError;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
use crate::{Commitments, RunArgs};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use colored::Colorize;
|
||||
@@ -66,48 +66,16 @@ use snark_verifier::system::halo2::Config;
|
||||
use std::error::Error;
|
||||
use std::fs::File;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::io::BufWriter;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::io::{Cursor, Write};
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::process::Command;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::sync::OnceLock;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::io::BufWriter;
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use tabled::Tabled;
|
||||
use thiserror::Error;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
static _SOLC_REQUIREMENT: OnceLock<bool> = OnceLock::new();
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
fn check_solc_requirement() {
|
||||
info!("checking solc installation..");
|
||||
_SOLC_REQUIREMENT.get_or_init(|| match Command::new("solc").arg("--version").output() {
|
||||
Ok(output) => {
|
||||
debug!("solc output: {:#?}", output);
|
||||
debug!("solc output success: {:#?}", output.status.success());
|
||||
if !output.status.success() {
|
||||
log::error!(
|
||||
"`solc` check failed: {}",
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
return false;
|
||||
}
|
||||
debug!("solc check passed, proceeding");
|
||||
true
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("`solc` check failed: solc not found");
|
||||
false
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
lazy_static! {
|
||||
@@ -152,7 +120,11 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
srs_path,
|
||||
logrows,
|
||||
commitment,
|
||||
} => gen_srs_cmd(srs_path, logrows as u32, commitment),
|
||||
} => gen_srs_cmd(
|
||||
srs_path,
|
||||
logrows as u32,
|
||||
commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT)?),
|
||||
),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::GetSrs {
|
||||
srs_path,
|
||||
@@ -160,12 +132,16 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
logrows,
|
||||
commitment,
|
||||
} => get_srs_cmd(srs_path, settings_path, logrows, commitment).await,
|
||||
Commands::Table { model, args } => table(model, args),
|
||||
Commands::Table { model, args } => table(model.unwrap_or(DEFAULT_MODEL.into()), args),
|
||||
Commands::GenSettings {
|
||||
model,
|
||||
settings_path,
|
||||
args,
|
||||
} => gen_circuit_settings(model, settings_path, args),
|
||||
} => gen_circuit_settings(
|
||||
model.unwrap_or(DEFAULT_MODEL.into()),
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
args,
|
||||
),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CalibrateSettings {
|
||||
model,
|
||||
@@ -178,16 +154,17 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
max_logrows,
|
||||
only_range_check_rebase,
|
||||
} => calibrate(
|
||||
model,
|
||||
data,
|
||||
settings_path,
|
||||
model.unwrap_or(DEFAULT_MODEL.into()),
|
||||
data.unwrap_or(DEFAULT_DATA.into()),
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
target,
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
only_range_check_rebase,
|
||||
only_range_check_rebase.unwrap_or(DEFAULT_ONLY_RANGE_CHECK_REBASE.parse()?),
|
||||
max_logrows,
|
||||
)
|
||||
.await
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
Commands::GenWitness {
|
||||
data,
|
||||
@@ -195,9 +172,19 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
output,
|
||||
vk_path,
|
||||
srs_path,
|
||||
} => gen_witness(compiled_circuit, data, Some(output), vk_path, srs_path)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
Commands::Mock { model, witness } => mock(model, witness),
|
||||
} => gen_witness(
|
||||
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
|
||||
data.unwrap_or(DEFAULT_DATA.into()),
|
||||
Some(output.unwrap_or(DEFAULT_WITNESS.into())),
|
||||
vk_path,
|
||||
srs_path,
|
||||
)
|
||||
.await
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
Commands::Mock { model, witness } => mock(
|
||||
model.unwrap_or(DEFAULT_MODEL.into()),
|
||||
witness.unwrap_or(DEFAULT_WITNESS.into()),
|
||||
),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CreateEvmVerifier {
|
||||
vk_path,
|
||||
@@ -206,28 +193,48 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
render_vk_seperately,
|
||||
} => create_evm_verifier(
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
render_vk_seperately,
|
||||
),
|
||||
} => {
|
||||
create_evm_verifier(
|
||||
vk_path.unwrap_or(DEFAULT_VK.into()),
|
||||
srs_path,
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
sol_code_path.unwrap_or(DEFAULT_SOL_CODE.into()),
|
||||
abi_path.unwrap_or(DEFAULT_VERIFIER_ABI.into()),
|
||||
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse()?),
|
||||
)
|
||||
.await
|
||||
}
|
||||
Commands::CreateEvmVK {
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
} => create_evm_vk(vk_path, srs_path, settings_path, sol_code_path, abi_path),
|
||||
} => {
|
||||
create_evm_vk(
|
||||
vk_path.unwrap_or(DEFAULT_VK.into()),
|
||||
srs_path,
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
sol_code_path.unwrap_or(DEFAULT_VK_SOL.into()),
|
||||
abi_path.unwrap_or(DEFAULT_VK_ABI.into()),
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CreateEvmDataAttestation {
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
data,
|
||||
} => create_evm_data_attestation(settings_path, sol_code_path, abi_path, data),
|
||||
} => {
|
||||
create_evm_data_attestation(
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
sol_code_path.unwrap_or(DEFAULT_SOL_CODE_DA.into()),
|
||||
abi_path.unwrap_or(DEFAULT_VERIFIER_DA_ABI.into()),
|
||||
data.unwrap_or(DEFAULT_DATA.into()),
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::CreateEvmVerifierAggr {
|
||||
vk_path,
|
||||
@@ -237,20 +244,27 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
aggregation_settings,
|
||||
logrows,
|
||||
render_vk_seperately,
|
||||
} => create_evm_aggregate_verifier(
|
||||
vk_path,
|
||||
srs_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
aggregation_settings,
|
||||
logrows,
|
||||
render_vk_seperately,
|
||||
),
|
||||
} => {
|
||||
create_evm_aggregate_verifier(
|
||||
vk_path.unwrap_or(DEFAULT_VK.into()),
|
||||
srs_path,
|
||||
sol_code_path.unwrap_or(DEFAULT_SOL_CODE_AGGREGATED.into()),
|
||||
abi_path.unwrap_or(DEFAULT_VERIFIER_AGGREGATED_ABI.into()),
|
||||
aggregation_settings,
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
|
||||
render_vk_seperately.unwrap_or(DEFAULT_RENDER_VK_SEPERATELY.parse()?),
|
||||
)
|
||||
.await
|
||||
}
|
||||
Commands::CompileCircuit {
|
||||
model,
|
||||
compiled_circuit,
|
||||
settings_path,
|
||||
} => compile_circuit(model, compiled_circuit, settings_path),
|
||||
} => compile_circuit(
|
||||
model.unwrap_or(DEFAULT_MODEL.into()),
|
||||
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
),
|
||||
Commands::Setup {
|
||||
compiled_circuit,
|
||||
srs_path,
|
||||
@@ -259,12 +273,12 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
witness,
|
||||
disable_selector_compression,
|
||||
} => setup(
|
||||
compiled_circuit,
|
||||
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
|
||||
srs_path,
|
||||
vk_path,
|
||||
pk_path,
|
||||
vk_path.unwrap_or(DEFAULT_VK.into()),
|
||||
pk_path.unwrap_or(DEFAULT_PK.into()),
|
||||
witness,
|
||||
disable_selector_compression,
|
||||
disable_selector_compression.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse()?),
|
||||
),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::SetupTestEvmData {
|
||||
@@ -276,8 +290,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
output_source,
|
||||
} => {
|
||||
setup_test_evm_witness(
|
||||
data,
|
||||
compiled_circuit,
|
||||
data.unwrap_or(DEFAULT_DATA.into()),
|
||||
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
|
||||
test_data,
|
||||
rpc_url,
|
||||
input_source,
|
||||
@@ -290,13 +304,17 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
addr,
|
||||
data,
|
||||
rpc_url,
|
||||
} => test_update_account_calls(addr, data, rpc_url).await,
|
||||
} => test_update_account_calls(addr, data.unwrap_or(DEFAULT_DATA.into()), rpc_url).await,
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::SwapProofCommitments {
|
||||
proof_path,
|
||||
witness_path,
|
||||
} => swap_proof_commitments_cmd(proof_path, witness_path)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
} => swap_proof_commitments_cmd(
|
||||
proof_path.unwrap_or(DEFAULT_PROOF.into()),
|
||||
witness_path.unwrap_or(DEFAULT_WITNESS.into()),
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::Prove {
|
||||
witness,
|
||||
@@ -307,20 +325,24 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
proof_type,
|
||||
check_mode,
|
||||
} => prove(
|
||||
witness,
|
||||
compiled_circuit,
|
||||
pk_path,
|
||||
Some(proof_path),
|
||||
witness.unwrap_or(DEFAULT_WITNESS.into()),
|
||||
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
|
||||
pk_path.unwrap_or(DEFAULT_PK.into()),
|
||||
Some(proof_path.unwrap_or(DEFAULT_PROOF.into())),
|
||||
srs_path,
|
||||
proof_type,
|
||||
check_mode,
|
||||
check_mode.unwrap_or(DEFAULT_CHECKMODE.parse()?),
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
Commands::MockAggregate {
|
||||
aggregation_snarks,
|
||||
logrows,
|
||||
split_proofs,
|
||||
} => mock_aggregate(aggregation_snarks, logrows, split_proofs),
|
||||
} => mock_aggregate(
|
||||
aggregation_snarks,
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
|
||||
split_proofs.unwrap_or(DEFAULT_SPLIT.parse()?),
|
||||
),
|
||||
Commands::SetupAggregate {
|
||||
sample_snarks,
|
||||
vk_path,
|
||||
@@ -332,12 +354,12 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
commitment,
|
||||
} => setup_aggregate(
|
||||
sample_snarks,
|
||||
vk_path,
|
||||
pk_path,
|
||||
vk_path.unwrap_or(DEFAULT_VK_AGGREGATED.into()),
|
||||
pk_path.unwrap_or(DEFAULT_PK_AGGREGATED.into()),
|
||||
srs_path,
|
||||
logrows,
|
||||
split_proofs,
|
||||
disable_selector_compression,
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
|
||||
split_proofs.unwrap_or(DEFAULT_SPLIT.parse()?),
|
||||
disable_selector_compression.unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse()?),
|
||||
commitment.into(),
|
||||
),
|
||||
Commands::Aggregate {
|
||||
@@ -351,14 +373,14 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
split_proofs,
|
||||
commitment,
|
||||
} => aggregate(
|
||||
proof_path,
|
||||
proof_path.unwrap_or(DEFAULT_PROOF_AGGREGATED.into()),
|
||||
aggregation_snarks,
|
||||
pk_path,
|
||||
pk_path.unwrap_or(DEFAULT_PK_AGGREGATED.into()),
|
||||
srs_path,
|
||||
transcript,
|
||||
logrows,
|
||||
check_mode,
|
||||
split_proofs,
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
|
||||
check_mode.unwrap_or(DEFAULT_CHECKMODE.parse()?),
|
||||
split_proofs.unwrap_or(DEFAULT_SPLIT.parse()?),
|
||||
commitment.into(),
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
@@ -368,8 +390,14 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
vk_path,
|
||||
srs_path,
|
||||
reduced_srs,
|
||||
} => verify(proof_path, settings_path, vk_path, srs_path, reduced_srs)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
} => verify(
|
||||
proof_path.unwrap_or(DEFAULT_PROOF.into()),
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
vk_path.unwrap_or(DEFAULT_VK.into()),
|
||||
srs_path,
|
||||
reduced_srs.unwrap_or(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse()?),
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
Commands::VerifyAggr {
|
||||
proof_path,
|
||||
vk_path,
|
||||
@@ -378,11 +406,11 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
logrows,
|
||||
commitment,
|
||||
} => verify_aggr(
|
||||
proof_path,
|
||||
vk_path,
|
||||
proof_path.unwrap_or(DEFAULT_PROOF_AGGREGATED.into()),
|
||||
vk_path.unwrap_or(DEFAULT_VK_AGGREGATED.into()),
|
||||
srs_path,
|
||||
logrows,
|
||||
reduced_srs,
|
||||
logrows.unwrap_or(DEFAULT_AGGREGATED_LOGROWS.parse()?),
|
||||
reduced_srs.unwrap_or(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse()?),
|
||||
commitment.into(),
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
@@ -395,9 +423,9 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
private_key,
|
||||
} => {
|
||||
deploy_evm(
|
||||
sol_code_path,
|
||||
sol_code_path.unwrap_or(DEFAULT_SOL_CODE.into()),
|
||||
rpc_url,
|
||||
addr_path,
|
||||
addr_path.unwrap_or(DEFAULT_CONTRACT_ADDRESS.into()),
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2Verifier",
|
||||
@@ -413,9 +441,9 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
private_key,
|
||||
} => {
|
||||
deploy_evm(
|
||||
sol_code_path,
|
||||
sol_code_path.unwrap_or(DEFAULT_VK_SOL.into()),
|
||||
rpc_url,
|
||||
addr_path,
|
||||
addr_path.unwrap_or(DEFAULT_CONTRACT_ADDRESS_VK.into()),
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2VerifyingKey",
|
||||
@@ -433,11 +461,11 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
private_key,
|
||||
} => {
|
||||
deploy_da_evm(
|
||||
data,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
data.unwrap_or(DEFAULT_DATA.into()),
|
||||
settings_path.unwrap_or(DEFAULT_SETTINGS.into()),
|
||||
sol_code_path.unwrap_or(DEFAULT_SOL_CODE_DA.into()),
|
||||
rpc_url,
|
||||
addr_path,
|
||||
addr_path.unwrap_or(DEFAULT_CONTRACT_ADDRESS_DA.into()),
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
)
|
||||
@@ -450,7 +478,16 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
rpc_url,
|
||||
addr_da,
|
||||
addr_vk,
|
||||
} => verify_evm(proof_path, addr_verifier, rpc_url, addr_da, addr_vk).await,
|
||||
} => {
|
||||
verify_evm(
|
||||
proof_path.unwrap_or(DEFAULT_PROOF.into()),
|
||||
addr_verifier,
|
||||
rpc_url,
|
||||
addr_da,
|
||||
addr_vk,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -636,7 +673,7 @@ pub(crate) fn table(model: PathBuf, run_args: RunArgs) -> Result<String, Box<dyn
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
pub(crate) fn gen_witness(
|
||||
pub(crate) async fn gen_witness(
|
||||
compiled_circuit_path: PathBuf,
|
||||
data: PathBuf,
|
||||
output: Option<PathBuf>,
|
||||
@@ -659,7 +696,7 @@ pub(crate) fn gen_witness(
|
||||
};
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
let mut input = circuit.load_graph_input(&data)?;
|
||||
let mut input = circuit.load_graph_input(&data).await?;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
let mut input = circuit.load_graph_input(&data)?;
|
||||
|
||||
@@ -682,6 +719,7 @@ pub(crate) fn gen_witness(
|
||||
vk.as_ref(),
|
||||
Some(&srs),
|
||||
true,
|
||||
true,
|
||||
)?
|
||||
}
|
||||
Commitments::IPA => {
|
||||
@@ -696,15 +734,22 @@ pub(crate) fn gen_witness(
|
||||
vk.as_ref(),
|
||||
Some(&srs),
|
||||
true,
|
||||
true,
|
||||
)?
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!("SRS for poly commit does not exist (will be ignored)");
|
||||
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true)?
|
||||
circuit.forward::<KZGCommitmentScheme<Bn256>>(
|
||||
&mut input,
|
||||
vk.as_ref(),
|
||||
None,
|
||||
true,
|
||||
true,
|
||||
)?
|
||||
}
|
||||
} else {
|
||||
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true)?
|
||||
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true, true)?
|
||||
};
|
||||
|
||||
// print each variable tuple (symbol, value) as symbol=value
|
||||
@@ -882,12 +927,12 @@ impl AccuracyResults {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[allow(trivial_casts)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn calibrate(
|
||||
pub(crate) async fn calibrate(
|
||||
model_path: PathBuf,
|
||||
data: PathBuf,
|
||||
settings_path: PathBuf,
|
||||
target: CalibrationTarget,
|
||||
lookup_safety_margin: i128,
|
||||
lookup_safety_margin: i64,
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
only_range_check_rebase: bool,
|
||||
@@ -905,7 +950,9 @@ pub(crate) fn calibrate(
|
||||
|
||||
let model = Model::from_run_args(&settings.run_args, &model_path)?;
|
||||
|
||||
let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
|
||||
let input_shapes = model.graph.input_shapes()?;
|
||||
|
||||
let chunks = data.split_into_batches(input_shapes).await?;
|
||||
info!("num calibration batches: {}", chunks.len());
|
||||
|
||||
debug!("running onnx predictions...");
|
||||
@@ -1003,6 +1050,7 @@ pub(crate) fn calibrate(
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
lookup_range: (i64::MIN, i64::MAX),
|
||||
..settings.run_args.clone()
|
||||
};
|
||||
|
||||
@@ -1038,7 +1086,13 @@ pub(crate) fn calibrate(
|
||||
.map_err(|e| format!("failed to load circuit inputs: {}", e))?;
|
||||
|
||||
let forward_res = circuit
|
||||
.forward::<KZGCommitmentScheme<Bn256>>(&mut data.clone(), None, None, true)
|
||||
.forward::<KZGCommitmentScheme<Bn256>>(
|
||||
&mut data.clone(),
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
.map_err(|e| format!("failed to forward: {}", e))?;
|
||||
|
||||
// push result to the hashmap
|
||||
@@ -1053,7 +1107,7 @@ pub(crate) fn calibrate(
|
||||
|
||||
match forward_res {
|
||||
Ok(_) => (),
|
||||
// typically errors will be due to the circuit overflowing the i128 limit
|
||||
// typically errors will be due to the circuit overflowing the i64 limit
|
||||
Err(e) => {
|
||||
error!("forward pass failed: {:?}", e);
|
||||
pb.inc(1);
|
||||
@@ -1281,7 +1335,7 @@ pub(crate) fn mock(
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) fn create_evm_verifier(
|
||||
pub(crate) async fn create_evm_verifier(
|
||||
vk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
settings_path: PathBuf,
|
||||
@@ -1289,8 +1343,6 @@ pub(crate) fn create_evm_verifier(
|
||||
abi_path: PathBuf,
|
||||
render_vk_seperately: bool,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
check_solc_requirement();
|
||||
|
||||
let settings = GraphSettings::load(&settings_path)?;
|
||||
let commitment: Commitments = settings.run_args.commitment.into();
|
||||
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
|
||||
@@ -1320,7 +1372,7 @@ pub(crate) fn create_evm_verifier(
|
||||
File::create(sol_code_path.clone())?.write_all(verifier_solidity.as_bytes())?;
|
||||
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2Verifier", 0)?;
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2Verifier", 0).await?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
|
||||
|
||||
@@ -1328,14 +1380,13 @@ pub(crate) fn create_evm_verifier(
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) fn create_evm_vk(
|
||||
pub(crate) async fn create_evm_vk(
|
||||
vk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
settings_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
abi_path: PathBuf,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
check_solc_requirement();
|
||||
let settings = GraphSettings::load(&settings_path)?;
|
||||
let commitment: Commitments = settings.run_args.commitment.into();
|
||||
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
|
||||
@@ -1362,7 +1413,7 @@ pub(crate) fn create_evm_vk(
|
||||
File::create(sol_code_path.clone())?.write_all(vk_solidity.as_bytes())?;
|
||||
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2VerifyingKey", 0)?;
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2VerifyingKey", 0).await?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
|
||||
|
||||
@@ -1370,7 +1421,7 @@ pub(crate) fn create_evm_vk(
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) fn create_evm_data_attestation(
|
||||
pub(crate) async fn create_evm_data_attestation(
|
||||
settings_path: PathBuf,
|
||||
_sol_code_path: PathBuf,
|
||||
_abi_path: PathBuf,
|
||||
@@ -1378,7 +1429,6 @@ pub(crate) fn create_evm_data_attestation(
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
#[allow(unused_imports)]
|
||||
use crate::graph::{DataSource, VarVisibility};
|
||||
check_solc_requirement();
|
||||
|
||||
let settings = GraphSettings::load(&settings_path)?;
|
||||
|
||||
@@ -1418,7 +1468,7 @@ pub(crate) fn create_evm_data_attestation(
|
||||
let mut f = File::create(_sol_code_path.clone())?;
|
||||
let _ = f.write(output.as_bytes());
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(_sol_code_path, "DataAttestation", 0)?;
|
||||
let (abi, _, _) = get_contract_artifacts(_sol_code_path, "DataAttestation", 0).await?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(_abi_path)?, &abi)?;
|
||||
} else {
|
||||
@@ -1439,7 +1489,6 @@ pub(crate) async fn deploy_da_evm(
|
||||
runs: usize,
|
||||
private_key: Option<String>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
check_solc_requirement();
|
||||
let contract_address = deploy_da_verifier_via_solidity(
|
||||
settings_path,
|
||||
data,
|
||||
@@ -1466,7 +1515,6 @@ pub(crate) async fn deploy_evm(
|
||||
private_key: Option<String>,
|
||||
contract_name: &str,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
check_solc_requirement();
|
||||
let contract_address = deploy_contract_via_solidity(
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
@@ -1492,7 +1540,6 @@ pub(crate) async fn verify_evm(
|
||||
addr_vk: Option<H160Flag>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
use crate::eth::verify_proof_with_data_attestation;
|
||||
check_solc_requirement();
|
||||
|
||||
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
|
||||
|
||||
@@ -1525,7 +1572,7 @@ pub(crate) async fn verify_evm(
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) fn create_evm_aggregate_verifier(
|
||||
pub(crate) async fn create_evm_aggregate_verifier(
|
||||
vk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
sol_code_path: PathBuf,
|
||||
@@ -1534,7 +1581,6 @@ pub(crate) fn create_evm_aggregate_verifier(
|
||||
logrows: u32,
|
||||
render_vk_seperately: bool,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
check_solc_requirement();
|
||||
let srs_path = get_srs_path(logrows, srs_path, Commitments::KZG);
|
||||
let params: ParamsKZG<Bn256> = load_srs_verifier::<KZGCommitmentScheme<Bn256>>(srs_path)?;
|
||||
|
||||
@@ -1580,7 +1626,7 @@ pub(crate) fn create_evm_aggregate_verifier(
|
||||
File::create(sol_code_path.clone())?.write_all(verifier_solidity.as_bytes())?;
|
||||
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2Verifier", 0)?;
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2Verifier", 0).await?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
|
||||
|
||||
@@ -1660,7 +1706,6 @@ pub(crate) async fn setup_test_evm_witness(
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
use crate::graph::TestOnChainData;
|
||||
|
||||
info!("run this command in background to keep the instance running for testing");
|
||||
let mut data = GraphData::from_path(data_path)?;
|
||||
let mut circuit = GraphCircuit::load(compiled_circuit_path)?;
|
||||
|
||||
@@ -1696,7 +1741,6 @@ pub(crate) async fn test_update_account_calls(
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
use crate::eth::update_account_calls;
|
||||
|
||||
check_solc_requirement();
|
||||
update_account_calls(addr.into(), data, rpc_url.as_deref()).await?;
|
||||
|
||||
Ok(String::new())
|
||||
|
||||
@@ -11,8 +11,8 @@ pub fn i32_to_felt<F: PrimeField>(x: i32) -> F {
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts an i128 to a PrimeField element.
|
||||
pub fn i128_to_felt<F: PrimeField>(x: i128) -> F {
|
||||
/// Converts an i64 to a PrimeField element.
|
||||
pub fn i64_to_felt<F: PrimeField>(x: i64) -> F {
|
||||
if x >= 0 {
|
||||
F::from_u128(x as u128)
|
||||
} else {
|
||||
@@ -37,7 +37,7 @@ pub fn felt_to_i32<F: PrimeField + PartialOrd + Field>(x: F) -> i32 {
|
||||
|
||||
/// Converts a PrimeField element to an f64.
|
||||
pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
|
||||
if x > F::from_u128(i128::MAX as u128) {
|
||||
if x > F::from_u128(i64::MAX as u128) {
|
||||
let rep = (-x).to_repr();
|
||||
let negtmp: &[u8] = rep.as_ref();
|
||||
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
|
||||
@@ -50,18 +50,18 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a PrimeField element to an i128.
|
||||
pub fn felt_to_i128<F: PrimeField + PartialOrd + Field>(x: F) -> i128 {
|
||||
if x > F::from_u128(i128::MAX as u128) {
|
||||
/// Converts a PrimeField element to an i64.
|
||||
pub fn felt_to_i64<F: PrimeField + PartialOrd + Field>(x: F) -> i64 {
|
||||
if x > F::from_u128(i64::MAX as u128) {
|
||||
let rep = (-x).to_repr();
|
||||
let negtmp: &[u8] = rep.as_ref();
|
||||
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
|
||||
-(lower_128 as i128)
|
||||
-(lower_128 as i64)
|
||||
} else {
|
||||
let rep = (x).to_repr();
|
||||
let tmp: &[u8] = rep.as_ref();
|
||||
let lower_128: u128 = u128::from_le_bytes(tmp[..16].try_into().unwrap());
|
||||
lower_128 as i128
|
||||
lower_128 as i64
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,10 +79,10 @@ mod test {
|
||||
let res: F = i32_to_felt(2_i32.pow(17));
|
||||
assert_eq!(res, F::from(131072));
|
||||
|
||||
let res: F = i128_to_felt(-15i128);
|
||||
let res: F = i64_to_felt(-15i64);
|
||||
assert_eq!(res, -F::from(15));
|
||||
|
||||
let res: F = i128_to_felt(2_i128.pow(17));
|
||||
let res: F = i64_to_felt(2_i64.pow(17));
|
||||
assert_eq!(res, F::from(131072));
|
||||
}
|
||||
|
||||
@@ -96,10 +96,10 @@ mod test {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn felttoi128() {
|
||||
for x in -(2i128.pow(20))..(2i128.pow(20)) {
|
||||
let fieldx: F = i128_to_felt::<F>(x);
|
||||
let xf: i128 = felt_to_i128::<F>(fieldx);
|
||||
fn felttoi64() {
|
||||
for x in -(2i64.pow(20))..(2i64.pow(20)) {
|
||||
let fieldx: F = i64_to_felt::<F>(x);
|
||||
let xf: i64 = felt_to_i64::<F>(fieldx);
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
use super::quantize_float;
|
||||
use super::GraphError;
|
||||
use crate::circuit::InputType;
|
||||
use crate::fieldutils::i128_to_felt;
|
||||
use crate::fieldutils::i64_to_felt;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::graph::postgres::Client;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::tensor::Tensor;
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use postgres::{Client, NoTls};
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::prelude::*;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
@@ -128,7 +128,7 @@ impl FileSourceInner {
|
||||
/// Convert to a field element
|
||||
pub fn to_field(&self, scale: crate::Scale) -> Fp {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => i128_to_felt(quantize_float(f, 0.0, scale).unwrap()),
|
||||
FileSourceInner::Float(f) => i64_to_felt(quantize_float(f, 0.0, scale).unwrap()),
|
||||
FileSourceInner::Bool(f) => {
|
||||
if *f {
|
||||
Fp::one()
|
||||
@@ -150,7 +150,7 @@ impl FileSourceInner {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
FileSourceInner::Field(f) => crate::fieldutils::felt_to_i128(*f) as f64,
|
||||
FileSourceInner::Field(f) => crate::fieldutils::felt_to_i64(*f) as f64,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -211,7 +211,9 @@ impl PostgresSource {
|
||||
}
|
||||
|
||||
/// Fetch data from postgres
|
||||
pub fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, Box<dyn std::error::Error>> {
|
||||
pub async fn fetch(
|
||||
&self,
|
||||
) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, Box<dyn std::error::Error>> {
|
||||
// clone to move into thread
|
||||
let user = self.user.clone();
|
||||
let host = self.host.clone();
|
||||
@@ -232,10 +234,10 @@ impl PostgresSource {
|
||||
)
|
||||
};
|
||||
|
||||
let mut client = Client::connect(&config, NoTls)?;
|
||||
let mut client = Client::connect(&config).await?;
|
||||
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
|
||||
// extract rows from query
|
||||
for row in client.query(&query, &[])? {
|
||||
for row in client.query(&query, &[]).await? {
|
||||
// extract features from row
|
||||
for i in 0..row.len() {
|
||||
res.push(row.get(i));
|
||||
@@ -245,11 +247,12 @@ impl PostgresSource {
|
||||
}
|
||||
|
||||
/// Fetch data from postgres and format it as a FileSource
|
||||
pub fn fetch_and_format_as_file(
|
||||
pub async fn fetch_and_format_as_file(
|
||||
&self,
|
||||
) -> Result<Vec<Vec<FileSourceInner>>, Box<dyn std::error::Error>> {
|
||||
Ok(self
|
||||
.fetch()?
|
||||
.fetch()
|
||||
.await?
|
||||
.iter()
|
||||
.map(|d| {
|
||||
d.iter()
|
||||
@@ -277,13 +280,13 @@ impl OnChainSource {
|
||||
mut shapes: Vec<Vec<usize>>,
|
||||
rpc: Option<&str>,
|
||||
) -> Result<(Vec<Tensor<Fp>>, Self), Box<dyn std::error::Error>> {
|
||||
use crate::eth::{evm_quantize, read_on_chain_inputs, test_on_chain_data};
|
||||
use crate::eth::{
|
||||
evm_quantize, read_on_chain_inputs, test_on_chain_data, DEFAULT_ANVIL_ENDPOINT,
|
||||
};
|
||||
use log::debug;
|
||||
|
||||
// Set up local anvil instance for reading on-chain data
|
||||
let (anvil, client) = crate::eth::setup_eth_backend(rpc, None).await?;
|
||||
|
||||
let address = client.address();
|
||||
let (client, client_address) = crate::eth::setup_eth_backend(rpc, None).await?;
|
||||
|
||||
let mut scales = scales;
|
||||
// set scales to 1 where data is a field element
|
||||
@@ -296,7 +299,8 @@ impl OnChainSource {
|
||||
|
||||
let calls_to_accounts = test_on_chain_data(client.clone(), data).await?;
|
||||
debug!("Calls to accounts: {:?}", calls_to_accounts);
|
||||
let inputs = read_on_chain_inputs(client.clone(), address, &calls_to_accounts).await?;
|
||||
let inputs =
|
||||
read_on_chain_inputs(client.clone(), client_address, &calls_to_accounts).await?;
|
||||
debug!("Inputs: {:?}", inputs);
|
||||
|
||||
let mut quantized_evm_inputs = vec![];
|
||||
@@ -325,7 +329,7 @@ impl OnChainSource {
|
||||
inputs.push(t);
|
||||
}
|
||||
|
||||
let used_rpc = rpc.unwrap_or(&anvil.endpoint()).to_string();
|
||||
let used_rpc = rpc.unwrap_or(DEFAULT_ANVIL_ENDPOINT).to_string();
|
||||
|
||||
// Fill the input_data field of the GraphData struct
|
||||
Ok((
|
||||
@@ -502,7 +506,7 @@ impl GraphData {
|
||||
}
|
||||
|
||||
///
|
||||
pub fn split_into_batches(
|
||||
pub async fn split_into_batches(
|
||||
&self,
|
||||
input_shapes: Vec<Vec<usize>>,
|
||||
) -> Result<Vec<Self>, Box<dyn std::error::Error>> {
|
||||
@@ -527,7 +531,7 @@ impl GraphData {
|
||||
GraphData {
|
||||
input_data: DataSource::DB(data),
|
||||
output_data: _,
|
||||
} => data.fetch_and_format_as_file()?,
|
||||
} => data.fetch_and_format_as_file().await?,
|
||||
};
|
||||
|
||||
for (i, shape) in input_shapes.iter().enumerate() {
|
||||
|
||||
@@ -6,10 +6,14 @@ pub mod model;
|
||||
pub mod modules;
|
||||
/// Inner elements of a computational graph that represent a single operation / constraints.
|
||||
pub mod node;
|
||||
/// postgres helper functions
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub mod postgres;
|
||||
/// Helper functions
|
||||
pub mod utilities;
|
||||
/// Representations of a computational graph's variables.
|
||||
pub mod vars;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use colored_json::ToColoredJson;
|
||||
#[cfg(unix)]
|
||||
@@ -62,13 +66,13 @@ pub use vars::*;
|
||||
use crate::pfsys::field_to_string;
|
||||
|
||||
/// The safety factor for the range of the lookup table.
|
||||
pub const RANGE_MULTIPLIER: i128 = 2;
|
||||
pub const RANGE_MULTIPLIER: i64 = 2;
|
||||
|
||||
/// The maximum number of columns in a lookup table.
|
||||
pub const MAX_NUM_LOOKUP_COLS: usize = 12;
|
||||
|
||||
/// Max representation of a lookup table input
|
||||
pub const MAX_LOOKUP_ABS: i128 = (MAX_NUM_LOOKUP_COLS as i128) * 2_i128.pow(MAX_PUBLIC_SRS);
|
||||
pub const MAX_LOOKUP_ABS: i64 = (MAX_NUM_LOOKUP_COLS as i64) * 2_i64.pow(MAX_PUBLIC_SRS);
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
lazy_static! {
|
||||
@@ -175,11 +179,11 @@ pub struct GraphWitness {
|
||||
/// Any hashes of outputs generated during the forward pass
|
||||
pub processed_outputs: Option<ModuleForwardResult>,
|
||||
/// max lookup input
|
||||
pub max_lookup_inputs: i128,
|
||||
pub max_lookup_inputs: i64,
|
||||
/// max lookup input
|
||||
pub min_lookup_inputs: i128,
|
||||
pub min_lookup_inputs: i64,
|
||||
/// max range check size
|
||||
pub max_range_size: i128,
|
||||
pub max_range_size: i64,
|
||||
}
|
||||
|
||||
impl GraphWitness {
|
||||
@@ -958,7 +962,7 @@ impl GraphCircuit {
|
||||
|
||||
///
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub fn load_graph_input(
|
||||
pub async fn load_graph_input(
|
||||
&mut self,
|
||||
data: &GraphData,
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
@@ -968,6 +972,7 @@ impl GraphCircuit {
|
||||
debug!("input scales: {:?}", scales);
|
||||
|
||||
self.process_data_source(&data.input_data, shapes, scales, input_types)
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
@@ -991,7 +996,7 @@ impl GraphCircuit {
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Process the data source for the model
|
||||
fn process_data_source(
|
||||
async fn process_data_source(
|
||||
&mut self,
|
||||
data: &DataSource,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
@@ -1005,21 +1010,14 @@ impl GraphCircuit {
|
||||
per_item_scale.extend(vec![scales[i]; shape.iter().product::<usize>()]);
|
||||
}
|
||||
|
||||
// start runtime and fetch data
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
runtime.block_on(async {
|
||||
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
|
||||
.await
|
||||
})
|
||||
self.load_on_chain_data(source.clone(), &shapes, per_item_scale)
|
||||
.await
|
||||
}
|
||||
DataSource::File(file_data) => {
|
||||
self.load_file_data(file_data, &shapes, scales, input_types)
|
||||
}
|
||||
DataSource::DB(pg) => {
|
||||
let data = pg.fetch_and_format_as_file()?;
|
||||
let data = pg.fetch_and_format_as_file().await?;
|
||||
self.load_file_data(&data, &shapes, scales, input_types)
|
||||
}
|
||||
}
|
||||
@@ -1034,8 +1032,8 @@ impl GraphCircuit {
|
||||
scales: Vec<crate::Scale>,
|
||||
) -> Result<Vec<Tensor<Fp>>, Box<dyn std::error::Error>> {
|
||||
use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend};
|
||||
let (_, client) = setup_eth_backend(Some(&source.rpc), None).await?;
|
||||
let inputs = read_on_chain_inputs(client.clone(), client.address(), &source.calls).await?;
|
||||
let (client, client_address) = setup_eth_backend(Some(&source.rpc), None).await?;
|
||||
let inputs = read_on_chain_inputs(client.clone(), client_address, &source.calls).await?;
|
||||
// quantize the supplied data using the provided scale + QuantizeData.sol
|
||||
let quantized_evm_inputs = evm_quantize(client, scales, &inputs).await?;
|
||||
// on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector
|
||||
@@ -1098,14 +1096,14 @@ impl GraphCircuit {
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
|
||||
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i64) -> Range {
|
||||
(
|
||||
lookup_safety_margin * min_max_lookup.0,
|
||||
lookup_safety_margin * min_max_lookup.1,
|
||||
)
|
||||
}
|
||||
|
||||
fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize {
|
||||
fn calc_num_cols(range_len: i64, max_logrows: u32) -> usize {
|
||||
let max_col_size = Table::<Fp>::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS);
|
||||
num_cols_required(range_len, max_col_size)
|
||||
}
|
||||
@@ -1113,7 +1111,7 @@ impl GraphCircuit {
|
||||
fn table_size_logrows(
|
||||
&self,
|
||||
safe_lookup_range: Range,
|
||||
max_range_size: i128,
|
||||
max_range_size: i64,
|
||||
) -> Result<u32, Box<dyn std::error::Error>> {
|
||||
// pick the range with the largest absolute size safe_lookup_range or max_range_size
|
||||
let safe_range = std::cmp::max(
|
||||
@@ -1132,9 +1130,9 @@ impl GraphCircuit {
|
||||
pub fn calc_min_logrows(
|
||||
&mut self,
|
||||
min_max_lookup: Range,
|
||||
max_range_size: i128,
|
||||
max_range_size: i64,
|
||||
max_logrows: Option<u32>,
|
||||
lookup_safety_margin: i128,
|
||||
lookup_safety_margin: i64,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// load the max logrows
|
||||
let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS);
|
||||
@@ -1228,7 +1226,7 @@ impl GraphCircuit {
|
||||
&self,
|
||||
k: u32,
|
||||
safe_lookup_range: Range,
|
||||
max_range_size: i128,
|
||||
max_range_size: i64,
|
||||
) -> bool {
|
||||
// if num cols is too large then the extended k is too large
|
||||
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS
|
||||
@@ -1287,6 +1285,7 @@ impl GraphCircuit {
|
||||
vk: Option<&VerifyingKey<G1Affine>>,
|
||||
srs: Option<&Scheme::ParamsProver>,
|
||||
witness_gen: bool,
|
||||
check_lookup: bool,
|
||||
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
|
||||
let original_inputs = inputs.to_vec();
|
||||
|
||||
@@ -1335,7 +1334,7 @@ impl GraphCircuit {
|
||||
|
||||
let mut model_results =
|
||||
self.model()
|
||||
.forward(inputs, &self.settings().run_args, witness_gen)?;
|
||||
.forward(inputs, &self.settings().run_args, witness_gen, check_lookup)?;
|
||||
|
||||
if visibility.output.requires_processing() {
|
||||
let module_outlets = visibility.output.overwrites_inputs();
|
||||
|
||||
@@ -65,11 +65,11 @@ pub struct ForwardResult {
|
||||
/// The outputs of the forward pass.
|
||||
pub outputs: Vec<Tensor<Fp>>,
|
||||
/// The maximum value of any input to a lookup operation.
|
||||
pub max_lookup_inputs: i128,
|
||||
pub max_lookup_inputs: i64,
|
||||
/// The minimum value of any input to a lookup operation.
|
||||
pub min_lookup_inputs: i128,
|
||||
pub min_lookup_inputs: i64,
|
||||
/// The max range check size
|
||||
pub max_range_size: i128,
|
||||
pub max_range_size: i64,
|
||||
}
|
||||
|
||||
impl From<DummyPassRes> for ForwardResult {
|
||||
@@ -117,11 +117,11 @@ pub struct DummyPassRes {
|
||||
/// range checks
|
||||
pub range_checks: HashSet<Range>,
|
||||
/// max lookup inputs
|
||||
pub max_lookup_inputs: i128,
|
||||
pub max_lookup_inputs: i64,
|
||||
/// min lookup inputs
|
||||
pub min_lookup_inputs: i128,
|
||||
pub min_lookup_inputs: i64,
|
||||
/// min range check
|
||||
pub max_range_size: i128,
|
||||
pub max_range_size: i64,
|
||||
/// outputs
|
||||
pub outputs: Vec<Tensor<Fp>>,
|
||||
}
|
||||
@@ -538,7 +538,7 @@ impl Model {
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
|
||||
let res = self.dummy_layout(run_args, &inputs, false)?;
|
||||
let res = self.dummy_layout(run_args, &inputs, false, false)?;
|
||||
|
||||
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
|
||||
|
||||
@@ -582,12 +582,13 @@ impl Model {
|
||||
model_inputs: &[Tensor<Fp>],
|
||||
run_args: &RunArgs,
|
||||
witness_gen: bool,
|
||||
check_lookup: bool,
|
||||
) -> Result<ForwardResult, Box<dyn Error>> {
|
||||
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
|
||||
.iter()
|
||||
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
|
||||
.collect();
|
||||
let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen)?;
|
||||
let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen, check_lookup)?;
|
||||
Ok(res.into())
|
||||
}
|
||||
|
||||
@@ -1392,6 +1393,7 @@ impl Model {
|
||||
run_args: &RunArgs,
|
||||
inputs: &[ValTensor<Fp>],
|
||||
witness_gen: bool,
|
||||
check_lookup: bool,
|
||||
) -> Result<DummyPassRes, Box<dyn Error>> {
|
||||
debug!("calculating num of constraints using dummy model layout...");
|
||||
|
||||
@@ -1410,7 +1412,8 @@ impl Model {
|
||||
vars: ModelVars::new_dummy(),
|
||||
};
|
||||
|
||||
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen);
|
||||
let mut region =
|
||||
RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen, check_lookup);
|
||||
|
||||
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
|
||||
|
||||
|
||||
@@ -314,7 +314,6 @@ impl GraphModules {
|
||||
let commitments = inputs.iter().fold(vec![], |mut acc, x| {
|
||||
let res = PolyCommitChip::commit::<Scheme>(
|
||||
x.to_vec(),
|
||||
vk.cs().degree() as u32,
|
||||
(vk.cs().blinding_factors() + 1) as u32,
|
||||
srs,
|
||||
);
|
||||
|
||||
492
src/graph/postgres.rs
Normal file
492
src/graph/postgres.rs
Normal file
@@ -0,0 +1,492 @@
|
||||
use log::{debug, error, info};
|
||||
use std::fmt::Debug;
|
||||
use std::net::IpAddr;
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, pin::Pin};
|
||||
use tokio::task::JoinHandle;
|
||||
#[doc(inline)]
|
||||
pub use tokio_postgres::config::{
|
||||
ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs,
|
||||
};
|
||||
use tokio_postgres::tls::NoTlsStream;
|
||||
use tokio_postgres::NoTls;
|
||||
use tokio_postgres::{error::DbError, types::ToSql, Error, Row, Socket, ToStatement};
|
||||
|
||||
/// Connection configuration.
|
||||
///
|
||||
/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
|
||||
///
|
||||
///
|
||||
#[derive(Clone)]
|
||||
pub struct Config {
|
||||
config: tokio_postgres::Config,
|
||||
notice_callback: Arc<dyn Fn(DbError) + Send + Sync>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for Config {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
fmt.debug_struct("Config")
|
||||
.field("config", &self.config)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Config {
|
||||
Config::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Creates a new configuration.
|
||||
pub fn new() -> Config {
|
||||
tokio_postgres::Config::new().into()
|
||||
}
|
||||
|
||||
/// Sets the user to authenticate with.
|
||||
///
|
||||
/// If the user is not set, then this defaults to the user executing this process.
|
||||
pub fn user(&mut self, user: &str) -> &mut Config {
|
||||
self.config.user(user);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the user to authenticate with, if one has been configured with
|
||||
/// the `user` method.
|
||||
pub fn get_user(&self) -> Option<&str> {
|
||||
self.config.get_user()
|
||||
}
|
||||
|
||||
/// Sets the password to authenticate with.
|
||||
pub fn password<T>(&mut self, password: T) -> &mut Config
|
||||
where
|
||||
T: AsRef<[u8]>,
|
||||
{
|
||||
self.config.password(password);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the password to authenticate with, if one has been configured with
|
||||
/// the `password` method.
|
||||
pub fn get_password(&self) -> Option<&[u8]> {
|
||||
self.config.get_password()
|
||||
}
|
||||
|
||||
/// Sets the name of the database to connect to.
|
||||
///
|
||||
/// Defaults to the user.
|
||||
pub fn dbname(&mut self, dbname: &str) -> &mut Config {
|
||||
self.config.dbname(dbname);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the name of the database to connect to, if one has been configured
|
||||
/// with the `dbname` method.
|
||||
pub fn get_dbname(&self) -> Option<&str> {
|
||||
self.config.get_dbname()
|
||||
}
|
||||
|
||||
/// Sets command line options used to configure the server.
|
||||
pub fn options(&mut self, options: &str) -> &mut Config {
|
||||
self.config.options(options);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the command line options used to configure the server, if the
|
||||
/// options have been set with the `options` method.
|
||||
pub fn get_options(&self) -> Option<&str> {
|
||||
self.config.get_options()
|
||||
}
|
||||
|
||||
/// Sets the value of the `application_name` runtime parameter.
|
||||
pub fn application_name(&mut self, application_name: &str) -> &mut Config {
|
||||
self.config.application_name(application_name);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the value of the `application_name` runtime parameter, if it has
|
||||
/// been set with the `application_name` method.
|
||||
pub fn get_application_name(&self) -> Option<&str> {
|
||||
self.config.get_application_name()
|
||||
}
|
||||
|
||||
/// Sets the SSL configuration.
|
||||
///
|
||||
/// Defaults to `prefer`.
|
||||
pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
|
||||
self.config.ssl_mode(ssl_mode);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the SSL configuration.
|
||||
pub fn get_ssl_mode(&self) -> SslMode {
|
||||
self.config.get_ssl_mode()
|
||||
}
|
||||
|
||||
/// Adds a host to the configuration.
|
||||
///
|
||||
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
|
||||
/// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets.
|
||||
/// There must be either no hosts, or the same number of hosts as hostaddrs.
|
||||
pub fn host(&mut self, host: &str) -> &mut Config {
|
||||
self.config.host(host);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the hosts that have been added to the configuration with `host`.
|
||||
pub fn get_hosts(&self) -> &[Host] {
|
||||
self.config.get_hosts()
|
||||
}
|
||||
|
||||
/// Gets the hostaddrs that have been added to the configuration with `hostaddr`.
|
||||
pub fn get_hostaddrs(&self) -> &[IpAddr] {
|
||||
self.config.get_hostaddrs()
|
||||
}
|
||||
|
||||
/// Adds a Unix socket host to the configuration.
|
||||
///
|
||||
/// Unlike `host`, this method allows non-UTF8 paths.
|
||||
#[cfg(unix)]
|
||||
pub fn host_path<T>(&mut self, host: T) -> &mut Config
|
||||
where
|
||||
T: AsRef<Path>,
|
||||
{
|
||||
self.config.host_path(host);
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds a hostaddr to the configuration.
|
||||
///
|
||||
/// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order.
|
||||
/// There must be either no hostaddrs, or the same number of hostaddrs as hosts.
|
||||
pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
|
||||
self.config.hostaddr(hostaddr);
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds a port to the configuration.
|
||||
///
|
||||
/// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
|
||||
/// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports
|
||||
/// as hosts.
|
||||
pub fn port(&mut self, port: u16) -> &mut Config {
|
||||
self.config.port(port);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the ports that have been added to the configuration with `port`.
|
||||
pub fn get_ports(&self) -> &[u16] {
|
||||
self.config.get_ports()
|
||||
}
|
||||
|
||||
/// Sets the timeout applied to socket-level connection attempts.
|
||||
///
|
||||
/// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
|
||||
/// host separately. Defaults to no limit.
|
||||
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
|
||||
self.config.connect_timeout(connect_timeout);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the connection timeout, if one has been set with the
|
||||
/// `connect_timeout` method.
|
||||
pub fn get_connect_timeout(&self) -> Option<&Duration> {
|
||||
self.config.get_connect_timeout()
|
||||
}
|
||||
|
||||
/// Sets the TCP user timeout.
|
||||
///
|
||||
/// This is ignored for Unix domain socket connections. It is only supported on systems where
|
||||
/// TCP_USER_TIMEOUT is available and will default to the system default if omitted or set to 0;
|
||||
/// on other systems, it has no effect.
|
||||
pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
|
||||
self.config.tcp_user_timeout(tcp_user_timeout);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the TCP user timeout, if one has been set with the
|
||||
/// `user_timeout` method.
|
||||
pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
|
||||
self.config.get_tcp_user_timeout()
|
||||
}
|
||||
|
||||
/// Controls the use of TCP keepalive.
|
||||
///
|
||||
/// This is ignored for Unix domain socket connections. Defaults to `true`.
|
||||
pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
|
||||
self.config.keepalives(keepalives);
|
||||
self
|
||||
}
|
||||
|
||||
/// Reports whether TCP keepalives will be used.
|
||||
pub fn get_keepalives(&self) -> bool {
|
||||
self.config.get_keepalives()
|
||||
}
|
||||
|
||||
/// Sets the amount of idle time before a keepalive packet is sent on the connection.
|
||||
///
|
||||
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours.
|
||||
pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
|
||||
self.config.keepalives_idle(keepalives_idle);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the configured amount of idle time before a keepalive packet will
|
||||
/// be sent on the connection.
|
||||
pub fn get_keepalives_idle(&self) -> Duration {
|
||||
self.config.get_keepalives_idle()
|
||||
}
|
||||
|
||||
/// Sets the time interval between TCP keepalive probes.
|
||||
/// On Windows, this sets the value of the tcp_keepalive struct’s keepaliveinterval field.
|
||||
///
|
||||
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
|
||||
pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
|
||||
self.config.keepalives_interval(keepalives_interval);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the time interval between TCP keepalive probes.
|
||||
pub fn get_keepalives_interval(&self) -> Option<Duration> {
|
||||
self.config.get_keepalives_interval()
|
||||
}
|
||||
|
||||
/// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
|
||||
///
|
||||
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
|
||||
pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
|
||||
self.config.keepalives_retries(keepalives_retries);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
|
||||
pub fn get_keepalives_retries(&self) -> Option<u32> {
|
||||
self.config.get_keepalives_retries()
|
||||
}
|
||||
|
||||
/// Sets the requirements of the session.
|
||||
///
|
||||
/// This can be used to connect to the primary server in a clustered database rather than one of the read-only
|
||||
/// secondary servers. Defaults to `Any`.
|
||||
pub fn target_session_attrs(
|
||||
&mut self,
|
||||
target_session_attrs: TargetSessionAttrs,
|
||||
) -> &mut Config {
|
||||
self.config.target_session_attrs(target_session_attrs);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the requirements of the session.
|
||||
pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
|
||||
self.config.get_target_session_attrs()
|
||||
}
|
||||
|
||||
/// Sets the channel binding behavior.
|
||||
///
|
||||
/// Defaults to `prefer`.
|
||||
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
|
||||
self.config.channel_binding(channel_binding);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the channel binding behavior.
|
||||
pub fn get_channel_binding(&self) -> ChannelBinding {
|
||||
self.config.get_channel_binding()
|
||||
}
|
||||
|
||||
/// Sets the host load balancing behavior.
|
||||
///
|
||||
/// Defaults to `disable`.
|
||||
pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
|
||||
self.config.load_balance_hosts(load_balance_hosts);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the host load balancing behavior.
|
||||
pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
|
||||
self.config.get_load_balance_hosts()
|
||||
}
|
||||
|
||||
/// Sets the notice callback.
|
||||
///
|
||||
/// This callback will be invoked with the contents of every
|
||||
/// [`AsyncMessage::Notice`] that is received by the connection. Notices use
|
||||
/// the same structure as errors, but they are not "errors" per-se.
|
||||
///
|
||||
/// Notices are distinct from notifications, which are instead accessible
|
||||
/// via the [`Notifications`] API.
|
||||
///
|
||||
/// [`AsyncMessage::Notice`]: tokio_postgres::AsyncMessage::Notice
|
||||
/// [`Notifications`]: crate::Notifications
|
||||
pub fn notice_callback<F>(&mut self, f: F) -> &mut Config
|
||||
where
|
||||
F: Fn(DbError) + Send + Sync + 'static,
|
||||
{
|
||||
self.notice_callback = Arc::new(f);
|
||||
self
|
||||
}
|
||||
|
||||
/// Opens a connection to a PostgreSQL database.
|
||||
pub async fn connect(&self) -> Result<Client, Error> {
|
||||
let (client, connection) = self.config.connect(NoTls).await?;
|
||||
|
||||
let connection = Connection::new(connection);
|
||||
|
||||
Ok(Client::new(client, connection))
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Config {
|
||||
type Err = Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Config, Error> {
|
||||
s.parse::<tokio_postgres::Config>().map(Config::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tokio_postgres::Config> for Config {
|
||||
fn from(config: tokio_postgres::Config) -> Config {
|
||||
Config {
|
||||
config,
|
||||
notice_callback: Arc::new(|notice| {
|
||||
info!("{}: {}", notice.severity(), notice.message())
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_debug_implementations, dead_code)]
|
||||
/// An asynchronous PostgreSQL connection. We use this to keep the connection alive / keep it pinned so that it doesn't
|
||||
/// get dropped.
|
||||
pub struct Connection {
|
||||
/// The underlying connection stream.
|
||||
connection: Pin<Box<tokio_postgres::Connection<Socket, NoTlsStream>>>,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
/// Creates a new connection.
|
||||
pub fn new(connection: tokio_postgres::Connection<Socket, NoTlsStream>) -> Self {
|
||||
Connection {
|
||||
connection: Box::pin(connection),
|
||||
}
|
||||
}
|
||||
|
||||
/// start the connection
|
||||
pub async fn start(self) {
|
||||
if let Err(e) = self.connection.await {
|
||||
error!("connection error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_debug_implementations, dead_code)]
|
||||
/// An asynchronous PostgreSQL client.
|
||||
pub struct Client {
|
||||
connection: JoinHandle<()>,
|
||||
client: tokio_postgres::Client,
|
||||
}
|
||||
|
||||
impl Drop for Client {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.close_inner();
|
||||
}
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub(crate) fn new(client: tokio_postgres::Client, connection: Connection) -> Client {
|
||||
// The connection object performs the actual communication with the database,
|
||||
// so spawn it off to run on its own.
|
||||
let thread = tokio::spawn(async move {
|
||||
connection.start().await;
|
||||
});
|
||||
|
||||
Client {
|
||||
client,
|
||||
connection: thread,
|
||||
}
|
||||
}
|
||||
|
||||
/// A convenience function which parses a configuration string into a `Config` and then connects to the database.
|
||||
///
|
||||
/// See the documentation for [`Config`] for information about the connection syntax.
|
||||
///
|
||||
/// [`Config`]: config/struct.Config.html
|
||||
pub async fn connect(params: &str) -> Result<Client, Error> {
|
||||
debug!("Connecting to database with params: {}", params);
|
||||
params.parse::<Config>()?.connect().await
|
||||
}
|
||||
|
||||
/// Returns a new `Config` object which can be used to configure and connect to a database.
|
||||
pub fn configure() -> Config {
|
||||
Config::new()
|
||||
}
|
||||
|
||||
/// Executes a statement, returning the number of rows modified.
|
||||
///
|
||||
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
|
||||
/// provided, 1-indexed.
|
||||
///
|
||||
/// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
|
||||
///
|
||||
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
|
||||
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
|
||||
/// with the `prepare` method.
|
||||
///
|
||||
pub async fn execute<T>(
|
||||
&mut self,
|
||||
query: &T,
|
||||
params: &[&(dyn ToSql + Sync)],
|
||||
) -> Result<u64, Error>
|
||||
where
|
||||
T: ?Sized + ToStatement + Debug,
|
||||
{
|
||||
debug!("Executing query: {:?}", query);
|
||||
self.client.execute(query, params).await
|
||||
}
|
||||
|
||||
/// Executes a statement, returning the resulting rows.
|
||||
///
|
||||
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
|
||||
/// provided, 1-indexed.
|
||||
///
|
||||
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
|
||||
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
|
||||
/// with the `prepare` method.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
pub async fn query<T>(
|
||||
&mut self,
|
||||
query: &T,
|
||||
params: &[&(dyn ToSql + Sync)],
|
||||
) -> Result<Vec<Row>, Error>
|
||||
where
|
||||
T: ?Sized + ToStatement + Debug,
|
||||
{
|
||||
debug!("Executing query: {:?}", query);
|
||||
self.client.query(query, params).await
|
||||
}
|
||||
|
||||
/// Determines if the client's connection has already closed.
|
||||
///
|
||||
/// If this returns `true`, the client is no longer usable.
|
||||
pub fn is_closed(&self) -> bool {
|
||||
self.client.is_closed()
|
||||
}
|
||||
|
||||
/// Closes the client's connection to the server.
|
||||
///
|
||||
/// This is equivalent to `Client`'s `Drop` implementation, except that it returns any error encountered to the
|
||||
/// caller.
|
||||
pub fn close(mut self) -> Result<(), Error> {
|
||||
self.close_inner()
|
||||
}
|
||||
|
||||
fn close_inner(&mut self) -> Result<(), Error> {
|
||||
self.client.__private_api_close();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -52,16 +52,16 @@ use tract_onnx::tract_hir::{
|
||||
/// * `dims` - the dimensionality of the resulting [Tensor].
|
||||
/// * `shift` - offset used in the fixed point representation.
|
||||
/// * `scale` - `2^scale` used in the fixed point representation.
|
||||
pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result<i128, TensorError> {
|
||||
pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result<i64, TensorError> {
|
||||
let mult = scale_to_multiplier(scale);
|
||||
let max_value = ((i128::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
|
||||
let max_value = ((i64::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
|
||||
|
||||
if *elem > max_value {
|
||||
return Err(TensorError::SigBitTruncationError);
|
||||
}
|
||||
|
||||
// we parallelize the quantization process as it seems to be quite slow at times
|
||||
let scaled = (mult * *elem + shift).round() as i128;
|
||||
let scaled = (mult * *elem + shift).round() as i64;
|
||||
|
||||
Ok(scaled)
|
||||
}
|
||||
@@ -72,7 +72,7 @@ pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result<i12
|
||||
/// * `scale` - `2^scale` used in the fixed point representation.
|
||||
/// * `shift` - offset used in the fixed point representation.
|
||||
pub fn dequantize(felt: Fp, scale: crate::Scale, shift: f64) -> f64 {
|
||||
let int_rep = crate::fieldutils::felt_to_i128(felt);
|
||||
let int_rep = crate::fieldutils::felt_to_i64(felt);
|
||||
let multiplier = scale_to_multiplier(scale);
|
||||
int_rep as f64 / multiplier - shift
|
||||
}
|
||||
@@ -1475,7 +1475,7 @@ pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
|
||||
visibility: &Visibility,
|
||||
) -> Result<Tensor<F>, Box<dyn std::error::Error>> {
|
||||
let mut value: Tensor<F> = const_value.par_enum_map(|_, x| {
|
||||
Ok::<_, TensorError>(crate::fieldutils::i128_to_felt::<F>(quantize_float(
|
||||
Ok::<_, TensorError>(crate::fieldutils::i64_to_felt::<F>(quantize_float(
|
||||
&(x).into(),
|
||||
0.0,
|
||||
scale,
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
)]
|
||||
// we allow this for our dynamic range based indexing scheme
|
||||
#![allow(clippy::single_range_in_vec_init)]
|
||||
#![feature(stmt_expr_attributes)]
|
||||
|
||||
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
|
||||
//!
|
||||
@@ -189,7 +190,7 @@ pub struct RunArgs {
|
||||
#[arg(long, default_value = "1")]
|
||||
pub scale_rebase_multiplier: u32,
|
||||
/// The min and max elements in the lookup table input column
|
||||
#[arg(short = 'B', long, value_parser = parse_key_val::<i128, i128>, default_value = "-32768->32768")]
|
||||
#[arg(short = 'B', long, value_parser = parse_key_val::<i64, i64>, default_value = "-32768->32768")]
|
||||
pub lookup_range: Range,
|
||||
/// The log_2 number of rows
|
||||
#[arg(short = 'K', long, default_value = "17")]
|
||||
@@ -200,13 +201,13 @@ pub struct RunArgs {
|
||||
/// Hand-written parser for graph variables, eg. batch_size=1
|
||||
#[arg(short = 'V', long, value_parser = parse_key_val::<String, usize>, default_value = "batch_size->1", value_delimiter = ',')]
|
||||
pub variables: Vec<(String, usize)>,
|
||||
/// Flags whether inputs are public, private, hashed, fixed, kzgcommit
|
||||
/// Flags whether inputs are public, private, fixed, hashed, polycommit
|
||||
#[arg(long, default_value = "private")]
|
||||
pub input_visibility: Visibility,
|
||||
/// Flags whether outputs are public, private, fixed, hashed, kzgcommit
|
||||
/// Flags whether outputs are public, private, fixed, hashed, polycommit
|
||||
#[arg(long, default_value = "public")]
|
||||
pub output_visibility: Visibility,
|
||||
/// Flags whether params are fixed, private, hashed, kzgcommit
|
||||
/// Flags whether params are fixed, private, hashed, polycommit
|
||||
#[arg(long, default_value = "private")]
|
||||
pub param_visibility: Visibility,
|
||||
#[arg(long, default_value = "false")]
|
||||
|
||||
269
src/python.rs
269
src/python.rs
@@ -6,7 +6,7 @@ use crate::circuit::modules::poseidon::{
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::circuit::{CheckMode, Tolerance};
|
||||
use crate::commands::*;
|
||||
use crate::fieldutils::{felt_to_i128, i128_to_felt};
|
||||
use crate::fieldutils::{felt_to_i64, i64_to_felt};
|
||||
use crate::graph::modules::POSEIDON_LEN_GRAPH;
|
||||
use crate::graph::TestDataSource;
|
||||
use crate::graph::{
|
||||
@@ -29,7 +29,6 @@ use pyo3_log;
|
||||
use snark_verifier::util::arithmetic::PrimeField;
|
||||
use std::str::FromStr;
|
||||
use std::{fs::File, path::PathBuf};
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
type PyFelt = String;
|
||||
|
||||
@@ -332,9 +331,9 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
|
||||
#[pyfunction(signature = (
|
||||
felt,
|
||||
))]
|
||||
fn felt_to_int(felt: PyFelt) -> PyResult<i128> {
|
||||
fn felt_to_int(felt: PyFelt) -> PyResult<i64> {
|
||||
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
|
||||
let int_rep = felt_to_i128(felt);
|
||||
let int_rep = felt_to_i64(felt);
|
||||
Ok(int_rep)
|
||||
}
|
||||
|
||||
@@ -358,7 +357,7 @@ fn felt_to_int(felt: PyFelt) -> PyResult<i128> {
|
||||
))]
|
||||
fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
|
||||
let int_rep = felt_to_i128(felt);
|
||||
let int_rep = felt_to_i64(felt);
|
||||
let multiplier = scale_to_multiplier(scale);
|
||||
let float_rep = int_rep as f64 / multiplier;
|
||||
Ok(float_rep)
|
||||
@@ -386,7 +385,7 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
|
||||
let int_rep = quantize_float(&input, 0.0, scale)
|
||||
.map_err(|_| PyIOError::new_err("Failed to quantize input"))?;
|
||||
let felt = i128_to_felt(int_rep);
|
||||
let felt = i64_to_felt(int_rep);
|
||||
Ok(crate::pfsys::field_to_string::<Fr>(&felt))
|
||||
}
|
||||
|
||||
@@ -547,7 +546,6 @@ fn kzg_commit(
|
||||
|
||||
let output = PolyCommitChip::commit::<KZGCommitmentScheme<Bn256>>(
|
||||
message,
|
||||
vk.cs().degree() as u32,
|
||||
(vk.cs().blinding_factors() + 1) as u32,
|
||||
&srs,
|
||||
);
|
||||
@@ -606,7 +604,6 @@ fn ipa_commit(
|
||||
|
||||
let output = PolyCommitChip::commit::<IPACommitmentScheme<G1Affine>>(
|
||||
message,
|
||||
vk.cs().degree() as u32,
|
||||
(vk.cs().blinding_factors() + 1) as u32,
|
||||
&srs,
|
||||
);
|
||||
@@ -781,29 +778,27 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
|
||||
commitment=None,
|
||||
))]
|
||||
fn get_srs(
|
||||
py: Python,
|
||||
settings_path: Option<PathBuf>,
|
||||
logrows: Option<u32>,
|
||||
srs_path: Option<PathBuf>,
|
||||
commitment: Option<PyCommitments>,
|
||||
) -> PyResult<bool> {
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let commitment: Option<Commitments> = match commitment {
|
||||
Some(c) => Some(c.into()),
|
||||
None => None,
|
||||
};
|
||||
|
||||
Runtime::new()
|
||||
.unwrap()
|
||||
.block_on(crate::execute::get_srs_cmd(
|
||||
srs_path,
|
||||
settings_path,
|
||||
logrows,
|
||||
commitment,
|
||||
))
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to get srs: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
Ok(true)
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::get_srs_cmd(srs_path, settings_path, logrows, commitment)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to get srs: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates the circuit settings
|
||||
@@ -887,33 +882,37 @@ fn gen_settings(
|
||||
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
|
||||
))]
|
||||
fn calibrate_settings(
|
||||
py: Python,
|
||||
data: PathBuf,
|
||||
model: PathBuf,
|
||||
settings: PathBuf,
|
||||
target: CalibrationTarget,
|
||||
lookup_safety_margin: i128,
|
||||
lookup_safety_margin: i64,
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
max_logrows: Option<u32>,
|
||||
only_range_check_rebase: bool,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::calibrate(
|
||||
model,
|
||||
data,
|
||||
settings,
|
||||
target,
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
only_range_check_rebase,
|
||||
max_logrows,
|
||||
)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to calibrate settings: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::calibrate(
|
||||
model,
|
||||
data,
|
||||
settings,
|
||||
target,
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
only_range_check_rebase,
|
||||
max_logrows,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to calibrate settings: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
Ok(true)
|
||||
})
|
||||
}
|
||||
|
||||
/// Runs the forward pass operation to generate a witness
|
||||
@@ -948,18 +947,22 @@ fn calibrate_settings(
|
||||
srs_path=None,
|
||||
))]
|
||||
fn gen_witness(
|
||||
py: Python,
|
||||
data: PathBuf,
|
||||
model: PathBuf,
|
||||
output: Option<PathBuf>,
|
||||
vk_path: Option<PathBuf>,
|
||||
srs_path: Option<PathBuf>,
|
||||
) -> PyResult<PyObject> {
|
||||
let output =
|
||||
crate::execute::gen_witness(model, data, output, vk_path, srs_path).map_err(|e| {
|
||||
let err_str = format!("Failed to run generate witness: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
Python::with_gil(|py| Ok(output.to_object(py)))
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
let output = crate::execute::gen_witness(model, data, output, vk_path, srs_path)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to generate witness: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
Python::with_gil(|py| Ok(output.to_object(py)))
|
||||
})
|
||||
}
|
||||
|
||||
/// Mocks the prover
|
||||
@@ -1462,27 +1465,31 @@ fn verify_aggr(
|
||||
render_vk_seperately = DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap(),
|
||||
))]
|
||||
fn create_evm_verifier(
|
||||
py: Python,
|
||||
vk_path: PathBuf,
|
||||
settings_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
abi_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
render_vk_seperately: bool,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::create_evm_verifier(
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
render_vk_seperately,
|
||||
)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run create_evm_verifier: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::create_evm_verifier(
|
||||
vk_path,
|
||||
srs_path,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
render_vk_seperately,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run create_evm_verifier: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
Ok(true)
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
|
||||
@@ -1512,18 +1519,27 @@ fn create_evm_verifier(
|
||||
abi_path=PathBuf::from(DEFAULT_VERIFIER_DA_ABI),
|
||||
))]
|
||||
fn create_evm_data_attestation(
|
||||
py: Python,
|
||||
input_data: PathBuf,
|
||||
settings_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
abi_path: PathBuf,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::create_evm_data_attestation(settings_path, sol_code_path, abi_path, input_data)
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::create_evm_data_attestation(
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
input_data,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run create_evm_data_attestation: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
Ok(true)
|
||||
})
|
||||
}
|
||||
|
||||
/// Setup test evm witness
|
||||
@@ -1561,29 +1577,31 @@ fn create_evm_data_attestation(
|
||||
rpc_url=None,
|
||||
))]
|
||||
fn setup_test_evm_witness(
|
||||
py: Python,
|
||||
data_path: PathBuf,
|
||||
compiled_circuit_path: PathBuf,
|
||||
test_data: PathBuf,
|
||||
input_source: PyTestDataSource,
|
||||
output_source: PyTestDataSource,
|
||||
rpc_url: Option<String>,
|
||||
) -> Result<bool, PyErr> {
|
||||
Runtime::new()
|
||||
.unwrap()
|
||||
.block_on(crate::execute::setup_test_evm_witness(
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::setup_test_evm_witness(
|
||||
data_path,
|
||||
compiled_circuit_path,
|
||||
test_data,
|
||||
rpc_url,
|
||||
input_source.into(),
|
||||
output_source.into(),
|
||||
))
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run setup_test_evm_witness: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
Ok(true)
|
||||
})
|
||||
}
|
||||
|
||||
/// deploys the solidity verifier
|
||||
@@ -1595,28 +1613,30 @@ fn setup_test_evm_witness(
|
||||
private_key=None,
|
||||
))]
|
||||
fn deploy_evm(
|
||||
py: Python,
|
||||
addr_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
optimizer_runs: usize,
|
||||
private_key: Option<String>,
|
||||
) -> Result<bool, PyErr> {
|
||||
Runtime::new()
|
||||
.unwrap()
|
||||
.block_on(crate::execute::deploy_evm(
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::deploy_evm(
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2Verifier",
|
||||
))
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run deploy_evm: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
Ok(true)
|
||||
})
|
||||
}
|
||||
|
||||
/// deploys the solidity vk verifier
|
||||
@@ -1628,28 +1648,30 @@ fn deploy_evm(
|
||||
private_key=None,
|
||||
))]
|
||||
fn deploy_vk_evm(
|
||||
py: Python,
|
||||
addr_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
optimizer_runs: usize,
|
||||
private_key: Option<String>,
|
||||
) -> Result<bool, PyErr> {
|
||||
Runtime::new()
|
||||
.unwrap()
|
||||
.block_on(crate::execute::deploy_evm(
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::deploy_evm(
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
"Halo2VerifyingKey",
|
||||
))
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run deploy_evm: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
Ok(true)
|
||||
})
|
||||
}
|
||||
|
||||
/// deploys the solidity da verifier
|
||||
@@ -1663,6 +1685,7 @@ fn deploy_vk_evm(
|
||||
private_key=None
|
||||
))]
|
||||
fn deploy_da_evm(
|
||||
py: Python,
|
||||
addr_path: PathBuf,
|
||||
input_data: PathBuf,
|
||||
settings_path: PathBuf,
|
||||
@@ -1670,10 +1693,9 @@ fn deploy_da_evm(
|
||||
rpc_url: Option<String>,
|
||||
optimizer_runs: usize,
|
||||
private_key: Option<String>,
|
||||
) -> Result<bool, PyErr> {
|
||||
Runtime::new()
|
||||
.unwrap()
|
||||
.block_on(crate::execute::deploy_da_evm(
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::deploy_da_evm(
|
||||
input_data,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
@@ -1681,13 +1703,15 @@ fn deploy_da_evm(
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
))
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run deploy_da_evm: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
Ok(true)
|
||||
})
|
||||
}
|
||||
/// verifies an evm compatible proof, you will need solc installed in your environment to run this
|
||||
///
|
||||
@@ -1718,13 +1742,14 @@ fn deploy_da_evm(
|
||||
addr_da = None,
|
||||
addr_vk = None,
|
||||
))]
|
||||
fn verify_evm(
|
||||
addr_verifier: &str,
|
||||
fn verify_evm<'a>(
|
||||
py: Python<'a>,
|
||||
addr_verifier: &'a str,
|
||||
proof_path: PathBuf,
|
||||
rpc_url: Option<String>,
|
||||
addr_da: Option<&str>,
|
||||
addr_vk: Option<&str>,
|
||||
) -> Result<bool, PyErr> {
|
||||
addr_da: Option<&'a str>,
|
||||
addr_vk: Option<&'a str>,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
let addr_verifier = H160Flag::from(addr_verifier);
|
||||
let addr_da = if let Some(addr_da) = addr_da {
|
||||
let addr_da = H160Flag::from(addr_da);
|
||||
@@ -1739,21 +1764,16 @@ fn verify_evm(
|
||||
None
|
||||
};
|
||||
|
||||
Runtime::new()
|
||||
.unwrap()
|
||||
.block_on(crate::execute::verify_evm(
|
||||
proof_path,
|
||||
addr_verifier,
|
||||
rpc_url,
|
||||
addr_da,
|
||||
addr_vk,
|
||||
))
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run verify_evm: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::verify_evm(proof_path, addr_verifier, rpc_url, addr_da, addr_vk)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run verify_evm: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
Ok(true)
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates an evm compatible aggregate verifier, you will need solc installed in your environment to run this
|
||||
@@ -1795,6 +1815,7 @@ fn verify_evm(
|
||||
render_vk_seperately = DEFAULT_RENDER_VK_SEPERATELY.parse().unwrap(),
|
||||
))]
|
||||
fn create_evm_verifier_aggr(
|
||||
py: Python,
|
||||
aggregation_settings: Vec<PathBuf>,
|
||||
vk_path: PathBuf,
|
||||
sol_code_path: PathBuf,
|
||||
@@ -1802,21 +1823,25 @@ fn create_evm_verifier_aggr(
|
||||
logrows: u32,
|
||||
srs_path: Option<PathBuf>,
|
||||
render_vk_seperately: bool,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::create_evm_aggregate_verifier(
|
||||
vk_path,
|
||||
srs_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
aggregation_settings,
|
||||
logrows,
|
||||
render_vk_seperately,
|
||||
)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run create_evm_verifier_aggr: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
Ok(true)
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::create_evm_aggregate_verifier(
|
||||
vk_path,
|
||||
srs_path,
|
||||
sol_code_path,
|
||||
abi_path,
|
||||
aggregation_settings,
|
||||
logrows,
|
||||
render_vk_seperately,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run create_evm_verifier_aggr: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
Ok(true)
|
||||
})
|
||||
}
|
||||
|
||||
// Python Module
|
||||
|
||||
BIN
src/tensor/metal/tensor_ops.air
Normal file
BIN
src/tensor/metal/tensor_ops.air
Normal file
Binary file not shown.
31
src/tensor/metal/tensor_ops.metal
Normal file
31
src/tensor/metal/tensor_ops.metal
Normal file
@@ -0,0 +1,31 @@
|
||||
[[kernel]]
|
||||
void add(
|
||||
constant long *inA [[buffer(0)]],
|
||||
constant long *inB [[buffer(1)]],
|
||||
device long *result [[buffer(2)]],
|
||||
uint index [[thread_position_in_grid]])
|
||||
{
|
||||
result[index] = inA[index] + inB[index];
|
||||
}
|
||||
|
||||
|
||||
[[kernel]]
|
||||
void sub(
|
||||
constant long *inA [[buffer(0)]],
|
||||
constant long *inB [[buffer(1)]],
|
||||
device long *result [[buffer(2)]],
|
||||
uint index [[thread_position_in_grid]])
|
||||
{
|
||||
result[index] = inA[index] - inB[index];
|
||||
}
|
||||
|
||||
|
||||
[[kernel]]
|
||||
void mul(
|
||||
constant long *inA [[buffer(0)]],
|
||||
constant long *inB [[buffer(1)]],
|
||||
device long *result [[buffer(2)]],
|
||||
uint index [[thread_position_in_grid]])
|
||||
{
|
||||
result[index] = inA[index] * inB[index];
|
||||
}
|
||||
BIN
src/tensor/metal/tensor_ops.metallib
Normal file
BIN
src/tensor/metal/tensor_ops.metallib
Normal file
Binary file not shown.
@@ -5,7 +5,7 @@ pub mod val;
|
||||
/// A wrapper around a tensor of Halo2 Value types.
|
||||
pub mod var;
|
||||
|
||||
use halo2curves::ff::PrimeField;
|
||||
use halo2curves::{bn256::Fr, ff::PrimeField};
|
||||
use maybe_rayon::{
|
||||
prelude::{
|
||||
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator,
|
||||
@@ -17,9 +17,12 @@ use serde::{Deserialize, Serialize};
|
||||
pub use val::*;
|
||||
pub use var::*;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use instant::Instant;
|
||||
|
||||
use crate::{
|
||||
circuit::utils,
|
||||
fieldutils::{felt_to_i32, i128_to_felt, i32_to_felt},
|
||||
fieldutils::{felt_to_i32, felt_to_i64, i32_to_felt, i64_to_felt},
|
||||
graph::Visibility,
|
||||
};
|
||||
|
||||
@@ -30,12 +33,18 @@ use halo2_proofs::{
|
||||
poly::Rotation,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
#[cfg(feature = "metal")]
|
||||
use metal::{Device, MTLResourceOptions, MTLSize};
|
||||
use std::error::Error;
|
||||
use std::fmt::Debug;
|
||||
use std::iter::Iterator;
|
||||
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
|
||||
use std::{cmp::max, ops::Rem};
|
||||
use thiserror::Error;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// A wrapper for tensor related errors.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TensorError {
|
||||
@@ -65,6 +74,28 @@ pub enum TensorError {
|
||||
Overflow(String),
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
const LIB_DATA: &[u8] = include_bytes!("metal/tensor_ops.metallib");
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
lazy_static::lazy_static! {
|
||||
static ref DEVICE: Device = Device::system_default().expect("no device found");
|
||||
|
||||
static ref LIB: metal::Library = DEVICE.new_library_with_data(LIB_DATA).unwrap();
|
||||
|
||||
static ref QUEUE: metal::CommandQueue = DEVICE.new_command_queue();
|
||||
|
||||
static ref PIPELINES: HashMap<String, metal::ComputePipelineState> = {
|
||||
let mut map = HashMap::new();
|
||||
for name in ["add", "sub", "mul"] {
|
||||
let function = LIB.get_function(name, None).unwrap();
|
||||
let pipeline = DEVICE.new_compute_pipeline_state_with_function(&function).unwrap();
|
||||
map.insert(name.to_string(), pipeline);
|
||||
}
|
||||
map
|
||||
};
|
||||
}
|
||||
|
||||
/// The (inner) type of tensor elements.
|
||||
pub trait TensorType: Clone + Debug + 'static {
|
||||
/// Returns the zero value.
|
||||
@@ -145,7 +176,7 @@ impl TensorType for f64 {
|
||||
}
|
||||
|
||||
tensor_type!(bool, Bool, false, true);
|
||||
tensor_type!(i128, Int128, 0, 1);
|
||||
tensor_type!(i64, Int64, 0, 1);
|
||||
tensor_type!(i32, Int32, 0, 1);
|
||||
tensor_type!(usize, USize, 0, 1);
|
||||
tensor_type!((), Empty, (), ());
|
||||
@@ -311,6 +342,94 @@ impl<T: TensorType> DerefMut for Tensor<T> {
|
||||
self.inner.deref_mut()
|
||||
}
|
||||
}
|
||||
/// Convert to i64 trait
|
||||
pub trait IntoI64 {
|
||||
/// Convert to i64
|
||||
fn into_i64(self) -> i64;
|
||||
|
||||
/// From i64
|
||||
fn from_i64(i: i64) -> Self;
|
||||
}
|
||||
|
||||
impl IntoI64 for i64 {
|
||||
fn into_i64(self) -> i64 {
|
||||
self
|
||||
}
|
||||
fn from_i64(i: i64) -> i64 {
|
||||
i
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoI64 for i32 {
|
||||
fn into_i64(self) -> i64 {
|
||||
self as i64
|
||||
}
|
||||
fn from_i64(i: i64) -> Self {
|
||||
i as i32
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoI64 for usize {
|
||||
fn into_i64(self) -> i64 {
|
||||
self as i64
|
||||
}
|
||||
fn from_i64(i: i64) -> Self {
|
||||
i as usize
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoI64 for f32 {
|
||||
fn into_i64(self) -> i64 {
|
||||
self as i64
|
||||
}
|
||||
fn from_i64(i: i64) -> Self {
|
||||
i as f32
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoI64 for f64 {
|
||||
fn into_i64(self) -> i64 {
|
||||
self as i64
|
||||
}
|
||||
fn from_i64(i: i64) -> Self {
|
||||
i as f64
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoI64 for () {
|
||||
fn into_i64(self) -> i64 {
|
||||
0
|
||||
}
|
||||
fn from_i64(_: i64) -> Self {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoI64 for Fr {
|
||||
fn into_i64(self) -> i64 {
|
||||
felt_to_i64(self)
|
||||
}
|
||||
fn from_i64(i: i64) -> Self {
|
||||
i64_to_felt::<Fr>(i)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + IntoI64> IntoI64 for Value<F> {
|
||||
fn into_i64(self) -> i64 {
|
||||
let mut res = vec![];
|
||||
self.map(|x| res.push(x.into_i64()));
|
||||
|
||||
if res.is_empty() {
|
||||
0
|
||||
} else {
|
||||
res[0]
|
||||
}
|
||||
}
|
||||
|
||||
fn from_i64(i: i64) -> Self {
|
||||
Value::known(F::from_i64(i))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialEq + TensorType> PartialEq for Tensor<T> {
|
||||
fn eq(&self, other: &Tensor<T>) -> bool {
|
||||
@@ -427,10 +546,10 @@ impl<F: PrimeField + TensorType + Clone> From<Tensor<i32>> for Tensor<Value<F>>
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + Clone> From<Tensor<i128>> for Tensor<Value<F>> {
|
||||
fn from(t: Tensor<i128>) -> Tensor<Value<F>> {
|
||||
impl<F: PrimeField + TensorType + Clone> From<Tensor<i64>> for Tensor<Value<F>> {
|
||||
fn from(t: Tensor<i64>) -> Tensor<Value<F>> {
|
||||
let mut ta: Tensor<Value<F>> =
|
||||
Tensor::from((0..t.len()).map(|i| Value::known(i128_to_felt::<F>(t[i]))));
|
||||
Tensor::from((0..t.len()).map(|i| Value::known(i64_to_felt::<F>(t[i]))));
|
||||
// safe to unwrap as we know the dims are correct
|
||||
ta.reshape(t.dims()).unwrap();
|
||||
ta
|
||||
@@ -1217,6 +1336,97 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
#[allow(unsafe_code)]
|
||||
/// Perform a tensor operation on the GPU using Metal.
|
||||
pub fn metal_tensor_op<T: Clone + TensorType + IntoI64 + Send + Sync>(
|
||||
v: &Tensor<T>,
|
||||
w: &Tensor<T>,
|
||||
op: &str,
|
||||
) -> Tensor<T> {
|
||||
assert_eq!(v.dims(), w.dims());
|
||||
|
||||
log::trace!("------------------------------------------------");
|
||||
|
||||
let start = Instant::now();
|
||||
let v = v
|
||||
.par_enum_map(|_, x| Ok::<_, TensorError>(x.into_i64()))
|
||||
.unwrap();
|
||||
let w = w
|
||||
.par_enum_map(|_, x| Ok::<_, TensorError>(x.into_i64()))
|
||||
.unwrap();
|
||||
log::trace!("Time to map tensors: {:?}", start.elapsed());
|
||||
|
||||
objc::rc::autoreleasepool(|| {
|
||||
// create function pipeline.
|
||||
// this compiles the function, so a pipline can't be created in performance sensitive code.
|
||||
|
||||
let pipeline = &PIPELINES[op];
|
||||
|
||||
let length = v.len() as u64;
|
||||
let size = length * core::mem::size_of::<i64>() as u64;
|
||||
assert_eq!(v.len(), w.len());
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
let buffer_a = DEVICE.new_buffer_with_data(
|
||||
unsafe { std::mem::transmute(v.as_ptr()) },
|
||||
size,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
let buffer_b = DEVICE.new_buffer_with_data(
|
||||
unsafe { std::mem::transmute(w.as_ptr()) },
|
||||
size,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
let buffer_result = DEVICE.new_buffer(
|
||||
size, // the operation will return an array with the same size.
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
log::trace!("Time to load buffers: {:?}", start.elapsed());
|
||||
|
||||
// for sending commands, a command buffer is needed.
|
||||
let start = Instant::now();
|
||||
let command_buffer = QUEUE.new_command_buffer();
|
||||
log::trace!("Time to load command buffer: {:?}", start.elapsed());
|
||||
// to write commands into a buffer an encoder is needed, in our case a compute encoder.
|
||||
let start = Instant::now();
|
||||
let compute_encoder = command_buffer.new_compute_command_encoder();
|
||||
compute_encoder.set_compute_pipeline_state(&pipeline);
|
||||
compute_encoder.set_buffers(
|
||||
0,
|
||||
&[Some(&buffer_a), Some(&buffer_b), Some(&buffer_result)],
|
||||
&[0; 3],
|
||||
);
|
||||
log::trace!("Time to load compute encoder: {:?}", start.elapsed());
|
||||
|
||||
// specify thread count and organization
|
||||
let start = Instant::now();
|
||||
let grid_size = MTLSize::new(length, 1, 1);
|
||||
let threadgroup_size = MTLSize::new(length, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_size, threadgroup_size);
|
||||
log::trace!("Time to dispatch threads: {:?}", start.elapsed());
|
||||
|
||||
// end encoding and execute commands
|
||||
let start = Instant::now();
|
||||
compute_encoder.end_encoding();
|
||||
command_buffer.commit();
|
||||
|
||||
command_buffer.wait_until_completed();
|
||||
log::trace!("Time to commit: {:?}", start.elapsed());
|
||||
|
||||
let start = Instant::now();
|
||||
let ptr = buffer_result.contents() as *const i64;
|
||||
let len = buffer_result.length() as usize / std::mem::size_of::<i64>();
|
||||
let slice = unsafe { core::slice::from_raw_parts(ptr, len) };
|
||||
let res = Tensor::new(Some(&slice.to_vec()), &v.dims()).unwrap();
|
||||
log::trace!("Time to get result: {:?}", start.elapsed());
|
||||
|
||||
res.map(|x| T::from_i64(x))
|
||||
})
|
||||
}
|
||||
|
||||
impl<T: Clone + TensorType> Tensor<Tensor<T>> {
|
||||
/// Flattens a tensor of tensors
|
||||
/// ```
|
||||
@@ -1238,7 +1448,9 @@ impl<T: Clone + TensorType> Tensor<Tensor<T>> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync> Add for Tensor<T> {
|
||||
impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Add
|
||||
for Tensor<T>
|
||||
{
|
||||
type Output = Result<Tensor<T>, TensorError>;
|
||||
/// Adds tensors.
|
||||
/// # Arguments
|
||||
@@ -1288,14 +1500,24 @@ impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync> Ad
|
||||
/// ```
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
let broadcasted_shape = get_broadcasted_shape(self.dims(), rhs.dims()).unwrap();
|
||||
let mut lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let rhs = rhs.expand(&broadcasted_shape).unwrap();
|
||||
|
||||
lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
|
||||
*o = o.clone() + r;
|
||||
});
|
||||
#[cfg(feature = "metal")]
|
||||
let res = metal_tensor_op(&lhs, &rhs, "add");
|
||||
|
||||
Ok(lhs)
|
||||
#[cfg(not(feature = "metal"))]
|
||||
let res = {
|
||||
let mut res: Tensor<T> = lhs
|
||||
.par_iter()
|
||||
.zip(rhs)
|
||||
.map(|(o, r)| o.clone() + r)
|
||||
.collect();
|
||||
res.reshape(&broadcasted_shape).unwrap();
|
||||
res
|
||||
};
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1318,6 +1540,7 @@ impl<T: TensorType + Neg<Output = T> + std::marker::Send + std::marker::Sync> Ne
|
||||
/// ```
|
||||
fn neg(self) -> Self {
|
||||
let mut output = self;
|
||||
|
||||
output.par_iter_mut().for_each(|x| {
|
||||
*x = x.clone().neg();
|
||||
});
|
||||
@@ -1325,7 +1548,9 @@ impl<T: TensorType + Neg<Output = T> + std::marker::Send + std::marker::Sync> Ne
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Sub for Tensor<T> {
|
||||
impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Sub
|
||||
for Tensor<T>
|
||||
{
|
||||
type Output = Result<Tensor<T>, TensorError>;
|
||||
/// Subtracts tensors.
|
||||
/// # Arguments
|
||||
@@ -1376,18 +1601,30 @@ impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Su
|
||||
/// ```
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
let broadcasted_shape = get_broadcasted_shape(self.dims(), rhs.dims()).unwrap();
|
||||
let mut lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let rhs = rhs.expand(&broadcasted_shape).unwrap();
|
||||
|
||||
lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
|
||||
*o = o.clone() - r;
|
||||
});
|
||||
#[cfg(feature = "metal")]
|
||||
let res = metal_tensor_op(&lhs, &rhs, "sub");
|
||||
|
||||
Ok(lhs)
|
||||
#[cfg(not(feature = "metal"))]
|
||||
let res = {
|
||||
let mut res: Tensor<T> = lhs
|
||||
.par_iter()
|
||||
.zip(rhs)
|
||||
.map(|(o, r)| o.clone() - r)
|
||||
.collect();
|
||||
res.reshape(&broadcasted_shape).unwrap();
|
||||
res
|
||||
};
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mul for Tensor<T> {
|
||||
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Mul
|
||||
for Tensor<T>
|
||||
{
|
||||
type Output = Result<Tensor<T>, TensorError>;
|
||||
/// Elementwise multiplies tensors.
|
||||
/// # Arguments
|
||||
@@ -1436,18 +1673,28 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mu
|
||||
/// ```
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
let broadcasted_shape = get_broadcasted_shape(self.dims(), rhs.dims()).unwrap();
|
||||
let mut lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let lhs = self.expand(&broadcasted_shape).unwrap();
|
||||
let rhs = rhs.expand(&broadcasted_shape).unwrap();
|
||||
|
||||
lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| {
|
||||
*o = o.clone() * r;
|
||||
});
|
||||
#[cfg(feature = "metal")]
|
||||
let res = metal_tensor_op(&lhs, &rhs, "mul");
|
||||
|
||||
Ok(lhs)
|
||||
#[cfg(not(feature = "metal"))]
|
||||
let res = {
|
||||
let mut res: Tensor<T> = lhs
|
||||
.par_iter()
|
||||
.zip(rhs)
|
||||
.map(|(o, r)| o.clone() * r)
|
||||
.collect();
|
||||
res.reshape(&broadcasted_shape).unwrap();
|
||||
res
|
||||
};
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Tensor<T> {
|
||||
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Tensor<T> {
|
||||
/// Elementwise raise a tensor to the nth power.
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -1661,4 +1908,66 @@ mod tests {
|
||||
let b = Tensor::<i32>::new(Some(&[1, 4]), &[2, 1]).unwrap();
|
||||
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "metal")]
|
||||
fn tensor_metal_int() {
|
||||
let a = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
|
||||
let b = Tensor::<i64>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
|
||||
let c = metal_tensor_op(&a, &b, "add");
|
||||
assert_eq!(c, Tensor::new(Some(&[2, 4, 6, 8]), &[2, 2]).unwrap());
|
||||
|
||||
let c = metal_tensor_op(&a, &b, "sub");
|
||||
assert_eq!(c, Tensor::new(Some(&[0, 0, 0, 0]), &[2, 2]).unwrap());
|
||||
|
||||
let c = metal_tensor_op(&a, &b, "mul");
|
||||
assert_eq!(c, Tensor::new(Some(&[1, 4, 9, 16]), &[2, 2]).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "metal")]
|
||||
fn tensor_metal_felt() {
|
||||
use halo2curves::bn256::Fr;
|
||||
|
||||
let a = Tensor::<Fr>::new(
|
||||
Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
|
||||
&[2, 2],
|
||||
)
|
||||
.unwrap();
|
||||
let b = Tensor::<Fr>::new(
|
||||
Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]),
|
||||
&[2, 2],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let c = metal_tensor_op(&a, &b, "add");
|
||||
assert_eq!(
|
||||
c,
|
||||
Tensor::<Fr>::new(
|
||||
Some(&[Fr::from(2), Fr::from(4), Fr::from(6), Fr::from(8)]),
|
||||
&[2, 2],
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
let c = metal_tensor_op(&a, &b, "sub");
|
||||
assert_eq!(
|
||||
c,
|
||||
Tensor::<Fr>::new(
|
||||
Some(&[Fr::from(0), Fr::from(0), Fr::from(0), Fr::from(0)]),
|
||||
&[2, 2],
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
let c = metal_tensor_op(&a, &b, "mul");
|
||||
assert_eq!(
|
||||
c,
|
||||
Tensor::<Fr>::new(
|
||||
Some(&[Fr::from(1), Fr::from(4), Fr::from(9), Fr::from(16)]),
|
||||
&[2, 2],
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -316,9 +316,9 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>> f
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
/// Allocate a new [ValTensor::Value] from the given [Tensor] of [i128].
|
||||
pub fn from_i128_tensor(t: Tensor<i128>) -> ValTensor<F> {
|
||||
let inner = t.map(|x| ValType::Value(Value::known(i128_to_felt(x))));
|
||||
/// Allocate a new [ValTensor::Value] from the given [Tensor] of [i64].
|
||||
pub fn from_i64_tensor(t: Tensor<i64>) -> ValTensor<F> {
|
||||
let inner = t.map(|x| ValType::Value(Value::known(i64_to_felt(x))));
|
||||
inner.into()
|
||||
}
|
||||
|
||||
@@ -521,9 +521,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
|
||||
/// Calls `int_evals` on the inner tensor.
|
||||
pub fn get_int_evals(&self) -> Result<Tensor<i128>, Box<dyn Error>> {
|
||||
pub fn get_int_evals(&self) -> Result<Tensor<i64>, Box<dyn Error>> {
|
||||
// finally convert to vector of integers
|
||||
let mut integer_evals: Vec<i128> = vec![];
|
||||
let mut integer_evals: Vec<i64> = vec![];
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: _, ..
|
||||
@@ -531,25 +531,25 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
// we have to push to an externally created vector or else vaf.map() returns an evaluation wrapped in Value<> (which we don't want)
|
||||
let _ = v.map(|vaf| match vaf {
|
||||
ValType::Value(v) => v.map(|f| {
|
||||
integer_evals.push(crate::fieldutils::felt_to_i128(f));
|
||||
integer_evals.push(crate::fieldutils::felt_to_i64(f));
|
||||
}),
|
||||
ValType::AssignedValue(v) => v.map(|f| {
|
||||
integer_evals.push(crate::fieldutils::felt_to_i128(f.evaluate()));
|
||||
integer_evals.push(crate::fieldutils::felt_to_i64(f.evaluate()));
|
||||
}),
|
||||
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
|
||||
v.value_field().map(|f| {
|
||||
integer_evals.push(crate::fieldutils::felt_to_i128(f.evaluate()));
|
||||
integer_evals.push(crate::fieldutils::felt_to_i64(f.evaluate()));
|
||||
})
|
||||
}
|
||||
ValType::Constant(v) => {
|
||||
integer_evals.push(crate::fieldutils::felt_to_i128(v));
|
||||
integer_evals.push(crate::fieldutils::felt_to_i64(v));
|
||||
Value::unknown()
|
||||
}
|
||||
});
|
||||
}
|
||||
_ => return Err(Box::new(TensorError::WrongMethod)),
|
||||
};
|
||||
let mut tensor: Tensor<i128> = integer_evals.into_iter().into();
|
||||
let mut tensor: Tensor<i64> = integer_evals.into_iter().into();
|
||||
match tensor.reshape(self.dims()) {
|
||||
_ => {}
|
||||
};
|
||||
|
||||
13
src/wasm.rs
13
src/wasm.rs
@@ -2,8 +2,8 @@ use crate::circuit::modules::polycommit::PolyCommitChip;
|
||||
use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
|
||||
use crate::circuit::modules::poseidon::PoseidonChip;
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::fieldutils::felt_to_i128;
|
||||
use crate::fieldutils::i128_to_felt;
|
||||
use crate::fieldutils::felt_to_i64;
|
||||
use crate::fieldutils::i64_to_felt;
|
||||
use crate::graph::modules::POSEIDON_LEN_GRAPH;
|
||||
use crate::graph::quantize_float;
|
||||
use crate::graph::scale_to_multiplier;
|
||||
@@ -113,7 +113,7 @@ pub fn feltToInt(
|
||||
let felt: Fr = serde_json::from_slice(&array[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
|
||||
Ok(wasm_bindgen::Clamped(
|
||||
serde_json::to_vec(&felt_to_i128(felt))
|
||||
serde_json::to_vec(&felt_to_i64(felt))
|
||||
.map_err(|e| JsError::new(&format!("Failed to serialize integer: {}", e)))?,
|
||||
))
|
||||
}
|
||||
@@ -127,7 +127,7 @@ pub fn feltToFloat(
|
||||
) -> Result<f64, JsError> {
|
||||
let felt: Fr = serde_json::from_slice(&array[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
|
||||
let int_rep = felt_to_i128(felt);
|
||||
let int_rep = felt_to_i64(felt);
|
||||
let multiplier = scale_to_multiplier(scale);
|
||||
Ok(int_rep as f64 / multiplier)
|
||||
}
|
||||
@@ -141,7 +141,7 @@ pub fn floatToFelt(
|
||||
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
|
||||
let int_rep =
|
||||
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
let felt = i128_to_felt(int_rep);
|
||||
let felt = i64_to_felt(int_rep);
|
||||
let vec = crate::pfsys::field_to_string::<halo2curves::bn256::Fr>(&felt);
|
||||
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err(
|
||||
|e| JsError::new(&format!("Failed to serialize a float to felt{}", e)),
|
||||
@@ -177,7 +177,6 @@ pub fn kzgCommit(
|
||||
|
||||
let output = PolyCommitChip::commit::<KZGCommitmentScheme<Bn256>>(
|
||||
message,
|
||||
vk.cs().degree() as u32,
|
||||
(vk.cs().blinding_factors() + 1) as u32,
|
||||
¶ms,
|
||||
);
|
||||
@@ -276,7 +275,7 @@ pub fn genWitness(
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
|
||||
let witness = circuit
|
||||
.forward::<KZGCommitmentScheme<Bn256>>(&mut input, None, None, false)
|
||||
.forward::<KZGCommitmentScheme<Bn256>>(&mut input, None, None, false, false)
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
|
||||
serde_json::to_vec(&witness)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
mod native_tests {
|
||||
|
||||
use ezkl::circuit::Tolerance;
|
||||
use ezkl::fieldutils::{felt_to_i128, i128_to_felt};
|
||||
use ezkl::fieldutils::{felt_to_i64, i64_to_felt};
|
||||
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
|
||||
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
|
||||
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
|
||||
@@ -200,7 +200,7 @@ mod native_tests {
|
||||
"1l_tiny_div",
|
||||
];
|
||||
|
||||
const TESTS: [&str; 92] = [
|
||||
const TESTS: [&str; 93] = [
|
||||
"1l_mlp", //0
|
||||
"1l_slice",
|
||||
"1l_concat",
|
||||
@@ -296,7 +296,8 @@ mod native_tests {
|
||||
"reducel1",
|
||||
"reducel2", // 89
|
||||
"1l_lppool",
|
||||
"lstm_large", // 91
|
||||
"lstm_large", // 91
|
||||
"lstm_medium", // 92
|
||||
];
|
||||
|
||||
const WASM_TESTS: [&str; 46] = [
|
||||
@@ -535,7 +536,7 @@ mod native_tests {
|
||||
}
|
||||
});
|
||||
|
||||
seq!(N in 0..=91 {
|
||||
seq!(N in 0..=92 {
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
#[ignore]
|
||||
@@ -623,7 +624,7 @@ mod native_tests {
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_large_batch_public_outputs_(test: &str) {
|
||||
// currently variable output rank is not supported in ONNX
|
||||
if test != "gather_nd" && test != "lstm_large" {
|
||||
if test != "gather_nd" && test != "lstm_large" && test != "lstm_medium" {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
@@ -908,7 +909,7 @@ mod native_tests {
|
||||
prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, true, "single", Commitments::KZG, 2);
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testWasm", false);
|
||||
// test_dir.close().unwrap();
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(WASM_TESTS[N])])*
|
||||
@@ -921,7 +922,7 @@ mod native_tests {
|
||||
prove_and_verify(path, test.to_string(), "safe", "hashed", "private", "public", 1, None, true, "single", Commitments::KZG, 2);
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testWasm", false);
|
||||
// test_dir.close().unwrap();
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(WASM_TESTS[N])])*
|
||||
@@ -1373,7 +1374,7 @@ mod native_tests {
|
||||
let witness = witness.clone();
|
||||
let outputs = witness.outputs.clone();
|
||||
|
||||
// get values as i128
|
||||
// get values as i64
|
||||
let output_perturbed_safe: Vec<Vec<halo2curves::bn256::Fr>> = outputs
|
||||
.iter()
|
||||
.map(|sv| {
|
||||
@@ -1383,10 +1384,10 @@ mod native_tests {
|
||||
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
|
||||
halo2curves::bn256::Fr::zero()
|
||||
} else {
|
||||
i128_to_felt(
|
||||
(felt_to_i128(*v) as f32
|
||||
i64_to_felt(
|
||||
(felt_to_i64(*v) as f32
|
||||
* (rand::thread_rng().gen_range(-0.01..0.01) * tolerance))
|
||||
as i128,
|
||||
as i64,
|
||||
)
|
||||
};
|
||||
|
||||
@@ -1396,7 +1397,7 @@ mod native_tests {
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// get values as i128
|
||||
// get values as i64
|
||||
let output_perturbed_bad: Vec<Vec<halo2curves::bn256::Fr>> = outputs
|
||||
.iter()
|
||||
.map(|sv| {
|
||||
@@ -1406,10 +1407,10 @@ mod native_tests {
|
||||
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
|
||||
halo2curves::bn256::Fr::from(2)
|
||||
} else {
|
||||
i128_to_felt(
|
||||
(felt_to_i128(*v) as f32
|
||||
i64_to_felt(
|
||||
(felt_to_i64(*v) as f32
|
||||
* (rand::thread_rng().gen_range(0.02..0.1) * tolerance))
|
||||
as i128,
|
||||
as i64,
|
||||
)
|
||||
};
|
||||
*v + perturbation
|
||||
|
||||
@@ -123,7 +123,7 @@ mod py_tests {
|
||||
}
|
||||
}
|
||||
|
||||
const TESTS: [&str; 32] = [
|
||||
const TESTS: [&str; 33] = [
|
||||
"proof_splitting.ipynb", // 0
|
||||
"variance.ipynb",
|
||||
"mnist_gan.ipynb",
|
||||
@@ -157,6 +157,7 @@ mod py_tests {
|
||||
"generalized_inverse.ipynb",
|
||||
"mnist_classifier.ipynb", // 30
|
||||
"world_rotation.ipynb",
|
||||
"logistic_regression.ipynb",
|
||||
];
|
||||
|
||||
macro_rules! test_func {
|
||||
@@ -169,7 +170,7 @@ mod py_tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
seq!(N in 0..=31 {
|
||||
seq!(N in 0..=32 {
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn run_notebook_(test: &str) {
|
||||
|
||||
@@ -109,7 +109,7 @@ def test_gen_srs():
|
||||
|
||||
|
||||
|
||||
def test_calibrate_over_user_range():
|
||||
async def test_calibrate_over_user_range():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
@@ -136,14 +136,14 @@ def test_calibrate_over_user_range():
|
||||
model_path, output_path, py_run_args=run_args)
|
||||
assert res == True
|
||||
|
||||
res = ezkl.calibrate_settings(
|
||||
res = await ezkl.calibrate_settings(
|
||||
data_path, model_path, output_path, "resources", 1, [0, 1, 2])
|
||||
assert res == True
|
||||
assert os.path.isfile(output_path)
|
||||
|
||||
|
||||
|
||||
def test_calibrate():
|
||||
async def test_calibrate():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
@@ -170,7 +170,7 @@ def test_calibrate():
|
||||
model_path, output_path, py_run_args=run_args)
|
||||
assert res == True
|
||||
|
||||
res = ezkl.calibrate_settings(
|
||||
res = await ezkl.calibrate_settings(
|
||||
data_path, model_path, output_path, "resources")
|
||||
assert res == True
|
||||
assert os.path.isfile(output_path)
|
||||
@@ -198,7 +198,7 @@ def test_model_compile():
|
||||
assert res == True
|
||||
|
||||
|
||||
def test_forward():
|
||||
async def test_forward():
|
||||
"""
|
||||
Test for vanilla forward pass
|
||||
"""
|
||||
@@ -217,7 +217,7 @@ def test_forward():
|
||||
'witness.json'
|
||||
)
|
||||
|
||||
res = ezkl.gen_witness(data_path, model_path, output_path)
|
||||
res = await ezkl.gen_witness(data_path, model_path, output_path)
|
||||
|
||||
with open(output_path, "r") as f:
|
||||
data = json.load(f)
|
||||
@@ -229,12 +229,12 @@ def test_forward():
|
||||
assert data["processed_outputs"]["poseidon_hash"] == res["processed_outputs"]["poseidon_hash"]
|
||||
|
||||
|
||||
def test_get_srs():
|
||||
async def test_get_srs():
|
||||
"""
|
||||
Test for get_srs
|
||||
"""
|
||||
settings_path = os.path.join(folder_path, 'settings.json')
|
||||
res = ezkl.get_srs(settings_path, srs_path=srs_path)
|
||||
res = await ezkl.get_srs(settings_path, srs_path=srs_path)
|
||||
|
||||
assert res == True
|
||||
|
||||
@@ -242,7 +242,7 @@ def test_get_srs():
|
||||
|
||||
another_srs_path = os.path.join(folder_path, "kzg_test_k8.params")
|
||||
|
||||
res = ezkl.get_srs(logrows=8, srs_path=another_srs_path, commitment=ezkl.PyCommitments.KZG)
|
||||
res = await ezkl.get_srs(logrows=8, srs_path=another_srs_path, commitment=ezkl.PyCommitments.KZG)
|
||||
|
||||
assert os.path.isfile(another_srs_path)
|
||||
|
||||
@@ -393,9 +393,7 @@ def test_prove_evm():
|
||||
assert os.path.isfile(proof_path)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_create_evm_verifier():
|
||||
async def test_create_evm_verifier():
|
||||
"""
|
||||
Create EVM verifier with solidity code
|
||||
In order to run this test you will need to install solc in your environment
|
||||
@@ -405,7 +403,7 @@ def test_create_evm_verifier():
|
||||
sol_code_path = os.path.join(folder_path, 'test.sol')
|
||||
abi_path = os.path.join(folder_path, 'test.abi')
|
||||
|
||||
res = ezkl.create_evm_verifier(
|
||||
res = await ezkl.create_evm_verifier(
|
||||
vk_path,
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
@@ -417,7 +415,7 @@ def test_create_evm_verifier():
|
||||
assert os.path.isfile(sol_code_path)
|
||||
|
||||
|
||||
def test_deploy_evm():
|
||||
async def test_deploy_evm():
|
||||
"""
|
||||
Test deployment of the verifier smart contract
|
||||
In order to run this you will need to install solc in your environment
|
||||
@@ -428,7 +426,7 @@ def test_deploy_evm():
|
||||
# TODO: without optimization there will be out of gas errors
|
||||
# sol_code_path = os.path.join(folder_path, 'test.sol')
|
||||
|
||||
res = ezkl.deploy_evm(
|
||||
res = await ezkl.deploy_evm(
|
||||
addr_path,
|
||||
sol_code_path,
|
||||
rpc_url=anvil_url,
|
||||
@@ -437,7 +435,7 @@ def test_deploy_evm():
|
||||
assert res == True
|
||||
|
||||
|
||||
def test_deploy_evm_with_private_key():
|
||||
async def test_deploy_evm_with_private_key():
|
||||
"""
|
||||
Test deployment of the verifier smart contract using a custom private key
|
||||
In order to run this you will need to install solc in your environment
|
||||
@@ -450,7 +448,7 @@ def test_deploy_evm_with_private_key():
|
||||
|
||||
anvil_default_private_key = "ac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80"
|
||||
|
||||
res = ezkl.deploy_evm(
|
||||
res = await ezkl.deploy_evm(
|
||||
addr_path,
|
||||
sol_code_path,
|
||||
rpc_url=anvil_url,
|
||||
@@ -462,7 +460,7 @@ def test_deploy_evm_with_private_key():
|
||||
custom_zero_balance_private_key = "ff9dfe0b6d31e93ba13460a4d6f63b5e31dd9532b1304f1cbccea7092a042aa4"
|
||||
|
||||
with pytest.raises(RuntimeError, match="Failed to run deploy_evm"):
|
||||
res = ezkl.deploy_evm(
|
||||
res = await ezkl.deploy_evm(
|
||||
addr_path,
|
||||
sol_code_path,
|
||||
rpc_url=anvil_url,
|
||||
@@ -470,7 +468,7 @@ def test_deploy_evm_with_private_key():
|
||||
)
|
||||
|
||||
|
||||
def test_verify_evm():
|
||||
async def test_verify_evm():
|
||||
"""
|
||||
Verifies an evm proof
|
||||
In order to run this you will need to install solc in your environment
|
||||
@@ -486,7 +484,7 @@ def test_verify_evm():
|
||||
# TODO: without optimization there will be out of gas errors
|
||||
# sol_code_path = os.path.join(folder_path, 'test.sol')
|
||||
|
||||
res = ezkl.verify_evm(
|
||||
res = await ezkl.verify_evm(
|
||||
addr,
|
||||
proof_path,
|
||||
rpc_url=anvil_url,
|
||||
@@ -497,7 +495,7 @@ def test_verify_evm():
|
||||
assert res == True
|
||||
|
||||
|
||||
def test_aggregate_and_verify_aggr():
|
||||
async def test_aggregate_and_verify_aggr():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
@@ -526,7 +524,7 @@ def test_aggregate_and_verify_aggr():
|
||||
res = ezkl.gen_settings(model_path, settings_path)
|
||||
assert res == True
|
||||
|
||||
res = ezkl.calibrate_settings(
|
||||
res = await ezkl.calibrate_settings(
|
||||
data_path, model_path, settings_path, "resources")
|
||||
assert res == True
|
||||
assert os.path.isfile(settings_path)
|
||||
@@ -548,7 +546,7 @@ def test_aggregate_and_verify_aggr():
|
||||
'1l_relu_aggr_witness.json'
|
||||
)
|
||||
|
||||
res = ezkl.gen_witness(data_path, compiled_model_path,
|
||||
res = await ezkl.gen_witness(data_path, compiled_model_path,
|
||||
output_path)
|
||||
|
||||
ezkl.prove(
|
||||
@@ -603,7 +601,7 @@ def test_aggregate_and_verify_aggr():
|
||||
assert res == True
|
||||
|
||||
|
||||
def test_evm_aggregate_and_verify_aggr():
|
||||
async def test_evm_aggregate_and_verify_aggr():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
@@ -628,7 +626,7 @@ def test_evm_aggregate_and_verify_aggr():
|
||||
settings_path,
|
||||
)
|
||||
|
||||
ezkl.calibrate_settings(
|
||||
await ezkl.calibrate_settings(
|
||||
data_path,
|
||||
model_path,
|
||||
settings_path,
|
||||
@@ -657,7 +655,7 @@ def test_evm_aggregate_and_verify_aggr():
|
||||
'1l_relu_aggr_evm_witness.json'
|
||||
)
|
||||
|
||||
res = ezkl.gen_witness(data_path, compiled_model_path,
|
||||
res = await ezkl.gen_witness(data_path, compiled_model_path,
|
||||
output_path)
|
||||
|
||||
ezkl.prove(
|
||||
@@ -698,7 +696,7 @@ def test_evm_aggregate_and_verify_aggr():
|
||||
sol_code_path = os.path.join(folder_path, 'aggr_evm_1l_relu.sol')
|
||||
abi_path = os.path.join(folder_path, 'aggr_evm_1l_relu.abi')
|
||||
|
||||
res = ezkl.create_evm_verifier_aggr(
|
||||
res = await ezkl.create_evm_verifier_aggr(
|
||||
[settings_path],
|
||||
aggregate_vk_path,
|
||||
sol_code_path,
|
||||
@@ -712,7 +710,7 @@ def test_evm_aggregate_and_verify_aggr():
|
||||
|
||||
addr_path = os.path.join(folder_path, 'address_aggr.json')
|
||||
|
||||
res = ezkl.deploy_evm(
|
||||
res = await ezkl.deploy_evm(
|
||||
addr_path,
|
||||
sol_code_path,
|
||||
rpc_url=anvil_url,
|
||||
@@ -730,7 +728,7 @@ def test_evm_aggregate_and_verify_aggr():
|
||||
# with open(addr_path, 'r') as file:
|
||||
# addr_aggr = file.read().rstrip()
|
||||
|
||||
# res = ezkl.verify_evm(
|
||||
# res = await ezkl.verify_evm(
|
||||
# aggregate_proof_path,
|
||||
# addr_aggr,
|
||||
# rpc_url=anvil_url,
|
||||
@@ -769,7 +767,7 @@ def get_examples():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_file, input_file", get_examples())
|
||||
def test_all_examples(model_file, input_file):
|
||||
async def test_all_examples(model_file, input_file):
|
||||
"""Tests all examples in the examples folder"""
|
||||
# gen settings
|
||||
settings_path = os.path.join(folder_path, "settings.json")
|
||||
@@ -783,7 +781,7 @@ def test_all_examples(model_file, input_file):
|
||||
res = ezkl.gen_settings(model_file, settings_path)
|
||||
assert res
|
||||
|
||||
res = ezkl.calibrate_settings(
|
||||
res = await ezkl.calibrate_settings(
|
||||
input_file, model_file, settings_path, "resources")
|
||||
assert res
|
||||
|
||||
@@ -814,7 +812,7 @@ def test_all_examples(model_file, input_file):
|
||||
assert os.path.isfile(pk_path)
|
||||
|
||||
print("Generating witness for example: ", model_file)
|
||||
res = ezkl.gen_witness(input_file, compiled_model_path, witness_path)
|
||||
res = await ezkl.gen_witness(input_file, compiled_model_path, witness_path)
|
||||
assert os.path.isfile(witness_path)
|
||||
|
||||
print("Proving example: ", model_file)
|
||||
|
||||
@@ -130,7 +130,6 @@ mod wasm32 {
|
||||
serde_json::from_slice(&commitment_ser[..]).unwrap();
|
||||
let reference_commitment = PolyCommitChip::commit::<KZGCommitmentScheme<Bn256>>(
|
||||
message,
|
||||
vk.cs().degree() as u32,
|
||||
(vk.cs().blinding_factors() + 1) as u32,
|
||||
¶ms,
|
||||
);
|
||||
@@ -151,10 +150,10 @@ mod wasm32 {
|
||||
.unwrap();
|
||||
assert_eq!(floating_point, (i as f64) / 4.0);
|
||||
|
||||
let integer: i128 =
|
||||
let integer: i64 =
|
||||
serde_json::from_slice(&feltToInt(clamped.clone()).map_err(|_| "failed").unwrap())
|
||||
.unwrap();
|
||||
assert_eq!(integer, i as i128);
|
||||
assert_eq!(integer, i as i64);
|
||||
|
||||
let hex_string = format!("{:?}", field_element.clone());
|
||||
let returned_string: String = feltToBigEndian(clamped.clone())
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user