Compare commits

...

4 Commits

Author SHA1 Message Date
dante
ff563e93a7 fix: bump python version (#761) 2024-04-02 17:08:26 +01:00
dante
5639d36097 chore: verify aggr wasm unit test (#760) 2024-04-01 20:54:20 +01:00
dante
4ec8d13082 chore: verify aggr in wasm (#758) 2024-03-29 23:28:20 +00:00
dante
12735aefd4 chore: reduce softmax recip DR (#756) 2024-03-27 01:14:29 +00:00
26 changed files with 4079 additions and 153 deletions

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag

View File

@@ -25,7 +25,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -70,7 +70,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: ${{ matrix.target }}
- name: Set Cargo.toml version to match github tag
@@ -115,7 +115,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -176,7 +176,7 @@ jobs:
# - uses: actions/checkout@v4
# - uses: actions/setup-python@v4
# with:
# python-version: 3.7
# python-version: 3.12
# - name: Install cross-compilation tools for aarch64
# if: matrix.target == 'aarch64'
@@ -228,7 +228,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -263,7 +263,7 @@ jobs:
apk add py3-pip
pip3 install -U pip
python3 -m venv .venv
source .venv/bin/activate
source .venv/bin/activate
pip3 install ezkl --no-index --find-links /io/dist/ --force-reinstall
python3 -c "import ezkl"
@@ -287,7 +287,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
- name: Set Cargo.toml version to match github tag
shell: bash

View File

@@ -184,7 +184,7 @@ jobs:
wasm32-tests:
runs-on: ubuntu-latest
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -207,7 +207,7 @@ jobs:
tutorial:
runs-on: ubuntu-latest
needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -224,7 +224,7 @@ jobs:
mock-proving-tests:
runs-on: non-gpu
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -281,7 +281,7 @@ jobs:
prove-and-verify-evm-tests:
runs-on: non-gpu
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -354,7 +354,7 @@ jobs:
prove-and-verify-tests:
runs-on: non-gpu
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -460,7 +460,7 @@ jobs:
prove-and-verify-mock-aggr-tests:
runs-on: self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -495,7 +495,7 @@ jobs:
prove-and-verify-aggr-tests:
runs-on: large-self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -512,7 +512,7 @@ jobs:
prove-and-verify-aggr-evm-tests:
runs-on: large-self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -557,16 +557,18 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Install cmake
run: sudo apt-get install -y cmake
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: Build python ezkl
@@ -576,12 +578,12 @@ jobs:
accuracy-measurement-tests:
runs-on: ubuntu-latest-32-cores
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
@@ -592,7 +594,7 @@ jobs:
crate: cargo-nextest
locked: true
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Div rebase
@@ -612,7 +614,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
@@ -626,10 +628,14 @@ jobs:
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: Install pip
run: python -m ensurepip --upgrade
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
# - name: authenticate-kaggle-cli
# shell: bash
# env:
@@ -645,7 +651,5 @@ jobs:
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_
# - name: Postgres tutorials
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1

3
.gitignore vendored
View File

@@ -48,4 +48,5 @@ node_modules
/dist
timingData.json
!tests/wasm/pk.key
!tests/wasm/vk.key
!tests/wasm/vk.key
!tests/wasm/vk_aggr.key

4
Cargo.lock generated
View File

@@ -4601,9 +4601,9 @@ dependencies = [
[[package]]
name = "serde-wasm-bindgen"
version = "0.4.5"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3b4c031cd0d9014307d82b8abf653c0290fbdaeb4c02d00c63cf52f728628bf"
checksum = "8302e169f0eddcc139c70f139d19d6467353af16f9fce27e8c30158036a1e16b"
dependencies = [
"js-sys",
"serde",

View File

@@ -95,10 +95,10 @@ getrandom = { version = "0.2.8", features = ["js"] }
instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] }
[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies]
wasm-bindgen-rayon = { version = "1.0", optional = true }
wasm-bindgen-test = "0.3.34"
serde-wasm-bindgen = "0.4"
wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"] }
wasm-bindgen-rayon = { version = "1.2.1", optional = true }
wasm-bindgen-test = "0.3.42"
serde-wasm-bindgen = "0.6.5"
wasm-bindgen = { version = "0.2.92", features = ["serde-serialize"] }
console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"

View File

@@ -67,6 +67,7 @@
"model.add(Dense(128, activation='relu'))\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(10, activation='softmax'))\n",
"model.output_names=['output']\n",
"\n",
"\n",
"# Train the model as you like here (skipped for brevity)\n",

View File

@@ -38,7 +38,7 @@
"import logging\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras.optimizers.legacy import Adam\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras.layers import *\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.datasets import mnist\n",
@@ -71,9 +71,11 @@
},
"outputs": [],
"source": [
"opt = Adam()\n",
"ZDIM = 100\n",
"\n",
"opt = Adam()\n",
"\n",
"\n",
"# discriminator\n",
"# 0 if it's fake, 1 if it's real\n",
"x = in1 = Input((28,28))\n",
@@ -114,8 +116,11 @@
"\n",
"gm = Model(in1, x)\n",
"gm.compile('adam', 'mse')\n",
"gm.output_names=['output']\n",
"gm.summary()\n",
"\n",
"opt = Adam()\n",
"\n",
"# GAN\n",
"dm.trainable = False\n",
"x = dm(gm.output)\n",
@@ -415,7 +420,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long

View File

@@ -349,6 +349,8 @@
"z_log_var = Dense(ZDIM)(x)\n",
"z = Lambda(lambda x: x[0] + K.exp(0.5 * x[1]) * K.random_normal(shape=K.shape(x[0])))([z_mu, z_log_var])\n",
"dec = get_decoder()\n",
"dec.output_names=['output']\n",
"\n",
"out = dec(z)\n",
"\n",
"mse_loss = mse(Reshape((28*28,))(in1), Reshape((28*28,))(out)) * 28 * 28\n",

View File

@@ -61,11 +61,10 @@
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.ensemble import RandomForestClassifier as Rf\n",
"import sk2torch\n",
"import torch\n",
"import ezkl\n",
"import os\n",
"from torch import nn\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"\n",
@@ -77,28 +76,12 @@
"clr.fit(X_train, y_train)\n",
"\n",
"\n",
"trees = []\n",
"for tree in clr.estimators_:\n",
" trees.append(sk2torch.wrap(tree))\n",
"\n",
"\n",
"class RandomForest(nn.Module):\n",
" def __init__(self, trees):\n",
" super(RandomForest, self).__init__()\n",
" self.trees = nn.ModuleList(trees)\n",
"\n",
" def forward(self, x):\n",
" out = self.trees[0](x)\n",
" for tree in self.trees[1:]:\n",
" out += tree(x)\n",
" return out / len(self.trees)\n",
"\n",
"\n",
"torch_rf = RandomForest(trees)\n",
"torch_rf = convert(clr, 'torch')\n",
"# assert predictions from torch are = to sklearn \n",
"diffs = []\n",
"for i in range(len(X_test)):\n",
" torch_pred = torch_rf(torch.tensor(X_test[i].reshape(1, -1)))\n",
" torch_pred = torch_rf.predict(torch.tensor(X_test[i].reshape(1, -1)))\n",
" sk_pred = clr.predict(X_test[i].reshape(1, -1))\n",
" diffs.append(torch_pred[0].round() - sk_pred[0])\n",
"\n",
@@ -134,14 +117,12 @@
"\n",
"# export to onnx format\n",
"\n",
"torch_rf.eval()\n",
"\n",
"# Input to the model\n",
"shape = X_train.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=False)\n",
"torch_out = torch_rf(x)\n",
"torch_out = torch_rf.predict(x)\n",
"# Export the model\n",
"torch.onnx.export(torch_rf, # model being run\n",
"torch.onnx.export(torch_rf.model, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
@@ -158,7 +139,7 @@
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
@@ -321,7 +302,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -57,7 +57,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -119,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -163,7 +163,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -217,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -227,6 +227,10 @@
" self.length = self.compute_length(self.file_good)\n",
" self.data = self.load_data(self.file_good)\n",
"\n",
" def __iter__(self):\n",
" for i in range(len(self.data)):\n",
" yield self.data[i]\n",
"\n",
" def parse_json_object(self, line):\n",
" try:\n",
" return json.loads(line)\n",
@@ -749,7 +753,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -209,6 +209,11 @@
" self.length = self.compute_length(self.file_good, self.file_bad)\n",
" self.data = self.load_data(self.file_good, self.file_bad)\n",
"\n",
" def __iter__(self):\n",
" for i in range(len(self.data)):\n",
" yield self.data[i]\n",
"\n",
"\n",
" def parse_json_object(self, line):\n",
" try:\n",
" return json.loads(line)\n",
@@ -637,7 +642,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -1,14 +1,14 @@
attrs==22.2.0
exceptiongroup==1.1.1
importlib-metadata==6.1.0
attrs==23.2.0
exceptiongroup==1.2.0
importlib-metadata==7.1.0
iniconfig==2.0.0
maturin==1.0.1
packaging==23.0
pluggy==1.0.0
pytest==7.2.2
maturin==1.5.0
packaging==24.0
pluggy==1.4.0
pytest==8.1.1
tomli==2.0.1
typing-extensions==4.5.0
zipp==3.15.0
onnx==1.14.1
onnxruntime==1.14.1
numpy==1.21.6
typing-extensions==4.10.0
zipp==3.18.1
onnx==1.15.0
onnxruntime==1.17.1
numpy==1.26.4

View File

@@ -568,10 +568,10 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let is_assigned = !input.any_unknowns()?;
let sorted = if is_assigned {
input
.get_int_evals()?
.iter()
.sorted_by(|a, b| a.cmp(b))
let mut int_evals = input.get_int_evals()?;
int_evals.par_sort_unstable_by(|a, b| a.cmp(b));
int_evals
.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
} else {
@@ -753,20 +753,28 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash
let _table_1 = region.assign_dynamic_lookup(&config.dynamic_lookups.tables[1], &table_1)?;
let table_len = table_0.len();
trace!("assigning tables took: {:?}", start.elapsed());
// now create a vartensor of constants for the dynamic lookup index
let table_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), table_len);
let _table_index =
region.assign_dynamic_lookup(&config.dynamic_lookups.tables[2], &table_index)?;
trace!("assigning table index took: {:?}", start.elapsed());
let lookup_0 = region.assign(&config.dynamic_lookups.inputs[0], &lookup_0)?;
let lookup_1 = region.assign(&config.dynamic_lookups.inputs[1], &lookup_1)?;
let lookup_len = lookup_0.len();
trace!("assigning lookups took: {:?}", start.elapsed());
// now set the lookup index
let lookup_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), lookup_len);
let _lookup_index = region.assign(&config.dynamic_lookups.inputs[2], &lookup_index)?;
trace!("assigning lookup index took: {:?}", start.elapsed());
if !region.is_dummy() {
(0..table_len)
.map(|i| {
@@ -3251,11 +3259,15 @@ pub(crate) fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>
input_scale: utils::F32,
output_scale: utils::F32,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// get the max then subtract it
let max_val = max(config, region, values)?;
// rebase the input to 0
let sub = pairwise(config, region, &[values[0].clone(), max_val], BaseOp::Sub)?;
// elementwise exponential
let ex = nonlinearity(
config,
region,
values,
&[sub],
&LookupOp::Exp { scale: input_scale },
)?;

View File

@@ -163,7 +163,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
///
pub fn update_constants(&mut self, constants: ConstantsMap<F>) {
self.assigned_constants.extend(constants.into_iter());
self.assigned_constants.extend(constants);
}
///
@@ -389,7 +389,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
shuffle_index.update(&local_reg.shuffle_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants.into_iter());
constants.extend(local_reg.assigned_constants);
res
})
@@ -574,8 +574,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self.assigned_constants,
)
} else {
let values_map = values.create_constants_map();
self.assigned_constants.extend(values_map);
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
}
@@ -599,8 +601,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self.assigned_constants,
)
} else {
let values_map = values.create_constants_map();
self.assigned_constants.extend(values_map);
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
}
@@ -630,9 +634,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self.assigned_constants,
)
} else {
let mut values_map = values.create_constants_map();
let inner_tensor = values.get_inner_tensor().unwrap();
let mut values_map = values.create_constants_map();
for o in ommissions {
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {

View File

@@ -911,7 +911,7 @@ pub(crate) fn calibrate(
let model = Model::from_run_args(&settings.run_args, &model_path)?;
let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
debug!("num of calibration batches: {}", chunks.len());
info!("num calibration batches: {}", chunks.len());
debug!("running onnx predictions...");
let original_predictions = Model::run_onnx_predictions(

View File

@@ -39,7 +39,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
};
use halo2curves::bn256::{self, Fr as Fp, G1Affine};
use halo2curves::ff::PrimeField;
use halo2curves::ff::{Field, PrimeField};
#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;
use log::{debug, error, trace, warn};
@@ -1451,7 +1451,8 @@ impl GraphCircuit {
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
struct CircuitSize {
/// The configuration for the graph circuit
pub struct CircuitSize {
num_instances: usize,
num_advice_columns: usize,
num_fixed: usize,
@@ -1461,7 +1462,8 @@ struct CircuitSize {
}
impl CircuitSize {
pub fn from_cs(cs: &ConstraintSystem<Fp>, logrows: u32) -> Self {
///
pub fn from_cs<F: Field>(cs: &ConstraintSystem<F>, logrows: u32) -> Self {
CircuitSize {
num_instances: cs.num_instance_columns(),
num_advice_columns: cs.num_advice_columns(),

View File

@@ -1,4 +1,8 @@
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::CircuitSize;
use crate::pfsys::{Snark, SnarkWitness};
#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;
use halo2_proofs::circuit::AssignedCell;
use halo2_proofs::plonk::{self};
use halo2_proofs::{
@@ -16,6 +20,8 @@ use halo2_wrong_ecc::{
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
#[cfg(not(target_arch = "wasm32"))]
use log::debug;
use log::trace;
use rand::rngs::OsRng;
use snark_verifier::loader::native::NativeLoader;
@@ -193,6 +199,23 @@ impl AggregationConfig {
let main_gate_config = MainGate::<F>::configure(meta);
let range_config =
RangeChip::<F>::configure(meta, &main_gate_config, composition_bits, overflow_bits);
#[cfg(not(target_arch = "wasm32"))]
{
let circuit_size = CircuitSize::from_cs(meta, 23);
// not wasm
debug!(
"circuit size: \n {}",
circuit_size
.as_json()
.unwrap()
.to_colored_json_auto()
.unwrap()
);
}
AggregationConfig {
main_gate_config,
range_config,

View File

@@ -448,25 +448,39 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
/// Returns the number of constants in the [ValTensor].
pub fn num_constants(&self) -> usize {
pub fn create_constants_map_iterator(
&self,
) -> core::iter::FilterMap<
core::slice::Iter<'_, ValType<F>>,
fn(&ValType<F>) -> Option<(F, ValType<F>)>,
> {
match self {
ValTensor::Value { inner, .. } => inner.iter().filter(|x| x.is_constant()).count(),
ValTensor::Instance { .. } => 0,
ValTensor::Value { inner, .. } => inner.iter().filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
None
}
}),
ValTensor::Instance { .. } => {
unreachable!("Instance tensors do not have constants")
}
}
}
/// Returns the number of constants in the [ValTensor].
pub fn create_constants_map(&self) -> ConstantsMap<F> {
match self {
ValTensor::Value { inner, .. } => {
let map = inner.iter().fold(ConstantsMap::new(), |mut acc, x| {
if let ValType::Constant(c) = x {
acc.insert(*c, x.clone());
ValTensor::Value { inner, .. } => inner
.par_iter()
.filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
None
}
acc
});
map
}
})
.collect(),
ValTensor::Instance { .. } => ConstantsMap::new(),
}
}

View File

@@ -8,12 +8,14 @@ use crate::graph::quantize_float;
use crate::graph::scale_to_multiplier;
use crate::graph::{GraphCircuit, GraphSettings};
use crate::pfsys::create_proof_circuit;
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
use crate::pfsys::verify_proof_circuit;
use crate::pfsys::TranscriptType;
use crate::tensor::TensorType;
use crate::CheckMode;
use crate::Commitments;
use console_error_panic_hook;
use halo2_proofs::plonk::*;
use halo2_proofs::poly::commitment::{CommitmentScheme, ParamsProver};
use halo2_proofs::poly::ipa::multiopen::{ProverIPA, VerifierIPA};
@@ -33,11 +35,10 @@ use halo2curves::bn256::{Bn256, Fr, G1Affine};
use halo2curves::ff::{FromUniformBytes, PrimeField};
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::str::FromStr;
use wasm_bindgen::prelude::*;
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
use console_error_panic_hook;
#[cfg(feature = "web")]
pub use wasm_bindgen_rayon::init_thread_pool;
@@ -395,6 +396,90 @@ pub fn verify(
}
}
#[wasm_bindgen]
#[allow(non_snake_case)]
/// Verify aggregate proof in browser using wasm
pub fn verifyAggr(
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
vk: wasm_bindgen::Clamped<Vec<u8>>,
logrows: u64,
srs: wasm_bindgen::Clamped<Vec<u8>>,
commitment: &str,
) -> Result<bool, JsError> {
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof_js[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
let mut reader = std::io::BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
(),
)
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
let commit = Commitments::from_str(commitment).map_err(|e| JsError::new(&format!("{}", e)))?;
let orig_n = 1 << logrows;
let mut reader = std::io::BufReader::new(&srs[..]);
let result = match commit {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
Commitments::IPA => {
let params: ParamsIPA<_> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let strategy = IPASingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
};
match result {
Ok(_) => Ok(true),
Err(e) => Err(JsError::new(&format!("{}", e))),
}
}
/// Prove in browser using wasm
#[wasm_bindgen]
pub fn prove(

View File

@@ -56,33 +56,38 @@ mod py_tests {
// source .env/bin/activate
// pip install -r requirements.txt
// maturin develop --release --features python-bindings
// first install tf2onnx as it has protobuf conflict with onnx
let status = Command::new("pip")
.args(["install", "tf2onnx==1.16.1"])
.status()
.expect("failed to execute process");
assert!(status.success());
// now install torch, pandas, numpy, seaborn, jupyter
let status = Command::new("pip")
.args([
"install",
"torch-geometric==2.5.0",
"torch==2.0.1",
"torchvision==0.15.2",
"pandas==2.0.3",
"numpy==1.23",
"seaborn==0.12.2",
"jupyter==1.0.0",
"onnx==1.14.0",
"kaggle==1.5.15",
"py-solc-x==1.1.1",
"web3==6.5.0",
"librosa==0.10.0.post2",
"keras==2.12.0",
"tensorflow==2.12.0",
"tensorflow-datasets==4.9.3",
"tf2onnx==1.14.0",
"pytorch-lightning==2.0.6",
"torch-geometric==2.5.2",
"torch==2.2.2",
"torchvision==0.17.2",
"pandas==2.2.1",
"numpy==1.26.4",
"seaborn==0.13.2",
"notebook==7.1.2",
"nbconvert==7.16.3",
"onnx==1.16.0",
"kaggle==1.6.8",
"py-solc-x==2.0.2",
"web3==6.16.0",
"librosa==0.10.1",
"keras==3.1.1",
"tensorflow==2.16.1",
"tensorflow-datasets==4.9.4",
"pytorch-lightning==2.2.1",
"sk2torch==1.2.0",
"scikit-learn==1.3.1",
"xgboost==1.7.6",
"hummingbird-ml==0.4.9",
"lightgbm==4.0.0",
"scikit-learn==1.4.1.post1",
"xgboost==2.0.3",
"hummingbird-ml==0.4.11",
"lightgbm==4.3.0",
])
.status()
.expect("failed to execute process");

View File

@@ -11,7 +11,7 @@ mod wasm32 {
bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian,
feltToFloat, feltToInt, feltToLittleEndian, genPk, genVk, genWitness, inputValidation,
pkValidation, poseidonHash, proofValidation, prove, settingsValidation, srsValidation,
u8_array_to_u128_le, verify, vkValidation, witnessValidation,
u8_array_to_u128_le, verify, verifyAggr, vkValidation, witnessValidation,
};
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::{Fr, G1Affine};
@@ -27,10 +27,29 @@ mod wasm32 {
pub const NETWORK: &[u8] = include_bytes!("../tests/wasm/network.onnx");
pub const INPUT: &[u8] = include_bytes!("../tests/wasm/input.json");
pub const PROOF: &[u8] = include_bytes!("../tests/wasm/proof.json");
pub const PROOF_AGGR: &[u8] = include_bytes!("../tests/wasm/proof_aggr.json");
pub const SETTINGS: &[u8] = include_bytes!("../tests/wasm/settings.json");
pub const PK: &[u8] = include_bytes!("../tests/wasm/pk.key");
pub const VK: &[u8] = include_bytes!("../tests/wasm/vk.key");
pub const VK_AGGR: &[u8] = include_bytes!("../tests/wasm/vk_aggr.key");
pub const SRS: &[u8] = include_bytes!("../tests/wasm/kzg");
pub const SRS1: &[u8] = include_bytes!("../tests/wasm/kzg1.srs");
#[wasm_bindgen_test]
async fn can_verify_aggr() {
let value = verifyAggr(
wasm_bindgen::Clamped(PROOF_AGGR.to_vec()),
wasm_bindgen::Clamped(VK_AGGR.to_vec()),
21,
wasm_bindgen::Clamped(SRS1.to_vec()),
"kzg",
)
.map_err(|_| "failed")
.unwrap();
// should not fail
assert!(value);
}
#[wasm_bindgen_test]
async fn verify_encode_verifier_calldata() {

BIN
tests/wasm/kzg1.srs Normal file

Binary file not shown.

3075
tests/wasm/proof_aggr.json Normal file

File diff suppressed because one or more lines are too long

BIN
tests/wasm/vk_aggr.key Normal file

Binary file not shown.