mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 16:27:59 -05:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c1ce8c88d0 | ||
|
|
876a9584a1 | ||
|
|
7d7f049cc4 | ||
|
|
96f3fd94b2 | ||
|
|
6263510c56 | ||
|
|
f5b8ae3213 | ||
|
|
b2e4e414f0 | ||
|
|
0b0199e2b7 | ||
|
|
5e169bdd17 | ||
|
|
64cbcb3f7e |
@@ -2,3 +2,16 @@
|
||||
runner = 'wasm-bindgen-test-runner'
|
||||
rustflags = ["-C", "target-feature=+atomics,+bulk-memory,+mutable-globals","-C",
|
||||
"link-arg=--max-memory=4294967296"]
|
||||
|
||||
|
||||
[target.x86_64-apple-darwin]
|
||||
rustflags = [
|
||||
"-C", "link-arg=-undefined",
|
||||
"-C", "link-arg=dynamic_lookup",
|
||||
]
|
||||
|
||||
[target.aarch64-apple-darwin]
|
||||
rustflags = [
|
||||
"-C", "link-arg=-undefined",
|
||||
"-C", "link-arg=dynamic_lookup",
|
||||
]
|
||||
4
.github/workflows/pypi-gpu.yml
vendored
4
.github/workflows/pypi-gpu.yml
vendored
@@ -98,14 +98,14 @@ jobs:
|
||||
# publishes to PyPI
|
||||
- name: Publish package distributions to PyPI
|
||||
continue-on-error: true
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
with:
|
||||
packages-dir: ./
|
||||
|
||||
# publishes to TestPyPI
|
||||
- name: Publish package distribution to TestPyPI
|
||||
continue-on-error: true
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
with:
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
packages-dir: ./
|
||||
|
||||
4
.github/workflows/pypi.yml
vendored
4
.github/workflows/pypi.yml
vendored
@@ -348,14 +348,14 @@ jobs:
|
||||
# publishes to PyPI
|
||||
- name: Publish package distributions to PyPI
|
||||
continue-on-error: true
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
with:
|
||||
packages-dir: ./
|
||||
|
||||
# publishes to TestPyPI
|
||||
- name: Publish package distribution to TestPyPI
|
||||
continue-on-error: true
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
uses: pypa/gh-action-pypi-publish@unstable/v1
|
||||
with:
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
packages-dir: ./
|
||||
|
||||
57
.github/workflows/rust.yml
vendored
57
.github/workflows/rust.yml
vendored
@@ -207,23 +207,6 @@ jobs:
|
||||
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
|
||||
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
|
||||
|
||||
tutorial:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Circuit Render
|
||||
run: cargo nextest run --release --verbose tests::tutorial_
|
||||
|
||||
mock-proving-tests:
|
||||
runs-on: non-gpu
|
||||
needs: [build, library-tests, docs, python-tests, python-integration-tests]
|
||||
@@ -494,23 +477,23 @@ jobs:
|
||||
- name: Mock aggr tests (KZG)
|
||||
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
|
||||
|
||||
prove-and-verify-aggr-tests-gpu:
|
||||
runs-on: GPU
|
||||
env:
|
||||
ENABLE_ICICLE_GPU: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: KZG )tests
|
||||
run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
|
||||
# prove-and-verify-aggr-tests-gpu:
|
||||
# runs-on: GPU
|
||||
# env:
|
||||
# ENABLE_ICICLE_GPU: true
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - uses: actions-rs/toolchain@v1
|
||||
# with:
|
||||
# toolchain: nightly-2024-07-18
|
||||
# override: true
|
||||
# components: rustfmt, clippy
|
||||
# - uses: baptiste0928/cargo-install@v1
|
||||
# with:
|
||||
# crate: cargo-nextest
|
||||
# locked: true
|
||||
# - name: KZG tests
|
||||
# run: cargo nextest run --verbose tests_aggr::kzg_aggr_prove_and_verify_ --features icicle --test-threads 1 -- --include-ignored
|
||||
|
||||
prove-and-verify-aggr-tests:
|
||||
runs-on: large-self-hosted
|
||||
@@ -614,8 +597,6 @@ jobs:
|
||||
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
|
||||
- name: Div rebase
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_div_rebase_
|
||||
- name: Public inputs
|
||||
run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_inputs_
|
||||
- name: fixed params
|
||||
@@ -669,6 +650,10 @@ jobs:
|
||||
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
|
||||
- name: Neural bow
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
|
||||
- name: Felt conversion
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::felt_conversion_test_ --no-capture
|
||||
- name: Postgres tutorials
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
|
||||
- name: Tictactoe tutorials
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -27,7 +27,6 @@ __pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.py[cod]
|
||||
bin/
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
@@ -49,4 +48,4 @@ timingData.json
|
||||
!tests/assets/pk.key
|
||||
!tests/assets/vk.key
|
||||
docs/python/build
|
||||
!tests/assets/vk_aggr.key
|
||||
!tests/assets/vk_aggr.key
|
||||
|
||||
717
Cargo.lock
generated
717
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
34
Cargo.toml
34
Cargo.toml
@@ -16,11 +16,11 @@ crate-type = ["cdylib", "rlib", "staticlib"]
|
||||
|
||||
|
||||
[dependencies]
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2" }
|
||||
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
|
||||
"derive_serde",
|
||||
] }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", branch = "ac/cache-lookup-commitments", features = [
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", features = [
|
||||
"circuit-params",
|
||||
] }
|
||||
rand = { version = "0.8", default-features = false }
|
||||
@@ -35,7 +35,7 @@ halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac
|
||||
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
|
||||
"derive_serde",
|
||||
] }
|
||||
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "ac/update-h2-curves", optional = true }
|
||||
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", optional = true }
|
||||
maybe-rayon = { version = "0.1.1", default-features = false }
|
||||
bincode = { version = "1.3.3", default-features = false }
|
||||
unzip-n = "0.1.2"
|
||||
@@ -79,21 +79,22 @@ tokio = { version = "1.35.0", default-features = false, features = [
|
||||
"macros",
|
||||
"rt-multi-thread",
|
||||
], optional = true }
|
||||
pyo3 = { version = "0.21.2", features = [
|
||||
pyo3 = { version = "0.23.2", features = [
|
||||
"extension-module",
|
||||
"abi3-py37",
|
||||
"macros",
|
||||
], default-features = false, optional = true }
|
||||
pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch = "migration-pyo3-0.21", features = [
|
||||
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.23.0", features = [
|
||||
"attributes",
|
||||
"tokio-runtime",
|
||||
], default-features = false, optional = true }
|
||||
pyo3-log = { version = "0.10.0", default-features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default-features = false, optional = true }
|
||||
pyo3-log = { version = "0.12.0", default-features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
|
||||
tabled = { version = "0.12.0", optional = true }
|
||||
metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
|
||||
objc = { version = "0.2.4", optional = true }
|
||||
mimalloc = { version = "0.1", optional = true }
|
||||
pyo3-stub-gen = { version = "0.6.0", optional = true }
|
||||
|
||||
# universal bindings
|
||||
uniffi = { version = "=0.28.0", optional = true }
|
||||
@@ -210,6 +211,10 @@ required-features = ["ezkl"]
|
||||
name = "ios_gen_bindings"
|
||||
required-features = ["ios-bindings", "uuid", "camino", "uniffi_bindgen"]
|
||||
|
||||
[[bin]]
|
||||
name = "py_stub_gen"
|
||||
required-features = ["python-bindings"]
|
||||
|
||||
[features]
|
||||
web = ["wasm-bindgen-rayon"]
|
||||
default = [
|
||||
@@ -220,7 +225,7 @@ default = [
|
||||
"parallel-poly-read",
|
||||
]
|
||||
onnx = ["dep:tract-onnx"]
|
||||
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
|
||||
python-bindings = ["pyo3", "pyo3-log", "pyo3-async-runtimes", "pyo3-stub-gen"]
|
||||
ios-bindings = ["mv-lookup", "precompute-coset", "parallel-poly-read", "uniffi"]
|
||||
ios-bindings-test = ["ios-bindings", "uniffi/bindgen-tests"]
|
||||
ezkl = [
|
||||
@@ -269,12 +274,9 @@ empty-cmd = []
|
||||
no-banner = []
|
||||
no-update = []
|
||||
|
||||
# icicle patch to 0.1.0 if feature icicle is enabled
|
||||
[patch.'https://github.com/ingonyama-zk/icicle']
|
||||
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
|
||||
|
||||
[patch.'https://github.com/zkonduit/halo2']
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#8b13a0d2a7a34d8daab010dadb2c47dfa47d37d0", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b", package = "halo2_proofs" }
|
||||
|
||||
[patch.crates-io]
|
||||
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
|
||||
@@ -284,3 +286,11 @@ rustflags = ["-C", "relocation-model=pic"]
|
||||
lto = "fat"
|
||||
codegen-units = 1
|
||||
# panic = "abort"
|
||||
|
||||
|
||||
[package.metadata.wasm-pack.profile.release]
|
||||
wasm-opt = [
|
||||
"-O4",
|
||||
"--flexible-inline-max-function-size",
|
||||
"4294967295",
|
||||
]
|
||||
130
examples/notebooks/felt_conversion_test.ipynb
Normal file
130
examples/notebooks/felt_conversion_test.ipynb
Normal file
File diff suppressed because one or more lines are too long
766
examples/notebooks/neural_bow.ipynb
Normal file
766
examples/notebooks/neural_bow.ipynb
Normal file
@@ -0,0 +1,766 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"This is a zk version of the tutorial found [here](https://github.com/bentrevett/pytorch-sentiment-analysis/blob/main/1%20-%20Neural%20Bag%20of%20Words.ipynb). The original tutorial is part of the PyTorch Sentiment Analysis series by Ben Trevett.\n",
|
||||
"\n",
|
||||
"1 - NBoW\n",
|
||||
"\n",
|
||||
"In this series we'll be building a machine learning model to perform sentiment analysis -- a subset of text classification where the task is to detect if a given sentence is positive or negative -- using PyTorch and torchtext. The dataset used will be movie reviews from the IMDb dataset, which we'll obtain using the datasets library.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"Preparing Data\n",
|
||||
"\n",
|
||||
"Before we can implement our NBoW model, we first have to perform quite a few steps to get our data ready to use. NLP usually requires quite a lot of data wrangling beforehand, though libraries such as datasets and torchtext handle most of this for us.\n",
|
||||
"\n",
|
||||
"The steps to take are:\n",
|
||||
"\n",
|
||||
" 1. importing modules\n",
|
||||
" 2. loading data\n",
|
||||
" 3. tokenizing data\n",
|
||||
" 4. creating data splits\n",
|
||||
" 5. creating a vocabulary\n",
|
||||
" 6. numericalizing data\n",
|
||||
" 7. creating the data loaders\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install torchtex"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import collections\n",
|
||||
"\n",
|
||||
"import datasets\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"import torchtext\n",
|
||||
"import tqdm\n",
|
||||
"\n",
|
||||
"# It is usually good practice to run your experiments multiple times with different random seeds -- both to measure the variance of your model and also to avoid having results only calculated with either \"good\" or \"bad\" seeds, i.e. being very lucky or unlucky with the randomness in the training process.\n",
|
||||
"\n",
|
||||
"seed = 1234\n",
|
||||
"\n",
|
||||
"np.random.seed(seed)\n",
|
||||
"torch.manual_seed(seed)\n",
|
||||
"torch.cuda.manual_seed(seed)\n",
|
||||
"torch.backends.cudnn.deterministic = True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_data, test_data = datasets.load_dataset(\"imdb\", split=[\"train\", \"test\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can check the features attribute of a split to get more information about the features. We can see that text is a Value of dtype=string -- in other words, it's a string -- and that label is a ClassLabel. A ClassLabel means the feature is an integer representation of which class the example belongs to. num_classes=2 means that our labels are one of two values, 0 or 1, and names=['neg', 'pos'] gives us the human-readable versions of those values. Thus, a label of 0 means the example is a negative review and a label of 1 means the example is a positive review."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_data.features\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_data[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"One of the first things we need to do to our data is tokenize it. Machine learning models aren't designed to handle strings, they're design to handle numbers. So what we need to do is break down our string into individual tokens, and then convert these tokens to numbers. We'll get to the conversion later, but first we'll look at tokenization.\n",
|
||||
"\n",
|
||||
"Tokenization involves using a tokenizer to process the strings in our dataset. A tokenizer is a function that goes from a string to a list of strings. There are many types of tokenizers available, but we're going to use a relatively simple one provided by torchtext called the basic_english tokenizer. We load our tokenizer as such:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenizer = torchtext.data.utils.get_tokenizer(\"basic_english\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def tokenize_example(example, tokenizer, max_length):\n",
|
||||
" tokens = tokenizer(example[\"text\"])[:max_length]\n",
|
||||
" return {\"tokens\": tokens}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"max_length = 256\n",
|
||||
"\n",
|
||||
"train_data = train_data.map(\n",
|
||||
" tokenize_example, fn_kwargs={\"tokenizer\": tokenizer, \"max_length\": max_length}\n",
|
||||
")\n",
|
||||
"test_data = test_data.map(\n",
|
||||
" tokenize_example, fn_kwargs={\"tokenizer\": tokenizer, \"max_length\": max_length}\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# create validation data \n",
|
||||
"# Why have both a validation set and a test set? Your test set respresents the real world data that you'd see if you actually deployed this model. You won't be able to see what data your model will be fed once deployed, and your test set is supposed to reflect that. Every time we tune our model hyperparameters or training set-up to make it do a bit better on the test set, we are leak information from the test set into the training process. If we do this too often then we begin to overfit on the test set. Hence, we need some data which can act as a \"proxy\" test set which we can look at more frequently in order to evaluate how well our model actually does on unseen data -- this is the validation set.\n",
|
||||
"\n",
|
||||
"test_size = 0.25\n",
|
||||
"\n",
|
||||
"train_valid_data = train_data.train_test_split(test_size=test_size)\n",
|
||||
"train_data = train_valid_data[\"train\"]\n",
|
||||
"valid_data = train_valid_data[\"test\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, we have to build a vocabulary. This is look-up table where every unique token in your dataset has a corresponding index (an integer).\n",
|
||||
"\n",
|
||||
"We do this as machine learning models cannot operate on strings, only numerical vaslues. Each index is used to construct a one-hot vector for each token. A one-hot vector is a vector where all the elements are 0, except one, which is 1, and the dimensionality is the total number of unique tokens in your vocabulary, commonly denoted by V."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"min_freq = 5\n",
|
||||
"special_tokens = [\"<unk>\", \"<pad>\"]\n",
|
||||
"\n",
|
||||
"vocab = torchtext.vocab.build_vocab_from_iterator(\n",
|
||||
" train_data[\"tokens\"],\n",
|
||||
" min_freq=min_freq,\n",
|
||||
" specials=special_tokens,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# We store the indices of the unknown and padding tokens (zero and one, respectively) in variables, as we'll use these further on in this notebook.\n",
|
||||
"\n",
|
||||
"unk_index = vocab[\"<unk>\"]\n",
|
||||
"pad_index = vocab[\"<pad>\"]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"vocab.set_default_index(unk_index)\n",
|
||||
"\n",
|
||||
"# To look-up a list of tokens, we can use the vocabulary's lookup_indices method.\n",
|
||||
"vocab.lookup_indices([\"hello\", \"world\", \"some_token\", \"<pad>\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we have our vocabulary, we can numericalize our data. This involves converting the tokens within our dataset into indices. Similar to how we tokenized our data using the Dataset.map method, we'll define a function that takes an example and our vocabulary, gets the index for each token in each example and then creates an ids field which containes the numericalized tokens."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def numericalize_example(example, vocab):\n",
|
||||
" ids = vocab.lookup_indices(example[\"tokens\"])\n",
|
||||
" return {\"ids\": ids}\n",
|
||||
"\n",
|
||||
"train_data = train_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
|
||||
"valid_data = valid_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
|
||||
"test_data = test_data.map(numericalize_example, fn_kwargs={\"vocab\": vocab})\n",
|
||||
"\n",
|
||||
"train_data = train_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n",
|
||||
"valid_data = valid_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n",
|
||||
"test_data = test_data.with_format(type=\"torch\", columns=[\"ids\", \"label\"])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The final step of preparing the data is creating the data loaders. We can iterate over a data loader to retrieve batches of examples. This is also where we will perform any padding that is necessary.\n",
|
||||
"\n",
|
||||
"We first need to define a function to collate a batch, consisting of a list of examples, into what we want our data loader to output.\n",
|
||||
"\n",
|
||||
"Here, our desired output from the data loader is a dictionary with keys of \"ids\" and \"label\".\n",
|
||||
"\n",
|
||||
"The value of batch[\"ids\"] should be a tensor of shape [batch size, length], where length is the length of the longest sentence (in terms of tokens) within the batch, and all sentences shorter than this should be padded to that length.\n",
|
||||
"\n",
|
||||
"The value of batch[\"label\"] should be a tensor of shape [batch size] consisting of the label for each sentence in the batch.\n",
|
||||
"\n",
|
||||
"We define a function, get_collate_fn, which is passed the pad token index and returns the actual collate function. Within the actual collate function, collate_fn, we get a list of \"ids\" tensors for each example in the batch, and then use the pad_sequence function, which converts the list of tensors into the desired [batch size, length] shaped tensor and performs padding using the specified pad_index. By default, pad_sequence will return a [length, batch size] shaped tensor, but by setting batch_first=True, these two dimensions are switched. We get a list of \"label\" tensors and convert the list of tensors into a single [batch size] shaped tensor."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_collate_fn(pad_index):\n",
|
||||
" def collate_fn(batch):\n",
|
||||
" batch_ids = [i[\"ids\"] for i in batch]\n",
|
||||
" batch_ids = nn.utils.rnn.pad_sequence(\n",
|
||||
" batch_ids, padding_value=pad_index, batch_first=True\n",
|
||||
" )\n",
|
||||
" batch_label = [i[\"label\"] for i in batch]\n",
|
||||
" batch_label = torch.stack(batch_label)\n",
|
||||
" batch = {\"ids\": batch_ids, \"label\": batch_label}\n",
|
||||
" return batch\n",
|
||||
"\n",
|
||||
" return collate_fn\n",
|
||||
"\n",
|
||||
"def get_data_loader(dataset, batch_size, pad_index, shuffle=False):\n",
|
||||
" collate_fn = get_collate_fn(pad_index)\n",
|
||||
" data_loader = torch.utils.data.DataLoader(\n",
|
||||
" dataset=dataset,\n",
|
||||
" batch_size=batch_size,\n",
|
||||
" collate_fn=collate_fn,\n",
|
||||
" shuffle=shuffle,\n",
|
||||
" )\n",
|
||||
" return data_loader\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"batch_size = 512\n",
|
||||
"\n",
|
||||
"train_data_loader = get_data_loader(train_data, batch_size, pad_index, shuffle=True)\n",
|
||||
"valid_data_loader = get_data_loader(valid_data, batch_size, pad_index)\n",
|
||||
"test_data_loader = get_data_loader(test_data, batch_size, pad_index)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"class NBoW(nn.Module):\n",
|
||||
" def __init__(self, vocab_size, embedding_dim, output_dim, pad_index):\n",
|
||||
" super().__init__()\n",
|
||||
" self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)\n",
|
||||
" self.fc = nn.Linear(embedding_dim, output_dim)\n",
|
||||
"\n",
|
||||
" def forward(self, ids):\n",
|
||||
" # ids = [batch size, seq len]\n",
|
||||
" embedded = self.embedding(ids)\n",
|
||||
" # embedded = [batch size, seq len, embedding dim]\n",
|
||||
" pooled = embedded.mean(dim=1)\n",
|
||||
" # pooled = [batch size, embedding dim]\n",
|
||||
" prediction = self.fc(pooled)\n",
|
||||
" # prediction = [batch size, output dim]\n",
|
||||
" return prediction\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"vocab_size = len(vocab)\n",
|
||||
"embedding_dim = 300\n",
|
||||
"output_dim = len(train_data.unique(\"label\"))\n",
|
||||
"\n",
|
||||
"model = NBoW(vocab_size, embedding_dim, output_dim, pad_index)\n",
|
||||
"\n",
|
||||
"def count_parameters(model):\n",
|
||||
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(f\"The model has {count_parameters(model):,} trainable parameters\")\n",
|
||||
"\n",
|
||||
"vectors = torchtext.vocab.GloVe()\n",
|
||||
"\n",
|
||||
"pretrained_embedding = vectors.get_vecs_by_tokens(vocab.get_itos())\n",
|
||||
"\n",
|
||||
"optimizer = optim.Adam(model.parameters())\n",
|
||||
"\n",
|
||||
"criterion = nn.CrossEntropyLoss()\n",
|
||||
"\n",
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"\n",
|
||||
"model = model.to(device)\n",
|
||||
"criterion = criterion.to(device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train(data_loader, model, criterion, optimizer, device):\n",
|
||||
" model.train()\n",
|
||||
" epoch_losses = []\n",
|
||||
" epoch_accs = []\n",
|
||||
" for batch in tqdm.tqdm(data_loader, desc=\"training...\"):\n",
|
||||
" ids = batch[\"ids\"].to(device)\n",
|
||||
" label = batch[\"label\"].to(device)\n",
|
||||
" prediction = model(ids)\n",
|
||||
" loss = criterion(prediction, label)\n",
|
||||
" accuracy = get_accuracy(prediction, label)\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" epoch_losses.append(loss.item())\n",
|
||||
" epoch_accs.append(accuracy.item())\n",
|
||||
" return np.mean(epoch_losses), np.mean(epoch_accs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def evaluate(data_loader, model, criterion, device):\n",
|
||||
" model.eval()\n",
|
||||
" epoch_losses = []\n",
|
||||
" epoch_accs = []\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for batch in tqdm.tqdm(data_loader, desc=\"evaluating...\"):\n",
|
||||
" ids = batch[\"ids\"].to(device)\n",
|
||||
" label = batch[\"label\"].to(device)\n",
|
||||
" prediction = model(ids)\n",
|
||||
" loss = criterion(prediction, label)\n",
|
||||
" accuracy = get_accuracy(prediction, label)\n",
|
||||
" epoch_losses.append(loss.item())\n",
|
||||
" epoch_accs.append(accuracy.item())\n",
|
||||
" return np.mean(epoch_losses), np.mean(epoch_accs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_accuracy(prediction, label):\n",
|
||||
" batch_size, _ = prediction.shape\n",
|
||||
" predicted_classes = prediction.argmax(dim=-1)\n",
|
||||
" correct_predictions = predicted_classes.eq(label).sum()\n",
|
||||
" accuracy = correct_predictions / batch_size\n",
|
||||
" return accuracy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"n_epochs = 10\n",
|
||||
"best_valid_loss = float(\"inf\")\n",
|
||||
"\n",
|
||||
"metrics = collections.defaultdict(list)\n",
|
||||
"\n",
|
||||
"for epoch in range(n_epochs):\n",
|
||||
" train_loss, train_acc = train(\n",
|
||||
" train_data_loader, model, criterion, optimizer, device\n",
|
||||
" )\n",
|
||||
" valid_loss, valid_acc = evaluate(valid_data_loader, model, criterion, device)\n",
|
||||
" metrics[\"train_losses\"].append(train_loss)\n",
|
||||
" metrics[\"train_accs\"].append(train_acc)\n",
|
||||
" metrics[\"valid_losses\"].append(valid_loss)\n",
|
||||
" metrics[\"valid_accs\"].append(valid_acc)\n",
|
||||
" if valid_loss < best_valid_loss:\n",
|
||||
" best_valid_loss = valid_loss\n",
|
||||
" torch.save(model.state_dict(), \"nbow.pt\")\n",
|
||||
" print(f\"epoch: {epoch}\")\n",
|
||||
" print(f\"train_loss: {train_loss:.3f}, train_acc: {train_acc:.3f}\")\n",
|
||||
" print(f\"valid_loss: {valid_loss:.3f}, valid_acc: {valid_acc:.3f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig = plt.figure(figsize=(10, 6))\n",
|
||||
"ax = fig.add_subplot(1, 1, 1)\n",
|
||||
"ax.plot(metrics[\"train_losses\"], label=\"train loss\")\n",
|
||||
"ax.plot(metrics[\"valid_losses\"], label=\"valid loss\")\n",
|
||||
"ax.set_xlabel(\"epoch\")\n",
|
||||
"ax.set_ylabel(\"loss\")\n",
|
||||
"ax.set_xticks(range(n_epochs))\n",
|
||||
"ax.legend()\n",
|
||||
"ax.grid()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig = plt.figure(figsize=(10, 6))\n",
|
||||
"ax = fig.add_subplot(1, 1, 1)\n",
|
||||
"ax.plot(metrics[\"train_accs\"], label=\"train accuracy\")\n",
|
||||
"ax.plot(metrics[\"valid_accs\"], label=\"valid accuracy\")\n",
|
||||
"ax.set_xlabel(\"epoch\")\n",
|
||||
"ax.set_ylabel(\"loss\")\n",
|
||||
"ax.set_xticks(range(n_epochs))\n",
|
||||
"ax.legend()\n",
|
||||
"ax.grid()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.load_state_dict(torch.load(\"nbow.pt\"))\n",
|
||||
"\n",
|
||||
"test_loss, test_acc = evaluate(test_data_loader, model, criterion, device)\n",
|
||||
"\n",
|
||||
"print(f\"test_loss: {test_loss:.3f}, test_acc: {test_acc:.3f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def predict_sentiment(text, model, tokenizer, vocab, device):\n",
|
||||
" tokens = tokenizer(text)\n",
|
||||
" ids = vocab.lookup_indices(tokens)\n",
|
||||
" tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
|
||||
" prediction = model(tensor).squeeze(dim=0)\n",
|
||||
" probability = torch.softmax(prediction, dim=-1)\n",
|
||||
" predicted_class = prediction.argmax(dim=-1).item()\n",
|
||||
" predicted_probability = probability[predicted_class].item()\n",
|
||||
" return predicted_class, predicted_probability"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"text = \"This film is terrible!\"\n",
|
||||
"\n",
|
||||
"predict_sentiment(text, model, tokenizer, vocab, device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"text = \"This film is great!\"\n",
|
||||
"\n",
|
||||
"predict_sentiment(text, model, tokenizer, vocab, device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"text = \"This film is not terrible, it's great!\"\n",
|
||||
"\n",
|
||||
"predict_sentiment(text, model, tokenizer, vocab, device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"text = \"This film is not great, it's terrible!\"\n",
|
||||
"\n",
|
||||
"predict_sentiment(text, model, tokenizer, vocab, device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def text_to_tensor(text, tokenizer, vocab, device):\n",
|
||||
" tokens = tokenizer(text)\n",
|
||||
" ids = vocab.lookup_indices(tokens)\n",
|
||||
" tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
|
||||
" return tensor\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we do onnx stuff to get the data ready for the zk-circuit."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"import json\n",
|
||||
"\n",
|
||||
"text = \"This film is terrible!\"\n",
|
||||
"x = text_to_tensor(text, tokenizer, vocab, device)\n",
|
||||
"\n",
|
||||
"# Flips the neural net into inference mode\n",
|
||||
"model.eval()\n",
|
||||
"model.to('cpu')\n",
|
||||
"\n",
|
||||
"model_path = \"network.onnx\"\n",
|
||||
"data_path = \"input.json\"\n",
|
||||
"\n",
|
||||
" # Export the model\n",
|
||||
"torch.onnx.export(model, # model being run\n",
|
||||
" x, # model input (or a tuple for multiple inputs)\n",
|
||||
" model_path, # where to save the model (can be a file or file-like object)\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=10, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names = ['input'], # the model's input names\n",
|
||||
" output_names = ['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
|
||||
" 'output' : {0 : 'batch_size'}})\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
|
||||
"\n",
|
||||
"data_json = dict(input_data = [data_array])\n",
|
||||
"\n",
|
||||
"print(data_json)\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump(data_json, open(data_path, 'w'))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import ezkl\n",
|
||||
"\n",
|
||||
"run_args = ezkl.PyRunArgs()\n",
|
||||
"run_args.logrows = 23\n",
|
||||
"run_args.scale_rebase_multiplier = 10\n",
|
||||
"# inputs should be auditable by all\n",
|
||||
"run_args.input_visibility = \"public\"\n",
|
||||
"# same with outputs\n",
|
||||
"run_args.output_visibility = \"public\"\n",
|
||||
"# for simplicity, we'll just use the fixed model visibility: i.e it is public and can't be changed by the prover\n",
|
||||
"run_args.param_visibility = \"fixed\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# TODO: Dictionary outputs\n",
|
||||
"res = ezkl.gen_settings(py_run_args=run_args)\n",
|
||||
"assert res == True\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.compile_circuit()\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# srs path\n",
|
||||
"res = await ezkl.get_srs()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# now generate the witness file\n",
|
||||
"res = await ezkl.gen_witness()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = ezkl.mock()\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
|
||||
"# WE GOT KEYS\n",
|
||||
"# WE GOT CIRCUIT PARAMETERS\n",
|
||||
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
|
||||
"\n",
|
||||
"res = ezkl.setup()\n",
|
||||
"\n",
|
||||
"assert res == True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# GENERATE A PROOF\n",
|
||||
"res = ezkl.prove(proof_path=\"proof.json\")\n",
|
||||
"\n",
|
||||
"print(res)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# VERIFY IT\n",
|
||||
"res = ezkl.verify()\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can also verify it on chain by creating an onchain verifier"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"solc-select\"])\n",
|
||||
" !solc-select install 0.8.20\n",
|
||||
" !solc-select use 0.8.20\n",
|
||||
" !solc --version\n",
|
||||
" import os\n",
|
||||
"\n",
|
||||
"# rely on local installation if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" import os\n",
|
||||
" pass"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"res = await ezkl.create_evm_verifier()\n",
|
||||
"assert res == True\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You should see a `Verifier.sol`. Right-click and save it locally.\n",
|
||||
"\n",
|
||||
"Now go to [https://remix.ethereum.org](https://remix.ethereum.org).\n",
|
||||
"\n",
|
||||
"Create a new file within remix and copy the verifier code over.\n",
|
||||
"\n",
|
||||
"Finally, compile the code and deploy. For the demo you can deploy to the test environment within remix.\n",
|
||||
"\n",
|
||||
"If everything works, you would have deployed your verifer onchain! Copy the values in the cell above to the respective fields to test if the verifier is working."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -1,75 +1,52 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import sys
|
||||
from torch import nn
|
||||
import json
|
||||
|
||||
sys.path.append("..")
|
||||
|
||||
class Model(nn.Module):
|
||||
"""
|
||||
Just one Linear layer
|
||||
"""
|
||||
def __init__(self, configs):
|
||||
super(Model, self).__init__()
|
||||
self.seq_len = configs.seq_len
|
||||
self.pred_len = configs.pred_len
|
||||
|
||||
# Use this line if you want to visualize the weights
|
||||
# self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
|
||||
self.channels = configs.enc_in
|
||||
self.individual = configs.individual
|
||||
if self.individual:
|
||||
self.Linear = nn.ModuleList()
|
||||
for i in range(self.channels):
|
||||
self.Linear.append(nn.Linear(self.seq_len,self.pred_len))
|
||||
else:
|
||||
self.Linear = nn.Linear(self.seq_len, self.pred_len)
|
||||
|
||||
def forward(self, x):
|
||||
# x: [Batch, Input length, Channel]
|
||||
if self.individual:
|
||||
output = torch.zeros([x.size(0),self.pred_len,x.size(2)],dtype=x.dtype).to(x.device)
|
||||
for i in range(self.channels):
|
||||
output[:,:,i] = self.Linear[i](x[:,:,i])
|
||||
x = output
|
||||
else:
|
||||
x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
|
||||
return x # [Batch, Output length, Channel]
|
||||
|
||||
class Configs:
|
||||
def __init__(self, seq_len, pred_len, enc_in=321, individual=True):
|
||||
self.seq_len = seq_len
|
||||
self.pred_len = pred_len
|
||||
self.enc_in = enc_in
|
||||
self.individual = individual
|
||||
|
||||
model = 'Linear'
|
||||
seq_len = 10
|
||||
pred_len = 4
|
||||
enc_in = 3
|
||||
|
||||
configs = Configs(seq_len, pred_len, enc_in, True)
|
||||
circuit = Model(configs)
|
||||
|
||||
x = torch.randn(1, seq_len, pred_len)
|
||||
import numpy as np
|
||||
import tf2onnx
|
||||
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=15, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
# the model's input names
|
||||
input_names=['input'],
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.layers import *
|
||||
from tensorflow.keras.models import Model
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
# gather_nd in tf then export to onnx
|
||||
x = in1 = Input((4, 1), dtype=tf.int32)
|
||||
w = in2 = Input((4, ), dtype=tf.int32)
|
||||
|
||||
class MyLayer(Layer):
|
||||
def call(self, x, w):
|
||||
shape = tf.constant([8])
|
||||
return tf.scatter_nd(x, w, shape)
|
||||
|
||||
x = MyLayer()(x, w)
|
||||
|
||||
|
||||
|
||||
tm = Model((in1, in2), x)
|
||||
tm.summary()
|
||||
tm.compile(optimizer='adam', loss='mse')
|
||||
|
||||
shape = [1, 4, 1]
|
||||
index_shape = [1, 4]
|
||||
# After training, export to onnx (network.onnx) and create a data file (input.json)
|
||||
x = np.random.randint(0, 4, shape)
|
||||
# w = random int tensor
|
||||
w = np.random.randint(0, 4, index_shape)
|
||||
|
||||
spec = tf.TensorSpec(shape, tf.int32, name='input_0')
|
||||
index_spec = tf.TensorSpec(index_shape, tf.int32, name='input_1')
|
||||
|
||||
model_path = "network.onnx"
|
||||
|
||||
tf2onnx.convert.from_keras(tm, input_signature=[spec, index_spec], inputs_as_nchw=['input_0', 'input_1'], opset=12, output_path=model_path)
|
||||
|
||||
|
||||
d = x.reshape([-1]).tolist()
|
||||
d1 = w.reshape([-1]).tolist()
|
||||
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
input_data=[d, d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
|
||||
@@ -1 +1,16 @@
|
||||
{"input_data": [[0.1874287724494934, 1.0498261451721191, 0.22384068369865417, 1.048445224761963, -0.5670360326766968, -0.38653188943862915, 0.12878702580928802, -2.3675858974456787, 0.5800458192825317, -0.43653929233551025, -0.2511898875236511, 0.3324051797389984, 0.27960312366485596, 0.4763695001602173, 0.3796705901622772, 1.1334782838821411, -0.87981778383255, -1.2451434135437012, 0.7672272324562073, -0.24404007196426392, -0.6875824928283691, 0.3619358539581299, -0.10131897777318954, 0.7169521450996399, 1.6585893630981445, -0.5451845526695251, 0.429487019777298, 0.7426952123641968, -0.2543637454509735, 0.06546942889690399, 0.7939824461936951, 0.1579471379518509, -0.043604474514722824, -0.8621711730957031, -0.5344759821891785, -0.05880478024482727, -0.17351101338863373, 0.5095029473304749, -0.7864817976951599, -0.449171245098114]]}
|
||||
{
|
||||
"input_data": [
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3
|
||||
],
|
||||
[
|
||||
1,
|
||||
0,
|
||||
2,
|
||||
1
|
||||
]
|
||||
]
|
||||
}
|
||||
Binary file not shown.
851
ezkl.pyi
Normal file
851
ezkl.pyi
Normal file
@@ -0,0 +1,851 @@
|
||||
# This file is automatically generated by pyo3_stub_gen
|
||||
# ruff: noqa: E501, F401
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import typing
|
||||
from enum import Enum, auto
|
||||
|
||||
class PyG1:
|
||||
r"""
|
||||
pyclass containing the struct used for G1, this is mostly a helper class
|
||||
"""
|
||||
...
|
||||
|
||||
class PyG1Affine:
|
||||
r"""
|
||||
pyclass containing the struct used for G1
|
||||
"""
|
||||
...
|
||||
|
||||
class PyRunArgs:
|
||||
r"""
|
||||
Python class containing the struct used for run_args
|
||||
|
||||
Returns
|
||||
-------
|
||||
PyRunArgs
|
||||
"""
|
||||
...
|
||||
|
||||
class PyCommitments(Enum):
|
||||
r"""
|
||||
pyclass representing an enum, denoting the type of commitment
|
||||
"""
|
||||
KZG = auto()
|
||||
IPA = auto()
|
||||
|
||||
class PyInputType(Enum):
|
||||
Bool = auto()
|
||||
F16 = auto()
|
||||
F32 = auto()
|
||||
F64 = auto()
|
||||
Int = auto()
|
||||
TDim = auto()
|
||||
|
||||
class PyTestDataSource(Enum):
|
||||
r"""
|
||||
pyclass representing an enum
|
||||
"""
|
||||
File = auto()
|
||||
OnChain = auto()
|
||||
|
||||
def aggregate(aggregation_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],proof_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,transcript:str,logrows:int,check_mode:str,split_proofs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],commitment:PyCommitments) -> bool:
|
||||
r"""
|
||||
Creates an aggregated proof
|
||||
|
||||
Arguments
|
||||
---------
|
||||
aggregation_snarks: list[str]
|
||||
List of paths to the various proofs
|
||||
|
||||
proof_path: str
|
||||
Path to output the aggregated proof
|
||||
|
||||
vk_path: str
|
||||
Path to the VK file
|
||||
|
||||
transcript:
|
||||
Proof transcript type to be used. `evm` used by default. `poseidon` is also supported
|
||||
|
||||
logrows:
|
||||
Logrows used for aggregation circuit
|
||||
|
||||
check_mode: str
|
||||
Run sanity checks during calculations. Accepts `safe` or `unsafe`
|
||||
|
||||
split-proofs: bool
|
||||
Whether the accumulated proofs are segments of a larger circuit
|
||||
|
||||
srs_path: str
|
||||
Path to the SRS used
|
||||
|
||||
commitment: str
|
||||
Accepts "kzg" or "ipa"
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def buffer_to_felts(buffer:typing.Sequence[int]) -> list[str]:
|
||||
r"""
|
||||
Converts a buffer to vector of field elements
|
||||
|
||||
Arguments
|
||||
-------
|
||||
buffer: list[int]
|
||||
List of integers representing a buffer
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
List of field elements represented as strings
|
||||
"""
|
||||
...
|
||||
|
||||
def calibrate_settings(data:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,settings:str | os.PathLike | pathlib.Path,target:str,lookup_safety_margin:float,scales:typing.Optional[typing.Sequence[int]],scale_rebase_multiplier:typing.Sequence[int],max_logrows:typing.Optional[int]) -> typing.Any:
|
||||
r"""
|
||||
Calibrates the circuit settings
|
||||
|
||||
Arguments
|
||||
---------
|
||||
data: str
|
||||
Path to the calibration data
|
||||
|
||||
model: str
|
||||
Path to the onnx file
|
||||
|
||||
settings: str
|
||||
Path to the settings file
|
||||
|
||||
lookup_safety_margin: int
|
||||
the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. larger = safer but slower
|
||||
|
||||
scales: list[int]
|
||||
Optional scales to specifically try for calibration
|
||||
|
||||
scale_rebase_multiplier: list[int]
|
||||
Optional scale rebase multipliers to specifically try for calibration. This is the multiplier at which we divide to return to the input scale.
|
||||
|
||||
max_logrows: int
|
||||
Optional max logrows to use for calibration
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def compile_circuit(model:str | os.PathLike | pathlib.Path,compiled_circuit:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path) -> bool:
|
||||
r"""
|
||||
Compiles the circuit for use in other steps
|
||||
|
||||
Arguments
|
||||
---------
|
||||
model: str
|
||||
Path to the onnx model file
|
||||
|
||||
compiled_circuit: str
|
||||
Path to output the compiled circuit
|
||||
|
||||
settings_path: str
|
||||
Path to the settings files
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def create_evm_data_attestation(input_data:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,witness_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
|
||||
r"""
|
||||
Creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
|
||||
|
||||
Arguments
|
||||
---------
|
||||
input_data: str
|
||||
The path to the .json data file, which should contain the necessary calldata and account addresses needed to read from all the on-chain view functions that return the data that the network ingests as inputs
|
||||
|
||||
settings_path: str
|
||||
The path to the settings file
|
||||
|
||||
sol_code_path: str
|
||||
The path to the create the solidity verifier
|
||||
|
||||
abi_path: str
|
||||
The path to create the ABI for the solidity verifier
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def create_evm_verifier(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reusable:bool) -> typing.Any:
|
||||
r"""
|
||||
Creates an EVM compatible verifier, you will need solc installed in your environment to run this
|
||||
|
||||
Arguments
|
||||
---------
|
||||
vk_path: str
|
||||
The path to the verification key file
|
||||
|
||||
settings_path: str
|
||||
The path to the settings file
|
||||
|
||||
sol_code_path: str
|
||||
The path to the create the solidity verifier
|
||||
|
||||
abi_path: str
|
||||
The path to create the ABI for the solidity verifier
|
||||
|
||||
srs_path: str
|
||||
The path to the SRS file
|
||||
|
||||
reusable: bool
|
||||
Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def create_evm_verifier_aggr(aggregation_settings:typing.Sequence[str | os.PathLike | pathlib.Path],vk_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,logrows:int,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reusable:bool) -> typing.Any:
|
||||
r"""
|
||||
Creates an evm compatible aggregate verifier, you will need solc installed in your environment to run this
|
||||
|
||||
Arguments
|
||||
---------
|
||||
aggregation_settings: str
|
||||
path to the settings file
|
||||
|
||||
vk_path: str
|
||||
The path to load the desired verification key file
|
||||
|
||||
sol_code_path: str
|
||||
The path to the Solidity code
|
||||
|
||||
abi_path: str
|
||||
The path to output the Solidity verifier ABI
|
||||
|
||||
logrows: int
|
||||
Number of logrows used during aggregated setup
|
||||
|
||||
srs_path: str
|
||||
The path to the SRS file
|
||||
|
||||
reusable: bool
|
||||
Whether the verifier should be rendered as a reusable contract. If so, then you will need to deploy the VK artifact separately which you can generate using the create_evm_vka command
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def create_evm_vka(vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,abi_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
|
||||
r"""
|
||||
Creates an Evm VK artifact. This command generated a VK with circuit specific meta data encoding in memory for use by the reusable H2 verifier.
|
||||
This is useful for deploying verifier that were otherwise too big to fit on chain and required aggregation.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
vk_path: str
|
||||
The path to the verification key file
|
||||
|
||||
settings_path: str
|
||||
The path to the settings file
|
||||
|
||||
sol_code_path: str
|
||||
The path to the create the solidity verifying key.
|
||||
|
||||
abi_path: str
|
||||
The path to create the ABI for the solidity verifier
|
||||
|
||||
srs_path: str
|
||||
The path to the SRS file
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def deploy_da_evm(addr_path:str | os.PathLike | pathlib.Path,input_data:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],optimizer_runs:int,private_key:typing.Optional[str]) -> typing.Any:
|
||||
r"""
|
||||
deploys the solidity da verifier
|
||||
"""
|
||||
...
|
||||
|
||||
def deploy_evm(addr_path:str | os.PathLike | pathlib.Path,sol_code_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],contract_type:str,optimizer_runs:int,private_key:typing.Optional[str]) -> typing.Any:
|
||||
r"""
|
||||
deploys the solidity verifier
|
||||
"""
|
||||
...
|
||||
|
||||
def encode_evm_calldata(proof:str | os.PathLike | pathlib.Path,calldata:str | os.PathLike | pathlib.Path,addr_vk:typing.Optional[str]) -> list[int]:
|
||||
r"""
|
||||
Creates encoded evm calldata from a proof file
|
||||
|
||||
Arguments
|
||||
---------
|
||||
proof: str
|
||||
Path to the proof file
|
||||
|
||||
calldata: str
|
||||
Path to the calldata file to save
|
||||
|
||||
addr_vk: str
|
||||
The address of the verification key contract (if the verifier key is to be rendered as a separate contract)
|
||||
|
||||
Returns
|
||||
-------
|
||||
vec[u8]
|
||||
The encoded calldata
|
||||
"""
|
||||
...
|
||||
|
||||
def felt_to_big_endian(felt:str) -> str:
|
||||
r"""
|
||||
Converts a field element hex string to big endian
|
||||
|
||||
Arguments
|
||||
-------
|
||||
felt: str
|
||||
The field element represented as a string
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
field element represented as a string
|
||||
"""
|
||||
...
|
||||
|
||||
def felt_to_float(felt:str,scale:int) -> float:
|
||||
r"""
|
||||
Converts a field element hex string to a floating point number
|
||||
|
||||
Arguments
|
||||
-------
|
||||
felt: str
|
||||
The field element represented as a string
|
||||
|
||||
scale: float
|
||||
The scaling factor used to convert the field element into a floating point representation
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
"""
|
||||
...
|
||||
|
||||
def felt_to_int(felt:str) -> int:
|
||||
r"""
|
||||
Converts a field element hex string to an integer
|
||||
|
||||
Arguments
|
||||
-------
|
||||
felt: str
|
||||
The field element represented as a string
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
"""
|
||||
...
|
||||
|
||||
def float_to_felt(input:float,scale:int,input_type:PyInputType) -> str:
|
||||
r"""
|
||||
Converts a floating point element to a field element hex string
|
||||
|
||||
Arguments
|
||||
-------
|
||||
input: float
|
||||
The field element represented as a string
|
||||
|
||||
scale: float
|
||||
The scaling factor used to quantize the float into a field element
|
||||
|
||||
input_type: PyInputType
|
||||
The type of the input
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The field element represented as a string
|
||||
"""
|
||||
...
|
||||
|
||||
def gen_settings(model:str | os.PathLike | pathlib.Path,output:str | os.PathLike | pathlib.Path,py_run_args:typing.Optional[PyRunArgs]) -> bool:
|
||||
r"""
|
||||
Generates the circuit settings
|
||||
|
||||
Arguments
|
||||
---------
|
||||
model: str
|
||||
Path to the onnx file
|
||||
|
||||
output: str
|
||||
Path to create the settings file
|
||||
|
||||
py_run_args: PyRunArgs
|
||||
PyRunArgs object to initialize the settings
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def gen_srs(srs_path:str | os.PathLike | pathlib.Path,logrows:int) -> None:
|
||||
r"""
|
||||
Generates the Structured Reference String (SRS), use this only for testing purposes
|
||||
|
||||
Arguments
|
||||
---------
|
||||
srs_path: str
|
||||
Path to the create the SRS file
|
||||
|
||||
logrows: int
|
||||
The number of logrows for the SRS file
|
||||
"""
|
||||
...
|
||||
|
||||
def gen_vk_from_pk_aggr(path_to_pk:str | os.PathLike | pathlib.Path,vk_output_path:str | os.PathLike | pathlib.Path) -> bool:
|
||||
r"""
|
||||
Generates a vk from a pk for an aggregate circuit and saves it to a file
|
||||
|
||||
Arguments
|
||||
-------
|
||||
path_to_pk: str
|
||||
Path to the proving key
|
||||
|
||||
vk_output_path: str
|
||||
Path to create the vk file
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def gen_vk_from_pk_single(path_to_pk:str | os.PathLike | pathlib.Path,circuit_settings_path:str | os.PathLike | pathlib.Path,vk_output_path:str | os.PathLike | pathlib.Path) -> bool:
|
||||
r"""
|
||||
Generates a vk from a pk for a model circuit and saves it to a file
|
||||
|
||||
Arguments
|
||||
-------
|
||||
path_to_pk: str
|
||||
Path to the proving key
|
||||
|
||||
circuit_settings_path: str
|
||||
Path to the witness file
|
||||
|
||||
vk_output_path: str
|
||||
Path to create the vk file
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def gen_witness(data:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,output:typing.Optional[str | os.PathLike | pathlib.Path],vk_path:typing.Optional[str | os.PathLike | pathlib.Path],srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
|
||||
r"""
|
||||
Runs the forward pass operation to generate a witness
|
||||
|
||||
Arguments
|
||||
---------
|
||||
data: str
|
||||
Path to the data file
|
||||
|
||||
model: str
|
||||
Path to the compiled model file
|
||||
|
||||
output: str
|
||||
Path to create the witness file
|
||||
|
||||
vk_path: str
|
||||
Path to the verification key
|
||||
|
||||
srs_path: str
|
||||
Path to the SRS file
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
Python object containing the witness values
|
||||
"""
|
||||
...
|
||||
|
||||
def get_srs(settings_path:typing.Optional[str | os.PathLike | pathlib.Path],logrows:typing.Optional[int],srs_path:typing.Optional[str | os.PathLike | pathlib.Path],commitment:typing.Optional[PyCommitments]) -> typing.Any:
|
||||
r"""
|
||||
Gets a public srs
|
||||
|
||||
Arguments
|
||||
---------
|
||||
settings_path: str
|
||||
Path to the settings file
|
||||
|
||||
logrows: int
|
||||
The number of logrows for the SRS file
|
||||
|
||||
srs_path: str
|
||||
Path to the create the SRS file
|
||||
|
||||
commitment: str
|
||||
Specify the commitment used ("kzg", "ipa")
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def ipa_commit(message:typing.Sequence[str],vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> list[PyG1Affine]:
|
||||
r"""
|
||||
Generate an ipa commitment.
|
||||
|
||||
Arguments
|
||||
-------
|
||||
message: list[str]
|
||||
List of field elements represnted as strings
|
||||
|
||||
vk_path: str
|
||||
Path to the verification key
|
||||
|
||||
settings_path: str
|
||||
Path to the settings file
|
||||
|
||||
srs_path: str
|
||||
Path to the Structure Reference String (SRS) file
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[PyG1Affine]
|
||||
"""
|
||||
...
|
||||
|
||||
def kzg_commit(message:typing.Sequence[str],vk_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> list[PyG1Affine]:
|
||||
r"""
|
||||
Generate a kzg commitment.
|
||||
|
||||
Arguments
|
||||
-------
|
||||
message: list[str]
|
||||
List of field elements represnted as strings
|
||||
|
||||
vk_path: str
|
||||
Path to the verification key
|
||||
|
||||
settings_path: str
|
||||
Path to the settings file
|
||||
|
||||
srs_path: str
|
||||
Path to the Structure Reference String (SRS) file
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[PyG1Affine]
|
||||
"""
|
||||
...
|
||||
|
||||
def mock(witness:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path) -> bool:
|
||||
r"""
|
||||
Mocks the prover
|
||||
|
||||
Arguments
|
||||
---------
|
||||
witness: str
|
||||
Path to the witness file
|
||||
|
||||
model: str
|
||||
Path to the compiled model file
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def mock_aggregate(aggregation_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],logrows:int,split_proofs:bool) -> bool:
|
||||
r"""
|
||||
Mocks the aggregate prover
|
||||
|
||||
Arguments
|
||||
---------
|
||||
aggregation_snarks: list[str]
|
||||
List of paths to the relevant proof files
|
||||
|
||||
logrows: int
|
||||
Number of logrows to use for the aggregation circuit
|
||||
|
||||
split_proofs: bool
|
||||
Indicates whether the accumulated are segments of a larger proof
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def poseidon_hash(message:typing.Sequence[str]) -> list[str]:
|
||||
r"""
|
||||
Generate a poseidon hash.
|
||||
|
||||
Arguments
|
||||
-------
|
||||
message: list[str]
|
||||
List of field elements represented as strings
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
List of field elements represented as strings
|
||||
"""
|
||||
...
|
||||
|
||||
def prove(witness:str | os.PathLike | pathlib.Path,model:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,proof_path:typing.Optional[str | os.PathLike | pathlib.Path],proof_type:str,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> typing.Any:
|
||||
r"""
|
||||
Runs the prover on a set of inputs
|
||||
|
||||
Arguments
|
||||
---------
|
||||
witness: str
|
||||
Path to the witness file
|
||||
|
||||
model: str
|
||||
Path to the compiled model file
|
||||
|
||||
pk_path: str
|
||||
Path to the proving key file
|
||||
|
||||
proof_path: str
|
||||
Path to create the proof file
|
||||
|
||||
proof_type: str
|
||||
Accepts `single`, `for-aggr`
|
||||
|
||||
srs_path: str
|
||||
Path to the SRS file
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def setup(model:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],witness_path:typing.Optional[str | os.PathLike | pathlib.Path],disable_selector_compression:bool) -> bool:
|
||||
r"""
|
||||
Runs the setup process
|
||||
|
||||
Arguments
|
||||
---------
|
||||
model: str
|
||||
Path to the compiled model file
|
||||
|
||||
vk_path: str
|
||||
Path to create the verification key file
|
||||
|
||||
pk_path: str
|
||||
Path to create the proving key file
|
||||
|
||||
srs_path: str
|
||||
Path to the SRS file
|
||||
|
||||
witness_path: str
|
||||
Path to the witness file
|
||||
|
||||
disable_selector_compression: bool
|
||||
Whether to compress the selectors or not
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def setup_aggregate(sample_snarks:typing.Sequence[str | os.PathLike | pathlib.Path],vk_path:str | os.PathLike | pathlib.Path,pk_path:str | os.PathLike | pathlib.Path,logrows:int,split_proofs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],disable_selector_compression:bool,commitment:PyCommitments) -> bool:
|
||||
r"""
|
||||
Runs the setup process for an aggregate setup
|
||||
|
||||
Arguments
|
||||
---------
|
||||
sample_snarks: list[str]
|
||||
List of paths to the various proofs
|
||||
|
||||
vk_path: str
|
||||
Path to create the aggregated VK
|
||||
|
||||
pk_path: str
|
||||
Path to create the aggregated PK
|
||||
|
||||
logrows: int
|
||||
Number of logrows to use
|
||||
|
||||
split_proofs: bool
|
||||
Whether the accumulated are segments of a larger proof
|
||||
|
||||
srs_path: str
|
||||
Path to the SRS file
|
||||
|
||||
disable_selector_compression: bool
|
||||
Whether to compress selectors
|
||||
|
||||
commitment: str
|
||||
Accepts `kzg`, `ipa`
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def setup_test_evm_witness(data_path:str | os.PathLike | pathlib.Path,compiled_circuit_path:str | os.PathLike | pathlib.Path,test_data:str | os.PathLike | pathlib.Path,input_source:PyTestDataSource,output_source:PyTestDataSource,rpc_url:typing.Optional[str]) -> typing.Any:
|
||||
r"""
|
||||
Setup test evm witness
|
||||
|
||||
Arguments
|
||||
---------
|
||||
data_path: str
|
||||
The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
|
||||
compiled_circuit_path: str
|
||||
The path to the compiled model file (generated using the compile-circuit command)
|
||||
|
||||
test_data: str
|
||||
For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information derived from the file information in the data .json file. Should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
|
||||
input_sources: str
|
||||
Where the input data comes from
|
||||
|
||||
output_source: str
|
||||
Where the output data comes from
|
||||
|
||||
rpc_url: str
|
||||
RPC URL for an EVM compatible node, if None, uses Anvil as a local RPC node
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def swap_proof_commitments(proof_path:str | os.PathLike | pathlib.Path,witness_path:str | os.PathLike | pathlib.Path) -> None:
|
||||
r"""
|
||||
Swap the commitments in a proof
|
||||
|
||||
Arguments
|
||||
-------
|
||||
proof_path: str
|
||||
Path to the proof file
|
||||
|
||||
witness_path: str
|
||||
Path to the witness file
|
||||
"""
|
||||
...
|
||||
|
||||
def table(model:str | os.PathLike | pathlib.Path,py_run_args:typing.Optional[PyRunArgs]) -> str:
|
||||
r"""
|
||||
Displays the table as a string in python
|
||||
|
||||
Arguments
|
||||
---------
|
||||
model: str
|
||||
Path to the onnx file
|
||||
|
||||
Returns
|
||||
---------
|
||||
str
|
||||
Table of the nodes in the onnx file
|
||||
"""
|
||||
...
|
||||
|
||||
def verify(proof_path:str | os.PathLike | pathlib.Path,settings_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,srs_path:typing.Optional[str | os.PathLike | pathlib.Path],reduced_srs:bool) -> bool:
|
||||
r"""
|
||||
Verifies a given proof
|
||||
|
||||
Arguments
|
||||
---------
|
||||
proof_path: str
|
||||
Path to create the proof file
|
||||
|
||||
settings_path: str
|
||||
Path to the settings file
|
||||
|
||||
vk_path: str
|
||||
Path to the verification key file
|
||||
|
||||
srs_path: str
|
||||
Path to the SRS file
|
||||
|
||||
non_reduced_srs: bool
|
||||
Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def verify_aggr(proof_path:str | os.PathLike | pathlib.Path,vk_path:str | os.PathLike | pathlib.Path,logrows:int,commitment:PyCommitments,reduced_srs:bool,srs_path:typing.Optional[str | os.PathLike | pathlib.Path]) -> bool:
|
||||
r"""
|
||||
Verifies and aggregate proof
|
||||
|
||||
Arguments
|
||||
---------
|
||||
proof_path: str
|
||||
The path to the proof file
|
||||
|
||||
vk_path: str
|
||||
The path to the verification key file
|
||||
|
||||
logrows: int
|
||||
logrows used for aggregation circuit
|
||||
|
||||
commitment: str
|
||||
Accepts "kzg" or "ipa"
|
||||
|
||||
reduced_srs: bool
|
||||
Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
|
||||
|
||||
srs_path: str
|
||||
The path to the SRS file
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
def verify_evm(addr_verifier:str,proof_path:str | os.PathLike | pathlib.Path,rpc_url:typing.Optional[str],addr_da:typing.Optional[str],addr_vk:typing.Optional[str]) -> typing.Any:
|
||||
r"""
|
||||
verifies an evm compatible proof, you will need solc installed in your environment to run this
|
||||
|
||||
Arguments
|
||||
---------
|
||||
addr_verifier: str
|
||||
The verifier contract's address as a hex string
|
||||
|
||||
proof_path: str
|
||||
The path to the proof file (generated using the prove command)
|
||||
|
||||
rpc_url: str
|
||||
RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
|
||||
|
||||
addr_da: str
|
||||
does the verifier use data attestation ?
|
||||
|
||||
addr_vk: str
|
||||
The addess of the separate VK contract (if the verifier key is rendered as a separate contract)
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
...
|
||||
|
||||
9
src/bin/py_stub_gen.rs
Normal file
9
src/bin/py_stub_gen.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
use pyo3_stub_gen::Result;
|
||||
|
||||
fn main() -> Result<()> {
|
||||
// `stub_info` is a function defined by `define_stub_info_gatherer!` macro.
|
||||
env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")).init();
|
||||
let stub = ezkl::bindings::python::stub_info()?;
|
||||
stub.generate()?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -4,6 +4,7 @@ use crate::circuit::modules::poseidon::{
|
||||
PoseidonChip,
|
||||
};
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::circuit::InputType;
|
||||
use crate::circuit::{CheckMode, Tolerance};
|
||||
use crate::commands::*;
|
||||
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
|
||||
@@ -26,7 +27,12 @@ use pyo3::exceptions::{PyIOError, PyRuntimeError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::wrap_pyfunction;
|
||||
use pyo3_log;
|
||||
use pyo3_stub_gen::{
|
||||
define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
|
||||
derive::gen_stub_pyfunction, TypeInfo,
|
||||
};
|
||||
use snark_verifier::util::arithmetic::PrimeField;
|
||||
use std::collections::HashSet;
|
||||
use std::str::FromStr;
|
||||
use std::{fs::File, path::PathBuf};
|
||||
|
||||
@@ -35,6 +41,7 @@ type PyFelt = String;
|
||||
/// pyclass representing an enum
|
||||
#[pyclass]
|
||||
#[derive(Debug, Clone)]
|
||||
#[gen_stub_pyclass_enum]
|
||||
enum PyTestDataSource {
|
||||
/// The data is loaded from a file
|
||||
File,
|
||||
@@ -54,6 +61,7 @@ impl From<PyTestDataSource> for TestDataSource {
|
||||
/// pyclass containing the struct used for G1, this is mostly a helper class
|
||||
#[pyclass]
|
||||
#[derive(Debug, Clone)]
|
||||
#[gen_stub_pyclass]
|
||||
struct PyG1 {
|
||||
#[pyo3(get, set)]
|
||||
/// Field Element representing x
|
||||
@@ -100,6 +108,7 @@ impl pyo3::ToPyObject for PyG1 {
|
||||
/// pyclass containing the struct used for G1
|
||||
#[pyclass]
|
||||
#[derive(Debug, Clone)]
|
||||
#[gen_stub_pyclass]
|
||||
pub struct PyG1Affine {
|
||||
#[pyo3(get, set)]
|
||||
///
|
||||
@@ -145,6 +154,7 @@ impl pyo3::ToPyObject for PyG1Affine {
|
||||
///
|
||||
#[pyclass]
|
||||
#[derive(Clone)]
|
||||
#[gen_stub_pyclass]
|
||||
struct PyRunArgs {
|
||||
#[pyo3(get, set)]
|
||||
/// float: The tolerance for error on model outputs
|
||||
@@ -259,6 +269,7 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Debug, Clone)]
|
||||
#[gen_stub_pyclass_enum]
|
||||
/// pyclass representing an enum, denoting the type of commitment
|
||||
pub enum PyCommitments {
|
||||
/// KZG commitment
|
||||
@@ -306,6 +317,65 @@ impl FromStr for PyCommitments {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Debug, Clone)]
|
||||
#[gen_stub_pyclass_enum]
|
||||
enum PyInputType {
|
||||
///
|
||||
Bool,
|
||||
///
|
||||
F16,
|
||||
///
|
||||
F32,
|
||||
///
|
||||
F64,
|
||||
///
|
||||
Int,
|
||||
///
|
||||
TDim,
|
||||
}
|
||||
|
||||
impl From<InputType> for PyInputType {
|
||||
fn from(input_type: InputType) -> Self {
|
||||
match input_type {
|
||||
InputType::Bool => PyInputType::Bool,
|
||||
InputType::F16 => PyInputType::F16,
|
||||
InputType::F32 => PyInputType::F32,
|
||||
InputType::F64 => PyInputType::F64,
|
||||
InputType::Int => PyInputType::Int,
|
||||
InputType::TDim => PyInputType::TDim,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PyInputType> for InputType {
|
||||
fn from(py_input_type: PyInputType) -> Self {
|
||||
match py_input_type {
|
||||
PyInputType::Bool => InputType::Bool,
|
||||
PyInputType::F16 => InputType::F16,
|
||||
PyInputType::F32 => InputType::F32,
|
||||
PyInputType::F64 => InputType::F64,
|
||||
PyInputType::Int => InputType::Int,
|
||||
PyInputType::TDim => InputType::TDim,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for PyInputType {
|
||||
type Err = String;
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"bool" => Ok(PyInputType::Bool),
|
||||
"f16" => Ok(PyInputType::F16),
|
||||
"f32" => Ok(PyInputType::F32),
|
||||
"f64" => Ok(PyInputType::F64),
|
||||
"int" => Ok(PyInputType::Int),
|
||||
"tdim" => Ok(PyInputType::TDim),
|
||||
_ => Err("Invalid value for InputType".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a field element hex string to big endian
|
||||
///
|
||||
/// Arguments
|
||||
@@ -322,6 +392,7 @@ impl FromStr for PyCommitments {
|
||||
#[pyfunction(signature = (
|
||||
felt,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
|
||||
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
|
||||
Ok(format!("{:?}", felt))
|
||||
@@ -341,6 +412,7 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
|
||||
#[pyfunction(signature = (
|
||||
felt,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn felt_to_int(felt: PyFelt) -> PyResult<IntegerRep> {
|
||||
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
|
||||
let int_rep = felt_to_integer_rep(felt);
|
||||
@@ -365,6 +437,7 @@ fn felt_to_int(felt: PyFelt) -> PyResult<IntegerRep> {
|
||||
felt,
|
||||
scale
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
|
||||
let int_rep = felt_to_integer_rep(felt);
|
||||
@@ -383,6 +456,9 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
/// scale: float
|
||||
/// The scaling factor used to quantize the float into a field element
|
||||
///
|
||||
/// input_type: PyInputType
|
||||
/// The type of the input
|
||||
///
|
||||
/// Returns
|
||||
/// -------
|
||||
/// str
|
||||
@@ -390,9 +466,12 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
///
|
||||
#[pyfunction(signature = (
|
||||
input,
|
||||
scale
|
||||
scale,
|
||||
input_type=PyInputType::F64
|
||||
))]
|
||||
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
|
||||
#[gen_stub_pyfunction]
|
||||
fn float_to_felt(mut input: f64, scale: crate::Scale, input_type: PyInputType) -> PyResult<PyFelt> {
|
||||
InputType::roundtrip(&input_type.into(), &mut input);
|
||||
let int_rep = quantize_float(&input, 0.0, scale)
|
||||
.map_err(|_| PyIOError::new_err("Failed to quantize input"))?;
|
||||
let felt = integer_rep_to_felt(int_rep);
|
||||
@@ -414,6 +493,7 @@ fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
|
||||
#[pyfunction(signature = (
|
||||
buffer
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
|
||||
fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 {
|
||||
let mut n: u128 = 0;
|
||||
@@ -486,6 +566,7 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
|
||||
#[pyfunction(signature = (
|
||||
message,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
|
||||
let message: Vec<Fr> = message
|
||||
.iter()
|
||||
@@ -531,6 +612,7 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
|
||||
settings_path=PathBuf::from(DEFAULT_SETTINGS),
|
||||
srs_path=None
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn kzg_commit(
|
||||
message: Vec<PyFelt>,
|
||||
vk_path: PathBuf,
|
||||
@@ -589,6 +671,7 @@ fn kzg_commit(
|
||||
settings_path=PathBuf::from(DEFAULT_SETTINGS),
|
||||
srs_path=None
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn ipa_commit(
|
||||
message: Vec<PyFelt>,
|
||||
vk_path: PathBuf,
|
||||
@@ -635,6 +718,7 @@ fn ipa_commit(
|
||||
proof_path=PathBuf::from(DEFAULT_PROOF),
|
||||
witness_path=PathBuf::from(DEFAULT_WITNESS),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResult<()> {
|
||||
crate::execute::swap_proof_commitments_cmd(proof_path, witness_path)
|
||||
.map_err(|_| PyIOError::new_err("Failed to swap commitments"))?;
|
||||
@@ -664,6 +748,7 @@ fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResul
|
||||
circuit_settings_path=PathBuf::from(DEFAULT_SETTINGS),
|
||||
vk_output_path=PathBuf::from(DEFAULT_VK),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn gen_vk_from_pk_single(
|
||||
path_to_pk: PathBuf,
|
||||
circuit_settings_path: PathBuf,
|
||||
@@ -701,6 +786,7 @@ fn gen_vk_from_pk_single(
|
||||
path_to_pk=PathBuf::from(DEFAULT_PK_AGGREGATED),
|
||||
vk_output_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult<bool> {
|
||||
let pk = load_pk::<KZGCommitmentScheme<Bn256>, AggregationCircuit>(path_to_pk, ())
|
||||
.map_err(|_| PyIOError::new_err("Failed to load pk"))?;
|
||||
@@ -730,6 +816,7 @@ fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult
|
||||
model = PathBuf::from(DEFAULT_MODEL),
|
||||
py_run_args = None
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn table(model: PathBuf, py_run_args: Option<PyRunArgs>) -> PyResult<String> {
|
||||
let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into();
|
||||
let mut reader = File::open(model).map_err(|_| PyIOError::new_err("Failed to open model"))?;
|
||||
@@ -755,6 +842,7 @@ fn table(model: PathBuf, py_run_args: Option<PyRunArgs>) -> PyResult<String> {
|
||||
srs_path,
|
||||
logrows,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
|
||||
let params = ezkl_gen_srs::<KZGCommitmentScheme<Bn256>>(logrows as u32);
|
||||
save_params::<KZGCommitmentScheme<Bn256>>(&srs_path, ¶ms)?;
|
||||
@@ -787,6 +875,7 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
|
||||
srs_path=None,
|
||||
commitment=None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn get_srs(
|
||||
py: Python,
|
||||
settings_path: Option<PathBuf>,
|
||||
@@ -799,7 +888,7 @@ fn get_srs(
|
||||
None => None,
|
||||
};
|
||||
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::get_srs_cmd(srs_path, settings_path, logrows, commitment)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -833,6 +922,7 @@ fn get_srs(
|
||||
output=PathBuf::from(DEFAULT_SETTINGS),
|
||||
py_run_args = None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn gen_settings(
|
||||
model: PathBuf,
|
||||
output: PathBuf,
|
||||
@@ -888,6 +978,7 @@ fn gen_settings(
|
||||
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
|
||||
max_logrows = None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn calibrate_settings(
|
||||
py: Python,
|
||||
data: PathBuf,
|
||||
@@ -899,7 +990,7 @@ fn calibrate_settings(
|
||||
scale_rebase_multiplier: Vec<u32>,
|
||||
max_logrows: Option<u32>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::calibrate(
|
||||
model,
|
||||
data,
|
||||
@@ -951,6 +1042,7 @@ fn calibrate_settings(
|
||||
vk_path=None,
|
||||
srs_path=None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn gen_witness(
|
||||
py: Python,
|
||||
data: PathBuf,
|
||||
@@ -959,7 +1051,7 @@ fn gen_witness(
|
||||
vk_path: Option<PathBuf>,
|
||||
srs_path: Option<PathBuf>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
let output = crate::execute::gen_witness(model, data, output, vk_path, srs_path)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -988,6 +1080,7 @@ fn gen_witness(
|
||||
witness=PathBuf::from(DEFAULT_WITNESS),
|
||||
model=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
|
||||
crate::execute::mock(model, witness).map_err(|e| {
|
||||
let err_str = format!("Failed to run mock: {}", e);
|
||||
@@ -1018,6 +1111,7 @@ fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
|
||||
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
|
||||
split_proofs = false,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn mock_aggregate(
|
||||
aggregation_snarks: Vec<PathBuf>,
|
||||
logrows: u32,
|
||||
@@ -1065,6 +1159,7 @@ fn mock_aggregate(
|
||||
witness_path = None,
|
||||
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn setup(
|
||||
model: PathBuf,
|
||||
vk_path: PathBuf,
|
||||
@@ -1123,6 +1218,7 @@ fn setup(
|
||||
proof_type=ProofType::default(),
|
||||
srs_path=None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn prove(
|
||||
witness: PathBuf,
|
||||
model: PathBuf,
|
||||
@@ -1178,6 +1274,7 @@ fn prove(
|
||||
srs_path=None,
|
||||
reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse::<bool>().unwrap(),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn verify(
|
||||
proof_path: PathBuf,
|
||||
settings_path: PathBuf,
|
||||
@@ -1237,6 +1334,7 @@ fn verify(
|
||||
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
|
||||
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn setup_aggregate(
|
||||
sample_snarks: Vec<PathBuf>,
|
||||
vk_path: PathBuf,
|
||||
@@ -1287,6 +1385,7 @@ fn setup_aggregate(
|
||||
compiled_circuit=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
|
||||
settings_path=PathBuf::from(DEFAULT_SETTINGS),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn compile_circuit(
|
||||
model: PathBuf,
|
||||
compiled_circuit: PathBuf,
|
||||
@@ -1346,6 +1445,7 @@ fn compile_circuit(
|
||||
srs_path=None,
|
||||
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn aggregate(
|
||||
aggregation_snarks: Vec<PathBuf>,
|
||||
proof_path: PathBuf,
|
||||
@@ -1411,6 +1511,7 @@ fn aggregate(
|
||||
reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse().unwrap(),
|
||||
srs_path=None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn verify_aggr(
|
||||
proof_path: PathBuf,
|
||||
vk_path: PathBuf,
|
||||
@@ -1458,6 +1559,7 @@ fn verify_aggr(
|
||||
calldata=PathBuf::from(DEFAULT_CALLDATA),
|
||||
addr_vk=None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn encode_evm_calldata<'a>(
|
||||
proof: PathBuf,
|
||||
calldata: PathBuf,
|
||||
@@ -1510,6 +1612,7 @@ fn encode_evm_calldata<'a>(
|
||||
srs_path=None,
|
||||
reusable = DEFAULT_RENDER_REUSABLE.parse().unwrap(),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn create_evm_verifier(
|
||||
py: Python,
|
||||
vk_path: PathBuf,
|
||||
@@ -1519,7 +1622,7 @@ fn create_evm_verifier(
|
||||
srs_path: Option<PathBuf>,
|
||||
reusable: bool,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::create_evm_verifier(
|
||||
vk_path,
|
||||
srs_path,
|
||||
@@ -1569,6 +1672,7 @@ fn create_evm_verifier(
|
||||
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
|
||||
srs_path=None
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn create_evm_vka(
|
||||
py: Python,
|
||||
vk_path: PathBuf,
|
||||
@@ -1577,7 +1681,7 @@ fn create_evm_vka(
|
||||
abi_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::create_evm_vka(vk_path, srs_path, settings_path, sol_code_path, abi_path)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -1616,6 +1720,7 @@ fn create_evm_vka(
|
||||
abi_path=PathBuf::from(DEFAULT_VERIFIER_DA_ABI),
|
||||
witness_path=None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn create_evm_data_attestation(
|
||||
py: Python,
|
||||
input_data: PathBuf,
|
||||
@@ -1624,7 +1729,7 @@ fn create_evm_data_attestation(
|
||||
abi_path: PathBuf,
|
||||
witness_path: Option<PathBuf>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::create_evm_data_attestation(
|
||||
settings_path,
|
||||
sol_code_path,
|
||||
@@ -1676,6 +1781,7 @@ fn create_evm_data_attestation(
|
||||
output_source,
|
||||
rpc_url=None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn setup_test_evm_witness(
|
||||
py: Python,
|
||||
data_path: PathBuf,
|
||||
@@ -1685,7 +1791,7 @@ fn setup_test_evm_witness(
|
||||
output_source: PyTestDataSource,
|
||||
rpc_url: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::setup_test_evm_witness(
|
||||
data_path,
|
||||
compiled_circuit_path,
|
||||
@@ -1713,6 +1819,7 @@ fn setup_test_evm_witness(
|
||||
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
|
||||
private_key=None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn deploy_evm(
|
||||
py: Python,
|
||||
addr_path: PathBuf,
|
||||
@@ -1722,7 +1829,7 @@ fn deploy_evm(
|
||||
optimizer_runs: usize,
|
||||
private_key: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::deploy_evm(
|
||||
sol_code_path,
|
||||
rpc_url,
|
||||
@@ -1751,6 +1858,7 @@ fn deploy_evm(
|
||||
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
|
||||
private_key=None
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn deploy_da_evm(
|
||||
py: Python,
|
||||
addr_path: PathBuf,
|
||||
@@ -1761,7 +1869,7 @@ fn deploy_da_evm(
|
||||
optimizer_runs: usize,
|
||||
private_key: Option<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::deploy_da_evm(
|
||||
input_data,
|
||||
settings_path,
|
||||
@@ -1809,6 +1917,7 @@ fn deploy_da_evm(
|
||||
addr_da = None,
|
||||
addr_vk = None,
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn verify_evm<'a>(
|
||||
py: Python<'a>,
|
||||
addr_verifier: &'a str,
|
||||
@@ -1831,7 +1940,7 @@ fn verify_evm<'a>(
|
||||
None
|
||||
};
|
||||
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::verify_evm(proof_path, addr_verifier, rpc_url, addr_da, addr_vk)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -1881,6 +1990,7 @@ fn verify_evm<'a>(
|
||||
srs_path=None,
|
||||
reusable = DEFAULT_RENDER_REUSABLE.parse().unwrap(),
|
||||
))]
|
||||
#[gen_stub_pyfunction]
|
||||
fn create_evm_verifier_aggr(
|
||||
py: Python,
|
||||
aggregation_settings: Vec<PathBuf>,
|
||||
@@ -1891,7 +2001,7 @@ fn create_evm_verifier_aggr(
|
||||
srs_path: Option<PathBuf>,
|
||||
reusable: bool,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
crate::execute::create_evm_aggregate_verifier(
|
||||
vk_path,
|
||||
srs_path,
|
||||
@@ -1911,15 +2021,19 @@ fn create_evm_verifier_aggr(
|
||||
})
|
||||
}
|
||||
|
||||
// Define a function to gather stub information.
|
||||
define_stub_info_gatherer!(stub_info);
|
||||
|
||||
// Python Module
|
||||
#[pymodule]
|
||||
fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
pyo3_log::init();
|
||||
m.add_class::<PyRunArgs>()?;
|
||||
m.add_class::<PyG1Affine>()?;
|
||||
m.add_class::<PyG1>()?;
|
||||
m.add_class::<PyTestDataSource>()?;
|
||||
m.add_class::<PyCommitments>()?;
|
||||
m.add_class::<PyInputType>()?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
|
||||
@@ -1958,3 +2072,48 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(encode_evm_calldata, m)?)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl pyo3_stub_gen::PyStubType for CalibrationTarget {
|
||||
fn type_output() -> TypeInfo {
|
||||
TypeInfo {
|
||||
name: "str".to_string(),
|
||||
import: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl pyo3_stub_gen::PyStubType for ProofType {
|
||||
fn type_output() -> TypeInfo {
|
||||
TypeInfo {
|
||||
name: "str".to_string(),
|
||||
import: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl pyo3_stub_gen::PyStubType for TranscriptType {
|
||||
fn type_output() -> TypeInfo {
|
||||
TypeInfo {
|
||||
name: "str".to_string(),
|
||||
import: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl pyo3_stub_gen::PyStubType for CheckMode {
|
||||
fn type_output() -> TypeInfo {
|
||||
TypeInfo {
|
||||
name: "str".to_string(),
|
||||
import: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl pyo3_stub_gen::PyStubType for ContractType {
|
||||
fn type_output() -> TypeInfo {
|
||||
TypeInfo {
|
||||
name: "str".to_string(),
|
||||
import: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,10 +141,11 @@ pub(crate) fn gen_vk(
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to create verifying key: {}", e)))?;
|
||||
|
||||
let mut serialized_vk = Vec::new();
|
||||
vk.write(&mut serialized_vk, halo2_proofs::SerdeFormat::RawBytes)
|
||||
.map_err(|e| {
|
||||
EZKLError::InternalError(format!("Failed to serialize verifying key: {}", e))
|
||||
})?;
|
||||
vk.write(
|
||||
&mut serialized_vk,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to serialize verifying key: {}", e)))?;
|
||||
|
||||
Ok(serialized_vk)
|
||||
}
|
||||
@@ -165,7 +166,7 @@ pub(crate) fn gen_pk(
|
||||
let mut reader = BufReader::new(&vk[..]);
|
||||
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
circuit.settings().clone(),
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
|
||||
@@ -197,7 +198,7 @@ pub(crate) fn verify(
|
||||
let mut reader = BufReader::new(&vk[..]);
|
||||
let vk = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
circuit_settings.clone(),
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
|
||||
@@ -277,7 +278,7 @@ pub(crate) fn verify_aggr(
|
||||
let mut reader = BufReader::new(&vk[..]);
|
||||
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
(),
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?;
|
||||
@@ -365,7 +366,7 @@ pub(crate) fn prove(
|
||||
let mut reader = BufReader::new(&pk[..]);
|
||||
let pk = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
circuit.settings().clone(),
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;
|
||||
@@ -487,7 +488,7 @@ pub(crate) fn vk_validation(vk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKL
|
||||
let mut reader = BufReader::new(&vk[..]);
|
||||
let _ = VerifyingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
circuit_settings,
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?;
|
||||
@@ -504,7 +505,7 @@ pub(crate) fn pk_validation(pk: Vec<u8>, settings: Vec<u8>) -> Result<bool, EZKL
|
||||
let mut reader = BufReader::new(&pk[..]);
|
||||
let _ = ProvingKey::<G1Affine>::read::<_, GraphCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
halo2_proofs::SerdeFormat::RawBytesUnchecked,
|
||||
circuit_settings,
|
||||
)
|
||||
.map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?;
|
||||
|
||||
@@ -22,6 +22,7 @@ use halo2curves::{
|
||||
bn256::{Bn256, Fr, G1Affine},
|
||||
ff::PrimeField,
|
||||
};
|
||||
use std::str::FromStr;
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
|
||||
|
||||
@@ -113,9 +114,15 @@ pub fn feltToFloat(
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn floatToFelt(
|
||||
input: f64,
|
||||
mut input: f64,
|
||||
scale: crate::Scale,
|
||||
input_type: &str,
|
||||
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
|
||||
crate::circuit::InputType::roundtrip(
|
||||
&crate::circuit::InputType::from_str(input_type)
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?,
|
||||
&mut input,
|
||||
);
|
||||
let int_rep =
|
||||
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
let felt = integer_rep_to_felt(int_rep);
|
||||
|
||||
@@ -8,10 +8,9 @@ use halo2_proofs::{
|
||||
use log::debug;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::{
|
||||
conversion::{FromPyObject, PyTryFrom},
|
||||
conversion::{FromPyObject, IntoPy},
|
||||
exceptions::PyValueError,
|
||||
prelude::*,
|
||||
types::PyString,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
@@ -139,10 +138,9 @@ impl IntoPy<PyObject> for CheckMode {
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains CheckMode from PyObject (Required for CheckMode to be compatible with Python)
|
||||
impl<'source> FromPyObject<'source> for CheckMode {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
|
||||
let strval = trystr.to_string();
|
||||
match strval.to_lowercase().as_str() {
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
|
||||
let trystr = String::extract_bound(ob)?;
|
||||
match trystr.to_lowercase().as_str() {
|
||||
"safe" => Ok(CheckMode::SAFE),
|
||||
"unsafe" => Ok(CheckMode::UNSAFE),
|
||||
_ => Err(PyValueError::new_err("Invalid value for CheckMode")),
|
||||
@@ -161,8 +159,8 @@ impl IntoPy<PyObject> for Tolerance {
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains Tolerance from PyObject (Required for Tolerance to be compatible with Python)
|
||||
impl<'source> FromPyObject<'source> for Tolerance {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
if let Ok((val, scale)) = ob.extract::<(f32, f32)>() {
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
|
||||
if let Ok((val, scale)) = <(f32, f32)>::extract_bound(ob) {
|
||||
Ok(Tolerance {
|
||||
val,
|
||||
scale: utils::F32(scale),
|
||||
|
||||
@@ -97,4 +97,7 @@ pub enum CircuitError {
|
||||
/// Invalid scale
|
||||
#[error("negative scale for an op that requires positive inputs {0}")]
|
||||
NegativeScale(String),
|
||||
#[error("invalid input type {0}")]
|
||||
/// Invalid input type
|
||||
InvalidInputType(String),
|
||||
}
|
||||
|
||||
@@ -1687,6 +1687,7 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd + std::
|
||||
Ok(output.into())
|
||||
}
|
||||
|
||||
// assumes unique values in fullset
|
||||
pub(crate) fn get_missing_set_elements<
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
|
||||
>(
|
||||
@@ -1699,6 +1700,8 @@ pub(crate) fn get_missing_set_elements<
|
||||
let set_len = fullset.len();
|
||||
input.flatten();
|
||||
|
||||
// while fullset is less than len of input concat
|
||||
|
||||
let is_assigned = !input.any_unknowns()? && !fullset.any_unknowns()?;
|
||||
|
||||
let mut claimed_output: ValTensor<F> = if is_assigned {
|
||||
|
||||
@@ -105,7 +105,10 @@ impl InputType {
|
||||
}
|
||||
|
||||
///
|
||||
pub fn roundtrip<T: num::ToPrimitive + num::FromPrimitive + Clone>(&self, input: &mut T) {
|
||||
pub fn roundtrip<T: num::ToPrimitive + num::FromPrimitive + Clone + std::fmt::Debug>(
|
||||
&self,
|
||||
input: &mut T,
|
||||
) {
|
||||
match self {
|
||||
InputType::Bool => {
|
||||
let boolean_input = input.clone().to_i64().unwrap();
|
||||
@@ -118,7 +121,7 @@ impl InputType {
|
||||
*input = T::from_f32(f32_input).unwrap();
|
||||
}
|
||||
InputType::F32 => {
|
||||
let f32_input = input.clone().to_f32().unwrap();
|
||||
let f32_input: f32 = input.clone().to_f32().unwrap();
|
||||
*input = T::from_f32(f32_input).unwrap();
|
||||
}
|
||||
InputType::F64 => {
|
||||
@@ -133,6 +136,22 @@ impl InputType {
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for InputType {
|
||||
type Err = CircuitError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"bool" => Ok(InputType::Bool),
|
||||
"f16" => Ok(InputType::F16),
|
||||
"f32" => Ok(InputType::F32),
|
||||
"f64" => Ok(InputType::F64),
|
||||
"int" => Ok(InputType::Int),
|
||||
"tdim" => Ok(InputType::TDim),
|
||||
e => Err(CircuitError::InvalidInputType(e.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct Input {
|
||||
|
||||
@@ -211,7 +211,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
self.min_lookup_inputs().to_string().green(),
|
||||
self.max_range_size().to_string().green(),
|
||||
self.dynamic_lookup_col_coord().to_string().green(),
|
||||
self.shuffle_col_coord().to_string().green(),
|
||||
self.shuffle_col_coord().to_string().green(),
|
||||
self.max_dynamic_input_len().to_string().green()
|
||||
);
|
||||
}
|
||||
@@ -474,7 +474,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min forcefully
|
||||
/// Update the max and min forcefully
|
||||
pub fn update_max_min_lookup_inputs_force(
|
||||
&mut self,
|
||||
min: IntegerRep,
|
||||
@@ -611,7 +611,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
var: &VarTensor,
|
||||
values: &ValTensor<F>,
|
||||
) -> Result<(ValTensor<F>, usize), CircuitError> {
|
||||
|
||||
self.update_max_dynamic_input_len(values.len());
|
||||
|
||||
if let Some(region) = &self.region {
|
||||
|
||||
122
src/commands.rs
122
src/commands.rs
@@ -2,12 +2,7 @@ use alloy::primitives::Address as H160;
|
||||
use clap::{Command, Parser, Subcommand};
|
||||
use clap_complete::{generate, Generator, Shell};
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::{
|
||||
conversion::{FromPyObject, PyTryFrom},
|
||||
exceptions::PyValueError,
|
||||
prelude::*,
|
||||
types::PyString,
|
||||
};
|
||||
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
@@ -109,8 +104,8 @@ impl IntoPy<PyObject> for TranscriptType {
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains TranscriptType from PyObject (Required for TranscriptType to be compatible with Python)
|
||||
impl<'source> FromPyObject<'source> for TranscriptType {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
|
||||
let trystr = String::extract_bound(ob)?;
|
||||
let strval = trystr.to_string();
|
||||
match strval.to_lowercase().as_str() {
|
||||
"poseidon" => Ok(TranscriptType::Poseidon),
|
||||
@@ -196,9 +191,7 @@ pub enum ContractType {
|
||||
|
||||
impl Default for ContractType {
|
||||
fn default() -> Self {
|
||||
ContractType::Verifier {
|
||||
reusable: false,
|
||||
}
|
||||
ContractType::Verifier { reusable: false }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -210,10 +203,8 @@ impl std::fmt::Display for ContractType {
|
||||
match self {
|
||||
ContractType::Verifier { reusable: true } => {
|
||||
"verifier/reusable".to_string()
|
||||
},
|
||||
ContractType::Verifier {
|
||||
reusable: false,
|
||||
} => "verifier".to_string(),
|
||||
}
|
||||
ContractType::Verifier { reusable: false } => "verifier".to_string(),
|
||||
ContractType::VerifyingKeyArtifact => "vka".to_string(),
|
||||
}
|
||||
)
|
||||
@@ -241,7 +232,6 @@ impl From<&str> for ContractType {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
|
||||
/// wrapper for H160 to make it easy to parse into flag vals
|
||||
pub struct H160Flag {
|
||||
@@ -287,9 +277,8 @@ impl IntoPy<PyObject> for CalibrationTarget {
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains CalibrationTarget from PyObject (Required for CalibrationTarget to be compatible with Python)
|
||||
impl<'source> FromPyObject<'source> for CalibrationTarget {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
|
||||
let strval = trystr.to_string();
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
|
||||
let strval = String::extract_bound(ob)?;
|
||||
match strval.to_lowercase().as_str() {
|
||||
"resources" => Ok(CalibrationTarget::Resources {
|
||||
col_overflow: false,
|
||||
@@ -306,12 +295,8 @@ impl<'source> FromPyObject<'source> for CalibrationTarget {
|
||||
impl IntoPy<PyObject> for ContractType {
|
||||
fn into_py(self, py: Python) -> PyObject {
|
||||
match self {
|
||||
ContractType::Verifier { reusable: true } => {
|
||||
"verifier/reusable".to_object(py)
|
||||
}
|
||||
ContractType::Verifier {
|
||||
reusable: false,
|
||||
} => "verifier".to_object(py),
|
||||
ContractType::Verifier { reusable: true } => "verifier/reusable".to_object(py),
|
||||
ContractType::Verifier { reusable: false } => "verifier".to_object(py),
|
||||
ContractType::VerifyingKeyArtifact => "vka".to_object(py),
|
||||
}
|
||||
}
|
||||
@@ -320,13 +305,10 @@ impl IntoPy<PyObject> for ContractType {
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains ContractType from PyObject (Required for ContractType to be compatible with Python)
|
||||
impl<'source> FromPyObject<'source> for ContractType {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
|
||||
let strval = trystr.to_string();
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
|
||||
let strval = String::extract_bound(ob)?;
|
||||
match strval.to_lowercase().as_str() {
|
||||
"verifier" => Ok(ContractType::Verifier {
|
||||
reusable: false,
|
||||
}),
|
||||
"verifier" => Ok(ContractType::Verifier { reusable: false }),
|
||||
"verifier/reusable" => Ok(ContractType::Verifier { reusable: true }),
|
||||
"vka" => Ok(ContractType::VerifyingKeyArtifact),
|
||||
_ => Err(PyValueError::new_err("Invalid value for ContractType")),
|
||||
@@ -341,45 +323,45 @@ pub fn get_styles() -> clap::builder::Styles {
|
||||
clap::builder::styling::Style::new()
|
||||
.bold()
|
||||
.underline()
|
||||
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Cyan))),
|
||||
.fg_color(Some(clap::builder::styling::Color::Ansi(
|
||||
clap::builder::styling::AnsiColor::Cyan,
|
||||
))),
|
||||
)
|
||||
.header(
|
||||
clap::builder::styling::Style::new()
|
||||
.bold()
|
||||
.underline()
|
||||
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Cyan))),
|
||||
)
|
||||
.literal(
|
||||
clap::builder::styling::Style::new().fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Magenta))),
|
||||
)
|
||||
.invalid(
|
||||
clap::builder::styling::Style::new()
|
||||
.bold()
|
||||
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red))),
|
||||
)
|
||||
.error(
|
||||
clap::builder::styling::Style::new()
|
||||
.bold()
|
||||
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red))),
|
||||
.fg_color(Some(clap::builder::styling::Color::Ansi(
|
||||
clap::builder::styling::AnsiColor::Cyan,
|
||||
))),
|
||||
)
|
||||
.literal(clap::builder::styling::Style::new().fg_color(Some(
|
||||
clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Magenta),
|
||||
)))
|
||||
.invalid(clap::builder::styling::Style::new().bold().fg_color(Some(
|
||||
clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red),
|
||||
)))
|
||||
.error(clap::builder::styling::Style::new().bold().fg_color(Some(
|
||||
clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Red),
|
||||
)))
|
||||
.valid(
|
||||
clap::builder::styling::Style::new()
|
||||
.bold()
|
||||
.underline()
|
||||
.fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::Green))),
|
||||
)
|
||||
.placeholder(
|
||||
clap::builder::styling::Style::new().fg_color(Some(clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::White))),
|
||||
.fg_color(Some(clap::builder::styling::Color::Ansi(
|
||||
clap::builder::styling::AnsiColor::Green,
|
||||
))),
|
||||
)
|
||||
.placeholder(clap::builder::styling::Style::new().fg_color(Some(
|
||||
clap::builder::styling::Color::Ansi(clap::builder::styling::AnsiColor::White),
|
||||
)))
|
||||
}
|
||||
|
||||
|
||||
/// Print completions for the given generator
|
||||
pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
|
||||
generate(gen, cmd, cmd.get_name().to_string(), &mut std::io::stdout());
|
||||
}
|
||||
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
#[command(author, about, long_about = None)]
|
||||
@@ -393,7 +375,6 @@ pub struct Cli {
|
||||
pub command: Option<Commands>,
|
||||
}
|
||||
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
|
||||
pub enum Commands {
|
||||
@@ -443,7 +424,7 @@ pub enum Commands {
|
||||
},
|
||||
|
||||
/// Calibrates the proving scale, lookup bits and logrows from a circuit settings file.
|
||||
CalibrateSettings {
|
||||
CalibrateSettings {
|
||||
/// The path to the .json calibration data file.
|
||||
#[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE, value_hint = clap::ValueHint::FilePath)]
|
||||
data: Option<PathBuf>,
|
||||
@@ -490,7 +471,7 @@ pub enum Commands {
|
||||
commitment: Option<Commitments>,
|
||||
},
|
||||
|
||||
/// Gets an SRS from a circuit settings file.
|
||||
/// Gets an SRS from a circuit settings file.
|
||||
#[command(name = "get-srs")]
|
||||
GetSrs {
|
||||
/// The path to output the desired srs file, if set to None will save to ~/.ezkl/srs
|
||||
@@ -575,7 +556,7 @@ pub enum Commands {
|
||||
require_equals = true,
|
||||
num_args = 0..=1,
|
||||
default_value_t = TranscriptType::default(),
|
||||
value_enum,
|
||||
value_enum,
|
||||
value_hint = clap::ValueHint::Other
|
||||
)]
|
||||
transcript: TranscriptType,
|
||||
@@ -625,7 +606,7 @@ pub enum Commands {
|
||||
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)]
|
||||
disable_selector_compression: Option<bool>,
|
||||
},
|
||||
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
|
||||
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
|
||||
#[command(arg_required_else_help = true)]
|
||||
SetupTestEvmData {
|
||||
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
@@ -649,7 +630,7 @@ pub enum Commands {
|
||||
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
|
||||
output_source: TestDataSource,
|
||||
},
|
||||
/// The Data Attestation Verifier contract stores the account calls to fetch data to feed into ezkl. This call data can be updated by an admin account. This tests that admin account is able to update this call data.
|
||||
/// The Data Attestation Verifier contract stores the account calls to fetch data to feed into ezkl. This call data can be updated by an admin account. This tests that admin account is able to update this call data.
|
||||
#[command(arg_required_else_help = true)]
|
||||
TestUpdateAccountCalls {
|
||||
/// The path to the verifier contract's address
|
||||
@@ -662,7 +643,7 @@ pub enum Commands {
|
||||
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
|
||||
rpc_url: Option<String>,
|
||||
},
|
||||
/// Swaps the positions in the transcript that correspond to commitments
|
||||
/// Swaps the positions in the transcript that correspond to commitments
|
||||
SwapProofCommitments {
|
||||
/// The path to the proof file
|
||||
#[arg(short = 'P', long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)]
|
||||
@@ -672,7 +653,7 @@ pub enum Commands {
|
||||
witness_path: Option<PathBuf>,
|
||||
},
|
||||
|
||||
/// Loads model, data, and creates proof
|
||||
/// Loads model, data, and creates proof
|
||||
Prove {
|
||||
/// The path to the .json witness file (generated using the gen-witness command)
|
||||
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)]
|
||||
@@ -694,7 +675,7 @@ pub enum Commands {
|
||||
require_equals = true,
|
||||
num_args = 0..=1,
|
||||
default_value_t = ProofType::Single,
|
||||
value_enum,
|
||||
value_enum,
|
||||
value_hint = clap::ValueHint::Other
|
||||
)]
|
||||
proof_type: ProofType,
|
||||
@@ -702,7 +683,7 @@ pub enum Commands {
|
||||
#[arg(long, default_value = DEFAULT_CHECKMODE, value_hint = clap::ValueHint::Other)]
|
||||
check_mode: Option<CheckMode>,
|
||||
},
|
||||
/// Encodes a proof into evm calldata
|
||||
/// Encodes a proof into evm calldata
|
||||
#[command(name = "encode-evm-calldata")]
|
||||
EncodeEvmCalldata {
|
||||
/// The path to the proof file (generated using the prove command)
|
||||
@@ -715,7 +696,7 @@ pub enum Commands {
|
||||
#[arg(long, value_hint = clap::ValueHint::Other)]
|
||||
addr_vk: Option<H160Flag>,
|
||||
},
|
||||
/// Creates an Evm verifier for a single proof
|
||||
/// Creates an Evm verifier for a single proof
|
||||
#[command(name = "create-evm-verifier")]
|
||||
CreateEvmVerifier {
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
@@ -737,7 +718,7 @@ pub enum Commands {
|
||||
#[arg(long, default_value = DEFAULT_RENDER_REUSABLE, action = clap::ArgAction::SetTrue)]
|
||||
reusable: Option<bool>,
|
||||
},
|
||||
/// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier
|
||||
/// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier
|
||||
#[command(name = "create-evm-vka")]
|
||||
CreateEvmVKArtifact {
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
@@ -756,7 +737,7 @@ pub enum Commands {
|
||||
#[arg(long, default_value = DEFAULT_VK_ABI, value_hint = clap::ValueHint::FilePath)]
|
||||
abi_path: Option<PathBuf>,
|
||||
},
|
||||
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
|
||||
/// Creates an Evm verifier that attests to on-chain inputs for a single proof
|
||||
#[command(name = "create-evm-da")]
|
||||
CreateEvmDataAttestation {
|
||||
/// The path to load circuit settings .json file from (generated using the gen-settings command)
|
||||
@@ -780,7 +761,7 @@ pub enum Commands {
|
||||
witness: Option<PathBuf>,
|
||||
},
|
||||
|
||||
/// Creates an Evm verifier for an aggregate proof
|
||||
/// Creates an Evm verifier for an aggregate proof
|
||||
#[command(name = "create-evm-verifier-aggr")]
|
||||
CreateEvmVerifierAggr {
|
||||
/// The path to SRS, if None will use ~/.ezkl/srs/kzg{logrows}.srs
|
||||
@@ -844,7 +825,7 @@ pub enum Commands {
|
||||
#[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)]
|
||||
commitment: Option<Commitments>,
|
||||
},
|
||||
/// Deploys an evm contract (verifier, reusable verifier, or vk artifact) that is generated by ezkl
|
||||
/// Deploys an evm contract (verifier, reusable verifier, or vk artifact) that is generated by ezkl
|
||||
DeployEvm {
|
||||
/// The path to the Solidity code (generated using the create-evm-verifier command)
|
||||
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
|
||||
@@ -865,7 +846,7 @@ pub enum Commands {
|
||||
#[arg(long = "contract-type", short = 'C', default_value = DEFAULT_CONTRACT_DEPLOYMENT_TYPE, value_hint = clap::ValueHint::Other)]
|
||||
contract: ContractType,
|
||||
},
|
||||
/// Deploys an evm verifier that allows for data attestation
|
||||
/// Deploys an evm verifier that allows for data attestation
|
||||
#[command(name = "deploy-evm-da")]
|
||||
DeployEvmDataAttestation {
|
||||
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
|
||||
@@ -890,7 +871,7 @@ pub enum Commands {
|
||||
#[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
|
||||
private_key: Option<String>,
|
||||
},
|
||||
/// Verifies a proof using a local Evm executor, returning accept or reject
|
||||
/// Verifies a proof using a local Evm executor, returning accept or reject
|
||||
#[command(name = "verify-evm")]
|
||||
VerifyEvm {
|
||||
/// The path to the proof file (generated using the prove command)
|
||||
@@ -918,7 +899,6 @@ pub enum Commands {
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
impl Commands {
|
||||
/// Converts the commands to a json string
|
||||
pub fn as_json(&self) -> String {
|
||||
@@ -929,4 +909,4 @@ impl Commands {
|
||||
pub fn from_json(json: &str) -> Self {
|
||||
serde_json::from_str(json).unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -712,7 +712,8 @@ impl ToPyObject for DataSource {
|
||||
DataSource::OnChain(source) => {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("rpc_url", &source.rpc).unwrap();
|
||||
dict.set_item("calls_to_accounts", &source.calls).unwrap();
|
||||
dict.set_item("calls_to_accounts", &source.calls.to_object(py))
|
||||
.unwrap();
|
||||
dict.to_object(py)
|
||||
}
|
||||
DataSource::DB(source) => {
|
||||
|
||||
@@ -60,7 +60,10 @@ use pyo3::prelude::*;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::types::PyDict;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::types::PyDictMethods;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::ops::Deref;
|
||||
pub use utilities::*;
|
||||
@@ -343,10 +346,10 @@ impl ToPyObject for GraphWitness {
|
||||
if let Some(processed_inputs) = &self.processed_inputs {
|
||||
//poseidon_hash
|
||||
if let Some(processed_inputs_poseidon_hash) = &processed_inputs.poseidon_hash {
|
||||
insert_poseidon_hash_pydict(dict_inputs, processed_inputs_poseidon_hash).unwrap();
|
||||
insert_poseidon_hash_pydict(&dict_inputs, processed_inputs_poseidon_hash).unwrap();
|
||||
}
|
||||
if let Some(processed_inputs_polycommit) = &processed_inputs.polycommit {
|
||||
insert_polycommit_pydict(dict_inputs, processed_inputs_polycommit).unwrap();
|
||||
insert_polycommit_pydict(&dict_inputs, processed_inputs_polycommit).unwrap();
|
||||
}
|
||||
|
||||
dict.set_item("processed_inputs", dict_inputs).unwrap();
|
||||
@@ -354,10 +357,10 @@ impl ToPyObject for GraphWitness {
|
||||
|
||||
if let Some(processed_params) = &self.processed_params {
|
||||
if let Some(processed_params_poseidon_hash) = &processed_params.poseidon_hash {
|
||||
insert_poseidon_hash_pydict(dict_params, processed_params_poseidon_hash).unwrap();
|
||||
insert_poseidon_hash_pydict(&dict_params, processed_params_poseidon_hash).unwrap();
|
||||
}
|
||||
if let Some(processed_params_polycommit) = &processed_params.polycommit {
|
||||
insert_polycommit_pydict(dict_inputs, processed_params_polycommit).unwrap();
|
||||
insert_polycommit_pydict(&dict_params, processed_params_polycommit).unwrap();
|
||||
}
|
||||
|
||||
dict.set_item("processed_params", dict_params).unwrap();
|
||||
@@ -365,10 +368,11 @@ impl ToPyObject for GraphWitness {
|
||||
|
||||
if let Some(processed_outputs) = &self.processed_outputs {
|
||||
if let Some(processed_outputs_poseidon_hash) = &processed_outputs.poseidon_hash {
|
||||
insert_poseidon_hash_pydict(dict_outputs, processed_outputs_poseidon_hash).unwrap();
|
||||
insert_poseidon_hash_pydict(&dict_outputs, processed_outputs_poseidon_hash)
|
||||
.unwrap();
|
||||
}
|
||||
if let Some(processed_outputs_polycommit) = &processed_outputs.polycommit {
|
||||
insert_polycommit_pydict(dict_inputs, processed_outputs_polycommit).unwrap();
|
||||
insert_polycommit_pydict(&dict_outputs, processed_outputs_polycommit).unwrap();
|
||||
}
|
||||
|
||||
dict.set_item("processed_outputs", dict_outputs).unwrap();
|
||||
@@ -379,7 +383,10 @@ impl ToPyObject for GraphWitness {
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec<Fp>) -> Result<(), PyErr> {
|
||||
fn insert_poseidon_hash_pydict(
|
||||
pydict: &Bound<'_, PyDict>,
|
||||
poseidon_hash: &Vec<Fp>,
|
||||
) -> Result<(), PyErr> {
|
||||
let poseidon_hash: Vec<String> = poseidon_hash.iter().map(field_to_string).collect();
|
||||
pydict.set_item("poseidon_hash", poseidon_hash)?;
|
||||
|
||||
@@ -387,7 +394,10 @@ fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec<Fp>) -> Resu
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
fn insert_polycommit_pydict(pydict: &PyDict, commits: &Vec<Vec<G1Affine>>) -> Result<(), PyErr> {
|
||||
fn insert_polycommit_pydict(
|
||||
pydict: &Bound<'_, PyDict>,
|
||||
commits: &Vec<Vec<G1Affine>>,
|
||||
) -> Result<(), PyErr> {
|
||||
use crate::bindings::python::PyG1Affine;
|
||||
let poseidon_hash: Vec<Vec<PyG1Affine>> = commits
|
||||
.iter()
|
||||
|
||||
@@ -656,7 +656,7 @@ impl Model {
|
||||
|
||||
let mut symbol_values = SymbolValues::default();
|
||||
for (symbol, value) in run_args.variables.iter() {
|
||||
let symbol = model.symbol_table.sym(symbol);
|
||||
let symbol = model.symbols.sym(symbol);
|
||||
symbol_values = symbol_values.with(&symbol, *value as i64);
|
||||
debug!("set {} to {}", symbol, value);
|
||||
}
|
||||
@@ -1199,9 +1199,9 @@ impl Model {
|
||||
// Then number of columns in the circuits
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
region.debug_report();
|
||||
debug!("input indices: {:?}", node.inputs());
|
||||
debug!("output scales: {:?}", node.out_scales());
|
||||
debug!(
|
||||
trace!("input indices: {:?}", node.inputs());
|
||||
trace!("output scales: {:?}", node.out_scales());
|
||||
trace!(
|
||||
"input scales: {:?}",
|
||||
node.inputs()
|
||||
.iter()
|
||||
@@ -1220,8 +1220,8 @@ impl Model {
|
||||
// we re-assign inputs, always from the 0 outlet
|
||||
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
|
||||
};
|
||||
debug!("output dims: {:?}", node.out_dims());
|
||||
debug!(
|
||||
trace!("output dims: {:?}", node.out_dims());
|
||||
trace!(
|
||||
"input dims {:?}",
|
||||
values.iter().map(|v| v.dims()).collect_vec()
|
||||
);
|
||||
|
||||
@@ -1007,21 +1007,21 @@ pub fn new_op_from_onnx(
|
||||
op
|
||||
}
|
||||
"Iff" => SupportedOp::Linear(PolyOp::Iff),
|
||||
"Less" => {
|
||||
"<" => {
|
||||
if inputs.len() == 2 {
|
||||
SupportedOp::Hybrid(HybridOp::Less)
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "less".to_string()));
|
||||
}
|
||||
}
|
||||
"LessEqual" => {
|
||||
"<=" => {
|
||||
if inputs.len() == 2 {
|
||||
SupportedOp::Hybrid(HybridOp::LessEqual)
|
||||
} else {
|
||||
return Err(GraphError::InvalidDims(idx, "less equal".to_string()));
|
||||
}
|
||||
}
|
||||
"Greater" => {
|
||||
">" => {
|
||||
// Extract the slope layer hyperparams
|
||||
if inputs.len() == 2 {
|
||||
SupportedOp::Hybrid(HybridOp::Greater)
|
||||
@@ -1029,7 +1029,7 @@ pub fn new_op_from_onnx(
|
||||
return Err(GraphError::InvalidDims(idx, "greater".to_string()));
|
||||
}
|
||||
}
|
||||
"GreaterEqual" => {
|
||||
">=" => {
|
||||
// Extract the slope layer hyperparams
|
||||
if inputs.len() == 2 {
|
||||
SupportedOp::Hybrid(HybridOp::GreaterEqual)
|
||||
@@ -1155,6 +1155,41 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
}
|
||||
"Div" => {
|
||||
let const_idx = inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, n)| n.is_constant())
|
||||
.map(|(i, _)| i)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if const_idx.len() > 1 {
|
||||
return Err(GraphError::InvalidDims(idx, "div".to_string()));
|
||||
}
|
||||
|
||||
let const_idx = const_idx[0];
|
||||
|
||||
if const_idx != 1 {
|
||||
unimplemented!("only support div with constant as second input")
|
||||
}
|
||||
|
||||
if let Some(c) = inputs[const_idx].opkind().get_mutable_constant() {
|
||||
if c.raw_values.len() == 1 && c.raw_values[0] != 0. {
|
||||
inputs[const_idx].decrement_use();
|
||||
deleted_indices.push(const_idx);
|
||||
// get the non constant index
|
||||
let denom = c.raw_values[0];
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::Div {
|
||||
denom: denom.into(),
|
||||
})
|
||||
} else {
|
||||
unimplemented!("only support non zero divisors of size 1")
|
||||
}
|
||||
} else {
|
||||
unimplemented!("only support div with constant as second input")
|
||||
}
|
||||
}
|
||||
"Cube" => SupportedOp::Linear(PolyOp::Pow(3)),
|
||||
"Square" => SupportedOp::Linear(PolyOp::Pow(2)),
|
||||
"Conv" => {
|
||||
@@ -1215,7 +1250,7 @@ pub fn new_op_from_onnx(
|
||||
"And" => SupportedOp::Linear(PolyOp::And),
|
||||
"Or" => SupportedOp::Linear(PolyOp::Or),
|
||||
"Xor" => SupportedOp::Linear(PolyOp::Xor),
|
||||
"Equals" => SupportedOp::Hybrid(HybridOp::Equals),
|
||||
"==" => SupportedOp::Hybrid(HybridOp::Equals),
|
||||
"Deconv" => {
|
||||
let deconv_node: &Deconv = match node.op().downcast_ref::<Deconv>() {
|
||||
Some(b) => b,
|
||||
|
||||
@@ -9,8 +9,7 @@ use itertools::Itertools;
|
||||
use log::debug;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::{
|
||||
exceptions::PyValueError, types::PyString, FromPyObject, IntoPy, PyAny, PyObject, PyResult,
|
||||
PyTryFrom, Python, ToPyObject,
|
||||
exceptions::PyValueError, FromPyObject, IntoPy, PyObject, PyResult, Python, ToPyObject,
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -137,10 +136,8 @@ impl IntoPy<PyObject> for Visibility {
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains Visibility from PyObject (Required for Visibility to be compatible with Python)
|
||||
impl<'source> FromPyObject<'source> for Visibility {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
let trystr = <PyString as PyTryFrom>::try_from(ob)?;
|
||||
let strval = trystr.to_string();
|
||||
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> PyResult<Self> {
|
||||
let strval = String::extract_bound(ob)?;
|
||||
let strval = strval.as_str();
|
||||
|
||||
if strval.contains("hashed/private") {
|
||||
|
||||
@@ -120,7 +120,7 @@ pub fn version() -> &'static str {
|
||||
}
|
||||
}
|
||||
|
||||
/// Bindings managment
|
||||
/// Bindings management
|
||||
#[cfg(any(
|
||||
feature = "ios-bindings",
|
||||
all(target_arch = "wasm32", target_os = "unknown"),
|
||||
|
||||
@@ -46,6 +46,9 @@ use thiserror::Error as thisError;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::types::PyDictMethods;
|
||||
|
||||
use halo2curves::bn256::{Bn256, Fr, G1Affine};
|
||||
|
||||
fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat {
|
||||
@@ -116,9 +119,8 @@ impl ToPyObject for ProofType {
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains StrategyType from PyObject (Required for StrategyType to be compatible with Python)
|
||||
impl<'source> pyo3::FromPyObject<'source> for ProofType {
|
||||
fn extract(ob: &'source pyo3::PyAny) -> pyo3::PyResult<Self> {
|
||||
let trystr = <pyo3::types::PyString as pyo3::PyTryFrom>::try_from(ob)?;
|
||||
let strval = trystr.to_string();
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> pyo3::PyResult<Self> {
|
||||
let strval = String::extract_bound(ob)?;
|
||||
match strval.to_lowercase().as_str() {
|
||||
"single" => Ok(ProofType::Single),
|
||||
"for-aggr" => Ok(ProofType::ForAggr),
|
||||
@@ -174,9 +176,8 @@ impl pyo3::IntoPy<PyObject> for StrategyType {
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Obtains StrategyType from PyObject (Required for StrategyType to be compatible with Python)
|
||||
impl<'source> pyo3::FromPyObject<'source> for StrategyType {
|
||||
fn extract(ob: &'source pyo3::PyAny) -> pyo3::PyResult<Self> {
|
||||
let trystr = <pyo3::types::PyString as pyo3::PyTryFrom>::try_from(ob)?;
|
||||
let strval = trystr.to_string();
|
||||
fn extract_bound(ob: &pyo3::Bound<'source, pyo3::PyAny>) -> pyo3::PyResult<Self> {
|
||||
let strval = String::extract_bound(ob)?;
|
||||
match strval.to_lowercase().as_str() {
|
||||
"single" => Ok(StrategyType::Single),
|
||||
"accum" => Ok(StrategyType::Accum),
|
||||
@@ -235,7 +236,7 @@ impl ToPyObject for TranscriptType {
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
///
|
||||
pub fn g1affine_to_pydict(g1affine_dict: &PyDict, g1affine: &G1Affine) {
|
||||
pub fn g1affine_to_pydict(g1affine_dict: &pyo3::Bound<'_, PyDict>, g1affine: &G1Affine) {
|
||||
let g1affine_x = field_to_string(&g1affine.x);
|
||||
let g1affine_y = field_to_string(&g1affine.y);
|
||||
g1affine_dict.set_item("x", g1affine_x).unwrap();
|
||||
@@ -246,7 +247,7 @@ pub fn g1affine_to_pydict(g1affine_dict: &PyDict, g1affine: &G1Affine) {
|
||||
use halo2curves::bn256::G1;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
///
|
||||
pub fn g1_to_pydict(g1_dict: &PyDict, g1: &G1) {
|
||||
pub fn g1_to_pydict(g1_dict: &pyo3::Bound<'_, PyDict>, g1: &G1) {
|
||||
let g1_x = field_to_string(&g1.x);
|
||||
let g1_y = field_to_string(&g1.y);
|
||||
let g1_z = field_to_string(&g1.z);
|
||||
@@ -337,7 +338,7 @@ where
|
||||
dict.set_item("instances", field_elems).unwrap();
|
||||
let hex_proof = hex::encode(&self.proof);
|
||||
dict.set_item("proof", format!("0x{}", hex_proof)).unwrap();
|
||||
dict.set_item("transcript_type", self.transcript_type)
|
||||
dict.set_item("transcript_type", self.transcript_type.to_object(py))
|
||||
.unwrap();
|
||||
dict.to_object(py)
|
||||
}
|
||||
|
||||
@@ -1109,6 +1109,13 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
pub fn expand(&self, shape: &[usize]) -> Result<Self, TensorError> {
|
||||
// if both have length 1 then we can just return the tensor
|
||||
if self.dims().iter().product::<usize>() == 1 && shape.iter().product::<usize>() == 1 {
|
||||
let mut output = self.clone();
|
||||
output.reshape(shape)?;
|
||||
return Ok(output);
|
||||
}
|
||||
|
||||
if self.dims().len() > shape.len() {
|
||||
return Err(TensorError::DimError(format!(
|
||||
"Cannot expand {:?} to the smaller shape {:?}",
|
||||
|
||||
@@ -1050,6 +1050,7 @@ pub fn scatter_nd<T: TensorType + Send + Sync>(
|
||||
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
let index_val = index.get_slice(&slice)?;
|
||||
let index_slice = index_val.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
|
||||
let src_val = src.get_slice(&slice)?;
|
||||
output.set_slice(&index_slice, &src_val)?;
|
||||
Ok::<_, TensorError>(())
|
||||
|
||||
@@ -492,7 +492,7 @@ mod native_tests {
|
||||
#[cfg(feature="icicle")]
|
||||
seq!(N in 0..=2 {
|
||||
#(#[test_case(TESTS_AGGR[N])])*
|
||||
fn aggr_prove_and_verify_(test: &str) {
|
||||
fn kzg_aggr_prove_and_verify_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(test_dir.path().to_str().unwrap(), test);
|
||||
@@ -636,7 +636,7 @@ mod native_tests {
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_large_batch_public_outputs_(test: &str) {
|
||||
// currently variable output rank is not supported in ONNX
|
||||
if test != "gather_nd" && test != "lstm_large" && test != "lstm_medium" {
|
||||
if test != "gather_nd" && test != "lstm_large" && test != "lstm_medium" && test != "scatter_nd" {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
|
||||
@@ -68,6 +68,8 @@ mod py_tests {
|
||||
"install",
|
||||
"torch-geometric==2.5.2",
|
||||
"torch==2.2.2",
|
||||
"datasets==3.2.0",
|
||||
"torchtext==0.17.2",
|
||||
"torchvision==0.17.2",
|
||||
"pandas==2.2.1",
|
||||
"numpy==1.26.4",
|
||||
@@ -189,6 +191,27 @@ mod py_tests {
|
||||
anvil_child.kill().unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
#[test]
|
||||
fn neural_bag_of_words_notebook() {
|
||||
crate::py_tests::init_binary();
|
||||
let test_dir: TempDir = TempDir::new("neural_bow").unwrap();
|
||||
let path = test_dir.path().to_str().unwrap();
|
||||
crate::py_tests::mv_test_(path, "neural_bow.ipynb");
|
||||
run_notebook(path, "neural_bow.ipynb");
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn felt_conversion_test_notebook() {
|
||||
crate::py_tests::init_binary();
|
||||
let test_dir: TempDir = TempDir::new("felt_conversion_test").unwrap();
|
||||
let path = test_dir.path().to_str().unwrap();
|
||||
crate::py_tests::mv_test_(path, "felt_conversion_test.ipynb");
|
||||
run_notebook(path, "felt_conversion_test.ipynb");
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn voice_notebook_() {
|
||||
crate::py_tests::init_binary();
|
||||
|
||||
Reference in New Issue
Block a user