Compare commits

...

6 Commits

Author SHA1 Message Date
dante
5389012b68 fix: patch large batch ex (#763) 2024-04-03 02:33:57 +01:00
dante
48223cca11 fix: make commitment optional for backwards compat (#762) 2024-04-03 02:26:50 +01:00
dante
32c3a5e159 fix: hold stacked outputs in a separate map 2024-04-02 21:37:20 +01:00
dante
ff563e93a7 fix: bump python version (#761) 2024-04-02 17:08:26 +01:00
dante
5639d36097 chore: verify aggr wasm unit test (#760) 2024-04-01 20:54:20 +01:00
dante
4ec8d13082 chore: verify aggr in wasm (#758) 2024-03-29 23:28:20 +00:00
33 changed files with 4139 additions and 185 deletions

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag

View File

@@ -25,7 +25,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -70,7 +70,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: ${{ matrix.target }}
- name: Set Cargo.toml version to match github tag
@@ -115,7 +115,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -176,7 +176,7 @@ jobs:
# - uses: actions/checkout@v4
# - uses: actions/setup-python@v4
# with:
# python-version: 3.7
# python-version: 3.12
# - name: Install cross-compilation tools for aarch64
# if: matrix.target == 'aarch64'
@@ -228,7 +228,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -263,7 +263,7 @@ jobs:
apk add py3-pip
pip3 install -U pip
python3 -m venv .venv
source .venv/bin/activate
source .venv/bin/activate
pip3 install ezkl --no-index --find-links /io/dist/ --force-reinstall
python3 -c "import ezkl"
@@ -287,7 +287,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
- name: Set Cargo.toml version to match github tag
shell: bash

View File

@@ -184,7 +184,7 @@ jobs:
wasm32-tests:
runs-on: ubuntu-latest
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -207,7 +207,7 @@ jobs:
tutorial:
runs-on: ubuntu-latest
needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -224,7 +224,7 @@ jobs:
mock-proving-tests:
runs-on: non-gpu
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -281,7 +281,7 @@ jobs:
prove-and-verify-evm-tests:
runs-on: non-gpu
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -354,7 +354,7 @@ jobs:
prove-and-verify-tests:
runs-on: non-gpu
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -460,7 +460,7 @@ jobs:
prove-and-verify-mock-aggr-tests:
runs-on: self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -495,7 +495,7 @@ jobs:
prove-and-verify-aggr-tests:
runs-on: large-self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -512,7 +512,7 @@ jobs:
prove-and-verify-aggr-evm-tests:
runs-on: large-self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -557,16 +557,18 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Install cmake
run: sudo apt-get install -y cmake
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: Build python ezkl
@@ -576,12 +578,12 @@ jobs:
accuracy-measurement-tests:
runs-on: ubuntu-latest-32-cores
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
@@ -592,7 +594,7 @@ jobs:
crate: cargo-nextest
locked: true
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Div rebase
@@ -612,7 +614,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-02-06
@@ -626,10 +628,14 @@ jobs:
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: Install pip
run: python -m ensurepip --upgrade
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
# - name: authenticate-kaggle-cli
# shell: bash
# env:
@@ -645,7 +651,5 @@ jobs:
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_
# - name: Postgres tutorials
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1

3
.gitignore vendored
View File

@@ -48,4 +48,5 @@ node_modules
/dist
timingData.json
!tests/wasm/pk.key
!tests/wasm/vk.key
!tests/wasm/vk.key
!tests/wasm/vk_aggr.key

4
Cargo.lock generated
View File

@@ -4601,9 +4601,9 @@ dependencies = [
[[package]]
name = "serde-wasm-bindgen"
version = "0.4.5"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3b4c031cd0d9014307d82b8abf653c0290fbdaeb4c02d00c63cf52f728628bf"
checksum = "8302e169f0eddcc139c70f139d19d6467353af16f9fce27e8c30158036a1e16b"
dependencies = [
"js-sys",
"serde",

View File

@@ -95,10 +95,10 @@ getrandom = { version = "0.2.8", features = ["js"] }
instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] }
[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies]
wasm-bindgen-rayon = { version = "1.0", optional = true }
wasm-bindgen-test = "0.3.34"
serde-wasm-bindgen = "0.4"
wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"] }
wasm-bindgen-rayon = { version = "1.2.1", optional = true }
wasm-bindgen-test = "0.3.42"
serde-wasm-bindgen = "0.6.5"
wasm-bindgen = { version = "0.2.92", features = ["serde-serialize"] }
console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"

View File

@@ -67,6 +67,7 @@
"model.add(Dense(128, activation='relu'))\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(10, activation='softmax'))\n",
"model.output_names=['output']\n",
"\n",
"\n",
"# Train the model as you like here (skipped for brevity)\n",

View File

@@ -38,7 +38,7 @@
"import logging\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras.optimizers.legacy import Adam\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras.layers import *\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.datasets import mnist\n",
@@ -71,9 +71,11 @@
},
"outputs": [],
"source": [
"opt = Adam()\n",
"ZDIM = 100\n",
"\n",
"opt = Adam()\n",
"\n",
"\n",
"# discriminator\n",
"# 0 if it's fake, 1 if it's real\n",
"x = in1 = Input((28,28))\n",
@@ -114,8 +116,11 @@
"\n",
"gm = Model(in1, x)\n",
"gm.compile('adam', 'mse')\n",
"gm.output_names=['output']\n",
"gm.summary()\n",
"\n",
"opt = Adam()\n",
"\n",
"# GAN\n",
"dm.trainable = False\n",
"x = dm(gm.output)\n",
@@ -415,7 +420,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long

View File

@@ -349,6 +349,8 @@
"z_log_var = Dense(ZDIM)(x)\n",
"z = Lambda(lambda x: x[0] + K.exp(0.5 * x[1]) * K.random_normal(shape=K.shape(x[0])))([z_mu, z_log_var])\n",
"dec = get_decoder()\n",
"dec.output_names=['output']\n",
"\n",
"out = dec(z)\n",
"\n",
"mse_loss = mse(Reshape((28*28,))(in1), Reshape((28*28,))(out)) * 28 * 28\n",

View File

@@ -61,11 +61,10 @@
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.ensemble import RandomForestClassifier as Rf\n",
"import sk2torch\n",
"import torch\n",
"import ezkl\n",
"import os\n",
"from torch import nn\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"\n",
@@ -77,28 +76,12 @@
"clr.fit(X_train, y_train)\n",
"\n",
"\n",
"trees = []\n",
"for tree in clr.estimators_:\n",
" trees.append(sk2torch.wrap(tree))\n",
"\n",
"\n",
"class RandomForest(nn.Module):\n",
" def __init__(self, trees):\n",
" super(RandomForest, self).__init__()\n",
" self.trees = nn.ModuleList(trees)\n",
"\n",
" def forward(self, x):\n",
" out = self.trees[0](x)\n",
" for tree in self.trees[1:]:\n",
" out += tree(x)\n",
" return out / len(self.trees)\n",
"\n",
"\n",
"torch_rf = RandomForest(trees)\n",
"torch_rf = convert(clr, 'torch')\n",
"# assert predictions from torch are = to sklearn \n",
"diffs = []\n",
"for i in range(len(X_test)):\n",
" torch_pred = torch_rf(torch.tensor(X_test[i].reshape(1, -1)))\n",
" torch_pred = torch_rf.predict(torch.tensor(X_test[i].reshape(1, -1)))\n",
" sk_pred = clr.predict(X_test[i].reshape(1, -1))\n",
" diffs.append(torch_pred[0].round() - sk_pred[0])\n",
"\n",
@@ -134,14 +117,12 @@
"\n",
"# export to onnx format\n",
"\n",
"torch_rf.eval()\n",
"\n",
"# Input to the model\n",
"shape = X_train.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=False)\n",
"torch_out = torch_rf(x)\n",
"torch_out = torch_rf.predict(x)\n",
"# Export the model\n",
"torch.onnx.export(torch_rf, # model being run\n",
"torch.onnx.export(torch_rf.model, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
@@ -158,7 +139,7 @@
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
@@ -321,7 +302,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -57,7 +57,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -119,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -163,7 +163,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -217,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -227,6 +227,10 @@
" self.length = self.compute_length(self.file_good)\n",
" self.data = self.load_data(self.file_good)\n",
"\n",
" def __iter__(self):\n",
" for i in range(len(self.data)):\n",
" yield self.data[i]\n",
"\n",
" def parse_json_object(self, line):\n",
" try:\n",
" return json.loads(line)\n",
@@ -749,7 +753,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -209,6 +209,11 @@
" self.length = self.compute_length(self.file_good, self.file_bad)\n",
" self.data = self.load_data(self.file_good, self.file_bad)\n",
"\n",
" def __iter__(self):\n",
" for i in range(len(self.data)):\n",
" yield self.data[i]\n",
"\n",
"\n",
" def parse_json_object(self, line):\n",
" try:\n",
" return json.loads(line)\n",
@@ -637,7 +642,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -0,0 +1,13 @@
{
"input_data": [
[
0.8894134163856506,
0.8894201517105103
]
],
"output_data": [
[
0.8436377
]
]
}

Binary file not shown.

View File

@@ -1,14 +1,14 @@
attrs==22.2.0
exceptiongroup==1.1.1
importlib-metadata==6.1.0
attrs==23.2.0
exceptiongroup==1.2.0
importlib-metadata==7.1.0
iniconfig==2.0.0
maturin==1.0.1
packaging==23.0
pluggy==1.0.0
pytest==7.2.2
maturin==1.5.0
packaging==24.0
pluggy==1.4.0
pytest==8.1.1
tomli==2.0.1
typing-extensions==4.5.0
zipp==3.15.0
onnx==1.14.1
onnxruntime==1.14.1
numpy==1.21.6
typing-extensions==4.10.0
zipp==3.18.1
onnx==1.15.0
onnxruntime==1.17.1
numpy==1.26.4

View File

@@ -444,7 +444,7 @@ pub enum Commands {
disable_selector_compression: bool,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
commitment: Option<Commitments>,
},
/// Aggregates proofs :)
Aggregate {
@@ -479,7 +479,7 @@ pub enum Commands {
split_proofs: bool,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
commitment: Option<Commitments>,
},
/// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements
CompileCircuit {
@@ -726,7 +726,7 @@ pub enum Commands {
logrows: u32,
/// commitment
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
commitment: Option<Commitments>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that is generated by ezkl

View File

@@ -339,7 +339,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
logrows,
split_proofs,
disable_selector_compression,
commitment,
commitment.into(),
),
Commands::Aggregate {
proof_path,
@@ -360,7 +360,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
logrows,
check_mode,
split_proofs,
commitment,
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Verify {
@@ -384,7 +384,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
logrows,
reduced_srs,
commitment,
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
#[cfg(not(target_arch = "wasm32"))]
@@ -586,7 +586,7 @@ pub(crate) async fn get_srs_cmd(
} else if let Some(settings_p) = settings_path {
if settings_p.exists() {
let settings = GraphSettings::load(&settings_p)?;
settings.run_args.commitment
settings.run_args.commitment.into()
} else {
return Err(err_string.into());
}
@@ -666,21 +666,17 @@ pub(crate) async fn gen_witness(
// if any of the settings have kzg visibility then we need to load the srs
let commitment: Commitments = settings.run_args.commitment.into();
let start_time = Instant::now();
let witness = if settings.module_requires_polycommit() {
if get_srs_path(
settings.run_args.logrows,
srs_path.clone(),
settings.run_args.commitment,
)
.exists()
{
match settings.run_args.commitment {
if get_srs_path(settings.run_args.logrows, srs_path.clone(), commitment).exists() {
match Commitments::from(settings.run_args.commitment) {
Commitments::KZG => {
let srs: ParamsKZG<Bn256> = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path.clone(),
settings.run_args.logrows,
settings.run_args.commitment,
commitment,
)?;
circuit.forward::<KZGCommitmentScheme<_>>(
&mut input,
@@ -694,7 +690,7 @@ pub(crate) async fn gen_witness(
load_params_prover::<IPACommitmentScheme<G1Affine>>(
srs_path.clone(),
settings.run_args.logrows,
settings.run_args.commitment,
commitment,
)?;
circuit.forward::<IPACommitmentScheme<_>>(
&mut input,
@@ -1303,17 +1299,19 @@ pub(crate) fn create_evm_verifier(
render_vk_seperately: bool,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let circuit_settings = GraphSettings::load(&settings_path)?;
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
circuit_settings.run_args.commitment,
settings.run_args.logrows,
commitment,
)?;
let num_instance = circuit_settings.total_instances();
let num_instance = settings.total_instances();
let num_instance: usize = num_instance.iter().sum::<usize>();
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, circuit_settings)?;
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)?;
trace!("params computed");
let generator = halo2_solidity_verifier::SolidityGenerator::new(
@@ -1347,17 +1345,18 @@ pub(crate) fn create_evm_vk(
abi_path: PathBuf,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let circuit_settings = GraphSettings::load(&settings_path)?;
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
circuit_settings.run_args.commitment,
settings.run_args.logrows,
commitment,
)?;
let num_instance = circuit_settings.total_instances();
let num_instance = settings.total_instances();
let num_instance: usize = num_instance.iter().sum::<usize>();
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, circuit_settings)?;
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)?;
trace!("params computed");
let generator = halo2_solidity_verifier::SolidityGenerator::new(
@@ -1626,8 +1625,9 @@ pub(crate) fn setup(
}
let logrows = circuit.settings().run_args.logrows;
let commitment: Commitments = circuit.settings().run_args.commitment.into();
let pk = match circuit.settings().run_args.commitment {
let pk = match commitment {
Commitments::KZG => {
let params = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path,
@@ -1736,7 +1736,8 @@ pub(crate) fn prove(
let transcript: TranscriptType = proof_type.into();
let proof_split_commits: Option<ProofSplitCommit> = data.into();
let commitment = circuit_settings.run_args.commitment;
let commitment = circuit_settings.run_args.commitment.into();
let logrows = circuit_settings.run_args.logrows;
// creates and verifies the proof
let mut snark = match commitment {
Commitments::KZG => {
@@ -1745,7 +1746,7 @@ pub(crate) fn prove(
let params = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
logrows,
Commitments::KZG,
)?;
match strategy {
@@ -2187,8 +2188,9 @@ pub(crate) fn verify(
let circuit_settings = GraphSettings::load(&settings_path)?;
let logrows = circuit_settings.run_args.logrows;
let commitment = circuit_settings.run_args.commitment.into();
match circuit_settings.run_args.commitment {
match commitment {
Commitments::KZG => {
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
let params: ParamsKZG<Bn256> = if reduced_srs {

View File

@@ -39,7 +39,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
};
use halo2curves::bn256::{self, Fr as Fp, G1Affine};
use halo2curves::ff::PrimeField;
use halo2curves::ff::{Field, PrimeField};
#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;
use log::{debug, error, trace, warn};
@@ -1451,7 +1451,8 @@ impl GraphCircuit {
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
struct CircuitSize {
/// The configuration for the graph circuit
pub struct CircuitSize {
num_instances: usize,
num_advice_columns: usize,
num_fixed: usize,
@@ -1461,7 +1462,8 @@ struct CircuitSize {
}
impl CircuitSize {
pub fn from_cs(cs: &ConstraintSystem<Fp>, logrows: u32) -> Self {
///
pub fn from_cs<F: Field>(cs: &ConstraintSystem<F>, logrows: u32) -> Self {
CircuitSize {
num_instances: cs.num_instance_columns(),
num_advice_columns: cs.num_advice_columns(),

View File

@@ -803,13 +803,18 @@ impl Model {
let input_state_idx = input_state_idx(&input_mappings);
let mut output_mappings = vec![];
for mapping in b.output_mapping.iter() {
for (i, mapping) in b.output_mapping.iter().enumerate() {
let mut mappings = vec![];
if let Some(outlet) = mapping.last_value_slot {
mappings.push(OutputMapping::Single {
outlet,
is_state: mapping.state,
});
} else if mapping.state {
mappings.push(OutputMapping::Single {
outlet: i,
is_state: mapping.state,
});
}
if let Some(last) = mapping.scan {
mappings.push(OutputMapping::Stacked {
@@ -818,6 +823,7 @@ impl Model {
is_state: false,
});
}
output_mappings.push(mappings);
}
@@ -1264,8 +1270,8 @@ impl Model {
let num_iter = number_of_iterations(&input_mappings, input_dims.collect());
debug!(
"{} iteration(s) in a subgraph with inputs {:?} and sources {:?}",
num_iter, inputs, model.graph.inputs
"{} iteration(s) in a subgraph with inputs {:?}, sources {:?}, and outputs {:?}",
num_iter, inputs, model.graph.inputs, model.graph.outputs
);
let mut full_results: Vec<ValTensor<Fp>> = vec![];
@@ -1297,6 +1303,7 @@ impl Model {
let res = model.layout_nodes(config, region, &mut subgraph_results)?;
let mut outlets = BTreeMap::new();
let mut stacked_outlets = BTreeMap::new();
for (mappings, outlet_res) in output_mappings.iter().zip(res) {
for mapping in mappings {
@@ -1309,25 +1316,42 @@ impl Model {
let stacked_res = full_results[*outlet]
.clone()
.concat_axis(outlet_res.clone(), axis)?;
outlets.insert(outlet, stacked_res);
} else {
outlets.insert(outlet, outlet_res.clone());
stacked_outlets.insert(outlet, stacked_res);
}
outlets.insert(outlet, outlet_res.clone());
}
}
}
}
full_results = outlets.into_values().collect_vec();
// now extend with stacked elements
let mut pre_stacked_outlets = outlets.clone();
pre_stacked_outlets.extend(stacked_outlets);
let outlets = outlets.into_values().collect_vec();
full_results = pre_stacked_outlets.into_values().collect_vec();
let output_states = output_state_idx(output_mappings);
let input_states = input_state_idx(&input_mappings);
assert_eq!(input_states.len(), output_states.len());
assert_eq!(
input_states.len(),
output_states.len(),
"input and output states must be the same length, got {:?} and {:?}",
input_mappings,
output_mappings
);
for (input_idx, output_idx) in input_states.iter().zip(output_states) {
values[*input_idx] = full_results[output_idx].clone();
assert_eq!(
values[*input_idx].dims(),
outlets[output_idx].dims(),
"input and output dims must be the same, got {:?} and {:?}",
values[*input_idx].dims(),
outlets[output_idx].dims()
);
values[*input_idx] = outlets[output_idx].clone();
}
}

View File

@@ -114,6 +114,12 @@ pub enum Commitments {
IPA,
}
impl From<Option<Commitments>> for Commitments {
fn from(value: Option<Commitments>) -> Self {
value.unwrap_or(Commitments::KZG)
}
}
impl FromStr for Commitments {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
@@ -214,7 +220,7 @@ pub struct RunArgs {
pub check_mode: CheckMode,
/// commitment scheme
#[arg(long, default_value = "kzg")]
pub commitment: Commitments,
pub commitment: Option<Commitments>,
}
impl Default for RunArgs {
@@ -234,7 +240,7 @@ impl Default for RunArgs {
div_rebasing: false,
rebase_frac_zero_constants: false,
check_mode: CheckMode::UNSAFE,
commitment: Commitments::KZG,
commitment: None,
}
}
}

View File

@@ -1,4 +1,8 @@
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::CircuitSize;
use crate::pfsys::{Snark, SnarkWitness};
#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;
use halo2_proofs::circuit::AssignedCell;
use halo2_proofs::plonk::{self};
use halo2_proofs::{
@@ -16,6 +20,8 @@ use halo2_wrong_ecc::{
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
#[cfg(not(target_arch = "wasm32"))]
use log::debug;
use log::trace;
use rand::rngs::OsRng;
use snark_verifier::loader::native::NativeLoader;
@@ -193,6 +199,23 @@ impl AggregationConfig {
let main_gate_config = MainGate::<F>::configure(meta);
let range_config =
RangeChip::<F>::configure(meta, &main_gate_config, composition_bits, overflow_bits);
#[cfg(not(target_arch = "wasm32"))]
{
let circuit_size = CircuitSize::from_cs(meta, 23);
// not wasm
debug!(
"circuit size: \n {}",
circuit_size
.as_json()
.unwrap()
.to_colored_json_auto()
.unwrap()
);
}
AggregationConfig {
main_gate_config,
range_config,

View File

@@ -197,7 +197,7 @@ impl From<PyRunArgs> for RunArgs {
div_rebasing: py_run_args.div_rebasing,
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
check_mode: py_run_args.check_mode,
commitment: py_run_args.commitment.into(),
commitment: Some(py_run_args.commitment.into()),
}
}
}
@@ -234,6 +234,16 @@ pub enum PyCommitments {
IPA,
}
impl From<Option<Commitments>> for PyCommitments {
fn from(commitment: Option<Commitments>) -> Self {
match commitment {
Some(Commitments::KZG) => PyCommitments::KZG,
Some(Commitments::IPA) => PyCommitments::IPA,
None => PyCommitments::KZG,
}
}
}
impl From<PyCommitments> for Commitments {
fn from(py_commitments: PyCommitments) -> Self {
match py_commitments {

View File

@@ -1,3 +1,5 @@
use core::{iter::FilterMap, slice::Iter};
use crate::circuit::region::ConstantsMap;
use super::{
@@ -450,10 +452,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Returns the number of constants in the [ValTensor].
pub fn create_constants_map_iterator(
&self,
) -> core::iter::FilterMap<
core::slice::Iter<'_, ValType<F>>,
fn(&ValType<F>) -> Option<(F, ValType<F>)>,
> {
) -> FilterMap<Iter<'_, ValType<F>>, fn(&ValType<F>) -> Option<(F, ValType<F>)>> {
match self {
ValTensor::Value { inner, .. } => inner.iter().filter_map(|x| {
if let ValType::Constant(v) = x {

View File

@@ -8,12 +8,14 @@ use crate::graph::quantize_float;
use crate::graph::scale_to_multiplier;
use crate::graph::{GraphCircuit, GraphSettings};
use crate::pfsys::create_proof_circuit;
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
use crate::pfsys::verify_proof_circuit;
use crate::pfsys::TranscriptType;
use crate::tensor::TensorType;
use crate::CheckMode;
use crate::Commitments;
use console_error_panic_hook;
use halo2_proofs::plonk::*;
use halo2_proofs::poly::commitment::{CommitmentScheme, ParamsProver};
use halo2_proofs::poly::ipa::multiopen::{ProverIPA, VerifierIPA};
@@ -33,11 +35,10 @@ use halo2curves::bn256::{Bn256, Fr, G1Affine};
use halo2curves::ff::{FromUniformBytes, PrimeField};
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::str::FromStr;
use wasm_bindgen::prelude::*;
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
use console_error_panic_hook;
#[cfg(feature = "web")]
pub use wasm_bindgen_rayon::init_thread_pool;
@@ -336,8 +337,94 @@ pub fn verify(
let orig_n = 1 << circuit_settings.run_args.logrows;
let commitment = circuit_settings.run_args.commitment.into();
let mut reader = std::io::BufReader::new(&srs[..]);
let result = match circuit_settings.run_args.commitment {
let result = match commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
Commitments::IPA => {
let params: ParamsIPA<_> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let strategy = IPASingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
};
match result {
Ok(_) => Ok(true),
Err(e) => Err(JsError::new(&format!("{}", e))),
}
}
#[wasm_bindgen]
#[allow(non_snake_case)]
/// Verify aggregate proof in browser using wasm
pub fn verifyAggr(
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
vk: wasm_bindgen::Clamped<Vec<u8>>,
logrows: u64,
srs: wasm_bindgen::Clamped<Vec<u8>>,
commitment: &str,
) -> Result<bool, JsError> {
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof_js[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
let mut reader = std::io::BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
(),
)
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
let commit = Commitments::from_str(commitment).map_err(|e| JsError::new(&format!("{}", e)))?;
let orig_n = 1 << logrows;
let mut reader = std::io::BufReader::new(&srs[..]);
let result = match commit {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
@@ -436,8 +523,9 @@ pub fn prove(
// read in kzg params
let mut reader = std::io::BufReader::new(&srs[..]);
let commitment = circuit.settings().run_args.commitment.into();
// creates and verifies the proof
let proof = match circuit.settings().run_args.commitment {
let proof = match commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)

View File

@@ -122,7 +122,7 @@ mod native_tests {
let settings: GraphSettings = serde_json::from_str(&settings).unwrap();
let logrows = settings.run_args.logrows;
download_srs(logrows, settings.run_args.commitment);
download_srs(logrows, settings.run_args.commitment.into());
}
fn mv_test_(test_dir: &str, test: &str) {
@@ -200,7 +200,7 @@ mod native_tests {
"1l_tiny_div",
];
const TESTS: [&str; 91] = [
const TESTS: [&str; 92] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
@@ -296,6 +296,7 @@ mod native_tests {
"reducel1",
"reducel2", // 89
"1l_lppool",
"lstm_large", // 91
];
const WASM_TESTS: [&str; 46] = [
@@ -534,7 +535,7 @@ mod native_tests {
}
});
seq!(N in 0..=90 {
seq!(N in 0..=91 {
#(#[test_case(TESTS[N])])*
#[ignore]
@@ -622,7 +623,7 @@ mod native_tests {
#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
// currently variable output rank is not supported in ONNX
if test != "gather_nd" {
if test != "gather_nd" && test != "lstm_large" {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
@@ -1970,7 +1971,7 @@ mod native_tests {
.expect("failed to parse settings file");
// get_srs for the graph_settings_num_instances
download_srs(1, graph_settings.run_args.commitment);
download_srs(1, graph_settings.run_args.commitment.into());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([

View File

@@ -56,33 +56,38 @@ mod py_tests {
// source .env/bin/activate
// pip install -r requirements.txt
// maturin develop --release --features python-bindings
// first install tf2onnx as it has protobuf conflict with onnx
let status = Command::new("pip")
.args(["install", "tf2onnx==1.16.1"])
.status()
.expect("failed to execute process");
assert!(status.success());
// now install torch, pandas, numpy, seaborn, jupyter
let status = Command::new("pip")
.args([
"install",
"torch-geometric==2.5.0",
"torch==2.0.1",
"torchvision==0.15.2",
"pandas==2.0.3",
"numpy==1.23",
"seaborn==0.12.2",
"jupyter==1.0.0",
"onnx==1.14.0",
"kaggle==1.5.15",
"py-solc-x==1.1.1",
"web3==6.5.0",
"librosa==0.10.0.post2",
"keras==2.12.0",
"tensorflow==2.12.0",
"tensorflow-datasets==4.9.3",
"tf2onnx==1.14.0",
"pytorch-lightning==2.0.6",
"torch-geometric==2.5.2",
"torch==2.2.2",
"torchvision==0.17.2",
"pandas==2.2.1",
"numpy==1.26.4",
"seaborn==0.13.2",
"notebook==7.1.2",
"nbconvert==7.16.3",
"onnx==1.16.0",
"kaggle==1.6.8",
"py-solc-x==2.0.2",
"web3==6.16.0",
"librosa==0.10.1",
"keras==3.1.1",
"tensorflow==2.16.1",
"tensorflow-datasets==4.9.4",
"pytorch-lightning==2.2.1",
"sk2torch==1.2.0",
"scikit-learn==1.3.1",
"xgboost==1.7.6",
"hummingbird-ml==0.4.9",
"lightgbm==4.0.0",
"scikit-learn==1.4.1.post1",
"xgboost==2.0.3",
"hummingbird-ml==0.4.11",
"lightgbm==4.3.0",
])
.status()
.expect("failed to execute process");

View File

@@ -11,7 +11,7 @@ mod wasm32 {
bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian,
feltToFloat, feltToInt, feltToLittleEndian, genPk, genVk, genWitness, inputValidation,
pkValidation, poseidonHash, proofValidation, prove, settingsValidation, srsValidation,
u8_array_to_u128_le, verify, vkValidation, witnessValidation,
u8_array_to_u128_le, verify, verifyAggr, vkValidation, witnessValidation,
};
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::{Fr, G1Affine};
@@ -27,10 +27,29 @@ mod wasm32 {
pub const NETWORK: &[u8] = include_bytes!("../tests/wasm/network.onnx");
pub const INPUT: &[u8] = include_bytes!("../tests/wasm/input.json");
pub const PROOF: &[u8] = include_bytes!("../tests/wasm/proof.json");
pub const PROOF_AGGR: &[u8] = include_bytes!("../tests/wasm/proof_aggr.json");
pub const SETTINGS: &[u8] = include_bytes!("../tests/wasm/settings.json");
pub const PK: &[u8] = include_bytes!("../tests/wasm/pk.key");
pub const VK: &[u8] = include_bytes!("../tests/wasm/vk.key");
pub const VK_AGGR: &[u8] = include_bytes!("../tests/wasm/vk_aggr.key");
pub const SRS: &[u8] = include_bytes!("../tests/wasm/kzg");
pub const SRS1: &[u8] = include_bytes!("../tests/wasm/kzg1.srs");
#[wasm_bindgen_test]
async fn can_verify_aggr() {
let value = verifyAggr(
wasm_bindgen::Clamped(PROOF_AGGR.to_vec()),
wasm_bindgen::Clamped(VK_AGGR.to_vec()),
21,
wasm_bindgen::Clamped(SRS1.to_vec()),
"kzg",
)
.map_err(|_| "failed")
.unwrap();
// should not fail
assert!(value);
}
#[wasm_bindgen_test]
async fn verify_encode_verifier_calldata() {

BIN
tests/wasm/kzg1.srs Normal file

Binary file not shown.

Binary file not shown.

3075
tests/wasm/proof_aggr.json Normal file

File diff suppressed because one or more lines are too long

View File

@@ -24,8 +24,7 @@
"param_visibility": "Private",
"div_rebasing": false,
"rebase_frac_zero_constants": false,
"check_mode": "UNSAFE",
"commitment": "KZG"
"check_mode": "UNSAFE"
},
"num_rows": 16,
"total_dynamic_col_size": 0,

BIN
tests/wasm/vk_aggr.key Normal file

Binary file not shown.