Compare commits

...

7 Commits

Author SHA1 Message Date
Ethan Cemer
56e2326be1 *nuke (#742) 2024-03-14 14:11:03 -05:00
Ethan Cemer
2be181db35 feat: merge @ezkljs/verify package into core repo. (#736) 2024-03-14 01:13:14 +00:00
jmjac
de9e3f2673 Add __version__ to python bindings (#739) 2024-03-13 14:22:20 +00:00
dante
a1450f8df7 feat: gather_nd/scatter_nd support (#737) 2024-03-11 22:05:40 +00:00
dante
ea535e2ecd refactor: use linear index constraints for gather and scatter (#735) 2024-03-09 18:00:21 +00:00
Alexander Camuto
f8aa91ed08 fix: windows compile 2024-03-06 11:40:44 +00:00
dante
a59e3780b2 chore: rm recip_int helper (#733) 2024-03-05 21:51:14 +00:00
43 changed files with 3303 additions and 578 deletions

View File

@@ -1,4 +1,4 @@
name: Build and Publish WASM<>JS Bindings
name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
on:
workflow_dispatch:
@@ -14,7 +14,7 @@ defaults:
run:
working-directory: .
jobs:
wasm-publish:
publish-wasm-bindings:
name: publish-wasm-bindings
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
@@ -174,3 +174,40 @@ jobs:
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
in-browser-evm-ver-publish:
name: publish-in-browser-evm-verifier-package
needs: ["publish-wasm-bindings"]
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4
- name: Update version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update @ezkljs/engine version in package.json
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
- name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
run: |
sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
- name: Set up Node.js
uses: actions/setup-node@v3
with:
node-version: "18.12.1"
registry-url: "https://registry.npmjs.org"
- name: Publish to npm
run: |
cd in-browser-evm-verifier
npm install
npm run build
npm ci
npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

View File

@@ -303,12 +303,24 @@ jobs:
with:
node-version: "18.12.1"
cache: "pnpm"
- name: Install dependencies
- name: Install dependencies for js tests and in-browser-evm-verifier package
run: |
pnpm install --no-frozen-lockfile
pnpm install --dir ./in-browser-evm-verifier --no-frozen-lockfile
env:
CI: false
NODE_ENV: development
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --release --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" in-browser-evm-verifier/nodejs/ezkl.js
- name: Build @ezkljs/verify package
run: |
cd in-browser-evm-verifier
pnpm build:commonjs
cd ..
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
@@ -364,7 +376,7 @@ jobs:
with:
node-version: "18.12.1"
cache: "pnpm"
- name: Install dependencies
- name: Install dependencies for js tests
run: |
pnpm install --no-frozen-lockfile
env:

1
.gitignore vendored
View File

@@ -45,6 +45,7 @@ var/
*.whl
*.bak
node_modules
/dist
timingData.json
!tests/wasm/pk.key
!tests/wasm/vk.key

87
Cargo.lock generated
View File

@@ -843,7 +843,7 @@ dependencies = [
"anstyle",
"bitflags 1.3.2",
"clap_lex",
"strsim 0.10.0",
"strsim",
]
[[package]]
@@ -1191,41 +1191,6 @@ dependencies = [
"cuda-config",
]
[[package]]
name = "darling"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d706e75d87e35569db781a9b5e2416cff1236a47ed380831f959382ccd5f858"
dependencies = [
"darling_core",
"darling_macro",
]
[[package]]
name = "darling_core"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0c960ae2da4de88a91b2d920c2a7233b400bc33cb28453a2987822d8392519b"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim 0.9.3",
"syn 1.0.109",
]
[[package]]
name = "darling_macro"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b5a2f4ac4969822c62224815d069952656cadc7084fdca9751e6d959189b72"
dependencies = [
"darling_core",
"quote",
"syn 1.0.109",
]
[[package]]
name = "der"
version = "0.7.6"
@@ -1258,31 +1223,6 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "derive_builder"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2658621297f2cf68762a6f7dc0bb7e1ff2cfd6583daef8ee0fed6f7ec468ec0"
dependencies = [
"darling",
"derive_builder_core",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "derive_builder_core"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2791ea3e372c8495c0bc2033991d76b512cd799d07491fbd6890124db9458bef"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "derive_more"
version = "0.99.17"
@@ -2290,12 +2230,10 @@ dependencies = [
"icicle",
"log",
"maybe-rayon",
"plotters",
"rand_chacha",
"rand_core 0.6.4",
"rustacuda",
"sha3 0.9.1",
"tabbycat",
"tracing",
]
@@ -2624,12 +2562,6 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "idna"
version = "0.4.0"
@@ -5014,12 +4946,6 @@ dependencies = [
"unicode-normalization",
]
[[package]]
name = "strsim"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
[[package]]
name = "strsim"
version = "0.10.0"
@@ -5110,17 +5036,6 @@ dependencies = [
"libc",
]
[[package]]
name = "tabbycat"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c45590f0f859197b4545be1b17b2bc3cc7bb075f7d1cc0ea1dc6521c0bf256a3"
dependencies = [
"anyhow",
"derive_builder",
"regex",
]
[[package]]
name = "tabled"
version = "0.12.2"

View File

@@ -15,73 +15,96 @@ crate-type = ["cdylib", "rlib"]
[dependencies]
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "main" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "main" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev="9fff22c", features=["derive_serde"] }
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "main" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch = "main" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "9fff22c", features = [
"derive_serde",
] }
rand = { version = "0.8", default_features = false }
itertools = { version = "0.10.3", default_features = false }
clap = { version = "4.3.3", features = ["derive"]}
clap = { version = "4.3.3", features = ["derive"] }
serde = { version = "1.0.126", features = ["derive"], optional = true }
serde_json = { version = "1.0.97", default_features = false, features = ["float_roundtrip", "raw_value"], optional = true }
serde_json = { version = "1.0.97", default_features = false, features = [
"float_roundtrip",
"raw_value",
], optional = true }
log = { version = "0.4.17", default_features = false, optional = true }
thiserror = { version = "1.0.38", default_features = false }
hex = { version = "0.4.3", default_features = false }
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features=["derive_serde"]}
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch= "main" }
maybe-rayon = { version = "0.1.1", default_features = false }
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
"derive_serde",
] }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "main" }
maybe-rayon = { version = "0.1.1", default_features = false }
bincode = { version = "1.3.3", default_features = false }
ark-std = { version = "^0.3.0", default-features = false }
unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = "1.6.0"
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
# evm related deps
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
ethers = { version = "2.0.11", default_features = false, features = ["ethers-solc"] }
indicatif = {version = "0.17.5", features = ["rayon"]}
gag = { version = "1.0.0", default_features = false}
ethers = { version = "2.0.11", default_features = false, features = [
"ethers-solc",
] }
indicatif = { version = "0.17.5", features = ["rayon"] }
gag = { version = "1.0.0", default_features = false }
instant = { version = "0.1" }
reqwest = { version = "0.11.14", default-features = false, features = ["default-tls", "multipart", "stream"] }
reqwest = { version = "0.11.14", default-features = false, features = [
"default-tls",
"multipart",
"stream",
] }
openssl = { version = "0.10.55", features = ["vendored"] }
postgres = "0.19.5"
pg_bigdecimal = "0.1.5"
lazy_static = "1.4.0"
colored_json = { version = "3.0.1", default_features = false, optional = true}
colored_json = { version = "3.0.1", default_features = false, optional = true }
plotters = { version = "0.3.0", default_features = false, optional = true }
regex = { version = "1", default_features = false }
tokio = { version = "1.26.0", default_features = false, features = ["macros", "rt"] }
tokio = { version = "1.26.0", default_features = false, features = [
"macros",
"rt",
] }
tokio-util = { version = "0.7.9", features = ["codec"] }
pyo3 = { version = "0.20.2", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
pyo3-asyncio = { version = "0.20.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
pyo3 = { version = "0.20.2", features = [
"extension-module",
"abi3-py37",
"macros",
], default_features = false, optional = true }
pyo3-asyncio = { version = "0.20.0", features = [
"attributes",
"tokio-runtime",
], default_features = false, optional = true }
pyo3-log = { version = "0.9.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
colored = { version = "2.0.0", default_features = false, optional = true}
env_logger = { version = "0.10.0", default_features = false, optional = true}
colored = { version = "2.0.0", default_features = false, optional = true }
env_logger = { version = "0.10.0", default_features = false, optional = true }
chrono = "0.4.31"
sha256 = "1.4.0"
[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2.8", features = ["js"] }
instant = { version = "0.1", features = [ "wasm-bindgen", "inaccurate" ] }
instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] }
[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies]
wasm-bindgen-rayon = { version = "1.0", optional=true }
wasm-bindgen-rayon = { version = "1.0", optional = true }
wasm-bindgen-test = "0.3.34"
serde-wasm-bindgen = "0.4"
wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"]}
wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"] }
console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"
[dev-dependencies]
criterion = {version = "0.3", features = ["html_reports"]}
criterion = { version = "0.3", features = ["html_reports"] }
tempfile = "3.3.0"
lazy_static = "1.4.0"
mnist = "0.5"
@@ -153,11 +176,24 @@ required-features = ["ezkl"]
[features]
web = ["wasm-bindgen-rayon"]
default = ["ezkl", "mv-lookup"]
render = ["halo2_proofs/dev-graph", "plotters"]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
ezkl = ["onnx", "serde", "serde_json", "log", "colored", "env_logger", "tabled/color", "colored_json", "halo2_proofs/circuit-params"]
mv-lookup = ["halo2_proofs/mv-lookup", "snark-verifier/mv-lookup", "halo2_solidity_verifier/mv-lookup"]
ezkl = [
"onnx",
"serde",
"serde_json",
"log",
"colored",
"env_logger",
"tabled/color",
"colored_json",
"halo2_proofs/circuit-params",
]
mv-lookup = [
"halo2_proofs/mv-lookup",
"snark-verifier/mv-lookup",
"halo2_solidity_verifier/mv-lookup",
]
det-prove = []
icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
@@ -165,7 +201,7 @@ no-banner = []
# icicle patch to 0.1.0 if feature icicle is enabled
[patch.'https://github.com/ingonyama-zk/icicle']
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix"}
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
[profile.release]
rustflags = [ "-C", "relocation-model=pic" ]
rustflags = ["-C", "relocation-model=pic"]

View File

@@ -74,6 +74,10 @@ For more details visit the [docs](https://docs.ezkl.xyz).
Build the auto-generated rust documentation and open the docs in your browser locally. `cargo doc --open`
#### In-browser EVM verifier
As an alternative to running the native Halo2 verifier as a WASM binding in the browser, you can use the in-browser EVM verifier. The source code of which you can find in the `in-browser-evm-verifier` directory and a README with instructions on how to use it.
### building the project 🔨

View File

@@ -0,0 +1,48 @@
from torch import nn
import json
import numpy as np
import tf2onnx
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
# gather_nd in tf then export to onnx
x = in1 = Input((15, 18,))
w = in2 = Input((15, 1), dtype=tf.int32)
x = tf.gather_nd(x, w, batch_dims=1)
tm = Model((in1, in2), x )
tm.summary()
tm.compile(optimizer='adam', loss='mse')
shape = [1, 15, 18]
index_shape = [1, 15, 1]
# After training, export to onnx (network.onnx) and create a data file (input.json)
x = 0.1*np.random.rand(1,*shape)
# w = random int tensor
w = np.random.randint(0, 10, index_shape)
spec = tf.TensorSpec(shape, tf.float32, name='input_0')
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')
model_path = "network.onnx"
tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)
d = x.reshape([-1]).tolist()
d1 = w.reshape([-1]).tolist()
data = dict(
input_data=[d, d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import sys
import json
sys.path.append("..")
class Model(nn.Module):
"""
Just one Linear layer
"""
def __init__(self, configs):
super(Model, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
# Use this line if you want to visualize the weights
# self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
self.channels = configs.enc_in
self.individual = configs.individual
if self.individual:
self.Linear = nn.ModuleList()
for i in range(self.channels):
self.Linear.append(nn.Linear(self.seq_len,self.pred_len))
else:
self.Linear = nn.Linear(self.seq_len, self.pred_len)
def forward(self, x):
# x: [Batch, Input length, Channel]
if self.individual:
output = torch.zeros([x.size(0),self.pred_len,x.size(2)],dtype=x.dtype).to(x.device)
for i in range(self.channels):
output[:,:,i] = self.Linear[i](x[:,:,i])
x = output
else:
x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
return x # [Batch, Output length, Channel]
class Configs:
def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
self.seq_len = seq_len
self.pred_len = pred_len
self.enc_in = enc_in
self.individual = individual
model = 'Linear'
seq_len = 10
pred_len = 4
enc_in = 3
configs = Configs(seq_len, pred_len, enc_in, True)
circuit = Model(configs)
x = torch.randn(1, seq_len, pred_len)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=15, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
# the model's input names
input_names=['input'],
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}

Binary file not shown.

View File

@@ -0,0 +1,60 @@
# inbrowser-evm-verify
We would like the Solidity verifier to be canonical and usually all you ever need. For this, we need to be able to run that verifier in browser.
## How to use (Node js)
```ts
import localEVMVerify from '@ezkljs/verify';
// Load in the proof file as a buffer
const proofFileBuffer = fs.readFileSync(`${path}/${example}/proof.pf`)
// Stringified EZKL evm verifier bytecode (this is just an example don't use in production)
const bytecode = '0x608060405234801561001057600080fd5b5060d38061001f6000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c8063cfae321714610046575b600080fd5b6100496100f1565b60405161005691906100f1565b60405180910390f35b'
const result = await localEVMVerify(proofFileBuffer, bytecode)
console.log('result', result)
```
**Note**: Run `ezkl create-evm-verifier` to get the Solidity verifier, with which you can retrieve the bytecode once compiled. We recommend compiling to the Shanghai hardfork target, else you will have to pass an additional parameter specifying the EVM version to the `localEVMVerify` function like so (for Paris hardfork):
```ts
import localEVMVerify, { hardfork } from '@ezkljs/verify';
const result = await localEVMVerify(proofFileBuffer, bytecode, hardfork['Paris'])
```
**Note**: You can also verify separated vk verifiers using the `localEVMVerify` function. Just pass the vk verifier bytecode as the third parameter like so:
```ts
import localEVMVerify from '@ezkljs/verify';
const result = await localEVMVerify(proofFileBuffer, verifierBytecode, VKBytecode)
```
## How to use (Browser)
```ts
import localEVMVerify from '@ezkljs/verify';
// Load in the proof file as a buffer using the web apis (fetch, FileReader, etc)
// We use fetch in this example to load the proof file as a buffer
const proofFileBuffer = await fetch(`${path}/${example}/proof.pf`).then(res => res.arrayBuffer())
// Stringified EZKL evm verifier bytecode (this is just an example don't use in production)
const bytecode = '0x608060405234801561001057600080fd5b5060d38061001f6000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c8063cfae321714610046575b600080fd5b6100496100f1565b60405161005691906100f1565b60405180910390f35b'
const result = await browserEVMVerify(proofFileBuffer, bytecode)
console.log('result', result)
```
Output:
```ts
result: true
```

View File

@@ -0,0 +1,42 @@
{
"name": "@ezkljs/verify",
"version": "0.0.0",
"publishConfig": {
"access": "public"
},
"description": "Evm verify EZKL proofs in the browser.",
"main": "dist/commonjs/index.js",
"module": "dist/esm/index.js",
"types": "dist/commonjs/index.d.ts",
"files": [
"dist",
"LICENSE",
"README.md"
],
"scripts": {
"clean": "rm -r dist || true",
"build:commonjs": "tsc --project tsconfig.commonjs.json && resolve-tspaths -p tsconfig.commonjs.json",
"build:esm": "tsc --project tsconfig.esm.json && resolve-tspaths -p tsconfig.esm.json",
"build": "pnpm run clean && pnpm run build:commonjs && pnpm run build:esm"
},
"dependencies": {
"@ethereumjs/common": "^4.0.0",
"@ethereumjs/evm": "^2.0.0",
"@ethereumjs/statemanager": "^2.0.0",
"@ethereumjs/tx": "^5.0.0",
"@ethereumjs/util": "^9.0.0",
"@ethereumjs/vm": "^7.0.0",
"@ethersproject/abi": "^5.7.0",
"@ezkljs/engine": "^9.4.4",
"ethers": "^6.7.1",
"json-bigint": "^1.0.0"
},
"devDependencies": {
"@types/node": "^20.8.3",
"ts-loader": "^9.5.0",
"ts-node": "^10.9.1",
"resolve-tspaths": "^0.8.16",
"tsconfig-paths": "^4.2.0",
"typescript": "^5.2.2"
}
}

1479
in-browser-evm-verifier/pnpm-lock.yaml generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,145 @@
import { defaultAbiCoder as AbiCoder } from '@ethersproject/abi'
import { Address, hexToBytes } from '@ethereumjs/util'
import { Chain, Common, Hardfork } from '@ethereumjs/common'
import { LegacyTransaction, LegacyTxData } from '@ethereumjs/tx'
// import { DefaultStateManager } from '@ethereumjs/statemanager'
// import { Blockchain } from '@ethereumjs/blockchain'
import { VM } from '@ethereumjs/vm'
import { EVM } from '@ethereumjs/evm'
import { buildTransaction, encodeDeployment } from './utils/tx-builder'
import { getAccountNonce, insertAccount } from './utils/account-utils'
import { encodeVerifierCalldata } from '../nodejs/ezkl';
import { error } from 'console'
async function deployContract(
vm: VM,
common: Common,
senderPrivateKey: Uint8Array,
deploymentBytecode: string
): Promise<Address> {
// Contracts are deployed by sending their deployment bytecode to the address 0
// The contract params should be abi-encoded and appended to the deployment bytecode.
// const data =
const data = encodeDeployment(deploymentBytecode)
const txData = {
data,
nonce: await getAccountNonce(vm, senderPrivateKey),
}
const tx = LegacyTransaction.fromTxData(
buildTransaction(txData) as LegacyTxData,
{ common, allowUnlimitedInitCodeSize: true },
).sign(senderPrivateKey)
const deploymentResult = await vm.runTx({
tx,
skipBlockGasLimitValidation: true,
skipNonce: true
})
if (deploymentResult.execResult.exceptionError) {
throw deploymentResult.execResult.exceptionError
}
return deploymentResult.createdAddress!
}
async function verify(
vm: VM,
contractAddress: Address,
caller: Address,
proof: Uint8Array | Uint8ClampedArray,
vkAddress?: Address | Uint8Array,
): Promise<boolean> {
if (proof instanceof Uint8Array) {
proof = new Uint8ClampedArray(proof.buffer)
}
if (vkAddress) {
const vkAddressBytes = hexToBytes(vkAddress.toString())
const vkAddressArray = Array.from(vkAddressBytes)
let string = JSON.stringify(vkAddressArray)
const uint8Array = new TextEncoder().encode(string);
// Step 3: Convert to Uint8ClampedArray
vkAddress = new Uint8Array(uint8Array.buffer);
// convert uitn8array of length
error('vkAddress', vkAddress)
}
const data = encodeVerifierCalldata(proof, vkAddress)
const verifyResult = await vm.evm.runCall({
to: contractAddress,
caller: caller,
origin: caller, // The tx.origin is also the caller here
data: data,
})
if (verifyResult.execResult.exceptionError) {
throw verifyResult.execResult.exceptionError
}
const results = AbiCoder.decode(['bool'], verifyResult.execResult.returnValue)
return results[0]
}
/**
* Spins up an ephemeral EVM instance for executing the bytecode of a solidity verifier
* @param proof Json serialized proof file
* @param bytecode The bytecode of a compiled solidity verifier.
* @param bytecode_vk The bytecode of a contract that stores the vk. (Optional, only required if the vk is stored in a separate contract)
* @param evmVersion The evm version to use for the verification. (Default: London)
* @returns The result of the evm verification.
* @throws If the verify transaction reverts
*/
export default async function localEVMVerify(
proof: Uint8Array | Uint8ClampedArray,
bytecode_verifier: string,
bytecode_vk?: string,
evmVersion?: Hardfork,
): Promise<boolean> {
try {
const hardfork = evmVersion ? evmVersion : Hardfork['Shanghai']
const common = new Common({ chain: Chain.Mainnet, hardfork })
const accountPk = hexToBytes(
'0xe331b6d69882b4cb4ea581d88e0b604039a3de5967688d3dcffdd2270c0fd109', // anvil deterministic Pk
)
const evm = new EVM({
allowUnlimitedContractSize: true,
allowUnlimitedInitCodeSize: true,
})
const vm = await VM.create({ common, evm })
const accountAddress = Address.fromPrivateKey(accountPk)
await insertAccount(vm, accountAddress)
const verifierAddress = await deployContract(
vm,
common,
accountPk,
bytecode_verifier
)
if (bytecode_vk) {
const accountPk = hexToBytes("0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80"); // anvil deterministic Pk
const accountAddress = Address.fromPrivateKey(accountPk)
await insertAccount(vm, accountAddress)
const output = await deployContract(vm, common, accountPk, bytecode_vk)
const result = await verify(vm, verifierAddress, accountAddress, proof, output)
return true
}
const result = await verify(vm, verifierAddress, accountAddress, proof)
return result
} catch (error) {
// log or re-throw the error, depending on your needs
console.error('An error occurred:', error)
throw error
}
}

View File

@@ -0,0 +1,32 @@
import { VM } from '@ethereumjs/vm'
import { Account, Address } from '@ethereumjs/util'
export const keyPair = {
secretKey:
'0x3cd7232cd6f3fc66a57a6bedc1a8ed6c228fff0a327e169c2bcc5e869ed49511',
publicKey:
'0x0406cc661590d48ee972944b35ad13ff03c7876eae3fd191e8a2f77311b0a3c6613407b5005e63d7d8d76b89d5f900cde691497688bb281e07a5052ff61edebdc0',
}
export const insertAccount = async (vm: VM, address: Address) => {
const acctData = {
nonce: 0,
balance: BigInt('1000000000000000000'), // 1 eth
}
const account = Account.fromAccountData(acctData)
await vm.stateManager.putAccount(address, account)
}
export const getAccountNonce = async (
vm: VM,
accountPrivateKey: Uint8Array,
) => {
const address = Address.fromPrivateKey(accountPrivateKey)
const account = await vm.stateManager.getAccount(address)
if (account) {
return account.nonce
} else {
return BigInt(0)
}
}

View File

@@ -0,0 +1,59 @@
import { Interface, defaultAbiCoder as AbiCoder } from '@ethersproject/abi'
import {
AccessListEIP2930TxData,
FeeMarketEIP1559TxData,
TxData,
} from '@ethereumjs/tx'
type TransactionsData =
| TxData
| AccessListEIP2930TxData
| FeeMarketEIP1559TxData
export const encodeFunction = (
method: string,
params?: {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
types: any[]
values: unknown[]
},
): string => {
const parameters = params?.types ?? []
const methodWithParameters = `function ${method}(${parameters.join(',')})`
const signatureHash = new Interface([methodWithParameters]).getSighash(method)
const encodedArgs = AbiCoder.encode(parameters, params?.values ?? [])
return signatureHash + encodedArgs.slice(2)
}
export const encodeDeployment = (
bytecode: string,
params?: {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
types: any[]
values: unknown[]
},
) => {
const deploymentData = '0x' + bytecode
if (params) {
const argumentsEncoded = AbiCoder.encode(params.types, params.values)
return deploymentData + argumentsEncoded.slice(2)
}
return deploymentData
}
export const buildTransaction = (
data: Partial<TransactionsData>,
): TransactionsData => {
const defaultData: Partial<TransactionsData> = {
gasLimit: 3_000_000_000_000_000,
gasPrice: 7,
value: 0,
data: '0x',
}
return {
...defaultData,
...data,
}
}

View File

@@ -0,0 +1,7 @@
{
"extends": "./tsconfig.json",
"compilerOptions": {
"module": "CommonJS",
"outDir": "./dist/commonjs"
}
}

View File

@@ -0,0 +1,7 @@
{
"extends": "./tsconfig.json",
"compilerOptions": {
"module": "ES2020",
"outDir": "./dist/esm"
}
}

View File

@@ -0,0 +1,62 @@
{
"compilerOptions": {
"rootDir": "src",
"target": "es2017",
"outDir": "dist",
"declaration": true,
"lib": [
"dom",
"dom.iterable",
"esnext"
],
"allowJs": true,
"checkJs": true,
"skipLibCheck": true,
"strict": true,
"forceConsistentCasingInFileNames": true,
"noEmit": false,
"esModuleInterop": true,
"module": "CommonJS",
"moduleResolution": "node",
"resolveJsonModule": true,
"isolatedModules": true,
"jsx": "preserve",
// "incremental": true,
"noUncheckedIndexedAccess": true,
"baseUrl": ".",
"paths": {
"@/*": [
"./src/*"
]
}
},
"include": [
"src/**/*.ts",
"src/**/*.tsx",
"src/**/*.cjs",
"src/**/*.mjs"
],
"exclude": [
"node_modules"
],
// NEW: Options for file/directory watching
"watchOptions": {
// Use native file system events for files and directories
"watchFile": "useFsEvents",
"watchDirectory": "useFsEvents",
// Poll files for updates more frequently
// when they're updated a lot.
"fallbackPolling": "dynamicPriority",
// Don't coalesce watch notification
"synchronousWatchDirectory": true,
// Finally, two additional settings for reducing the amount of possible
// files to track work from these directories
"excludeDirectories": [
"**/node_modules",
"_build"
],
"excludeFiles": [
"build/fileWhichChangesOften.ts"
]
}
}

View File

@@ -7,7 +7,7 @@
"test": "jest"
},
"devDependencies": {
"@ezkljs/engine": "^2.4.5",
"@ezkljs/engine": "^9.4.4",
"@ezkljs/verify": "^0.0.6",
"@jest/types": "^29.6.3",
"@types/file-saver": "^2.0.5",
@@ -27,4 +27,4 @@
"tsconfig-paths": "^4.2.0",
"typescript": "5.1.6"
}
}
}

11
pnpm-lock.yaml generated
View File

@@ -6,8 +6,8 @@ settings:
devDependencies:
'@ezkljs/engine':
specifier: ^2.4.5
version: 2.4.5
specifier: ^9.4.4
version: 9.4.4
'@ezkljs/verify':
specifier: ^0.0.6
version: 0.0.6(buffer@6.0.3)
@@ -785,6 +785,13 @@ packages:
json-bigint: 1.0.0
dev: true
/@ezkljs/engine@9.4.4:
resolution: {integrity: sha512-kNsTmDQa8mIiQ6yjJmBMwVgAAxh4nfs4NCtnewJifonyA8Mfhs+teXwwW8WhERRDoQPUofKO2pT8BPvV/XGIDA==}
dependencies:
'@types/json-bigint': 1.0.1
json-bigint: 1.0.0
dev: true
/@ezkljs/verify@0.0.6(buffer@6.0.3):
resolution: {integrity: sha512-9DHoEhLKl1DBGuUVseXLThuMyYceY08Zymr/OsLH0zbdA9OoISYhb77j4QPm4ANRKEm5dCi8oHDqkwGbFc2xFQ==}
dependencies:

View File

@@ -17,7 +17,6 @@ pub enum BaseOp {
Sub,
SumInit,
Sum,
IsZero,
IsBoolean,
}
@@ -35,7 +34,6 @@ impl BaseOp {
BaseOp::Add => a + b,
BaseOp::Sub => a - b,
BaseOp::Mult => a * b,
BaseOp::IsZero => b,
BaseOp::IsBoolean => b,
_ => panic!("nonaccum_f called on accumulating operation"),
}
@@ -76,7 +74,6 @@ impl BaseOp {
BaseOp::Mult => "MULT",
BaseOp::Sum => "SUM",
BaseOp::SumInit => "SUMINIT",
BaseOp::IsZero => "ISZERO",
BaseOp::IsBoolean => "ISBOOLEAN",
}
}
@@ -93,7 +90,6 @@ impl BaseOp {
BaseOp::Mult => (0, 1),
BaseOp::Sum => (-1, 2),
BaseOp::SumInit => (0, 1),
BaseOp::IsZero => (0, 1),
BaseOp::IsBoolean => (0, 1),
}
}
@@ -110,7 +106,6 @@ impl BaseOp {
BaseOp::Mult => 2,
BaseOp::Sum => 1,
BaseOp::SumInit => 1,
BaseOp::IsZero => 0,
BaseOp::IsBoolean => 0,
}
}
@@ -127,7 +122,6 @@ impl BaseOp {
BaseOp::SumInit => 0,
BaseOp::CumProd => 1,
BaseOp::CumProdInit => 0,
BaseOp::IsZero => 0,
BaseOp::IsBoolean => 0,
}
}

View File

@@ -387,7 +387,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
nonaccum_selectors.insert((BaseOp::Add, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Sub, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::Mult, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::IsZero, i, j), meta.selector());
nonaccum_selectors.insert((BaseOp::IsBoolean, i, j), meta.selector());
}
}
@@ -432,12 +431,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
}
BaseOp::IsZero => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("non accum: output query failed");
vec![expected_output[base_op.constraint_idx()].clone()]
}
_ => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)

View File

@@ -132,41 +132,6 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd>(
Ok(claimed_output)
}
fn recip_int<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
input: &[ValTensor<F>; 1],
) -> Result<ValTensor<F>, Box<dyn Error>> {
// assert is boolean
let zero_inverse_val = tensor::ops::nonlinearities::zero_recip(1.0)[0];
// get values where input is 0
let zero_mask = equals_zero(config, region, input)?;
let zero_mask_minus_one = pairwise(
config,
region,
&[zero_mask.clone(), create_unit_tensor(1)],
BaseOp::Sub,
)?;
let zero_inverse_val = pairwise(
config,
region,
&[
zero_mask,
create_constant_tensor(i128_to_felt(zero_inverse_val), 1),
],
BaseOp::Mult,
)?;
pairwise(
config,
region,
&[zero_mask_minus_one, zero_inverse_val],
BaseOp::Add,
)
}
/// recip accumulated layout
pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -175,10 +140,6 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
input_scale: F,
output_scale: F,
) -> Result<ValTensor<F>, Box<dyn Error>> {
if output_scale == F::ONE || output_scale == F::ZERO {
return recip_int(config, region, value);
}
let input = value[0].clone();
let input_dims = input.dims();
@@ -188,8 +149,11 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
// range_check_bracket is min of input_scale * output_scale and 2^F::S - 3
let range_check_len = std::cmp::min(integer_output_scale, 2_i128.pow(F::S - 4));
let input_scale_ratio =
i128_to_felt(integer_input_scale * integer_output_scale / range_check_len);
let input_scale_ratio = if range_check_len > 0 {
i128_to_felt(integer_input_scale * integer_output_scale / range_check_len)
} else {
F::ONE
};
let range_check_bracket = range_check_len / 2;
@@ -234,11 +198,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
let equal_zero_mask = equals_zero(config, region, &[input.clone()])?;
let equal_inverse_mask = equals(
config,
region,
&[claimed_output.clone(), zero_inverse],
)?;
let equal_inverse_mask = equals(config, region, &[claimed_output.clone(), zero_inverse])?;
// assert the two masks are equal
enforce_equality(
@@ -249,12 +209,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
let unit_scale = create_constant_tensor(i128_to_felt(range_check_len), 1);
let unit_mask = pairwise(
config,
region,
&[equal_zero_mask, unit_scale],
BaseOp::Mult,
)?;
let unit_mask = pairwise(config, region, &[equal_zero_mask, unit_scale], BaseOp::Mult)?;
// now add the unit mask to the rebased_div
let rebased_offset_div = pairwise(config, region, &[rebased_div, unit_mask], BaseOp::Add)?;
@@ -691,14 +646,13 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
dim_indices: ValTensor<F>,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (mut input, index) = (values[0].clone(), values[1].clone());
input.flatten();
if !(dim_indices.all_prev_assigned() || region.is_dummy()) {
return Err("dim_indices must be assigned".into());
}
// these will be assigned as constants
let dim_indices: ValTensor<F> =
Tensor::from((0..input.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()?;
@@ -974,91 +928,25 @@ pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd>(
values: &[ValTensor<F>; 2],
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (mut input, mut index_clone) = (values[0].clone(), values[1].clone());
let (input, mut index_clone) = (values[0].clone(), values[1].clone());
index_clone.flatten();
if index_clone.is_singleton() {
index_clone.reshape(&[1])?;
}
let mut assigned_len = vec![];
if !input.all_prev_assigned() {
input = region.assign(&config.custom_gates.inputs[0], &input)?;
assigned_len.push(input.len());
}
if !index_clone.all_prev_assigned() {
index_clone = region.assign(&config.custom_gates.inputs[1], &index_clone)?;
assigned_len.push(index_clone.len());
}
if !assigned_len.is_empty() {
// safe to unwrap since we've just checked it has at least one element
region.increment(*assigned_len.iter().max().unwrap());
}
// Calculate the output tensor size
let input_dims = input.dims();
let mut output_size = input_dims.to_vec();
output_size[dim] = index_clone.dims()[0];
// these will be assigned as constants
let mut indices = Tensor::from((0..input.dims()[dim] as u64).map(|x| F::from(x)));
indices.set_visibility(&crate::graph::Visibility::Fixed);
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
region.increment(indices.len());
let linear_index =
linearize_element_index(config, region, &[index_clone], input_dims, dim, true)?;
let mut iteration_dims = output_size.clone();
iteration_dims[dim] = 1;
let mut output = select(config, region, &[input, linear_index])?;
// Allocate memory for the output tensor
let cartesian_coord = iteration_dims
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let mut results = HashMap::new();
for coord in cartesian_coord {
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
slice[dim] = 0..input_dims[dim];
let mut sliced_input = input.get_slice(&slice)?;
sliced_input.flatten();
let res = select(
config,
region,
&[sliced_input, index_clone.clone()],
indices.clone(),
)?;
results.insert(coord, res);
}
// Allocate memory for the output tensor
let cartesian_coord = output_size
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let mut output = Tensor::new(None, &output_size)?.par_enum_map(|i, _: ValType<F>| {
let coord = cartesian_coord[i].clone();
let mut key = coord.clone();
key[dim] = 0;
let result = &results.get(&key).ok_or("missing result")?;
let o = result.get_inner_tensor().map_err(|_| "missing tensor")?[coord[dim]].clone();
Ok::<ValType<F>, region::RegionError>(o)
})?;
// Reshape the output tensor
if index_clone.is_singleton() {
output_size.remove(dim);
}
output.reshape(&output_size)?;
Ok(output.into())
Ok(output)
}
/// Gather accumulated layout
@@ -1067,82 +955,387 @@ pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (mut input, mut index) = (values[0].clone(), values[1].clone());
) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
let (input, index) = (values[0].clone(), values[1].clone());
assert_eq!(input.dims().len(), index.dims().len());
if !input.all_prev_assigned() {
input = region.assign(&config.custom_gates.inputs[0], &input)?;
}
if !index.all_prev_assigned() {
index = region.assign(&config.custom_gates.inputs[1], &index)?;
}
region.increment(std::cmp::max(input.len(), index.len()));
// Calculate the output tensor size
let input_dims = input.dims();
let output_size = index.dims().to_vec();
// these will be assigned as constants
let mut indices = Tensor::from((0..input_dims[dim] as u64).map(|x| F::from(x)));
indices.set_visibility(&crate::graph::Visibility::Fixed);
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
region.increment(indices.len());
let linear_index = linearize_element_index(config, region, &[index], input.dims(), dim, false)?;
let mut iteration_dims = output_size.clone();
iteration_dims[dim] = 1;
let mut output = select(config, region, &[input, linear_index.clone()])?;
output.reshape(&output_size)?;
Ok((output, linear_index))
}
/// Gather accumulated layout
pub(crate) fn gather_nd<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
batch_dims: usize,
) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
let (input, index) = (values[0].clone(), values[1].clone());
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
let last_value = index_dims
.last()
.ok_or(TensorError::DimMismatch("gather_nd".to_string()))?;
if index_dims.last() > Some(&(input_dims.len() - batch_dims)) {
return Err(TensorError::DimMismatch("gather_nd".to_string()).into());
}
let output_size =
// If indices_shape[-1] == r-b, since the rank of indices is q,
// indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b,
// where N is an integer equals to the product of 1 and all the elements in the batch dimensions of the indices_shape.
// Let us think of each such r-b ranked tensor as indices_slice.
// Each scalar value corresponding to data[0:b-1,indices_slice] is filled into
// the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
// if indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensor containing 1-D tensors of dimension < r-b.
// Let us think of each such tensors as indices_slice.
// Each tensor slice corresponding to data[0:b-1, indices_slice , :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
{
let output_rank = input_dims.len() + index_dims.len() - 1 - batch_dims - last_value;
let mut dims = index_dims[..index_dims.len() - 1].to_vec();
let input_offset = batch_dims + last_value;
dims.extend(input_dims[input_offset..input_dims.len()].to_vec());
assert_eq!(output_rank, dims.len());
dims
};
let linear_index = linearize_nd_index(config, region, &[index], input.dims(), batch_dims)?;
let mut output = select(config, region, &[input, linear_index.clone()])?;
output.reshape(&output_size)?;
Ok((output, linear_index))
}
/// Takes a tensor representing a multi-dimensional index and returns a tensor representing the linearized index.
/// The linearized index is the index of the element in the flattened tensor.
/// FOr instance if the dims is [3,5,2], the linearized index of [2] at dim 1 is 2*5 + 3 = 13
pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
dims: &[usize],
dim: usize,
is_flat_index: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let index = values[0].clone();
if !is_flat_index {
assert_eq!(index.dims().len(), dims.len());
// if the index is already flat, return it
if index.dims().len() == 1 {
return Ok(index);
}
}
let dim_multiplier: Tensor<usize> = Tensor::new(None, &[dims.len()])?;
let dim_multiplier: Tensor<F> = dim_multiplier.par_enum_map(|i, _| {
let mut res = 1;
for dim in dims.iter().skip(i + 1) {
res *= dim;
}
Ok::<_, region::RegionError>(F::from(res as u64))
})?;
let iteration_dims = if is_flat_index {
let mut dims = dims.to_vec();
dims[dim] = index.len();
dims
} else {
index.dims().to_vec()
};
// Allocate memory for the output tensor
let cartesian_coord = iteration_dims
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let mut results = HashMap::new();
let val_dim_multiplier: ValTensor<F> = dim_multiplier
.get_slice(&[dim..dim + 1])?
.map(|x| ValType::Constant(x))
.into();
for coord in cartesian_coord {
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
slice[dim] = 0..input_dims[dim];
let mut output = Tensor::new(None, &[cartesian_coord.len()])?;
let mut sliced_input = input.get_slice(&slice)?;
sliced_input.flatten();
let inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
let coord = cartesian_coord[i].clone();
let slice: Vec<Range<usize>> = if is_flat_index {
coord[dim..dim + 1].iter().map(|x| *x..*x + 1).collect()
} else {
coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>()
};
slice[dim] = 0..output_size[dim];
let mut sliced_index = index.get_slice(&slice)?;
sliced_index.flatten();
let index_val = index.get_slice(&slice)?;
let res = select(
let mut const_offset = F::ZERO;
for i in 0..dims.len() {
if i != dim {
const_offset += F::from(coord[i] as u64) * dim_multiplier[i];
}
}
let const_offset = create_constant_tensor(const_offset, 1);
let res = pairwise(
config,
region,
&[sliced_input, sliced_index],
indices.clone(),
&[index_val, val_dim_multiplier.clone()],
BaseOp::Mult,
)?;
results.insert(coord, res);
}
let res = pairwise(config, region, &[res, const_offset], BaseOp::Add)?;
// Allocate memory for the output tensor
let cartesian_coord = output_size
Ok(res.get_inner_tensor()?[0].clone())
};
region.apply_in_loop(&mut output, inner_loop_function)?;
Ok(output.into())
}
/// Takes a tensor representing a nd index and returns a tensor representing the linearized index.
/// The linearized index is the index of the element in the flattened tensor.
/// Given data tensor of rank r >= 1, indices tensor of rank q >= 1, and batch_dims integer b, this operator gathers slices of data into an output tensor of rank q + r - indices_shape[-1] - 1 - b.
/// indices is an q-dimensional integer tensor, best thought of as a (q-1)-dimensional tensor of index-tuples into data, where each element defines a slice of data
/// batch_dims (denoted as b) is an integer indicating the number of batch dimensions, i.e the leading b number of dimensions of data tensor and indices are representing the batches, and the gather starts from the b+1 dimension.
/// Some salient points about the inputs rank and shape:
/// r >= 1 and q >= 1 are to be honored. There is no dependency condition to be met between ranks r and q
/// The first b dimensions of the shape of indices tensor and data tensor must be equal.
/// b < min(q, r) is to be honored.
/// The indices_shape[-1] should have a value between 1 (inclusive) and rank r-b (inclusive)
/// All values in indices are expected to be within bounds [-s, s-1] along axis of size s (i.e.) -data_shape[i] <= indices[...,i] <= data_shape[i] - 1. It is an error if any of the index values are out of bounds.
// The output is computed as follows:
/// The output tensor is obtained by mapping each index-tuple in the indices tensor to the corresponding slice of the input data.
/// If indices_shape[-1] > r-b => error condition
/// If indices_shape[-1] == r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b, where N is an integer equals to the product of 1 and all the elements in the batch dimensions of the indices_shape.
/// Let us think of each such r-b ranked tensor as indices_slice. Each scalar value corresponding to data[0:b-1,indices_slice] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor (Example 1 below)
/// If indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensor containing 1-D tensors of dimension < r-b. Let us think of each such tensors as indices_slice. Each tensor slice corresponding to data[0:b-1, indices_slice , :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor (Examples 2, 3, 4 and 5 below)
pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
dims: &[usize],
batch_dims: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let index = values[0].clone();
let index_dims = index.dims().to_vec();
let last_dim = index.dims().last().unwrap();
let input_rank = dims[batch_dims..].len();
let dim_multiplier: Tensor<usize> = Tensor::new(None, &[dims.len()])?;
let dim_multiplier: Tensor<F> = dim_multiplier.par_enum_map(|i, _| {
let mut res = 1;
for dim in dims.iter().skip(i + 1) {
res *= dim;
}
Ok::<_, region::RegionError>(F::from(res as u64))
})?;
let iteration_dims = index.dims()[0..batch_dims].to_vec();
let mut batch_cartesian_coord = iteration_dims
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let output = Tensor::new(None, &output_size)?.par_enum_map(|i, _: ValType<F>| {
let coord = cartesian_coord[i].clone();
let mut key = coord.clone();
key[dim] = 0;
let result = &results.get(&key).ok_or("missing result")?;
let o = result.get_inner_tensor().map_err(|_| "missing tensor")?[coord[dim]].clone();
Ok::<ValType<F>, region::RegionError>(o)
})?;
if batch_cartesian_coord.is_empty() {
batch_cartesian_coord.push(vec![]);
}
let index_dim_multiplier: ValTensor<F> = dim_multiplier
.get_slice(&[batch_dims..dims.len()])?
.map(|x| ValType::Constant(x))
.into();
let mut outer_results = vec![];
for coord in batch_cartesian_coord {
let slice: Vec<Range<usize>> = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let mut index_slice = index.get_slice(&slice)?;
index_slice.reshape(&index_dims[batch_dims..])?;
// expand the index to the full dims by iterating over the rest of the dims and inserting constants
// eg in the case
// batch_dims = 0
// data = [[[0,1],[2,3]],[[4,5],[6,7]]] # data_shape = [2, 2, 2]
// indices = [[0,1],[1,0]] # indices_shape = [2, 2]
// output = [[2,3],[4,5]] # output_shape = [2, 2]
// the index should be expanded to the shape [2,2,3]: [[0,1,0],[0,1,1],[1,0,0],[1,0,1]]
let mut inner_cartesian_coord = index_slice.dims()[0..index_slice.dims().len() - 1]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
if inner_cartesian_coord.is_empty() {
inner_cartesian_coord.push(vec![]);
}
let indices = if last_dim < &input_rank {
inner_cartesian_coord
.iter()
.map(|x| {
let slice = x.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let index = index_slice.get_slice(&slice)?;
// map over cartesian coord of rest of dims and insert constants
let grid = (*last_dim..input_rank)
.map(|x| 0..dims[x])
.multi_cartesian_product();
Ok(grid
.map(|x| {
let index = index.clone();
let constant_valtensor: ValTensor<F> = Tensor::from(
x.into_iter().map(|x| ValType::Constant(F::from(x as u64))),
)
.into();
index.concat(constant_valtensor)
})
.collect::<Result<Vec<_>, TensorError>>()?)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>()
} else {
inner_cartesian_coord
.iter()
.map(|x| {
let slice = x.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
Ok(index_slice.get_slice(&slice)?)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?
};
let mut const_offset = F::ZERO;
for i in 0..batch_dims {
const_offset += F::from(coord[i] as u64) * dim_multiplier[i];
}
let const_offset = create_constant_tensor(const_offset, 1);
let mut results = vec![];
for index_val in indices {
let mut index_val = index_val.clone();
index_val.flatten();
let res = pairwise(
config,
region,
&[index_val.clone(), index_dim_multiplier.clone()],
BaseOp::Mult,
)?;
let res = res.concat(const_offset.clone())?;
let res = sum(config, region, &[res])?;
results.push(res.get_inner_tensor()?.clone());
// assert than res is less than the product of the dims
assert!(
res.get_int_evals()?
.iter()
.all(|x| *x < dims.iter().product::<usize>() as i128),
"res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})",
dims.iter().product::<usize>(),
index_val.show(),
index_dim_multiplier.show(),
res.show()
);
}
let result_tensor = Tensor::from(results.into_iter());
outer_results.push(result_tensor.combine()?);
}
let output = Tensor::from(outer_results.into_iter());
let output = output.combine()?;
Ok(output.into())
}
pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
ordered: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (mut input, fullset) = (values[0].clone(), values[1].clone());
let set_len = fullset.len();
input.flatten();
let is_assigned = !input.any_unknowns()? && !fullset.any_unknowns()?;
let mut claimed_output: ValTensor<F> = if is_assigned {
let input_evals = input.get_int_evals()?;
let mut fullset_evals = fullset.get_int_evals()?.into_iter().collect::<Vec<_>>();
// get the difference between the two vectors
for eval in input_evals.iter() {
// delete first occurence of that value
if let Some(pos) = fullset_evals.iter().position(|x| x == eval) {
fullset_evals.remove(pos);
}
}
// if fullset + input is the same length, then input is a subset of fullset, else randomly delete elements, this is a patch for
// the fact that we can't have a tensor of unknowns when using constant during gen-settings
if fullset_evals.len() != set_len - input.len() {
fullset_evals.truncate(set_len - input.len());
}
fullset_evals
.iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
} else {
let dim = fullset.len() - input.len();
Tensor::new(Some(&vec![Value::<F>::unknown(); dim]), &[dim])?.into()
};
// assign the claimed output
claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
// input and claimed output should be the shuffles of fullset
// concatentate input and claimed output
let input_and_claimed_output = input.concat(claimed_output.clone())?;
// assert that this is a permutation/shuffle
shuffles(
config,
region,
&[input_and_claimed_output.clone()],
&[fullset.clone()],
)?;
if ordered {
// assert that the claimed output is sorted
claimed_output = _sort_ascending(config, region, &[claimed_output])?;
}
Ok(claimed_output)
}
/// Gather accumulated layout
pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -1150,88 +1343,158 @@ pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
values: &[ValTensor<F>; 3],
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (mut input, mut index, mut src) = (values[0].clone(), values[1].clone(), values[2].clone());
let (input, mut index, src) = (values[0].clone(), values[1].clone(), values[2].clone());
assert_eq!(input.dims().len(), index.dims().len());
let mut assigned_len = vec![];
if !input.all_prev_assigned() {
input = region.assign(&config.custom_gates.inputs[0], &input)?;
assigned_len.push(input.len());
}
if !index.all_prev_assigned() {
index = region.assign(&config.custom_gates.inputs[1], &index)?;
assigned_len.push(index.len());
}
if !src.all_prev_assigned() {
src = region.assign(&config.custom_gates.output, &src)?;
assigned_len.push(src.len());
region.increment(index.len());
}
if !assigned_len.is_empty() {
// safe to unwrap since we've just checked it has at least one element
region.increment(*assigned_len.iter().max().unwrap());
}
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
// Calculate the output tensor size
let input_dim = input.dims()[dim];
let output_size = index.dims().to_vec();
let claimed_output: ValTensor<F> = if is_assigned {
let input_inner = input.get_int_evals()?;
let index_inner = index.get_int_evals()?.map(|x| x as usize);
let src_inner = src.get_int_evals()?;
// these will be assigned as constants
let mut indices = Tensor::from((0..input_dim as u64).map(|x| F::from(x)));
indices.set_visibility(&crate::graph::Visibility::Fixed);
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
region.increment(indices.len());
let res = tensor::ops::scatter(&input_inner, &index_inner, &src_inner, dim)?;
// Allocate memory for the output tensor
let cartesian_coord = output_size
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
let mut output: Tensor<()> = Tensor::new(None, &output_size)?;
let mut inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| {
let coord = cartesian_coord[i].clone();
let index_val = index.get_inner_tensor()?.get(&coord);
let src_val = src.get_inner_tensor()?.get(&coord);
let src_valtensor: ValTensor<F> = Tensor::from([src_val.clone()].into_iter()).into();
let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
slice[dim] = 0..input_dim;
let mut sliced_input = input.get_slice(&slice)?;
sliced_input.flatten();
let index_valtensor: ValTensor<F> = Tensor::from([index_val.clone()].into_iter()).into();
let mask = equals(config, region, &[index_valtensor, indices.clone()])?;
let res = iff(config, region, &[mask, src_valtensor, sliced_input])?;
let input_cartesian_coord = slice.into_iter().multi_cartesian_product();
let mutable_input_inner = input.get_inner_tensor_mut()?;
for (i, r) in res.get_inner_tensor()?.iter().enumerate() {
let coord = input_cartesian_coord
.clone()
.nth(i)
.ok_or("invalid coord")?;
*mutable_input_inner.get_mut(&coord) = r.clone();
}
Ok(())
res.iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); input.len()]),
&[input.len()],
)?
.into()
};
output
.iter_mut()
.enumerate()
.map(|(i, _)| inner_loop_function(i, region))
.collect::<Result<Vec<()>, Box<dyn Error>>>()?;
// assign the claimed output
let mut claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
region.increment(claimed_output.len());
claimed_output.reshape(input.dims())?;
Ok(input)
// scatter elements is the inverse of gather elements
let (gather_src, linear_index) = gather_elements(
config,
region,
&[claimed_output.clone(), index.clone()],
dim,
)?;
// assert this is equal to the src
enforce_equality(config, region, &[gather_src, src])?;
let full_index_set: ValTensor<F> =
Tensor::from((0..input.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();
let input_indices = get_missing_set_elements(
config,
region,
&[linear_index, full_index_set.clone()],
true,
)?;
claimed_output.flatten();
let (gather_input, _) = gather_elements(
config,
region,
&[claimed_output.clone(), input_indices.clone()],
0,
)?;
// assert this is a subset of the input
dynamic_lookup(
config,
region,
&[input_indices, gather_input],
&[full_index_set, input.clone()],
)?;
claimed_output.reshape(input.dims())?;
Ok(claimed_output)
}
/// Scatter Nd
pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 3],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (input, mut index, src) = (values[0].clone(), values[1].clone(), values[2].clone());
if !index.all_prev_assigned() {
index = region.assign(&config.custom_gates.inputs[1], &index)?;
region.increment(index.len());
}
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
let claimed_output: ValTensor<F> = if is_assigned {
let input_inner = input.get_int_evals()?;
let index_inner = index.get_int_evals()?.map(|x| x as usize);
let src_inner = src.get_int_evals()?;
let res = tensor::ops::scatter_nd(&input_inner, &index_inner, &src_inner)?;
res.iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); input.len()]),
&[input.len()],
)?
.into()
};
// assign the claimed output
let mut claimed_output = region.assign(&config.custom_gates.output, &claimed_output)?;
region.increment(claimed_output.len());
claimed_output.reshape(input.dims())?;
// scatter elements is the inverse of gather elements
let (gather_src, linear_index) =
gather_nd(config, region, &[claimed_output.clone(), index.clone()], 0)?;
// assert this is equal to the src
enforce_equality(config, region, &[gather_src, src])?;
let full_index_set: ValTensor<F> =
Tensor::from((0..input.len() as u64).map(|x| ValType::Constant(F::from(x)))).into();
let input_indices = get_missing_set_elements(
config,
region,
&[linear_index, full_index_set.clone()],
true,
)?;
// now that it is flattened we can gather over elements on dim 0
claimed_output.flatten();
let (gather_input, _) = gather_elements(
config,
region,
&[claimed_output.clone(), input_indices.clone()],
0,
)?;
// assert this is a subset of the input
dynamic_lookup(
config,
region,
&[input_indices, gather_input],
&[full_index_set, input.clone()],
)?;
claimed_output.reshape(input.dims())?;
Ok(claimed_output)
}
/// sum accumulated layout
@@ -1488,17 +1751,11 @@ pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// these will be assigned as constants
let mut indices = Tensor::from((0..values[0].dims()[dim] as u64).map(|x| F::from(x)));
indices.set_visibility(&crate::graph::Visibility::Fixed);
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
region.increment(indices.len());
let argmax = move |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1]|
-> Result<ValTensor<F>, Box<dyn Error>> {
argmax(config, region, values, indices.clone())
};
let argmax =
move |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1]|
-> Result<ValTensor<F>, Box<dyn Error>> { argmax(config, region, values) };
// calculate value of output
axes_wise_op(config, region, values, &[dim], argmax)
@@ -1524,18 +1781,12 @@ pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// calculate value of output
// these will be assigned as constants
let mut indices = Tensor::from((0..values[0].dims()[dim] as u64).map(|x| F::from(x)));
indices.set_visibility(&crate::graph::Visibility::Fixed);
let indices = region.assign(&config.custom_gates.inputs[1], &indices.try_into()?)?;
region.increment(indices.len());
let argmin = move |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1]|
-> Result<ValTensor<F>, Box<dyn Error>> {
argmin(config, region, values, indices.clone())
};
let argmin =
move |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1]|
-> Result<ValTensor<F>, Box<dyn Error>> { argmin(config, region, values) };
axes_wise_op(config, region, values, &[dim], argmin)
}
@@ -1851,7 +2102,8 @@ pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
// take the product of diff and output
let prod_check = pairwise(config, region, &[values, output.clone()], BaseOp::Mult)?;
is_zero_identity(config, region, &[prod_check], false)?;
let zero_tensor = create_zero_tensor(prod_check.len());
enforce_equality(config, region, &[prod_check, zero_tensor])?;
Ok(output)
}
@@ -1963,13 +2215,7 @@ pub(crate) fn sumpool<F: PrimeField + TensorType + PartialOrd>(
.map(|coord| {
let (b, i) = (coord[0], coord[1]);
let input = values[0].get_slice(&[b..b + 1, i..i + 1])?;
let output = conv(
config,
region,
&[input, kernel.clone()],
padding,
stride,
)?;
let output = conv(config, region, &[input, kernel.clone()], padding, stride)?;
res.push(output);
Ok(())
})
@@ -2448,38 +2694,6 @@ pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
/// is zero identity constraint.
pub(crate) fn is_zero_identity<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
assign: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let output = if assign || !values[0].get_const_indices()?.is_empty() {
let output = region.assign(&config.custom_gates.output, &values[0])?;
region.increment(output.len());
output
} else {
values[0].clone()
};
// Enable the selectors
if !region.is_dummy() {
(0..output.len())
.map(|j| {
let index = region.linear_coord() - j - 1;
let (x, y, z) = config.custom_gates.output.cartesian_coord(index);
let selector = config.custom_gates.selectors.get(&(BaseOp::IsZero, x, y));
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
Ok(output)
}
/// Boolean identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -2747,7 +2961,6 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
indices: ValTensor<F>,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// this is safe because we later constrain it
let argmax = values[0]
@@ -2770,7 +2983,6 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
config,
region,
&[values[0].clone(), assigned_argmax.clone()],
indices,
)?;
let max_val = max(config, region, &[values[0].clone()])?;
@@ -2785,7 +2997,6 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
indices: ValTensor<F>,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// this is safe because we later constrain it
let argmin = values[0]
@@ -2809,7 +3020,6 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
config,
region,
&[values[0].clone(), assigned_argmin.clone()],
indices,
)?;
let min_val = min(config, region, &[values[0].clone()])?;

View File

@@ -14,10 +14,17 @@ pub enum PolyOp {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
GatherND {
batch_dims: usize,
indices: Option<Tensor<usize>>,
},
ScatterElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
ScatterND {
constant_idx: Option<Tensor<usize>>,
},
MultiBroadcastTo {
shape: Vec<usize>,
},
@@ -89,7 +96,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn as_string(&self) -> String {
match &self {
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
PolyOp::ScatterND { .. } => "SCATTERND".into(),
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -213,6 +222,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
};
tensor::ops::gather_elements(&x, &y, *dim)
}
PolyOp::GatherND {
indices,
batch_dims,
} => {
let x = inputs[0].clone();
let y = if let Some(idx) = indices {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_nd(&x, &y, *batch_dims)
}
PolyOp::ScatterElements { dim, constant_idx } => {
let x = inputs[0].clone();
@@ -229,6 +250,21 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
};
tensor::ops::scatter(&x, &idx, &src, *dim)
}
PolyOp::ScatterND { constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter_nd(&x, &idx, &src)
}
}?;
Ok(ForwardResult { output: res })
@@ -276,7 +312,17 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?.0
}
}
PolyOp::GatherND {
batch_dims,
indices,
} => {
if let Some(idx) = indices {
tensor::ops::gather_nd(values[0].get_inner_tensor()?, idx, *batch_dims)?.into()
} else {
layouts::gather_nd(config, region, values[..].try_into()?, *batch_dims)?.0
}
}
PolyOp::ScatterElements { dim, constant_idx } => {
@@ -292,6 +338,18 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
}
}
PolyOp::ScatterND { constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::scatter_nd(
values[0].get_inner_tensor()?,
idx,
values[1].get_inner_tensor()?,
)?
.into()
} else {
layouts::scatter_nd(config, region, values[..].try_into()?)?
}
}
PolyOp::DeConv {
padding,
output_padding,
@@ -389,7 +447,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
vec![1, 2]
} else if matches!(self, PolyOp::Concat { .. }) {
(0..100).collect()
} else if matches!(self, PolyOp::ScatterElements { .. }) {
} else if matches!(self, PolyOp::ScatterElements { .. })
| matches!(self, PolyOp::ScatterND { .. })
{
vec![0, 2]
} else {
vec![]

View File

@@ -294,21 +294,6 @@ pub enum Commands {
args: RunArgs,
},
#[cfg(feature = "render")]
/// Renders the model circuit to a .png file. For an overview of how to interpret these plots, see https://zcash.github.io/halo2/user/dev-tools.html
#[command(arg_required_else_help = true)]
RenderCircuit {
/// The path to the .onnx model file
#[arg(short = 'M', long)]
model: PathBuf,
/// Path to save the .png circuit render
#[arg(short = 'O', long)]
output: PathBuf,
/// proving arguments
#[clap(flatten)]
args: RunArgs,
},
/// Generates the witness from an input file.
GenWitness {
/// The path to the .json data file
@@ -402,9 +387,6 @@ pub enum Commands {
/// Number of logrows to use for srs. Overrides settings_path if specified.
#[arg(long, default_value = None)]
logrows: Option<u32>,
/// Check mode for SRS. Verifies downloaded srs is valid. Set to unsafe for speed.
#[arg(long, default_value = DEFAULT_CHECKMODE)]
check: CheckMode,
},
/// Loads model and input and runs mock prover (for testing)
Mock {

View File

@@ -48,8 +48,6 @@ use log::debug;
use log::{info, trace, warn};
#[cfg(not(target_arch = "wasm32"))]
use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator};
#[cfg(feature = "render")]
use plotters::prelude::*;
#[cfg(not(target_arch = "wasm32"))]
use rand::Rng;
use std::error::Error;
@@ -159,15 +157,8 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
settings_path,
logrows,
check,
} => get_srs_cmd(srs_path, settings_path, logrows, check).await,
} => get_srs_cmd(srs_path, settings_path, logrows).await,
Commands::Table { model, args } => table(model, args),
#[cfg(feature = "render")]
Commands::RenderCircuit {
model,
output,
args,
} => render(model, output, args),
Commands::GenSettings {
model,
settings_path,
@@ -492,23 +483,28 @@ async fn fetch_srs(uri: &str) -> Result<Vec<u8>, Box<dyn Error>> {
}
#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
pub(crate) fn get_file_hash(path: &PathBuf) -> Result<String, Box<dyn Error>> {
use std::io::Read;
let path = get_srs_path(logrows, srs_path);
let file = std::fs::File::open(path.clone())?;
let file = std::fs::File::open(path)?;
let mut reader = std::io::BufReader::new(file);
let mut buffer = vec![];
let mut reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
let bytes_read = reader.read_to_end(&mut buffer)?;
info!(
"read {} bytes from SRS file (vector of len = {})",
"read {} bytes from file (vector of len = {})",
bytes_read,
buffer.len()
);
let hash = sha256::digest(buffer);
info!("SRS hash: {}", hash);
info!("file hash: {}", hash);
Ok(hash)
}
#[cfg(not(target_arch = "wasm32"))]
fn check_srs_hash(logrows: u32, srs_path: Option<PathBuf>) -> Result<String, Box<dyn Error>> {
let path = get_srs_path(logrows, srs_path);
let hash = get_file_hash(&path)?;
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
Some(h) => h,
@@ -532,7 +528,6 @@ pub(crate) async fn get_srs_cmd(
srs_path: Option<PathBuf>,
settings_path: Option<PathBuf>,
logrows: Option<u32>,
check_mode: CheckMode,
) -> Result<String, Box<dyn Error>> {
// logrows overrides settings
@@ -560,21 +555,20 @@ pub(crate) async fn get_srs_cmd(
let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k);
let mut reader = Cursor::new(fetch_srs(&srs_uri).await?);
// check the SRS
if matches!(check_mode, CheckMode::SAFE) {
#[cfg(not(target_arch = "wasm32"))]
let pb = init_spinner();
#[cfg(not(target_arch = "wasm32"))]
pb.set_message("Validating SRS (this may take a while) ...");
ParamsKZG::<Bn256>::read(&mut reader)?;
#[cfg(not(target_arch = "wasm32"))]
pb.finish_with_message("SRS validated");
}
#[cfg(not(target_arch = "wasm32"))]
let pb = init_spinner();
#[cfg(not(target_arch = "wasm32"))]
pb.set_message("Validating SRS (this may take a while) ...");
let params = ParamsKZG::<Bn256>::read(&mut reader)?;
#[cfg(not(target_arch = "wasm32"))]
pb.finish_with_message("SRS validated.");
info!("Saving SRS to disk...");
let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone()))?;
let mut buffer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, &mut file);
buffer.write_all(reader.get_ref())?;
buffer.flush()?;
params.write(&mut buffer)?;
info!("Saved SRS to disk.");
info!("SRS downloaded");
} else {
@@ -969,8 +963,8 @@ pub(crate) fn calibrate(
continue;
}
}
// drop the gag
// drop the gag
#[cfg(unix)]
drop(_r);
#[cfg(unix)]
@@ -1194,29 +1188,6 @@ pub(crate) fn mock(
Ok(String::new())
}
#[cfg(feature = "render")]
pub(crate) fn render(
model: PathBuf,
output: PathBuf,
args: RunArgs,
) -> Result<String, Box<dyn Error>> {
let circuit = GraphCircuit::from_run_args(&args, &model)?;
info!("Rendering circuit");
// Create the area we want to draw on.
// We could use SVGBackend if we want to render to .svg instead.
// for an overview of how to interpret these plots, see https://zcash.github.io/halo2/user/dev-tools.html
let root = BitMapBackend::new(&output, (512, 512)).into_drawing_area();
root.fill(&TRANSPARENT)?;
let root = root.titled("Layout", ("sans-serif", 20))?;
halo2_proofs::dev::CircuitLayout::default()
// We hide labels, else most circuits become impossible to decipher because of overlaid text
.show_labels(false)
.render(circuit.settings().run_args.logrows, &circuit, &root)?;
Ok(String::new())
}
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn create_evm_verifier(
vk_path: PathBuf,
@@ -1695,7 +1666,7 @@ pub(crate) fn fuzz(
let logrows = circuit.settings().run_args.logrows;
info!("setting up tests");
#[cfg(unix)]
let _r = Gag::stdout()?;
let params = gen_srs::<KZGCommitmentScheme<Bn256>>(logrows);
@@ -1713,6 +1684,7 @@ pub(crate) fn fuzz(
let public_inputs = circuit.prepare_public_inputs(&data)?;
let strategy = KZGSingleStrategy::new(&params);
#[cfg(unix)]
std::mem::drop(_r);
info!("starting fuzzing");
@@ -1903,6 +1875,7 @@ pub(crate) fn run_fuzz_fn(
passed: &AtomicBool,
) {
let num_failures = AtomicI64::new(0);
#[cfg(unix)]
let _r = Gag::stdout().unwrap();
let pb = init_bar(num_runs as u64);
@@ -1916,6 +1889,7 @@ pub(crate) fn run_fuzz_fn(
pb.inc(1);
});
pb.finish_with_message("Done.");
#[cfg(unix)]
std::mem::drop(_r);
info!(
"num failures: {} out of {}",

View File

@@ -23,7 +23,10 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(not(target_arch = "wasm32"))]
use tract_onnx::tract_core::ops::{
array::{Gather, GatherElements, MultiBroadcastTo, OneHot, ScatterElements, Slice, Topk},
array::{
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
Slice, Topk,
},
change_axes::AxisOp,
cnn::{Conv, Deconv},
einsum::EinSum,
@@ -467,6 +470,78 @@ pub fn new_op_from_onnx(
// Extract the max value
}
"ScatterNd" => {
if inputs.len() != 3 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"scatter nd".to_string(),
)));
};
// just verify it deserializes correctly
let _op = load_op::<ScatterNd>(node.op(), idx, node.op().name().to_string())?;
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
}
// }
if inputs[1].opkind().is_input() {
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
}));
inputs[1].bump_scale(0);
}
op
}
"GatherNd" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"gather nd".to_string(),
)));
};
let op = load_op::<GatherNd>(node.op(), idx, node.op().name().to_string())?;
let batch_dims = op.batch_dims;
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
})
}
// }
if inputs[1].opkind().is_input() {
inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input {
scale: 0,
datum_type: InputType::TDim,
}));
inputs[1].bump_scale(0);
}
op
}
"GatherElements" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(

View File

@@ -484,7 +484,6 @@ fn get_srs(
srs_path,
settings_path,
logrows,
CheckMode::SAFE,
))
.map_err(|e| {
let err_str = format!("Failed to get srs: {}", e);
@@ -1104,6 +1103,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyG1Affine>()?;
m.add_class::<PyG1>()?;
m.add_class::<PyTestDataSource>()?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_float, m)?)?;

View File

@@ -673,6 +673,68 @@ impl<T: Clone + TensorType> Tensor<T> {
Tensor::new(Some(&res), &dims)
}
/// Set a slice of the Tensor.
/// ```
/// use ezkl::tensor::Tensor;
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
/// let b = Tensor::<i32>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
/// a.set_slice(&[1..2], &Tensor::<i32>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
/// assert_eq!(a, b);
/// ```
pub fn set_slice(
&mut self,
indices: &[Range<usize>],
value: &Tensor<T>,
) -> Result<(), TensorError>
where
T: Send + Sync,
{
if indices.is_empty() {
return Ok(());
}
if self.dims.len() < indices.len() {
return Err(TensorError::DimError(format!(
"The dimensionality of the slice {:?} is greater than the tensor's {:?}",
indices, self.dims
)));
}
// if indices weren't specified we fill them in as required
let mut full_indices = indices.to_vec();
let omitted_dims = (indices.len()..self.dims.len())
.map(|i| self.dims[i])
.collect::<Vec<_>>();
for dim in &omitted_dims {
full_indices.push(0..*dim);
}
let full_dims = full_indices
.iter()
.map(|x| x.end - x.start)
.collect::<Vec<_>>();
// now broadcast the value to the full dims
let value = value.expand(&full_dims)?;
let cartesian_coord: Vec<Vec<usize>> = full_indices
.iter()
.cloned()
.multi_cartesian_product()
.collect();
let _ = cartesian_coord
.iter()
.enumerate()
.map(|(i, e)| {
self.set(e, value[i].clone());
})
.collect::<Vec<_>>();
Ok(())
}
/// Get the array index from rows / columns indices.
///
/// ```

View File

@@ -2,7 +2,7 @@ use super::TensorError;
use crate::tensor::{Tensor, TensorType};
use itertools::Itertools;
use maybe_rayon::{
iter::IndexedParallelIterator, iter::IntoParallelRefMutIterator, iter::ParallelIterator,
iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
prelude::IntoParallelRefIterator,
};
use std::collections::{HashMap, HashSet};
@@ -1328,6 +1328,316 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
Ok(output)
}
/// Gather ND.
/// # Arguments
/// * `input` - Tensor
/// * `index` - Tensor of indices to gather
/// * `batch_dims` - Number of batch dimensions
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::gather_nd;
/// let x = Tensor::<i128>::new(
/// Some(&[0, 1, 2, 3]),
/// &[2, 2],
/// ).unwrap();
/// let index = Tensor::<usize>::new(
/// Some(&[0, 0, 1, 1]),
/// &[2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 3]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[1, 0]),
/// &[2, 1],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[0, 1, 2, 3, 4, 5, 6, 7]),
/// &[2, 2, 2],
/// ).unwrap();
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 1, 0]),
/// &[2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 1, 0]),
/// &[2, 1, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 1, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[1, 0]),
/// &[2, 1],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 1).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1]),
/// &[2, 2, 3],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 0, 1, 1, 1, 0]),
/// &[2, 2, 2],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 3, 0, 1, 6, 7, 4, 5]), &[2, 2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1, 0, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = gather_nd(&x, &index, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 7]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
pub fn gather_nd<T: TensorType + Send + Sync>(
input: &Tensor<T>,
index: &Tensor<usize>,
batch_dims: usize,
) -> Result<Tensor<T>, TensorError> {
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
let last_value = index_dims
.last()
.ok_or(TensorError::DimMismatch("gather_nd".to_string()))?;
if last_value > &(input_dims.len() - batch_dims) {
return Err(TensorError::DimMismatch("gather_nd".to_string()));
}
let output_size =
// If indices_shape[-1] == r-b, since the rank of indices is q,
// indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b,
// where N is an integer equals to the product of 1 and all the elements in the batch dimensions of the indices_shape.
// Let us think of each such r-b ranked tensor as indices_slice.
// Each scalar value corresponding to data[0:b-1,indices_slice] is filled into
// the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
// if indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensor containing 1-D tensors of dimension < r-b.
// Let us think of each such tensors as indices_slice.
// Each tensor slice corresponding to data[0:b-1, indices_slice , :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor
{
let output_rank = input_dims.len() + index_dims.len() - 1 - batch_dims - last_value;
let mut dims = index_dims[..index_dims.len() - 1].to_vec();
let input_offset = batch_dims + last_value;
dims.extend(input_dims[input_offset..input_dims.len()].to_vec());
assert_eq!(output_rank, dims.len());
dims
};
// cartesian coord over batch dims
let mut batch_cartesian_coord = input_dims[0..batch_dims]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
if batch_cartesian_coord.is_empty() {
batch_cartesian_coord.push(vec![]);
}
let outputs = batch_cartesian_coord
.par_iter()
.map(|batch_coord| {
let batch_slice = batch_coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let mut index_slice = index.get_slice(&batch_slice)?;
index_slice.reshape(&index.dims()[batch_dims..])?;
let mut input_slice = input.get_slice(&batch_slice)?;
input_slice.reshape(&input.dims()[batch_dims..])?;
let mut inner_cartesian_coord = index_slice.dims()[0..index_slice.dims().len() - 1]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
if inner_cartesian_coord.is_empty() {
inner_cartesian_coord.push(vec![]);
}
let output = inner_cartesian_coord
.iter()
.map(|coord| {
let slice = coord
.iter()
.map(|x| *x..*x + 1)
.chain(batch_coord.iter().map(|x| *x..*x + 1))
.collect::<Vec<_>>();
let index_slice = index_slice
.get_slice(&slice)
.unwrap()
.iter()
.map(|x| *x..*x + 1)
.collect::<Vec<_>>();
input_slice.get_slice(&index_slice).unwrap()
})
.collect::<Tensor<_>>();
output.combine()
})
.collect::<Result<Vec<_>, _>>()?;
let mut outputs = outputs.into_iter().flatten().collect::<Tensor<_>>();
outputs.reshape(&output_size)?;
Ok(outputs)
}
/// Scatter ND.
/// This operator is the inverse of GatherND.
/// # Arguments
/// * `input` - Tensor
/// * `index` - Tensor of indices to scatter
/// * `src` - Tensor of src
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::scatter_nd;
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[8],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[4, 3, 1, 7]),
/// &[4, 1],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9, 10, 11, 12]),
/// &[4],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 11, 3, 10, 9, 6, 7, 12]), &[8]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
/// &[4, 4, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 2]),
/// &[2, 1],
/// ).unwrap();
///
/// let src = Tensor::<i128>::new(
/// Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
/// 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
/// ]),
/// &[2, 4, 4],
/// ).unwrap();
///
/// let result = scatter_nd(&x, &index, &src).unwrap();
///
/// let expected = Tensor::<i128>::new(
/// Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
/// 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
/// 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
/// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8]),
/// &[4, 4, 4],
/// ).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[2, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1]),
/// &[2, 1],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9, 10]),
/// &[2],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[9, 9, 9, 9, 10, 10, 10, 10]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
///
/// let x = Tensor::<i128>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8]),
/// &[2, 4],
/// ).unwrap();
///
/// let index = Tensor::<usize>::new(
/// Some(&[0, 1]),
/// &[1, 1, 2],
/// ).unwrap();
/// let src = Tensor::<i128>::new(
/// Some(&[9]),
/// &[1, 1],
/// ).unwrap();
/// let result = scatter_nd(&x, &index, &src).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 9, 3, 4, 5, 6, 7, 8]), &[2, 4]).unwrap();
/// assert_eq!(result, expected);
/// ````
///
pub fn scatter_nd<T: TensorType + Send + Sync>(
input: &Tensor<T>,
index: &Tensor<usize>,
src: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
let last_value = index_dims
.last()
.ok_or(TensorError::DimMismatch("scatter_nd".to_string()))?;
if last_value > &input_dims.len() {
return Err(TensorError::DimMismatch("scatter_nd".to_string()));
}
let mut output = input.clone();
let cartesian_coord = index_dims[0..index_dims.len() - 1]
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();
cartesian_coord
.iter()
.map(|coord| {
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let index_val = index.get_slice(&slice)?;
let index_slice = index_val.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let src_val = src.get_slice(&slice)?;
output.set_slice(&index_slice, &src_val)?;
Ok(())
})
.collect::<Result<Vec<_>, _>>()?;
Ok(output)
}
fn axes_op<T: TensorType + Send + Sync>(
a: &Tensor<T>,
axes: &[usize],

View File

@@ -494,7 +494,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
_ => return Err(Box::new(TensorError::WrongMethod)),
};
Ok(integer_evals.into_iter().into())
let mut tensor: Tensor<i128> = integer_evals.into_iter().into();
match tensor.reshape(self.dims()) {
_ => {}
};
Ok(tensor)
}
/// Calls `pad_to_zero_rem` on the inner tensor.

View File

@@ -193,7 +193,7 @@ mod native_tests {
"1l_tiny_div",
];
const TESTS: [&str; 77] = [
const TESTS: [&str; 79] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
@@ -275,6 +275,8 @@ mod native_tests {
"ltsf",
"remainder", //75
"bitshift",
"gather_nd",
"scatter_nd",
];
const WASM_TESTS: [&str; 46] = [
@@ -502,7 +504,7 @@ mod native_tests {
}
});
seq!(N in 0..=76 {
seq!(N in 0..=78 {
#(#[test_case(TESTS[N])])*
#[ignore]
@@ -589,13 +591,16 @@ mod native_tests {
#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
test_dir.close().unwrap();
// currently variable output rank is not supported in ONNX
if test != "gather_nd" {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let large_batch_dir = &format!("large_batches_{}", test);
crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10);
mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0);
test_dir.close().unwrap();
}
}
#(#[test_case(TESTS[N])])*
@@ -853,7 +858,7 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, true, "single");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testWasm");
run_js_tests(path, test.to_string(), "testWasm", false);
// test_dir.close().unwrap();
}
@@ -866,7 +871,7 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
kzg_prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, None, true, "single");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testWasm");
run_js_tests(path, test.to_string(), "testWasm", false);
test_dir.close().unwrap();
}
@@ -914,6 +919,7 @@ mod native_tests {
use crate::native_tests::kzg_fuzz;
use tempdir::TempDir;
use crate::native_tests::Hardfork;
use crate::native_tests::run_js_tests;
/// Currently only on chain inputs that return a non-negative value are supported.
const TESTS_ON_CHAIN_INPUT: [&str; 17] = [
@@ -1008,8 +1014,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "public");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
test_dir.close().unwrap();
}
@@ -1021,8 +1027,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
kzg_evm_prove_and_verify_render_seperately(2, path, test.to_string(), "private", "private", "public");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", true);
test_dir.close().unwrap();
}
@@ -1035,8 +1041,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let mut _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
kzg_evm_prove_and_verify(2, path, test.to_string(), "hashed", "private", "private");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
test_dir.close().unwrap();
}
@@ -1052,8 +1058,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let mut _anvil_child = crate::native_tests::start_anvil(false, hardfork);
kzg_evm_prove_and_verify(2, path, test.to_string(), "kzgcommit", "private", "public");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
test_dir.close().unwrap();
}
@@ -1065,8 +1071,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "hashed", "public");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
test_dir.close().unwrap();
}
@@ -1078,8 +1084,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "hashed");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
test_dir.close().unwrap();
}
@@ -1091,8 +1097,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "kzgcommit", "public");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
test_dir.close().unwrap();
}
@@ -1104,8 +1110,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
kzg_evm_prove_and_verify(2, path, test.to_string(), "private", "private", "kzgcommit");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
test_dir.close().unwrap();
}
@@ -1116,8 +1122,8 @@ mod native_tests {
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(false, Hardfork::Latest);
kzg_evm_prove_and_verify(2, path, test.to_string(), "kzgcommit", "kzgcommit", "kzgcommit");
// #[cfg(not(feature = "icicle"))]
// run_js_tests(path, test.to_string(), "testBrowserEvmVerify");
#[cfg(not(feature = "icicle"))]
run_js_tests(path, test.to_string(), "testBrowserEvmVerify", false);
test_dir.close().unwrap();
}
@@ -2179,15 +2185,17 @@ mod native_tests {
}
// run js browser evm verify tests for a given example
fn run_js_tests(test_dir: &str, example_name: String, js_test: &str) {
fn run_js_tests(test_dir: &str, example_name: String, js_test: &str, vk: bool) {
let example = format!("--example={}", example_name);
let dir = format!("--dir={}", test_dir);
let mut args = vec!["run", "test", js_test, &example, &dir];
let vk_string: String;
if vk {
vk_string = format!("--vk={}", vk);
args.push(&vk_string);
};
let status = Command::new("pnpm")
.args([
"run",
"test",
js_test,
&format!("--example={}", example_name),
&format!("--dir={}", test_dir),
])
.args(&args)
.status()
.expect("failed to execute process");
assert!(status.success());

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@@ -1,28 +1,42 @@
import localEVMVerify, { Hardfork } from '@ezkljs/verify'
import localEVMVerify from '../../in-browser-evm-verifier/src/index'
import { serialize, deserialize } from '@ezkljs/engine/nodejs'
import { compileContracts } from './utils'
import * as fs from 'fs'
exports.USER_NAME = require("minimist")(process.argv.slice(2))["example"];
exports.EXAMPLE = require("minimist")(process.argv.slice(2))["example"];
exports.PATH = require("minimist")(process.argv.slice(2))["dir"];
exports.VK = require("minimist")(process.argv.slice(2))["vk"];
describe('localEVMVerify', () => {
let bytecode: string
let bytecode_verifier: string
let bytecode_vk: string | undefined = undefined
let proof: any
const example = exports.USER_NAME || "1l_mlp"
const example = exports.EXAMPLE || "1l_mlp"
const path = exports.PATH || "../ezkl/examples/onnx"
const vk = exports.VK || false
beforeEach(() => {
let solcOutput = compileContracts(path, example)
const solcOutput = compileContracts(path, example, 'kzg')
bytecode =
bytecode_verifier =
solcOutput.contracts['artifacts/Verifier.sol']['Halo2Verifier'].evm.bytecode
.object
console.log('size', bytecode.length)
if (vk) {
const solcOutput_vk = compileContracts(path, example, 'vk')
bytecode_vk =
solcOutput_vk.contracts['artifacts/Verifier.sol']['Halo2VerifyingKey'].evm.bytecode
.object
console.log('size of verifier bytecode', bytecode_verifier.length)
}
console.log('verifier bytecode', bytecode_verifier)
})
it('should return true when verification succeeds', async () => {
@@ -30,7 +44,9 @@ describe('localEVMVerify', () => {
proof = deserialize(proofFileBuffer)
const result = await localEVMVerify(proofFileBuffer, bytecode)
const result = await localEVMVerify(proofFileBuffer, bytecode_verifier, bytecode_vk)
console.log('result', result)
expect(result).toBe(true)
})
@@ -39,13 +55,16 @@ describe('localEVMVerify', () => {
let result: boolean = true
console.log(proof.proof)
try {
let index = Math.floor(Math.random() * (proof.proof.length - 2)) + 2
let number = (proof.proof[index] + 1) % 16
let index = Math.round((Math.random() * (proof.proof.length))) % proof.proof.length
console.log('index', index)
console.log('index', proof.proof[index])
let number = (proof.proof[index] + 1) % 256
console.log('index', index)
console.log('new number', number)
proof.proof[index] = number
console.log('index post', proof.proof[index])
const proofModified = serialize(proof)
result = await localEVMVerify(proofModified, bytecode)
result = await localEVMVerify(proofModified, bytecode_verifier, bytecode_vk)
} catch (error) {
// Check if the error thrown is the "out of gas" error.
expect(error).toEqual(

View File

@@ -43,21 +43,21 @@ export function serialize(data: object | string): Uint8ClampedArray { // data is
return new Uint8ClampedArray(uint8Array.buffer);
}
export function getSolcInput(path: string, example: string) {
export function getSolcInput(path: string, example: string, name: string) {
return {
language: 'Solidity',
sources: {
'artifacts/Verifier.sol': {
content: fsSync.readFileSync(`${path}/${example}/kzg.sol`, 'utf-8'),
content: fsSync.readFileSync(`${path}/${example}/${name}.sol`, 'utf-8'),
},
// If more contracts were to be compiled, they should have their own entries here
},
settings: {
optimizer: {
enabled: true,
runs: 200,
runs: 1,
},
evmVersion: 'london',
evmVersion: 'shanghai',
outputSelection: {
'*': {
'*': ['abi', 'evm.bytecode'],
@@ -67,8 +67,8 @@ export function getSolcInput(path: string, example: string) {
}
}
export function compileContracts(path: string, example: string) {
const input = getSolcInput(path, example)
export function compileContracts(path: string, example: string, name: string) {
const input = getSolcInput(path, example, name)
const output = JSON.parse(solc.compile(JSON.stringify(input)))
let compilationFailed = false

Binary file not shown.

1
verifier_abi.json Normal file
View File

@@ -0,0 +1 @@
[{"type":"function","name":"verifyProof","inputs":[{"internalType":"bytes","name":"proof","type":"bytes"},{"internalType":"uint256[]","name":"instances","type":"uint256[]"}],"outputs":[{"internalType":"bool","name":"","type":"bool"}],"stateMutability":"nonpayable"}]

1
vk.abi Normal file
View File

@@ -0,0 +1 @@
[{"type":"constructor","inputs":[]}]