mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
10 Commits
verifier-r
...
v15.2.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
457196f9c1 | ||
|
|
a3c131dac0 | ||
|
|
fd9c2305ac | ||
|
|
a0060f341d | ||
|
|
17f1d42739 | ||
|
|
ebaee9e2b1 | ||
|
|
d51cba589a | ||
|
|
1cb1b6e143 | ||
|
|
d2b683b527 | ||
|
|
a06b09ef1f |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -9,7 +9,6 @@ pkg
|
||||
!AttestData.sol
|
||||
!VerifierBase.sol
|
||||
!LoadInstances.sol
|
||||
!VerifierManager.sol
|
||||
*.pf
|
||||
*.vk
|
||||
*.pk
|
||||
|
||||
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -2397,7 +2397,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_solidity_verifier"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=vka-log#c319e229ad677ee4c7d95bdae45c2958350cfd14"
|
||||
source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier?branch=ac/update-h2-curves#eede1db7f3e599112bd1186e9d1913286bdcb539"
|
||||
dependencies = [
|
||||
"askama",
|
||||
"blake2b_simd",
|
||||
|
||||
64
Cargo.toml
64
Cargo.toml
@@ -19,11 +19,8 @@ crate-type = ["cdylib", "rlib", "staticlib"]
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
|
||||
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
|
||||
"derive_serde",
|
||||
"derive_serde",
|
||||
] }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", branch = "ac/cache-lookup-commitments", features = [
|
||||
"circuit-params",
|
||||
] }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", branch = "ac/cache-lookup-commitments", features = ["circuit-params"] }
|
||||
rand = { version = "0.8", default-features = false }
|
||||
itertools = { version = "0.10.3", default-features = false }
|
||||
clap = { version = "4.5.3", features = ["derive"], optional = true }
|
||||
@@ -36,9 +33,9 @@ halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac
|
||||
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
|
||||
"derive_serde",
|
||||
] }
|
||||
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "vka-log", optional = true }
|
||||
maybe-rayon = { version = "0.1.1", default_features = false }
|
||||
bincode = { version = "1.3.3", default_features = false }
|
||||
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "ac/update-h2-curves", optional = true }
|
||||
maybe-rayon = { version = "0.1.1", default-features = false }
|
||||
bincode = { version = "1.3.3", default-features = false }
|
||||
unzip-n = "0.1.2"
|
||||
num = "0.4.1"
|
||||
portable-atomic = { version = "1.6.0", optional = true }
|
||||
@@ -46,7 +43,10 @@ tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package
|
||||
semver = { version = "1.0.22", optional = true }
|
||||
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||
serde_json = { version = "1.0.97", features = ["float_roundtrip", "raw_value"] }
|
||||
serde_json = { version = "1.0.97", features = [
|
||||
"float_roundtrip",
|
||||
"raw_value",
|
||||
] }
|
||||
|
||||
# evm related deps
|
||||
alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev = "5fbf57bac99edef9d8475190109a7ea9fb7e5e83", features = [
|
||||
@@ -56,39 +56,23 @@ alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev = "5
|
||||
"rpc-types-eth",
|
||||
"signer-wallet",
|
||||
"node-bindings",
|
||||
|
||||
], optional = true }
|
||||
foundry-compilers = { version = "0.4.1", features = [
|
||||
"svm-solc",
|
||||
|
||||
], optional = true }
|
||||
foundry-compilers = { version = "0.4.1", features = ["svm-solc"], optional = true }
|
||||
ethabi = { version = "18", optional = true }
|
||||
indicatif = { version = "0.17.5", features = ["rayon"], optional = true }
|
||||
gag = { version = "1.0.0", default-features = false, optional = true }
|
||||
instant = { version = "0.1" }
|
||||
reqwest = { version = "0.12.4", default-features = false, features = [
|
||||
"default-tls",
|
||||
"multipart",
|
||||
"stream",
|
||||
], optional = true }
|
||||
reqwest = { version = "0.12.4", default-features = false, features = ["default-tls", "multipart", "stream"], optional = true }
|
||||
openssl = { version = "0.10.55", features = ["vendored"], optional = true }
|
||||
tokio-postgres = { version = "0.7.10", optional = true }
|
||||
pg_bigdecimal = { version = "0.1.5", optional = true }
|
||||
lazy_static = { version = "1.4.0", optional = true }
|
||||
colored_json = { version = "3.0.1", default-features = false, optional = true }
|
||||
regex = { version = "1", default-features = false, optional = true }
|
||||
tokio = { version = "1.35.0", default-features = false, features = [
|
||||
"macros",
|
||||
"rt-multi-thread",
|
||||
], optional = true }
|
||||
pyo3 = { version = "0.21.2", features = [
|
||||
"extension-module",
|
||||
"abi3-py37",
|
||||
"macros",
|
||||
], default-features = false, optional = true }
|
||||
pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch = "migration-pyo3-0.21", features = [
|
||||
"attributes",
|
||||
"tokio-runtime",
|
||||
], default-features = false, optional = true }
|
||||
tokio = { version = "1.35.0", default-features = false, features = ["macros", "rt-multi-thread"], optional = true }
|
||||
pyo3 = { version = "0.21.2", features = ["extension-module", "abi3-py37", "macros"], default-features = false, optional = true }
|
||||
pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch="migration-pyo3-0.21", features = ["attributes", "tokio-runtime"], default-features = false, optional = true }
|
||||
pyo3-log = { version = "0.10.0", default-features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default-features = false, optional = true }
|
||||
tabled = { version = "0.12.0", optional = true }
|
||||
@@ -185,7 +169,7 @@ harness = false
|
||||
|
||||
|
||||
[[bench]]
|
||||
name = "relu"
|
||||
name = "sigmoid"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
@@ -193,12 +177,12 @@ name = "relu_lookupless"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "accum_matmul_relu"
|
||||
name = "accum_matmul_sigmoid"
|
||||
harness = false
|
||||
|
||||
|
||||
[[bench]]
|
||||
name = "accum_matmul_relu_overflow"
|
||||
name = "accum_matmul_sigmoid_overflow"
|
||||
harness = false
|
||||
|
||||
[[bin]]
|
||||
@@ -213,13 +197,7 @@ required-features = ["ios-bindings", "uuid", "camino", "uniffi_bindgen"]
|
||||
|
||||
[features]
|
||||
web = ["wasm-bindgen-rayon"]
|
||||
default = [
|
||||
"ezkl",
|
||||
"mv-lookup",
|
||||
"precompute-coset",
|
||||
"no-banner",
|
||||
"parallel-poly-read",
|
||||
]
|
||||
default = ["ezkl", "mv-lookup", "precompute-coset", "no-banner", "parallel-poly-read"]
|
||||
onnx = ["dep:tract-onnx"]
|
||||
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
|
||||
ios-bindings = ["mv-lookup", "precompute-coset", "parallel-poly-read", "uniffi"]
|
||||
@@ -253,10 +231,7 @@ ezkl = [
|
||||
"dep:clap",
|
||||
"dep:tosubcommand",
|
||||
]
|
||||
parallel-poly-read = [
|
||||
"halo2_proofs/circuit-params",
|
||||
"halo2_proofs/parallel-poly-read",
|
||||
]
|
||||
parallel-poly-read = ["halo2_proofs/circuit-params", "halo2_proofs/parallel-poly-read"]
|
||||
mv-lookup = [
|
||||
"halo2_proofs/mv-lookup",
|
||||
"snark-verifier/mv-lookup",
|
||||
@@ -285,3 +260,4 @@ rustflags = ["-C", "relocation-model=pic"]
|
||||
lto = "fat"
|
||||
codegen-units = 1
|
||||
# panic = "abort"
|
||||
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
[
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "owner",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"name": "OwnableInvalidOwner",
|
||||
"type": "error"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "account",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"name": "OwnableUnauthorizedAccount",
|
||||
"type": "error"
|
||||
},
|
||||
{
|
||||
"anonymous": false,
|
||||
"inputs": [
|
||||
{
|
||||
"indexed": false,
|
||||
"internalType": "address",
|
||||
"name": "addr",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"name": "DeployedVerifier",
|
||||
"type": "event"
|
||||
},
|
||||
{
|
||||
"anonymous": false,
|
||||
"inputs": [
|
||||
{
|
||||
"indexed": true,
|
||||
"internalType": "address",
|
||||
"name": "previousOwner",
|
||||
"type": "address"
|
||||
},
|
||||
{
|
||||
"indexed": true,
|
||||
"internalType": "address",
|
||||
"name": "newOwner",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"name": "OwnershipTransferred",
|
||||
"type": "event"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "bytecode",
|
||||
"type": "bytes"
|
||||
}
|
||||
],
|
||||
"name": "deployVerifier",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "addr",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "owner",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "bytecode",
|
||||
"type": "bytes"
|
||||
}
|
||||
],
|
||||
"name": "precomputeAddress",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "renounceOwnership",
|
||||
"outputs": [],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "newOwner",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"name": "transferOwnership",
|
||||
"outputs": [],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"name": "verifierAddresses",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "bool",
|
||||
"name": "",
|
||||
"type": "bool"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
@@ -64,7 +64,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
&a,
|
||||
BITS,
|
||||
K,
|
||||
&LookupOp::LeakyReLU { slope: 0.0.into() },
|
||||
&LookupOp::Sigmoid { scale: 1.0.into() },
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -93,7 +93,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
@@ -65,7 +65,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
&a,
|
||||
BITS,
|
||||
k,
|
||||
&LookupOp::LeakyReLU { slope: 0.0.into() },
|
||||
&LookupOp::Sigmoid { scale: 1.0.into() },
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -94,7 +94,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
@@ -68,7 +68,14 @@ impl Circuit<Fr> for NLCircuit {
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
|
||||
@@ -42,7 +42,7 @@ impl Circuit<Fr> for NLCircuit {
|
||||
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let nl = LookupOp::LeakyReLU { slope: 0.0.into() };
|
||||
let nl = LookupOp::Sigmoid { scale: 1.0.into() };
|
||||
|
||||
let mut config = Config::default();
|
||||
|
||||
@@ -68,7 +68,7 @@ impl Circuit<Fr> for NLCircuit {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
@@ -1,184 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
pragma solidity 0.8.20;
|
||||
|
||||
// lib/openzeppelin-contracts/contracts/utils/Context.sol
|
||||
|
||||
// OpenZeppelin Contracts (last updated v5.0.1) (utils/Context.sol)
|
||||
|
||||
/**
|
||||
* @dev Provides information about the current execution context, including the
|
||||
* sender of the transaction and its data. While these are generally available
|
||||
* via msg.sender and msg.data, they should not be accessed in such a direct
|
||||
* manner, since when dealing with meta-transactions the account sending and
|
||||
* paying for execution may not be the actual sender (as far as an application
|
||||
* is concerned).
|
||||
*
|
||||
* This contract is only required for intermediate, library-like contracts.
|
||||
*/
|
||||
abstract contract Context {
|
||||
function _msgSender() internal view virtual returns (address) {
|
||||
return msg.sender;
|
||||
}
|
||||
|
||||
function _msgData() internal view virtual returns (bytes calldata) {
|
||||
return msg.data;
|
||||
}
|
||||
|
||||
function _contextSuffixLength() internal view virtual returns (uint256) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// lib/openzeppelin-contracts/contracts/access/Ownable.sol
|
||||
|
||||
// OpenZeppelin Contracts (last updated v5.0.0) (access/Ownable.sol)
|
||||
|
||||
/**
|
||||
* @dev Contract module which provides a basic access control mechanism, where
|
||||
* there is an account (an owner) that can be granted exclusive access to
|
||||
* specific functions.
|
||||
*
|
||||
* The initial owner is set to the address provided by the deployer. This can
|
||||
* later be changed with {transferOwnership}.
|
||||
*
|
||||
* This module is used through inheritance. It will make available the modifier
|
||||
* `onlyOwner`, which can be applied to your functions to restrict their use to
|
||||
* the owner.
|
||||
*/
|
||||
abstract contract Ownable is Context {
|
||||
/// set the owener initialy to be the anvil test account
|
||||
address private _owner = 0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266;
|
||||
|
||||
/**
|
||||
* @dev The caller account is not authorized to perform an operation.
|
||||
*/
|
||||
error OwnableUnauthorizedAccount(address account);
|
||||
|
||||
/**
|
||||
* @dev The owner is not a valid owner account. (eg. `address(0)`)
|
||||
*/
|
||||
error OwnableInvalidOwner(address owner);
|
||||
|
||||
event OwnershipTransferred(
|
||||
address indexed previousOwner,
|
||||
address indexed newOwner
|
||||
);
|
||||
|
||||
/**
|
||||
* @dev Initializes the contract setting the address provided by the deployer as the initial owner.
|
||||
*/
|
||||
constructor() {
|
||||
_transferOwnership(msg.sender);
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Throws if called by any account other than the owner.
|
||||
*/
|
||||
modifier onlyOwner() {
|
||||
_checkOwner();
|
||||
_;
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Returns the address of the current owner.
|
||||
*/
|
||||
function owner() public view virtual returns (address) {
|
||||
return _owner;
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Throws if the sender is not the owner.
|
||||
*/
|
||||
function _checkOwner() internal view virtual {
|
||||
if (owner() != _msgSender()) {
|
||||
revert OwnableUnauthorizedAccount(_msgSender());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Leaves the contract without owner. It will not be possible to call
|
||||
* `onlyOwner` functions. Can only be called by the current owner.
|
||||
*
|
||||
* NOTE: Renouncing ownership will leave the contract without an owner,
|
||||
* thereby disabling any functionality that is only available to the owner.
|
||||
*/
|
||||
function renounceOwnership() public virtual onlyOwner {
|
||||
_transferOwnership(address(0));
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Transfers ownership of the contract to a new account (`newOwner`).
|
||||
* Can only be called by the current owner.
|
||||
*/
|
||||
function transferOwnership(address newOwner) public virtual onlyOwner {
|
||||
if (newOwner == address(0)) {
|
||||
revert OwnableInvalidOwner(address(0));
|
||||
}
|
||||
_transferOwnership(newOwner);
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Transfers ownership of the contract to a new account (`newOwner`).
|
||||
* Internal function without access restriction.
|
||||
*/
|
||||
function _transferOwnership(address newOwner) internal virtual {
|
||||
address oldOwner = _owner;
|
||||
_owner = newOwner;
|
||||
emit OwnershipTransferred(oldOwner, newOwner);
|
||||
}
|
||||
}
|
||||
|
||||
// interface for the reusable verifier.
|
||||
interface Halo2VerifierReusable {
|
||||
function verifyProof(
|
||||
address vkArtifact,
|
||||
bytes calldata proof,
|
||||
uint256[] calldata instances
|
||||
) external returns (bool);
|
||||
}
|
||||
|
||||
// Manages the deployment of all EZKL reusbale verifiers (ezkl version specific), verifiying key artifacts (circuit specific) and
|
||||
// routing proof verifications to the correct VKA and associate reusable verifier.
|
||||
// Helps to prevent the deployment of duplicate verifiers.
|
||||
contract EZKLVerifierManager is Ownable {
|
||||
/// @dev Mapping that checks if a given reusable verifier has been deployed
|
||||
mapping(address => bool) public verifierAddresses;
|
||||
|
||||
event DeployedVerifier(address addr);
|
||||
|
||||
// 1. Compute the address of the verifier to be deployed
|
||||
function precomputeAddress(
|
||||
bytes memory bytecode
|
||||
) public view returns (address) {
|
||||
bytes32 hash = keccak256(
|
||||
abi.encodePacked(
|
||||
bytes1(0xff),
|
||||
address(this),
|
||||
uint(0),
|
||||
keccak256(bytecode)
|
||||
)
|
||||
);
|
||||
|
||||
return address(uint160(uint(hash)));
|
||||
}
|
||||
|
||||
// 2. Deploy the reusable verifier using create2
|
||||
/// @param bytecode The bytecode of the reusable verifier to deploy
|
||||
function deployVerifier(
|
||||
bytes memory bytecode
|
||||
) public returns (address addr) {
|
||||
assembly {
|
||||
addr := create2(
|
||||
0x0, // value, hardcode to 0
|
||||
add(bytecode, 0x20),
|
||||
mload(bytecode),
|
||||
0x0 // salt, hardcode to 0
|
||||
)
|
||||
if iszero(extcodesize(addr)) {
|
||||
revert(0, 0)
|
||||
}
|
||||
}
|
||||
verifierAddresses[addr] = true;
|
||||
emit DeployedVerifier(addr);
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
ezkl==0.0.0
|
||||
ezkl==15.2.0
|
||||
sphinx
|
||||
sphinx-rtd-theme
|
||||
sphinxcontrib-napoleon
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ezkl
|
||||
|
||||
project = 'ezkl'
|
||||
release = '0.0.0'
|
||||
release = '15.2.0'
|
||||
version = release
|
||||
|
||||
|
||||
|
||||
@@ -146,6 +146,8 @@ where
|
||||
let params = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);
|
||||
let output = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
|
||||
|
||||
println!("INPUT COL {:#?}", input);
|
||||
|
||||
let mut layer_config = PolyConfig::configure(
|
||||
@@ -156,15 +158,11 @@ where
|
||||
);
|
||||
|
||||
layer_config
|
||||
.configure_lookup(
|
||||
cs,
|
||||
&input,
|
||||
&output,
|
||||
¶ms,
|
||||
(LOOKUP_MIN, LOOKUP_MAX),
|
||||
K,
|
||||
&LookupOp::LeakyReLU { slope: 0.0.into() },
|
||||
)
|
||||
.configure_range_check(cs, &input, ¶ms, (-1, 1), K)
|
||||
.unwrap();
|
||||
|
||||
layer_config
|
||||
.configure_range_check(cs, &input, ¶ms, (0, 1023), K)
|
||||
.unwrap();
|
||||
|
||||
layer_config
|
||||
@@ -195,6 +193,11 @@ where
|
||||
) -> Result<(), Error> {
|
||||
config.layer_config.layout_tables(&mut layouter).unwrap();
|
||||
|
||||
config
|
||||
.layer_config
|
||||
.layout_range_checks(&mut layouter)
|
||||
.unwrap();
|
||||
|
||||
let x = layouter
|
||||
.assign_region(
|
||||
|| "mlp_4d",
|
||||
@@ -224,7 +227,10 @@ where
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x.unwrap()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -53,6 +53,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
// tells the config layer to add an affine op to the circuit gate
|
||||
|
||||
let _constant = VarTensor::constant_cols(cs, K, LEN, false);
|
||||
|
||||
println!("INPUT COL {:#?}", input);
|
||||
|
||||
let mut layer_config = PolyConfig::<F>::configure(
|
||||
cs,
|
||||
&[input.clone(), params.clone()],
|
||||
@@ -60,17 +64,12 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
CheckMode::SAFE,
|
||||
);
|
||||
|
||||
// sets up a new ReLU table and resuses it for l1 and l3 non linearities
|
||||
layer_config
|
||||
.configure_lookup(
|
||||
cs,
|
||||
&input,
|
||||
&output,
|
||||
¶ms,
|
||||
(LOOKUP_MIN, LOOKUP_MAX),
|
||||
K,
|
||||
&LookupOp::LeakyReLU { slope: 0.0.into() },
|
||||
)
|
||||
.configure_range_check(cs, &input, ¶ms, (-1, 1), K)
|
||||
.unwrap();
|
||||
|
||||
layer_config
|
||||
.configure_range_check(cs, &input, ¶ms, (0, 1023), K)
|
||||
.unwrap();
|
||||
|
||||
// sets up a new ReLU table and resuses it for l1 and l3 non linearities
|
||||
@@ -104,6 +103,11 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
) -> Result<(), Error> {
|
||||
config.layer_config.layout_tables(&mut layouter).unwrap();
|
||||
|
||||
config
|
||||
.layer_config
|
||||
.layout_range_checks(&mut layouter)
|
||||
.unwrap();
|
||||
|
||||
let x = layouter
|
||||
.assign_region(
|
||||
|| "mlp_4d",
|
||||
@@ -144,7 +148,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
scale: 1,
|
||||
slope: 0.0.into(),
|
||||
}),
|
||||
)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
@@ -184,7 +191,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
scale: 1,
|
||||
slope: 0.0.into(),
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
println!("6");
|
||||
|
||||
@@ -21,9 +21,9 @@ def main():
|
||||
torch_model = Circuit()
|
||||
# Input to the model
|
||||
shape = [3, 2, 3]
|
||||
w = 0.1*torch.rand(1, *shape, requires_grad=True)
|
||||
x = 0.1*torch.rand(1, *shape, requires_grad=True)
|
||||
y = 0.1*torch.rand(1, *shape, requires_grad=True)
|
||||
w = 2 * torch.rand(1, *shape, requires_grad=True) - 1
|
||||
x = 2 * torch.rand(1, *shape, requires_grad=True) - 1
|
||||
y = 2 * torch.rand(1, *shape, requires_grad=True) - 1
|
||||
torch_out = torch_model(w, x, y)
|
||||
# Export the model
|
||||
torch.onnx.export(torch_model, # model being run
|
||||
|
||||
@@ -1 +1 @@
|
||||
{"input_shapes": [[3, 2, 3], [3, 2, 3], [3, 2, 3], [3, 2, 3]], "input_data": [[0.0025284828152507544, 0.04976580664515495, 0.025840921327471733, 0.0829394981265068, 0.09595223516225815, 0.08764562010765076, 0.06308566778898239, 0.062386948615312576, 0.08090643584728241, 0.09267748892307281, 0.07428313046693802, 0.08987367898225784, 0.005716216750442982, 0.0666426345705986, 0.012837404385209084, 0.05769496038556099, 0.05761152133345604, 0.08006472885608673], [0.007834953255951405, 0.011380612850189209, 0.08560049533843994, 0.022283583879470825, 0.07879520952701569, 0.04422441124916077, 0.030812596902251244, 0.006081616971641779, 0.011045408435165882, 0.08776585012674332, 0.044985152781009674, 0.015603715553879738, 0.07923348993062973, 0.04872611165046692, 0.0036642670165747404, 0.05142095685005188, 0.0963878259062767, 0.03225792199373245], [0.09952805936336517, 0.002214533044025302, 0.011696457862854004, 0.022422820329666138, 0.04151459410786629, 0.027647346258163452, 0.011919880285859108, 0.006539052817970514, 0.06569185107946396, 0.034328874200582504, 0.0032284557819366455, 0.004105025436729193, 0.022395813837647438, 0.07135921716690063, 0.07882415503263474, 0.09764843434095383, 0.05335796996951103, 0.0525360181927681]], "output_data": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]}
|
||||
{"input_shapes": [[3, 2, 3], [3, 2, 3], [3, 2, 3], [3, 2, 3]], "input_data": [[0.6261028051376343, 0.49872446060180664, -0.04514765739440918, 0.5936200618743896, 0.9271858930587769, 0.6688600778579712, -0.20331168174743652, -0.7016235589981079, 0.025863051414489746, -0.19426143169403076, 0.9827852249145508, 0.4897397756576538, 0.2992602586746216, 0.7011144161224365, 0.9278832674026489, 0.5943725109100342, -0.573331356048584, 0.3675816059112549], [0.7803324460983276, -0.9616303443908691, 0.6070173978805542, -0.028337717056274414, -0.5080242156982422, -0.9280107021331787, 0.6150380373001099, 0.3865993022918701, -0.43668973445892334, 0.17152702808380127, 0.5144252777099609, -0.28881049156188965, 0.8932310342788696, 0.059034109115600586, 0.6865451335906982, 0.009820222854614258, 0.23011493682861328, -0.9492779970169067], [-0.21352827548980713, -0.16015326976776123, -0.38964390754699707, 0.13464701175689697, -0.8814496994018555, 0.5037975311279297, -0.804405927658081, 0.9858957529067993, 0.19567716121673584, 0.9777265787124634, 0.6151977777481079, 0.568595290184021, 0.10584986209869385, -0.8975653648376465, 0.6235959529876709, -0.547879695892334, 0.9289869070053101, 0.7567293643951416]], "output_data": [[1.0, 0.0, -0.0, 1.0, 1.0, 1.0, -0.0, -1.0, 0.0, -0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, -1.0, 0.0], [0.0, -1.0, 0.0, -1.0, -1.0, -1.0, 0.0, 0.0, -1.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0], [-0.0, -0.0, -0.0, 1.0, -0.0, 1.0, -0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -0.0, 1.0, -0.0, 1.0, 1.0]]}
|
||||
@@ -1,10 +1,11 @@
|
||||
pytorch2.0.1:â
|
||||
pytorch2.2.2:ă
|
||||
|
||||
woutput_w/Round"Round
|
||||
|
||||
xoutput_x/Floor"Floor
|
||||
|
||||
youtput_y/Ceil"Ceil torch_jitZ%
|
||||
youtput_y/Ceil"Ceil
|
||||
main_graphZ%
|
||||
w
|
||||
|
||||
|
||||
|
||||
42
examples/onnx/rsqrt/gen.py
Normal file
42
examples/onnx/rsqrt/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
# reciprocal sqrt
|
||||
m = 1 / torch.sqrt(x)
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/rsqrt/input.json
Normal file
1
examples/onnx/rsqrt/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.8590779900550842, 0.4029041528701782, 0.6507361531257629, 0.9782488942146301, 0.37392884492874146, 0.6867020726203918, 0.11407750844955444, 0.362740159034729]]}
|
||||
17
examples/onnx/rsqrt/network.onnx
Normal file
17
examples/onnx/rsqrt/network.onnx
Normal file
@@ -0,0 +1,17 @@
|
||||
pytorch2.2.2:Ź
|
||||
$
|
||||
input/Sqrt_output_0/Sqrt"Sqrt
|
||||
1
|
||||
/Sqrt_output_0output/Reciprocal"
|
||||
Reciprocal
|
||||
main_graphZ!
|
||||
input
|
||||
|
||||
|
||||
batch_size
|
||||
b"
|
||||
output
|
||||
|
||||
|
||||
batch_size
|
||||
B
|
||||
@@ -94,4 +94,7 @@ pub enum CircuitError {
|
||||
#[error("[io] {0}")]
|
||||
/// IO error
|
||||
IoError(#[from] std::io::Error),
|
||||
/// Invalid scale
|
||||
#[error("negative scale for an op that requires positive inputs {0}")]
|
||||
NegativeScale(String),
|
||||
}
|
||||
|
||||
@@ -13,10 +13,21 @@ use serde::{Deserialize, Serialize};
|
||||
/// An enum representing the operations that consist of both lookups and arithmetic operations.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum HybridOp {
|
||||
Ceil {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Floor {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Round {
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
},
|
||||
Recip {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
use_range_check_for_int: bool,
|
||||
},
|
||||
Div {
|
||||
denom: utils::F32,
|
||||
@@ -45,6 +56,8 @@ pub enum HybridOp {
|
||||
ReduceArgMin {
|
||||
dim: usize,
|
||||
},
|
||||
Max,
|
||||
Min,
|
||||
Softmax {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
@@ -79,6 +92,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
| HybridOp::Less { .. }
|
||||
| HybridOp::Equals { .. }
|
||||
| HybridOp::GreaterEqual { .. }
|
||||
| HybridOp::Max
|
||||
| HybridOp::Min
|
||||
| HybridOp::LessEqual { .. } => {
|
||||
vec![0, 1]
|
||||
}
|
||||
@@ -93,13 +108,17 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
HybridOp::Ceil { scale, legs } => format!("CEIL(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs),
|
||||
HybridOp::Max => format!("MAX"),
|
||||
HybridOp::Min => format!("MIN"),
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
use_range_check_for_int,
|
||||
} => format!(
|
||||
"RECIP (input_scale={}, output_scale={}, use_range_check_for_int={})",
|
||||
input_scale, output_scale, use_range_check_for_int
|
||||
"RECIP (input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
),
|
||||
HybridOp::Div {
|
||||
denom,
|
||||
@@ -162,6 +181,17 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
values: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Ok(Some(match self {
|
||||
HybridOp::Ceil { scale, legs } => {
|
||||
layouts::ceil(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Floor { scale, legs } => {
|
||||
layouts::floor(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Round { scale, legs } => {
|
||||
layouts::round(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
HybridOp::Max => layouts::max_comp(config, region, values[..].try_into()?)?,
|
||||
HybridOp::Min => layouts::min_comp(config, region, values[..].try_into()?)?,
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
@@ -179,31 +209,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
use_range_check_for_int,
|
||||
} => {
|
||||
if input_scale.0.fract() == 0.0
|
||||
&& output_scale.0.fract() == 0.0
|
||||
&& *use_range_check_for_int
|
||||
{
|
||||
layouts::recip(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(input_scale.0 as i128),
|
||||
integer_rep_to_felt(output_scale.0 as i128),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
config,
|
||||
region,
|
||||
values.try_into()?,
|
||||
&LookupOp::Recip {
|
||||
input_scale: *input_scale,
|
||||
output_scale: *output_scale,
|
||||
},
|
||||
)?
|
||||
}
|
||||
}
|
||||
} => layouts::recip(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(input_scale.0 as i128),
|
||||
integer_rep_to_felt(output_scale.0 as i128),
|
||||
)?,
|
||||
HybridOp::Div {
|
||||
denom,
|
||||
use_range_check_for_int,
|
||||
|
||||
@@ -4155,6 +4155,110 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
Ok(assigned_argmin)
|
||||
}
|
||||
|
||||
/// Max layout
|
||||
/// # Arguments
|
||||
/// * `config` - BaseConfig
|
||||
/// * `region` - RegionCtx
|
||||
/// * `values` - &[ValTensor<F>; 2]
|
||||
/// # Returns
|
||||
/// * ValTensor<F>
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::max_comp;
|
||||
/// use ezkl::tensor::val::ValTensor;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 2, 3, 0]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
/// let y = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 1, 1, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
///
|
||||
/// let result = max_comp::<Fp>(&dummy_config, &mut dummy_region, &[x, y]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[5, 2, 3, 1]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
///
|
||||
pub fn max_comp<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let is_greater = greater(config, region, values)?;
|
||||
let is_less = not(config, region, &[is_greater.clone()])?;
|
||||
|
||||
let max_val_p1 = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[values[0].clone(), is_greater],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
let max_val_p2 = pairwise(config, region, &[values[1].clone(), is_less], BaseOp::Mult)?;
|
||||
|
||||
pairwise(config, region, &[max_val_p1, max_val_p2], BaseOp::Add)
|
||||
}
|
||||
|
||||
/// Min comp layout
|
||||
/// # Arguments
|
||||
/// * `config` - BaseConfig
|
||||
/// * `region` - RegionCtx
|
||||
/// * `values` - &[ValTensor<F>; 2]
|
||||
/// # Returns
|
||||
/// * ValTensor<F>
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::min_comp;
|
||||
/// use ezkl::tensor::val::ValTensor;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 2, 3, 0]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
/// let y = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[5, 1, 1, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
/// let result = min_comp::<Fp>(&dummy_config, &mut dummy_region, &[x, y]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[5, 1, 1, 0]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
pub fn min_comp<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let is_greater = greater(config, region, values)?;
|
||||
let is_less = not(config, region, &[is_greater.clone()])?;
|
||||
|
||||
let min_val_p1 = pairwise(config, region, &[values[0].clone(), is_less], BaseOp::Mult)?;
|
||||
|
||||
let min_val_p2 = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[values[1].clone(), is_greater],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
pairwise(config, region, &[min_val_p1, min_val_p2], BaseOp::Add)
|
||||
}
|
||||
|
||||
/// max layout
|
||||
pub(crate) fn max<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
@@ -4178,6 +4282,438 @@ pub(crate) fn min<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
|
||||
/// floor layout
|
||||
/// # Arguments
|
||||
/// * `config` - BaseConfig
|
||||
/// * `region` - RegionCtx
|
||||
/// * `values` - &[ValTensor<F>; 1]
|
||||
/// * `scale` - utils::F32
|
||||
/// * `legs` - usize
|
||||
/// # Returns
|
||||
/// * ValTensor<F>
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::floor;
|
||||
/// use ezkl::tensor::val::ValTensor;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, -2, -3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
/// let result = floor::<Fp>(&dummy_config, &mut dummy_region, &[x], 2.0.into(), 2).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, -2, -4, 0]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
pub fn floor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// decompose with base scale and then set the last element to zero
|
||||
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
|
||||
// set the last element to zero and then recompose, we don't actually need to assign here
|
||||
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
|
||||
let zero = ValType::Constant(F::ZERO);
|
||||
|
||||
let negative_one = create_constant_tensor(integer_rep_to_felt(-1), 1);
|
||||
let assigned_negative_one = region.assign(&config.custom_gates.inputs[1], &negative_one)?;
|
||||
|
||||
region.increment(1);
|
||||
|
||||
let dims = decomposition.dims().to_vec();
|
||||
let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec();
|
||||
|
||||
let mut incremented_tensor = Tensor::new(None, &first_dims)?;
|
||||
|
||||
let cartesian_coord = first_dims
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let inner_loop_function =
|
||||
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let mut sliced_input = decomposition.get_slice(&slice)?;
|
||||
sliced_input.flatten();
|
||||
let last_elem = sliced_input.last()?;
|
||||
|
||||
let last_elem_is_zero = equals_zero(config, region, &[last_elem.clone()])?;
|
||||
let last_elem_is_not_zero = not(config, region, &[last_elem_is_zero.clone()])?;
|
||||
|
||||
let sign = sliced_input.first()?;
|
||||
let is_negative = equals(config, region, &[sign, assigned_negative_one.clone()])?;
|
||||
|
||||
let is_negative_and_not_zero = and(
|
||||
config,
|
||||
region,
|
||||
&[last_elem_is_not_zero.clone(), is_negative.clone()],
|
||||
)?;
|
||||
|
||||
// increment the penultimate element
|
||||
let incremented_elem = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?,
|
||||
is_negative_and_not_zero.clone(),
|
||||
],
|
||||
BaseOp::Add,
|
||||
)?;
|
||||
|
||||
let mut inner_tensor = sliced_input.get_inner_tensor()?.clone();
|
||||
inner_tensor[sliced_input.len() - 2] =
|
||||
incremented_elem.get_inner_tensor()?.clone()[0].clone();
|
||||
|
||||
// set the last elem to zero
|
||||
inner_tensor[sliced_input.len() - 1] = zero.clone();
|
||||
|
||||
Ok(inner_tensor.clone())
|
||||
};
|
||||
|
||||
region.apply_in_loop(&mut incremented_tensor, inner_loop_function)?;
|
||||
|
||||
let mut incremented_tensor = incremented_tensor.combine()?;
|
||||
incremented_tensor.reshape(&dims)?;
|
||||
|
||||
recompose(
|
||||
config,
|
||||
region,
|
||||
&[incremented_tensor.into()],
|
||||
&(scale.0 as usize),
|
||||
)
|
||||
}
|
||||
|
||||
/// ceil layout
|
||||
/// # Arguments
|
||||
/// * `config` - BaseConfig
|
||||
/// * `region` - RegionCtx
|
||||
/// * `values` - &[ValTensor<F>; 1]
|
||||
/// * `scale` - utils::F32
|
||||
/// * `legs` - usize
|
||||
/// # Returns
|
||||
/// * ValTensor<F>
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::ceil;
|
||||
/// use ezkl::tensor::val::ValTensor;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, -2, 3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
/// let result = ceil::<Fp>(&dummy_config, &mut dummy_region, &[x], 2.0.into(), 2).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, -2, 4, 2]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
///
|
||||
pub fn ceil<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// decompose with base scale and then set the last element to zero
|
||||
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
|
||||
// set the last element to zero and then recompose, we don't actually need to assign here
|
||||
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
|
||||
let zero = ValType::Constant(F::ZERO);
|
||||
|
||||
let one = create_constant_tensor(integer_rep_to_felt(1), 1);
|
||||
let assigned_one = region.assign(&config.custom_gates.inputs[1], &one)?;
|
||||
|
||||
region.increment(1);
|
||||
|
||||
let dims = decomposition.dims().to_vec();
|
||||
let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec();
|
||||
|
||||
let mut incremented_tensor = Tensor::new(None, &first_dims)?;
|
||||
|
||||
let cartesian_coord = first_dims
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let inner_loop_function =
|
||||
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let mut sliced_input = decomposition.get_slice(&slice)?;
|
||||
sliced_input.flatten();
|
||||
let last_elem = sliced_input.last()?;
|
||||
|
||||
let last_elem_is_zero = equals_zero(config, region, &[last_elem.clone()])?;
|
||||
let last_elem_is_not_zero = not(config, region, &[last_elem_is_zero.clone()])?;
|
||||
|
||||
let sign = sliced_input.first()?;
|
||||
let is_positive = equals(config, region, &[sign, assigned_one.clone()])?;
|
||||
|
||||
let is_positive_and_not_zero = and(
|
||||
config,
|
||||
region,
|
||||
&[last_elem_is_not_zero.clone(), is_positive.clone()],
|
||||
)?;
|
||||
|
||||
// increment the penultimate element
|
||||
let incremented_elem = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?,
|
||||
is_positive_and_not_zero.clone(),
|
||||
],
|
||||
BaseOp::Add,
|
||||
)?;
|
||||
|
||||
let mut inner_tensor = sliced_input.get_inner_tensor()?.clone();
|
||||
inner_tensor[sliced_input.len() - 2] =
|
||||
incremented_elem.get_inner_tensor()?.clone()[0].clone();
|
||||
|
||||
// set the last elem to zero
|
||||
inner_tensor[sliced_input.len() - 1] = zero.clone();
|
||||
|
||||
Ok(inner_tensor.clone())
|
||||
};
|
||||
|
||||
region.apply_in_loop(&mut incremented_tensor, inner_loop_function)?;
|
||||
|
||||
let mut incremented_tensor = incremented_tensor.combine()?;
|
||||
incremented_tensor.reshape(&dims)?;
|
||||
|
||||
recompose(
|
||||
config,
|
||||
region,
|
||||
&[incremented_tensor.into()],
|
||||
&(scale.0 as usize),
|
||||
)
|
||||
}
|
||||
|
||||
/// round layout
|
||||
/// # Arguments
|
||||
/// * `config` - BaseConfig
|
||||
/// * `region` - RegionCtx
|
||||
/// * `values` - &[ValTensor<F>; 1]
|
||||
/// * `scale` - utils::F32
|
||||
/// * `legs` - usize
|
||||
/// # Returns
|
||||
/// * ValTensor<F>
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::circuit::ops::layouts::round;
|
||||
/// use ezkl::tensor::val::ValTensor;
|
||||
/// use halo2curves::bn256::Fr as Fp;
|
||||
/// use ezkl::circuit::region::RegionCtx;
|
||||
/// use ezkl::circuit::region::RegionSettings;
|
||||
/// use ezkl::circuit::BaseConfig;
|
||||
/// let dummy_config = BaseConfig::dummy(12, 2);
|
||||
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
|
||||
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, -2, 3, 1]),
|
||||
/// &[1, 1, 2, 2],
|
||||
/// ).unwrap());
|
||||
/// let result = round::<Fp>(&dummy_config, &mut dummy_region, &[x], 4.0.into(), 2).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, -4, 4, 0]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(result.int_evals().unwrap(), expected);
|
||||
/// ```
|
||||
///
|
||||
pub fn round<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
scale: utils::F32,
|
||||
legs: usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
// decompose with base scale and then set the last element to zero
|
||||
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
|
||||
// set the last element to zero and then recompose, we don't actually need to assign here
|
||||
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
|
||||
let zero = ValType::Constant(F::ZERO);
|
||||
|
||||
let one = create_constant_tensor(integer_rep_to_felt(1), 1);
|
||||
let assigned_one = region.assign(&config.custom_gates.inputs[1], &one)?;
|
||||
let negative_one = create_constant_tensor(integer_rep_to_felt(-1), 1);
|
||||
let assigned_negative_one = region.assign(&config.custom_gates.output, &negative_one)?;
|
||||
|
||||
region.increment(1);
|
||||
|
||||
// if scale is not exactly divisible by 2 we warn
|
||||
if scale.0 % 2.0 != 0.0 {
|
||||
log::warn!("Scale is not exactly divisible by 2.0, rounding may not be accurate");
|
||||
}
|
||||
|
||||
let midway_point: ValTensor<F> = create_constant_tensor(
|
||||
integer_rep_to_felt((scale.0 / 2.0).round() as IntegerRep),
|
||||
1,
|
||||
);
|
||||
let assigned_midway_point = region.assign(&config.custom_gates.inputs[1], &midway_point)?;
|
||||
|
||||
let dims = decomposition.dims().to_vec();
|
||||
let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec();
|
||||
|
||||
let mut incremented_tensor = Tensor::new(None, &first_dims)?;
|
||||
|
||||
let cartesian_coord = first_dims
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let inner_loop_function =
|
||||
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let mut sliced_input = decomposition.get_slice(&slice)?;
|
||||
sliced_input.flatten();
|
||||
let last_elem = sliced_input.last()?;
|
||||
|
||||
let sign = sliced_input.first()?;
|
||||
let is_positive = equals(config, region, &[sign.clone(), assigned_one.clone()])?;
|
||||
let is_negative = equals(config, region, &[sign, assigned_negative_one.clone()])?;
|
||||
|
||||
let is_greater_than_midway = greater_equal(
|
||||
config,
|
||||
region,
|
||||
&[last_elem.clone(), assigned_midway_point.clone()],
|
||||
)?;
|
||||
|
||||
// if greater than midway point and positive, increment
|
||||
let is_positive_and_more_than_midway = and(
|
||||
config,
|
||||
region,
|
||||
&[is_positive.clone(), is_greater_than_midway.clone()],
|
||||
)?;
|
||||
|
||||
// is less than midway point and negative, decrement
|
||||
let is_negative_and_more_than_midway = and(
|
||||
config,
|
||||
region,
|
||||
&[is_negative.clone(), is_greater_than_midway],
|
||||
)?;
|
||||
|
||||
let conditions_for_increment = or(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
is_positive_and_more_than_midway.clone(),
|
||||
is_negative_and_more_than_midway.clone(),
|
||||
],
|
||||
)?;
|
||||
|
||||
// increment the penultimate element
|
||||
let incremented_elem = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[
|
||||
sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?,
|
||||
conditions_for_increment.clone(),
|
||||
],
|
||||
BaseOp::Add,
|
||||
)?;
|
||||
|
||||
let mut inner_tensor = sliced_input.get_inner_tensor()?.clone();
|
||||
inner_tensor[sliced_input.len() - 2] =
|
||||
incremented_elem.get_inner_tensor()?.clone()[0].clone();
|
||||
|
||||
// set the last elem to zero
|
||||
inner_tensor[sliced_input.len() - 1] = zero.clone();
|
||||
|
||||
Ok(inner_tensor.clone())
|
||||
};
|
||||
|
||||
region.apply_in_loop(&mut incremented_tensor, inner_loop_function)?;
|
||||
|
||||
let mut incremented_tensor = incremented_tensor.combine()?;
|
||||
incremented_tensor.reshape(&dims)?;
|
||||
|
||||
recompose(
|
||||
config,
|
||||
region,
|
||||
&[incremented_tensor.into()],
|
||||
&(scale.0 as usize),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn recompose<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
base: &usize,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let input = values[0].clone();
|
||||
|
||||
let first_dims = input.dims().to_vec()[..input.dims().len() - 1].to_vec();
|
||||
let n = input.dims().last().unwrap() - 1;
|
||||
|
||||
let is_assigned = !input.all_prev_assigned();
|
||||
|
||||
let bases: ValTensor<F> = Tensor::from(
|
||||
(0..n)
|
||||
.rev()
|
||||
.map(|x| ValType::Constant(integer_rep_to_felt(base.pow(x as u32) as IntegerRep))),
|
||||
)
|
||||
.into();
|
||||
|
||||
// multiply and sum the values
|
||||
let mut output: Tensor<Tensor<ValType<F>>> = Tensor::new(None, &first_dims)?;
|
||||
|
||||
let cartesian_coord = first_dims
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let inner_loop_function =
|
||||
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
|
||||
let coord = cartesian_coord[i].clone();
|
||||
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let mut sliced_input = input.get_slice(&slice)?;
|
||||
sliced_input.flatten();
|
||||
|
||||
if !is_assigned {
|
||||
sliced_input = region.assign(&config.custom_gates.inputs[0], &sliced_input)?;
|
||||
}
|
||||
|
||||
// get the sign bit and make sure it is valid
|
||||
let sign = sliced_input.first()?;
|
||||
let rest = sliced_input.get_slice(&[1..sliced_input.len()])?;
|
||||
|
||||
let prod_decomp = dot(config, region, &[rest, bases.clone()])?;
|
||||
|
||||
let signed_decomp = pairwise(config, region, &[prod_decomp, sign], BaseOp::Mult)?;
|
||||
|
||||
Ok(signed_decomp.get_inner_tensor()?.clone())
|
||||
};
|
||||
|
||||
region.apply_in_loop(&mut output, inner_loop_function)?;
|
||||
|
||||
let mut combined_output = output.combine()?;
|
||||
|
||||
combined_output.reshape(&first_dims)?;
|
||||
|
||||
Ok(combined_output.into())
|
||||
}
|
||||
|
||||
pub(crate) fn decompose<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
@@ -4263,7 +4799,6 @@ pub(crate) fn sign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let mut decomp = decompose(config, region, values, ®ion.base(), ®ion.legs())?;
|
||||
// get every n elements now, which correspond to the sign bit
|
||||
|
||||
decomp.get_every_n(region.legs() + 1)?;
|
||||
decomp.reshape(values[0].dims())?;
|
||||
|
||||
@@ -4280,10 +4815,12 @@ pub(crate) fn abs<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
pairwise(config, region, &[values[0].clone(), sign], BaseOp::Mult)
|
||||
}
|
||||
|
||||
pub(crate) fn relu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
pub(crate) fn leaky_relu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
alpha: &utils::F32,
|
||||
input_scale: &i32,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
let sign = sign(config, region, values)?;
|
||||
|
||||
@@ -4292,12 +4829,45 @@ pub(crate) fn relu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
|
||||
let relu_mask = equals(config, region, &[sign, unit])?;
|
||||
|
||||
pairwise(
|
||||
let positive = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[values[0].clone(), relu_mask],
|
||||
&[values[0].clone(), relu_mask.clone()],
|
||||
BaseOp::Mult,
|
||||
)
|
||||
)?;
|
||||
|
||||
if alpha.0 == 0. {
|
||||
return Ok(positive);
|
||||
}
|
||||
|
||||
if input_scale < &0 {
|
||||
return Err(CircuitError::NegativeScale("leaky_relu".to_string()));
|
||||
}
|
||||
|
||||
let scale_constant = create_constant_tensor(F::from(2_i32.pow(*input_scale as u32) as u64), 1);
|
||||
|
||||
let rescaled_positive = pairwise(config, region, &[positive, scale_constant], BaseOp::Mult)?;
|
||||
|
||||
let neg_mask = not(config, region, &[relu_mask])?;
|
||||
|
||||
let quantized_alpha = quantize_tensor(
|
||||
Tensor::from([alpha.0; 1].into_iter()),
|
||||
*input_scale,
|
||||
&crate::graph::Visibility::Fixed,
|
||||
)?;
|
||||
|
||||
let alpha_tensor = create_constant_tensor(quantized_alpha[0], 1);
|
||||
|
||||
let scaled_neg_mask = pairwise(config, region, &[neg_mask, alpha_tensor], BaseOp::Mult)?;
|
||||
|
||||
let neg_part = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[values[0].clone(), scaled_neg_mask],
|
||||
BaseOp::Mult,
|
||||
)?;
|
||||
|
||||
pairwise(config, region, &[rescaled_positive, neg_part], BaseOp::Add)
|
||||
}
|
||||
|
||||
fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
|
||||
@@ -15,101 +15,29 @@ use halo2curves::ff::PrimeField;
|
||||
/// An enum representing the operations that can be used to express more complex operations via accumulation
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
|
||||
pub enum LookupOp {
|
||||
Div {
|
||||
denom: utils::F32,
|
||||
},
|
||||
Cast {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Max {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
Min {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
Ceil {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Floor {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Round {
|
||||
scale: utils::F32,
|
||||
},
|
||||
RoundHalfToEven {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sqrt {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Rsqrt {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Recip {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
},
|
||||
LeakyReLU {
|
||||
slope: utils::F32,
|
||||
},
|
||||
Sigmoid {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Ln {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Exp {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Cos {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ACos {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Cosh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ACosh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sin {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ASin {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Sinh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ASinh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Tan {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ATan {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Tanh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
ATanh {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Erf {
|
||||
scale: utils::F32,
|
||||
},
|
||||
KroneckerDelta,
|
||||
Pow {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
HardSwish {
|
||||
scale: utils::F32,
|
||||
},
|
||||
Div { denom: utils::F32 },
|
||||
Cast { scale: utils::F32 },
|
||||
RoundHalfToEven { scale: utils::F32 },
|
||||
Sqrt { scale: utils::F32 },
|
||||
Rsqrt { scale: utils::F32 },
|
||||
Sigmoid { scale: utils::F32 },
|
||||
Ln { scale: utils::F32 },
|
||||
Exp { scale: utils::F32 },
|
||||
Cos { scale: utils::F32 },
|
||||
ACos { scale: utils::F32 },
|
||||
Cosh { scale: utils::F32 },
|
||||
ACosh { scale: utils::F32 },
|
||||
Sin { scale: utils::F32 },
|
||||
ASin { scale: utils::F32 },
|
||||
Sinh { scale: utils::F32 },
|
||||
ASinh { scale: utils::F32 },
|
||||
Tan { scale: utils::F32 },
|
||||
ATan { scale: utils::F32 },
|
||||
Tanh { scale: utils::F32 },
|
||||
ATanh { scale: utils::F32 },
|
||||
Erf { scale: utils::F32 },
|
||||
Pow { scale: utils::F32, a: utils::F32 },
|
||||
HardSwish { scale: utils::F32 },
|
||||
}
|
||||
|
||||
impl LookupOp {
|
||||
@@ -123,21 +51,10 @@ impl LookupOp {
|
||||
/// as path
|
||||
pub fn as_path(&self) -> String {
|
||||
match self {
|
||||
LookupOp::Ceil { scale } => format!("ceil_{}", scale),
|
||||
LookupOp::Floor { scale } => format!("floor_{}", scale),
|
||||
LookupOp::Round { scale } => format!("round_{}", scale),
|
||||
LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
|
||||
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
|
||||
LookupOp::KroneckerDelta => "kronecker_delta".into(),
|
||||
LookupOp::Max { scale, a } => format!("max_{}_{}", scale, a),
|
||||
LookupOp::Min { scale, a } => format!("min_{}_{}", scale, a),
|
||||
LookupOp::Div { denom } => format!("div_{}", denom),
|
||||
LookupOp::Cast { scale } => format!("cast_{}", scale),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => format!("recip_{}_{}", input_scale, output_scale),
|
||||
LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a),
|
||||
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
|
||||
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
|
||||
LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
|
||||
@@ -168,47 +85,18 @@ impl LookupOp {
|
||||
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
|
||||
let res =
|
||||
match &self {
|
||||
LookupOp::Ceil { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::ceil(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Floor { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::floor(&x, scale.into()))
|
||||
}
|
||||
LookupOp::Round { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::round(&x, scale.into()))
|
||||
}
|
||||
LookupOp::RoundHalfToEven { scale } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()),
|
||||
),
|
||||
LookupOp::Pow { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::pow(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::KroneckerDelta => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::kronecker_delta(&x))
|
||||
}
|
||||
LookupOp::Max { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::max(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::Min { scale, a } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::min(&x, scale.0.into(), a.0.into()),
|
||||
),
|
||||
LookupOp::Div { denom } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
|
||||
),
|
||||
LookupOp::Cast { scale } => Ok::<_, TensorError>(
|
||||
tensor::ops::nonlinearities::const_div(&x, f32::from(*scale).into()),
|
||||
),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => Ok::<_, TensorError>(tensor::ops::nonlinearities::recip(
|
||||
&x,
|
||||
input_scale.into(),
|
||||
output_scale.into(),
|
||||
)),
|
||||
LookupOp::LeakyReLU { slope: a } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into()))
|
||||
}
|
||||
LookupOp::Sigmoid { scale } => {
|
||||
Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
|
||||
}
|
||||
@@ -283,25 +171,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
/// Returns the name of the operation
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
LookupOp::Ceil { scale } => format!("CEIL(scale={})", scale),
|
||||
LookupOp::Floor { scale } => format!("FLOOR(scale={})", scale),
|
||||
LookupOp::Round { scale } => format!("ROUND(scale={})", scale),
|
||||
LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale),
|
||||
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
|
||||
LookupOp::KroneckerDelta => "K_DELTA".into(),
|
||||
LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
|
||||
LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
|
||||
LookupOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
} => format!(
|
||||
"RECIP(input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
),
|
||||
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
|
||||
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
|
||||
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
LookupOp::LeakyReLU { slope: a } => format!("L_RELU(slope={})", a),
|
||||
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
|
||||
LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
|
||||
LookupOp::Erf { scale } => format!("ERF(scale={})", scale),
|
||||
@@ -344,8 +218,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
let in_scale = inputs_scale[0];
|
||||
in_scale + multiplier_to_scale(1. / scale.0 as f64)
|
||||
}
|
||||
LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()),
|
||||
LookupOp::KroneckerDelta => 0,
|
||||
_ => inputs_scale[0],
|
||||
};
|
||||
Ok(scale)
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use crate::{
|
||||
circuit::layouts,
|
||||
circuit::{
|
||||
layouts,
|
||||
utils::{self, F32},
|
||||
},
|
||||
tensor::{self, Tensor, TensorError},
|
||||
};
|
||||
|
||||
@@ -9,9 +12,12 @@ use super::{base::BaseOp, *};
|
||||
/// An enum representing the operations that can be expressed as arithmetic (non lookup) operations.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum PolyOp {
|
||||
ReLU,
|
||||
Abs,
|
||||
Sign,
|
||||
LeakyReLU {
|
||||
slope: utils::F32,
|
||||
scale: i32,
|
||||
},
|
||||
GatherElements {
|
||||
dim: usize,
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
@@ -112,9 +118,9 @@ impl<
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match &self {
|
||||
PolyOp::LeakyReLU { slope: a, .. } => format!("LEAKYRELU (slope={})", a),
|
||||
PolyOp::Abs => "ABS".to_string(),
|
||||
PolyOp::Sign => "SIGN".to_string(),
|
||||
PolyOp::ReLU => "RELU".to_string(),
|
||||
PolyOp::GatherElements { dim, constant_idx } => format!(
|
||||
"GATHERELEMENTS (dim={}, constant_idx{})",
|
||||
dim,
|
||||
@@ -198,7 +204,9 @@ impl<
|
||||
Ok(Some(match self {
|
||||
PolyOp::Abs => layouts::abs(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?)?,
|
||||
PolyOp::ReLU => layouts::relu(config, region, values[..].try_into()?)?,
|
||||
PolyOp::LeakyReLU { slope, scale } => {
|
||||
layouts::leaky_relu(config, region, values[..].try_into()?, slope, scale)?
|
||||
}
|
||||
PolyOp::MultiBroadcastTo { shape } => {
|
||||
layouts::expand(config, region, values[..].try_into()?, shape)?
|
||||
}
|
||||
@@ -329,6 +337,12 @@ impl<
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let scale = match self {
|
||||
// this corresponds to the relu operation
|
||||
PolyOp::LeakyReLU {
|
||||
slope: F32(0.0), ..
|
||||
} => in_scales[0],
|
||||
// this corresponds to the leaky relu operation with a slope which induces a change in scale
|
||||
PolyOp::LeakyReLU { scale, .. } => in_scales[0] + *scale,
|
||||
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
|
||||
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
|
||||
PolyOp::Iff => in_scales[1],
|
||||
|
||||
@@ -1379,7 +1379,10 @@ mod conv_relu_col_ultra_overflow {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap().unwrap()],
|
||||
Box::new(PolyOp::ReLU),
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
@@ -2347,7 +2350,14 @@ mod matmul_relu {
|
||||
.unwrap();
|
||||
let _output = config
|
||||
.base_config
|
||||
.layout(&mut region, &[output.unwrap()], Box::new(PolyOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
@@ -2439,7 +2449,14 @@ mod relu {
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 2, 2);
|
||||
Ok(config
|
||||
.layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU))
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap())
|
||||
},
|
||||
)
|
||||
@@ -2482,11 +2499,11 @@ mod lookup_ultra_overflow {
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ReLUCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
struct SigmoidCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub input: ValTensor<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for ReLUCircuit<F> {
|
||||
impl Circuit<F> for SigmoidCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
@@ -2500,7 +2517,7 @@ mod lookup_ultra_overflow {
|
||||
.map(|_| VarTensor::new_advice(cs, 4, 1, 3))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let nl = LookupOp::LeakyReLU { slope: 0.0.into() };
|
||||
let nl = LookupOp::Sigmoid { scale: 1.0.into() };
|
||||
|
||||
let mut config = BaseConfig::default();
|
||||
|
||||
@@ -2533,7 +2550,7 @@ mod lookup_ultra_overflow {
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
},
|
||||
@@ -2546,13 +2563,13 @@ mod lookup_ultra_overflow {
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn relucircuit() {
|
||||
fn sigmoidcircuit() {
|
||||
// get some logs fam
|
||||
crate::logger::init_logger();
|
||||
// parameters
|
||||
let a = Tensor::from((0..4).map(|i| Value::known(F::from(i + 1))));
|
||||
|
||||
let circuit = ReLUCircuit::<F> {
|
||||
let circuit = SigmoidCircuit::<F> {
|
||||
input: ValTensor::from(a),
|
||||
};
|
||||
|
||||
@@ -2562,7 +2579,7 @@ mod lookup_ultra_overflow {
|
||||
|
||||
let pk = crate::pfsys::create_keys::<
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
ReLUCircuit<F>,
|
||||
SigmoidCircuit<F>,
|
||||
>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -95,9 +95,6 @@ pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
|
||||
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";
|
||||
/// Default commitment
|
||||
pub const DEFAULT_COMMITMENT: &str = "kzg";
|
||||
// TODO: In prod this will be the same across all chains we deploy to using the EZKL multisig create2 deployment.
|
||||
/// Default address of the verifier manager.
|
||||
pub const DEFAULT_VERIFIER_MANAGER_ADDRESS: &str = "0xdc64a140aa3e981100a9beca4e685f962f0cf6c9";
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
|
||||
@@ -190,13 +187,11 @@ pub enum ContractType {
|
||||
/// Deploys a verifier contrat tailored to the circuit and not reusable
|
||||
Verifier {
|
||||
/// Whether to deploy a reusable verifier. This can reduce state bloat on-chain since you need only deploy a verifying key artifact (vka) for a given circuit which is significantly smaller than the verifier contract (up to 4 times smaller for large circuits)
|
||||
/// Can also be used as an alternative to aggregation for verifiers that are otherwise too large to fit on-chain.
|
||||
/// Can also be used as an alternative to aggregation for verifiers that are otherwise too large to fit on-chain.
|
||||
reusable: bool,
|
||||
},
|
||||
/// Deploys a verifying key artifact that the reusable verifier loads into memory during runtime. Encodes the circuit specific data that was otherwise hardcoded onto the stack.
|
||||
VerifyingKeyArtifact,
|
||||
/// Manages the deployments of all reusable verifier and verifying artifact keys. Routes all the verification tx to the correct artifacts.
|
||||
VerifierManager
|
||||
}
|
||||
|
||||
impl Default for ContractType {
|
||||
@@ -220,7 +215,6 @@ impl std::fmt::Display for ContractType {
|
||||
reusable: false,
|
||||
} => "verifier".to_string(),
|
||||
ContractType::VerifyingKeyArtifact => "vka".to_string(),
|
||||
ContractType::VerifierManager => "manager".to_string()
|
||||
}
|
||||
)
|
||||
}
|
||||
@@ -238,16 +232,16 @@ impl From<&str> for ContractType {
|
||||
"verifier" => ContractType::Verifier { reusable: false },
|
||||
"verifier/reusable" => ContractType::Verifier { reusable: true },
|
||||
"vka" => ContractType::VerifyingKeyArtifact,
|
||||
"manager" => ContractType::VerifierManager,
|
||||
_ => {
|
||||
log::error!("Invalid value for ContractType");
|
||||
log::warn!("Defaulting to verifier");
|
||||
ContractType::default()
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
|
||||
/// wrapper for H160 to make it easy to parse into flag vals
|
||||
pub struct H160Flag {
|
||||
@@ -514,7 +508,7 @@ pub enum Commands {
|
||||
/// Gets an SRS from a circuit settings file.
|
||||
#[command(name = "get-srs")]
|
||||
GetSrs {
|
||||
/// The path to output the desired srs file, if set to None will save to $EZKL_REPO_PATH/srs
|
||||
/// The path to output the desired srs file, if set to None will save to ~/.ezkl/srs
|
||||
#[arg(long, default_value = None, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// Path to the circuit settings .json file to read in logrows from. Overriden by logrows if specified.
|
||||
@@ -561,7 +555,7 @@ pub enum Commands {
|
||||
/// The path to save the proving key to
|
||||
#[arg(long, default_value = DEFAULT_PK_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
|
||||
pk_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// logrows used for aggregation circuit
|
||||
@@ -588,7 +582,7 @@ pub enum Commands {
|
||||
/// The path to output the proof file to
|
||||
#[arg(long, default_value = DEFAULT_PROOF_AGGREGATED, value_hint = clap::ValueHint::FilePath)]
|
||||
proof_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
#[arg(
|
||||
@@ -630,7 +624,7 @@ pub enum Commands {
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
|
||||
compiled_circuit: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to output the verification key file to
|
||||
@@ -707,7 +701,7 @@ pub enum Commands {
|
||||
/// The path to output the proof file to
|
||||
#[arg(long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
|
||||
proof_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
#[arg(
|
||||
@@ -739,7 +733,7 @@ pub enum Commands {
|
||||
/// Creates an Evm verifier for a single proof
|
||||
#[command(name = "create-evm-verifier")]
|
||||
CreateEvmVerifier {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
@@ -761,7 +755,7 @@ pub enum Commands {
|
||||
/// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier
|
||||
#[command(name = "create-evm-vka")]
|
||||
CreateEvmVKArtifact {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
@@ -804,7 +798,7 @@ pub enum Commands {
|
||||
/// Creates an Evm verifier for an aggregate proof
|
||||
#[command(name = "create-evm-verifier-aggr")]
|
||||
CreateEvmVerifierAggr {
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// The path to load the desired verification key file
|
||||
@@ -837,7 +831,7 @@ pub enum Commands {
|
||||
/// The path to the verification key file (generated using the setup command)
|
||||
#[arg(long, default_value = DEFAULT_VK, value_hint = clap::ValueHint::FilePath)]
|
||||
vk_path: Option<PathBuf>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
|
||||
@@ -855,7 +849,7 @@ pub enum Commands {
|
||||
/// reduced srs
|
||||
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION, action = clap::ArgAction::SetTrue)]
|
||||
reduced_srs: Option<bool>,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
#[arg(long, value_hint = clap::ValueHint::FilePath)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// logrows used for aggregation circuit
|
||||
@@ -882,14 +876,6 @@ pub enum Commands {
|
||||
/// Private secp256K1 key in hex format, 64 chars, no 0x prefix, of the account signing transactions. If None the private key will be generated by Anvil
|
||||
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
|
||||
private_key: Option<String>,
|
||||
/// Deployed verifier manager contract's address
|
||||
/// Used to facilitate reusable verifier and vk artifact deployment
|
||||
#[arg(long, value_hint = clap::ValueHint::Other)]
|
||||
addr_verifier_manager: Option<H160Flag>,
|
||||
/// Deployed reusable verifier contract's address
|
||||
/// Use to facilitate reusable verifier and vk artifact deployment
|
||||
#[arg(long, value_hint = clap::ValueHint::Other)]
|
||||
addr_reusable_verifier: Option<H160Flag>,
|
||||
/// Contract type to be deployed
|
||||
#[arg(long = "contract-type", short = 'C', default_value = DEFAULT_CONTRACT_DEPLOYMENT_TYPE, value_hint = clap::ValueHint::Other)]
|
||||
contract: ContractType,
|
||||
|
||||
105
src/eth.rs
105
src/eth.rs
@@ -31,7 +31,7 @@ use alloy::transports::{RpcError, TransportErrorKind};
|
||||
use foundry_compilers::artifacts::Settings as SolcSettings;
|
||||
use foundry_compilers::error::{SolcError, SolcIoError};
|
||||
use foundry_compilers::Solc;
|
||||
use halo2_solidity_verifier::{encode_calldata, encode_deploy};
|
||||
use halo2_solidity_verifier::encode_calldata;
|
||||
use halo2curves::bn256::{Fr, G1Affine};
|
||||
use halo2curves::group::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
@@ -213,16 +213,6 @@ abigen!(
|
||||
}
|
||||
);
|
||||
|
||||
// The bytecode here was generated from running solc compiler version 0.8.20 with optimization enabled and runs param set to 1.
|
||||
abigen!(
|
||||
#[allow(missing_docs)]
|
||||
#[sol(
|
||||
rpc,
|
||||
bytecode = "60806040525f80546001600160a01b03191673f39fd6e51aad88f6f4ce6ab8827279cfffb92266179055348015610034575f80fd5b5061003e33610043565b610092565b5f80546001600160a01b038381166001600160a01b0319831681178455604051919092169283917f8be0079c531659141344cd1fd0a4f28419497f9722a3daafe3b4186f6b6457e09190a35050565b6103dc8061009f5f395ff3fe608060405234801561000f575f80fd5b5060043610610060575f3560e01c80635717ecef146100645780635d34fd561461009b578063715018a6146100bb5780637a33ac87146100c55780638da5cb5b14610125578063f2fde38b1461012d575b5f80fd5b6100866100723660046102a7565b60016020525f908152604090205460ff1681565b60405190151581526020015b60405180910390f35b6100ae6100a93660046102e8565b610140565b6040516100929190610392565b6100c36101bf565b005b6100ae6100d33660046102e8565b8051602091820120604080516001600160f81b0319818501523060601b6001600160601b03191660218201525f6035820152605580820193909352815180820390930183526075019052805191012090565b6100ae6101d2565b6100c361013b3660046102a7565b6101e0565b5f610149610226565b5f8251602084015ff59050803b61015e575f80fd5b6001600160a01b0381165f90815260016020819052604091829020805460ff19169091179055517f27bf8213352a1c07513a54703c920b9e437940154edead05874c43279acf166c906101b2908390610392565b60405180910390a1919050565b6101c7610226565b6101d05f610258565b565b5f546001600160a01b031690565b6101e8610226565b6001600160a01b03811661021a575f604051631e4fbdf760e01b81526004016102119190610392565b60405180910390fd5b61022381610258565b50565b3361022f6101d2565b6001600160a01b0316146101d0573360405163118cdaa760e01b81526004016102119190610392565b5f80546001600160a01b038381166001600160a01b0319831681178455604051919092169283917f8be0079c531659141344cd1fd0a4f28419497f9722a3daafe3b4186f6b6457e09190a35050565b5f602082840312156102b7575f80fd5b81356001600160a01b03811681146102cd575f80fd5b9392505050565b634e487b7160e01b5f52604160045260245ffd5b5f602082840312156102f8575f80fd5b81356001600160401b038082111561030e575f80fd5b818401915084601f830112610321575f80fd5b813581811115610333576103336102d4565b604051601f8201601f19908116603f0116810190838211818310171561035b5761035b6102d4565b81604052828152876020848701011115610373575f80fd5b826020860160208301375f928101602001929092525095945050505050565b6001600160a01b039190911681526020019056fea26469706673582212201d85104628b308554b775f612650220008f8e318f66dc4ace466d82d70bae4e264736f6c63430008140033"
|
||||
)]
|
||||
EZKLVerifierManager,
|
||||
"./abis/EZKLVerifierManager.json"
|
||||
);
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum EthError {
|
||||
#[error("a transport error occurred: {0}")]
|
||||
@@ -362,99 +352,6 @@ pub async fn deploy_contract_via_solidity(
|
||||
Ok(contract)
|
||||
}
|
||||
|
||||
pub async fn deploy_vka(
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<&str>,
|
||||
runs: usize,
|
||||
private_key: Option<&str>,
|
||||
contract_name: &str,
|
||||
verifier_manager: H160,
|
||||
reusable_verifier: H160,
|
||||
) -> Result<H160, EthError> {
|
||||
let (client, _) = setup_eth_backend(rpc_url, private_key).await?;
|
||||
|
||||
// Create an instance of the EZKLVerifierManager contract
|
||||
let verifier_manager_contract = EZKLVerifierManager::new(verifier_manager, client.clone());
|
||||
|
||||
// Get the bytecode of the contract to be deployed
|
||||
let (_, bytecode, _run_time_bytecode) =
|
||||
get_contract_artifacts(sol_code_path.clone(), contract_name, runs).await?;
|
||||
|
||||
// Check if the reusable verifier is already deployed
|
||||
let deployed_verifier: bool = verifier_manager_contract
|
||||
.verifierAddresses(reusable_verifier)
|
||||
.call()
|
||||
.await?
|
||||
._0;
|
||||
|
||||
if deployed_verifier == false {
|
||||
panic!("The reusable verifier for this VKA has not been deployed yet.");
|
||||
}
|
||||
|
||||
let encoded = encode_deploy(&bytecode);
|
||||
|
||||
debug!("encoded: {:#?}", hex::encode(&encoded));
|
||||
|
||||
let input: TransactionInput = encoded.into();
|
||||
|
||||
let tx = TransactionRequest::default()
|
||||
.to(reusable_verifier)
|
||||
.input(input);
|
||||
debug!("transaction {:#?}", tx);
|
||||
|
||||
let result = client.call(&tx).await;
|
||||
|
||||
if let Err(e) = result {
|
||||
return Err(EvmVerificationError::SolidityExecution(e.to_string()).into());
|
||||
}
|
||||
|
||||
// Now send the tx
|
||||
let _ = client.send_transaction(tx).await?;
|
||||
|
||||
let result = result?;
|
||||
debug!("result: {:#?}", result.to_vec());
|
||||
|
||||
let contract = H160::from_slice(&result.to_vec()[12..32]);
|
||||
return Ok(contract);
|
||||
}
|
||||
|
||||
pub async fn deploy_reusable_verifier(
|
||||
sol_code_path: PathBuf,
|
||||
rpc_url: Option<&str>,
|
||||
runs: usize,
|
||||
private_key: Option<&str>,
|
||||
contract_name: &str,
|
||||
verifier_manager: H160,
|
||||
) -> Result<H160, EthError> {
|
||||
let (client, _) = setup_eth_backend(rpc_url, private_key).await?;
|
||||
|
||||
// Create an instance of the EZKLVerifierManager contract
|
||||
let verifier_manager_contract = EZKLVerifierManager::new(verifier_manager, client.clone());
|
||||
|
||||
// Get the bytecode of the contract to be deployed
|
||||
let (_, bytecode, _run_time_bytecode) =
|
||||
get_contract_artifacts(sol_code_path.clone(), contract_name, runs).await?;
|
||||
|
||||
// Deploy the contract using the EZKLVerifierManager
|
||||
let output = verifier_manager_contract
|
||||
.deployVerifier(bytecode.clone().into())
|
||||
.call()
|
||||
.await?;
|
||||
let out = verifier_manager_contract
|
||||
.precomputeAddress(bytecode.clone().into())
|
||||
.call()
|
||||
.await?;
|
||||
// assert that out == output
|
||||
assert_eq!(out._0, output.addr);
|
||||
// Get the deployed contract address from the receipt
|
||||
let contract = output.addr;
|
||||
let _ = verifier_manager_contract
|
||||
.deployVerifier(bytecode.into())
|
||||
.send()
|
||||
.await?;
|
||||
return Ok(contract);
|
||||
}
|
||||
|
||||
///
|
||||
pub async fn deploy_da_verifier_via_solidity(
|
||||
settings_path: PathBuf,
|
||||
|
||||
105
src/execute.rs
105
src/execute.rs
@@ -410,46 +410,24 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
commitment.into(),
|
||||
)
|
||||
.map(|e| serde_json::to_string(&e).unwrap()),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::DeployEvm {
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
addr_path,
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
addr_verifier_manager,
|
||||
addr_reusable_verifier,
|
||||
contract,
|
||||
} => {
|
||||
// if contract type is either verifier/reusable
|
||||
match contract {
|
||||
ContractType::Verifier { reusable: true } => {
|
||||
if addr_verifier_manager.is_none() {
|
||||
panic!("Must pass a verifier manager address for reusable verifier")
|
||||
}
|
||||
}
|
||||
ContractType::VerifyingKeyArtifact => {
|
||||
if addr_verifier_manager.is_none() || addr_reusable_verifier.is_none() {
|
||||
panic!(
|
||||
"Must pass a verifier manager address and reusable verifier address for verifying key artifact"
|
||||
)
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
deploy_evm(
|
||||
sol_code_path.unwrap_or(DEFAULT_SOL_CODE.into()),
|
||||
rpc_url,
|
||||
addr_path.unwrap_or(DEFAULT_CONTRACT_ADDRESS.into()),
|
||||
optimizer_runs,
|
||||
private_key,
|
||||
addr_verifier_manager.map(|s| s.into()),
|
||||
addr_reusable_verifier.map(|s| s.into()),
|
||||
contract,
|
||||
)
|
||||
.await
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Commands::DeployEvmDataAttestation {
|
||||
data,
|
||||
settings_path,
|
||||
@@ -693,17 +671,18 @@ pub(crate) async fn get_srs_cmd(
|
||||
let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k);
|
||||
let mut reader = Cursor::new(fetch_srs(&srs_uri).await?);
|
||||
// check the SRS
|
||||
let pb = init_spinner();
|
||||
pb.set_message("Validating SRS (this may take a while) ...");
|
||||
let pb = init_spinner();
|
||||
pb.set_message("Validating SRS (this may take a while) ...");
|
||||
let params = ParamsKZG::<Bn256>::read(&mut reader)?;
|
||||
pb.finish_with_message("SRS validated.");
|
||||
pb.finish_with_message("SRS validated.");
|
||||
|
||||
info!("Saving SRS to disk...");
|
||||
let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone(), commitment))?;
|
||||
let computed_srs_path = get_srs_path(k, srs_path.clone(), commitment);
|
||||
let mut file = std::fs::File::create(&computed_srs_path)?;
|
||||
let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
|
||||
params.write(&mut buffer)?;
|
||||
|
||||
info!("Saved SRS to disk.");
|
||||
info!("Saved SRS to {}.", computed_srs_path.as_os_str().to_str().unwrap_or("disk"));
|
||||
|
||||
info!("SRS downloaded");
|
||||
} else {
|
||||
@@ -749,7 +728,7 @@ pub(crate) async fn gen_witness(
|
||||
None
|
||||
};
|
||||
|
||||
let mut input = circuit.load_graph_input(&data).await?;
|
||||
let mut input = circuit.load_graph_input(&data).await?;
|
||||
#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))]
|
||||
let mut input = circuit.load_graph_input(&data)?;
|
||||
|
||||
@@ -1439,7 +1418,6 @@ pub(crate) async fn create_evm_verifier(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) async fn create_evm_vka(
|
||||
vk_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
@@ -1468,20 +1446,9 @@ pub(crate) async fn create_evm_vka(
|
||||
num_instance,
|
||||
);
|
||||
|
||||
let (reusable_verifier, vk_solidity) = generator.render_separately()?;
|
||||
let vk_solidity = generator.render_separately()?.1;
|
||||
|
||||
// Remove the first line of vk_solidity (license identifier). Same license identifier for all contracts in this .sol
|
||||
let vk_solidity = vk_solidity
|
||||
.lines()
|
||||
.skip(1)
|
||||
.collect::<Vec<&str>>()
|
||||
.join("\n");
|
||||
|
||||
// We store each contracts to the same file...
|
||||
// We need to do this so that during the deployment transaction we make sure
|
||||
// verifier manager links the VKA to the correct reusable_verifier.
|
||||
let combined_solidity = format!("{}\n\n{}", reusable_verifier, vk_solidity);
|
||||
File::create(sol_code_path.clone())?.write_all(combined_solidity.as_bytes())?;
|
||||
File::create(sol_code_path.clone())?.write_all(vk_solidity.as_bytes())?;
|
||||
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Halo2VerifyingArtifact", 0).await?;
|
||||
@@ -1599,51 +1566,21 @@ pub(crate) async fn deploy_evm(
|
||||
addr_path: PathBuf,
|
||||
runs: usize,
|
||||
private_key: Option<String>,
|
||||
verifier_manager: Option<alloy::primitives::Address>,
|
||||
reusable_verifier: Option<alloy::primitives::Address>,
|
||||
contract: ContractType,
|
||||
) -> Result<String, EZKLError> {
|
||||
use crate::eth::{deploy_reusable_verifier, deploy_vka};
|
||||
|
||||
let contract_name = match contract {
|
||||
ContractType::Verifier { reusable: false } => "Halo2Verifier",
|
||||
ContractType::Verifier { reusable: true } => "Halo2VerifierReusable",
|
||||
ContractType::VerifyingKeyArtifact => "Halo2VerifyingArtifact",
|
||||
ContractType::VerifierManager => "EZKLVerifierManager",
|
||||
};
|
||||
|
||||
let contract_address = if contract_name == "Halo2VerifierReusable" {
|
||||
// Use VerifierManager to deploy the contract
|
||||
deploy_reusable_verifier(
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
runs,
|
||||
private_key.as_deref(),
|
||||
contract_name,
|
||||
verifier_manager.unwrap(),
|
||||
)
|
||||
.await?
|
||||
} else if contract_name == "Halo2VerifyingArtifact" {
|
||||
deploy_vka(
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
runs,
|
||||
private_key.as_deref(),
|
||||
contract_name,
|
||||
verifier_manager.unwrap(),
|
||||
reusable_verifier.unwrap(),
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
deploy_contract_via_solidity(
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
runs,
|
||||
private_key.as_deref(),
|
||||
contract_name,
|
||||
)
|
||||
.await?
|
||||
};
|
||||
let contract_address = deploy_contract_via_solidity(
|
||||
sol_code_path,
|
||||
rpc_url.as_deref(),
|
||||
runs,
|
||||
private_key.as_deref(),
|
||||
contract_name,
|
||||
)
|
||||
.await?;
|
||||
|
||||
info!("Contract deployed at: {:#?}", contract_address);
|
||||
|
||||
@@ -2085,7 +2022,7 @@ pub(crate) fn mock_aggregate(
|
||||
}
|
||||
}
|
||||
// proof aggregation
|
||||
let pb = {
|
||||
let pb = {
|
||||
let pb = init_spinner();
|
||||
pb.set_message("Aggregating (may take a while)...");
|
||||
pb
|
||||
@@ -2096,7 +2033,7 @@ pub(crate) fn mock_aggregate(
|
||||
let prover = halo2_proofs::dev::MockProver::run(logrows, &circuit, vec![circuit.instances()])
|
||||
.map_err(|e| ExecutionError::MockProverError(e.to_string()))?;
|
||||
prover.verify().map_err(ExecutionError::VerifyError)?;
|
||||
pb.finish_with_message("Done.");
|
||||
pb.finish_with_message("Done.");
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
@@ -2190,7 +2127,7 @@ pub(crate) fn aggregate(
|
||||
}
|
||||
|
||||
// proof aggregation
|
||||
let pb = {
|
||||
let pb = {
|
||||
let pb = init_spinner();
|
||||
pb.set_message("Aggregating (may take a while)...");
|
||||
pb
|
||||
@@ -2339,7 +2276,7 @@ pub(crate) fn aggregate(
|
||||
);
|
||||
snark.save(&proof_path)?;
|
||||
|
||||
pb.finish_with_message("Done.");
|
||||
pb.finish_with_message("Done.");
|
||||
|
||||
Ok(snark)
|
||||
}
|
||||
|
||||
@@ -763,81 +763,41 @@ pub fn new_op_from_onnx(
|
||||
.map(|(i, _)| i)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if const_inputs.len() != 1 {
|
||||
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = const_inputs[0];
|
||||
let boxed_op = inputs[const_idx].opkind();
|
||||
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
|
||||
if c.len() == 1 {
|
||||
c[0]
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "max".to_string()));
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
|
||||
};
|
||||
|
||||
if inputs.len() == 2 {
|
||||
if let Some(node) = inputs.get_mut(const_idx) {
|
||||
node.decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
}
|
||||
if unit == 0. {
|
||||
SupportedOp::Linear(PolyOp::ReLU)
|
||||
if const_inputs.len() > 0 {
|
||||
let const_idx = const_inputs[0];
|
||||
let boxed_op = inputs[const_idx].opkind();
|
||||
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
|
||||
if c.len() == 1 {
|
||||
c[0]
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "max".to_string()));
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "Max".to_string()));
|
||||
};
|
||||
if unit == 0. {
|
||||
if let Some(node) = inputs.get_mut(const_idx) {
|
||||
node.decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
}
|
||||
SupportedOp::Linear(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
})
|
||||
} else {
|
||||
SupportedOp::Hybrid(HybridOp::Max)
|
||||
}
|
||||
} else {
|
||||
// get the non-constant index
|
||||
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
|
||||
SupportedOp::Nonlinear(LookupOp::Max {
|
||||
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
|
||||
a: crate::circuit::utils::F32(unit),
|
||||
})
|
||||
SupportedOp::Hybrid(HybridOp::Max)
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "max".to_string()));
|
||||
}
|
||||
}
|
||||
"Min" => {
|
||||
// Extract the min value
|
||||
// first find the input that is a constant
|
||||
// and then extract the value
|
||||
let const_inputs = inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, n)| n.is_constant())
|
||||
.map(|(i, _)| i)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if const_inputs.len() != 1 {
|
||||
return Err(GraphError::OpMismatch(idx, "Min".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = const_inputs[0];
|
||||
let boxed_op = inputs[const_idx].opkind();
|
||||
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
|
||||
if c.len() == 1 {
|
||||
c[0]
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "min".to_string()));
|
||||
}
|
||||
} else {
|
||||
return Err(GraphError::OpMismatch(idx, "Min".to_string()));
|
||||
};
|
||||
|
||||
if inputs.len() == 2 {
|
||||
if let Some(node) = inputs.get_mut(const_idx) {
|
||||
node.decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
}
|
||||
|
||||
// get the non-constant index
|
||||
let non_const_idx = if const_idx == 0 { 1 } else { 0 };
|
||||
|
||||
SupportedOp::Nonlinear(LookupOp::Min {
|
||||
scale: scale_to_multiplier(inputs[non_const_idx].out_scales()[0]).into(),
|
||||
a: crate::circuit::utils::F32(unit),
|
||||
})
|
||||
SupportedOp::Hybrid(HybridOp::Min)
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "min".to_string()));
|
||||
}
|
||||
@@ -849,7 +809,6 @@ pub fn new_op_from_onnx(
|
||||
SupportedOp::Hybrid(HybridOp::Recip {
|
||||
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
|
||||
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
|
||||
use_range_check_for_int: true,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -864,8 +823,9 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
};
|
||||
|
||||
SupportedOp::Nonlinear(LookupOp::LeakyReLU {
|
||||
SupportedOp::Linear(PolyOp::LeakyReLU {
|
||||
slope: crate::circuit::utils::F32(leaky_op.alpha),
|
||||
scale: scales.params,
|
||||
})
|
||||
}
|
||||
"Scan" => {
|
||||
@@ -1123,14 +1083,17 @@ pub fn new_op_from_onnx(
|
||||
pool_dims: kernel_shape.to_vec(),
|
||||
})
|
||||
}
|
||||
"Ceil" => SupportedOp::Nonlinear(LookupOp::Ceil {
|
||||
"Ceil" => SupportedOp::Hybrid(HybridOp::Ceil {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Floor" => SupportedOp::Nonlinear(LookupOp::Floor {
|
||||
"Floor" => SupportedOp::Hybrid(HybridOp::Floor {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"Round" => SupportedOp::Nonlinear(LookupOp::Round {
|
||||
"Round" => SupportedOp::Hybrid(HybridOp::Round {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
legs: run_args.decomp_legs,
|
||||
}),
|
||||
"RoundHalfToEven" => SupportedOp::Nonlinear(LookupOp::RoundHalfToEven {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
@@ -1146,10 +1109,17 @@ pub fn new_op_from_onnx(
|
||||
if c.raw_values.len() > 1 {
|
||||
unimplemented!("only support scalar pow")
|
||||
}
|
||||
SupportedOp::Nonlinear(LookupOp::Pow {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
a: crate::circuit::utils::F32(c.raw_values[0]),
|
||||
})
|
||||
|
||||
let exponent = c.raw_values[0];
|
||||
|
||||
if exponent.fract() == 0.0 {
|
||||
SupportedOp::Linear(PolyOp::Pow(exponent as u32))
|
||||
} else {
|
||||
SupportedOp::Nonlinear(LookupOp::Pow {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
a: crate::circuit::utils::F32(exponent),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
unimplemented!("only support constant pow for now")
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ pub fn get_rep(
|
||||
n: usize,
|
||||
) -> Result<Vec<IntegerRep>, DecompositionError> {
|
||||
// check if x is too large
|
||||
if x.abs() > (base.pow(n as u32) as IntegerRep) {
|
||||
if x.abs() > (base.pow(n as u32) as IntegerRep) - 1 {
|
||||
return Err(DecompositionError::TooLarge(*x, base, n));
|
||||
}
|
||||
let mut rep = vec![0; n + 1];
|
||||
@@ -1421,85 +1421,6 @@ pub fn slice<T: TensorType + Send + Sync>(
|
||||
pub mod nonlinearities {
|
||||
use super::*;
|
||||
|
||||
/// Ceiling operator.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
///
|
||||
/// use ezkl::tensor::ops::nonlinearities::ceil;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
/// &[3, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = ceil(&x, 2.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 2, 4, 4, 6, 6]), &[3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn ceil(a: &Tensor<IntegerRep>, scale: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale;
|
||||
let rounded = kix.ceil() * scale;
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Floor operator.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::floor;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
/// &[3, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = floor(&x, 2.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 2, 2, 4, 4, 6]), &[3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn floor(a: &Tensor<IntegerRep>, scale: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale;
|
||||
let rounded = kix.floor() * scale;
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Round operator.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::round;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
/// &[3, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = round(&x, 2.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 2, 4, 4, 6, 6]), &[3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn round(a: &Tensor<IntegerRep>, scale: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale;
|
||||
let rounded = kix.round() * scale;
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Round half to even operator.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
@@ -1553,35 +1474,6 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Applies Kronecker delta to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::kronecker_delta;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = kronecker_delta(&x);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 0, 0, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn kronecker_delta<T: TensorType + std::cmp::PartialEq + Send + Sync>(
|
||||
a: &Tensor<T>,
|
||||
) -> Tensor<T> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
if a_i == T::zero().unwrap() {
|
||||
Ok::<_, TensorError>(T::one().unwrap())
|
||||
} else {
|
||||
Ok::<_, TensorError>(T::zero().unwrap())
|
||||
}
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies sigmoid to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -1750,27 +1642,6 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies sign to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::sign;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[-2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = sign(&x);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[-1, 1, 1, 1, 1, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn sign(a: &Tensor<IntegerRep>) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(a_i.signum()))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies square root to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -2254,101 +2125,6 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies leaky relu to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `scale` - Single value
|
||||
/// * `slope` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::leakyrelu;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, -5]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = leakyrelu(&x, 0.1);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 15, 2, 1, 1, -1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn leakyrelu(a: &Tensor<IntegerRep>, slope: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let rounded = if a_i < 0 {
|
||||
let d_inv_x = (slope) * (a_i as f64);
|
||||
d_inv_x.round() as IntegerRep
|
||||
} else {
|
||||
let d_inv_x = a_i as f64;
|
||||
d_inv_x.round() as IntegerRep
|
||||
};
|
||||
Ok::<_, TensorError>(rounded)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies max to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - scalar
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::max;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, -5]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = max(&x, 1.0, 1.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 15, 2, 1, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn max(a: &Tensor<IntegerRep>, scale_input: f64, threshold: f64) -> Tensor<IntegerRep> {
|
||||
// calculate value of output
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let d_inv_x = (a_i as f64) / scale_input;
|
||||
let rounded = if d_inv_x <= threshold {
|
||||
(threshold * scale_input).round() as IntegerRep
|
||||
} else {
|
||||
(d_inv_x * scale_input).round() as IntegerRep
|
||||
};
|
||||
Ok::<_, TensorError>(rounded)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies min to a tensor of integers.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - scalar
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::min;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, -5]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = min(&x, 1.0, 2.0);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 2, 2, 1, 1, -5]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn min(a: &Tensor<IntegerRep>, scale_input: f64, threshold: f64) -> Tensor<IntegerRep> {
|
||||
// calculate value of output
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let d_inv_x = (a_i as f64) / scale_input;
|
||||
let rounded = if d_inv_x >= threshold {
|
||||
(threshold * scale_input).round() as IntegerRep
|
||||
} else {
|
||||
(d_inv_x * scale_input).round() as IntegerRep
|
||||
};
|
||||
Ok::<_, TensorError>(rounded)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise divides a tensor with a const integer element.
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -2429,104 +2205,6 @@ pub mod nonlinearities {
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise greater than
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::greater_than;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 7, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2.0;
|
||||
/// let result = greater_than(&x, k);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn greater_than(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) > 0_f64)))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise greater than
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::greater_than_equal;
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 7, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2.0;
|
||||
/// let result = greater_than_equal(&x, k);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 1, 1, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn greater_than_equal(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) >= 0_f64)))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise less than
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::less_than;
|
||||
///
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 7, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2.0;
|
||||
///
|
||||
/// let result = less_than(&x, k);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 1, 0, 0, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn less_than(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) < 0_f64)))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise less than
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `b` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::less_than_equal;
|
||||
///
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 7, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2.0;
|
||||
///
|
||||
/// let result = less_than_equal(&x, k);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 1, 0, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn less_than_equal(a: &Tensor<IntegerRep>, b: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) <= 0_f64)))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Ops that return the transcript i.e intermediate calcs of an op
|
||||
|
||||
Binary file not shown.
@@ -205,7 +205,7 @@ mod native_tests {
|
||||
"1l_tiny_div",
|
||||
];
|
||||
|
||||
const TESTS: [&str; 94] = [
|
||||
const TESTS: [&str; 95] = [
|
||||
"1l_mlp", //0
|
||||
"1l_slice",
|
||||
"1l_concat",
|
||||
@@ -304,6 +304,7 @@ mod native_tests {
|
||||
"lstm_large", // 91
|
||||
"lstm_medium", // 92
|
||||
"lenet_5", // 93
|
||||
"rsqrt", // 94
|
||||
];
|
||||
|
||||
const WASM_TESTS: [&str; 46] = [
|
||||
@@ -542,7 +543,7 @@ mod native_tests {
|
||||
}
|
||||
});
|
||||
|
||||
seq!(N in 0..=93 {
|
||||
seq!(N in 0..=94 {
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
#[ignore]
|
||||
@@ -851,9 +852,11 @@ mod native_tests {
|
||||
fn kzg_prove_and_verify_tight_lookup_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let path = test_dir.into_path();
|
||||
let path = path.to_str().unwrap();
|
||||
crate::native_tests::mv_test_(path, test);
|
||||
prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, false, "single", Commitments::KZG, 1);
|
||||
test_dir.close().unwrap();
|
||||
// test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
@@ -1000,21 +1003,13 @@ mod native_tests {
|
||||
use crate::native_tests::run_js_tests;
|
||||
use ezkl::logger::init_logger;
|
||||
use crate::native_tests::lazy_static;
|
||||
use std::sync::Once;
|
||||
|
||||
// Global variables to store verifier hashes and identical verifiers
|
||||
lazy_static! {
|
||||
static ref ANVIL_INSTANCE: std::sync::Mutex<Option<std::process::Child>> = std::sync::Mutex::new(None);
|
||||
// create a new variable of type
|
||||
static ref REUSABLE_VERIFIER_ADDR: std::sync::Mutex<Option<String>> = std::sync::Mutex::new(None);
|
||||
}
|
||||
|
||||
static INIT: Once = Once::new();
|
||||
|
||||
fn initialize() {
|
||||
INIT.call_once(|| {
|
||||
let anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
*ANVIL_INSTANCE.lock().unwrap() = Some(anvil_child);
|
||||
});
|
||||
}
|
||||
|
||||
/// Currently only on chain inputs that return a non-negative value are supported.
|
||||
const TESTS_ON_CHAIN_INPUT: [&str; 17] = [
|
||||
@@ -1126,10 +1121,9 @@ mod native_tests {
|
||||
|
||||
});
|
||||
|
||||
seq!(N in 0..=93 {
|
||||
seq!(N in 0..4 {
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn kzg_evm_prove_and_verify_reusable_verifier_(test: &str) {
|
||||
initialize();
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
@@ -1137,18 +1131,28 @@ mod native_tests {
|
||||
init_logger();
|
||||
log::error!("Running kzg_evm_prove_and_verify_reusable_verifier_ for test: {}", test);
|
||||
// default vis
|
||||
kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "private", "private", "public", false);
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "private", "private", "public", &mut REUSABLE_VERIFIER_ADDR.lock().unwrap(), false);
|
||||
// public/public vis
|
||||
kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "public", "private", "public", false);
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "public", "private", "public", &mut Some(reusable_verifier_address), false);
|
||||
// hashed input
|
||||
kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "hashed", "private", "public", false);
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "hashed", "private", "public", &mut Some(reusable_verifier_address), false);
|
||||
|
||||
match REUSABLE_VERIFIER_ADDR.try_lock() {
|
||||
Ok(mut addr) => {
|
||||
*addr = Some(reusable_verifier_address.clone());
|
||||
log::error!("Reusing the same verifeir deployed at address: {}", reusable_verifier_address);
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to acquire lock on REUSABLE_VERIFIER_ADDR");
|
||||
}
|
||||
}
|
||||
|
||||
test_dir.close().unwrap();
|
||||
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn kzg_evm_prove_and_verify_reusable_verifier_with_overflow_(test: &str) {
|
||||
initialize();
|
||||
// verifier too big to fit on chain with overflow calibration target
|
||||
if test == "1l_eltwise_div" || test == "lenet_5" || test == "ltsf" || test == "lstm_large" {
|
||||
return;
|
||||
@@ -1160,13 +1164,24 @@ mod native_tests {
|
||||
init_logger();
|
||||
log::error!("Running kzg_evm_prove_and_verify_reusable_verifier_with_overflow_ for test: {}", test);
|
||||
// default vis
|
||||
kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "private", "private", "public", true);
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "private", "private", "public", &mut REUSABLE_VERIFIER_ADDR.lock().unwrap(), true);
|
||||
// public/public vis
|
||||
kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "public", "private", "public", true);
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "public", "private", "public", &mut Some(reusable_verifier_address), true);
|
||||
// hashed input
|
||||
kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "hashed", "private", "public", true);
|
||||
let reusable_verifier_address: String = kzg_evm_prove_and_verify_reusable_verifier(2, path, test.to_string(), "hashed", "private", "public", &mut Some(reusable_verifier_address), true);
|
||||
|
||||
match REUSABLE_VERIFIER_ADDR.try_lock() {
|
||||
Ok(mut addr) => {
|
||||
*addr = Some(reusable_verifier_address.clone());
|
||||
log::error!("Reusing the same verifeir deployed at address: {}", reusable_verifier_address);
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to acquire lock on REUSABLE_VERIFIER_ADDR");
|
||||
}
|
||||
}
|
||||
|
||||
test_dir.close().unwrap();
|
||||
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1619,7 +1634,6 @@ mod native_tests {
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(args)
|
||||
.stdout(std::process::Stdio::null())
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
@@ -2216,8 +2230,9 @@ mod native_tests {
|
||||
input_visibility: &str,
|
||||
param_visibility: &str,
|
||||
output_visibility: &str,
|
||||
reusable_verifier_address: &mut Option<String>,
|
||||
overflow: bool,
|
||||
) {
|
||||
) -> String {
|
||||
let anvil_url = ANVIL_URL.as_str();
|
||||
|
||||
prove_and_verify(
|
||||
@@ -2240,82 +2255,57 @@ mod native_tests {
|
||||
|
||||
let vk_arg = format!("{}/{}/key.vk", test_dir, example_name);
|
||||
let rpc_arg = format!("--rpc-url={}", anvil_url);
|
||||
// addr path for verifier manager contract
|
||||
let addr_path_arg = format!("--addr-path={}/{}/addr.txt", test_dir, example_name);
|
||||
let verifier_manager_arg: String;
|
||||
let settings_arg = format!("--settings-path={}", settings_path);
|
||||
// reusable verifier sol_arg
|
||||
let sol_arg = format!("--sol-code-path={}/{}/kzg.sol", test_dir, example_name);
|
||||
|
||||
// create the reusable verifier
|
||||
let args = vec![
|
||||
"create-evm-verifier",
|
||||
"--vk-path",
|
||||
&vk_arg,
|
||||
&settings_arg,
|
||||
&sol_arg,
|
||||
"--reusable",
|
||||
];
|
||||
// if the reusable verifier address is not set, create the verifier
|
||||
let deployed_addr_arg = match reusable_verifier_address {
|
||||
Some(addr) => addr.clone(),
|
||||
None => {
|
||||
// create the reusable verifier
|
||||
let args = vec![
|
||||
"create-evm-verifier",
|
||||
"--vk-path",
|
||||
&vk_arg,
|
||||
&settings_arg,
|
||||
&sol_arg,
|
||||
"--reusable",
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// deploy the verifier manager
|
||||
let args = vec![
|
||||
"deploy-evm",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg.as_str(),
|
||||
// set the sol code path to be contracts/VerifierManager.sol relative to root
|
||||
"--sol-code-path=contracts/VerifierManager.sol",
|
||||
"-C=manager",
|
||||
];
|
||||
// deploy the verifier
|
||||
let args = vec![
|
||||
"deploy-evm",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg.as_str(),
|
||||
sol_arg.as_str(),
|
||||
"-C=verifier/reusable",
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// read in the address of the verifier manager
|
||||
let addr = std::fs::read_to_string(format!("{}/{}/addr.txt", test_dir, example_name))
|
||||
.expect("failed to read address file");
|
||||
// read in the address
|
||||
let addr =
|
||||
std::fs::read_to_string(format!("{}/{}/addr.txt", test_dir, example_name))
|
||||
.expect("failed to read address file");
|
||||
|
||||
verifier_manager_arg = format!("--addr-verifier-manager={}", addr);
|
||||
|
||||
// if the reusable verifier address is not set, deploy the verifier manager and then create the verifier
|
||||
let rv_addr = {
|
||||
// addr path for rv contract
|
||||
let addr_path_arg = format!("--addr-path={}/{}/addr_rv.txt", test_dir, example_name);
|
||||
// deploy the reusable verifier via the verifier router.
|
||||
let args = vec![
|
||||
"deploy-evm",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg.as_str(),
|
||||
sol_arg.as_str(),
|
||||
verifier_manager_arg.as_str(),
|
||||
"-C=verifier/reusable",
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// read in the address of the verifier manager
|
||||
let addr =
|
||||
std::fs::read_to_string(format!("{}/{}/addr_rv.txt", test_dir, example_name))
|
||||
.expect("failed to read address file");
|
||||
|
||||
addr
|
||||
let deployed_addr_arg = format!("--addr-verifier={}", addr);
|
||||
// set the reusable verifier address
|
||||
*reusable_verifier_address = Some(addr);
|
||||
deployed_addr_arg
|
||||
}
|
||||
};
|
||||
|
||||
let addr_path_arg_vk = format!("--addr-path={}/{}/addr_vk.txt", test_dir, example_name);
|
||||
let sol_arg_vk: String = format!("--sol-code-path={}/{}/vk.sol", test_dir, example_name);
|
||||
// create the verifier
|
||||
let addr_path_arg_vk = format!("--addr-path={}/{}/addr_vk.txt", test_dir, example_name);
|
||||
let sol_arg_vk: String = format!("--sol-code-path={}/{}/vk.sol", test_dir, example_name);
|
||||
// create the verifier
|
||||
@@ -2333,15 +2323,11 @@ mod native_tests {
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let rv_addr_arg = format!("--addr-reusable-verifier={}", rv_addr);
|
||||
|
||||
// deploy the vka via the "DeployVKA" command on the reusable verifier
|
||||
// deploy the vka
|
||||
let args = vec![
|
||||
"deploy-evm",
|
||||
rpc_arg.as_str(),
|
||||
addr_path_arg_vk.as_str(),
|
||||
verifier_manager_arg.as_str(),
|
||||
rv_addr_arg.as_str(),
|
||||
sol_arg_vk.as_str(),
|
||||
"-C=vka",
|
||||
];
|
||||
@@ -2371,8 +2357,6 @@ mod native_tests {
|
||||
|
||||
assert!(status.success());
|
||||
|
||||
let deployed_addr_arg = format!("--addr-verifier={}", rv_addr);
|
||||
|
||||
// now verify the proof
|
||||
let pf_arg = format!("{}/{}/proof.pf", test_dir, example_name);
|
||||
let args = vec![
|
||||
@@ -2432,6 +2416,9 @@ mod native_tests {
|
||||
i
|
||||
);
|
||||
}
|
||||
|
||||
// Returned deploy_addr_arg for reusable verifier
|
||||
deployed_addr_arg
|
||||
}
|
||||
|
||||
// run js browser evm verify tests for a given example
|
||||
|
||||
@@ -124,41 +124,40 @@ mod py_tests {
|
||||
}
|
||||
|
||||
const TESTS: [&str; 34] = [
|
||||
"ezkl_demo_batch.ipynb",
|
||||
"proof_splitting.ipynb", // 0
|
||||
"variance.ipynb",
|
||||
"mnist_gan.ipynb",
|
||||
// "mnist_vae.ipynb",
|
||||
"keras_simple_demo.ipynb",
|
||||
"mnist_gan_proof_splitting.ipynb", // 4
|
||||
"hashed_vis.ipynb", // 5
|
||||
"simple_demo_all_public.ipynb",
|
||||
"data_attest.ipynb",
|
||||
"little_transformer.ipynb",
|
||||
"simple_demo_aggregated_proofs.ipynb",
|
||||
"ezkl_demo.ipynb", // 10
|
||||
"lstm.ipynb",
|
||||
"set_membership.ipynb", // 12
|
||||
"decision_tree.ipynb",
|
||||
"random_forest.ipynb",
|
||||
"gradient_boosted_trees.ipynb", // 15
|
||||
"xgboost.ipynb",
|
||||
"lightgbm.ipynb",
|
||||
"svm.ipynb",
|
||||
"simple_demo_public_input_output.ipynb",
|
||||
"simple_demo_public_network_output.ipynb", // 20
|
||||
"gcn.ipynb",
|
||||
"linear_regression.ipynb",
|
||||
"stacked_regression.ipynb",
|
||||
"data_attest_hashed.ipynb",
|
||||
"kzg_vis.ipynb", // 25
|
||||
"kmeans.ipynb",
|
||||
"solvency.ipynb",
|
||||
"sklearn_mlp.ipynb",
|
||||
"generalized_inverse.ipynb",
|
||||
"mnist_classifier.ipynb", // 30
|
||||
"world_rotation.ipynb",
|
||||
"logistic_regression.ipynb",
|
||||
"ezkl_demo_batch.ipynb", // 0
|
||||
"proof_splitting.ipynb", // 1
|
||||
"variance.ipynb", // 2
|
||||
"mnist_gan.ipynb", // 3
|
||||
"keras_simple_demo.ipynb", // 4
|
||||
"mnist_gan_proof_splitting.ipynb", // 5
|
||||
"hashed_vis.ipynb", // 6
|
||||
"simple_demo_all_public.ipynb", // 7
|
||||
"data_attest.ipynb", // 8
|
||||
"little_transformer.ipynb", // 9
|
||||
"simple_demo_aggregated_proofs.ipynb", // 10
|
||||
"ezkl_demo.ipynb", // 11
|
||||
"lstm.ipynb", // 12
|
||||
"set_membership.ipynb", // 13
|
||||
"decision_tree.ipynb", // 14
|
||||
"random_forest.ipynb", // 15
|
||||
"gradient_boosted_trees.ipynb", // 16
|
||||
"xgboost.ipynb", // 17
|
||||
"lightgbm.ipynb", // 18
|
||||
"svm.ipynb", // 19
|
||||
"simple_demo_public_input_output.ipynb", // 20
|
||||
"simple_demo_public_network_output.ipynb", // 21
|
||||
"gcn.ipynb", // 22
|
||||
"linear_regression.ipynb", // 23
|
||||
"stacked_regression.ipynb", // 24
|
||||
"data_attest_hashed.ipynb", // 25
|
||||
"kzg_vis.ipynb", // 26
|
||||
"kmeans.ipynb", // 27
|
||||
"solvency.ipynb", // 28
|
||||
"sklearn_mlp.ipynb", // 29
|
||||
"generalized_inverse.ipynb", // 30
|
||||
"mnist_classifier.ipynb", // 31
|
||||
"world_rotation.ipynb", // 32
|
||||
"logistic_regression.ipynb", // 33
|
||||
];
|
||||
|
||||
macro_rules! test_func {
|
||||
|
||||
@@ -1 +1 @@
|
||||
[{"type":"function","name":"deployVKA","inputs":[{"name":"bytecode","type":"bytes","internalType":"bytes"}],"outputs":[{"name":"addr","type":"address","internalType":"address"}],"stateMutability":"nonpayable"},{"type":"function","name":"precomputeAddress","inputs":[{"name":"bytecode","type":"bytes","internalType":"bytes"}],"outputs":[{"name":"","type":"address","internalType":"address"}],"stateMutability":"view"},{"type":"function","name":"verifyProof","inputs":[{"name":"vk","type":"address","internalType":"address"},{"name":"proof","type":"bytes","internalType":"bytes"},{"name":"instances","type":"uint256[]","internalType":"uint256[]"}],"outputs":[{"name":"","type":"bool","internalType":"bool"}],"stateMutability":"nonpayable"},{"type":"function","name":"vkaLog","inputs":[{"name":"","type":"address","internalType":"address"}],"outputs":[{"name":"","type":"bool","internalType":"bool"}],"stateMutability":"view"},{"type":"event","name":"DeployedVKArtifact","inputs":[{"name":"vka","type":"address","indexed":false,"internalType":"address"}],"anonymous":false},{"type":"error","name":"UnloggedVka","inputs":[{"name":"vka","type":"address","internalType":"address"}]}]
|
||||
[{"type":"function","name":"verifyProof","inputs":[{"internalType":"bytes","name":"proof","type":"bytes"},{"internalType":"uint256[]","name":"instances","type":"uint256[]"}],"outputs":[{"internalType":"bool","name":"","type":"bool"}],"stateMutability":"nonpayable"}]
|
||||
Reference in New Issue
Block a user