Mirror of https://github.com/zkonduit/ezkl.git (synced 2026-01-13 16:27:59 -05:00)
Compare commits
14 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | 2be181db35 |  |
|  | de9e3f2673 |  |
|  | a1450f8df7 |  |
|  | ea535e2ecd |  |
|  | f8aa91ed08 |  |
|  | a59e3780b2 |  |
|  | 345fb5672a |  |
|  | 70daaff2e4 |  |
|  | a437d8a51f |  |
|  | fe535c1cac |  |
|  | 3e8dcb001a |  |
|  | 14786acb95 |  |
|  | 80a3c44cb4 |  |
|  | 1656846d1a |  |
@@ -1,4 +1,4 @@
name: Build and Publish WASM<>JS Bindings
name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)

on:
  workflow_dispatch:
@@ -14,7 +14,7 @@ defaults:
  run:
    working-directory: .
jobs:
  wasm-publish:
  publish-wasm-bindings:
    name: publish-wasm-bindings
    runs-on: ubuntu-latest
    if: startsWith(github.ref, 'refs/tags/')
@@ -174,3 +174,40 @@ jobs:
          npm publish
        env:
          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

  in-browser-evm-ver-publish:
    name: publish-in-browser-evm-verifier-package
    needs: ["publish-wasm-bindings"]
    runs-on: ubuntu-latest
    if: startsWith(github.ref, 'refs/tags/')
    steps:
      - uses: actions/checkout@v4
      - name: Update version in package.json
        shell: bash
        env:
          RELEASE_TAG: ${{ github.ref_name }}
        run: |
          sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
      - name: Update @ezkljs/engine version in package.json
        shell: bash
        env:
          RELEASE_TAG: ${{ github.ref_name }}
        run: |
          sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
      - name: Update the engine import in in-browser-evm-verifier to use the @ezkljs/engine package instead of the local one
        run: |
          sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
      - name: Set up Node.js
        uses: actions/setup-node@v3
        with:
          node-version: "18.12.1"
          registry-url: "https://registry.npmjs.org"
      - name: Publish to npm
        run: |
          cd in-browser-evm-verifier
          npm install
          npm run build
          npm ci
          npm publish
        env:
          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
.github/workflows/rust.yml (vendored, 22 lines changed)
@@ -303,12 +303,24 @@
        with:
          node-version: "18.12.1"
          cache: "pnpm"
      - name: Install dependencies
      - name: Install dependencies for js tests and in-browser-evm-verifier package
        run: |
          pnpm install --no-frozen-lockfile
          pnpm install --dir ./in-browser-evm-verifier --no-frozen-lockfile
        env:
          CI: false
          NODE_ENV: development
      - name: Build wasm package for nodejs target.
        run: |
          wasm-pack build --release --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
      - name: Replace memory definition in nodejs
        run: |
          sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" in-browser-evm-verifier/nodejs/ezkl.js
      - name: Build @ezkljs/verify package
        run: |
          cd in-browser-evm-verifier
          pnpm build:commonjs
          cd ..
      - name: Install solc
        run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
      - name: Install Anvil
@@ -364,7 +376,7 @@
        with:
          node-version: "18.12.1"
          cache: "pnpm"
      - name: Install dependencies
      - name: Install dependencies for js tests
        run: |
          pnpm install --no-frozen-lockfile
        env:
@@ -575,7 +587,7 @@
      - name: Install Anvil
        run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
      - name: Build python ezkl
        run: source .env/bin/activate; maturin develop --features python-bindings --release
        run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
      - name: Run pytest
        run: source .env/bin/activate; pytest -vv

@@ -599,7 +611,7 @@
      - name: Setup Virtual Env and Install python dependencies
        run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
      - name: Build python ezkl
        run: source .env/bin/activate; maturin develop --features python-bindings --release
        run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
      - name: Div rebase
        run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_div_rebase_
      - name: Public inputs
@@ -634,7 +646,7 @@
      - name: Setup Virtual Env and Install python dependencies
        run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
      - name: Build python ezkl
        run: source .env/bin/activate; maturin develop --features python-bindings --release
        run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
      # - name: authenticate-kaggle-cli
      #   shell: bash
      #   env:
.gitignore (vendored, 1 line changed)
@@ -45,6 +45,7 @@ var/
*.whl
*.bak
node_modules
/dist
timingData.json
!tests/wasm/pk.key
!tests/wasm/vk.key
Cargo.lock (generated, 4 lines changed)
@@ -2263,7 +2263,7 @@ dependencies = [
[[package]]
name = "halo2_gadgets"
version = "0.2.0"
source = "git+https://github.com/zkonduit/halo2?branch=main#fe7522c85c8c434d7ceb9f663b0fb51909b9994f"
source = "git+https://github.com/zkonduit/halo2?branch=main#4d7e6ddac661283e2b73c551b2e8f0011cedd50f"
dependencies = [
 "arrayvec 0.7.4",
 "bitvec 1.0.1",
@@ -2280,7 +2280,7 @@ dependencies = [
[[package]]
name = "halo2_proofs"
version = "0.3.0"
source = "git+https://github.com/zkonduit/halo2?branch=main#fe7522c85c8c434d7ceb9f663b0fb51909b9994f"
source = "git+https://github.com/zkonduit/halo2?branch=main#4d7e6ddac661283e2b73c551b2e8f0011cedd50f"
dependencies = [
 "blake2b_simd",
 "env_logger",
@@ -74,6 +74,10 @@ For more details visit the [docs](https://docs.ezkl.xyz).

Build the auto-generated rust documentation and open the docs in your browser locally. `cargo doc --open`

#### In-browser EVM verifier

As an alternative to running the native Halo2 verifier as a WASM binding in the browser, you can use the in-browser EVM verifier. You can find its source code in the `in-browser-evm-verifier` directory, along with a README with instructions on how to use it.
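A minimal usage sketch (not part of the repository) of the intended flow, assuming a proof and verifier bytecode were produced beforehand with `ezkl prove` and `ezkl create-evm-verifier` plus a solc compile; the file names here are hypothetical:

```ts
import fs from 'fs';
import localEVMVerify from '@ezkljs/verify';

// Hypothetical paths: a proof serialized by `ezkl prove` and the compiled
// verifier bytecode (hex string) of the Solidity verifier.
const proof = fs.readFileSync('proof.pf');
const bytecode = fs.readFileSync('verifier.bytecode', 'utf8');

// Deploys the verifier on an ephemeral in-process EVM and runs the verify call.
const verified = await localEVMVerify(proof, bytecode);
console.log('verified:', verified);
```
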
### building the project 🔨
@@ -343,7 +343,6 @@
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" compress_selectors=True,\n",
" )\n",
"\n",
" assert res == True\n",

@@ -633,7 +633,7 @@
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [4])"
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales = [11])"
]
},
{
@@ -664,7 +664,6 @@
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
")"
]
},
examples/onnx/gather_nd/gen.py (new file, 48 lines)
@@ -0,0 +1,48 @@
import json
import numpy as np
import tf2onnx

import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model

# gather_nd in tf then export to onnx

x = in1 = Input((15, 18,))
w = in2 = Input((15, 1), dtype=tf.int32)
x = tf.gather_nd(x, w, batch_dims=1)
tm = Model((in1, in2), x)
tm.summary()
tm.compile(optimizer='adam', loss='mse')

shape = [1, 15, 18]
index_shape = [1, 15, 1]
# After training, export to onnx (network.onnx) and create a data file (input.json)
x = 0.1*np.random.rand(1, *shape)
# w = random int tensor
w = np.random.randint(0, 10, index_shape)

spec = tf.TensorSpec(shape, tf.float32, name='input_0')
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')

model_path = "network.onnx"

tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)

d = x.reshape([-1]).tolist()
d1 = w.reshape([-1]).tolist()

data = dict(
    input_data=[d, d1],
)

# Serialize data into file:
json.dump(data, open("input.json", 'w'))
examples/onnx/gather_nd/input.json (new file; diff suppressed because one or more lines are too long)

examples/onnx/gather_nd/network.onnx (new binary file; not shown)

examples/onnx/scatter_nd/gen.py (new file, 76 lines)
@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import sys
import json

sys.path.append("..")

class Model(nn.Module):
    """
    Just one Linear layer
    """
    def __init__(self, configs):
        super(Model, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len

        # Use this line if you want to visualize the weights
        # self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
        self.channels = configs.enc_in
        self.individual = configs.individual
        if self.individual:
            self.Linear = nn.ModuleList()
            for i in range(self.channels):
                self.Linear.append(nn.Linear(self.seq_len, self.pred_len))
        else:
            self.Linear = nn.Linear(self.seq_len, self.pred_len)

    def forward(self, x):
        # x: [Batch, Input length, Channel]
        if self.individual:
            output = torch.zeros([x.size(0), self.pred_len, x.size(2)], dtype=x.dtype).to(x.device)
            for i in range(self.channels):
                output[:, :, i] = self.Linear[i](x[:, :, i])
            x = output
        else:
            x = self.Linear(x.permute(0, 2, 1)).permute(0, 2, 1)
        return x  # [Batch, Output length, Channel]

class Configs:
    def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.enc_in = enc_in
        self.individual = individual

model = 'Linear'
seq_len = 10
pred_len = 4
enc_in = 3

configs = Configs(seq_len, pred_len, enc_in, True)
circuit = Model(configs)

x = torch.randn(1, seq_len, pred_len)

torch.onnx.export(circuit, x, "network.onnx",
                  export_params=True,  # store the trained parameter weights inside the model file
                  opset_version=15,  # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  # the model's input names
                  input_names=['input'],
                  output_names=['output'],  # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},  # variable length axes
                                'output': {0: 'batch_size'}})

d1 = ((x).detach().numpy()).reshape([-1]).tolist()

data = dict(
    input_data=[d1],
)

# Serialize data into file:
json.dump(data, open("input.json", 'w'))
examples/onnx/scatter_nd/input.json (new file)
@@ -0,0 +1 @@
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}

examples/onnx/scatter_nd/network.onnx (new binary file; not shown)
in-browser-evm-verifier/README.md (new file, 60 lines)
@@ -0,0 +1,60 @@

# inbrowser-evm-verify

We would like the Solidity verifier to be canonical and, in most cases, all you ever need. For that, we need to be able to run that verifier in the browser.

## How to use (Node js)

```ts
import fs from 'fs';
import localEVMVerify from '@ezkljs/verify';

// Load in the proof file as a buffer
const proofFileBuffer = fs.readFileSync(`${path}/${example}/proof.pf`)

// Stringified EZKL evm verifier bytecode (this is just an example, don't use in production)
const bytecode = '0x608060405234801561001057600080fd5b5060d38061001f6000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c8063cfae321714610046575b600080fd5b6100496100f1565b60405161005691906100f1565b60405180910390f35b'

const result = await localEVMVerify(proofFileBuffer, bytecode)

console.log('result', result)
```

**Note**: Run `ezkl create-evm-verifier` to get the Solidity verifier, from which you can retrieve the bytecode once compiled. We recommend compiling to the Shanghai hardfork target; otherwise you will have to pass an additional parameter specifying the EVM version to the `localEVMVerify` function, like so (for the Paris hardfork):

```ts
import localEVMVerify from '@ezkljs/verify';
import { Hardfork } from '@ethereumjs/common';

const result = await localEVMVerify(proofFileBuffer, bytecode, undefined, Hardfork.Paris)
```

**Note**: You can also verify separated vk verifiers using the `localEVMVerify` function. Just pass the vk verifier bytecode as the third parameter, like so:

```ts
import localEVMVerify from '@ezkljs/verify';

const result = await localEVMVerify(proofFileBuffer, verifierBytecode, VKBytecode)
```

## How to use (Browser)

```ts
import localEVMVerify from '@ezkljs/verify';

// Load in the proof file as a buffer using the web apis (fetch, FileReader, etc.)
// We use fetch in this example to load the proof file as a buffer
const proofFileBuffer = new Uint8Array(
  await fetch(`${path}/${example}/proof.pf`).then(res => res.arrayBuffer())
)

// Stringified EZKL evm verifier bytecode (this is just an example, don't use in production)
const bytecode = '0x608060405234801561001057600080fd5b5060d38061001f6000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c8063cfae321714610046575b600080fd5b6100496100f1565b60405161005691906100f1565b60405180910390f35b'

const result = await localEVMVerify(proofFileBuffer, bytecode)

console.log('result', result)
```

Output:

```ts
result: true
```
in-browser-evm-verifier/package.json (new file, 42 lines)
@@ -0,0 +1,42 @@
{
  "name": "@ezkljs/verify",
  "version": "0.0.0",
  "publishConfig": {
    "access": "public"
  },
  "description": "Evm verify EZKL proofs in the browser.",
  "main": "dist/commonjs/index.js",
  "module": "dist/esm/index.js",
  "types": "dist/commonjs/index.d.ts",
  "files": [
    "dist",
    "LICENSE",
    "README.md"
  ],
  "scripts": {
    "clean": "rm -r dist || true",
    "build:commonjs": "tsc --project tsconfig.commonjs.json && resolve-tspaths -p tsconfig.commonjs.json",
    "build:esm": "tsc --project tsconfig.esm.json && resolve-tspaths -p tsconfig.esm.json",
    "build": "pnpm run clean && pnpm run build:commonjs && pnpm run build:esm"
  },
  "dependencies": {
    "@ethereumjs/common": "^4.0.0",
    "@ethereumjs/evm": "^2.0.0",
    "@ethereumjs/statemanager": "^2.0.0",
    "@ethereumjs/tx": "^5.0.0",
    "@ethereumjs/util": "^9.0.0",
    "@ethereumjs/vm": "^7.0.0",
    "@ethersproject/abi": "^5.7.0",
    "@ezkljs/engine": "^9.4.4",
    "ethers": "^6.7.1",
    "json-bigint": "^1.0.0"
  },
  "devDependencies": {
    "@types/node": "^20.8.3",
    "ts-loader": "^9.5.0",
    "ts-node": "^10.9.1",
    "resolve-tspaths": "^0.8.16",
    "tsconfig-paths": "^4.2.0",
    "typescript": "^5.2.2"
  }
}
in-browser-evm-verifier/pnpm-lock.yaml (generated, new file, 1479 lines; diff suppressed because it is too large)

in-browser-evm-verifier/src/index.ts (new file, 145 lines)
@@ -0,0 +1,145 @@
import { defaultAbiCoder as AbiCoder } from '@ethersproject/abi'
import { Address, hexToBytes } from '@ethereumjs/util'
import { Chain, Common, Hardfork } from '@ethereumjs/common'
import { LegacyTransaction, LegacyTxData } from '@ethereumjs/tx'
import { VM } from '@ethereumjs/vm'
import { EVM } from '@ethereumjs/evm'
import { buildTransaction, encodeDeployment } from './utils/tx-builder'
import { getAccountNonce, insertAccount } from './utils/account-utils'
import { encodeVerifierCalldata } from '../nodejs/ezkl';

async function deployContract(
  vm: VM,
  common: Common,
  senderPrivateKey: Uint8Array,
  deploymentBytecode: string
): Promise<Address> {
  // Contracts are deployed by sending their deployment bytecode to the address 0.
  // The contract params should be abi-encoded and appended to the deployment bytecode.
  const data = encodeDeployment(deploymentBytecode)
  const txData = {
    data,
    nonce: await getAccountNonce(vm, senderPrivateKey),
  }

  const tx = LegacyTransaction.fromTxData(
    buildTransaction(txData) as LegacyTxData,
    { common, allowUnlimitedInitCodeSize: true },
  ).sign(senderPrivateKey)

  const deploymentResult = await vm.runTx({
    tx,
    skipBlockGasLimitValidation: true,
    skipNonce: true
  })

  if (deploymentResult.execResult.exceptionError) {
    throw deploymentResult.execResult.exceptionError
  }

  return deploymentResult.createdAddress!
}

async function verify(
  vm: VM,
  contractAddress: Address,
  caller: Address,
  proof: Uint8Array | Uint8ClampedArray,
  vkAddress?: Address | Uint8Array,
): Promise<boolean> {
  if (proof instanceof Uint8Array) {
    proof = new Uint8ClampedArray(proof.buffer)
  }
  if (vkAddress) {
    // Convert the vk contract address into the byte encoding expected by
    // encodeVerifierCalldata: a JSON-serialized byte array, utf-8 encoded.
    const vkAddressBytes = hexToBytes(vkAddress.toString())
    const vkAddressArray = Array.from(vkAddressBytes)
    const string = JSON.stringify(vkAddressArray)
    vkAddress = new TextEncoder().encode(string)
  }
  const data = encodeVerifierCalldata(proof, vkAddress)

  const verifyResult = await vm.evm.runCall({
    to: contractAddress,
    caller: caller,
    origin: caller, // The tx.origin is also the caller here
    data: data,
  })

  if (verifyResult.execResult.exceptionError) {
    throw verifyResult.execResult.exceptionError
  }

  const results = AbiCoder.decode(['bool'], verifyResult.execResult.returnValue)

  return results[0]
}

/**
 * Spins up an ephemeral EVM instance for executing the bytecode of a solidity verifier
 * @param proof Json serialized proof file
 * @param bytecode_verifier The bytecode of a compiled solidity verifier.
 * @param bytecode_vk The bytecode of a contract that stores the vk. (Optional, only required if the vk is stored in a separate contract)
 * @param evmVersion The evm version to use for the verification. (Default: Shanghai)
 * @returns The result of the evm verification.
 * @throws If the verify transaction reverts
 */
export default async function localEVMVerify(
  proof: Uint8Array | Uint8ClampedArray,
  bytecode_verifier: string,
  bytecode_vk?: string,
  evmVersion?: Hardfork,
): Promise<boolean> {
  try {
    const hardfork = evmVersion ? evmVersion : Hardfork['Shanghai']
    const common = new Common({ chain: Chain.Mainnet, hardfork })
    const accountPk = hexToBytes(
      '0xe331b6d69882b4cb4ea581d88e0b604039a3de5967688d3dcffdd2270c0fd109', // anvil deterministic Pk
    )

    const evm = new EVM({
      allowUnlimitedContractSize: true,
      allowUnlimitedInitCodeSize: true,
    })

    const vm = await VM.create({ common, evm })
    const accountAddress = Address.fromPrivateKey(accountPk)

    await insertAccount(vm, accountAddress)

    const verifierAddress = await deployContract(
      vm,
      common,
      accountPk,
      bytecode_verifier
    )

    if (bytecode_vk) {
      const accountPk = hexToBytes("0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80"); // anvil deterministic Pk
      const accountAddress = Address.fromPrivateKey(accountPk)
      await insertAccount(vm, accountAddress)
      const output = await deployContract(vm, common, accountPk, bytecode_vk)
      return await verify(vm, verifierAddress, accountAddress, proof, output)
    }

    const result = await verify(vm, verifierAddress, accountAddress, proof)

    return result
  } catch (error) {
    // log or re-throw the error, depending on your needs
    console.error('An error occurred:', error)
    throw error
  }
}
in-browser-evm-verifier/src/utils/account-utils.ts (new file, 32 lines)
@@ -0,0 +1,32 @@
import { VM } from '@ethereumjs/vm'
import { Account, Address } from '@ethereumjs/util'

export const keyPair = {
  secretKey:
    '0x3cd7232cd6f3fc66a57a6bedc1a8ed6c228fff0a327e169c2bcc5e869ed49511',
  publicKey:
    '0x0406cc661590d48ee972944b35ad13ff03c7876eae3fd191e8a2f77311b0a3c6613407b5005e63d7d8d76b89d5f900cde691497688bb281e07a5052ff61edebdc0',
}

export const insertAccount = async (vm: VM, address: Address) => {
  const acctData = {
    nonce: 0,
    balance: BigInt('1000000000000000000'), // 1 eth
  }
  const account = Account.fromAccountData(acctData)

  await vm.stateManager.putAccount(address, account)
}

export const getAccountNonce = async (
  vm: VM,
  accountPrivateKey: Uint8Array,
) => {
  const address = Address.fromPrivateKey(accountPrivateKey)
  const account = await vm.stateManager.getAccount(address)
  if (account) {
    return account.nonce
  } else {
    return BigInt(0)
  }
}
in-browser-evm-verifier/src/utils/tx-builder.ts (new file, 59 lines)
@@ -0,0 +1,59 @@
import { Interface, defaultAbiCoder as AbiCoder } from '@ethersproject/abi'
import {
  AccessListEIP2930TxData,
  FeeMarketEIP1559TxData,
  TxData,
} from '@ethereumjs/tx'

type TransactionsData =
  | TxData
  | AccessListEIP2930TxData
  | FeeMarketEIP1559TxData

export const encodeFunction = (
  method: string,
  params?: {
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    types: any[]
    values: unknown[]
  },
): string => {
  const parameters = params?.types ?? []
  const methodWithParameters = `function ${method}(${parameters.join(',')})`
  const signatureHash = new Interface([methodWithParameters]).getSighash(method)
  const encodedArgs = AbiCoder.encode(parameters, params?.values ?? [])

  return signatureHash + encodedArgs.slice(2)
}

export const encodeDeployment = (
  bytecode: string,
  params?: {
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    types: any[]
    values: unknown[]
  },
) => {
  const deploymentData = '0x' + bytecode
  if (params) {
    const argumentsEncoded = AbiCoder.encode(params.types, params.values)
    return deploymentData + argumentsEncoded.slice(2)
  }
  return deploymentData
}

export const buildTransaction = (
  data: Partial<TransactionsData>,
): TransactionsData => {
  const defaultData: Partial<TransactionsData> = {
    gasLimit: 3_000_000_000_000_000,
    gasPrice: 7,
    value: 0,
    data: '0x',
  }

  return {
    ...defaultData,
    ...data,
  }
}
in-browser-evm-verifier/tsconfig.commonjs.json (new file, 7 lines)
@@ -0,0 +1,7 @@
{
  "extends": "./tsconfig.json",
  "compilerOptions": {
    "module": "CommonJS",
    "outDir": "./dist/commonjs"
  }
}
in-browser-evm-verifier/tsconfig.esm.json (new file, 7 lines)
@@ -0,0 +1,7 @@
{
  "extends": "./tsconfig.json",
  "compilerOptions": {
    "module": "ES2020",
    "outDir": "./dist/esm"
  }
}
in-browser-evm-verifier/tsconfig.json (new file, 62 lines)
@@ -0,0 +1,62 @@
{
  "compilerOptions": {
    "rootDir": "src",
    "target": "es2017",
    "outDir": "dist",
    "declaration": true,
    "lib": [
      "dom",
      "dom.iterable",
      "esnext"
    ],
    "allowJs": true,
    "checkJs": true,
    "skipLibCheck": true,
    "strict": true,
    "forceConsistentCasingInFileNames": true,
    "noEmit": false,
    "esModuleInterop": true,
    "module": "CommonJS",
    "moduleResolution": "node",
    "resolveJsonModule": true,
    "isolatedModules": true,
    "jsx": "preserve",
    // "incremental": true,
    "noUncheckedIndexedAccess": true,
    "baseUrl": ".",
    "paths": {
      "@/*": [
        "./src/*"
      ]
    }
  },
  "include": [
    "src/**/*.ts",
    "src/**/*.tsx",
    "src/**/*.cjs",
    "src/**/*.mjs"
  ],
  "exclude": [
    "node_modules"
  ],
  // NEW: Options for file/directory watching
  "watchOptions": {
    // Use native file system events for files and directories
    "watchFile": "useFsEvents",
    "watchDirectory": "useFsEvents",
    // Poll files for updates more frequently
    // when they're updated a lot.
    "fallbackPolling": "dynamicPriority",
    // Don't coalesce watch notifications
    "synchronousWatchDirectory": true,
    // Finally, two additional settings for reducing the amount of possible
    // files to track work from these directories
    "excludeDirectories": [
      "**/node_modules",
      "_build"
    ],
    "excludeFiles": [
      "build/fileWhichChangesOften.ts"
    ]
  }
}
@@ -7,7 +7,7 @@
    "test": "jest"
  },
  "devDependencies": {
    "@ezkljs/engine": "^2.4.5",
    "@ezkljs/engine": "^9.4.4",
    "@ezkljs/verify": "^0.0.6",
    "@jest/types": "^29.6.3",
    "@types/file-saver": "^2.0.5",
@@ -27,4 +27,4 @@
    "tsconfig-paths": "^4.2.0",
    "typescript": "5.1.6"
  }
}
}
pnpm-lock.yaml (generated, 11 lines changed)
@@ -6,8 +6,8 @@ settings:

devDependencies:
  '@ezkljs/engine':
    specifier: ^2.4.5
    version: 2.4.5
    specifier: ^9.4.4
    version: 9.4.4
  '@ezkljs/verify':
    specifier: ^0.0.6
    version: 0.0.6(buffer@6.0.3)
@@ -785,6 +785,13 @@ packages:
      json-bigint: 1.0.0
    dev: true

  /@ezkljs/engine@9.4.4:
    resolution: {integrity: sha512-kNsTmDQa8mIiQ6yjJmBMwVgAAxh4nfs4NCtnewJifonyA8Mfhs+teXwwW8WhERRDoQPUofKO2pT8BPvV/XGIDA==}
    dependencies:
      '@types/json-bigint': 1.0.1
      json-bigint: 1.0.0
    dev: true

  /@ezkljs/verify@0.0.6(buffer@6.0.3):
    resolution: {integrity: sha512-9DHoEhLKl1DBGuUVseXLThuMyYceY08Zymr/OsLH0zbdA9OoISYhb77j4QPm4ANRKEm5dCi8oHDqkwGbFc2xFQ==}
    dependencies:
@@ -12,15 +12,11 @@ pub enum BaseOp {
    DotInit,
    CumProdInit,
    CumProd,
    Identity,
    Add,
    Mult,
    Sub,
    SumInit,
    Sum,
    Neg,
    Range { tol: i32 },
    IsZero,
    IsBoolean,
}

@@ -36,12 +32,8 @@ impl BaseOp {
        let (a, b) = inputs;
        match &self {
            BaseOp::Add => a + b,
            BaseOp::Identity => b,
            BaseOp::Neg => -b,
            BaseOp::Sub => a - b,
            BaseOp::Mult => a * b,
            BaseOp::Range { .. } => b,
            BaseOp::IsZero => b,
            BaseOp::IsBoolean => b,
            _ => panic!("nonaccum_f called on accumulating operation"),
        }
@@ -73,19 +65,15 @@ impl BaseOp {
    /// display func
    pub fn as_str(&self) -> &'static str {
        match self {
            BaseOp::Identity => "IDENTITY",
            BaseOp::Dot => "DOT",
            BaseOp::DotInit => "DOTINIT",
            BaseOp::CumProdInit => "CUMPRODINIT",
            BaseOp::CumProd => "CUMPROD",
            BaseOp::Add => "ADD",
            BaseOp::Neg => "NEG",
            BaseOp::Sub => "SUB",
            BaseOp::Mult => "MULT",
            BaseOp::Sum => "SUM",
            BaseOp::SumInit => "SUMINIT",
            BaseOp::Range { .. } => "RANGE",
            BaseOp::IsZero => "ISZERO",
            BaseOp::IsBoolean => "ISBOOLEAN",
        }
    }
@@ -93,8 +81,6 @@ impl BaseOp {
    /// Returns the range of the query offset for this operation.
    pub fn query_offset_rng(&self) -> (i32, usize) {
        match self {
            BaseOp::Identity => (0, 1),
            BaseOp::Neg => (0, 1),
            BaseOp::DotInit => (0, 1),
            BaseOp::Dot => (-1, 2),
            BaseOp::CumProd => (-1, 2),
@@ -104,8 +90,6 @@ impl BaseOp {
            BaseOp::Mult => (0, 1),
            BaseOp::Sum => (-1, 2),
            BaseOp::SumInit => (0, 1),
            BaseOp::Range { .. } => (0, 1),
            BaseOp::IsZero => (0, 1),
            BaseOp::IsBoolean => (0, 1),
        }
    }
@@ -113,8 +97,6 @@ impl BaseOp {
    /// Returns the number of inputs for this operation.
    pub fn num_inputs(&self) -> usize {
        match self {
            BaseOp::Identity => 1,
            BaseOp::Neg => 1,
            BaseOp::DotInit => 2,
            BaseOp::Dot => 2,
            BaseOp::CumProdInit => 1,
@@ -124,8 +106,6 @@ impl BaseOp {
            BaseOp::Mult => 2,
            BaseOp::Sum => 1,
            BaseOp::SumInit => 1,
            BaseOp::Range { .. } => 1,
            BaseOp::IsZero => 0,
            BaseOp::IsBoolean => 0,
        }
    }
@@ -133,19 +113,15 @@ impl BaseOp {
    /// Returns the number of outputs for this operation.
    pub fn constraint_idx(&self) -> usize {
        match self {
            BaseOp::Identity => 0,
            BaseOp::Neg => 0,
            BaseOp::DotInit => 0,
            BaseOp::Dot => 1,
            BaseOp::Add => 0,
            BaseOp::Sub => 0,
            BaseOp::Mult => 0,
            BaseOp::Range { .. } => 0,
            BaseOp::Sum => 1,
            BaseOp::SumInit => 0,
            BaseOp::CumProd => 1,
            BaseOp::CumProdInit => 0,
            BaseOp::IsZero => 0,
            BaseOp::IsBoolean => 0,
        }
    }

@@ -188,31 +188,158 @@ impl<'source> FromPyObject<'source> for Tolerance {
    }
}

/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct DynamicLookups {
    /// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
    pub lookup_selectors: BTreeMap<(usize, usize), Selector>,
    /// Selectors for the dynamic lookup tables
    pub table_selectors: Vec<Selector>,
    /// Inputs to the dynamic lookups
    pub inputs: Vec<VarTensor>,
    /// Dynamic lookup tables
    pub tables: Vec<VarTensor>,
}

impl DynamicLookups {
    /// Returns a new [DynamicLookups] with no inputs, no selectors, and no tables.
    pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
        let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
        let single_col_dummy_var = VarTensor::dummy(col_size, 1);

        Self {
            lookup_selectors: BTreeMap::new(),
            table_selectors: vec![],
            inputs: vec![dummy_var.clone(), dummy_var.clone(), dummy_var.clone()],
            tables: vec![
                single_col_dummy_var.clone(),
                single_col_dummy_var.clone(),
                single_col_dummy_var.clone(),
            ],
        }
    }
}

/// A struct representing the selectors for the shuffles
#[derive(Clone, Debug, Default)]
pub struct Shuffles {
    /// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many shuffle ops.
    pub input_selectors: BTreeMap<(usize, usize), Selector>,
    /// Selectors for the shuffle reference tables
    pub reference_selectors: Vec<Selector>,
    /// Inputs to the shuffles
    pub inputs: Vec<VarTensor>,
    /// References the inputs are shuffled against
    pub references: Vec<VarTensor>,
}

impl Shuffles {
    /// Returns a new [Shuffles] with no inputs, no selectors, and no references.
    pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
        let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
        let single_col_dummy_var = VarTensor::dummy(col_size, 1);

        Self {
            input_selectors: BTreeMap::new(),
            reference_selectors: vec![],
            inputs: vec![dummy_var.clone(), dummy_var.clone()],
            references: vec![single_col_dummy_var.clone(), single_col_dummy_var.clone()],
        }
    }
}

/// A struct representing the selectors for the static lookup tables
#[derive(Clone, Debug, Default)]
pub struct StaticLookups<F: PrimeField + TensorType + PartialOrd> {
    /// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many static lookup ops.
    pub selectors: BTreeMap<(LookupOp, usize, usize), Selector>,
    /// Tables for the static lookups
    pub tables: BTreeMap<LookupOp, Table<F>>,
    /// the VarTensor used as the lookup index
    pub index: VarTensor,
    /// the VarTensor used as the lookup output
    pub output: VarTensor,
    /// the VarTensor used as the lookup input
    pub input: VarTensor,
}

impl<F: PrimeField + TensorType + PartialOrd> StaticLookups<F> {
    /// Returns a new [StaticLookups] with no inputs, no selectors, and no tables.
    pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
        let dummy_var = VarTensor::dummy(col_size, num_inner_cols);

        Self {
            selectors: BTreeMap::new(),
            tables: BTreeMap::new(),
            index: dummy_var.clone(),
            output: dummy_var.clone(),
            input: dummy_var,
        }
    }
}

/// A struct representing the selectors for custom gates
#[derive(Clone, Debug, Default)]
pub struct CustomGates {
    /// the inputs to the accumulated operations.
    pub inputs: Vec<VarTensor>,
    /// the (currently singular) output of the accumulated operations.
    pub output: VarTensor,
    /// [Selector]s for the custom gates
    pub selectors: BTreeMap<(BaseOp, usize, usize), Selector>,
}

impl CustomGates {
    /// Returns a new [CustomGates] with no inputs, no selectors, and no tables.
    pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
        let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
        Self {
            inputs: vec![dummy_var.clone(), dummy_var.clone()],
            output: dummy_var,
            selectors: BTreeMap::new(),
        }
    }
}

/// A struct representing the selectors for the range checks
#[derive(Clone, Debug, Default)]
pub struct RangeChecks<F: PrimeField + TensorType + PartialOrd> {
    /// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many range check ops.
    pub selectors: BTreeMap<(Range, usize, usize), Selector>,
    /// Range checks, keyed by range
    pub ranges: BTreeMap<Range, RangeCheck<F>>,
    /// the VarTensor used as the range check index
    pub index: VarTensor,
    /// the VarTensor used as the range check input
    pub input: VarTensor,
}

impl<F: PrimeField + TensorType + PartialOrd> RangeChecks<F> {
    /// Returns a new [RangeChecks] with no inputs, no selectors, and no tables.
    pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
        let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
        Self {
            selectors: BTreeMap::new(),
            ranges: BTreeMap::new(),
            index: dummy_var.clone(),
            input: dummy_var,
        }
    }
}

/// Configuration for an accumulated arg.
#[derive(Clone, Debug, Default)]
pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
    /// the inputs to the accumulated operations.
    pub inputs: Vec<VarTensor>,
    /// the VarTensor reserved for lookup operations (could be an element of inputs)
    /// Note that you should be careful to ensure that the lookup_input is not simultaneously assigned to by other non-lookup operations, e.g. in the case of composite ops.
    pub lookup_input: VarTensor,
    /// the (currently singular) output of the accumulated operations.
    pub output: VarTensor,
    /// the VarTensor reserved for lookup operations (could be an element of inputs or the same as output)
    /// Note that you should be careful to ensure that the lookup_output is not simultaneously assigned to by other non-lookup operations, e.g. in the case of composite ops.
    pub lookup_output: VarTensor,
    /// the VarTensor used as the lookup index
    pub lookup_index: VarTensor,
    /// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure [BaseOp].
    pub selectors: BTreeMap<(BaseOp, usize, usize), Selector>,
    /// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
    pub lookup_selectors: BTreeMap<(LookupOp, usize, usize), Selector>,
    /// Tables for the lookups
    pub tables: BTreeMap<LookupOp, Table<F>>,
    /// Range checks, keyed by range
    pub range_checks: BTreeMap<Range, RangeCheck<F>>,
    /// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
    pub range_check_selectors: BTreeMap<(Range, usize, usize), Selector>,
    /// Custom gates
    pub custom_gates: CustomGates,
    /// StaticLookups
    pub static_lookups: StaticLookups<F>,
    /// [Selector]s for the dynamic lookup tables
    pub dynamic_lookups: DynamicLookups,
    /// [Selector]s for the range checks
    pub range_checks: RangeChecks<F>,
    /// [Selector]s for the shuffles
    pub shuffles: Shuffles,
    /// Activate sanity checks
    pub check_mode: CheckMode,
    _marker: PhantomData<F>,
@@ -221,19 +348,12 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
    /// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
    pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
        let dummy_var = VarTensor::dummy(col_size, num_inner_cols);

        Self {
            inputs: vec![dummy_var.clone(), dummy_var.clone()],
            lookup_input: dummy_var.clone(),
            output: dummy_var.clone(),
            lookup_output: dummy_var.clone(),
            lookup_index: dummy_var,
            selectors: BTreeMap::new(),
            lookup_selectors: BTreeMap::new(),
            range_check_selectors: BTreeMap::new(),
            tables: BTreeMap::new(),
            range_checks: BTreeMap::new(),
            custom_gates: CustomGates::dummy(col_size, num_inner_cols),
            static_lookups: StaticLookups::dummy(col_size, num_inner_cols),
            dynamic_lookups: DynamicLookups::dummy(col_size, num_inner_cols),
            shuffles: Shuffles::dummy(col_size, num_inner_cols),
            range_checks: RangeChecks::dummy(col_size, num_inner_cols),
            check_mode: CheckMode::SAFE,
            _marker: PhantomData,
        }
@@ -266,10 +386,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
            for j in 0..output.num_inner_cols() {
                nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector());
                nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector());
                nonaccum_selectors.insert((BaseOp::Neg, i, j), meta.selector());
                nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector());
                nonaccum_selectors.insert((BaseOp::IsZero, i, j), meta.selector());
                nonaccum_selectors.insert((BaseOp::Identity, i, j), meta.selector());
                nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector());
            }
        }
@@ -314,12 +431,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {

                    vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
                }
                BaseOp::IsZero => {
                    let expected_output: Tensor<Expression<F>> = output
                        .query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
                        .expect("non accum: output query failed");
                    vec![expected_output[base_op.constraint_idx()].clone()]
                }
                _ => {
                    let expected_output: Tensor<Expression<F>> = output
                        .query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
@@ -373,16 +484,15 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
            .collect();

        Self {
            selectors,
            lookup_selectors: BTreeMap::new(),
            range_check_selectors: BTreeMap::new(),
            inputs: inputs.to_vec(),
            lookup_input: VarTensor::Empty,
            lookup_output: VarTensor::Empty,
            lookup_index: VarTensor::Empty,
            tables: BTreeMap::new(),
            range_checks: BTreeMap::new(),
            output: output.clone(),
            custom_gates: CustomGates {
                inputs: inputs.to_vec(),
                output: output.clone(),
                selectors,
            },
            static_lookups: StaticLookups::default(),
            dynamic_lookups: DynamicLookups::default(),
            shuffles: Shuffles::default(),
            range_checks: RangeChecks::default(),
            check_mode,
            _marker: PhantomData,
        }
@@ -403,8 +513,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
    where
        F: Field,
    {
        let mut selectors = BTreeMap::new();

        if !index.is_advice() {
            return Err("wrong input type for lookup index".into());
        }
@@ -417,9 +525,9 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {

        // we borrow mutably twice so we need to do this dance

        let table = if !self.tables.contains_key(nl) {
        let table = if !self.static_lookups.tables.contains_key(nl) {
            // as all tables have the same input we see if there's another table whose input we can reuse
            let table = if let Some(table) = self.tables.values().next() {
            let table = if let Some(table) = self.static_lookups.tables.values().next() {
                Table::<F>::configure(
                    cs,
                    lookup_range,
@@ -430,7 +538,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
            } else {
                Table::<F>::configure(cs, lookup_range, logrows, nl, None)
            };
            self.tables.insert(nl.clone(), table.clone());
            self.static_lookups.tables.insert(nl.clone(), table.clone());
            table
        } else {
            return Ok(());
@@ -514,26 +622,193 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
                    res
                });
            }
            selectors.insert((nl.clone(), x, y), multi_col_selector);
            self.static_lookups
                .selectors
                .insert((nl.clone(), x, y), multi_col_selector);
        }
    }
        self.lookup_selectors.extend(selectors);
        // if we haven't previously initialized the input/output, do so now
        if let VarTensor::Empty = self.lookup_input {
        if let VarTensor::Empty = self.static_lookups.input {
            debug!("assigning lookup input");
            self.lookup_input = input.clone();
            self.static_lookups.input = input.clone();
        }
        if let VarTensor::Empty = self.lookup_output {
        if let VarTensor::Empty = self.static_lookups.output {
            debug!("assigning lookup output");
            self.lookup_output = output.clone();
            self.static_lookups.output = output.clone();
        }
        if let VarTensor::Empty = self.lookup_index {
        if let VarTensor::Empty = self.static_lookups.index {
            debug!("assigning lookup index");
            self.lookup_index = index.clone();
            self.static_lookups.index = index.clone();
        }
        Ok(())
    }

    /// Configures and creates lookup selectors
    #[allow(clippy::too_many_arguments)]
    pub fn configure_dynamic_lookup(
        &mut self,
        cs: &mut ConstraintSystem<F>,
        lookups: &[VarTensor; 3],
        tables: &[VarTensor; 3],
    ) -> Result<(), Box<dyn Error>>
    where
        F: Field,
    {
        for l in lookups.iter() {
            if !l.is_advice() {
                return Err("wrong input type for dynamic lookup".into());
            }
        }

        for t in tables.iter() {
            if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
                return Err("wrong table type for dynamic lookup".into());
            }
        }

        let one = Expression::Constant(F::ONE);

        let s_ltable = cs.complex_selector();

        for x in 0..lookups[0].num_blocks() {
            for y in 0..lookups[0].num_inner_cols() {
                let s_lookup = cs.complex_selector();

                cs.lookup_any("lookup", |cs| {
                    let s_lookupq = cs.query_selector(s_lookup);
                    let mut expression = vec![];
                    let s_ltableq = cs.query_selector(s_ltable);
                    let mut lookup_queries = vec![one.clone()];

                    for lookup in lookups {
                        lookup_queries.push(match lookup {
                            VarTensor::Advice { inner: advices, .. } => {
                                cs.query_advice(advices[x][y], Rotation(0))
                            }
                            _ => unreachable!(),
                        });
                    }

                    let mut table_queries = vec![one.clone()];
                    for table in tables {
                        table_queries.push(match table {
                            VarTensor::Advice { inner: advices, .. } => {
                                cs.query_advice(advices[0][0], Rotation(0))
                            }
                            _ => unreachable!(),
                        });
                    }

                    let lhs = lookup_queries.into_iter().map(|c| c * s_lookupq.clone());
                    let rhs = table_queries.into_iter().map(|c| c * s_ltableq.clone());
                    expression.extend(lhs.zip(rhs));

                    expression
                });
                self.dynamic_lookups
                    .lookup_selectors
                    .entry((x, y))
                    .or_insert(s_lookup);
            }
        }
        self.dynamic_lookups.table_selectors.push(s_ltable);

        // if we haven't previously initialized the input/output, do so now
        if self.dynamic_lookups.tables.is_empty() {
            debug!("assigning dynamic lookup table");
            self.dynamic_lookups.tables = tables.to_vec();
        }
        if self.dynamic_lookups.inputs.is_empty() {
            debug!("assigning dynamic lookup input");
            self.dynamic_lookups.inputs = lookups.to_vec();
        }

        Ok(())
    }

    /// Configures and creates lookup selectors
    #[allow(clippy::too_many_arguments)]
    pub fn configure_shuffles(
        &mut self,
        cs: &mut ConstraintSystem<F>,
        inputs: &[VarTensor; 2],
        references: &[VarTensor; 2],
    ) -> Result<(), Box<dyn Error>>
    where
        F: Field,
    {
        for l in inputs.iter() {
            if !l.is_advice() {
                return Err("wrong input type for dynamic lookup".into());
            }
        }

        for t in references.iter() {
            if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
                return Err("wrong table type for dynamic lookup".into());
            }
        }

        let one = Expression::Constant(F::ONE);

        let s_reference = cs.complex_selector();

        for x in 0..inputs[0].num_blocks() {
            for y in 0..inputs[0].num_inner_cols() {
                let s_input = cs.complex_selector();

                cs.lookup_any("lookup", |cs| {
                    let s_inputq = cs.query_selector(s_input);
                    let mut expression = vec![];
                    let s_referenceq = cs.query_selector(s_reference);
                    let mut input_queries = vec![one.clone()];

                    for input in inputs {
                        input_queries.push(match input {
                            VarTensor::Advice { inner: advices, .. } => {
                                cs.query_advice(advices[x][y], Rotation(0))
                            }
                            _ => unreachable!(),
                        });
                    }

                    let mut ref_queries = vec![one.clone()];
                    for reference in references {
                        ref_queries.push(match reference {
                            VarTensor::Advice { inner: advices, .. } => {
                                cs.query_advice(advices[0][0], Rotation(0))
                            }
                            _ => unreachable!(),
                        });
                    }

                    let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
                    let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
                    expression.extend(lhs.zip(rhs));

                    expression
                });
                self.shuffles
                    .input_selectors
                    .entry((x, y))
                    .or_insert(s_input);
            }
        }
        self.shuffles.reference_selectors.push(s_reference);

        // if we haven't previously initialized the input/output, do so now
        if self.shuffles.references.is_empty() {
            debug!("assigning shuffles reference");
            self.shuffles.references = references.to_vec();
        }
        if self.shuffles.inputs.is_empty() {
            debug!("assigning shuffles input");
            self.shuffles.inputs = inputs.to_vec();
        }

        Ok(())
    }

    /// Configures and creates lookup selectors
    #[allow(clippy::too_many_arguments)]
    pub fn configure_range_check(
@@ -547,23 +822,22 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
    where
        F: Field,
    {
        let mut selectors = BTreeMap::new();

        if !input.is_advice() {
            return Err("wrong input type for lookup input".into());
        }

        // we borrow mutably twice so we need to do this dance

        let range_check =
            if let std::collections::btree_map::Entry::Vacant(e) = self.range_checks.entry(range) {
                // as all tables have the same input we see if there's another table whose input we can reuse
                let range_check = RangeCheck::<F>::configure(cs, range, logrows);
                e.insert(range_check.clone());
                range_check
            } else {
                return Ok(());
            };
        let range_check = if let std::collections::btree_map::Entry::Vacant(e) =
            self.range_checks.ranges.entry(range)
        {
            // as all tables have the same input we see if there's another table whose input we can reuse
            let range_check = RangeCheck::<F>::configure(cs, range, logrows);
            e.insert(range_check.clone());
            range_check
        } else {
            return Ok(());
        };

        for x in 0..input.num_blocks() {
            for y in 0..input.num_inner_cols() {
@@ -620,19 +894,20 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
                    res
                });
            }
            selectors.insert((range, x, y), multi_col_selector);
            self.range_checks
                .selectors
                .insert((range, x, y), multi_col_selector);
        }
    }
        self.range_check_selectors.extend(selectors);
        // if we haven't previously initialized the input/output, do so now
        if let VarTensor::Empty = self.lookup_input {
            debug!("assigning lookup input");
            self.lookup_input = input.clone();
        if let VarTensor::Empty = self.range_checks.input {
            debug!("assigning range check input");
            self.range_checks.input = input.clone();
        }

        if let VarTensor::Empty = self.lookup_index {
            debug!("assigning lookup index");
            self.lookup_index = index.clone();
        if let VarTensor::Empty = self.range_checks.index {
            debug!("assigning range check index");
            self.range_checks.index = index.clone();
        }

        Ok(())
@@ -640,7 +915,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {

    /// layout_tables must be called before layout.
    pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
        for (i, table) in self.tables.values_mut().enumerate() {
        for (i, table) in self.static_lookups.tables.values_mut().enumerate() {
            if !table.is_assigned {
                debug!(
                    "laying out table for {}",
@@ -661,7 +936,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
        &mut self,
        layouter: &mut impl Layouter<F>,
    ) -> Result<(), Box<dyn Error>> {
        for range_check in self.range_checks.values_mut() {
        for range_check in self.range_checks.ranges.values_mut() {
            if !range_check.is_assigned {
                debug!("laying out range check for {:?}", range_check.range);
                range_check.layout(layouter)?;

@@ -277,7 +277,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
                ..
            } => {
                if denom.0.fract() == 0.0 && *use_range_check_for_int {
                    layouts::div(
                    layouts::loop_div(
                        config,
                        region,
                        values[..].try_into()?,

(one file's diff suppressed because it is too large)
@@ -14,10 +14,17 @@ pub enum PolyOp {
        dim: usize,
        constant_idx: Option<Tensor<usize>>,
    },
    GatherND {
        batch_dims: usize,
        indices: Option<Tensor<usize>>,
    },
    ScatterElements {
        dim: usize,
        constant_idx: Option<Tensor<usize>>,
    },
    ScatterND {
        constant_idx: Option<Tensor<usize>>,
    },
    MultiBroadcastTo {
        shape: Vec<usize>,
    },
@@ -60,8 +67,6 @@ pub enum PolyOp {
        len_prod: usize,
    },
    Pow(u32),
    Pack(u32, u32),
    GlobalSumPool,
    Concat {
        axis: usize,
    },
@@ -91,7 +96,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
    fn as_string(&self) -> String {
        match &self {
            PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
            PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
            PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
            PolyOp::ScatterND { .. } => "SCATTERND".into(),
            PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
            PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
            PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -110,8 +117,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
            PolyOp::Sum { .. } => "SUM".into(),
            PolyOp::Prod { .. } => "PROD".into(),
            PolyOp::Pow(_) => "POW".into(),
            PolyOp::Pack(_, _) => "PACK".into(),
            PolyOp::GlobalSumPool => "GLOBALSUMPOOL".into(),
            PolyOp::Conv { .. } => "CONV".into(),
            PolyOp::DeConv { .. } => "DECONV".into(),
            PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
@@ -181,13 +186,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
                output_padding,
                stride,
            } => tensor::ops::deconv(&inputs, *padding, *output_padding, *stride),
            PolyOp::Pack(base, scale) => {
                if 1 != inputs.len() {
                    return Err(TensorError::DimMismatch("pack inputs".to_string()));
                }

                tensor::ops::pack(&inputs[0], F::from(*base as u64), *scale)
            }
            PolyOp::Pow(u) => {
                if 1 != inputs.len() {
                    return Err(TensorError::DimMismatch("pow inputs".to_string()));
@@ -206,7 +204,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
                }
                tensor::ops::prod_axes(&inputs[0], axes)
            }
            PolyOp::GlobalSumPool => unreachable!(),
            PolyOp::Concat { axis } => {
                tensor::ops::concat(&inputs.iter().collect::<Vec<_>>(), *axis)
            }
@@ -225,6 +222,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
                };
                tensor::ops::gather_elements(&x, &y, *dim)
            }
            PolyOp::GatherND {
                indices,
                batch_dims,
            } => {
                let x = inputs[0].clone();
                let y = if let Some(idx) = indices {
                    idx.clone()
                } else {
                    inputs[1].clone().map(|x| felt_to_i128(x) as usize)
                };
                tensor::ops::gather_nd(&x, &y, *batch_dims)
            }
            PolyOp::ScatterElements { dim, constant_idx } => {
                let x = inputs[0].clone();

@@ -241,6 +250,21 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
                };
                tensor::ops::scatter(&x, &idx, &src, *dim)
            }

            PolyOp::ScatterND { constant_idx } => {
                let x = inputs[0].clone();
                let idx = if let Some(idx) = constant_idx {
                    idx.clone()
                } else {
                    inputs[1].clone().map(|x| felt_to_i128(x) as usize)
                };
                let src = if constant_idx.is_some() {
                    inputs[1].clone()
                } else {
                    inputs[2].clone()
                };
                tensor::ops::scatter_nd(&x, &idx, &src)
            }
        }?;

        Ok(ForwardResult { output: res })
@@ -288,7 +312,17 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
                if let Some(idx) = constant_idx {
                    tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
                } else {
                    layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
                    layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
                }
            }
            PolyOp::GatherND {
                batch_dims,
                indices,
            } => {
                if let Some(idx) = indices {
                    tensor::ops::gather_nd(values[0].get_inner_tensor()?, idx, *batch_dims)?.into()
                } else {
                    layouts::gather_nd(config, region, values[..].try_into()?, *batch_dims)?.0
                }
            }
            PolyOp::ScatterElements { dim, constant_idx } => {
@@ -304,6 +338,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
                    layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
                }
            }
            PolyOp::ScatterND { constant_idx } => {
                if let Some(idx) = constant_idx {
                    tensor::ops::scatter_nd(
                        values[0].get_inner_tensor()?,
                        idx,
                        values[1].get_inner_tensor()?,
                    )?
                    .into()
                } else {
                    layouts::scatter_nd(config, region, values[..].try_into()?)?
                }
            }
            PolyOp::DeConv {
                padding,
                output_padding,
@@ -334,10 +380,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
                input
            }
            PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp)?,
            PolyOp::Pack(base, scale) => {
                layouts::pack(config, region, values[..].try_into()?, *base, *scale)?
            }
            PolyOp::GlobalSumPool => unreachable!(),
            PolyOp::Concat { axis } => layouts::concat(values[..].try_into()?, axis)?,
            PolyOp::Slice { axis, start, end } => {
                layouts::slice(config, region, values[..].try_into()?, axis, start, end)?
@@ -405,7 +447,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
            vec![1, 2]
        } else if matches!(self, PolyOp::Concat { .. }) {
            (0..100).collect()
        } else if matches!(self, PolyOp::ScatterElements { .. }) {
        } else if matches!(self, PolyOp::ScatterElements { .. })
            | matches!(self, PolyOp::ScatterND { .. })
        {
            vec![0, 2]
        } else {
            vec![]
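As a reference point for the new GatherND variant above: ONNX's GatherND gathers slices of the data tensor addressed by the trailing dimension of the index tensor. A minimal self-contained sketch of the batch_dims = 0 case over a flat row-major buffer (this illustrates the standard ONNX semantics, not ezkl's tensor::ops::gather_nd implementation):

// GatherND with batch_dims = 0: each index row of length k picks the slice
// data[i_0, ..., i_{k-1}, ..] out of the data tensor.
fn gather_nd(data: &[i64], data_shape: &[usize], indices: &[Vec<usize>]) -> Vec<i64> {
    // strides for row-major layout
    let mut strides = vec![1usize; data_shape.len()];
    for d in (0..data_shape.len() - 1).rev() {
        strides[d] = strides[d + 1] * data_shape[d + 1];
    }
    let mut out = vec![];
    for idx in indices {
        // flat offset of the addressed slice
        let offset: usize = idx.iter().zip(&strides).map(|(i, s)| i * s).sum();
        // length of the remaining (gathered) slice; empty product = 1 (scalar)
        let slice_len: usize = data_shape[idx.len()..].iter().product();
        out.extend_from_slice(&data[offset..offset + slice_len.max(1)]);
    }
    out
}

fn main() {
    // data shape [2, 2]: [[0, 1], [2, 3]]
    let data = [0, 1, 2, 3];
    // indices [[0, 0], [1, 1]] gather the scalars 0 and 3
    assert_eq!(gather_nd(&data, &[2, 2], &[vec![0, 0], vec![1, 1]]), vec![0, 3]);
    // indices [[1], [0]] gather whole rows
    assert_eq!(gather_nd(&data, &[2, 2], &[vec![1], vec![0]]), vec![2, 3, 0, 1]);
}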
@@ -20,6 +20,66 @@ use portable_atomic::AtomicI128 as AtomicInt;

use super::lookup::LookupOp;

/// Dynamic lookup index
#[derive(Clone, Debug, Default)]
pub struct DynamicLookupIndex {
    index: usize,
    col_coord: usize,
}

impl DynamicLookupIndex {
    /// Create a new dynamic lookup index
    pub fn new(index: usize, col_coord: usize) -> DynamicLookupIndex {
        DynamicLookupIndex { index, col_coord }
    }

    /// Get the lookup index
    pub fn index(&self) -> usize {
        self.index
    }

    /// Get the column coord
    pub fn col_coord(&self) -> usize {
        self.col_coord
    }

    /// update with another dynamic lookup index
    pub fn update(&mut self, other: &DynamicLookupIndex) {
        self.index += other.index;
        self.col_coord += other.col_coord;
    }
}

/// Dynamic lookup index
#[derive(Clone, Debug, Default)]
pub struct ShuffleIndex {
    index: usize,
    col_coord: usize,
}

impl ShuffleIndex {
    /// Create a new dynamic lookup index
    pub fn new(index: usize, col_coord: usize) -> ShuffleIndex {
        ShuffleIndex { index, col_coord }
    }

    /// Get the lookup index
    pub fn index(&self) -> usize {
        self.index
    }

    /// Get the column coord
    pub fn col_coord(&self) -> usize {
        self.col_coord
    }

    /// update with another shuffle index
    pub fn update(&mut self, other: &ShuffleIndex) {
        self.index += other.index;
        self.col_coord += other.col_coord;
    }
}

/// Region error
#[derive(Debug, thiserror::Error)]
pub enum RegionError {
@@ -66,12 +126,14 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
    linear_coord: usize,
    num_inner_cols: usize,
    total_constants: usize,
    dynamic_lookup_index: DynamicLookupIndex,
    shuffle_index: ShuffleIndex,
    used_lookups: HashSet<LookupOp>,
    used_range_checks: HashSet<Range>,
    max_lookup_inputs: i128,
    min_lookup_inputs: i128,
    min_range_check: i128,
    max_range_check: i128,
    max_range_size: i128,
    throw_range_check_error: bool,
}

impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
@@ -80,6 +142,31 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
        self.total_constants += n;
    }

    ///
    pub fn increment_dynamic_lookup_index(&mut self, n: usize) {
        self.dynamic_lookup_index.index += n;
    }

    ///
    pub fn increment_dynamic_lookup_col_coord(&mut self, n: usize) {
        self.dynamic_lookup_index.col_coord += n;
    }

    ///
    pub fn increment_shuffle_index(&mut self, n: usize) {
        self.shuffle_index.index += n;
    }

    ///
    pub fn increment_shuffle_col_coord(&mut self, n: usize) {
        self.shuffle_index.col_coord += n;
    }

    ///
    pub fn throw_range_check_error(&self) -> bool {
        self.throw_range_check_error
    }

    /// Create a new region context
    pub fn new(region: Region<'a, F>, row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
        let region = Some(RefCell::new(region));
@@ -91,12 +178,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
            row,
            linear_coord,
            total_constants: 0,
            dynamic_lookup_index: DynamicLookupIndex::default(),
            shuffle_index: ShuffleIndex::default(),
            used_lookups: HashSet::new(),
            used_range_checks: HashSet::new(),
            max_lookup_inputs: 0,
            min_lookup_inputs: 0,
            max_range_check: 0,
            min_range_check: 0,
            max_range_size: 0,
            throw_range_check_error: false,
        }
    }
    /// Create a new region context from a wrapped region
@@ -104,6 +193,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
        region: Option<RefCell<Region<'a, F>>>,
        row: usize,
        num_inner_cols: usize,
        dynamic_lookup_index: DynamicLookupIndex,
        shuffle_index: ShuffleIndex,
    ) -> RegionCtx<'a, F> {
        let linear_coord = row * num_inner_cols;
        RegionCtx {
@@ -112,17 +203,23 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
            linear_coord,
            row,
            total_constants: 0,
            dynamic_lookup_index,
            shuffle_index,
            used_lookups: HashSet::new(),
            used_range_checks: HashSet::new(),
            max_lookup_inputs: 0,
            min_lookup_inputs: 0,
            max_range_check: 0,
            min_range_check: 0,
            max_range_size: 0,
            throw_range_check_error: false,
        }
    }

    /// Create a new region context
    pub fn new_dummy(row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> {
    pub fn new_dummy(
        row: usize,
        num_inner_cols: usize,
        throw_range_check_error: bool,
    ) -> RegionCtx<'a, F> {
        let region = None;
        let linear_coord = row * num_inner_cols;

@@ -132,12 +229,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
            linear_coord,
            row,
            total_constants: 0,
            dynamic_lookup_index: DynamicLookupIndex::default(),
            shuffle_index: ShuffleIndex::default(),
            used_lookups: HashSet::new(),
            used_range_checks: HashSet::new(),
            max_lookup_inputs: 0,
            min_lookup_inputs: 0,
            max_range_check: 0,
            min_range_check: 0,
            max_range_size: 0,
            throw_range_check_error,
        }
    }

@@ -147,8 +246,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
        linear_coord: usize,
        total_constants: usize,
        num_inner_cols: usize,
        used_lookups: HashSet<LookupOp>,
        used_range_checks: HashSet<Range>,
        throw_range_check_error: bool,
    ) -> RegionCtx<'a, F> {
        let region = None;
        RegionCtx {
@@ -157,12 +255,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
            linear_coord,
            row,
            total_constants,
            used_lookups,
            used_range_checks,
            dynamic_lookup_index: DynamicLookupIndex::default(),
            shuffle_index: ShuffleIndex::default(),
            used_lookups: HashSet::new(),
            used_range_checks: HashSet::new(),
            max_lookup_inputs: 0,
            min_lookup_inputs: 0,
            max_range_check: 0,
            min_range_check: 0,
            max_range_size: 0,
            throw_range_check_error,
        }
    }

@@ -217,6 +317,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
        let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
        let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
        let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
        let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
        let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));

        *output = output
            .par_enum_map(|idx, _| {
@@ -232,8 +334,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
                    starting_linear_coord,
                    starting_constants,
                    self.num_inner_cols,
                    HashSet::new(),
                    HashSet::new(),
                    self.throw_range_check_error,
                );
                let res = inner_loop_function(idx, &mut local_reg);
                // we update the offset and constants
@@ -252,14 +353,19 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
                // update the lookups
                let mut lookups = lookups.lock().unwrap();
                lookups.extend(local_reg.used_lookups());
                // update the range checks
                let mut range_checks = range_checks.lock().unwrap();
                range_checks.extend(local_reg.used_range_checks());
                // update the dynamic lookup index
                let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
                dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
                // update the shuffle index
                let mut shuffle_index = shuffle_index.lock().unwrap();
                shuffle_index.update(&local_reg.shuffle_index);

                res
            })
            .map_err(|e| {
                log::error!("dummy_loop: {:?}", e);
                Error::Synthesis
            })?;
            .map_err(|e| RegionError::from(format!("dummy_loop: {:?}", e)))?;
        self.total_constants = constants.into_inner();
        self.linear_coord = linear_coord.into_inner();
        #[allow(trivial_numeric_casts)]
@@ -282,6 +388,28 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
            .map_err(|e| {
                RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
            })?;
        self.dynamic_lookup_index = Arc::try_unwrap(dynamic_lookup_index)
            .map_err(|e| {
                RegionError::from(format!(
                    "dummy_loop: failed to get dynamic lookup index: {:?}",
                    e
                ))
            })?
            .into_inner()
            .map_err(|e| {
                RegionError::from(format!(
                    "dummy_loop: failed to get dynamic lookup index: {:?}",
                    e
                ))
            })?;
        self.shuffle_index = Arc::try_unwrap(shuffle_index)
            .map_err(|e| {
                RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
            })?
            .into_inner()
            .map_err(|e| {
                RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
            })?;

        Ok(())
    }
@@ -310,8 +438,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
            return Err("update_max_min_lookup_range: invalid range".into());
        }

        self.max_range_check = self.max_range_check.max(range.1);
        self.min_range_check = self.min_range_check.min(range.0);
        let range_size = (range.1 - range.0).abs();

        self.max_range_size = self.max_range_size.max(range_size);
        Ok(())
    }

@@ -351,6 +480,26 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
        self.total_constants
    }

    /// Get the dynamic lookup index
    pub fn dynamic_lookup_index(&self) -> usize {
        self.dynamic_lookup_index.index
    }

    /// Get the dynamic lookup column coordinate
    pub fn dynamic_lookup_col_coord(&self) -> usize {
        self.dynamic_lookup_index.col_coord
    }

    /// Get the shuffle index
    pub fn shuffle_index(&self) -> usize {
        self.shuffle_index.index
    }

    /// Get the shuffle column coordinate
    pub fn shuffle_col_coord(&self) -> usize {
        self.shuffle_index.col_coord
    }

    /// get used lookups
    pub fn used_lookups(&self) -> HashSet<LookupOp> {
        self.used_lookups.clone()
@@ -371,14 +520,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
        self.min_lookup_inputs
    }

    /// min range check
    pub fn min_range_check(&self) -> i128 {
        self.min_range_check
    }

    /// max range check
    pub fn max_range_check(&self) -> i128 {
        self.max_range_check
    pub fn max_range_size(&self) -> i128 {
        self.max_range_size
    }

    /// Assign a constant value
@@ -405,6 +549,38 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
        }
    }

    ///
    pub fn combined_dynamic_shuffle_coord(&self) -> usize {
        self.dynamic_lookup_col_coord() + self.shuffle_col_coord()
    }

    /// Assign a valtensor to a vartensor
    pub fn assign_dynamic_lookup(
        &mut self,
        var: &VarTensor,
        values: &ValTensor<F>,
    ) -> Result<ValTensor<F>, Error> {
        self.total_constants += values.num_constants();
        if let Some(region) = &self.region {
            var.assign(
                &mut region.borrow_mut(),
                self.combined_dynamic_shuffle_coord(),
                values,
            )
        } else {
            Ok(values.clone())
        }
    }

    /// Assign a valtensor to a vartensor
    pub fn assign_shuffle(
        &mut self,
        var: &VarTensor,
        values: &ValTensor<F>,
    ) -> Result<ValTensor<F>, Error> {
        self.assign_dynamic_lookup(var, values)
    }

    /// Assign a valtensor to a vartensor
    pub fn assign_with_omissions(
        &mut self,
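The dummy_loop changes above follow a common concurrency pattern: each parallel worker accumulates into a thread-local index, the locals are merged into a shared value behind an Arc<Mutex<_>>, and the merged value is recovered at the end with Arc::try_unwrap(..).into_inner(). A minimal self-contained sketch of that pattern using plain std::thread (ezkl itself uses rayon's par_enum_map, which is analogous; the Index type here is illustrative):

use std::sync::{Arc, Mutex};
use std::thread;

#[derive(Clone, Debug, Default)]
struct Index {
    index: usize,
    col_coord: usize,
}

impl Index {
    fn update(&mut self, other: &Index) {
        self.index += other.index;
        self.col_coord += other.col_coord;
    }
}

fn main() {
    let shared = Arc::new(Mutex::new(Index::default()));

    let handles: Vec<_> = (0..4)
        .map(|i| {
            let shared = Arc::clone(&shared);
            thread::spawn(move || {
                // each worker tracks its own local index ...
                let local = Index { index: 1, col_coord: i * 10 };
                // ... and merges it into the shared one when done
                shared.lock().unwrap().update(&local);
            })
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }

    // recover the merged value once all workers are finished,
    // mirroring Arc::try_unwrap(..).into_inner() in the diff
    let merged = Arc::try_unwrap(shared).unwrap().into_inner().unwrap();
    assert_eq!(merged.index, 4);
    assert_eq!(merged.col_coord, 10 + 20 + 30);
}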
@@ -133,9 +133,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
    }

    ///
    pub fn num_cols_required(range: Range, col_size: usize) -> usize {
        // double it to be safe
        let range_len = range.1 - range.0;
    pub fn num_cols_required(range_len: i128, col_size: usize) -> usize {
        // number of cols needed to store the range
        (range_len / (col_size as i128)) as usize + 1
    }
@@ -152,7 +150,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
        let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
        let col_size = Self::cal_col_size(logrows, factors);
        // number of cols needed to store the range
        let num_cols = num_cols_required(range, col_size);
        let num_cols = num_cols_required((range.1 - range.0).abs(), col_size);

        log::debug!("table range: {:?}", range);

@@ -313,7 +311,7 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
        let factors = cs.blinding_factors() + RESERVED_BLINDING_ROWS_PAD;
        let col_size = Self::cal_col_size(logrows, factors);
        // number of cols needed to store the range
        let num_cols = num_cols_required(range, col_size);
        let num_cols = num_cols_required((range.1 - range.0).abs(), col_size);

        let inputs = {
            let mut cols = vec![];
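To make the new num_cols_required signature concrete: a column holds roughly 2^logrows values minus the rows reserved for blinding, so a lookup spanning range_len values needs range_len / col_size + 1 columns. A small sketch of the arithmetic (the 10 reserved rows are an illustrative figure, not ezkl's exact constant):

fn num_cols_required(range_len: i128, col_size: usize) -> usize {
    // number of cols needed to store the range
    (range_len / (col_size as i128)) as usize + 1
}

fn main() {
    // e.g. logrows = 17 with ~10 rows reserved for blinding (illustrative)
    let col_size = (1usize << 17) - 10;
    // a symmetric lookup over [-2^16, 2^16] has 2^17 + 1 entries -> 2 columns
    let range_len: i128 = 2 * (1 << 16) + 1;
    assert_eq!(num_cols_required(range_len, col_size), 2);
    // a small range fits in a single column
    assert_eq!(num_cols_required(1 << 10, col_size), 1);
}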
@@ -1,4 +1,3 @@
use crate::circuit::ops::hybrid::HybridOp;
use crate::circuit::ops::poly::PolyOp;
use crate::circuit::*;
use crate::tensor::{Tensor, TensorType, ValTensor, VarTensor};
@@ -358,8 +357,6 @@ mod matmul_col_ultra_overflow_double_col {
        );

        assert!(result.is_ok());

        println!("done.");
    }
}

@@ -477,8 +474,6 @@ mod matmul_col_ultra_overflow {
        );

        assert!(result.is_ok());

        println!("done.");
    }
}

@@ -1281,8 +1276,6 @@ mod conv_col_ultra_overflow {
        );

        assert!(result.is_ok());

        println!("done.");
    }
}

@@ -1436,8 +1429,6 @@ mod conv_relu_col_ultra_overflow {
        );

        assert!(result.is_ok());

        println!("done.");
    }
}

@@ -1575,6 +1566,280 @@ mod add {
    }
}

#[cfg(test)]
mod dynamic_lookup {
    use super::*;

    const K: usize = 6;
    const LEN: usize = 4;
    const NUM_LOOP: usize = 5;

    #[derive(Clone)]
    struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
        tables: [[ValTensor<F>; 2]; NUM_LOOP],
        lookups: [[ValTensor<F>; 2]; NUM_LOOP],
        _marker: PhantomData<F>,
    }

    impl Circuit<F> for MyCircuit<F> {
        type Config = BaseConfig<F>;
        type FloorPlanner = SimpleFloorPlanner;
        type Params = TestParams;

        fn without_witnesses(&self) -> Self {
            self.clone()
        }

        fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
            let a = VarTensor::new_advice(cs, K, 2, LEN);
            let b = VarTensor::new_advice(cs, K, 2, LEN);
            let c: VarTensor = VarTensor::new_advice(cs, K, 2, LEN);

            let d = VarTensor::new_advice(cs, K, 1, LEN);
            let e = VarTensor::new_advice(cs, K, 1, LEN);
            let f: VarTensor = VarTensor::new_advice(cs, K, 1, LEN);

            let _constant = VarTensor::constant_cols(cs, K, LEN * NUM_LOOP, false);

            let mut config =
                Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
            config
                .configure_dynamic_lookup(
                    cs,
                    &[a.clone(), b.clone(), c.clone()],
                    &[d.clone(), e.clone(), f.clone()],
                )
                .unwrap();
            config
        }

        fn synthesize(
            &self,
            config: Self::Config,
            mut layouter: impl Layouter<F>,
        ) -> Result<(), Error> {
            layouter
                .assign_region(
                    || "",
                    |region| {
                        let mut region = RegionCtx::new(region, 0, 1);
                        for i in 0..NUM_LOOP {
                            layouts::dynamic_lookup(
                                &config,
                                &mut region,
                                &self.lookups[i],
                                &self.tables[i],
                            )
                            .map_err(|_| Error::Synthesis)?;
                        }
                        assert_eq!(
                            region.dynamic_lookup_col_coord(),
                            NUM_LOOP * self.tables[0][0].len()
                        );
                        assert_eq!(region.dynamic_lookup_index(), NUM_LOOP);

                        Ok(())
                    },
                )
                .unwrap();
            Ok(())
        }
    }

    #[test]
    fn dynamiclookupcircuit() {
        // parameters
        let tables = (0..NUM_LOOP)
            .map(|loop_idx| {
                [
                    ValTensor::from(Tensor::from(
                        (0..LEN).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
                    )),
                    ValTensor::from(Tensor::from(
                        (0..LEN).map(|i| Value::known(F::from((loop_idx * i * i) as u64 + 1))),
                    )),
                ]
            })
            .collect::<Vec<_>>();

        let lookups = (0..NUM_LOOP)
            .map(|loop_idx| {
                [
                    ValTensor::from(Tensor::from(
                        (0..3).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
                    )),
                    ValTensor::from(Tensor::from(
                        (0..3).map(|i| Value::known(F::from((loop_idx * i * i) as u64 + 1))),
                    )),
                ]
            })
            .collect::<Vec<_>>();

        let circuit = MyCircuit::<F> {
            tables: tables.clone().try_into().unwrap(),
            lookups: lookups.try_into().unwrap(),
            _marker: PhantomData,
        };

        let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
        prover.assert_satisfied();

        let lookups = (0..NUM_LOOP)
            .map(|loop_idx| {
                let prev_idx = if loop_idx == 0 {
                    NUM_LOOP - 1
                } else {
                    loop_idx - 1
                };
                [
                    ValTensor::from(Tensor::from(
                        (0..3).map(|i| Value::known(F::from((i * prev_idx) as u64 + 1))),
                    )),
                    ValTensor::from(Tensor::from(
                        (0..3).map(|i| Value::known(F::from((prev_idx * i * i) as u64 + 1))),
                    )),
                ]
            })
            .collect::<Vec<_>>();

        let circuit = MyCircuit::<F> {
            tables: tables.try_into().unwrap(),
            lookups: lookups.try_into().unwrap(),
            _marker: PhantomData,
        };

        let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
        assert!(prover.verify().is_err());
    }
}

#[cfg(test)]
mod shuffle {
    use super::*;

    const K: usize = 6;
    const LEN: usize = 4;
    const NUM_LOOP: usize = 5;

    #[derive(Clone)]
    struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
        inputs: [[ValTensor<F>; 1]; NUM_LOOP],
        references: [[ValTensor<F>; 1]; NUM_LOOP],
        _marker: PhantomData<F>,
    }

    impl Circuit<F> for MyCircuit<F> {
        type Config = BaseConfig<F>;
        type FloorPlanner = SimpleFloorPlanner;
        type Params = TestParams;

        fn without_witnesses(&self) -> Self {
            self.clone()
        }

        fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
            let a = VarTensor::new_advice(cs, K, 2, LEN);
            let b = VarTensor::new_advice(cs, K, 2, LEN);
            let c: VarTensor = VarTensor::new_advice(cs, K, 2, LEN);

            let d = VarTensor::new_advice(cs, K, 1, LEN);
            let e = VarTensor::new_advice(cs, K, 1, LEN);

            let _constant = VarTensor::constant_cols(cs, K, LEN * NUM_LOOP, false);

            let mut config =
                Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
            config
                .configure_shuffles(cs, &[a.clone(), b.clone()], &[d.clone(), e.clone()])
                .unwrap();
            config
        }

        fn synthesize(
            &self,
            config: Self::Config,
            mut layouter: impl Layouter<F>,
        ) -> Result<(), Error> {
            layouter
                .assign_region(
                    || "",
                    |region| {
                        let mut region = RegionCtx::new(region, 0, 1);
                        for i in 0..NUM_LOOP {
                            layouts::shuffles(
                                &config,
                                &mut region,
                                &self.inputs[i],
                                &self.references[i],
                            )
                            .map_err(|_| Error::Synthesis)?;
                        }
                        assert_eq!(
                            region.shuffle_col_coord(),
                            NUM_LOOP * self.references[0][0].len()
                        );
                        assert_eq!(region.shuffle_index(), NUM_LOOP);

                        Ok(())
                    },
                )
                .unwrap();
            Ok(())
        }
    }

    #[test]
    fn shufflecircuit() {
        // parameters
        let references = (0..NUM_LOOP)
            .map(|loop_idx| {
                [ValTensor::from(Tensor::from((0..LEN).map(|i| {
                    Value::known(F::from((i * loop_idx) as u64 + 1))
                })))]
            })
            .collect::<Vec<_>>();

        let inputs = (0..NUM_LOOP)
            .map(|loop_idx| {
                [ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
                    Value::known(F::from((i * loop_idx) as u64 + 1))
                })))]
            })
            .collect::<Vec<_>>();

        let circuit = MyCircuit::<F> {
            references: references.clone().try_into().unwrap(),
            inputs: inputs.try_into().unwrap(),
            _marker: PhantomData,
        };

        let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
        prover.assert_satisfied();

        let inputs = (0..NUM_LOOP)
            .map(|loop_idx| {
                let prev_idx = if loop_idx == 0 {
                    NUM_LOOP - 1
                } else {
                    loop_idx - 1
                };
                [ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
                    Value::known(F::from((i * prev_idx) as u64 + 1))
                })))]
            })
            .collect::<Vec<_>>();

        let circuit = MyCircuit::<F> {
            references: references.try_into().unwrap(),
            inputs: inputs.try_into().unwrap(),
            _marker: PhantomData,
        };

        let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
        assert!(prover.verify().is_err());
    }
}

#[cfg(test)]
mod add_with_overflow {
    use super::*;
@@ -1978,75 +2243,6 @@ mod pow {
    }
}

#[cfg(test)]
mod pack {
    use super::*;

    const K: usize = 8;
    const LEN: usize = 4;

    #[derive(Clone)]
    struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
        inputs: [ValTensor<F>; 1],
        _marker: PhantomData<F>,
    }

    impl Circuit<F> for MyCircuit<F> {
        type Config = BaseConfig<F>;
        type FloorPlanner = SimpleFloorPlanner;
        type Params = TestParams;

        fn without_witnesses(&self) -> Self {
            self.clone()
        }

        fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
            let a = VarTensor::new_advice(cs, K, 1, LEN);
            let b = VarTensor::new_advice(cs, K, 1, LEN);
            let output = VarTensor::new_advice(cs, K, 1, LEN);

            Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
        }

        fn synthesize(
            &self,
            mut config: Self::Config,
            mut layouter: impl Layouter<F>,
        ) -> Result<(), Error> {
            layouter
                .assign_region(
                    || "",
                    |region| {
                        let mut region = RegionCtx::new(region, 0, 1);
                        config
                            .layout(
                                &mut region,
                                &self.inputs.clone(),
                                Box::new(PolyOp::Pack(2, 1)),
                            )
                            .map_err(|_| Error::Synthesis)
                    },
                )
                .unwrap();
            Ok(())
        }
    }

    #[test]
    fn packcircuit() {
        // parameters
        let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));

        let circuit = MyCircuit::<F> {
            inputs: [ValTensor::from(a)],
            _marker: PhantomData,
        };

        let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
        prover.assert_satisfied();
    }
}

#[cfg(test)]
mod matmul_relu {
    use super::*;
@@ -2334,117 +2530,5 @@ mod lookup_ultra_overflow {
        );

        assert!(result.is_ok());

        println!("done.");
    }
}

#[cfg(test)]
mod softmax {

    use super::*;
    use halo2_proofs::{
        circuit::{Layouter, SimpleFloorPlanner, Value},
        dev::MockProver,
        plonk::{Circuit, ConstraintSystem, Error},
    };

    const K: usize = 18;
    const LEN: usize = 3;
    const SCALE: f32 = 128.0;

    #[derive(Clone)]
    struct SoftmaxCircuit<F: PrimeField + TensorType + PartialOrd> {
        pub input: ValTensor<F>,
        _marker: PhantomData<F>,
    }

    impl Circuit<F> for SoftmaxCircuit<F> {
        type Config = BaseConfig<F>;
        type FloorPlanner = SimpleFloorPlanner;
        type Params = TestParams;

        fn without_witnesses(&self) -> Self {
            self.clone()
        }
        fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
            let a = VarTensor::new_advice(cs, K, 1, LEN);
            let b = VarTensor::new_advice(cs, K, 1, LEN);
            let output = VarTensor::new_advice(cs, K, 1, LEN);
            let mut config = Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE);
            let advices = (0..3)
                .map(|_| VarTensor::new_advice(cs, K, 1, LEN))
                .collect::<Vec<_>>();

            config
                .configure_lookup(
                    cs,
                    &advices[0],
                    &advices[1],
                    &advices[2],
                    (-32768, 32768),
                    K,
                    &LookupOp::Exp {
                        scale: SCALE.into(),
                    },
                )
                .unwrap();
            config
                .configure_lookup(
                    cs,
                    &advices[0],
                    &advices[1],
                    &advices[2],
                    (-32768, 32768),
                    K,
                    &LookupOp::Recip {
                        input_scale: SCALE.into(),
                        output_scale: SCALE.into(),
                    },
                )
                .unwrap();
            config
        }

        fn synthesize(
            &self,
            mut config: Self::Config,
            mut layouter: impl Layouter<F>,
        ) -> Result<(), Error> {
            config.layout_tables(&mut layouter).unwrap();
            layouter
                .assign_region(
                    || "",
                    |region| {
                        let mut region = RegionCtx::new(region, 0, 1);
                        let _output = config
                            .layout(
                                &mut region,
                                &[self.input.clone()],
                                Box::new(HybridOp::Softmax {
                                    scale: SCALE.into(),
                                    axes: vec![0],
                                }),
                            )
                            .unwrap();
                        Ok(())
                    },
                )
                .unwrap();

            Ok(())
        }
    }

    #[test]
    fn softmax_circuit() {
        let input = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));

        let circuit = SoftmaxCircuit::<F> {
            input: ValTensor::from(input),
            _marker: PhantomData,
        };
        let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
        prover.assert_satisfied();
    }
}
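The failing cases in both new tests above exercise the same fact: a dynamic lookup or shuffle argument only accepts witnesses drawn from (or forming a permutation of) the committed reference, so values generated from a different loop index must fail verification. Outside the proof system, the property a shuffle enforces is plain multiset equality, sketched here for intuition:

// A shuffle argument accepts `input` iff it is a permutation of `reference`.
// Outside a circuit this is just multiset equality:
fn is_shuffle(input: &[u64], reference: &[u64]) -> bool {
    let mut a = input.to_vec();
    let mut b = reference.to_vec();
    a.sort_unstable();
    b.sort_unstable();
    a == b
}

fn main() {
    assert!(is_shuffle(&[3, 1, 2], &[1, 2, 3])); // reversed order: accepted
    assert!(!is_shuffle(&[1, 1, 2], &[1, 2, 3])); // wrong multiset: rejected
}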
@@ -77,7 +77,7 @@ pub const DEFAULT_CALIBRATION_FILE: &str = "calibration.json";
/// Default lookup safety margin
pub const DEFAULT_LOOKUP_SAFETY_MARGIN: &str = "2";
/// Default Compress selectors
pub const DEFAULT_COMPRESS_SELECTORS: &str = "false";
pub const DEFAULT_DISABLE_SELECTOR_COMPRESSION: &str = "false";
/// Default render vk seperately
pub const DEFAULT_RENDER_VK_SEPERATELY: &str = "false";
/// Default VK sol path
@@ -402,9 +402,6 @@ pub enum Commands {
        /// Number of logrows to use for srs. Overrides settings_path if specified.
        #[arg(long, default_value = None)]
        logrows: Option<u32>,
        /// Check mode for SRS. Verifies downloaded srs is valid. Set to unsafe for speed.
        #[arg(long, default_value = DEFAULT_CHECKMODE)]
        check: CheckMode,
    },
    /// Loads model and input and runs mock prover (for testing)
    Mock {
@@ -450,8 +447,8 @@ pub enum Commands {
        #[arg(long, default_value = DEFAULT_SPLIT)]
        split_proofs: bool,
        /// compress selectors
        #[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
        compress_selectors: bool,
        #[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
        disable_selector_compression: bool,
    },
    /// Aggregates proofs :)
    Aggregate {
@@ -471,7 +468,7 @@ pub enum Commands {
            long,
            require_equals = true,
            num_args = 0..=1,
            default_value_t = TranscriptType::EVM,
            default_value_t = TranscriptType::default(),
            value_enum
        )]
        transcript: TranscriptType,
@@ -515,8 +512,8 @@ pub enum Commands {
        #[arg(short = 'W', long)]
        witness: Option<PathBuf>,
        /// compress selectors
        #[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
        compress_selectors: bool,
        #[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
        disable_selector_compression: bool,
    },

    #[cfg(not(target_arch = "wasm32"))]
@@ -526,13 +523,13 @@ pub enum Commands {
        #[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
        witness: PathBuf,
        /// The path to the compiled model file (generated using the compile-circuit command)
        #[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT)]
        #[arg(short = 'M', long)]
        compiled_circuit: PathBuf,
        #[arg(
            long,
            require_equals = true,
            num_args = 0..=1,
            default_value_t = TranscriptType::EVM,
            default_value_t = TranscriptType::default(),
            value_enum
        )]
        transcript: TranscriptType,
@@ -540,8 +537,8 @@ pub enum Commands {
        #[arg(long, default_value = DEFAULT_FUZZ_RUNS)]
        num_runs: usize,
        /// compress selectors
        #[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
        compress_selectors: bool,
        #[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
        disable_selector_compression: bool,
    },
    #[cfg(not(target_arch = "wasm32"))]
    /// Deploys a test contract that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
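Note the polarity flip in the rename above: with `compress_selectors` defaulting to "false", compression was opt-in; with `disable_selector_compression` defaulting to "false", compression is on unless explicitly disabled. A minimal clap-derive sketch of the same flag pattern (the Cli struct is hypothetical; only the attribute mirrors the diff):

use clap::Parser;

const DEFAULT_DISABLE_SELECTOR_COMPRESSION: &str = "false";

/// Hypothetical CLI mirroring the flag pattern in commands.rs.
#[derive(Parser, Debug)]
struct Cli {
    /// disable selector compression (compression stays on unless this is set)
    #[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
    disable_selector_compression: bool,
}

fn main() {
    let cli = Cli::parse();
    let compress = !cli.disable_selector_compression;
    println!("selector compression enabled: {}", compress);
}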
200 src/execute.rs
@@ -23,6 +23,7 @@ use crate::pfsys::{create_proof_circuit_kzg, verify_proof_circuit_kzg};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::RunArgs;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
use halo2_proofs::poly::commitment::Params;
@@ -63,7 +64,11 @@ use std::process::Command;
use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
#[cfg(not(target_arch = "wasm32"))]
use std::sync::OnceLock;

#[cfg(not(target_arch = "wasm32"))]
use crate::EZKL_BUF_CAPACITY;
#[cfg(not(target_arch = "wasm32"))]
use std::io::BufWriter;
use std::time::Duration;
use tabled::Tabled;
use thiserror::Error;
@@ -140,13 +145,13 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
            compiled_circuit,
            transcript,
            num_runs,
            compress_selectors,
            disable_selector_compression,
        } => fuzz(
            compiled_circuit,
            witness,
            transcript,
            num_runs,
            compress_selectors,
            disable_selector_compression,
        ),
        Commands::GenSrs { srs_path, logrows } => gen_srs_cmd(srs_path, logrows as u32),
        #[cfg(not(target_arch = "wasm32"))]
@@ -154,8 +159,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
            srs_path,
            settings_path,
            logrows,
            check,
        } => get_srs_cmd(srs_path, settings_path, logrows, check).await,
        } => get_srs_cmd(srs_path, settings_path, logrows).await,
        Commands::Table { model, args } => table(model, args),
        #[cfg(feature = "render")]
        Commands::RenderCircuit {
@@ -260,14 +264,14 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
            vk_path,
            pk_path,
            witness,
            compress_selectors,
            disable_selector_compression,
        } => setup(
            compiled_circuit,
            srs_path,
            vk_path,
            pk_path,
            witness,
            compress_selectors,
            disable_selector_compression,
        ),
        #[cfg(not(target_arch = "wasm32"))]
        Commands::SetupTestEvmData {
@@ -331,7 +335,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
            srs_path,
            logrows,
            split_proofs,
            compress_selectors,
            disable_selector_compression,
        } => setup_aggregate(
            sample_snarks,
            vk_path,
@@ -339,7 +343,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
            srs_path,
            logrows,
            split_proofs,
            compress_selectors,
            disable_selector_compression,
        ),
        Commands::Aggregate {
            proof_path,
@@ -487,16 +491,28 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
}

#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, Box<dyn Error>> {
    use std::io::Read;

    let path = get_srs_path(logrows, srs_path);
    let file = std::fs::File::open(path.clone())?;
    let file = std::fs::File::open(path)?;
    let mut reader = std::io::BufReader::new(file);
    let mut buffer = vec![];
    let bytes_read = std::io::BufReader::new(file).read_to_end(&mut buffer)?;
    debug!("read {} bytes from SRS file", bytes_read);
    let bytes_read = reader.read_to_end(&mut buffer)?;
    info!(
        "read {} bytes from file (vector of len = {})",
        bytes_read,
        buffer.len()
    );

    let hash = sha256::digest(buffer);
    info!("SRS hash: {}", hash);
    info!("file hash: {}", hash);

    Ok(hash)
}

#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
    let path = get_srs_path(logrows, srs_path);
    let hash = get_file_hash(&path)?;

    let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
        Some(h) => h,
@@ -520,7 +536,6 @@ pub(crate) async fn get_srs_cmd(
    srs_path: Option<PathBuf>,
    settings_path: Option<PathBuf>,
    logrows: Option<u32>,
    check_mode: CheckMode,
) -> Result<String, Box<dyn Error>> {
    // logrows overrides settings

@@ -548,18 +563,21 @@ pub(crate) async fn get_srs_cmd(
        let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k);
        let mut reader = Cursor::new(fetch_srs(&srs_uri).await?);
        // check the SRS
        if matches!(check_mode, CheckMode::SAFE) {
            #[cfg(not(target_arch = "wasm32"))]
            let pb = init_spinner();
            #[cfg(not(target_arch = "wasm32"))]
            pb.set_message("Validating SRS (this may take a while) ...");
            ParamsKZG::<Bn256>::read(&mut reader)?;
            #[cfg(not(target_arch = "wasm32"))]
            pb.finish_with_message("SRS validated");
        }
        #[cfg(not(target_arch = "wasm32"))]
        let pb = init_spinner();
        #[cfg(not(target_arch = "wasm32"))]
        pb.set_message("Validating SRS (this may take a while) ...");
        let params = ParamsKZG::<Bn256>::read(&mut reader)?;
        #[cfg(not(target_arch = "wasm32"))]
        pb.finish_with_message("SRS validated.");

        info!("Saving SRS to disk...");
        let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone()))?;
        file.write_all(reader.get_ref())?;
        let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
        params.write(&mut buffer)?;

        info!("Saved SRS to disk.");

        info!("SRS downloaded");
    } else {
        info!("SRS already exists at that path");
@@ -618,7 +636,7 @@ pub(crate) async fn gen_witness(

    let start_time = Instant::now();

    let witness = circuit.forward(&mut input, vk.as_ref(), srs.as_ref())?;
    let witness = circuit.forward(&mut input, vk.as_ref(), srs.as_ref(), false)?;

    // print each variable tuple (symbol, value) as symbol=value
    trace!(
@@ -806,18 +824,8 @@ pub(crate) fn calibrate(
    let settings = GraphSettings::load(&settings_path)?;
    // now retrieve the run args
    // we load the model to get the input and output shapes
    // check if gag already exists

    #[cfg(unix)]
    let _r = match Gag::stdout() {
        Ok(r) => Some(r),
        Err(_) => None,
    };

    let model = Model::from_run_args(&settings.run_args, &model_path)?;
    // drop the gag
    #[cfg(unix)]
    std::mem::drop(_r);

    let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
    info!("num of calibration batches: {}", chunks.len());
@@ -833,7 +841,7 @@ pub(crate) fn calibrate(
    let range = if let Some(scales) = scales {
        scales
    } else {
        (10..14).collect::<Vec<crate::Scale>>()
        (11..14).collect::<Vec<crate::Scale>>()
    };

    let div_rebasing = if only_range_check_rebase {
@@ -896,17 +904,12 @@ pub(crate) fn calibrate(
            input_scale, param_scale, scale_rebase_multiplier, div_rebasing
        ));

        #[cfg(unix)]
        let _r = match Gag::stdout() {
            Ok(r) => Some(r),
            Err(_) => None,
        };
        #[cfg(unix)]
        let _q = match Gag::stderr() {
            Ok(r) => Some(r),
            Err(_) => None,
        };
        let key = (input_scale, param_scale, scale_rebase_multiplier);
        let key = (
            input_scale,
            param_scale,
            scale_rebase_multiplier,
            div_rebasing,
        );
        forward_pass_res.insert(key, vec![]);

        let local_run_args = RunArgs {
@@ -917,20 +920,27 @@ pub(crate) fn calibrate(
            ..settings.run_args.clone()
        };

        // if unix get a gag
        #[cfg(unix)]
        let _r = match Gag::stdout() {
            Ok(g) => Some(g),
            _ => None,
        };
        #[cfg(unix)]
        let _g = match Gag::stderr() {
            Ok(g) => Some(g),
            _ => None,
        };

        let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
            Ok(c) => c,
            Err(e) => {
                // drop the gag
                #[cfg(unix)]
                std::mem::drop(_r);
                #[cfg(unix)]
                std::mem::drop(_q);
                debug!("circuit creation from run args failed: {:?}", e);
                continue;
            }
        };

        chunks
        let forward_res = chunks
            .iter()
            .map(|chunk| {
                let chunk = chunk.clone();
@@ -940,7 +950,7 @@ pub(crate) fn calibrate(
                    .map_err(|e| format!("failed to load circuit inputs: {}", e))?;

                let forward_res = circuit
                    .forward(&mut data.clone(), None, None)
                    .forward(&mut data.clone(), None, None, true)
                    .map_err(|e| format!("failed to forward: {}", e))?;

                // push result to the hashmap
@@ -951,53 +961,46 @@ pub(crate) fn calibrate(

                Ok(()) as Result<(), String>
            })
            .collect::<Result<Vec<()>, String>>()?;
            .collect::<Result<Vec<()>, String>>();

        let min_lookup_range = forward_pass_res
            .get(&key)
            .unwrap()
        match forward_res {
            Ok(_) => (),
            // typically errors will be due to the circuit overflowing the i128 limit
            Err(e) => {
                debug!("forward pass failed: {:?}", e);
                continue;
            }
        }

        // drop the gag
        #[cfg(unix)]
        drop(_r);
        #[cfg(unix)]
        drop(_g);

        let result = forward_pass_res.get(&key).ok_or("key not found")?;

        let min_lookup_range = result
            .iter()
            .map(|x| x.min_lookup_inputs)
            .min()
            .unwrap_or(0);

        let max_lookup_range = forward_pass_res
            .get(&key)
            .unwrap()
        let max_lookup_range = result
            .iter()
            .map(|x| x.max_lookup_inputs)
            .max()
            .unwrap_or(0);

        let min_range_check = forward_pass_res
            .get(&key)
            .unwrap()
            .iter()
            .map(|x| x.min_range_check)
            .min()
            .unwrap_or(0);
        let max_range_size = result.iter().map(|x| x.max_range_size).max().unwrap_or(0);

        let max_range_check = forward_pass_res
            .get(&key)
            .unwrap()
            .iter()
            .map(|x| x.max_range_check)
            .max()
            .unwrap_or(0);

        let res = circuit.calibrate_from_min_max(
        let res = circuit.calc_min_logrows(
            (min_lookup_range, max_lookup_range),
            (min_range_check, max_range_check),
            max_range_size,
            max_logrows,
            lookup_safety_margin,
        );

        // // drop the gag
        // #[cfg(unix)]
        // std::mem::drop(_r);
        // #[cfg(unix)]
        // std::mem::drop(_q);

        if res.is_ok() {
            let new_settings = circuit.settings().clone();

@@ -1110,6 +1113,7 @@ pub(crate) fn calibrate(
            best_params.run_args.input_scale,
            best_params.run_args.param_scale,
            best_params.run_args.scale_rebase_multiplier,
            best_params.run_args.div_rebasing,
        ))
        .ok_or("no params found")?
        .iter()
@@ -1528,7 +1532,7 @@ pub(crate) fn setup(
    vk_path: PathBuf,
    pk_path: PathBuf,
    witness: Option<PathBuf>,
    compress_selectors: bool,
    disable_selector_compression: bool,
) -> Result<String, Box<dyn Error>> {
    // these aren't real values so the sanity checks are mostly meaningless
    let mut circuit = GraphCircuit::load(compiled_circuit)?;
@@ -1542,7 +1546,7 @@ pub(crate) fn setup(
    let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
        &circuit,
        &params,
        compress_selectors,
        disable_selector_compression,
    )
    .map_err(Box::<dyn Error>::from)?;

@@ -1683,7 +1687,7 @@ pub(crate) fn fuzz(
    data_path: PathBuf,
    transcript: TranscriptType,
    num_runs: usize,
    compress_selectors: bool,
    disable_selector_compression: bool,
) -> Result<String, Box<dyn Error>> {
    check_solc_requirement();
    let passed = AtomicBool::new(true);
@@ -1693,7 +1697,7 @@ pub(crate) fn fuzz(
    let logrows = circuit.settings().run_args.logrows;

    info!("setting up tests");

    #[cfg(unix)]
    let _r = Gag::stdout()?;
    let params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);

@@ -1702,7 +1706,7 @@ pub(crate) fn fuzz(
    let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
        &circuit,
        &params,
        compress_selectors,
        disable_selector_compression,
    )
    .map_err(Box::<dyn Error>::from)?;

@@ -1711,6 +1715,7 @@ pub(crate) fn fuzz(
    let public_inputs = circuit.prepare_public_inputs(&data)?;

    let strategy = KZGSingleStrategy::new(&params);
    #[cfg(unix)]
    std::mem::drop(_r);

    info!("starting fuzzing");
@@ -1723,7 +1728,7 @@ pub(crate) fn fuzz(
        let bad_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
            &circuit,
            &new_params,
            compress_selectors,
            disable_selector_compression,
        )
        .map_err(|_| ())?;

@@ -1801,7 +1806,7 @@ pub(crate) fn fuzz(
        let bad_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
            &circuit,
            &new_params,
            compress_selectors,
            disable_selector_compression,
        )
        .map_err(|_| ())?;

@@ -1901,6 +1906,7 @@ pub(crate) fn run_fuzz_fn(
    passed: &AtomicBool,
) {
    let num_failures = AtomicI64::new(0);
    #[cfg(unix)]
    let _r = Gag::stdout().unwrap();

    let pb = init_bar(num_runs as u64);
@@ -1914,6 +1920,7 @@ pub(crate) fn run_fuzz_fn(
        pb.inc(1);
    });
    pb.finish_with_message("Done.");
    #[cfg(unix)]
    std::mem::drop(_r);
    info!(
        "num failures: {} out of {}",
@@ -1980,7 +1987,7 @@ pub(crate) fn setup_aggregate(
    srs_path: Option<PathBuf>,
    logrows: u32,
    split_proofs: bool,
    compress_selectors: bool,
    disable_selector_compression: bool,
) -> Result<String, Box<dyn Error>> {
    // the K used for the aggregation circuit
    let params = load_params_cmd(srs_path, logrows)?;
@@ -1994,7 +2001,7 @@ pub(crate) fn setup_aggregate(
    let agg_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(
        &agg_circuit,
        &params,
        compress_selectors,
        disable_selector_compression,
    )?;

    let agg_vk = agg_pk.get_vk();
@@ -2071,7 +2078,8 @@ pub(crate) fn verify(
    let circuit_settings = GraphSettings::load(&settings_path)?;

    let params = if reduced_srs {
        load_params_cmd(srs_path, circuit_settings.log2_total_instances())?
        // only need G_0 for the verification with shplonk
        load_params_cmd(srs_path, 1)?
    } else {
        load_params_cmd(srs_path, circuit_settings.run_args.logrows)?
    };
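For orientation, the get_srs_cmd/check_srs_hash flow above boils down to: fetch the SRS, validate it by deserializing ParamsKZG, write it back through a buffered writer, then compare the file's SHA-256 digest against a table of known hashes. A condensed sketch of the hash-comparison step (known_hashes stands in for crate::srs_sha::PUBLIC_SRS_SHA256_HASHES; the sha256::digest call mirrors the one in the diff):

use std::collections::HashMap;
use std::error::Error;
use std::io::Read;
use std::path::PathBuf;

// Compute the sha256 digest of a file, as in get_file_hash above.
fn get_file_hash(path: &PathBuf) -> Result<String, Box<dyn Error>> {
    let file = std::fs::File::open(path)?;
    let mut reader = std::io::BufReader::new(file);
    let mut buffer = vec![];
    reader.read_to_end(&mut buffer)?;
    Ok(sha256::digest(buffer))
}

// `known_hashes` is a stand-in for the PUBLIC_SRS_SHA256_HASHES table.
fn check_srs_hash(
    logrows: u32,
    path: &PathBuf,
    known_hashes: &HashMap<u32, String>,
) -> Result<String, Box<dyn Error>> {
    let hash = get_file_hash(path)?;
    let expected = known_hashes
        .get(&logrows)
        .ok_or("no known hash for this logrows")?;
    if &hash != expected {
        return Err(format!("SRS hash mismatch: {} != {}", hash, expected).into());
    }
    Ok(hash)
}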
346 src/graph/mod.rs
@@ -12,6 +12,8 @@ pub mod utilities;
|
||||
pub mod vars;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use colored_json::ToColoredJson;
|
||||
#[cfg(unix)]
|
||||
use gag::Gag;
|
||||
use halo2_proofs::plonk::VerifyingKey;
|
||||
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
|
||||
pub use input::DataSource;
|
||||
@@ -61,8 +63,11 @@ use crate::pfsys::field_to_string;
|
||||
/// The safety factor for the range of the lookup table.
|
||||
pub const RANGE_MULTIPLIER: i128 = 2;
|
||||
|
||||
/// The maximum number of columns in a lookup table.
|
||||
pub const MAX_NUM_LOOKUP_COLS: usize = 12;
|
||||
|
||||
/// Max representation of a lookup table input
|
||||
pub const MAX_LOOKUP_ABS: i128 = 8 * 2_i128.pow(MAX_PUBLIC_SRS);
|
||||
pub const MAX_LOOKUP_ABS: i128 = (MAX_NUM_LOOKUP_COLS as i128) * 2_i128.pow(MAX_PUBLIC_SRS);
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
lazy_static! {
|
||||
@@ -134,15 +139,16 @@ pub enum GraphError {
|
||||
MissingResults,
|
||||
}
|
||||
|
||||
const ASSUMED_BLINDING_FACTORS: usize = 5;
|
||||
///
|
||||
pub const ASSUMED_BLINDING_FACTORS: usize = 5;
|
||||
/// The minimum number of rows in the grid
|
||||
pub const MIN_LOGROWS: u32 = 6;
|
||||
|
||||
/// 26
|
||||
pub const MAX_PUBLIC_SRS: u32 = bn256::Fr::S - 2;
|
||||
|
||||
/// Lookup deg
|
||||
pub const LOOKUP_DEG: usize = 5;
|
||||
///
|
||||
pub const RESERVED_BLINDING_ROWS: usize = ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD;
|
||||
|
||||
use std::cell::RefCell;
|
||||
|
||||
@@ -171,10 +177,8 @@ pub struct GraphWitness {
|
||||
pub max_lookup_inputs: i128,
|
||||
/// max lookup input
|
||||
pub min_lookup_inputs: i128,
|
||||
/// max range check input
|
||||
pub max_range_check: i128,
|
||||
/// max range check input
|
||||
    pub min_range_check: i128,
    /// max range check size
    pub max_range_size: i128,
}

impl GraphWitness {
@@ -202,8 +206,7 @@ impl GraphWitness {
            processed_outputs: None,
            max_lookup_inputs: 0,
            min_lookup_inputs: 0,
            max_range_check: 0,
            min_range_check: 0,
            max_range_size: 0,
        }
    }

@@ -376,9 +379,7 @@ impl ToPyObject for GraphWitness {
            .unwrap();
        dict.set_item("min_lookup_inputs", self.min_lookup_inputs)
            .unwrap();
        dict.set_item("max_range_check", self.max_range_check)
            .unwrap();
        dict.set_item("min_range_check", self.min_range_check)
        dict.set_item("max_range_size", self.max_range_size)
            .unwrap();

        if let Some(processed_inputs) = &self.processed_inputs {
@@ -450,6 +451,14 @@ pub struct GraphSettings {
    pub total_assignments: usize,
    /// total const size
    pub total_const_size: usize,
    /// total dynamic column size
    pub total_dynamic_col_size: usize,
    /// number of dynamic lookups
    pub num_dynamic_lookups: usize,
    /// number of shuffles
    pub num_shuffles: usize,
    /// total shuffle column size
    pub total_shuffle_col_size: usize,
    /// the shape of public inputs to the model (in order of appearance)
    pub model_instance_shapes: Vec<Vec<usize>>,
    /// model output scales
@@ -473,6 +482,30 @@ pub struct GraphSettings {
}

impl GraphSettings {
    fn model_constraint_logrows(&self) -> u32 {
        (self.num_rows as f64 + RESERVED_BLINDING_ROWS as f64)
            .log2()
            .ceil() as u32
    }

    fn dynamic_lookup_and_shuffle_logrows(&self) -> u32 {
        (self.total_dynamic_col_size as f64 + self.total_shuffle_col_size as f64)
            .log2()
            .ceil() as u32
    }

    fn dynamic_lookup_and_shuffle_col_size(&self) -> usize {
        self.total_dynamic_col_size + self.total_shuffle_col_size
    }

    fn module_constraint_logrows(&self) -> u32 {
        (self.module_sizes.max_constraints() as f64).log2().ceil() as u32
    }

    fn constants_logrows(&self) -> u32 {
        (self.total_const_size as f64).log2().ceil() as u32
    }

    /// calculate the total number of instances
    pub fn total_instances(&self) -> Vec<usize> {
        let mut instances: Vec<usize> = self
@@ -557,6 +590,16 @@ impl GraphSettings {
            || self.run_args.param_visibility.is_hashed()
    }

    /// requires dynamic lookup
    pub fn requires_dynamic_lookup(&self) -> bool {
        self.num_dynamic_lookups > 0
    }

    /// requires dynamic shuffle
    pub fn requires_shuffle(&self) -> bool {
        self.num_shuffles > 0
    }

    /// any kzg visibility
    pub fn module_requires_kzg(&self) -> bool {
        self.run_args.input_visibility.is_kzgcommit()
@@ -1005,10 +1048,6 @@ impl GraphCircuit {
        Ok(data)
    }

    fn reserved_blinding_rows() -> f64 {
        (ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD) as f64
    }

    fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
        let mut margin = (
            lookup_safety_margin * min_max_lookup.0,
@@ -1022,18 +1061,34 @@ impl GraphCircuit {
        margin
    }

    fn calc_num_cols(safe_range: Range, max_logrows: u32) -> usize {
        let max_col_size = Table::<Fp>::cal_col_size(
            max_logrows as usize,
            Self::reserved_blinding_rows() as usize,
        );
        num_cols_required(safe_range, max_col_size)
    fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize {
        let max_col_size = Table::<Fp>::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS);
        num_cols_required(range_len, max_col_size)
    }
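    // [editor's sketch, not part of the diff] The new `calc_num_cols` spreads
    // a lookup/range of a given length across fixed-size table columns. A
    // minimal standalone model of that arithmetic, assuming `cal_col_size`
    // returns 2^logrows minus the reserved blinding rows and that
    // `num_cols_required` is a ceiling division (both assumptions, not copied
    // from the source):
    //
    //     fn col_size(logrows: usize, reserved_rows: usize) -> usize {
    //         (1 << logrows) - reserved_rows
    //     }
    //     fn cols_needed(range_len: usize, col_size: usize) -> usize {
    //         (range_len + col_size - 1) / col_size // ceiling division
    //     }
    //
    // e.g. a range of length 2^17 = 131072 at logrows = 16 with 12 reserved
    // rows gives col_size = 65524, so 3 columns are needed (2 columns only
    // hold 131048 values).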
    fn calc_min_logrows(
    fn table_size_logrows(
        &self,
        safe_lookup_range: Range,
        max_range_size: i128,
    ) -> Result<u32, Box<dyn std::error::Error>> {
        // pick the range with the largest absolute size: safe_lookup_range or max_range_size
        let safe_range = std::cmp::max(
            (safe_lookup_range.1 - safe_lookup_range.0).abs(),
            max_range_size,
        );

        let min_bits = (safe_range as f64 + RESERVED_BLINDING_ROWS as f64 + 1.)
            .log2()
            .ceil() as u32;

        Ok(min_bits)
    }
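    // [editor's note, not part of the diff] Worked example of the `min_bits`
    // formula above: with a safe range of length 2^16 = 65536 and, say, 12
    // reserved blinding rows (a placeholder count), the table must hold
    // 65536 + 12 + 1 = 65549 rows, and ceil(log2(65549)) = 17, so at least
    // 17 logrows are needed for the lookup table alone.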
    /// calculate the minimum logrows required for the circuit
    pub fn calc_min_logrows(
        &mut self,
        min_max_lookup: Range,
        min_max_range_checks: Range,
        max_range_size: i128,
        max_logrows: Option<u32>,
        lookup_safety_margin: i128,
    ) -> Result<(), Box<dyn std::error::Error>> {
@@ -1043,68 +1098,60 @@ impl GraphCircuit {
        let mut max_logrows = std::cmp::max(max_logrows, MIN_LOGROWS);
        let mut min_logrows = MIN_LOGROWS;

        let reserved_blinding_rows = Self::reserved_blinding_rows();
        let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);

        // check if the max lookup input has overflowed
        if min_max_lookup.1.abs() > MAX_LOOKUP_ABS / lookup_safety_margin
            || min_max_lookup.0.abs() > MAX_LOOKUP_ABS / lookup_safety_margin
        {
        if (min_max_lookup.1 - min_max_lookup.0).abs() > MAX_LOOKUP_ABS / lookup_safety_margin {
            let err_string = format!("max lookup input {:?} is too large", min_max_lookup);
            return Err(err_string.into());
        }

        if min_max_range_checks.1.abs() > MAX_LOOKUP_ABS
            || min_max_range_checks.1.abs() > MAX_LOOKUP_ABS
        {
            let err_string = format!(
                "max range check input {:?} is too large",
                min_max_range_checks
            );
        if max_range_size.abs() > MAX_LOOKUP_ABS {
            let err_string = format!("max range check size {:?} is too large", max_range_size);
            return Err(err_string.into());
        }

        let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);
        // pick the range with the largest absolute size between safe_lookup_range and min_max_range_checks
        let safe_range = if (safe_lookup_range.1 - safe_lookup_range.0)
            > (min_max_range_checks.1 - min_max_range_checks.0)
        {
            safe_lookup_range
        } else {
            min_max_range_checks
        };
        // These are hard lower limits, we can't overflow instances or modules constraints
        let instance_logrows = self.settings().log2_total_instances();
        let module_constraint_logrows = self.settings().module_constraint_logrows();
        let dynamic_lookup_logrows = self.settings().dynamic_lookup_and_shuffle_logrows();
        min_logrows = std::cmp::max(
            min_logrows,
            // max of the instance logrows and the module constraint logrows and the dynamic lookup logrows is the lower limit
            *[
                instance_logrows,
                module_constraint_logrows,
                dynamic_lookup_logrows,
            ]
            .iter()
            .max()
            .unwrap(),
        );

        // These are upper limits, going above these is wasteful, but they are not hard limits
        let model_constraint_logrows = self.settings().model_constraint_logrows();
        let min_bits = self.table_size_logrows(safe_lookup_range, max_range_size)?;
        let constants_logrows = self.settings().constants_logrows();
        max_logrows = std::cmp::min(
            max_logrows,
            // max of the model constraint logrows, min_bits, and the constants logrows is the upper limit
            *[model_constraint_logrows, min_bits, constants_logrows]
                .iter()
                .max()
                .unwrap(),
        );

        // we now have a min and max logrows
        max_logrows = std::cmp::max(min_logrows, max_logrows);

        // degrade the max logrows until the extended k is small enough
        while min_logrows < max_logrows
            && !self.extended_k_is_small_enough(
                min_logrows,
                Self::calc_num_cols(safe_range, min_logrows),
            )
        {
            min_logrows += 1;
        }

        if !self
            .extended_k_is_small_enough(min_logrows, Self::calc_num_cols(safe_range, min_logrows))
        {
            let err_string = format!(
                "extended k is too large to accommodate the quotient polynomial with logrows {}",
                min_logrows
            );
            debug!("{}", err_string);
            return Err(err_string.into());
        }

        while min_logrows < max_logrows
            && !self.extended_k_is_small_enough(
                max_logrows,
                Self::calc_num_cols(safe_range, max_logrows),
            )
            && !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size)
        {
            max_logrows -= 1;
        }

        if !self
            .extended_k_is_small_enough(max_logrows, Self::calc_num_cols(safe_range, max_logrows))
        {
        if !self.extended_k_is_small_enough(max_logrows, safe_lookup_range, max_range_size) {
            let err_string = format!(
                "extended k is too large to accommodate the quotient polynomial with logrows {}",
                max_logrows
@@ -1113,67 +1160,27 @@ impl GraphCircuit {
            return Err(err_string.into());
        }

        let min_bits = ((safe_range.1 - safe_range.0) as f64 + reserved_blinding_rows + 1.)
            .log2()
            .ceil() as usize;

        let min_rows_from_constraints = (self.settings().num_rows as f64 + reserved_blinding_rows)
            .log2()
            .ceil() as usize;

        let mut logrows = std::cmp::max(min_bits, min_rows_from_constraints);

        // if public input then public inputs col will have public inputs len
        if self.settings().run_args.input_visibility.is_public()
            || self.settings().run_args.output_visibility.is_public()
        {
            let mut max_instance_len = self
                .model()
                .instance_shapes()?
                .iter()
                .fold(0, |acc, x| std::cmp::max(acc, x.iter().product::<usize>()))
                as f64
                + reserved_blinding_rows;
            // if there are modules then we need to add the max module size
            if self.settings().uses_modules() {
                max_instance_len += self
                    .settings()
                    .module_sizes
                    .num_instances()
                    .iter()
                    .sum::<usize>() as f64;
            }
            let instance_len_logrows = (max_instance_len).log2().ceil() as usize;
            logrows = std::cmp::max(logrows, instance_len_logrows);
            // this is for fixed const columns
        }

        // ensure logrows is at least 4
        logrows = std::cmp::max(logrows, min_logrows as usize);
        logrows = std::cmp::min(logrows, max_logrows as usize);
        let logrows = max_logrows;

        let model = self.model().clone();
        let settings_mut = self.settings_mut();
        settings_mut.run_args.lookup_range = safe_lookup_range;
        settings_mut.run_args.logrows = logrows as u32;
        settings_mut.run_args.logrows = logrows;

        *settings_mut = GraphCircuit::new(model, &settings_mut.run_args)?
            .settings()
            .clone();

        // recalculate the total const size given the new logrows
        let total_const_len = settings_mut.total_const_size;
        let const_len_logrows = (total_const_len as f64).log2().ceil() as u32;
        settings_mut.run_args.logrows =
            std::cmp::max(settings_mut.run_args.logrows, const_len_logrows);
        // recalculate the total number of constraints given the new logrows
        let min_rows_from_constraints = (settings_mut.num_rows as f64 + reserved_blinding_rows)
            .log2()
            .ceil() as u32;
        settings_mut.run_args.logrows =
            std::cmp::max(settings_mut.run_args.logrows, min_rows_from_constraints);

        settings_mut.run_args.logrows = std::cmp::min(max_logrows, settings_mut.run_args.logrows);
        // recalculate the logrows if there has been overflow on the constants
        settings_mut.run_args.logrows = std::cmp::max(
            settings_mut.run_args.logrows,
            settings_mut.constants_logrows(),
        );
        // recalculate the logrows if there has been overflow for the model constraints
        settings_mut.run_args.logrows = std::cmp::max(
            settings_mut.run_args.logrows,
            settings_mut.model_constraint_logrows(),
        );

        debug!(
            "setting lookup_range to: {:?}, setting logrows to: {}",
@@ -1184,12 +1191,48 @@ impl GraphCircuit {
        Ok(())
    }
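    // [editor's note, not part of the diff] The net effect of the rewritten
    // `calc_min_logrows` is a clamp-then-search, roughly (a sketch, under the
    // assumption that `extended_k_is_small_enough` is monotone in k):
    //
    //     min_logrows = max(MIN_LOGROWS, instance, module, dynamic-lookup logrows)
    //     max_logrows = min(user max, max(model, lookup-table, constants logrows))
    //     max_logrows = max(min_logrows, max_logrows)
    //     while min_logrows < max_logrows && !small_enough(max_logrows) {
    //         max_logrows -= 1;
    //     }
    //
    // i.e. hard lower bounds are applied first, wasteful upper bounds second,
    // and the extended-k check then walks the upper bound down.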
    fn extended_k_is_small_enough(&self, k: u32, num_lookup_cols: usize) -> bool {
        let max_degree = self.settings().run_args.num_inner_cols + 2;
        let max_lookup_degree = LOOKUP_DEG + num_lookup_cols - 1; // num_lookup_cols - 1 is the degree of the lookup synthetic selector
    fn extended_k_is_small_enough(
        &self,
        k: u32,
        safe_lookup_range: Range,
        max_range_size: i128,
    ) -> bool {
        // if num cols is too large then the extended k is too large
        if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS
            || Self::calc_num_cols(max_range_size, k) > MAX_NUM_LOOKUP_COLS
        {
            return false;
        }

        let max_degree = std::cmp::max(max_degree, max_lookup_degree);
        let mut settings = self.settings().clone();
        settings.run_args.lookup_range = safe_lookup_range;
        settings.run_args.logrows = k;
        settings.required_range_checks = vec![(0, max_range_size)];
        let mut cs = ConstraintSystem::default();
        // if unix get a gag
        #[cfg(unix)]
        let _r = match Gag::stdout() {
            Ok(g) => Some(g),
            _ => None,
        };
        #[cfg(unix)]
        let _g = match Gag::stderr() {
            Ok(g) => Some(g),
            _ => None,
        };

        Self::configure_with_params(&mut cs, settings);

        // drop the gag
        #[cfg(unix)]
        drop(_r);
        #[cfg(unix)]
        drop(_g);

        #[cfg(feature = "mv-lookup")]
        let cs = cs.chunk_lookups();
        // quotient_poly_degree * params.n - 1 is the degree of the quotient polynomial
        let max_degree = cs.degree();
        let quotient_poly_degree = (max_degree - 1) as u64;
        // n = 2^k
        let n = 1u64 << k;
@@ -1204,29 +1247,13 @@ impl GraphCircuit {
        true
    }
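    // [editor's note, not part of the diff] The extended evaluation domain
    // must hold the quotient polynomial, so its size is roughly
    // quotient_poly_degree * n with n = 2^k, giving
    // extended_k = ceil(log2(quotient_poly_degree * n)). For example, with
    // cs.degree() = 5 the quotient degree factor is 4, and at k = 17
    // (n = 131072) the extended domain needs 4n = 524288 = 2^19 rows, i.e.
    // extended_k = 19. The check fails once extended_k exceeds the supported
    // domain size.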
    /// Calibrate the circuit to the supplied data.
    pub fn calibrate_from_min_max(
        &mut self,
        min_max_lookup: Range,
        min_max_range_checks: Range,
        max_logrows: Option<u32>,
        lookup_safety_margin: i128,
    ) -> Result<(), Box<dyn std::error::Error>> {
        self.calc_min_logrows(
            min_max_lookup,
            min_max_range_checks,
            max_logrows,
            lookup_safety_margin,
        )?;
        Ok(())
    }

    /// Runs the forward pass of the model / graph of computations and any associated hashing.
    pub fn forward(
        &self,
        inputs: &mut [Tensor<Fp>],
        vk: Option<&VerifyingKey<G1Affine>>,
        srs: Option<&ParamsKZG<Bn256>>,
        throw_range_check_error: bool,
    ) -> Result<GraphWitness, Box<dyn std::error::Error>> {
        let original_inputs = inputs.to_vec();

@@ -1267,7 +1294,9 @@ impl GraphCircuit {
            }
        }

        let mut model_results = self.model().forward(inputs, &self.settings().run_args)?;
        let mut model_results =
            self.model()
                .forward(inputs, &self.settings().run_args, throw_range_check_error)?;

        if visibility.output.requires_processing() {
            let module_outlets = visibility.output.overwrites_inputs();
@@ -1310,8 +1339,7 @@ impl GraphCircuit {
            processed_outputs,
            max_lookup_inputs: model_results.max_lookup_inputs,
            min_lookup_inputs: model_results.min_lookup_inputs,
            max_range_check: model_results.max_range_check,
            min_range_check: model_results.min_range_check,
            max_range_size: model_results.max_range_size,
        };

        witness.generate_rescaled_elements(
@@ -1518,34 +1546,18 @@ impl Circuit<Fp> for GraphCircuit {
            params.run_args.logrows as usize,
        );

        let mut vars = ModelVars::new(
            cs,
            params.run_args.logrows as usize,
            params.total_assignments,
            params.run_args.num_inner_cols,
            params.total_const_size,
            params.module_requires_fixed(),
        );
        let mut vars = ModelVars::new(cs, &params);

        module_configs.configure_complex_modules(cs, visibility, params.module_sizes.clone());

        vars.instantiate_instance(
            cs,
            params.model_instance_shapes,
            params.model_instance_shapes.clone(),
            params.run_args.input_scale,
            module_configs.instance,
        );

        let base = Model::configure(
            cs,
            &vars,
            params.run_args.lookup_range,
            params.run_args.logrows as usize,
            params.required_lookups,
            params.required_range_checks,
            params.check_mode,
        )
        .unwrap();
        let base = Model::configure(cs, &vars, &params).unwrap();

        let model_config = ModelConfig { base, vars };
@@ -67,10 +67,8 @@ pub struct ForwardResult {
    pub max_lookup_inputs: i128,
    /// The minimum value of any input to a lookup operation.
    pub min_lookup_inputs: i128,
    /// The max range check value
    pub max_range_check: i128,
    /// The min range check value
    pub min_range_check: i128,
    /// The max range check size
    pub max_range_size: i128,
}

impl From<DummyPassRes> for ForwardResult {
@@ -79,8 +77,7 @@ impl From<DummyPassRes> for ForwardResult {
            outputs: res.outputs,
            max_lookup_inputs: res.max_lookup_inputs,
            min_lookup_inputs: res.min_lookup_inputs,
            min_range_check: res.min_range_check,
            max_range_check: res.max_range_check,
            max_range_size: res.max_range_size,
        }
    }
}
@@ -102,6 +99,14 @@ pub type NodeGraph = BTreeMap<usize, NodeType>;
pub struct DummyPassRes {
    /// number of rows used
    pub num_rows: usize,
    /// num dynamic lookups
    pub num_dynamic_lookups: usize,
    /// dynamic lookup col size
    pub dynamic_lookup_col_coord: usize,
    /// num shuffles
    pub num_shuffles: usize,
    /// shuffle
    pub shuffle_col_coord: usize,
    /// linear coordinate
    pub linear_coord: usize,
    /// total const size
@@ -115,9 +120,7 @@ pub struct DummyPassRes {
    /// min lookup inputs
    pub min_lookup_inputs: i128,
    /// min range check
    pub min_range_check: i128,
    /// max range check
    pub max_range_check: i128,
    pub max_range_size: i128,
    /// outputs
    pub outputs: Vec<Tensor<Fp>>,
}
@@ -531,7 +534,7 @@ impl Model {
        })
        .collect::<Result<Vec<_>, Box<dyn Error>>>()?;

        let res = self.dummy_layout(run_args, &inputs)?;
        let res = self.dummy_layout(run_args, &inputs, false)?;

        // if we're using percentage tolerance, we need to add the necessary range check ops for it.

@@ -545,6 +548,10 @@ impl Model {
            required_range_checks: res.range_checks.into_iter().collect(),
            model_output_scales: self.graph.get_output_scales()?,
            model_input_scales: self.graph.get_input_scales(),
            num_dynamic_lookups: res.num_dynamic_lookups,
            total_dynamic_col_size: res.dynamic_lookup_col_coord,
            num_shuffles: res.num_shuffles,
            total_shuffle_col_size: res.shuffle_col_coord,
            total_const_size: res.total_const_size,
            check_mode,
            version: env!("CARGO_PKG_VERSION").to_string(),
@@ -570,12 +577,13 @@ impl Model {
        &self,
        model_inputs: &[Tensor<Fp>],
        run_args: &RunArgs,
        throw_range_check_error: bool,
    ) -> Result<ForwardResult, Box<dyn Error>> {
        let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
            .iter()
            .map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
            .collect();
        let res = self.dummy_layout(run_args, &valtensor_inputs)?;
        let res = self.dummy_layout(run_args, &valtensor_inputs, throw_range_check_error)?;
        Ok(res.into())
    }

@@ -1007,24 +1015,24 @@ impl Model {
    /// # Arguments
    /// * `meta` - The constraint system.
    /// * `vars` - The variables for the circuit.
    /// * `run_args` - [RunArgs]
    /// * `required_lookups` - The required lookup operations for the circuit.
    /// * `settings` - [GraphSettings]
    pub fn configure(
        meta: &mut ConstraintSystem<Fp>,
        vars: &ModelVars<Fp>,
        lookup_range: Range,
        logrows: usize,
        required_lookups: Vec<LookupOp>,
        required_range_checks: Vec<Range>,
        check_mode: CheckMode,
        settings: &GraphSettings,
    ) -> Result<PolyConfig<Fp>, Box<dyn Error>> {
        info!("configuring model");
        debug!("configuring model");

        let lookup_range = settings.run_args.lookup_range;
        let logrows = settings.run_args.logrows as usize;
        let required_lookups = settings.required_lookups.clone();
        let required_range_checks = settings.required_range_checks.clone();

        let mut base_gate = PolyConfig::configure(
            meta,
            vars.advices[0..2].try_into()?,
            &vars.advices[2],
            check_mode,
            settings.check_mode,
        );
        // set scale for HybridOp::RangeCheck and call self.conf_lookup on that op for percentage tolerance case
        let input = &vars.advices[0];
@@ -1038,6 +1046,22 @@ impl Model {
            base_gate.configure_range_check(meta, input, index, range, logrows)?;
        }

        if settings.requires_dynamic_lookup() {
            base_gate.configure_dynamic_lookup(
                meta,
                vars.advices[0..3].try_into()?,
                vars.advices[3..6].try_into()?,
            )?;
        }

        if settings.requires_shuffle() {
            base_gate.configure_shuffles(
                meta,
                vars.advices[0..2].try_into()?,
                vars.advices[3..5].try_into()?,
            )?;
        }

        Ok(base_gate)
    }

@@ -1356,6 +1380,7 @@ impl Model {
        &self,
        run_args: &RunArgs,
        inputs: &[ValTensor<Fp>],
        throw_range_check_error: bool,
    ) -> Result<DummyPassRes, Box<dyn Error>> {
        debug!("calculating num of constraints using dummy model layout...");

@@ -1374,7 +1399,7 @@ impl Model {
            vars: ModelVars::new_dummy(),
        };

        let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols);
        let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, throw_range_check_error);

        let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;

@@ -1441,8 +1466,11 @@ impl Model {
            range_checks: region.used_range_checks(),
            max_lookup_inputs: region.max_lookup_inputs(),
            min_lookup_inputs: region.min_lookup_inputs(),
            min_range_check: region.min_range_check(),
            max_range_check: region.max_range_check(),
            max_range_size: region.max_range_size(),
            num_dynamic_lookups: region.dynamic_lookup_index(),
            dynamic_lookup_col_coord: region.dynamic_lookup_col_coord(),
            num_shuffles: region.shuffle_index(),
            shuffle_col_coord: region.shuffle_col_coord(),
            outputs,
        };


@@ -23,7 +23,10 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(not(target_arch = "wasm32"))]
use tract_onnx::tract_core::ops::{
    array::{Gather, GatherElements, MultiBroadcastTo, OneHot, ScatterElements, Slice, Topk},
    array::{
        Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
        Slice, Topk,
    },
    change_axes::AxisOp,
    cnn::{Conv, Deconv},
    einsum::EinSum,
@@ -467,6 +470,78 @@ pub fn new_op_from_onnx(

            // Extract the max value
        }
        "ScatterNd" => {
            if inputs.len() != 3 {
                return Err(Box::new(GraphError::InvalidDims(
                    idx,
                    "scatter nd".to_string(),
                )));
            };
            // just verify it deserializes correctly
            let _op = load_op::<ScatterNd>(node.op(), idx, node.op().name().to_string())?;

            let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
                constant_idx: None,
            });

            // if param_visibility.is_public() {
            if let Some(c) = inputs[1].opkind().get_mutable_constant() {
                inputs[1].decrement_use();
                deleted_indices.push(inputs.len() - 1);
                op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
                    constant_idx: Some(c.raw_values.map(|x| x as usize)),
                })
            }
            // }

            if inputs[1].opkind().is_input() {
                inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
                    scale: 0,
                    datum_type: InputType::TDim,
                }));
                inputs[1].bump_scale(0);
            }

            op
        }

        "GatherNd" => {
            if inputs.len() != 2 {
                return Err(Box::new(GraphError::InvalidDims(
                    idx,
                    "gather nd".to_string(),
                )));
            };
            let op = load_op::<GatherNd>(node.op(), idx, node.op().name().to_string())?;
            let batch_dims = op.batch_dims;

            let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
                batch_dims,
                indices: None,
            });

            // if param_visibility.is_public() {
            if let Some(c) = inputs[1].opkind().get_mutable_constant() {
                inputs[1].decrement_use();
                deleted_indices.push(inputs.len() - 1);
                op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
                    batch_dims,
                    indices: Some(c.raw_values.map(|x| x as usize)),
                })
            }
            // }

            if inputs[1].opkind().is_input() {
                inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
                    scale: 0,
                    datum_type: InputType::TDim,
                }));
                inputs[1].bump_scale(0);
            }

            op
        }

        "GatherElements" => {
            if inputs.len() != 2 {
                return Err(Box::new(GraphError::InvalidDims(
@@ -734,7 +809,7 @@ pub fn new_op_from_onnx(
                SupportedOp::Hybrid(HybridOp::Recip {
                    input_scale: (scale_to_multiplier(in_scale) as f32).into(),
                    output_scale: (scale_to_multiplier(max_scale) as f32).into(),
                    use_range_check_for_int: false,
                    use_range_check_for_int: true,
                })
            }


@@ -420,20 +420,34 @@ impl<F: PrimeField + TensorType + PartialOrd> ModelVars<F> {
    }

    /// Allocate all columns that will be assigned to by a model.
    pub fn new(
        cs: &mut ConstraintSystem<F>,
        logrows: usize,
        var_len: usize,
        num_inner_cols: usize,
        num_constants: usize,
        module_requires_fixed: bool,
    ) -> Self {
    pub fn new(cs: &mut ConstraintSystem<F>, params: &GraphSettings) -> Self {
        debug!("number of blinding factors: {}", cs.blinding_factors());

        let advices = (0..3)
        let logrows = params.run_args.logrows as usize;
        let var_len = params.total_assignments;
        let num_inner_cols = params.run_args.num_inner_cols;
        let num_constants = params.total_const_size;
        let module_requires_fixed = params.module_requires_fixed();
        let requires_dynamic_lookup = params.requires_dynamic_lookup();
        let requires_shuffle = params.requires_shuffle();
        let dynamic_lookup_and_shuffle_size = params.dynamic_lookup_and_shuffle_col_size();

        let mut advices = (0..3)
            .map(|_| VarTensor::new_advice(cs, logrows, num_inner_cols, var_len))
            .collect_vec();

        if requires_dynamic_lookup || requires_shuffle {
            let num_cols = if requires_dynamic_lookup { 3 } else { 2 };
            for _ in 0..num_cols {
                let dynamic_lookup =
                    VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_and_shuffle_size);
                if dynamic_lookup.num_blocks() > 1 {
                    panic!("dynamic lookup or shuffle should only have one block");
                };
                advices.push(dynamic_lookup);
            }
        }

        debug!(
            "model uses {} advice blocks (size={})",
            advices.iter().map(|v| v.num_blocks()).sum::<usize>(),

@@ -180,6 +180,9 @@ impl RunArgs {
        if self.num_inner_cols < 1 {
            return Err("num_inner_cols must be >= 1".into());
        }
        if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
            return Err("tolerance > 0.0 requires output_visibility to be public".into());
        }
        Ok(())
    }
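// [editor's note, not part of the diff] The new validation rejects a nonzero
// tolerance without public outputs, since the percentage-tolerance range
// check is applied against public output instances. E.g. (hypothetical
// values) a RunArgs with tolerance.val = 1.0 and a non-public
// output_visibility now fails fast at validation rather than later during
// setup or proving.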
@@ -197,7 +197,11 @@ impl std::fmt::Display for TranscriptType {
    }
}

impl ToFlags for TranscriptType {}
impl ToFlags for TranscriptType {
    fn to_flags(&self) -> Vec<String> {
        vec![format!("{}", self)]
    }
}
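// [editor's note, not part of the diff] The empty impl previously fell back
// to the trait's default `to_flags`; the override serializes the variant via
// its Display impl, so the emitted CLI flag value now matches what `{}`
// prints for the transcript type — e.g., assuming Display renders "evm",
// `to_flags()` yields `vec!["evm".to_string()]`.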
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for TranscriptType {
|
||||
@@ -480,7 +484,7 @@ where
|
||||
pub fn create_keys<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
|
||||
circuit: &C,
|
||||
params: &'_ Scheme::ParamsProver,
|
||||
compress_selectors: bool,
|
||||
disable_selector_compression: bool,
|
||||
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
@@ -492,7 +496,7 @@ where
|
||||
// Initialize verifying key
|
||||
let now = Instant::now();
|
||||
trace!("preparing VK");
|
||||
let vk = keygen_vk_custom(params, &empty_circuit, compress_selectors)?;
|
||||
let vk = keygen_vk_custom(params, &empty_circuit, !disable_selector_compression)?;
|
||||
let elapsed = now.elapsed();
|
||||
info!("VK took {}.{}", elapsed.as_secs(), elapsed.subsec_millis());
|
||||
|
||||
|
||||
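// [editor's note, not part of the diff] The flag's polarity flips here:
// `compress_selectors: true` used to mean "compress", while the new
// `disable_selector_compression` is negated before being handed to
// `keygen_vk_custom`, so compression stays on unless explicitly disabled:
//
//     disable_selector_compression = false  ->  keygen_vk_custom(.., true)
//     disable_selector_compression = true   ->  keygen_vk_custom(.., false)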
@@ -484,7 +484,6 @@ fn get_srs(
        srs_path,
        settings_path,
        logrows,
        CheckMode::SAFE,
    ))
    .map_err(|e| {
        let err_str = format!("Failed to get srs: {}", e);
@@ -622,7 +621,7 @@ fn mock_aggregate(
    pk_path=PathBuf::from(DEFAULT_PK),
    srs_path=None,
    witness_path = None,
    compress_selectors=DEFAULT_COMPRESS_SELECTORS.parse().unwrap(),
    disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
))]
fn setup(
    model: PathBuf,
@@ -630,7 +629,7 @@ fn setup(
    pk_path: PathBuf,
    srs_path: Option<PathBuf>,
    witness_path: Option<PathBuf>,
    compress_selectors: bool,
    disable_selector_compression: bool,
) -> Result<bool, PyErr> {
    crate::execute::setup(
        model,
@@ -638,7 +637,7 @@ fn setup(
        vk_path,
        pk_path,
        witness_path,
        compress_selectors,
        disable_selector_compression,
    )
    .map_err(|e| {
        let err_str = format!("Failed to run setup: {}", e);
@@ -719,7 +718,7 @@ fn verify(
    logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
    split_proofs = false,
    srs_path = None,
    compress_selectors=DEFAULT_COMPRESS_SELECTORS.parse().unwrap(),
    disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
))]
fn setup_aggregate(
    sample_snarks: Vec<PathBuf>,
@@ -728,7 +727,7 @@ fn setup_aggregate(
    logrows: u32,
    split_proofs: bool,
    srs_path: Option<PathBuf>,
    compress_selectors: bool,
    disable_selector_compression: bool,
) -> Result<bool, PyErr> {
    crate::execute::setup_aggregate(
        sample_snarks,
@@ -737,7 +736,7 @@ fn setup_aggregate(
        srs_path,
        logrows,
        split_proofs,
        compress_selectors,
        disable_selector_compression,
    )
    .map_err(|e| {
        let err_str = format!("Failed to setup aggregate: {}", e);
@@ -1104,6 +1103,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    m.add_class::<PyG1Affine>()?;
    m.add_class::<PyG1>()?;
    m.add_class::<PyTestDataSource>()?;
    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
    m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
    m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
    m.add_function(wrap_pyfunction!(felt_to_float, m)?)?;
@@ -673,6 +673,68 @@ impl<T: Clone + TensorType> Tensor<T> {
        Tensor::new(Some(&res), &dims)
    }

    /// Set a slice of the Tensor.
    /// ```
    /// use ezkl::tensor::Tensor;
    /// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
    /// let b = Tensor::<i32>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
    /// a.set_slice(&[1..2], &Tensor::<i32>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
    /// assert_eq!(a, b);
    /// ```
    pub fn set_slice(
        &mut self,
        indices: &[Range<usize>],
        value: &Tensor<T>,
    ) -> Result<(), TensorError>
    where
        T: Send + Sync,
    {
        if indices.is_empty() {
            return Ok(());
        }
        if self.dims.len() < indices.len() {
            return Err(TensorError::DimError(format!(
                "The dimensionality of the slice {:?} is greater than the tensor's {:?}",
                indices, self.dims
            )));
        }

        // if indices weren't specified we fill them in as required
        let mut full_indices = indices.to_vec();

        let omitted_dims = (indices.len()..self.dims.len())
            .map(|i| self.dims[i])
            .collect::<Vec<_>>();

        for dim in &omitted_dims {
            full_indices.push(0..*dim);
        }

        let full_dims = full_indices
            .iter()
            .map(|x| x.end - x.start)
            .collect::<Vec<_>>();

        // now broadcast the value to the full dims
        let value = value.expand(&full_dims)?;

        let cartesian_coord: Vec<Vec<usize>> = full_indices
            .iter()
            .cloned()
            .multi_cartesian_product()
            .collect();

        let _ = cartesian_coord
            .iter()
            .enumerate()
            .map(|(i, e)| {
                self.set(e, value[i].clone());
            })
            .collect::<Vec<_>>();

        Ok(())
    }
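    // [editor's note, not part of the diff] Trailing dimensions that are not
    // named in `indices` are filled with their full range before assignment,
    // so on a [2, 3] tensor `a.set_slice(&[1..2], &v)` behaves like
    // `a.set_slice(&[1..2, 0..3], &v)`, with `v` broadcast to shape [1, 3]
    // first (as the doctest above exercises).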
    /// Get the array index from rows / columns indices.
    ///
    /// ```
@@ -1526,18 +1588,20 @@ pub fn get_broadcasted_shape(
    let num_dims_a = shape_a.len();
    let num_dims_b = shape_b.len();

    // reewrite the below using match
    if num_dims_a == num_dims_b {
        let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
        for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
            let max_dim = dim_a.max(dim_b);
            broadcasted_shape.push(*max_dim);
    match (num_dims_a, num_dims_b) {
        (a, b) if a == b => {
            let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
            for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
                let max_dim = dim_a.max(dim_b);
                broadcasted_shape.push(*max_dim);
            }
            Ok(broadcasted_shape)
        }
        Ok(broadcasted_shape)
    } else if num_dims_a < num_dims_b {
        Ok(shape_b.to_vec())
    } else {
        Ok(shape_a.to_vec())
        (a, b) if a < b => Ok(shape_b.to_vec()),
        (a, b) if a > b => Ok(shape_a.to_vec()),
        _ => Err(Box::new(TensorError::DimError(
            "Unknown condition for broadcasting".to_string(),
        ))),
    }
}
////////////////////////
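// [editor's note, not part of the diff] The match arms preserve the old
// behaviour: equal ranks take the elementwise max of the dims, and otherwise
// the higher-rank shape wins wholesale, e.g.
//
//     get_broadcasted_shape(&[2, 3], &[2, 1])  -> Ok([2, 3])
//     get_broadcasted_shape(&[3], &[2, 1, 3])  -> Ok([2, 1, 3])
//
// (a simpler rule than full numpy-style broadcasting); the final `_` arm is
// unreachable today but turns any future logic gap into an error instead of
// a silent fallthrough.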
@@ -2,7 +2,7 @@ use super::TensorError;
use crate::tensor::{Tensor, TensorType};
use itertools::Itertools;
use maybe_rayon::{
    iter::IndexedParallelIterator, iter::IntoParallelRefMutIterator, iter::ParallelIterator,
    iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
    prelude::IntoParallelRefIterator,
};
use std::collections::{HashMap, HashSet};
@@ -1328,6 +1328,316 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
    Ok(output)
}

/// Gather ND.
/// # Arguments
/// * `input` - Tensor
/// * `index` - Tensor of indices to gather
/// * `batch_dims` - Number of batch dimensions
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::gather_nd;
/// let x = Tensor::<i128>::new(
///    Some(&[0, 1, 2, 3]),
///    &[2, 2],
/// ).unwrap();
/// let index = Tensor::<usize>::new(
///    Some(&[0, 0, 1, 1]),
///    &[2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 3]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
///    Some(&[1, 0]),
///    &[2, 1],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
///    Some(&[0, 1, 2, 3, 4, 5, 6, 7]),
///    &[2, 2, 2],
/// ).unwrap();
/// let index = Tensor::<usize>::new(
///    Some(&[0, 1, 1, 0]),
///    &[2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
///    Some(&[0, 1, 1, 0]),
///    &[2, 1, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 1, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
///    Some(&[1, 0]),
///    &[2, 1],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 1).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
///    Some(&[0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1]),
///    &[2, 2, 3],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
///    Some(&[0, 1, 0, 0, 1, 1, 1, 0]),
///    &[2, 2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1, 6, 7, 4, 5]), &[2, 2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
///    Some(&[0, 1, 0, 1, 1, 1]),
///    &[2, 3],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 7]), &[2]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn gather_nd<T: TensorType + Send + Sync>(
    input: &Tensor<T>,
    index: &Tensor<usize>,
    batch_dims: usize,
) -> Result<Tensor<T>, TensorError> {
    // Calculate the output tensor size
    let index_dims = index.dims().to_vec();
    let input_dims = input.dims().to_vec();
    let last_value = index_dims
        .last()
        .ok_or(TensorError::DimMismatch("gather_nd".to_string()))?;
    if last_value > &(input_dims.len() - batch_dims) {
        return Err(TensorError::DimMismatch("gather_nd".to_string()));
    }

    let output_size =
    // If indices_shape[-1] == r-b, since the rank of indices is q,
    // indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b,
    // where N is an integer equal to the product of 1 and all the elements in the batch dimensions of the indices_shape.
    // Let us think of each such r-b ranked tensor as indices_slice.
    // Each scalar value corresponding to data[0:b-1,indices_slice] is filled into
    // the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
    // if indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension < r-b.
    // Let us think of each such tensor as indices_slice.
    // Each tensor slice corresponding to data[0:b-1, indices_slice , :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
    {
        let output_rank = input_dims.len() + index_dims.len() - 1 - batch_dims - last_value;

        let mut dims = index_dims[..index_dims.len() - 1].to_vec();
        let input_offset = batch_dims + last_value;
        dims.extend(input_dims[input_offset..input_dims.len()].to_vec());

        assert_eq!(output_rank, dims.len());
        dims

    };

    // cartesian coord over batch dims
    let mut batch_cartesian_coord = input_dims[0..batch_dims]
        .iter()
        .map(|x| 0..*x)
        .multi_cartesian_product()
        .collect::<Vec<_>>();

    if batch_cartesian_coord.is_empty() {
        batch_cartesian_coord.push(vec![]);
    }

    let outputs = batch_cartesian_coord
        .par_iter()
        .map(|batch_coord| {
            let batch_slice = batch_coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
            let mut index_slice = index.get_slice(&batch_slice)?;
            index_slice.reshape(&index.dims()[batch_dims..])?;
            let mut input_slice = input.get_slice(&batch_slice)?;
            input_slice.reshape(&input.dims()[batch_dims..])?;

            let mut inner_cartesian_coord = index_slice.dims()[0..index_slice.dims().len() - 1]
                .iter()
                .map(|x| 0..*x)
                .multi_cartesian_product()
                .collect::<Vec<_>>();

            if inner_cartesian_coord.is_empty() {
                inner_cartesian_coord.push(vec![]);
            }

            let output = inner_cartesian_coord
                .iter()
                .map(|coord| {
                    let slice = coord
                        .iter()
                        .map(|x| *x..*x + 1)
                        .chain(batch_coord.iter().map(|x| *x..*x + 1))
                        .collect::<Vec<_>>();

                    let index_slice = index_slice
                        .get_slice(&slice)
                        .unwrap()
                        .iter()
                        .map(|x| *x..*x + 1)
                        .collect::<Vec<_>>();

                    input_slice.get_slice(&index_slice).unwrap()
                })
                .collect::<Tensor<_>>();

            output.combine()
        })
        .collect::<Result<Vec<_>, _>>()?;

    let mut outputs = outputs.into_iter().flatten().collect::<Tensor<_>>();

    outputs.reshape(&output_size)?;

    Ok(outputs)
}
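// [editor's note, not part of the diff] Shape rule implemented above, in
// ONNX GatherND terms: with r = input rank, q = index rank, b = batch_dims,
// and m = indices_shape[-1], the output rank is q + r - m - b - 1. In the
// doctest with input [2, 2, 2] (r = 3), index [2, 1, 2] (q = 3, m = 2) and
// b = 0, that gives rank 3 + 3 - 2 - 0 - 1 = 3 and shape [2, 1, 2], matching
// the expected tensor.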
/// Scatter ND.
/// This operator is the inverse of GatherND.
/// # Arguments
/// * `input` - Tensor
/// * `index` - Tensor of indices to scatter
/// * `src` - Tensor of src
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::scatter_nd;
/// let x = Tensor::<i128>::new(
///    Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
///    &[8],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
///    Some(&[4, 3, 1, 7]),
///    &[4, 1],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
///    Some(&[9, 10, 11, 12]),
///    &[4],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 11, 3, 10, 9, 6, 7, 12]), &[8]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
///    Some(&[1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
///            1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
///            8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
///            8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
///    &[4, 4, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
///    Some(&[0, 2]),
///    &[2, 1],
/// ).unwrap();
///
/// let src = Tensor::<i128>::new(
///    Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
///            1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
///    ]),
///    &[2, 4, 4],
/// ).unwrap();
///
/// let result = scatter_nd(&x, &index, &src).unwrap();
///
/// let expected = Tensor::<i128>::new(
///    Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
///            1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
///            1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
///            8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
///    &[4, 4, 4],
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
///    Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
///    &[2, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
///    Some(&[0, 1]),
///    &[2, 1],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
///    Some(&[9, 10]),
///    &[2],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[9, 9, 9, 9, 10, 10, 10, 10]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
///    Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
///    &[2, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
///    Some(&[0, 1]),
///    &[1, 1, 2],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
///    Some(&[9]),
///    &[1, 1],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 9, 3, 4, 5, 6, 7, 8]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
/// ```
///
pub fn scatter_nd<T: TensorType + Send + Sync>(
    input: &Tensor<T>,
    index: &Tensor<usize>,
    src: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
    // Calculate the output tensor size
    let index_dims = index.dims().to_vec();
    let input_dims = input.dims().to_vec();
    let last_value = index_dims
        .last()
        .ok_or(TensorError::DimMismatch("scatter_nd".to_string()))?;
    if last_value > &input_dims.len() {
        return Err(TensorError::DimMismatch("scatter_nd".to_string()));
    }

    let mut output = input.clone();

    let cartesian_coord = index_dims[0..index_dims.len() - 1]
        .iter()
        .map(|x| 0..*x)
        .multi_cartesian_product()
        .collect::<Vec<_>>();

    cartesian_coord
        .iter()
        .map(|coord| {
            let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
            let index_val = index.get_slice(&slice)?;
            let index_slice = index_val.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
            let src_val = src.get_slice(&slice)?;
            output.set_slice(&index_slice, &src_val)?;
            Ok(())
        })
        .collect::<Result<Vec<_>, _>>()?;

    Ok(output)
}
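// [editor's note, not part of the diff] The output starts as a clone of
// `input`, and each index row selects a slice of `output` to overwrite with
// the matching `src` slice via `set_slice`. Iteration over index rows is
// sequential here, so if two rows point at the same location the later row
// wins, e.g. (hypothetical) indices [[1], [1]] with src [9, 10] on a
// length-2 tensor leave 10 at position 1.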
fn axes_op<T: TensorType + Send + Sync>(
    a: &Tensor<T>,
    axes: &[usize],
@@ -3773,6 +4083,30 @@ pub mod nonlinearities {
        .unwrap()
    }

    /// Elementwise inverse.
    /// # Arguments
    /// * `out_scale` - Single value
    /// # Examples
    /// ```
    /// use ezkl::tensor::Tensor;
    /// use ezkl::tensor::ops::nonlinearities::zero_recip;
    /// let k = 2_f64;
    /// let result = zero_recip(1.0);
    /// let expected = Tensor::<i128>::new(Some(&[4503599627370496]), &[1]).unwrap();
    /// assert_eq!(result, expected);
    /// ```
    pub fn zero_recip(out_scale: f64) -> Tensor<i128> {
        let a = Tensor::<i128>::new(Some(&[0]), &[1]).unwrap();

        a.par_enum_map(|_, a_i| {
            let rescaled = a_i as f64;
            let denom = (1_f64) / (rescaled + f64::EPSILON);
            let d_inv_x = out_scale * denom;
            Ok::<_, TensorError>(d_inv_x.round() as i128)
        })
        .unwrap()
    }
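    // [editor's note, not part of the diff] The doctest constant is not
    // arbitrary: f64::EPSILON is 2^-52, so with out_scale = 1.0 the function
    // computes 1 / (0 + 2^-52) = 2^52 = 4503599627370496, i.e. the "zero
    // reciprocal" saturates to a large finite value rather than dividing by
    // zero.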
    /// Elementwise greater than
    /// # Arguments
    ///

@@ -4,6 +4,37 @@ use super::{
};
use halo2_proofs::{arithmetic::Field, plonk::Instance};

pub(crate) fn create_constant_tensor<
    F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
>(
    val: F,
    len: usize,
) -> ValTensor<F> {
    let mut constant = Tensor::from(vec![ValType::Constant(val); len].into_iter());
    constant.set_visibility(&crate::graph::Visibility::Fixed);
    ValTensor::from(constant)
}

pub(crate) fn create_unit_tensor<
    F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
>(
    len: usize,
) -> ValTensor<F> {
    let mut unit = Tensor::from(vec![ValType::Constant(F::ONE); len].into_iter());
    unit.set_visibility(&crate::graph::Visibility::Fixed);
    ValTensor::from(unit)
}

pub(crate) fn create_zero_tensor<
    F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
>(
    len: usize,
) -> ValTensor<F> {
    let mut zero = Tensor::from(vec![ValType::Constant(F::ZERO); len].into_iter());
    zero.set_visibility(&crate::graph::Visibility::Fixed);
    ValTensor::from(zero)
}
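// [editor's note, not part of the diff] `create_unit_tensor` and
// `create_zero_tensor` are fixed-visibility specialisations of
// `create_constant_tensor`: the three differ only in the constant used, so
// `create_unit_tensor(n)` is equivalent to `create_constant_tensor(F::ONE, n)`.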
#[derive(Debug, Clone)]
/// A [ValType] is a wrapper around Halo2 value(s).
pub enum ValType<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd> {
@@ -318,6 +349,19 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
        matches!(self, ValTensor::Instance { .. })
    }

    /// reverse order of elements whilst preserving the shape
    pub fn reverse(&mut self) -> Result<(), Box<dyn Error>> {
        match self {
            ValTensor::Value { inner: v, .. } => {
                v.reverse();
            }
            ValTensor::Instance { .. } => {
                return Err(Box::new(TensorError::WrongMethod));
            }
        };
        Ok(())
    }

    ///
    pub fn set_initial_instance_offset(&mut self, offset: usize) {
        if let ValTensor::Instance { initial_offset, .. } = self {
@@ -450,7 +494,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
            }
            _ => return Err(Box::new(TensorError::WrongMethod)),
        };
        Ok(integer_evals.into_iter().into())
        let mut tensor: Tensor<i128> = integer_evals.into_iter().into();
        match tensor.reshape(self.dims()) {
            _ => {}
        };

        Ok(tensor)
    }

    /// Calls `pad_to_zero_rem` on the inner tensor.
src/wasm.rs
@@ -78,6 +78,17 @@ pub fn feltToBigEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String,
    Ok(format!("{:?}", felt))
}

/// Converts a felt to a little endian string
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToLittleEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
    let felt: Fr = serde_json::from_slice(&array[..])
        .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
    let repr = serde_json::to_string(&felt).unwrap();
    let b: String = serde_json::from_str(&repr).unwrap();
    Ok(b)
}

/// Converts a hex string to a byte array
#[wasm_bindgen]
#[allow(non_snake_case)]
@@ -211,7 +222,7 @@ pub fn genWitness(
    .map_err(|e| JsError::new(&format!("{}", e)))?;

    let witness = circuit
        .forward(&mut input, None, None)
        .forward(&mut input, None, None, false)
        .map_err(|e| JsError::new(&format!("{}", e)))?;

    serde_json::to_vec(&witness)
@@ -2,6 +2,7 @@
|
||||
#[cfg(test)]
|
||||
mod native_tests {
|
||||
|
||||
use ezkl::circuit::Tolerance;
|
||||
use ezkl::fieldutils::{felt_to_i128, i128_to_felt};
|
||||
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
|
||||
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
|
||||
@@ -192,7 +193,7 @@ mod native_tests {
|
||||
"1l_tiny_div",
|
||||
];
|
||||
|
||||
const TESTS: [&str; 77] = [
|
||||
const TESTS: [&str; 79] = [
|
||||
"1l_mlp", //0
|
||||
"1l_slice",
|
||||
"1l_concat",
|
||||
@@ -274,9 +275,11 @@ mod native_tests {
|
||||
"ltsf",
|
||||
"remainder", //75
|
||||
"bitshift",
|
||||
"gather_nd",
|
||||
"scatter_nd",
|
||||
];
|
||||
|
||||
const WASM_TESTS: [&str; 48] = [
|
||||
const WASM_TESTS: [&str; 46] = [
|
||||
"1l_mlp",
|
||||
"1l_slice",
|
||||
"1l_concat",
|
||||
@@ -325,8 +328,6 @@ mod native_tests {
|
||||
"1l_where",
|
||||
"boolean",
|
||||
"boolean_identity",
|
||||
"decision_tree", // "variable_cnn",
|
||||
"random_forest",
|
||||
"gradient_boosted_trees",
|
||||
"1l_topk",
|
||||
// "xgboost",
|
||||
@@ -503,7 +504,7 @@ mod native_tests {
|
||||
}
|
||||
});
|
||||
|
||||
seq!(N in 0..=76 {
|
||||
seq!(N in 0..=78 {
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
#[ignore]
|
||||
@@ -586,15 +587,20 @@ mod native_tests {
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_large_batch_public_outputs_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let large_batch_dir = &format!("large_batches_{}", test);
|
||||
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
// currently variable output rank is not supported in ONNX
|
||||
if test != "gather_nd" {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let large_batch_dir = &format!("large_batches_{}", test);
|
||||
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
|
||||
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
@@ -841,7 +847,7 @@ mod native_tests {
|
||||
|
||||
});
|
||||
|
||||
seq!(N in 0..=47 {
|
||||
seq!(N in 0..=45 {
|
||||
|
||||
#(#[test_case(WASM_TESTS[N])])*
|
||||
fn kzg_prove_and_verify_with_overflow_(test: &str) {
|
||||
@@ -852,7 +858,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, true, "single");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testWasm");
|
||||
run_js_tests(path, test.to_string(), "testWasm", false);
|
||||
// test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -865,7 +871,7 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, None, true, "single");
|
||||
#[cfg(not(feature = "icicle"))]
|
||||
run_js_tests(path, test.to_string(), "testWasm");
|
||||
run_js_tests(path, test.to_string(), "testWasm", false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
@@ -913,6 +919,7 @@ mod native_tests {
|
||||
use crate::native_tests::kzg_fuzz;
|
||||
use tempdir::TempDir;
|
||||
use crate::native_tests::Hardfork;
|
||||
use crate::native_tests::run_js_tests;
|
||||
|
||||
/// Currently only on chain inputs that return a non-negative value are supported.
|
||||
const TESTS_ON_CHAIN_INPUT: [&str; 17] = [
|
||||
@@ -1007,8 +1014,8 @@ mod native_tests {
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
|
||||
    kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "public");
    // #[cfg(not(feature = "icicle"))]
    // run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
    #[cfg(not(feature = "icicle"))]
    run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
    test_dir.close().unwrap();
}

@@ -1020,8 +1027,8 @@ mod native_tests {
    let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
    let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
    kzg_evm_prove_and_verify_render_seperately(2, path, test.to_string(), "private", "private", "public");
    // #[cfg(not(feature = "icicle"))]
    // run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
    #[cfg(not(feature = "icicle"))]
    run_js_tests(path, test.to_string(), "testBrowserEvmVerify", true);
    test_dir.close().unwrap();
}

@@ -1034,8 +1041,8 @@ mod native_tests {
    let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
    let mut _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
    kzg_evm_prove_and_verify(2, path, test.to_string(), "hashed", "private", "private");
    // #[cfg(not(feature = "icicle"))]
    // run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
    #[cfg(not(feature = "icicle"))]
    run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
    test_dir.close().unwrap();
}

@@ -1051,8 +1058,8 @@ mod native_tests {
    let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
    let mut _anvil_child = crate::native_tests::start_anvil(false, hardfork);
    kzg_evm_prove_and_verify(2, path, test.to_string(), "kzgcommit", "private", "public");
    // #[cfg(not(feature = "icicle"))]
    // run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
    #[cfg(not(feature = "icicle"))]
    run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
    test_dir.close().unwrap();
}

@@ -1064,8 +1071,8 @@ mod native_tests {
    let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
    let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
    kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "hashed", "public");
    // #[cfg(not(feature = "icicle"))]
    // run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
    #[cfg(not(feature = "icicle"))]
    run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
    test_dir.close().unwrap();
}

@@ -1077,8 +1084,8 @@ mod native_tests {
    let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
    let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
    kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "hashed");
    // #[cfg(not(feature = "icicle"))]
    // run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
    #[cfg(not(feature = "icicle"))]
    run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
    test_dir.close().unwrap();
}

@@ -1090,8 +1097,8 @@ mod native_tests {
    let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
    let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
    kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "kzgcommit", "public");
    // #[cfg(not(feature = "icicle"))]
    // run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
    #[cfg(not(feature = "icicle"))]
    run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
    test_dir.close().unwrap();
}

@@ -1103,8 +1110,8 @@ mod native_tests {
    let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
    let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
    kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "kzgcommit");
    // #[cfg(not(feature = "icicle"))]
    // run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
    #[cfg(not(feature = "icicle"))]
    run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
    test_dir.close().unwrap();
}

@@ -1115,8 +1122,8 @@ mod native_tests {
    let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
    let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
    kzg_evm_prove_and_verify(2, path, test.to_string(), "kzgcommit", "kzgcommit", "kzgcommit");
    // #[cfg(not(feature = "icicle"))]
    // run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
    #[cfg(not(feature = "icicle"))]
    run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
    test_dir.close().unwrap();
}
@@ -1288,6 +1295,7 @@ mod native_tests {
    scales_to_use: Option<Vec<u32>>,
    tolerance: f32,
) {
    let mut tolerance = tolerance;
    gen_circuit_settings_and_witness(
        test_dir,
        example_name.clone(),
@@ -1299,16 +1307,10 @@ mod native_tests {
        scales_to_use,
        2,
        false,
        tolerance,
        &mut tolerance,
    );

    let settings =
        GraphSettings::load(&format!("{}/{}/settings.json", test_dir, example_name).into())
            .unwrap();

    let any_output_scales_smol = settings.model_output_scales.iter().any(|s| *s <= 0);

    if tolerance > 0.0 && !any_output_scales_smol {
    if tolerance > 0.0 {
        // load witness and shift the output by a small amount that is less than tolerance percent
        let witness = GraphWitness::from_path(
            format!("{}/{}/witness.json", test_dir, example_name).into(),
@@ -1333,7 +1335,7 @@ mod native_tests {
                    as i128,
            )
        };

        *v + perturbation
    })
    .collect::<Vec<_>>()
@@ -1444,7 +1446,7 @@ mod native_tests {
    scales_to_use: Option<Vec<u32>>,
    num_inner_columns: usize,
    div_rebasing: bool,
    tolerance: f32,
    tolerance: &mut f32,
) {
    let mut args = vec![
        "gen-settings".to_string(),
@@ -1502,6 +1504,24 @@ mod native_tests {
        .expect("failed to execute process");
    assert!(status.success());

    let mut settings =
        GraphSettings::load(&format!("{}/{}/settings.json", test_dir, example_name).into())
            .unwrap();

    let any_output_scales_smol = settings.model_output_scales.iter().any(|s| *s <= 0);

    if any_output_scales_smol {
        // set the tolerance to 0.0
        settings.run_args.tolerance = Tolerance {
            val: 0.0,
            scale: 0.0.into(),
        };
        settings
            .save(&format!("{}/{}/settings.json", test_dir, example_name).into())
            .unwrap();
        *tolerance = 0.0;
    }

    let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
        .args([
            "compile-circuit",
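
Aside: the hunks above thread the test tolerance through as `&mut f32` so that `gen_circuit_settings_and_witness` can zero it whenever any model output scale is <= 0; at such scales the quantization step is at least 1, a percentage tolerance is meaningless, and the caller's perturbation check is then skipped. A minimal TypeScript sketch of that rule (the function name `effectiveTolerance` is hypothetical, not part of the codebase):

// Hypothetical illustration of the rule applied in the Rust hunk above:
// a percentage tolerance is only honoured when every output scale is positive.
function effectiveTolerance(outputScales: number[], tolerance: number): number {
  return outputScales.some((s) => s <= 0) ? 0.0 : tolerance;
}

// effectiveTolerance([7, 7], 2.0) === 2.0; effectiveTolerance([0, 7], 2.0) === 0.0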
@@ -1559,7 +1579,7 @@ mod native_tests {
        None,
        2,
        div_rebasing,
        0.0,
        &mut 0.0,
    );

    println!(
@@ -1819,7 +1839,7 @@ mod native_tests {
        scales_to_use,
        num_inner_columns,
        false,
        0.0,
        &mut 0.0,
    );

    let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
@@ -1835,6 +1855,7 @@ mod native_tests {
        &format!("{}/{}/key.pk", test_dir, example_name),
        "--vk-path",
        &format!("{}/{}/key.vk", test_dir, example_name),
        "--disable-selector-compression",
    ])
    .status()
    .expect("failed to execute process");
@@ -1921,7 +1942,7 @@ mod native_tests {
        None,
        2,
        false,
        0.0,
        &mut 0.0,
    );

    let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
@@ -2164,15 +2185,17 @@ mod native_tests {
    }

    // run js browser evm verify tests for a given example
    fn run_js_tests(test_dir: &str, example_name: String, js_test: &str) {
    fn run_js_tests(test_dir: &str, example_name: String, js_test: &str, vk: bool) {
        let example = format!("--example={}", example_name);
        let dir = format!("--dir={}", test_dir);
        let mut args = vec!["run", "test", js_test, &example, &dir];
        let vk_string: String;
        if vk {
            vk_string = format!("--vk={}", vk);
            args.push(&vk_string);
        };
        let status = Command::new("pnpm")
            .args([
                "run",
                "test",
                js_test,
                &format!("--example={}", example_name),
                &format!("--dir={}", test_dir),
            ])
            .args(&args)
            .status()
            .expect("failed to execute process");
        assert!(status.success());
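
With the argument list now built dynamically, the harness appends `--vk=true` only when the caller requests the separately rendered verifying key. On the JS side the flag arrives via minimist, as in the jest tests further down; a short TypeScript sketch (variable names here are illustrative):

// Sketch: reading the flags the Rust harness passes to `pnpm run test ...`.
const argv = require('minimist')(process.argv.slice(2));

const example: string = argv['example'] || '1l_mlp';
const dir: string = argv['dir'] || '../ezkl/examples/onnx';
// minimist may surface `--vk=true` as a string or a boolean, so accept both.
const useSeparateVk: boolean = argv['vk'] === true || argv['vk'] === 'true';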
@@ -2198,7 +2221,7 @@ mod native_tests {
        Some(vec![4]),
        1,
        false,
        0.0,
        &mut 0.0,
    );

    let model_path = format!("{}/{}/network.compiled", test_dir, example_name);
@@ -91,9 +91,7 @@ def compare_outputs(zk_output, onnx_output):
            print("------- zk_output: ", list1_i)
            print("------- onnx_output: ", list2_i)

    return np.mean(np.abs(res))
    return res


if __name__ == '__main__':
@@ -113,6 +111,9 @@ if __name__ == '__main__':
    onnx_output = get_onnx_output(model_file, input_file)
    # compare the outputs
    percentage_difference = compare_outputs(ezkl_output, onnx_output)
    mean_percentage_difference = np.mean(np.abs(percentage_difference))
    max_percentage_difference = np.max(np.abs(percentage_difference))
    # print the percentage difference
    print("mean percent diff: ", percentage_difference)
    assert percentage_difference < target, "Percentage difference is too high"
    print("mean percent diff: ", mean_percentage_difference)
    print("max percent diff: ", max_percentage_difference)
    assert mean_percentage_difference < target, "Percentage difference is too high"
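
The script change above makes `compare_outputs` return the raw per-element percentage differences so the caller can report both the mean and the max before asserting on the mean. The same reduction in TypeScript, as a sketch with stand-in data:

// Stand-in for the array the updated compare_outputs returns.
const diffs: number[] = [0.1, -0.4, 0.2];
const absDiffs = diffs.map(Math.abs);
const meanPercentDiff = absDiffs.reduce((a, b) => a + b, 0) / absDiffs.length;
const maxPercentDiff = Math.max(...absDiffs);
console.assert(meanPercentDiff < 1.0, 'Percentage difference is too high');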
@@ -9,8 +9,8 @@ mod wasm32 {
    use ezkl::pfsys;
    use ezkl::wasm::{
        bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian,
        feltToFloat, feltToInt, genPk, genVk, genWitness, inputValidation, pkValidation,
        poseidonHash, proofValidation, prove, settingsValidation, srsValidation,
        feltToFloat, feltToInt, feltToLittleEndian, genPk, genVk, genWitness, inputValidation,
        pkValidation, poseidonHash, proofValidation, prove, settingsValidation, srsValidation,
        u8_array_to_u128_le, verify, vkValidation, witnessValidation,
    };
    use halo2_solidity_verifier::encode_calldata;
@@ -89,9 +89,16 @@ mod wasm32 {
            .unwrap();
        assert_eq!(integer, i as i128);

        let hex_string = format!("{:?}", field_element);
        let returned_string: String = feltToBigEndian(clamped).map_err(|_| "failed").unwrap();
        let hex_string = format!("{:?}", field_element.clone());
        let returned_string: String = feltToBigEndian(clamped.clone())
            .map_err(|_| "failed")
            .unwrap();
        assert_eq!(hex_string, returned_string);
        let repr = serde_json::to_string(&field_element).unwrap();
        let little_endian_string: String = serde_json::from_str(&repr).unwrap();
        let returned_string: String =
            feltToLittleEndian(clamped).map_err(|_| "failed").unwrap();
        assert_eq!(little_endian_string, returned_string);
    }
}
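
The wasm test now also covers `feltToLittleEndian`, asserting its output matches the field element's serde JSON representation (a little-endian hex string). From the JS side the two endianness helpers would be used roughly as follows, assuming they are exported by `@ezkljs/engine` like the other bindings in the import list above:

import { feltToBigEndian, feltToLittleEndian } from '@ezkljs/engine/nodejs';

// `felt` is a serialized field element as produced by the engine's helpers
// (e.g. one element of bufferToVecOfFelt's output); typed loosely here.
function endianPair(felt: any): { be: string; le: string } {
  return { be: feltToBigEndian(felt), le: feltToLittleEndian(felt) };
}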
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -27,6 +27,10 @@
        "check_mode": "UNSAFE"
    },
    "num_rows": 16,
    "total_dynamic_col_size": 0,
    "num_dynamic_lookups": 0,
    "num_shuffles": 0,
    "total_shuffle_col_size": 0,
    "total_assignments": 32,
    "total_const_size": 8,
    "model_instance_shapes": [
@@ -1,28 +1,42 @@
import localEVMVerify, { Hardfork } from '@ezkljs/verify'
import localEVMVerify from '../../in-browser-evm-verifier/src/index'
import { serialize, deserialize } from '@ezkljs/engine/nodejs'
import { compileContracts } from './utils'
import * as fs from 'fs'

exports.USER_NAME = require("minimist")(process.argv.slice(2))["example"];
exports.EXAMPLE = require("minimist")(process.argv.slice(2))["example"];
exports.PATH = require("minimist")(process.argv.slice(2))["dir"];
exports.VK = require("minimist")(process.argv.slice(2))["vk"];

describe('localEVMVerify', () => {

  let bytecode: string
  let bytecode_verifier: string

  let bytecode_vk: string | undefined = undefined

  let proof: any

  const example = exports.USER_NAME || "1l_mlp"
  const example = exports.EXAMPLE || "1l_mlp"
  const path = exports.PATH || "../ezkl/examples/onnx"
  const vk = exports.VK || false

  beforeEach(() => {
    let solcOutput = compileContracts(path, example)
    const solcOutput = compileContracts(path, example, 'kzg')

    bytecode =
    bytecode_verifier =
      solcOutput.contracts['artifacts/Verifier.sol']['Halo2Verifier'].evm.bytecode
        .object

    console.log('size', bytecode.length)
    if (vk) {
      const solcOutput_vk = compileContracts(path, example, 'vk')

      bytecode_vk =
        solcOutput_vk.contracts['artifacts/Verifier.sol']['Halo2VerifyingKey'].evm.bytecode
          .object

      console.log('size of verifier bytecode', bytecode_verifier.length)
    }
    console.log('verifier bytecode', bytecode_verifier)
  })

  it('should return true when verification succeeds', async () => {
@@ -30,7 +44,9 @@ describe('localEVMVerify', () => {

    proof = deserialize(proofFileBuffer)

    const result = await localEVMVerify(proofFileBuffer, bytecode)
    const result = await localEVMVerify(proofFileBuffer, bytecode_verifier, bytecode_vk)

    console.log('result', result)

    expect(result).toBe(true)
  })
@@ -39,13 +55,16 @@ describe('localEVMVerify', () => {
    let result: boolean = true
    console.log(proof.proof)
    try {
      let index = Math.floor(Math.random() * (proof.proof.length - 2)) + 2
      let number = (proof.proof[index] + 1) % 16
      let index = Math.round((Math.random() * (proof.proof.length))) % proof.proof.length
      console.log('index', index)
      console.log('index', proof.proof[index])
      let number = (proof.proof[index] + 1) % 256
      console.log('index', index)
      console.log('new number', number)
      proof.proof[index] = number
      console.log('index post', proof.proof[index])
      const proofModified = serialize(proof)
      result = await localEVMVerify(proofModified, bytecode)
      result = await localEVMVerify(proofModified, bytecode_verifier, bytecode_vk)
    } catch (error) {
      // Check if the error thrown is the "out of gas" error.
      expect(error).toEqual(
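
As the updated test shows, `localEVMVerify` now takes the verifier bytecode and, optionally, the bytecode of a separately rendered verifying-key contract. A minimal usage sketch (the paths and wrapper function are hypothetical):

import localEVMVerify from '../../in-browser-evm-verifier/src/index';
import * as fs from 'fs';

// Pass `vkBytecode` only when the verifier was rendered separately from its key.
async function verifyOnLocalEVM(
  proofPath: string,
  verifierBytecode: string,
  vkBytecode?: string,
): Promise<boolean> {
  const proof = fs.readFileSync(proofPath);
  return localEVMVerify(proof, verifierBytecode, vkBytecode);
}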
@@ -38,7 +38,10 @@ describe('Generate witness, prove and verify', () => {
    let pk = await readEzklArtifactsFile(path, example, 'key.pk');
    let circuit_ser = await readEzklArtifactsFile(path, example, 'network.compiled');
    circuit_settings_ser = await readEzklArtifactsFile(path, example, 'settings.json');
    params_ser = await readEzklSrsFile(path, example);
    // get the log rows from the circuit settings
    const circuit_settings = deserialize(circuit_settings_ser) as any;
    const logrows = circuit_settings.run_args.logrows as string;
    params_ser = await readEzklSrsFile(logrows);
    const startTimeProve = Date.now();
    result = wasmFunctions.prove(witness, pk, circuit_ser, params_ser);
    const endTimeProve = Date.now();
@@ -54,6 +57,7 @@ describe('Generate witness, prove and verify', () => {
    let result
    const vk = await readEzklArtifactsFile(path, example, 'key.vk');
    const startTimeVerify = Date.now();
    params_ser = await readEzklSrsFile("1");
    result = wasmFunctions.verify(proof_ser, vk, circuit_settings_ser, params_ser);
    const result_ref = wasmFunctions.verify(proof_ser_ref, vk, circuit_settings_ser, params_ser);
    const endTimeVerify = Date.now();
@@ -16,15 +16,7 @@ export async function readEzklArtifactsFile(path: string, example: string, filen
  return new Uint8ClampedArray(buffer.buffer);
}

export async function readEzklSrsFile(path: string, example: string): Promise<Uint8ClampedArray> {
  // const settingsPath = path.join(__dirname, '..', '..', 'ezkl', 'examples', 'onnx', example, 'settings.json');

  const settingsPath = `${path}/${example}/settings.json`
  const settingsBuffer = await fs.readFile(settingsPath, { encoding: 'utf-8' });
  const settings = JSONBig.parse(settingsBuffer);
  const logrows = settings.run_args.logrows;
  // const filePath = path.join(__dirname, '..', '..', 'ezkl', 'examples', 'onnx', `kzg${logrows}.srs`);
// srs path is at $HOME/.ezkl/srs
export async function readEzklSrsFile(logrows: string): Promise<Uint8ClampedArray> {
  const filePath = `${userHomeDir}/.ezkl/srs/kzg${logrows}.srs`
  const buffer = await fs.readFile(filePath);
  return new Uint8ClampedArray(buffer.buffer);
@@ -51,21 +43,21 @@ export function serialize(data: object | string): Uint8ClampedArray { // data is
  return new Uint8ClampedArray(uint8Array.buffer);
}

export function getSolcInput(path: string, example: string) {
export function getSolcInput(path: string, example: string, name: string) {
  return {
    language: 'Solidity',
    sources: {
      'artifacts/Verifier.sol': {
        content: fsSync.readFileSync(`${path}/${example}/kzg.sol`, 'utf-8'),
        content: fsSync.readFileSync(`${path}/${example}/${name}.sol`, 'utf-8'),
      },
      // If more contracts were to be compiled, they should have their own entries here
    },
    settings: {
      optimizer: {
        enabled: true,
        runs: 200,
        runs: 1,
      },
      evmVersion: 'london',
      evmVersion: 'shanghai',
      outputSelection: {
        '*': {
          '*': ['abi', 'evm.bytecode'],
@@ -75,8 +67,8 @@ export function getSolcInput(path: string, example: string) {
  }
}

export function compileContracts(path: string, example: string) {
  const input = getSolcInput(path, example)
export function compileContracts(path: string, example: string, name: string) {
  const input = getSolcInput(path, example, name)
  const output = JSON.parse(solc.compile(JSON.stringify(input)))

  let compilationFailed = false
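
`getSolcInput` and `compileContracts` now take the contract file's base name, so one helper compiles both `kzg.sol` and, when it exists, the separately rendered `vk.sol`. Usage following the tests above (the example name and path are placeholders):

import { compileContracts } from './utils';

const verifierOut = compileContracts('../ezkl/examples/onnx', '1l_mlp', 'kzg');
const verifierBytecode =
  verifierOut.contracts['artifacts/Verifier.sol']['Halo2Verifier'].evm.bytecode.object;

// Only when the verifying key was rendered as its own contract:
const vkOut = compileContracts('../ezkl/examples/onnx', '1l_mlp', 'vk');
const vkBytecode =
  vkOut.contracts['artifacts/Verifier.sol']['Halo2VerifyingKey'].evm.bytecode.object;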
Binary file not shown.
@@ -1 +1 @@
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1,"max_range_check":0,"min_range_check":0}
{"inputs":[["0200000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000","0100000000000000000000000000000000000000000000000000000000000000"]],"pretty_elements":{"rescaled_inputs":[["2","1","1"]],"inputs":[["0x0000000000000000000000000000000000000000000000000000000000000002","0x0000000000000000000000000000000000000000000000000000000000000001","0x0000000000000000000000000000000000000000000000000000000000000001"]],"processed_inputs":[],"processed_params":[],"processed_outputs":[],"rescaled_outputs":[["0","0","0","0"]],"outputs":[["0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000","0x0000000000000000000000000000000000000000000000000000000000000000"]]},"outputs":[["0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000","0000000000000000000000000000000000000000000000000000000000000000"]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":0,"min_lookup_inputs":-1,"max_range_size":0}
verifier_abi.json (new file)
@@ -0,0 +1 @@
[{"type":"function","name":"verifyProof","inputs":[{"internalType":"bytes","name":"proof","type":"bytes"},{"internalType":"uint256[]","name":"instances","type":"uint256[]"}],"outputs":[{"internalType":"bool","name":"","type":"bool"}],"stateMutability":"nonpayable"}]