{
"cells": [
{
"cell_type": "markdown",
"id": "d0a82619",
"metadata": {},
"source": [
"Credits to [geohot](https://github.com/geohot/ai-notebooks/blob/master/mnist_gan.ipynb) for most of this code\n",
"\n",
"## Model Architecture and training"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c22afe46",
"metadata": {},
"outputs": [],
"source": [
"%pip install pytorch_lightning\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12fb79a8",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import math\n",
"import numpy as np\n",
"\n",
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"\n",
"import pytorch_lightning as pl\n",
"\n",
"# check if notebook is in colab\n",
"try:\n",
"    # install ezkl\n",
"    import google.colab\n",
"    import subprocess\n",
"    import sys\n",
"    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
"    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
"    pass\n",
"\n",
"\n",
"# uncomment for more descriptive logging\n",
"# import logging\n",
"# FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"# logging.basicConfig(format=FORMAT)\n",
"# logging.getLogger().setLevel(logging.DEBUG)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8638e94e",
"metadata": {},
"outputs": [],
"source": [
"class BaseDataModule(pl.LightningDataModule):\n",
"    def __init__(self, batch_size=32, split=0.8, *args, **kwargs):\n",
"        super().__init__()\n",
"        self.ds_X, self.ds_Y = self.get_dataset(*args, **kwargs)\n",
"        self.split = int(self.ds_X.shape[0]*split)\n",
"        self.batch_size = batch_size\n",
"\n",
"    def train_dataloader(self):\n",
"        ds_X_train, ds_Y_train = self.ds_X[0:self.split], self.ds_Y[0:self.split]\n",
"        return torch.utils.data.DataLoader(list(zip(ds_X_train, ds_Y_train)), batch_size=self.batch_size)\n",
"\n",
"    def val_dataloader(self):\n",
"        ds_X_test, ds_Y_test = self.ds_X[self.split:], self.ds_Y[self.split:]\n",
"        return torch.utils.data.DataLoader(list(zip(ds_X_test, ds_Y_test)), batch_size=self.batch_size)\n",
"\n",
"class ReverseDataModule(BaseDataModule):\n",
"    def get_dataset(self, cnt=10000, seq_len=6):\n",
"        ds = np.random.randint(0, 10, size=(cnt, seq_len))\n",
"        return ds, ds[:, ::-1].ravel().reshape(cnt, seq_len)\n",
"\n",
"# dataset idea from https://github.com/karpathy/minGPT/blob/master/play_math.ipynb\n",
"class AdditionDataModule(BaseDataModule):\n",
"    def get_dataset(self):\n",
"        ret = []\n",
"        for i in range(100):\n",
"            for j in range(100):\n",
"                s = i+j\n",
"                ret.append([i//10, i%10, j//10, j%10, s//100, (s//10)%10, s%10])\n",
"        ds = np.array(ret)\n",
"        return ds[:, 0:6], np.copy(ds[:, 1:])\n",
"\n",
"# this is the hardest to learn and requires 4 layers\n",
"class ParityDataModule(BaseDataModule):\n",
"    def get_dataset(self, seq_len=10):\n",
"        ds_X, ds_Y = [], []\n",
"        for i in range(2**seq_len):\n",
"            x = [int(x) for x in list(bin(i)[2:].rjust(seq_len, '0'))]\n",
"            ds_X.append(x)\n",
"            ds_Y.append((np.cumsum(x)%2).tolist())\n",
"        return np.array(ds_X), np.array(ds_Y)\n",
"\n",
"class WikipediaDataModule(BaseDataModule):\n",
"    def get_dataset(self, seq_len=50):\n",
"        global enwik8\n",
"        if 'enwik8' not in globals():\n",
"            import requests\n",
"            enwik8_zipped = requests.get(\"https://data.deepai.org/enwik8.zip\").content\n",
"            from zipfile import ZipFile\n",
"            import io\n",
"            enwik8 = ZipFile(io.BytesIO(enwik8_zipped)).read('enwik8')\n",
"        en = np.frombuffer(enwik8, dtype=np.uint8).astype(np.int64)\n",
" en = en[0:-seq_len+1]\n",
|
|
" en[en>127] = 127\n",
|
|
" return en[0:-1].reshape(-1, seq_len), en[1:].reshape(-1, seq_len)"
|
|
]
|
|
},
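{
"cell_type": "markdown",
"id": "0c1de2aa",
"metadata": {},
"source": [
"A quick, optional sanity check (not part of the original notebook): in `AdditionDataModule`, each input row holds the digits `[i//10, i%10, j//10, j%10, s//100, (s//10)%10]` and the target is the same row shifted left by one, ending in `s%10`, so the model is asked to predict the digits of the sum. The index `12*100 + 34` below is just an illustrative example."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0c1de2ab",
"metadata": {},
"outputs": [],
"source": [
"# illustrative only: inspect one encoded addition example (12 + 34 = 46)\n",
"demo = AdditionDataModule(batch_size=4)\n",
"idx = 12*100 + 34  # rows are generated in i-major order\n",
"print(\"input digits: \", demo.ds_X[idx])\n",
"print(\"target digits:\", demo.ds_Y[idx])\n"
]
},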
{
"cell_type": "code",
"execution_count": null,
"id": "323554ca",
"metadata": {},
"outputs": [],
"source": [
"def attention(queries, keys, values):\n",
"    d = queries.shape[-1]\n",
"    scores = torch.matmul(queries, keys.transpose(-2,-1))/math.sqrt(d)\n",
"    attention_weights = F.softmax(scores, dim=-1)\n",
"    return torch.matmul(attention_weights, values)\n",
"\n",
"class MultiHeadAttention(nn.Module):\n",
"    def __init__(self, embed_dim, num_heads):\n",
"        super(MultiHeadAttention, self).__init__()\n",
"        self.embed_dim, self.num_heads = embed_dim, num_heads\n",
"        assert embed_dim % num_heads == 0\n",
"        self.projection_dim = embed_dim // num_heads\n",
"\n",
"        self.W_q = nn.Linear(embed_dim, embed_dim)\n",
"        self.W_k = nn.Linear(embed_dim, embed_dim)\n",
"        self.W_v = nn.Linear(embed_dim, embed_dim)\n",
"        self.W_o = nn.Linear(embed_dim, embed_dim)\n",
"\n",
"    def transpose(self, x):\n",
"        x = x.reshape(x.shape[0], x.shape[1], self.num_heads, self.projection_dim)\n",
"        return x.permute(0, 2, 1, 3)\n",
"\n",
"    def transpose_output(self, x):\n",
"        x = x.permute(0, 2, 1, 3)\n",
"        return x.reshape(x.shape[0], x.shape[1], self.embed_dim)\n",
"\n",
"    def forward(self, q, k, v):\n",
"        q = self.transpose(self.W_q(q))\n",
"        k = self.transpose(self.W_k(k))\n",
"        v = self.transpose(self.W_v(v))\n",
"        output = attention(q, k, v)\n",
"        return self.W_o(self.transpose_output(output))\n",
"\n",
"class TransformerBlock(nn.Module):\n",
"    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):\n",
"        super(TransformerBlock, self).__init__()\n",
"        self.att = MultiHeadAttention(embed_dim, num_heads)\n",
"        self.ffn = nn.Sequential(\n",
"            nn.Linear(embed_dim, ff_dim), nn.ReLU(), nn.Linear(ff_dim, embed_dim)\n",
"        )\n",
"        self.layernorm1 = nn.LayerNorm(embed_dim)\n",
"        self.layernorm2 = nn.LayerNorm(embed_dim)\n",
"        self.dropout = nn.Dropout(rate)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.layernorm1(x + self.dropout(self.att(x, x, x)))\n",
"        x = self.layernorm2(x + self.dropout(self.ffn(x)))\n",
"        return x\n",
"\n",
"class TokenAndPositionEmbedding(nn.Module):\n",
"    def __init__(self, maxlen, vocab_size, embed_dim):\n",
"        super(TokenAndPositionEmbedding, self).__init__()\n",
"        self.token_emb = nn.Embedding(vocab_size, embed_dim)\n",
"        self.pos_emb = nn.Embedding(maxlen, embed_dim)\n",
"    def forward(self, x):\n",
"        pos = torch.arange(0, x.size(1), dtype=torch.int32, device=x.device)\n",
"        return self.token_emb(x) + self.pos_emb(pos).view(1, x.size(1), -1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "167e42e3",
"metadata": {},
"outputs": [],
"source": [
"class LittleTransformer(pl.LightningModule):\n",
"    def __init__(self, seq_len=6, max_value=10, layer_count=2, embed_dim=128, num_heads=4, ff_dim=32):\n",
"        super().__init__()\n",
"        self.max_value = max_value\n",
"        self.model = nn.Sequential(\n",
"            TokenAndPositionEmbedding(seq_len, max_value, embed_dim),\n",
"            *[TransformerBlock(embed_dim, num_heads, ff_dim) for x in range(layer_count)],\n",
"            nn.Linear(embed_dim, max_value),\n",
"            nn.LogSoftmax(dim=-1))\n",
"\n",
"    def forward(self, x):\n",
"        return self.model(x)\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        x, y = batch\n",
"        output = self.model(x)\n",
"        loss = F.nll_loss(output.view(-1, self.max_value), y.view(-1))\n",
"        self.log(\"train_loss\", loss)\n",
"        return loss\n",
"\n",
"    def validation_step(self, val_batch, batch_idx):\n",
"        x, y = val_batch\n",
"        pred = self.model(x).argmax(dim=2)\n",
"        val_accuracy = (pred == y).type(torch.float).mean()\n",
"        self.log(\"val_accuracy\", val_accuracy, prog_bar=True)\n",
"\n",
"    def configure_optimizers(self):\n",
"        if self.device.type == 'cuda':\n",
"            import apex\n",
"            return apex.optimizers.FusedAdam(self.parameters(), lr=3e-4)\n",
"        else:\n",
"            return torch.optim.Adam(self.parameters(), lr=3e-4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a2f48c98",
"metadata": {},
"outputs": [],
"source": [
"model = LittleTransformer(seq_len=6)\n",
"# max_epochs=0 keeps this demo fast; increase it if you actually want to train the model\n",
"trainer = pl.Trainer(enable_progress_bar=True, max_epochs=0)\n",
"data = AdditionDataModule(batch_size=64)\n",
"# data = ReverseDataModule(cnt=1000, seq_len=20)\n",
"# data = ParityDataModule(seq_len=14)\n",
"trainer.fit(model, data)"
]
},
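{
"cell_type": "markdown",
"id": "1f2a3b4c",
"metadata": {},
"source": [
"Another optional check (not part of the original notebook): run the model on one encoded example and decode its argmax predictions. With `max_epochs=0` the weights are untrained, so the digits will look random; after real training they should match the target row."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f2a3b4d",
"metadata": {},
"outputs": [],
"source": [
"# illustrative only: greedy predictions for the 12 + 34 example\n",
"example = torch.tensor(data.ds_X[1234:1235], dtype=torch.long)\n",
"with torch.no_grad():\n",
"    pred = model(example).argmax(dim=-1)\n",
"print(\"predicted digits:\", pred[0].tolist())\n",
"print(\"target digits:   \", data.ds_Y[1234].tolist())\n"
]
},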
{
"cell_type": "markdown",
"id": "fa7d277e",
"metadata": {},
"source": [
"## EZKL\n",
"\n",
"The cells below export the model to ONNX and then walk through the ezkl pipeline: generate and calibrate settings, compile the circuit, fetch the SRS, generate a witness, run a mock proof, set up the proving and verifying keys, prove, and verify."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f339a28",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27ce542b",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"shape = [1, 6]\n",
"# After training, export to onnx (network.onnx) and create a data file (input.json)\n",
"x = torch.zeros(shape, dtype=torch.long)\n",
"x = x.reshape(shape)\n",
"\n",
"print(x)\n",
"\n",
"# Flips the neural net into inference mode\n",
"model.eval()\n",
"model.to('cpu')\n",
"\n",
"# Export the model\n",
"torch.onnx.export(model,                     # model being run\n",
"                  x,                         # model input (or a tuple for multiple inputs)\n",
"                  model_path,                # where to save the model (can be a file or file-like object)\n",
"                  export_params=True,        # store the trained parameter weights inside the model file\n",
"                  opset_version=10,          # the ONNX version to export the model to\n",
"                  do_constant_folding=True,  # whether to execute constant folding for optimization\n",
"                  input_names = ['input'],   # the model's input names\n",
"                  output_names = ['output'], # the model's output names\n",
"                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes\n",
"                                'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data_json = dict(input_data = [data_array])\n",
"\n",
"print(data_json)\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data_json, open(data_path, 'w'))\n"
]
},
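{
"cell_type": "markdown",
"id": "9a8b7c6d",
"metadata": {},
"source": [
"Optional sanity check (not part of the original notebook): load the exported graph with the `onnx` package (installed above in the Colab branch; install it locally if needed) and run the checker before handing the file to ezkl."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a8b7c6e",
"metadata": {},
"outputs": [],
"source": [
"# illustrative only: verify the exported ONNX graph is well formed\n",
"import onnx\n",
"onnx_model = onnx.load(model_path)\n",
"onnx.checker.check_model(onnx_model)\n",
"print(\"ops in graph:\", sorted({n.op_type for n in onnx_model.graph.node}))\n"
]
},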
{
"cell_type": "code",
"execution_count": null,
"id": "36ddc6f9",
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"# for verbose logging from the ezkl backend, set the RUST_LOG environment variable (e.g. os.environ[\"RUST_LOG\"] = \"trace\") before running these cells\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2fe6d972",
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"data_array = (torch.randn(20, *shape).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(cal_path, 'w'))\n",
"\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
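{
"cell_type": "markdown",
"id": "7d6e5f4a",
"metadata": {},
"source": [
"Optional (not part of the original notebook): `settings.json` is plain JSON, so it can be printed to see what calibration chose (quantization scales, circuit size parameters, and so on) before compiling."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7d6e5f4b",
"metadata": {},
"outputs": [],
"source": [
"# illustrative only: peek at the calibrated settings\n",
"with open(settings_path) as f:\n",
"    print(json.dumps(json.load(f), indent=2)[:2000])\n"
]
},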
{
"cell_type": "code",
"execution_count": null,
"id": "0990f5a8",
"metadata": {},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b80dc01",
"metadata": {},
"outputs": [],
"source": [
"# fetch the structured reference string (SRS) needed for proving\n",
"res = await ezkl.get_srs(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54cbde29",
"metadata": {},
"outputs": [],
"source": [
"# now generate the witness file\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
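{
"cell_type": "markdown",
"id": "3c2b1a0f",
"metadata": {},
"source": [
"Optional (not part of the original notebook): the witness is also a JSON file, so its top-level structure can be inspected directly."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c2b1a1f",
"metadata": {},
"outputs": [],
"source": [
"# illustrative only: list the top-level keys of the generated witness\n",
"with open(witness_path) as f:\n",
"    print(list(json.load(f).keys()))\n"
]
},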
{
"cell_type": "code",
"execution_count": null,
"id": "28760638",
"metadata": {},
"outputs": [],
"source": [
"# sanity-check the circuit constraints with a mock proof before generating a real one\n",
"res = ezkl.mock(witness_path, compiled_model_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e595112",
"metadata": {},
"outputs": [],
"source": [
"# setup: generate the proving and verifying keys for the compiled circuit\n",
"res = ezkl.setup(\n",
"    compiled_model_path,\n",
"    vk_path,\n",
"    pk_path,\n",
")\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d37adaef",
"metadata": {},
"outputs": [],
"source": [
"# generate a proof\n",
"\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
"    witness_path,\n",
"    compiled_model_path,\n",
"    pk_path,\n",
"    proof_path,\n",
"    \"single\",\n",
")\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
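{
"cell_type": "markdown",
"id": "5e4d3c2b",
"metadata": {},
"source": [
"Optional (not part of the original notebook): a rough sense of proof size from the file on disk."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e4d3c2c",
"metadata": {},
"outputs": [],
"source": [
"# illustrative only: report the size of the serialized proof\n",
"print(f\"proof file: {os.path.getsize(proof_path)} bytes\")\n"
]
},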
{
"cell_type": "code",
"execution_count": null,
"id": "5b58acd5",
"metadata": {},
"outputs": [],
"source": [
"# verify the proof\n",
"res = ezkl.verify(\n",
"    proof_path,\n",
"    settings_path,\n",
"    vk_path,\n",
")\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}