Files
ezkl/examples/notebooks/kmeans.ipynb
2025-06-27 22:58:10 +02:00

292 lines
11 KiB
Plaintext

{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "cf69bb3f-94e6-4dba-92cd-ce08df117d67",
"metadata": {},
"source": [
"## K-means\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95613ee9",
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"sk2torch\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"\n",
"# here we create and (potentially train a model)\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"import json\n",
"import numpy as np\n",
"from sklearn.cluster import KMeans\n",
"from hummingbird.ml import convert\n",
"import torch\n",
"import ezkl\n",
"import os\n",
"\n",
"\n",
"# Create a dataset of two Gaussians. There will be some overlap\n",
"# between the two classes, which adds some uncertainty to the model.\n",
"xs = np.concatenate(\n",
" [\n",
" np.random.random(size=(256, 2)) + [1, 0],\n",
" np.random.random(size=(256, 2)) + [-1, 0],\n",
" ],\n",
" axis=0,\n",
")\n",
"\n",
"# Train an SVM on the data and wrap it in PyTorch.\n",
"sk_model = KMeans()\n",
"sk_model.fit(xs)\n",
"model = convert(sk_model, backend=\"pytorch\").model\n",
"\n",
"\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37637c4",
"metadata": {},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82db373a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"# Create a coordinate grid to compute a vector field on.\n",
"spaced = np.linspace(-2, 2, num=25)\n",
"grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n",
"\n",
"# Input to the model\n",
"shape = xs.shape[1:]\n",
"x = grid_xs[0:1]\n",
"torch_out = model(x)\n",
"# Export the model\n",
"torch.onnx.export(model, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
" \"network.onnx\",\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names=['input'], # the model's input names\n",
" output_names=['output'], # the model's output names\n",
" dynamic_axes={'input': {0: 'batch_size'}, # variable length axes\n",
" 'output': {0: 'batch_size'}})\n",
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e374a2",
"metadata": {},
"outputs": [],
"source": [
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (grid_xs[0:20].detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3aa4f090",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b74dcee",
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18c8b7c7",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c384cbc8",
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76f00d41",
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}