{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"Credits to [geohot](https://github.com/geohot/ai-notebooks/blob/master/mnist_gan.ipynb) for most of this code"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"tf2onnx\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"# make sure you have the dependencies required here already installed\n",
"import ezkl\n",
"import os\n",
"import json\n",
"import time\n",
"import random\n",
"import logging\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras.layers import *\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.datasets import mnist\n",
"\n",
"# uncomment for more descriptive logging \n",
"# FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"# logging.basicConfig(format=FORMAT)\n",
"# logging.getLogger().setLevel(logging.INFO)\n",
"\n",
"# Can we build a simple GAN that can produce all 10 mnist digits?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
"x_train, x_test = [x/255.0 for x in [x_train, x_test]]\n",
"y_train, y_test = [tf.keras.utils.to_categorical(x) for x in [y_train, y_test]]"
]
},
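{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (an addition, not part of the original flow): pixels should be scaled to [0, 1] and labels one-hot encoded over 10 classes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# expect (60000, 28, 28) (60000, 10) and (10000, 28, 28) (10000, 10)\n",
"print(x_train.shape, y_train.shape)\n",
"print(x_test.shape, y_test.shape)\n",
"# expect pixel range [0.0, 1.0] after the /255 normalization\n",
"print(x_train.min(), x_train.max())"
]
},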
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"ZDIM = 100\n",
"\n",
"opt = Adam()\n",
"\n",
"\n",
"# discriminator\n",
"# 0 if it's fake, 1 if it's real\n",
"x = in1 = Input((28,28))\n",
"x = Reshape((28,28,1))(x)\n",
"\n",
"x = Conv2D(64, (5,5), padding='same', strides=(2,2))(x)\n",
"x = BatchNormalization()(x)\n",
"x = ELU()(x)\n",
"\n",
"x = Conv2D(128, (5,5), padding='same', strides=(2,2))(x)\n",
"x = BatchNormalization()(x)\n",
"x = ELU()(x)\n",
"\n",
"x = Flatten()(x)\n",
"x = Dense(128)(x)\n",
"x = BatchNormalization()(x)\n",
"x = ELU()(x)\n",
"x = Dense(1, activation='sigmoid')(x)\n",
"dm = Model(in1, x)\n",
"dm.compile(opt, 'binary_crossentropy')\n",
"dm.summary()\n",
"\n",
"# generator, output digits\n",
"x = in1 = Input((ZDIM,))\n",
"\n",
"x = Dense(7*7*64)(x)\n",
"x = BatchNormalization()(x)\n",
"x = ELU()(x)\n",
"x = Reshape((7,7,64))(x)\n",
"\n",
"x = Conv2DTranspose(128, (5,5), strides=(2,2), padding='same')(x)\n",
"x = BatchNormalization()(x)\n",
"x = ELU()(x)\n",
"\n",
"x = Conv2DTranspose(1, (5,5), strides=(2,2), padding='same')(x)\n",
"x = Activation('sigmoid')(x)\n",
"x = Reshape((28,28))(x)\n",
"\n",
"gm = Model(in1, x)\n",
"gm.compile('adam', 'mse')\n",
"gm.output_names=['output']\n",
"gm.summary()\n",
"\n",
"opt = Adam()\n",
"\n",
"# GAN\n",
"dm.trainable = False\n",
"x = dm(gm.output)\n",
"tm = Model(gm.input, x)\n",
"tm.compile(opt, 'binary_crossentropy')\n",
"\n",
"dlosses, glosses = [], []"
]
},
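{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training, it's worth confirming the two networks compose (a sanity check added here, not in the original): a batch of latent vectors should map to 28x28 images, and the discriminator should map those images to a single score in [0, 1]."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"z = np.random.normal(size=(4, ZDIM)).astype('float32')\n",
"imgs = gm.predict(z, verbose=0)\n",
"scores = dm.predict(imgs, verbose=0)\n",
"print(imgs.shape)   # expect (4, 28, 28)\n",
"print(scores.shape) # expect (4, 1), values in [0, 1]"
]
},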
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"from matplotlib.pyplot import figure, imshow, show\n",
"\n",
"BS = 256\n",
"\n",
"# GAN training loop\n",
"# make larger if you want it to look better\n",
"for i in range(1):\n",
" # train discriminator\n",
" dm.trainable = True\n",
" real_i = x_train[np.random.choice(x_train.shape[0], BS)]\n",
" fake_i = gm.predict_on_batch(np.random.normal(0,1,size=(BS,ZDIM)))\n",
" dloss_r = dm.train_on_batch(real_i, np.ones(BS))\n",
" dloss_f = dm.train_on_batch(fake_i, np.zeros(BS))\n",
" dloss = (dloss_r + dloss_f)/2\n",
"\n",
" # train generator\n",
" dm.trainable = False\n",
" gloss_0 = tm.train_on_batch(np.random.normal(0,1,size=(BS,ZDIM)), np.ones(BS))\n",
" gloss_1 = tm.train_on_batch(np.random.normal(0,1,size=(BS,ZDIM)), np.ones(BS))\n",
" gloss = (gloss_0 + gloss_1)/2\n",
"\n",
" if i%50 == 0:\n",
" print(\"%4d: dloss:%8.4f gloss:%8.4f\" % (i, dloss, gloss))\n",
" dlosses.append(dloss)\n",
" glosses.append(gloss)\n",
" \n",
" if i%250 == 0:\n",
" \n",
" figure(figsize=(16,16))\n",
" imshow(np.concatenate(gm.predict(np.random.normal(size=(10,ZDIM))), axis=1))\n",
" show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib.pyplot import plot, legend\n",
"figure(figsize=(8,8))\n",
"plot(dlosses[100:], label=\"Discriminator Loss\")\n",
"plot(glosses[100:], label=\"Generator Loss\")\n",
"legend()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = []\n",
"for i in range(10):\n",
" x.append(np.concatenate(gm.predict(np.random.normal(size=(10,ZDIM))), axis=1))\n",
"imshow(np.concatenate(x, axis=0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os \n",
"\n",
"model_path = os.path.join('gan.onnx')\n",
"compiled_model_path = os.path.join('gan.compiled')\n",
"pk_path = os.path.join('gan.pk')\n",
"vk_path = os.path.join('gan.vk')\n",
"settings_path = os.path.join('gan_settings.json')\n",
"srs_path = os.path.join('gan_kzg.srs')\n",
"witness_path = os.path.join('gan_witness.json')\n",
"data_path = os.path.join('gan_input.json')\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we export the _generator_ to onnx"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"import numpy as np\n",
"import tf2onnx\n",
"import tensorflow as tf\n",
"import json\n",
"\n",
"shape = [1, ZDIM]\n",
"# After training, export to onnx (network.onnx) and create a data file (input.json)\n",
"x = 0.1*np.random.rand(1,*shape)\n",
"\n",
"spec = tf.TensorSpec(shape, tf.float32, name='input_0')\n",
"\n",
"\n",
"tf2onnx.convert.from_keras(gm, input_signature=[spec], inputs_as_nchw=['input_0'], opset=12, output_path=model_path)\n",
"\n",
"data_array = x.reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump( data, open(data_path, 'w' ))\n"
]
},
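{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, validate the exported graph with the `onnx` checker (already installed above). This is a lightweight structural check of the file, not an inference test; the input/output names below assume the `input_0`/`output` names set during export."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import onnx\n",
"\n",
"onnx_model = onnx.load(model_path)\n",
"onnx.checker.check_model(onnx_model)\n",
"print([i.name for i in onnx_model.graph.input])  # should include 'input_0'\n",
"print([o.name for o in onnx_model.graph.output]) # should include 'output'"
]
},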
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"private\"\n",
"run_args.param_visibility = \"fixed\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"\n",
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (0.2 * np.random.rand(20, *shape)).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"\n",
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[0,6])"
]
},
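{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Calibration rewrites the settings file with the scales it selected. Peeking at the raw JSON shows what it chose; the exact keys vary across ezkl versions, so treat this as exploratory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"settings = json.load(open(settings_path))\n",
"print(list(settings.keys()))\n",
"# 'run_args' typically records the chosen scales and visibility flags\n",
"print(settings.get('run_args'))"
]
},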
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file \n",
"witness_path = \"gan_witness.json\"\n",
"\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
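{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The witness records the quantized inputs and outputs of the circuit run. A quick look at its top-level keys (these are version-dependent, so this is just a peek) confirms it was produced correctly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"witness = json.load(open(witness_path))\n",
"print(list(witness.keys()))"
]
},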
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# uncomment to mock prove\n",
"# res = ezkl.mock(witness_path, compiled_model_path)\n",
"# assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# HERE WE SETUP THE CIRCUIT PARAMS\n",
"# WE GOT KEYS\n",
"# WE GOT CIRCUIT PARAMETERS\n",
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
"\n",
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# GENERATE A PROOF\n",
"\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
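{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"For a rough sense of proof cost, check the size of the proof artifact on disk (the internal structure of the file varies by ezkl version, so we only report its size here)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print('proof size: %d bytes' % os.path.getsize(proof_path))"
]
},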
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# VERIFY IT\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}