Compare commits

..

1 Commits

Author SHA1 Message Date
dante
5639d36097 chore: verify aggr wasm unit test (#760) 2024-04-01 20:54:20 +01:00
21 changed files with 3929 additions and 117 deletions

View File

@@ -25,7 +25,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -70,7 +70,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: ${{ matrix.target }}
- name: Set Cargo.toml version to match github tag
@@ -115,7 +115,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -176,7 +176,7 @@ jobs:
# - uses: actions/checkout@v4
# - uses: actions/setup-python@v4
# with:
# python-version: 3.7
# python-version: 3.12
# - name: Install cross-compilation tools for aarch64
# if: matrix.target == 'aarch64'
@@ -228,7 +228,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -263,7 +263,7 @@ jobs:
apk add py3-pip
pip3 install -U pip
python3 -m venv .venv
source .venv/bin/activate
source .venv/bin/activate
pip3 install ezkl --no-index --find-links /io/dist/ --force-reinstall
python3 -c "import ezkl"
@@ -287,7 +287,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
- name: Set Cargo.toml version to match github tag
shell: bash

View File

@@ -557,12 +557,14 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Install cmake
run: sudo apt-get install -y cmake
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Setup Virtual Env and Install python dependencies
@@ -581,7 +583,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
@@ -612,7 +614,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
@@ -630,6 +632,8 @@ jobs:
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_
# - name: authenticate-kaggle-cli
# shell: bash
# env:
@@ -645,7 +649,5 @@ jobs:
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_
# - name: Postgres tutorials
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1

3
.gitignore vendored
View File

@@ -48,4 +48,5 @@ node_modules
/dist
timingData.json
!tests/wasm/pk.key
!tests/wasm/vk.key
!tests/wasm/vk.key
!tests/wasm/vk_aggr.key

4
Cargo.lock generated
View File

@@ -4601,9 +4601,9 @@ dependencies = [
[[package]]
name = "serde-wasm-bindgen"
version = "0.4.5"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3b4c031cd0d9014307d82b8abf653c0290fbdaeb4c02d00c63cf52f728628bf"
checksum = "8302e169f0eddcc139c70f139d19d6467353af16f9fce27e8c30158036a1e16b"
dependencies = [
"js-sys",
"serde",

View File

@@ -95,10 +95,10 @@ getrandom = { version = "0.2.8", features = ["js"] }
instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] }
[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies]
wasm-bindgen-rayon = { version = "1.0", optional = true }
wasm-bindgen-test = "0.3.34"
serde-wasm-bindgen = "0.4"
wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"] }
wasm-bindgen-rayon = { version = "1.2.1", optional = true }
wasm-bindgen-test = "0.3.42"
serde-wasm-bindgen = "0.6.5"
wasm-bindgen = { version = "0.2.92", features = ["serde-serialize"] }
console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"

View File

@@ -67,6 +67,7 @@
"model.add(Dense(128, activation='relu'))\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(10, activation='softmax'))\n",
"model.output_names=['output']\n",
"\n",
"\n",
"# Train the model as you like here (skipped for brevity)\n",

View File

@@ -38,7 +38,7 @@
"import logging\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras.optimizers.legacy import Adam\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras.layers import *\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.datasets import mnist\n",
@@ -71,9 +71,11 @@
},
"outputs": [],
"source": [
"opt = Adam()\n",
"ZDIM = 100\n",
"\n",
"opt = Adam()\n",
"\n",
"\n",
"# discriminator\n",
"# 0 if it's fake, 1 if it's real\n",
"x = in1 = Input((28,28))\n",
@@ -114,8 +116,11 @@
"\n",
"gm = Model(in1, x)\n",
"gm.compile('adam', 'mse')\n",
"gm.output_names=['output']\n",
"gm.summary()\n",
"\n",
"opt = Adam()\n",
"\n",
"# GAN\n",
"dm.trainable = False\n",
"x = dm(gm.output)\n",
@@ -415,7 +420,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long

View File

@@ -349,6 +349,8 @@
"z_log_var = Dense(ZDIM)(x)\n",
"z = Lambda(lambda x: x[0] + K.exp(0.5 * x[1]) * K.random_normal(shape=K.shape(x[0])))([z_mu, z_log_var])\n",
"dec = get_decoder()\n",
"dec.output_names=['output']\n",
"\n",
"out = dec(z)\n",
"\n",
"mse_loss = mse(Reshape((28*28,))(in1), Reshape((28*28,))(out)) * 28 * 28\n",

View File

@@ -61,11 +61,10 @@
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.ensemble import RandomForestClassifier as Rf\n",
"import sk2torch\n",
"import torch\n",
"import ezkl\n",
"import os\n",
"from torch import nn\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"\n",
@@ -77,28 +76,12 @@
"clr.fit(X_train, y_train)\n",
"\n",
"\n",
"trees = []\n",
"for tree in clr.estimators_:\n",
" trees.append(sk2torch.wrap(tree))\n",
"\n",
"\n",
"class RandomForest(nn.Module):\n",
" def __init__(self, trees):\n",
" super(RandomForest, self).__init__()\n",
" self.trees = nn.ModuleList(trees)\n",
"\n",
" def forward(self, x):\n",
" out = self.trees[0](x)\n",
" for tree in self.trees[1:]:\n",
" out += tree(x)\n",
" return out / len(self.trees)\n",
"\n",
"\n",
"torch_rf = RandomForest(trees)\n",
"torch_rf = convert(clr, 'torch')\n",
"# assert predictions from torch are = to sklearn \n",
"diffs = []\n",
"for i in range(len(X_test)):\n",
" torch_pred = torch_rf(torch.tensor(X_test[i].reshape(1, -1)))\n",
" torch_pred = torch_rf.predict(torch.tensor(X_test[i].reshape(1, -1)))\n",
" sk_pred = clr.predict(X_test[i].reshape(1, -1))\n",
" diffs.append(torch_pred[0].round() - sk_pred[0])\n",
"\n",
@@ -134,14 +117,12 @@
"\n",
"# export to onnx format\n",
"\n",
"torch_rf.eval()\n",
"\n",
"# Input to the model\n",
"shape = X_train.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=False)\n",
"torch_out = torch_rf(x)\n",
"torch_out = torch_rf.predict(x)\n",
"# Export the model\n",
"torch.onnx.export(torch_rf, # model being run\n",
"torch.onnx.export(torch_rf.model, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
@@ -158,7 +139,7 @@
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
@@ -321,7 +302,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -57,7 +57,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -119,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -163,7 +163,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -217,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -227,6 +227,10 @@
" self.length = self.compute_length(self.file_good)\n",
" self.data = self.load_data(self.file_good)\n",
"\n",
" def __iter__(self):\n",
" for i in range(len(self.data)):\n",
" yield self.data[i]\n",
"\n",
" def parse_json_object(self, line):\n",
" try:\n",
" return json.loads(line)\n",
@@ -749,7 +753,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -209,6 +209,11 @@
" self.length = self.compute_length(self.file_good, self.file_bad)\n",
" self.data = self.load_data(self.file_good, self.file_bad)\n",
"\n",
" def __iter__(self):\n",
" for i in range(len(self.data)):\n",
" yield self.data[i]\n",
"\n",
"\n",
" def parse_json_object(self, line):\n",
" try:\n",
" return json.loads(line)\n",
@@ -637,7 +642,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -1,14 +1,14 @@
attrs==22.2.0
exceptiongroup==1.1.1
importlib-metadata==6.1.0
attrs==23.2.0
exceptiongroup==1.2.0
importlib-metadata==7.1.0
iniconfig==2.0.0
maturin==1.0.1
packaging==23.0
pluggy==1.0.0
pytest==7.2.2
maturin==1.5.0
packaging==24.0
pluggy==1.4.0
pytest==8.1.1
tomli==2.0.1
typing-extensions==4.5.0
zipp==3.15.0
onnx==1.14.1
onnxruntime==1.14.1
numpy==1.21.6
typing-extensions==4.10.0
zipp==3.18.1
onnx==1.15.0
onnxruntime==1.17.1
numpy==1.26.4

View File

@@ -39,7 +39,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
};
use halo2curves::bn256::{self, Fr as Fp, G1Affine};
use halo2curves::ff::PrimeField;
use halo2curves::ff::{Field, PrimeField};
#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;
use log::{debug, error, trace, warn};
@@ -1451,7 +1451,8 @@ impl GraphCircuit {
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
struct CircuitSize {
/// The configuration for the graph circuit
pub struct CircuitSize {
num_instances: usize,
num_advice_columns: usize,
num_fixed: usize,
@@ -1461,7 +1462,8 @@ struct CircuitSize {
}
impl CircuitSize {
pub fn from_cs(cs: &ConstraintSystem<Fp>, logrows: u32) -> Self {
///
pub fn from_cs<F: Field>(cs: &ConstraintSystem<F>, logrows: u32) -> Self {
CircuitSize {
num_instances: cs.num_instance_columns(),
num_advice_columns: cs.num_advice_columns(),

View File

@@ -1,4 +1,8 @@
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::CircuitSize;
use crate::pfsys::{Snark, SnarkWitness};
#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;
use halo2_proofs::circuit::AssignedCell;
use halo2_proofs::plonk::{self};
use halo2_proofs::{
@@ -16,6 +20,8 @@ use halo2_wrong_ecc::{
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
#[cfg(not(target_arch = "wasm32"))]
use log::debug;
use log::trace;
use rand::rngs::OsRng;
use snark_verifier::loader::native::NativeLoader;
@@ -193,6 +199,23 @@ impl AggregationConfig {
let main_gate_config = MainGate::<F>::configure(meta);
let range_config =
RangeChip::<F>::configure(meta, &main_gate_config, composition_bits, overflow_bits);
#[cfg(not(target_arch = "wasm32"))]
{
let circuit_size = CircuitSize::from_cs(meta, 23);
// not wasm
debug!(
"circuit size: \n {}",
circuit_size
.as_json()
.unwrap()
.to_colored_json_auto()
.unwrap()
);
}
AggregationConfig {
main_gate_config,
range_config,

View File

@@ -419,6 +419,8 @@ pub fn verifyAggr(
let commit = Commitments::from_str(commitment).map_err(|e| JsError::new(&format!("{}", e)))?;
let orig_n = 1 << logrows;
let mut reader = std::io::BufReader::new(&srs[..]);
let result = match commit {
Commitments::KZG => {
@@ -433,7 +435,7 @@ pub fn verifyAggr(
KZGSingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, logrows),
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
@@ -442,7 +444,7 @@ pub fn verifyAggr(
KZGSingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, logrows)
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
@@ -458,7 +460,7 @@ pub fn verifyAggr(
IPASingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, logrows),
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierIPA<_>,
@@ -466,7 +468,7 @@ pub fn verifyAggr(
IPASingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, logrows)
>(&proof, &params, &vk, strategy, orig_n)
}
}
}

View File

@@ -56,33 +56,38 @@ mod py_tests {
// source .env/bin/activate
// pip install -r requirements.txt
// maturin develop --release --features python-bindings
// first install tf2onnx as it has protobuf conflict with onnx
let status = Command::new("pip")
.args(["install", "tf2onnx==1.16.1"])
.status()
.expect("failed to execute process");
assert!(status.success());
// now install torch, pandas, numpy, seaborn, jupyter
let status = Command::new("pip")
.args([
"install",
"torch-geometric==2.5.0",
"torch==2.0.1",
"torchvision==0.15.2",
"pandas==2.0.3",
"numpy==1.23",
"seaborn==0.12.2",
"jupyter==1.0.0",
"onnx==1.14.0",
"kaggle==1.5.15",
"py-solc-x==1.1.1",
"web3==6.5.0",
"librosa==0.10.0.post2",
"keras==2.12.0",
"tensorflow==2.12.0",
"tensorflow-datasets==4.9.3",
"tf2onnx==1.14.0",
"pytorch-lightning==2.0.6",
"torch-geometric==2.5.2",
"torch==2.2.2",
"torchvision==0.17.2",
"pandas==2.2.1",
"numpy==1.26.4",
"seaborn==0.13.2",
"notebook==7.1.2",
"nbconvert==7.16.3",
"onnx==1.16.0",
"kaggle==1.6.8",
"py-solc-x==2.0.2",
"web3==6.16.0",
"librosa==0.10.1",
"keras==3.1.1",
"tensorflow==2.16.1",
"tensorflow-datasets==4.9.4",
"pytorch-lightning==2.2.1",
"sk2torch==1.2.0",
"scikit-learn==1.3.1",
"xgboost==1.7.6",
"hummingbird-ml==0.4.9",
"lightgbm==4.0.0",
"scikit-learn==1.4.1.post1",
"xgboost==2.0.3",
"hummingbird-ml==0.4.11",
"lightgbm==4.3.0",
])
.status()
.expect("failed to execute process");

View File

@@ -11,7 +11,7 @@ mod wasm32 {
bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian,
feltToFloat, feltToInt, feltToLittleEndian, genPk, genVk, genWitness, inputValidation,
pkValidation, poseidonHash, proofValidation, prove, settingsValidation, srsValidation,
u8_array_to_u128_le, verify, vkValidation, witnessValidation,
u8_array_to_u128_le, verify, verifyAggr, vkValidation, witnessValidation,
};
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::{Fr, G1Affine};
@@ -27,10 +27,29 @@ mod wasm32 {
pub const NETWORK: &[u8] = include_bytes!("../tests/wasm/network.onnx");
pub const INPUT: &[u8] = include_bytes!("../tests/wasm/input.json");
pub const PROOF: &[u8] = include_bytes!("../tests/wasm/proof.json");
pub const PROOF_AGGR: &[u8] = include_bytes!("../tests/wasm/proof_aggr.json");
pub const SETTINGS: &[u8] = include_bytes!("../tests/wasm/settings.json");
pub const PK: &[u8] = include_bytes!("../tests/wasm/pk.key");
pub const VK: &[u8] = include_bytes!("../tests/wasm/vk.key");
pub const VK_AGGR: &[u8] = include_bytes!("../tests/wasm/vk_aggr.key");
pub const SRS: &[u8] = include_bytes!("../tests/wasm/kzg");
pub const SRS1: &[u8] = include_bytes!("../tests/wasm/kzg1.srs");
#[wasm_bindgen_test]
async fn can_verify_aggr() {
let value = verifyAggr(
wasm_bindgen::Clamped(PROOF_AGGR.to_vec()),
wasm_bindgen::Clamped(VK_AGGR.to_vec()),
21,
wasm_bindgen::Clamped(SRS1.to_vec()),
"kzg",
)
.map_err(|_| "failed")
.unwrap();
// should not fail
assert!(value);
}
#[wasm_bindgen_test]
async fn verify_encode_verifier_calldata() {

BIN
tests/wasm/kzg1.srs Normal file

Binary file not shown.

3075
tests/wasm/proof_aggr.json Normal file

File diff suppressed because one or more lines are too long

BIN
tests/wasm/vk_aggr.key Normal file

Binary file not shown.