mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 00:08:12 -05:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2be181db35 | ||
|
|
de9e3f2673 | ||
|
|
a1450f8df7 | ||
|
|
ea535e2ecd | ||
|
|
f8aa91ed08 | ||
|
|
a59e3780b2 | ||
|
|
345fb5672a | ||
|
|
70daaff2e4 |
@@ -1,4 +1,4 @@
|
||||
name: Build and Publish WASM<>JS Bindings
|
||||
name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
@@ -14,7 +14,7 @@ defaults:
|
||||
run:
|
||||
working-directory: .
|
||||
jobs:
|
||||
wasm-publish:
|
||||
publish-wasm-bindings:
|
||||
name: publish-wasm-bindings
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
@@ -174,3 +174,40 @@ jobs:
|
||||
npm publish
|
||||
env:
|
||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
|
||||
|
||||
in-browser-evm-ver-publish:
|
||||
name: publish-in-browser-evm-verifier-package
|
||||
needs: ["publish-wasm-bindings"]
|
||||
runs-on: ubuntu-latest
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Update version in package.json
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
|
||||
- name: Update @ezkljs/engine version in package.json
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
|
||||
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
|
||||
run: |
|
||||
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: "18.12.1"
|
||||
registry-url: "https://registry.npmjs.org"
|
||||
- name: Publish to npm
|
||||
run: |
|
||||
cd in-browser-evm-verifier
|
||||
npm install
|
||||
npm run build
|
||||
npm ci
|
||||
npm publish
|
||||
env:
|
||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
|
||||
16
.github/workflows/rust.yml
vendored
16
.github/workflows/rust.yml
vendored
@@ -303,12 +303,24 @@ jobs:
|
||||
with:
|
||||
node-version: "18.12.1"
|
||||
cache: "pnpm"
|
||||
- name: Install dependencies
|
||||
- name: Install dependencies for js tests and in-browser-evm-verifier package
|
||||
run: |
|
||||
pnpm install --no-frozen-lockfile
|
||||
pnpm install --dir ./in-browser-evm-verifier --no-frozen-lockfile
|
||||
env:
|
||||
CI: false
|
||||
NODE_ENV: development
|
||||
- name: Build wasm package for nodejs target.
|
||||
run: |
|
||||
wasm-pack build --release --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
|
||||
- name: Replace memory definition in nodejs
|
||||
run: |
|
||||
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" in-browser-evm-verifier/nodejs/ezkl.js
|
||||
- name: Build @ezkljs/verify package
|
||||
run: |
|
||||
cd in-browser-evm-verifier
|
||||
pnpm build:commonjs
|
||||
cd ..
|
||||
- name: Install solc
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
@@ -364,7 +376,7 @@ jobs:
|
||||
with:
|
||||
node-version: "18.12.1"
|
||||
cache: "pnpm"
|
||||
- name: Install dependencies
|
||||
- name: Install dependencies for js tests
|
||||
run: |
|
||||
pnpm install --no-frozen-lockfile
|
||||
env:
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -45,6 +45,7 @@ var/
|
||||
*.whl
|
||||
*.bak
|
||||
node_modules
|
||||
/dist
|
||||
timingData.json
|
||||
!tests/wasm/pk.key
|
||||
!tests/wasm/vk.key
|
||||
@@ -74,6 +74,10 @@ For more details visit the [docs](https://docs.ezkl.xyz).
|
||||
|
||||
Build the auto-generated rust documentation and open the docs in your browser locally. `cargo doc --open`
|
||||
|
||||
#### In-browser EVM verifier
|
||||
|
||||
As an alternative to running the native Halo2 verifier as a WASM binding in the browser, you can use the in-browser EVM verifier. The source code of which you can find in the `in-browser-evm-verifier` directory and a README with instructions on how to use it.
|
||||
|
||||
|
||||
### building the project 🔨
|
||||
|
||||
|
||||
48
examples/onnx/gather_nd/gen.py
Normal file
48
examples/onnx/gather_nd/gen.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from torch import nn
|
||||
import json
|
||||
import numpy as np
|
||||
import tf2onnx
|
||||
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.layers import *
|
||||
from tensorflow.keras.models import Model
|
||||
|
||||
|
||||
# gather_nd in tf then export to onnx
|
||||
|
||||
|
||||
|
||||
|
||||
x = in1 = Input((15, 18,))
|
||||
w = in2 = Input((15, 1), dtype=tf.int32)
|
||||
x = tf.gather_nd(x, w, batch_dims=1)
|
||||
tm = Model((in1, in2), x )
|
||||
tm.summary()
|
||||
tm.compile(optimizer='adam', loss='mse')
|
||||
|
||||
shape = [1, 15, 18]
|
||||
index_shape = [1, 15, 1]
|
||||
# After training, export to onnx (network.onnx) and create a data file (input.json)
|
||||
x = 0.1*np.random.rand(1,*shape)
|
||||
# w = random int tensor
|
||||
w = np.random.randint(0, 10, index_shape)
|
||||
|
||||
spec = tf.TensorSpec(shape, tf.float32, name='input_0')
|
||||
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')
|
||||
|
||||
model_path = "network.onnx"
|
||||
|
||||
tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)
|
||||
|
||||
|
||||
d = x.reshape([-1]).tolist()
|
||||
d1 = w.reshape([-1]).tolist()
|
||||
|
||||
|
||||
data = dict(
|
||||
input_data=[d, d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/gather_nd/input.json
Normal file
1
examples/onnx/gather_nd/input.json
Normal file
File diff suppressed because one or more lines are too long
BIN
examples/onnx/gather_nd/network.onnx
Normal file
BIN
examples/onnx/gather_nd/network.onnx
Normal file
Binary file not shown.
76
examples/onnx/scatter_nd/gen.py
Normal file
76
examples/onnx/scatter_nd/gen.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import sys
|
||||
import json
|
||||
|
||||
sys.path.append("..")
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Just one Linear layer
|
||||
"""
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.seq_len = configs.seq_len
|
||||
self.pred_len = configs.pred_len
|
||||
|
||||
# Use this line if you want to visualize the weights
|
||||
# self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
|
||||
self.channels = configs.enc_in
|
||||
self.individual = configs.individual
|
||||
if self.individual:
|
||||
self.Linear = nn.ModuleList()
|
||||
for i in range(self.channels):
|
||||
self.Linear.append(nn.Linear(self.seq_len,self.pred_len))
|
||||
else:
|
||||
self.Linear = nn.Linear(self.seq_len, self.pred_len)
|
||||
|
||||
def forward(self, x):
|
||||
# x: [Batch, Input length, Channel]
|
||||
if self.individual:
|
||||
output = torch.zeros([x.size(0),self.pred_len,x.size(2)],dtype=x.dtype).to(x.device)
|
||||
for i in range(self.channels):
|
||||
output[:,:,i] = self.Linear[i](x[:,:,i])
|
||||
x = output
|
||||
else:
|
||||
x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
|
||||
return x # [Batch, Output length, Channel]
|
||||
|
||||
class Configs:
|
||||
def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
|
||||
self.seq_len = seq_len
|
||||
self.pred_len = pred_len
|
||||
self.enc_in = enc_in
|
||||
self.individual = individual
|
||||
|
||||
model = 'Linear'
|
||||
seq_len = 10
|
||||
pred_len = 4
|
||||
enc_in = 3
|
||||
|
||||
configs = Configs(seq_len, pred_len, enc_in, True)
|
||||
circuit = Model(configs)
|
||||
|
||||
x = torch.randn(1, seq_len, pred_len)
|
||||
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=15, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
# the model's input names
|
||||
input_names=['input'],
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/scatter_nd/input.json
Normal file
1
examples/onnx/scatter_nd/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}
|
||||
BIN
examples/onnx/scatter_nd/network.onnx
Normal file
BIN
examples/onnx/scatter_nd/network.onnx
Normal file
Binary file not shown.
60
in-browser-evm-verifier/README.md
Normal file
60
in-browser-evm-verifier/README.md
Normal file
@@ -0,0 +1,60 @@
|
||||
# inbrowser-evm-verify
|
||||
|
||||
We would like the Solidity verifier to be canonical and usually all you ever need. For this, we need to be able to run that verifier in browser.
|
||||
|
||||
## How to use (Node js)
|
||||
|
||||
```ts
|
||||
import localEVMVerify from '@ezkljs/verify';
|
||||
|
||||
// Load in the proof file as a buffer
|
||||
const proofFileBuffer = fs.readFileSync(`${path}/${example}/proof.pf`)
|
||||
|
||||
// Stringified EZKL evm verifier bytecode (this is just an example don't use in production)
|
||||
const bytecode = '0x608060405234801561001057600080fd5b5060d38061001f6000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c8063cfae321714610046575b600080fd5b6100496100f1565b60405161005691906100f1565b60405180910390f35b'
|
||||
|
||||
const result = await localEVMVerify(proofFileBuffer, bytecode)
|
||||
|
||||
console.log('result', result)
|
||||
```
|
||||
|
||||
**Note**: Run `ezkl create-evm-verifier` to get the Solidity verifier, with which you can retrieve the bytecode once compiled. We recommend compiling to the Shanghai hardfork target, else you will have to pass an additional parameter specifying the EVM version to the `localEVMVerify` function like so (for Paris hardfork):
|
||||
|
||||
```ts
|
||||
import localEVMVerify, { hardfork } from '@ezkljs/verify';
|
||||
|
||||
const result = await localEVMVerify(proofFileBuffer, bytecode, hardfork['Paris'])
|
||||
```
|
||||
|
||||
**Note**: You can also verify separated vk verifiers using the `localEVMVerify` function. Just pass the vk verifier bytecode as the third parameter like so:
|
||||
```ts
|
||||
import localEVMVerify from '@ezkljs/verify';
|
||||
|
||||
const result = await localEVMVerify(proofFileBuffer, verifierBytecode, VKBytecode)
|
||||
```
|
||||
|
||||
|
||||
## How to use (Browser)
|
||||
|
||||
```ts
|
||||
import localEVMVerify from '@ezkljs/verify';
|
||||
|
||||
// Load in the proof file as a buffer using the web apis (fetch, FileReader, etc)
|
||||
// We use fetch in this example to load the proof file as a buffer
|
||||
const proofFileBuffer = await fetch(`${path}/${example}/proof.pf`).then(res => res.arrayBuffer())
|
||||
|
||||
// Stringified EZKL evm verifier bytecode (this is just an example don't use in production)
|
||||
const bytecode = '0x608060405234801561001057600080fd5b5060d38061001f6000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c8063cfae321714610046575b600080fd5b6100496100f1565b60405161005691906100f1565b60405180910390f35b'
|
||||
|
||||
const result = await browserEVMVerify(proofFileBuffer, bytecode)
|
||||
|
||||
console.log('result', result)
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```ts
|
||||
result: true
|
||||
```
|
||||
|
||||
|
||||
42
in-browser-evm-verifier/package.json
Normal file
42
in-browser-evm-verifier/package.json
Normal file
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"name": "@ezkljs/verify",
|
||||
"version": "0.0.0",
|
||||
"publishConfig": {
|
||||
"access": "public"
|
||||
},
|
||||
"description": "Evm verify EZKL proofs in the browser.",
|
||||
"main": "dist/commonjs/index.js",
|
||||
"module": "dist/esm/index.js",
|
||||
"types": "dist/commonjs/index.d.ts",
|
||||
"files": [
|
||||
"dist",
|
||||
"LICENSE",
|
||||
"README.md"
|
||||
],
|
||||
"scripts": {
|
||||
"clean": "rm -r dist || true",
|
||||
"build:commonjs": "tsc --project tsconfig.commonjs.json && resolve-tspaths -p tsconfig.commonjs.json",
|
||||
"build:esm": "tsc --project tsconfig.esm.json && resolve-tspaths -p tsconfig.esm.json",
|
||||
"build": "pnpm run clean && pnpm run build:commonjs && pnpm run build:esm"
|
||||
},
|
||||
"dependencies": {
|
||||
"@ethereumjs/common": "^4.0.0",
|
||||
"@ethereumjs/evm": "^2.0.0",
|
||||
"@ethereumjs/statemanager": "^2.0.0",
|
||||
"@ethereumjs/tx": "^5.0.0",
|
||||
"@ethereumjs/util": "^9.0.0",
|
||||
"@ethereumjs/vm": "^7.0.0",
|
||||
"@ethersproject/abi": "^5.7.0",
|
||||
"@ezkljs/engine": "^9.4.4",
|
||||
"ethers": "^6.7.1",
|
||||
"json-bigint": "^1.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.8.3",
|
||||
"ts-loader": "^9.5.0",
|
||||
"ts-node": "^10.9.1",
|
||||
"resolve-tspaths": "^0.8.16",
|
||||
"tsconfig-paths": "^4.2.0",
|
||||
"typescript": "^5.2.2"
|
||||
}
|
||||
}
|
||||
1479
in-browser-evm-verifier/pnpm-lock.yaml
generated
Normal file
1479
in-browser-evm-verifier/pnpm-lock.yaml
generated
Normal file
File diff suppressed because it is too large
Load Diff
145
in-browser-evm-verifier/src/index.ts
Normal file
145
in-browser-evm-verifier/src/index.ts
Normal file
@@ -0,0 +1,145 @@
|
||||
import { defaultAbiCoder as AbiCoder } from '@ethersproject/abi'
|
||||
import { Address, hexToBytes } from '@ethereumjs/util'
|
||||
import { Chain, Common, Hardfork } from '@ethereumjs/common'
|
||||
import { LegacyTransaction, LegacyTxData } from '@ethereumjs/tx'
|
||||
// import { DefaultStateManager } from '@ethereumjs/statemanager'
|
||||
// import { Blockchain } from '@ethereumjs/blockchain'
|
||||
import { VM } from '@ethereumjs/vm'
|
||||
import { EVM } from '@ethereumjs/evm'
|
||||
import { buildTransaction, encodeDeployment } from './utils/tx-builder'
|
||||
import { getAccountNonce, insertAccount } from './utils/account-utils'
|
||||
import { encodeVerifierCalldata } from '../nodejs/ezkl';
|
||||
import { error } from 'console'
|
||||
|
||||
async function deployContract(
|
||||
vm: VM,
|
||||
common: Common,
|
||||
senderPrivateKey: Uint8Array,
|
||||
deploymentBytecode: string
|
||||
): Promise<Address> {
|
||||
// Contracts are deployed by sending their deployment bytecode to the address 0
|
||||
// The contract params should be abi-encoded and appended to the deployment bytecode.
|
||||
// const data =
|
||||
const data = encodeDeployment(deploymentBytecode)
|
||||
const txData = {
|
||||
data,
|
||||
nonce: await getAccountNonce(vm, senderPrivateKey),
|
||||
}
|
||||
|
||||
const tx = LegacyTransaction.fromTxData(
|
||||
buildTransaction(txData) as LegacyTxData,
|
||||
{ common, allowUnlimitedInitCodeSize: true },
|
||||
).sign(senderPrivateKey)
|
||||
|
||||
const deploymentResult = await vm.runTx({
|
||||
tx,
|
||||
skipBlockGasLimitValidation: true,
|
||||
skipNonce: true
|
||||
})
|
||||
|
||||
if (deploymentResult.execResult.exceptionError) {
|
||||
throw deploymentResult.execResult.exceptionError
|
||||
}
|
||||
|
||||
return deploymentResult.createdAddress!
|
||||
}
|
||||
|
||||
async function verify(
|
||||
vm: VM,
|
||||
contractAddress: Address,
|
||||
caller: Address,
|
||||
proof: Uint8Array | Uint8ClampedArray,
|
||||
vkAddress?: Address | Uint8Array,
|
||||
): Promise<boolean> {
|
||||
if (proof instanceof Uint8Array) {
|
||||
proof = new Uint8ClampedArray(proof.buffer)
|
||||
}
|
||||
if (vkAddress) {
|
||||
const vkAddressBytes = hexToBytes(vkAddress.toString())
|
||||
const vkAddressArray = Array.from(vkAddressBytes)
|
||||
|
||||
let string = JSON.stringify(vkAddressArray)
|
||||
|
||||
const uint8Array = new TextEncoder().encode(string);
|
||||
|
||||
// Step 3: Convert to Uint8ClampedArray
|
||||
vkAddress = new Uint8Array(uint8Array.buffer);
|
||||
|
||||
// convert uitn8array of length
|
||||
error('vkAddress', vkAddress)
|
||||
}
|
||||
const data = encodeVerifierCalldata(proof, vkAddress)
|
||||
|
||||
const verifyResult = await vm.evm.runCall({
|
||||
to: contractAddress,
|
||||
caller: caller,
|
||||
origin: caller, // The tx.origin is also the caller here
|
||||
data: data,
|
||||
})
|
||||
|
||||
if (verifyResult.execResult.exceptionError) {
|
||||
throw verifyResult.execResult.exceptionError
|
||||
}
|
||||
|
||||
const results = AbiCoder.decode(['bool'], verifyResult.execResult.returnValue)
|
||||
|
||||
return results[0]
|
||||
}
|
||||
|
||||
/**
|
||||
* Spins up an ephemeral EVM instance for executing the bytecode of a solidity verifier
|
||||
* @param proof Json serialized proof file
|
||||
* @param bytecode The bytecode of a compiled solidity verifier.
|
||||
* @param bytecode_vk The bytecode of a contract that stores the vk. (Optional, only required if the vk is stored in a separate contract)
|
||||
* @param evmVersion The evm version to use for the verification. (Default: London)
|
||||
* @returns The result of the evm verification.
|
||||
* @throws If the verify transaction reverts
|
||||
*/
|
||||
export default async function localEVMVerify(
|
||||
proof: Uint8Array | Uint8ClampedArray,
|
||||
bytecode_verifier: string,
|
||||
bytecode_vk?: string,
|
||||
evmVersion?: Hardfork,
|
||||
): Promise<boolean> {
|
||||
try {
|
||||
const hardfork = evmVersion ? evmVersion : Hardfork['Shanghai']
|
||||
const common = new Common({ chain: Chain.Mainnet, hardfork })
|
||||
const accountPk = hexToBytes(
|
||||
'0xe331b6d69882b4cb4ea581d88e0b604039a3de5967688d3dcffdd2270c0fd109', // anvil deterministic Pk
|
||||
)
|
||||
|
||||
const evm = new EVM({
|
||||
allowUnlimitedContractSize: true,
|
||||
allowUnlimitedInitCodeSize: true,
|
||||
})
|
||||
|
||||
const vm = await VM.create({ common, evm })
|
||||
const accountAddress = Address.fromPrivateKey(accountPk)
|
||||
|
||||
await insertAccount(vm, accountAddress)
|
||||
|
||||
const verifierAddress = await deployContract(
|
||||
vm,
|
||||
common,
|
||||
accountPk,
|
||||
bytecode_verifier
|
||||
)
|
||||
|
||||
if (bytecode_vk) {
|
||||
const accountPk = hexToBytes("0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80"); // anvil deterministic Pk
|
||||
const accountAddress = Address.fromPrivateKey(accountPk)
|
||||
await insertAccount(vm, accountAddress)
|
||||
const output = await deployContract(vm, common, accountPk, bytecode_vk)
|
||||
const result = await verify(vm, verifierAddress, accountAddress, proof, output)
|
||||
return true
|
||||
}
|
||||
|
||||
const result = await verify(vm, verifierAddress, accountAddress, proof)
|
||||
|
||||
return result
|
||||
} catch (error) {
|
||||
// log or re-throw the error, depending on your needs
|
||||
console.error('An error occurred:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
32
in-browser-evm-verifier/src/utils/account-utils.ts
Normal file
32
in-browser-evm-verifier/src/utils/account-utils.ts
Normal file
@@ -0,0 +1,32 @@
|
||||
import { VM } from '@ethereumjs/vm'
|
||||
import { Account, Address } from '@ethereumjs/util'
|
||||
|
||||
export const keyPair = {
|
||||
secretKey:
|
||||
'0x3cd7232cd6f3fc66a57a6bedc1a8ed6c228fff0a327e169c2bcc5e869ed49511',
|
||||
publicKey:
|
||||
'0x0406cc661590d48ee972944b35ad13ff03c7876eae3fd191e8a2f77311b0a3c6613407b5005e63d7d8d76b89d5f900cde691497688bb281e07a5052ff61edebdc0',
|
||||
}
|
||||
|
||||
export const insertAccount = async (vm: VM, address: Address) => {
|
||||
const acctData = {
|
||||
nonce: 0,
|
||||
balance: BigInt('1000000000000000000'), // 1 eth
|
||||
}
|
||||
const account = Account.fromAccountData(acctData)
|
||||
|
||||
await vm.stateManager.putAccount(address, account)
|
||||
}
|
||||
|
||||
export const getAccountNonce = async (
|
||||
vm: VM,
|
||||
accountPrivateKey: Uint8Array,
|
||||
) => {
|
||||
const address = Address.fromPrivateKey(accountPrivateKey)
|
||||
const account = await vm.stateManager.getAccount(address)
|
||||
if (account) {
|
||||
return account.nonce
|
||||
} else {
|
||||
return BigInt(0)
|
||||
}
|
||||
}
|
||||
59
in-browser-evm-verifier/src/utils/tx-builder.ts
Normal file
59
in-browser-evm-verifier/src/utils/tx-builder.ts
Normal file
@@ -0,0 +1,59 @@
|
||||
import { Interface, defaultAbiCoder as AbiCoder } from '@ethersproject/abi'
|
||||
import {
|
||||
AccessListEIP2930TxData,
|
||||
FeeMarketEIP1559TxData,
|
||||
TxData,
|
||||
} from '@ethereumjs/tx'
|
||||
|
||||
type TransactionsData =
|
||||
| TxData
|
||||
| AccessListEIP2930TxData
|
||||
| FeeMarketEIP1559TxData
|
||||
|
||||
export const encodeFunction = (
|
||||
method: string,
|
||||
params?: {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
types: any[]
|
||||
values: unknown[]
|
||||
},
|
||||
): string => {
|
||||
const parameters = params?.types ?? []
|
||||
const methodWithParameters = `function ${method}(${parameters.join(',')})`
|
||||
const signatureHash = new Interface([methodWithParameters]).getSighash(method)
|
||||
const encodedArgs = AbiCoder.encode(parameters, params?.values ?? [])
|
||||
|
||||
return signatureHash + encodedArgs.slice(2)
|
||||
}
|
||||
|
||||
export const encodeDeployment = (
|
||||
bytecode: string,
|
||||
params?: {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
types: any[]
|
||||
values: unknown[]
|
||||
},
|
||||
) => {
|
||||
const deploymentData = '0x' + bytecode
|
||||
if (params) {
|
||||
const argumentsEncoded = AbiCoder.encode(params.types, params.values)
|
||||
return deploymentData + argumentsEncoded.slice(2)
|
||||
}
|
||||
return deploymentData
|
||||
}
|
||||
|
||||
export const buildTransaction = (
|
||||
data: Partial<TransactionsData>,
|
||||
): TransactionsData => {
|
||||
const defaultData: Partial<TransactionsData> = {
|
||||
gasLimit: 3_000_000_000_000_000,
|
||||
gasPrice: 7,
|
||||
value: 0,
|
||||
data: '0x',
|
||||
}
|
||||
|
||||
return {
|
||||
...defaultData,
|
||||
...data,
|
||||
}
|
||||
}
|
||||
7
in-browser-evm-verifier/tsconfig.commonjs.json
Normal file
7
in-browser-evm-verifier/tsconfig.commonjs.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"extends": "./tsconfig.json",
|
||||
"compilerOptions": {
|
||||
"module": "CommonJS",
|
||||
"outDir": "./dist/commonjs"
|
||||
}
|
||||
}
|
||||
7
in-browser-evm-verifier/tsconfig.esm.json
Normal file
7
in-browser-evm-verifier/tsconfig.esm.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"extends": "./tsconfig.json",
|
||||
"compilerOptions": {
|
||||
"module": "ES2020",
|
||||
"outDir": "./dist/esm"
|
||||
}
|
||||
}
|
||||
62
in-browser-evm-verifier/tsconfig.json
Normal file
62
in-browser-evm-verifier/tsconfig.json
Normal file
@@ -0,0 +1,62 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"rootDir": "src",
|
||||
"target": "es2017",
|
||||
"outDir": "dist",
|
||||
"declaration": true,
|
||||
"lib": [
|
||||
"dom",
|
||||
"dom.iterable",
|
||||
"esnext"
|
||||
],
|
||||
"allowJs": true,
|
||||
"checkJs": true,
|
||||
"skipLibCheck": true,
|
||||
"strict": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"noEmit": false,
|
||||
"esModuleInterop": true,
|
||||
"module": "CommonJS",
|
||||
"moduleResolution": "node",
|
||||
"resolveJsonModule": true,
|
||||
"isolatedModules": true,
|
||||
"jsx": "preserve",
|
||||
// "incremental": true,
|
||||
"noUncheckedIndexedAccess": true,
|
||||
"baseUrl": ".",
|
||||
"paths": {
|
||||
"@/*": [
|
||||
"./src/*"
|
||||
]
|
||||
}
|
||||
},
|
||||
"include": [
|
||||
"src/**/*.ts",
|
||||
"src/**/*.tsx",
|
||||
"src/**/*.cjs",
|
||||
"src/**/*.mjs"
|
||||
],
|
||||
"exclude": [
|
||||
"node_modules"
|
||||
],
|
||||
// NEW: Options for file/directory watching
|
||||
"watchOptions": {
|
||||
// Use native file system events for files and directories
|
||||
"watchFile": "useFsEvents",
|
||||
"watchDirectory": "useFsEvents",
|
||||
// Poll files for updates more frequently
|
||||
// when they're updated a lot.
|
||||
"fallbackPolling": "dynamicPriority",
|
||||
// Don't coalesce watch notification
|
||||
"synchronousWatchDirectory": true,
|
||||
// Finally, two additional settings for reducing the amount of possible
|
||||
// files to track work from these directories
|
||||
"excludeDirectories": [
|
||||
"**/node_modules",
|
||||
"_build"
|
||||
],
|
||||
"excludeFiles": [
|
||||
"build/fileWhichChangesOften.ts"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,7 @@
|
||||
"test": "jest"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@ezkljs/engine": "^2.4.5",
|
||||
"@ezkljs/engine": "^9.4.4",
|
||||
"@ezkljs/verify": "^0.0.6",
|
||||
"@jest/types": "^29.6.3",
|
||||
"@types/file-saver": "^2.0.5",
|
||||
@@ -27,4 +27,4 @@
|
||||
"tsconfig-paths": "^4.2.0",
|
||||
"typescript": "5.1.6"
|
||||
}
|
||||
}
|
||||
}
|
||||
11
pnpm-lock.yaml
generated
11
pnpm-lock.yaml
generated
@@ -6,8 +6,8 @@ settings:
|
||||
|
||||
devDependencies:
|
||||
'@ezkljs/engine':
|
||||
specifier: ^2.4.5
|
||||
version: 2.4.5
|
||||
specifier: ^9.4.4
|
||||
version: 9.4.4
|
||||
'@ezkljs/verify':
|
||||
specifier: ^0.0.6
|
||||
version: 0.0.6(buffer@6.0.3)
|
||||
@@ -785,6 +785,13 @@ packages:
|
||||
json-bigint: 1.0.0
|
||||
dev: true
|
||||
|
||||
/@ezkljs/engine@9.4.4:
|
||||
resolution: {integrity: sha512-kNsTmDQa8mIiQ6yjJmBMwVgAAxh4nfs4NCtnewJifonyA8Mfhs+teXwwW8WhERRDoQPUofKO2pT8BPvV/XGIDA==}
|
||||
dependencies:
|
||||
'@types/json-bigint': 1.0.1
|
||||
json-bigint: 1.0.0
|
||||
dev: true
|
||||
|
||||
/@ezkljs/verify@0.0.6(buffer@6.0.3):
|
||||
resolution: {integrity: sha512-9DHoEhLKl1DBGuUVseXLThuMyYceY08Zymr/OsLH0zbdA9OoISYhb77j4QPm4ANRKEm5dCi8oHDqkwGbFc2xFQ==}
|
||||
dependencies:
|
||||
|
||||
@@ -17,7 +17,6 @@ pub enum BaseOp {
|
||||
Sub,
|
||||
SumInit,
|
||||
Sum,
|
||||
IsZero,
|
||||
IsBoolean,
|
||||
}
|
||||
|
||||
@@ -35,7 +34,6 @@ impl BaseOp {
|
||||
BaseOp::Add => a + b,
|
||||
BaseOp::Sub => a - b,
|
||||
BaseOp::Mult => a * b,
|
||||
BaseOp::IsZero => b,
|
||||
BaseOp::IsBoolean => b,
|
||||
_ => panic!("nonaccum_f called on accumulating operation"),
|
||||
}
|
||||
@@ -76,7 +74,6 @@ impl BaseOp {
|
||||
BaseOp::Mult => "MULT",
|
||||
BaseOp::Sum => "SUM",
|
||||
BaseOp::SumInit => "SUMINIT",
|
||||
BaseOp::IsZero => "ISZERO",
|
||||
BaseOp::IsBoolean => "ISBOOLEAN",
|
||||
}
|
||||
}
|
||||
@@ -93,7 +90,6 @@ impl BaseOp {
|
||||
BaseOp::Mult => (0, 1),
|
||||
BaseOp::Sum => (-1, 2),
|
||||
BaseOp::SumInit => (0, 1),
|
||||
BaseOp::IsZero => (0, 1),
|
||||
BaseOp::IsBoolean => (0, 1),
|
||||
}
|
||||
}
|
||||
@@ -110,7 +106,6 @@ impl BaseOp {
|
||||
BaseOp::Mult => 2,
|
||||
BaseOp::Sum => 1,
|
||||
BaseOp::SumInit => 1,
|
||||
BaseOp::IsZero => 0,
|
||||
BaseOp::IsBoolean => 0,
|
||||
}
|
||||
}
|
||||
@@ -127,7 +122,6 @@ impl BaseOp {
|
||||
BaseOp::SumInit => 0,
|
||||
BaseOp::CumProd => 1,
|
||||
BaseOp::CumProdInit => 0,
|
||||
BaseOp::IsZero => 0,
|
||||
BaseOp::IsBoolean => 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -387,7 +387,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector());
|
||||
nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector());
|
||||
nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector());
|
||||
nonaccum_selectors.insert((BaseOp::IsZero, i, j), meta.selector());
|
||||
nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector());
|
||||
}
|
||||
}
|
||||
@@ -432,12 +431,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
|
||||
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
|
||||
}
|
||||
BaseOp::IsZero => {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
|
||||
.expect("non accum: output query failed");
|
||||
vec![expected_output[base_op.constraint_idx()].clone()]
|
||||
}
|
||||
_ => {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -14,10 +14,17 @@ pub enum PolyOp {
|
||||
dim: usize,
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
},
|
||||
GatherND {
|
||||
batch_dims: usize,
|
||||
indices: Option<Tensor<usize>>,
|
||||
},
|
||||
ScatterElements {
|
||||
dim: usize,
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
},
|
||||
ScatterND {
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
},
|
||||
MultiBroadcastTo {
|
||||
shape: Vec<usize>,
|
||||
},
|
||||
@@ -60,8 +67,6 @@ pub enum PolyOp {
|
||||
len_prod: usize,
|
||||
},
|
||||
Pow(u32),
|
||||
Pack(u32, u32),
|
||||
GlobalSumPool,
|
||||
Concat {
|
||||
axis: usize,
|
||||
},
|
||||
@@ -91,7 +96,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
fn as_string(&self) -> String {
|
||||
match &self {
|
||||
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
|
||||
PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
|
||||
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
|
||||
PolyOp::ScatterND { .. } => "SCATTERND".into(),
|
||||
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
|
||||
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
|
||||
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
|
||||
@@ -110,8 +117,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
PolyOp::Sum { .. } => "SUM".into(),
|
||||
PolyOp::Prod { .. } => "PROD".into(),
|
||||
PolyOp::Pow(_) => "POW".into(),
|
||||
PolyOp::Pack(_, _) => "PACK".into(),
|
||||
PolyOp::GlobalSumPool => "GLOBALSUMPOOL".into(),
|
||||
PolyOp::Conv { .. } => "CONV".into(),
|
||||
PolyOp::DeConv { .. } => "DECONV".into(),
|
||||
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
|
||||
@@ -181,13 +186,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
output_padding,
|
||||
stride,
|
||||
} => tensor::ops::deconv(&inputs, *padding, *output_padding, *stride),
|
||||
PolyOp::Pack(base, scale) => {
|
||||
if 1 != inputs.len() {
|
||||
return Err(TensorError::DimMismatch("pack inputs".to_string()));
|
||||
}
|
||||
|
||||
tensor::ops::pack(&inputs[0], F::from(*base as u64), *scale)
|
||||
}
|
||||
PolyOp::Pow(u) => {
|
||||
if 1 != inputs.len() {
|
||||
return Err(TensorError::DimMismatch("pow inputs".to_string()));
|
||||
@@ -206,7 +204,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
}
|
||||
tensor::ops::prod_axes(&inputs[0], axes)
|
||||
}
|
||||
PolyOp::GlobalSumPool => unreachable!(),
|
||||
PolyOp::Concat { axis } => {
|
||||
tensor::ops::concat(&inputs.iter().collect::<Vec<_>>(), *axis)
|
||||
}
|
||||
@@ -225,6 +222,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
};
|
||||
tensor::ops::gather_elements(&x, &y, *dim)
|
||||
}
|
||||
PolyOp::GatherND {
|
||||
indices,
|
||||
batch_dims,
|
||||
} => {
|
||||
let x = inputs[0].clone();
|
||||
let y = if let Some(idx) = indices {
|
||||
idx.clone()
|
||||
} else {
|
||||
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
|
||||
};
|
||||
tensor::ops::gather_nd(&x, &y, *batch_dims)
|
||||
}
|
||||
PolyOp::ScatterElements { dim, constant_idx } => {
|
||||
let x = inputs[0].clone();
|
||||
|
||||
@@ -241,6 +250,21 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
};
|
||||
tensor::ops::scatter(&x, &idx, &src, *dim)
|
||||
}
|
||||
|
||||
PolyOp::ScatterND { constant_idx } => {
|
||||
let x = inputs[0].clone();
|
||||
let idx = if let Some(idx) = constant_idx {
|
||||
idx.clone()
|
||||
} else {
|
||||
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
|
||||
};
|
||||
let src = if constant_idx.is_some() {
|
||||
inputs[1].clone()
|
||||
} else {
|
||||
inputs[2].clone()
|
||||
};
|
||||
tensor::ops::scatter_nd(&x, &idx, &src)
|
||||
}
|
||||
}?;
|
||||
|
||||
Ok(ForwardResult { output: res })
|
||||
@@ -288,7 +312,17 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
|
||||
} else {
|
||||
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
|
||||
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
|
||||
}
|
||||
}
|
||||
PolyOp::GatherND {
|
||||
batch_dims,
|
||||
indices,
|
||||
} => {
|
||||
if let Some(idx) = indices {
|
||||
tensor::ops::gather_nd(values[0].get_inner_tensor()?, idx, *batch_dims)?.into()
|
||||
} else {
|
||||
layouts::gather_nd(config, region, values[..].try_into()?, *batch_dims)?.0
|
||||
}
|
||||
}
|
||||
PolyOp::ScatterElements { dim, constant_idx } => {
|
||||
@@ -304,6 +338,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
|
||||
}
|
||||
}
|
||||
PolyOp::ScatterND { constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::scatter_nd(
|
||||
values[0].get_inner_tensor()?,
|
||||
idx,
|
||||
values[1].get_inner_tensor()?,
|
||||
)?
|
||||
.into()
|
||||
} else {
|
||||
layouts::scatter_nd(config, region, values[..].try_into()?)?
|
||||
}
|
||||
}
|
||||
PolyOp::DeConv {
|
||||
padding,
|
||||
output_padding,
|
||||
@@ -334,10 +380,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
input
|
||||
}
|
||||
PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp)?,
|
||||
PolyOp::Pack(base, scale) => {
|
||||
layouts::pack(config, region, values[..].try_into()?, *base, *scale)?
|
||||
}
|
||||
PolyOp::GlobalSumPool => unreachable!(),
|
||||
PolyOp::Concat { axis } => layouts::concat(values[..].try_into()?, axis)?,
|
||||
PolyOp::Slice { axis, start, end } => {
|
||||
layouts::slice(config, region, values[..].try_into()?, axis, start, end)?
|
||||
@@ -405,7 +447,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
vec![1, 2]
|
||||
} else if matches!(self, PolyOp::Concat { .. }) {
|
||||
(0..100).collect()
|
||||
} else if matches!(self, PolyOp::ScatterElements { .. }) {
|
||||
} else if matches!(self, PolyOp::ScatterElements { .. })
|
||||
| matches!(self, PolyOp::ScatterND { .. })
|
||||
{
|
||||
vec![0, 2]
|
||||
} else {
|
||||
vec![]
|
||||
|
||||
@@ -2243,75 +2243,6 @@ mod pow {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod pack {
|
||||
use super::*;
|
||||
|
||||
const K: usize = 8;
|
||||
const LEN: usize = 4;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
inputs: [ValTensor<F>; 1],
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for MyCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let b = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
let output = VarTensor::new_advice(cs, K, 1, LEN);
|
||||
|
||||
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
Box::new(PolyOp::Pack(2, 1)),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn packcircuit() {
|
||||
// parameters
|
||||
let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));
|
||||
|
||||
let circuit = MyCircuit::<F> {
|
||||
inputs: [ValTensor::from(a)],
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod matmul_relu {
|
||||
use super::*;
|
||||
|
||||
@@ -402,9 +402,6 @@ pub enum Commands {
|
||||
/// Number of logrows to use for srs. Overrides settings_path if specified.
|
||||
#[arg(long, default_value = None)]
|
||||
logrows: Option<u32>,
|
||||
/// Check mode for SRS. Verifies downloaded srs is valid. Set to unsafe for speed.
|
||||
#[arg(long, default_value = DEFAULT_CHECKMODE)]
|
||||
check: CheckMode,
|
||||
},
|
||||
/// Loads model and input and runs mock prover (for testing)
|
||||
Mock {
|
||||
|
||||
@@ -159,8 +159,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
|
||||
srs_path,
|
||||
settings_path,
|
||||
logrows,
|
||||
check,
|
||||
} => get_srs_cmd(srs_path, settings_path, logrows, check).await,
|
||||
} => get_srs_cmd(srs_path, settings_path, logrows).await,
|
||||
Commands::Table { model, args } => table(model, args),
|
||||
#[cfg(feature = "render")]
|
||||
Commands::RenderCircuit {
|
||||
@@ -492,23 +491,28 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
|
||||
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, Box<dyn Error>> {
|
||||
use std::io::Read;
|
||||
|
||||
let path = get_srs_path(logrows, srs_path);
|
||||
let file = std::fs::File::open(path.clone())?;
|
||||
let file = std::fs::File::open(path)?;
|
||||
let mut reader = std::io::BufReader::new(file);
|
||||
let mut buffer = vec![];
|
||||
let mut reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
let bytes_read = reader.read_to_end(&mut buffer)?;
|
||||
|
||||
info!(
|
||||
"read {} bytes from SRS file (vector of len = {})",
|
||||
"read {} bytes from file (vector of len = {})",
|
||||
bytes_read,
|
||||
buffer.len()
|
||||
);
|
||||
|
||||
let hash = sha256::digest(buffer);
|
||||
info!("SRS hash: {}", hash);
|
||||
info!("file hash: {}", hash);
|
||||
|
||||
Ok(hash)
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
|
||||
let path = get_srs_path(logrows, srs_path);
|
||||
let hash = get_file_hash(&path)?;
|
||||
|
||||
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
|
||||
Some(h) => h,
|
||||
@@ -532,7 +536,6 @@ pub(crate) async fn get_srs_cmd(
|
||||
srs_path: Option<PathBuf>,
|
||||
settings_path: Option<PathBuf>,
|
||||
logrows: Option<u32>,
|
||||
check_mode: CheckMode,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
// logrows overrides settings
|
||||
|
||||
@@ -560,21 +563,20 @@ pub(crate) async fn get_srs_cmd(
|
||||
let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k);
|
||||
let mut reader = Cursor::new(fetch_srs(&srs_uri).await?);
|
||||
// check the SRS
|
||||
if matches!(check_mode, CheckMode::SAFE) {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
let pb = init_spinner();
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.set_message("Validating SRS (this may take a while) ...");
|
||||
ParamsKZG::<Bn256>::read(&mut reader)?;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.finish_with_message("SRS validated");
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
let pb = init_spinner();
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.set_message("Validating SRS (this may take a while) ...");
|
||||
let params = ParamsKZG::<Bn256>::read(&mut reader)?;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pb.finish_with_message("SRS validated.");
|
||||
|
||||
info!("Saving SRS to disk...");
|
||||
let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone()))?;
|
||||
|
||||
let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
|
||||
buffer.write_all(reader.get_ref())?;
|
||||
buffer.flush()?;
|
||||
params.write(&mut buffer)?;
|
||||
|
||||
info!("Saved SRS to disk.");
|
||||
|
||||
info!("SRS downloaded");
|
||||
} else {
|
||||
@@ -902,7 +904,12 @@ pub(crate) fn calibrate(
|
||||
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
|
||||
));
|
||||
|
||||
let key = (input_scale, param_scale, scale_rebase_multiplier);
|
||||
let key = (
|
||||
input_scale,
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
);
|
||||
forward_pass_res.insert(key, vec![]);
|
||||
|
||||
let local_run_args = RunArgs {
|
||||
@@ -971,29 +978,21 @@ pub(crate) fn calibrate(
|
||||
#[cfg(unix)]
|
||||
drop(_g);
|
||||
|
||||
let min_lookup_range = forward_pass_res
|
||||
.get(&key)
|
||||
.unwrap()
|
||||
let result = forward_pass_res.get(&key).ok_or("key not found")?;
|
||||
|
||||
let min_lookup_range = result
|
||||
.iter()
|
||||
.map(|x| x.min_lookup_inputs)
|
||||
.min()
|
||||
.unwrap_or(0);
|
||||
|
||||
let max_lookup_range = forward_pass_res
|
||||
.get(&key)
|
||||
.unwrap()
|
||||
let max_lookup_range = result
|
||||
.iter()
|
||||
.map(|x| x.max_lookup_inputs)
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
let max_range_size = forward_pass_res
|
||||
.get(&key)
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|x| x.max_range_size)
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
let max_range_size = result.iter().map(|x| x.max_range_size).max().unwrap_or(0);
|
||||
|
||||
let res = circuit.calc_min_logrows(
|
||||
(min_lookup_range, max_lookup_range),
|
||||
@@ -1114,6 +1113,7 @@ pub(crate) fn calibrate(
|
||||
best_params.run_args.input_scale,
|
||||
best_params.run_args.param_scale,
|
||||
best_params.run_args.scale_rebase_multiplier,
|
||||
best_params.run_args.div_rebasing,
|
||||
))
|
||||
.ok_or("no params found")?
|
||||
.iter()
|
||||
@@ -1697,7 +1697,7 @@ pub(crate) fn fuzz(
|
||||
let logrows = circuit.settings().run_args.logrows;
|
||||
|
||||
info!("setting up tests");
|
||||
|
||||
#[cfg(unix)]
|
||||
let _r = Gag::stdout()?;
|
||||
let params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
|
||||
|
||||
@@ -1715,6 +1715,7 @@ pub(crate) fn fuzz(
|
||||
let public_inputs = circuit.prepare_public_inputs(&data)?;
|
||||
|
||||
let strategy = KZGSingleStrategy::new(¶ms);
|
||||
#[cfg(unix)]
|
||||
std::mem::drop(_r);
|
||||
|
||||
info!("starting fuzzing");
|
||||
@@ -1905,6 +1906,7 @@ pub(crate) fn run_fuzz_fn(
|
||||
passed: &AtomicBool,
|
||||
) {
|
||||
let num_failures = AtomicI64::new(0);
|
||||
#[cfg(unix)]
|
||||
let _r = Gag::stdout().unwrap();
|
||||
|
||||
let pb = init_bar(num_runs as u64);
|
||||
@@ -1918,6 +1920,7 @@ pub(crate) fn run_fuzz_fn(
|
||||
pb.inc(1);
|
||||
});
|
||||
pb.finish_with_message("Done.");
|
||||
#[cfg(unix)]
|
||||
std::mem::drop(_r);
|
||||
info!(
|
||||
"num failures: {} out of {}",
|
||||
|
||||
@@ -23,7 +23,10 @@ use std::sync::Arc;
|
||||
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use tract_onnx::tract_core::ops::{
|
||||
array::{Gather, GatherElements, MultiBroadcastTo, OneHot, ScatterElements, Slice, Topk},
|
||||
array::{
|
||||
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
|
||||
Slice, Topk,
|
||||
},
|
||||
change_axes::AxisOp,
|
||||
cnn::{Conv, Deconv},
|
||||
einsum::EinSum,
|
||||
@@ -467,6 +470,78 @@ pub fn new_op_from_onnx(
|
||||
|
||||
// Extract the max value
|
||||
}
|
||||
"ScatterNd" => {
|
||||
if inputs.len() != 3 {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
idx,
|
||||
"scatter nd".to_string(),
|
||||
)));
|
||||
};
|
||||
// just verify it deserializes correctly
|
||||
let _op = load_op::<ScatterNd>(node.op(), idx, node.op().name().to_string())?;
|
||||
|
||||
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
|
||||
constant_idx: None,
|
||||
});
|
||||
|
||||
// if param_visibility.is_public() {
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
})
|
||||
}
|
||||
// }
|
||||
|
||||
if inputs[1].opkind().is_input() {
|
||||
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
|
||||
scale: 0,
|
||||
datum_type: InputType::TDim,
|
||||
}));
|
||||
inputs[1].bump_scale(0);
|
||||
}
|
||||
|
||||
op
|
||||
}
|
||||
|
||||
"GatherNd" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
idx,
|
||||
"gather nd".to_string(),
|
||||
)));
|
||||
};
|
||||
let op = load_op::<GatherNd>(node.op(), idx, node.op().name().to_string())?;
|
||||
let batch_dims = op.batch_dims;
|
||||
|
||||
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
|
||||
batch_dims,
|
||||
indices: None,
|
||||
});
|
||||
|
||||
// if param_visibility.is_public() {
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
|
||||
batch_dims,
|
||||
indices: Some(c.raw_values.map(|x| x as usize)),
|
||||
})
|
||||
}
|
||||
// }
|
||||
|
||||
if inputs[1].opkind().is_input() {
|
||||
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
|
||||
scale: 0,
|
||||
datum_type: InputType::TDim,
|
||||
}));
|
||||
inputs[1].bump_scale(0);
|
||||
}
|
||||
|
||||
op
|
||||
}
|
||||
|
||||
"GatherElements" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
|
||||
@@ -484,7 +484,6 @@ fn get_srs(
|
||||
srs_path,
|
||||
settings_path,
|
||||
logrows,
|
||||
CheckMode::SAFE,
|
||||
))
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to get srs: {}", e);
|
||||
@@ -1104,6 +1103,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<PyG1Affine>()?;
|
||||
m.add_class::<PyG1>()?;
|
||||
m.add_class::<PyTestDataSource>()?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(felt_to_float, m)?)?;
|
||||
|
||||
@@ -673,6 +673,68 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
Tensor::new(Some(&res), &dims)
|
||||
}
|
||||
|
||||
/// Set a slice of the Tensor.
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
|
||||
/// let b = Tensor::<i32>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
|
||||
/// a.set_slice(&[1..2], &Tensor::<i32>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
|
||||
/// assert_eq!(a, b);
|
||||
/// ```
|
||||
pub fn set_slice(
|
||||
&mut self,
|
||||
indices: &[Range<usize>],
|
||||
value: &Tensor<T>,
|
||||
) -> Result<(), TensorError>
|
||||
where
|
||||
T: Send + Sync,
|
||||
{
|
||||
if indices.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
if self.dims.len() < indices.len() {
|
||||
return Err(TensorError::DimError(format!(
|
||||
"The dimensionality of the slice {:?} is greater than the tensor's {:?}",
|
||||
indices, self.dims
|
||||
)));
|
||||
}
|
||||
|
||||
// if indices weren't specified we fill them in as required
|
||||
let mut full_indices = indices.to_vec();
|
||||
|
||||
let omitted_dims = (indices.len()..self.dims.len())
|
||||
.map(|i| self.dims[i])
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for dim in &omitted_dims {
|
||||
full_indices.push(0..*dim);
|
||||
}
|
||||
|
||||
let full_dims = full_indices
|
||||
.iter()
|
||||
.map(|x| x.end - x.start)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// now broadcast the value to the full dims
|
||||
let value = value.expand(&full_dims)?;
|
||||
|
||||
let cartesian_coord: Vec<Vec<usize>> = full_indices
|
||||
.iter()
|
||||
.cloned()
|
||||
.multi_cartesian_product()
|
||||
.collect();
|
||||
|
||||
let _ = cartesian_coord
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, e)| {
|
||||
self.set(e, value[i].clone());
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the array index from rows / columns indices.
|
||||
///
|
||||
/// ```
|
||||
|
||||
@@ -2,7 +2,7 @@ use super::TensorError;
|
||||
use crate::tensor::{Tensor, TensorType};
|
||||
use itertools::Itertools;
|
||||
use maybe_rayon::{
|
||||
iter::IndexedParallelIterator, iter::IntoParallelRefMutIterator, iter::ParallelIterator,
|
||||
iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
|
||||
prelude::IntoParallelRefIterator,
|
||||
};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
@@ -1328,6 +1328,316 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Gather ND.
|
||||
/// # Arguments
|
||||
/// * `input` - Tensor
|
||||
/// * `index` - Tensor of indices to gather
|
||||
/// * `batch_dims` - Number of batch dimensions
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::tensor::ops::gather_nd;
|
||||
/// let x = Tensor::<i128>::new(
|
||||
/// Some(&[0, 1, 2, 3]),
|
||||
/// &[2, 2],
|
||||
/// ).unwrap();
|
||||
/// let index = Tensor::<usize>::new(
|
||||
/// Some(&[0, 0, 1, 1]),
|
||||
/// &[2, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = gather_nd(&x, &index, 0).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[0, 3]), &[2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let index = Tensor::<usize>::new(
|
||||
/// Some(&[1, 0]),
|
||||
/// &[2, 1],
|
||||
/// ).unwrap();
|
||||
/// let result = gather_nd(&x, &index, 0).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1]), &[2, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let x = Tensor::<i128>::new(
|
||||
/// Some(&[0, 1, 2, 3, 4, 5, 6, 7]),
|
||||
/// &[2, 2, 2],
|
||||
/// ).unwrap();
|
||||
/// let index = Tensor::<usize>::new(
|
||||
/// Some(&[0, 1, 1, 0]),
|
||||
/// &[2, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = gather_nd(&x, &index, 0).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let index = Tensor::<usize>::new(
|
||||
/// Some(&[0, 1, 1, 0]),
|
||||
/// &[2, 1, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = gather_nd(&x, &index, 0).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 1, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let index = Tensor::<usize>::new(
|
||||
/// Some(&[1, 0]),
|
||||
/// &[2, 1],
|
||||
/// ).unwrap();
|
||||
/// let result = gather_nd(&x, &index, 1).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let index = Tensor::<usize>::new(
|
||||
/// Some(&[0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1]),
|
||||
/// &[2, 2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = gather_nd(&x, &index, 0).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let index = Tensor::<usize>::new(
|
||||
/// Some(&[0, 1, 0, 0, 1, 1, 1, 0]),
|
||||
/// &[2, 2, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = gather_nd(&x, &index, 0).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1, 6, 7, 4, 5]), &[2, 2, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let index = Tensor::<usize>::new(
|
||||
/// Some(&[0, 1, 0, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = gather_nd(&x, &index, 0).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[2, 7]), &[2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
pub fn gather_nd<T: TensorType + Send + Sync>(
|
||||
input: &Tensor<T>,
|
||||
index: &Tensor<usize>,
|
||||
batch_dims: usize,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
// Calculate the output tensor size
|
||||
let index_dims = index.dims().to_vec();
|
||||
let input_dims = input.dims().to_vec();
|
||||
let last_value = index_dims
|
||||
.last()
|
||||
.ok_or(TensorError::DimMismatch("gather_nd".to_string()))?;
|
||||
if last_value > &(input_dims.len() - batch_dims) {
|
||||
return Err(TensorError::DimMismatch("gather_nd".to_string()));
|
||||
}
|
||||
|
||||
let output_size =
|
||||
// If indices_shape[-1] == r-b, since the rank of indices is q,
|
||||
// indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b,
|
||||
// where N is an integer equals to the product of 1 and all the elements in the batch dimensions of the indices_shape.
|
||||
// Let us think of each such r-b ranked tensor as indices_slice.
|
||||
// Each scalar value corresponding to data[0:b-1,indices_slice] is filled into
|
||||
// the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
|
||||
// if indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensor containing 1-D tensors of dimension < r-b.
|
||||
// Let us think of each such tensors as indices_slice.
|
||||
// Each tensor slice corresponding to data[0:b-1, indices_slice , :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
|
||||
{
|
||||
let output_rank = input_dims.len() + index_dims.len() - 1 - batch_dims - last_value;
|
||||
|
||||
let mut dims = index_dims[..index_dims.len() - 1].to_vec();
|
||||
let input_offset = batch_dims + last_value;
|
||||
dims.extend(input_dims[input_offset..input_dims.len()].to_vec());
|
||||
|
||||
assert_eq!(output_rank, dims.len());
|
||||
dims
|
||||
|
||||
};
|
||||
|
||||
// cartesian coord over batch dims
|
||||
let mut batch_cartesian_coord = input_dims[0..batch_dims]
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if batch_cartesian_coord.is_empty() {
|
||||
batch_cartesian_coord.push(vec![]);
|
||||
}
|
||||
|
||||
let outputs = batch_cartesian_coord
|
||||
.par_iter()
|
||||
.map(|batch_coord| {
|
||||
let batch_slice = batch_coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let mut index_slice = index.get_slice(&batch_slice)?;
|
||||
index_slice.reshape(&index.dims()[batch_dims..])?;
|
||||
let mut input_slice = input.get_slice(&batch_slice)?;
|
||||
input_slice.reshape(&input.dims()[batch_dims..])?;
|
||||
|
||||
let mut inner_cartesian_coord = index_slice.dims()[0..index_slice.dims().len() - 1]
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if inner_cartesian_coord.is_empty() {
|
||||
inner_cartesian_coord.push(vec![]);
|
||||
}
|
||||
|
||||
let output = inner_cartesian_coord
|
||||
.iter()
|
||||
.map(|coord| {
|
||||
let slice = coord
|
||||
.iter()
|
||||
.map(|x| *x..*x + 1)
|
||||
.chain(batch_coord.iter().map(|x| *x..*x + 1))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let index_slice = index_slice
|
||||
.get_slice(&slice)
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|x| *x..*x + 1)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
input_slice.get_slice(&index_slice).unwrap()
|
||||
})
|
||||
.collect::<Tensor<_>>();
|
||||
|
||||
output.combine()
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let mut outputs = outputs.into_iter().flatten().collect::<Tensor<_>>();
|
||||
|
||||
outputs.reshape(&output_size)?;
|
||||
|
||||
Ok(outputs)
|
||||
}
|
||||
|
||||
/// Scatter ND.
|
||||
/// This operator is the inverse of GatherND.
|
||||
/// # Arguments
|
||||
/// * `input` - Tensor
|
||||
/// * `index` - Tensor of indices to scatter
|
||||
/// * `src` - Tensor of src
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::tensor::ops::scatter_nd;
|
||||
/// let x = Tensor::<i128>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
|
||||
/// &[8],
|
||||
/// ).unwrap();
|
||||
///
|
||||
/// let index = Tensor::<usize>::new(
|
||||
/// Some(&[4, 3, 1, 7]),
|
||||
/// &[4, 1],
|
||||
/// ).unwrap();
|
||||
/// let src = Tensor::<i128>::new(
|
||||
/// Some(&[9, 10, 11, 12]),
|
||||
/// &[4],
|
||||
/// ).unwrap();
|
||||
/// let result = scatter_nd(&x, &index, &src).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 11, 3, 10, 9, 6, 7, 12]), &[8]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let x = Tensor::<i128>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
|
||||
/// 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
|
||||
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
|
||||
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
|
||||
/// &[4, 4, 4],
|
||||
/// ).unwrap();
|
||||
///
|
||||
/// let index = Tensor::<usize>::new(
|
||||
/// Some(&[0, 2]),
|
||||
/// &[2, 1],
|
||||
/// ).unwrap();
|
||||
///
|
||||
/// let src = Tensor::<i128>::new(
|
||||
/// Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
|
||||
/// 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
|
||||
/// ]),
|
||||
/// &[2, 4, 4],
|
||||
/// ).unwrap();
|
||||
///
|
||||
/// let result = scatter_nd(&x, &index, &src).unwrap();
|
||||
///
|
||||
/// let expected = Tensor::<i128>::new(
|
||||
/// Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
|
||||
/// 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
|
||||
/// 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
|
||||
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
|
||||
/// &[4, 4, 4],
|
||||
/// ).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let x = Tensor::<i128>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
|
||||
/// &[2, 4],
|
||||
/// ).unwrap();
|
||||
///
|
||||
/// let index = Tensor::<usize>::new(
|
||||
/// Some(&[0, 1]),
|
||||
/// &[2, 1],
|
||||
/// ).unwrap();
|
||||
/// let src = Tensor::<i128>::new(
|
||||
/// Some(&[9, 10]),
|
||||
/// &[2],
|
||||
/// ).unwrap();
|
||||
/// let result = scatter_nd(&x, &index, &src).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[9, 9, 9, 9, 10, 10, 10, 10]), &[2, 4]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let x = Tensor::<i128>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
|
||||
/// &[2, 4],
|
||||
/// ).unwrap();
|
||||
///
|
||||
/// let index = Tensor::<usize>::new(
|
||||
/// Some(&[0, 1]),
|
||||
/// &[1, 1, 2],
|
||||
/// ).unwrap();
|
||||
/// let src = Tensor::<i128>::new(
|
||||
/// Some(&[9]),
|
||||
/// &[1, 1],
|
||||
/// ).unwrap();
|
||||
/// let result = scatter_nd(&x, &index, &src).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 9, 3, 4, 5, 6, 7, 8]), &[2, 4]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ````
|
||||
///
|
||||
pub fn scatter_nd<T: TensorType + Send + Sync>(
|
||||
input: &Tensor<T>,
|
||||
index: &Tensor<usize>,
|
||||
src: &Tensor<T>,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
// Calculate the output tensor size
|
||||
let index_dims = index.dims().to_vec();
|
||||
let input_dims = input.dims().to_vec();
|
||||
let last_value = index_dims
|
||||
.last()
|
||||
.ok_or(TensorError::DimMismatch("scatter_nd".to_string()))?;
|
||||
if last_value > &input_dims.len() {
|
||||
return Err(TensorError::DimMismatch("scatter_nd".to_string()));
|
||||
}
|
||||
|
||||
let mut output = input.clone();
|
||||
|
||||
let cartesian_coord = index_dims[0..index_dims.len() - 1]
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
.multi_cartesian_product()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
cartesian_coord
|
||||
.iter()
|
||||
.map(|coord| {
|
||||
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let index_val = index.get_slice(&slice)?;
|
||||
let index_slice = index_val.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let src_val = src.get_slice(&slice)?;
|
||||
output.set_slice(&index_slice, &src_val)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn axes_op<T: TensorType + Send + Sync>(
|
||||
a: &Tensor<T>,
|
||||
axes: &[usize],
|
||||
|
||||
@@ -4,6 +4,37 @@ use super::{
|
||||
};
|
||||
use halo2_proofs::{arithmetic::Field, plonk::Instance};
|
||||
|
||||
pub(crate) fn create_constant_tensor<
|
||||
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
|
||||
>(
|
||||
val: F,
|
||||
len: usize,
|
||||
) -> ValTensor<F> {
|
||||
let mut constant = Tensor::from(vec![ValType::Constant(val); len].into_iter());
|
||||
constant.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
ValTensor::from(constant)
|
||||
}
|
||||
|
||||
pub(crate) fn create_unit_tensor<
|
||||
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
|
||||
>(
|
||||
len: usize,
|
||||
) -> ValTensor<F> {
|
||||
let mut unit = Tensor::from(vec![ValType::Constant(F::ONE); len].into_iter());
|
||||
unit.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
ValTensor::from(unit)
|
||||
}
|
||||
|
||||
pub(crate) fn create_zero_tensor<
|
||||
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
|
||||
>(
|
||||
len: usize,
|
||||
) -> ValTensor<F> {
|
||||
let mut zero = Tensor::from(vec![ValType::Constant(F::ZERO); len].into_iter());
|
||||
zero.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
ValTensor::from(zero)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// A [ValType] is a wrapper around Halo2 value(s).
|
||||
pub enum ValType<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd> {
|
||||
@@ -463,7 +494,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
}
|
||||
_ => return Err(Box::new(TensorError::WrongMethod)),
|
||||
};
|
||||
Ok(integer_evals.into_iter().into())
|
||||
let mut tensor: Tensor<i128> = integer_evals.into_iter().into();
|
||||
match tensor.reshape(self.dims()) {
|
||||
_ => {}
|
||||
};
|
||||
|
||||
Ok(tensor)
|
||||
}
|
||||
|
||||
/// Calls `pad_to_zero_rem` on the inner tensor.
|
||||
|
||||
@@ -193,7 +193,7 @@ mod native_tests {
|
||||
"1l_tiny_div",
|
||||
];
|
||||
|
||||
const TESTS: [&str; 77] = [
|
||||
const TESTS: [&str; 79] = [
|
||||
"1l_mlp", //0
|
||||
"1l_slice",
|
||||
"1l_concat",
|
||||
@@ -275,6 +275,8 @@ mod native_tests {
|
||||
"ltsf",
|
||||
"remainder", //75
|
||||
"bitshift",
|
||||
"gather_nd",
|
||||
"scatter_nd",
|
||||
];
|
||||
|
||||
const WASM_TESTS: [&str; 46] = [
|
||||
@@ -502,7 +504,7 @@ mod native_tests {
|
||||
}
|
||||
});
|
||||
|
||||
seq!(N in 0..=76 {
|
||||
seq!(N in 0..=78 {
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
#[ignore]
|
||||
@@ -589,13 +591,16 @@ mod native_tests {
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_large_batch_public_outputs_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let large_batch_dir = &format!("large_batches_{}", test);
|
||||
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
// currently variable output rank is not supported in ONNX
|
||||
if test != "gather_nd" {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let large_batch_dir = &format!("large_batches_{}", test);
|
||||
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
@@ -853,7 +858,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, true, "single");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testWasm");
|
||||
run_js_tests(path, test.to_string(), "testWasm", false);
|
||||
// test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -866,7 +871,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, None, true, "single");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testWasm");
|
||||
run_js_tests(path, test.to_string(), "testWasm", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -914,6 +919,7 @@ mod native_tests {
|
||||
use crate::native_tests::kzg_fuzz;
|
||||
use tempdir::TempDir;
|
||||
use crate::native_tests::Hardfork;
|
||||
use crate::native_tests::run_js_tests;
|
||||
|
||||
/// Currently only on chain inputs that return a non-negative value are supported.
|
||||
const TESTS_ON_CHAIN_INPUT: [&str; 17] = [
|
||||
@@ -1008,8 +1014,8 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "public");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
|
||||
test_dir.close().unwrap();
|
||||
|
||||
}
|
||||
@@ -1021,8 +1027,8 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify_render_seperately(2, path, test.to_string(), "private", "private", "public");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", true);
|
||||
test_dir.close().unwrap();
|
||||
|
||||
}
|
||||
@@ -1035,8 +1041,8 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let mut _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "hashed", "private", "private");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -1052,8 +1058,8 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let mut _anvil_child = crate::native_tests::start_anvil(false, hardfork);
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "kzgcommit", "private", "public");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -1065,8 +1071,8 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "hashed", "public");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
|
||||
test_dir.close().unwrap();
|
||||
|
||||
}
|
||||
@@ -1078,8 +1084,8 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "hashed");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -1091,8 +1097,8 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "kzgcommit", "public");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -1104,8 +1110,8 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "kzgcommit");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -1116,8 +1122,8 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
kzg_evm_prove_and_verify(2, path, test.to_string(), "kzgcommit", "kzgcommit", "kzgcommit");
|
||||
// #[cfg(not(feature = "icicle"))]
|
||||
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -2179,15 +2185,17 @@ mod native_tests {
|
||||
}
|
||||
|
||||
// run js browser evm verify tests for a given example
|
||||
fn run_js_tests(test_dir: &str, example_name: String, js_test: &str) {
|
||||
fn run_js_tests(test_dir: &str, example_name: String, js_test: &str, vk: bool) {
|
||||
let example = format!("--example={}", example_name);
|
||||
let dir = format!("--dir={}", test_dir);
|
||||
let mut args = vec!["run", "test", js_test, &example, &dir];
|
||||
let vk_string: String;
|
||||
if vk {
|
||||
vk_string = format!("--vk={}", vk);
|
||||
args.push(&vk_string);
|
||||
};
|
||||
let status = Command::new("pnpm")
|
||||
.args([
|
||||
"run",
|
||||
"test",
|
||||
js_test,
|
||||
&format!("--example={}", example_name),
|
||||
&format!("--dir={}", test_dir),
|
||||
])
|
||||
.args(&args)
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -1,28 +1,42 @@
|
||||
import localEVMVerify, { Hardfork } from '@ezkljs/verify'
|
||||
import localEVMVerify from '../../in-browser-evm-verifier/src/index'
|
||||
import { serialize, deserialize } from '@ezkljs/engine/nodejs'
|
||||
import { compileContracts } from './utils'
|
||||
import * as fs from 'fs'
|
||||
|
||||
exports.USER_NAME = require("minimist")(process.argv.slice(2))["example"];
|
||||
exports.EXAMPLE = require("minimist")(process.argv.slice(2))["example"];
|
||||
exports.PATH = require("minimist")(process.argv.slice(2))["dir"];
|
||||
exports.VK = require("minimist")(process.argv.slice(2))["vk"];
|
||||
|
||||
describe('localEVMVerify', () => {
|
||||
|
||||
let bytecode: string
|
||||
let bytecode_verifier: string
|
||||
|
||||
let bytecode_vk: string | undefined = undefined
|
||||
|
||||
let proof: any
|
||||
|
||||
const example = exports.USER_NAME || "1l_mlp"
|
||||
const example = exports.EXAMPLE || "1l_mlp"
|
||||
const path = exports.PATH || "../ezkl/examples/onnx"
|
||||
const vk = exports.VK || false
|
||||
|
||||
beforeEach(() => {
|
||||
let solcOutput = compileContracts(path, example)
|
||||
const solcOutput = compileContracts(path, example, 'kzg')
|
||||
|
||||
bytecode =
|
||||
bytecode_verifier =
|
||||
solcOutput.contracts['artifacts/Verifier.sol']['Halo2Verifier'].evm.bytecode
|
||||
.object
|
||||
|
||||
console.log('size', bytecode.length)
|
||||
if (vk) {
|
||||
const solcOutput_vk = compileContracts(path, example, 'vk')
|
||||
|
||||
bytecode_vk =
|
||||
solcOutput_vk.contracts['artifacts/Verifier.sol']['Halo2VerifyingKey'].evm.bytecode
|
||||
.object
|
||||
|
||||
|
||||
console.log('size of verifier bytecode', bytecode_verifier.length)
|
||||
}
|
||||
console.log('verifier bytecode', bytecode_verifier)
|
||||
})
|
||||
|
||||
it('should return true when verification succeeds', async () => {
|
||||
@@ -30,7 +44,9 @@ describe('localEVMVerify', () => {
|
||||
|
||||
proof = deserialize(proofFileBuffer)
|
||||
|
||||
const result = await localEVMVerify(proofFileBuffer, bytecode)
|
||||
const result = await localEVMVerify(proofFileBuffer, bytecode_verifier, bytecode_vk)
|
||||
|
||||
console.log('result', result)
|
||||
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
@@ -39,13 +55,16 @@ describe('localEVMVerify', () => {
|
||||
let result: boolean = true
|
||||
console.log(proof.proof)
|
||||
try {
|
||||
let index = Math.floor(Math.random() * (proof.proof.length - 2)) + 2
|
||||
let number = (proof.proof[index] + 1) % 16
|
||||
let index = Math.round((Math.random() * (proof.proof.length))) % proof.proof.length
|
||||
console.log('index', index)
|
||||
console.log('index', proof.proof[index])
|
||||
let number = (proof.proof[index] + 1) % 256
|
||||
console.log('index', index)
|
||||
console.log('new number', number)
|
||||
proof.proof[index] = number
|
||||
console.log('index post', proof.proof[index])
|
||||
const proofModified = serialize(proof)
|
||||
result = await localEVMVerify(proofModified, bytecode)
|
||||
result = await localEVMVerify(proofModified, bytecode_verifier, bytecode_vk)
|
||||
} catch (error) {
|
||||
// Check if the error thrown is the "out of gas" error.
|
||||
expect(error).toEqual(
|
||||
|
||||
@@ -43,21 +43,21 @@ export function serialize(data: object | string): Uint8ClampedArray { // data is
|
||||
return new Uint8ClampedArray(uint8Array.buffer);
|
||||
}
|
||||
|
||||
export function getSolcInput(path: string, example: string) {
|
||||
export function getSolcInput(path: string, example: string, name: string) {
|
||||
return {
|
||||
language: 'Solidity',
|
||||
sources: {
|
||||
'artifacts/Verifier.sol': {
|
||||
content: fsSync.readFileSync(`${path}/${example}/kzg.sol`, 'utf-8'),
|
||||
content: fsSync.readFileSync(`${path}/${example}/${name}.sol`, 'utf-8'),
|
||||
},
|
||||
// If more contracts were to be compiled, they should have their own entries here
|
||||
},
|
||||
settings: {
|
||||
optimizer: {
|
||||
enabled: true,
|
||||
runs: 200,
|
||||
runs: 1,
|
||||
},
|
||||
evmVersion: 'london',
|
||||
evmVersion: 'shanghai',
|
||||
outputSelection: {
|
||||
'*': {
|
||||
'*': ['abi', 'evm.bytecode'],
|
||||
@@ -67,8 +67,8 @@ export function getSolcInput(path: string, example: string) {
|
||||
}
|
||||
}
|
||||
|
||||
export function compileContracts(path: string, example: string) {
|
||||
const input = getSolcInput(path, example)
|
||||
export function compileContracts(path: string, example: string, name: string) {
|
||||
const input = getSolcInput(path, example, name)
|
||||
const output = JSON.parse(solc.compile(JSON.stringify(input)))
|
||||
|
||||
let compilationFailed = false
|
||||
|
||||
Binary file not shown.
1
verifier_abi.json
Normal file
1
verifier_abi.json
Normal file
@@ -0,0 +1 @@
|
||||
[{"type":"function","name":"verifyProof","inputs":[{"internalType":"bytes","name":"proof","type":"bytes"},{"internalType":"uint256[]","name":"instances","type":"uint256[]"}],"outputs":[{"internalType":"bool","name":"","type":"bool"}],"stateMutability":"nonpayable"}]
|
||||
Reference in New Issue
Block a user