Compare commits

...

12 Commits

Author SHA1 Message Date
Ethan Cemer
2be181db35 feat: merge @ezkljs/verify package into core repo. (#736) 2024-03-14 01:13:14 +00:00
jmjac
de9e3f2673 Add __version__ to python bindings (#739) 2024-03-13 14:22:20 +00:00
dante
a1450f8df7 feat: gather_nd/scatter_nd support (#737) 2024-03-11 22:05:40 +00:00
dante
ea535e2ecd refactor: use linear index constraints for gather and scatter (#735) 2024-03-09 18:00:21 +00:00
Alexander Camuto
f8aa91ed08 fix: windows compile 2024-03-06 11:40:44 +00:00
dante
a59e3780b2 chore: rm recip_int helper (#733) 2024-03-05 21:51:14 +00:00
dante
345fb5672a chore: cleanup unused args (#732) 2024-03-05 13:43:29 +00:00
dante
70daaff2e4 chore: cleanup calibrate (#731) 2024-03-04 17:52:11 +00:00
dante
a437d8a51f feat: "sub"-dynamic tables (#730) 2024-03-04 10:35:28 +00:00
Ethan Cemer
fe535c1cac feat: wasm felt to little endian string (#729)
---------

Co-authored-by: Alexander Camuto <45801863+alexander-camuto@users.noreply.github.com>
2024-03-01 14:06:20 +00:00
dante
3e8dcb001a chore: test for reduced-srs on wasm bundle (#728)
---------

Co-authored-by: Ethan <tylercemer@gmail.com>
2024-03-01 13:23:07 +00:00
dante
14786acb95 feat: dynamic lookups (#727) 2024-03-01 01:44:45 +00:00
53 changed files with 4699 additions and 1232 deletions

View File

@@ -1,4 +1,4 @@
name: Build and Publish WASM<>JS Bindings
name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
on:
workflow_dispatch:
@@ -14,7 +14,7 @@ defaults:
run:
working-directory: .
jobs:
wasm-publish:
publish-wasm-bindings:
name: publish-wasm-bindings
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
@@ -174,3 +174,40 @@ jobs:
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
in-browser-evm-ver-publish:
name: publish-in-browser-evm-verifier-package
needs: ["publish-wasm-bindings"]
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update @ezkljs/engine version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
run: |
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
- name: Set up Node.js
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
- name: Publish to npm
run: |
cd in-browser-evm-verifier
npm install
npm run build
npm ci
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

View File

@@ -303,12 +303,24 @@ jobs:
with:
node-version: "18.12.1"
cache: "pnpm"
- name: Install dependencies
- name: Install dependencies for js tests and in-browser-evm-verifier package
run: |
pnpm install --no-frozen-lockfile
pnpm install --dir ./in-browser-evm-verifier --no-frozen-lockfile
env:
CI: false
NODE_ENV: development
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --release --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" in-browser-evm-verifier/nodejs/ezkl.js
- name: Build @ezkljs/verify package
run: |
cd in-browser-evm-verifier
pnpm build:commonjs
cd ..
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
@@ -364,7 +376,7 @@ jobs:
with:
node-version: "18.12.1"
cache: "pnpm"
- name: Install dependencies
- name: Install dependencies for js tests
run: |
pnpm install --no-frozen-lockfile
env:
@@ -575,7 +587,7 @@ jobs:
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
- name: Build python ezkl
run: source .env/bin/activate; maturin develop --features python-bindings --release
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Run pytest
run: source .env/bin/activate; pytest -vv
@@ -599,7 +611,7 @@ jobs:
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; maturin develop --features python-bindings --release
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Div rebase
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_div_rebase_
- name: Public inputs
@@ -634,7 +646,7 @@ jobs:
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; maturin develop --features python-bindings --release
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
# - name: authenticate-kaggle-cli
# shell: bash
# env:

1
.gitignore vendored
View File

@@ -45,6 +45,7 @@ var/
*.whl
*.bak
node_modules
/dist
timingData.json
!tests/wasm/pk.key
!tests/wasm/vk.key

View File

@@ -74,6 +74,10 @@ For more details visit the [docs](https://docs.ezkl.xyz).
Build the auto-generated rust documentation and open the docs in your browser locally. `cargo doc --open`
#### In-browser EVM verifier
As an alternative to running the native Halo2 verifier as a WASM binding in the browser, you can use the in-browser EVM verifier. The source code of which you can find in the `in-browser-evm-verifier` directory and a README with instructions on how to use it.
### building the project 🔨

View File

@@ -343,7 +343,6 @@
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" compress_selectors=True,\n",
" )\n",
"\n",
" assert res == True\n",

View File

@@ -0,0 +1,48 @@
from torch import nn
import json
import numpy as np
import tf2onnx
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
# gather_nd in tf then export to onnx
x = in1 = Input((15, 18,))
w = in2 = Input((15, 1), dtype=tf.int32)
x = tf.gather_nd(x, w, batch_dims=1)
tm = Model((in1, in2), x )
tm.summary()
tm.compile(optimizer='adam', loss='mse')
shape = [1, 15, 18]
index_shape = [1, 15, 1]
# After training, export to onnx (network.onnx) and create a data file (input.json)
x = 0.1*np.random.rand(1,*shape)
# w = random int tensor
w = np.random.randint(0, 10, index_shape)
spec = tf.TensorSpec(shape, tf.float32, name='input_0')
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')
model_path = "network.onnx"
tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)
d = x.reshape([-1]).tolist()
d1 = w.reshape([-1]).tolist()
data = dict(
input_data=[d, d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import sys
import json
sys.path.append("..")
class Model(nn.Module):
"""
Just one Linear layer
"""
def __init__(self, configs):
super(Model, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
# Use this line if you want to visualize the weights
# self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
self.channels = configs.enc_in
self.individual = configs.individual
if self.individual:
self.Linear = nn.ModuleList()
for i in range(self.channels):
self.Linear.append(nn.Linear(self.seq_len,self.pred_len))
else:
self.Linear = nn.Linear(self.seq_len, self.pred_len)
def forward(self, x):
# x: [Batch, Input length, Channel]
if self.individual:
output = torch.zeros([x.size(0),self.pred_len,x.size(2)],dtype=x.dtype).to(x.device)
for i in range(self.channels):
output[:,:,i] = self.Linear[i](x[:,:,i])
x = output
else:
x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
return x # [Batch, Output length, Channel]
class Configs:
def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
self.seq_len = seq_len
self.pred_len = pred_len
self.enc_in = enc_in
self.individual = individual
model = 'Linear'
seq_len = 10
pred_len = 4
enc_in = 3
configs = Configs(seq_len, pred_len, enc_in, True)
circuit = Model(configs)
x = torch.randn(1, seq_len, pred_len)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=15, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
# the model's input names
input_names=['input'],
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}

Binary file not shown.

View File

@@ -0,0 +1,60 @@
# inbrowser-evm-verify
We would like the Solidity verifier to be canonical and usually all you ever need. For this, we need to be able to run that verifier in browser.
## How to use (Node js)
```ts
import localEVMVerify from '@ezkljs/verify';
// Load in the proof file as a buffer
const proofFileBuffer = fs.readFileSync(`${path}/${example}/proof.pf`)
// Stringified EZKL evm verifier bytecode (this is just an example don't use in production)
const bytecode = '0x608060405234801561001057600080fd5b5060d38061001f6000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c8063cfae321714610046575b600080fd5b6100496100f1565b60405161005691906100f1565b60405180910390f35b'
const result = await localEVMVerify(proofFileBuffer, bytecode)
console.log('result', result)
```
**Note**: Run `ezkl create-evm-verifier` to get the Solidity verifier, with which you can retrieve the bytecode once compiled. We recommend compiling to the Shanghai hardfork target, else you will have to pass an additional parameter specifying the EVM version to the `localEVMVerify` function like so (for Paris hardfork):
```ts
import localEVMVerify, { hardfork } from '@ezkljs/verify';
const result = await localEVMVerify(proofFileBuffer, bytecode, hardfork['Paris'])
```
**Note**: You can also verify separated vk verifiers using the `localEVMVerify` function. Just pass the vk verifier bytecode as the third parameter like so:
```ts
import localEVMVerify from '@ezkljs/verify';
const result = await localEVMVerify(proofFileBuffer, verifierBytecode, VKBytecode)
```
## How to use (Browser)
```ts
import localEVMVerify from '@ezkljs/verify';
// Load in the proof file as a buffer using the web apis (fetch, FileReader, etc)
// We use fetch in this example to load the proof file as a buffer
const proofFileBuffer = await fetch(`${path}/${example}/proof.pf`).then(res => res.arrayBuffer())
// Stringified EZKL evm verifier bytecode (this is just an example don't use in production)
const bytecode = '0x608060405234801561001057600080fd5b5060d38061001f6000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c8063cfae321714610046575b600080fd5b6100496100f1565b60405161005691906100f1565b60405180910390f35b'
const result = await browserEVMVerify(proofFileBuffer, bytecode)
console.log('result', result)
```
Output:
```ts
result: true
```

View File

@@ -0,0 +1,42 @@
{
"name": "@ezkljs/verify",
"version": "0.0.0",
"publishConfig": {
"access": "public"
},
"description": "Evm verify EZKL proofs in the browser.",
"main": "dist/commonjs/index.js",
"module": "dist/esm/index.js",
"types": "dist/commonjs/index.d.ts",
"files": [
"dist",
"LICENSE",
"README.md"
],
"scripts": {
"clean": "rm -r dist || true",
"build:commonjs": "tsc --project tsconfig.commonjs.json && resolve-tspaths -p tsconfig.commonjs.json",
"build:esm": "tsc --project tsconfig.esm.json && resolve-tspaths -p tsconfig.esm.json",
"build": "pnpm run clean && pnpm run build:commonjs && pnpm run build:esm"
},
"dependencies": {
"@ethereumjs/common": "^4.0.0",
"@ethereumjs/evm": "^2.0.0",
"@ethereumjs/statemanager": "^2.0.0",
"@ethereumjs/tx": "^5.0.0",
"@ethereumjs/util": "^9.0.0",
"@ethereumjs/vm": "^7.0.0",
"@ethersproject/abi": "^5.7.0",
"@ezkljs/engine": "^9.4.4",
"ethers": "^6.7.1",
"json-bigint": "^1.0.0"
},
"devDependencies": {
"@types/node": "^20.8.3",
"ts-loader": "^9.5.0",
"ts-node": "^10.9.1",
"resolve-tspaths": "^0.8.16",
"tsconfig-paths": "^4.2.0",
"typescript": "^5.2.2"
}
}

1479
in-browser-evm-verifier/pnpm-lock.yaml generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,145 @@
import { defaultAbiCoder as AbiCoder } from '@ethersproject/abi'
import { Address, hexToBytes } from '@ethereumjs/util'
import { Chain, Common, Hardfork } from '@ethereumjs/common'
import { LegacyTransaction, LegacyTxData } from '@ethereumjs/tx'
// import { DefaultStateManager } from '@ethereumjs/statemanager'
// import { Blockchain } from '@ethereumjs/blockchain'
import { VM } from '@ethereumjs/vm'
import { EVM } from '@ethereumjs/evm'
import { buildTransaction, encodeDeployment } from './utils/tx-builder'
import { getAccountNonce, insertAccount } from './utils/account-utils'
import { encodeVerifierCalldata } from '../nodejs/ezkl';
import { error } from 'console'
async function deployContract(
vm: VM,
common: Common,
senderPrivateKey: Uint8Array,
deploymentBytecode: string
): Promise<Address> {
// Contracts are deployed by sending their deployment bytecode to the address 0
// The contract params should be abi-encoded and appended to the deployment bytecode.
// const data =
const data = encodeDeployment(deploymentBytecode)
const txData = {
data,
nonce: await getAccountNonce(vm, senderPrivateKey),
}
const tx = LegacyTransaction.fromTxData(
buildTransaction(txData) as LegacyTxData,
{ common, allowUnlimitedInitCodeSize: true },
).sign(senderPrivateKey)
const deploymentResult = await vm.runTx({
tx,
skipBlockGasLimitValidation: true,
skipNonce: true
})
if (deploymentResult.execResult.exceptionError) {
throw deploymentResult.execResult.exceptionError
}
return deploymentResult.createdAddress!
}
async function verify(
vm: VM,
contractAddress: Address,
caller: Address,
proof: Uint8Array | Uint8ClampedArray,
vkAddress?: Address | Uint8Array,
): Promise<boolean> {
if (proof instanceof Uint8Array) {
proof = new Uint8ClampedArray(proof.buffer)
}
if (vkAddress) {
const vkAddressBytes = hexToBytes(vkAddress.toString())
const vkAddressArray = Array.from(vkAddressBytes)
let string = JSON.stringify(vkAddressArray)
const uint8Array = new TextEncoder().encode(string);
// Step 3: Convert to Uint8ClampedArray
vkAddress = new Uint8Array(uint8Array.buffer);
// convert uitn8array of length
error('vkAddress', vkAddress)
}
const data = encodeVerifierCalldata(proof, vkAddress)
const verifyResult = await vm.evm.runCall({
to: contractAddress,
caller: caller,
origin: caller, // The tx.origin is also the caller here
data: data,
})
if (verifyResult.execResult.exceptionError) {
throw verifyResult.execResult.exceptionError
}
const results = AbiCoder.decode(['bool'], verifyResult.execResult.returnValue)
return results[0]
}
/**
* Spins up an ephemeral EVM instance for executing the bytecode of a solidity verifier
* @param proof Json serialized proof file
* @param bytecode The bytecode of a compiled solidity verifier.
* @param bytecode_vk The bytecode of a contract that stores the vk. (Optional, only required if the vk is stored in a separate contract)
* @param evmVersion The evm version to use for the verification. (Default: London)
* @returns The result of the evm verification.
* @throws If the verify transaction reverts
*/
export default async function localEVMVerify(
proof: Uint8Array | Uint8ClampedArray,
bytecode_verifier: string,
bytecode_vk?: string,
evmVersion?: Hardfork,
): Promise<boolean> {
try {
const hardfork = evmVersion ? evmVersion : Hardfork['Shanghai']
const common = new Common({ chain: Chain.Mainnet, hardfork })
const accountPk = hexToBytes(
'0xe331b6d69882b4cb4ea581d88e0b604039a3de5967688d3dcffdd2270c0fd109', // anvil deterministic Pk
)
const evm = new EVM({
allowUnlimitedContractSize: true,
allowUnlimitedInitCodeSize: true,
})
const vm = await VM.create({ common, evm })
const accountAddress = Address.fromPrivateKey(accountPk)
await insertAccount(vm, accountAddress)
const verifierAddress = await deployContract(
vm,
common,
accountPk,
bytecode_verifier
)
if (bytecode_vk) {
const accountPk = hexToBytes("0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80"); // anvil deterministic Pk
const accountAddress = Address.fromPrivateKey(accountPk)
await insertAccount(vm, accountAddress)
const output = await deployContract(vm, common, accountPk, bytecode_vk)
const result = await verify(vm, verifierAddress, accountAddress, proof, output)
return true
}
const result = await verify(vm, verifierAddress, accountAddress, proof)
return result
} catch (error) {
// log or re-throw the error, depending on your needs
console.error('An error occurred:', error)
throw error
}
}

View File

@@ -0,0 +1,32 @@
import { VM } from '@ethereumjs/vm'
import { Account, Address } from '@ethereumjs/util'
export const keyPair = {
secretKey:
'0x3cd7232cd6f3fc66a57a6bedc1a8ed6c228fff0a327e169c2bcc5e869ed49511',
publicKey:
'0x0406cc661590d48ee972944b35ad13ff03c7876eae3fd191e8a2f77311b0a3c6613407b5005e63d7d8d76b89d5f900cde691497688bb281e07a5052ff61edebdc0',
}
export const insertAccount = async (vm: VM, address: Address) => {
const acctData = {
nonce: 0,
balance: BigInt('1000000000000000000'), // 1 eth
}
const account = Account.fromAccountData(acctData)
await vm.stateManager.putAccount(address, account)
}
export const getAccountNonce = async (
vm: VM,
accountPrivateKey: Uint8Array,
) => {
const address = Address.fromPrivateKey(accountPrivateKey)
const account = await vm.stateManager.getAccount(address)
if (account) {
return account.nonce
} else {
return BigInt(0)
}
}

View File

@@ -0,0 +1,59 @@
import { Interface, defaultAbiCoder as AbiCoder } from '@ethersproject/abi'
import {
AccessListEIP2930TxData,
FeeMarketEIP1559TxData,
TxData,
} from '@ethereumjs/tx'
type TransactionsData =
| TxData
| AccessListEIP2930TxData
| FeeMarketEIP1559TxData
export const encodeFunction = (
method: string,
params?: {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
types: any[]
values: unknown[]
},
): string => {
const parameters = params?.types ?? []
const methodWithParameters = `function ${method}(${parameters.join(',')})`
const signatureHash = new Interface([methodWithParameters]).getSighash(method)
const encodedArgs = AbiCoder.encode(parameters, params?.values ?? [])
return signatureHash + encodedArgs.slice(2)
}
export const encodeDeployment = (
bytecode: string,
params?: {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
types: any[]
values: unknown[]
},
) => {
const deploymentData = '0x' + bytecode
if (params) {
const argumentsEncoded = AbiCoder.encode(params.types, params.values)
return deploymentData + argumentsEncoded.slice(2)
}
return deploymentData
}
export const buildTransaction = (
data: Partial<TransactionsData>,
): TransactionsData => {
const defaultData: Partial<TransactionsData> = {
gasLimit: 3_000_000_000_000_000,
gasPrice: 7,
value: 0,
data: '0x',
}
return {
...defaultData,
...data,
}
}

View File

@@ -0,0 +1,7 @@
{
"extends": "./tsconfig.json",
"compilerOptions": {
"module": "CommonJS",
"outDir": "./dist/commonjs"
}
}

View File

@@ -0,0 +1,7 @@
{
"extends": "./tsconfig.json",
"compilerOptions": {
"module": "ES2020",
"outDir": "./dist/esm"
}
}

View File

@@ -0,0 +1,62 @@
{
"compilerOptions": {
"rootDir": "src",
"target": "es2017",
"outDir": "dist",
"declaration": true,
"lib": [
"dom",
"dom.iterable",
"esnext"
],
"allowJs": true,
"checkJs": true,
"skipLibCheck": true,
"strict": true,
"forceConsistentCasingInFileNames": true,
"noEmit": false,
"esModuleInterop": true,
"module": "CommonJS",
"moduleResolution": "node",
"resolveJsonModule": true,
"isolatedModules": true,
"jsx": "preserve",
// "incremental": true,
"noUncheckedIndexedAccess": true,
"baseUrl": ".",
"paths": {
"@/*": [
"./src/*"
]
}
},
"include": [
"src/**/*.ts",
"src/**/*.tsx",
"src/**/*.cjs",
"src/**/*.mjs"
],
"exclude": [
"node_modules"
],
// NEW: Options for file/directory watching
"watchOptions": {
// Use native file system events for files and directories
"watchFile": "useFsEvents",
"watchDirectory": "useFsEvents",
// Poll files for updates more frequently
// when they're updated a lot.
"fallbackPolling": "dynamicPriority",
// Don't coalesce watch notification
"synchronousWatchDirectory": true,
// Finally, two additional settings for reducing the amount of possible
// files to track work from these directories
"excludeDirectories": [
"**/node_modules",
"_build"
],
"excludeFiles": [
"build/fileWhichChangesOften.ts"
]
}
}

View File

@@ -7,7 +7,7 @@
"test": "jest"
},
"devDependencies": {
"@ezkljs/engine": "^2.4.5",
"@ezkljs/engine": "^9.4.4",
"@ezkljs/verify": "^0.0.6",
"@jest/types": "^29.6.3",
"@types/file-saver": "^2.0.5",
@@ -27,4 +27,4 @@
"tsconfig-paths": "^4.2.0",
"typescript": "5.1.6"
}
}
}

11
pnpm-lock.yaml generated
View File

@@ -6,8 +6,8 @@ settings:
devDependencies:
'@ezkljs/engine':
specifier: ^2.4.5
version: 2.4.5
specifier: ^9.4.4
version: 9.4.4
'@ezkljs/verify':
specifier: ^0.0.6
version: 0.0.6(buffer@6.0.3)
@@ -785,6 +785,13 @@ packages:
json-bigint: 1.0.0
dev: true
/@ezkljs/engine@9.4.4:
resolution: {integrity: sha512-kNsTmDQa8mIiQ6yjJmBMwVgAAxh4nfs4NCtnewJifonyA8Mfhs+teXwwW8WhERRDoQPUofKO2pT8BPvV/XGIDA==}
dependencies:
'@types/json-bigint': 1.0.1
json-bigint: 1.0.0
dev: true
/@ezkljs/verify@0.0.6(buffer@6.0.3):
resolution: {integrity: sha512-9DHoEhLKl1DBGuUVseXLThuMyYceY08Zymr/OsLH0zbdA9OoISYhb77j4QPm4ANRKEm5dCi8oHDqkwGbFc2xFQ==}
dependencies:

View File

@@ -12,15 +12,11 @@ pub enum BaseOp {
DotInit,
CumProdInit,
CumProd,
Identity,
Add,
Mult,
Sub,
SumInit,
Sum,
Neg,
Range { tol: i32 },
IsZero,
IsBoolean,
}
@@ -36,12 +32,8 @@ impl BaseOp {
let (a, b) = inputs;
match &self {
BaseOp::Add => a + b,
BaseOp::Identity => b,
BaseOp::Neg => -b,
BaseOp::Sub => a - b,
BaseOp::Mult => a * b,
BaseOp::Range { .. } => b,
BaseOp::IsZero => b,
BaseOp::IsBoolean => b,
_ => panic!("nonaccum_f called on accumulating operation"),
}
@@ -73,19 +65,15 @@ impl BaseOp {
/// display func
pub fn as_str(&self) -> &'static str {
match self {
BaseOp::Identity => "IDENTITY",
BaseOp::Dot => "DOT",
BaseOp::DotInit => "DOTINIT",
BaseOp::CumProdInit => "CUMPRODINIT",
BaseOp::CumProd => "CUMPROD",
BaseOp::Add => "ADD",
BaseOp::Neg => "NEG",
BaseOp::Sub => "SUB",
BaseOp::Mult => "MULT",
BaseOp::Sum => "SUM",
BaseOp::SumInit => "SUMINIT",
BaseOp::Range { .. } => "RANGE",
BaseOp::IsZero => "ISZERO",
BaseOp::IsBoolean => "ISBOOLEAN",
}
}
@@ -93,8 +81,6 @@ impl BaseOp {
/// Returns the range of the query offset for this operation.
pub fn query_offset_rng(&self) -> (i32, usize) {
match self {
BaseOp::Identity => (0, 1),
BaseOp::Neg => (0, 1),
BaseOp::DotInit => (0, 1),
BaseOp::Dot => (-1, 2),
BaseOp::CumProd => (-1, 2),
@@ -104,8 +90,6 @@ impl BaseOp {
BaseOp::Mult => (0, 1),
BaseOp::Sum => (-1, 2),
BaseOp::SumInit => (0, 1),
BaseOp::Range { .. } => (0, 1),
BaseOp::IsZero => (0, 1),
BaseOp::IsBoolean => (0, 1),
}
}
@@ -113,8 +97,6 @@ impl BaseOp {
/// Returns the number of inputs for this operation.
pub fn num_inputs(&self) -> usize {
match self {
BaseOp::Identity => 1,
BaseOp::Neg => 1,
BaseOp::DotInit => 2,
BaseOp::Dot => 2,
BaseOp::CumProdInit => 1,
@@ -124,8 +106,6 @@ impl BaseOp {
BaseOp::Mult => 2,
BaseOp::Sum => 1,
BaseOp::SumInit => 1,
BaseOp::Range { .. } => 1,
BaseOp::IsZero => 0,
BaseOp::IsBoolean => 0,
}
}
@@ -133,19 +113,15 @@ impl BaseOp {
/// Returns the number of outputs for this operation.
pub fn constraint_idx(&self) -> usize {
match self {
BaseOp::Identity => 0,
BaseOp::Neg => 0,
BaseOp::DotInit => 0,
BaseOp::Dot => 1,
BaseOp::Add => 0,
BaseOp::Sub => 0,
BaseOp::Mult => 0,
BaseOp::Range { .. } => 0,
BaseOp::Sum => 1,
BaseOp::SumInit => 0,
BaseOp::CumProd => 1,
BaseOp::CumProdInit => 0,
BaseOp::IsZero => 0,
BaseOp::IsBoolean => 0,
}
}

View File

@@ -188,31 +188,158 @@ impl<'source> FromPyObject<'source> for Tolerance {
}
}
/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct DynamicLookups {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub lookup_selectors: BTreeMap<(usize, usize), Selector>,
/// Selectors for the dynamic lookup tables
pub table_selectors: Vec<Selector>,
/// Inputs:
pub inputs: Vec<VarTensor>,
/// tables
pub tables: Vec<VarTensor>,
}
impl DynamicLookups {
/// Returns a new [DynamicLookups] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
let single_col_dummy_var = VarTensor::dummy(col_size, 1);
Self {
lookup_selectors: BTreeMap::new(),
table_selectors: vec![],
inputs: vec![dummy_var.clone(), dummy_var.clone(), dummy_var.clone()],
tables: vec![
single_col_dummy_var.clone(),
single_col_dummy_var.clone(),
single_col_dummy_var.clone(),
],
}
}
}
/// A struct representing the selectors for the dynamic lookup tables
#[derive(Clone, Debug, Default)]
pub struct Shuffles {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub input_selectors: BTreeMap<(usize, usize), Selector>,
/// Selectors for the dynamic lookup tables
pub reference_selectors: Vec<Selector>,
/// Inputs:
pub inputs: Vec<VarTensor>,
/// tables
pub references: Vec<VarTensor>,
}
impl Shuffles {
/// Returns a new [DynamicLookups] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
let single_col_dummy_var = VarTensor::dummy(col_size, 1);
Self {
input_selectors: BTreeMap::new(),
reference_selectors: vec![],
inputs: vec![dummy_var.clone(), dummy_var.clone()],
references: vec![single_col_dummy_var.clone(), single_col_dummy_var.clone()],
}
}
}
/// A struct representing the selectors for the static lookup tables
#[derive(Clone, Debug, Default)]
pub struct StaticLookups<F: PrimeField + TensorType + PartialOrd> {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub selectors: BTreeMap<(LookupOp, usize, usize), Selector>,
/// Selectors for the dynamic lookup tables
pub tables: BTreeMap<LookupOp, Table<F>>,
///
pub index: VarTensor,
///
pub output: VarTensor,
///
pub input: VarTensor,
}
impl<F: PrimeField + TensorType + PartialOrd> StaticLookups<F> {
/// Returns a new [StaticLookups] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
Self {
selectors: BTreeMap::new(),
tables: BTreeMap::new(),
index: dummy_var.clone(),
output: dummy_var.clone(),
input: dummy_var,
}
}
}
/// A struct representing the selectors for custom gates
#[derive(Clone, Debug, Default)]
pub struct CustomGates {
/// the inputs to the accumulated operations.
pub inputs: Vec<VarTensor>,
/// the (currently singular) output of the accumulated operations.
pub output: VarTensor,
/// selector
pub selectors: BTreeMap<(BaseOp, usize, usize), Selector>,
}
impl CustomGates {
/// Returns a new [CustomGates] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
Self {
inputs: vec![dummy_var.clone(), dummy_var.clone()],
output: dummy_var,
selectors: BTreeMap::new(),
}
}
}
/// A struct representing the selectors for the range checks
#[derive(Clone, Debug, Default)]
pub struct RangeChecks<F: PrimeField + TensorType + PartialOrd> {
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many dynamic lookup ops.
pub selectors: BTreeMap<(Range, usize, usize), Selector>,
/// Selectors for the dynamic lookup tables
pub ranges: BTreeMap<Range, RangeCheck<F>>,
///
pub index: VarTensor,
///
pub input: VarTensor,
}
impl<F: PrimeField + TensorType + PartialOrd> RangeChecks<F> {
/// Returns a new [RangeChecks] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
Self {
selectors: BTreeMap::new(),
ranges: BTreeMap::new(),
index: dummy_var.clone(),
input: dummy_var,
}
}
}
/// Configuration for an accumulated arg.
#[derive(Clone, Debug, Default)]
pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
/// the inputs to the accumulated operations.
pub inputs: Vec<VarTensor>,
/// the VarTensor reserved for lookup operations (could be an element of inputs)
/// Note that you should be careful to ensure that the lookup_input is not simultaneously assigned to by other non-lookup operations eg. in the case of composite ops.
pub lookup_input: VarTensor,
/// the (currently singular) output of the accumulated operations.
pub output: VarTensor,
/// the VarTensor reserved for lookup operations (could be an element of inputs or the same as output)
/// Note that you should be careful to ensure that the lookup_output is not simultaneously assigned to by other non-lookup operations eg. in the case of composite ops.
pub lookup_output: VarTensor,
///
pub lookup_index: VarTensor,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure [BaseOp].
pub selectors: BTreeMap<(BaseOp, usize, usize), Selector>,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
pub lookup_selectors: BTreeMap<(LookupOp, usize, usize), Selector>,
///
pub tables: BTreeMap<LookupOp, Table<F>>,
///
pub range_checks: BTreeMap<Range, RangeCheck<F>>,
/// [Selector]s generated when configuring the layer. We use a [BTreeMap] as we expect to configure many lookup ops.
pub range_check_selectors: BTreeMap<(Range, usize, usize), Selector>,
/// Custom gates
pub custom_gates: CustomGates,
/// StaticLookups
pub static_lookups: StaticLookups<F>,
/// [Selector]s for the dynamic lookup tables
pub dynamic_lookups: DynamicLookups,
/// [Selector]s for the range checks
pub range_checks: RangeChecks<F>,
/// [Selector]s for the shuffles
pub shuffles: Shuffles,
/// Activate sanity checks
pub check_mode: CheckMode,
_marker: PhantomData<F>,
@@ -221,19 +348,12 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
let dummy_var = VarTensor::dummy(col_size, num_inner_cols);
Self {
inputs: vec![dummy_var.clone(), dummy_var.clone()],
lookup_input: dummy_var.clone(),
output: dummy_var.clone(),
lookup_output: dummy_var.clone(),
lookup_index: dummy_var,
selectors: BTreeMap::new(),
lookup_selectors: BTreeMap::new(),
range_check_selectors: BTreeMap::new(),
tables: BTreeMap::new(),
range_checks: BTreeMap::new(),
custom_gates: CustomGates::dummy(col_size, num_inner_cols),
static_lookups: StaticLookups::dummy(col_size, num_inner_cols),
dynamic_lookups: DynamicLookups::dummy(col_size, num_inner_cols),
shuffles: Shuffles::dummy(col_size, num_inner_cols),
range_checks: RangeChecks::dummy(col_size, num_inner_cols),
check_mode: CheckMode::SAFE,
_marker: PhantomData,
}
@@ -266,10 +386,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
for j in 0..output.num_inner_cols() {
nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Neg, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::IsZero, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Identity, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector());
}
}
@@ -314,12 +431,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
}
BaseOp::IsZero => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("non accum: output query failed");
vec![expected_output[base_op.constraint_idx()].clone()]
}
_ => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
@@ -373,16 +484,15 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
.collect();
Self {
selectors,
lookup_selectors: BTreeMap::new(),
range_check_selectors: BTreeMap::new(),
inputs: inputs.to_vec(),
lookup_input: VarTensor::Empty,
lookup_output: VarTensor::Empty,
lookup_index: VarTensor::Empty,
tables: BTreeMap::new(),
range_checks: BTreeMap::new(),
output: output.clone(),
custom_gates: CustomGates {
inputs: inputs.to_vec(),
output: output.clone(),
selectors,
},
static_lookups: StaticLookups::default(),
dynamic_lookups: DynamicLookups::default(),
shuffles: Shuffles::default(),
range_checks: RangeChecks::default(),
check_mode,
_marker: PhantomData,
}
@@ -403,8 +513,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
where
F: Field,
{
let mut selectors = BTreeMap::new();
if !index.is_advice() {
return Err("wrong input type for lookup index".into());
}
@@ -417,9 +525,9 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
// we borrow mutably twice so we need to do this dance
let table = if !self.tables.contains_key(nl) {
let table = if !self.static_lookups.tables.contains_key(nl) {
// as all tables have the same input we see if there's another table who's input we can reuse
let table = if let Some(table) = self.tables.values().next() {
let table = if let Some(table) = self.static_lookups.tables.values().next() {
Table::<F>::configure(
cs,
lookup_range,
@@ -430,7 +538,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
} else {
Table::<F>::configure(cs, lookup_range, logrows, nl, None)
};
self.tables.insert(nl.clone(), table.clone());
self.static_lookups.tables.insert(nl.clone(), table.clone());
table
} else {
return Ok(());
@@ -514,26 +622,193 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
res
});
}
selectors.insert((nl.clone(), x, y), multi_col_selector);
self.static_lookups
.selectors
.insert((nl.clone(), x, y), multi_col_selector);
}
}
self.lookup_selectors.extend(selectors);
// if we haven't previously initialized the input/output, do so now
if let VarTensor::Empty = self.lookup_input {
if let VarTensor::Empty = self.static_lookups.input {
debug!("assigning lookup input");
self.lookup_input = input.clone();
self.static_lookups.input = input.clone();
}
if let VarTensor::Empty = self.lookup_output {
if let VarTensor::Empty = self.static_lookups.output {
debug!("assigning lookup output");
self.lookup_output = output.clone();
self.static_lookups.output = output.clone();
}
if let VarTensor::Empty = self.lookup_index {
if let VarTensor::Empty = self.static_lookups.index {
debug!("assigning lookup index");
self.lookup_index = index.clone();
self.static_lookups.index = index.clone();
}
Ok(())
}
/// Configures and creates lookup selectors
#[allow(clippy::too_many_arguments)]
pub fn configure_dynamic_lookup(
&mut self,
cs: &mut ConstraintSystem<F>,
lookups: &[VarTensor; 3],
tables: &[VarTensor; 3],
) -> Result<(), Box<dyn Error>>
where
F: Field,
{
for l in lookups.iter() {
if !l.is_advice() {
return Err("wrong input type for dynamic lookup".into());
}
}
for t in tables.iter() {
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
return Err("wrong table type for dynamic lookup".into());
}
}
let one = Expression::Constant(F::ONE);
let s_ltable = cs.complex_selector();
for x in 0..lookups[0].num_blocks() {
for y in 0..lookups[0].num_inner_cols() {
let s_lookup = cs.complex_selector();
cs.lookup_any("lookup", |cs| {
let s_lookupq = cs.query_selector(s_lookup);
let mut expression = vec![];
let s_ltableq = cs.query_selector(s_ltable);
let mut lookup_queries = vec![one.clone()];
for lookup in lookups {
lookup_queries.push(match lookup {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
});
}
let mut table_queries = vec![one.clone()];
for table in tables {
table_queries.push(match table {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[0][0], Rotation(0))
}
_ => unreachable!(),
});
}
let lhs = lookup_queries.into_iter().map(|c| c * s_lookupq.clone());
let rhs = table_queries.into_iter().map(|c| c * s_ltableq.clone());
expression.extend(lhs.zip(rhs));
expression
});
self.dynamic_lookups
.lookup_selectors
.entry((x, y))
.or_insert(s_lookup);
}
}
self.dynamic_lookups.table_selectors.push(s_ltable);
// if we haven't previously initialized the input/output, do so now
if self.dynamic_lookups.tables.is_empty() {
debug!("assigning dynamic lookup table");
self.dynamic_lookups.tables = tables.to_vec();
}
if self.dynamic_lookups.inputs.is_empty() {
debug!("assigning dynamic lookup input");
self.dynamic_lookups.inputs = lookups.to_vec();
}
Ok(())
}
/// Configures and creates lookup selectors
#[allow(clippy::too_many_arguments)]
pub fn configure_shuffles(
&mut self,
cs: &mut ConstraintSystem<F>,
inputs: &[VarTensor; 2],
references: &[VarTensor; 2],
) -> Result<(), Box<dyn Error>>
where
F: Field,
{
for l in inputs.iter() {
if !l.is_advice() {
return Err("wrong input type for dynamic lookup".into());
}
}
for t in references.iter() {
if !t.is_advice() || t.num_blocks() > 1 || t.num_inner_cols() > 1 {
return Err("wrong table type for dynamic lookup".into());
}
}
let one = Expression::Constant(F::ONE);
let s_reference = cs.complex_selector();
for x in 0..inputs[0].num_blocks() {
for y in 0..inputs[0].num_inner_cols() {
let s_input = cs.complex_selector();
cs.lookup_any("lookup", |cs| {
let s_inputq = cs.query_selector(s_input);
let mut expression = vec![];
let s_referenceq = cs.query_selector(s_reference);
let mut input_queries = vec![one.clone()];
for input in inputs {
input_queries.push(match input {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[x][y], Rotation(0))
}
_ => unreachable!(),
});
}
let mut ref_queries = vec![one.clone()];
for reference in references {
ref_queries.push(match reference {
VarTensor::Advice { inner: advices, .. } => {
cs.query_advice(advices[0][0], Rotation(0))
}
_ => unreachable!(),
});
}
let lhs = input_queries.into_iter().map(|c| c * s_inputq.clone());
let rhs = ref_queries.into_iter().map(|c| c * s_referenceq.clone());
expression.extend(lhs.zip(rhs));
expression
});
self.shuffles
.input_selectors
.entry((x, y))
.or_insert(s_input);
}
}
self.shuffles.reference_selectors.push(s_reference);
// if we haven't previously initialized the input/output, do so now
if self.shuffles.references.is_empty() {
debug!("assigning shuffles reference");
self.shuffles.references = references.to_vec();
}
if self.shuffles.inputs.is_empty() {
debug!("assigning shuffles input");
self.shuffles.inputs = inputs.to_vec();
}
Ok(())
}
/// Configures and creates lookup selectors
#[allow(clippy::too_many_arguments)]
pub fn configure_range_check(
@@ -547,23 +822,22 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
where
F: Field,
{
let mut selectors = BTreeMap::new();
if !input.is_advice() {
return Err("wrong input type for lookup input".into());
}
// we borrow mutably twice so we need to do this dance
let range_check =
if let std::collections::btree_map::Entry::Vacant(e) = self.range_checks.entry(range) {
// as all tables have the same input we see if there's another table who's input we can reuse
let range_check = RangeCheck::<F>::configure(cs, range, logrows);
e.insert(range_check.clone());
range_check
} else {
return Ok(());
};
let range_check = if let std::collections::btree_map::Entry::Vacant(e) =
self.range_checks.ranges.entry(range)
{
// as all tables have the same input we see if there's another table who's input we can reuse
let range_check = RangeCheck::<F>::configure(cs, range, logrows);
e.insert(range_check.clone());
range_check
} else {
return Ok(());
};
for x in 0..input.num_blocks() {
for y in 0..input.num_inner_cols() {
@@ -620,19 +894,20 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
res
});
}
selectors.insert((range, x, y), multi_col_selector);
self.range_checks
.selectors
.insert((range, x, y), multi_col_selector);
}
}
self.range_check_selectors.extend(selectors);
// if we haven't previously initialized the input/output, do so now
if let VarTensor::Empty = self.lookup_input {
debug!("assigning lookup input");
self.lookup_input = input.clone();
if let VarTensor::Empty = self.range_checks.input {
debug!("assigning range check input");
self.range_checks.input = input.clone();
}
if let VarTensor::Empty = self.lookup_index {
debug!("assigning lookup index");
self.lookup_index = index.clone();
if let VarTensor::Empty = self.range_checks.index {
debug!("assigning range check index");
self.range_checks.index = index.clone();
}
Ok(())
@@ -640,7 +915,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
/// layout_tables must be called before layout.
pub fn layout_tables(&mut self, layouter: &mut impl Layouter<F>) -> Result<(), Box<dyn Error>> {
for (i, table) in self.tables.values_mut().enumerate() {
for (i, table) in self.static_lookups.tables.values_mut().enumerate() {
if !table.is_assigned {
debug!(
"laying out table for {}",
@@ -661,7 +936,7 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
&mut self,
layouter: &mut impl Layouter<F>,
) -> Result<(), Box<dyn Error>> {
for range_check in self.range_checks.values_mut() {
for range_check in self.range_checks.ranges.values_mut() {
if !range_check.is_assigned {
debug!("laying out range check for {:?}", range_check.range);
range_check.layout(layouter)?;

File diff suppressed because it is too large Load Diff

View File

@@ -14,10 +14,17 @@ pub enum PolyOp {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
GatherND {
batch_dims: usize,
indices: Option<Tensor<usize>>,
},
ScatterElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
ScatterND {
constant_idx: Option<Tensor<usize>>,
},
MultiBroadcastTo {
shape: Vec<usize>,
},
@@ -60,8 +67,6 @@ pub enum PolyOp {
len_prod: usize,
},
Pow(u32),
Pack(u32, u32),
GlobalSumPool,
Concat {
axis: usize,
},
@@ -91,7 +96,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn as_string(&self) -> String {
match &self {
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
PolyOp::ScatterND { .. } => "SCATTERND".into(),
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -110,8 +117,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Sum { .. } => "SUM".into(),
PolyOp::Prod { .. } => "PROD".into(),
PolyOp::Pow(_) => "POW".into(),
PolyOp::Pack(_, _) => "PACK".into(),
PolyOp::GlobalSumPool => "GLOBALSUMPOOL".into(),
PolyOp::Conv { .. } => "CONV".into(),
PolyOp::DeConv { .. } => "DECONV".into(),
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
@@ -181,13 +186,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
output_padding,
stride,
} => tensor::ops::deconv(&inputs, *padding, *output_padding, *stride),
PolyOp::Pack(base, scale) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pack inputs".to_string()));
}
tensor::ops::pack(&inputs[0], F::from(*base as u64), *scale)
}
PolyOp::Pow(u) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pow inputs".to_string()));
@@ -206,7 +204,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
}
tensor::ops::prod_axes(&inputs[0], axes)
}
PolyOp::GlobalSumPool => unreachable!(),
PolyOp::Concat { axis } => {
tensor::ops::concat(&inputs.iter().collect::<Vec<_>>(), *axis)
}
@@ -225,6 +222,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
};
tensor::ops::gather_elements(&x, &y, *dim)
}
PolyOp::GatherND {
indices,
batch_dims,
} => {
let x = inputs[0].clone();
let y = if let Some(idx) = indices {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_nd(&x, &y, *batch_dims)
}
PolyOp::ScatterElements { dim, constant_idx } => {
let x = inputs[0].clone();
@@ -241,6 +250,21 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
};
tensor::ops::scatter(&x, &idx, &src, *dim)
}
PolyOp::ScatterND { constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter_nd(&x, &idx, &src)
}
}?;
Ok(ForwardResult { output: res })
@@ -288,7 +312,17 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
}
}
PolyOp::GatherND {
batch_dims,
indices,
} => {
if let Some(idx) = indices {
tensor::ops::gather_nd(values[0].get_inner_tensor()?, idx, *batch_dims)?.into()
} else {
layouts::gather_nd(config, region, values[..].try_into()?, *batch_dims)?.0
}
}
PolyOp::ScatterElements { dim, constant_idx } => {
@@ -304,6 +338,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
}
}
PolyOp::ScatterND { constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::scatter_nd(
values[0].get_inner_tensor()?,
idx,
values[1].get_inner_tensor()?,
)?
.into()
} else {
layouts::scatter_nd(config, region, values[..].try_into()?)?
}
}
PolyOp::DeConv {
padding,
output_padding,
@@ -334,10 +380,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
input
}
PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp)?,
PolyOp::Pack(base, scale) => {
layouts::pack(config, region, values[..].try_into()?, *base, *scale)?
}
PolyOp::GlobalSumPool => unreachable!(),
PolyOp::Concat { axis } => layouts::concat(values[..].try_into()?, axis)?,
PolyOp::Slice { axis, start, end } => {
layouts::slice(config, region, values[..].try_into()?, axis, start, end)?
@@ -405,7 +447,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
vec![1, 2]
} else if matches!(self, PolyOp::Concat { .. }) {
(0..100).collect()
} else if matches!(self, PolyOp::ScatterElements { .. }) {
} else if matches!(self, PolyOp::ScatterElements { .. })
| matches!(self, PolyOp::ScatterND { .. })
{
vec![0, 2]
} else {
vec![]

View File

@@ -20,6 +20,66 @@ use portable_atomic::AtomicI128 as AtomicInt;
use super::lookup::LookupOp;
/// Dynamic lookup index
#[derive(Clone, Debug, Default)]
pub struct DynamicLookupIndex {
index: usize,
col_coord: usize,
}
impl DynamicLookupIndex {
/// Create a new dynamic lookup index
pub fn new(index: usize, col_coord: usize) -> DynamicLookupIndex {
DynamicLookupIndex { index, col_coord }
}
/// Get the lookup index
pub fn index(&self) -> usize {
self.index
}
/// Get the column coord
pub fn col_coord(&self) -> usize {
self.col_coord
}
/// update with another dynamic lookup index
pub fn update(&mut self, other: &DynamicLookupIndex) {
self.index += other.index;
self.col_coord += other.col_coord;
}
}
/// Dynamic lookup index
#[derive(Clone, Debug, Default)]
pub struct ShuffleIndex {
index: usize,
col_coord: usize,
}
impl ShuffleIndex {
/// Create a new dynamic lookup index
pub fn new(index: usize, col_coord: usize) -> ShuffleIndex {
ShuffleIndex { index, col_coord }
}
/// Get the lookup index
pub fn index(&self) -> usize {
self.index
}
/// Get the column coord
pub fn col_coord(&self) -> usize {
self.col_coord
}
/// update with another shuffle index
pub fn update(&mut self, other: &ShuffleIndex) {
self.index += other.index;
self.col_coord += other.col_coord;
}
}
/// Region error
#[derive(Debug, thiserror::Error)]
pub enum RegionError {
@@ -66,6 +126,8 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
linear_coord: usize,
num_inner_cols: usize,
total_constants: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
max_lookup_inputs: i128,
@@ -80,6 +142,26 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.total_constants += n;
}
///
pub fn increment_dynamic_lookup_index(&mut self, n: usize) {
self.dynamic_lookup_index.index += n;
}
///
pub fn increment_dynamic_lookup_col_coord(&mut self, n: usize) {
self.dynamic_lookup_index.col_coord += n;
}
///
pub fn increment_shuffle_index(&mut self, n: usize) {
self.shuffle_index.index += n;
}
///
pub fn increment_shuffle_col_coord(&mut self, n: usize) {
self.shuffle_index.col_coord += n;
}
///
pub fn throw_range_check_error(&self) -> bool {
self.throw_range_check_error
@@ -96,6 +178,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
row,
linear_coord,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
@@ -109,6 +193,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
region: Option<RefCell<Region<'a, F>>>,
row: usize,
num_inner_cols: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
) -> RegionCtx<'a, F> {
let linear_coord = row * num_inner_cols;
RegionCtx {
@@ -117,6 +203,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index,
shuffle_index,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
@@ -141,6 +229,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
@@ -156,8 +246,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord: usize,
total_constants: usize,
num_inner_cols: usize,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
throw_range_check_error: bool,
) -> RegionCtx<'a, F> {
let region = None;
@@ -167,8 +255,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
linear_coord,
row,
total_constants,
used_lookups,
used_range_checks,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
@@ -227,6 +317,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
*output = output
.par_enum_map(|idx, _| {
@@ -242,8 +334,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
starting_linear_coord,
starting_constants,
self.num_inner_cols,
HashSet::new(),
HashSet::new(),
self.throw_range_check_error,
);
let res = inner_loop_function(idx, &mut local_reg);
@@ -263,14 +353,19 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
// update the lookups
let mut lookups = lookups.lock().unwrap();
lookups.extend(local_reg.used_lookups());
// update the range checks
let mut range_checks = range_checks.lock().unwrap();
range_checks.extend(local_reg.used_range_checks());
// update the dynamic lookup index
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
// update the shuffle index
let mut shuffle_index = shuffle_index.lock().unwrap();
shuffle_index.update(&local_reg.shuffle_index);
res
})
.map_err(|e| {
log::error!("dummy_loop: {:?}", e);
Error::Synthesis
})?;
.map_err(|e| RegionError::from(format!("dummy_loop: {:?}", e)))?;
self.total_constants = constants.into_inner();
self.linear_coord = linear_coord.into_inner();
#[allow(trivial_numeric_casts)]
@@ -293,6 +388,28 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get range checks: {:?}", e))
})?;
self.dynamic_lookup_index = Arc::try_unwrap(dynamic_lookup_index)
.map_err(|e| {
RegionError::from(format!(
"dummy_loop: failed to get dynamic lookup index: {:?}",
e
))
})?
.into_inner()
.map_err(|e| {
RegionError::from(format!(
"dummy_loop: failed to get dynamic lookup index: {:?}",
e
))
})?;
self.shuffle_index = Arc::try_unwrap(shuffle_index)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?;
Ok(())
}
@@ -363,6 +480,26 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.total_constants
}
/// Get the dynamic lookup index
pub fn dynamic_lookup_index(&self) -> usize {
self.dynamic_lookup_index.index
}
/// Get the dynamic lookup column coordinate
pub fn dynamic_lookup_col_coord(&self) -> usize {
self.dynamic_lookup_index.col_coord
}
/// Get the shuffle index
pub fn shuffle_index(&self) -> usize {
self.shuffle_index.index
}
/// Get the shuffle column coordinate
pub fn shuffle_col_coord(&self) -> usize {
self.shuffle_index.col_coord
}
/// get used lookups
pub fn used_lookups(&self) -> HashSet<LookupOp> {
self.used_lookups.clone()
@@ -412,6 +549,38 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
}
///
pub fn combined_dynamic_shuffle_coord(&self) -> usize {
self.dynamic_lookup_col_coord() + self.shuffle_col_coord()
}
/// Assign a valtensor to a vartensor
pub fn assign_dynamic_lookup(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.total_constants += values.num_constants();
if let Some(region) = &self.region {
var.assign(
&mut region.borrow_mut(),
self.combined_dynamic_shuffle_coord(),
values,
)
} else {
Ok(values.clone())
}
}
/// Assign a valtensor to a vartensor
pub fn assign_shuffle(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.assign_dynamic_lookup(var, values)
}
/// Assign a valtensor to a vartensor
pub fn assign_with_omissions(
&mut self,

View File

@@ -357,8 +357,6 @@ mod matmul_col_ultra_overflow_double_col {
);
assert!(result.is_ok());
println!("done.");
}
}
@@ -476,8 +474,6 @@ mod matmul_col_ultra_overflow {
);
assert!(result.is_ok());
println!("done.");
}
}
@@ -1280,8 +1276,6 @@ mod conv_col_ultra_overflow {
);
assert!(result.is_ok());
println!("done.");
}
}
@@ -1435,8 +1429,6 @@ mod conv_relu_col_ultra_overflow {
);
assert!(result.is_ok());
println!("done.");
}
}
@@ -1574,6 +1566,280 @@ mod add {
}
}
#[cfg(test)]
mod dynamic_lookup {
use super::*;
const K: usize = 6;
const LEN: usize = 4;
const NUM_LOOP: usize = 5;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
tables: [[ValTensor<F>; 2]; NUM_LOOP],
lookups: [[ValTensor<F>; 2]; NUM_LOOP],
_marker: PhantomData<F>,
}
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 2, LEN);
let b = VarTensor::new_advice(cs, K, 2, LEN);
let c: VarTensor = VarTensor::new_advice(cs, K, 2, LEN);
let d = VarTensor::new_advice(cs, K, 1, LEN);
let e = VarTensor::new_advice(cs, K, 1, LEN);
let f: VarTensor = VarTensor::new_advice(cs, K, 1, LEN);
let _constant = VarTensor::constant_cols(cs, K, LEN * NUM_LOOP, false);
let mut config =
Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
config
.configure_dynamic_lookup(
cs,
&[a.clone(), b.clone(), c.clone()],
&[d.clone(), e.clone(), f.clone()],
)
.unwrap();
config
}
fn synthesize(
&self,
config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
for i in 0..NUM_LOOP {
layouts::dynamic_lookup(
&config,
&mut region,
&self.lookups[i],
&self.tables[i],
)
.map_err(|_| Error::Synthesis)?;
}
assert_eq!(
region.dynamic_lookup_col_coord(),
NUM_LOOP * self.tables[0][0].len()
);
assert_eq!(region.dynamic_lookup_index(), NUM_LOOP);
Ok(())
},
)
.unwrap();
Ok(())
}
}
#[test]
fn dynamiclookupcircuit() {
// parameters
let tables = (0..NUM_LOOP)
.map(|loop_idx| {
[
ValTensor::from(Tensor::from(
(0..LEN).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
)),
ValTensor::from(Tensor::from(
(0..LEN).map(|i| Value::known(F::from((loop_idx * i * i) as u64 + 1))),
)),
]
})
.collect::<Vec<_>>();
let lookups = (0..NUM_LOOP)
.map(|loop_idx| {
[
ValTensor::from(Tensor::from(
(0..3).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
)),
ValTensor::from(Tensor::from(
(0..3).map(|i| Value::known(F::from((loop_idx * i * i) as u64 + 1))),
)),
]
})
.collect::<Vec<_>>();
let circuit = MyCircuit::<F> {
tables: tables.clone().try_into().unwrap(),
lookups: lookups.try_into().unwrap(),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
let lookups = (0..NUM_LOOP)
.map(|loop_idx| {
let prev_idx = if loop_idx == 0 {
NUM_LOOP - 1
} else {
loop_idx - 1
};
[
ValTensor::from(Tensor::from(
(0..3).map(|i| Value::known(F::from((i * prev_idx) as u64 + 1))),
)),
ValTensor::from(Tensor::from(
(0..3).map(|i| Value::known(F::from((prev_idx * i * i) as u64 + 1))),
)),
]
})
.collect::<Vec<_>>();
let circuit = MyCircuit::<F> {
tables: tables.try_into().unwrap(),
lookups: lookups.try_into().unwrap(),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
assert!(prover.verify().is_err());
}
}
#[cfg(test)]
mod shuffle {
use super::*;
const K: usize = 6;
const LEN: usize = 4;
const NUM_LOOP: usize = 5;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [[ValTensor<F>; 1]; NUM_LOOP],
references: [[ValTensor<F>; 1]; NUM_LOOP],
_marker: PhantomData<F>,
}
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 2, LEN);
let b = VarTensor::new_advice(cs, K, 2, LEN);
let c: VarTensor = VarTensor::new_advice(cs, K, 2, LEN);
let d = VarTensor::new_advice(cs, K, 1, LEN);
let e = VarTensor::new_advice(cs, K, 1, LEN);
let _constant = VarTensor::constant_cols(cs, K, LEN * NUM_LOOP, false);
let mut config =
Self::Config::configure(cs, &[a.clone(), b.clone()], &c, CheckMode::SAFE);
config
.configure_shuffles(cs, &[a.clone(), b.clone()], &[d.clone(), e.clone()])
.unwrap();
config
}
fn synthesize(
&self,
config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
for i in 0..NUM_LOOP {
layouts::shuffles(
&config,
&mut region,
&self.inputs[i],
&self.references[i],
)
.map_err(|_| Error::Synthesis)?;
}
assert_eq!(
region.shuffle_col_coord(),
NUM_LOOP * self.references[0][0].len()
);
assert_eq!(region.shuffle_index(), NUM_LOOP);
Ok(())
},
)
.unwrap();
Ok(())
}
}
#[test]
fn shufflecircuit() {
// parameters
let references = (0..NUM_LOOP)
.map(|loop_idx| {
[ValTensor::from(Tensor::from((0..LEN).map(|i| {
Value::known(F::from((i * loop_idx) as u64 + 1))
})))]
})
.collect::<Vec<_>>();
let inputs = (0..NUM_LOOP)
.map(|loop_idx| {
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
Value::known(F::from((i * loop_idx) as u64 + 1))
})))]
})
.collect::<Vec<_>>();
let circuit = MyCircuit::<F> {
references: references.clone().try_into().unwrap(),
inputs: inputs.try_into().unwrap(),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
let inputs = (0..NUM_LOOP)
.map(|loop_idx| {
let prev_idx = if loop_idx == 0 {
NUM_LOOP - 1
} else {
loop_idx - 1
};
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
Value::known(F::from((i * prev_idx) as u64 + 1))
})))]
})
.collect::<Vec<_>>();
let circuit = MyCircuit::<F> {
references: references.try_into().unwrap(),
inputs: inputs.try_into().unwrap(),
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
assert!(prover.verify().is_err());
}
}
#[cfg(test)]
mod add_with_overflow {
use super::*;
@@ -1977,75 +2243,6 @@ mod pow {
}
}
#[cfg(test)]
mod pack {
use super::*;
const K: usize = 8;
const LEN: usize = 4;
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 1],
_marker: PhantomData<F>,
}
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, 1, LEN);
let b = VarTensor::new_advice(cs, K, 1, LEN);
let output = VarTensor::new_advice(cs, K, 1, LEN);
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
config
.layout(
&mut region,
&self.inputs.clone(),
Box::new(PolyOp::Pack(2, 1)),
)
.map_err(|_| Error::Synthesis)
},
)
.unwrap();
Ok(())
}
}
#[test]
fn packcircuit() {
// parameters
let a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64 + 1))));
let circuit = MyCircuit::<F> {
inputs: [ValTensor::from(a)],
_marker: PhantomData,
};
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
prover.assert_satisfied();
}
}
#[cfg(test)]
mod matmul_relu {
use super::*;
@@ -2333,7 +2530,5 @@ mod lookup_ultra_overflow {
);
assert!(result.is_ok());
println!("done.");
}
}

View File

@@ -77,7 +77,7 @@ pub const DEFAULT_CALIBRATION_FILE: &str = "calibration.json";
/// Default lookup safety margin
pub const DEFAULT_LOOKUP_SAFETY_MARGIN: &str = "2";
/// Default Compress selectors
pub const DEFAULT_COMPRESS_SELECTORS: &str = "false";
pub const DEFAULT_DISABLE_SELECTOR_COMPRESSION: &str = "false";
/// Default render vk seperately
pub const DEFAULT_RENDER_VK_SEPERATELY: &str = "false";
/// Default VK sol path
@@ -402,9 +402,6 @@ pub enum Commands {
/// Number of logrows to use for srs. Overrides settings_path if specified.
#[arg(long, default_value = None)]
logrows: Option<u32>,
/// Check mode for SRS. Verifies downloaded srs is valid. Set to unsafe for speed.
#[arg(long, default_value = DEFAULT_CHECKMODE)]
check: CheckMode,
},
/// Loads model and input and runs mock prover (for testing)
Mock {
@@ -450,8 +447,8 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_SPLIT)]
split_proofs: bool,
/// compress selectors
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
compress_selectors: bool,
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
disable_selector_compression: bool,
},
/// Aggregates proofs :)
Aggregate {
@@ -515,8 +512,8 @@ pub enum Commands {
#[arg(short = 'W', long)]
witness: Option<PathBuf>,
/// compress selectors
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
compress_selectors: bool,
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
disable_selector_compression: bool,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -540,8 +537,8 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_FUZZ_RUNS)]
num_runs: usize,
/// compress selectors
#[arg(long, default_value = DEFAULT_COMPRESS_SELECTORS)]
compress_selectors: bool,
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
disable_selector_compression: bool,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information

View File

@@ -23,6 +23,7 @@ use crate::pfsys::{create_proof_circuit_kzg, verify_proof_circuit_kzg};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::RunArgs;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
use halo2_proofs::poly::commitment::Params;
@@ -63,7 +64,11 @@ use std::process::Command;
use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
#[cfg(not(target_arch = "wasm32"))]
use std::sync::OnceLock;
#[cfg(not(target_arch = "wasm32"))]
use crate::EZKL_BUF_CAPACITY;
#[cfg(not(target_arch = "wasm32"))]
use std::io::BufWriter;
use std::time::Duration;
use tabled::Tabled;
use thiserror::Error;
@@ -140,13 +145,13 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
compiled_circuit,
transcript,
num_runs,
compress_selectors,
disable_selector_compression,
} => fuzz(
compiled_circuit,
witness,
transcript,
num_runs,
compress_selectors,
disable_selector_compression,
),
Commands::GenSrs { srs_path, logrows } => gen_srs_cmd(srs_path, logrows as u32),
#[cfg(not(target_arch = "wasm32"))]
@@ -154,8 +159,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
settings_path,
logrows,
check,
} => get_srs_cmd(srs_path, settings_path, logrows, check).await,
} => get_srs_cmd(srs_path, settings_path, logrows).await,
Commands::Table { model, args } => table(model, args),
#[cfg(feature = "render")]
Commands::RenderCircuit {
@@ -260,14 +264,14 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
vk_path,
pk_path,
witness,
compress_selectors,
disable_selector_compression,
} => setup(
compiled_circuit,
srs_path,
vk_path,
pk_path,
witness,
compress_selectors,
disable_selector_compression,
),
#[cfg(not(target_arch = "wasm32"))]
Commands::SetupTestEvmData {
@@ -331,7 +335,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
logrows,
split_proofs,
compress_selectors,
disable_selector_compression,
} => setup_aggregate(
sample_snarks,
vk_path,
@@ -339,7 +343,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
logrows,
split_proofs,
compress_selectors,
disable_selector_compression,
),
Commands::Aggregate {
proof_path,
@@ -487,16 +491,28 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
}
#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, Box<dyn Error>> {
use std::io::Read;
let path = get_srs_path(logrows, srs_path);
let file = std::fs::File::open(path.clone())?;
let file = std::fs::File::open(path)?;
let mut reader = std::io::BufReader::new(file);
let mut buffer = vec![];
let bytes_read = std::io::BufReader::new(file).read_to_end(&mut buffer)?;
debug!("read {} bytes from SRS file", bytes_read);
let bytes_read = reader.read_to_end(&mut buffer)?;
info!(
"read {} bytes from file (vector of len = {})",
bytes_read,
buffer.len()
);
let hash = sha256::digest(buffer);
info!("SRS hash: {}", hash);
info!("file hash: {}", hash);
Ok(hash)
}
#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
let path = get_srs_path(logrows, srs_path);
let hash = get_file_hash(&path)?;
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
Some(h) => h,
@@ -520,7 +536,6 @@ pub(crate) async fn get_srs_cmd(
srs_path: Option<PathBuf>,
settings_path: Option<PathBuf>,
logrows: Option<u32>,
check_mode: CheckMode,
) -> Result<String, Box<dyn Error>> {
// logrows overrides settings
@@ -548,18 +563,21 @@ pub(crate) async fn get_srs_cmd(
let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k);
let mut reader = Cursor::new(fetch_srs(&srs_uri).await?);
// check the SRS
if matches!(check_mode, CheckMode::SAFE) {
#[cfg(not(target_arch = "wasm32"))]
let pb = init_spinner();
#[cfg(not(target_arch = "wasm32"))]
pb.set_message("Validating SRS (this may take a while) ...");
ParamsKZG::<Bn256>::read(&mut reader)?;
#[cfg(not(target_arch = "wasm32"))]
pb.finish_with_message("SRS validated");
}
#[cfg(not(target_arch = "wasm32"))]
let pb = init_spinner();
#[cfg(not(target_arch = "wasm32"))]
pb.set_message("Validating SRS (this may take a while) ...");
let params = ParamsKZG::<Bn256>::read(&mut reader)?;
#[cfg(not(target_arch = "wasm32"))]
pb.finish_with_message("SRS validated.");
info!("Saving SRS to disk...");
let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone()))?;
file.write_all(reader.get_ref())?;
let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
params.write(&mut buffer)?;
info!("Saved SRS to disk.");
info!("SRS downloaded");
} else {
info!("SRS already exists at that path");
@@ -806,7 +824,6 @@ pub(crate) fn calibrate(
let settings = GraphSettings::load(&settings_path)?;
// now retrieve the run args
// we load the model to get the input and output shapes
// check if gag already exists
let model = Model::from_run_args(&settings.run_args, &model_path)?;
@@ -887,7 +904,12 @@ pub(crate) fn calibrate(
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
));
let key = (input_scale, param_scale, scale_rebase_multiplier);
let key = (
input_scale,
param_scale,
scale_rebase_multiplier,
div_rebasing,
);
forward_pass_res.insert(key, vec![]);
let local_run_args = RunArgs {
@@ -898,6 +920,18 @@ pub(crate) fn calibrate(
..settings.run_args.clone()
};
// if unix get a gag
#[cfg(unix)]
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
#[cfg(unix)]
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
};
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
Err(e) => {
@@ -938,31 +972,29 @@ pub(crate) fn calibrate(
}
}
let min_lookup_range = forward_pass_res
.get(&key)
.unwrap()
// drop the gag
#[cfg(unix)]
drop(_r);
#[cfg(unix)]
drop(_g);
let result = forward_pass_res.get(&key).ok_or("key not found")?;
let min_lookup_range = result
.iter()
.map(|x| x.min_lookup_inputs)
.min()
.unwrap_or(0);
let max_lookup_range = forward_pass_res
.get(&key)
.unwrap()
let max_lookup_range = result
.iter()
.map(|x| x.max_lookup_inputs)
.max()
.unwrap_or(0);
let max_range_size = forward_pass_res
.get(&key)
.unwrap()
.iter()
.map(|x| x.max_range_size)
.max()
.unwrap_or(0);
let max_range_size = result.iter().map(|x| x.max_range_size).max().unwrap_or(0);
let res = circuit.calibrate_from_min_max(
let res = circuit.calc_min_logrows(
(min_lookup_range, max_lookup_range),
max_range_size,
max_logrows,
@@ -1081,6 +1113,7 @@ pub(crate) fn calibrate(
best_params.run_args.input_scale,
best_params.run_args.param_scale,
best_params.run_args.scale_rebase_multiplier,
best_params.run_args.div_rebasing,
))
.ok_or("no params found")?
.iter()
@@ -1499,7 +1532,7 @@ pub(crate) fn setup(
vk_path: PathBuf,
pk_path: PathBuf,
witness: Option<PathBuf>,
compress_selectors: bool,
disable_selector_compression: bool,
) -> Result<String, Box<dyn Error>> {
// these aren't real values so the sanity checks are mostly meaningless
let mut circuit = GraphCircuit::load(compiled_circuit)?;
@@ -1513,7 +1546,7 @@ pub(crate) fn setup(
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&params,
compress_selectors,
disable_selector_compression,
)
.map_err(Box::<dyn Error>::from)?;
@@ -1654,7 +1687,7 @@ pub(crate) fn fuzz(
data_path: PathBuf,
transcript: TranscriptType,
num_runs: usize,
compress_selectors: bool,
disable_selector_compression: bool,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let passed = AtomicBool::new(true);
@@ -1664,7 +1697,7 @@ pub(crate) fn fuzz(
let logrows = circuit.settings().run_args.logrows;
info!("setting up tests");
#[cfg(unix)]
let _r = Gag::stdout()?;
let params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
@@ -1673,7 +1706,7 @@ pub(crate) fn fuzz(
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&params,
compress_selectors,
disable_selector_compression,
)
.map_err(Box::<dyn Error>::from)?;
@@ -1682,6 +1715,7 @@ pub(crate) fn fuzz(
let public_inputs = circuit.prepare_public_inputs(&data)?;
let strategy = KZGSingleStrategy::new(&params);
#[cfg(unix)]
std::mem::drop(_r);
info!("starting fuzzing");
@@ -1694,7 +1728,7 @@ pub(crate) fn fuzz(
let bad_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&new_params,
compress_selectors,
disable_selector_compression,
)
.map_err(|_| ())?;
@@ -1772,7 +1806,7 @@ pub(crate) fn fuzz(
let bad_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(
&circuit,
&new_params,
compress_selectors,
disable_selector_compression,
)
.map_err(|_| ())?;
@@ -1872,6 +1906,7 @@ pub(crate) fn run_fuzz_fn(
passed: &AtomicBool,
) {
let num_failures = AtomicI64::new(0);
#[cfg(unix)]
let _r = Gag::stdout().unwrap();
let pb = init_bar(num_runs as u64);
@@ -1885,6 +1920,7 @@ pub(crate) fn run_fuzz_fn(
pb.inc(1);
});
pb.finish_with_message("Done.");
#[cfg(unix)]
std::mem::drop(_r);
info!(
"num failures: {} out of {}",
@@ -1951,7 +1987,7 @@ pub(crate) fn setup_aggregate(
srs_path: Option<PathBuf>,
logrows: u32,
split_proofs: bool,
compress_selectors: bool,
disable_selector_compression: bool,
) -> Result<String, Box<dyn Error>> {
// the K used for the aggregation circuit
let params = load_params_cmd(srs_path, logrows)?;
@@ -1965,7 +2001,7 @@ pub(crate) fn setup_aggregate(
let agg_pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(
&agg_circuit,
&params,
compress_selectors,
disable_selector_compression,
)?;
let agg_vk = agg_pk.get_vk();
@@ -2042,7 +2078,8 @@ pub(crate) fn verify(
let circuit_settings = GraphSettings::load(&settings_path)?;
let params = if reduced_srs {
load_params_cmd(srs_path, circuit_settings.log2_total_instances())?
// only need G_0 for the verification with shplonk
load_params_cmd(srs_path, 1)?
} else {
load_params_cmd(srs_path, circuit_settings.run_args.logrows)?
};

View File

@@ -12,6 +12,8 @@ pub mod utilities;
pub mod vars;
#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::plonk::VerifyingKey;
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
pub use input::DataSource;
@@ -449,6 +451,14 @@ pub struct GraphSettings {
pub total_assignments: usize,
/// total const size
pub total_const_size: usize,
/// total dynamic column size
pub total_dynamic_col_size: usize,
/// number of dynamic lookups
pub num_dynamic_lookups: usize,
/// number of shuffles
pub num_shuffles: usize,
/// total shuffle column size
pub total_shuffle_col_size: usize,
/// the shape of public inputs to the model (in order of appearance)
pub model_instance_shapes: Vec<Vec<usize>>,
/// model output scales
@@ -478,6 +488,16 @@ impl GraphSettings {
.ceil() as u32
}
fn dynamic_lookup_and_shuffle_logrows(&self) -> u32 {
(self.total_dynamic_col_size as f64 + self.total_shuffle_col_size as f64)
.log2()
.ceil() as u32
}
fn dynamic_lookup_and_shuffle_col_size(&self) -> usize {
self.total_dynamic_col_size + self.total_shuffle_col_size
}
fn module_constraint_logrows(&self) -> u32 {
(self.module_sizes.max_constraints() as f64).log2().ceil() as u32
}
@@ -570,6 +590,16 @@ impl GraphSettings {
|| self.run_args.param_visibility.is_hashed()
}
/// requires dynamic lookup
pub fn requires_dynamic_lookup(&self) -> bool {
self.num_dynamic_lookups > 0
}
/// requires dynamic shuffle
pub fn requires_shuffle(&self) -> bool {
self.num_shuffles > 0
}
/// any kzg visibility
pub fn module_requires_kzg(&self) -> bool {
self.run_args.input_visibility.is_kzgcommit()
@@ -1054,7 +1084,8 @@ impl GraphCircuit {
Ok(min_bits)
}
fn calc_min_logrows(
/// calculate the minimum logrows required for the circuit
pub fn calc_min_logrows(
&mut self,
min_max_lookup: Range,
max_range_size: i128,
@@ -1083,14 +1114,18 @@ impl GraphCircuit {
// These are hard lower limits, we can't overflow instances or modules constraints
let instance_logrows = self.settings().log2_total_instances();
let module_constraint_logrows = self.settings().module_constraint_logrows();
let dynamic_lookup_logrows = self.settings().dynamic_lookup_and_shuffle_logrows();
min_logrows = std::cmp::max(
min_logrows,
// max of the instance logrows and the module constraint logrows is the lower limit
[instance_logrows, module_constraint_logrows]
.iter()
.max()
.unwrap()
.clone(),
// max of the instance logrows and the module constraint logrows and the dynamic lookup logrows is the lower limit
*[
instance_logrows,
module_constraint_logrows,
dynamic_lookup_logrows,
]
.iter()
.max()
.unwrap(),
);
// These are upper limits, going above these is wasteful, but they are not hard limits
@@ -1100,11 +1135,10 @@ impl GraphCircuit {
max_logrows = std::cmp::min(
max_logrows,
// max of the model constraint logrows, min_bits, and the constants logrows is the upper limit
[model_constraint_logrows, min_bits, constants_logrows]
*[model_constraint_logrows, min_bits, constants_logrows]
.iter()
.max()
.unwrap()
.clone(),
.unwrap(),
);
// we now have a min and max logrows
@@ -1164,9 +1198,9 @@ impl GraphCircuit {
max_range_size: i128,
) -> bool {
// if num cols is too large then the extended k is too large
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS {
return false;
} else if Self::calc_num_cols(max_range_size, k) > MAX_NUM_LOOKUP_COLS {
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS
|| Self::calc_num_cols(max_range_size, k) > MAX_NUM_LOOKUP_COLS
{
return false;
}
@@ -1175,19 +1209,30 @@ impl GraphCircuit {
settings.run_args.logrows = k;
settings.required_range_checks = vec![(0, max_range_size)];
let mut cs = ConstraintSystem::default();
// fetch gag
// if unix get a gag
#[cfg(unix)]
let _r = match gag::Gag::stdout() {
Ok(r) => Some(r),
Err(_) => None,
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
#[cfg(unix)]
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
};
Self::configure_with_params(&mut cs, settings);
// drop the gag
#[cfg(unix)]
drop(_r);
#[cfg(unix)]
drop(_g);
#[cfg(feature = "mv-lookup")]
let cs = cs.chunk_lookups();
// quotient_poly_degree * params.n - 1 is the degree of the quotient polynomial
let max_degree = cs.degree();
#[cfg(unix)]
std::mem::drop(_r);
let quotient_poly_degree = (max_degree - 1) as u64;
// n = 2^k
let n = 1u64 << k;
@@ -1202,23 +1247,6 @@ impl GraphCircuit {
true
}
/// Calibrate the circuit to the supplied data.
pub fn calibrate_from_min_max(
&mut self,
min_max_lookup: Range,
max_range_size: i128,
max_logrows: Option<u32>,
lookup_safety_margin: i128,
) -> Result<(), Box<dyn std::error::Error>> {
self.calc_min_logrows(
min_max_lookup,
max_range_size,
max_logrows,
lookup_safety_margin,
)?;
Ok(())
}
/// Runs the forward pass of the model / graph of computations and any associated hashing.
pub fn forward(
&self,
@@ -1518,34 +1546,18 @@ impl Circuit<Fp> for GraphCircuit {
params.run_args.logrows as usize,
);
let mut vars = ModelVars::new(
cs,
params.run_args.logrows as usize,
params.total_assignments,
params.run_args.num_inner_cols,
params.total_const_size,
params.module_requires_fixed(),
);
let mut vars = ModelVars::new(cs, &params);
module_configs.configure_complex_modules(cs, visibility, params.module_sizes.clone());
vars.instantiate_instance(
cs,
params.model_instance_shapes,
params.model_instance_shapes.clone(),
params.run_args.input_scale,
module_configs.instance,
);
let base = Model::configure(
cs,
&vars,
params.run_args.lookup_range,
params.run_args.logrows as usize,
params.required_lookups,
params.required_range_checks,
params.check_mode,
)
.unwrap();
let base = Model::configure(cs, &vars, &params).unwrap();
let model_config = ModelConfig { base, vars };

View File

@@ -99,6 +99,14 @@ pub type NodeGraph = BTreeMap<usize, NodeType>;
pub struct DummyPassRes {
/// number of rows use
pub num_rows: usize,
/// num dynamic lookups
pub num_dynamic_lookups: usize,
/// dynamic lookup col size
pub dynamic_lookup_col_coord: usize,
/// num shuffles
pub num_shuffles: usize,
/// shuffle
pub shuffle_col_coord: usize,
/// linear coordinate
pub linear_coord: usize,
/// total const size
@@ -540,6 +548,10 @@ impl Model {
required_range_checks: res.range_checks.into_iter().collect(),
model_output_scales: self.graph.get_output_scales()?,
model_input_scales: self.graph.get_input_scales(),
num_dynamic_lookups: res.num_dynamic_lookups,
total_dynamic_col_size: res.dynamic_lookup_col_coord,
num_shuffles: res.num_shuffles,
total_shuffle_col_size: res.shuffle_col_coord,
total_const_size: res.total_const_size,
check_mode,
version: env!("CARGO_PKG_VERSION").to_string(),
@@ -1003,24 +1015,24 @@ impl Model {
/// # Arguments
/// * `meta` - The constraint system.
/// * `vars` - The variables for the circuit.
/// * `run_args` - [RunArgs]
/// * `required_lookups` - The required lookup operations for the circuit.
/// * `settings` - [GraphSettings]
pub fn configure(
meta: &mut ConstraintSystem<Fp>,
vars: &ModelVars<Fp>,
lookup_range: Range,
logrows: usize,
required_lookups: Vec<LookupOp>,
required_range_checks: Vec<Range>,
check_mode: CheckMode,
settings: &GraphSettings,
) -> Result<PolyConfig<Fp>, Box<dyn Error>> {
info!("configuring model");
debug!("configuring model");
let lookup_range = settings.run_args.lookup_range;
let logrows = settings.run_args.logrows as usize;
let required_lookups = settings.required_lookups.clone();
let required_range_checks = settings.required_range_checks.clone();
let mut base_gate = PolyConfig::configure(
meta,
vars.advices[0..2].try_into()?,
&vars.advices[2],
check_mode,
settings.check_mode,
);
// set scale for HybridOp::RangeCheck and call self.conf_lookup on that op for percentage tolerance case
let input = &vars.advices[0];
@@ -1034,6 +1046,22 @@ impl Model {
base_gate.configure_range_check(meta, input, index, range, logrows)?;
}
if settings.requires_dynamic_lookup() {
base_gate.configure_dynamic_lookup(
meta,
vars.advices[0..3].try_into()?,
vars.advices[3..6].try_into()?,
)?;
}
if settings.requires_shuffle() {
base_gate.configure_shuffles(
meta,
vars.advices[0..2].try_into()?,
vars.advices[3..5].try_into()?,
)?;
}
Ok(base_gate)
}
@@ -1439,6 +1467,10 @@ impl Model {
max_lookup_inputs: region.max_lookup_inputs(),
min_lookup_inputs: region.min_lookup_inputs(),
max_range_size: region.max_range_size(),
num_dynamic_lookups: region.dynamic_lookup_index(),
dynamic_lookup_col_coord: region.dynamic_lookup_col_coord(),
num_shuffles: region.shuffle_index(),
shuffle_col_coord: region.shuffle_col_coord(),
outputs,
};

View File

@@ -23,7 +23,10 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(not(target_arch = "wasm32"))]
use tract_onnx::tract_core::ops::{
array::{Gather, GatherElements, MultiBroadcastTo, OneHot, ScatterElements, Slice, Topk},
array::{
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
Slice, Topk,
},
change_axes::AxisOp,
cnn::{Conv, Deconv},
einsum::EinSum,
@@ -467,6 +470,78 @@ pub fn new_op_from_onnx(
// Extract the max value
}
"ScatterNd" => {
if inputs.len() != 3 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"scatter nd".to_string(),
)));
};
// just verify it deserializes correctly
let _op = load_op::<ScatterNd>(node.op(), idx, node.op().name().to_string())?;
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
}
// }
if inputs[1].opkind().is_input() {
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
}));
inputs[1].bump_scale(0);
}
op
}
"GatherNd" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"gather nd".to_string(),
)));
};
let op = load_op::<GatherNd>(node.op(), idx, node.op().name().to_string())?;
let batch_dims = op.batch_dims;
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
})
}
// }
if inputs[1].opkind().is_input() {
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
}));
inputs[1].bump_scale(0);
}
op
}
"GatherElements" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(

View File

@@ -420,20 +420,34 @@ impl<F: PrimeField + TensorType + PartialOrd> ModelVars<F> {
}
/// Allocate all columns that will be assigned to by a model.
pub fn new(
cs: &mut ConstraintSystem<F>,
logrows: usize,
var_len: usize,
num_inner_cols: usize,
num_constants: usize,
module_requires_fixed: bool,
) -> Self {
pub fn new(cs: &mut ConstraintSystem<F>, params: &GraphSettings) -> Self {
debug!("number of blinding factors: {}", cs.blinding_factors());
let advices = (0..3)
let logrows = params.run_args.logrows as usize;
let var_len = params.total_assignments;
let num_inner_cols = params.run_args.num_inner_cols;
let num_constants = params.total_const_size;
let module_requires_fixed = params.module_requires_fixed();
let requires_dynamic_lookup = params.requires_dynamic_lookup();
let requires_shuffle = params.requires_shuffle();
let dynamic_lookup_and_shuffle_size = params.dynamic_lookup_and_shuffle_col_size();
let mut advices = (0..3)
.map(|_| VarTensor::new_advice(cs, logrows, num_inner_cols, var_len))
.collect_vec();
if requires_dynamic_lookup || requires_shuffle {
let num_cols = if requires_dynamic_lookup { 3 } else { 2 };
for _ in 0..num_cols {
let dynamic_lookup =
VarTensor::new_advice(cs, logrows, 1, dynamic_lookup_and_shuffle_size);
if dynamic_lookup.num_blocks() > 1 {
panic!("dynamic lookup or shuffle should only have one block");
};
advices.push(dynamic_lookup);
}
}
debug!(
"model uses {} advice blocks (size={})",
advices.iter().map(|v| v.num_blocks()).sum::<usize>(),

View File

@@ -180,10 +180,8 @@ impl RunArgs {
if self.num_inner_cols < 1 {
return Err("num_inner_cols must be >= 1".into());
}
if self.tolerance.val > 0.0 {
if self.output_visibility != Visibility::Public {
return Err("tolerance > 0.0 requires output_visibility to be public".into());
}
if self.tolerance.val > 0.0 && self.output_visibility != Visibility::Public {
return Err("tolerance > 0.0 requires output_visibility to be public".into());
}
Ok(())
}

View File

@@ -484,7 +484,7 @@ where
pub fn create_keys<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
circuit: &C,
params: &'_ Scheme::ParamsProver,
compress_selectors: bool,
disable_selector_compression: bool,
) -> Result<ProvingKey<Scheme::Curve>, halo2_proofs::plonk::Error>
where
C: Circuit<Scheme::Scalar>,
@@ -496,7 +496,7 @@ where
// Initialize verifying key
let now = Instant::now();
trace!("preparing VK");
let vk = keygen_vk_custom(params, &empty_circuit, compress_selectors)?;
let vk = keygen_vk_custom(params, &empty_circuit, !disable_selector_compression)?;
let elapsed = now.elapsed();
info!("VK took {}.{}", elapsed.as_secs(), elapsed.subsec_millis());

View File

@@ -484,7 +484,6 @@ fn get_srs(
srs_path,
settings_path,
logrows,
CheckMode::SAFE,
))
.map_err(|e| {
let err_str = format!("Failed to get srs: {}", e);
@@ -622,7 +621,7 @@ fn mock_aggregate(
pk_path=PathBuf::from(DEFAULT_PK),
srs_path=None,
witness_path = None,
compress_selectors=DEFAULT_COMPRESS_SELECTORS.parse().unwrap(),
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
))]
fn setup(
model: PathBuf,
@@ -630,7 +629,7 @@ fn setup(
pk_path: PathBuf,
srs_path: Option<PathBuf>,
witness_path: Option<PathBuf>,
compress_selectors: bool,
disable_selector_compression: bool,
) -> Result<bool, PyErr> {
crate::execute::setup(
model,
@@ -638,7 +637,7 @@ fn setup(
vk_path,
pk_path,
witness_path,
compress_selectors,
disable_selector_compression,
)
.map_err(|e| {
let err_str = format!("Failed to run setup: {}", e);
@@ -719,7 +718,7 @@ fn verify(
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
split_proofs = false,
srs_path = None,
compress_selectors=DEFAULT_COMPRESS_SELECTORS.parse().unwrap(),
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
))]
fn setup_aggregate(
sample_snarks: Vec<PathBuf>,
@@ -728,7 +727,7 @@ fn setup_aggregate(
logrows: u32,
split_proofs: bool,
srs_path: Option<PathBuf>,
compress_selectors: bool,
disable_selector_compression: bool,
) -> Result<bool, PyErr> {
crate::execute::setup_aggregate(
sample_snarks,
@@ -737,7 +736,7 @@ fn setup_aggregate(
srs_path,
logrows,
split_proofs,
compress_selectors,
disable_selector_compression,
)
.map_err(|e| {
let err_str = format!("Failed to setup aggregate: {}", e);
@@ -1104,6 +1103,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyG1Affine>()?;
m.add_class::<PyG1>()?;
m.add_class::<PyTestDataSource>()?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_float, m)?)?;

View File

@@ -673,6 +673,68 @@ impl<T: Clone + TensorType> Tensor<T> {
Tensor::new(Some(&res), &dims)
}
/// Set a slice of the Tensor.
/// ```
/// use ezkl::tensor::Tensor;
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
/// let b = Tensor::<i32>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
/// a.set_slice(&[1..2], &Tensor::<i32>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
/// assert_eq!(a, b);
/// ```
pub fn set_slice(
&mut self,
indices: &[Range<usize>],
value: &Tensor<T>,
) -> Result<(), TensorError>
where
T: Send + Sync,
{
if indices.is_empty() {
return Ok(());
}
if self.dims.len() < indices.len() {
return Err(TensorError::DimError(format!(
"The dimensionality of the slice {:?} is greater than the tensor's {:?}",
indices, self.dims
)));
}
// if indices weren't specified we fill them in as required
let mut full_indices = indices.to_vec();
let omitted_dims = (indices.len()..self.dims.len())
.map(|i| self.dims[i])
.collect::<Vec<_>>();
for dim in &omitted_dims {
full_indices.push(0..*dim);
}
let full_dims = full_indices
.iter()
.map(|x| x.end - x.start)
.collect::<Vec<_>>();
// now broadcast the value to the full dims
let value = value.expand(&full_dims)?;
let cartesian_coord: Vec<Vec<usize>> = full_indices
.iter()
.cloned()
.multi_cartesian_product()
.collect();
let _ = cartesian_coord
.iter()
.enumerate()
.map(|(i, e)| {
self.set(e, value[i].clone());
})
.collect::<Vec<_>>();
Ok(())
}
/// Get the array index from rows / columns indices.
///
/// ```
@@ -1526,18 +1588,20 @@ pub fn get_broadcasted_shape(
let num_dims_a = shape_a.len();
let num_dims_b = shape_b.len();
// reewrite the below using match
if num_dims_a == num_dims_b {
let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
let max_dim = dim_a.max(dim_b);
broadcasted_shape.push(*max_dim);
match (num_dims_a, num_dims_b) {
(a, b) if a == b => {
let mut broadcasted_shape = Vec::with_capacity(num_dims_a);
for (dim_a, dim_b) in shape_a.iter().zip(shape_b.iter()) {
let max_dim = dim_a.max(dim_b);
broadcasted_shape.push(*max_dim);
}
Ok(broadcasted_shape)
}
Ok(broadcasted_shape)
} else if num_dims_a < num_dims_b {
Ok(shape_b.to_vec())
} else {
Ok(shape_a.to_vec())
(a, b) if a < b => Ok(shape_b.to_vec()),
(a, b) if a > b => Ok(shape_a.to_vec()),
_ => Err(Box::new(TensorError::DimError(
"Unknown condition for broadcasting".to_string(),
))),
}
}
////////////////////////

View File

@@ -2,7 +2,7 @@ use super::TensorError;
use crate::tensor::{Tensor, TensorType};
use itertools::Itertools;
use maybe_rayon::{
iter::IndexedParallelIterator, iter::IntoParallelRefMutIterator, iter::ParallelIterator,
iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
prelude::IntoParallelRefIterator,
};
use std::collections::{HashMap, HashSet};
@@ -1328,6 +1328,316 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
Ok(output)
}
/// Gather ND.
/// # Arguments
/// * `input` - Tensor
/// * `index` - Tensor of indices to gather
/// * `batch_dims` - Number of batch dimensions
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::gather_nd;
/// let x = Tensor::<i128>::new(
/// Some(&[0, 1, 2, 3]),
/// &[2, 2],
/// ).unwrap();
/// let index = Tensor::<usize>::new(
/// Some(&[0, 0, 1, 1]),
/// &[2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 3]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[1, 0]),
/// &[2, 1],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[0, 1, 2, 3, 4, 5, 6, 7]),
/// &[2, 2, 2],
/// ).unwrap();
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 1, 0]),
/// &[2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 1, 0]),
/// &[2, 1, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 1, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[1, 0]),
/// &[2, 1],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 1).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1]),
/// &[2, 2, 3],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 0, 1, 1, 1, 0]),
/// &[2, 2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1, 6, 7, 4, 5]), &[2, 2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 7]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
pub fn gather_nd<T: TensorType + Send + Sync>(
input: &Tensor<T>,
index: &Tensor<usize>,
batch_dims: usize,
) -> Result<Tensor<T>, TensorError> {
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
let last_value = index_dims
.last()
.ok_or(TensorError::DimMismatch("gather_nd".to_string()))?;
if last_value > &(input_dims.len() - batch_dims) {
return Err(TensorError::DimMismatch("gather_nd".to_string()));
}
let output_size =
// If indices_shape[-1] == r-b, since the rank of indices is q,
// indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b,
// where N is an integer equals to the product of 1 and all the elements in the batch dimensions of the indices_shape.
// Let us think of each such r-b ranked tensor as indices_slice.
// Each scalar value corresponding to data[0:b-1,indices_slice] is filled into
// the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
// if indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensor containing 1-D tensors of dimension < r-b.
// Let us think of each such tensors as indices_slice.
// Each tensor slice corresponding to data[0:b-1, indices_slice , :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
{
let output_rank = input_dims.len() + index_dims.len() - 1 - batch_dims - last_value;
let mut dims = index_dims[..index_dims.len() - 1].to_vec();
let input_offset = batch_dims + last_value;
dims.extend(input_dims[input_offset..input_dims.len()].to_vec());
assert_eq!(output_rank, dims.len());
dims
};
// cartesian coord over batch dims
let mut batch_cartesian_coord = input_dims[0..batch_dims]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
if batch_cartesian_coord.is_empty() {
batch_cartesian_coord.push(vec![]);
}
let outputs = batch_cartesian_coord
.par_iter()
.map(|batch_coord| {
let batch_slice = batch_coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let mut index_slice = index.get_slice(&batch_slice)?;
index_slice.reshape(&index.dims()[batch_dims..])?;
let mut input_slice = input.get_slice(&batch_slice)?;
input_slice.reshape(&input.dims()[batch_dims..])?;
let mut inner_cartesian_coord = index_slice.dims()[0..index_slice.dims().len() - 1]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
if inner_cartesian_coord.is_empty() {
inner_cartesian_coord.push(vec![]);
}
let output = inner_cartesian_coord
.iter()
.map(|coord| {
let slice = coord
.iter()
.map(|x| *x..*x + 1)
.chain(batch_coord.iter().map(|x| *x..*x + 1))
.collect::<Vec<_>>();
let index_slice = index_slice
.get_slice(&slice)
.unwrap()
.iter()
.map(|x| *x..*x + 1)
.collect::<Vec<_>>();
input_slice.get_slice(&index_slice).unwrap()
})
.collect::<Tensor<_>>();
output.combine()
})
.collect::<Result<Vec<_>, _>>()?;
let mut outputs = outputs.into_iter().flatten().collect::<Tensor<_>>();
outputs.reshape(&output_size)?;
Ok(outputs)
}
/// Scatter ND.
/// This operator is the inverse of GatherND.
/// # Arguments
/// * `input` - Tensor
/// * `index` - Tensor of indices to scatter
/// * `src` - Tensor of src
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::scatter_nd;
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[8],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[4, 3, 1, 7]),
/// &[4, 1],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9, 10, 11, 12]),
/// &[4],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 11, 3, 10, 9, 6, 7, 12]), &[8]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
/// &[4, 4, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 2]),
/// &[2, 1],
/// ).unwrap();
///
/// let src = Tensor::<i128>::new(
/// Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
/// 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
/// ]),
/// &[2, 4, 4],
/// ).unwrap();
///
/// let result = scatter_nd(&x, &index, &src).unwrap();
///
/// let expected = Tensor::<i128>::new(
/// Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
/// 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
/// &[4, 4, 4],
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[2, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1]),
/// &[2, 1],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9, 10]),
/// &[2],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[9, 9, 9, 9, 10, 10, 10, 10]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[2, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1]),
/// &[1, 1, 2],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9]),
/// &[1, 1],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 9, 3, 4, 5, 6, 7, 8]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
/// ````
///
pub fn scatter_nd<T: TensorType + Send + Sync>(
input: &Tensor<T>,
index: &Tensor<usize>,
src: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
let last_value = index_dims
.last()
.ok_or(TensorError::DimMismatch("scatter_nd".to_string()))?;
if last_value > &input_dims.len() {
return Err(TensorError::DimMismatch("scatter_nd".to_string()));
}
let mut output = input.clone();
let cartesian_coord = index_dims[0..index_dims.len() - 1]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
cartesian_coord
.iter()
.map(|coord| {
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let index_val = index.get_slice(&slice)?;
let index_slice = index_val.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let src_val = src.get_slice(&slice)?;
output.set_slice(&index_slice, &src_val)?;
Ok(())
})
.collect::<Result<Vec<_>, _>>()?;
Ok(output)
}
fn axes_op<T: TensorType + Send + Sync>(
a: &Tensor<T>,
axes: &[usize],

View File

@@ -4,6 +4,37 @@ use super::{
};
use halo2_proofs::{arithmetic::Field, plonk::Instance};
pub(crate) fn create_constant_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
>(
val: F,
len: usize,
) -> ValTensor<F> {
let mut constant = Tensor::from(vec![ValType::Constant(val); len].into_iter());
constant.set_visibility(&crate::graph::Visibility::Fixed);
ValTensor::from(constant)
}
pub(crate) fn create_unit_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
>(
len: usize,
) -> ValTensor<F> {
let mut unit = Tensor::from(vec![ValType::Constant(F::ONE); len].into_iter());
unit.set_visibility(&crate::graph::Visibility::Fixed);
ValTensor::from(unit)
}
pub(crate) fn create_zero_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
>(
len: usize,
) -> ValTensor<F> {
let mut zero = Tensor::from(vec![ValType::Constant(F::ZERO); len].into_iter());
zero.set_visibility(&crate::graph::Visibility::Fixed);
ValTensor::from(zero)
}
#[derive(Debug, Clone)]
/// A [ValType] is a wrapper around Halo2 value(s).
pub enum ValType<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd> {
@@ -318,6 +349,19 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
matches!(self, ValTensor::Instance { .. })
}
/// reverse order of elements whilst preserving the shape
pub fn reverse(&mut self) -> Result<(), Box<dyn Error>> {
match self {
ValTensor::Value { inner: v, .. } => {
v.reverse();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
}
};
Ok(())
}
///
pub fn set_initial_instance_offset(&mut self, offset: usize) {
if let ValTensor::Instance { initial_offset, .. } = self {
@@ -450,7 +494,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
_ => return Err(Box::new(TensorError::WrongMethod)),
};
Ok(integer_evals.into_iter().into())
let mut tensor: Tensor<i128> = integer_evals.into_iter().into();
match tensor.reshape(self.dims()) {
_ => {}
};
Ok(tensor)
}
/// Calls `pad_to_zero_rem` on the inner tensor.

View File

@@ -78,6 +78,17 @@ pub fn feltToBigEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String,
Ok(format!("{:?}", felt))
}
/// Converts a felt to a little endian string
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn feltToLittleEndian(array: wasm_bindgen::Clamped<Vec<u8>>) -> Result<String, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
let repr = serde_json::to_string(&felt).unwrap();
let b: String = serde_json::from_str(&repr).unwrap();
Ok(b)
}
/// Converts a hex string to a byte array
#[wasm_bindgen]
#[allow(non_snake_case)]

View File

@@ -193,7 +193,7 @@ mod native_tests {
"1l_tiny_div",
];
const TESTS: [&str; 77] = [
const TESTS: [&str; 79] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
@@ -275,6 +275,8 @@ mod native_tests {
"ltsf",
"remainder", //75
"bitshift",
"gather_nd",
"scatter_nd",
];
const WASM_TESTS: [&str; 46] = [
@@ -502,7 +504,7 @@ mod native_tests {
}
});
seq!(N in 0..=76 {
seq!(N in 0..=78 {
#(#[test_case(TESTS[N])])*
#[ignore]
@@ -589,13 +591,16 @@ mod native_tests {
#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
test_dir.close().unwrap();
// currently variable output rank is not supported in ONNX
if test != "gather_nd" {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
test_dir.close().unwrap();
}
}
#(#[test_case(TESTS[N])])*
@@ -853,7 +858,7 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, true, "single");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testWasm");
run_js_tests(path, test.to_string(), "testWasm", false);
// test_dir.close().unwrap();
}
@@ -866,7 +871,7 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, None, true, "single");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testWasm");
run_js_tests(path, test.to_string(), "testWasm", false);
test_dir.close().unwrap();
}
@@ -914,6 +919,7 @@ mod native_tests {
use crate::native_tests::kzg_fuzz;
use tempdir::TempDir;
use crate::native_tests::Hardfork;
use crate::native_tests::run_js_tests;
/// Currently only on chain inputs that return a non-negative value are supported.
const TESTS_ON_CHAIN_INPUT: [&str; 17] = [
@@ -1008,8 +1014,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "public");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
test_dir.close().unwrap();
}
@@ -1021,8 +1027,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
kzg_evm_prove_and_verify_render_seperately(2, path, test.to_string(), "private", "private", "public");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", true);
test_dir.close().unwrap();
}
@@ -1035,8 +1041,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let mut _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
kzg_evm_prove_and_verify(2, path, test.to_string(), "hashed", "private", "private");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
test_dir.close().unwrap();
}
@@ -1052,8 +1058,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let mut _anvil_child = crate::native_tests::start_anvil(false, hardfork);
kzg_evm_prove_and_verify(2, path, test.to_string(), "kzgcommit", "private", "public");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
test_dir.close().unwrap();
}
@@ -1065,8 +1071,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "hashed", "public");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
test_dir.close().unwrap();
}
@@ -1078,8 +1084,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "hashed");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
test_dir.close().unwrap();
}
@@ -1091,8 +1097,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "kzgcommit", "public");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
test_dir.close().unwrap();
}
@@ -1104,8 +1110,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "kzgcommit");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
test_dir.close().unwrap();
}
@@ -1116,8 +1122,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
kzg_evm_prove_and_verify(2, path, test.to_string(), "kzgcommit", "kzgcommit", "kzgcommit");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
test_dir.close().unwrap();
}
@@ -1507,7 +1513,7 @@ mod native_tests {
if any_output_scales_smol {
// set the tolerance to 0.0
settings.run_args.tolerance = Tolerance {
val: 0.0.into(),
val: 0.0,
scale: 0.0.into(),
};
settings
@@ -1849,6 +1855,7 @@ mod native_tests {
&format!("{}/{}/key.pk", test_dir, example_name),
"--vk-path",
&format!("{}/{}/key.vk", test_dir, example_name),
"--disable-selector-compression",
])
.status()
.expect("failed to execute process");
@@ -2178,15 +2185,17 @@ mod native_tests {
}
// run js browser evm verify tests for a given example
fn run_js_tests(test_dir: &str, example_name: String, js_test: &str) {
fn run_js_tests(test_dir: &str, example_name: String, js_test: &str, vk: bool) {
let example = format!("--example={}", example_name);
let dir = format!("--dir={}", test_dir);
let mut args = vec!["run", "test", js_test, &example, &dir];
let vk_string: String;
if vk {
vk_string = format!("--vk={}", vk);
args.push(&vk_string);
};
let status = Command::new("pnpm")
.args([
"run",
"test",
js_test,
&format!("--example={}", example_name),
&format!("--dir={}", test_dir),
])
.args(&args)
.status()
.expect("failed to execute process");
assert!(status.success());

View File

@@ -9,8 +9,8 @@ mod wasm32 {
use ezkl::pfsys;
use ezkl::wasm::{
bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian,
feltToFloat, feltToInt, genPk, genVk, genWitness, inputValidation, pkValidation,
poseidonHash, proofValidation, prove, settingsValidation, srsValidation,
feltToFloat, feltToInt, feltToLittleEndian, genPk, genVk, genWitness, inputValidation,
pkValidation, poseidonHash, proofValidation, prove, settingsValidation, srsValidation,
u8_array_to_u128_le, verify, vkValidation, witnessValidation,
};
use halo2_solidity_verifier::encode_calldata;
@@ -89,9 +89,16 @@ mod wasm32 {
.unwrap();
assert_eq!(integer, i as i128);
let hex_string = format!("{:?}", field_element);
let returned_string: String = feltToBigEndian(clamped).map_err(|_| "failed").unwrap();
let hex_string = format!("{:?}", field_element.clone());
let returned_string: String = feltToBigEndian(clamped.clone())
.map_err(|_| "failed")
.unwrap();
assert_eq!(hex_string, returned_string);
let repr = serde_json::to_string(&field_element).unwrap();
let little_endian_string: String = serde_json::from_str(&repr).unwrap();
let returned_string: String =
feltToLittleEndian(clamped).map_err(|_| "failed").unwrap();
assert_eq!(little_endian_string, returned_string);
}
}

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@@ -27,6 +27,10 @@
"check_mode": "UNSAFE"
},
"num_rows": 16,
"total_dynamic_col_size": 0,
"num_dynamic_lookups": 0,
"num_shuffles": 0,
"total_shuffle_col_size": 0,
"total_assignments": 32,
"total_const_size": 8,
"model_instance_shapes": [

View File

@@ -1,28 +1,42 @@
import localEVMVerify, { Hardfork } from '@ezkljs/verify'
import localEVMVerify from '../../in-browser-evm-verifier/src/index'
import { serialize, deserialize } from '@ezkljs/engine/nodejs'
import { compileContracts } from './utils'
import * as fs from 'fs'
exports.USER_NAME = require("minimist")(process.argv.slice(2))["example"];
exports.EXAMPLE = require("minimist")(process.argv.slice(2))["example"];
exports.PATH = require("minimist")(process.argv.slice(2))["dir"];
exports.VK = require("minimist")(process.argv.slice(2))["vk"];
describe('localEVMVerify', () => {
let bytecode: string
let bytecode_verifier: string
let bytecode_vk: string | undefined = undefined
let proof: any
const example = exports.USER_NAME || "1l_mlp"
const example = exports.EXAMPLE || "1l_mlp"
const path = exports.PATH || "../ezkl/examples/onnx"
const vk = exports.VK || false
beforeEach(() => {
let solcOutput = compileContracts(path, example)
const solcOutput = compileContracts(path, example, 'kzg')
bytecode =
bytecode_verifier =
solcOutput.contracts['artifacts/Verifier.sol']['Halo2Verifier'].evm.bytecode
.object
console.log('size', bytecode.length)
if (vk) {
const solcOutput_vk = compileContracts(path, example, 'vk')
bytecode_vk =
solcOutput_vk.contracts['artifacts/Verifier.sol']['Halo2VerifyingKey'].evm.bytecode
.object
console.log('size of verifier bytecode', bytecode_verifier.length)
}
console.log('verifier bytecode', bytecode_verifier)
})
it('should return true when verification succeeds', async () => {
@@ -30,7 +44,9 @@ describe('localEVMVerify', () => {
proof = deserialize(proofFileBuffer)
const result = await localEVMVerify(proofFileBuffer, bytecode)
const result = await localEVMVerify(proofFileBuffer, bytecode_verifier, bytecode_vk)
console.log('result', result)
expect(result).toBe(true)
})
@@ -39,13 +55,16 @@ describe('localEVMVerify', () => {
let result: boolean = true
console.log(proof.proof)
try {
let index = Math.floor(Math.random() * (proof.proof.length - 2)) + 2
let number = (proof.proof[index] + 1) % 16
let index = Math.round((Math.random() * (proof.proof.length))) % proof.proof.length
console.log('index', index)
console.log('index', proof.proof[index])
let number = (proof.proof[index] + 1) % 256
console.log('index', index)
console.log('new number', number)
proof.proof[index] = number
console.log('index post', proof.proof[index])
const proofModified = serialize(proof)
result = await localEVMVerify(proofModified, bytecode)
result = await localEVMVerify(proofModified, bytecode_verifier, bytecode_vk)
} catch (error) {
// Check if the error thrown is the "out of gas" error.
expect(error).toEqual(

View File

@@ -38,7 +38,10 @@ describe('Generate witness, prove and verify', () => {
let pk = await readEzklArtifactsFile(path, example, 'key.pk');
let circuit_ser = await readEzklArtifactsFile(path, example, 'network.compiled');
circuit_settings_ser = await readEzklArtifactsFile(path, example, 'settings.json');
params_ser = await readEzklSrsFile(path, example);
// get the log rows from the circuit settings
const circuit_settings = deserialize(circuit_settings_ser) as any;
const logrows = circuit_settings.run_args.logrows as string;
params_ser = await readEzklSrsFile(logrows);
const startTimeProve = Date.now();
result = wasmFunctions.prove(witness, pk, circuit_ser, params_ser);
const endTimeProve = Date.now();
@@ -54,6 +57,7 @@ describe('Generate witness, prove and verify', () => {
let result
const vk = await readEzklArtifactsFile(path, example, 'key.vk');
const startTimeVerify = Date.now();
params_ser = await readEzklSrsFile("1");
result = wasmFunctions.verify(proof_ser, vk, circuit_settings_ser, params_ser);
const result_ref = wasmFunctions.verify(proof_ser_ref, vk, circuit_settings_ser, params_ser);
const endTimeVerify = Date.now();

View File

@@ -16,15 +16,7 @@ export async function readEzklArtifactsFile(path: string, example: string, filen
return new Uint8ClampedArray(buffer.buffer);
}
export async function readEzklSrsFile(path: string, example: string): Promise<Uint8ClampedArray> {
// const settingsPath = path.join(__dirname, '..', '..', 'ezkl', 'examples', 'onnx', example, 'settings.json');
const settingsPath = `${path}/${example}/settings.json`
const settingsBuffer = await fs.readFile(settingsPath, { encoding: 'utf-8' });
const settings = JSONBig.parse(settingsBuffer);
const logrows = settings.run_args.logrows;
// const filePath = path.join(__dirname, '..', '..', 'ezkl', 'examples', 'onnx', `kzg${logrows}.srs`);
// srs path is at $HOME/.ezkl/srs
export async function readEzklSrsFile(logrows: string): Promise<Uint8ClampedArray> {
const filePath = `${userHomeDir}/.ezkl/srs/kzg${logrows}.srs`
const buffer = await fs.readFile(filePath);
return new Uint8ClampedArray(buffer.buffer);
@@ -51,21 +43,21 @@ export function serialize(data: object | string): Uint8ClampedArray { // data is
return new Uint8ClampedArray(uint8Array.buffer);
}
export function getSolcInput(path: string, example: string) {
export function getSolcInput(path: string, example: string, name: string) {
return {
language: 'Solidity',
sources: {
'artifacts/Verifier.sol': {
content: fsSync.readFileSync(`${path}/${example}/kzg.sol`, 'utf-8'),
content: fsSync.readFileSync(`${path}/${example}/${name}.sol`, 'utf-8'),
},
// If more contracts were to be compiled, they should have their own entries here
},
settings: {
optimizer: {
enabled: true,
runs: 200,
runs: 1,
},
evmVersion: 'london',
evmVersion: 'shanghai',
outputSelection: {
'*': {
'*': ['abi', 'evm.bytecode'],
@@ -75,8 +67,8 @@ export function getSolcInput(path: string, example: string) {
}
}
export function compileContracts(path: string, example: string) {
const input = getSolcInput(path, example)
export function compileContracts(path: string, example: string, name: string) {
const input = getSolcInput(path, example, name)
const output = JSON.parse(solc.compile(JSON.stringify(input)))
let compilationFailed = false

Binary file not shown.

1
verifier_abi.json Normal file
View File

@@ -0,0 +1 @@
[{"type":"function","name":"verifyProof","inputs":[{"internalType":"bytes","name":"proof","type":"bytes"},{"internalType":"uint256[]","name":"instances","type":"uint256[]"}],"outputs":[{"internalType":"bool","name":"","type":"bool"}],"stateMutability":"nonpayable"}]

1
vk.abi Normal file
View File

@@ -0,0 +1 @@
[{"type":"constructor","inputs":[]}]