mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-08 22:08:07 -05:00
428 lines
16 KiB
Plaintext
428 lines
16 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"\n",
|
|
"Credits to [geohot](https://github.com/geohot/ai-notebooks/blob/master/mnist_gan.ipynb) for most of this code"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# check if notebook is in colab\n",
|
|
"try:\n",
|
|
" # importing google.colab only succeeds inside Colab; when it does, install ezkl plus the ONNX export dependencies\n",
|
|
" import google.colab\n",
|
|
" import subprocess\n",
|
|
" import sys\n",
|
|
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
|
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"tf2onnx\"])\n",
|
|
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
|
"\n",
|
|
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
|
"except:\n",
|
|
" pass\n",
|
|
"\n",
|
|
"# make sure you have the dependencies required here already installed\n",
|
|
"import ezkl\n",
|
|
"import os\n",
|
|
"import json\n",
|
|
"import time\n",
|
|
"import random\n",
|
|
"import logging\n",
|
|
"\n",
|
|
"import tensorflow as tf\n",
|
|
"from tensorflow.keras.optimizers import Adam\n",
|
|
"from tensorflow.keras.layers import *\n",
|
|
"from tensorflow.keras.models import Model\n",
|
|
"from tensorflow.keras.datasets import mnist\n",
|
|
"\n",
|
|
"# uncomment for more descriptive logging \n",
|
|
"# FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
|
|
"# logging.basicConfig(format=FORMAT)\n",
|
|
"# logging.getLogger().setLevel(logging.INFO)\n",
|
|
"\n",
|
|
"# Can we build a simple GAN that can produce all 10 MNIST digits?"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
|
|
"x_train, x_test = [x/255.0 for x in [x_train, x_test]]\n",
|
|
"y_train, y_test = [tf.keras.utils.to_categorical(x) for x in [y_train, y_test]]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"scrolled": false
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"ZDIM = 100\n",
|
|
"\n",
|
|
"opt = Adam()\n",
|
|
"\n",
|
|
"\n",
|
|
"# discriminator\n",
|
|
"# 0 if it's fake, 1 if it's real\n",
|
|
"x = in1 = Input((28,28))\n",
|
|
"x = Reshape((28,28,1))(x)\n",
|
|
"\n",
|
|
"x = Conv2D(64, (5,5), padding='same', strides=(2,2))(x)\n",
|
|
"x = BatchNormalization()(x)\n",
|
|
"x = ELU()(x)\n",
|
|
"\n",
|
|
"x = Conv2D(128, (5,5), padding='same', strides=(2,2))(x)\n",
|
|
"x = BatchNormalization()(x)\n",
|
|
"x = ELU()(x)\n",
|
|
"\n",
|
|
"x = Flatten()(x)\n",
|
|
"x = Dense(128)(x)\n",
|
|
"x = BatchNormalization()(x)\n",
|
|
"x = ELU()(x)\n",
|
|
"x = Dense(1, activation='sigmoid')(x)\n",
|
|
"dm = Model(in1, x)\n",
|
|
"dm.compile(opt, 'binary_crossentropy')\n",
|
|
"dm.summary()\n",
|
|
"\n",
|
|
"# generator, output digits\n",
|
|
"x = in1 = Input((ZDIM,))\n",
|
|
"\n",
|
|
"x = Dense(7*7*64)(x)\n",
|
|
"x = BatchNormalization()(x)\n",
|
|
"x = ELU()(x)\n",
|
|
"x = Reshape((7,7,64))(x)\n",
|
|
"\n",
|
|
"x = Conv2DTranspose(128, (5,5), strides=(2,2), padding='same')(x)\n",
|
|
"x = BatchNormalization()(x)\n",
|
|
"x = ELU()(x)\n",
|
|
"\n",
|
|
"x = Conv2DTranspose(1, (5,5), strides=(2,2), padding='same')(x)\n",
|
|
"x = Activation('sigmoid')(x)\n",
|
|
"x = Reshape((28,28))(x)\n",
|
|
"\n",
|
|
"gm = Model(in1, x)\n",
|
|
"gm.compile('adam', 'mse')\n",
|
|
"gm.output_names=['output']\n",
|
|
"gm.summary()\n",
|
|
"\n",
|
|
"opt = Adam()\n",
|
|
"\n",
|
|
"# GAN\n",
|
|
"dm.trainable = False\n",
|
|
"x = dm(gm.output)\n",
|
|
"tm = Model(gm.input, x)\n",
|
|
"tm.compile(opt, 'binary_crossentropy')\n",
|
|
"\n",
|
|
"dlosses, glosses = [], []"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"scrolled": false
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"from matplotlib.pyplot import figure, imshow, show\n",
|
|
"\n",
|
|
"BS = 256\n",
|
|
"\n",
|
|
"# GAN training loop\n",
|
|
"# increase the iteration count in range(...) below if you want better-looking digits; a single iteration only smoke-tests the loop\n",
|
|
"for i in range(1):\n",
|
|
" # train discriminator\n",
|
|
" dm.trainable = True\n",
|
|
" real_i = x_train[np.random.choice(x_train.shape[0], BS)]\n",
|
|
" fake_i = gm.predict_on_batch(np.random.normal(0,1,size=(BS,ZDIM)))\n",
|
|
" dloss_r = dm.train_on_batch(real_i, np.ones(BS))\n",
|
|
" dloss_f = dm.train_on_batch(fake_i, np.zeros(BS))\n",
|
|
" dloss = (dloss_r + dloss_f)/2\n",
|
|
"\n",
|
|
" # train generator\n",
|
|
" dm.trainable = False\n",
|
|
" gloss_0 = tm.train_on_batch(np.random.normal(0,1,size=(BS,ZDIM)), np.ones(BS))\n",
|
|
" gloss_1 = tm.train_on_batch(np.random.normal(0,1,size=(BS,ZDIM)), np.ones(BS))\n",
|
|
" gloss = (gloss_0 + gloss_1)/2\n",
|
|
"\n",
|
|
" if i%50 == 0:\n",
|
|
" print(\"%4d: dloss:%8.4f gloss:%8.4f\" % (i, dloss, gloss))\n",
|
|
" dlosses.append(dloss)\n",
|
|
" glosses.append(gloss)\n",
|
|
" \n",
|
|
" if i%250 == 0:\n",
|
|
" \n",
|
|
" figure(figsize=(16,16))\n",
|
|
" imshow(np.concatenate(gm.predict(np.random.normal(size=(10,ZDIM))), axis=1))\n",
|
|
" show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from matplotlib.pyplot import plot, legend\n",
|
|
"figure(figsize=(8,8))\n",
|
|
"plot(dlosses[100:], label=\"Discriminator Loss\")\n",
|
|
"plot(glosses[100:], label=\"Generator Loss\")\n",
|
|
"legend()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"x = []\n",
|
|
"for i in range(10):\n",
|
|
" x.append(np.concatenate(gm.predict(np.random.normal(size=(10,ZDIM))), axis=1))\n",
|
|
"imshow(np.concatenate(x, axis=0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os \n",
|
|
"\n",
|
|
"model_path = os.path.join('gan.onnx')\n",
|
|
"compiled_model_path = os.path.join('gan.compiled')\n",
|
|
"pk_path = os.path.join('gan.pk')\n",
|
|
"vk_path = os.path.join('gan.vk')\n",
|
|
"settings_path = os.path.join('gan_settings.json')\n",
|
|
"srs_path = os.path.join('gan_kzg.srs')\n",
|
|
"witness_path = os.path.join('gan_witness.json')\n",
|
|
"data_path = os.path.join('gan_input.json')\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now we export the _generator_ to ONNX."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"import numpy as np\n",
|
|
"import tf2onnx\n",
|
|
"import tensorflow as tf\n",
|
|
"import json\n",
|
|
"\n",
|
|
"shape = [1, ZDIM]\n",
|
|
"# After training, export the generator to ONNX (model_path, i.e. gan.onnx) and create an input data file (data_path, i.e. gan_input.json)\n",
|
|
"x = 0.1*np.random.rand(1,*shape)\n",
|
|
"\n",
|
|
"spec = tf.TensorSpec(shape, tf.float32, name='input_0')\n",
|
|
"\n",
|
|
"\n",
|
|
"tf2onnx.convert.from_keras(gm, input_signature=[spec], inputs_as_nchw=['input_0'], opset=12, output_path=model_path)\n",
|
|
"\n",
|
|
"data_array = x.reshape([-1]).tolist()\n",
|
|
"\n",
|
|
"data = dict(input_data = [data_array])\n",
|
|
"\n",
|
|
"# Serialize data into file:\n",
|
|
"json.dump( data, open(data_path, 'w' ))\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import ezkl\n",
|
|
"\n",
|
|
"run_args = ezkl.PyRunArgs()\n",
|
|
"run_args.input_visibility = \"private\"\n",
|
|
"run_args.param_visibility = \"fixed\"\n",
|
|
"run_args.output_visibility = \"public\"\n",
|
|
"run_args.variables = [(\"batch_size\", 1)]\n",
|
|
"\n",
|
|
"!RUST_LOG=trace\n",
|
|
"# TODO: Dictionary outputs\n",
|
|
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
|
|
"assert res == True\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"cal_path = os.path.join(\"calibration.json\")\n",
|
|
"\n",
|
|
"data_array = (0.2 * np.random.rand(20, *shape)).reshape([-1]).tolist()\n",
|
|
"\n",
|
|
"data = dict(input_data = [data_array])\n",
|
|
"\n",
|
|
"# Serialize data into file:\n",
|
|
"json.dump(data, open(cal_path, 'w'))\n",
|
|
"\n",
|
|
"\n",
|
|
"ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", scales=[0,6])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
|
|
"assert res == True"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# fetch the SRS (structured reference string); get_srs is driven by the settings file\n",
|
|
"res = await ezkl.get_srs( settings_path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# now generate the witness file \n",
|
|
"witness_path = \"gan_witness.json\"\n",
|
|
"\n",
|
|
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
|
|
"assert os.path.isfile(witness_path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# uncomment to mock prove\n",
|
|
"# res = ezkl.mock(witness_path, compiled_model_path)\n",
|
|
"# assert res == True"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"# HERE WE SET UP THE CIRCUIT PARAMS\n",
|
|
"# WE GOT KEYS\n",
|
|
"# WE GOT CIRCUIT PARAMETERS\n",
|
|
"# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK\n",
|
|
"\n",
|
|
"res = ezkl.setup(\n",
|
|
" compiled_model_path,\n",
|
|
" vk_path,\n",
|
|
" pk_path,\n",
|
|
" \n",
|
|
" )\n",
|
|
"\n",
|
|
"assert res == True\n",
|
|
"assert os.path.isfile(vk_path)\n",
|
|
"assert os.path.isfile(pk_path)\n",
|
|
"assert os.path.isfile(settings_path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# GENERATE A PROOF\n",
|
|
"\n",
|
|
"\n",
|
|
"proof_path = os.path.join('test.pf')\n",
|
|
"\n",
|
|
"res = ezkl.prove(\n",
|
|
" witness_path,\n",
|
|
" compiled_model_path,\n",
|
|
" pk_path,\n",
|
|
" proof_path,\n",
|
|
" \n",
|
|
" \"single\",\n",
|
|
" )\n",
|
|
"\n",
|
|
"print(res)\n",
|
|
"assert os.path.isfile(proof_path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# VERIFY IT\n",
|
|
"res = ezkl.verify(\n",
|
|
" proof_path,\n",
|
|
" settings_path,\n",
|
|
" vk_path,\n",
|
|
" \n",
|
|
" )\n",
|
|
"\n",
|
|
"assert res == True\n",
|
|
"print(\"verified\")"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.2"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
} |