Compare commits

...

15 Commits

Author | SHA1 | Message | Date
dante | 201d679806 | Update utilities.rs | 2024-12-27 17:51:50 -05:00
dante | 4218f0c004 | fix: get_slice should not use intermediate Vec | 2024-12-27 17:34:17 -05:00
Jseam | 8b223354cc | fix: add version string and sed (#893) | 2024-12-27 14:24:28 -05:00
dante | caa6ef8e16 | fix: const filtering strat is size dependent (#891) | 2024-12-27 09:43:59 -05:00
Artem | c4354c10a5 | fix: ios bindings update action (#886) | 2024-12-16 10:49:13 -05:00
dante | c1ce8c88d0 | chore: rm wasm serialization checks (#890) | 2024-12-12 22:20:29 -05:00
dante | 876a9584a1 | chore: optimize wasm bundle for speed over size (#889) | 2024-12-12 15:35:17 -05:00
dante | 7d7f049cc4 | chore: neural bag of words example (#888) | 2024-12-12 14:20:21 -05:00
dante | 96f3fd94b2 | feat: ICICLE MSM and NTT integration (#884) | 2024-12-07 00:32:09 +00:00
dante | 6263510c56 | fix: bump pypi-publish to unstable to use twine updates (#881) | 2024-12-06 23:19:29 +00:00
Jseam | f5b8ae3213 | fix: revert pypi to 1.11.0 (#880) | 2024-12-05 14:46:40 -05:00
dante | b2e4e414f0 | chore: update pyo3 and add stub (#879) | 2024-12-05 10:35:06 -05:00
Dmitry | 0b0199e2b7 | fix: typo in lib.rs (#877) | 2024-12-03 18:46:46 -05:00
dante | 5e169bdd17 | chore: update tract to 0.21.8-pre (#878) | 2024-12-03 16:52:03 -05:00
dante | 64cbcb3f7e | chore: explicitly compile div op (#876) | 2024-11-28 17:14:53 +09:00
40 changed files with 3199 additions and 749 deletions


@@ -2,3 +2,16 @@
runner = 'wasm-bindgen-test-runner'
rustflags = ["-C", "target-feature=+atomics,+bulk-memory,+mutable-globals","-C",
"link-arg=--max-memory=4294967296"]
[target.x86_64-apple-darwin]
rustflags = [
"-C", "link-arg=-undefined",
"-C", "link-arg=dynamic_lookup",
]
[target.aarch64-apple-darwin]
rustflags = [
"-C", "link-arg=-undefined",
"-C", "link-arg=dynamic_lookup",
]


@@ -34,6 +34,7 @@ jobs:
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/ezkl/ezkl-gpu/" pyproject.toml.orig >pyproject.toml
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- uses: actions-rs/toolchain@v1
with:
@@ -98,14 +99,14 @@ jobs:
# publishes to PyPI
- name: Publish package distributions to PyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@release/v1
uses: pypa/gh-action-pypi-publish@unstable/v1
with:
packages-dir: ./
# publishes to TestPyPI
- name: Publish package distribution to TestPyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@release/v1
uses: pypa/gh-action-pypi-publish@unstable/v1
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ./
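
For context on the sed step added above: it swaps the 0.0.0 placeholder version in pyproject.toml for the pushed tag, with every "v" stripped by the bash expansion ${RELEASE_TAG//v}. A minimal Python sketch of the same substitution, not part of the workflow, with a hypothetical tag value:

```python
# Sketch of what `sed "s/0\.0\.0/${RELEASE_TAG//v}/"` does to pyproject.toml.
release_tag = "v12.4.1"                 # hypothetical ${{ github.ref_name }}
version = release_tag.replace("v", "")  # bash ${RELEASE_TAG//v} removes every "v"
pyproject = 'name = "ezkl-gpu"\nversion = "0.0.0"\n'
print(pyproject.replace("0.0.0", version, 1))  # first match only, like sed without /g
```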


@@ -233,6 +233,14 @@ jobs:
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag
shell: bash
env:
RELEASE_TAG: ${{ github.ref_name }}
run: |
mv pyproject.toml pyproject.toml.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" pyproject.toml.orig >pyproject.toml
- name: Set Cargo.toml version to match github tag
shell: bash
env:
@@ -242,7 +250,6 @@ jobs:
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.toml.orig >Cargo.toml
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- name: Install required libraries
shell: bash
run: |
@@ -348,14 +355,14 @@ jobs:
# publishes to PyPI
- name: Publish package distributions to PyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@release/v1
uses: pypa/gh-action-pypi-publish@unstable/v1
with:
packages-dir: ./
# publishes to TestPyPI
- name: Publish package distribution to TestPyPI
continue-on-error: true
uses: pypa/gh-action-pypi-publish@release/v1
uses: pypa/gh-action-pypi-publish@unstable/v1
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ./


@@ -207,23 +207,6 @@ jobs:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
tutorial:
runs-on: ubuntu-latest
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Circuit Render
run: cargo nextest run --release --verbose tests::tutorial_
mock-proving-tests:
runs-on: non-gpu
needs: [build, library-tests, docs, python-tests, python-integration-tests]
@@ -494,23 +477,23 @@ jobs:
- name: Mock aggr tests (KZG)
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
prove-and-verify-aggr-tests-gpu:
runs-on: GPU
env:
ENABLE_ICICLE_GPU: true
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: KZG )tests
run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
# prove-and-verify-aggr-tests-gpu:
# runs-on: GPU
# env:
# ENABLE_ICICLE_GPU: true
# steps:
# - uses: actions/checkout@v4
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: nightly-2024-07-18
# override: true
# components: rustfmt, clippy
# - uses: baptiste0928/cargo-install@v1
# with:
# crate: cargo-nextest
# locked: true
# - name: KZG tests
# run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
prove-and-verify-aggr-tests:
runs-on: large-self-hosted
@@ -614,8 +597,6 @@ jobs:
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Div rebase
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_div_rebase_
- name: Public inputs
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_inputs_
- name: fixed params
@@ -669,6 +650,10 @@ jobs:
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Neural bow
run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
- name: Felt conversion
run: source .env/bin/activate; cargo nextest run py_tests::tests::felt_conversion_test_ --no-capture
- name: Postgres tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
- name: Tictactoe tutorials

.github/workflows/swift-pm.yml (new file, 129 lines)

@@ -0,0 +1,129 @@
name: Build and Publish EZKL iOS SPM package
on:
push:
tags:
# Only support SemVer versioning tags
- 'v[0-9]+.[0-9]+.[0-9]+'
- '[0-9]+.[0-9]+.[0-9]+'
jobs:
build-and-update:
runs-on: macos-latest
env:
EZKL_SWIFT_PACKAGE_REPO: github.com/zkonduit/ezkl-swift-package.git
steps:
- name: Checkout EZKL
uses: actions/checkout@v3
- name: Extract TAG from github.ref_name
run: |
# github.ref_name is provided by GitHub Actions and contains the tag name directly.
TAG="${{ github.ref_name }}"
echo "Original TAG: $TAG"
# Remove leading 'v' if present to match the Swift Package Manager version format.
NEW_TAG=${TAG#v}
echo "Stripped TAG: $NEW_TAG"
echo "TAG=$NEW_TAG" >> $GITHUB_ENV
- name: Install Rust (nightly)
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
override: true
- name: Build EzklCoreBindings
run: CONFIGURATION=release cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
- name: Clone ezkl-swift-package repository
run: |
git clone https://${{ env.EZKL_SWIFT_PACKAGE_REPO }}
- name: Copy EzklCoreBindings
run: |
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
- name: Copy Test Files
run: |
rm -rf ezkl-swift-package/Tests/EzklAssets/*
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
- name: Check for changes
id: check_changes
run: |
cd ezkl-swift-package
if git diff --quiet Sources/EzklCoreBindings Tests/EzklAssets; then
echo "no_changes=true" >> $GITHUB_OUTPUT
else
echo "no_changes=false" >> $GITHUB_OUTPUT
fi
- name: Set up Xcode environment
if: steps.check_changes.outputs.no_changes == 'false'
run: |
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license accept
- name: Run Package Tests
if: steps.check_changes.outputs.no_changes == 'false'
run: |
cd ezkl-swift-package
xcodebuild test \
-scheme EzklPackage \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-resultBundlePath ../testResults
- name: Run Example App Tests
if: steps.check_changes.outputs.no_changes == 'false'
run: |
cd ezkl-swift-package/Example
xcodebuild test \
-project Example.xcodeproj \
-scheme EzklApp \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-parallel-testing-enabled NO \
-resultBundlePath ../../exampleTestResults \
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
- name: Setup Git
run: |
cd ezkl-swift-package
git config user.name "GitHub Action"
git config user.email "action@github.com"
git remote set-url origin https://zkonduit:${EZKL_SWIFT_PACKAGE_REPO_TOKEN}@${{ env.EZKL_SWIFT_PACKAGE_REPO }}
env:
EZKL_SWIFT_PACKAGE_REPO_TOKEN: ${{ secrets.EZKL_PORTER_TOKEN }}
- name: Commit and Push Changes
if: steps.check_changes.outputs.no_changes == 'false'
run: |
cd ezkl-swift-package
git add Sources/EzklCoreBindings Tests/EzklAssets
git commit -m "Automatically updated EzklCoreBindings for EZKL"
if ! git push origin; then
echo "::error::Failed to push changes to ${{ env.EZKL_SWIFT_PACKAGE_REPO }}. Please ensure that EZKL_PORTER_TOKEN has the correct permissions."
exit 1
fi
- name: Tag the latest commit
run: |
cd ezkl-swift-package
source $GITHUB_ENV
# Tag the latest commit on the current branch
if git rev-parse "$TAG" >/dev/null 2>&1; then
echo "Tag $TAG already exists locally. Skipping tag creation."
else
git tag "$TAG"
fi
if ! git push origin "$TAG"; then
echo "::error::Failed to push tag '$TAG' to ${{ env.EZKL_SWIFT_PACKAGE_REPO }}. Please ensure EZKL_PORTER_TOKEN has correct permissions."
exit 1
fi


@@ -1,85 +0,0 @@
name: Build and Publish EZKL iOS SPM package
on:
workflow_dispatch:
inputs:
tag:
description: "The tag to release"
required: true
push:
tags:
- "*"
jobs:
build-and-update:
runs-on: macos-latest
steps:
- name: Checkout EZKL
uses: actions/checkout@v3
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
override: true
- name: Build EzklCoreBindings
run: CONFIGURATION=release cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features
- name: Clone ezkl-swift-package repository
run: |
git clone https://github.com/zkonduit/ezkl-swift-package.git
- name: Copy EzklCoreBindings
run: |
rm -rf ezkl-swift-package/Sources/EzklCoreBindings
cp -r build/EzklCoreBindings ezkl-swift-package/Sources/
- name: Copy Test Files
run: |
rm -rf ezkl-swift-package/Tests/EzklAssets/*
cp tests/assets/kzg ezkl-swift-package/Tests/EzklAssets/kzg.srs
cp tests/assets/input.json ezkl-swift-package/Tests/EzklAssets/input.json
cp tests/assets/model.compiled ezkl-swift-package/Tests/EzklAssets/network.ezkl
cp tests/assets/settings.json ezkl-swift-package/Tests/EzklAssets/settings.json
- name: Set up Xcode environment
run: |
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
sudo xcodebuild -license accept
- name: Run Package Tests
run: |
cd ezkl-swift-package
xcodebuild test \
-scheme EzklPackage \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-resultBundlePath ../testResults
- name: Run Example App Tests
run: |
cd ezkl-swift-package/Example
xcodebuild test \
-project Example.xcodeproj \
-scheme EzklApp \
-destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \
-parallel-testing-enabled NO \
-resultBundlePath ../../exampleTestResults \
-skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder
- name: Commit and Push Changes to feat/ezkl-direct-integration
run: |
cd ezkl-swift-package
git config user.name "GitHub Action"
git config user.email "action@github.com"
git add Sources/EzklCoreBindings
git add Tests/EzklAssets
git commit -m "Automatically updated EzklCoreBindings for EZKL"
git tag ${{ github.event.inputs.tag }}
git remote set-url origin https://zkonduit:${EZKL_PORTER_TOKEN}@github.com/zkonduit/ezkl-swift-package.git
git push origin
git push origin tag ${{ github.event.inputs.tag }}
env:
EZKL_PORTER_TOKEN: ${{ secrets.EZKL_PORTER_TOKEN }}

.gitignore (3 lines changed)

@@ -27,7 +27,6 @@ __pycache__/
*.pyc
*.pyo
*.py[cod]
bin/
build/
develop-eggs/
dist/
@@ -49,4 +48,4 @@ timingData.json
!tests/assets/pk.key
!tests/assets/vk.key
docs/python/build
!tests/assets/vk_aggr.key
!tests/assets/vk_aggr.key

Cargo.lock (generated, 717 lines changed)

File diff suppressed because it is too large


@@ -16,11 +16,11 @@ crate-type = ["cdylib", "rlib", "staticlib"]
[dependencies]
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
halo2_gadgets = { git = "https://github.com/zkonduit/halo2" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
"derive_serde",
] }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", branch = "ac/cache-lookup-commitments", features = [
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", features = [
"circuit-params",
] }
rand = { version = "0.8", default-features = false }
@@ -35,7 +35,7 @@ halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
"derive_serde",
] }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "ac/update-h2-curves", optional = true }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", optional = true }
maybe-rayon = { version = "0.1.1", default-features = false }
bincode = { version = "1.3.3", default-features = false }
unzip-n = "0.1.2"
@@ -79,21 +79,22 @@ tokio = { version = "1.35.0", default-features = false, features = [
"macros",
"rt-multi-thread",
], optional = true }
pyo3 = { version = "0.21.2", features = [
pyo3 = { version = "0.23.2", features = [
"extension-module",
"abi3-py37",
"macros",
], default-features = false, optional = true }
pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch = "migration-pyo3-0.21", features = [
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.23.0", features = [
"attributes",
"tokio-runtime",
], default-features = false, optional = true }
pyo3-log = { version = "0.10.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default-features = false, optional = true }
pyo3-log = { version = "0.12.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
objc = { version = "0.2.4", optional = true }
mimalloc = { version = "0.1", optional = true }
pyo3-stub-gen = { version = "0.6.0", optional = true }
# universal bindings
uniffi = { version = "=0.28.0", optional = true }
@@ -146,6 +147,10 @@ shellexpand = "3.1.0"
runner = 'wasm-bindgen-test-runner'
[[bench]]
name = "zero_finder"
harness = false
[[bench]]
name = "accum_dot"
harness = false
@@ -210,6 +215,10 @@ required-features = ["ezkl"]
name = "ios_gen_bindings"
required-features = ["ios-bindings", "uuid", "camino", "uniffi_bindgen"]
[[bin]]
name = "py_stub_gen"
required-features = ["python-bindings"]
[features]
web = ["wasm-bindgen-rayon"]
default = [
@@ -220,7 +229,7 @@ default = [
"parallel-poly-read",
]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-async-runtimes", "pyo3-stub-gen"]
ios-bindings = ["mv-lookup", "precompute-coset", "parallel-poly-read", "uniffi"]
ios-bindings-test = ["ios-bindings", "uniffi/bindgen-tests"]
ezkl = [
@@ -269,12 +278,9 @@ empty-cmd = []
no-banner = []
no-update = []
# icicle patch to 0.1.0 if feature icicle is enabled
[patch.'https://github.com/ingonyama-zk/icicle']
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#8b13a0d2a7a34d8daab010dadb2c47dfa47d37d0", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b", package = "halo2_proofs" }
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
@@ -284,3 +290,11 @@ rustflags = ["-C", "relocation-model=pic"]
lto = "fat"
codegen-units = 1
# panic = "abort"
[package.metadata.wasm-pack.profile.release]
wasm-opt = [
"-O4",
"--flexible-inline-max-function-size",
"4294967295",
]

benches/zero_finder.rs (new file, 116 lines)

@@ -0,0 +1,116 @@
use std::thread;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use halo2curves::{bn256::Fr as F, ff::Field};
use maybe_rayon::{
iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator},
slice::ParallelSlice,
};
use rand::Rng;
// Assuming these are your types
#[derive(Clone)]
enum ValType {
Constant(F),
AssignedConstant(usize, F),
Other,
}
// Helper to generate test data
fn generate_test_data(size: usize, zero_probability: f64) -> Vec<ValType> {
let mut rng = rand::thread_rng();
(0..size)
.map(|_i| {
if rng.gen::<f64>() < zero_probability {
ValType::Constant(F::ZERO)
} else {
ValType::Constant(F::ONE) // Or some other non-zero value
}
})
.collect()
}
fn bench_zero_finding(c: &mut Criterion) {
let sizes = [
1_000, // 1K
10_000, // 10K
100_000, // 100K
256 * 256 * 2, // Our specific case
1_000_000, // 1M
10_000_000, // 10M
];
let zero_probability = 0.1; // 10% zeros
let mut group = c.benchmark_group("zero_finding");
group.sample_size(10); // Adjust based on your needs
for &size in &sizes {
let data = generate_test_data(size, zero_probability);
// Benchmark sequential version
group.bench_function(format!("sequential_{}", size), |b| {
b.iter(|| {
let result = data
.iter()
.enumerate()
.filter_map(|(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(i)
}
_ => None,
})
.collect::<Vec<_>>();
black_box(result)
})
});
// Benchmark parallel version
group.bench_function(format!("parallel_{}", size), |b| {
b.iter(|| {
let result = data
.par_iter()
.enumerate()
.filter_map(|(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(i)
}
_ => None,
})
.collect::<Vec<_>>();
black_box(result)
})
});
// Benchmark chunked parallel version
group.bench_function(format!("chunked_parallel_{}", size), |b| {
b.iter(|| {
let num_cores = thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
let chunk_size = (size / num_cores).max(100);
let result = data
.par_chunks(chunk_size)
.enumerate()
.flat_map(|(chunk_idx, chunk)| {
chunk
.par_iter() // Make sure we use par_iter() here
.enumerate()
.filter_map(move |(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(chunk_idx * chunk_size + i)
}
_ => None,
})
})
.collect::<Vec<_>>();
black_box(result)
})
});
}
group.finish();
}
criterion_group!(benches, bench_zero_finding);
criterion_main!(benches);

File diff suppressed because one or more lines are too long


@@ -0,0 +1,766 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"This is a zk version of the tutorial found [here](https://github.com/bentrevett/pytorch-sentiment-analysis/blob/main/1%20-%20Neural%20Bag%20of%20Words.ipynb). The original tutorial is part of the PyTorch Sentiment Analysis series by Ben Trevett.\n",
"\n",
"1 - NBoW\n",
"\n",
"In this series we'll be building a machine learning model to perform sentiment analysis -- a subset of text classification where the task is to detect if a given sentence is positive or negative -- using PyTorch and torchtext. The dataset used will be movie reviews from the IMDb dataset, which we'll obtain using the datasets library.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"Preparing Data\n",
"\n",
"Before we can implement our NBoW model, we first have to perform quite a few steps to get our data ready to use. NLP usually requires quite a lot of data wrangling beforehand, though libraries such as datasets and torchtext handle most of this for us.\n",
"\n",
"The steps to take are:\n",
"\n",
" 1. importing modules\n",
" 2. loading data\n",
" 3. tokenizing data\n",
" 4. creating data splits\n",
" 5. creating a vocabulary\n",
" 6. numericalizing data\n",
" 7. creating the data loaders\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! pip install torchtex"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import collections\n",
"\n",
"import datasets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torchtext\n",
"import tqdm\n",
"\n",
"# It is usually good practice to run your experiments multiple times with different random seeds -- both to measure the variance of your model and also to avoid having results only calculated with either \"good\" or \"bad\" seeds, i.e. being very lucky or unlucky with the randomness in the training process.\n",
"\n",
"seed = 1234\n",
"\n",
"np.random.seed(seed)\n",
"torch.manual_seed(seed)\n",
"torch.cuda.manual_seed(seed)\n",
"torch.backends.cudnn.deterministic = True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_data, test_data = datasets.load_dataset(\"imdb\", split=[\"train\", \"test\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can check the features attribute of a split to get more information about the features. We can see that text is a Value of dtype=string -- in other words, it's a string -- and that label is a ClassLabel. A ClassLabel means the feature is an integer representation of which class the example belongs to. num_classes=2 means that our labels are one of two values, 0 or 1, and names=['neg', 'pos'] gives us the human-readable versions of those values. Thus, a label of 0 means the example is a negative review and a label of 1 means the example is a positive review."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_data.features\n"
]
},
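{
To make the ClassLabel description above concrete, a small sketch using the datasets library's ClassLabel helpers; the printed values follow from num_classes=2 and names=['neg', 'pos'] described in the text:

```python
# Sketch: inspect the label feature described above.
label_feature = train_data.features["label"]
print(label_feature.num_classes)  # 2
print(label_feature.names)        # ['neg', 'pos']
print(label_feature.int2str(0))   # 'neg' -> label 0 is a negative review
print(label_feature.int2str(1))   # 'pos' -> label 1 is a positive review
```
},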
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_data[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One of the first things we need to do to our data is tokenize it. Machine learning models aren't designed to handle strings, they're design to handle numbers. So what we need to do is break down our string into individual tokens, and then convert these tokens to numbers. We'll get to the conversion later, but first we'll look at tokenization.\n",
"\n",
"Tokenization involves using a tokenizer to process the strings in our dataset. A tokenizer is a function that goes from a string to a list of strings. There are many types of tokenizers available, but we're going to use a relatively simple one provided by torchtext called the basic_english tokenizer. We load our tokenizer as such:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = torchtext.data.utils.get_tokenizer(\"basic_english\")"
]
},
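{
As a quick illustration of the "string in, list of strings out" behaviour described above; the exact tokens can vary slightly between torchtext versions:

```python
# Sketch: basic_english lowercases the text and splits on whitespace and punctuation.
print(tokenizer("You'll love this film!"))
# roughly: ['you', "'", 'll', 'love', 'this', 'film', '!']
```
},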
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tokenize_example(example, tokenizer, max_length):\n",
" tokens = tokenizer(example[\"text\"])[:max_length]\n",
" return {\"tokens\": tokens}\n",
"\n",
"\n",
"max_length = 256\n",
"\n",
"train_data = train_data.map(\n",
" tokenize_example, fn_kwargs={\"tokenizer\": tokenizer, \"max_length\": max_length}\n",
")\n",
"test_data = test_data.map(\n",
" tokenize_example, fn_kwargs={\"tokenizer\": tokenizer, \"max_length\": max_length}\n",
")\n",
"\n",
"\n",
"# create validation data \n",
"# Why have both a validation set and a test set? Your test set respresents the real world data that you'd see if you actually deployed this model. You won't be able to see what data your model will be fed once deployed, and your test set is supposed to reflect that. Every time we tune our model hyperparameters or training set-up to make it do a bit better on the test set, we are leak information from the test set into the training process. If we do this too often then we begin to overfit on the test set. Hence, we need some data which can act as a \"proxy\" test set which we can look at more frequently in order to evaluate how well our model actually does on unseen data -- this is the validation set.\n",
"\n",
"test_size = 0.25\n",
"\n",
"train_valid_data = train_data.train_test_split(test_size=test_size)\n",
"train_data = train_valid_data[\"train\"]\n",
"valid_data = train_valid_data[\"test\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we have to build a vocabulary. This is look-up table where every unique token in your dataset has a corresponding index (an integer).\n",
"\n",
"We do this as machine learning models cannot operate on strings, only numerical vaslues. Each index is used to construct a one-hot vector for each token. A one-hot vector is a vector where all the elements are 0, except one, which is 1, and the dimensionality is the total number of unique tokens in your vocabulary, commonly denoted by V."
]
},
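{
A tiny worked example of the one-hot encoding mentioned above, using a made-up vocabulary size V = 5 and token index 2; both values are illustrative only:

```python
# Sketch: a one-hot vector is all zeros except a single 1 at the token's index.
import torch

V, index = 5, 2           # hypothetical vocabulary size and token index
one_hot = torch.zeros(V)
one_hot[index] = 1.0
print(one_hot)            # tensor([0., 0., 1., 0., 0.])
```
},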
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"min_freq = 5\n",
"special_tokens = [\"<unk>\", \"<pad>\"]\n",
"\n",
"vocab = torchtext.vocab.build_vocab_from_iterator(\n",
" train_data[\"tokens\"],\n",
" min_freq=min_freq,\n",
" specials=special_tokens,\n",
")\n",
"\n",
"# We store the indices of the unknown and padding tokens (zero and one, respectively) in variables, as we'll use these further on in this notebook.\n",
"\n",
"unk_index = vocab[\"<unk>\"]\n",
"pad_index = vocab[\"<pad>\"]\n",
"\n",
"\n",
"vocab.set_default_index(unk_index)\n",
"\n",
"# To look-up a list of tokens, we can use the vocabulary's lookup_indices method.\n",
"vocab.lookup_indices([\"hello\", \"world\", \"some_token\", \"<pad>\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we have our vocabulary, we can numericalize our data. This involves converting the tokens within our dataset into indices. Similar to how we tokenized our data using the Dataset.map method, we'll define a function that takes an example and our vocabulary, gets the index for each token in each example and then creates an ids field which containes the numericalized tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def numericalize_example(example, vocab):\n",
" ids = vocab.lookup_indices(example[\"tokens\"])\n",
" return {\"ids\": ids}\n",
"\n",
"train_data = train_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
"valid_data = valid_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
"test_data = test_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
"\n",
"train_data = train_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n",
"valid_data = valid_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n",
"test_data = test_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The final step of preparing the data is creating the data loaders. We can iterate over a data loader to retrieve batches of examples. This is also where we will perform any padding that is necessary.\n",
"\n",
"We first need to define a function to collate a batch, consisting of a list of examples, into what we want our data loader to output.\n",
"\n",
"Here, our desired output from the data loader is a dictionary with keys of \"ids\" and \"label\".\n",
"\n",
"The value of batch[\"ids\"] should be a tensor of shape [batch size, length], where length is the length of the longest sentence (in terms of tokens) within the batch, and all sentences shorter than this should be padded to that length.\n",
"\n",
"The value of batch[\"label\"] should be a tensor of shape [batch size] consisting of the label for each sentence in the batch.\n",
"\n",
"We define a function, get_collate_fn, which is passed the pad token index and returns the actual collate function. Within the actual collate function, collate_fn, we get a list of \"ids\" tensors for each example in the batch, and then use the pad_sequence function, which converts the list of tensors into the desired [batch size, length] shaped tensor and performs padding using the specified pad_index. By default, pad_sequence will return a [length, batch size] shaped tensor, but by setting batch_first=True, these two dimensions are switched. We get a list of \"label\" tensors and convert the list of tensors into a single [batch size] shaped tensor."
]
},
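{
Before the actual collate function below, a toy sketch of the pad_sequence behaviour described above; the token ids and the pad index of 1 are made up for the example:

```python
# Sketch: pad_sequence with batch_first=True returns a [batch size, length] tensor,
# padding shorter sequences with padding_value.
import torch
import torch.nn as nn

a = torch.tensor([5, 6, 7])
b = torch.tensor([8, 9])
padded = nn.utils.rnn.pad_sequence([a, b], padding_value=1, batch_first=True)
print(padded)        # tensor([[5, 6, 7], [8, 9, 1]])
print(padded.shape)  # torch.Size([2, 3])
```
},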
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_collate_fn(pad_index):\n",
" def collate_fn(batch):\n",
" batch_ids = [i[\"ids\"] for i in batch]\n",
" batch_ids = nn.utils.rnn.pad_sequence(\n",
" batch_ids, padding_value=pad_index, batch_first=True\n",
" )\n",
" batch_label = [i[\"label\"] for i in batch]\n",
" batch_label = torch.stack(batch_label)\n",
" batch = {\"ids\": batch_ids, \"label\": batch_label}\n",
" return batch\n",
"\n",
" return collate_fn\n",
"\n",
"def get_data_loader(dataset, batch_size, pad_index, shuffle=False):\n",
" collate_fn = get_collate_fn(pad_index)\n",
" data_loader = torch.utils.data.DataLoader(\n",
" dataset=dataset,\n",
" batch_size=batch_size,\n",
" collate_fn=collate_fn,\n",
" shuffle=shuffle,\n",
" )\n",
" return data_loader\n",
"\n",
"\n",
"batch_size = 512\n",
"\n",
"train_data_loader = get_data_loader(train_data, batch_size, pad_index, shuffle=True)\n",
"valid_data_loader = get_data_loader(valid_data, batch_size, pad_index)\n",
"test_data_loader = get_data_loader(test_data, batch_size, pad_index)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"class NBoW(nn.Module):\n",
" def __init__(self, vocab_size, embedding_dim, output_dim, pad_index):\n",
" super().__init__()\n",
" self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)\n",
" self.fc = nn.Linear(embedding_dim, output_dim)\n",
"\n",
" def forward(self, ids):\n",
" # ids = [batch size, seq len]\n",
" embedded = self.embedding(ids)\n",
" # embedded = [batch size, seq len, embedding dim]\n",
" pooled = embedded.mean(dim=1)\n",
" # pooled = [batch size, embedding dim]\n",
" prediction = self.fc(pooled)\n",
" # prediction = [batch size, output dim]\n",
" return prediction\n",
"\n",
"\n",
"vocab_size = len(vocab)\n",
"embedding_dim = 300\n",
"output_dim = len(train_data.unique(\"label\"))\n",
"\n",
"model = NBoW(vocab_size, embedding_dim, output_dim, pad_index)\n",
"\n",
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"\n",
"print(f\"The model has {count_parameters(model):,} trainable parameters\")\n",
"\n",
"vectors = torchtext.vocab.GloVe()\n",
"\n",
"pretrained_embedding = vectors.get_vecs_by_tokens(vocab.get_itos())\n",
"\n",
"optimizer = optim.Adam(model.parameters())\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"model = model.to(device)\n",
"criterion = criterion.to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train(data_loader, model, criterion, optimizer, device):\n",
" model.train()\n",
" epoch_losses = []\n",
" epoch_accs = []\n",
" for batch in tqdm.tqdm(data_loader, desc=\"training...\"):\n",
" ids = batch[\"ids\"].to(device)\n",
" label = batch[\"label\"].to(device)\n",
" prediction = model(ids)\n",
" loss = criterion(prediction, label)\n",
" accuracy = get_accuracy(prediction, label)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" epoch_losses.append(loss.item())\n",
" epoch_accs.append(accuracy.item())\n",
" return np.mean(epoch_losses), np.mean(epoch_accs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(data_loader, model, criterion, device):\n",
" model.eval()\n",
" epoch_losses = []\n",
" epoch_accs = []\n",
" with torch.no_grad():\n",
" for batch in tqdm.tqdm(data_loader, desc=\"evaluating...\"):\n",
" ids = batch[\"ids\"].to(device)\n",
" label = batch[\"label\"].to(device)\n",
" prediction = model(ids)\n",
" loss = criterion(prediction, label)\n",
" accuracy = get_accuracy(prediction, label)\n",
" epoch_losses.append(loss.item())\n",
" epoch_accs.append(accuracy.item())\n",
" return np.mean(epoch_losses), np.mean(epoch_accs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_accuracy(prediction, label):\n",
" batch_size, _ = prediction.shape\n",
" predicted_classes = prediction.argmax(dim=-1)\n",
" correct_predictions = predicted_classes.eq(label).sum()\n",
" accuracy = correct_predictions / batch_size\n",
" return accuracy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n_epochs = 10\n",
"best_valid_loss = float(\"inf\")\n",
"\n",
"metrics = collections.defaultdict(list)\n",
"\n",
"for epoch in range(n_epochs):\n",
" train_loss, train_acc = train(\n",
" train_data_loader, model, criterion, optimizer, device\n",
" )\n",
" valid_loss, valid_acc = evaluate(valid_data_loader, model, criterion, device)\n",
" metrics[\"train_losses\"].append(train_loss)\n",
" metrics[\"train_accs\"].append(train_acc)\n",
" metrics[\"valid_losses\"].append(valid_loss)\n",
" metrics[\"valid_accs\"].append(valid_acc)\n",
" if valid_loss < best_valid_loss:\n",
" best_valid_loss = valid_loss\n",
" torch.save(model.state_dict(), \"nbow.pt\")\n",
" print(f\"epoch: {epoch}\")\n",
" print(f\"train_loss: {train_loss:.3f}, train_acc: {train_acc:.3f}\")\n",
" print(f\"valid_loss: {valid_loss:.3f}, valid_acc: {valid_acc:.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig = plt.figure(figsize=(10, 6))\n",
"ax = fig.add_subplot(1, 1, 1)\n",
"ax.plot(metrics[\"train_losses\"], label=\"train loss\")\n",
"ax.plot(metrics[\"valid_losses\"], label=\"valid loss\")\n",
"ax.set_xlabel(\"epoch\")\n",
"ax.set_ylabel(\"loss\")\n",
"ax.set_xticks(range(n_epochs))\n",
"ax.legend()\n",
"ax.grid()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig = plt.figure(figsize=(10, 6))\n",
"ax = fig.add_subplot(1, 1, 1)\n",
"ax.plot(metrics[\"train_accs\"], label=\"train accuracy\")\n",
"ax.plot(metrics[\"valid_accs\"], label=\"valid accuracy\")\n",
"ax.set_xlabel(\"epoch\")\n",
"ax.set_ylabel(\"loss\")\n",
"ax.set_xticks(range(n_epochs))\n",
"ax.legend()\n",
"ax.grid()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.load_state_dict(torch.load(\"nbow.pt\"))\n",
"\n",
"test_loss, test_acc = evaluate(test_data_loader, model, criterion, device)\n",
"\n",
"print(f\"test_loss: {test_loss:.3f}, test_acc: {test_acc:.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def predict_sentiment(text, model, tokenizer, vocab, device):\n",
" tokens = tokenizer(text)\n",
" ids = vocab.lookup_indices(tokens)\n",
" tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
" prediction = model(tensor).squeeze(dim=0)\n",
" probability = torch.softmax(prediction, dim=-1)\n",
" predicted_class = prediction.argmax(dim=-1).item()\n",
" predicted_probability = probability[predicted_class].item()\n",
" return predicted_class, predicted_probability"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"This film is terrible!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"This film is great!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"This film is not terrible, it's great!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"This film is not great, it's terrible!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def text_to_tensor(text, tokenizer, vocab, device):\n",
" tokens = tokenizer(text)\n",
" ids = vocab.lookup_indices(tokens)\n",
" tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
" return tensor\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we do onnx stuff to get the data ready for the zk-circuit."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"import json\n",
"\n",
"text = \"This film is terrible!\"\n",
"x = text_to_tensor(text, tokenizer, vocab, device)\n",
"\n",
"# Flips the neural net into inference mode\n",
"model.eval()\n",
"model.to('cpu')\n",
"\n",
"model_path = \"network.onnx\"\n",
"data_path = \"input.json\"\n",
"\n",
" # Export the model\n",
"torch.onnx.export(model, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" model_path, # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data_json = dict(input_data = [data_array])\n",
"\n",
"print(data_json)\n",
"\n",
" # Serialize data into file:\n",
"json.dump(data_json, open(data_path, 'w'))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.logrows = 23\n",
"run_args.scale_rebase_multiplier = 10\n",
"# inputs should be auditable by all\n",
"run_args.input_visibility = \"public\"\n",
"# same with outputs\n",
"run_args.output_visibility = \"public\"\n",
"# for simplicity, we'll just use the fixed model visibility: i.e it is public and can't be changed by the prover\n",
"run_args.param_visibility = \"fixed\"\n",
"\n",
"\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(py_run_args=run_args)\n",
"assert res == True\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit()\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file\n",
"res = await ezkl.gen_witness()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.mock()\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"res = ezkl.setup()\n",
"\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"res = ezkl.prove(proof_path=\"proof.json\")\n",
"\n",
"print(res)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"res = ezkl.verify()\n",
"\n",
"assert res == True\n",
"print(\"verified\")\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also verify it on chain by creating an onchain verifier"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"solc-select\"])\n",
" !solc-select install 0.8.20\n",
" !solc-select use 0.8.20\n",
" !solc --version\n",
" import os\n",
"\n",
"# rely on local installation if the notebook is not in colab\n",
"except:\n",
" import os\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = await ezkl.create_evm_verifier()\n",
"assert res == True\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You should see a `Verifier.sol`. Right-click and save it locally.\n",
"\n",
"Now go to [https://remix.ethereum.org](https://remix.ethereum.org).\n",
"\n",
"Create a new file within remix and copy the verifier code over.\n",
"\n",
"Finally, compile the code and deploy. For the demo you can deploy to the test environment within remix.\n",
"\n",
"If everything works, you would have deployed your verifer onchain! Copy the values in the cell above to the respective fields to test if the verifier is working."
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -1,75 +1,52 @@
import torch
import torch.nn as nn
import sys
from torch import nn
import json
sys.path.append("..")
class Model(nn.Module):
"""
Just one Linear layer
"""
def __init__(self, configs):
super(Model, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
# Use this line if you want to visualize the weights
# self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
self.channels = configs.enc_in
self.individual = configs.individual
if self.individual:
self.Linear = nn.ModuleList()
for i in range(self.channels):
self.Linear.append(nn.Linear(self.seq_len,self.pred_len))
else:
self.Linear = nn.Linear(self.seq_len, self.pred_len)
def forward(self, x):
# x: [Batch, Input length, Channel]
if self.individual:
output = torch.zeros([x.size(0),self.pred_len,x.size(2)],dtype=x.dtype).to(x.device)
for i in range(self.channels):
output[:,:,i] = self.Linear[i](x[:,:,i])
x = output
else:
x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
return x # [Batch, Output length, Channel]
class Configs:
def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
self.seq_len = seq_len
self.pred_len = pred_len
self.enc_in = enc_in
self.individual = individual
model = 'Linear'
seq_len = 10
pred_len = 4
enc_in = 3
configs = Configs(seq_len, pred_len, enc_in, True)
circuit = Model(configs)
x = torch.randn(1, seq_len, pred_len)
import numpy as np
import tf2onnx
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=15, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
# the model's input names
input_names=['input'],
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
# scatter_nd in tf then export to onnx
x = in1 = Input((4, 1), dtype=tf.int32)
w = in2 = Input((4, ), dtype=tf.int32)
class MyLayer(Layer):
def call(self, x, w):
shape = tf.constant([8])
return tf.scatter_nd(x, w, shape)
x = MyLayer()(x, w)
tm = Model((in1, in2), x)
tm.summary()
tm.compile(optimizer='adam', loss='mse')
shape = [1, 4, 1]
index_shape = [1, 4]
# After training, export to onnx (network.onnx) and create a data file (input.json)
x = np.random.randint(0, 4, shape)
# w = random int tensor
w = np.random.randint(0, 4, index_shape)
spec = tf.TensorSpec(shape, tf.int32, name='input_0')
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')
model_path = "network.onnx"
tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)
d = x.reshape([-1]).tolist()
d1 = w.reshape([-1]).tolist()
data = dict(
input_data=[d1],
input_data=[d, d1],
)
# Serialize data into file:


@@ -1 +1,16 @@
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}
{
"input_data": [
[
0,
1,
2,
3
],
[
1,
0,
2,
1
]
]
}

ezkl.pyi (new file, 851 lines)

@@ -0,0 +1,851 @@
# This file is automatically generated by pyo3_stub_gen
# ruff: noqa: E501, F401
import os
import pathlib
import typing
from enum import Enum, auto
class PyG1:
r"""
pyclass containing the struct used for G1, this is mostly a helper class
"""
...
class PyG1Affine:
r"""
pyclass containing the struct used for G1
"""
...
class PyRunArgs:
r"""
Python class containing the struct used for run_args
Returns
-------
PyRunArgs
"""
...
class PyCommitments(Enum):
r"""
pyclass representing an enum, denoting the type of commitment
"""
KZG = auto()
IPA = auto()
class PyInputType(Enum):
Bool = auto()
F16 = auto()
F32 = auto()
F64 = auto()
Int = auto()
TDim = auto()
class PyTestDataSource(Enum):
r"""
pyclass representing an enum
"""
File = auto()
OnChain = auto()
def aggregate(aggregation_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],proof_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,transcript:str,logrows:int,check_mode:str,split_proofs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],commitment:PyCommitments) -> bool:
r"""
Creates an aggregated proof
Arguments
---------
aggregation_snarks: list[str]
List of paths to the various proofs
proof_path: str
Path to output the aggregated proof
vk_path: str
Path to the VK file
transcript:
Proof transcript type to be used. `evm` used by default. `poseidon` is also supported
logrows:
Logrows used for aggregation circuit
check_mode: str
Run sanity checks during calculations. Accepts `safe` or `unsafe`
split-proofs: bool
Whether the accumulated proofs are segments of a larger circuit
srs_path: str
Path to the SRS used
commitment: str
Accepts "kzg" or "ipa"
Returns
-------
bool
"""
...
def buffer_to_felts(buffer:typing.Sequence[int]) -> list[str]:
r"""
Converts a buffer to vector of field elements
Arguments
-------
buffer: list[int]
List of integers representing a buffer
Returns
-------
list[str]
List of field elements represented as strings
"""
...
def calibrate_settings(data:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,settings:str | os.PathLike | pathlib.Path,target:str,lookup_safety_margin:float,scales:typing.Optional[typing.Sequence[int]],scale_rebase_multiplier:typing.Sequence[int],max_logrows:typing.Optional[int]) -> typing.Any:
r"""
Calibrates the circuit settings
Arguments
---------
data: str
Path to the calibration data
model: str
Path to the onnx file
settings: str
Path to the settings file
lookup_safety_margin: int
the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. larger = safer but slower
scales: list[int]
Optional scales to specifically try for calibration
scale_rebase_multiplier: list[int]
Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale.
max_logrows: int
Optional max logrows to use for calibration
Returns
-------
bool
"""
...
def compile_circuit(model:str | os.PathLike | pathlib.Path,compiled_circuit:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path) -> bool:
r"""
Compiles the circuit for use in other steps
Arguments
---------
model: str
Path to the onnx model file
compiled_circuit: str
Path to output the compiled circuit
settings_path: str
Path to the settings files
Returns
-------
bool
"""
...
def create_evm_data_attestation(input_data:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,witness_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
r"""
Creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
Arguments
---------
input_data: str
The path to the .json data file, which should contain the necessary calldata and account addresses needed to read from all the on-chain view functions that return the data that the network ingests as inputs
settings_path: str
The path to the settings file
sol_code_path: str
The path to the create the solidity verifier
abi_path: str
The path to create the ABI for the solidity verifier
Returns
-------
bool
"""
...
def create_evm_verifier(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reusable:bool) -> typing.Any:
r"""
Creates an EVM compatible verifier, you will need solc installed in your environment to run this
Arguments
---------
vk_path: str
The path to the verification key file
settings_path: str
The path to the settings file
sol_code_path: str
The path to the create the solidity verifier
abi_path: str
The path to create the ABI for the solidity verifier
srs_path: str
The path to the SRS file
reusable: bool
Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
Returns
-------
bool
"""
...
def create_evm_verifier_aggr(aggregation_settings:typing.Sequence[str | os.PathLike | pathlib.Path],vk_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,logrows:int,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reusable:bool) -> typing.Any:
r"""
Creates an evm compatible aggregate verifier, you will need solc installed in your environment to run this
Arguments
---------
aggregation_settings: str
path to the settings file
vk_path: str
The path to load the desired verification key file
sol_code_path: str
The path to the Solidity code
abi_path: str
The path to output the Solidity verifier ABI
logrows: int
Number of logrows used during aggregated setup
srs_path: str
The path to the SRS file
reusable: bool
Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
Returns
-------
bool
"""
...
def create_evm_vka(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
r"""
Creates an EVM VK artifact. This command generates a VK with circuit-specific metadata encoded in memory for use by the reusable H2 verifier.
This is useful for deploying verifiers that would otherwise be too big to fit on chain and would require aggregation.
Arguments
---------
vk_path: str
The path to the verification key file
settings_path: str
The path to the settings file
sol_code_path: str
The path to the create the solidity verifying key.
abi_path: str
The path to create the ABI for the solidity verifier
srs_path: str
The path to the SRS file
Returns
-------
bool
"""
...
def deploy_da_evm(addr_path:str | os.PathLike | pathlib.Path,input_data:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],optimizer_runs:int,private_key:typing.Optional[str]) -> typing.Any:
r"""
deploys the solidity da verifier
"""
...
def deploy_evm(addr_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],contract_type:str,optimizer_runs:int,private_key:typing.Optional[str]) -> typing.Any:
r"""
deploys the solidity verifier
"""
...
def encode_evm_calldata(proof:str | os.PathLike | pathlib.Path,calldata:str | os.PathLike | pathlib.Path,addr_vk:typing.Optional[str]) -> list[int]:
r"""
Creates encoded evm calldata from a proof file
Arguments
---------
proof: str
Path to the proof file
calldata: str
Path to the calldata file to save
addr_vk: str
The address of the verification key contract (if the verifier key is to be rendered as a separate contract)
Returns
-------
vec[u8]
The encoded calldata
"""
...
def felt_to_big_endian(felt:str) -> str:
r"""
Converts a field element hex string to big endian
Arguments
-------
felt: str
The field element represented as a string
Returns
-------
str
field element represented as a string
"""
...
def felt_to_float(felt:str,scale:int) -> float:
r"""
Converts a field element hex string to a floating point number
Arguments
-------
felt: str
The field element represented as a string
scale: float
The scaling factor used to convert the field element into a floating point representation
Returns
-------
float
"""
...
def felt_to_int(felt:str) -> int:
r"""
Converts a field element hex string to an integer
Arguments
-------
felt: str
The field element represented as a string
Returns
-------
int
"""
...
def float_to_felt(input:float,scale:int,input_type:PyInputType) -> str:
r"""
Converts a floating point element to a field element hex string
Arguments
-------
input: float
The field element represented as a string
scale: float
The scaling factor used to quantize the float into a field element
input_type: PyInputType
The type of the input
Returns
-------
str
The field element represented as a string
"""
...
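A quick, hedged sanity check of the quantization round trip described by the four conversion helpers above (the scale is an arbitrary example; PyInputType.F64 mirrors the default input type in the Rust binding):

```python
import ezkl

scale = 7  # arbitrary example: values are quantized in steps of 1 / 2**7

felt = ezkl.float_to_felt(0.5, scale, ezkl.PyInputType.F64)
print(felt)                             # hex string for the field element
print(ezkl.felt_to_int(felt))           # 64, i.e. round(0.5 * 2**scale)
print(ezkl.felt_to_float(felt, scale))  # 0.5, up to quantization error
print(ezkl.felt_to_big_endian(felt))    # the same element as big-endian hex
```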
def gen_settings(model:str | os.PathLike | pathlib.Path,output:str | os.PathLike | pathlib.Path,py_run_args:typing.Optional[PyRunArgs]) -> bool:
r"""
Generates the circuit settings
Arguments
---------
model: str
Path to the onnx file
output: str
Path to create the settings file
py_run_args: PyRunArgs
PyRunArgs object to initialize the settings
Returns
-------
bool
"""
...
def gen_srs(srs_path:str | os.PathLike | pathlib.Path,logrows:int) -> None:
r"""
Generates the Structured Reference String (SRS); use this only for testing purposes
Arguments
---------
srs_path: str
Path at which to create the SRS file
logrows: int
The number of logrows for the SRS file
"""
...
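For local experiments the SRS can be generated directly; a hedged sketch (the path and logrows value are arbitrary, and for anything beyond testing you would fetch a public SRS with get_srs below):

```python
import ezkl

# Testing only: generate a small KZG SRS with 2**17 rows.
ezkl.gen_srs("kzg17.srs", 17)
```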
def gen_vk_from_pk_aggr(path_to_pk:str | os.PathLike | pathlib.Path,vk_output_path:str | os.PathLike | pathlib.Path) -> bool:
r"""
Generates a vk from a pk for an aggregate circuit and saves it to a file
Arguments
-------
path_to_pk: str
Path to the proving key
vk_output_path: str
Path to create the vk file
Returns
-------
bool
"""
...
def gen_vk_from_pk_single(path_to_pk:str | os.PathLike | pathlib.Path,circuit_settings_path:str | os.PathLike | pathlib.Path,vk_output_path:str | os.PathLike | pathlib.Path) -> bool:
r"""
Generates a vk from a pk for a model circuit and saves it to a file
Arguments
-------
path_to_pk: str
Path to the proving key
circuit_settings_path: str
Path to the circuit settings file
vk_output_path: str
Path to create the vk file
Returns
-------
bool
"""
...
def gen_witness(data:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,output:typing.Optional[str | os.PathLike | pathlib.Path],vk_path:typing.Optional[str | os.PathLike | pathlib.Path],srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
r"""
Runs the forward pass operation to generate a witness
Arguments
---------
data: str
Path to the data file
model: str
Path to the compiled model file
output: str
Path to create the witness file
vk_path: str
Path to the verification key
srs_path: str
Path to the SRS file
Returns
-------
dict
Python object containing the witness values
"""
...
def get_srs(settings_path:typing.Optional[str | os.PathLike | pathlib.Path],logrows:typing.Optional[int],srs_path:typing.Optional[str | os.PathLike | pathlib.Path],commitment:typing.Optional[PyCommitments]) -> typing.Any:
r"""
Gets a public SRS
Arguments
---------
settings_path: str
Path to the settings file
logrows: int
The number of logrows for the SRS file
srs_path: str
Path at which to create the SRS file
commitment: str
Specify the commitment used ("kzg", "ipa")
Returns
-------
bool
"""
...
def ipa_commit(message:typing.Sequence[str],vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> list[PyG1Affine]:
r"""
Generate an ipa commitment.
Arguments
-------
message: list[str]
List of field elements represented as strings
vk_path: str
Path to the verification key
settings_path: str
Path to the settings file
srs_path: str
Path to the Structured Reference String (SRS) file
Returns
-------
list[PyG1Affine]
"""
...
def kzg_commit(message:typing.Sequence[str],vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> list[PyG1Affine]:
r"""
Generate a kzg commitment.
Arguments
-------
message: list[str]
List of field elements represented as strings
vk_path: str
Path to the verification key
settings_path: str
Path to the settings file
srs_path: str
Path to the Structured Reference String (SRS) file
Returns
-------
list[PyG1Affine]
"""
...
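A hedged sketch of committing to a short message with KZG (the message values are arbitrary, and vk.key / settings.json / kzg17.srs are assumed to exist from earlier steps):

```python
import ezkl

# Each message entry is a field element encoded as a hex string.
message = [
    ezkl.float_to_felt(1.0, 7, ezkl.PyInputType.F64),
    ezkl.float_to_felt(2.0, 7, ezkl.PyInputType.F64),
]
points = ezkl.kzg_commit(message, "vk.key", "settings.json", srs_path="kzg17.srs")
print(points)  # list of PyG1Affine commitment points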
def mock(witness:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path) -> bool:
r"""
Mocks the prover
Arguments
---------
witness: str
Path to the witness file
model: str
Path to the compiled model file
Returns
-------
bool
"""
...
def mock_aggregate(aggregation_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],logrows:int,split_proofs:bool) -> bool:
r"""
Mocks the aggregate prover
Arguments
---------
aggregation_snarks: list[str]
List of paths to the relevant proof files
logrows: int
Number of logrows to use for the aggregation circuit
split_proofs: bool
Indicates whether the accumulated proofs are segments of a larger proof
Returns
-------
bool
"""
...
def poseidon_hash(message:typing.Sequence[str]) -> list[str]:
r"""
Generate a poseidon hash.
Arguments
-------
message: list[str]
List of field elements represented as strings
Returns
-------
list[str]
List of field elements represented as strings
"""
...
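For instance, hashing two field elements (a hedged sketch; the inputs are arbitrary):

```python
import ezkl

# Scale 0 quantizes the floats to the integers 3 and 4 before hashing.
digest = ezkl.poseidon_hash([
    ezkl.float_to_felt(3.0, 0, ezkl.PyInputType.F64),
    ezkl.float_to_felt(4.0, 0, ezkl.PyInputType.F64),
])
print(digest)  # list of field elements encoded as hex strings
```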
def prove(witness:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,proof_path:typing.Optional[str | os.PathLike | pathlib.Path],proof_type:str,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
r"""
Runs the prover on a set of inputs
Arguments
---------
witness: str
Path to the witness file
model: str
Path to the compiled model file
pk_path: str
Path to the proving key file
proof_path: str
Path to create the proof file
proof_type: str
Accepts `single`, `for-aggr`
srs_path: str
Path to the SRS file
Returns
-------
bool
"""
...
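Putting gen_witness and prove together, a hedged end-to-end sketch (all paths are hypothetical; network.compiled and pk.key are assumed to come from compile-circuit and setup runs, and gen_witness is awaited because that binding is async):

```python
import asyncio
import ezkl

async def main():
    # Forward pass: derive the witness from input.json and the compiled circuit.
    await ezkl.gen_witness("input.json", "network.compiled", "witness.json")
    # Produce a single (non-aggregated) proof with the proving key.
    ezkl.prove("witness.json", "network.compiled", "pk.key", "proof.json", "single")

asyncio.run(main())
```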
def setup(model:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],witness_path:typing.Optional[str | os.PathLike | pathlib.Path],disable_selector_compression:bool) -> bool:
r"""
Runs the setup process
Arguments
---------
model: str
Path to the compiled model file
vk_path: str
Path to create the verification key file
pk_path: str
Path to create the proving key file
srs_path: str
Path to the SRS file
witness_path: str
Path to the witness file
disable_selector_compression: bool
Whether to disable selector compression
Returns
-------
bool
"""
...
def setup_aggregate(sample_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],vk_path:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,logrows:int,split_proofs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],disable_selector_compression:bool,commitment:PyCommitments) -> bool:
r"""
Runs the setup process for an aggregate setup
Arguments
---------
sample_snarks: list[str]
List of paths to the various proofs
vk_path: str
Path to create the aggregated VK
pk_path: str
Path to create the aggregated PK
logrows: int
Number of logrows to use
split_proofs: bool
Whether the accumulated proofs are segments of a larger proof
srs_path: str
Path to the SRS file
disable_selector_compression: bool
Whether to disable selector compression
commitment: str
Accepts `kzg`, `ipa`
Returns
-------
bool
"""
...
def setup_test_evm_witness(data_path:str | os.PathLike | pathlib.Path,compiled_circuit_path:str | os.PathLike | pathlib.Path,test_data:str | os.PathLike | pathlib.Path,input_source:PyTestDataSource,output_source:PyTestDataSource,rpc_url:typing.Optional[str]) -> typing.Any:
r"""
Sets up a test EVM witness
Arguments
---------
data_path: str
The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
compiled_circuit_path: str
The path to the compiled model file (generated using the compile-circuit command)
test_data: str
For testing purposes only. The optional path to the .json data file that will be generated, containing the OnChain data storage information derived from the file information in the data .json file. Should include both the network input (possibly private) and the network output (public input to the proof)
input_source: PyTestDataSource
Where the input data comes from
output_source: PyTestDataSource
Where the output data comes from
rpc_url: str
RPC URL for an EVM compatible node, if None, uses Anvil as a local RPC node
Returns
-------
bool
"""
...
def swap_proof_commitments(proof_path:str | os.PathLike | pathlib.Path,witness_path:str | os.PathLike | pathlib.Path) -> None:
r"""
Swap the commitments in a proof
Arguments
-------
proof_path: str
Path to the proof file
witness_path: str
Path to the witness file
"""
...
def table(model:str | os.PathLike | pathlib.Path,py_run_args:typing.Optional[PyRunArgs]) -> str:
r"""
Displays the table as a string in python
Arguments
---------
model: str
Path to the onnx file
py_run_args: PyRunArgs
PyRunArgs object to initialize the settings
Returns
-------
str
Table of the nodes in the onnx file
"""
...
def verify(proof_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reduced_srs:bool) -> bool:
r"""
Verifies a given proof
Arguments
---------
proof_path: str
Path to the proof file
settings_path: str
Path to the settings file
vk_path: str
Path to the verification key file
srs_path: str
Path to the SRS file
reduced_srs: bool
Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
Returns
-------
bool
"""
...
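And the corresponding local check (a hedged sketch; the omitted srs_path and reduced_srs arguments are assumed to fall back to their defaults):

```python
import ezkl

# Returns True when the proof is consistent with the settings and verification key.
assert ezkl.verify("proof.json", "settings.json", "vk.key")
```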
def verify_aggr(proof_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,logrows:int,commitment:PyCommitments,reduced_srs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> bool:
r"""
Verifies an aggregate proof
Arguments
---------
proof_path: str
The path to the proof file
vk_path: str
The path to the verification key file
logrows: int
Number of logrows used for the aggregation circuit
commitment: str
Accepts "kzg" or "ipa"
reduced_srs: bool
Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
srs_path: str
The path to the SRS file
Returns
-------
bool
"""
...
def verify_evm(addr_verifier:str,proof_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],addr_da:typing.Optional[str],addr_vk:typing.Optional[str]) -> typing.Any:
r"""
Verifies an EVM compatible proof; you will need solc installed in your environment to run this
Arguments
---------
addr_verifier: str
The verifier contract's address as a hex string
proof_path: str
The path to the proof file (generated using the prove command)
rpc_url: str
RPC URL for an Ethereum node; if None, uses Anvil but WON'T persist state
addr_da: str
The address of the data attestation contract (if the verifier uses data attestation)
addr_vk: str
The address of the separate VK contract (if the verification key is rendered as a separate contract)
Returns
-------
bool
"""
...
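Finally, a hedged sketch of deploying the generated verifier and checking a proof on chain. The RPC URL assumes an Anvil node is already running locally (state does not persist across calls when rpc_url is None), and all paths and the contract_type value are illustrative:

```python
import asyncio
import ezkl

RPC = "http://127.0.0.1:8545"  # assumed local Anvil node

async def main():
    # Deploy Verifier.sol; the resulting contract address is written to addr.txt.
    await ezkl.deploy_evm("addr.txt", "Verifier.sol", rpc_url=RPC,
                          contract_type="verifier", optimizer_runs=1,
                          private_key=None)
    with open("addr.txt") as f:
        addr = f.read().strip()
    # Verify the proof against the deployed contract (requires solc locally).
    print(await ezkl.verify_evm(addr, "proof.json", rpc_url=RPC))

asyncio.run(main())
```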

View File

@@ -12,6 +12,7 @@ asyncio_mode = "auto"
[project]
name = "ezkl"
version = "0.0.0"
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Rust",

9
src/bin/py_stub_gen.rs Normal file
View File

@@ -0,0 +1,9 @@
use pyo3_stub_gen::Result;
fn main() -> Result<()> {
// `stub_info` is a function defined by `define_stub_info_gatherer!` macro.
env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")).init();
let stub = ezkl::bindings::python::stub_info()?;
stub.generate()?;
Ok(())
}

View File

@@ -4,6 +4,7 @@ use crate::circuit::modules::poseidon::{
PoseidonChip,
};
use crate::circuit::modules::Module;
use crate::circuit::InputType;
use crate::circuit::{CheckMode, Tolerance};
use crate::commands::*;
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
@@ -26,7 +27,12 @@ use pyo3::exceptions::{PyIOError, PyRuntimeError};
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3_log;
use pyo3_stub_gen::{
define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction, TypeInfo,
};
use snark_verifier::util::arithmetic::PrimeField;
use std::collections::HashSet;
use std::str::FromStr;
use std::{fs::File, path::PathBuf};
@@ -35,6 +41,7 @@ type PyFelt = String;
/// pyclass representing an enum
#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass_enum]
enum PyTestDataSource {
/// The data is loaded from a file
File,
@@ -54,6 +61,7 @@ impl From<PyTestDataSource> for TestDataSource {
/// pyclass containing the struct used for G1, this is mostly a helper class
#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass]
struct PyG1 {
#[pyo3(get, set)]
/// Field Element representing x
@@ -100,6 +108,7 @@ impl pyo3::ToPyObject for PyG1 {
/// pyclass containing the struct used for G1
#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass]
pub struct PyG1Affine {
#[pyo3(get, set)]
///
@@ -145,6 +154,7 @@ impl pyo3::ToPyObject for PyG1Affine {
///
#[pyclass]
#[derive(Clone)]
#[gen_stub_pyclass]
struct PyRunArgs {
#[pyo3(get, set)]
/// float: The tolerance for error on model outputs
@@ -259,6 +269,7 @@ impl Into<PyRunArgs> for RunArgs {
#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass_enum]
/// pyclass representing an enum, denoting the type of commitment
pub enum PyCommitments {
/// KZG commitment
@@ -306,6 +317,65 @@ impl FromStr for PyCommitments {
}
}
#[pyclass]
#[derive(Debug, Clone)]
#[gen_stub_pyclass_enum]
enum PyInputType {
///
Bool,
///
F16,
///
F32,
///
F64,
///
Int,
///
TDim,
}
impl From<InputType> for PyInputType {
fn from(input_type: InputType) -> Self {
match input_type {
InputType::Bool => PyInputType::Bool,
InputType::F16 => PyInputType::F16,
InputType::F32 => PyInputType::F32,
InputType::F64 => PyInputType::F64,
InputType::Int => PyInputType::Int,
InputType::TDim => PyInputType::TDim,
}
}
}
impl From<PyInputType> for InputType {
fn from(py_input_type: PyInputType) -> Self {
match py_input_type {
PyInputType::Bool => InputType::Bool,
PyInputType::F16 => InputType::F16,
PyInputType::F32 => InputType::F32,
PyInputType::F64 => InputType::F64,
PyInputType::Int => InputType::Int,
PyInputType::TDim => InputType::TDim,
}
}
}
impl FromStr for PyInputType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"bool" => Ok(PyInputType::Bool),
"f16" => Ok(PyInputType::F16),
"f32" => Ok(PyInputType::F32),
"f64" => Ok(PyInputType::F64),
"int" => Ok(PyInputType::Int),
"tdim" => Ok(PyInputType::TDim),
_ => Err("Invalid value for InputType".to_string()),
}
}
}
/// Converts a field element hex string to big endian
///
/// Arguments
@@ -322,6 +392,7 @@ impl FromStr for PyCommitments {
#[pyfunction(signature = (
felt,
))]
#[gen_stub_pyfunction]
fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
Ok(format!("{:?}", felt))
@@ -341,6 +412,7 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
#[pyfunction(signature = (
felt,
))]
#[gen_stub_pyfunction]
fn felt_to_int(felt: PyFelt) -> PyResult<IntegerRep> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
let int_rep = felt_to_integer_rep(felt);
@@ -365,6 +437,7 @@ fn felt_to_int(felt: PyFelt) -> PyResult<IntegerRep> {
felt,
scale
))]
#[gen_stub_pyfunction]
fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
let int_rep = felt_to_integer_rep(felt);
@@ -383,6 +456,9 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
/// scale: float
/// The scaling factor used to quantize the float into a field element
///
/// input_type: PyInputType
/// The type of the input
///
/// Returns
/// -------
/// str
@@ -390,9 +466,12 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
///
#[pyfunction(signature = (
input,
scale
scale,
input_type=PyInputType::F64
))]
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
#[gen_stub_pyfunction]
fn float_to_felt(mut input: f64, scale: crate::Scale, input_type: PyInputType) -> PyResult<PyFelt> {
InputType::roundtrip(&input_type.into(), &mut input);
let int_rep = quantize_float(&input, 0.0, scale)
.map_err(|_| PyIOError::new_err("Failed to quantize input"))?;
let felt = integer_rep_to_felt(int_rep);
@@ -414,6 +493,7 @@ fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
#[pyfunction(signature = (
buffer
))]
#[gen_stub_pyfunction]
fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 {
let mut n: u128 = 0;
@@ -486,6 +566,7 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
#[pyfunction(signature = (
message,
))]
#[gen_stub_pyfunction]
fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
let message: Vec<Fr> = message
.iter()
@@ -531,6 +612,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
#[gen_stub_pyfunction]
fn kzg_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
@@ -589,6 +671,7 @@ fn kzg_commit(
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
#[gen_stub_pyfunction]
fn ipa_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
@@ -635,6 +718,7 @@ fn ipa_commit(
proof_path=PathBuf::from(DEFAULT_PROOF),
witness_path=PathBuf::from(DEFAULT_WITNESS),
))]
#[gen_stub_pyfunction]
fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResult<()> {
crate::execute::swap_proof_commitments_cmd(proof_path, witness_path)
.map_err(|_| PyIOError::new_err("Failed to swap commitments"))?;
@@ -664,6 +748,7 @@ fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResul
circuit_settings_path=PathBuf::from(DEFAULT_SETTINGS),
vk_output_path=PathBuf::from(DEFAULT_VK),
))]
#[gen_stub_pyfunction]
fn gen_vk_from_pk_single(
path_to_pk: PathBuf,
circuit_settings_path: PathBuf,
@@ -701,6 +786,7 @@ fn gen_vk_from_pk_single(
path_to_pk=PathBuf::from(DEFAULT_PK_AGGREGATED),
vk_output_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
))]
#[gen_stub_pyfunction]
fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult<bool> {
let pk = load_pk::<KZGCommitmentScheme<Bn256>, AggregationCircuit>(path_to_pk, ())
.map_err(|_| PyIOError::new_err("Failed to load pk"))?;
@@ -730,6 +816,7 @@ fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult
model = PathBuf::from(DEFAULT_MODEL),
py_run_args = None
))]
#[gen_stub_pyfunction]
fn table(model: PathBuf, py_run_args: Option<PyRunArgs>) -> PyResult<String> {
let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into();
let mut reader = File::open(model).map_err(|_| PyIOError::new_err("Failed to open model"))?;
@@ -755,6 +842,7 @@ fn table(model: PathBuf, py_run_args: Option<PyRunArgs>) -> PyResult<String> {
srs_path,
logrows,
))]
#[gen_stub_pyfunction]
fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
let params = ezkl_gen_srs::<KZGCommitmentScheme<Bn256>>(logrows as u32);
save_params::<KZGCommitmentScheme<Bn256>>(&srs_path, &params)?;
@@ -787,6 +875,7 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
srs_path=None,
commitment=None,
))]
#[gen_stub_pyfunction]
fn get_srs(
py: Python,
settings_path: Option<PathBuf>,
@@ -799,7 +888,7 @@ fn get_srs(
None => None,
};
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::get_srs_cmd(srs_path, settings_path, logrows, commitment)
.await
.map_err(|e| {
@@ -833,6 +922,7 @@ fn get_srs(
output=PathBuf::from(DEFAULT_SETTINGS),
py_run_args = None,
))]
#[gen_stub_pyfunction]
fn gen_settings(
model: PathBuf,
output: PathBuf,
@@ -888,6 +978,7 @@ fn gen_settings(
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
max_logrows = None,
))]
#[gen_stub_pyfunction]
fn calibrate_settings(
py: Python,
data: PathBuf,
@@ -899,7 +990,7 @@ fn calibrate_settings(
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::calibrate(
model,
data,
@@ -951,6 +1042,7 @@ fn calibrate_settings(
vk_path=None,
srs_path=None,
))]
#[gen_stub_pyfunction]
fn gen_witness(
py: Python,
data: PathBuf,
@@ -959,7 +1051,7 @@ fn gen_witness(
vk_path: Option<PathBuf>,
srs_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let output = crate::execute::gen_witness(model, data, output, vk_path, srs_path)
.await
.map_err(|e| {
@@ -988,6 +1080,7 @@ fn gen_witness(
witness=PathBuf::from(DEFAULT_WITNESS),
model=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
))]
#[gen_stub_pyfunction]
fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
crate::execute::mock(model, witness).map_err(|e| {
let err_str = format!("Failed to run mock: {}", e);
@@ -1018,6 +1111,7 @@ fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
split_proofs = false,
))]
#[gen_stub_pyfunction]
fn mock_aggregate(
aggregation_snarks: Vec<PathBuf>,
logrows: u32,
@@ -1065,6 +1159,7 @@ fn mock_aggregate(
witness_path = None,
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn setup(
model: PathBuf,
vk_path: PathBuf,
@@ -1123,6 +1218,7 @@ fn setup(
proof_type=ProofType::default(),
srs_path=None,
))]
#[gen_stub_pyfunction]
fn prove(
witness: PathBuf,
model: PathBuf,
@@ -1178,6 +1274,7 @@ fn prove(
srs_path=None,
reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse::<bool>().unwrap(),
))]
#[gen_stub_pyfunction]
fn verify(
proof_path: PathBuf,
settings_path: PathBuf,
@@ -1237,6 +1334,7 @@ fn verify(
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn setup_aggregate(
sample_snarks: Vec<PathBuf>,
vk_path: PathBuf,
@@ -1287,6 +1385,7 @@ fn setup_aggregate(
compiled_circuit=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
))]
#[gen_stub_pyfunction]
fn compile_circuit(
model: PathBuf,
compiled_circuit: PathBuf,
@@ -1346,6 +1445,7 @@ fn compile_circuit(
srs_path=None,
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn aggregate(
aggregation_snarks: Vec<PathBuf>,
proof_path: PathBuf,
@@ -1411,6 +1511,7 @@ fn aggregate(
reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse().unwrap(),
srs_path=None,
))]
#[gen_stub_pyfunction]
fn verify_aggr(
proof_path: PathBuf,
vk_path: PathBuf,
@@ -1458,6 +1559,7 @@ fn verify_aggr(
calldata=PathBuf::from(DEFAULT_CALLDATA),
addr_vk=None,
))]
#[gen_stub_pyfunction]
fn encode_evm_calldata<'a>(
proof: PathBuf,
calldata: PathBuf,
@@ -1510,6 +1612,7 @@ fn encode_evm_calldata<'a>(
srs_path=None,
reusable = DEFAULT_RENDER_REUSABLE.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn create_evm_verifier(
py: Python,
vk_path: PathBuf,
@@ -1519,7 +1622,7 @@ fn create_evm_verifier(
srs_path: Option<PathBuf>,
reusable: bool,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::create_evm_verifier(
vk_path,
srs_path,
@@ -1569,6 +1672,7 @@ fn create_evm_verifier(
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
srs_path=None
))]
#[gen_stub_pyfunction]
fn create_evm_vka(
py: Python,
vk_path: PathBuf,
@@ -1577,7 +1681,7 @@ fn create_evm_vka(
abi_path: PathBuf,
srs_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::create_evm_vka(vk_path, srs_path, settings_path, sol_code_path, abi_path)
.await
.map_err(|e| {
@@ -1616,6 +1720,7 @@ fn create_evm_vka(
abi_path=PathBuf::from(DEFAULT_VERIFIER_DA_ABI),
witness_path=None,
))]
#[gen_stub_pyfunction]
fn create_evm_data_attestation(
py: Python,
input_data: PathBuf,
@@ -1624,7 +1729,7 @@ fn create_evm_data_attestation(
abi_path: PathBuf,
witness_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::create_evm_data_attestation(
settings_path,
sol_code_path,
@@ -1676,6 +1781,7 @@ fn create_evm_data_attestation(
output_source,
rpc_url=None,
))]
#[gen_stub_pyfunction]
fn setup_test_evm_witness(
py: Python,
data_path: PathBuf,
@@ -1685,7 +1791,7 @@ fn setup_test_evm_witness(
output_source: PyTestDataSource,
rpc_url: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::setup_test_evm_witness(
data_path,
compiled_circuit_path,
@@ -1713,6 +1819,7 @@ fn setup_test_evm_witness(
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None,
))]
#[gen_stub_pyfunction]
fn deploy_evm(
py: Python,
addr_path: PathBuf,
@@ -1722,7 +1829,7 @@ fn deploy_evm(
optimizer_runs: usize,
private_key: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::deploy_evm(
sol_code_path,
rpc_url,
@@ -1751,6 +1858,7 @@ fn deploy_evm(
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None
))]
#[gen_stub_pyfunction]
fn deploy_da_evm(
py: Python,
addr_path: PathBuf,
@@ -1761,7 +1869,7 @@ fn deploy_da_evm(
optimizer_runs: usize,
private_key: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::deploy_da_evm(
input_data,
settings_path,
@@ -1809,6 +1917,7 @@ fn deploy_da_evm(
addr_da = None,
addr_vk = None,
))]
#[gen_stub_pyfunction]
fn verify_evm<'a>(
py: Python<'a>,
addr_verifier: &'a str,
@@ -1831,7 +1940,7 @@ fn verify_evm<'a>(
None
};
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::verify_evm(proof_path, addr_verifier, rpc_url, addr_da, addr_vk)
.await
.map_err(|e| {
@@ -1881,6 +1990,7 @@ fn verify_evm<'a>(
srs_path=None,
reusable = DEFAULT_RENDER_REUSABLE.parse().unwrap(),
))]
#[gen_stub_pyfunction]
fn create_evm_verifier_aggr(
py: Python,
aggregation_settings: Vec<PathBuf>,
@@ -1891,7 +2001,7 @@ fn create_evm_verifier_aggr(
srs_path: Option<PathBuf>,
reusable: bool,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::create_evm_aggregate_verifier(
vk_path,
srs_path,
@@ -1911,15 +2021,19 @@ fn create_evm_verifier_aggr(
})
}
// Define a function to gather stub information.
define_stub_info_gatherer!(stub_info);
// Python Module
#[pymodule]
fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
pyo3_log::init();
m.add_class::<PyRunArgs>()?;
m.add_class::<PyG1Affine>()?;
m.add_class::<PyG1>()?;
m.add_class::<PyTestDataSource>()?;
m.add_class::<PyCommitments>()?;
m.add_class::<PyInputType>()?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
@@ -1958,3 +2072,48 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(encode_evm_calldata, m)?)?;
Ok(())
}
impl pyo3_stub_gen::PyStubType for CalibrationTarget {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}
impl pyo3_stub_gen::PyStubType for ProofType {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}
impl pyo3_stub_gen::PyStubType for TranscriptType {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}
impl pyo3_stub_gen::PyStubType for CheckMode {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}
impl pyo3_stub_gen::PyStubType for ContractType {
fn type_output() -> TypeInfo {
TypeInfo {
name: "str".to_string(),
import: HashSet::new(),
}
}
}

View File

@@ -141,10 +141,11 @@ pub(crate) fn gen_vk(
.map_err(|e| EZKLError::InternalError(format!("Failed to create verifying key: {}", e)))?;
let mut serialized_vk = Vec::new();
vk.write(&mut serialized_vk, halo2_proofs::SerdeFormat::RawBytes)
.map_err(|e| {
EZKLError::InternalError(format!("Failed to serialize verifying key: {}", e))
})?;
vk.write(
&mut serialized_vk,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize verifying key: {}", e)))?;
Ok(serialized_vk)
}
@@ -165,7 +166,7 @@ pub(crate) fn gen_pk(
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit.settings().clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
@@ -197,7 +198,7 @@ pub(crate) fn verify(
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit_settings.clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
@@ -277,7 +278,7 @@ pub(crate) fn verify_aggr(
let mut reader = BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
@@ -365,7 +366,7 @@ pub(crate) fn prove(
let mut reader = BufReader::new(&pk[..]);
let pk = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit.settings().clone(),
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;
@@ -487,7 +488,7 @@ pub(crate) fn vk_validation(vk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKL
let mut reader = BufReader::new(&vk[..]);
let _ = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit_settings,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
@@ -504,7 +505,7 @@ pub(crate) fn pk_validation(pk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKL
let mut reader = BufReader::new(&pk[..]);
let _ = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
halo2_proofs::SerdeFormat::RawBytesUnchecked,
circuit_settings,
)
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;

View File

@@ -22,6 +22,7 @@ use halo2curves::{
bn256::{Bn256, Fr, G1Affine},
ff::PrimeField,
};
use std::str::FromStr;
use wasm_bindgen::prelude::*;
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
@@ -113,9 +114,15 @@ pub fn feltToFloat(
#[wasm_bindgen]
#[allow(non_snake_case)]
pub fn floatToFelt(
input: f64,
mut input: f64,
scale: crate::Scale,
input_type: &str,
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
crate::circuit::InputType::roundtrip(
&crate::circuit::InputType::from_str(input_type)
.map_err(|e| JsError::new(&format!("{}", e)))?,
&mut input,
);
let int_rep =
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
let felt = integer_rep_to_felt(int_rep);

View File

@@ -8,10 +8,9 @@ use halo2_proofs::{
use log::debug;
#[cfg(feature = "python-bindings")]
use pyo3::{
conversion::{FromPyObject, PyTryFrom},
conversion::{FromPyObject, IntoPy},
exceptions::PyValueError,
prelude::*,
types::PyString,
};
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
@@ -139,10 +138,9 @@ impl IntoPy<PyObject> for CheckMode {
#[cfg(feature = "python-bindings")]
/// Obtains CheckMode from PyObject (Required for CheckMode to be compatible with Python)
impl<'source> FromPyObject<'source> for CheckMode {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
let strval = trystr.to_string();
match strval.to_lowercase().as_str() {
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
let trystr = String::extract_bound(ob)?;
match trystr.to_lowercase().as_str() {
"safe" => Ok(CheckMode::SAFE),
"unsafe" => Ok(CheckMode::UNSAFE),
_ => Err(PyValueError::new_err("Invalid value for CheckMode")),
@@ -161,8 +159,8 @@ impl IntoPy<PyObject> for Tolerance {
#[cfg(feature = "python-bindings")]
/// Obtains Tolerance from PyObject (Required for Tolerance to be compatible with Python)
impl<'source> FromPyObject<'source> for Tolerance {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
if let Ok((val, scale)) = ob.extract::<(f32, f32)>() {
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
if let Ok((val, scale)) = <(f32, f32)>::extract_bound(ob) {
Ok(Tolerance {
val,
scale: utils::F32(scale),

View File

@@ -97,4 +97,7 @@ pub enum CircuitError {
/// Invalid scale
#[error("negative scale for an op that requires positive inputs {0}")]
NegativeScale(String),
#[error("invalid input type {0}")]
/// Invalid input type
InvalidInputType(String),
}

View File

@@ -1687,6 +1687,7 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd + std::
Ok(output.into())
}
// assumes unique values in fullset
pub(crate) fn get_missing_set_elements<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
@@ -1699,6 +1700,8 @@ pub(crate) fn get_missing_set_elements<
let set_len = fullset.len();
input.flatten();
// while fullset is less than len of input concat
let is_assigned = !input.any_unknowns()? && !fullset.any_unknowns()?;
let mut claimed_output: ValTensor<F> = if is_assigned {

View File

@@ -105,7 +105,10 @@ impl InputType {
}
///
pub fn roundtrip<T: num::ToPrimitive + num::FromPrimitive + Clone>(&self, input: &mut T) {
pub fn roundtrip<T: num::ToPrimitive + num::FromPrimitive + Clone + std::fmt::Debug>(
&self,
input: &mut T,
) {
match self {
InputType::Bool => {
let boolean_input = input.clone().to_i64().unwrap();
@@ -118,7 +121,7 @@ impl InputType {
*input = T::from_f32(f32_input).unwrap();
}
InputType::F32 => {
let f32_input = input.clone().to_f32().unwrap();
let f32_input: f32 = input.clone().to_f32().unwrap();
*input = T::from_f32(f32_input).unwrap();
}
InputType::F64 => {
@@ -133,6 +136,22 @@ impl InputType {
}
}
impl std::str::FromStr for InputType {
type Err = CircuitError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"bool" => Ok(InputType::Bool),
"f16" => Ok(InputType::F16),
"f32" => Ok(InputType::F32),
"f64" => Ok(InputType::F64),
"int" => Ok(InputType::Int),
"tdim" => Ok(InputType::TDim),
e => Err(CircuitError::InvalidInputType(e.to_string())),
}
}
}
///
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Input {

View File

@@ -211,7 +211,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.min_lookup_inputs().to_string().green(),
self.max_range_size().to_string().green(),
self.dynamic_lookup_col_coord().to_string().green(),
self.shuffle_col_coord().to_string().green(),
self.shuffle_col_coord().to_string().green(),
self.max_dynamic_input_len().to_string().green()
);
}
@@ -474,7 +474,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
Ok(())
}
/// Update the max and min forcefully
/// Update the max and min forcefully
pub fn update_max_min_lookup_inputs_force(
&mut self,
min: IntegerRep,
@@ -611,7 +611,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<(ValTensor<F>, usize), CircuitError> {
self.update_max_dynamic_input_len(values.len());
if let Some(region) = &self.region {

View File

@@ -2,12 +2,7 @@ use alloy::primitives::Address as H160;
use clap::{Command, Parser, Subcommand};
use clap_complete::{generate, Generator, Shell};
#[cfg(feature = "python-bindings")]
use pyo3::{
conversion::{FromPyObject, PyTryFrom},
exceptions::PyValueError,
prelude::*,
types::PyString,
};
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::str::FromStr;
@@ -109,8 +104,8 @@ impl IntoPy<PyObject> for TranscriptType {
#[cfg(feature = "python-bindings")]
/// Obtains TranscriptType from PyObject (Required for TranscriptType to be compatible with Python)
impl<'source> FromPyObject<'source> for TranscriptType {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
let trystr = String::extract_bound(ob)?;
let strval = trystr.to_string();
match strval.to_lowercase().as_str() {
"poseidon" => Ok(TranscriptType::Poseidon),
@@ -196,9 +191,7 @@ pub enum ContractType {
impl Default for ContractType {
fn default() -> Self {
ContractType::Verifier {
reusable: false,
}
ContractType::Verifier { reusable: false }
}
}
@@ -210,10 +203,8 @@ impl std::fmt::Display for ContractType {
match self {
ContractType::Verifier { reusable: true } => {
"verifier/reusable".to_string()
},
ContractType::Verifier {
reusable: false,
} => "verifier".to_string(),
}
ContractType::Verifier { reusable: false } => "verifier".to_string(),
ContractType::VerifyingKeyArtifact => "vka".to_string(),
}
)
@@ -241,7 +232,6 @@ impl From<&str> for ContractType {
}
}
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
/// wrapper for H160 to make it easy to parse into flag vals
pub struct H160Flag {
@@ -287,9 +277,8 @@ impl IntoPy<PyObject> for CalibrationTarget {
#[cfg(feature = "python-bindings")]
/// Obtains CalibrationTarget from PyObject (Required for CalibrationTarget to be compatible with Python)
impl<'source> FromPyObject<'source> for CalibrationTarget {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
let strval = trystr.to_string();
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
let strval = String::extract_bound(ob)?;
match strval.to_lowercase().as_str() {
"resources" => Ok(CalibrationTarget::Resources {
col_overflow: false,
@@ -306,12 +295,8 @@ impl<'source> FromPyObject<'source> for CalibrationTarget {
impl IntoPy<PyObject> for ContractType {
fn into_py(self, py: Python) -> PyObject {
match self {
ContractType::Verifier { reusable: true } => {
"verifier/reusable".to_object(py)
}
ContractType::Verifier {
reusable: false,
} => "verifier".to_object(py),
ContractType::Verifier { reusable: true } => "verifier/reusable".to_object(py),
ContractType::Verifier { reusable: false } => "verifier".to_object(py),
ContractType::VerifyingKeyArtifact => "vka".to_object(py),
}
}
@@ -320,13 +305,10 @@ impl IntoPy<PyObject> for ContractType {
#[cfg(feature = "python-bindings")]
/// Obtains ContractType from PyObject (Required for ContractType to be compatible with Python)
impl<'source> FromPyObject<'source> for ContractType {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
let strval = trystr.to_string();
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
let strval = String::extract_bound(ob)?;
match strval.to_lowercase().as_str() {
"verifier" => Ok(ContractType::Verifier {
reusable: false,
}),
"verifier" => Ok(ContractType::Verifier { reusable: false }),
"verifier/reusable" => Ok(ContractType::Verifier { reusable: true }),
"vka" => Ok(ContractType::VerifyingKeyArtifact),
_ => Err(PyValueError::new_err("Invalid value for ContractType")),
@@ -341,45 +323,45 @@ pub fn get_styles() -> clap::builder::Styles {
clap::builder::styling::Style::new()
.bold()
.underline()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Cyan))),
.fg_color(Some(clap::builder::styling::Color::Ansi(
clap::builder::styling::AnsiColor::Cyan,
))),
)
.header(
clap::builder::styling::Style::new()
.bold()
.underline()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Cyan))),
)
.literal(
clap::builder::styling::Style::new().fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Magenta))),
)
.invalid(
clap::builder::styling::Style::new()
.bold()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red))),
)
.error(
clap::builder::styling::Style::new()
.bold()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red))),
.fg_color(Some(clap::builder::styling::Color::Ansi(
clap::builder::styling::AnsiColor::Cyan,
))),
)
.literal(clap::builder::styling::Style::new().fg_color(Some(
clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Magenta),
)))
.invalid(clap::builder::styling::Style::new().bold().fg_color(Some(
clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red),
)))
.error(clap::builder::styling::Style::new().bold().fg_color(Some(
clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red),
)))
.valid(
clap::builder::styling::Style::new()
.bold()
.underline()
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Green))),
)
.placeholder(
clap::builder::styling::Style::new().fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::White))),
.fg_color(Some(clap::builder::styling::Color::Ansi(
clap::builder::styling::AnsiColor::Green,
))),
)
.placeholder(clap::builder::styling::Style::new().fg_color(Some(
clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::White),
)))
}
/// Print completions for the given generator
pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
generate(gen, cmd, cmd.get_name().to_string(), &mut std::io::stdout());
}
#[allow(missing_docs)]
#[derive(Parser, Debug, Clone)]
#[command(author, about, long_about = None)]
@@ -393,7 +375,6 @@ pub struct Cli {
pub command: Option<Commands>,
}
#[allow(missing_docs)]
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
pub enum Commands {
@@ -443,7 +424,7 @@ pub enum Commands {
},
/// Calibrates the proving scale, lookup bits and logrows from a circuit settings file.
CalibrateSettings {
CalibrateSettings {
/// The path to the .json calibration data file.
#[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE, value_hint = clap::ValueHint::FilePath)]
data: Option<PathBuf>,
@@ -490,7 +471,7 @@ pub enum Commands {
commitment: Option<Commitments>,
},
/// Gets an SRS from a circuit settings file.
/// Gets an SRS from a circuit settings file.
#[command(name = "get-srs")]
GetSrs {
/// The path to output the desired srs file, if set to None will save to ~/.ezkl/srs
@@ -575,7 +556,7 @@ pub enum Commands {
require_equals = true,
num_args = 0..=1,
default_value_t = TranscriptType::default(),
value_enum,
value_enum,
value_hint = clap::ValueHint::Other
)]
transcript: TranscriptType,
@@ -625,7 +606,7 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)]
disable_selector_compression: Option<bool>,
},
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
#[command(arg_required_else_help = true)]
SetupTestEvmData {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
@@ -649,7 +630,7 @@ pub enum Commands {
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
output_source: TestDataSource,
},
/// The Data Attestation Verifier contract stores the account calls to fetch data to feed into ezkl. This call data can be updated by an admin account. This tests that admin account is able to update this call data.
/// The Data Attestation Verifier contract stores the account calls to fetch data to feed into ezkl. This call data can be updated by an admin account. This tests that admin account is able to update this call data.
#[command(arg_required_else_help = true)]
TestUpdateAccountCalls {
/// The path to the verifier contract's address
@@ -662,7 +643,7 @@ pub enum Commands {
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
},
/// Swaps the positions in the transcript that correspond to commitments
/// Swaps the positions in the transcript that correspond to commitments
SwapProofCommitments {
/// The path to the proof file
#[arg(short = 'P', long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
@@ -672,7 +653,7 @@ pub enum Commands {
witness_path: Option<PathBuf>,
},
/// Loads model, data, and creates proof
/// Loads model, data, and creates proof
Prove {
/// The path to the .json witness file (generated using the gen-witness command)
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
@@ -694,7 +675,7 @@ pub enum Commands {
require_equals = true,
num_args = 0..=1,
default_value_t = ProofType::Single,
value_enum,
value_enum,
value_hint = clap::ValueHint::Other
)]
proof_type: ProofType,
@@ -702,7 +683,7 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_CHECKMODE, value_hint = clap::ValueHint::Other)]
check_mode: Option<CheckMode>,
},
/// Encodes a proof into evm calldata
/// Encodes a proof into evm calldata
#[command(name = "encode-evm-calldata")]
EncodeEvmCalldata {
/// The path to the proof file (generated using the prove command)
@@ -715,7 +696,7 @@ pub enum Commands {
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_vk: Option<H160Flag>,
},
/// Creates an Evm verifier for a single proof
/// Creates an Evm verifier for a single proof
#[command(name = "create-evm-verifier")]
CreateEvmVerifier {
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
@@ -737,7 +718,7 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_RENDER_REUSABLE, action = clap::ArgAction::SetTrue)]
reusable: Option<bool>,
},
/// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier
/// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier
#[command(name = "create-evm-vka")]
CreateEvmVKArtifact {
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
@@ -756,7 +737,7 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_VK_ABI, value_hint = clap::ValueHint::FilePath)]
abi_path: Option<PathBuf>,
},
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
#[command(name = "create-evm-da")]
CreateEvmDataAttestation {
/// The path to load circuit settings .json file from (generated using the gen-settings command)
@@ -780,7 +761,7 @@ pub enum Commands {
witness: Option<PathBuf>,
},
/// Creates an Evm verifier for an aggregate proof
/// Creates an Evm verifier for an aggregate proof
#[command(name = "create-evm-verifier-aggr")]
CreateEvmVerifierAggr {
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
@@ -844,7 +825,7 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
commitment: Option<Commitments>,
},
/// Deploys an evm contract (verifier, reusable verifier, or vk artifact) that is generated by ezkl
/// Deploys an evm contract (verifier, reusable verifier, or vk artifact) that is generated by ezkl
DeployEvm {
/// The path to the Solidity code (generated using the create-evm-verifier command)
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
@@ -865,7 +846,7 @@ pub enum Commands {
#[arg(long = "contract-type", short = 'C', default_value = DEFAULT_CONTRACT_DEPLOYMENT_TYPE, value_hint = clap::ValueHint::Other)]
contract: ContractType,
},
/// Deploys an evm verifier that allows for data attestation
/// Deploys an evm verifier that allows for data attestation
#[command(name = "deploy-evm-da")]
DeployEvmDataAttestation {
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
@@ -890,7 +871,7 @@ pub enum Commands {
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
private_key: Option<String>,
},
/// Verifies a proof using a local Evm executor, returning accept or reject
/// Verifies a proof using a local Evm executor, returning accept or reject
#[command(name = "verify-evm")]
VerifyEvm {
/// The path to the proof file (generated using the prove command)
@@ -918,7 +899,6 @@ pub enum Commands {
},
}
impl Commands {
/// Converts the commands to a json string
pub fn as_json(&self) -> String {
@@ -929,4 +909,4 @@ impl Commands {
pub fn from_json(json: &str) -> Self {
serde_json::from_str(json).unwrap()
}
}
}

View File

@@ -488,7 +488,8 @@ pub async fn deploy_da_verifier_via_solidity(
}
}
let contract = match call_to_account {
match call_to_account {
Some(call) => {
deploy_single_da_contract(
client,
@@ -514,8 +515,7 @@ pub async fn deploy_da_verifier_via_solidity(
)
.await
}
};
return contract;
}
}
async fn deploy_multi_da_contract(
@@ -630,7 +630,7 @@ async fn deploy_single_da_contract(
// bytes memory _callData,
PackedSeqToken(call_data.as_ref()),
// uint256 _decimals,
WordToken(B256::from(decimals).into()),
WordToken(B256::from(decimals)),
// uint[] memory _scales,
DynSeqToken(
scales

View File

@@ -712,7 +712,8 @@ impl ToPyObject for DataSource {
DataSource::OnChain(source) => {
let dict = PyDict::new(py);
dict.set_item("rpc_url", &source.rpc).unwrap();
dict.set_item("calls_to_accounts", &source.calls).unwrap();
dict.set_item("calls_to_accounts", &source.calls.to_object(py))
.unwrap();
dict.to_object(py)
}
DataSource::DB(source) => {

View File

@@ -60,7 +60,10 @@ use pyo3::prelude::*;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDictMethods;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
pub use utilities::*;
@@ -343,10 +346,10 @@ impl ToPyObject for GraphWitness {
if let Some(processed_inputs) = &self.processed_inputs {
//poseidon_hash
if let Some(processed_inputs_poseidon_hash) = &processed_inputs.poseidon_hash {
insert_poseidon_hash_pydict(dict_inputs, processed_inputs_poseidon_hash).unwrap();
insert_poseidon_hash_pydict(&dict_inputs, processed_inputs_poseidon_hash).unwrap();
}
if let Some(processed_inputs_polycommit) = &processed_inputs.polycommit {
insert_polycommit_pydict(dict_inputs, processed_inputs_polycommit).unwrap();
insert_polycommit_pydict(&dict_inputs, processed_inputs_polycommit).unwrap();
}
dict.set_item("processed_inputs", dict_inputs).unwrap();
@@ -354,10 +357,10 @@ impl ToPyObject for GraphWitness {
if let Some(processed_params) = &self.processed_params {
if let Some(processed_params_poseidon_hash) = &processed_params.poseidon_hash {
insert_poseidon_hash_pydict(dict_params, processed_params_poseidon_hash).unwrap();
insert_poseidon_hash_pydict(&dict_params, processed_params_poseidon_hash).unwrap();
}
if let Some(processed_params_polycommit) = &processed_params.polycommit {
insert_polycommit_pydict(dict_inputs, processed_params_polycommit).unwrap();
insert_polycommit_pydict(&dict_params, processed_params_polycommit).unwrap();
}
dict.set_item("processed_params", dict_params).unwrap();
@@ -365,10 +368,11 @@ impl ToPyObject for GraphWitness {
if let Some(processed_outputs) = &self.processed_outputs {
if let Some(processed_outputs_poseidon_hash) = &processed_outputs.poseidon_hash {
insert_poseidon_hash_pydict(dict_outputs, processed_outputs_poseidon_hash).unwrap();
insert_poseidon_hash_pydict(&dict_outputs, processed_outputs_poseidon_hash)
.unwrap();
}
if let Some(processed_outputs_polycommit) = &processed_outputs.polycommit {
insert_polycommit_pydict(dict_inputs, processed_outputs_polycommit).unwrap();
insert_polycommit_pydict(&dict_outputs, processed_outputs_polycommit).unwrap();
}
dict.set_item("processed_outputs", dict_outputs).unwrap();
@@ -379,7 +383,10 @@ impl ToPyObject for GraphWitness {
}
#[cfg(feature = "python-bindings")]
fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec<Fp>) -> Result<(), PyErr> {
fn insert_poseidon_hash_pydict(
pydict: &Bound<'_, PyDict>,
poseidon_hash: &Vec<Fp>,
) -> Result<(), PyErr> {
let poseidon_hash: Vec<String> = poseidon_hash.iter().map(field_to_string).collect();
pydict.set_item("poseidon_hash", poseidon_hash)?;
@@ -387,7 +394,10 @@ fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec<Fp>) -> Resu
}
#[cfg(feature = "python-bindings")]
fn insert_polycommit_pydict(pydict: &PyDict, commits: &Vec<Vec<G1Affine>>) -> Result<(), PyErr> {
fn insert_polycommit_pydict(
pydict: &Bound<'_, PyDict>,
commits: &Vec<Vec<G1Affine>>,
) -> Result<(), PyErr> {
use crate::bindings::python::PyG1Affine;
let poseidon_hash: Vec<Vec<PyG1Affine>> = commits
.iter()

View File

@@ -656,7 +656,7 @@ impl Model {
let mut symbol_values = SymbolValues::default();
for (symbol, value) in run_args.variables.iter() {
let symbol = model.symbol_table.sym(symbol);
let symbol = model.symbols.sym(symbol);
symbol_values = symbol_values.with(&symbol, *value as i64);
debug!("set {} to {}", symbol, value);
}
@@ -1199,9 +1199,9 @@ impl Model {
// Then number of columns in the circuits
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
region.debug_report();
debug!("input indices: {:?}", node.inputs());
debug!("output scales: {:?}", node.out_scales());
debug!(
trace!("input indices: {:?}", node.inputs());
trace!("output scales: {:?}", node.out_scales());
trace!(
"input scales: {:?}",
node.inputs()
.iter()
@@ -1220,8 +1220,8 @@ impl Model {
// we re-assign inputs, always from the 0 outlet
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
};
debug!("output dims: {:?}", node.out_dims());
debug!(
trace!("output dims: {:?}", node.out_dims());
trace!(
"input dims {:?}",
values.iter().map(|v| v.dims()).collect_vec()
);

View File

@@ -142,8 +142,6 @@ use tract_onnx::prelude::SymbolValues;
pub fn extract_tensor_value(
input: Arc<tract_onnx::prelude::Tensor>,
) -> Result<Tensor<f32>, GraphError> {
use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
let dt = input.datum_type();
let dims = input.shape().to_vec();
@@ -156,7 +154,7 @@ pub fn extract_tensor_value(
match dt {
DatumType::F16 => {
let vec = input.as_slice::<tract_onnx::prelude::f16>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| (*x).into()).collect();
let cast: Vec<f32> = vec.iter().map(|x| (*x).into()).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::F32 => {
@@ -165,61 +163,61 @@ pub fn extract_tensor_value(
}
DatumType::F64 => {
let vec = input.as_slice::<f64>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::I64 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<i64>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::I32 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<i32>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::I16 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<i16>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::I8 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<i8>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::U8 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<u8>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::U16 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<u16>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::U32 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<u32>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::U64 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<u64>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::Bool => {
// Generally a shape or hyperparam
let vec = input.as_slice::<bool>()?.to_vec();
let cast: Vec<f32> = vec.par_iter().map(|x| *x as usize as f32).collect();
let cast: Vec<f32> = vec.iter().map(|x| *x as usize as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::TDim => {
@@ -227,7 +225,7 @@ pub fn extract_tensor_value(
let vec = input.as_slice::<tract_onnx::prelude::TDim>()?.to_vec();
let cast: Result<Vec<f32>, GraphError> = vec
.par_iter()
.iter()
.map(|x| match x.to_i64() {
Ok(v) => Ok(v as f32),
Err(_) => match x.to_i64() {
@@ -1007,21 +1005,21 @@ pub fn new_op_from_onnx(
op
}
"Iff" => SupportedOp::Linear(PolyOp::Iff),
"Less" => {
"<" => {
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::Less)
} else {
return Err(GraphError::InvalidDims(idx, "less".to_string()));
}
}
"LessEqual" => {
"<=" => {
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::LessEqual)
} else {
return Err(GraphError::InvalidDims(idx, "less equal".to_string()));
}
}
"Greater" => {
">" => {
// Extract the slope layer hyperparams
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::Greater)
@@ -1029,7 +1027,7 @@ pub fn new_op_from_onnx(
return Err(GraphError::InvalidDims(idx, "greater".to_string()));
}
}
"GreaterEqual" => {
">=" => {
// Extract the slope layer hyperparams
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::GreaterEqual)
@@ -1136,23 +1134,56 @@ pub fn new_op_from_onnx(
a: crate::circuit::utils::F32(exponent),
})
}
} else if let Some(c) = inputs[0].opkind().get_mutable_constant() {
inputs[0].decrement_use();
deleted_indices.push(0);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar base")
}
let base = c.raw_values[0];
SupportedOp::Nonlinear(LookupOp::Exp {
scale: scale_to_multiplier(input_scales[1]).into(),
base: base.into(),
})
} else {
if let Some(c) = inputs[0].opkind().get_mutable_constant() {
inputs[0].decrement_use();
deleted_indices.push(0);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar base")
}
unimplemented!("only support constant base or pow for now")
}
}
"Div" => {
let const_idx = inputs
.iter()
.enumerate()
.filter(|(_, n)| n.is_constant())
.map(|(i, _)| i)
.collect::<Vec<_>>();
let base = c.raw_values[0];
if const_idx.len() > 1 {
return Err(GraphError::InvalidDims(idx, "div".to_string()));
}
SupportedOp::Nonlinear(LookupOp::Exp {
scale: scale_to_multiplier(input_scales[1]).into(),
base: base.into(),
let const_idx = const_idx[0];
if const_idx != 1 {
unimplemented!("only support div with constant as second input")
}
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
if c.raw_values.len() == 1 && c.raw_values[0] != 0. {
inputs[const_idx].decrement_use();
deleted_indices.push(const_idx);
// get the non constant index
let denom = c.raw_values[0];
SupportedOp::Hybrid(HybridOp::Div {
denom: denom.into(),
})
} else {
unimplemented!("only support constant base or pow for now")
unimplemented!("only support non zero divisors of size 1")
}
} else {
unimplemented!("only support div with constant as second input")
}
}
"Cube" => SupportedOp::Linear(PolyOp::Pow(3)),
@@ -1215,7 +1246,7 @@ pub fn new_op_from_onnx(
"And" => SupportedOp::Linear(PolyOp::And),
"Or" => SupportedOp::Linear(PolyOp::Or),
"Xor" => SupportedOp::Linear(PolyOp::Xor),
"Equals" => SupportedOp::Hybrid(HybridOp::Equals),
"==" => SupportedOp::Hybrid(HybridOp::Equals),
"Deconv" => {
let deconv_node: &Deconv = match node.op().downcast_ref::<Deconv>() {
Some(b) => b,
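The new "Div" arm above only lowers to HybridOp::Div when exactly one input is constant, that constant is the second (denominator) input, and it holds a single non-zero value; the constant node is then marked deleted. A standalone sketch of just the divisor check, on plain f32 slices rather than graph nodes (the helper name is illustrative):

// Accept a denominator only if the constant input is a single non-zero scalar,
// matching the checks in the "Div" arm above.
fn scalar_nonzero_divisor(raw_values: &[f32]) -> Option<f32> {
    match raw_values {
        [d] if *d != 0.0 => Some(*d),   // scalar, non-zero: usable as HybridOp::Div denom
        _ => None,                       // zero or non-scalar constants are rejected
    }
}

fn main() {
    assert_eq!(scalar_nonzero_divisor(&[2.0]), Some(2.0));   // valid constant denominator
    assert_eq!(scalar_nonzero_divisor(&[0.0]), None);        // zero divisor rejected
    assert_eq!(scalar_nonzero_divisor(&[1.0, 2.0]), None);   // only scalar constants supported
}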

View File

@@ -9,8 +9,7 @@ use itertools::Itertools;
use log::debug;
#[cfg(feature = "python-bindings")]
use pyo3::{
exceptions::PyValueError, types::PyString, FromPyObject, IntoPy, PyAny, PyObject, PyResult,
PyTryFrom, Python, ToPyObject,
exceptions::PyValueError, FromPyObject, IntoPy, PyObject, PyResult, Python, ToPyObject,
};
use serde::{Deserialize, Serialize};
@@ -137,10 +136,8 @@ impl IntoPy<PyObject> for Visibility {
#[cfg(feature = "python-bindings")]
/// Obtains Visibility from PyObject (Required for Visibility to be compatible with Python)
impl<'source> FromPyObject<'source> for Visibility {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
let strval = trystr.to_string();
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
let strval = String::extract_bound(ob)?;
let strval = strval.as_str();
if strval.contains("hashed/private") {
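This impl, like the ProofType and StrategyType impls further down, moves from pyo3's old PyTryFrom-based extract to the Bound API's extract_bound. A minimal sketch of the same pattern for a throwaway enum, assuming pyo3 0.21+ (Mode and its variants are illustrative, not ezkl types):

use pyo3::prelude::*;

enum Mode {
    Public,
    Private,
}

impl<'source> FromPyObject<'source> for Mode {
    fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
        // extract the Python string through the Bound API, then match on it
        let strval = String::extract_bound(ob)?;
        match strval.to_lowercase().as_str() {
            "public" => Ok(Mode::Public),
            "private" => Ok(Mode::Private),
            _ => Err(pyo3::exceptions::PyValueError::new_err(
                "expected 'public' or 'private'",
            )),
        }
    }
}

With such an impl in place, calling .extract::<Mode>() on a Bound<'_, PyAny> holding a string behaves as the old extract did.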

View File

@@ -120,7 +120,7 @@ pub fn version() -> &'static str {
}
}
/// Bindings managment
/// Bindings management
#[cfg(any(
feature = "ios-bindings",
all(target_arch = "wasm32", target_os = "unknown"),

View File

@@ -46,6 +46,9 @@ use thiserror::Error as thisError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tosubcommand::ToFlags;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDictMethods;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
@@ -116,9 +119,8 @@ impl ToPyObject for ProofType {
#[cfg(feature = "python-bindings")]
/// Obtains StrategyType from PyObject (Required for StrategyType to be compatible with Python)
impl<'source> pyo3::FromPyObject<'source> for ProofType {
fn extract(ob: &'source pyo3::PyAny) -> pyo3::PyResult<Self> {
let trystr = <pyo3::types::PyString as pyo3::PyTryFrom>::try_from(ob)?;
let strval = trystr.to_string();
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> pyo3::PyResult<Self> {
let strval = String::extract_bound(ob)?;
match strval.to_lowercase().as_str() {
"single" => Ok(ProofType::Single),
"for-aggr" => Ok(ProofType::ForAggr),
@@ -174,9 +176,8 @@ impl pyo3::IntoPy<PyObject> for StrategyType {
#[cfg(feature = "python-bindings")]
/// Obtains StrategyType from PyObject (Required for StrategyType to be compatible with Python)
impl<'source> pyo3::FromPyObject<'source> for StrategyType {
fn extract(ob: &'source pyo3::PyAny) -> pyo3::PyResult<Self> {
let trystr = <pyo3::types::PyString as pyo3::PyTryFrom>::try_from(ob)?;
let strval = trystr.to_string();
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> pyo3::PyResult<Self> {
let strval = String::extract_bound(ob)?;
match strval.to_lowercase().as_str() {
"single" => Ok(StrategyType::Single),
"accum" => Ok(StrategyType::Accum),
@@ -235,7 +236,7 @@ impl ToPyObject for TranscriptType {
#[cfg(feature = "python-bindings")]
///
pub fn g1affine_to_pydict(g1affine_dict: &PyDict, g1affine: &G1Affine) {
pub fn g1affine_to_pydict(g1affine_dict: &pyo3::Bound<'_, PyDict>, g1affine: &G1Affine) {
let g1affine_x = field_to_string(&g1affine.x);
let g1affine_y = field_to_string(&g1affine.y);
g1affine_dict.set_item("x", g1affine_x).unwrap();
@@ -246,7 +247,7 @@ pub fn g1affine_to_pydict(g1affine_dict: &PyDict, g1affine: &G1Affine) {
use halo2curves::bn256::G1;
#[cfg(feature = "python-bindings")]
///
pub fn g1_to_pydict(g1_dict: &PyDict, g1: &G1) {
pub fn g1_to_pydict(g1_dict: &pyo3::Bound<'_, PyDict>, g1: &G1) {
let g1_x = field_to_string(&g1.x);
let g1_y = field_to_string(&g1.y);
let g1_z = field_to_string(&g1.z);
@@ -337,7 +338,7 @@ where
dict.set_item("instances", field_elems).unwrap();
let hex_proof = hex::encode(&self.proof);
dict.set_item("proof", format!("0x{}", hex_proof)).unwrap();
dict.set_item("transcript_type", self.transcript_type)
dict.set_item("transcript_type", self.transcript_type.to_object(py))
.unwrap();
dict.to_object(py)
}
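The dict helpers above now take pyo3's Bound<'_, PyDict> instead of the GIL-ref &PyDict. A minimal sketch of that helper shape, assuming pyo3 0.21+ with PyDict::new_bound and the PyDictMethods trait (the "x"/"y" keys are illustrative; the diff's helpers fill in curve coordinates):

use pyo3::prelude::*;
use pyo3::types::{PyDict, PyDictMethods};

// Populate an already-created Bound dict, mirroring g1affine_to_pydict above.
fn point_to_pydict(dict: &Bound<'_, PyDict>, x: &str, y: &str) {
    dict.set_item("x", x).unwrap();
    dict.set_item("y", y).unwrap();
}

// Typical call site, assuming a held GIL token `py`:
// let dict = PyDict::new_bound(py);
// point_to_pydict(&dict, "0x01", "0x02");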

View File

@@ -638,42 +638,44 @@ impl<T: Clone + TensorType> Tensor<T> {
where
T: Send + Sync,
{
if indices.is_empty() {
// Fast path: empty indices or full tensor slice
if indices.is_empty()
|| indices.iter().map(|x| x.end - x.start).collect::<Vec<_>>() == self.dims
{
return Ok(self.clone());
}
// Validate dimensions
if self.dims.len() < indices.len() {
return Err(TensorError::DimError(format!(
"The dimensionality of the slice {:?} is greater than the tensor's {:?}",
indices, self.dims
)));
} else if indices.iter().map(|x| x.end - x.start).collect::<Vec<_>>() == self.dims {
// else if slice is the same as dims, return self
return Ok(self.clone());
}
// if indices weren't specified we fill them in as required
let mut full_indices = indices.to_vec();
// Pre-allocate the full indices vector with capacity
let mut full_indices = Vec::with_capacity(self.dims.len());
full_indices.extend_from_slice(indices);
for i in 0..(self.dims.len() - indices.len()) {
full_indices.push(0..self.dims()[indices.len() + i])
}
// Fill remaining dimensions
full_indices.extend((indices.len()..self.dims.len()).map(|i| 0..self.dims[i]));
let cartesian_coord: Vec<Vec<usize>> = full_indices
// Pre-calculate total size and allocate result vector
let total_size: usize = full_indices
.iter()
.cloned()
.multi_cartesian_product()
.collect();
let res: Vec<T> = cartesian_coord
.par_iter()
.map(|e| {
let index = self.get_index(e);
self[index].clone()
})
.collect();
.map(|range| range.end - range.start)
.product();
let mut res = Vec::with_capacity(total_size);
// Calculate new dimensions once
let dims: Vec<usize> = full_indices.iter().map(|e| e.end - e.start).collect();
// Use iterator directly without collecting into intermediate Vec
for coord in full_indices.iter().cloned().multi_cartesian_product() {
let index = self.get_index(&coord);
res.push(self[index].clone());
}
Tensor::new(Some(&res), &dims)
}
@@ -1109,6 +1111,13 @@ impl<T: Clone + TensorType> Tensor<T> {
///
/// ```
pub fn expand(&self, shape: &[usize]) -> Result<Self, TensorError> {
// if both have length 1 then we can just return the tensor
if self.dims().iter().product::<usize>() == 1 && shape.iter().product::<usize>() == 1 {
let mut output = self.clone();
output.reshape(shape)?;
return Ok(output);
}
if self.dims().len() > shape.len() {
return Err(TensorError::DimError(format!(
"Cannot expand {:?} to the smaller shape {:?}",

View File

@@ -1050,6 +1050,7 @@ pub fn scatter_nd<T: TensorType + Send + Sync>(
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let index_val = index.get_slice(&slice)?;
let index_slice = index_val.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let src_val = src.get_slice(&slice)?;
output.set_slice(&index_slice, &src_val)?;
Ok::<_, TensorError>(())

View File

@@ -1,12 +1,12 @@
use crate::{circuit::region::ConstantsMap, fieldutils::felt_to_integer_rep};
use maybe_rayon::slice::Iter;
use maybe_rayon::slice::{Iter, ParallelSlice};
use super::{
ops::{intercalate_values, pad, resize},
*,
};
use halo2_proofs::{arithmetic::Field, circuit::Cell, plonk::Instance};
use maybe_rayon::iter::{FilterMap, IntoParallelIterator, ParallelIterator};
use maybe_rayon::iter::{FilterMap, ParallelIterator};
pub(crate) fn create_constant_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
@@ -455,7 +455,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
}
/// Returns the number of constants in the [ValTensor].
/// Returns an iterator over the [ValTensor]'s constants.
pub fn create_constants_map_iterator(
&self,
) -> FilterMap<Iter<'_, ValType<F>>, fn(&ValType<F>) -> Option<(F, ValType<F>)>> {
@@ -473,20 +473,48 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
}
/// Returns the number of constants in the [ValTensor].
/// Returns a map of the constants in the [ValTensor].
pub fn create_constants_map(&self) -> ConstantsMap<F> {
match self {
ValTensor::Value { inner, .. } => inner
.par_iter()
.filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
None
}
})
.collect(),
ValTensor::Instance { .. } => ConstantsMap::new(),
let threshold = 1_000_000; // Tuned using the benchmarks
if self.len() < threshold {
match self {
ValTensor::Value { inner, .. } => inner
.par_iter()
.filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
None
}
})
.collect(),
ValTensor::Instance { .. } => ConstantsMap::new(),
}
} else {
// Use parallel for larger arrays
let num_cores = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
let chunk_size = (self.len() / num_cores).max(100_000);
match self {
ValTensor::Value { inner, .. } => inner
.par_chunks(chunk_size)
.flat_map(|chunk| {
chunk
.par_iter() // Make sure we use par_iter() here
.filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
None
}
})
})
.collect(),
ValTensor::Instance { .. } => ConstantsMap::new(),
}
}
}
@@ -878,70 +906,161 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// remove constant zero values constants
pub fn remove_const_zero_values(&mut self) {
match self {
ValTensor::Value { inner: v, dims, .. } => {
*v = v
.clone()
.into_par_iter()
.filter_map(|e| {
if let ValType::Constant(r) = e {
if r == F::ZERO {
return None;
let size_threshold = 1_000_000; // Tuned using the benchmarks
if self.len() < size_threshold {
match self {
ValTensor::Value { inner: v, dims, .. } => {
*v = v
.clone()
.into_iter()
.filter_map(|e| {
if let ValType::Constant(r) = e {
if r == F::ZERO {
return None;
}
} else if let ValType::AssignedConstant(_, r) = e {
if r == F::ZERO {
return None;
}
}
} else if let ValType::AssignedConstant(_, r) = e {
if r == F::ZERO {
return None;
}
}
Some(e)
})
.collect();
*dims = v.dims().to_vec();
Some(e)
})
.collect();
*dims = v.dims().to_vec();
}
ValTensor::Instance { .. } => {}
}
} else {
// Use parallel for larger arrays
let num_cores = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
let chunk_size = (self.len() / num_cores).max(100_000);
match self {
ValTensor::Value { inner: v, dims, .. } => {
*v = v
.par_chunks_mut(chunk_size)
.flat_map(|chunk| {
chunk
.par_iter_mut() // Make sure we use par_iter_mut() here
.filter_map(|e| {
if let ValType::Constant(r) = e {
if *r == F::ZERO {
return None;
}
} else if let ValType::AssignedConstant(_, r) = e {
if *r == F::ZERO {
return None;
}
}
Some(e.clone())
})
})
.collect();
*dims = v.dims().to_vec();
}
ValTensor::Instance { .. } => {}
}
ValTensor::Instance { .. } => {}
}
}
/// gets constants
/// Returns the indices of constant zero values in the [ValTensor].
pub fn get_const_zero_indices(&self) -> Vec<usize> {
match self {
ValTensor::Value { inner: v, .. } => v
.par_iter()
.enumerate()
.filter_map(|(i, e)| {
if let ValType::Constant(r) = e {
if *r == F::ZERO {
return Some(i);
let size_threshold = 1_000_000; // Tuned using the benchmarks
if self.len() < size_threshold {
// Use single-threaded for smaller arrays
match &self {
ValTensor::Value { inner: v, .. } => v
.iter()
.enumerate()
.filter_map(|(i, e)| {
match e {
// Combine both match arms to reduce branching
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(i)
}
_ => None,
}
} else if let ValType::AssignedConstant(_, r) = e {
if *r == F::ZERO {
return Some(i);
}
}
None
})
.collect(),
ValTensor::Instance { .. } => vec![],
})
.collect(),
ValTensor::Instance { .. } => vec![],
}
} else {
// Use parallel for larger arrays
let num_cores = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
let chunk_size = (self.len() / num_cores).max(100_000);
match &self {
ValTensor::Value { inner: v, .. } => v
.par_chunks(chunk_size)
.enumerate()
.flat_map(|(chunk_idx, chunk)| {
chunk
.par_iter() // Make sure we use par_iter() here
.enumerate()
.filter_map(move |(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(chunk_idx * chunk_size + i)
}
_ => None,
})
})
.collect::<Vec<_>>(),
ValTensor::Instance { .. } => vec![],
}
}
}
/// gets constants
/// gets constant indices
pub fn get_const_indices(&self) -> Vec<usize> {
match self {
ValTensor::Value { inner: v, .. } => v
.par_iter()
.enumerate()
.filter_map(|(i, e)| {
if let ValType::Constant(_) = e {
Some(i)
} else if let ValType::AssignedConstant(_, _) = e {
Some(i)
} else {
None
}
})
.collect(),
ValTensor::Instance { .. } => vec![],
let size_threshold = 1_000_000; // Tuned using the benchmarks
if self.len() < size_threshold {
// Use single-threaded for smaller arrays
match &self {
ValTensor::Value { inner: v, .. } => v
.iter()
.enumerate()
.filter_map(|(i, e)| {
match e {
// Combine both match arms to reduce branching
ValType::Constant(_) | ValType::AssignedConstant(_, _) => Some(i),
_ => None,
}
})
.collect(),
ValTensor::Instance { .. } => vec![],
}
} else {
// Use parallel for larger arrays
let num_cores = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
let chunk_size = (self.len() / num_cores).max(100_000);
match &self {
ValTensor::Value { inner: v, .. } => v
.par_chunks(chunk_size)
.enumerate()
.flat_map(|(chunk_idx, chunk)| {
chunk
.par_iter() // Make sure we use par_iter() here
.enumerate()
.filter_map(move |(i, e)| match e {
ValType::Constant(_) | ValType::AssignedConstant(_, _) => {
Some(chunk_idx * chunk_size + i)
}
_ => None,
})
})
.collect::<Vec<_>>(),
ValTensor::Instance { .. } => vec![],
}
}
}
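The three accessors above now pick between a sequential pass and rayon par_chunks based on a 1,000,000-element threshold, with chunk sizes derived from available_parallelism. A standalone sketch of that pattern, using rayon directly rather than the crate's maybe_rayon wrapper (the function, the element type, and the constants mirror the diff but are otherwise illustrative):

use rayon::prelude::*;

fn zero_indices(values: &[u64]) -> Vec<usize> {
    const THRESHOLD: usize = 1_000_000; // below this, rayon's overhead outweighs the gain
    if values.len() < THRESHOLD {
        // sequential path for small inputs
        values
            .iter()
            .enumerate()
            .filter_map(|(i, v)| (*v == 0).then_some(i))
            .collect()
    } else {
        // parallel path: one chunk per core, with a floor on chunk size
        let num_cores = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(1);
        let chunk_size = (values.len() / num_cores).max(100_000);
        values
            .par_chunks(chunk_size)
            .enumerate()
            .flat_map(|(chunk_idx, chunk)| {
                chunk
                    .par_iter()
                    .enumerate()
                    // offset local indices back into the global index space
                    .filter_map(move |(i, v)| (*v == 0).then_some(chunk_idx * chunk_size + i))
            })
            .collect()
    }
}

fn main() {
    assert_eq!(zero_indices(&[1, 0, 3, 0]), vec![1, 3]);
}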

View File

@@ -492,7 +492,7 @@ mod native_tests {
#[cfg(feature="icicle")]
seq!(N in 0..=2 {
#(#[test_case(TESTS_AGGR[N])])*
fn aggr_prove_and_verify_(test: &str) {
fn kzg_aggr_prove_and_verify_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(test_dir.path().to_str().unwrap(), test);
@@ -636,7 +636,7 @@ mod native_tests {
#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
// currently variable output rank is not supported in ONNX
if test != "gather_nd" && test != "lstm_large" && test != "lstm_medium" {
if test != "gather_nd" && test != "lstm_large" && test != "lstm_medium" && test != "scatter_nd" {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);

View File

@@ -68,6 +68,8 @@ mod py_tests {
"install",
"torch-geometric==2.5.2",
"torch==2.2.2",
"datasets==3.2.0",
"torchtext==0.17.2",
"torchvision==0.17.2",
"pandas==2.2.1",
"numpy==1.26.4",
@@ -189,6 +191,27 @@ mod py_tests {
anvil_child.kill().unwrap();
}
});
#[test]
fn neural_bag_of_words_notebook() {
crate::py_tests::init_binary();
let test_dir: TempDir = TempDir::new("neural_bow").unwrap();
let path = test_dir.path().to_str().unwrap();
crate::py_tests::mv_test_(path, "neural_bow.ipynb");
run_notebook(path, "neural_bow.ipynb");
test_dir.close().unwrap();
}
#[test]
fn felt_conversion_test_notebook() {
crate::py_tests::init_binary();
let test_dir: TempDir = TempDir::new("felt_conversion_test").unwrap();
let path = test_dir.path().to_str().unwrap();
crate::py_tests::mv_test_(path, "felt_conversion_test.ipynb");
run_notebook(path, "felt_conversion_test.ipynb");
test_dir.close().unwrap();
}
#[test]
fn voice_notebook_() {
crate::py_tests::init_binary();