mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-14 16:57:59 -05:00
Compare commits
15 Commits
verifier-r
...
v15.5.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c21ec96817 | ||
|
|
85302453d9 | ||
|
|
523c77c912 | ||
|
|
948e5cd4b9 | ||
|
|
00155e585f | ||
|
|
0876faa12c | ||
|
|
a3c131dac0 | ||
|
|
fd9c2305ac | ||
|
|
a0060f341d | ||
|
|
17f1d42739 | ||
|
|
ebaee9e2b1 | ||
|
|
d51cba589a | ||
|
|
1cb1b6e143 | ||
|
|
d2b683b527 | ||
|
|
a06b09ef1f |
2
.github/workflows/rust.yml
vendored
2
.github/workflows/rust.yml
vendored
@@ -240,6 +240,8 @@ jobs:
|
||||
locked: true
|
||||
# - name: The Worm Mock
|
||||
# run: cargo nextest run --release --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
|
||||
- name: public outputs and bounded lookup log
|
||||
run: cargo nextest run --release --verbose tests::mock_bounded_lookup_log --test-threads 32
|
||||
- name: public outputs and tolerance > 0
|
||||
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
|
||||
- name: public outputs + batch size == 10
|
||||
|
||||
12
.github/workflows/update-ios-package.yml
vendored
12
.github/workflows/update-ios-package.yml
vendored
@@ -36,6 +36,15 @@ jobs:
|
||||
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
|
||||
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
|
||||
|
||||
- name: Copy Test Files
|
||||
run: |
|
||||
rm -rf ezkl-swift-package/Tests/EzklAssets/*
|
||||
|
||||
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
|
||||
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
|
||||
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
|
||||
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
|
||||
|
||||
- name: Set up Xcode environment
|
||||
run: |
|
||||
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
|
||||
@@ -66,10 +75,11 @@ jobs:
|
||||
git config user.name "GitHub Action"
|
||||
git config user.email "action@github.com"
|
||||
git add Sources/EzklCoreBindings
|
||||
git add Tests/EzklAssets
|
||||
git commit -m "Automatically updated EzklCoreBindings for EZKL"
|
||||
git tag ${{ github.event.inputs.tag }}
|
||||
git remote set-url origin https://zkonduit:${EZKL_PORTER_TOKEN}@github.com/zkonduit/ezkl-swift-package.git
|
||||
git push origin
|
||||
git push origin --tags
|
||||
git push origin tag ${{ github.event.inputs.tag }}
|
||||
env:
|
||||
EZKL_PORTER_TOKEN: ${{ secrets.EZKL_PORTER_TOKEN }}
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -9,7 +9,6 @@ pkg
|
||||
!AttestData.sol
|
||||
!VerifierBase.sol
|
||||
!LoadInstances.sol
|
||||
!VerifierManager.sol
|
||||
*.pf
|
||||
*.vk
|
||||
*.pk
|
||||
|
||||
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -2397,7 +2397,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_solidity_verifier"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=vka-log#c319e229ad677ee4c7d95bdae45c2958350cfd14"
|
||||
source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=ac/update-h2-curves#eede1db7f3e599112bd1186e9d1913286bdcb539"
|
||||
dependencies = [
|
||||
"askama",
|
||||
"blake2b_simd",
|
||||
|
||||
64
Cargo.toml
64
Cargo.toml
@@ -19,11 +19,8 @@ crate-type = ["cdylib", "rlib", "staticlib"]
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
|
||||
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
|
||||
"derive_serde",
|
||||
"derive_serde",
|
||||
] }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", branch = "ac/cache-lookup-commitments", features = [
|
||||
"circuit-params",
|
||||
] }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", branch = "ac/cache-lookup-commitments", features = ["circuit-params"] }
|
||||
rand = { version = "0.8", default-features = false }
|
||||
itertools = { version = "0.10.3", default-features = false }
|
||||
clap = { version = "4.5.3", features = ["derive"], optional = true }
|
||||
@@ -36,9 +33,9 @@ halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac
|
||||
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
|
||||
"derive_serde",
|
||||
] }
|
||||
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "vka-log", optional = true }
|
||||
maybe-rayon = { version = "0.1.1", default_features = false }
|
||||
bincode = { version = "1.3.3", default_features = false }
|
||||
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "ac/update-h2-curves", optional = true }
|
||||
maybe-rayon = { version = "0.1.1", default-features = false }
|
||||
bincode = { version = "1.3.3", default-features = false }
|
||||
unzip-n = "0.1.2"
|
||||
num = "0.4.1"
|
||||
portable-atomic = { version = "1.6.0", optional = true }
|
||||
@@ -46,7 +43,10 @@ tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package
|
||||
semver = { version = "1.0.22", optional = true }
|
||||
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||
serde_json = { version = "1.0.97", features = ["float_roundtrip", "raw_value"] }
|
||||
serde_json = { version = "1.0.97", features = [
|
||||
"float_roundtrip",
|
||||
"raw_value",
|
||||
] }
|
||||
|
||||
# evm related deps
|
||||
alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev = "5fbf57bac99edef9d8475190109a7ea9fb7e5e83", features = [
|
||||
@@ -56,39 +56,23 @@ alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev = "5
|
||||
"rpc-types-eth",
|
||||
"signer-wallet",
|
||||
"node-bindings",
|
||||
|
||||
], optional = true }
|
||||
foundry-compilers = { version = "0.4.1", features = [
|
||||
"svm-solc",
|
||||
|
||||
], optional = true }
|
||||
foundry-compilers = { version = "0.4.1", features = ["svm-solc"], optional = true }
|
||||
ethabi = { version = "18", optional = true }
|
||||
indicatif = { version = "0.17.5", features = ["rayon"], optional = true }
|
||||
gag = { version = "1.0.0", default-features = false, optional = true }
|
||||
instant = { version = "0.1" }
|
||||
reqwest = { version = "0.12.4", default-features = false, features = [
|
||||
"default-tls",
|
||||
"multipart",
|
||||
"stream",
|
||||
], optional = true }
|
||||
reqwest = { version = "0.12.4", default-features = false, features = ["default-tls", "multipart", "stream"], optional = true }
|
||||
openssl = { version = "0.10.55", features = ["vendored"], optional = true }
|
||||
tokio-postgres = { version = "0.7.10", optional = true }
|
||||
pg_bigdecimal = { version = "0.1.5", optional = true }
|
||||
lazy_static = { version = "1.4.0", optional = true }
|
||||
colored_json = { version = "3.0.1", default-features = false, optional = true }
|
||||
regex = { version = "1", default-features = false, optional = true }
|
||||
tokio = { version = "1.35.0", default-features = false, features = [
|
||||
"macros",
|
||||
"rt-multi-thread",
|
||||
], optional = true }
|
||||
pyo3 = { version = "0.21.2", features = [
|
||||
"extension-module",
|
||||
"abi3-py37",
|
||||
"macros",
|
||||
], default-features = false, optional = true }
|
||||
pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch = "migration-pyo3-0.21", features = [
|
||||
"attributes",
|
||||
"tokio-runtime",
|
||||
], default-features = false, optional = true }
|
||||
tokio = { version = "1.35.0", default-features = false, features = ["macros", "rt-multi-thread"], optional = true }
|
||||
pyo3 = { version = "0.21.2", features = ["extension-module", "abi3-py37", "macros"], default-features = false, optional = true }
|
||||
pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch="migration-pyo3-0.21", features = ["attributes", "tokio-runtime"], default-features = false, optional = true }
|
||||
pyo3-log = { version = "0.10.0", default-features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default-features = false, optional = true }
|
||||
tabled = { version = "0.12.0", optional = true }
|
||||
@@ -185,7 +169,7 @@ harness = false
|
||||
|
||||
|
||||
[[bench]]
|
||||
name = "relu"
|
||||
name = "sigmoid"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
@@ -193,12 +177,12 @@ name = "relu_lookupless"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "accum_matmul_relu"
|
||||
name = "accum_matmul_sigmoid"
|
||||
harness = false
|
||||
|
||||
|
||||
[[bench]]
|
||||
name = "accum_matmul_relu_overflow"
|
||||
name = "accum_matmul_sigmoid_overflow"
|
||||
harness = false
|
||||
|
||||
[[bin]]
|
||||
@@ -213,13 +197,7 @@ required-features = ["ios-bindings", "uuid", "camino", "uniffi_bindgen"]
|
||||
|
||||
[features]
|
||||
web = ["wasm-bindgen-rayon"]
|
||||
default = [
|
||||
"ezkl",
|
||||
"mv-lookup",
|
||||
"precompute-coset",
|
||||
"no-banner",
|
||||
"parallel-poly-read",
|
||||
]
|
||||
default = ["ezkl", "mv-lookup", "precompute-coset", "no-banner", "parallel-poly-read"]
|
||||
onnx = ["dep:tract-onnx"]
|
||||
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
|
||||
ios-bindings = ["mv-lookup", "precompute-coset", "parallel-poly-read", "uniffi"]
|
||||
@@ -253,10 +231,7 @@ ezkl = [
|
||||
"dep:clap",
|
||||
"dep:tosubcommand",
|
||||
]
|
||||
parallel-poly-read = [
|
||||
"halo2_proofs/circuit-params",
|
||||
"halo2_proofs/parallel-poly-read",
|
||||
]
|
||||
parallel-poly-read = ["halo2_proofs/circuit-params", "halo2_proofs/parallel-poly-read"]
|
||||
mv-lookup = [
|
||||
"halo2_proofs/mv-lookup",
|
||||
"snark-verifier/mv-lookup",
|
||||
@@ -285,3 +260,4 @@ rustflags = ["-C", "relocation-model=pic"]
|
||||
lto = "fat"
|
||||
codegen-units = 1
|
||||
# panic = "abort"
|
||||
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
[
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "owner",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"name": "OwnableInvalidOwner",
|
||||
"type": "error"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "account",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"name": "OwnableUnauthorizedAccount",
|
||||
"type": "error"
|
||||
},
|
||||
{
|
||||
"anonymous": false,
|
||||
"inputs": [
|
||||
{
|
||||
"indexed": false,
|
||||
"internalType": "address",
|
||||
"name": "addr",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"name": "DeployedVerifier",
|
||||
"type": "event"
|
||||
},
|
||||
{
|
||||
"anonymous": false,
|
||||
"inputs": [
|
||||
{
|
||||
"indexed": true,
|
||||
"internalType": "address",
|
||||
"name": "previousOwner",
|
||||
"type": "address"
|
||||
},
|
||||
{
|
||||
"indexed": true,
|
||||
"internalType": "address",
|
||||
"name": "newOwner",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"name": "OwnershipTransferred",
|
||||
"type": "event"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "bytecode",
|
||||
"type": "bytes"
|
||||
}
|
||||
],
|
||||
"name": "deployVerifier",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "addr",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "owner",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "bytecode",
|
||||
"type": "bytes"
|
||||
}
|
||||
],
|
||||
"name": "precomputeAddress",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "renounceOwnership",
|
||||
"outputs": [],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "newOwner",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"name": "transferOwnership",
|
||||
"outputs": [],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"name": "verifierAddresses",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "bool",
|
||||
"name": "",
|
||||
"type": "bool"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
@@ -64,7 +64,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
&a,
|
||||
BITS,
|
||||
K,
|
||||
&LookupOp::LeakyReLU { slope: 0.0.into() },
|
||||
&LookupOp::Sigmoid { scale: 1.0.into() },
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -93,7 +93,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
@@ -65,7 +65,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
&a,
|
||||
BITS,
|
||||
k,
|
||||
&LookupOp::LeakyReLU { slope: 0.0.into() },
|
||||
&LookupOp::Sigmoid { scale: 1.0.into() },
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -94,7 +94,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
@@ -68,7 +68,14 @@ impl Circuit<Fr> for NLCircuit {
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
|
||||
@@ -42,7 +42,7 @@ impl Circuit<Fr> for NLCircuit {
|
||||
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let nl = LookupOp::LeakyReLU { slope: 0.0.into() };
|
||||
let nl = LookupOp::Sigmoid { scale: 1.0.into() };
|
||||
|
||||
let mut config = Config::default();
|
||||
|
||||
@@ -68,7 +68,7 @@ impl Circuit<Fr> for NLCircuit {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
@@ -1,184 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
pragma solidity 0.8.20;
|
||||
|
||||
// lib/openzeppelin-contracts/contracts/utils/Context.sol
|
||||
|
||||
// OpenZeppelin Contracts (last updated v5.0.1) (utils/Context.sol)
|
||||
|
||||
/**
|
||||
* @dev Provides information about the current execution context, including the
|
||||
* sender of the transaction and its data. While these are generally available
|
||||
* via msg.sender and msg.data, they should not be accessed in such a direct
|
||||
* manner, since when dealing with meta-transactions the account sending and
|
||||
* paying for execution may not be the actual sender (as far as an application
|
||||
* is concerned).
|
||||
*
|
||||
* This contract is only required for intermediate, library-like contracts.
|
||||
*/
|
||||
abstract contract Context {
|
||||
function _msgSender() internal view virtual returns (address) {
|
||||
return msg.sender;
|
||||
}
|
||||
|
||||
function _msgData() internal view virtual returns (bytes calldata) {
|
||||
return msg.data;
|
||||
}
|
||||
|
||||
function _contextSuffixLength() internal view virtual returns (uint256) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// lib/openzeppelin-contracts/contracts/access/Ownable.sol
|
||||
|
||||
// OpenZeppelin Contracts (last updated v5.0.0) (access/Ownable.sol)
|
||||
|
||||
/**
|
||||
* @dev Contract module which provides a basic access control mechanism, where
|
||||
* there is an account (an owner) that can be granted exclusive access to
|
||||
* specific functions.
|
||||
*
|
||||
* The initial owner is set to the address provided by the deployer. This can
|
||||
* later be changed with {transferOwnership}.
|
||||
*
|
||||
* This module is used through inheritance. It will make available the modifier
|
||||
* `onlyOwner`, which can be applied to your functions to restrict their use to
|
||||
* the owner.
|
||||
*/
|
||||
abstract contract Ownable is Context {
|
||||
/// set the owener initialy to be the anvil test account
|
||||
address private _owner = 0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266;
|
||||
|
||||
/**
|
||||
* @dev The caller account is not authorized to perform an operation.
|
||||
*/
|
||||
error OwnableUnauthorizedAccount(address account);
|
||||
|
||||
/**
|
||||
* @dev The owner is not a valid owner account. (eg. `address(0)`)
|
||||
*/
|
||||
error OwnableInvalidOwner(address owner);
|
||||
|
||||
event OwnershipTransferred(
|
||||
address indexed previousOwner,
|
||||
address indexed newOwner
|
||||
);
|
||||
|
||||
/**
|
||||
* @dev Initializes the contract setting the address provided by the deployer as the initial owner.
|
||||
*/
|
||||
constructor() {
|
||||
_transferOwnership(msg.sender);
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Throws if called by any account other than the owner.
|
||||
*/
|
||||
modifier onlyOwner() {
|
||||
_checkOwner();
|
||||
_;
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Returns the address of the current owner.
|
||||
*/
|
||||
function owner() public view virtual returns (address) {
|
||||
return _owner;
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Throws if the sender is not the owner.
|
||||
*/
|
||||
function _checkOwner() internal view virtual {
|
||||
if (owner() != _msgSender()) {
|
||||
revert OwnableUnauthorizedAccount(_msgSender());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Leaves the contract without owner. It will not be possible to call
|
||||
* `onlyOwner` functions. Can only be called by the current owner.
|
||||
*
|
||||
* NOTE: Renouncing ownership will leave the contract without an owner,
|
||||
* thereby disabling any functionality that is only available to the owner.
|
||||
*/
|
||||
function renounceOwnership() public virtual onlyOwner {
|
||||
_transferOwnership(address(0));
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Transfers ownership of the contract to a new account (`newOwner`).
|
||||
* Can only be called by the current owner.
|
||||
*/
|
||||
function transferOwnership(address newOwner) public virtual onlyOwner {
|
||||
if (newOwner == address(0)) {
|
||||
revert OwnableInvalidOwner(address(0));
|
||||
}
|
||||
_transferOwnership(newOwner);
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Transfers ownership of the contract to a new account (`newOwner`).
|
||||
* Internal function without access restriction.
|
||||
*/
|
||||
function _transferOwnership(address newOwner) internal virtual {
|
||||
address oldOwner = _owner;
|
||||
_owner = newOwner;
|
||||
emit OwnershipTransferred(oldOwner, newOwner);
|
||||
}
|
||||
}
|
||||
|
||||
// interface for the reusable verifier.
|
||||
interface Halo2VerifierReusable {
|
||||
function verifyProof(
|
||||
address vkArtifact,
|
||||
bytes calldata proof,
|
||||
uint256[] calldata instances
|
||||
) external returns (bool);
|
||||
}
|
||||
|
||||
// Manages the deployment of all EZKL reusbale verifiers (ezkl version specific), verifiying key artifacts (circuit specific) and
|
||||
// routing proof verifications to the correct VKA and associate reusable verifier.
|
||||
// Helps to prevent the deployment of duplicate verifiers.
|
||||
contract EZKLVerifierManager is Ownable {
|
||||
/// @dev Mapping that checks if a given reusable verifier has been deployed
|
||||
mapping(address => bool) public verifierAddresses;
|
||||
|
||||
event DeployedVerifier(address addr);
|
||||
|
||||
// 1. Compute the address of the verifier to be deployed
|
||||
function precomputeAddress(
|
||||
bytes memory bytecode
|
||||
) public view returns (address) {
|
||||
bytes32 hash = keccak256(
|
||||
abi.encodePacked(
|
||||
bytes1(0xff),
|
||||
address(this),
|
||||
uint(0),
|
||||
keccak256(bytecode)
|
||||
)
|
||||
);
|
||||
|
||||
return address(uint160(uint(hash)));
|
||||
}
|
||||
|
||||
// 2. Deploy the reusable verifier using create2
|
||||
/// @param bytecode The bytecode of the reusable verifier to deploy
|
||||
function deployVerifier(
|
||||
bytes memory bytecode
|
||||
) public returns (address addr) {
|
||||
assembly {
|
||||
addr := create2(
|
||||
0x0, // value, hardcode to 0
|
||||
add(bytecode, 0x20),
|
||||
mload(bytecode),
|
||||
0x0 // salt, hardcode to 0
|
||||
)
|
||||
if iszero(extcodesize(addr)) {
|
||||
revert(0, 0)
|
||||
}
|
||||
}
|
||||
verifierAddresses[addr] = true;
|
||||
emit DeployedVerifier(addr);
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
ezkl==0.0.0
|
||||
ezkl==15.5.1
|
||||
sphinx
|
||||
sphinx-rtd-theme
|
||||
sphinxcontrib-napoleon
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ezkl
|
||||
|
||||
project = 'ezkl'
|
||||
release = '0.0.0'
|
||||
release = '15.5.1'
|
||||
version = release
|
||||
|
||||
|
||||
|
||||
@@ -146,6 +146,8 @@ where
|
||||
let params = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);
|
||||
let output = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
|
||||
|
||||
println!("INPUT COL {:#?}", input);
|
||||
|
||||
let mut layer_config = PolyConfig::configure(
|
||||
@@ -156,15 +158,11 @@ where
|
||||
);
|
||||
|
||||
layer_config
|
||||
.configure_lookup(
|
||||
cs,
|
||||
&input,
|
||||
&output,
|
||||
¶ms,
|
||||
(LOOKUP_MIN, LOOKUP_MAX),
|
||||
K,
|
||||
&LookupOp::LeakyReLU { slope: 0.0.into() },
|
||||
)
|
||||
.configure_range_check(cs, &input, ¶ms, (-1, 1), K)
|
||||
.unwrap();
|
||||
|
||||
layer_config
|
||||
.configure_range_check(cs, &input, ¶ms, (0, 1023), K)
|
||||
.unwrap();
|
||||
|
||||
layer_config
|
||||
@@ -195,6 +193,11 @@ where
|
||||
) -> Result<(), Error> {
|
||||
config.layer_config.layout_tables(&mut layouter).unwrap();
|
||||
|
||||
config
|
||||
.layer_config
|
||||
.layout_range_checks(&mut layouter)
|
||||
.unwrap();
|
||||
|
||||
let x = layouter
|
||||
.assign_region(
|
||||
|| "mlp_4d",
|
||||
@@ -224,7 +227,10 @@ where
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x.unwrap()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -53,6 +53,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
// tells the config layer to add an affine op to the circuit gate
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
|
||||
|
||||
println!("INPUT COL {:#?}", input);
|
||||
|
||||
let mut layer_config = PolyConfig::<F>::configure(
|
||||
cs,
|
||||
&[input.clone(), params.clone()],
|
||||
@@ -60,17 +64,12 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
CheckMode::SAFE,
|
||||
);
|
||||
|
||||
// sets up a new ReLU table and resuses it for l1 and l3 non linearities
|
||||
layer_config
|
||||
.configure_lookup(
|
||||
cs,
|
||||
&input,
|
||||
&output,
|
||||
¶ms,
|
||||
(LOOKUP_MIN, LOOKUP_MAX),
|
||||
K,
|
||||
&LookupOp::LeakyReLU { slope: 0.0.into() },
|
||||
)
|
||||
.configure_range_check(cs, &input, ¶ms, (-1, 1), K)
|
||||
.unwrap();
|
||||
|
||||
layer_config
|
||||
.configure_range_check(cs, &input, ¶ms, (0, 1023), K)
|
||||
.unwrap();
|
||||
|
||||
// sets up a new ReLU table and resuses it for l1 and l3 non linearities
|
||||
@@ -104,6 +103,11 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
) -> Result<(), Error> {
|
||||
config.layer_config.layout_tables(&mut layouter).unwrap();
|
||||
|
||||
config
|
||||
.layer_config
|
||||
.layout_range_checks(&mut layouter)
|
||||
.unwrap();
|
||||
|
||||
let x = layouter
|
||||
.assign_region(
|
||||
|| "mlp_4d",
|
||||
@@ -144,7 +148,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
scale: 1,
|
||||
slope: 0.0.into(),
|
||||
}),
|
||||
)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
@@ -184,7 +191,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
scale: 1,
|
||||
slope: 0.0.into(),
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
println!("6");
|
||||
|
||||
@@ -592,7 +592,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
"version": "3.12.5"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
|
||||
@@ -648,10 +648,10 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.7"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
|
||||
@@ -271,7 +271,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -171,7 +171,7 @@
|
||||
"json.dump(data, open(cal_path, 'w'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
"await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -328,7 +328,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 26,
|
||||
"id": "171702d3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -348,7 +348,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 27,
|
||||
"id": "671dfdd5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -364,7 +364,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 28,
|
||||
"id": "50eba2f4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -399,9 +399,9 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
|
||||
42
examples/onnx/log/gen.py
Normal file
42
examples/onnx/log/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.log(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 3)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/log/input.json
Normal file
1
examples/onnx/log/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[1.9252371788024902, 1.8418371677398682, 0.8400403261184692, 2.083845853805542, 0.9760497808456421, 0.6940176486968994, 0.015579521656036377, 2.2689192295074463]]}
|
||||
14
examples/onnx/log/network.onnx
Normal file
14
examples/onnx/log/network.onnx
Normal file
@@ -0,0 +1,14 @@
|
||||
pytorch2.2.2:o
|
||||
|
||||
inputoutput/Log"Log
|
||||
main_graphZ!
|
||||
input
|
||||
|
||||
|
||||
batch_size
|
||||
b"
|
||||
output
|
||||
|
||||
|
||||
batch_size
|
||||
B
|
||||
@@ -21,9 +21,9 @@ def main():
|
||||
torch_model = Circuit()
|
||||
# Input to the model
|
||||
shape = [3, 2, 3]
|
||||
w = 0.1*torch.rand(1, *shape, requires_grad=True)
|
||||
x = 0.1*torch.rand(1, *shape, requires_grad=True)
|
||||
y = 0.1*torch.rand(1, *shape, requires_grad=True)
|
||||
w = 2 * torch.rand(1, *shape, requires_grad=True) - 1
|
||||
x = 2 * torch.rand(1, *shape, requires_grad=True) - 1
|
||||
y = 2 * torch.rand(1, *shape, requires_grad=True) - 1
|
||||
torch_out = torch_model(w, x, y)
|
||||
# Export the model
|
||||
torch.onnx.export(torch_model, # model being run
|
||||
|
||||
@@ -1 +1,148 @@
|
||||
{"input_shapes": [[3, 2, 3], [3, 2, 3], [3, 2, 3], [3, 2, 3]], "input_data": [[0.0025284828152507544, 0.04976580664515495, 0.025840921327471733, 0.0829394981265068, 0.09595223516225815, 0.08764562010765076, 0.06308566778898239, 0.062386948615312576, 0.08090643584728241, 0.09267748892307281, 0.07428313046693802, 0.08987367898225784, 0.005716216750442982, 0.0666426345705986, 0.012837404385209084, 0.05769496038556099, 0.05761152133345604, 0.08006472885608673], [0.007834953255951405, 0.011380612850189209, 0.08560049533843994, 0.022283583879470825, 0.07879520952701569, 0.04422441124916077, 0.030812596902251244, 0.006081616971641779, 0.011045408435165882, 0.08776585012674332, 0.044985152781009674, 0.015603715553879738, 0.07923348993062973, 0.04872611165046692, 0.0036642670165747404, 0.05142095685005188, 0.0963878259062767, 0.03225792199373245], [0.09952805936336517, 0.002214533044025302, 0.011696457862854004, 0.022422820329666138, 0.04151459410786629, 0.027647346258163452, 0.011919880285859108, 0.006539052817970514, 0.06569185107946396, 0.034328874200582504, 0.0032284557819366455, 0.004105025436729193, 0.022395813837647438, 0.07135921716690063, 0.07882415503263474, 0.09764843434095383, 0.05335796996951103, 0.0525360181927681]], "output_data": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]}
|
||||
{
|
||||
"input_shapes": [
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
3,
|
||||
2,
|
||||
3
|
||||
]
|
||||
],
|
||||
"input_data": [
|
||||
[
|
||||
0.5,
|
||||
1.5,
|
||||
-0.04514765739440918,
|
||||
0.5936200618743896,
|
||||
0.9271858930587769,
|
||||
0.6688600778579712,
|
||||
-0.20331168174743652,
|
||||
-0.7016235589981079,
|
||||
0.025863051414489746,
|
||||
-0.19426143169403076,
|
||||
0.9827852249145508,
|
||||
0.4897397756576538,
|
||||
-1.5,
|
||||
-0.5,
|
||||
0.9278832674026489,
|
||||
0.5943725109100342,
|
||||
-0.573331356048584,
|
||||
0.3675816059112549
|
||||
],
|
||||
[
|
||||
0.7803324460983276,
|
||||
-0.9616303443908691,
|
||||
0.6070173978805542,
|
||||
-0.028337717056274414,
|
||||
-0.5080242156982422,
|
||||
-0.9280107021331787,
|
||||
0.6150380373001099,
|
||||
0.3865993022918701,
|
||||
-0.43668973445892334,
|
||||
0.17152702808380127,
|
||||
0.5144252777099609,
|
||||
-0.28881049156188965,
|
||||
0.8932310342788696,
|
||||
0.059034109115600586,
|
||||
0.6865451335906982,
|
||||
0.009820222854614258,
|
||||
0.23011493682861328,
|
||||
-0.9492779970169067
|
||||
],
|
||||
[
|
||||
-0.21352827548980713,
|
||||
-0.16015326976776123,
|
||||
-0.38964390754699707,
|
||||
0.13464701175689697,
|
||||
-0.8814496994018555,
|
||||
0.5037975311279297,
|
||||
-0.804405927658081,
|
||||
0.9858957529067993,
|
||||
0.19567716121673584,
|
||||
0.9777265787124634,
|
||||
0.6151977777481079,
|
||||
0.568595290184021,
|
||||
0.10584986209869385,
|
||||
-0.8975653648376465,
|
||||
0.6235959529876709,
|
||||
-0.547879695892334,
|
||||
0.9289869070053101,
|
||||
0.7567293643951416
|
||||
]
|
||||
],
|
||||
"output_data": [
|
||||
[
|
||||
1.0,
|
||||
0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
-1.0,
|
||||
0.0
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
-1.0,
|
||||
-1.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
-1.0
|
||||
],
|
||||
[
|
||||
-0.0,
|
||||
-0.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
-0.0,
|
||||
1.0,
|
||||
1.0
|
||||
]
|
||||
]
|
||||
}
|
||||
@@ -1,10 +1,11 @@
|
||||
pytorch2.0.1:â
|
||||
pytorch2.2.2:ă
|
||||
|
||||
woutput_w/Round"Round
|
||||
|
||||
xoutput_x/Floor"Floor
|
||||
|
||||
youtput_y/Ceil"Ceil torch_jitZ%
|
||||
youtput_y/Ceil"Ceil
|
||||
main_graphZ%
|
||||
w
|
||||
|
||||
|
||||
|
||||
42
examples/onnx/rsqrt/gen.py
Normal file
42
examples/onnx/rsqrt/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
# reciprocal sqrt
|
||||
m = 1 / torch.sqrt(x)
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/rsqrt/input.json
Normal file
1
examples/onnx/rsqrt/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.8590779900550842, 0.4029041528701782, 0.6507361531257629, 0.9782488942146301, 0.37392884492874146, 0.6867020726203918, 0.11407750844955444, 0.362740159034729]]}
|
||||
17
examples/onnx/rsqrt/network.onnx
Normal file
17
examples/onnx/rsqrt/network.onnx
Normal file
@@ -0,0 +1,17 @@
|
||||
pytorch2.2.2:Ź
|
||||
$
|
||||
input/Sqrt_output_0/Sqrt"Sqrt
|
||||
1
|
||||
/Sqrt_output_0output/Reciprocal"
|
||||
Reciprocal
|
||||
main_graphZ!
|
||||
input
|
||||
|
||||
|
||||
batch_size
|
||||
b"
|
||||
output
|
||||
|
||||
|
||||
batch_size
|
||||
B
|
||||
@@ -180,9 +180,6 @@ struct PyRunArgs {
|
||||
/// list[tuple[str, int]]: Hand-written parser for graph variables, eg. batch_size=1
|
||||
pub variables: Vec<(String, usize)>,
|
||||
#[pyo3(get, set)]
|
||||
/// bool: Rebase the scale using lookup table for division instead of using a range check
|
||||
pub div_rebasing: bool,
|
||||
#[pyo3(get, set)]
|
||||
/// bool: Should constants with 0.0 fraction be rebased to scale 0
|
||||
pub rebase_frac_zero_constants: bool,
|
||||
#[pyo3(get, set)]
|
||||
@@ -197,6 +194,9 @@ struct PyRunArgs {
|
||||
/// int: The number of legs used for decomposition
|
||||
#[pyo3(get, set)]
|
||||
pub decomp_legs: usize,
|
||||
/// bool: Should the circuit use unbounded lookups for log
|
||||
#[pyo3(get, set)]
|
||||
pub bounded_log_lookup: bool,
|
||||
}
|
||||
|
||||
/// default instantiation of PyRunArgs
|
||||
@@ -212,6 +212,7 @@ impl PyRunArgs {
|
||||
impl From<PyRunArgs> for RunArgs {
|
||||
fn from(py_run_args: PyRunArgs) -> Self {
|
||||
RunArgs {
|
||||
bounded_log_lookup: py_run_args.bounded_log_lookup,
|
||||
tolerance: Tolerance::from(py_run_args.tolerance),
|
||||
input_scale: py_run_args.input_scale,
|
||||
param_scale: py_run_args.param_scale,
|
||||
@@ -223,7 +224,6 @@ impl From<PyRunArgs> for RunArgs {
|
||||
output_visibility: py_run_args.output_visibility,
|
||||
param_visibility: py_run_args.param_visibility,
|
||||
variables: py_run_args.variables,
|
||||
div_rebasing: py_run_args.div_rebasing,
|
||||
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
|
||||
check_mode: py_run_args.check_mode,
|
||||
commitment: Some(py_run_args.commitment.into()),
|
||||
@@ -236,6 +236,7 @@ impl From<PyRunArgs> for RunArgs {
|
||||
impl Into<PyRunArgs> for RunArgs {
|
||||
fn into(self) -> PyRunArgs {
|
||||
PyRunArgs {
|
||||
bounded_log_lookup: self.bounded_log_lookup,
|
||||
tolerance: self.tolerance.val,
|
||||
input_scale: self.input_scale,
|
||||
param_scale: self.param_scale,
|
||||
@@ -247,7 +248,6 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
output_visibility: self.output_visibility,
|
||||
param_visibility: self.param_visibility,
|
||||
variables: self.variables,
|
||||
div_rebasing: self.div_rebasing,
|
||||
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
|
||||
check_mode: self.check_mode,
|
||||
commitment: self.commitment.into(),
|
||||
@@ -873,8 +873,6 @@ fn gen_settings(
|
||||
/// max_logrows: int
|
||||
/// Optional max logrows to use for calibration
|
||||
///
|
||||
/// only_range_check_rebase: bool
|
||||
/// Check ranges when rebasing
|
||||
///
|
||||
/// Returns
|
||||
/// -------
|
||||
@@ -889,7 +887,6 @@ fn gen_settings(
|
||||
scales = None,
|
||||
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
|
||||
max_logrows = None,
|
||||
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
|
||||
))]
|
||||
fn calibrate_settings(
|
||||
py: Python,
|
||||
@@ -901,7 +898,6 @@ fn calibrate_settings(
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
max_logrows: Option<u32>,
|
||||
only_range_check_rebase: bool,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::calibrate(
|
||||
@@ -912,7 +908,6 @@ fn calibrate_settings(
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
only_range_check_rebase,
|
||||
max_logrows,
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -94,4 +94,7 @@ pub enum CircuitError {
|
||||
#[error("[io] {0}")]
|
||||
/// IO error
|
||||
IoError(#[from] std::io::Error),
|
||||
/// Invalid scale
|
||||
#[error("negative scale for an op that requires positive inputs {0}")]
|
||||
NegativeScale(String),
|
||||
}
|
||||
|
||||
@@ -13,14 +13,38 @@ use serde::{Deserialize, Serialize};
|
||||
/// An enum representing the operations that consist of both lookups and arithmetic operations.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum HybridOp {
|
||||
Ln {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Rsqrt {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
},
|
||||
Sqrt {
|
||||
scale: utils::F32,
|
||||
},
|
||||
RoundHalfToEven {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Ceil {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Floor {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Round {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Recip {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
use_range_check_for_int: bool,
|
||||
},
|
||||
Div {
|
||||
denom: utils::F32,
|
||||
use_range_check_for_int: bool,
|
||||
},
|
||||
ReduceMax {
|
||||
axes: Vec<usize>,
|
||||
@@ -45,6 +69,8 @@ pub enum HybridOp {
|
||||
ReduceArgMin {
|
||||
dim: usize,
|
||||
},
|
||||
Max,
|
||||
Min,
|
||||
Softmax {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
@@ -79,6 +105,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
| HybridOp::Less { .. }
|
||||
| HybridOp::Equals { .. }
|
||||
| HybridOp::GreaterEqual { .. }
|
||||
| HybridOp::Max
|
||||
| HybridOp::Min
|
||||
| HybridOp::LessEqual { .. } => {
|
||||
vec![0, 1]
|
||||
}
|
||||
@@ -93,21 +121,32 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
HybridOp::Rsqrt {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => format!(
|
||||
"RSQRT (input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
),
|
||||
HybridOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
|
||||
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
HybridOp::RoundHalfToEven { scale, legs } => {
|
||||
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
|
||||
}
|
||||
HybridOp::Ceil { scale, legs } => format!("CEIL(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs),
|
||||
|
||||
HybridOp::Max => "MAX".to_string(),
|
||||
HybridOp::Min => "MIN".to_string(),
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
use_range_check_for_int,
|
||||
} => format!(
|
||||
"RECIP (input_scale={}, output_scale={}, use_range_check_for_int={})",
|
||||
input_scale, output_scale, use_range_check_for_int
|
||||
),
|
||||
HybridOp::Div {
|
||||
denom,
|
||||
use_range_check_for_int,
|
||||
} => format!(
|
||||
"DIV (denom={}, use_range_check_for_int={})",
|
||||
denom, use_range_check_for_int
|
||||
"RECIP (input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
),
|
||||
HybridOp::Div { denom } => format!("DIV (denom={})", denom),
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
@@ -162,6 +201,34 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
values: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Ok(Some(match self {
|
||||
HybridOp::Rsqrt {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => layouts::rsqrt(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
*input_scale,
|
||||
*output_scale,
|
||||
)?,
|
||||
HybridOp::Sqrt { scale } => {
|
||||
layouts::sqrt(config, region, values[..].try_into()?, *scale)?
|
||||
}
|
||||
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
|
||||
HybridOp::RoundHalfToEven { scale, legs } => {
|
||||
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Ceil { scale, legs } => {
|
||||
layouts::ceil(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Floor { scale, legs } => {
|
||||
layouts::floor(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Round { scale, legs } => {
|
||||
layouts::round(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Max => layouts::max_comp(config, region, values[..].try_into()?)?,
|
||||
HybridOp::Min => layouts::min_comp(config, region, values[..].try_into()?)?,
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
@@ -179,38 +246,16 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
use_range_check_for_int,
|
||||
} => {
|
||||
if input_scale.0.fract() == 0.0
|
||||
&& output_scale.0.fract() == 0.0
|
||||
&& *use_range_check_for_int
|
||||
{
|
||||
layouts::recip(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(input_scale.0 as i128),
|
||||
integer_rep_to_felt(output_scale.0 as i128),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
config,
|
||||
region,
|
||||
values.try_into()?,
|
||||
&LookupOp::Recip {
|
||||
input_scale: *input_scale,
|
||||
output_scale: *output_scale,
|
||||
},
|
||||
)?
|
||||
}
|
||||
}
|
||||
HybridOp::Div {
|
||||
denom,
|
||||
use_range_check_for_int,
|
||||
..
|
||||
} => {
|
||||
if denom.0.fract() == 0.0 && *use_range_check_for_int {
|
||||
layouts::loop_div(
|
||||
} => layouts::recip(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(input_scale.0 as i128),
|
||||
integer_rep_to_felt(output_scale.0 as i128),
|
||||
)?,
|
||||
HybridOp::Div { denom, .. } => {
|
||||
if denom.0.fract() == 0.0 {
|
||||
layouts::div(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
@@ -301,9 +346,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
| HybridOp::ReduceArgMax { .. }
|
||||
| HybridOp::OneHot { .. }
|
||||
| HybridOp::ReduceArgMin { .. } => 0,
|
||||
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
|
||||
|
||||
HybridOp::Recip { output_scale, .. } | HybridOp::Rsqrt { output_scale, .. } => {
|
||||
multiplier_to_scale(output_scale.0 as f64)
|
||||
}
|
||||
HybridOp::Softmax {
|
||||
output_scale,
|
||||
input_scale,
|
||||
..
|
||||
} => multiplier_to_scale((output_scale.0 * input_scale.0) as f64),
|
||||
HybridOp::Ln {
|
||||
scale: output_scale,
|
||||
} => 4 * multiplier_to_scale(output_scale.0 as f64),
|
||||
_ => in_scales[0],
|
||||
};
|
||||
Ok(scale)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,6 @@ use serde::{Deserialize, Serialize};
|
||||
use crate::{
|
||||
circuit::{layouts, table::Range, utils},
|
||||
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorError, TensorType},
|
||||
};
|
||||
|
||||
@@ -15,101 +14,27 @@ use halo2curves::ff::PrimeField;
|
||||
/// An enum representing the operations that can be used to express more complex operations via accumulation
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
|
||||
pub enum LookupOp {
|
||||
Div {
|
||||
denom: utils::F32,
|
||||
},
|
||||
Cast {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Max {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
Min {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
Ceil {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Floor {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Round {
|
||||
scale: utils::F32,
|
||||
},
|
||||
RoundHalfToEven {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sqrt {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Rsqrt {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Recip {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
},
|
||||
LeakyReLU {
|
||||
slope: utils::F32,
|
||||
},
|
||||
Sigmoid {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Ln {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Exp {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Cos {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ACos {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Cosh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ACosh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sin {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ASin {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sinh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ASinh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Tan {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ATan {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Tanh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ATanh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Erf {
|
||||
scale: utils::F32,
|
||||
},
|
||||
KroneckerDelta,
|
||||
Pow {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
HardSwish {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Div { denom: utils::F32 },
|
||||
IsOdd,
|
||||
PowersOfTwo { scale: utils::F32 },
|
||||
Ln { scale: utils::F32 },
|
||||
Sigmoid { scale: utils::F32 },
|
||||
Exp { scale: utils::F32 },
|
||||
Cos { scale: utils::F32 },
|
||||
ACos { scale: utils::F32 },
|
||||
Cosh { scale: utils::F32 },
|
||||
ACosh { scale: utils::F32 },
|
||||
Sin { scale: utils::F32 },
|
||||
ASin { scale: utils::F32 },
|
||||
Sinh { scale: utils::F32 },
|
||||
ASinh { scale: utils::F32 },
|
||||
Tan { scale: utils::F32 },
|
||||
ATan { scale: utils::F32 },
|
||||
Tanh { scale: utils::F32 },
|
||||
ATanh { scale: utils::F32 },
|
||||
Erf { scale: utils::F32 },
|
||||
Pow { scale: utils::F32, a: utils::F32 },
|
||||
HardSwish { scale: utils::F32 },
|
||||
}
|
||||
|
||||
impl LookupOp {
|
||||
@@ -123,27 +48,14 @@ impl LookupOp {
|
||||
/// as path
|
||||
pub fn as_path(&self) -> String {
|
||||
match self {
|
||||
LookupOp::Ceil { scale } => format!("ceil_{}", scale),
|
||||
LookupOp::Floor { scale } => format!("floor_{}", scale),
|
||||
LookupOp::Round { scale } => format!("round_{}", scale),
|
||||
LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
|
||||
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
|
||||
LookupOp::KroneckerDelta => "kronecker_delta".into(),
|
||||
LookupOp::Max { scale, a } => format!("max_{}_{}", scale, a),
|
||||
LookupOp::Min { scale, a } => format!("min_{}_{}", scale, a),
|
||||
LookupOp::Ln { scale } => format!("ln_{}", scale),
|
||||
LookupOp::PowersOfTwo { scale } => format!("pow2_{}", scale),
|
||||
LookupOp::IsOdd => "is_odd".to_string(),
|
||||
LookupOp::Div { denom } => format!("div_{}", denom),
|
||||
LookupOp::Cast { scale } => format!("cast_{}", scale),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => format!("recip_{}_{}", input_scale, output_scale),
|
||||
LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a),
|
||||
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
|
||||
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
|
||||
LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
|
||||
LookupOp::Erf { scale } => format!("erf_{}", scale),
|
||||
LookupOp::Exp { scale } => format!("exp_{}", scale),
|
||||
LookupOp::Ln { scale } => format!("ln_{}", scale),
|
||||
LookupOp::Cos { scale } => format!("cos_{}", scale),
|
||||
LookupOp::ACos { scale } => format!("acos_{}", scale),
|
||||
LookupOp::Cosh { scale } => format!("cosh_{}", scale),
|
||||
@@ -168,65 +80,28 @@ impl LookupOp {
|
||||
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
|
||||
let res =
|
||||
match &self {
|
||||
LookupOp::Ceil { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ceil(&x, scale.into()))
|
||||
LookupOp::Ln { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Floor { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::floor(&x, scale.into()))
|
||||
LookupOp::PowersOfTwo { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ipow2(&x, scale.0.into()))
|
||||
}
|
||||
LookupOp::Round { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::round(&x, scale.into()))
|
||||
}
|
||||
LookupOp::RoundHalfToEven { scale } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()),
|
||||
),
|
||||
LookupOp::IsOdd => Ok::<_, TensorError>(tensor::ops::nonlinearities::is_odd(&x)),
|
||||
LookupOp::Pow { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::pow(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::KroneckerDelta => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::kronecker_delta(&x))
|
||||
}
|
||||
LookupOp::Max { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::max(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::Min { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::min(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::Div { denom } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
|
||||
),
|
||||
LookupOp::Cast { scale } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::const_div(&x, f32::from(*scale).into()),
|
||||
),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => Ok::<_, TensorError>(tensor::ops::nonlinearities::recip(
|
||||
&x,
|
||||
input_scale.into(),
|
||||
output_scale.into(),
|
||||
)),
|
||||
LookupOp::LeakyReLU { slope: a } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into()))
|
||||
}
|
||||
LookupOp::Sigmoid { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Sqrt { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::sqrt(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Rsqrt { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::rsqrt(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Erf { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::erffunc(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Exp { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::exp(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Ln { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ln(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Cos { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::cos(&x, scale.into()))
|
||||
}
|
||||
@@ -283,29 +158,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
/// Returns the name of the operation
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
LookupOp::Ceil { scale } => format!("CEIL(scale={})", scale),
|
||||
LookupOp::Floor { scale } => format!("FLOOR(scale={})", scale),
|
||||
LookupOp::Round { scale } => format!("ROUND(scale={})", scale),
|
||||
LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale),
|
||||
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
|
||||
LookupOp::KroneckerDelta => "K_DELTA".into(),
|
||||
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
|
||||
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => format!(
|
||||
"RECIP(input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
),
|
||||
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
|
||||
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
|
||||
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
LookupOp::LeakyReLU { slope: a } => format!("L_RELU(slope={})", a),
|
||||
LookupOp::PowersOfTwo { scale } => format!("POWERS_OF_TWO(scale={})", scale),
|
||||
LookupOp::IsOdd => "IS_ODD".to_string(),
|
||||
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
|
||||
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
|
||||
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
|
||||
LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
|
||||
LookupOp::Erf { scale } => format!("ERF(scale={})", scale),
|
||||
LookupOp::Rsqrt { scale } => format!("RSQRT(scale={})", scale),
|
||||
LookupOp::Exp { scale } => format!("EXP(scale={})", scale),
|
||||
LookupOp::Tan { scale } => format!("TAN(scale={})", scale),
|
||||
LookupOp::ATan { scale } => format!("ATAN(scale={})", scale),
|
||||
@@ -339,15 +198,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
|
||||
/// Returns the scale of the output of the operation.
|
||||
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let scale = match self {
|
||||
LookupOp::Cast { scale } => {
|
||||
let in_scale = inputs_scale[0];
|
||||
in_scale + multiplier_to_scale(1. / scale.0 as f64)
|
||||
}
|
||||
LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()),
|
||||
LookupOp::KroneckerDelta => 0,
|
||||
_ => inputs_scale[0],
|
||||
};
|
||||
let scale = inputs_scale[0];
|
||||
Ok(scale)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use crate::{
|
||||
circuit::layouts,
|
||||
circuit::{
|
||||
layouts,
|
||||
utils::{self, F32},
|
||||
},
|
||||
tensor::{self, Tensor, TensorError},
|
||||
};
|
||||
|
||||
@@ -9,9 +12,12 @@ use super::{base::BaseOp, *};
|
||||
/// An enum representing the operations that can be expressed as arithmetic (non lookup) operations.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum PolyOp {
|
||||
ReLU,
|
||||
Abs,
|
||||
Sign,
|
||||
LeakyReLU {
|
||||
slope: utils::F32,
|
||||
scale: i32,
|
||||
},
|
||||
GatherElements {
|
||||
dim: usize,
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
@@ -112,9 +118,9 @@ impl<
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match &self {
|
||||
PolyOp::LeakyReLU { slope: a, .. } => format!("LEAKYRELU (slope={})", a),
|
||||
PolyOp::Abs => "ABS".to_string(),
|
||||
PolyOp::Sign => "SIGN".to_string(),
|
||||
PolyOp::ReLU => "RELU".to_string(),
|
||||
PolyOp::GatherElements { dim, constant_idx } => format!(
|
||||
"GATHERELEMENTS (dim={}, constant_idx{})",
|
||||
dim,
|
||||
@@ -198,7 +204,9 @@ impl<
|
||||
Ok(Some(match self {
|
||||
PolyOp::Abs => layouts::abs(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?)?,
|
||||
PolyOp::ReLU => layouts::relu(config, region, values[..].try_into()?)?,
|
||||
PolyOp::LeakyReLU { slope, scale } => {
|
||||
layouts::leaky_relu(config, region, values[..].try_into()?, slope, scale)?
|
||||
}
|
||||
PolyOp::MultiBroadcastTo { shape } => {
|
||||
layouts::expand(config, region, values[..].try_into()?, shape)?
|
||||
}
|
||||
@@ -329,6 +337,12 @@ impl<
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let scale = match self {
|
||||
// this corresponds to the relu operation
|
||||
PolyOp::LeakyReLU {
|
||||
slope: F32(0.0), ..
|
||||
} => in_scales[0],
|
||||
// this corresponds to the leaky relu operation with a slope which induces a change in scale
|
||||
PolyOp::LeakyReLU { scale, .. } => in_scales[0] + *scale,
|
||||
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
|
||||
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
|
||||
PolyOp::Iff => in_scales[1],
|
||||
|
||||
@@ -474,6 +474,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min forcefully
|
||||
pub fn update_max_min_lookup_inputs_force(
|
||||
&mut self,
|
||||
min: IntegerRep,
|
||||
max: IntegerRep,
|
||||
) -> Result<(), CircuitError> {
|
||||
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
|
||||
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min from inputs
|
||||
pub fn update_max_min_lookup_range(&mut self, range: Range) -> Result<(), CircuitError> {
|
||||
if range.0 > range.1 {
|
||||
|
||||
@@ -150,12 +150,16 @@ pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
/// get largest element represented by the range
|
||||
pub fn largest(&self) -> IntegerRep {
|
||||
self.range.0 + (self.col_size * self.table_inputs.len() - 1) as IntegerRep
|
||||
}
|
||||
fn name(&self) -> String {
|
||||
format!(
|
||||
"{}_{}_{}",
|
||||
self.nonlinearity.as_path(),
|
||||
self.range.0,
|
||||
self.range.1
|
||||
self.largest()
|
||||
)
|
||||
}
|
||||
/// Configures the table.
|
||||
@@ -222,7 +226,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
}
|
||||
|
||||
let smallest = self.range.0;
|
||||
let largest = self.range.1;
|
||||
let largest = self.largest();
|
||||
|
||||
let gen_table = || -> Result<(Tensor<F>, Tensor<F>), crate::tensor::TensorError> {
|
||||
let inputs = Tensor::from(smallest..=largest)
|
||||
@@ -291,6 +295,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
|
||||
row_offset += chunk_idx * self.col_size;
|
||||
let (x, y) = self.cartesian_coord(row_offset);
|
||||
|
||||
if !preassigned_input {
|
||||
table.assign_cell(
|
||||
|| format!("nl_i_col row {}", row_offset),
|
||||
|
||||
@@ -1379,7 +1379,10 @@ mod conv_relu_col_ultra_overflow {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap().unwrap()],
|
||||
Box::new(PolyOp::ReLU),
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
@@ -2347,7 +2350,14 @@ mod matmul_relu {
|
||||
.unwrap();
|
||||
let _output = config
|
||||
.base_config
|
||||
.layout(&mut region, &[output.unwrap()], Box::new(PolyOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
@@ -2439,7 +2449,14 @@ mod relu {
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 2, 2);
|
||||
Ok(config
|
||||
.layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap())
|
||||
},
|
||||
)
|
||||
@@ -2482,11 +2499,11 @@ mod lookup_ultra_overflow {
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ReLUCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
struct SigmoidCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub input: ValTensor<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for ReLUCircuit<F> {
|
||||
impl Circuit<F> for SigmoidCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
@@ -2500,7 +2517,7 @@ mod lookup_ultra_overflow {
|
||||
.map(|_| VarTensor::new_advice(cs, 4, 1, 3))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let nl = LookupOp::LeakyReLU { slope: 0.0.into() };
|
||||
let nl = LookupOp::Sigmoid { scale: 1.0.into() };
|
||||
|
||||
let mut config = BaseConfig::default();
|
||||
|
||||
@@ -2533,7 +2550,7 @@ mod lookup_ultra_overflow {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
},
|
||||
@@ -2546,13 +2563,13 @@ mod lookup_ultra_overflow {
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn relucircuit() {
|
||||
fn sigmoidcircuit() {
|
||||
// get some logs fam
|
||||
crate::logger::init_logger();
|
||||
// parameters
|
||||
let a = Tensor::from((0..4).map(|i| Value::known(F::from(i + 1))));
|
||||
|
||||
let circuit = ReLUCircuit::<F> {
|
||||
let circuit = SigmoidCircuit::<F> {
|
||||
input: ValTensor::from(a),
|
||||
};
|
||||
|
||||
@@ -2562,7 +2579,7 @@ mod lookup_ultra_overflow {
|
||||
|
||||
let pk = crate::pfsys::create_keys::<
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
ReLUCircuit<F>,
|
||||
SigmoidCircuit<F>,
|
||||
>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -95,9 +95,6 @@ pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
|
||||
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";
|
||||
/// Default commitment
|
||||
pub const DEFAULT_COMMITMENT: &str = "kzg";
|
||||
// TODO: In prod this will be the same across all chains we deploy to using the EZKL multisig create2 deployment.
|
||||
/// Default address of the verifier manager.
|
||||
pub const DEFAULT_VERIFIER_MANAGER_ADDRESS: &str = "0xdc64a140aa3e981100a9beca4e685f962f0cf6c9";
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
|
||||
@@ -190,13 +187,11 @@ pub enum ContractType {
|
||||
/// Deploys a verifier contrat tailored to the circuit and not reusable
|
||||
Verifier {
|
||||
/// Whether to deploy a reusable verifier. This can reduce state bloat on-chain since you need only deploy a verifying key artifact (vka) for a given circuit which is significantly smaller than the verifier contract (up to 4 times smaller for large circuits)
|
||||
/// Can also be used as an alternative to aggregation for verifiers that are otherwise too large to fit on-chain.
|
||||
/// Can also be used as an alternative to aggregation for verifiers that are otherwise too large to fit on-chain.
|
||||
reusable: bool,
|
||||
},
|
||||
/// Deploys a verifying key artifact that the reusable verifier loads into memory during runtime. Encodes the circuit specific data that was otherwise hardcoded onto the stack.
|
||||
VerifyingKeyArtifact,
|
||||
/// Manages the deployments of all reusable verifier and verifying artifact keys. Routes all the verification tx to the correct artifacts.
|
||||
VerifierManager
|
||||
}
|
||||
|
||||
impl Default for ContractType {
|
||||
@@ -220,7 +215,6 @@ impl std::fmt::Display for ContractType {
|
||||
reusable: false,
|
||||
} => "verifier".to_string(),
|
||||
ContractType::VerifyingKeyArtifact => "vka".to_string(),
|
||||
ContractType::VerifierManager => "manager".to_string()
|
||||
}
|
||||
)
|
||||
}
|
||||
@@ -238,16 +232,16 @@ impl From<&str> for ContractType {
|
||||
"verifier" => ContractType::Verifier { reusable: false },
|
||||
"verifier/reusable" => ContractType::Verifier { reusable: true },
|
||||
"vka" => ContractType::VerifyingKeyArtifact,
|
||||
"manager" => ContractType::VerifierManager,
|
||||
_ => {
|
||||
log::error!("Invalid value for ContractType");
|
||||
log::warn!("Defaulting to verifier");
|
||||
ContractType::default()
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
|
||||
/// wrapper for H160 to make it easy to parse into flag vals
|
||||
pub struct H160Flag {
|
||||
@@ -339,18 +333,6 @@ impl<'source> FromPyObject<'source> for ContractType {
|
||||
}
|
||||
}
|
||||
}
|
||||
// not wasm
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
|
||||
lazy_static! {
|
||||
/// The version of the ezkl library
|
||||
pub static ref VERSION: &'static str = if env!("CARGO_PKG_VERSION") == "0.0.0" {
|
||||
"source - no compatibility guaranteed"
|
||||
} else {
|
||||
env!("CARGO_PKG_VERSION")
|
||||
};
|
||||
}
|
||||
|
||||
/// Get the styles for the CLI
|
||||
pub fn get_styles() -> clap::builder::Styles {
|
||||
@@ -401,7 +383,7 @@ pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
#[command(author, about, long_about = None)]
|
||||
#[clap(version = *VERSION, styles = get_styles(), trailing_var_arg = true)]
|
||||
#[clap(version = crate::version(), styles = get_styles(), trailing_var_arg = true)]
|
||||
pub struct Cli {
|
||||
/// If provided, outputs the completion file for given shell
|
||||
#[clap(long = "generate", value_parser)]
|
||||
@@ -492,9 +474,6 @@ pub enum Commands {
|
||||
/// max logrows to use for calibration, 26 is the max public SRS size
|
||||
#[arg(long, value_hint = clap::ValueHint::Other)]
|
||||
max_logrows: Option<u32>,
|
||||
// whether to only range check rebases (instead of trying both range check and lookup)
|
||||
#[arg(long, default_value = DEFAULT_ONLY_RANGE_CHECK_REBASE, action = clap::ArgAction::SetTrue)]
|
||||
only_range_check_rebase: Option<bool>,
|
||||
},
|
||||
|
||||
/// Generates a dummy SRS
|
||||
@@ -514,7 +493,7 @@ pub enum Commands {
|
||||
/// Gets an SRS from a circuit settings file.
|
||||
#[command(name = "get-srs")]
|
||||
GetSrs {
|
||||
/// The path to output the desired srs file, if set to None will save to $EZKL_REPO_PATH/srs
|
||||
/// The path to output the desired srs file, if set to None will save to ~/.ezkl/srs
|
||||
#[arg(long, default_value = None, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// Path to the circuit settings .json file to read in logrows from. Overriden by logrows if specified.
|
||||
@@ -561,7 +540,7 @@ pub enum Commands {
|
||||
/// The path to save the proving key to
|
||||
#[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
|
||||
pk_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// logrows used for aggregation circuit
|
||||
@@ -588,7 +567,7 @@ pub enum Commands {
|
||||
/// The path to output the proof file to
|
||||
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
|
||||
proof_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
#[arg(
|
||||
@@ -630,7 +609,7 @@ pub enum Commands {
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
|
||||
compiled_circuit: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to output the verification key file to
|
||||
@@ -707,7 +686,7 @@ pub enum Commands {
|
||||
/// The path to output the proof file to
|
||||
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
|
||||
proof_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
#[arg(
|
||||
@@ -739,7 +718,7 @@ pub enum Commands {
|
||||
/// Creates an Evm verifier for a single proof
|
||||
#[command(name = "create-evm-verifier")]
|
||||
CreateEvmVerifier {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
@@ -761,7 +740,7 @@ pub enum Commands {
|
||||
/// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier
|
||||
#[command(name = "create-evm-vka")]
|
||||
CreateEvmVKArtifact {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
@@ -804,7 +783,7 @@ pub enum Commands {
|
||||
/// Creates an Evm verifier for an aggregate proof
|
||||
#[command(name = "create-evm-verifier-aggr")]
|
||||
CreateEvmVerifierAggr {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load the desired verification key file
|
||||
@@ -837,7 +816,7 @@ pub enum Commands {
|
||||
/// The path to the verification key file (generated using the setup command)
|
||||
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
|
||||
vk_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
|
||||
@@ -855,7 +834,7 @@ pub enum Commands {
|
||||
/// reduced srs
|
||||
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
|
||||
reduced_srs: Option<bool>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// logrows used for aggregation circuit
|
||||
@@ -882,14 +861,6 @@ pub enum Commands {
|
||||
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
|
||||
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
|
||||
private_key: Option<String>,
|
||||
/// Deployed verifier manager contract's address
|
||||
/// Used to facilitate reusable verifier and vk artifact deployment
|
||||
#[arg(long, value_hint = clap::ValueHint::Other)]
|
||||
addr_verifier_manager: Option<H160Flag>,
|
||||
/// Deployed reusable verifier contract's address
|
||||
/// Use to facilitate reusable verifier and vk artifact deployment
|
||||
#[arg(long, value_hint = clap::ValueHint::Other)]
|
||||
addr_reusable_verifier: Option<H160Flag>,
|
||||
/// Contract type to be deployed
|
||||
#[arg(long = "contract-type", short = 'C', default_value = DEFAULT_CONTRACT_DEPLOYMENT_TYPE, value_hint = clap::ValueHint::Other)]
|
||||
contract: ContractType,
|
||||
|
||||
105
src/eth.rs
105
src/eth.rs
@@ -31,7 +31,7 @@ use alloy::transports::{RpcError, TransportErrorKind};
|
||||
use foundry_compilers::artifacts::Settings as SolcSettings;
|
||||
use foundry_compilers::error::{SolcError, SolcIoError};
|
||||
use foundry_compilers::Solc;
|
||||
use halo2_solidity_verifier::{encode_calldata, encode_deploy};
|
||||
use halo2_solidity_verifier::encode_calldata;
|
||||
use halo2curves::bn256::{Fr, G1Affine};
|
||||
use halo2curves::group::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
@@ -213,16 +213,6 @@ abigen!(
|
||||
}
|
||||
);
|
||||
|
||||
// The bytecode here was generated from running solc compiler version 0.8.20 with optimization enabled and runs param set to 1.
|
||||
abigen!(
|
||||
#[allow(missing_docs)]
|
||||
#[sol(
|
||||
rpc,
|
||||
bytecode = "60806040525f80546001600160a01b03191673f39fd6e51aad88f6f4ce6ab8827279cfffb92266179055348015610034575f80fd5b5061003e33610043565b610092565b5f80546001600160a01b038381166001600160a01b0319831681178455604051919092169283917f8be0079c531659141344cd1fd0a4f28419497f9722a3daafe3b4186f6b6457e09190a35050565b6103dc8061009f5f395ff3fe608060405234801561000f575f80fd5b5060043610610060575f3560e01c80635717ecef146100645780635d34fd561461009b578063715018a6146100bb5780637a33ac87146100c55780638da5cb5b14610125578063f2fde38b1461012d575b5f80fd5b6100866100723660046102a7565b60016020525f908152604090205460ff1681565b60405190151581526020015b60405180910390f35b6100ae6100a93660046102e8565b610140565b6040516100929190610392565b6100c36101bf565b005b6100ae6100d33660046102e8565b8051602091820120604080516001600160f81b0319818501523060601b6001600160601b03191660218201525f6035820152605580820193909352815180820390930183526075019052805191012090565b6100ae6101d2565b6100c361013b3660046102a7565b6101e0565b5f610149610226565b5f8251602084015ff59050803b61015e575f80fd5b6001600160a01b0381165f90815260016020819052604091829020805460ff19169091179055517f27bf8213352a1c07513a54703c920b9e437940154edead05874c43279acf166c906101b2908390610392565b60405180910390a1919050565b6101c7610226565b6101d05f610258565b565b5f546001600160a01b031690565b6101e8610226565b6001600160a01b03811661021a575f604051631e4fbdf760e01b81526004016102119190610392565b60405180910390fd5b61022381610258565b50565b3361022f6101d2565b6001600160a01b0316146101d0573360405163118cdaa760e01b81526004016102119190610392565b5f80546001600160a01b038381166001600160a01b0319831681178455604051919092169283917f8be0079c531659141344cd1fd0a4f28419497f9722a3daafe3b4186f6b6457e09190a35050565b5f602082840312156102b7575f80fd5b81356001600160a01b03811681146102cd575f80fd5b9392505050565b634e487b7160e01b5f52604160045260245ffd5b5f602082840312156102f8575f80fd5b81356001600160401b038082111561030e575f80fd5b818401915084601f830112610321575f80fd5b813581811115610333576103336102d4565b604051601f8201601f19908116603f0116810190838211818310171561035b5761035b6102d4565b81604052828152876020848701011115610373575f80fd5b826020860160208301375f928101602001929092525095945050505050565b6001600160a01b039190911681526020019056fea26469706673582212201d85104628b308554b775f612650220008f8e318f66dc4ace466d82d70bae4e264736f6c63430008140033"
|
||||
)]
|
||||
EZKLVerifierManager,
|
||||
"./abis/EZKLVerifierManager.json"
|
||||
);
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum EthError {
|
||||
#[error("a transport error occurred: {0}")]
|
||||
@@ -362,99 +352,6 @@ pub async fn deploy_contract_via_solidity(
|
||||
Ok(contract)
|
||||
}
|
||||
|
||||
pub async fn deploy_vka(
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<&str>,
|
||||
runs: usize,
|
||||
private_key: Option<&str>,
|
||||
contract_name: &str,
|
||||
verifier_manager: H160,
|
||||
reusable_verifier: H160,
|
||||
) -> Result<H160, EthError> {
|
||||
let (client, _) = setup_eth_backend(rpc_url, private_key).await?;
|
||||
|
||||
// Create an instance of the EZKLVerifierManager contract
|
||||
let verifier_manager_contract = EZKLVerifierManager::new(verifier_manager, client.clone());
|
||||
|
||||
// Get the bytecode of the contract to be deployed
|
||||
let (_, bytecode, _run_time_bytecode) =
|
||||
get_contract_artifacts(sol_code_path.clone(), contract_name, runs).await?;
|
||||
|
||||
// Check if the reusable verifier is already deployed
|
||||
let deployed_verifier: bool = verifier_manager_contract
|
||||
.verifierAddresses(reusable_verifier)
|
||||
.call()
|
||||
.await?
|
||||
._0;
|
||||
|
||||
if deployed_verifier == false {
|
||||
panic!("The reusable verifier for this VKA has not been deployed yet.");
|
||||
}
|
||||
|
||||
let encoded = encode_deploy(&bytecode);
|
||||
|
||||
debug!("encoded: {:#?}", hex::encode(&encoded));
|
||||
|
||||
let input: TransactionInput = encoded.into();
|
||||
|
||||
let tx = TransactionRequest::default()
|
||||
.to(reusable_verifier)
|
||||
.input(input);
|
||||
debug!("transaction {:#?}", tx);
|
||||
|
||||
let result = client.call(&tx).await;
|
||||
|
||||
if let Err(e) = result {
|
||||
return Err(EvmVerificationError::SolidityExecution(e.to_string()).into());
|
||||
}
|
||||
|
||||
// Now send the tx
|
||||
let _ = client.send_transaction(tx).await?;
|
||||
|
||||
let result = result?;
|
||||
debug!("result: {:#?}", result.to_vec());
|
||||
|
||||
let contract = H160::from_slice(&result.to_vec()[12..32]);
|
||||
return Ok(contract);
|
||||
}
|
||||
|
||||
pub async fn deploy_reusable_verifier(
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<&str>,
|
||||
runs: usize,
|
||||
private_key: Option<&str>,
|
||||
contract_name: &str,
|
||||
verifier_manager: H160,
|
||||
) -> Result<H160, EthError> {
|
||||
let (client, _) = setup_eth_backend(rpc_url, private_key).await?;
|
||||
|
||||
// Create an instance of the EZKLVerifierManager contract
|
||||
let verifier_manager_contract = EZKLVerifierManager::new(verifier_manager, client.clone());
|
||||
|
||||
// Get the bytecode of the contract to be deployed
|
||||
let (_, bytecode, _run_time_bytecode) =
|
||||
get_contract_artifacts(sol_code_path.clone(), contract_name, runs).await?;
|
||||
|
||||
// Deploy the contract using the EZKLVerifierManager
|
||||
let output = verifier_manager_contract
|
||||
.deployVerifier(bytecode.clone().into())
|
||||
.call()
|
||||
.await?;
|
||||
let out = verifier_manager_contract
|
||||
.precomputeAddress(bytecode.clone().into())
|
||||
.call()
|
||||
.await?;
|
||||
// assert that out == output
|
||||
assert_eq!(out._0, output.addr);
|
||||
// Get the deployed contract address from the receipt
|
||||
let contract = output.addr;
|
||||
let _ = verifier_manager_contract
|
||||
.deployVerifier(bytecode.into())
|
||||
.send()
|
||||
.await?;
|
||||
return Ok(contract);
|
||||
}
|
||||
|
||||
///
|
||||
pub async fn deploy_da_verifier_via_solidity(
|
||||
settings_path: PathBuf,
|
||||
|
||||
124
src/execute.rs
124
src/execute.rs
@@ -140,7 +140,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
max_logrows,
|
||||
only_range_check_rebase,
|
||||
} => calibrate(
|
||||
model.unwrap_or(DEFAULT_MODEL.into()),
|
||||
data.unwrap_or(DEFAULT_DATA.into()),
|
||||
@@ -149,7 +148,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
lookup_safety_margin,
|
||||
scales,
|
||||
scale_rebase_multiplier,
|
||||
only_range_check_rebase.unwrap_or(DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap()),
|
||||
max_logrows,
|
||||
)
|
||||
.await
|
||||
@@ -410,46 +408,24 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
commitment.into(),
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::DeployEvm {
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
addr_verifier_manager,
|
||||
addr_reusable_verifier,
|
||||
contract,
|
||||
} => {
|
||||
// if contract type is either verifier/reusable
|
||||
match contract {
|
||||
ContractType::Verifier { reusable: true } => {
|
||||
if addr_verifier_manager.is_none() {
|
||||
panic!("Must pass a verifier manager address for reusable verifier")
|
||||
}
|
||||
}
|
||||
ContractType::VerifyingKeyArtifact => {
|
||||
if addr_verifier_manager.is_none() || addr_reusable_verifier.is_none() {
|
||||
panic!(
|
||||
"Must pass a verifier manager address and reusable verifier address for verifying key artifact"
|
||||
)
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
deploy_evm(
|
||||
sol_code_path.unwrap_or(DEFAULT_SOL_CODE.into()),
|
||||
rpc_url,
|
||||
addr_path.unwrap_or(DEFAULT_CONTRACT_ADDRESS.into()),
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
addr_verifier_manager.map(|s| s.into()),
|
||||
addr_reusable_verifier.map(|s| s.into()),
|
||||
contract,
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::DeployEvmDataAttestation {
|
||||
data,
|
||||
settings_path,
|
||||
@@ -699,11 +675,15 @@ pub(crate) async fn get_srs_cmd(
|
||||
pb.finish_with_message("SRS validated.");
|
||||
|
||||
info!("Saving SRS to disk...");
|
||||
let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone(), commitment))?;
|
||||
let computed_srs_path = get_srs_path(k, srs_path.clone(), commitment);
|
||||
let mut file = std::fs::File::create(&computed_srs_path)?;
|
||||
let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
|
||||
params.write(&mut buffer)?;
|
||||
|
||||
info!("Saved SRS to disk.");
|
||||
info!(
|
||||
"Saved SRS to {}.",
|
||||
computed_srs_path.as_os_str().to_str().unwrap_or("disk")
|
||||
);
|
||||
|
||||
info!("SRS downloaded");
|
||||
} else {
|
||||
@@ -989,7 +969,6 @@ pub(crate) async fn calibrate(
|
||||
lookup_safety_margin: f64,
|
||||
scales: Option<Vec<crate::Scale>>,
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
only_range_check_rebase: bool,
|
||||
max_logrows: Option<u32>,
|
||||
) -> Result<GraphSettings, EZKLError> {
|
||||
use log::error;
|
||||
@@ -1025,12 +1004,6 @@ pub(crate) async fn calibrate(
|
||||
(11..14).collect::<Vec<crate::Scale>>()
|
||||
};
|
||||
|
||||
let div_rebasing = if only_range_check_rebase {
|
||||
vec![false]
|
||||
} else {
|
||||
vec![true, false]
|
||||
};
|
||||
|
||||
let mut found_params: Vec<GraphSettings> = vec![];
|
||||
|
||||
// 2 x 2 grid
|
||||
@@ -1068,12 +1041,6 @@ pub(crate) async fn calibrate(
|
||||
.map(|(a, b)| (*a, *b))
|
||||
.collect::<Vec<((crate::Scale, crate::Scale), u32)>>();
|
||||
|
||||
let range_grid = range_grid
|
||||
.iter()
|
||||
.cartesian_product(div_rebasing.iter())
|
||||
.map(|(a, b)| (*a, *b))
|
||||
.collect::<Vec<(((crate::Scale, crate::Scale), u32), bool)>>();
|
||||
|
||||
let mut forward_pass_res = HashMap::new();
|
||||
|
||||
let pb = init_bar(range_grid.len() as u64);
|
||||
@@ -1082,30 +1049,23 @@ pub(crate) async fn calibrate(
|
||||
let mut num_failed = 0;
|
||||
let mut num_passed = 0;
|
||||
|
||||
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
|
||||
for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid {
|
||||
pb.set_message(format!(
|
||||
"i-scale: {}, p-scale: {}, rebase-(x): {}, div-rebase: {}, fail: {}, pass: {}",
|
||||
"i-scale: {}, p-scale: {}, rebase-(x): {}, fail: {}, pass: {}",
|
||||
input_scale.to_string().blue(),
|
||||
param_scale.to_string().blue(),
|
||||
scale_rebase_multiplier.to_string().blue(),
|
||||
div_rebasing.to_string().yellow(),
|
||||
scale_rebase_multiplier.to_string().yellow(),
|
||||
num_failed.to_string().red(),
|
||||
num_passed.to_string().green()
|
||||
));
|
||||
|
||||
let key = (
|
||||
input_scale,
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
);
|
||||
let key = (input_scale, param_scale, scale_rebase_multiplier);
|
||||
forward_pass_res.insert(key, vec![]);
|
||||
|
||||
let local_run_args = RunArgs {
|
||||
input_scale,
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
lookup_range: (IntegerRep::MIN, IntegerRep::MAX),
|
||||
..settings.run_args.clone()
|
||||
};
|
||||
@@ -1209,7 +1169,6 @@ pub(crate) async fn calibrate(
|
||||
let found_run_args = RunArgs {
|
||||
input_scale: new_settings.run_args.input_scale,
|
||||
param_scale: new_settings.run_args.param_scale,
|
||||
div_rebasing: new_settings.run_args.div_rebasing,
|
||||
lookup_range: new_settings.run_args.lookup_range,
|
||||
logrows: new_settings.run_args.logrows,
|
||||
scale_rebase_multiplier: new_settings.run_args.scale_rebase_multiplier,
|
||||
@@ -1317,7 +1276,6 @@ pub(crate) async fn calibrate(
|
||||
best_params.run_args.input_scale,
|
||||
best_params.run_args.param_scale,
|
||||
best_params.run_args.scale_rebase_multiplier,
|
||||
best_params.run_args.div_rebasing,
|
||||
))
|
||||
.ok_or("no params found")?
|
||||
.iter()
|
||||
@@ -1439,7 +1397,6 @@ pub(crate) async fn create_evm_verifier(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn create_evm_vka(
|
||||
vk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
@@ -1468,20 +1425,9 @@ pub(crate) async fn create_evm_vka(
|
||||
num_instance,
|
||||
);
|
||||
|
||||
let (reusable_verifier, vk_solidity) = generator.render_separately()?;
|
||||
let vk_solidity = generator.render_separately()?.1;
|
||||
|
||||
// Remove the first line of vk_solidity (license identifier). Same license identifier for all contracts in this .sol
|
||||
let vk_solidity = vk_solidity
|
||||
.lines()
|
||||
.skip(1)
|
||||
.collect::<Vec<&str>>()
|
||||
.join("\n");
|
||||
|
||||
// We store each contracts to the same file...
|
||||
// We need to do this so that during the deployment transaction we make sure
|
||||
// verifier manager links the VKA to the correct reusable_verifier.
|
||||
let combined_solidity = format!("{}\n\n{}", reusable_verifier, vk_solidity);
|
||||
File::create(sol_code_path.clone())?.write_all(combined_solidity.as_bytes())?;
|
||||
File::create(sol_code_path.clone())?.write_all(vk_solidity.as_bytes())?;
|
||||
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2VerifyingArtifact", 0).await?;
|
||||
@@ -1599,51 +1545,21 @@ pub(crate) async fn deploy_evm(
|
||||
addr_path: PathBuf,
|
||||
runs: usize,
|
||||
private_key: Option<String>,
|
||||
verifier_manager: Option<alloy::primitives::Address>,
|
||||
reusable_verifier: Option<alloy::primitives::Address>,
|
||||
contract: ContractType,
|
||||
) -> Result<String, EZKLError> {
|
||||
use crate::eth::{deploy_reusable_verifier, deploy_vka};
|
||||
|
||||
let contract_name = match contract {
|
||||
ContractType::Verifier { reusable: false } => "Halo2Verifier",
|
||||
ContractType::Verifier { reusable: true } => "Halo2VerifierReusable",
|
||||
ContractType::VerifyingKeyArtifact => "Halo2VerifyingArtifact",
|
||||
ContractType::VerifierManager => "EZKLVerifierManager",
|
||||
};
|
||||
|
||||
let contract_address = if contract_name == "Halo2VerifierReusable" {
|
||||
// Use VerifierManager to deploy the contract
|
||||
deploy_reusable_verifier(
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
runs,
|
||||
private_key.as_deref(),
|
||||
contract_name,
|
||||
verifier_manager.unwrap(),
|
||||
)
|
||||
.await?
|
||||
} else if contract_name == "Halo2VerifyingArtifact" {
|
||||
deploy_vka(
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
runs,
|
||||
private_key.as_deref(),
|
||||
contract_name,
|
||||
verifier_manager.unwrap(),
|
||||
reusable_verifier.unwrap(),
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
deploy_contract_via_solidity(
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
runs,
|
||||
private_key.as_deref(),
|
||||
contract_name,
|
||||
)
|
||||
.await?
|
||||
};
|
||||
let contract_address = deploy_contract_via_solidity(
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
runs,
|
||||
private_key.as_deref(),
|
||||
contract_name,
|
||||
)
|
||||
.await?;
|
||||
|
||||
info!("Contract deployed at: {:#?}", contract_address);
|
||||
|
||||
|
||||
@@ -133,6 +133,8 @@ pub struct GraphWitness {
|
||||
pub min_lookup_inputs: IntegerRep,
|
||||
/// max range check size
|
||||
pub max_range_size: IntegerRep,
|
||||
/// (optional) version of ezkl used
|
||||
pub version: Option<String>,
|
||||
}
|
||||
|
||||
impl GraphWitness {
|
||||
@@ -161,6 +163,7 @@ impl GraphWitness {
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
version: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1350,6 +1353,7 @@ impl GraphCircuit {
|
||||
max_lookup_inputs: model_results.max_lookup_inputs,
|
||||
min_lookup_inputs: model_results.min_lookup_inputs,
|
||||
max_range_size: model_results.max_range_size,
|
||||
version: Some(crate::version().to_string()),
|
||||
};
|
||||
|
||||
witness.generate_rescaled_elements(
|
||||
|
||||
@@ -915,20 +915,9 @@ impl Model {
|
||||
if scales.contains_key(&i) {
|
||||
let scale_diff = n.out_scale - scales[&i];
|
||||
n.opkind = if scale_diff > 0 {
|
||||
RebaseScale::rebase(
|
||||
n.opkind,
|
||||
scales[&i],
|
||||
n.out_scale,
|
||||
1,
|
||||
run_args.div_rebasing,
|
||||
)
|
||||
RebaseScale::rebase(n.opkind, scales[&i], n.out_scale, 1)
|
||||
} else {
|
||||
RebaseScale::rebase_up(
|
||||
n.opkind,
|
||||
scales[&i],
|
||||
n.out_scale,
|
||||
run_args.div_rebasing,
|
||||
)
|
||||
RebaseScale::rebase_up(n.opkind, scales[&i], n.out_scale)
|
||||
};
|
||||
n.out_scale = scales[&i];
|
||||
}
|
||||
|
||||
@@ -120,7 +120,6 @@ impl RebaseScale {
|
||||
global_scale: crate::Scale,
|
||||
op_out_scale: crate::Scale,
|
||||
scale_rebase_multiplier: u32,
|
||||
div_rebasing: bool,
|
||||
) -> SupportedOp {
|
||||
if (op_out_scale > (global_scale * scale_rebase_multiplier as i32))
|
||||
&& !inner.is_constant()
|
||||
@@ -137,7 +136,6 @@ impl RebaseScale {
|
||||
multiplier,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32((multiplier) as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
original_scale: op.original_scale,
|
||||
})
|
||||
@@ -148,7 +146,6 @@ impl RebaseScale {
|
||||
multiplier,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32(multiplier as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
original_scale: op_out_scale,
|
||||
})
|
||||
@@ -163,7 +160,6 @@ impl RebaseScale {
|
||||
inner: SupportedOp,
|
||||
target_scale: crate::Scale,
|
||||
op_out_scale: crate::Scale,
|
||||
div_rebasing: bool,
|
||||
) -> SupportedOp {
|
||||
if (op_out_scale < (target_scale)) && !inner.is_constant() && !inner.is_input() {
|
||||
let multiplier = scale_to_multiplier(op_out_scale - target_scale);
|
||||
@@ -176,7 +172,6 @@ impl RebaseScale {
|
||||
original_scale: op.original_scale,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32((multiplier) as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
@@ -187,7 +182,6 @@ impl RebaseScale {
|
||||
original_scale: op_out_scale,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32(multiplier as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -595,13 +589,7 @@ impl Node {
|
||||
let mut out_scale = opkind.out_scale(in_scales.clone())?;
|
||||
// rescale the inputs if necessary to get consistent fixed points, we select the largest scale (highest precision)
|
||||
let global_scale = scales.get_max();
|
||||
opkind = RebaseScale::rebase(
|
||||
opkind,
|
||||
global_scale,
|
||||
out_scale,
|
||||
scales.rebase_multiplier,
|
||||
run_args.div_rebasing,
|
||||
);
|
||||
opkind = RebaseScale::rebase(opkind, global_scale, out_scale, scales.rebase_multiplier);
|
||||
|
||||
out_scale = opkind.out_scale(in_scales)?;
|
||||
|
||||
|
||||
@@ -763,93 +763,52 @@ pub fn new_op_from_onnx(
|
||||
.map(|(i, _)| i)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if const_inputs.len() != 1 {
|
||||
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = const_inputs[0];
|
||||
let boxed_op = inputs[const_idx].opkind();
|
||||
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
|
||||
if c.len() == 1 {
|
||||
c[0]
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "max".to_string()));
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
|
||||
};
|
||||
|
||||
if inputs.len() == 2 {
|
||||
if let Some(node) = inputs.get_mut(const_idx) {
|
||||
node.decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
}
|
||||
if unit == 0. {
|
||||
SupportedOp::Linear(PolyOp::ReLU)
|
||||
if !const_inputs.is_empty() {
|
||||
let const_idx = const_inputs[0];
|
||||
let boxed_op = inputs[const_idx].opkind();
|
||||
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
|
||||
if c.len() == 1 {
|
||||
c[0]
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "max".to_string()));
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
|
||||
};
|
||||
if unit == 0. {
|
||||
if let Some(node) = inputs.get_mut(const_idx) {
|
||||
node.decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
}
|
||||
SupportedOp::Linear(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
})
|
||||
} else {
|
||||
SupportedOp::Hybrid(HybridOp::Max)
|
||||
}
|
||||
} else {
|
||||
// get the non-constant index
|
||||
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
|
||||
SupportedOp::Nonlinear(LookupOp::Max {
|
||||
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
|
||||
a: crate::circuit::utils::F32(unit),
|
||||
})
|
||||
SupportedOp::Hybrid(HybridOp::Max)
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "max".to_string()));
|
||||
}
|
||||
}
|
||||
"Min" => {
|
||||
// Extract the min value
|
||||
// first find the input that is a constant
|
||||
// and then extract the value
|
||||
let const_inputs = inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, n)| n.is_constant())
|
||||
.map(|(i, _)| i)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if const_inputs.len() != 1 {
|
||||
return Err(GraphError::OpMismatch(idx, "Min".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = const_inputs[0];
|
||||
let boxed_op = inputs[const_idx].opkind();
|
||||
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
|
||||
if c.len() == 1 {
|
||||
c[0]
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "min".to_string()));
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "Min".to_string()));
|
||||
};
|
||||
|
||||
if inputs.len() == 2 {
|
||||
if let Some(node) = inputs.get_mut(const_idx) {
|
||||
node.decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
}
|
||||
|
||||
// get the non-constant index
|
||||
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
|
||||
|
||||
SupportedOp::Nonlinear(LookupOp::Min {
|
||||
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
|
||||
a: crate::circuit::utils::F32(unit),
|
||||
})
|
||||
SupportedOp::Hybrid(HybridOp::Min)
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "min".to_string()));
|
||||
}
|
||||
}
|
||||
"Recip" => {
|
||||
let in_scale = inputs[0].out_scales()[0];
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
// If the input scale is larger than the params scale
|
||||
SupportedOp::Hybrid(HybridOp::Recip {
|
||||
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
|
||||
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
|
||||
use_range_check_for_int: true,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -864,8 +823,9 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
};
|
||||
|
||||
SupportedOp::Nonlinear(LookupOp::LeakyReLU {
|
||||
SupportedOp::Linear(PolyOp::LeakyReLU {
|
||||
slope: crate::circuit::utils::F32(leaky_op.alpha),
|
||||
scale: scales.params,
|
||||
})
|
||||
}
|
||||
"Scan" => {
|
||||
@@ -877,61 +837,75 @@ pub fn new_op_from_onnx(
|
||||
"Abs" => SupportedOp::Linear(PolyOp::Abs),
|
||||
"Neg" => SupportedOp::Linear(PolyOp::Neg),
|
||||
"HardSwish" => SupportedOp::Nonlinear(LookupOp::HardSwish {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Sigmoid" => SupportedOp::Nonlinear(LookupOp::Sigmoid {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Sqrt" => SupportedOp::Nonlinear(LookupOp::Sqrt {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
}),
|
||||
"Rsqrt" => SupportedOp::Nonlinear(LookupOp::Rsqrt {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
"Sqrt" => SupportedOp::Hybrid(HybridOp::Sqrt {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Rsqrt" => {
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
SupportedOp::Hybrid(HybridOp::Rsqrt {
|
||||
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
|
||||
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
|
||||
})
|
||||
}
|
||||
"Exp" => SupportedOp::Nonlinear(LookupOp::Exp {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
}),
|
||||
"Ln" => SupportedOp::Nonlinear(LookupOp::Ln {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Ln" => {
|
||||
if run_args.bounded_log_lookup {
|
||||
SupportedOp::Hybrid(HybridOp::Ln {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
})
|
||||
} else {
|
||||
SupportedOp::Nonlinear(LookupOp::Ln {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
"Sin" => SupportedOp::Nonlinear(LookupOp::Sin {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Cos" => SupportedOp::Nonlinear(LookupOp::Cos {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Tan" => SupportedOp::Nonlinear(LookupOp::Tan {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Asin" => SupportedOp::Nonlinear(LookupOp::ASin {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Acos" => SupportedOp::Nonlinear(LookupOp::ACos {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Atan" => SupportedOp::Nonlinear(LookupOp::ATan {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Sinh" => SupportedOp::Nonlinear(LookupOp::Sinh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Cosh" => SupportedOp::Nonlinear(LookupOp::Cosh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Tanh" => SupportedOp::Nonlinear(LookupOp::Tanh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Asinh" => SupportedOp::Nonlinear(LookupOp::ASinh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Acosh" => SupportedOp::Nonlinear(LookupOp::ACosh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Atanh" => SupportedOp::Nonlinear(LookupOp::ATanh {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Erf" => SupportedOp::Nonlinear(LookupOp::Erf {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
}),
|
||||
"Source" => {
|
||||
let dt = node.outputs[0].fact.datum_type;
|
||||
@@ -975,11 +949,9 @@ pub fn new_op_from_onnx(
|
||||
replace_const(
|
||||
0,
|
||||
0,
|
||||
SupportedOp::Nonlinear(LookupOp::Cast {
|
||||
scale: crate::circuit::utils::F32(scale_to_multiplier(
|
||||
input_scales[0],
|
||||
)
|
||||
as f32),
|
||||
SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
)?
|
||||
} else {
|
||||
@@ -1085,7 +1057,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
};
|
||||
|
||||
let in_scale = inputs[0].out_scales()[0];
|
||||
let in_scale = input_scales[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::Softmax {
|
||||
@@ -1123,17 +1095,21 @@ pub fn new_op_from_onnx(
|
||||
pool_dims: kernel_shape.to_vec(),
|
||||
})
|
||||
}
|
||||
"Ceil" => SupportedOp::Nonlinear(LookupOp::Ceil {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
"Ceil" => SupportedOp::Hybrid(HybridOp::Ceil {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Floor" => SupportedOp::Nonlinear(LookupOp::Floor {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
"Floor" => SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Round" => SupportedOp::Nonlinear(LookupOp::Round {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
"Round" => SupportedOp::Hybrid(HybridOp::Round {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"RoundHalfToEven" => SupportedOp::Nonlinear(LookupOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
"RoundHalfToEven" => SupportedOp::Hybrid(HybridOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Sign" => SupportedOp::Linear(PolyOp::Sign),
|
||||
"Pow" => {
|
||||
@@ -1146,10 +1122,17 @@ pub fn new_op_from_onnx(
|
||||
if c.raw_values.len() > 1 {
|
||||
unimplemented!("only support scalar pow")
|
||||
}
|
||||
SupportedOp::Nonlinear(LookupOp::Pow {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
a: crate::circuit::utils::F32(c.raw_values[0]),
|
||||
})
|
||||
|
||||
let exponent = c.raw_values[0];
|
||||
|
||||
if exponent.fract() == 0.0 {
|
||||
SupportedOp::Linear(PolyOp::Pow(exponent as u32))
|
||||
} else {
|
||||
SupportedOp::Nonlinear(LookupOp::Pow {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
a: crate::circuit::utils::F32(exponent),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
unimplemented!("only support constant pow for now")
|
||||
}
|
||||
|
||||
22
src/lib.rs
22
src/lib.rs
@@ -108,6 +108,18 @@ use serde::{Deserialize, Serialize};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
// if CARGO VERSION is 0.0.0 replace with "source - no compatibility guaranteed"
|
||||
/// The version of the ezkl library
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
/// Get the version of the library
|
||||
pub fn version() -> &'static str {
|
||||
match VERSION {
|
||||
"0.0.0" => "source - no compatibility guaranteed",
|
||||
_ => VERSION,
|
||||
}
|
||||
}
|
||||
|
||||
/// Bindings managment
|
||||
#[cfg(any(
|
||||
feature = "ios-bindings",
|
||||
@@ -297,8 +309,6 @@ pub struct RunArgs {
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
/// Rebase the scale using lookup table for division instead of using a range check
|
||||
pub div_rebasing: bool,
|
||||
/// Should constants with 0.0 fraction be rebased to scale 0
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
@@ -317,11 +327,18 @@ pub struct RunArgs {
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "2", value_hint = clap::ValueHint::Other))]
|
||||
/// the number of legs used for decompositions
|
||||
pub decomp_legs: usize,
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
/// use unbounded lookup for the log
|
||||
pub bounded_log_lookup: bool,
|
||||
}
|
||||
|
||||
impl Default for RunArgs {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
bounded_log_lookup: false,
|
||||
tolerance: Tolerance::default(),
|
||||
input_scale: 7,
|
||||
param_scale: 7,
|
||||
@@ -333,7 +350,6 @@ impl Default for RunArgs {
|
||||
input_visibility: Visibility::Private,
|
||||
output_visibility: Visibility::Public,
|
||||
param_visibility: Visibility::Private,
|
||||
div_rebasing: false,
|
||||
rebase_frac_zero_constants: false,
|
||||
check_mode: CheckMode::UNSAFE,
|
||||
commitment: None,
|
||||
|
||||
@@ -59,10 +59,7 @@ fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)]
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
derive(ValueEnum)
|
||||
)]
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
|
||||
pub enum ProofType {
|
||||
#[default]
|
||||
Single,
|
||||
@@ -134,10 +131,7 @@ impl<'source> pyo3::FromPyObject<'source> for ProofType {
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
derive(ValueEnum)
|
||||
)]
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
|
||||
pub enum StrategyType {
|
||||
Single,
|
||||
Accum,
|
||||
@@ -203,10 +197,7 @@ pub enum PfSysError {
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)]
|
||||
#[cfg_attr(
|
||||
all(feature = "ezkl", not(target_arch = "wasm32")),
|
||||
derive(ValueEnum)
|
||||
)]
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), derive(ValueEnum))]
|
||||
pub enum TranscriptType {
|
||||
Poseidon,
|
||||
#[default]
|
||||
@@ -324,6 +315,8 @@ where
|
||||
pub timestamp: Option<u128>,
|
||||
/// commitment
|
||||
pub commitment: Option<Commitments>,
|
||||
/// (optional) version of ezkl used to generate the proof
|
||||
version: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
@@ -385,6 +378,7 @@ where
|
||||
.as_millis(),
|
||||
),
|
||||
commitment,
|
||||
version: Some(crate::version().to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -920,6 +914,7 @@ mod tests {
|
||||
pretty_public_inputs: None,
|
||||
timestamp: None,
|
||||
commitment: None,
|
||||
version: None,
|
||||
};
|
||||
|
||||
snark
|
||||
|
||||
@@ -27,7 +27,7 @@ pub fn get_rep(
|
||||
n: usize,
|
||||
) -> Result<Vec<IntegerRep>, DecompositionError> {
|
||||
// check if x is too large
|
||||
if x.abs() > (base.pow(n as u32) as IntegerRep) {
|
||||
if x.abs() > (base.pow(n as u32) as IntegerRep) - 1 {
|
||||
return Err(DecompositionError::TooLarge(*x, base, n));
|
||||
}
|
||||
let mut rep = vec![0; n + 1];
|
||||
@@ -1421,85 +1421,6 @@ pub fn slice<T: TensorType + Send + Sync>(
|
||||
pub mod nonlinearities {
|
||||
use super::*;
|
||||
|
||||
/// Ceiling operator.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
///
|
||||
/// use ezkl::tensor::ops::nonlinearities::ceil;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
/// &[3, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = ceil(&x, 2.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 2, 4, 4, 6, 6]), &[3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn ceil(a: &Tensor<IntegerRep>, scale: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale;
|
||||
let rounded = kix.ceil() * scale;
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Floor operator.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::floor;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
/// &[3, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = floor(&x, 2.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 2, 2, 4, 4, 6]), &[3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn floor(a: &Tensor<IntegerRep>, scale: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale;
|
||||
let rounded = kix.floor() * scale;
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Round operator.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::round;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
/// &[3, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = round(&x, 2.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 2, 4, 4, 6, 6]), &[3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn round(a: &Tensor<IntegerRep>, scale: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale;
|
||||
let rounded = kix.round() * scale;
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Round half to even operator.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
@@ -1553,30 +1474,88 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Applies Kronecker delta to a tensor of integers.
|
||||
/// Checks if a tensor's elements are odd
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::kronecker_delta;
|
||||
/// use ezkl::tensor::ops::nonlinearities::is_odd;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = kronecker_delta(&x);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 0, 0, 1]), &[2, 3]).unwrap();
|
||||
///
|
||||
/// let result = is_odd(&x);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 1, 0, 1, 1, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn kronecker_delta<T: TensorType + std::cmp::PartialEq + Send + Sync>(
|
||||
a: &Tensor<T>,
|
||||
) -> Tensor<T> {
|
||||
pub fn is_odd(a: &Tensor<IntegerRep>) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
if a_i == T::zero().unwrap() {
|
||||
Ok::<_, TensorError>(T::one().unwrap())
|
||||
let rounded = if a_i % 2 == 0 { 0 } else { 1 };
|
||||
Ok::<_, TensorError>(rounded)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Powers of 2
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::ipow2;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = ipow2(&x, 1.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 32768, 4, 2, 2, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn ipow2(a: &Tensor<IntegerRep>, scale_output: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = a_i as f64;
|
||||
let kix = scale_output * (2.0_f64).powf(kix);
|
||||
let rounded = kix.round();
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies ln base 2 to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale_input` - Single value
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::ilog2;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 2]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = ilog2(&x, 1.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 4, 1, 0, 0, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn ilog2(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale_input;
|
||||
let log = (kix).log2();
|
||||
let floor = log.floor();
|
||||
let ceil = log.ceil();
|
||||
let floor_dist = ((2.0_f64).powf(floor) - kix).abs();
|
||||
let ceil_dist = (kix - (2.0_f64).powf(ceil)).abs();
|
||||
|
||||
if floor_dist < ceil_dist {
|
||||
Ok::<_, TensorError>(floor as IntegerRep)
|
||||
} else {
|
||||
Ok::<_, TensorError>(T::zero().unwrap())
|
||||
Ok::<_, TensorError>(ceil as IntegerRep)
|
||||
}
|
||||
})
|
||||
.unwrap()
|
||||
@@ -1710,12 +1689,11 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies exponential to a tensor of integers.
|
||||
/// Elementwise applies ln to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `scale_input` - Single value
|
||||
/// * `scale_output` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
@@ -1750,27 +1728,6 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies sign to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::sign;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[-2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = sign(&x);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[-1, 1, 1, 1, 1, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn sign(a: &Tensor<IntegerRep>) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(a_i.signum()))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies square root to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -2254,101 +2211,6 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies leaky relu to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// * `slope` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::leakyrelu;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, -5]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = leakyrelu(&x, 0.1);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 15, 2, 1, 1, -1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn leakyrelu(a: &Tensor<IntegerRep>, slope: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let rounded = if a_i < 0 {
|
||||
let d_inv_x = (slope) * (a_i as f64);
|
||||
d_inv_x.round() as IntegerRep
|
||||
} else {
|
||||
let d_inv_x = a_i as f64;
|
||||
d_inv_x.round() as IntegerRep
|
||||
};
|
||||
Ok::<_, TensorError>(rounded)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies max to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - scalar
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::max;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, -5]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = max(&x, 1.0, 1.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 15, 2, 1, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn max(a: &Tensor<IntegerRep>, scale_input: f64, threshold: f64) -> Tensor<IntegerRep> {
|
||||
// calculate value of output
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let d_inv_x = (a_i as f64) / scale_input;
|
||||
let rounded = if d_inv_x <= threshold {
|
||||
(threshold * scale_input).round() as IntegerRep
|
||||
} else {
|
||||
(d_inv_x * scale_input).round() as IntegerRep
|
||||
};
|
||||
Ok::<_, TensorError>(rounded)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies min to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - scalar
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::min;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, -5]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = min(&x, 1.0, 2.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 2, 2, 1, 1, -5]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn min(a: &Tensor<IntegerRep>, scale_input: f64, threshold: f64) -> Tensor<IntegerRep> {
|
||||
// calculate value of output
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let d_inv_x = (a_i as f64) / scale_input;
|
||||
let rounded = if d_inv_x >= threshold {
|
||||
(threshold * scale_input).round() as IntegerRep
|
||||
} else {
|
||||
(d_inv_x * scale_input).round() as IntegerRep
|
||||
};
|
||||
Ok::<_, TensorError>(rounded)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise divides a tensor with a const integer element.
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -2429,104 +2291,6 @@ pub mod nonlinearities {
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise greater than
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::greater_than;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 7, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2.0;
|
||||
/// let result = greater_than(&x, k);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn greater_than(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) > 0_f64)))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise greater than
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::greater_than_equal;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 7, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2.0;
|
||||
/// let result = greater_than_equal(&x, k);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 1, 1, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn greater_than_equal(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) >= 0_f64)))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise less than
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::less_than;
|
||||
///
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 7, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2.0;
|
||||
///
|
||||
/// let result = less_than(&x, k);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 1, 0, 0, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn less_than(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) < 0_f64)))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise less than
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::less_than_equal;
|
||||
///
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 7, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2.0;
|
||||
///
|
||||
/// let result = less_than_equal(&x, k);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 1, 0, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn less_than_equal(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) <= 0_f64)))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Ops that return the transcript i.e intermediate calcs of an op
|
||||
|
||||
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -27,7 +27,8 @@
|
||||
"check_mode": "UNSAFE",
|
||||
"commitment": "KZG",
|
||||
"decomp_base": 128,
|
||||
"decomp_legs": 2
|
||||
"decomp_legs": 2,
|
||||
"bounded_log_lookup": false
|
||||
},
|
||||
"num_rows": 46,
|
||||
"total_assignments": 92,
|
||||
|
||||
@@ -1 +1 @@
|
||||
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":0,"max_range_size":127}
|
||||
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":0,"max_range_size":127,"version":"source - no compatibility guaranteed"}
|
||||
@@ -205,7 +205,7 @@ mod native_tests {
|
||||
"1l_tiny_div",
|
||||
];
|
||||
|
||||
const TESTS: [&str; 94] = [
|
||||
const TESTS: [&str; 96] = [
|
||||
"1l_mlp", //0
|
||||
"1l_slice",
|
||||
"1l_concat",
|
||||
@@ -304,6 +304,8 @@ mod native_tests {
|
||||
"lstm_large", // 91
|
||||
"lstm_medium", // 92
|
||||
"lenet_5", // 93
|
||||
"rsqrt", // 94
|
||||
"log", // 95
|
||||
];
|
||||
|
||||
const WASM_TESTS: [&str; 46] = [
|
||||
@@ -537,12 +539,12 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
seq!(N in 0..=93 {
|
||||
seq!(N in 0..=95 {
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
#[ignore]
|
||||
@@ -554,15 +556,7 @@ mod native_tests {
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn accuracy_measurement_div_rebase_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6, true);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn accuracy_measurement_public_outputs_(test: &str) {
|
||||
@@ -570,7 +564,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6, false);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 2.6);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -580,7 +574,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "fixed", "private", 1, "accuracy", 2.6 , false);
|
||||
accuracy_measurement(path, test.to_string(), "private", "fixed", "private", 1, "accuracy", 2.6 );
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -590,7 +584,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "public", "private", "private", 1, "accuracy", 2.6, false);
|
||||
accuracy_measurement(path, test.to_string(), "public", "private", "private", 1, "accuracy", 2.6);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -601,7 +595,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 3.1, false);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 3.1);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -610,7 +604,17 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_bounded_lookup_log(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0, true);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -621,7 +625,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// gen random number between 0.0 and 1.0
|
||||
let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0;
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance);
|
||||
mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -636,7 +640,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let large_batch_dir = &format!("large_batches_{}", test);
|
||||
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
}
|
||||
@@ -646,7 +650,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -655,7 +659,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -664,7 +668,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -673,7 +677,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -682,7 +686,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -691,7 +695,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -700,7 +704,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "polycommit", "private", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -710,7 +714,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -720,7 +724,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "polycommit", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -729,7 +733,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -739,7 +743,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "private", "polycommit", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -748,7 +752,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -758,7 +762,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "public", "polycommit", "hashed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -768,7 +772,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "polycommit", "polycommit", "polycommit", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -778,7 +782,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -788,7 +792,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -798,7 +802,7 @@ mod native_tests {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
// needs an extra row for the large model
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -851,9 +855,11 @@ mod native_tests {
|
||||
fn kzg_prove_and_verify_tight_lookup_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let path = test_dir.into_path();
|
||||
let path = path.to_str().unwrap();
|
||||
crate::native_tests::mv_test_(path, test);
|
||||
prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, false, "single", Commitments::KZG, 1);
|
||||
test_dir.close().unwrap();
|
||||
// test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
@@ -973,7 +979,7 @@ mod native_tests {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0);
|
||||
mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
});
|
||||
@@ -1000,21 +1006,13 @@ mod native_tests {
|
||||
use crate::native_tests::run_js_tests;
|
||||
use ezkl::logger::init_logger;
|
||||
use crate::native_tests::lazy_static;
|
||||
use std::sync::Once;
|
||||
|
||||
// Global variables to store verifier hashes and identical verifiers
|
||||
lazy_static! {
|
||||
static ref ANVIL_INSTANCE: std::sync::Mutex<Option<std::process::Child>> = std::sync::Mutex::new(None);
|
||||
// create a new variable of type
|
||||
static ref REUSABLE_VERIFIER_ADDR: std::sync::Mutex<Option<String>> = std::sync::Mutex::new(None);
|
||||
}
|
||||
|
||||
static INIT: Once = Once::new();
|
||||
|
||||
fn initialize() {
|
||||
INIT.call_once(|| {
|
||||
let anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
*ANVIL_INSTANCE.lock().unwrap() = Some(anvil_child);
|
||||
});
|
||||
}
|
||||
|
||||
/// Currently only on chain inputs that return a non-negative value are supported.
|
||||
const TESTS_ON_CHAIN_INPUT: [&str; 17] = [
|
||||
@@ -1126,10 +1124,9 @@ mod native_tests {
|
||||
|
||||
});
|
||||
|
||||
seq!(N in 0..=93 {
|
||||
seq!(N in 0..4 {
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn kzg_evm_prove_and_verify_reusable_verifier_(test: &str) {
|
||||
initialize();
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
@@ -1137,18 +1134,28 @@ mod native_tests {
|
||||
init_logger();
|
||||
log::error!("Running kzg_evm_prove_and_verify_reusable_verifier_ for test: {}", test);
|
||||
// default vis
|
||||
kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "private", "private", "public", false);
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "private", "private", "public", &mut REUSABLE_VERIFIER_ADDR.lock().unwrap(), false);
|
||||
// public/public vis
|
||||
kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "public", "private", "public", false);
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "public", "private", "public", &mut Some(reusable_verifier_address), false);
|
||||
// hashed input
|
||||
kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "hashed", "private", "public", false);
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "hashed", "private", "public", &mut Some(reusable_verifier_address), false);
|
||||
|
||||
match REUSABLE_VERIFIER_ADDR.try_lock() {
|
||||
Ok(mut addr) => {
|
||||
*addr = Some(reusable_verifier_address.clone());
|
||||
log::error!("Reusing the same verifeir deployed at address: {}", reusable_verifier_address);
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to acquire lock on REUSABLE_VERIFIER_ADDR");
|
||||
}
|
||||
}
|
||||
|
||||
test_dir.close().unwrap();
|
||||
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn kzg_evm_prove_and_verify_reusable_verifier_with_overflow_(test: &str) {
|
||||
initialize();
|
||||
// verifier too big to fit on chain with overflow calibration target
|
||||
if test == "1l_eltwise_div" || test == "lenet_5" || test == "ltsf" || test == "lstm_large" {
|
||||
return;
|
||||
@@ -1160,13 +1167,24 @@ mod native_tests {
|
||||
init_logger();
|
||||
log::error!("Running kzg_evm_prove_and_verify_reusable_verifier_with_overflow_ for test: {}", test);
|
||||
// default vis
|
||||
kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "private", "private", "public", true);
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "private", "private", "public", &mut REUSABLE_VERIFIER_ADDR.lock().unwrap(), true);
|
||||
// public/public vis
|
||||
kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "public", "private", "public", true);
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "public", "private", "public", &mut Some(reusable_verifier_address), true);
|
||||
// hashed input
|
||||
kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "hashed", "private", "public", true);
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "hashed", "private", "public", &mut Some(reusable_verifier_address), true);
|
||||
|
||||
match REUSABLE_VERIFIER_ADDR.try_lock() {
|
||||
Ok(mut addr) => {
|
||||
*addr = Some(reusable_verifier_address.clone());
|
||||
log::error!("Reusing the same verifeir deployed at address: {}", reusable_verifier_address);
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to acquire lock on REUSABLE_VERIFIER_ADDR");
|
||||
}
|
||||
}
|
||||
|
||||
test_dir.close().unwrap();
|
||||
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1438,6 +1456,7 @@ mod native_tests {
|
||||
cal_target: &str,
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
tolerance: f32,
|
||||
bounded_lookup_log: bool,
|
||||
) {
|
||||
let mut tolerance = tolerance;
|
||||
gen_circuit_settings_and_witness(
|
||||
@@ -1450,10 +1469,10 @@ mod native_tests {
|
||||
cal_target,
|
||||
scales_to_use,
|
||||
2,
|
||||
false,
|
||||
&mut tolerance,
|
||||
Commitments::KZG,
|
||||
2,
|
||||
bounded_lookup_log,
|
||||
);
|
||||
|
||||
if tolerance > 0.0 {
|
||||
@@ -1591,10 +1610,10 @@ mod native_tests {
|
||||
cal_target: &str,
|
||||
scales_to_use: Option<Vec<u32>>,
|
||||
num_inner_columns: usize,
|
||||
div_rebasing: bool,
|
||||
tolerance: &mut f32,
|
||||
commitment: Commitments,
|
||||
lookup_safety_margin: usize,
|
||||
bounded_lookup_log: bool,
|
||||
) {
|
||||
let mut args = vec![
|
||||
"gen-settings".to_string(),
|
||||
@@ -1613,13 +1632,12 @@ mod native_tests {
|
||||
format!("--commitment={}", commitment),
|
||||
];
|
||||
|
||||
if div_rebasing {
|
||||
args.push("--div-rebasing".to_string());
|
||||
};
|
||||
if bounded_lookup_log {
|
||||
args.push("--bounded-log-lookup".to_string());
|
||||
}
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(args)
|
||||
.stdout(std::process::Stdio::null())
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
@@ -1716,7 +1734,6 @@ mod native_tests {
|
||||
batch_size: usize,
|
||||
cal_target: &str,
|
||||
target_perc: f32,
|
||||
div_rebasing: bool,
|
||||
) {
|
||||
gen_circuit_settings_and_witness(
|
||||
test_dir,
|
||||
@@ -1728,10 +1745,10 @@ mod native_tests {
|
||||
cal_target,
|
||||
None,
|
||||
2,
|
||||
div_rebasing,
|
||||
&mut 0.0,
|
||||
Commitments::KZG,
|
||||
2,
|
||||
false,
|
||||
);
|
||||
|
||||
println!(
|
||||
@@ -2012,10 +2029,10 @@ mod native_tests {
|
||||
target_str,
|
||||
scales_to_use,
|
||||
num_inner_columns,
|
||||
false,
|
||||
&mut 0.0,
|
||||
commitment,
|
||||
lookup_safety_margin,
|
||||
false,
|
||||
);
|
||||
|
||||
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
|
||||
@@ -2216,8 +2233,9 @@ mod native_tests {
|
||||
input_visibility: &str,
|
||||
param_visibility: &str,
|
||||
output_visibility: &str,
|
||||
reusable_verifier_address: &mut Option<String>,
|
||||
overflow: bool,
|
||||
) {
|
||||
) -> String {
|
||||
let anvil_url = ANVIL_URL.as_str();
|
||||
|
||||
prove_and_verify(
|
||||
@@ -2240,82 +2258,57 @@ mod native_tests {
|
||||
|
||||
let vk_arg = format!("{}/{}/key.vk", test_dir, example_name);
|
||||
let rpc_arg = format!("--rpc-url={}", anvil_url);
|
||||
// addr path for verifier manager contract
|
||||
let addr_path_arg = format!("--addr-path={}/{}/addr.txt", test_dir, example_name);
|
||||
let verifier_manager_arg: String;
|
||||
let settings_arg = format!("--settings-path={}", settings_path);
|
||||
// reusable verifier sol_arg
|
||||
let sol_arg = format!("--sol-code-path={}/{}/kzg.sol", test_dir, example_name);
|
||||
|
||||
// create the reusable verifier
|
||||
let args = vec![
|
||||
"create-evm-verifier",
|
||||
"--vk-path",
|
||||
&vk_arg,
|
||||
&settings_arg,
|
||||
&sol_arg,
|
||||
"--reusable",
|
||||
];
|
||||
// if the reusable verifier address is not set, create the verifier
|
||||
let deployed_addr_arg = match reusable_verifier_address {
|
||||
Some(addr) => addr.clone(),
|
||||
None => {
|
||||
// create the reusable verifier
|
||||
let args = vec![
|
||||
"create-evm-verifier",
|
||||
"--vk-path",
|
||||
&vk_arg,
|
||||
&settings_arg,
|
||||
&sol_arg,
|
||||
"--reusable",
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// deploy the verifier manager
|
||||
let args = vec![
|
||||
"deploy-evm",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg.as_str(),
|
||||
// set the sol code path to be contracts/VerifierManager.sol relative to root
|
||||
"--sol-code-path=contracts/VerifierManager.sol",
|
||||
"-C=manager",
|
||||
];
|
||||
// deploy the verifier
|
||||
let args = vec![
|
||||
"deploy-evm",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg.as_str(),
|
||||
sol_arg.as_str(),
|
||||
"-C=verifier/reusable",
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// read in the address of the verifier manager
|
||||
let addr = std::fs::read_to_string(format!("{}/{}/addr.txt", test_dir, example_name))
|
||||
.expect("failed to read address file");
|
||||
// read in the address
|
||||
let addr =
|
||||
std::fs::read_to_string(format!("{}/{}/addr.txt", test_dir, example_name))
|
||||
.expect("failed to read address file");
|
||||
|
||||
verifier_manager_arg = format!("--addr-verifier-manager={}", addr);
|
||||
|
||||
// if the reusable verifier address is not set, deploy the verifier manager and then create the verifier
|
||||
let rv_addr = {
|
||||
// addr path for rv contract
|
||||
let addr_path_arg = format!("--addr-path={}/{}/addr_rv.txt", test_dir, example_name);
|
||||
// deploy the reusable verifier via the verifier router.
|
||||
let args = vec![
|
||||
"deploy-evm",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg.as_str(),
|
||||
sol_arg.as_str(),
|
||||
verifier_manager_arg.as_str(),
|
||||
"-C=verifier/reusable",
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// read in the address of the verifier manager
|
||||
let addr =
|
||||
std::fs::read_to_string(format!("{}/{}/addr_rv.txt", test_dir, example_name))
|
||||
.expect("failed to read address file");
|
||||
|
||||
addr
|
||||
let deployed_addr_arg = format!("--addr-verifier={}", addr);
|
||||
// set the reusable verifier address
|
||||
*reusable_verifier_address = Some(addr);
|
||||
deployed_addr_arg
|
||||
}
|
||||
};
|
||||
|
||||
let addr_path_arg_vk = format!("--addr-path={}/{}/addr_vk.txt", test_dir, example_name);
|
||||
let sol_arg_vk: String = format!("--sol-code-path={}/{}/vk.sol", test_dir, example_name);
|
||||
// create the verifier
|
||||
let addr_path_arg_vk = format!("--addr-path={}/{}/addr_vk.txt", test_dir, example_name);
|
||||
let sol_arg_vk: String = format!("--sol-code-path={}/{}/vk.sol", test_dir, example_name);
|
||||
// create the verifier
|
||||
@@ -2333,15 +2326,11 @@ mod native_tests {
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let rv_addr_arg = format!("--addr-reusable-verifier={}", rv_addr);
|
||||
|
||||
// deploy the vka via the "DeployVKA" command on the reusable verifier
|
||||
// deploy the vka
|
||||
let args = vec![
|
||||
"deploy-evm",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg_vk.as_str(),
|
||||
verifier_manager_arg.as_str(),
|
||||
rv_addr_arg.as_str(),
|
||||
sol_arg_vk.as_str(),
|
||||
"-C=vka",
|
||||
];
|
||||
@@ -2371,8 +2360,6 @@ mod native_tests {
|
||||
|
||||
assert!(status.success());
|
||||
|
||||
let deployed_addr_arg = format!("--addr-verifier={}", rv_addr);
|
||||
|
||||
// now verify the proof
|
||||
let pf_arg = format!("{}/{}/proof.pf", test_dir, example_name);
|
||||
let args = vec![
|
||||
@@ -2432,6 +2419,9 @@ mod native_tests {
|
||||
i
|
||||
);
|
||||
}
|
||||
|
||||
// Returned deploy_addr_arg for reusable verifier
|
||||
deployed_addr_arg
|
||||
}
|
||||
|
||||
// run js browser evm verify tests for a given example
|
||||
@@ -2471,10 +2461,10 @@ mod native_tests {
|
||||
// we need the accuracy
|
||||
Some(vec![4]),
|
||||
1,
|
||||
false,
|
||||
&mut 0.0,
|
||||
Commitments::KZG,
|
||||
2,
|
||||
false,
|
||||
);
|
||||
|
||||
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);
|
||||
|
||||
@@ -124,41 +124,40 @@ mod py_tests {
|
||||
}
|
||||
|
||||
const TESTS: [&str; 34] = [
|
||||
"ezkl_demo_batch.ipynb",
|
||||
"proof_splitting.ipynb", // 0
|
||||
"variance.ipynb",
|
||||
"mnist_gan.ipynb",
|
||||
// "mnist_vae.ipynb",
|
||||
"keras_simple_demo.ipynb",
|
||||
"mnist_gan_proof_splitting.ipynb", // 4
|
||||
"hashed_vis.ipynb", // 5
|
||||
"simple_demo_all_public.ipynb",
|
||||
"data_attest.ipynb",
|
||||
"little_transformer.ipynb",
|
||||
"simple_demo_aggregated_proofs.ipynb",
|
||||
"ezkl_demo.ipynb", // 10
|
||||
"lstm.ipynb",
|
||||
"set_membership.ipynb", // 12
|
||||
"decision_tree.ipynb",
|
||||
"random_forest.ipynb",
|
||||
"gradient_boosted_trees.ipynb", // 15
|
||||
"xgboost.ipynb",
|
||||
"lightgbm.ipynb",
|
||||
"svm.ipynb",
|
||||
"simple_demo_public_input_output.ipynb",
|
||||
"simple_demo_public_network_output.ipynb", // 20
|
||||
"gcn.ipynb",
|
||||
"linear_regression.ipynb",
|
||||
"stacked_regression.ipynb",
|
||||
"data_attest_hashed.ipynb",
|
||||
"kzg_vis.ipynb", // 25
|
||||
"kmeans.ipynb",
|
||||
"solvency.ipynb",
|
||||
"sklearn_mlp.ipynb",
|
||||
"generalized_inverse.ipynb",
|
||||
"mnist_classifier.ipynb", // 30
|
||||
"world_rotation.ipynb",
|
||||
"logistic_regression.ipynb",
|
||||
"ezkl_demo_batch.ipynb", // 0
|
||||
"proof_splitting.ipynb", // 1
|
||||
"variance.ipynb", // 2
|
||||
"mnist_gan.ipynb", // 3
|
||||
"keras_simple_demo.ipynb", // 4
|
||||
"mnist_gan_proof_splitting.ipynb", // 5
|
||||
"hashed_vis.ipynb", // 6
|
||||
"simple_demo_all_public.ipynb", // 7
|
||||
"data_attest.ipynb", // 8
|
||||
"little_transformer.ipynb", // 9
|
||||
"simple_demo_aggregated_proofs.ipynb", // 10
|
||||
"ezkl_demo.ipynb", // 11
|
||||
"lstm.ipynb", // 12
|
||||
"set_membership.ipynb", // 13
|
||||
"decision_tree.ipynb", // 14
|
||||
"random_forest.ipynb", // 15
|
||||
"gradient_boosted_trees.ipynb", // 16
|
||||
"xgboost.ipynb", // 17
|
||||
"lightgbm.ipynb", // 18
|
||||
"svm.ipynb", // 19
|
||||
"simple_demo_public_input_output.ipynb", // 20
|
||||
"simple_demo_public_network_output.ipynb", // 21
|
||||
"gcn.ipynb", // 22
|
||||
"linear_regression.ipynb", // 23
|
||||
"stacked_regression.ipynb", // 24
|
||||
"data_attest_hashed.ipynb", // 25
|
||||
"kzg_vis.ipynb", // 26
|
||||
"kmeans.ipynb", // 27
|
||||
"solvency.ipynb", // 28
|
||||
"sklearn_mlp.ipynb", // 29
|
||||
"generalized_inverse.ipynb", // 30
|
||||
"mnist_classifier.ipynb", // 31
|
||||
"world_rotation.ipynb", // 32
|
||||
"logistic_regression.ipynb", // 33
|
||||
];
|
||||
|
||||
macro_rules! test_func {
|
||||
|
||||
@@ -1 +1 @@
|
||||
[{"type":"function","name":"deployVKA","inputs":[{"name":"bytecode","type":"bytes","internalType":"bytes"}],"outputs":[{"name":"addr","type":"address","internalType":"address"}],"stateMutability":"nonpayable"},{"type":"function","name":"precomputeAddress","inputs":[{"name":"bytecode","type":"bytes","internalType":"bytes"}],"outputs":[{"name":"","type":"address","internalType":"address"}],"stateMutability":"view"},{"type":"function","name":"verifyProof","inputs":[{"name":"vk","type":"address","internalType":"address"},{"name":"proof","type":"bytes","internalType":"bytes"},{"name":"instances","type":"uint256[]","internalType":"uint256[]"}],"outputs":[{"name":"","type":"bool","internalType":"bool"}],"stateMutability":"nonpayable"},{"type":"function","name":"vkaLog","inputs":[{"name":"","type":"address","internalType":"address"}],"outputs":[{"name":"","type":"bool","internalType":"bool"}],"stateMutability":"view"},{"type":"event","name":"DeployedVKArtifact","inputs":[{"name":"vka","type":"address","indexed":false,"internalType":"address"}],"anonymous":false},{"type":"error","name":"UnloggedVka","inputs":[{"name":"vka","type":"address","internalType":"address"}]}]
|
||||
[{"type":"function","name":"verifyProof","inputs":[{"internalType":"bytes","name":"proof","type":"bytes"},{"internalType":"uint256[]","name":"instances","type":"uint256[]"}],"outputs":[{"internalType":"bool","name":"","type":"bool"}],"stateMutability":"nonpayable"}]
|
||||
Reference in New Issue
Block a user