{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Simple MNIST Classifier"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"tensorflow_datasets\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"tf2onnx\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import time\n",
"import random\n",
"import logging\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras.layers import *\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
"\n",
"# uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.INFO)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"(ds_train, ds_test), ds_info = tfds.load(\n",
" 'mnist',\n",
" split=['train', 'test'],\n",
" shuffle_files=True,\n",
" as_supervised=True,\n",
" with_info=True,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def normalize_img(image, label):\n",
" \"\"\"Normalizes images: `uint8` -> `float32`.\"\"\"\n",
" return tf.cast(image, tf.float32) / 255., label\n",
"\n",
"ds_train = ds_train.map(\n",
" normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n",
"ds_train = ds_train.cache()\n",
"ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)\n",
"ds_train = ds_train.batch(128)\n",
"ds_train = ds_train.prefetch(tf.data.AUTOTUNE)"
]
},
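{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick, optional sanity check on the input pipeline, the sketch below displays one normalized training image. It assumes `matplotlib` is installed; nothing else in the notebook depends on it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# take one batch from the pipeline and show the first image\n",
"images, labels = next(iter(ds_train))\n",
"plt.imshow(images[0].numpy().squeeze(), cmap='gray')\n",
"plt.title(f\"label: {labels[0].numpy()}\")\n",
"plt.show()"
]
},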
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ds_test = ds_test.map(\n",
" normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n",
"ds_test = ds_test.batch(128)\n",
"ds_test = ds_test.cache()\n",
"ds_test = ds_test.prefetch(tf.data.AUTOTUNE)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
" tf.keras.layers.Dense(128, activation='relu'),\n",
" tf.keras.layers.Dense(10)\n",
"])\n",
"model.compile(\n",
" optimizer=tf.keras.optimizers.Adam(0.001),\n",
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n",
")\n",
"\n",
"model.fit(\n",
" ds_train,\n",
" epochs=6,\n",
" validation_data=ds_test,\n",
")\n"
]
},
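{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before exporting the model, it is worth confirming its accuracy on the held-out test set; `model.evaluate` returns the loss followed by the compiled metrics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# evaluate the trained model on the test set\n",
"test_loss, test_acc = model.evaluate(ds_test)\n",
"print(f\"test accuracy: {test_acc:.4f}\")"
]
},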
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os \n",
"\n",
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('key.pk')\n",
"vk_path = os.path.join('key.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"srs_path = os.path.join('kzg.srs')\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(list(ds_train)[0][0].numpy()[0:1].shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"import numpy as np\n",
"import tf2onnx\n",
"import tensorflow as tf\n",
"import json\n",
"\n",
"# After training, export to onnx (network.onnx) and create a data file (input.json)\n",
"x = list(ds_train)[0][0].numpy()[0:1]\n",
"\n",
"spec = tf.TensorSpec([1, 28, 28], tf.float32, name='input_0')\n",
"\n",
"\n",
"tf2onnx.convert.from_keras(model, input_signature=[spec], inputs_as_nchw=['input_0'], opset=12, output_path=model_path)\n",
"\n",
"data_array = x.reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump( data, open(data_path, 'w' ))\n"
]
},
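{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the small sketch below records what the trained Keras model predicts for the sample just written to `input.json`; the proof generated later attests to this same inference."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the plaintext prediction for the exported sample, for later comparison\n",
"logits = model.predict(x)\n",
"print(\"predicted digit:\", np.argmax(logits, axis=-1)[0])"
]
},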
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"private\"\n",
"run_args.param_visibility = \"fixed\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"\n",
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", scales = [0, 7])\n",
"assert res == True\n",
"print(\"verified\")"
]
},
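{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch to inspect the calibrated settings file. The exact contents depend on the ezkl version, so we only list the top-level keys here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# peek at the calibrated settings written by ezkl\n",
"with open(settings_path, 'r') as f:\n",
"    settings = json.load(f)\n",
"print(list(settings.keys()))"
]
},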
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = ezkl.get_srs(srs_path, settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"witness_path = \"witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
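{
"cell_type": "markdown",
"metadata": {},
"source": [
"The witness file holds the quantized inputs and outputs of the circuit execution. A small sketch to list its top-level structure (the keys vary by ezkl version):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# list the top-level structure of the generated witness\n",
"with open(witness_path, 'r') as f:\n",
"    witness = json.load(f)\n",
"print(list(witness.keys()))"
]
},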
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# uncomment to mock prove\n",
"res = ezkl.mock(witness_path, compiled_model_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" srs_path,\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" srs_path,\n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
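{
"cell_type": "markdown",
"metadata": {},
"source": [
"A rough sense of the artifact sizes on disk:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# report the size of the proof and keys on disk\n",
"for p in [proof_path, vk_path, pk_path]:\n",
"    print(f\"{p}: {os.path.getsize(p)} bytes\")"
]
},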
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" srs_path,\n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now create an EVM / `.sol` verifier that can be deployed on chain to verify submitted proofs using a view function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"abi_path = 'test.abi'\n",
"sol_code_path = 'test.sol'\n",
"\n",
"res = ezkl.create_evm_verifier(\n",
" vk_path,\n",
" srs_path,\n",
" settings_path,\n",
" sol_code_path,\n",
" abi_path,\n",
" )\n",
"assert res == True"
]
},
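{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check that the Solidity source and ABI were actually written:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the generated Solidity verifier and its ABI now exist on disk\n",
"assert os.path.isfile(sol_code_path)\n",
"assert os.path.isfile(abi_path)"
]
},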
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Verify on the evm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Make sure anvil is running locally first\n",
"# run with $ anvil -p 3030\n",
"# we use the default anvil node here\n",
"import json\n",
"\n",
"address_path = os.path.join(\"address.json\")\n",
"\n",
"res = ezkl.deploy_evm(\n",
" address_path,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
"\n",
"with open(address_path, 'r') as file:\n",
" addr = file.read().rstrip()"
]
},
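{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the address the verifier contract was deployed to\n",
"print(\"verifier deployed at:\", addr)"
]
},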
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# make sure anvil is running locally\n",
"# $ anvil -p 3030\n",
"\n",
"res = ezkl.verify_evm(\n",
" proof_path,\n",
" addr,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}