Files
ezkl/examples/notebooks/little_transformer.ipynb
2025-06-27 22:58:10 +02:00

493 lines
20 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "d0a82619",
"metadata": {},
"source": [
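"Credits to [geohot](https://github.com/geohot/ai-notebooks/blob/master/mnist_gan.ipynb) for most of this code.\n",
"\n",
"This notebook defines a small transformer and its PyTorch Lightning training loop, exports the model to ONNX, and then uses EZKL to compile it into a circuit and to create and verify a zero-knowledge proof of its inference.\n",
"\n",
"## Model Architecture and training"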
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c22afe46",
"metadata": {},
"outputs": [],
"source": [
"# PyTorch Lightning provides the LightningModule / Trainer training loop used below.\n",
"# NOTE(review): unpinned; consider pinning a version for reproducible runs.\n",
"%pip install pytorch_lightning\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12fb79a8",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import math\n",
"import numpy as np\n",
"\n",
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"\n",
"import pytorch_lightning as pl\n",
"\n",
"# Seed every RNG the notebook touches (dataset generation, weight init, shuffling).\n",
"SEED = 0\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"\n",
"# Detect Google Colab: only there do we install ezkl/onnx from inside the notebook.\n",
"try:\n",
"    import google.colab  # noqa: F401 -- imported only to detect the runtime\n",
"    IN_COLAB = True\n",
"except ImportError:\n",
"    IN_COLAB = False\n",
"\n",
"if IN_COLAB:\n",
"    import subprocess\n",
"    import sys\n",
"    # Install into the running kernel's interpreter. A failed install raises\n",
"    # CalledProcessError here instead of being silently swallowed.\n",
"    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
"    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"# Outside Colab, rely on a local installation of ezkl and onnx.\n",
"\n",
"\n",
"# uncomment for more descriptive logging \n",
"# import logging\n",
"# FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"# logging.basicConfig(format=FORMAT)\n",
"# logging.getLogger().setLevel(logging.DEBUG)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8638e94e",
"metadata": {},
"outputs": [],
"source": [
"class BaseDataModule(pl.LightningDataModule):\n",
"    \"\"\"Hold an (X, Y) dataset in memory and serve train/validation DataLoaders.\n",
"\n",
"    Subclasses implement get_dataset(*args, **kwargs) -> (ds_X, ds_Y), two arrays\n",
"    with the same number of rows; extra positional/keyword arguments are forwarded\n",
"    to get_dataset.\n",
"\n",
"    Args:\n",
"        batch_size: batch size for both loaders.\n",
"        split: fraction of rows used for training; the remainder is validation.\n",
"        shuffle: permute the rows (with a fixed seed) before splitting. The\n",
"            generated datasets are enumerated in a systematic order (e.g. the\n",
"            addition table walks i and j from 0 to 99), so a contiguous split would\n",
"            hold whole value ranges out of training. shuffle=False restores the\n",
"            original ordered split.\n",
"        seed: seed for that permutation, so the split is reproducible.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, batch_size=32, split=0.8, *args, shuffle=True, seed=0, **kwargs):\n",
"        super().__init__()\n",
"        ds_X, ds_Y = self.get_dataset(*args, **kwargs)\n",
"        if shuffle:\n",
"            order = np.random.default_rng(seed).permutation(ds_X.shape[0])\n",
"            ds_X, ds_Y = ds_X[order], ds_Y[order]\n",
"        self.ds_X, self.ds_Y = ds_X, ds_Y\n",
"        self.split = int(self.ds_X.shape[0]*split)\n",
"        self.batch_size = batch_size\n",
"\n",
"    def train_dataloader(self):\n",
"        ds_X_train, ds_Y_train = self.ds_X[0:self.split], self.ds_Y[0:self.split]\n",
"        # reshuffle the training rows every epoch\n",
"        return torch.utils.data.DataLoader(list(zip(ds_X_train, ds_Y_train)), batch_size=self.batch_size, shuffle=True)\n",
"\n",
"    def val_dataloader(self):\n",
"        ds_X_test, ds_Y_test = self.ds_X[self.split:], self.ds_Y[self.split:]\n",
"        return torch.utils.data.DataLoader(list(zip(ds_X_test, ds_Y_test)), batch_size=self.batch_size)\n",
"\n",
"class ReverseDataModule(BaseDataModule):\n",
"    def get_dataset(self, cnt=10000, seq_len=6):\n",
"        \"\"\"cnt random digit sequences of length seq_len; each target row is its input reversed.\"\"\"\n",
"        sequences = np.random.randint(0, 10, size=(cnt, seq_len))\n",
"        # [:, ::-1] is a negatively-strided view, which torch refuses to wrap;\n",
"        # materialise it as a fresh C-contiguous array.\n",
"        reversed_sequences = np.ascontiguousarray(sequences[:, ::-1])\n",
"        return sequences, reversed_sequences\n",
" \n",
"# dataset idea from https://github.com/karpathy/minGPT/blob/master/play_math.ipynb\n",
"class AdditionDataModule(BaseDataModule):\n",
"    def get_dataset(self):\n",
"        \"\"\"Every a + b for two-digit a, b as 7-digit rows: a1 a2 b1 b2 s1 s2 s3.\n",
"\n",
"        X is the first six digits and Y the last six (X shifted left by one), so the\n",
"        model predicts the next digit at every position.\n",
"        \"\"\"\n",
"        def digits(n, width):\n",
"            # most-significant digit first, zero-padded to width\n",
"            return [(n // 10 ** p) % 10 for p in reversed(range(width))]\n",
"\n",
"        table = np.array([digits(a, 2) + digits(b, 2) + digits(a + b, 3)\n",
"                          for a in range(100) for b in range(100)])\n",
"        return table[:, :6], table[:, 1:].copy()\n",
"\n",
"# this is the hardest to learn and requires 4 layers\n",
"class ParityDataModule(BaseDataModule):\n",
"    def get_dataset(self, seq_len=10):\n",
"        \"\"\"All 2**seq_len bit strings; each target is the running parity (prefix XOR).\"\"\"\n",
"        inputs, targets = [], []\n",
"        for value in range(2 ** seq_len):\n",
"            bits = [int(ch) for ch in format(value, 'b').zfill(seq_len)]\n",
"            inputs.append(bits)\n",
"            targets.append((np.cumsum(bits) % 2).tolist())\n",
"        return np.array(inputs), np.array(targets)\n",
" \n",
"class WikipediaDataModule(BaseDataModule):\n",
"    def get_dataset(self, seq_len=50):\n",
"        \"\"\"Next-byte prediction windows over enwik8, bytes clamped to 0..127.\n",
"\n",
"        The archive is downloaded once per kernel and cached in the module-level\n",
"        enwik8 global. Returns int64 arrays (X, Y) of shape (n_windows, seq_len),\n",
"        where Y is X shifted forward by one byte.\n",
"        \"\"\"\n",
"        global enwik8\n",
"        if 'enwik8' not in globals():\n",
"            import io\n",
"            from zipfile import ZipFile\n",
"            import requests\n",
"            response = requests.get(\"https://data.deepai.org/enwik8.zip\", timeout=300)\n",
"            # fail loudly on HTTP errors instead of trying to unzip an error page\n",
"            response.raise_for_status()\n",
"            enwik8 = ZipFile(io.BytesIO(response.content)).read('enwik8')\n",
"        # np.int was removed in NumPy 1.24; use an explicit 64-bit integer dtype\n",
"        en = np.frombuffer(enwik8, dtype=np.uint8).astype(np.int64)\n",
"        # largest multiple of seq_len that still leaves one extra byte for the shifted\n",
"        # target; works for any seq_len >= 1 and any corpus length\n",
"        usable = ((len(en) - 1) // seq_len) * seq_len\n",
"        en = en[:usable + 1]\n",
"        # clamp to a 128-symbol vocabulary\n",
"        en[en > 127] = 127\n",
"        return en[:usable].reshape(-1, seq_len), en[1:usable + 1].reshape(-1, seq_len)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "323554ca",
"metadata": {},
"outputs": [],
"source": [
"def attention(queries, keys, values):\n",
"    \"\"\"Scaled dot-product attention over the last two axes.\n",
"\n",
"    Computes softmax(queries @ keys^T / sqrt(d)) @ values, where d is the size of\n",
"    the last axis of queries. No mask is applied, so every position can attend to\n",
"    every other position.\n",
"    \"\"\"\n",
"    scale = math.sqrt(queries.shape[-1])\n",
"    scores = queries @ keys.transpose(-2, -1) / scale\n",
"    return F.softmax(scores, dim=-1) @ values\n",
"\n",
"class MultiHeadAttention(nn.Module):\n",
"    \"\"\"Multi-head attention: project q/k/v, attend per head, merge heads, project out.\n",
"\n",
"    embed_dim must be divisible by num_heads; each head works on\n",
"    projection_dim = embed_dim // num_heads features.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, embed_dim, num_heads):\n",
"        super().__init__()\n",
"        assert embed_dim % num_heads == 0\n",
"        self.embed_dim = embed_dim\n",
"        self.num_heads = num_heads\n",
"        self.projection_dim = embed_dim // num_heads\n",
"\n",
"        # query / key / value / output projections, each embed_dim -> embed_dim\n",
"        self.W_q = nn.Linear(embed_dim, embed_dim)\n",
"        self.W_k = nn.Linear(embed_dim, embed_dim)\n",
"        self.W_v = nn.Linear(embed_dim, embed_dim)\n",
"        self.W_o = nn.Linear(embed_dim, embed_dim)\n",
"\n",
"    def transpose(self, x):\n",
"        \"\"\"Split features into heads: (b, s, embed_dim) -> (b, num_heads, s, projection_dim).\"\"\"\n",
"        batch, seq = x.shape[0], x.shape[1]\n",
"        split = x.reshape(batch, seq, self.num_heads, self.projection_dim)\n",
"        return split.permute(0, 2, 1, 3)\n",
"\n",
"    def transpose_output(self, x):\n",
"        \"\"\"Undo transpose: (b, num_heads, s, projection_dim) -> (b, s, embed_dim).\"\"\"\n",
"        regrouped = x.permute(0, 2, 1, 3)\n",
"        batch, seq = regrouped.shape[0], regrouped.shape[1]\n",
"        return regrouped.reshape(batch, seq, self.embed_dim)\n",
"\n",
"    def forward(self, q, k, v):\n",
"        heads = [self.transpose(proj(t)) for proj, t in ((self.W_q, q), (self.W_k, k), (self.W_v, v))]\n",
"        return self.W_o(self.transpose_output(attention(*heads)))\n",
" \n",
"class TransformerBlock(nn.Module):\n",
"    \"\"\"Post-norm encoder block: self-attention then a ReLU feed-forward network,\n",
"    each followed by dropout, a residual add, and LayerNorm.\"\"\"\n",
"\n",
"    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):\n",
"        super().__init__()\n",
"        self.att = MultiHeadAttention(embed_dim, num_heads)\n",
"        self.ffn = nn.Sequential(\n",
"            nn.Linear(embed_dim, ff_dim),\n",
"            nn.ReLU(),\n",
"            nn.Linear(ff_dim, embed_dim),\n",
"        )\n",
"        self.layernorm1 = nn.LayerNorm(embed_dim)\n",
"        self.layernorm2 = nn.LayerNorm(embed_dim)\n",
"        self.dropout = nn.Dropout(rate)\n",
"\n",
"    def forward(self, x):\n",
"        attended = self.layernorm1(x + self.dropout(self.att(x, x, x)))\n",
"        return self.layernorm2(attended + self.dropout(self.ffn(attended)))\n",
" \n",
"class TokenAndPositionEmbedding(nn.Module):\n",
"    \"\"\"Learned token embedding plus learned absolute position embedding.\n",
"\n",
"    maxlen is the size of the position table, i.e. the longest sequence accepted.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, maxlen, vocab_size, embed_dim):\n",
"        super().__init__()\n",
"        self.token_emb = nn.Embedding(vocab_size, embed_dim)\n",
"        self.pos_emb = nn.Embedding(maxlen, embed_dim)\n",
"\n",
"    def forward(self, x):\n",
"        seq_len = x.size(1)\n",
"        positions = torch.arange(0, seq_len, dtype=torch.int32, device=x.device)\n",
"        # broadcast one row of position vectors across the batch axis\n",
"        return self.token_emb(x) + self.pos_emb(positions).view(1, seq_len, -1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "167e42e3",
"metadata": {},
"outputs": [],
"source": [
"class LittleTransformer(pl.LightningModule):\n",
"    \"\"\"Token+position embedding, layer_count transformer blocks, and a per-position\n",
"    log-softmax over max_value classes.\n",
"\n",
"    Args:\n",
"        seq_len: size of the position embedding table (longest accepted sequence).\n",
"        max_value: vocabulary size; also the number of output classes.\n",
"        layer_count: number of TransformerBlock layers.\n",
"        embed_dim, num_heads, ff_dim: attention / feed-forward sizes of each block.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, seq_len=6, max_value=10, layer_count=2, embed_dim=128, num_heads=4, ff_dim=32):\n",
"        super().__init__()\n",
"        self.max_value = max_value\n",
"        self.model = nn.Sequential(\n",
"            TokenAndPositionEmbedding(seq_len, max_value, embed_dim),\n",
"            *[TransformerBlock(embed_dim, num_heads, ff_dim) for _ in range(layer_count)],\n",
"            nn.Linear(embed_dim, max_value),\n",
"            nn.LogSoftmax(dim=-1))\n",
"\n",
"    def forward(self, x):\n",
"        return self.model(x)\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        x, y = batch\n",
"        output = self.model(x)\n",
"        # nll_loss requires int64 targets; numpy-built batches can arrive as int32\n",
"        # (numpy's default integer on some platforms), so cast explicitly.\n",
"        loss = F.nll_loss(output.reshape(-1, self.max_value), y.reshape(-1).long())\n",
"        self.log(\"train_loss\", loss)\n",
"        return loss\n",
"\n",
"    def validation_step(self, val_batch, batch_idx):\n",
"        x, y = val_batch\n",
"        # per-token accuracy: argmax over the class axis\n",
"        pred = self.model(x).argmax(dim=2)\n",
"        val_accuracy = (pred == y).type(torch.float).mean()\n",
"        self.log(\"val_accuracy\", val_accuracy, prog_bar=True)\n",
"\n",
"    def configure_optimizers(self):\n",
"        # Adam with lr=3e-4. On CUDA, prefer NVIDIA apex's FusedAdam, but apex is an\n",
"        # optional, separately-built package: fall back to torch.optim.Adam when it is\n",
"        # missing instead of failing every GPU run with ImportError.\n",
"        if self.device.type == 'cuda':\n",
"            try:\n",
"                from apex.optimizers import FusedAdam\n",
"            except ImportError:\n",
"                pass\n",
"            else:\n",
"                return FusedAdam(self.parameters(), lr=3e-4)\n",
"        return torch.optim.Adam(self.parameters(), lr=3e-4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a2f48c98",
"metadata": {},
"outputs": [],
"source": [
"# Number of training epochs. max_epochs=0 runs no training epochs, so the model\n",
"# exported below keeps its random initialisation; that is enough to exercise the\n",
"# proving pipeline quickly. Raise it to actually learn the task.\n",
"MAX_EPOCHS = 0\n",
"\n",
"model = LittleTransformer(seq_len=6)\n",
"trainer = pl.Trainer(enable_progress_bar=True, max_epochs=MAX_EPOCHS)\n",
"data = AdditionDataModule(batch_size=64)\n",
"# Alternative tasks: pass a matching LittleTransformer(seq_len=...) for each\n",
"# (parity also wants layer_count=4, per its class comment).\n",
"#data = ReverseDataModule(cnt=1000, seq_len=20)\n",
"#data = ParityDataModule(seq_len=14)\n",
"trainer.fit(model, data)"
]
},
{
"cell_type": "markdown",
"id": "fa7d277e",
"metadata": {},
"source": [
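"## EZKL\n",
"\n",
"Export the model to ONNX, generate and calibrate the circuit settings, compile the circuit and fetch the SRS, then generate a witness, check it with the mock prover, set up the keys, and finally prove and verify."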
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f339a28",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# File names for every artifact of the EZKL pipeline (relative to the working directory).\n",
"model_path = 'network.onnx'               # exported ONNX model\n",
"compiled_model_path = 'network.compiled'  # compiled ezkl circuit\n",
"pk_path = 'test.pk'                       # proving key\n",
"vk_path = 'test.vk'                       # verification key\n",
"settings_path = 'settings.json'           # circuit settings (generated, then calibrated)\n",
"\n",
"witness_path = 'witness.json'             # witness computed from the model input\n",
"data_path = 'input.json'                  # model input used to build the witness\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27ce542b",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"shape = [1, 6]\n",
"# After training, export to onnx (network.onnx) and create a data file (input.json).\n",
"# The model consumes token ids, so the sample input is an integer (torch.long) tensor.\n",
"x = torch.zeros(shape, dtype=torch.long)\n",
"print(x)\n",
"\n",
"# Flip the network into inference mode (disables dropout) and move it to the CPU for export.\n",
"model.eval()\n",
"model.to('cpu')\n",
"\n",
"# Export the model\n",
"torch.onnx.export(model,                     # model being run\n",
"                  x,                         # model input (or a tuple for multiple inputs)\n",
"                  model_path,                # where to save the model (can be a file or file-like object)\n",
"                  export_params=True,        # store the trained parameter weights inside the model file\n",
"                  opset_version=10,          # the ONNX version to export the model to\n",
"                  do_constant_folding=True,  # whether to execute constant folding for optimization\n",
"                  input_names=['input'],     # the model's input names\n",
"                  output_names=['output'],   # the model's output names\n",
"                  dynamic_axes={'input': {0: 'batch_size'},    # variable length axes\n",
"                                'output': {0: 'batch_size'}})\n",
"\n",
"data_array = x.detach().numpy().reshape([-1]).tolist()\n",
"data_json = dict(input_data=[data_array])\n",
"print(data_json)\n",
"\n",
"# Serialize data into file; the with-block closes (and flushes) it before ezkl reads it.\n",
"with open(data_path, 'w') as f:\n",
"    json.dump(data_json, f)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36ddc6f9",
"metadata": {},
"outputs": [],
"source": [
"import ezkl\n",
"\n",
"# Note: a shell line such as `!RUST_LOG=trace` runs in a throwaway subshell and never\n",
"# changes this kernel's environment, so it has no effect here. For more verbose logs,\n",
"# enable the commented-out logging configuration in the imports cell.\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2fe6d972",
"metadata": {},
"outputs": [],
"source": [
"cal_path = os.path.join(\"calibration.json\")\n",
"\n",
"# Calibration samples must resemble real inputs: the network embeds integer token ids\n",
"# in [0, max_value), so draw 20 random id sequences rather than Gaussian floats.\n",
"cal_array = torch.randint(0, model.max_value, (20, *shape)).detach().numpy().reshape([-1]).tolist()\n",
"\n",
"cal_data = dict(input_data = [cal_array])\n",
"\n",
"# Serialize data into file:\n",
"with open(cal_path, 'w') as f:\n",
"    json.dump(cal_data, f)\n",
"\n",
"# Calibrate the settings against the calibration samples written above.\n",
"# \"resources\" is the calibration target (see the ezkl docs for the alternatives).\n",
"res = ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")\n",
"assert res == True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0990f5a8",
"metadata": {},
"outputs": [],
"source": [
"# Compile the ONNX model with the calibrated settings into the circuit file used by\n",
"# the witness, setup and proving steps below.\n",
"res = ezkl.compile_circuit(\n",
"    model_path,\n",
"    compiled_model_path,\n",
"    settings_path,\n",
")\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b80dc01",
"metadata": {},
"outputs": [],
"source": [
"# Fetch the structured reference string (SRS) for the circuit described by settings_path.\n",
"# get_srs is async; top-level await works because IPython runs cells in an event loop.\n",
"res = await ezkl.get_srs(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54cbde29",
"metadata": {},
"outputs": [],
"source": [
"# Generate the witness for the input at data_path, writing it to the witness_path\n",
"# configured together with the other artifact paths (no per-cell override).\n",
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28760638",
"metadata": {},
"outputs": [],
"source": [
"# Check the witness against the compiled circuit with the mock prover\n",
"# (no keys are needed yet; setup runs in the next cell).\n",
"res = ezkl.mock(\n",
"    witness_path,\n",
"    compiled_model_path,\n",
")\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e595112",
"metadata": {},
"outputs": [],
"source": [
"# Set up the circuit: writes the verification key (vk_path) and proving key (pk_path)\n",
"# for the compiled circuit.\n",
"res = ezkl.setup(compiled_model_path, vk_path, pk_path)\n",
"\n",
"assert res == True\n",
"for artifact in (vk_path, pk_path, settings_path):\n",
"    assert os.path.isfile(artifact)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d37adaef",
"metadata": {},
"outputs": [],
"source": [
"# Generate a proof from the witness using the proving key.\n",
"proof_path = 'test.pf'\n",
"\n",
"# \"single\" selects the proof type (presumably a standalone, non-aggregated proof;\n",
"# confirm against the ezkl docs for the installed version).\n",
"res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, \"single\")\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5b58acd5",
"metadata": {},
"outputs": [],
"source": [
"# Verify the proof against the circuit settings and the verification key.\n",
"res = ezkl.verify(proof_path, settings_path, vk_path)\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}