Mirror of https://github.com/zkonduit/ezkl.git (synced 2026-01-13 00:08:12 -05:00)

Compare commits: ac/patch-p...release-v1 (6 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 18c7e35514 |  |
|  | 8aaf518b5e |  |
|  | 1b7b43e073 |  |
|  | f78618ec59 |  |
|  | 0943e534ee |  |
|  | 316a9a3b40 |  |
@@ -1,4 +1,4 @@
-name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
+name: Build and Publish EZKL Engine npm package
 
 on:
   workflow_dispatch:
@@ -62,7 +62,7 @@ jobs:
             "web/ezkl_bg.wasm",
             "web/ezkl.js",
             "web/ezkl.d.ts",
-            "web/snippets/wasm-bindgen-rayon-7afa899f36665473/src/workerHelpers.js",
+            "web/snippets/**/*",
             "web/package.json",
             "web/utils.js",
             "ezkl.d.ts"
@@ -79,6 +79,10 @@ jobs:
        run: |
          sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" pkg/nodejs/ezkl.js
 
+      - name: Replace `import.meta.url` with `import.meta.resolve` definition in workerHelpers.js
+        run: |
+          find ./pkg/web/snippets -type f -name "*.js" -exec sed -i "s|import.meta.url|import.meta.resolve|" {} +
+
       - name: Add serialize and deserialize methods to nodejs bundle
         run: |
           echo '
@@ -174,40 +178,3 @@ jobs:
           npm publish
         env:
           NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
-
-  in-browser-evm-ver-publish:
-    name: publish-in-browser-evm-verifier-package
-    needs: ["publish-wasm-bindings"]
-    runs-on: ubuntu-latest
-    if: startsWith(github.ref, 'refs/tags/')
-    steps:
-      - uses: actions/checkout@v4
-      - name: Update version in package.json
-        shell: bash
-        env:
-          RELEASE_TAG: ${{ github.ref_name }}
-        run: |
-          sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
-      - name: Update @ezkljs/engine version in package.json
-        shell: bash
-        env:
-          RELEASE_TAG: ${{ github.ref_name }}
-        run: |
-          sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
-      - name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
-        run: |
-          sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
-      - name: Set up Node.js
-        uses: actions/setup-node@v3
-        with:
-          node-version: "18.12.1"
-          registry-url: "https://registry.npmjs.org"
-      - name: Publish to npm
-        run: |
-          cd in-browser-evm-verifier
-          npm install
-          npm run build
-          npm ci
-          npm publish
-        env:
-          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
.github/workflows/pypi.yml | 15 (vendored)
@@ -128,6 +128,7 @@ jobs:
           mv Cargo.lock Cargo.lock.orig
           sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
 
+
       - name: Install required libraries
         shell: bash
         run: |
@@ -359,3 +360,17 @@ jobs:
         with:
           repository-url: https://test.pypi.org/legacy/
          packages-dir: ./
+
+  doc-publish:
+    name: Trigger ReadTheDocs Build
+    runs-on: ubuntu-latest
+    needs: pypi-publish
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Trigger RTDs build
+        uses: dfm/rtds-action@v1
+        with:
+          webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
+          webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
+          commit_ref: ${{ github.ref_name }}
.github/workflows/rust.yml | 6 (vendored)
@@ -307,8 +307,8 @@ jobs:
       run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
     - name: Install dependencies for js tests and in-browser-evm-verifier package
       run: |
-        pnpm install --no-frozen-lockfile
-        pnpm install --dir ./in-browser-evm-verifier --no-frozen-lockfile
+        pnpm install --frozen-lockfile
+        pnpm install --dir ./in-browser-evm-verifier --frozen-lockfile
       env:
         CI: false
         NODE_ENV: development
@@ -380,7 +380,7 @@ jobs:
         cache: "pnpm"
     - name: Install dependencies for js tests
       run: |
-        pnpm install --no-frozen-lockfile
+        pnpm install --frozen-lockfile
       env:
         CI: false
         NODE_ENV: development
.github/workflows/tagging.yml | 36 (vendored)
@@ -14,6 +14,40 @@ jobs:
       - uses: actions/checkout@v4
       - name: Bump version and push tag
         id: tag_version
-        uses: mathieudutour/github-tag-action@v6.1
+        uses: mathieudutour/github-tag-action@v6.2
         with:
           github_token: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Set Cargo.toml version to match github tag for docs
+        shell: bash
+        env:
+          RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
+        run: |
+          mv docs/python/src/conf.py docs/python/src/conf.py.orig
+          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/src/conf.py.orig >docs/python/src/conf.py
+          rm docs/python/src/conf.py.orig
+          mv docs/python/requirements-docs.txt docs/python/requirements-docs.txt.orig
+          sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/requirements-docs.txt.orig >docs/python/requirements-docs.txt
+          rm docs/python/requirements-docs.txt.orig
+
+      - name: Commit files and create tag
+        env:
+          RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
+        run: |
+          git config --local user.email "github-actions[bot]@users.noreply.github.com"
+          git config --local user.name "github-actions[bot]"
+          git fetch --tags
+          git checkout -b release-$RELEASE_TAG
+          git add .
+          git commit -m "ci: update version string in docs"
+          git tag -d $RELEASE_TAG
+          git tag $RELEASE_TAG
+
+      - name: Push changes
+        uses: ad-m/github-push-action@master
+        env:
+          RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
+        with:
+          branch: release-${{ steps.tag_version.outputs.new_tag }}
+          force: true
+          tags: true
.github/workflows/verify.yml | 54 (vendored, new file)
@@ -0,0 +1,54 @@
+name: Build and Publish EZKL npm packages (wasm bindings and in-browser evm verifier)
+
+on:
+  workflow_dispatch:
+    inputs:
+      tag:
+        description: "The tag to release"
+        required: true
+  push:
+    tags:
+      - "*"
+
+defaults:
+  run:
+    working-directory: .
+jobs:
+  in-browser-evm-ver-publish:
+    name: publish-in-browser-evm-verifier-package
+    runs-on: ubuntu-latest
+    if: startsWith(github.ref, 'refs/tags/')
+    steps:
+      - uses: actions/checkout@v4
+      - name: Update version in package.json
+        shell: bash
+        env:
+          RELEASE_TAG: ${{ github.ref_name }}
+        run: |
+          sed -i "s|\"version\": \".*\"|\"version\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
+      - name: Update @ezkljs/engine version in package.json
+        shell: bash
+        env:
+          RELEASE_TAG: ${{ github.ref_name }}
+        run: |
+          sed -i "s|\"@ezkljs/engine\": \".*\"|\"@ezkljs/engine\": \"${{ github.ref_name }}\"|" in-browser-evm-verifier/package.json
+      - name: Update the engine import in in-browser-evm-verifier to use @ezkljs/engine package instead of the local one;
+        run: |
+          sed -i "s|import { encodeVerifierCalldata } from '../nodejs/ezkl';|import { encodeVerifierCalldata } from '@ezkljs/engine';|" in-browser-evm-verifier/src/index.ts
+      - name: Use pnpm 8
+        uses: pnpm/action-setup@v2
+        with:
+          version: 8
+      - name: Set up Node.js
+        uses: actions/setup-node@v3
+        with:
+          node-version: "18.12.1"
+          registry-url: "https://registry.npmjs.org"
+      - name: Publish to npm
+        run: |
+          cd in-browser-evm-verifier
+          pnpm install --frozen-lockfile
+          pnpm run build
+          pnpm publish
+        env:
+          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
.gitignore | 1 (vendored)
@@ -49,4 +49,5 @@ node_modules
 timingData.json
 !tests/wasm/pk.key
 !tests/wasm/vk.key
+docs/python/build
 !tests/wasm/vk_aggr.key
.python-version | 1 (new file)
@@ -0,0 +1 @@
+3.12.1
.readthedocs.yaml | 26 (new file)
@@ -0,0 +1,26 @@
+# .readthedocs.yaml
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+version: 2
+
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.12"
+
+# Build documentation in the "docs/" directory with Sphinx
+sphinx:
+  configuration: ./docs/python/src/conf.py
+
+# Optionally build your docs in additional formats such as PDF and ePub
+# formats:
+#   - pdf
+#   - epub
+
+# Optional but recommended, declare the Python requirements required
+# to build your documentation
+# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
+python:
+  install:
+    - requirements: ./docs/python/requirements-docs.txt
Cargo.lock | 49 (generated)
@@ -2699,6 +2699,15 @@ dependencies = [
 "either",
 ]
 
+[[package]]
+name = "itertools"
+version = "0.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
+dependencies = [
+ "either",
+]
+
 [[package]]
 name = "itoa"
 version = "1.0.10"
@@ -2968,9 +2977,9 @@ checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149"
 
 [[package]]
 name = "memmap2"
-version = "0.5.10"
+version = "0.9.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "83faa42c0a078c393f6b29d5db232d8be22776a891f8f56e5284faee4a20b327"
+checksum = "fe751422e4a8caa417e13c3ea66452215d7d63e19e604f4980461212f3ae1322"
 dependencies = [
  "libc",
 ]
@@ -4845,12 +4854,12 @@ checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82"
 
 [[package]]
 name = "string-interner"
-version = "0.14.0"
+version = "0.15.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "91e2531d8525b29b514d25e275a43581320d587b86db302b9a7e464bac579648"
+checksum = "07f9fdfdd31a0ff38b59deb401be81b73913d76c9cc5b1aed4e1330a223420b9"
 dependencies = [
  "cfg-if",
- "hashbrown 0.11.2",
+ "hashbrown 0.14.3",
  "serde",
 ]
@@ -5389,8 +5398,8 @@ dependencies = [
 
 [[package]]
 name = "tract-core"
-version = "0.20.23-pre"
-source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
+version = "0.21.3"
+source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
 dependencies = [
  "anyhow",
  "bit-set",
@@ -5413,12 +5422,12 @@ dependencies = [
 
 [[package]]
 name = "tract-data"
-version = "0.20.23-pre"
-source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
+version = "0.21.3"
+source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
 dependencies = [
  "anyhow",
  "half 2.2.1",
- "itertools 0.10.5",
+ "itertools 0.12.1",
  "lazy_static",
  "maplit",
  "ndarray",
@@ -5432,8 +5441,8 @@ dependencies = [
 
 [[package]]
 name = "tract-hir"
-version = "0.20.23-pre"
-source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
+version = "0.21.3"
+source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
 dependencies = [
  "derive-new",
  "log",
@@ -5442,8 +5451,8 @@ dependencies = [
 
 [[package]]
 name = "tract-linalg"
-version = "0.20.23-pre"
-source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
+version = "0.21.3"
+source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
 dependencies = [
  "cc",
  "derive-new",
@@ -5466,8 +5475,8 @@ dependencies = [
 
 [[package]]
 name = "tract-nnef"
-version = "0.20.23-pre"
-source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
+version = "0.21.3"
+source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
 dependencies = [
  "byteorder",
  "flate2",
@@ -5480,8 +5489,8 @@ dependencies = [
 
 [[package]]
 name = "tract-onnx"
-version = "0.20.23-pre"
-source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
+version = "0.21.3"
+source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
 dependencies = [
  "bytes",
  "derive-new",
@@ -5497,8 +5506,8 @@ dependencies = [
 
 [[package]]
 name = "tract-onnx-opl"
-version = "0.20.23-pre"
-source = "git+https://github.com/sonos/tract/?rev=7b1aa33b2f7d1f19b80e270c83320f0f94daff69#7b1aa33b2f7d1f19b80e270c83320f0f94daff69"
+version = "0.21.3"
+source = "git+https://github.com/sonos/tract/?rev=681a096f02c9d7d363102d9fb0e446d1710ac2c8#681a096f02c9d7d363102d9fb0e446d1710ac2c8"
 dependencies = [
  "getrandom",
  "log",
@@ -80,7 +80,7 @@ pyo3-asyncio = { version = "0.20.0", features = [
     "tokio-runtime",
 ], default_features = false, optional = true }
 pyo3-log = { version = "0.9.0", default_features = false, optional = true }
-tract-onnx = { git = "https://github.com/sonos/tract/", rev = "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
+tract-onnx = { git = "https://github.com/sonos/tract/", rev = "681a096f02c9d7d363102d9fb0e446d1710ac2c8", default_features = false, optional = true }
 tabled = { version = "0.12.0", optional = true }
@@ -70,8 +70,8 @@ impl Circuit<Fr> for MyCircuit {
                     &mut region,
                     &[self.image.clone(), self.kernel.clone(), self.bias.clone()],
                     Box::new(PolyOp::Conv {
-                        padding: [(0, 0); 2],
-                        stride: (1, 1),
+                        padding: vec![(0, 0)],
+                        stride: vec![1; 2],
                     }),
                 )
                 .unwrap();
@@ -65,9 +65,9 @@ impl Circuit<Fr> for MyCircuit {
                     &mut region,
                     &[self.image.clone()],
                     Box::new(HybridOp::SumPool {
-                        padding: [(0, 0); 2],
-                        stride: (1, 1),
-                        kernel_shape: (2, 2),
+                        padding: vec![(0, 0); 2],
+                        stride: vec![1, 1],
+                        kernel_shape: vec![2, 2],
                         normalized: false,
                     }),
                 )
docs/python/build.sh | 2 (new executable file)
@@ -0,0 +1,2 @@
+#!/bin/sh
+sphinx-build ./src build
docs/python/requirements-docs.txt | 4 (new file)
@@ -0,0 +1,4 @@
+ezkl==10.3.2
+sphinx
+sphinx-rtd-theme
+sphinxcontrib-napoleon
docs/python/src/conf.py | 29 (new file)
@@ -0,0 +1,29 @@
+import ezkl
+
+project = 'ezkl'
+release = '10.3.2'
+version = release
+
+
+extensions = [
+    'sphinx.ext.autodoc',
+    'sphinx.ext.autosummary',
+    'sphinx.ext.intersphinx',
+    'sphinx.ext.todo',
+    'sphinx.ext.inheritance_diagram',
+    'sphinx.ext.autosectionlabel',
+    'sphinx.ext.napoleon',
+    'sphinx_rtd_theme',
+]
+
+autosummary_generate = True
+autosummary_imported_members = True
+
+templates_path = ['_templates']
+exclude_patterns = []
+
+# -- Options for HTML output -------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
+
+html_theme = 'sphinx_rtd_theme'
+html_static_path = ['_static']
docs/python/src/index.rst | 11 (new file)
@@ -0,0 +1,11 @@
+.. extension documentation master file, created by
+   sphinx-quickstart on Mon Jun 19 15:02:05 2023.
+   You can adapt this file completely to your liking, but it should at least
+   contain the root `toctree` directive.
+
+ezkl python bindings
+================================================
+
+.. automodule:: ezkl
+    :members:
+    :undoc-members:
@@ -203,8 +203,8 @@ where
         let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
 
         let op = PolyOp::Conv {
-            padding: [(PADDING, PADDING); 2],
-            stride: (STRIDE, STRIDE),
+            padding: vec![(PADDING, PADDING); 2],
+            stride: vec![STRIDE; 2],
         };
         let x = config
             .layer_config
@@ -20,16 +20,16 @@
     "build": "npm run clean && npm run build:commonjs && npm run build:esm"
   },
   "dependencies": {
-    "@ethereumjs/common": "^4.0.0",
-    "@ethereumjs/evm": "^2.0.0",
-    "@ethereumjs/statemanager": "^2.0.0",
-    "@ethereumjs/tx": "^5.0.0",
-    "@ethereumjs/util": "^9.0.0",
-    "@ethereumjs/vm": "^7.0.0",
-    "@ethersproject/abi": "^5.7.0",
+    "@ethereumjs/common": "4.0.0",
+    "@ethereumjs/evm": "2.0.0",
+    "@ethereumjs/statemanager": "2.0.0",
+    "@ethereumjs/tx": "5.0.0",
+    "@ethereumjs/util": "9.0.0",
+    "@ethereumjs/vm": "7.0.0",
+    "@ethersproject/abi": "5.7.0",
     "@ezkljs/engine": "^9.4.4",
-    "ethers": "^6.7.1",
-    "json-bigint": "^1.0.0"
+    "ethers": "6.7.1",
+    "json-bigint": "1.0.0"
   },
   "devDependencies": {
     "@types/node": "^20.8.3",
in-browser-evm-verifier/pnpm-lock.yaml | 18 (generated)
@@ -6,34 +6,34 @@ settings:
 
 dependencies:
   '@ethereumjs/common':
-    specifier: ^4.0.0
+    specifier: 4.0.0
     version: 4.0.0
   '@ethereumjs/evm':
-    specifier: ^2.0.0
+    specifier: 2.0.0
     version: 2.0.0
   '@ethereumjs/statemanager':
-    specifier: ^2.0.0
+    specifier: 2.0.0
     version: 2.0.0
   '@ethereumjs/tx':
-    specifier: ^5.0.0
+    specifier: 5.0.0
     version: 5.0.0
   '@ethereumjs/util':
-    specifier: ^9.0.0
+    specifier: 9.0.0
     version: 9.0.0
   '@ethereumjs/vm':
-    specifier: ^7.0.0
+    specifier: 7.0.0
     version: 7.0.0
   '@ethersproject/abi':
-    specifier: ^5.7.0
+    specifier: 5.7.0
     version: 5.7.0
   '@ezkljs/engine':
     specifier: ^9.4.4
     version: 9.4.4
   ethers:
-    specifier: ^6.7.1
+    specifier: 6.7.1
    version: 6.7.1
   json-bigint:
-    specifier: ^1.0.0
+    specifier: 1.0.0
     version: 1.0.0
 
 devDependencies:
@@ -36,7 +36,7 @@ if [ "$(which ezkl)s" != "s" ] && [ "$(which ezkl)" != "$EZKL_DIR/ezkl" ] ; the
     exit 1
 fi
 
-if [[ ":$PATH:" != *":${EZKl_DIR}:"* ]]; then
+if [[ ":$PATH:" != *":${EZKL_DIR}:"* ]]; then
     # Add the ezkl directory to the path and ensure the old PATH variables remain.
     echo >> $PROFILE && echo "export PATH=\"\$PATH:$EZKL_DIR\"" >> $PROFILE
 fi
@@ -1,5 +1,5 @@
 [build-system]
-requires = ["maturin>=0.14,<0.15"]
+requires = ["maturin>=1.0,<2.0"]
 build-backend = "maturin"
 
 [tool.pytest.ini_options]
@@ -2,7 +2,7 @@ attrs==23.2.0
 exceptiongroup==1.2.0
 importlib-metadata==7.1.0
 iniconfig==2.0.0
-maturin==1.5.0
+maturin==1.5.1
 packaging==24.0
 pluggy==1.4.0
 pytest==8.1.1
@@ -11,4 +11,4 @@ typing-extensions==4.10.0
 zipp==3.18.1
 onnx==1.15.0
 onnxruntime==1.17.1
-numpy==1.26.4
+numpy==1.26.4
@@ -956,20 +956,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
         values: &[ValTensor<F>],
         op: Box<dyn Op<F>>,
     ) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
-        let res = op.layout(self, region, values)?;
-
-        if matches!(&self.check_mode, CheckMode::SAFE) && !region.is_dummy() {
-            if let Some(claimed_output) = &res {
-                // during key generation this will be unknown vals so we use this as a flag to check
-                let mut is_assigned = !claimed_output.any_unknowns()?;
-                for val in values.iter() {
-                    is_assigned = is_assigned && !val.any_unknowns()?;
-                }
-                if is_assigned {
-                    op.safe_mode_check(claimed_output, values)?;
-                }
-            }
-        };
-        Ok(res)
+        op.layout(self, region, values)
     }
 }
@@ -1,9 +1,9 @@
 use super::*;
 use crate::{
     circuit::{layouts, utils, Tolerance},
-    fieldutils::{felt_to_i128, i128_to_felt},
+    fieldutils::i128_to_felt,
     graph::multiplier_to_scale,
-    tensor::{self, Tensor, TensorError, TensorType, ValTensor},
+    tensor::{self, Tensor, TensorType, ValTensor},
 };
 use halo2curves::ff::PrimeField;
 use serde::{Deserialize, Serialize};
@@ -29,15 +29,15 @@ pub enum HybridOp {
         dim: usize,
     },
     SumPool {
-        padding: [(usize, usize); 2],
-        stride: (usize, usize),
-        kernel_shape: (usize, usize),
+        padding: Vec<(usize, usize)>,
+        stride: Vec<usize>,
+        kernel_shape: Vec<usize>,
         normalized: bool,
     },
-    MaxPool2d {
-        padding: [(usize, usize); 2],
-        stride: (usize, usize),
-        pool_dims: (usize, usize),
+    MaxPool {
+        padding: Vec<(usize, usize)>,
+        stride: Vec<usize>,
+        pool_dims: Vec<usize>,
     },
     ReduceMin {
         axes: Vec<usize>,
@@ -85,93 +85,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
     fn as_any(&self) -> &dyn Any {
         self
     }
-    /// Matches a [Op] to an operation in the `tensor::ops` module.
-    fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
-        let x = inputs[0].clone().map(|x| felt_to_i128(x));
-
-        let res = match &self {
-            HybridOp::ReduceMax { axes, .. } => tensor::ops::max_axes(&x, axes)?,
-            HybridOp::ReduceMin { axes, .. } => tensor::ops::min_axes(&x, axes)?,
-            HybridOp::Div { denom, .. } => {
-                crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64)
-            }
-            HybridOp::Recip {
-                input_scale,
-                output_scale,
-                ..
-            } => crate::tensor::ops::nonlinearities::recip(
-                &x,
-                input_scale.0 as f64,
-                output_scale.0 as f64,
-            ),
-            HybridOp::ReduceArgMax { dim } => tensor::ops::argmax_axes(&x, *dim)?,
-            HybridOp::ReduceArgMin { dim } => tensor::ops::argmin_axes(&x, *dim)?,
-            HybridOp::Gather { dim, constant_idx } => {
-                if let Some(idx) = constant_idx {
-                    tensor::ops::gather(&x, idx, *dim)?
-                } else {
-                    let y = inputs[1].clone().map(|x| felt_to_i128(x));
-                    tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?
-                }
-            }
-            HybridOp::OneHot { dim, num_classes } => {
-                tensor::ops::one_hot(&x, *num_classes, *dim)?.clone()
-            }
-
-            HybridOp::TopK { dim, k, largest } => tensor::ops::topk_axes(&x, *k, *dim, *largest)?,
-            HybridOp::MaxPool2d {
-                padding,
-                stride,
-                pool_dims,
-                ..
-            } => tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
-            HybridOp::SumPool {
-                padding,
-                stride,
-                kernel_shape,
-                normalized,
-            } => tensor::ops::sumpool(&x, *padding, *stride, *kernel_shape, *normalized)?,
-            HybridOp::Softmax {
-                input_scale,
-                output_scale,
-                axes,
-            } => tensor::ops::nonlinearities::softmax_axes(
-                &x,
-                input_scale.into(),
-                output_scale.into(),
-                axes,
-            ),
-            HybridOp::RangeCheck(tol) => {
-                let y = inputs[1].clone().map(|x| felt_to_i128(x));
-                tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
-            }
-            HybridOp::Greater => {
-                let y = inputs[1].clone().map(|x| felt_to_i128(x));
-                tensor::ops::greater(&x, &y)?
-            }
-            HybridOp::GreaterEqual => {
-                let y = inputs[1].clone().map(|x| felt_to_i128(x));
-                tensor::ops::greater_equal(&x, &y)?
-            }
-            HybridOp::Less => {
-                let y = inputs[1].clone().map(|x| felt_to_i128(x));
-                tensor::ops::less(&x, &y)?
-            }
-            HybridOp::LessEqual => {
-                let y = inputs[1].clone().map(|x| felt_to_i128(x));
-                tensor::ops::less_equal(&x, &y)?
-            }
-            HybridOp::Equals => {
-                let y = inputs[1].clone().map(|x| felt_to_i128(x));
-                tensor::ops::equals(&x, &y)?
-            }
-        };
-
-        // convert back to felt
-        let output = res.map(|x| i128_to_felt(x));
-
-        Ok(ForwardResult { output })
-    }
 
     fn as_string(&self) -> String {
         match self {
@@ -201,12 +114,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
             ),
             HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
             HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
-            HybridOp::MaxPool2d {
+            HybridOp::MaxPool {
                 padding,
                 stride,
                 pool_dims,
             } => format!(
-                "MAXPOOL2D (padding={:?}, stride={:?}, pool_dims={:?})",
+                "MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
                 padding, stride, pool_dims
             ),
             HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
@@ -253,9 +166,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                 config,
                 region,
                 values[..].try_into()?,
-                *padding,
-                *stride,
-                *kernel_shape,
+                padding,
+                stride,
+                kernel_shape,
                 *normalized,
             )?,
             HybridOp::Recip {
@@ -315,17 +228,17 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
                 }
             }
 
-            HybridOp::MaxPool2d {
+            HybridOp::MaxPool {
                 padding,
                 stride,
                 pool_dims,
-            } => layouts::max_pool2d(
+            } => layouts::max_pool(
                 config,
                 region,
                 values[..].try_into()?,
-                *padding,
-                *stride,
-                *pool_dims,
+                padding,
+                stride,
+                pool_dims,
             )?,
             HybridOp::ReduceMax { axes } => {
                 layouts::max_axes(config, region, values[..].try_into()?, axes)?
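
The two pooling variants above trade fixed 2-D tuples for Vec-based fields, so one enum case can describe pooling over any number of spatial axes. Below is a minimal sketch of constructing the renamed op under the new field types; the enum and field names are taken from the hunks above, while everything else (imports, the surrounding circuit) is elided, so treat it as illustration rather than the crate's exact API surface.

    // Sketch only: the Vec-based fields introduced in this diff.
    let pool_2d = HybridOp::MaxPool {
        padding: vec![(1, 1); 2], // one (before, after) pair per spatial axis
        stride: vec![2, 2],       // one stride per spatial axis
        pool_dims: vec![2, 2],    // kernel extent per spatial axis
    };
    // Pooling over a single axis is now expressible with one-element vectors,
    // which the old [(usize, usize); 2] / (usize, usize) fields could not say:
    let pool_1d = HybridOp::MaxPool {
        padding: vec![(0, 0)],
        stride: vec![1],
        pool_dims: vec![3],
    };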
(File diff suppressed because it is too large.)
@@ -135,15 +135,12 @@ impl LookupOp {
         let range = range as i128;
         (-range, range)
     }
-}
-
-impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
-    /// Returns a reference to the Any trait.
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
+
     /// Matches a [Op] to an operation in the `tensor::ops` module.
-    fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
+    pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
+        &self,
+        x: &[Tensor<F>],
+    ) -> Result<ForwardResult<F>, TensorError> {
         let x = x[0].clone().map(|x| felt_to_i128(x));
         let res = match &self {
             LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
@@ -235,6 +232,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
 
         Ok(ForwardResult { output })
     }
+}
+
+impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
+    /// Returns a reference to the Any trait.
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
 
     /// Returns the name of the operation
     fn as_string(&self) -> String {
@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
 
 use crate::{
     graph::quantize_tensor,
-    tensor::{self, Tensor, TensorError, TensorType, ValTensor},
+    tensor::{self, Tensor, TensorType, ValTensor},
 };
 use halo2curves::ff::PrimeField;
 
@@ -35,8 +35,6 @@ pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Ha
 pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
     std::fmt::Debug + Send + Sync + Any
 {
-    /// Matches a [Op] to an operation in the `tensor::ops` module.
-    fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError>;
     /// Returns a string representation of the operation.
     fn as_string(&self) -> String;
 
@@ -71,33 +69,6 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
 
     /// Returns a reference to the Any trait.
     fn as_any(&self) -> &dyn Any;
-
-    /// Safe mode output checl
-    fn safe_mode_check(
-        &self,
-        claimed_output: &ValTensor<F>,
-        original_values: &[ValTensor<F>],
-    ) -> Result<(), TensorError> {
-        let felt_evals = original_values
-            .iter()
-            .map(|v| {
-                let mut evals = v.get_felt_evals().map_err(|_| TensorError::FeltError)?;
-                evals.reshape(v.dims())?;
-                Ok(evals)
-            })
-            .collect::<Result<Vec<_>, _>>()?;
-
-        let ref_op: Tensor<F> = self.f(&felt_evals)?.output;
-
-        let mut output = claimed_output
-            .get_felt_evals()
-            .map_err(|_| TensorError::FeltError)?;
-        output.reshape(claimed_output.dims())?;
-
-        assert_eq!(output, ref_op);
-
-        Ok(())
-    }
 }
 
 impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
@@ -176,12 +147,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
         self
     }
 
-    fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
-        Ok(ForwardResult {
-            output: x[0].clone(),
-        })
-    }
-
     fn as_string(&self) -> String {
         "Input".into()
     }
@@ -235,9 +200,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknow
     fn as_any(&self) -> &dyn Any {
         self
     }
-    fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
-        Err(TensorError::WrongMethod)
-    }
 
     fn as_string(&self) -> String {
         "Unknown".into()
@@ -307,11 +269,6 @@ impl<
     fn as_any(&self) -> &dyn Any {
         self
     }
-    fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
-        let output = self.quantized_values.clone();
-
-        Ok(ForwardResult { output })
-    }
 
     fn as_string(&self) -> String {
         format!("CONST (scale={})", self.quantized_values.scale().unwrap())
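
With `f` and `safe_mode_check` deleted, the `Op` trait no longer carries a witness-free forward pass; evaluation now lives on concrete op types instead (see the inherent `LookupOp::f` earlier in this diff). The following is a rough sketch of the trait shape these removals imply, with bounds and signatures abbreviated and several remaining methods omitted, so it is an assumption-laden outline rather than the crate's literal definition.

    // Abbreviated sketch of the slimmed-down trait; generic bounds and
    // several methods (out_scale, clone_dyn, ...) are omitted for brevity.
    pub trait Op<F>: std::fmt::Debug + Send + Sync + Any {
        /// String representation, used by the debug logging in graph layout.
        fn as_string(&self) -> String;
        /// Lay the op out in the circuit; after this diff it is the only
        /// evaluation entry point, with no trait-level f() forward pass
        /// and no safe_mode_check() re-execution.
        fn layout(
            &self,
            config: &mut BaseConfig<F>,
            region: &mut RegionCtx<F>,
            values: &[ValTensor<F>],
        ) -> Result<Option<ValTensor<F>>, Box<dyn Error>>;
        /// Downcasting hook.
        fn as_any(&self) -> &dyn Any;
    }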
@@ -1,6 +1,5 @@
 use crate::{
     circuit::layouts,
-    fieldutils::felt_to_i128,
     tensor::{self, Tensor, TensorError},
 };
 
@@ -32,8 +31,8 @@ pub enum PolyOp {
         equation: String,
     },
     Conv {
-        padding: [(usize, usize); 2],
-        stride: (usize, usize),
+        padding: Vec<(usize, usize)>,
+        stride: Vec<usize>,
     },
     Downsample {
         axis: usize,
@@ -41,9 +40,9 @@ pub enum PolyOp {
         modulo: usize,
     },
     DeConv {
-        padding: [(usize, usize); 2],
-        output_padding: (usize, usize),
-        stride: (usize, usize),
+        padding: Vec<(usize, usize)>,
+        output_padding: Vec<usize>,
+        stride: Vec<usize>,
     },
     Add,
     Sub,
@@ -58,10 +57,13 @@ pub enum PolyOp {
         destination: usize,
     },
     Flatten(Vec<usize>),
-    Pad([(usize, usize); 2]),
+    Pad(Vec<(usize, usize)>),
     Sum {
         axes: Vec<usize>,
     },
+    MeanOfSquares {
+        axes: Vec<usize>,
+    },
     Prod {
         axes: Vec<usize>,
         len_prod: usize,
@@ -105,10 +107,28 @@ impl<
 
     fn as_string(&self) -> String {
         match &self {
-            PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
-            PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
-            PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
-            PolyOp::ScatterND { .. } => "SCATTERND".into(),
+            PolyOp::GatherElements { dim, constant_idx } => format!(
+                "GATHERELEMENTS (dim={}, constant_idx{})",
+                dim,
+                constant_idx.is_some()
+            ),
+            PolyOp::GatherND {
+                batch_dims,
+                indices,
+            } => format!(
+                "GATHERND (batch_dims={}, constant_idx{})",
+                batch_dims,
+                indices.is_some()
+            ),
+            PolyOp::MeanOfSquares { axes } => format!("MEANOFSQUARES (axes={:?})", axes),
+            PolyOp::ScatterElements { dim, constant_idx } => format!(
+                "SCATTERELEMENTS (dim={}, constant_idx{})",
+                dim,
+                constant_idx.is_some()
+            ),
+            PolyOp::ScatterND { constant_idx } => {
+                format!("SCATTERND (constant_idx={})", constant_idx.is_some())
+            }
             PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
             PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
             PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -120,15 +140,26 @@ impl<
             }
             PolyOp::Reshape(shape) => format!("RESHAPE (shape={:?})", shape),
             PolyOp::Flatten(_) => "FLATTEN".into(),
-            PolyOp::Pad(_) => "PAD".into(),
+            PolyOp::Pad(pads) => format!("PAD (pads={:?})", pads),
             PolyOp::Add => "ADD".into(),
             PolyOp::Mult => "MULT".into(),
             PolyOp::Sub => "SUB".into(),
             PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
             PolyOp::Prod { .. } => "PROD".into(),
             PolyOp::Pow(_) => "POW".into(),
-            PolyOp::Conv { .. } => "CONV".into(),
-            PolyOp::DeConv { .. } => "DECONV".into(),
+            PolyOp::Conv { stride, padding } => {
+                format!("CONV (stride={:?}, padding={:?})", stride, padding)
+            }
+            PolyOp::DeConv {
+                stride,
+                padding,
+                output_padding,
+            } => {
+                format!(
+                    "DECONV (stride={:?}, padding={:?}, output_padding={:?})",
+                    stride, padding, output_padding
+                )
+            }
             PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
             PolyOp::Slice { axis, start, end } => {
                 format!("SLICE (axis={}, start={}, end={})", axis, start, end)
@@ -142,146 +173,6 @@ impl<
         }
     }
 
-    /// Matches a [Op] to an operation in the `tensor::ops` module.
-    fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
-        let mut inputs = inputs.to_vec();
-        let res = match &self {
-            PolyOp::MultiBroadcastTo { shape } => {
-                if 1 != inputs.len() {
-                    return Err(TensorError::DimMismatch(
-                        "multibroadcastto inputs".to_string(),
-                    ));
-                }
-                inputs[0].expand(shape)
-            }
-            PolyOp::And => tensor::ops::and(&inputs[0], &inputs[1]),
-            PolyOp::Or => tensor::ops::or(&inputs[0], &inputs[1]),
-            PolyOp::Xor => tensor::ops::xor(&inputs[0], &inputs[1]),
-            PolyOp::Not => tensor::ops::not(&inputs[0]),
-            PolyOp::Downsample {
-                axis,
-                stride,
-                modulo,
-            } => tensor::ops::downsample(&inputs[0], *axis, *stride, *modulo),
-            PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
-            PolyOp::Iff => tensor::ops::iff(&inputs[0], &inputs[1], &inputs[2]),
-            PolyOp::Einsum { equation } => tensor::ops::einsum(equation, &inputs),
-            PolyOp::Identity { .. } => Ok(inputs[0].clone()),
-            PolyOp::Reshape(new_dims) => {
-                let mut t = inputs[0].clone();
-                t.reshape(new_dims)?;
-                Ok(t)
-            }
-            PolyOp::MoveAxis {
-                source,
-                destination,
-            } => inputs[0].move_axis(*source, *destination),
-            PolyOp::Flatten(new_dims) => {
-                let mut t = inputs[0].clone();
-                t.reshape(new_dims)?;
-                Ok(t)
-            }
-            PolyOp::Pad(p) => {
-                if 1 != inputs.len() {
-                    return Err(TensorError::DimMismatch("pad inputs".to_string()));
-                }
-                tensor::ops::pad(&inputs[0], *p)
-            }
-            PolyOp::Add => tensor::ops::add(&inputs),
-            PolyOp::Neg => tensor::ops::neg(&inputs[0]),
-            PolyOp::Sub => tensor::ops::sub(&inputs),
-            PolyOp::Mult => tensor::ops::mult(&inputs),
-            PolyOp::Conv { padding, stride } => tensor::ops::conv(&inputs, *padding, *stride),
-            PolyOp::DeConv {
-                padding,
-                output_padding,
-                stride,
-            } => tensor::ops::deconv(&inputs, *padding, *output_padding, *stride),
-            PolyOp::Pow(u) => {
-                if 1 != inputs.len() {
-                    return Err(TensorError::DimMismatch("pow inputs".to_string()));
-                }
-                inputs[0].pow(*u)
-            }
-            PolyOp::Sum { axes } => {
-                if 1 != inputs.len() {
-                    return Err(TensorError::DimMismatch("sum inputs".to_string()));
-                }
-                tensor::ops::sum_axes(&inputs[0], axes)
-            }
-            PolyOp::Prod { axes, .. } => {
-                if 1 != inputs.len() {
-                    return Err(TensorError::DimMismatch("prod inputs".to_string()));
-                }
-                tensor::ops::prod_axes(&inputs[0], axes)
-            }
-            PolyOp::Concat { axis } => {
-                tensor::ops::concat(&inputs.iter().collect::<Vec<_>>(), *axis)
-            }
-            PolyOp::Slice { axis, start, end } => {
-                if 1 != inputs.len() {
-                    return Err(TensorError::DimMismatch("slice inputs".to_string()));
-                }
-                tensor::ops::slice(&inputs[0], axis, start, end)
-            }
-            PolyOp::GatherElements { dim, constant_idx } => {
-                let x = inputs[0].clone();
-                let y = if let Some(idx) = constant_idx {
-                    idx.clone()
-                } else {
-                    inputs[1].clone().map(|x| felt_to_i128(x) as usize)
-                };
-                tensor::ops::gather_elements(&x, &y, *dim)
-            }
-            PolyOp::GatherND {
-                indices,
-                batch_dims,
-            } => {
-                let x = inputs[0].clone();
-                let y = if let Some(idx) = indices {
-                    idx.clone()
-                } else {
-                    inputs[1].clone().map(|x| felt_to_i128(x) as usize)
-                };
-                tensor::ops::gather_nd(&x, &y, *batch_dims)
-            }
-            PolyOp::ScatterElements { dim, constant_idx } => {
-                let x = inputs[0].clone();
-
-                let idx = if let Some(idx) = constant_idx {
-                    idx.clone()
-                } else {
-                    inputs[1].clone().map(|x| felt_to_i128(x) as usize)
-                };
-
-                let src = if constant_idx.is_some() {
-                    inputs[1].clone()
-                } else {
-                    inputs[2].clone()
-                };
-                tensor::ops::scatter(&x, &idx, &src, *dim)
-            }
-
-            PolyOp::ScatterND { constant_idx } => {
-                let x = inputs[0].clone();
-                let idx = if let Some(idx) = constant_idx {
-                    idx.clone()
-                } else {
-                    inputs[1].clone().map(|x| felt_to_i128(x) as usize)
-                };
-                let src = if constant_idx.is_some() {
-                    inputs[1].clone()
-                } else {
-                    inputs[2].clone()
-                };
-                tensor::ops::scatter_nd(&x, &idx, &src)
-            }
-            PolyOp::Trilu { upper, k } => tensor::ops::trilu(&inputs[0], *k, *upper),
-        }?;
-
-        Ok(ForwardResult { output: res })
-    }
-
     fn layout(
         &self,
         config: &mut crate::circuit::BaseConfig<F>,
@@ -292,6 +183,9 @@ impl<
             PolyOp::MultiBroadcastTo { shape } => {
                 layouts::expand(config, region, values[..].try_into()?, shape)?
             }
+            PolyOp::MeanOfSquares { axes } => {
+                layouts::mean_of_squares_axes(config, region, values[..].try_into()?, axes)?
+            }
             PolyOp::Xor => layouts::xor(config, region, values[..].try_into()?)?,
             PolyOp::Or => layouts::or(config, region, values[..].try_into()?)?,
             PolyOp::And => layouts::and(config, region, values[..].try_into()?)?,
@@ -318,7 +212,7 @@ impl<
                 layouts::prod_axes(config, region, values[..].try_into()?, axes)?
             }
             PolyOp::Conv { padding, stride } => {
-                layouts::conv(config, region, values[..].try_into()?, *padding, *stride)?
+                layouts::conv(config, region, values[..].try_into()?, padding, stride)?
             }
             PolyOp::GatherElements { dim, constant_idx } => {
                 if let Some(idx) = constant_idx {
@@ -370,9 +264,9 @@ impl<
                 config,
                 region,
                 values[..].try_into()?,
-                *padding,
-                *output_padding,
-                *stride,
+                padding,
+                output_padding,
+                stride,
             )?,
             PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
             PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
@@ -388,7 +282,7 @@ impl<
                 )));
                 }
                 let mut input = values[0].clone();
-                input.pad(*p)?;
+                input.pad(p.clone(), 0)?;
                 input
             }
             PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp)?,
@@ -404,6 +298,7 @@ impl<
 
     fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
         let scale = match self {
+            PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
            PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
             PolyOp::Iff => in_scales[1],
             PolyOp::Einsum { .. } => {
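
The polynomial ops follow the same generalization: Conv, DeConv, and Pad now carry Vec-based hyperparameters, and Pad takes one (before, after) pair per padded axis rather than a fixed two-element array. A minimal construction sketch under the new field types (names from the hunks above; all surrounding context elided, so this is illustrative rather than the crate's exact API):

    // Sketch only: a conventional 2-D convolution under the new types.
    let conv2d = PolyOp::Conv {
        padding: vec![(1, 1); 2], // one (before, after) pair per spatial axis
        stride: vec![2; 2],       // one stride per spatial axis
    };
    // Pad is per-axis now, so padding all three axes of a rank-3 tensor
    // is expressible where the old Pad([(usize, usize); 2]) was not:
    let pad = PolyOp::Pad(vec![(0, 0), (1, 1), (1, 1)]);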
@@ -17,8 +17,6 @@ use crate::{
 
 use crate::circuit::lookup::LookupOp;
 
-use super::Op;
-
 /// The range of the lookup table.
 pub type Range = (i128, i128);
 
@@ -113,11 +111,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
         let chunk = chunk as i128;
         // we index from 1 to prevent soundness issues
         let first_element = i128_to_felt(chunk * (self.col_size as i128) + self.range.0);
-        let op_f = Op::<F>::f(
-            &self.nonlinearity,
-            &[Tensor::from(vec![first_element].into_iter())],
-        )
-        .unwrap();
+        let op_f = self
+            .nonlinearity
+            .f(&[Tensor::from(vec![first_element].into_iter())])
+            .unwrap();
         (first_element, op_f.output[0])
     }
 
@@ -205,8 +202,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
         let smallest = self.range.0;
         let largest = self.range.1;
 
-        let inputs = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
-        let evals = Op::<F>::f(&self.nonlinearity, &[inputs.clone()])?;
+        let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
+        let evals = self.nonlinearity.f(&[inputs.clone()])?;
         let chunked_inputs = inputs.chunks(self.col_size);
 
         self.is_assigned = true;
@@ -1048,8 +1048,8 @@ mod conv {
                         &mut region,
                         &self.inputs,
                         Box::new(PolyOp::Conv {
-                            padding: [(1, 1); 2],
-                            stride: (2, 2),
+                            padding: vec![(1, 1); 2],
+                            stride: vec![2; 2],
                         }),
                     )
                     .map_err(|_| Error::Synthesis)
@@ -1198,8 +1198,8 @@ mod conv_col_ultra_overflow {
                        &mut region,
                        &[self.image.clone(), self.kernel.clone()],
                        Box::new(PolyOp::Conv {
-                            padding: [(1, 1); 2],
-                            stride: (2, 2),
+                            padding: vec![(1, 1); 2],
+                            stride: vec![2; 2],
                        }),
                    )
                    .map_err(|_| Error::Synthesis)
@@ -1343,8 +1343,8 @@ mod conv_relu_col_ultra_overflow {
                        &mut region,
                        &[self.image.clone(), self.kernel.clone()],
                        Box::new(PolyOp::Conv {
-                            padding: [(1, 1); 2],
-                            stride: (2, 2),
+                            padding: vec![(1, 1); 2],
+                            stride: vec![2; 2],
                        }),
                    )
                    .map_err(|_| Error::Synthesis);
@@ -1200,6 +1200,20 @@ impl Model {
             .collect();
 
         for (idx, node) in self.graph.nodes.iter() {
+            debug!("laying out {}: {}", idx, node.as_str(),);
+            // Then number of columns in the circuits
+            #[cfg(not(target_arch = "wasm32"))]
+            region.debug_report();
+            debug!("input indices: {:?}", node.inputs());
+            debug!("output scales: {:?}", node.out_scales());
+            debug!(
+                "input scales: {:?}",
+                node.inputs()
+                    .iter()
+                    .map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
+                    .collect_vec()
+            );
+
             let mut values: Vec<ValTensor<Fp>> = if !node.is_input() {
                 node.inputs()
                     .iter()
@@ -1211,25 +1225,11 @@ impl Model {
                 // we re-assign inputs, always from the 0 outlet
                 vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
             };
 
-            debug!("laying out {}: {}", idx, node.as_str(),);
-            // Then number of columns in the circuits
-            #[cfg(not(target_arch = "wasm32"))]
-            region.debug_report();
-            debug!("dims: {:?}", node.out_dims());
+            debug!("output dims: {:?}", node.out_dims());
             debug!(
-                "input_dims {:?}",
+                "input dims {:?}",
                 values.iter().map(|v| v.dims()).collect_vec()
             );
-            debug!("output scales: {:?}", node.out_scales());
-            debug!("input indices: {:?}", node.inputs());
-            debug!(
-                "input scales: {:?}",
-                node.inputs()
-                    .iter()
-                    .map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
-                    .collect_vec()
-            );
 
             match &node {
                 NodeType::Node(n) => {
@@ -14,7 +14,6 @@ use crate::circuit::Op;
 use crate::circuit::Unknown;
 #[cfg(not(target_arch = "wasm32"))]
 use crate::graph::new_op_from_onnx;
-use crate::tensor::Tensor;
 use crate::tensor::TensorError;
 use halo2curves::bn256::Fr as Fp;
 #[cfg(not(target_arch = "wasm32"))]
@@ -61,20 +60,6 @@ impl Op<Fp> for Rescaled {
     fn as_any(&self) -> &dyn std::any::Any {
         self
     }
-    fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
-        if self.scale.len() != x.len() {
-            return Err(TensorError::DimMismatch("rescaled inputs".to_string()));
-        }
-
-        let mut rescaled_inputs = vec![];
-        let inputs = &mut x.to_vec();
-        for (i, ri) in inputs.iter_mut().enumerate() {
-            let mult_tensor = Tensor::from([Fp::from(self.scale[i].1 as u64)].into_iter());
-            let res = (ri.clone() * mult_tensor)?;
-            rescaled_inputs.push(res);
-        }
-        Op::<Fp>::f(&*self.inner, &rescaled_inputs)
-    }
 
     fn as_string(&self) -> String {
         format!("RESCALED INPUT ({})", self.inner.as_string())
@@ -215,13 +200,6 @@ impl Op<Fp> for RebaseScale {
     fn as_any(&self) -> &dyn std::any::Any {
         self
     }
-    fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
-        let mut res = Op::<Fp>::f(&*self.inner, x)?;
-        let rebase_res = Op::<Fp>::f(&self.rebase_op, &[res.output])?;
-        res.output = rebase_res.output;
-
-        Ok(res)
-    }
 
     fn as_string(&self) -> String {
         format!(
@@ -389,13 +367,6 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
 }
 
 impl Op<Fp> for SupportedOp {
-    fn f(
-        &self,
-        inputs: &[Tensor<Fp>],
-    ) -> Result<crate::circuit::ForwardResult<Fp>, crate::tensor::TensorError> {
-        self.as_op().f(inputs)
-    }
-
     fn layout(
         &self,
         config: &mut crate::circuit::BaseConfig<Fp>,
@@ -509,7 +509,7 @@ pub fn new_op_from_onnx(
             // if param_visibility.is_public() {
             if let Some(c) = inputs[1].opkind().get_mutable_constant() {
                 inputs[1].decrement_use();
-                deleted_indices.push(inputs.len() - 1);
+                deleted_indices.push(1);
                 op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
                     constant_idx: Some(c.raw_values.map(|x| x as usize)),
                 })
@@ -545,7 +545,7 @@ pub fn new_op_from_onnx(
             // if param_visibility.is_public() {
             if let Some(c) = inputs[1].opkind().get_mutable_constant() {
                 inputs[1].decrement_use();
-                deleted_indices.push(inputs.len() - 1);
+                deleted_indices.push(1);
                 op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
                     batch_dims,
                     indices: Some(c.raw_values.map(|x| x as usize)),
@@ -582,7 +582,7 @@ pub fn new_op_from_onnx(
             // if param_visibility.is_public() {
             if let Some(c) = inputs[1].opkind().get_mutable_constant() {
                 inputs[1].decrement_use();
-                deleted_indices.push(inputs.len() - 1);
+                deleted_indices.push(1);
                 op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
                     dim: axis,
                     constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -734,6 +734,19 @@ pub fn new_op_from_onnx(
 
             SupportedOp::Linear(PolyOp::Sum { axes })
         }
+        "Reduce<MeanOfSquares>" => {
+            if inputs.len() != 1 {
+                return Err(Box::new(GraphError::InvalidDims(
+                    idx,
+                    "mean of squares".to_string(),
+                )));
+            };
+            let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
+            let axes = op.axes.into_iter().collect();
+
+            SupportedOp::Linear(PolyOp::MeanOfSquares { axes })
+        }
+
         "Max" => {
             // Extract the max value
             // first find the input that is a constant
@@ -1106,17 +1119,7 @@ pub fn new_op_from_onnx(
                 .ok_or(GraphError::MissingParams("stride".to_string()))?;
             let padding = match &pool_spec.padding {
                 PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
-                    if b.len() == 2 && a.len() == 2 {
-                        [(b[0], b[1]), (a[0], a[1])]
-                    } else if b.len() == 1 && a.len() == 1 {
-                        [(b[0], b[0]), (a[0], a[0])]
-                    } else if b.len() == 1 && a.len() == 2 {
-                        [(b[0], b[0]), (a[0], a[1])]
-                    } else if b.len() == 2 && a.len() == 1 {
-                        [(b[0], b[1]), (a[0], a[0])]
-                    } else {
-                        return Err(Box::new(GraphError::MissingParams("padding".to_string())));
-                    }
+                    b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
                 }
                 _ => {
                     return Err(Box::new(GraphError::MissingParams("padding".to_string())));
@@ -1124,26 +1127,10 @@ pub fn new_op_from_onnx(
             };
             let kernel_shape = &pool_spec.kernel_shape;
 
-            let (stride_h, stride_w) = if stride.len() == 1 {
-                (1, stride[0])
-            } else if stride.len() == 2 {
-                (stride[0], stride[1])
-            } else {
-                return Err(Box::new(GraphError::MissingParams("stride".to_string())));
-            };
-
-            let (kernel_height, kernel_width) = if kernel_shape.len() == 1 {
-                (1, kernel_shape[0])
-            } else if kernel_shape.len() == 2 {
-                (kernel_shape[0], kernel_shape[1])
-            } else {
-                return Err(Box::new(GraphError::MissingParams("kernel".to_string())));
-            };
-
-            SupportedOp::Hybrid(HybridOp::MaxPool2d {
+            SupportedOp::Hybrid(HybridOp::MaxPool {
                 padding,
-                stride: (stride_h, stride_w),
-                pool_dims: (kernel_height, kernel_width),
+                stride: stride.to_vec(),
+                pool_dims: kernel_shape.to_vec(),
             })
         }
         "Ceil" => SupportedOp::Nonlinear(LookupOp::Ceil {
@@ -1165,7 +1152,7 @@ pub fn new_op_from_onnx(
             // if param_visibility.is_public() {
             if let Some(c) = inputs[1].opkind().get_mutable_constant() {
                 inputs[1].decrement_use();
-                deleted_indices.push(inputs.len() - 1);
+                deleted_indices.push(1);
                 if c.raw_values.len() > 1 {
                     unimplemented!("only support scalar pow")
                 }
@@ -1205,15 +1192,7 @@ pub fn new_op_from_onnx(
             }
 
             let stride = match conv_node.pool_spec.strides.clone() {
-                Some(s) => {
-                    if s.len() == 1 {
-                        (s[0], s[0])
-                    } else if s.len() == 2 {
-                        (s[0], s[1])
-                    } else {
-                        return Err(Box::new(GraphError::MissingParams("strides".to_string())));
-                    }
-                }
+                Some(s) => s.to_vec(),
                 None => {
                     return Err(Box::new(GraphError::MissingParams("strides".to_string())));
                 }
@@ -1221,17 +1200,7 @@ pub fn new_op_from_onnx(
 
             let padding = match &conv_node.pool_spec.padding {
                 PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
-                    if b.len() == 2 && a.len() == 2 {
-                        [(b[0], b[1]), (a[0], a[1])]
-                    } else if b.len() == 1 && a.len() == 1 {
-                        [(b[0], b[0]), (a[0], a[0])]
-                    } else if b.len() == 1 && a.len() == 2 {
-                        [(b[0], b[0]), (a[0], a[1])]
-                    } else if b.len() == 2 && a.len() == 1 {
-                        [(b[0], b[1]), (a[0], a[0])]
-                    } else {
-                        return Err(Box::new(GraphError::MissingParams("padding".to_string())));
-                    }
+                    b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
                 }
                 _ => {
                     return Err(Box::new(GraphError::MissingParams("padding".to_string())));
@@ -1286,33 +1255,20 @@ pub fn new_op_from_onnx(
             }
 
             let stride = match deconv_node.pool_spec.strides.clone() {
-                Some(s) => (s[0], s[1]),
+                Some(s) => s.to_vec(),
                 None => {
                     return Err(Box::new(GraphError::MissingParams("strides".to_string())));
                 }
             };
             let padding = match &deconv_node.pool_spec.padding {
                 PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
-                    if b.len() == 2 && a.len() == 2 {
-                        [(b[0], b[1]), (a[0], a[1])]
-                    } else if b.len() == 1 && a.len() == 1 {
-                        [(b[0], b[0]), (a[0], a[0])]
-                    } else if b.len() == 1 && a.len() == 2 {
-                        [(b[0], b[0]), (a[0], a[1])]
-                    } else if b.len() == 2 && a.len() == 1 {
-                        [(b[0], b[1]), (a[0], a[0])]
-                    } else {
-                        return Err(Box::new(GraphError::MissingParams("padding".to_string())));
-                    }
+                    b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
                 }
                 _ => {
                     return Err(Box::new(GraphError::MissingParams("padding".to_string())));
                 }
             };
 
-            let output_padding: (usize, usize) =
-                (deconv_node.adjustments[0], deconv_node.adjustments[1]);
-
             // if bias exists then rescale it to the input + kernel scale
             if input_scales.len() == 3 {
                 let bias_scale = input_scales[2];
@@ -1331,7 +1287,7 @@ pub fn new_op_from_onnx(
 
             SupportedOp::Linear(PolyOp::DeConv {
                 padding,
-                output_padding,
+                output_padding: deconv_node.adjustments.to_vec(),
                 stride,
             })
         }
@@ -1432,46 +1388,17 @@ pub fn new_op_from_onnx(
                 .ok_or(GraphError::MissingParams("stride".to_string()))?;
             let padding = match &pool_spec.padding {
                 PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
-                    if b.len() == 2 && a.len() == 2 {
-                        [(b[0], b[1]), (a[0], a[1])]
-                    } else if b.len() == 1 && a.len() == 1 {
-                        [(b[0], b[0]), (a[0], a[0])]
-                    } else if b.len() == 1 && a.len() == 2 {
-                        [(b[0], b[0]), (a[0], a[1])]
-                    } else if b.len() == 2 && a.len() == 1 {
-                        [(b[0], b[1]), (a[0], a[0])]
-                    } else {
-                        return Err(Box::new(GraphError::MissingParams("padding".to_string())));
-                    }
+                    b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
                 }
                 _ => {
                     return Err(Box::new(GraphError::MissingParams("padding".to_string())));
                 }
             };
-            let kernel_shape = &pool_spec.kernel_shape;
-
-            let (stride_h, stride_w) = if stride.len() == 1 {
-                (1, stride[0])
-            } else if stride.len() == 2 {
-                (stride[0], stride[1])
-            } else {
-                return Err(Box::new(GraphError::MissingParams("stride".to_string())));
-            };
-
-            let (kernel_height, kernel_width) = if kernel_shape.len() == 1 {
-                (1, kernel_shape[0])
-            } else if kernel_shape.len() == 2 {
-                (kernel_shape[0], kernel_shape[1])
-            } else {
-                return Err(Box::new(GraphError::MissingParams(
-                    "kernel shape".to_string(),
-                )));
-            };
-
             SupportedOp::Hybrid(HybridOp::SumPool {
                 padding,
-                stride: (stride_h, stride_w),
-                kernel_shape: (kernel_height, kernel_width),
+                stride: stride.to_vec(),
+                kernel_shape: pool_spec.kernel_shape.to_vec(),
                 normalized: sumpool_node.normalize,
             })
         }
@@ -1498,29 +1425,7 @@ pub fn new_op_from_onnx(
             )));
             }
 
-            let padding_len = pad_node.pads.len();
-
-            // we only support symmetrical padding that affects the last 2 dims (height and width params)
-            for (i, pad_params) in pad_node.pads.iter().enumerate() {
-                if (i < padding_len - 2) && ((pad_params.0 != 0) || (pad_params.1 != 0)) {
-                    return Err(Box::new(GraphError::MisformedParams(
-                        "ezkl currently only supports padding height and width dimensions"
-                            .to_string(),
-                    )));
-                }
-            }
-
-            let padding = [
-                (
-                    pad_node.pads[padding_len - 2].0,
-                    pad_node.pads[padding_len - 1].0,
-                ),
-                (
-                    pad_node.pads[padding_len - 2].1,
-                    pad_node.pads[padding_len - 1].1,
-                ),
-            ];
-            SupportedOp::Linear(PolyOp::Pad(padding))
+            SupportedOp::Linear(PolyOp::Pad(pad_node.pads.to_vec()))
         }
         "RmAxis" | "Reshape" | "AddAxis" => {
             // Extract the slope layer hyperparams
680
src/python.rs
680
src/python.rs
@@ -33,6 +33,7 @@ use tokio::runtime::Runtime;

type PyFelt = String;

/// pyclass representing an enum
#[pyclass]
#[derive(Debug, Clone)]
enum PyTestDataSource {
@@ -51,14 +52,17 @@ impl From<PyTestDataSource> for TestDataSource {
}
}

/// pyclass containing the struct used for G1
/// pyclass containing the struct used for G1, this is mostly a helper class
#[pyclass]
#[derive(Debug, Clone)]
struct PyG1 {
#[pyo3(get, set)]
/// Field Element representing x
x: PyFelt,
#[pyo3(get, set)]
/// Field Element representing y
y: PyFelt,
/// Field Element representing z
#[pyo3(get, set)]
z: PyFelt,
}
@@ -134,39 +138,59 @@ impl pyo3::ToPyObject for PyG1Affine {
}
}

/// pyclass containing the struct used for run_args
/// Python class containing the struct used for run_args
///
/// Returns
/// -------
/// PyRunArgs
///
#[pyclass]
#[derive(Clone)]
struct PyRunArgs {
#[pyo3(get, set)]
/// float: The tolerance for error on model outputs
pub tolerance: f32,
#[pyo3(get, set)]
/// int: The denominator in the fixed point representation used when quantizing inputs
pub input_scale: crate::Scale,
#[pyo3(get, set)]
/// int: The denominator in the fixed point representation used when quantizing parameters
pub param_scale: crate::Scale,
#[pyo3(get, set)]
/// int: If the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this is a more advanced parameter, use with caution)
pub scale_rebase_multiplier: u32,
#[pyo3(get, set)]
/// list[int]: The min and max elements in the lookup table input column
pub lookup_range: crate::circuit::table::Range,
#[pyo3(get, set)]
/// int: The log_2 number of rows
pub logrows: u32,
#[pyo3(get, set)]
/// int: The number of inner columns used for the lookup table
pub num_inner_cols: usize,
#[pyo3(get, set)]
/// string: accepts `public`, `private`, `fixed`, `hashed/public`, `hashed/private`, `polycommit`
pub input_visibility: Visibility,
#[pyo3(get, set)]
/// string: accepts `public`, `private`, `fixed`, `hashed/public`, `hashed/private`, `polycommit`
pub output_visibility: Visibility,
#[pyo3(get, set)]
/// string: accepts `public`, `private`, `fixed`, `hashed/public`, `hashed/private`, `polycommit`
pub param_visibility: Visibility,
#[pyo3(get, set)]
/// list[tuple[str, int]]: Hand-written parser for graph variables, eg. batch_size=1
pub variables: Vec<(String, usize)>,
#[pyo3(get, set)]
/// bool: Rebase the scale using lookup table for division instead of using a range check
pub div_rebasing: bool,
#[pyo3(get, set)]
/// bool: Should constants with 0.0 fraction be rebased to scale 0
pub rebase_frac_zero_constants: bool,
#[pyo3(get, set)]
/// str: check mode, accepts `safe`, `unsafe`
pub check_mode: CheckMode,
#[pyo3(get, set)]
/// str: commitment type, accepts `kzg`, `ipa`
pub commitment: PyCommitments,
}
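For orientation, a minimal sketch of how these run args surface on the Python side (assuming the published `ezkl` package; the attribute names mirror the struct fields above):

```python
import ezkl

# sketch: configure run args before generating circuit settings
run_args = ezkl.PyRunArgs()
run_args.input_visibility = "public"   # accepts public/private/fixed/hashed/polycommit
run_args.param_visibility = "fixed"
run_args.output_visibility = "public"
run_args.input_scale = 7               # fixed-point denominator is 2^7
run_args.logrows = 17
```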

@@ -226,7 +250,7 @@ impl Into<PyRunArgs> for RunArgs {

#[pyclass]
#[derive(Debug, Clone)]
/// Pyclass marking the type of commitment
/// pyclass representing an enum, denoting the type of commitment
pub enum PyCommitments {
/// KZG commitment
KZG,
@@ -273,7 +297,19 @@ impl FromStr for PyCommitments {
}
}

/// Converts a felt to big endian
/// Converts a field element hex string to big endian
///
/// Arguments
/// -------
/// felt: str
///     The field element represented as a string
///
/// Returns
/// -------
/// str
///     field element represented as a string
///
#[pyfunction(signature = (
felt,
))]
@@ -283,22 +319,45 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
}

/// Converts a field element hex string to an integer
///
/// Arguments
/// -------
/// felt: str
///     The field element represented as a string
///
/// Returns
/// -------
/// int
///
#[pyfunction(signature = (
array,
felt,
))]
fn felt_to_int(array: PyFelt) -> PyResult<i128> {
let felt = crate::pfsys::string_to_field::<Fr>(&array);
fn felt_to_int(felt: PyFelt) -> PyResult<i128> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
let int_rep = felt_to_i128(felt);
Ok(int_rep)
}

/// Converts a field eleement hex string to a floating point number
/// Converts a field element hex string to a floating point number
///
/// Arguments
/// -------
/// felt: str
///     The field element represented as a string
///
/// scale: float
///     The scaling factor used to convert the field element into a floating point representation
///
/// Returns
/// -------
/// float
///
#[pyfunction(signature = (
array,
felt,
scale
))]
fn felt_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&array);
fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
let int_rep = felt_to_i128(felt);
let multiplier = scale_to_multiplier(scale);
let float_rep = int_rep as f64 / multiplier;
@@ -306,9 +365,23 @@ fn felt_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
}

/// Converts a floating point element to a field element hex string
///
/// Arguments
/// -------
/// input: float
///     The floating point number to be converted
///
/// scale: float
///     The scaling factor used to quantize the float into a field element
///
/// Returns
/// -------
/// str
///     The field element represented as a string
///
#[pyfunction(signature = (
input,
scale
))]
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
let int_rep = quantize_float(&input, 0.0, scale)
@@ -318,9 +391,20 @@ fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
}
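Taken together, `float_to_felt`, `felt_to_float`, and `felt_to_int` form a quantize/dequantize round trip at a given scale. A hedged usage sketch from Python (function names as exported above; exact rounding depends on `quantize_float`):

```python
import ezkl

scale = 7                                  # fixed-point multiplier is 2**7 = 128
felt = ezkl.float_to_felt(0.5, scale)      # hex-string field element
assert ezkl.felt_to_int(felt) == 64        # 0.5 * 128
assert abs(ezkl.felt_to_float(felt, scale) - 0.5) < 2**-scale
```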

/// Converts a buffer to vector of field elements
///
/// Arguments
/// -------
/// buffer: list[int]
///     List of integers representing a buffer
///
/// Returns
/// -------
/// list[str]
///     List of field elements represented as strings
///
#[pyfunction(signature = (
buffer
))]
fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 {
let mut n: u128 = 0;
@@ -379,9 +463,20 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
}

/// Generate a poseidon hash.
///
/// Arguments
/// -------
/// message: list[str]
///     List of field elements represented as strings
///
/// Returns
/// -------
/// list[str]
///     List of field elements represented as strings
///
#[pyfunction(signature = (
message,
))]
fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
let message: Vec<Fr> = message
.iter()
@@ -402,12 +497,31 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
}
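A small usage sketch, assuming the `ezkl` package exposes `poseidon_hash` as above (it takes and returns lists of field-element strings):

```python
import ezkl

message = [ezkl.float_to_felt(x, 7) for x in (0.25, 0.5)]
digest = ezkl.poseidon_hash(message)   # list containing the poseidon digest felt(s)
print(digest)
```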

/// Generate a kzg commitment.
///
/// Arguments
/// -------
/// message: list[str]
///     List of field elements represented as strings
///
/// vk_path: str
///     Path to the verification key
///
/// settings_path: str
///     Path to the settings file
///
/// srs_path: str
///     Path to the Structured Reference String (SRS) file
///
/// Returns
/// -------
/// list[PyG1Affine]
///
#[pyfunction(signature = (
message,
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
fn kzg_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
@@ -442,12 +556,31 @@ fn kzg_commit(
}

/// Generate an ipa commitment.
///
/// Arguments
/// -------
/// message: list[str]
///     List of field elements represented as strings
///
/// vk_path: str
///     Path to the verification key
///
/// settings_path: str
///     Path to the settings file
///
/// srs_path: str
///     Path to the Structured Reference String (SRS) file
///
/// Returns
/// -------
/// list[PyG1Affine]
///
#[pyfunction(signature = (
message,
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
fn ipa_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
@@ -482,10 +615,19 @@ fn ipa_commit(
}

/// Swap the commitments in a proof
///
/// Arguments
/// -------
/// proof_path: str
///     Path to the proof file
///
/// witness_path: str
///     Path to the witness file
///
#[pyfunction(signature = (
proof_path=PathBuf::from(DEFAULT_PROOF),
witness_path=PathBuf::from(DEFAULT_WITNESS),
))]
fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResult<()> {
crate::execute::swap_proof_commitments_cmd(proof_path, witness_path)
.map_err(|_| PyIOError::new_err("Failed to swap commitments"))?;
@@ -494,11 +636,27 @@ fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResul
}

/// Generates a vk from a pk for a model circuit and saves it to a file
///
/// Arguments
/// -------
/// path_to_pk: str
///     Path to the proving key
///
/// circuit_settings_path: str
///     Path to the settings file
///
/// vk_output_path: str
///     Path to create the vk file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
path_to_pk=PathBuf::from(DEFAULT_PK),
circuit_settings_path=PathBuf::from(DEFAULT_SETTINGS),
vk_output_path=PathBuf::from(DEFAULT_VK),
))]
fn gen_vk_from_pk_single(
path_to_pk: PathBuf,
circuit_settings_path: PathBuf,
@@ -520,10 +678,22 @@ fn gen_vk_from_pk_single(
}

/// Generates a vk from a pk for an aggregate circuit and saves it to a file
///
/// Arguments
/// -------
/// path_to_pk: str
///     Path to the proving key
///
/// vk_output_path: str
///     Path to create the vk file
///
/// Returns
/// -------
/// bool
#[pyfunction(signature = (
path_to_pk=PathBuf::from(DEFAULT_PK_AGGREGATED),
vk_output_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
))]
fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult<bool> {
let pk = load_pk::<KZGCommitmentScheme<Bn256>, AggregationCircuit>(path_to_pk, ())
.map_err(|_| PyIOError::new_err("Failed to load pk"))?;
@@ -538,6 +708,17 @@ fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult
}

/// Displays the table as a string in python
///
/// Arguments
/// ---------
/// model: str
///     Path to the onnx file
///
/// Returns
/// ---------
/// str
///     Table of the nodes in the onnx file
///
#[pyfunction(signature = (
model = PathBuf::from(DEFAULT_MODEL),
py_run_args = None
@@ -553,7 +734,16 @@ fn table(model: PathBuf, py_run_args: Option<PyRunArgs>) -> PyResult<String> {
}
}

/// generates the srs
/// Generates the Structured Reference String (SRS), use this only for testing purposes
///
/// Arguments
/// ---------
/// srs_path: str
///     Path to create the SRS file
///
/// logrows: int
///     The number of logrows for the SRS file
///
#[pyfunction(signature = (
srs_path,
logrows,
@@ -564,7 +754,26 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
Ok(())
}
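As a sketch (paths are illustrative; for production, fetch a public SRS via `get_srs` below rather than generating one locally):

```python
import ezkl

# testing only: generate a local SRS with 2^17 rows
ezkl.gen_srs("kzg17.srs", 17)
```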

/// gets a public srs
/// Gets a public srs
///
/// Arguments
/// ---------
/// settings_path: str
///     Path to the settings file
///
/// logrows: int
///     The number of logrows for the SRS file
///
/// srs_path: str
///     Path to create the SRS file
///
/// commitment: str
///     Specify the commitment used ("kzg", "ipa")
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
settings_path=PathBuf::from(DEFAULT_SETTINGS),
logrows=None,
@@ -597,7 +806,23 @@ fn get_srs(
Ok(true)
}

/// generates the circuit settings
/// Generates the circuit settings
///
/// Arguments
/// ---------
/// model: str
///     Path to the onnx file
///
/// output: str
///     Path to create the settings file
///
/// py_run_args: PyRunArgs
///     PyRunArgs object to initialize the settings
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
output=PathBuf::from(DEFAULT_SETTINGS),
@@ -618,7 +843,38 @@ fn gen_settings(
Ok(true)
}

/// calibrates the circuit settings
/// Calibrates the circuit settings
///
/// Arguments
/// ---------
/// data: str
///     Path to the calibration data
///
/// model: str
///     Path to the onnx file
///
/// settings: str
///     Path to the settings file
///
/// lookup_safety_margin: int
///     The lookup safety margin to use for calibration. If the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. Larger = safer but slower
///
/// scales: list[int]
///     Optional scales to specifically try for calibration
///
/// scale_rebase_multiplier: list[int]
///     Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale.
///
/// max_logrows: int
///     Optional max logrows to use for calibration
///
/// only_range_check_rebase: bool
///     Check ranges when rebasing
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
data = PathBuf::from(DEFAULT_CALIBRATION_FILE),
model = PathBuf::from(DEFAULT_MODEL),
@@ -660,7 +916,30 @@ fn calibrate_settings(
Ok(true)
}
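A hedged sketch of the usual settings flow, relying on the positional argument order documented above (file names are illustrative; the remaining calibration parameters fall back to their defaults):

```python
import ezkl

ezkl.gen_settings("network.onnx", "settings.json")
# search for scales that keep lookups within range, using sample inputs
ezkl.calibrate_settings("cal_data.json", "network.onnx", "settings.json")
```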

/// runs the forward pass operation
/// Runs the forward pass operation to generate a witness
///
/// Arguments
/// ---------
/// data: str
///     Path to the data file
///
/// model: str
///     Path to the compiled model file
///
/// output: str
///     Path to create the witness file
///
/// vk_path: str
///     Path to the verification key
///
/// srs_path: str
///     Path to the SRS file
///
/// Returns
/// -------
/// dict
///     Python object containing the witness values
///
#[pyfunction(signature = (
data=PathBuf::from(DEFAULT_DATA),
model=PathBuf::from(DEFAULT_MODEL),
@@ -687,7 +966,20 @@ fn gen_witness(
Python::with_gil(|py| Ok(output.to_object(py)))
}

/// mocks the prover
/// Mocks the prover
///
/// Arguments
/// ---------
/// witness: str
///     Path to the witness file
///
/// model: str
///     Path to the compiled model file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
witness=PathBuf::from(DEFAULT_WITNESS),
model=PathBuf::from(DEFAULT_MODEL),
@@ -700,7 +992,23 @@ fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
Ok(true)
}

/// mocks the aggregate prover
/// Mocks the aggregate prover
///
/// Arguments
/// ---------
/// aggregation_snarks: list[str]
///     List of paths to the relevant proof files
///
/// logrows: int
///     Number of logrows to use for the aggregation circuit
///
/// split_proofs: bool
///     Indicates whether the accumulated proofs are segments of a larger proof
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
aggregation_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
@@ -719,7 +1027,32 @@ fn mock_aggregate(
Ok(true)
}

/// runs the prover on a set of inputs
/// Runs the setup process
///
/// Arguments
/// ---------
/// model: str
///     Path to the compiled model file
///
/// vk_path: str
///     Path to create the verification key file
///
/// pk_path: str
///     Path to create the proving key file
///
/// srs_path: str
///     Path to the SRS file
///
/// witness_path: str
///     Path to the witness file
///
/// disable_selector_compression: bool
///     Whether to disable selector compression
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
vk_path=PathBuf::from(DEFAULT_VK),
@@ -752,7 +1085,32 @@ fn setup(
Ok(true)
}

/// runs the prover on a set of inputs
/// Runs the prover on a set of inputs
///
/// Arguments
/// ---------
/// witness: str
///     Path to the witness file
///
/// model: str
///     Path to the compiled model file
///
/// pk_path: str
///     Path to the proving key file
///
/// proof_path: str
///     Path to create the proof file
///
/// proof_type: str
///     Accepts `single`, `for-aggr`
///
/// srs_path: str
///     Path to the SRS file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
witness=PathBuf::from(DEFAULT_WITNESS),
model=PathBuf::from(DEFAULT_MODEL),
@@ -786,7 +1144,29 @@ fn prove(
Python::with_gil(|py| Ok(snark.to_object(py)))
}

/// verifies a given proof
/// Verifies a given proof
///
/// Arguments
/// ---------
/// proof_path: str
///     Path to the proof file
///
/// settings_path: str
///     Path to the settings file
///
/// vk_path: str
///     Path to the verification key file
///
/// srs_path: str
///     Path to the SRS file
///
/// non_reduced_srs: bool
///     Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
proof_path=PathBuf::from(DEFAULT_PROOF),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
@@ -816,6 +1196,38 @@ fn verify(
Ok(true)
}
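Putting the pieces together, a minimal end-to-end sketch under the argument orders documented above (file names are illustrative; some builds expose these bindings as async, in which case the calls must be awaited):

```python
import ezkl

ezkl.compile_circuit("network.onnx", "network.ezkl", "settings.json")
ezkl.setup("network.ezkl", "vk.key", "pk.key")                   # SRS resolved from defaults
ezkl.gen_witness("input.json", "network.ezkl", "witness.json")
ezkl.prove("witness.json", "network.ezkl", "pk.key", "proof.json", "single")
assert ezkl.verify("proof.json", "settings.json", "vk.key")
```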

/// Runs the setup process for an aggregate setup
///
/// Arguments
/// ---------
/// sample_snarks: list[str]
///     List of paths to the various proofs
///
/// vk_path: str
///     Path to create the aggregated VK
///
/// pk_path: str
///     Path to create the aggregated PK
///
/// logrows: int
///     Number of logrows to use
///
/// split_proofs: bool
///     Whether the accumulated proofs are segments of a larger proof
///
/// srs_path: str
///     Path to the SRS file
///
/// disable_selector_compression: bool
///     Whether to disable selector compression
///
/// commitment: str
///     Accepts `kzg`, `ipa`
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
sample_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
vk_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
@@ -854,6 +1266,23 @@ fn setup_aggregate(
Ok(true)
}

/// Compiles the circuit for use in other steps
///
/// Arguments
/// ---------
/// model: str
///     Path to the onnx model file
///
/// compiled_circuit: str
///     Path to output the compiled circuit
///
/// settings_path: str
///     Path to the settings file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
compiled_circuit=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
@@ -872,7 +1301,41 @@ fn compile_circuit(
Ok(true)
}

/// creates an aggregated proof
/// Creates an aggregated proof
///
/// Arguments
/// ---------
/// aggregation_snarks: list[str]
///     List of paths to the various proofs
///
/// proof_path: str
///     Path to output the aggregated proof
///
/// vk_path: str
///     Path to the VK file
///
/// transcript:
///     Proof transcript type to be used. `evm` used by default. `poseidon` is also supported
///
/// logrows:
///     Logrows used for aggregation circuit
///
/// check_mode: str
///     Run sanity checks during calculations. Accepts `safe` or `unsafe`
///
/// split_proofs: bool
///     Whether the accumulated proofs are segments of a larger circuit
///
/// srs_path: str
///     Path to the SRS used
///
/// commitment: str
///     Accepts "kzg" or "ipa"
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
aggregation_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
proof_path=PathBuf::from(DEFAULT_PROOF_AGGREGATED),
@@ -915,7 +1378,32 @@ fn aggregate(
Ok(true)
}

/// verifies and aggregate proof
/// Verifies an aggregate proof
///
/// Arguments
/// ---------
/// proof_path: str
///     The path to the proof file
///
/// vk_path: str
///     The path to the verification key file
///
/// logrows: int
///     logrows used for aggregation circuit
///
/// commitment: str
///     Accepts "kzg" or "ipa"
///
/// reduced_srs: bool
///     Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
///
/// srs_path: str
///     The path to the SRS file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
proof_path=PathBuf::from(DEFAULT_PROOF_AGGREGATED),
vk_path=PathBuf::from(DEFAULT_VK),
@@ -948,7 +1436,32 @@ fn verify_aggr(
Ok(true)
}

/// creates an EVM compatible verifier, you will need solc installed in your environment to run this
/// Creates an EVM compatible verifier, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// vk_path: str
///     The path to the verification key file
///
/// settings_path: str
///     The path to the settings file
///
/// sol_code_path: str
///     The path to create the solidity verifier
///
/// abi_path: str
///     The path to create the ABI for the solidity verifier
///
/// srs_path: str
///     The path to the SRS file
///
/// render_vk_separately: bool
///     Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
@@ -981,7 +1494,26 @@ fn create_evm_verifier(
Ok(true)
}
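A usage sketch (requires `solc` on the PATH, per the note above; paths are illustrative):

```python
import ezkl

# render a Solidity verifier and its ABI from the verification key and settings
ezkl.create_evm_verifier("vk.key", "settings.json", "verifier.sol", "verifier_abi.json")
```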

// creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
/// Creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// input_data: str
///     The path to the .json data file, which should contain the necessary calldata and account addresses needed to read from all the on-chain view functions that return the data that the network ingests as inputs
///
/// settings_path: str
///     The path to the settings file
///
/// sol_code_path: str
///     The path to create the solidity verifier
///
/// abi_path: str
///     The path to create the ABI for the solidity verifier
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
input_data=PathBuf::from(DEFAULT_DATA),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
@@ -1003,6 +1535,32 @@ fn create_evm_data_attestation(
Ok(true)
}

/// Setup test evm witness
///
/// Arguments
/// ---------
/// data_path: str
///     The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
///
/// compiled_circuit_path: str
///     The path to the compiled model file (generated using the compile-circuit command)
///
/// test_data: str
///     For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information derived from the file information in the data .json file. Should include both the network input (possibly private) and the network output (public input to the proof)
///
/// input_sources: str
///     Where the input data comes from
///
/// output_source: str
///     Where the output data comes from
///
/// rpc_url: str
///     RPC URL for an EVM compatible node, if None, uses Anvil as a local RPC node
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
data_path,
compiled_circuit_path,
@@ -1037,6 +1595,7 @@ fn setup_test_evm_witness(
Ok(true)
}

/// deploys the solidity verifier
#[pyfunction(signature = (
addr_path,
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
@@ -1069,6 +1628,7 @@ fn deploy_evm(
Ok(true)
}

/// deploys the solidity vk verifier
#[pyfunction(signature = (
addr_path,
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
@@ -1101,6 +1661,7 @@ fn deploy_vk_evm(
Ok(true)
}

/// deploys the solidity da verifier
#[pyfunction(signature = (
addr_path,
input_data,
@@ -1138,6 +1699,27 @@ fn deploy_da_evm(
Ok(true)
}
/// verifies an evm compatible proof, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// addr_verifier: str
///     The path to verifier contract's address
///
/// proof_path: str
///     The path to the proof file (generated using the prove command)
///
/// rpc_url: str
///     RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
///
/// addr_da: str
///     Does the verifier use data attestation?
///
/// addr_vk: str
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
addr_verifier,
proof_path=PathBuf::from(DEFAULT_PROOF),
@@ -1183,7 +1765,35 @@ fn verify_evm(
Ok(true)
}

/// creates an evm compatible aggregate verifier, you will need solc installed in your environment to run this
/// Creates an evm compatible aggregate verifier, you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// aggregation_settings: str
///     Path to the settings file
///
/// vk_path: str
///     The path to load the desired verification key file
///
/// sol_code_path: str
///     The path to the Solidity code
///
/// abi_path: str
///     The path to output the Solidity verifier ABI
///
/// logrows: int
///     Number of logrows used during aggregated setup
///
/// srs_path: str
///     The path to the SRS file
///
/// render_vk_separately: bool
///     Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
aggregation_settings=vec![PathBuf::from(DEFAULT_PROOF)],
vk_path=PathBuf::from(DEFAULT_VK_AGGREGATED),

@@ -60,6 +60,9 @@ pub enum TensorError {
/// Unsupported operation
#[error("Unsupported operation on a tensor type")]
Unsupported,
/// Overflow
#[error("Unsigned integer overflow or underflow error in op: {0}")]
Overflow(String),
}

/// The (inner) type of tensor elements.
@@ -934,6 +937,7 @@ impl<T: Clone + TensorType> Tensor<T> {
pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<Self, TensorError> {
assert!(source < self.dims.len());
assert!(destination < self.dims.len());

let mut new_dims = self.dims.clone();
new_dims.remove(source);
new_dims.insert(destination, self.dims[source]);
@@ -965,6 +969,8 @@ impl<T: Clone + TensorType> Tensor<T> {
old_coord[source - 1] = *c;
} else if (i < source && source < destination)
|| (i < destination && source > destination)
|| (i > source && source > destination)
|| (i > destination && source < destination)
{
old_coord[i] = *c;
} else if i > source && source < destination {
@@ -977,7 +983,10 @@ impl<T: Clone + TensorType> Tensor<T> {
));
}
}
output.set(&coord, self.get(&old_coord));

let value = self.get(&old_coord);

output.set(&coord, value);
}

Ok(output)
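The coordinate remapping above appears to mirror numpy's `moveaxis` semantics (remove the source axis, reinsert it at the destination). A reference sketch of the expected behaviour, with numpy used purely as an oracle for the index bookkeeping:

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)
y = np.moveaxis(x, 0, 2)          # source=0, destination=2
assert y.shape == (3, 4, 2)       # dims[0] removed, reinserted at position 2
assert y[1, 2, 0] == x[0, 1, 2]   # each output coord maps back to an input coord
```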

2116
src/tensor/ops.rs
File diff suppressed because it is too large
@@ -316,6 +316,12 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>> f
}

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Allocate a new [ValTensor::Value] from the given [Tensor] of [i128].
pub fn from_i128_tensor(t: Tensor<i128>) -> ValTensor<F> {
let inner = t.map(|x| ValType::Value(Value::known(i128_to_felt(x))));
inner.into()
}

/// Allocate a new [ValTensor::Instance] from the ConstraintSystem with the given tensor `dims`, optionally enabling `equality`.
pub fn new_instance(
cs: &mut ConstraintSystem<F>,
@@ -873,13 +879,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
};
Ok(())
}
/// Calls `pad` on the inner [Tensor].
pub fn pad(&mut self, padding: [(usize, usize); 2]) -> Result<(), TensorError> {
/// Calls `pad_spatial_dims` on the inner [Tensor].
pub fn pad(&mut self, padding: Vec<(usize, usize)>, offset: usize) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
} => {
*v = pad(v, padding)?;
*v = pad(v, padding, offset)?;
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {