mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 00:08:12 -05:00
Compare commits
7 Commits
ac/documen
...
release-v2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c48ff1a4e9 | ||
|
|
fe978caa85 | ||
|
|
1bef92407c | ||
|
|
5ff1c48ede | ||
|
|
ab4997d0c2 | ||
|
|
701e69dd2f | ||
|
|
f631445e26 |
21
.github/workflows/rust.yml
vendored
21
.github/workflows/rust.yml
vendored
@@ -245,6 +245,25 @@ jobs:
|
||||
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
|
||||
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
|
||||
|
||||
foudry-solidity-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
|
||||
with:
|
||||
persist-credentials: false
|
||||
submodules: recursive
|
||||
|
||||
- name: Install Foundry
|
||||
uses: foundry-rs/foundry-toolchain@v1
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
cd tests/foundry
|
||||
forge install https://github.com/foundry-rs/forge-std --no-git --no-commit
|
||||
forge test -vvvv --fuzz-runs 64
|
||||
|
||||
mock-proving-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -875,4 +894,4 @@ jobs:
|
||||
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
|
||||
-parallel-testing-enabled NO \
|
||||
-resultBundlePath ../../exampleTestResults \
|
||||
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
|
||||
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -9,6 +9,7 @@ pkg
|
||||
!AttestData.sol
|
||||
!VerifierBase.sol
|
||||
!LoadInstances.sol
|
||||
!AttestData.t.sol
|
||||
*.pf
|
||||
*.vk
|
||||
*.pk
|
||||
@@ -49,3 +50,5 @@ timingData.json
|
||||
!tests/assets/vk.key
|
||||
docs/python/build
|
||||
!tests/assets/vk_aggr.key
|
||||
cache
|
||||
out
|
||||
|
||||
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -2441,7 +2441,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_solidity_verifier"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier#7def6101d32331182f91483832e4fd293d75f33e"
|
||||
source = "git+https://github.com/alexander-camuto/halo2-solidity-verifier#80c20d6ab57d3b28b2a28df4b63c30923bde17e1"
|
||||
dependencies = [
|
||||
"askama",
|
||||
"blake2b_simd",
|
||||
|
||||
@@ -297,8 +297,4 @@ inherits = "dev"
|
||||
opt-level = 3
|
||||
|
||||
[package.metadata.wasm-pack.profile.release]
|
||||
wasm-opt = [
|
||||
"-O4",
|
||||
"--flexible-inline-max-function-size",
|
||||
"4294967295",
|
||||
]
|
||||
wasm-opt = ["-O4", "--flexible-inline-max-function-size", "4294967295"]
|
||||
|
||||
312
abis/DataAttestation.json
Normal file
312
abis/DataAttestation.json
Normal file
@@ -0,0 +1,312 @@
|
||||
[
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "_contractAddresses",
|
||||
"type": "address"
|
||||
},
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "_callData",
|
||||
"type": "bytes"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256[]",
|
||||
"name": "_decimals",
|
||||
"type": "uint256[]"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256[]",
|
||||
"name": "_bits",
|
||||
"type": "uint256[]"
|
||||
},
|
||||
{
|
||||
"internalType": "uint8",
|
||||
"name": "_instanceOffset",
|
||||
"type": "uint8"
|
||||
}
|
||||
],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "constructor"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "HALF_ORDER",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "",
|
||||
"type": "uint256"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "ORDER",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "",
|
||||
"type": "uint256"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "uint256[]",
|
||||
"name": "instances",
|
||||
"type": "uint256[]"
|
||||
}
|
||||
],
|
||||
"name": "attestData",
|
||||
"outputs": [],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "callData",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "",
|
||||
"type": "bytes"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "contractAddress",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "encoded",
|
||||
"type": "bytes"
|
||||
}
|
||||
],
|
||||
"name": "getInstancesCalldata",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "uint256[]",
|
||||
"name": "instances",
|
||||
"type": "uint256[]"
|
||||
}
|
||||
],
|
||||
"stateMutability": "pure",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "encoded",
|
||||
"type": "bytes"
|
||||
}
|
||||
],
|
||||
"name": "getInstancesMemory",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "uint256[]",
|
||||
"name": "instances",
|
||||
"type": "uint256[]"
|
||||
}
|
||||
],
|
||||
"stateMutability": "pure",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "index",
|
||||
"type": "uint256"
|
||||
}
|
||||
],
|
||||
"name": "getScalars",
|
||||
"outputs": [
|
||||
{
|
||||
"components": [
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "decimals",
|
||||
"type": "uint256"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "bits",
|
||||
"type": "uint256"
|
||||
}
|
||||
],
|
||||
"internalType": "struct DataAttestation.Scalars",
|
||||
"name": "",
|
||||
"type": "tuple"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "instanceOffset",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "uint8",
|
||||
"name": "",
|
||||
"type": "uint8"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "x",
|
||||
"type": "uint256"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "y",
|
||||
"type": "uint256"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "denominator",
|
||||
"type": "uint256"
|
||||
}
|
||||
],
|
||||
"name": "mulDiv",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "result",
|
||||
"type": "uint256"
|
||||
}
|
||||
],
|
||||
"stateMutability": "pure",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "int256",
|
||||
"name": "x",
|
||||
"type": "int256"
|
||||
},
|
||||
{
|
||||
"components": [
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "decimals",
|
||||
"type": "uint256"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "bits",
|
||||
"type": "uint256"
|
||||
}
|
||||
],
|
||||
"internalType": "struct DataAttestation.Scalars",
|
||||
"name": "_scalars",
|
||||
"type": "tuple"
|
||||
}
|
||||
],
|
||||
"name": "quantizeData",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "int256",
|
||||
"name": "quantized_data",
|
||||
"type": "int256"
|
||||
}
|
||||
],
|
||||
"stateMutability": "pure",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "target",
|
||||
"type": "address"
|
||||
},
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "data",
|
||||
"type": "bytes"
|
||||
}
|
||||
],
|
||||
"name": "staticCall",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "",
|
||||
"type": "bytes"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "int256",
|
||||
"name": "x",
|
||||
"type": "int256"
|
||||
}
|
||||
],
|
||||
"name": "toFieldElement",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "field_element",
|
||||
"type": "uint256"
|
||||
}
|
||||
],
|
||||
"stateMutability": "pure",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "verifier",
|
||||
"type": "address"
|
||||
},
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "encoded",
|
||||
"type": "bytes"
|
||||
}
|
||||
],
|
||||
"name": "verifyWithDataAttestation",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "bool",
|
||||
"name": "",
|
||||
"type": "bool"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
@@ -1,167 +0,0 @@
|
||||
[
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address[]",
|
||||
"name": "_contractAddresses",
|
||||
"type": "address[]"
|
||||
},
|
||||
{
|
||||
"internalType": "bytes[][]",
|
||||
"name": "_callData",
|
||||
"type": "bytes[][]"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256[][]",
|
||||
"name": "_decimals",
|
||||
"type": "uint256[][]"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256[]",
|
||||
"name": "_scales",
|
||||
"type": "uint256[]"
|
||||
},
|
||||
{
|
||||
"internalType": "uint8",
|
||||
"name": "_instanceOffset",
|
||||
"type": "uint8"
|
||||
},
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "_admin",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "constructor"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "",
|
||||
"type": "uint256"
|
||||
}
|
||||
],
|
||||
"name": "accountCalls",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "contractAddress",
|
||||
"type": "address"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "callCount",
|
||||
"type": "uint256"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "admin",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "instanceOffset",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "uint8",
|
||||
"name": "",
|
||||
"type": "uint8"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "",
|
||||
"type": "uint256"
|
||||
}
|
||||
],
|
||||
"name": "scales",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "",
|
||||
"type": "uint256"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address[]",
|
||||
"name": "_contractAddresses",
|
||||
"type": "address[]"
|
||||
},
|
||||
{
|
||||
"internalType": "bytes[][]",
|
||||
"name": "_callData",
|
||||
"type": "bytes[][]"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256[][]",
|
||||
"name": "_decimals",
|
||||
"type": "uint256[][]"
|
||||
}
|
||||
],
|
||||
"name": "updateAccountCalls",
|
||||
"outputs": [],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "_admin",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"name": "updateAdmin",
|
||||
"outputs": [],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "verifier",
|
||||
"type": "address"
|
||||
},
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "encoded",
|
||||
"type": "bytes"
|
||||
}
|
||||
],
|
||||
"name": "verifyWithDataAttestation",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "bool",
|
||||
"name": "",
|
||||
"type": "bool"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
@@ -1,147 +0,0 @@
|
||||
[
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "_contractAddresses",
|
||||
"type": "address"
|
||||
},
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "_callData",
|
||||
"type": "bytes"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "_decimals",
|
||||
"type": "uint256"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256[]",
|
||||
"name": "_scales",
|
||||
"type": "uint256[]"
|
||||
},
|
||||
{
|
||||
"internalType": "uint8",
|
||||
"name": "_instanceOffset",
|
||||
"type": "uint8"
|
||||
},
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "_admin",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "constructor"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "accountCall",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "contractAddress",
|
||||
"type": "address"
|
||||
},
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "callData",
|
||||
"type": "bytes"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "decimals",
|
||||
"type": "uint256"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "admin",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [],
|
||||
"name": "instanceOffset",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "uint8",
|
||||
"name": "",
|
||||
"type": "uint8"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "_contractAddresses",
|
||||
"type": "address"
|
||||
},
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "_callData",
|
||||
"type": "bytes"
|
||||
},
|
||||
{
|
||||
"internalType": "uint256",
|
||||
"name": "_decimals",
|
||||
"type": "uint256"
|
||||
}
|
||||
],
|
||||
"name": "updateAccountCalls",
|
||||
"outputs": [],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "_admin",
|
||||
"type": "address"
|
||||
}
|
||||
],
|
||||
"name": "updateAdmin",
|
||||
"outputs": [],
|
||||
"stateMutability": "nonpayable",
|
||||
"type": "function"
|
||||
},
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"internalType": "address",
|
||||
"name": "verifier",
|
||||
"type": "address"
|
||||
},
|
||||
{
|
||||
"internalType": "bytes",
|
||||
"name": "encoded",
|
||||
"type": "bytes"
|
||||
}
|
||||
],
|
||||
"name": "verifyWithDataAttestation",
|
||||
"outputs": [
|
||||
{
|
||||
"internalType": "bool",
|
||||
"name": "",
|
||||
"type": "bool"
|
||||
}
|
||||
],
|
||||
"stateMutability": "view",
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
@@ -8,21 +8,27 @@ contract LoadInstances {
|
||||
*/
|
||||
function getInstancesMemory(
|
||||
bytes memory encoded
|
||||
) internal pure returns (uint256[] memory instances) {
|
||||
) public pure returns (uint256[] memory instances) {
|
||||
bytes4 funcSig;
|
||||
uint256 instances_offset;
|
||||
uint256 instances_length;
|
||||
assembly {
|
||||
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
|
||||
funcSig := mload(add(encoded, 0x20))
|
||||
|
||||
}
|
||||
if (funcSig == 0xaf83a18d) {
|
||||
instances_offset = 0x64;
|
||||
} else if (funcSig == 0x1e8e1e13) {
|
||||
instances_offset = 0x44;
|
||||
} else {
|
||||
revert("Invalid function signature");
|
||||
}
|
||||
assembly {
|
||||
// Fetch instances offset which is 4 + 32 + 32 bytes away from
|
||||
// start of encoded for `verifyProof(bytes,uint256[])`,
|
||||
// and 4 + 32 + 32 +32 away for `verifyProof(address,bytes,uint256[])`
|
||||
|
||||
instances_offset := mload(
|
||||
add(encoded, add(0x44, mul(0x20, eq(funcSig, 0xaf83a18d))))
|
||||
)
|
||||
instances_offset := mload(add(encoded, instances_offset))
|
||||
|
||||
instances_length := mload(add(add(encoded, 0x24), instances_offset))
|
||||
}
|
||||
@@ -41,6 +47,10 @@ contract LoadInstances {
|
||||
)
|
||||
}
|
||||
}
|
||||
require(
|
||||
funcSig == 0xaf83a18d || funcSig == 0x1e8e1e13,
|
||||
"Invalid function signature"
|
||||
);
|
||||
}
|
||||
/**
|
||||
* @dev Parse the instances array from the Halo2Verifier encoded calldata.
|
||||
@@ -49,23 +59,31 @@ contract LoadInstances {
|
||||
*/
|
||||
function getInstancesCalldata(
|
||||
bytes calldata encoded
|
||||
) internal pure returns (uint256[] memory instances) {
|
||||
) public pure returns (uint256[] memory instances) {
|
||||
bytes4 funcSig;
|
||||
uint256 instances_offset;
|
||||
uint256 instances_length;
|
||||
assembly {
|
||||
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
|
||||
funcSig := calldataload(encoded.offset)
|
||||
|
||||
}
|
||||
if (funcSig == 0xaf83a18d) {
|
||||
instances_offset = 0x44;
|
||||
} else if (funcSig == 0x1e8e1e13) {
|
||||
instances_offset = 0x24;
|
||||
} else {
|
||||
revert("Invalid function signature");
|
||||
}
|
||||
// We need to create a new assembly block in order for solidity
|
||||
// to cast the funcSig to a bytes4 type. Otherwise it will load the entire first 32 bytes of the calldata
|
||||
// within the block
|
||||
assembly {
|
||||
// Fetch instances offset which is 4 + 32 + 32 bytes away from
|
||||
// start of encoded for `verifyProof(bytes,uint256[])`,
|
||||
// and 4 + 32 + 32 +32 away for `verifyProof(address,bytes,uint256[])`
|
||||
|
||||
instances_offset := calldataload(
|
||||
add(
|
||||
encoded.offset,
|
||||
add(0x24, mul(0x20, eq(funcSig, 0xaf83a18d)))
|
||||
)
|
||||
add(encoded.offset, instances_offset)
|
||||
)
|
||||
|
||||
instances_length := calldataload(
|
||||
@@ -96,7 +114,7 @@ contract LoadInstances {
|
||||
// The kzg commitments of a given model, all aggregated into a single bytes array.
|
||||
// At solidity generation time, the commitments are hardcoded into the contract via the COMMITMENT_KZG constant.
|
||||
// It will be used to check that the proof commitments match the expected commitments.
|
||||
bytes constant COMMITMENT_KZG = hex"";
|
||||
bytes constant COMMITMENT_KZG = hex"1234";
|
||||
|
||||
contract SwapProofCommitments {
|
||||
/**
|
||||
@@ -113,17 +131,20 @@ contract SwapProofCommitments {
|
||||
assembly {
|
||||
// fetch function sig. Either `verifyProof(bytes,uint256[])` or `verifyProof(address,bytes,uint256[])`
|
||||
funcSig := calldataload(encoded.offset)
|
||||
|
||||
}
|
||||
if (funcSig == 0xaf83a18d) {
|
||||
proof_offset = 0x24;
|
||||
} else if (funcSig == 0x1e8e1e13) {
|
||||
proof_offset = 0x04;
|
||||
} else {
|
||||
revert("Invalid function signature");
|
||||
}
|
||||
assembly {
|
||||
// Fetch proof offset which is 4 + 32 bytes away from
|
||||
// start of encoded for `verifyProof(bytes,uint256[])`,
|
||||
// and 4 + 32 + 32 away for `verifyProof(address,bytes,uint256[])`
|
||||
|
||||
proof_offset := calldataload(
|
||||
add(
|
||||
encoded.offset,
|
||||
add(0x04, mul(0x20, eq(funcSig, 0xaf83a18d)))
|
||||
)
|
||||
)
|
||||
proof_offset := calldataload(add(encoded.offset, proof_offset))
|
||||
|
||||
proof_length := calldataload(
|
||||
add(add(encoded.offset, 0x04), proof_offset)
|
||||
@@ -154,7 +175,7 @@ contract SwapProofCommitments {
|
||||
let wordCommitment := mload(add(commitment, i))
|
||||
equal := eq(wordProof, wordCommitment)
|
||||
if eq(equal, 0) {
|
||||
return(0, 0)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -163,36 +184,38 @@ contract SwapProofCommitments {
|
||||
} /// end checkKzgCommits
|
||||
}
|
||||
|
||||
contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
|
||||
/**
|
||||
* @notice Struct used to make view only call to account to fetch the data that EZKL reads from.
|
||||
* @param the address of the account to make calls to
|
||||
* @param the abi encoded function calls to make to the `contractAddress`
|
||||
*/
|
||||
struct AccountCall {
|
||||
address contractAddress;
|
||||
bytes callData;
|
||||
contract DataAttestation is LoadInstances, SwapProofCommitments {
|
||||
// the address of the account to make calls to
|
||||
address public immutable contractAddress;
|
||||
|
||||
// the abi encoded function calls to make to the `contractAddress` that returns the attested to data
|
||||
bytes public callData;
|
||||
|
||||
struct Scalars {
|
||||
// The number of base 10 decimals to scale the data by.
|
||||
// For most ERC20 tokens this is 1e18
|
||||
uint256 decimals;
|
||||
// The number of fractional bits of the fixed point EZKL data points.
|
||||
uint256 bits;
|
||||
}
|
||||
AccountCall public accountCall;
|
||||
|
||||
uint[] scales;
|
||||
Scalars[] private scalars;
|
||||
|
||||
address public admin;
|
||||
function getScalars(uint256 index) public view returns (Scalars memory) {
|
||||
return scalars[index];
|
||||
}
|
||||
|
||||
/**
|
||||
* @notice EZKL P value
|
||||
* @dev In order to prevent the verifier from accepting two version of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are stricly less than P. a
|
||||
* @dev The reason for this is that the assmebly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
|
||||
*/
|
||||
uint256 constant ORDER =
|
||||
uint256 public constant ORDER =
|
||||
uint256(
|
||||
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
|
||||
);
|
||||
|
||||
uint256 constant INPUT_LEN = 0;
|
||||
|
||||
uint256 constant OUTPUT_LEN = 0;
|
||||
uint256 public constant HALF_ORDER = ORDER >> 1;
|
||||
|
||||
uint8 public instanceOffset;
|
||||
|
||||
@@ -204,53 +227,27 @@ contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
|
||||
constructor(
|
||||
address _contractAddresses,
|
||||
bytes memory _callData,
|
||||
uint256 _decimals,
|
||||
uint[] memory _scales,
|
||||
uint8 _instanceOffset,
|
||||
address _admin
|
||||
uint256[] memory _decimals,
|
||||
uint[] memory _bits,
|
||||
uint8 _instanceOffset
|
||||
) {
|
||||
admin = _admin;
|
||||
for (uint i; i < _scales.length; i++) {
|
||||
scales.push(1 << _scales[i]);
|
||||
require(
|
||||
_bits.length == _decimals.length,
|
||||
"Invalid scalar array lengths"
|
||||
);
|
||||
for (uint i; i < _bits.length; i++) {
|
||||
scalars.push(Scalars(10 ** _decimals[i], 1 << _bits[i]));
|
||||
}
|
||||
populateAccountCalls(_contractAddresses, _callData, _decimals);
|
||||
contractAddress = _contractAddresses;
|
||||
callData = _callData;
|
||||
instanceOffset = _instanceOffset;
|
||||
}
|
||||
|
||||
function updateAdmin(address _admin) external {
|
||||
require(msg.sender == admin, "Only admin can update admin");
|
||||
if (_admin == address(0)) {
|
||||
revert();
|
||||
}
|
||||
admin = _admin;
|
||||
}
|
||||
|
||||
function updateAccountCalls(
|
||||
address _contractAddresses,
|
||||
bytes memory _callData,
|
||||
uint256 _decimals
|
||||
) external {
|
||||
require(msg.sender == admin, "Only admin can update account calls");
|
||||
populateAccountCalls(_contractAddresses, _callData, _decimals);
|
||||
}
|
||||
|
||||
function populateAccountCalls(
|
||||
address _contractAddresses,
|
||||
bytes memory _callData,
|
||||
uint256 _decimals
|
||||
) internal {
|
||||
AccountCall memory _accountCall = accountCall;
|
||||
_accountCall.contractAddress = _contractAddresses;
|
||||
_accountCall.callData = _callData;
|
||||
_accountCall.decimals = 10 ** _decimals;
|
||||
accountCall = _accountCall;
|
||||
}
|
||||
|
||||
function mulDiv(
|
||||
uint256 x,
|
||||
uint256 y,
|
||||
uint256 denominator
|
||||
) internal pure returns (uint256 result) {
|
||||
) public pure returns (uint256 result) {
|
||||
unchecked {
|
||||
uint256 prod0;
|
||||
uint256 prod1;
|
||||
@@ -298,21 +295,28 @@ contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
|
||||
/**
|
||||
* @dev Quantize the data returned from the account calls to the scale used by the EZKL model.
|
||||
* @param x - One of the elements of the data returned from the account calls
|
||||
* @param _decimals - Number of base 10 decimals to scale the data by.
|
||||
* @param _scale - The base 2 scale used to convert the floating point value into a fixed point value.
|
||||
* @param _scalars - The scaling factors for the data returned from the account calls.
|
||||
*
|
||||
*/
|
||||
function quantizeData(
|
||||
int x,
|
||||
uint256 _decimals,
|
||||
uint256 _scale
|
||||
) internal pure returns (int256 quantized_data) {
|
||||
Scalars memory _scalars
|
||||
) public pure returns (int256 quantized_data) {
|
||||
if (_scalars.bits == 1 && _scalars.decimals == 1) {
|
||||
return x;
|
||||
}
|
||||
bool neg = x < 0;
|
||||
if (neg) x = -x;
|
||||
uint output = mulDiv(uint256(x), _scale, _decimals);
|
||||
if (mulmod(uint256(x), _scale, _decimals) * 2 >= _decimals) {
|
||||
uint output = mulDiv(uint256(x), _scalars.bits, _scalars.decimals);
|
||||
if (
|
||||
mulmod(uint256(x), _scalars.bits, _scalars.decimals) * 2 >=
|
||||
_scalars.decimals
|
||||
) {
|
||||
output += 1;
|
||||
}
|
||||
if (output > HALF_ORDER) {
|
||||
revert("Overflow field modulus");
|
||||
}
|
||||
quantized_data = neg ? -int256(output) : int256(output);
|
||||
}
|
||||
/**
|
||||
@@ -324,7 +328,7 @@ contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
|
||||
function staticCall(
|
||||
address target,
|
||||
bytes memory data
|
||||
) internal view returns (bytes memory) {
|
||||
) public view returns (bytes memory) {
|
||||
(bool success, bytes memory returndata) = target.staticcall(data);
|
||||
if (success) {
|
||||
if (returndata.length == 0) {
|
||||
@@ -345,7 +349,7 @@ contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
|
||||
*/
|
||||
function toFieldElement(
|
||||
int256 x
|
||||
) internal pure returns (uint256 field_element) {
|
||||
) public pure returns (uint256 field_element) {
|
||||
// The casting down to uint256 is safe because the order is about 2^254, and the value
|
||||
// of x ranges of -2^127 to 2^127, so x + int(ORDER) is always positive.
|
||||
return uint256(x + int(ORDER)) % ORDER;
|
||||
@@ -355,315 +359,16 @@ contract DataAttestationSingle is LoadInstances, SwapProofCommitments {
|
||||
* @dev Make the account calls to fetch the data that EZKL reads from and attest to the data.
|
||||
* @param instances - The public instances to the proof (the data in the proof that publicly accessible to the verifier).
|
||||
*/
|
||||
function attestData(uint256[] memory instances) internal view {
|
||||
require(
|
||||
instances.length >= INPUT_LEN + OUTPUT_LEN,
|
||||
"Invalid public inputs length"
|
||||
);
|
||||
AccountCall memory _accountCall = accountCall;
|
||||
uint[] memory _scales = scales;
|
||||
bytes memory returnData = staticCall(
|
||||
_accountCall.contractAddress,
|
||||
_accountCall.callData
|
||||
);
|
||||
function attestData(uint256[] memory instances) public view {
|
||||
bytes memory returnData = staticCall(contractAddress, callData);
|
||||
int256[] memory x = abi.decode(returnData, (int256[]));
|
||||
uint _offset;
|
||||
int output = quantizeData(x[0], _accountCall.decimals, _scales[0]);
|
||||
uint field_element = toFieldElement(output);
|
||||
int output;
|
||||
uint fieldElement;
|
||||
for (uint i = 0; i < x.length; i++) {
|
||||
if (field_element != instances[i + instanceOffset]) {
|
||||
_offset += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
uint length = x.length - _offset;
|
||||
for (uint i = 1; i < length; i++) {
|
||||
output = quantizeData(x[i], _accountCall.decimals, _scales[i]);
|
||||
field_element = toFieldElement(output);
|
||||
require(
|
||||
field_element == instances[i + instanceOffset + _offset],
|
||||
"Public input does not match"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Verify the proof with the data attestation.
|
||||
* @param verifier - The address of the verifier contract.
|
||||
* @param encoded - The verifier calldata.
|
||||
*/
|
||||
function verifyWithDataAttestation(
|
||||
address verifier,
|
||||
bytes calldata encoded
|
||||
) public view returns (bool) {
|
||||
require(verifier.code.length > 0, "Address: call to non-contract");
|
||||
attestData(getInstancesCalldata(encoded));
|
||||
// static call the verifier contract to verify the proof
|
||||
(bool success, bytes memory returndata) = verifier.staticcall(encoded);
|
||||
|
||||
if (success) {
|
||||
return abi.decode(returndata, (bool));
|
||||
} else {
|
||||
revert("low-level call to verifier failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This contract serves as a Data Attestation Verifier for the EZKL model.
|
||||
// It is designed to read and attest to instances of proofs generated from a specified circuit.
|
||||
// It is particularly constructed to read only int256 data from specified on-chain contracts' view functions.
|
||||
|
||||
// Overview of the contract functionality:
|
||||
// 1. Initialization: Through the constructor, it sets up the contract calls that the EZKL model will read from.
|
||||
// 2. Data Quantization: Quantizes the returned data into a scaled fixed-point representation. See the `quantizeData` method for details.
|
||||
// 3. Static Calls: Makes static calls to fetch data from other contracts. See the `staticCall` method.
|
||||
// 4. Field Element Conversion: The fixed-point representation is then converted into a field element modulo P using the `toFieldElement` method.
|
||||
// 5. Data Attestation: The `attestData` method validates that the public instances match the data fetched and processed by the contract.
|
||||
// 6. Proof Verification: The `verifyWithDataAttestationMulti` method parses the instances out of the encoded calldata and calls the `attestData` method to validate the public instances,
|
||||
// 6b. Optional KZG Commitment Verification: It also checks the KZG commitments in the proof against the expected commitments using the `checkKzgCommits` method.
|
||||
// then calls the `verifyProof` method to verify the proof on the verifier.
|
||||
|
||||
contract DataAttestationMulti is LoadInstances, SwapProofCommitments {
|
||||
/**
 * @notice Struct used to make view only calls to accounts to fetch the data that EZKL reads from.
 * @param contractAddress the address of the account to make calls to
 * @param callData the abi encoded function calls to make to the `contractAddress`, keyed by call index
 * @param decimals per-call divisor (stored as 10 ** decimals) used to interpret the returned integer as a fixed point value
 * @param callCount the number of calls stored for this account
 */
struct AccountCall {
    address contractAddress;
    mapping(uint256 => bytes) callData;
    mapping(uint256 => uint256) decimals;
    uint callCount;
}
// One entry per contract that EZKL reads data from.
AccountCall[] public accountCalls;

// Per-instance scale factors; the constructor stores each as (1 << scale).
uint[] public scales;

// Account authorized to rotate the admin and refresh the account calls.
address public admin;

/**
 * @notice EZKL P value
 * @dev In order to prevent the verifier from accepting two versions of the same pubInput, n and the quantity (n + P), where n + P <= 2^256, we require that all instances are strictly less than P.
 * @dev The reason for this is that the assembly code of the verifier performs all arithmetic operations modulo P and as a consequence can't distinguish between n and n + P.
 */
uint256 constant ORDER =
    uint256(
        0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
    );

// NOTE(review): both constants are 0 here — these look like codegen placeholders
// that are rewritten with the real call counts when the contract is generated; confirm.
uint256 constant INPUT_CALLS = 0;

uint256 constant OUTPUT_CALLS = 0;

// Offset into the public instances array at which attested data begins.
uint8 public instanceOffset;
|
||||
|
||||
/**
 * @dev Initialize the contract with account calls the EZKL model will read from.
 * @param _contractAddresses - The calls to all the contracts EZKL reads storage from.
 * @param _callData - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from.
 * @param _decimals - Decimal counts for each call's returned value.
 * @param _scales - Log2 scale factors; stored as (1 << scale).
 * @param _instanceOffset - Offset into the public instances at which attested data begins.
 * @param _admin - Account authorized to update the admin and the account calls.
 */
constructor(
    address[] memory _contractAddresses,
    bytes[][] memory _callData,
    uint256[][] memory _decimals,
    uint[] memory _scales,
    uint8 _instanceOffset,
    address _admin
) {
    admin = _admin;
    // Convert each log2 scale into its multiplicative factor.
    uint256 numScales = _scales.length;
    for (uint256 idx = 0; idx < numScales; ++idx) {
        scales.push(1 << _scales[idx]);
    }
    populateAccountCalls(_contractAddresses, _callData, _decimals);
    instanceOffset = _instanceOffset;
}
|
||||
|
||||
/**
 * @dev Hand the admin role to a new account. Only callable by the current admin.
 * @param _admin - The new admin address; the zero address is rejected.
 */
function updateAdmin(address _admin) external {
    require(msg.sender == admin, "Only admin can update admin");
    // Bare require (no message) reverts with empty data, matching the
    // original `revert()` for the zero-address case.
    require(_admin != address(0));
    admin = _admin;
}
|
||||
|
||||
/**
 * @dev Replace the stored account calls. Only callable by the admin.
 * @param _contractAddresses - The contracts EZKL reads storage from.
 * @param _callData - The abi encoded calls to make to each contract.
 * @param _decimals - Decimal counts for each call's returned value.
 */
function updateAccountCalls(
    address[] memory _contractAddresses,
    bytes[][] memory _callData,
    uint256[][] memory _decimals
) external {
    // Admin gate, then delegate to the shared population routine.
    require(msg.sender == admin, "Only admin can update account calls");
    populateAccountCalls(_contractAddresses, _callData, _decimals);
}
|
||||
|
||||
/**
 * @dev Write the provided calls into the `accountCalls` storage array.
 * @param _contractAddresses - One entry per contract EZKL reads from.
 * @param _callData - Per-contract list of abi encoded view-function calls.
 * @param _decimals - Per-call decimal counts; stored as the divisor 10 ** decimals.
 */
function populateAccountCalls(
    address[] memory _contractAddresses,
    bytes[][] memory _callData,
    uint256[][] memory _decimals
) internal {
    // NOTE(review): this requires accountCalls.length to already equal the input
    // length, which a freshly-declared dynamic array cannot satisfy for non-empty
    // input — presumably codegen pre-sizes accountCalls; confirm against the
    // contract generator.
    require(
        _contractAddresses.length == _callData.length &&
            accountCalls.length == _contractAddresses.length,
        "Invalid input length"
    );
    require(
        _decimals.length == _contractAddresses.length,
        "Invalid number of decimals"
    );
    // fill in the accountCalls storage array
    uint counter = 0;
    for (uint256 i = 0; i < _contractAddresses.length; i++) {
        AccountCall storage accountCall = accountCalls[i];
        accountCall.contractAddress = _contractAddresses[i];
        accountCall.callCount = _callData[i].length;
        for (uint256 j = 0; j < _callData[i].length; j++) {
            accountCall.callData[j] = _callData[i][j];
            // Store the divisor itself, not the raw decimal count.
            accountCall.decimals[j] = 10 ** _decimals[i][j];
        }
        // count the total number of storage reads across all of the accounts
        counter += _callData[i].length;
    }
    // Total calls must match the (codegen-substituted) expected call count.
    require(
        counter == INPUT_CALLS + OUTPUT_CALLS,
        "Invalid number of calls"
    );
}
|
||||
|
||||
/**
 * @dev Computes floor(x * y / denominator) with full 512-bit intermediate
 * precision (OpenZeppelin-style mulDiv), so the product x * y may exceed
 * 2^256 without overflowing as long as the final result fits in 256 bits.
 * @param x - First factor.
 * @param y - Second factor.
 * @param denominator - Divisor; must be greater than the high product word.
 * @return result - floor(x * y / denominator).
 */
function mulDiv(
    uint256 x,
    uint256 y,
    uint256 denominator
) internal pure returns (uint256 result) {
    unchecked {
        // 512-bit product: prod1 * 2^256 + prod0 == x * y.
        uint256 prod0;
        uint256 prod1;
        assembly {
            // mulmod against 2^256 - 1 recovers the high bits of the product.
            let mm := mulmod(x, y, not(0))
            prod0 := mul(x, y)
            prod1 := sub(sub(mm, prod0), lt(mm, prod0))
        }

        // Fast path: product fits in 256 bits.
        if (prod1 == 0) {
            return prod0 / denominator;
        }

        // Otherwise the quotient only fits if denominator > prod1.
        require(denominator > prod1, "Math: mulDiv overflow");

        // Subtract the remainder from [prod1 prod0] so the division is exact.
        uint256 remainder;
        assembly {
            remainder := mulmod(x, y, denominator)
            prod1 := sub(prod1, gt(remainder, prod0))
            prod0 := sub(prod0, remainder)
        }

        // Factor powers of two out of the denominator.
        uint256 twos = denominator & (~denominator + 1);
        assembly {
            denominator := div(denominator, twos)
            prod0 := div(prod0, twos)
            // twos becomes 2^256 / twos to shift the high bits down.
            twos := add(div(sub(0, twos), twos), 1)
        }

        // Merge the shifted high bits into the low word.
        prod0 |= prod1 * twos;

        // Newton-Raphson iteration for the modular inverse of the (now odd)
        // denominator mod 2^256; each step doubles the correct bit count.
        uint256 inverse = (3 * denominator) ^ 2;

        inverse *= 2 - denominator * inverse; // inverse mod 2^8
        inverse *= 2 - denominator * inverse; // inverse mod 2^16
        inverse *= 2 - denominator * inverse; // inverse mod 2^32
        inverse *= 2 - denominator * inverse; // inverse mod 2^64
        inverse *= 2 - denominator * inverse; // inverse mod 2^128
        inverse *= 2 - denominator * inverse; // inverse mod 2^256

        // Exact division: multiply by the modular inverse.
        result = prod0 * inverse;
        return result;
    }
}
|
||||
/**
 * @dev Quantize the data returned from the account calls to the scale used by the EZKL model.
 * @param data - The data returned from the account calls.
 * @param decimals - The number of decimals the data returned from the account calls has (for floating point representation).
 * @param scale - The scale used to convert the floating point value into a fixed point value.
 */
function quantizeData(
    bytes memory data,
    uint256 decimals,
    uint256 scale
) internal pure returns (int256 quantized_data) {
    int x = abi.decode(data, (int256));
    // Work on the magnitude; reapply the sign at the end.
    // NOTE(review): x == type(int256).min would revert on negation (0.8
    // checked arithmetic) — presumably out of range for attested data.
    bool neg = x < 0;
    if (neg) x = -x;
    uint output = mulDiv(uint256(x), scale, decimals);
    // Round half up: if the discarded remainder is >= decimals / 2, bump by one.
    if (mulmod(uint256(x), scale, decimals) * 2 >= decimals) {
        output += 1;
    }
    quantized_data = neg ? -int256(output) : int256(output);
}
|
||||
/**
 * @dev Make a static call to the account to fetch the data that EZKL reads from.
 * @param target - The address of the account to make calls to.
 * @param data - The abi encoded function calls to make to the `contractAddress` that EZKL reads storage from.
 * @return The data returned from the account calls. (Must come from either a view or pure function. Will throw an error otherwise)
 */
function staticCall(
    address target,
    bytes memory data
) internal view returns (bytes memory) {
    (bool ok, bytes memory ret) = target.staticcall(data);
    // Guard clause: bubble a uniform error on any failed low-level call.
    if (!ok) {
        revert("Address: low-level call failed");
    }
    // An empty return from an address with no code means we called an EOA,
    // which "succeeds" at the EVM level but returned no data — reject it.
    if (ret.length == 0) {
        require(
            target.code.length > 0,
            "Address: call to non-contract"
        );
    }
    return ret;
}
|
||||
/**
 * @dev Convert the fixed point quantized data into a field element.
 * @param x - The quantized data.
 * @return field_element - The field element.
 */
function toFieldElement(
    int256 x
) internal pure returns (uint256 field_element) {
    // Shift x into the non-negative range by adding ORDER, then reduce.
    // The cast is safe: ORDER is ~2^254 and x lies in roughly [-2^127, 2^127],
    // so x + int(ORDER) is always positive.
    field_element = uint256(x + int(ORDER)) % ORDER;
}
|
||||
|
||||
/**
 * @dev Make the account calls to fetch the data that EZKL reads from and attest to the data.
 * Each fetched value is quantized, mapped to a field element, and compared against the
 * corresponding public instance (offset by `instanceOffset`).
 * @param instances - The public instances to the proof (the data in the proof that publicly accessible to the verifier).
 */
function attestData(uint256[] memory instances) internal view {
    require(
        instances.length >= INPUT_CALLS + OUTPUT_CALLS,
        "Invalid public inputs length"
    );
    uint256 _accountCount = accountCalls.length;
    // Running index over all calls across all accounts; also indexes `scales`
    // and (with instanceOffset) the public instances.
    uint counter = 0;
    for (uint8 i = 0; i < _accountCount; ++i) {
        address account = accountCalls[i].contractAddress;
        for (uint8 j = 0; j < accountCalls[i].callCount; j++) {
            bytes memory returnData = staticCall(
                account,
                accountCalls[i].callData[j]
            );
            uint256 scale = scales[counter];
            int256 quantized_data = quantizeData(
                returnData,
                accountCalls[i].decimals[j],
                scale
            );
            uint256 field_element = toFieldElement(quantized_data);
            require(
                field_element == instances[counter + instanceOffset],
                "Public input does not match"
            );
            counter++;
            // Removed a bad-merge artifact here that referenced undefined
            // identifiers (output, x, scalars, fieldElement) and duplicated
            // the instance check with stale indices — it could not compile.
        }
    }
}
|
||||
|
||||
@@ -1,20 +1,52 @@
|
||||
# EZKL Security Note: Quantization-Induced Model Backdoors
|
||||
# EZKL Security Note: Quantization-Activated Model Backdoors
|
||||
|
||||
> Note: this only affects a situation where a party separate from an application's developer has access to the model's weights and can modify them. This is a common scenario in adversarial machine learning research, but can be less common in real-world applications. If you're building your models in house and deploying them yourself, this is less of a concern. If you're building a permissionless system where anyone can submit models, this is more of a concern.
|
||||
## Model backdoors and provenance
|
||||
|
||||
Models processed through EZKL's quantization step can harbor backdoors that are dormant in the original full-precision model but activate during quantization. These backdoors force specific outputs when triggered, with impact varying by application.
|
||||
Machine learning models inherently suffer from robustness issues, which can lead to various
|
||||
kinds of attacks, from backdoors to evasion attacks. These vulnerabilities are a direct byproduct of how machine learning models learn and cannot be remediated.
|
||||
|
||||
Key Factors:
|
||||
We say a model has a backdoor whenever a specific attacker-chosen trigger in the input leads
|
||||
to the model misbehaving. For instance, if we have an image classifier discriminating cats from dogs, the ability to turn any image of a cat into an image classified as a dog by changing a specific pixel pattern constitutes a backdoor.
|
||||
|
||||
- Larger models increase attack feasibility through more parameter capacity
|
||||
- Smaller quantization scales facilitate attacks by allowing greater weight modifications
|
||||
- Rebase ratio of 1 enables exploitation of convolutional layer consistency
|
||||
Backdoors can be introduced using many different vectors. An attacker can introduce a
|
||||
backdoor using traditional security vulnerabilities. For instance, they could directly alter the file containing model weights or dynamically hack the Python code of the model. In addition, backdoors can be introduced by the training data through a process known as poisoning. In this case, an attacker adds malicious data points to the dataset before the model is trained so that the model learns to associate the backdoor trigger with the intended misbehavior.
|
||||
|
||||
Limitations:
|
||||
All these vectors constitute a whole range of provenance challenges, as any component of an
|
||||
AI system can virtually be an entrypoint for a backdoor. Although provenance is already a
|
||||
concern with traditional code, the issue is exacerbated with AI, as retraining a model is
|
||||
cost-prohibitive. It is thus impractical to translate the “recompile it yourself” thinking to AI.
|
||||
|
||||
- Attack effectiveness depends on calibration settings and internal rescaling operations.
|
||||
## Quantization activated backdoors
|
||||
|
||||
Backdoors are a generic concern in AI that is outside the scope of EZKL. However, EZKL may
|
||||
activate a specific subset of backdoors. Several academic papers have demonstrated the
|
||||
possibility, both in theory and in practice, of implanting undetectable and inactive backdoors in a full precision model that can be reactivated by quantization.
|
||||
|
||||
An external attacker may trick the user of an application running EZKL into loading a model
|
||||
containing a quantization backdoor. This backdoor is active in the resulting model and circuit but not in the full-precision model supplied to EZKL, compromising the integrity of the target application and the resulting proof.
|
||||
|
||||
### When is this a concern for me as a user?
|
||||
|
||||
Any untrusted component in your AI stack may be a backdoor vector. In practice, the most
|
||||
sensitive parts include:
|
||||
|
||||
- Datasets downloaded from the web or containing crowdsourced data
|
||||
- Models downloaded from the web even after finetuning
|
||||
- Untrusted software dependencies (well-known frameworks such as PyTorch can typically
|
||||
be considered trusted)
|
||||
- Any component loaded through an unsafe serialization format, such as Pickle.
|
||||
Because backdoors are inherent to ML and cannot be eliminated, reviewing the provenance of
|
||||
these sensitive components is especially important.
|
||||
|
||||
### Responsibilities of the user and EZKL
|
||||
|
||||
As EZKL cannot prevent backdoored models from being used, it is the responsibility of the user to review the provenance of all the components in their AI stack to ensure that no backdoor could have been implanted. EZKL shall not be held responsible for misleading prediction proofs resulting from using a backdoored model or for any harm caused to a system or its users due to a misbehaving model.
|
||||
|
||||
### Limitations:
|
||||
|
||||
- Attack effectiveness depends on calibration settings and internal rescaling operations.
|
||||
- Further research needed on backdoor persistence through witness/proof stages.
|
||||
- Can be mitigated by evaluating the quantized model (using `ezkl gen-witness`), rather than relying on the evaluation of the original model.
|
||||
- Can be mitigated by evaluating the quantized model (using `ezkl gen-witness`), rather than relying on the evaluation of the original model in pytorch or onnx-runtime as difference in evaluation could reveal a backdoor.
|
||||
|
||||
References:
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ezkl
|
||||
|
||||
project = 'ezkl'
|
||||
release = '0.0.0'
|
||||
release = '21.0.0'
|
||||
version = release
|
||||
|
||||
|
||||
|
||||
@@ -272,33 +272,21 @@
|
||||
"\n",
|
||||
"- For file data sources, the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit are stored. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
|
||||
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestable by EZKL (aka field elements :-D). \n",
|
||||
"Here is what the schema for an on-chain data source graph input file should look like:\n",
|
||||
"Here is what the schema for an on-chain data source graph input file should look like for a single call data source:\n",
|
||||
" \n",
|
||||
"```json\n",
|
||||
"{\n",
|
||||
" \"input_data\": {\n",
|
||||
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
|
||||
" \"calls\": [\n",
|
||||
" {\n",
|
||||
" \"call_data\": [\n",
|
||||
" [\n",
|
||||
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns a single on-chain data point (we only support uint256 returns for now)\n",
|
||||
" 7 // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
|
||||
" ],\n",
|
||||
" [\n",
|
||||
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000001\",\n",
|
||||
" 5\n",
|
||||
" ],\n",
|
||||
" [\n",
|
||||
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000002\",\n",
|
||||
" 5\n",
|
||||
" ]\n",
|
||||
" ],\n",
|
||||
" \"address\": \"5fbdb2315678afecb367f032d93f642f64180aa3\" // The address of the contract that we are calling to get the data. \n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
"}"
|
||||
" \"input_data\": {\n",
|
||||
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
|
||||
" \"calls\": {\n",
|
||||
" \"call_data\": \"1f3be514000000000000000000000000c6962004f452be9203591991d15f6b388e09e8d00000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000070000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns an array of on-chain data points we are attesting to. \n",
|
||||
" \"decimals\": 0, // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
|
||||
" \"address\": \"9A213F53334279C128C37DA962E5472eCD90554f\", // The address of the contract that we are calling to get the data. \n",
|
||||
" \"len\": 12 // The number of data points returned by the view function (the length of the array)\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -307,7 +295,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"await ezkl.setup_test_evm_witness(\n",
|
||||
"await ezkl.setup_test_evm_data(\n",
|
||||
" data_path,\n",
|
||||
" compiled_model_path,\n",
|
||||
" # we write the call data to the same file as the input data\n",
|
||||
|
||||
@@ -337,6 +337,7 @@
|
||||
"w3 = Web3(HTTPProvider(RPC_URL))\n",
|
||||
"\n",
|
||||
"def test_on_chain_data(res):\n",
|
||||
" print(f'poseidon_hash: {res[\"processed_outputs\"][\"poseidon_hash\"]}')\n",
|
||||
" # Step 0: Convert the tensor to a flat list\n",
|
||||
" data = [int(ezkl.felt_to_big_endian(res['processed_outputs']['poseidon_hash'][0]), 0)]\n",
|
||||
"\n",
|
||||
@@ -356,6 +357,9 @@
|
||||
" arr.push(_numbers[i]);\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" function getArr() public view returns (uint[] memory) {\n",
|
||||
" return arr;\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" '''\n",
|
||||
"\n",
|
||||
@@ -382,31 +386,30 @@
|
||||
" contract = w3.eth.contract(address=tx_receipt['contractAddress'], abi=abi)\n",
|
||||
"\n",
|
||||
" # Step 4: Interact with the contract\n",
|
||||
" calldata = []\n",
|
||||
" for i, _ in enumerate(data):\n",
|
||||
" call = contract.functions.arr(i).build_transaction()\n",
|
||||
" calldata.append((call['data'][2:], 0))\n",
|
||||
" calldata = contract.functions.getArr().build_transaction()['data'][2:]\n",
|
||||
"\n",
|
||||
" # Prepare the calls_to_account object\n",
|
||||
" # If you were calling view functions across multiple contracts,\n",
|
||||
" # you would have multiple entries in the calls_to_account array,\n",
|
||||
" # one for each contract.\n",
|
||||
" calls_to_account = [{\n",
|
||||
" decimals = [0] * len(data)\n",
|
||||
" call_to_account = {\n",
|
||||
" 'call_data': calldata,\n",
|
||||
" 'decimals': decimals,\n",
|
||||
" 'address': contract.address[2:], # remove the '0x' prefix\n",
|
||||
" }]\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" print(f'calls_to_account: {calls_to_account}')\n",
|
||||
" print(f'call_to_account: {call_to_account}')\n",
|
||||
"\n",
|
||||
" return calls_to_account\n",
|
||||
" return call_to_account\n",
|
||||
"\n",
|
||||
"# Now let's start the Anvil process. You don't need to do this if you are deploying to a non-local chain.\n",
|
||||
"start_anvil()\n",
|
||||
"\n",
|
||||
"# Now let's call our function, passing in the same input tensor we used to export the model 2 cells above.\n",
|
||||
"calls_to_account = test_on_chain_data(res)\n",
|
||||
"call_to_account = test_on_chain_data(res)\n",
|
||||
"\n",
|
||||
"data = dict(input_data = [data_array], output_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
|
||||
"data = dict(input_data = [data_array], output_data = {'rpc': RPC_URL, 'call': call_to_account })\n",
|
||||
"\n",
|
||||
"# Serialize on-chain data into file:\n",
|
||||
"json.dump(data, open(\"input.json\", 'w'))\n",
|
||||
@@ -634,7 +637,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "ezkl",
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -648,7 +651,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.7"
|
||||
"version": "3.11.5"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
|
||||
@@ -276,33 +276,21 @@
|
||||
"\n",
|
||||
"- For file data sources, the raw floating point values that eventually get quantized, converted into field elements and stored in `witness.json` to be consumed by the circuit are stored. The output data contains the expected floating point values returned as outputs from running your vanilla pytorch model on the given inputs.\n",
|
||||
"- For on chain data sources, the input_data field contains all the data necessary to read and format the on chain data into something digestible by EZKL (aka field elements :-D). \n",
|
||||
"Here is what the schema for an on-chain data source graph input file should look like:\n",
|
||||
"Here is what the schema for an on-chain data source graph input file should look like for a single call data source:\n",
|
||||
" \n",
|
||||
"```json\n",
|
||||
"{\n",
|
||||
" \"input_data\": {\n",
|
||||
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
|
||||
" \"calls\": [\n",
|
||||
" {\n",
|
||||
" \"call_data\": [\n",
|
||||
" [\n",
|
||||
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns a single on-chain data point (we only support uint256 returns for now)\n",
|
||||
" 7 // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
|
||||
" ],\n",
|
||||
" [\n",
|
||||
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000001\",\n",
|
||||
" 5\n",
|
||||
" ],\n",
|
||||
" [\n",
|
||||
" \"71e5ee5f0000000000000000000000000000000000000000000000000000000000000002\",\n",
|
||||
" 5\n",
|
||||
" ]\n",
|
||||
" ],\n",
|
||||
" \"address\": \"5fbdb2315678afecb367f032d93f642f64180aa3\" // The address of the contract that we are calling to get the data. \n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
"}"
|
||||
" \"input_data\": {\n",
|
||||
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
|
||||
" \"calls\": {\n",
|
||||
" \"call_data\": \"1f3be514000000000000000000000000c6962004f452be9203591991d15f6b388e09e8d00000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000070000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns an array of on-chain data points we are attesting to. \n",
|
||||
" \"decimals\": 0, // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
|
||||
" \"address\": \"9A213F53334279C128C37DA962E5472eCD90554f\", // The address of the contract that we are calling to get the data. \n",
|
||||
" \"len\": 3 // The number of data points returned by the view function (the length of the array)\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -311,7 +299,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"await ezkl.setup_test_evm_witness(\n",
|
||||
"await ezkl.setup_test_evm_data(\n",
|
||||
" data_path,\n",
|
||||
" compiled_model_path,\n",
|
||||
" # we write the call data to the same file as the input data\n",
|
||||
@@ -337,7 +325,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = await ezkl.get_srs( settings_path)\n"
|
||||
"res = await ezkl.get_srs( settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -348,27 +336,6 @@
|
||||
"We now need to generate the circuit witness. These are the model outputs (and any hashes) that are generated when feeding the previously generated `input.json` through the circuit / model. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!export RUST_BACKTRACE=1\n",
|
||||
"\n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -391,6 +358,27 @@
|
||||
"assert os.path.isfile(settings_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!export RUST_BACKTRACE=1\n",
|
||||
"\n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here we setup verifying and proving keys for the circuit. As the name suggests the proving key is needed for ... proving and the verifying key is needed for ... verifying. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
@@ -581,7 +569,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "ezkl",
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -595,7 +583,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.13"
|
||||
"version": "3.11.5"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
|
||||
@@ -220,15 +220,6 @@
|
||||
"Check that the generated verifiers are identical for all models."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"start_anvil()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -60,7 +60,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -94,7 +94,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -134,7 +134,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -183,7 +183,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -201,7 +201,7 @@
|
||||
"run_args.input_visibility = \"public\"\n",
|
||||
"run_args.param_visibility = \"private\"\n",
|
||||
"run_args.output_visibility = \"public\"\n",
|
||||
"run_args.decomp_legs=6\n",
|
||||
"run_args.decomp_legs=5\n",
|
||||
"run_args.num_inner_cols = 1\n",
|
||||
"run_args.variables = [(\"batch_size\", 1)]"
|
||||
]
|
||||
@@ -270,7 +270,7 @@
|
||||
"{\n",
|
||||
" \"input_data\": {\n",
|
||||
" \"rpc\": \"http://localhost:3030\", // The rpc endpoint of the chain you are deploying your verifier to\n",
|
||||
" \"calls\": {\n",
|
||||
" \"call\": {\n",
|
||||
" \"call_data\": \"1f3be514000000000000000000000000c6962004f452be9203591991d15f6b388e09e8d00000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000b000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000000000000000000000000000000000000000000009000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000070000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000000500000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000\", // The abi encoded call data to a view function that returns an array of on-chain data points we are attesting to. \n",
|
||||
" \"decimals\": 0, // The number of decimal places of the large uint256 value. This is our way of representing large wei values as floating points on chain, since the evm only natively supports integer values.\n",
|
||||
" \"address\": \"9A213F53334279C128C37DA962E5472eCD90554f\", // The address of the contract that we are calling to get the data. \n",
|
||||
@@ -295,7 +295,6 @@
|
||||
"import torch\n",
|
||||
"import requests\n",
|
||||
"\n",
|
||||
"# This function counts the decimal places of a floating point number\n",
|
||||
"def count_decimal_places(num):\n",
|
||||
" num_str = str(num)\n",
|
||||
" if '.' in num_str:\n",
|
||||
@@ -303,69 +302,28 @@
|
||||
" else:\n",
|
||||
" return 0\n",
|
||||
"\n",
|
||||
"# setup web3 instance\n",
|
||||
"w3 = Web3(HTTPProvider(RPC_URL)) \n",
|
||||
"\n",
|
||||
"def set_next_block_timestamp(anvil_url, timestamp):\n",
|
||||
" # Send the JSON-RPC request to Anvil\n",
|
||||
" payload = {\n",
|
||||
" \"jsonrpc\": \"2.0\",\n",
|
||||
" \"id\": 1,\n",
|
||||
" \"method\": \"evm_setNextBlockTimestamp\",\n",
|
||||
" \"params\": [timestamp]\n",
|
||||
" }\n",
|
||||
" response = requests.post(anvil_url, json=payload)\n",
|
||||
" if response.status_code == 200:\n",
|
||||
" print(f\"Next block timestamp set to: {timestamp}\")\n",
|
||||
" else:\n",
|
||||
" print(f\"Failed to set next block timestamp: {response.text}\")\n",
|
||||
"\n",
|
||||
"def on_chain_data(tensor):\n",
|
||||
" # Step 0: Convert the tensor to a flat list\n",
|
||||
" data = tensor.view(-1).tolist()\n",
|
||||
"\n",
|
||||
" # Step 1: Prepare the calldata\n",
|
||||
" secondsAgo = [len(data) - 1 - i for i in range(len(data))]\n",
|
||||
"\n",
|
||||
" # Step 2: Prepare and compile the contract UniTickAttestor contract\n",
|
||||
" contract_source_code = '''\n",
|
||||
" // SPDX-License-Identifier: MIT\n",
|
||||
" pragma solidity ^0.8.20;\n",
|
||||
"\n",
|
||||
" /// @title Pool state that is not stored\n",
|
||||
" /// @notice Contains view functions to provide information about the pool that is computed rather than stored on the\n",
|
||||
" /// blockchain. The functions here may have variable gas costs.\n",
|
||||
" interface IUniswapV3PoolDerivedState {\n",
|
||||
" /// @notice Returns the cumulative tick and liquidity as of each timestamp `secondsAgo` from the current block timestamp\n",
|
||||
" /// @dev To get a time weighted average tick or liquidity-in-range, you must call this with two values, one representing\n",
|
||||
" /// the beginning of the period and another for the end of the period. E.g., to get the last hour time-weighted average tick,\n",
|
||||
" /// you must call it with secondsAgos = [3600, 0].\n",
|
||||
" /// log base sqrt(1.0001) of token1 / token0. The TickMath library can be used to go from a tick value to a ratio.\n",
|
||||
" /// @dev The time weighted average tick represents the geometric time weighted average price of the pool, in\n",
|
||||
" /// @param secondsAgos From how long ago each cumulative tick and liquidity value should be returned\n",
|
||||
" /// @return tickCumulatives Cumulative tick values as of each `secondsAgos` from the current block timestamp\n",
|
||||
" /// @return secondsPerLiquidityCumulativeX128s Cumulative seconds per liquidity-in-range value as of each `secondsAgos` from the current block\n",
|
||||
" /// timestamp\n",
|
||||
" function observe(\n",
|
||||
" uint32[] calldata secondsAgos\n",
|
||||
" )\n",
|
||||
" external\n",
|
||||
" view\n",
|
||||
" returns (\n",
|
||||
" int56[] memory tickCumulatives,\n",
|
||||
" uint160[] memory secondsPerLiquidityCumulativeX128s\n",
|
||||
" );\n",
|
||||
" ) external view returns (\n",
|
||||
" int56[] memory tickCumulatives,\n",
|
||||
" uint160[] memory secondsPerLiquidityCumulativeX128s\n",
|
||||
" );\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" /// @title Uniswap Wrapper around `pool.observe` that stores the parameters for fetching and then attesting to historical data\n",
|
||||
" /// @notice Provides functions to integrate with V3 pool oracle\n",
|
||||
" contract UniTickAttestor {\n",
|
||||
" /**\n",
|
||||
" * @notice Calculates time-weighted means of tick and liquidity for a given Uniswap V3 pool\n",
|
||||
" * @param pool Address of the pool that we want to observe\n",
|
||||
" * @param secondsAgo Number of seconds in the past from which to calculate the time-weighted means\n",
|
||||
" * @return tickCumulatives The cumulative tick values as of each `secondsAgo` from the current block timestamp\n",
|
||||
" */\n",
|
||||
" int256[] private cachedTicks;\n",
|
||||
"\n",
|
||||
" function consult(\n",
|
||||
" IUniswapV3PoolDerivedState pool,\n",
|
||||
" uint32[] memory secondsAgo\n",
|
||||
@@ -376,6 +334,21 @@
|
||||
" tickCumulatives[i] = int256(_ticks[i]);\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" function cache_price(\n",
|
||||
" IUniswapV3PoolDerivedState pool,\n",
|
||||
" uint32[] memory secondsAgo\n",
|
||||
" ) public {\n",
|
||||
" (int56[] memory _ticks,) = pool.observe(secondsAgo);\n",
|
||||
" cachedTicks = new int256[](_ticks.length);\n",
|
||||
" for (uint256 i = 0; i < _ticks.length; i++) {\n",
|
||||
" cachedTicks[i] = int256(_ticks[i]);\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" function readPriceCache() public view returns (int256[] memory) {\n",
|
||||
" return cachedTicks;\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" '''\n",
|
||||
"\n",
|
||||
@@ -385,69 +358,44 @@
|
||||
" \"settings\": {\"outputSelection\": {\"*\": {\"*\": [\"metadata\", \"evm.bytecode\", \"abi\"]}}}\n",
|
||||
" })\n",
|
||||
"\n",
|
||||
" # Get bytecode\n",
|
||||
" bytecode = compiled_sol['contracts']['UniTickAttestor.sol']['UniTickAttestor']['evm']['bytecode']['object']\n",
|
||||
"\n",
|
||||
" # Get ABI\n",
|
||||
" # In production if you are reading from really large contracts you can just use\n",
|
||||
" # a stripped down version of the ABI of the contract you are calling, containing only the view functions you will fetch data from.\n",
|
||||
" abi = json.loads(compiled_sol['contracts']['UniTickAttestor.sol']['UniTickAttestor']['metadata'])['output']['abi']\n",
|
||||
"\n",
|
||||
" # Step 3: Deploy the contract\n",
|
||||
" # Deploy contract\n",
|
||||
" UniTickAttestor = w3.eth.contract(abi=abi, bytecode=bytecode)\n",
|
||||
" tx_hash = UniTickAttestor.constructor().transact()\n",
|
||||
" tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)\n",
|
||||
" # If you are deploying to production you can skip the 3 lines of code above and just instantiate the contract like this,\n",
|
||||
" # passing the address and abi of the contract you are fetching data from.\n",
|
||||
" contract = w3.eth.contract(address=tx_receipt['contractAddress'], abi=abi)\n",
|
||||
"\n",
|
||||
" # Step 4: Interact with the contract\n",
|
||||
" call = contract.functions.consult(\n",
|
||||
" # Address of the UniV3 usdc-weth pool 0.005 fee\n",
|
||||
" # Step 4: Store data via cache_price transaction\n",
|
||||
" tx_hash = contract.functions.cache_price(\n",
|
||||
" \"0xC6962004f452bE9203591991D15f6b388e09E8D0\",\n",
|
||||
" secondsAgo\n",
|
||||
" ).build_transaction()\n",
|
||||
" result = contract.functions.consult(\n",
|
||||
" # Address of the UniV3 usdc-weth pool 0.005 fee\n",
|
||||
" \"0xC6962004f452bE9203591991D15f6b388e09E8D0\",\n",
|
||||
" secondsAgo\n",
|
||||
" ).call()\n",
|
||||
" \n",
|
||||
" print(f'result: {result}')\n",
|
||||
" ).transact()\n",
|
||||
" tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)\n",
|
||||
"\n",
|
||||
" # Step 5: Prepare calldata for readPriceCache\n",
|
||||
" call = contract.functions.readPriceCache().build_transaction()\n",
|
||||
" calldata = call['data'][2:]\n",
|
||||
"\n",
|
||||
" time_stamp = w3.eth.get_block('latest')['timestamp']\n",
|
||||
" # Get stored data\n",
|
||||
" result = contract.functions.readPriceCache().call()\n",
|
||||
" print(f'Cached ticks: {result}')\n",
|
||||
"\n",
|
||||
" print(f'time_stamp: {time_stamp}')\n",
|
||||
" decimals = [0] * len(data)\n",
|
||||
"\n",
|
||||
" # Set the next block timestamp using the fetched time_stamp\n",
|
||||
" set_next_block_timestamp(RPC_URL, time_stamp)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # Prepare the calls_to_account object\n",
|
||||
" # If you were calling view functions across multiple contracts,\n",
|
||||
" # you would have multiple entries in the calls_to_account array,\n",
|
||||
" # one for each contract.\n",
|
||||
" call_to_account = {\n",
|
||||
" 'call_data': calldata,\n",
|
||||
" 'decimals': 0,\n",
|
||||
" 'address': contract.address[2:], # remove the '0x' prefix\n",
|
||||
" 'len': len(data),\n",
|
||||
" 'decimals': decimals,\n",
|
||||
" 'address': contract.address[2:],\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" print(f'call_to_account: {call_to_account}')\n",
|
||||
"\n",
|
||||
" return call_to_account\n",
|
||||
"\n",
|
||||
"# Now let's start the Anvil process. You don't need to do this if you are deploying to a non-local chain.\n",
|
||||
"start_anvil()\n",
|
||||
"call_to_account = on_chain_data(x)\n",
|
||||
"\n",
|
||||
"# Now let's call our function, passing in the same input tensor we used to export the model 2 cells above.\n",
|
||||
"calls_to_account = on_chain_data(x)\n",
|
||||
"\n",
|
||||
"data = dict(input_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
|
||||
"\n",
|
||||
"# Serialize on-chain data into file:\n",
|
||||
"data = dict(input_data = {'rpc': RPC_URL, 'call': call_to_account })\n",
|
||||
"json.dump(data, open(\"input.json\", 'w'))"
|
||||
]
|
||||
},
|
||||
@@ -692,34 +640,7 @@
|
||||
"source": [
|
||||
"# !export RUST_BACKTRACE=1\n",
|
||||
"\n",
|
||||
"calls_to_account = on_chain_data(x)\n",
|
||||
"\n",
|
||||
"data = dict(input_data = {'rpc': RPC_URL, 'calls': calls_to_account })\n",
|
||||
"\n",
|
||||
"# Serialize on-chain data into file:\n",
|
||||
"json.dump(data, open(\"input.json\", 'w'))\n",
|
||||
"\n",
|
||||
"# setup web3 instance\n",
|
||||
"w3 = Web3(HTTPProvider(RPC_URL)) \n",
|
||||
"\n",
|
||||
"time_stamp = w3.eth.get_block('latest')['timestamp']\n",
|
||||
"\n",
|
||||
"print(f'time_stamp: {time_stamp}')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
||||
"\n",
|
||||
"res = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_model_path,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \"single\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"print(res)\n",
|
||||
"# print(res)\n",
|
||||
"assert os.path.isfile(proof_path)\n",
|
||||
"# read the verifier address\n",
|
||||
"addr_verifier = None\n",
|
||||
|
||||
@@ -246,7 +246,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ezkl.setup_test_evm_witness(\n",
|
||||
"ezkl.setup_test_evm_data(\n",
|
||||
" data_path,\n",
|
||||
" compiled_model_path,\n",
|
||||
" # we write the call data to the same file as the input data\n",
|
||||
@@ -374,14 +374,6 @@
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cc888848",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -525,7 +517,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@@ -539,7 +531,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
4
ezkl.pyi
4
ezkl.pyi
@@ -706,9 +706,9 @@ def setup_aggregate(sample_snarks:typing.Sequence[str | os.PathLike | pathlib.Pa
|
||||
"""
|
||||
...
|
||||
|
||||
def setup_test_evm_witness(data_path:str | os.PathLike | pathlib.Path,compiled_circuit_path:str | os.PathLike | pathlib.Path,test_data:str | os.PathLike | pathlib.Path,input_source:PyTestDataSource,output_source:PyTestDataSource,rpc_url:typing.Optional[str]) -> typing.Any:
|
||||
def setup_test_evm_data(data_path:str | os.PathLike | pathlib.Path,compiled_circuit_path:str | os.PathLike | pathlib.Path,test_data:str | os.PathLike | pathlib.Path,input_source:PyTestDataSource,output_source:PyTestDataSource,rpc_url:typing.Optional[str]) -> typing.Any:
|
||||
r"""
|
||||
Setup test evm witness
|
||||
Setup test evm data
|
||||
|
||||
Arguments
|
||||
---------
|
||||
|
||||
@@ -1819,10 +1819,10 @@ fn create_evm_data_attestation(
|
||||
test_data,
|
||||
input_source,
|
||||
output_source,
|
||||
rpc_url=None,
|
||||
rpc_url=None
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn setup_test_evm_witness(
|
||||
fn setup_test_evm_data(
|
||||
py: Python,
|
||||
data_path: String,
|
||||
compiled_circuit_path: PathBuf,
|
||||
@@ -1832,7 +1832,7 @@ fn setup_test_evm_witness(
|
||||
rpc_url: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::setup_test_evm_witness(
|
||||
crate::execute::setup_test_evm_data(
|
||||
data_path,
|
||||
compiled_circuit_path,
|
||||
test_data,
|
||||
@@ -1842,7 +1842,7 @@ fn setup_test_evm_witness(
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run setup_test_evm_witness: {}", e);
|
||||
let err_str = format!("Failed to run setup_test_evm_data: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
|
||||
@@ -2107,7 +2107,7 @@ fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(verify_evm, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(setup_test_evm_witness, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(setup_test_evm_data, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_evm_verifier_aggr, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(create_evm_data_attestation, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(encode_evm_calldata, m)?)?;
|
||||
|
||||
@@ -21,7 +21,10 @@ pub enum BaseOp {
|
||||
|
||||
/// Matches a [BaseOp] to an operation over inputs
|
||||
impl BaseOp {
|
||||
/// forward func
|
||||
/// forward func for non-accumulating operations
|
||||
/// # Panics
|
||||
/// Panics if called on an accumulating operation
|
||||
/// # Examples
|
||||
pub fn nonaccum_f<
|
||||
T: TensorType + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Neg<Output = T>,
|
||||
>(
|
||||
@@ -37,7 +40,9 @@ impl BaseOp {
|
||||
}
|
||||
}
|
||||
|
||||
/// forward func
|
||||
/// forward func for accumulating operations
|
||||
/// # Panics
|
||||
/// Panics if called on a non-accumulating operation
|
||||
pub fn accum_f<
|
||||
T: TensorType + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Neg<Output = T>,
|
||||
>(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -159,6 +159,8 @@ impl std::str::FromStr for InputType {
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl From<DatumType> for InputType {
|
||||
/// # Panics
|
||||
/// Panics if the datum type is not supported
|
||||
fn from(datum_type: DatumType) -> Self {
|
||||
match datum_type {
|
||||
DatumType::Bool => InputType::Bool,
|
||||
@@ -317,13 +319,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
|
||||
}
|
||||
|
||||
impl<
|
||||
F: PrimeField
|
||||
+ TensorType
|
||||
+ PartialOrd
|
||||
+ std::hash::Hash
|
||||
+ Serialize
|
||||
+ for<'de> Deserialize<'de>,
|
||||
> Op<F> for Constant<F>
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
|
||||
> Op<F> for Constant<F>
|
||||
{
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
|
||||
@@ -49,7 +49,7 @@ pub enum PolyOp {
|
||||
},
|
||||
Downsample {
|
||||
axis: usize,
|
||||
stride: usize,
|
||||
stride: isize,
|
||||
modulo: usize,
|
||||
},
|
||||
DeConv {
|
||||
@@ -108,13 +108,8 @@ pub enum PolyOp {
|
||||
}
|
||||
|
||||
impl<
|
||||
F: PrimeField
|
||||
+ TensorType
|
||||
+ PartialOrd
|
||||
+ std::hash::Hash
|
||||
+ Serialize
|
||||
+ for<'de> Deserialize<'de>,
|
||||
> Op<F> for PolyOp
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
|
||||
> Op<F> for PolyOp
|
||||
{
|
||||
/// Returns a reference to the Any trait.
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
@@ -188,7 +183,8 @@ impl<
|
||||
} => {
|
||||
format!(
|
||||
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={}, data_format={:?}, kernel_format={:?})",
|
||||
stride, padding, output_padding, group, data_format, kernel_format)
|
||||
stride, padding, output_padding, group, data_format, kernel_format
|
||||
)
|
||||
}
|
||||
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
|
||||
PolyOp::Slice { axis, start, end } => {
|
||||
|
||||
@@ -401,7 +401,6 @@ pub enum Commands {
|
||||
/// Generates the witness from an input file.
|
||||
GenWitness {
|
||||
/// The path to the .json data file
|
||||
/// You can also pass the input data as a string, eg. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
|
||||
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
|
||||
data: Option<String>,
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
@@ -435,7 +434,7 @@ pub enum Commands {
|
||||
/// The path to the .onnx model file
|
||||
#[arg(short = 'M', long, default_value = DEFAULT_MODEL, value_hint = clap::ValueHint::FilePath)]
|
||||
model: Option<PathBuf>,
|
||||
/// The path to the .json data file to output
|
||||
/// The path to the .json data file
|
||||
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
|
||||
data: Option<PathBuf>,
|
||||
/// Hand-written parser for graph variables, eg. batch_size=1
|
||||
@@ -448,7 +447,6 @@ pub enum Commands {
|
||||
/// Calibrates the proving scale, lookup bits and logrows from a circuit settings file.
|
||||
CalibrateSettings {
|
||||
/// The path to the .json calibration data file.
|
||||
/// You can also pass the input data as a string, eg. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
|
||||
#[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE, value_hint = clap::ValueHint::FilePath)]
|
||||
data: Option<String>,
|
||||
/// The path to the .onnx model file
|
||||
@@ -633,7 +631,6 @@ pub enum Commands {
|
||||
#[command(arg_required_else_help = true)]
|
||||
SetupTestEvmData {
|
||||
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
/// You can also pass the input data as a string, eg. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
|
||||
#[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
|
||||
data: Option<String>,
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
@@ -654,20 +651,6 @@ pub enum Commands {
|
||||
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
|
||||
output_source: TestDataSource,
|
||||
},
|
||||
/// The Data Attestation Verifier contract stores the account calls to fetch data to feed into ezkl. This call data can be updated by an admin account. This tests that admin account is able to update this call data.
|
||||
#[command(arg_required_else_help = true)]
|
||||
TestUpdateAccountCalls {
|
||||
/// The path to the verifier contract's address
|
||||
#[arg(long, value_hint = clap::ValueHint::Other)]
|
||||
addr: H160Flag,
|
||||
/// The path to the .json data file.
|
||||
/// You can also pass the input data as a string, eg. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
|
||||
#[arg(short = 'D', long, value_hint = clap::ValueHint::FilePath)]
|
||||
data: Option<String>,
|
||||
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
|
||||
rpc_url: Option<String>,
|
||||
},
|
||||
/// Swaps the positions in the transcript that correspond to commitments
|
||||
SwapProofCommitments {
|
||||
/// The path to the proof file
|
||||
@@ -875,7 +858,6 @@ pub enum Commands {
|
||||
#[command(name = "deploy-evm-da")]
|
||||
DeployEvmDataAttestation {
|
||||
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
/// You can also pass the input data as a string, eg. --data '{"input_data": [1.0,2.0,3.0]}' directly and skip the file
|
||||
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
|
||||
data: Option<String>,
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
|
||||
690
src/eth.rs
690
src/eth.rs
File diff suppressed because one or more lines are too long
@@ -2,13 +2,10 @@ use crate::EZKL_BUF_CAPACITY;
|
||||
use crate::circuit::CheckMode;
|
||||
use crate::circuit::region::RegionSettings;
|
||||
use crate::commands::CalibrationTarget;
|
||||
use crate::eth::{
|
||||
deploy_contract_via_solidity, deploy_da_verifier_via_solidity, fix_da_multi_sol,
|
||||
fix_da_single_sol,
|
||||
};
|
||||
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity, fix_da_sol};
|
||||
#[allow(unused_imports)]
|
||||
use crate::eth::{get_contract_artifacts, verify_proof_via_solidity};
|
||||
use crate::graph::input::{Calls, GraphData};
|
||||
use crate::graph::input::GraphData;
|
||||
use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model};
|
||||
use crate::graph::{TestDataSource, TestSources};
|
||||
use crate::pfsys::evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript};
|
||||
@@ -301,7 +298,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
input_source,
|
||||
output_source,
|
||||
} => {
|
||||
setup_test_evm_witness(
|
||||
setup_test_evm_data(
|
||||
data.unwrap_or(DEFAULT_DATA.into()),
|
||||
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
|
||||
test_data,
|
||||
@@ -311,11 +308,6 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
|
||||
)
|
||||
.await
|
||||
}
|
||||
Commands::TestUpdateAccountCalls {
|
||||
addr,
|
||||
data,
|
||||
rpc_url,
|
||||
} => test_update_account_calls(addr, data.unwrap_or(DEFAULT_DATA.into()), rpc_url).await,
|
||||
Commands::SwapProofCommitments {
|
||||
proof_path,
|
||||
witness_path,
|
||||
@@ -1540,50 +1532,28 @@ pub(crate) async fn create_evm_data_attestation(
|
||||
let data =
|
||||
GraphData::from_str(&input).unwrap_or_else(|_| GraphData::new(DataSource::File(vec![])));
|
||||
|
||||
debug!("data attestation data: {:?}", data);
|
||||
|
||||
// The number of input and output instances we attest to for the single call data attestation
|
||||
let mut input_len = None;
|
||||
let mut output_len = None;
|
||||
|
||||
let output_data = if let Some(DataSource::OnChain(source)) = data.output_data {
|
||||
if let Some(DataSource::OnChain(source)) = data.output_data {
|
||||
if visibility.output.is_private() {
|
||||
return Err("private output data on chain is not supported on chain".into());
|
||||
}
|
||||
let mut on_chain_output_data = vec![];
|
||||
match source.calls {
|
||||
Calls::Multiple(calls) => {
|
||||
for call in calls {
|
||||
on_chain_output_data.push(call);
|
||||
}
|
||||
}
|
||||
Calls::Single(call) => {
|
||||
output_len = Some(call.len);
|
||||
}
|
||||
}
|
||||
Some(on_chain_output_data)
|
||||
} else {
|
||||
None
|
||||
output_len = Some(source.call.decimals.len());
|
||||
};
|
||||
|
||||
let input_data = if let DataSource::OnChain(source) = data.input_data {
|
||||
if let DataSource::OnChain(source) = data.input_data {
|
||||
if visibility.input.is_private() {
|
||||
return Err("private input data on chain is not supported on chain".into());
|
||||
}
|
||||
let mut on_chain_input_data = vec![];
|
||||
match source.calls {
|
||||
Calls::Multiple(calls) => {
|
||||
for call in calls {
|
||||
on_chain_input_data.push(call);
|
||||
}
|
||||
}
|
||||
Calls::Single(call) => {
|
||||
input_len = Some(call.len);
|
||||
}
|
||||
}
|
||||
Some(on_chain_input_data)
|
||||
} else {
|
||||
None
|
||||
input_len = Some(source.call.decimals.len());
|
||||
};
|
||||
|
||||
// If both model inputs and outputs are attested to then we
|
||||
|
||||
// Read the settings file. Look if either the run_ars.input_visibility, run_args.output_visibility or run_args.param_visibility is KZGCommit
|
||||
// if so, then we need to load the witness
|
||||
|
||||
@@ -1604,24 +1574,16 @@ pub(crate) async fn create_evm_data_attestation(
|
||||
None
|
||||
};
|
||||
|
||||
// if either input_len or output_len is Some then we are in the single call data attestation mode
|
||||
if input_len.is_some() || output_len.is_some() {
|
||||
let output = fix_da_single_sol(input_len, output_len)?;
|
||||
let mut f = File::create(sol_code_path.clone())?;
|
||||
let _ = f.write(output.as_bytes());
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestationSingle", 0).await?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
|
||||
} else {
|
||||
let output = fix_da_multi_sol(input_data, output_data, commitment_bytes)?;
|
||||
let mut f = File::create(sol_code_path.clone())?;
|
||||
let _ = f.write(output.as_bytes());
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestationMulti", 0).await?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
|
||||
}
|
||||
let output: String = fix_da_sol(
|
||||
commitment_bytes,
|
||||
input_len.is_none() && output_len.is_none(),
|
||||
)?;
|
||||
let mut f = File::create(sol_code_path.clone())?;
|
||||
let _ = f.write(output.as_bytes());
|
||||
// fetch abi of the contract
|
||||
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestation", 0).await?;
|
||||
// save abi to file
|
||||
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
|
||||
|
||||
Ok(String::new())
|
||||
}
|
||||
@@ -1869,7 +1831,7 @@ pub(crate) fn setup(
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
pub(crate) async fn setup_test_evm_witness(
|
||||
pub(crate) async fn setup_test_evm_data(
|
||||
data_path: String,
|
||||
compiled_circuit_path: PathBuf,
|
||||
test_data: PathBuf,
|
||||
@@ -1905,17 +1867,6 @@ pub(crate) async fn setup_test_evm_witness(
|
||||
}
|
||||
|
||||
use crate::pfsys::ProofType;
|
||||
pub(crate) async fn test_update_account_calls(
|
||||
addr: H160Flag,
|
||||
data: String,
|
||||
rpc_url: Option<String>,
|
||||
) -> Result<String, EZKLError> {
|
||||
use crate::eth::update_account_calls;
|
||||
|
||||
update_account_calls(addr.into(), data, rpc_url.as_deref()).await?;
|
||||
|
||||
Ok(String::new())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn prove(
|
||||
|
||||
@@ -4,8 +4,6 @@ use crate::circuit::InputType;
|
||||
use crate::fieldutils::integer_rep_to_felt;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::graph::postgres::Client;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::tensor::Tensor;
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
@@ -168,85 +166,26 @@ impl<'de> Deserialize<'de> for FileSourceInner {
|
||||
/// Organized as a vector of vectors where each inner vector represents a row/entry
|
||||
pub type FileSource = Vec<Vec<FileSourceInner>>;
|
||||
|
||||
/// Represents different types of calls for fetching on-chain data
|
||||
#[derive(Clone, Debug, PartialOrd, PartialEq)]
|
||||
pub enum Calls {
|
||||
/// Multiple calls to different accounts, each returning individual values
|
||||
Multiple(Vec<CallsToAccount>),
|
||||
/// Single call returning an array of values
|
||||
Single(CallToAccount),
|
||||
}
|
||||
/// Represents which parts of the model (input/output) are attested to on-chain
|
||||
pub type InputOutput = (bool, bool);
|
||||
|
||||
impl Default for Calls {
|
||||
fn default() -> Self {
|
||||
Calls::Multiple(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Calls {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
match self {
|
||||
Calls::Single(data) => data.serialize(serializer),
|
||||
Calls::Multiple(data) => data.serialize(serializer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT
|
||||
// UNTAGGED ENUMS WONT WORK :( as highlighted here:
|
||||
impl<'de> Deserialize<'de> for Calls {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let this_json: Box<serde_json::value::RawValue> = Deserialize::deserialize(deserializer)?;
|
||||
let multiple_try: Result<Vec<CallsToAccount>, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = multiple_try {
|
||||
return Ok(Calls::Multiple(t));
|
||||
}
|
||||
let single_try: Result<CallToAccount, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = single_try {
|
||||
return Ok(Calls::Single(t));
|
||||
}
|
||||
|
||||
Err(serde::de::Error::custom("failed to deserialize Calls"))
|
||||
}
|
||||
}
|
||||
/// Configuration for accessing on-chain data sources
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct OnChainSource {
|
||||
/// Call specifications for fetching data
|
||||
pub calls: Calls,
|
||||
pub call: CallToAccount,
|
||||
/// RPC endpoint URL for accessing the chain
|
||||
pub rpc: RPCUrl,
|
||||
}
|
||||
|
||||
impl OnChainSource {
|
||||
/// Creates a new OnChainSource with multiple calls
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `calls` - Vector of call specifications
|
||||
/// * `rpc` - RPC endpoint URL
|
||||
pub fn new_multiple(calls: Vec<CallsToAccount>, rpc: RPCUrl) -> Self {
|
||||
OnChainSource {
|
||||
calls: Calls::Multiple(calls),
|
||||
rpc,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new OnChainSource with a single call
|
||||
/// Creates a new OnChainSource
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `call` - Call specification
|
||||
/// * `rpc` - RPC endpoint URL
|
||||
pub fn new_single(call: CallToAccount, rpc: RPCUrl) -> Self {
|
||||
OnChainSource {
|
||||
calls: Calls::Single(call),
|
||||
rpc,
|
||||
}
|
||||
pub fn new(call: CallToAccount, rpc: RPCUrl) -> Self {
|
||||
OnChainSource { call, rpc }
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
@@ -263,11 +202,8 @@ impl OnChainSource {
|
||||
scales: Vec<crate::Scale>,
|
||||
mut shapes: Vec<Vec<usize>>,
|
||||
rpc: Option<&str>,
|
||||
) -> Result<(Vec<Tensor<Fp>>, Self), GraphError> {
|
||||
use crate::eth::{
|
||||
evm_quantize_multi, read_on_chain_inputs_multi, test_on_chain_data,
|
||||
DEFAULT_ANVIL_ENDPOINT,
|
||||
};
|
||||
) -> Result<Self, GraphError> {
|
||||
use crate::eth::{read_on_chain_inputs, test_on_chain_data, DEFAULT_ANVIL_ENDPOINT};
|
||||
use log::debug;
|
||||
|
||||
// Set up local anvil instance for reading on-chain data
|
||||
@@ -281,46 +217,15 @@ impl OnChainSource {
|
||||
shapes[idx] = vec![i.len()];
|
||||
}
|
||||
}
|
||||
|
||||
let calls_to_accounts = test_on_chain_data(client.clone(), data).await?;
|
||||
debug!("Calls to accounts: {:?}", calls_to_accounts);
|
||||
let inputs =
|
||||
read_on_chain_inputs_multi(client.clone(), client_address, &calls_to_accounts).await?;
|
||||
debug!("Inputs: {:?}", inputs);
|
||||
|
||||
let mut quantized_evm_inputs = vec![];
|
||||
|
||||
let mut prev = 0;
|
||||
for (idx, i) in data.iter().enumerate() {
|
||||
quantized_evm_inputs.extend(
|
||||
evm_quantize_multi(
|
||||
client.clone(),
|
||||
vec![scales[idx]; i.len()],
|
||||
&(
|
||||
inputs.0[prev..i.len()].to_vec(),
|
||||
inputs.1[prev..i.len()].to_vec(),
|
||||
),
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
prev += i.len();
|
||||
}
|
||||
|
||||
// on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector
|
||||
let mut inputs: Vec<Tensor<Fp>> = vec![];
|
||||
for (input, shape) in [quantized_evm_inputs].iter().zip(shapes) {
|
||||
let mut t: Tensor<Fp> = input.iter().cloned().collect();
|
||||
t.reshape(&shape)?;
|
||||
inputs.push(t);
|
||||
}
|
||||
|
||||
let used_rpc = rpc.unwrap_or(DEFAULT_ANVIL_ENDPOINT).to_string();
|
||||
|
||||
let call_to_account = test_on_chain_data(client.clone(), data).await?;
|
||||
debug!("Call to account: {:?}", call_to_account);
|
||||
let inputs = read_on_chain_inputs(client.clone(), client_address, &call_to_account).await?;
|
||||
debug!("Inputs: {:?}", inputs);
|
||||
|
||||
// Fill the input_data field of the GraphData struct
|
||||
Ok((
|
||||
inputs,
|
||||
OnChainSource::new_multiple(calls_to_accounts.clone(), used_rpc),
|
||||
))
|
||||
Ok(OnChainSource::new(call_to_account, used_rpc))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -342,11 +247,9 @@ pub struct CallToAccount {
|
||||
/// ABI-encoded function call data
|
||||
pub call_data: Call,
|
||||
/// Number of decimal places for float conversion
|
||||
pub decimals: Decimals,
|
||||
pub decimals: Vec<Decimals>,
|
||||
/// Contract address to call
|
||||
pub address: String,
|
||||
/// Expected length of returned array
|
||||
pub len: usize,
|
||||
}
|
||||
|
||||
/// Represents different sources of input/output data for the EZKL model
|
||||
@@ -683,17 +586,6 @@ impl GraphData {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for CallsToAccount {
|
||||
/// Converts CallsToAccount to Python object
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("account", &self.address).unwrap();
|
||||
dict.set_item("call_data", &self.call_data).unwrap();
|
||||
dict.to_object(py)
|
||||
}
|
||||
}
|
||||
|
||||
// Additional Python bindings for various types...
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -860,21 +752,10 @@ impl ToPyObject for CallToAccount {
|
||||
dict.set_item("account", &self.address).unwrap();
|
||||
dict.set_item("call_data", &self.call_data).unwrap();
|
||||
dict.set_item("decimals", &self.decimals).unwrap();
|
||||
dict.set_item("len", &self.len).unwrap();
|
||||
dict.to_object(py)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for Calls {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
match self {
|
||||
Calls::Multiple(calls) => calls.to_object(py),
|
||||
Calls::Single(call) => call.to_object(py),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for DataSource {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
@@ -883,7 +764,7 @@ impl ToPyObject for DataSource {
|
||||
DataSource::OnChain(source) => {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("rpc_url", &source.rpc).unwrap();
|
||||
dict.set_item("calls_to_accounts", &source.calls.to_object(py))
|
||||
dict.set_item("calls_to_accounts", &source.call.to_object(py))
|
||||
.unwrap();
|
||||
dict.to_object(py)
|
||||
}
|
||||
|
||||
@@ -1026,24 +1026,11 @@ impl GraphCircuit {
|
||||
shapes: &Vec<Vec<usize>>,
|
||||
scales: Vec<crate::Scale>,
|
||||
) -> Result<Vec<Tensor<Fp>>, GraphError> {
|
||||
use crate::eth::{
|
||||
evm_quantize_multi, evm_quantize_single, read_on_chain_inputs_multi,
|
||||
read_on_chain_inputs_single, setup_eth_backend,
|
||||
};
|
||||
use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend};
|
||||
let (client, client_address) = setup_eth_backend(Some(&source.rpc), None).await?;
|
||||
let quantized_evm_inputs = match source.calls {
|
||||
input::Calls::Single(call) => {
|
||||
let (inputs, decimals) =
|
||||
read_on_chain_inputs_single(client.clone(), client_address, call).await?;
|
||||
|
||||
evm_quantize_single(client, scales, &inputs, decimals).await?
|
||||
}
|
||||
input::Calls::Multiple(calls) => {
|
||||
let inputs =
|
||||
read_on_chain_inputs_multi(client.clone(), client_address, &calls).await?;
|
||||
evm_quantize_multi(client, scales, &inputs).await?
|
||||
}
|
||||
};
|
||||
let input = read_on_chain_inputs(client.clone(), client_address, &source.call).await?;
|
||||
let quantized_evm_inputs =
|
||||
evm_quantize(client, scales, &input, &source.call.decimals).await?;
|
||||
// on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector
|
||||
let mut inputs: Vec<Tensor<Fp>> = vec![];
|
||||
for (input, shape) in [quantized_evm_inputs].iter().zip(shapes) {
|
||||
@@ -1444,6 +1431,8 @@ impl GraphCircuit {
|
||||
let output_scales = self.model().graph.get_output_scales()?;
|
||||
let input_shapes = self.model().graph.input_shapes()?;
|
||||
let output_shapes = self.model().graph.output_shapes()?;
|
||||
let mut input_data = None;
|
||||
let mut output_data = None;
|
||||
|
||||
if matches!(
|
||||
test_on_chain_data.data_sources.input,
|
||||
@@ -1454,23 +1443,12 @@ impl GraphCircuit {
|
||||
return Err(GraphError::OnChainDataSource);
|
||||
}
|
||||
|
||||
let input_data = match &data.input_data {
|
||||
DataSource::File(input_data) => input_data,
|
||||
input_data = match &data.input_data {
|
||||
DataSource::File(input_data) => Some(input_data),
|
||||
_ => {
|
||||
return Err(GraphError::OnChainDataSource);
|
||||
return Err(GraphError::MissingDataSource);
|
||||
}
|
||||
};
|
||||
// Get the flatten length of input_data
|
||||
// if the input source is a field then set scale to 0
|
||||
|
||||
let datam: (Vec<Tensor<Fp>>, OnChainSource) = OnChainSource::test_from_file_data(
|
||||
input_data,
|
||||
input_scales,
|
||||
input_shapes,
|
||||
test_on_chain_data.rpc.as_deref(),
|
||||
)
|
||||
.await?;
|
||||
data.input_data = datam.1.into();
|
||||
}
|
||||
if matches!(
|
||||
test_on_chain_data.data_sources.output,
|
||||
@@ -1481,20 +1459,43 @@ impl GraphCircuit {
|
||||
return Err(GraphError::OnChainDataSource);
|
||||
}
|
||||
|
||||
let output_data = match &data.output_data {
|
||||
Some(DataSource::File(output_data)) => output_data,
|
||||
Some(DataSource::OnChain(_)) => return Err(GraphError::OnChainDataSource),
|
||||
output_data = match &data.output_data {
|
||||
Some(DataSource::File(output_data)) => Some(output_data),
|
||||
_ => return Err(GraphError::MissingDataSource),
|
||||
};
|
||||
let datum: (Vec<Tensor<Fp>>, OnChainSource) = OnChainSource::test_from_file_data(
|
||||
output_data,
|
||||
output_scales,
|
||||
output_shapes,
|
||||
test_on_chain_data.rpc.as_deref(),
|
||||
)
|
||||
.await?;
|
||||
data.output_data = Some(datum.1.into());
|
||||
}
|
||||
// Merge the input and output data
|
||||
let mut file_data: Vec<Vec<input::FileSourceInner>> = vec![];
|
||||
let mut scales: Vec<crate::Scale> = vec![];
|
||||
let mut shapes: Vec<Vec<usize>> = vec![];
|
||||
if let Some(input_data) = input_data {
|
||||
file_data.extend(input_data.clone());
|
||||
scales.extend(input_scales.clone());
|
||||
shapes.extend(input_shapes.clone());
|
||||
}
|
||||
if let Some(output_data) = output_data {
|
||||
file_data.extend(output_data.clone());
|
||||
scales.extend(output_scales.clone());
|
||||
shapes.extend(output_shapes.clone());
|
||||
};
|
||||
// print file data
|
||||
debug!("file data: {:?}", file_data);
|
||||
|
||||
let on_chain_data: OnChainSource = OnChainSource::test_from_file_data(
|
||||
&file_data,
|
||||
scales,
|
||||
shapes,
|
||||
test_on_chain_data.rpc.as_deref(),
|
||||
)
|
||||
.await?;
|
||||
// Here we update the GraphData struct with the on-chain data
|
||||
if input_data.is_some() {
|
||||
data.input_data = on_chain_data.clone().into();
|
||||
}
|
||||
if output_data.is_some() {
|
||||
data.output_data = Some(on_chain_data.into());
|
||||
}
|
||||
debug!("test on-chain data: {:?}", data);
|
||||
// Save the updated GraphData struct to the data_path
|
||||
data.save(test_on_chain_data.data)?;
|
||||
Ok(())
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
use super::errors::GraphError;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use super::VarScales;
|
||||
use super::errors::GraphError;
|
||||
use super::{Rescaled, SupportedOp, Visibility};
|
||||
use crate::circuit::Op;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::circuit::poly::PolyOp;
|
||||
use crate::circuit::Op;
|
||||
use crate::fieldutils::IntegerRep;
|
||||
use crate::tensor::{Tensor, TensorError, TensorType};
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
@@ -22,6 +22,7 @@ use std::sync::Arc;
|
||||
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_core::ops::{
|
||||
Downsample,
|
||||
array::{
|
||||
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
|
||||
Slice, Topk,
|
||||
@@ -31,7 +32,6 @@ use tract_onnx::tract_core::ops::{
|
||||
einsum::EinSum,
|
||||
element_wise::ElementWiseOp,
|
||||
nn::{LeakyRelu, Reduce, Softmax},
|
||||
Downsample,
|
||||
};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_hir::{
|
||||
@@ -1398,7 +1398,7 @@ pub fn new_op_from_onnx(
|
||||
|
||||
SupportedOp::Linear(PolyOp::Downsample {
|
||||
axis: downsample_node.axis,
|
||||
stride: downsample_node.stride as usize,
|
||||
stride: downsample_node.stride,
|
||||
modulo: downsample_node.modulo,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -17,16 +17,16 @@ use crate::{Commitments, EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
|
||||
use clap::ValueEnum;
|
||||
use halo2_proofs::circuit::Value;
|
||||
use halo2_proofs::plonk::{
|
||||
create_proof, keygen_pk, keygen_vk_custom, verify_proof, Circuit, ProvingKey, VerifyingKey,
|
||||
Circuit, ProvingKey, VerifyingKey, create_proof, keygen_pk, keygen_vk_custom, verify_proof,
|
||||
};
|
||||
use halo2_proofs::poly::VerificationStrategy;
|
||||
use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier};
|
||||
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::VerificationStrategy;
|
||||
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer};
|
||||
use halo2curves::CurveAffine;
|
||||
use halo2curves::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
|
||||
use halo2curves::serde::SerdeObject;
|
||||
use halo2curves::CurveAffine;
|
||||
use instant::Instant;
|
||||
use log::{debug, info, trace};
|
||||
#[cfg(not(feature = "det-prove"))]
|
||||
@@ -51,6 +51,9 @@ use pyo3::types::PyDictMethods;
|
||||
|
||||
use halo2curves::bn256::{Bn256, Fr, G1Affine};
|
||||
|
||||
/// Converts a string to a `SerdeFormat`.
|
||||
/// # Panics
|
||||
/// Panics if the provided `s` is not a valid `SerdeFormat` (i.e. not one of "processed", "raw-bytes-unchecked", or "raw-bytes").
|
||||
fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
|
||||
match s {
|
||||
"processed" => halo2_proofs::SerdeFormat::Processed,
|
||||
@@ -321,7 +324,7 @@ where
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::{types::PyDict, PyObject, Python, ToPyObject};
|
||||
use pyo3::{PyObject, Python, ToPyObject, types::PyDict};
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl<F: PrimeField + SerdeObject + Serialize, C: CurveAffine + Serialize> ToPyObject for Snark<F, C>
|
||||
where
|
||||
@@ -345,9 +348,9 @@ where
|
||||
}
|
||||
|
||||
impl<
|
||||
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
|
||||
C: CurveAffine + Serialize + DeserializeOwned,
|
||||
> Snark<F, C>
|
||||
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
|
||||
C: CurveAffine + Serialize + DeserializeOwned,
|
||||
> Snark<F, C>
|
||||
where
|
||||
C::Scalar: Serialize + DeserializeOwned,
|
||||
C::ScalarExt: Serialize + DeserializeOwned,
|
||||
|
||||
@@ -27,7 +27,7 @@ pub use var::*;
|
||||
|
||||
use crate::{
|
||||
circuit::utils,
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
fieldutils::{IntegerRep, integer_rep_to_felt},
|
||||
graph::Visibility,
|
||||
};
|
||||
|
||||
@@ -415,7 +415,7 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
|
||||
Err(_) => {
|
||||
return Err(TensorError::FileLoadError(
|
||||
"Failed to read tensor".to_string(),
|
||||
))
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -926,6 +926,9 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
));
|
||||
}
|
||||
self.dims = vec![];
|
||||
}
|
||||
if self.dims() == &[0] && new_dims.iter().product::<usize>() == 1 {
|
||||
self.dims = Vec::from(new_dims);
|
||||
} else {
|
||||
let product = if new_dims != [0] {
|
||||
new_dims.iter().product::<usize>()
|
||||
@@ -1104,6 +1107,10 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
let mut output = self.clone();
|
||||
output.reshape(shape)?;
|
||||
return Ok(output);
|
||||
} else if self.dims() == &[0] && shape.iter().product::<usize>() == 1 {
|
||||
let mut output = self.clone();
|
||||
output.reshape(shape)?;
|
||||
return Ok(output);
|
||||
}
|
||||
|
||||
if self.dims().len() > shape.len() {
|
||||
@@ -1254,7 +1261,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
None => {
|
||||
return Err(TensorError::DimError(
|
||||
"Cannot get last element of empty tensor".to_string(),
|
||||
))
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1279,7 +1286,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
None => {
|
||||
return Err(TensorError::DimError(
|
||||
"Cannot get first element of empty tensor".to_string(),
|
||||
))
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1692,8 +1699,8 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
|
||||
|
||||
lhs.par_iter_mut()
|
||||
.zip(rhs)
|
||||
.map(|(o, r)| {
|
||||
match T::zero() { Some(zero) => {
|
||||
.map(|(o, r)| match T::zero() {
|
||||
Some(zero) => {
|
||||
if r != zero {
|
||||
*o = o.clone() % r;
|
||||
Ok(())
|
||||
@@ -1702,11 +1709,10 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync + P
|
||||
"Cannot divide by zero in remainder".to_string(),
|
||||
))
|
||||
}
|
||||
} _ => {
|
||||
Err(TensorError::InvalidArgument(
|
||||
"Undefined zero value".to_string(),
|
||||
))
|
||||
}}
|
||||
}
|
||||
_ => Err(TensorError::InvalidArgument(
|
||||
"Undefined zero value".to_string(),
|
||||
)),
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
|
||||
@@ -535,30 +535,101 @@ pub fn mult<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::S
|
||||
/// let result = downsample(&x, 1, 2, 2).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[3, 6]), &[2, 1]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
///
|
||||
/// // Test case 1: Negative stride along dimension 0
|
||||
/// // This should flip the order along dimension 0
|
||||
/// let result = downsample(&x, 0, -1, 0).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[4, 5, 6, 1, 2, 3]), // Flipped order of rows
|
||||
/// &[2, 3]
|
||||
/// ).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Test case 2: Negative stride along dimension 1
|
||||
/// // This should flip the order along dimension 1
|
||||
/// let result = downsample(&x, 1, -1, 0).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, 2, 1, 6, 5, 4]), // Flipped order of columns
|
||||
/// &[2, 3]
|
||||
/// ).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Test case 3: Negative stride with stride magnitude > 1
|
||||
/// // This should both skip and flip
|
||||
/// let result = downsample(&x, 1, -2, 0).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, 1, 6, 4]), // Take every 2nd element in reverse
|
||||
/// &[2, 2]
|
||||
/// ).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Test case 4: Negative stride with non-zero modulo
|
||||
/// // This should start at (size - 1 - modulo) and reverse
|
||||
/// let result = downsample(&x, 1, -2, 1).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 5]), // Start at second element from end, take every 2nd in reverse
|
||||
/// &[2, 1]
|
||||
/// ).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Create a larger test case for more complex downsampling
|
||||
/// let y = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
|
||||
/// &[3, 4],
|
||||
/// ).unwrap();
|
||||
///
|
||||
/// // Test case 5: Negative stride with modulo on larger tensor
|
||||
/// let result = downsample(&y, 1, -2, 1).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[3, 1, 7, 5, 11, 9]), // Start at one after reverse, take every 2nd
|
||||
/// &[3, 2]
|
||||
/// ).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn downsample<T: TensorType + Send + Sync>(
|
||||
input: &Tensor<T>,
|
||||
dim: usize,
|
||||
stride: usize,
|
||||
stride: isize, // Changed from usize to isize to support negative strides
|
||||
modulo: usize,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
let mut output_shape = input.dims().to_vec();
|
||||
// now downsample along axis dim offset by modulo, rounding up (+1 if remaidner is non-zero)
|
||||
let remainder = (input.dims()[dim] - modulo) % stride;
|
||||
let div = (input.dims()[dim] - modulo) / stride;
|
||||
output_shape[dim] = div + (remainder > 0) as usize;
|
||||
let mut output = Tensor::<T>::new(None, &output_shape)?;
|
||||
// Handle negative stride case
|
||||
if stride == 0 {
|
||||
return Err(TensorError::DimMismatch(
|
||||
"downsample stride cannot be zero".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if modulo > input.dims()[dim] {
|
||||
let stride_abs = stride.unsigned_abs();
|
||||
let mut output_shape = input.dims().to_vec();
|
||||
|
||||
if modulo >= input.dims()[dim] {
|
||||
return Err(TensorError::DimMismatch("downsample".to_string()));
|
||||
}
|
||||
|
||||
// now downsample along axis dim offset by modulo
|
||||
// Calculate output shape based on the absolute value of stride
|
||||
let remainder = (input.dims()[dim] - modulo) % stride_abs;
|
||||
let div = (input.dims()[dim] - modulo) / stride_abs;
|
||||
output_shape[dim] = div + (remainder > 0) as usize;
|
||||
|
||||
let mut output = Tensor::<T>::new(None, &output_shape)?;
|
||||
|
||||
// Calculate indices based on stride direction
|
||||
let indices = (0..output_shape.len())
|
||||
.map(|i| {
|
||||
if i == dim {
|
||||
let mut index = vec![0; output_shape[i]];
|
||||
for (i, idx) in index.iter_mut().enumerate() {
|
||||
*idx = i * stride + modulo;
|
||||
for (j, idx) in index.iter_mut().enumerate() {
|
||||
if stride > 0 {
|
||||
// Positive stride: move forward from modulo
|
||||
*idx = j * stride_abs + modulo;
|
||||
} else {
|
||||
// Negative stride: move backward from (size - 1 - modulo)
|
||||
*idx = (input.dims()[dim] - 1 - modulo) - j * stride_abs;
|
||||
}
|
||||
}
|
||||
index
|
||||
} else {
|
||||
@@ -2275,7 +2346,11 @@ pub mod nonlinearities {
|
||||
pub fn recip(a: &Tensor<IntegerRep>, input_scale: f64, out_scale: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let rescaled = (a_i as f64) / input_scale;
|
||||
let denom = (1_f64) / (rescaled + f64::EPSILON);
|
||||
let denom = if rescaled == 0_f64 {
|
||||
(1_f64) / (rescaled + f64::EPSILON)
|
||||
} else {
|
||||
(1_f64) / (rescaled)
|
||||
};
|
||||
let d_inv_x = out_scale * denom;
|
||||
Ok::<_, TensorError>(d_inv_x.round() as IntegerRep)
|
||||
})
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::collections::HashSet;
|
||||
|
||||
use log::{debug, error, warn};
|
||||
|
||||
use crate::circuit::{region::ConstantsMap, CheckMode};
|
||||
use crate::circuit::{CheckMode, region::ConstantsMap};
|
||||
|
||||
use super::*;
|
||||
/// A wrapper around Halo2's Column types that represents a tensor of variables in the circuit.
|
||||
@@ -403,7 +403,10 @@ impl VarTensor {
|
||||
let mut assigned_coord = 0;
|
||||
let mut res: ValTensor<F> = match values {
|
||||
ValTensor::Instance { .. } => {
|
||||
unimplemented!("cannot assign instance to advice columns with omissions")
|
||||
error!(
|
||||
"assignment with omissions is not supported on instance columns. increase K if you require more rows."
|
||||
);
|
||||
Err(halo2_proofs::plonk::Error::Synthesis)
|
||||
}
|
||||
ValTensor::Value { inner: v, .. } => Ok::<ValTensor<F>, halo2_proofs::plonk::Error>(
|
||||
v.enum_map(|coord, k| {
|
||||
@@ -569,8 +572,13 @@ impl VarTensor {
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
|
||||
match values {
|
||||
ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
|
||||
ValTensor::Value { inner: v, dims , ..} => {
|
||||
ValTensor::Instance { .. } => {
|
||||
error!(
|
||||
"duplication is not supported on instance columns. increase K if you require more rows."
|
||||
);
|
||||
Err(halo2_proofs::plonk::Error::Synthesis)
|
||||
}
|
||||
ValTensor::Value { inner: v, dims, .. } => {
|
||||
let duplication_freq = if single_inner_col {
|
||||
self.col_size()
|
||||
} else {
|
||||
@@ -583,21 +591,20 @@ impl VarTensor {
|
||||
self.num_inner_cols()
|
||||
};
|
||||
|
||||
let duplication_offset = if single_inner_col {
|
||||
row
|
||||
} else {
|
||||
offset
|
||||
};
|
||||
|
||||
let duplication_offset = if single_inner_col { row } else { offset };
|
||||
|
||||
// duplicates every nth element to adjust for column overflow
|
||||
let mut res: ValTensor<F> = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap().into();
|
||||
let mut res: ValTensor<F> = v
|
||||
.duplicate_every_n(duplication_freq, num_repeats, duplication_offset)
|
||||
.unwrap()
|
||||
.into();
|
||||
|
||||
let constants_map = res.create_constants_map();
|
||||
constants.extend(constants_map);
|
||||
|
||||
let total_used_len = res.len();
|
||||
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
|
||||
res.remove_every_n(duplication_freq, num_repeats, duplication_offset)
|
||||
.unwrap();
|
||||
|
||||
res.reshape(dims).unwrap();
|
||||
res.set_scale(values.scale());
|
||||
@@ -627,9 +634,13 @@ impl VarTensor {
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
|
||||
match values {
|
||||
ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
|
||||
ValTensor::Value { inner: v, dims , ..} => {
|
||||
|
||||
ValTensor::Instance { .. } => {
|
||||
error!(
|
||||
"duplication is not supported on instance columns. increase K if you require more rows."
|
||||
);
|
||||
Err(halo2_proofs::plonk::Error::Synthesis)
|
||||
}
|
||||
ValTensor::Value { inner: v, dims, .. } => {
|
||||
let duplication_freq = self.block_size();
|
||||
|
||||
let num_repeats = self.num_inner_cols();
|
||||
@@ -637,17 +648,31 @@ impl VarTensor {
|
||||
let duplication_offset = offset;
|
||||
|
||||
// duplicates every nth element to adjust for column overflow
|
||||
let v = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
|
||||
let v = v
|
||||
.duplicate_every_n(duplication_freq, num_repeats, duplication_offset)
|
||||
.map_err(|e| {
|
||||
error!("Error duplicating values: {:?}", e);
|
||||
halo2_proofs::plonk::Error::Synthesis
|
||||
})?;
|
||||
let mut res: ValTensor<F> = {
|
||||
v.enum_map(|coord, k| {
|
||||
let cell = self.assign_value(region, offset, k.clone(), coord, constants)?;
|
||||
Ok::<_, halo2_proofs::plonk::Error>(cell)
|
||||
|
||||
})?.into()};
|
||||
let cell =
|
||||
self.assign_value(region, offset, k.clone(), coord, constants)?;
|
||||
Ok::<_, halo2_proofs::plonk::Error>(cell)
|
||||
})?
|
||||
.into()
|
||||
};
|
||||
let total_used_len = res.len();
|
||||
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
|
||||
res.remove_every_n(duplication_freq, num_repeats, duplication_offset)
|
||||
.map_err(|e| {
|
||||
error!("Error duplicating values: {:?}", e);
|
||||
halo2_proofs::plonk::Error::Synthesis
|
||||
})?;
|
||||
|
||||
res.reshape(dims).unwrap();
|
||||
res.reshape(dims).map_err(|e| {
|
||||
error!("Error duplicating values: {:?}", e);
|
||||
halo2_proofs::plonk::Error::Synthesis
|
||||
})?;
|
||||
res.set_scale(values.scale());
|
||||
|
||||
Ok((res, total_used_len))
|
||||
@@ -681,61 +706,71 @@ impl VarTensor {
|
||||
let mut prev_cell = None;
|
||||
|
||||
match values {
|
||||
ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
|
||||
ValTensor::Value { inner: v, dims , ..} => {
|
||||
|
||||
ValTensor::Instance { .. } => {
|
||||
error!(
|
||||
"duplication is not supported on instance columns. increase K if you require more rows."
|
||||
);
|
||||
Err(halo2_proofs::plonk::Error::Synthesis)
|
||||
}
|
||||
ValTensor::Value { inner: v, dims, .. } => {
|
||||
let duplication_freq = self.col_size();
|
||||
let num_repeats = 1;
|
||||
let duplication_offset = row;
|
||||
|
||||
// duplicates every nth element to adjust for column overflow
|
||||
let v = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
|
||||
let mut res: ValTensor<F> =
|
||||
v.enum_map(|coord, k| {
|
||||
let v = v
|
||||
.duplicate_every_n(duplication_freq, num_repeats, duplication_offset)
|
||||
.unwrap();
|
||||
let mut res: ValTensor<F> = v
|
||||
.enum_map(|coord, k| {
|
||||
let step = self.num_inner_cols();
|
||||
|
||||
let step = self.num_inner_cols();
|
||||
let (x, y, z) = self.cartesian_coord(offset + coord * step);
|
||||
if matches!(check_mode, CheckMode::SAFE) && coord > 0 && z == 0 && y == 0 {
|
||||
// assert that duplication occurred correctly
|
||||
assert_eq!(
|
||||
Into::<IntegerRep>::into(k.clone()),
|
||||
Into::<IntegerRep>::into(v[coord - 1].clone())
|
||||
);
|
||||
};
|
||||
|
||||
let (x, y, z) = self.cartesian_coord(offset + coord * step);
|
||||
if matches!(check_mode, CheckMode::SAFE) && coord > 0 && z == 0 && y == 0 {
|
||||
// assert that duplication occurred correctly
|
||||
assert_eq!(Into::<IntegerRep>::into(k.clone()), Into::<IntegerRep>::into(v[coord - 1].clone()));
|
||||
};
|
||||
let cell =
|
||||
self.assign_value(region, offset, k.clone(), coord * step, constants)?;
|
||||
|
||||
let cell = self.assign_value(region, offset, k.clone(), coord * step, constants)?;
|
||||
let at_end_of_column = z == duplication_freq - 1;
|
||||
let at_beginning_of_column = z == 0;
|
||||
|
||||
let at_end_of_column = z == duplication_freq - 1;
|
||||
let at_beginning_of_column = z == 0;
|
||||
|
||||
if at_end_of_column {
|
||||
// if we are at the end of the column, we need to copy the cell to the next column
|
||||
prev_cell = Some(cell.clone());
|
||||
} else if coord > 0 && at_beginning_of_column {
|
||||
if let Some(prev_cell) = prev_cell.as_ref() {
|
||||
let cell = if let Some(cell) = cell.cell() {
|
||||
cell
|
||||
if at_end_of_column {
|
||||
// if we are at the end of the column, we need to copy the cell to the next column
|
||||
prev_cell = Some(cell.clone());
|
||||
} else if coord > 0 && at_beginning_of_column {
|
||||
if let Some(prev_cell) = prev_cell.as_ref() {
|
||||
let cell = if let Some(cell) = cell.cell() {
|
||||
cell
|
||||
} else {
|
||||
error!("Error getting cell: {:?}", (x, y));
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
};
|
||||
let prev_cell = if let Some(prev_cell) = prev_cell.cell() {
|
||||
prev_cell
|
||||
} else {
|
||||
error!("Error getting prev cell: {:?}", (x, y));
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
};
|
||||
region.constrain_equal(prev_cell, cell)?;
|
||||
} else {
|
||||
error!("Error getting cell: {:?}", (x,y));
|
||||
error!("Previous cell was not set");
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
};
|
||||
let prev_cell = if let Some(prev_cell) = prev_cell.cell() {
|
||||
prev_cell
|
||||
} else {
|
||||
error!("Error getting prev cell: {:?}", (x,y));
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
};
|
||||
region.constrain_equal(prev_cell,cell)?;
|
||||
} else {
|
||||
error!("Previous cell was not set");
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(cell)
|
||||
|
||||
})?.into();
|
||||
Ok(cell)
|
||||
})?
|
||||
.into();
|
||||
|
||||
let total_used_len = res.len();
|
||||
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
|
||||
res.remove_every_n(duplication_freq, num_repeats, duplication_offset)
|
||||
.unwrap();
|
||||
|
||||
res.reshape(dims).unwrap();
|
||||
res.set_scale(values.scale());
|
||||
@@ -771,21 +806,30 @@ impl VarTensor {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
ValType::PrevAssigned(region.assign_advice(|| "k", advices[x][y], z, || v)?)
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
_ => {
|
||||
error!("VarTensor was not initialized");
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
}
|
||||
},
|
||||
// Handle copying previously assigned value
|
||||
ValType::PrevAssigned(v) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
ValType::PrevAssigned(v.copy_advice(|| "k", region, advices[x][y], z)?)
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
_ => {
|
||||
error!("VarTensor was not initialized");
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
}
|
||||
},
|
||||
// Handle copying previously assigned constant
|
||||
ValType::AssignedConstant(v, val) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
ValType::AssignedConstant(v.copy_advice(|| "k", region, advices[x][y], z)?, val)
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
_ => {
|
||||
error!("VarTensor was not initialized");
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
}
|
||||
},
|
||||
// Handle assigning evaluated value
|
||||
ValType::AssignedValue(v) => match &self {
|
||||
@@ -794,7 +838,10 @@ impl VarTensor {
|
||||
.assign_advice(|| "k", advices[x][y], z, || v)?
|
||||
.evaluate(),
|
||||
),
|
||||
_ => unimplemented!(),
|
||||
_ => {
|
||||
error!("VarTensor was not initialized");
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
}
|
||||
},
|
||||
// Handle constant value assignment with caching
|
||||
ValType::Constant(v) => {
|
||||
|
||||
14
tests/foundry/.gitignore
vendored
Normal file
14
tests/foundry/.gitignore
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
# Compiler files
|
||||
cache/
|
||||
out/
|
||||
|
||||
# Ignores development broadcast logs
|
||||
!/broadcast
|
||||
/broadcast/*/31337/
|
||||
/broadcast/**/dry-run/
|
||||
|
||||
# Docs
|
||||
docs/
|
||||
|
||||
# Dotenv file
|
||||
.env
|
||||
66
tests/foundry/README.md
Normal file
66
tests/foundry/README.md
Normal file
@@ -0,0 +1,66 @@
|
||||
## Foundry
|
||||
|
||||
**Foundry is a blazing fast, portable and modular toolkit for Ethereum application development written in Rust.**
|
||||
|
||||
Foundry consists of:
|
||||
|
||||
- **Forge**: Ethereum testing framework (like Truffle, Hardhat and DappTools).
|
||||
- **Cast**: Swiss army knife for interacting with EVM smart contracts, sending transactions and getting chain data.
|
||||
- **Anvil**: Local Ethereum node, akin to Ganache, Hardhat Network.
|
||||
- **Chisel**: Fast, utilitarian, and verbose solidity REPL.
|
||||
|
||||
## Documentation
|
||||
|
||||
https://book.getfoundry.sh/
|
||||
|
||||
## Usage
|
||||
|
||||
### Build
|
||||
|
||||
```shell
|
||||
$ forge build
|
||||
```
|
||||
|
||||
### Test
|
||||
|
||||
```shell
|
||||
$ forge test
|
||||
```
|
||||
|
||||
### Format
|
||||
|
||||
```shell
|
||||
$ forge fmt
|
||||
```
|
||||
|
||||
### Gas Snapshots
|
||||
|
||||
```shell
|
||||
$ forge snapshot
|
||||
```
|
||||
|
||||
### Anvil
|
||||
|
||||
```shell
|
||||
$ anvil
|
||||
```
|
||||
|
||||
### Deploy
|
||||
|
||||
```shell
|
||||
$ forge script script/Counter.s.sol:CounterScript --rpc-url <your_rpc_url> --private-key <your_private_key>
|
||||
```
|
||||
|
||||
### Cast
|
||||
|
||||
```shell
|
||||
$ cast <subcommand>
|
||||
```
|
||||
|
||||
### Help
|
||||
|
||||
```shell
|
||||
$ forge --help
|
||||
$ anvil --help
|
||||
$ cast --help
|
||||
```
|
||||
6
tests/foundry/foundry.toml
Normal file
6
tests/foundry/foundry.toml
Normal file
@@ -0,0 +1,6 @@
|
||||
[profile.default]
|
||||
src = "../../contracts"
|
||||
out = "out"
|
||||
libs = ["lib"]
|
||||
|
||||
# See more config options https://github.com/foundry-rs/foundry/blob/master/crates/config/README.md#all-options
|
||||
1
tests/foundry/remappings.txt
Normal file
1
tests/foundry/remappings.txt
Normal file
@@ -0,0 +1 @@
|
||||
contracts/=../../contracts/
|
||||
429
tests/foundry/test/AttestData.t.sol
Normal file
429
tests/foundry/test/AttestData.t.sol
Normal file
@@ -0,0 +1,429 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
pragma solidity ^0.8.20;
|
||||
|
||||
import "forge-std/Test.sol";
|
||||
import {console} from "forge-std/console.sol";
|
||||
import "contracts/AttestData.sol" as AttestData;
|
||||
|
||||
contract MockVKA {
|
||||
constructor() {}
|
||||
}
|
||||
|
||||
contract MockVerifier {
|
||||
bool public shouldVerify;
|
||||
|
||||
constructor(bool _shouldVerify) {
|
||||
shouldVerify = _shouldVerify;
|
||||
}
|
||||
|
||||
function verifyProof(
|
||||
bytes calldata,
|
||||
uint256[] calldata
|
||||
) external view returns (bool) {
|
||||
require(shouldVerify, "Verification failed");
|
||||
return shouldVerify;
|
||||
}
|
||||
}
|
||||
|
||||
contract MockVerifierSeperate {
|
||||
bool public shouldVerify;
|
||||
|
||||
constructor(bool _shouldVerify) {
|
||||
shouldVerify = _shouldVerify;
|
||||
}
|
||||
|
||||
function verifyProof(
|
||||
address,
|
||||
bytes calldata,
|
||||
uint256[] calldata
|
||||
) external view returns (bool) {
|
||||
require(shouldVerify, "Verification failed");
|
||||
return shouldVerify;
|
||||
}
|
||||
}
|
||||
|
||||
contract MockTargetContract {
|
||||
int256[] public data;
|
||||
|
||||
constructor(int256[] memory _data) {
|
||||
data = _data;
|
||||
}
|
||||
|
||||
function setData(int256[] memory _data) external {
|
||||
data = _data;
|
||||
}
|
||||
|
||||
function getData() external view returns (int256[] memory) {
|
||||
return data;
|
||||
}
|
||||
}
|
||||
|
||||
contract DataAttestationTest is Test {
|
||||
AttestData.DataAttestation das;
|
||||
MockVerifier verifier;
|
||||
MockVerifierSeperate verifierSeperate;
|
||||
MockVKA vka;
|
||||
MockTargetContract target;
|
||||
int256[] mockData = [int256(1e18), -int256(5e17)];
|
||||
uint256[] decimals = [18, 18];
|
||||
uint256[] bits = [13, 13];
|
||||
uint8 instanceOffset = 0;
|
||||
bytes callData;
|
||||
|
||||
function setUp() public {
|
||||
target = new MockTargetContract(mockData);
|
||||
verifier = new MockVerifier(true);
|
||||
verifierSeperate = new MockVerifierSeperate(true);
|
||||
vka = new MockVKA();
|
||||
|
||||
callData = abi.encodeWithSignature("getData()");
|
||||
|
||||
das = new AttestData.DataAttestation(
|
||||
address(target),
|
||||
callData,
|
||||
decimals,
|
||||
bits,
|
||||
instanceOffset
|
||||
);
|
||||
}
|
||||
|
||||
// Fork of mulDivRound which doesn't revert on overflow and returns a boolean instead to indicate overflow
|
||||
function mulDivRound(
|
||||
uint256 x,
|
||||
uint256 y,
|
||||
uint256 denominator
|
||||
) public pure returns (uint256 result, bool overflow) {
|
||||
unchecked {
|
||||
uint256 prod0;
|
||||
uint256 prod1;
|
||||
assembly {
|
||||
let mm := mulmod(x, y, not(0))
|
||||
prod0 := mul(x, y)
|
||||
prod1 := sub(sub(mm, prod0), lt(mm, prod0))
|
||||
}
|
||||
uint256 remainder = mulmod(x, y, denominator);
|
||||
bool addOne;
|
||||
if (remainder * 2 >= denominator) {
|
||||
addOne = true;
|
||||
}
|
||||
|
||||
if (prod1 == 0) {
|
||||
if (addOne) {
|
||||
return ((prod0 / denominator) + 1, false);
|
||||
}
|
||||
return (prod0 / denominator, false);
|
||||
}
|
||||
|
||||
if (denominator > prod1) {
|
||||
return (0, true);
|
||||
}
|
||||
|
||||
assembly {
|
||||
prod1 := sub(prod1, gt(remainder, prod0))
|
||||
prod0 := sub(prod0, remainder)
|
||||
}
|
||||
|
||||
uint256 twos = denominator & (~denominator + 1);
|
||||
assembly {
|
||||
denominator := div(denominator, twos)
|
||||
prod0 := div(prod0, twos)
|
||||
twos := add(div(sub(0, twos), twos), 1)
|
||||
}
|
||||
|
||||
prod0 |= prod1 * twos;
|
||||
|
||||
uint256 inverse = (3 * denominator) ^ 2;
|
||||
|
||||
inverse *= 2 - denominator * inverse;
|
||||
inverse *= 2 - denominator * inverse;
|
||||
inverse *= 2 - denominator * inverse;
|
||||
inverse *= 2 - denominator * inverse;
|
||||
inverse *= 2 - denominator * inverse;
|
||||
inverse *= 2 - denominator * inverse;
|
||||
|
||||
result = prod0 * inverse;
|
||||
if (addOne) {
|
||||
result += 1;
|
||||
}
|
||||
return (result, false);
|
||||
}
|
||||
}
|
||||
struct SampleAttestation {
|
||||
int256 mockData;
|
||||
uint8 decimals;
|
||||
uint8 bits;
|
||||
}
|
||||
function test_fuzzAttestedData(
|
||||
SampleAttestation[] memory _attestations
|
||||
) public {
|
||||
vm.assume(_attestations.length == 1);
|
||||
int256[] memory _mockData = new int256[](1);
|
||||
uint256[] memory _decimals = new uint256[](1);
|
||||
uint256[] memory _bits = new uint256[](1);
|
||||
uint256[] memory _instances = new uint256[](1);
|
||||
for (uint256 i = 0; i < 1; i++) {
|
||||
SampleAttestation memory attestation = _attestations[i];
|
||||
_mockData[i] = attestation.mockData;
|
||||
vm.assume(attestation.mockData != type(int256).min); /// Will overflow int256 during negation op
|
||||
vm.assume(attestation.decimals < 77); /// Else will exceed uint256 bounds
|
||||
vm.assume(attestation.bits < 128); /// Else will exceed EZKL fixed point bounds for int128 type
|
||||
bool neg = attestation.mockData < 0;
|
||||
if (neg) {
|
||||
attestation.mockData = -attestation.mockData;
|
||||
}
|
||||
(uint256 _result, bool overflow) = mulDivRound(
|
||||
uint256(attestation.mockData),
|
||||
uint256(1 << attestation.bits),
|
||||
uint256(10 ** attestation.decimals)
|
||||
);
|
||||
vm.assume(!overflow);
|
||||
vm.assume(_result < das.HALF_ORDER());
|
||||
if (neg) {
|
||||
// No possibility of overflow here since output is less than or equal to HALF_ORDER
|
||||
// and therefore falls within the max range of int256 without overflow
|
||||
vm.assume(-int256(_result) > type(int128).min);
|
||||
_instances[i] =
|
||||
uint256(int(das.ORDER()) - int256(_result)) %
|
||||
das.ORDER();
|
||||
} else {
|
||||
vm.assume(_result < uint128(type(int128).max));
|
||||
_instances[i] = _result;
|
||||
}
|
||||
_decimals[i] = attestation.decimals;
|
||||
_bits[i] = attestation.bits;
|
||||
}
|
||||
// Update the attested data
|
||||
target.setData(_mockData);
|
||||
// Deploy the new data attestation contract
|
||||
AttestData.DataAttestation dasNew = new AttestData.DataAttestation(
|
||||
address(target),
|
||||
callData,
|
||||
_decimals,
|
||||
_bits,
|
||||
instanceOffset
|
||||
);
|
||||
bytes memory proof = hex"1234"; // Would normally contain commitments
|
||||
bytes memory encoded = abi.encodeWithSignature(
|
||||
"verifyProof(bytes,uint256[])",
|
||||
proof,
|
||||
_instances
|
||||
);
|
||||
|
||||
AttestData.DataAttestation.Scalars memory _scalars = AttestData
|
||||
.DataAttestation
|
||||
.Scalars(10 ** _decimals[0], 1 << _bits[0]);
|
||||
|
||||
int256 output = dasNew.quantizeData(_mockData[0], _scalars);
|
||||
console.log("output: ", output);
|
||||
uint256 fieldElement = dasNew.toFieldElement(output);
|
||||
// output should equal to _instances[0]
|
||||
assertEq(fieldElement, _instances[0]);
|
||||
|
||||
bool verificationResult = dasNew.verifyWithDataAttestation(
|
||||
address(verifier),
|
||||
encoded
|
||||
);
|
||||
assertTrue(verificationResult);
|
||||
}
|
||||
|
||||
// Test deployment parameters
|
||||
function testDeployment() public view {
|
||||
assertEq(das.contractAddress(), address(target));
|
||||
assertEq(das.callData(), abi.encodeWithSignature("getData()"));
|
||||
assertEq(das.instanceOffset(), instanceOffset);
|
||||
|
||||
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
|
||||
assertEq(scalar.decimals, 1e18);
|
||||
assertEq(scalar.bits, 1 << 13);
|
||||
}
|
||||
|
||||
// Test quantizeData function
|
||||
function testQuantizeData() public view {
|
||||
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
|
||||
|
||||
int256 positive = das.quantizeData(1e18, scalar);
|
||||
assertEq(positive, int256(scalar.bits));
|
||||
|
||||
int256 negative = das.quantizeData(-1e18, scalar);
|
||||
assertEq(negative, -int256(scalar.bits));
|
||||
|
||||
// Test rounding
|
||||
int half = int(0.5e18 / scalar.bits);
|
||||
int256 rounded = das.quantizeData(half, scalar);
|
||||
assertEq(rounded, 1);
|
||||
}
|
||||
|
||||
// Test staticCall functionality
|
||||
function testStaticCall() public view {
|
||||
bytes memory result = das.staticCall(
|
||||
address(target),
|
||||
abi.encodeWithSignature("getData()")
|
||||
);
|
||||
int256[] memory decoded = abi.decode(result, (int256[]));
|
||||
assertEq(decoded[0], mockData[0]);
|
||||
assertEq(decoded[1], mockData[1]);
|
||||
}
|
||||
|
||||
// Test attestData validation
|
||||
function testAttestDataSuccess() public view {
|
||||
uint256[] memory instances = new uint256[](2);
|
||||
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
|
||||
instances[0] = das.toFieldElement(int(scalar.bits));
|
||||
instances[1] = das.toFieldElement(-int(scalar.bits >> 1));
|
||||
das.attestData(instances); // Should not revert
|
||||
}
|
||||
|
||||
function testAttestDataFailure() public {
|
||||
uint256[] memory instances = new uint256[](2);
|
||||
instances[0] = das.toFieldElement(1e18); // Incorrect value
|
||||
instances[1] = das.toFieldElement(5e17);
|
||||
|
||||
vm.expectRevert("Public input does not match");
|
||||
das.attestData(instances);
|
||||
}
|
||||
|
||||
// Test full verification flow
|
||||
function testSuccessfulVerification() public view {
|
||||
// Prepare valid instances
|
||||
uint256[] memory instances = new uint256[](2);
|
||||
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
|
||||
instances[0] = das.toFieldElement(int(scalar.bits));
|
||||
instances[1] = das.toFieldElement(-int(scalar.bits >> 1));
|
||||
|
||||
// Create valid calldata (mock)
|
||||
bytes memory proof = hex"1234"; // Would normally contain commitments
|
||||
bytes memory encoded = abi.encodeWithSignature(
|
||||
"verifyProof(bytes,uint256[])",
|
||||
proof,
|
||||
instances
|
||||
);
|
||||
bytes memory encoded_vka = abi.encodeWithSignature(
|
||||
"verifyProof(address,bytes,uint256[])",
|
||||
address(vka),
|
||||
proof,
|
||||
instances
|
||||
);
|
||||
|
||||
bool result = das.verifyWithDataAttestation(address(verifier), encoded);
|
||||
assertTrue(result);
|
||||
result = das.verifyWithDataAttestation(
|
||||
address(verifierSeperate),
|
||||
encoded_vka
|
||||
);
|
||||
assertTrue(result);
|
||||
}
|
||||
|
||||
function testLoadInstances() public view {
|
||||
uint256[] memory instances = new uint256[](2);
|
||||
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
|
||||
instances[0] = das.toFieldElement(int(scalar.bits));
|
||||
instances[1] = das.toFieldElement(-int(scalar.bits >> 1));
|
||||
|
||||
// Create valid calldata (mock)
|
||||
bytes memory proof = hex"1234"; // Would normally contain commitments
|
||||
bytes memory encoded = abi.encodeWithSignature(
|
||||
"verifyProof(bytes,uint256[])",
|
||||
proof,
|
||||
instances
|
||||
);
|
||||
bytes memory encoded_vka = abi.encodeWithSignature(
|
||||
"verifyProof(address,bytes,uint256[])",
|
||||
address(vka),
|
||||
proof,
|
||||
instances
|
||||
);
|
||||
|
||||
// Load encoded instances from calldata
|
||||
uint256[] memory extracted_instances_calldata = das
|
||||
.getInstancesCalldata(encoded);
|
||||
assertEq(extracted_instances_calldata[0], instances[0]);
|
||||
assertEq(extracted_instances_calldata[1], instances[1]);
|
||||
// Load encoded instances from memory
|
||||
uint256[] memory extracted_instances_memory = das.getInstancesMemory(
|
||||
encoded
|
||||
);
|
||||
assertEq(extracted_instances_memory[0], instances[0]);
|
||||
assertEq(extracted_instances_memory[1], instances[1]);
|
||||
// Load encoded with vk instances from calldata
|
||||
uint256[] memory extracted_instances_calldata_vk = das
|
||||
.getInstancesCalldata(encoded_vka);
|
||||
assertEq(extracted_instances_calldata_vk[0], instances[0]);
|
||||
assertEq(extracted_instances_calldata_vk[1], instances[1]);
|
||||
// Load encoded with vk instances from memory
|
||||
uint256[] memory extracted_instances_memory_vk = das.getInstancesMemory(
|
||||
encoded_vka
|
||||
);
|
||||
assertEq(extracted_instances_memory_vk[0], instances[0]);
|
||||
assertEq(extracted_instances_memory_vk[1], instances[1]);
|
||||
}
|
||||
|
||||
function testInvalidCommitments() public {
|
||||
// Create calldata with invalid commitments
|
||||
bytes memory invalidProof = hex"5678";
|
||||
uint256[] memory instances = new uint256[](2);
|
||||
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
|
||||
instances[0] = das.toFieldElement(int(scalar.bits));
|
||||
instances[1] = das.toFieldElement(-int(scalar.bits >> 1));
|
||||
bytes memory encoded = abi.encodeWithSignature(
|
||||
"verifyProof(bytes,uint256[])",
|
||||
invalidProof,
|
||||
instances
|
||||
);
|
||||
|
||||
vm.expectRevert("Invalid KZG commitments");
|
||||
das.verifyWithDataAttestation(address(verifier), encoded);
|
||||
}
|
||||
|
||||
function testInvalidVerifier() public {
|
||||
MockVerifier invalidVerifier = new MockVerifier(false);
|
||||
uint256[] memory instances = new uint256[](2);
|
||||
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
|
||||
instances[0] = das.toFieldElement(int(scalar.bits));
|
||||
instances[1] = das.toFieldElement(-int(scalar.bits >> 1));
|
||||
bytes memory encoded = abi.encodeWithSignature(
|
||||
"verifyProof(bytes,uint256[])",
|
||||
hex"1234",
|
||||
instances
|
||||
);
|
||||
|
||||
vm.expectRevert("low-level call to verifier failed");
|
||||
das.verifyWithDataAttestation(address(invalidVerifier), encoded);
|
||||
}
|
||||
|
||||
// Test edge cases
|
||||
function testZeroValueQuantization() public view {
|
||||
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
|
||||
int256 zero = das.quantizeData(0, scalar);
|
||||
assertEq(zero, 0);
|
||||
}
|
||||
|
||||
function testOverflowProtection() public {
|
||||
int256 order = int(
|
||||
uint256(
|
||||
0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
|
||||
)
|
||||
);
|
||||
// int256 half_order = int(order >> 1);
|
||||
AttestData.DataAttestation.Scalars memory scalar = AttestData
|
||||
.DataAttestation
|
||||
.Scalars(1, 1 << 2);
|
||||
|
||||
vm.expectRevert("Overflow field modulus");
|
||||
das.quantizeData(order, scalar); // Value that would overflow
|
||||
}
|
||||
|
||||
function testInvalidFunctionSignature() public {
|
||||
uint256[] memory instances = new uint256[](2);
|
||||
AttestData.DataAttestation.Scalars memory scalar = das.getScalars(0);
|
||||
instances[0] = das.toFieldElement(int(scalar.bits));
|
||||
instances[1] = das.toFieldElement(-int(scalar.bits >> 1));
|
||||
bytes memory encoded_invalid_sig = abi.encodeWithSignature(
|
||||
"verifyProofff(bytes,uint256[])",
|
||||
hex"1234",
|
||||
instances
|
||||
);
|
||||
|
||||
vm.expectRevert("Invalid function signature");
|
||||
das.verifyWithDataAttestation(address(verifier), encoded_invalid_sig);
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,12 @@
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
#[cfg(test)]
|
||||
mod native_tests {
|
||||
|
||||
|
||||
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
|
||||
use ezkl::Commitments;
|
||||
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
|
||||
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
|
||||
use ezkl::pfsys::Snark;
|
||||
use ezkl::Commitments;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2curves::bn256::Bn256;
|
||||
use lazy_static::lazy_static;
|
||||
@@ -522,7 +522,7 @@ mod native_tests {
|
||||
use crate::native_tests::run_js_tests;
|
||||
use crate::native_tests::render_circuit;
|
||||
use crate::native_tests::model_serialization_different_binaries;
|
||||
|
||||
|
||||
use tempdir::TempDir;
|
||||
use ezkl::Commitments;
|
||||
|
||||
@@ -2293,7 +2293,12 @@ mod native_tests {
|
||||
.expect("failed to execute process");
|
||||
|
||||
if status.success() {
|
||||
log::error!("Verification unexpectedly succeeded for modified proof {}. Flipped bit {} in byte {}", i, random_bit, random_byte);
|
||||
log::error!(
|
||||
"Verification unexpectedly succeeded for modified proof {}. Flipped bit {} in byte {}",
|
||||
i,
|
||||
random_bit,
|
||||
random_byte
|
||||
);
|
||||
}
|
||||
|
||||
assert!(
|
||||
@@ -2435,23 +2440,43 @@ mod native_tests {
|
||||
));
|
||||
}
|
||||
input.save(data_path.clone().into()).unwrap();
|
||||
let args = vec![
|
||||
"setup-test-evm-data",
|
||||
"-D",
|
||||
data_path.as_str(),
|
||||
"-M",
|
||||
&model_path,
|
||||
"--test-data",
|
||||
test_on_chain_data_path.as_str(),
|
||||
rpc_arg.as_str(),
|
||||
test_input_source.as_str(),
|
||||
test_output_source.as_str(),
|
||||
];
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
.args([
|
||||
"setup-test-evm-data",
|
||||
"-D",
|
||||
data_path.as_str(),
|
||||
"-M",
|
||||
&model_path,
|
||||
"--test-data",
|
||||
test_on_chain_data_path.as_str(),
|
||||
rpc_arg.as_str(),
|
||||
test_input_source.as_str(),
|
||||
test_output_source.as_str(),
|
||||
])
|
||||
.args(args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
// generate the witness, passing the vk path to generate the necessary kzg commits only
|
||||
// if input visibility is NOT hashed
|
||||
if input_visibility != "hashed" {
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
.args([
|
||||
"gen-witness",
|
||||
"-D",
|
||||
&test_on_chain_data_path,
|
||||
"-M",
|
||||
&model_path,
|
||||
"-O",
|
||||
&witness_path,
|
||||
"--vk-path",
|
||||
&format!("{}/{}/key.vk", test_dir, example_name),
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
}
|
||||
}
|
||||
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
@@ -2594,56 +2619,6 @@ mod native_tests {
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
// Create a new set of test on chain data only for the on-chain input source
|
||||
if input_source != "file" || output_source != "file" {
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
.args([
|
||||
"setup-test-evm-data",
|
||||
"-D",
|
||||
data_path.as_str(),
|
||||
"-M",
|
||||
&model_path,
|
||||
"--test-data",
|
||||
test_on_chain_data_path.as_str(),
|
||||
rpc_arg.as_str(),
|
||||
test_input_source.as_str(),
|
||||
test_output_source.as_str(),
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
|
||||
assert!(status.success());
|
||||
|
||||
let deployed_addr_arg = format!("--addr={}", addr_da);
|
||||
|
||||
let args: Vec<&str> = vec![
|
||||
"test-update-account-calls",
|
||||
deployed_addr_arg.as_str(),
|
||||
"-D",
|
||||
test_on_chain_data_path.as_str(),
|
||||
rpc_arg.as_str(),
|
||||
];
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
|
||||
assert!(status.success());
|
||||
}
|
||||
// As sanity check, add example that should fail.
|
||||
let args = vec![
|
||||
"verify-evm",
|
||||
"--proof-path",
|
||||
PF_FAILURE,
|
||||
deployed_addr_verifier_arg.as_str(),
|
||||
deployed_addr_da_arg.as_str(),
|
||||
rpc_arg.as_str(),
|
||||
];
|
||||
let status = Command::new(format!("{}/{}", *CARGO_TARGET_DIR, TEST_BINARY))
|
||||
.args(args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(!status.success());
|
||||
}
|
||||
|
||||
fn build_ezkl() {
|
||||
|
||||
Reference in New Issue
Block a user