mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-09 14:28:00 -05:00
407 lines
15 KiB
Plaintext
407 lines
15 KiB
Plaintext
{
|
|
"cells": [
|
|
{
 "attachments": {},
 "cell_type": "markdown",
 "id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
 "metadata": {},
 "source": [
  "## EZKL Jupyter Notebook Demo (Aggregated Proofs)\n",
  "\n",
  "Demonstrates how to use EZKL with aggregated proofs: a small PyTorch model is exported to ONNX and proved, and the resulting proof is then aggregated, verified, and given an EVM verifier contract."
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "95613ee9",
 "metadata": {},
 "outputs": [],
 "source": [
  "# check if notebook is in colab, and if so install ezkl and onnx on the fly\n",
  "try:\n",
  "    import google.colab  # only importable inside Colab\n",
  "    import subprocess\n",
  "    import sys\n",
  "\n",
  "    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
  "    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
  "except ImportError:\n",
  "    # Not in colab: rely on the local installation of ezkl.\n",
  "    # Only ImportError is caught so that a failed pip install inside Colab\n",
  "    # surfaces here instead of being silently swallowed.\n",
  "    pass\n",
  "\n",
  "\n",
  "# here we create (and potentially train) a model\n",
  "\n",
  "# make sure you have the dependencies required here already installed\n",
  "import json\n",
  "import os\n",
  "\n",
  "import ezkl\n",
  "import torch\n",
  "from torch import nn\n",
  "\n",
  "\n",
  "# Defines the model\n",
  "# we got convs, we got relu, we got linear layers\n",
  "# What else could one want ????\n",
  "\n",
  "class MyModel(nn.Module):\n",
  "    \"\"\"Small CNN: two strided 5x5 convolutions followed by two linear layers.\n",
  "\n",
  "    The shape comments in forward() assume a (batch, 1, 28, 28) input, which\n",
  "    is the input shape this notebook exports the model with.\n",
  "    \"\"\"\n",
  "\n",
  "    def __init__(self):\n",
  "        super(MyModel, self).__init__()\n",
  "\n",
  "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=5, stride=2)\n",
  "        self.conv2 = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=5, stride=2)\n",
  "\n",
  "        self.relu = nn.ReLU()\n",
  "\n",
  "        # 48 = 3 channels * 4 * 4: the flattened conv2 output for a 28x28 input\n",
  "        self.d1 = nn.Linear(48, 48)\n",
  "        self.d2 = nn.Linear(48, 10)\n",
  "\n",
  "    def forward(self, x):\n",
  "        # (batch, 1, 28, 28) => (batch, 2, 12, 12)\n",
  "        x = self.conv1(x)\n",
  "        x = self.relu(x)\n",
  "        # (batch, 2, 12, 12) => (batch, 3, 4, 4)\n",
  "        x = self.conv2(x)\n",
  "        x = self.relu(x)\n",
  "\n",
  "        # flatten => (batch, 3*4*4) = (batch, 48)\n",
  "        x = x.flatten(start_dim=1)\n",
  "\n",
  "        # (batch, 48) => (batch, 48)\n",
  "        x = self.d1(x)\n",
  "        x = self.relu(x)\n",
  "\n",
  "        # logits => (batch, 10)\n",
  "        logits = self.d2(x)\n",
  "\n",
  "        return logits\n",
  "\n",
  "\n",
  "circuit = MyModel()\n",
  "\n",
  "# Train the model as you like here (skipped for brevity)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "b37637c4",
 "metadata": {},
 "outputs": [],
 "source": [
  "# All artifact paths live in the current working directory.\n",
  "\n",
  "# Model and its inputs\n",
  "model_path = os.path.join(\"network.onnx\")\n",
  "compiled_model_path = os.path.join(\"network.compiled\")\n",
  "settings_path = os.path.join(\"settings.json\")\n",
  "data_path = os.path.join(\"input.json\")\n",
  "witness_path = os.path.join(\"witness.json\")\n",
  "\n",
  "# Keys and proof for the single (inner) proof\n",
  "pk_path = os.path.join(\"test.pk\")\n",
  "vk_path = os.path.join(\"test.vk\")\n",
  "proof_path = os.path.join(\"test.pf\")\n",
  "\n",
  "# Not referenced by the later cells of this notebook\n",
  "srs_path = os.path.join(\"kzg.srs\")\n",
  "\n",
  "# Keys and proof for the aggregated proof\n",
  "aggregate_pk_path = os.path.join(\"aggr.pk\")\n",
  "aggregate_vk_path = os.path.join(\"aggr.vk\")\n",
  "aggregate_proof_path = os.path.join(\"aggr.pf\")"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "82db373a",
 "metadata": {},
 "outputs": [],
 "source": [
  "shape = [1, 28, 28]\n",
  "# After training, export to onnx (network.onnx) and create a data file (input.json)\n",
  "x = 0.1 * torch.rand(1, *shape, requires_grad=True)\n",
  "\n",
  "# Flips the neural net into inference mode\n",
  "circuit.eval()\n",
  "\n",
  "# Export the model\n",
  "torch.onnx.export(\n",
  "    circuit,                    # model being run\n",
  "    x,                          # model input (or a tuple for multiple inputs)\n",
  "    model_path,                 # where to save the model (can be a file or file-like object)\n",
  "    export_params=True,         # store the trained parameter weights inside the model file\n",
  "    opset_version=10,           # the ONNX opset version to export the model to\n",
  "    do_constant_folding=True,   # whether to execute constant folding for optimization\n",
  "    input_names=['input'],      # the model's input names\n",
  "    output_names=['output'],    # the model's output names\n",
  "    dynamic_axes={              # variable length axes\n",
  "        'input': {0: 'batch_size'},\n",
  "        'output': {0: 'batch_size'},\n",
  "    },\n",
  ")\n",
  "\n",
  "# Flatten the sample input into a plain list so it can be written as JSON\n",
  "data_array = x.detach().numpy().reshape([-1]).tolist()\n",
  "\n",
  "data = dict(input_data=[data_array])\n",
  "\n",
  "# Serialize data into file. The context manager closes (and therefore flushes)\n",
  "# the handle, so input.json is complete before a later cell reads it.\n",
  "with open(data_path, 'w') as data_file:\n",
  "    json.dump(data, data_file)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "d5e374a2",
 "metadata": {},
 "outputs": [],
 "source": [
  "# `!RUST_LOG=trace` would only set the variable inside a throwaway subshell;\n",
  "# %env sets it in the kernel process, which is where ezkl is imported.\n",
  "# NOTE(review): whether ezkl re-reads RUST_LOG after import is not confirmed.\n",
  "%env RUST_LOG=trace\n",
  "# TODO: Dictionary outputs\n",
  "res = ezkl.gen_settings(model_path, settings_path)\n",
  "assert res == True"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "7e5c1d2a",
 "metadata": {},
 "outputs": [],
 "source": [
  "cal_path = os.path.join(\"calibration.json\")\n",
  "\n",
  "# 20 random samples with the same per-sample shape as the exported input,\n",
  "# flattened into one list. Separate names are used so the `data` /\n",
  "# `data_array` built for input.json are not overwritten.\n",
  "cal_data_array = torch.rand(20, *shape).numpy().reshape([-1]).tolist()\n",
  "\n",
  "cal_data = dict(input_data=[cal_data_array])\n",
  "\n",
  "# Serialize data into file; the context manager closes the handle so the\n",
  "# file is complete before calibrate_settings is called on it.\n",
  "with open(cal_path, 'w') as cal_file:\n",
  "    json.dump(cal_data, cal_file)\n",
  "\n",
  "\n",
  "ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "3aa4f090",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Compile the ONNX model, with the settings file, into compiled_model_path\n",
  "res = ezkl.compile_circuit(\n",
  "    model_path,\n",
  "    compiled_model_path,\n",
  "    settings_path,\n",
  ")\n",
  "assert res == True"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "8b74dcee",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Get the SRS for the circuit described by the settings file\n",
  "# (get_srs is awaited, so this relies on the notebook's top-level await)\n",
  "res = await ezkl.get_srs(settings_path)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "18c8b7c7",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Generate the witness file from the input data and the compiled circuit\n",
  "res = ezkl.gen_witness(\n",
  "    data_path,\n",
  "    compiled_model_path,\n",
  "    witness_path,\n",
  ")\n",
  "\n",
  "# The witness must exist on disk before proving\n",
  "assert os.path.isfile(witness_path)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "b1c561a8",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Set up the circuit params: after this call the verifying key (vk_path)\n",
  "# and proving key (pk_path) are expected on disk.\n",
  "res = ezkl.setup(compiled_model_path, vk_path, pk_path)\n",
  "\n",
  "assert res == True\n",
  "for required_file in (vk_path, pk_path, settings_path):\n",
  "    assert os.path.isfile(required_file)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "c384cbc8",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Generate a proof for the single model run\n",
  "\n",
  "# Same value as the proof_path defined with the other paths\n",
  "proof_path = os.path.join('test.pf')\n",
  "\n",
  "# IMPORTANT NOTE: To produce an aggregated EVM proof you will want to use poseidon for the smaller proofs\n",
  "res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, \"for-aggr\")\n",
  "\n",
  "print(res)\n",
  "assert os.path.isfile(proof_path)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "76f00d41",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Verify the proof against the settings and the verifying key\n",
  "res = ezkl.verify(proof_path, settings_path, vk_path)\n",
  "\n",
  "assert res == True\n",
  "print(\"verified\")"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "0832b909",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Generate a larger SRS. This is needed for the aggregated proof.\n",
  "# logrows=21 matches the value passed to the aggregation calls below.\n",
  "res = await ezkl.get_srs(\n",
  "    settings_path=None,\n",
  "    logrows=21,\n",
  "    commitment=ezkl.PyCommitments.KZG,\n",
  ")"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "c5a64be6",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Run mock aggregate to check whether the proof works.\n",
  "# The mock run is used as a validity check because it takes less time than\n",
  "# producing a full aggregated proof.\n",
  "res = ezkl.mock_aggregate(\n",
  "    [proof_path],\n",
  "    21,\n",
  ")\n",
  "assert res == True"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "fee8acc6",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Setup the vk and pk for aggregate\n",
  "res = ezkl.setup_aggregate([proof_path], aggregate_vk_path, aggregate_pk_path, 21)\n",
  "\n",
  "# Both aggregate keys must exist on disk afterwards\n",
  "for aggregate_key_file in (aggregate_vk_path, aggregate_pk_path):\n",
  "    assert os.path.isfile(aggregate_key_file)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "171702d3",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Run aggregate proof over the single inner proof; the result is written to\n",
  "# aggregate_proof_path, which is checked below.\n",
  "res = ezkl.aggregate(\n",
  "    [proof_path],\n",
  "    aggregate_proof_path,\n",
  "    aggregate_pk_path,\n",
  "    \"evm\",\n",
  "    21,\n",
  "    \"safe\",\n",
  ")\n",
  "\n",
  "assert os.path.isfile(aggregate_proof_path)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "671dfdd5",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Check if the aggregated proof is valid against the aggregate verifying key\n",
  "res = ezkl.verify_aggr(\n",
  "    aggregate_proof_path,\n",
  "    aggregate_vk_path,\n",
  "    21,\n",
  ")\n",
  "assert res == True"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "id": "50eba2f4",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Create a smart contract verifier for the aggregated proof\n",
  "\n",
  "# Output locations for the Solidity source and its ABI\n",
  "sol_code_path = os.path.join(\"Verifier.sol\")\n",
  "abi_path = os.path.join(\"Verifier_ABI.json\")\n",
  "\n",
  "# create_evm_verifier_aggr is awaited, so this relies on top-level await\n",
  "res = await ezkl.create_evm_verifier_aggr(\n",
  "    [settings_path],\n",
  "    aggregate_vk_path,\n",
  "    sol_code_path,\n",
  "    abi_path,\n",
  "    logrows=21,\n",
  ")"
 ]
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
} |