{
"cells": [
{
"cell_type": "markdown",
"id": "0fe629d6",
"metadata": {},
"source": [
"# Quantized Logistic Regression\n",
"\n",
"Currently, **hdk** only supports unsigned integers of up to 7 bits. Nevertheless, we want to evaluate a logistic regression model with it. Luckily, we can use **quantization** to overcome this limitation!"
]
},
{
"cell_type": "markdown",
"id": "d0cfb561",
"metadata": {},
"source": [
"### Let's start by importing some libraries to develop our logistic regression model"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "3c1d929c",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch"
]
},
{
"cell_type": "markdown",
"id": "69d25f7c",
"metadata": {},
"source": [
"### And some helpers for visualization"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a89c1a6c",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from IPython.display import display"
]
},
{
"cell_type": "markdown",
"id": "7729c1de",
"metadata": {},
"source": [
"### We need a dataset, a handcrafted one for simplicity"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b77a9e82",
"metadata": {},
"outputs": [],
"source": [
"x = torch.tensor([[1, 1], [1, 2], [2, 1], [4, 1], [3, 2], [4, 2]]).float()\n",
"y = torch.tensor([[0], [0], [0], [1], [1], [1]]).float()"
]
},
{
"cell_type": "markdown",
"id": "cc8673ff",
"metadata": {},
"source": [
"### Let's visualize our dataset to get a grasp of it"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "35a98d1a",
"metadata": {},
"outputs": [],
"source": [
"plt.ioff()\n",
"fig, ax = plt.subplots(1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "56703410",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAPTklEQVR4nO3df4zkd13H8efruPPHpcgZbqO117v1D1ABKbQr1ED0lCgHmBJjTagVbCO5RKsup4mNEOkpaaIhchQbOC6lOdT1wNAGSgNGImAlhJo9LO2VCmmEKweNt7S5omBMznv7x3eW7q27O7N3szuzn30+ksnO9/v93Hxf/XTvtd/5zMxtqgpJ0sa3ZdQBJEnDYaFLUiMsdElqhIUuSY2w0CWpEVtHdeKdO3fW5OTkqE4vSRvS8ePHv1lVE0sdG1mhT05OMjs7O6rTS9KGlOTkcsdccpGkRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSI9ov9MW/M9XfoSqpUX0LPcnlST6V5ItJHk4yvcSYJHlXkkeTPJjkyrWJu0oHD8KBA0+XeFW3ffDgKFNJQzczA5OTsGVL93VmZtSJ2jeOcz7IFfpZ4A+q6nnA1cBNSZ63aMyrgOf0bvuB9ww15YWogjNn4Lbbni71Awe67TNnvFJXM2ZmYP9+OHmy+7Y+ebLbHoeCadW4znlqlcWW5CPA7VX1iQX73gt8uqqO9ba/BOytqseXe5ypqala838PfWGJz5uehkOHIFnbc0vrZHKyK5TF9uyBr351vdNsDqOc8yTHq2pqqWOrWkNPMgm8GLh/0aHLgK8t2D7V27f4z+9PMptkdm5ubjWnvjBJV94LWeZqzGOPrW6/Lt64zvnAhZ7kEuAu4E1V9a0LOVlVHamqqaqamphY8jcoDdf8FfpCC9fUpQbs3r26/bp44zrnAxV6km10ZT5TVXcvMeTrwOULtnf19o3OwuWW6Wk4d677unBNXWrArbfC9u3n79u+vduvtTGucz7Iu1wCvA94pKrescywe4A39N7tcjXw1Err5+sigR07zl8zP3So296xw2UXNeP66+HIkW79Num+HjnS7dfaGNc57/uiaJKXA/8MPASc6+1+M7AboKoO90r/dmAf8B3gxqpa8RXPdXlRtAt4fnkv3pakDWSlF0W39vvDVfUZYMUGrO6nwk0XFm+NLS5vy1xSo9r/pKgkbRIWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUiL6FnuTOJKeTnFjm+LOSfDTJF5I8nOTG4ceUJPUzyBX6UWDfCsdvAr5YVVcAe4G/SPI9Fx9NkrQafQu9qu4DnlxpCPDMJAEu6Y09O5x4kqRBDWMN/XbgJ4BvAA8B01V1bqmBSfYnmU0yOzc3N4RTS5LmDaPQXwk8APwI8CLg9iQ/sNTAqjpSVVNVNTUxMTGEU0uS5g2j0G8E7q7Oo8BXgB8fwuNKklZhGIX+GPAKgCQ/BPwY8O9DeFxJ0ips7TcgyTG6d6/sTHIKuAXYBlBVh4G3AUeTPAQEuLmqvrlmiSVJS+pb6FV1XZ/j3wB+cWiJJEkXxE+KSlIjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmN6FvoSe5McjrJiRXG7E3yQJKHk/zTcCNKkgYxyBX6UWDfcgeT7ADeDVxTVc8HfnUoySRJq9K30KvqPuDJFYb8Gn
B3VT3WG396SNkkSaswjDX05wI/mOTTSY4necNyA5PsTzKbZHZubm4Ip5YkzRtGoW8FrgJeA7wS+OMkz11qYFUdqaqpqpqamJgYwqklSfO2DuExTgFPVNW3gW8nuQ+4AvjyEB5bkjSgYVyhfwR4eZKtSbYDLwUeGcLjSpJWoe8VepJjwF5gZ5JTwC3ANoCqOlxVjyT5e+BB4BxwR1Ut+xZHSdLa6FvoVXXdAGPeDrx9KIkkSRfET4pKUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY3oW+hJ7kxyOsmJPuN+KsnZJNcOL54kaVCDXKEfBfatNCDJM4A/B/5hCJkkSRegb6FX1X3Ak32G/S5wF3B6GKEkSat30WvoSS4Dfhl4zwBj9yeZTTI7Nzd3saeWJC0wjBdF3wncXFXn+g2sqiNVNVVVUxMTE0M4tSRp3tYhPMYU8IEkADuBVyc5W1UfHsJjS5IGdNGFXlU/On8/yVHgXstcktZf30JPcgzYC+xMcgq4BdgGUFWH1zSdJGlgfQu9qq4b9MGq6oaLSiNJumB+UlSSGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJakTfQk9yZ5LTSU4sc/z6JA8meSjJZ5NcMfyYkqR+BrlCPwrsW+H4V4CfraqfBN4GHBlCLknSKm3tN6Cq7ksyucLxzy7Y/Bywawi5JEmrNOw19N8EPr7cwST7k8wmmZ2bmxvyqSVpcxtaoSf5ObpCv3m5MVV1pKqmqmpqYmJiWKeWJDHAkssgkrwQuAN4VVU9MYzHlCStzkVfoSfZDdwNvL6qvnzxkSRJF6LvFXqSY8BeYGeSU8AtwDaAqjoMvBV4NvDuJABnq2pqrQJLkpY2yLtcrutz/I3AG4eWSJJ0QfykqCQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUiPYLvWrlbQ2fcy6NRN9CT3JnktNJTixzPEneleTRJA8muXL4MS/QwYNw4MDThVLVbR88OMpUbXPOtUnMzMDkJGzZ0n2dmRl1osGu0I8C+1Y4/irgOb3bfuA9Fx9rCKrgzBm47banC+bAgW77zBmvGteCc65NYmYG9u+Hkye7b+uTJ7vtkZd6VfW9AZPAiWWOvRe4bsH2l4BL+z3mVVddVWvu3Lmq6emqbs672/R0t19rwznXJrBnz/nf4vO3PXvW/tzAbC3Tq6kBrpqSTAL3VtULljh2L/BnVfWZ3vY/AjdX1ewSY/fTXcWze/fuq06ePHkBP4JWqap7TjTv3DlI1v68m5lzrsZt2bL0E86k+3ZfS0mOV9XUkrnW9tTnq6ojVTVVVVMTExPrccLuKf9CC9d3NXzOuTaB3btXt3+9DKPQvw5cvmB7V2/faC1cv52e7n5sTk+fv76r4XLOtUnceits337+vu3bu/2jtHUIj3EP8DtJPgC8FHiqqh4fwuNenAR27OgK5dChbvvQoe7Yjh0uAawF51ybxPXXd1/f8hZ47LHuyvzWW5/ePyp919CTHAP2AjuB/wBuAbYBVNXhJAFup3snzHeAG5daP19samqqZmf7Drt4VecXyeJtDZ9zLq2ZldbQ+16hV9V1fY4XcNMFZlt7i4vEYll7zrk0Eu1/UlSSNg
kLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJasRAv7FoTU6czAHr8CuLvmsn8M11PN8wbdTsGzU3bNzsGzU3bNzs6517T1Ut+RuCRlbo6y3J7HL/5OS426jZN2pu2LjZN2pu2LjZxym3Sy6S1AgLXZIasZkK/cioA1yEjZp9o+aGjZt9o+aGjZt9bHJvmjV0SWrdZrpCl6SmWeiS1IimCj3JnUlOJzmxzPEkeVeSR5M8mOTK9c64nAGy703yVJIHere3rnfGpSS5PMmnknwxycNJppcYM3bzPmDucZ3z70vyL0m+0Mv+J0uM+d4kH+zN+f1JJkcQdXGmQXLfkGRuwZy/cRRZl5PkGUn+Ncm9Sxwb/ZxXVTM34GeAK4ETyxx/NfBxIMDVwP2jzryK7HuBe0edc4lclwJX9u4/E/gy8Lxxn/cBc4/rnAe4pHd/G3A/cPWiMb8NHO7dfx3wwQ2S+wbg9lFnXeG/4feBv13q+2Ic5rypK/Squg94coUhrwX+qjqfA3YkuXR90q1sgOxjqaoer6rP9+7/J/AIcNmiYWM37wPmHku9efyv3ua23m3xuxteC7y/d/9DwCuSZJ0iLmnA3GMryS7gNcAdywwZ+Zw3VegDuAz42oLtU2yQv8Q9P917uvrxJM8fdZjFek8xX0x35bXQWM/7CrlhTOe899T/AeA08ImqWnbOq+os8BTw7HUNuYQBcgP8Sm9p7kNJLl/fhCt6J/CHwLlljo98zjdboW9kn6f7NxyuAP4S+PBo45wvySXAXcCbqupbo84zqD65x3bOq+p/q+pFwC7gJUleMOJIAxkg90eByap6IfAJnr7iHakkvwScrqrjo86yks1W6F8HFv7E39XbN/aq6lvzT1er6mPAtiQ7RxwLgCTb6EpxpqruXmLIWM57v9zjPOfzquoM8Clg36JD353zJFuBZwFPrGu4FSyXu6qeqKr/6W3eAVy1ztGW8zLgmiRfBT4A/HySv1k0ZuRzvtkK/R7gDb13XVwNPFVVj4861CCS/PD8elySl9D9vxv5X9BepvcBj1TVO5YZNnbzPkjuMZ7ziSQ7eve/H/gF4N8WDbsH+I3e/WuBT1bv1bpRGST3otdWrqF7bWPkquqPqmpXVU3SveD5yar69UXDRj7nW9fzZGstyTG6dybsTHIKuIXuhReq6jDwMbp3XDwKfAe4cTRJ/78Bsl8L/FaSs8B/A68b9V/QnpcBrwce6q2NArwZ2A1jPe+D5B7XOb8UeH+SZ9D9kPm7qro3yZ8Cs1V1D90Pq79O8ijdi+2vG13c7xok9+8luQY4S5f7hpGlHcC4zbkf/ZekRmy2JRdJapaFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhrxf09l6LOTuZAtAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"x_min, x_max = x[:, 0].min(), x[:, 0].max()\n",
"x_deviation = x_max - x_min\n",
"\n",
"y_min, y_max = x[:, 1].min(), x[:, 1].max()\n",
"y_deviation = y_max - y_min\n",
"\n",
"ax.set_xlim(x_min - (x_deviation / 10), x_max + (x_deviation / 10))\n",
"ax.set_ylim(y_min - (y_deviation / 10), y_max + (y_deviation / 10))\n",
"\n",
"ax.scatter(\n",
"    np.array([x_i[0] for x_i, y_i in zip(x, y) if y_i == 0], dtype=np.float32),\n",
"    np.array([x_i[1] for x_i, y_i in zip(x, y) if y_i == 0], dtype=np.float32),\n",
"    marker=\"x\",\n",
"    color=\"red\",\n",
")\n",
"ax.scatter(\n",
"    np.array([x_i[0] for x_i, y_i in zip(x, y) if y_i == 1], dtype=np.float32),\n",
"    np.array([x_i[1] for x_i, y_i in zip(x, y) if y_i == 1], dtype=np.float32),\n",
"    marker=\"o\",\n",
"    color=\"blue\",\n",
")\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "e31b82e8",
"metadata": {},
"source": [
"### Now we need a model, so let's define it"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "cc5e72a2",
"metadata": {},
"outputs": [],
"source": [
"class Model(torch.nn.Module):\n",
"    def __init__(self, n):\n",
"        super(Model, self).__init__()\n",
"        self.fc = torch.nn.Linear(n, 1)\n",
"\n",
"    def forward(self, x):\n",
"        output = torch.sigmoid(self.fc(x))\n",
"        return output"
]
},
{
"cell_type": "markdown",
"id": "cefd8346",
"metadata": {},
"source": [
"### And create one\n",
"\n",
"The main purpose of this tutorial is not to train a logistic regression model but to use it homomorphically, so we will not discuss how the model is trained."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b9879f4d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 1 | Loss: 0.5758869647979736\n",
"Epoch: 101 | Loss: 0.13611836731433868\n",
"Epoch: 201 | Loss: 0.08021673560142517\n",
"Epoch: 301 | Loss: 0.05636058747768402\n",
"Epoch: 401 | Loss: 0.043306026607751846\n",
"Epoch: 501 | Loss: 0.03511128947138786\n",
"Epoch: 601 | Loss: 0.029501130804419518\n",
"Epoch: 701 | Loss: 0.025424323976039886\n",
"Epoch: 801 | Loss: 0.02233024500310421\n",
"Epoch: 901 | Loss: 0.01990305446088314\n",
"Epoch: 1001 | Loss: 0.0179488193243742\n",
"Epoch: 1101 | Loss: 0.01634199731051922\n",
"Epoch: 1201 | Loss: 0.014997857622802258\n",
"Epoch: 1301 | Loss: 0.013856985606253147\n",
"Epoch: 1401 | Loss: 0.012876608408987522\n",
"Epoch: 1501 | Loss: 0.012025204487144947\n"
]
}
],
"source": [
"model = Model(x.shape[1])\n",
"\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=1)\n",
"criterion = torch.nn.BCELoss()\n",
"\n",
"epochs = 1501\n",
"for e in range(1, epochs + 1):\n",
"    optimizer.zero_grad()\n",
"\n",
"    out = model(x)\n",
"    loss = criterion(out, y)\n",
"\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"\n",
"    if e % 100 == 1 or e == epochs:\n",
"        print(\"Epoch:\", e, \"|\", \"Loss:\", loss.item())"
]
},
{
"cell_type": "markdown",
"id": "01cfc83f",
"metadata": {},
"source": [
"### Time to make some predictions"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "78356d37",
"metadata": {},
"outputs": [],
"source": [
"contour_plot_x_data = np.linspace(x_min - (x_deviation / 10), x_max + 2 * (x_deviation / 10), 250)\n",
"contour_plot_y_data = np.linspace(y_min - (y_deviation / 10), y_max + 2 * (y_deviation / 10), 250)\n",
"contour_plot_x_data, contour_plot_y_data = np.meshgrid(contour_plot_x_data, contour_plot_y_data)\n",
"\n",
"inputs = np.stack((contour_plot_x_data.flatten(), contour_plot_y_data.flatten()), axis=1)\n",
"predictions = model(torch.tensor(inputs).float()).detach().numpy()"
]
},
{
"cell_type": "markdown",
"id": "58160140",
"metadata": {},
"source": [
"### Let's visualize our predictions to see how our model performs"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2a623999",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAT40lEQVR4nO3df2xdZ33H8c+XxluYkyURQcyuw7I/6KbDDeVHRjpRbVnQltJBqilMo9tgrYYibV06tElD449WG39NaAjPFURRqbLescIEFasrWIdy6aKV1VNoC27syaprg22sBRzfkPg2P6773R/nGhzX9r22j+9z7nPfL8nKvec88fn0afLJ8XPOvdfcXQCA1ve60AEAANmg0AEgEhQ6AESCQgeASFDoABCJLaEO3NnZ6bt27Qp1eOTMlStXtGPHDm3btk033XRT6DhAbj3//PM/cvc3LrcvWKHv2rVLx48fD3V45MzQ0JCOHDmi22+/Xdu3bw8dB8itzs7O7620jyUX5EKSJJqfn9fExISmp6dDxwFaEoWO3BgbG1OxWNTVq1dDRwFaEoUOAJGg0AEgEhQ6AESCQkfuXLt2LXQEoCVR6MiVcrmsarWq4eHh0FGAlkOhI1eSJFFfX1/oGEBLotABIBIUOgBEgkJHLrm7xsfHQ8cAWgqFjtxJkkS9vb2qVCq8DQCwBvEX+tLPTOUzVFtCoVBQqVQKHQNoKXXfbdHM9kh6VNKbJLmkk+7eu2SMSeqVdKekiqR73P257OOu0dNPS1euSIcPS2ZpmT/1lLR1q3TwYOh0QGYGB6VSSbp4UdqxQzp0SNq3L3SquOVxzhs5Q69K+it3TyTdJuk+M0uWjHmfpLfUvo5J+lymKdfDPS3zgYG0xBfKfGAg3c6ZOiIxOCj190vlcvrHulxOnw8Ohk4Wr7zOed0zdHefljRde3zJzIYl3SxpaNGwuyQ96u4u6Vkz22lmXbXfG4ZZemYupSU+MJA+PnDgp2fsQARKJen69Ru3Xb+ebg99xhirvM75mtbQzWyvpHdIGliy62ZJE4ueT9a2Lf39x8zsrJmdnZubW2PUdVhc6gso85YxNTWl2dlZ7nap4+LFtW3HxuV1zhsudDPbJukrkj7m7j9ez8Hc/aS773f3/Z2dnev5Fms9YLrMstjC8gtyr7u7W729vXrllVdCR8m1HTvWth0bl9c5b6jQzaxDaZl/wd0fX2bIlKQ9i5731LaFs3jN/MAB6YEH0l8Xr6kj9173uvhvxNqoQ4ekjo4bt3V0pNuxOfI6543c5WKSPi9p2N0/vcKwJyT9uZl9UdIBSReDrp9L6bLK1q03rpkvLL9s3cqyC6KxsGabtzsuYpbXOW/kQ6LfI+nDkgbN7IXatk9IerMkufsJSV9TesviS0pvW7w386TrcfBgeia+UN4LpU6ZIzL79oUvk3aTxzlv5C6X/5K0agPW7m65L6tQmVpa3pQ5gEg1coYOBJMkiUZHR+Xu2rVrl7q6ukJHAnKLK07IvbGxMd4GAGgAhQ4AkaDQASASFDoARIJCR8sol8uhIwC5RqGjJbi7RkdHNTIyEjoKkFsUOlpGf39/6AhArlHoABAJCh0AIkGho6VUq1XW0YEVUOhoGYVCQX19fRoZGdH0dNg38wTyiEJHS0mShE8wAlZAoQNAJCh0AIgEhQ4AkaDQ0XKmpqY0OzvL3S7AEhQ6Wk53d7d6e3tDxwByh0IHgEhQ6AAQCQodACJBoaNlVatVXbp0KXQMIDcodLSkQqGgUqmkiYkJSh2oodDRstxdL7/8cugYQG5Q6AAQCQodACJBoQNAJCh0tDwujAKpuoVuZo+Y2Xkze3GF/TvMrN/MvmNm58zs3uxjAssbGxtTqVTSzMxM6ChAcI2coZ+SdMcq+++TNOTut0o6KOkfzOxnNh4NaMzU1FToCEAu1C10dz8j6cJqQyRtNzOTtK02tppNPABAo7Zk8D0ekvSEpB9I2i7p99391e
UGmtkxScckaefOnRkcGgCwIIuLooclvSCpW9LbJT1kZj+/3EB3P+nu+919f2dnZwaHBqQLFy6oUqloeHiYi6Noa1kU+r2SHvfUS5LGJP1KBt8XaEihUFBfXx8fHo22l0Whf1/SeyXJzN4k6Zcl8XpsAGiyumvoZvaY0rtXdpvZpKQHJXVIkrufkPRJSafMbFCSSfq4u/9o0xIDAJZVt9Dd/e46+38g6bczSwQAWBdeKYooJEmi+fl5Xb58OXQUIBgKHdF45plnNDs7y8VRtC0KHdHo7u5WsVgMHQMIhkIHgEhQ6AAQCQodACJBoSM6lUpF09PToWMATUehIyoLF0bL5XLoKEDTUeiITrlc5tZFtCUKHQAiQaEDQCQodACIBIWOKM3Pz2toaIi7XdBWKHREJ0kSjY2NqVgs6urVq6HjAE1DoQNAJCh0AIgEhY6oXbt2LXQEoGkodESrXC6rWq1qeHg4dBSgKSh0RCtJEvX19YWOATQNhQ4AkaDQASASFDoARIJCR/TcnXdfRFug0BG1JEnU29vLh16gLVDoiF6hUFCpVAodA9h0FDoARIJCB4BI1C10M3vEzM6b2YurjDloZi+Y2Tkz+89sIwIAGtHIGfopSXestNPMdkr6rKQj7v5WSb+XSTIgY7Ozs9ztgqjVLXR3PyPpwipD/kDS4+7+/dr48xllAzLj7urt7eXNuhC1LNbQb5G0y8yeNrNvm9lHVhpoZsfM7KyZnZ2bm8vg0ACABVsy+h7vkvReSa+X9N9m9qy7jywd6O4nJZ2UpJ6eHs/g2ACAmiwKfVLSjLvPSZozszOSbpX0mkIHAGyeLJZc/k3S7Wa2xcx+TtIBSbwBNXJpfn5ely5dCh0D2BR1z9DN7DFJByXtNrNJSQ9K6pAkdz/h7sNm9u+SvivpVUkPu/uKtzgCoRQKBY2OjsrdtWvXLnV1dYWOBGSqbqG7+90NjPmUpE9lkgjYRGNjYxofH9fRo0dDRwEyxytFASASFDoARIJCR1viwihilMVti0BLOXfunPbu3StJuuWWW8KGATLEGTraTpIk6u/vDx0DyByFDgCRoNABIBIUOgBEgkJH26pWqxoZ4S2HEA8KHW2pUCior69PIyMj3MKIaFDoaFtJkoSOAGSKQgeASFDoABAJCh1t7/Lly6EjAJmg0NHWnnnmGc3Ozmp8fDx0FGDDKHS0te7ubhWLxdAxgExQ6AAQCQodACJBoaPtXbhwQZVKhRcYoeVR6Gh7hUJBpVJJExMTlDpaGoUOSHJ3vfzyy6FjABtCoQNAJCh0AIgEhQ4AkaDQgUUmJye5MIqWRaEDNWNjYzp9+rRmZmZCRwHWhUIHFpmamgodAVi3uoVuZo+Y2Xkze7HOuF81s6qZfTC7eACARjVyhn5K0h2rDTCzmyT9vaT/yCATAGAd6ha6u5+RdKHOsOOSviLpfBahAABrt+E1dDO7WdLvSvpcA2OPmdlZMzs7Nze30UMDmVt4X5fh4eHQUYA1y+Ki6GckfdzdX6030N1Puvt+d9/f2dmZwaGBbBUKBfX19Wl8fJzbF9FytmTwPfZL+qKZSdJuSXeaWdXdv5rB9wYANGjDhe7uv7Tw2MxOSXqSMgeA5qtb6Gb2mKSDknab2aSkByV1SJK7n9jUdACAhtUtdHe/u9Fv5u73bCgNkBPz8/OamZnR9u3bQ0cBGsYrRYElkiRRf3+/KpWKxsfHQ8cBGkahA8soFAoqFouhYwBrQqEDQCQodACIBIUOrKJSqWh6ejp0DKAhFDqwgu7ubhWLRZXL5dBRgIZQ6MAqKHO0EgodACJBoQNAJCh0AIgEhQ7U4e4aGhribhfkHoUOrCJJEp0+fVqlUil0FKAuCh0AIkGhA0AkKHQAiASFDjSIzxhF3lHoQAPOnTunarWqkZGR0FGAFVHoQAOSJFFvb2/oGMCqKHQAiASFDgCRoNABIBIUOrAG1WqVD45GblHoQIMKhYJ6e3v5FCPkFoUOrEGhUOB9XZ
BbFDoARIJCB4BIUOjAOszOznJxFLlTt9DN7BEzO29mL66w/w/N7LtmNmhm3zKzW7OPCeSHu6u3t1fXrl0LHQW4QSNn6Kck3bHK/jFJv+Hu+yR9UtLJDHIBANZoS70B7n7GzPausv9bi54+K6kng1wAgDXKeg39TyR9faWdZnbMzM6a2dm5ubmMDw0A7a3uGXqjzOw3lRb67SuNcfeTqi3J9PT0eFbHBkKYn58PHQG4QSaFbmZvk/SwpPe5+0wW3xPIs0KhoNHRUbm79uzZo+3bt4eOBGx8ycXM3izpcUkfdnfe/R9tY2xsjFeNIlfqnqGb2WOSDkrabWaTkh6U1CFJ7n5C0gOS3iDps2YmSVV3379ZgQEAy2vkLpe76+z/qKSPZpYIALAuvFIUACJBoQMbNDk5GToCIIlCBzbE3TU6OqqREe4HQHgUOrBB/f39oSMAkih0AIgGhQ4AkaDQASASFDqQgWq1yoVRBEehAxtUKBTU19enkZERXbp0KXQctDEKHchAkiShIwAUOgDEgkIHgEhQ6EBGxsfHNTExofHx8dBR0KYodCAj7q5isRg6BtoYhQ4AkaDQASASFDoARIJCBzJWqVR4gRGCoNCBDHV3d6tUKmlycpJSR9NR6EDGzp07x62LCIJCB4BIUOgAEAkKHQAiQaEDm2B+fp4Lo2g6Ch3IWJIkGhsb0+nTpzUzMxM6DtoIhQ5skqmpqdAR0GYodACIRPyF7r76c2SPOQeC2FJvgJk9Iun9ks67e2GZ/SapV9KdkiqS7nH357IOui5PPy1duSIdPiyZpcXy1FPS1q3SwYOh08WJOUebGByUSiXp4kVpxw7p0CFp376wmRo5Qz8l6Y5V9r9P0ltqX8ckfW7jsTLgnhbLwEBaKAvFMjCQbuesMXvM+Q0uXLigSqWi4eHh0FGQscFBqb9fKpfTP9blcvp8cDBsrrpn6O5+xsz2rjLkLkmPurtLetbMdppZl7tPZxVyXczSs0QpLZSBgfTxgQM/PXtEtpjzGxQKBfX19en+++8PHQUZK5Wk69dv3Hb9ero95Fl6FmvoN0uaWPR8srbtNczsmJmdNbOzc3NzGRy6jsUFs6ANi6WpmHO0gYsX17a9WZp6UdTdT7r7fnff39nZ2YwDpj/yL7awFIDNwZyjDezYsbbtzZJFoU9J2rPoeU9tW1iL128PHJAeeCD9dfH6LrLFnC/L3TU9HXYFEtk6dEjq6LhxW0dHuj2kLAr9CUkfsdRtki4GXz+X0h/xt269cf328OH0+datLAFsBub8NZIkUbFY1OzsLKUekX37pA98QNq5M/1jvXNn+jz0XS7mdc6azOwxSQcl7Zb0f5IelNQhSe5+onbb4kNK74SpSLrX3c/WO3BPT48fP358Q+Eb4n5jkSx9juwx569hZjp69Ki6urpCR0GL6+zs/La7719uXyN3udxdZ79Lum+d2Tbf0iJp82JpCuYcCCL+V4oCQJug0AEgEhQ60CRcGMVmo9CBJnB3FYtFPvACm4pCB5qkXC6HjoDIUegAEAkKHQAiQaEDQCQodKCJqtWqhoaGuNsFm4JCB5okSRKdPn1apVIpdBREikIHgEhQ6AAQibrvtrhpBzb7oaTvNfGQuyX9qInHy1KrZm/V3FLrZm/V3FLrZm927l909zcutyNYoTebmZ1d6S0n865Vs7dqbql1s7dqbql1s+cpN0suABAJCh0AItFOhX4ydIANaNXsrZpbat3srZpbat3sucndNmvoABC7djpDB4CoUegAEImoCt3MHjGz82b24gr7zcz+0cxeMrPvmtk7m51xJQ1kP2hmF83shdrXA83OuBwz22Nm3zSzITM7Z2Z/scyY3M17g7nzOudbzex/zOw7tex/u8yYnzWzL9XmfMDM9gaIujRTI7nvMbMfLprzj4bIuhIzu8nMnjezJ5fZF37O3T2aL0m/Lumdkl5cYf+dkr4uySTdJmkgdOY1ZD8o6cnQOZfJ1SXpnbXH2yWNSE
ryPu8N5s7rnJukbbXHHZIGJN22ZMyfSTpRe/whSV9qkdz3SHoodNZV/hv+UtK/LPfnIg9zHtUZurufkXRhlSF3SXrUU89K2mlmXc1Jt7oGsueSu0+7+3O1x5ckDUu6ecmw3M17g7lzqTaPl2tPO2pfS+9uuEvSP9Uef1nSe83MmhRxWQ3mzi0z65H0O5IeXmFI8DmPqtAbcLOkiUXPJ9Uif4lrfq324+rXzeytocMsVfsR8x1Kz7wWy/W8r5Jbyumc1370f0HSeUnfcPcV59zdq5IuSnpDU0Muo4HcknS0tjT3ZTPb09yEq/qMpL+W9OoK+4PPebsVeit7Tul7ONwqqU/SV8PGuZGZbZP0FUkfc/cfh87TqDq5czvn7j7v7m+X1CPp3WZWCBypIQ3k7pe0193fJukb+ukZb1Bm9n5J593926GzrKbdCn1K0uJ/8Xtq23LP3X+88OOqu39NUoeZ7Q4cS5JkZh1KS/EL7v74MkNyOe/1cud5zhe4e1nSNyXdsWTXT+bczLZI2iFppqnhVrFSbnefcfertacPS3pXk6Ot5D2SjpjZuKQvSjpkZv+8ZEzwOW+3Qn9C0kdqd13cJumiu7fER8eY2S8srMeZ2buV/r8L/he0lunzkobd/dMrDMvdvDeSO8dz/kYz21l7/HpJvyXpf5cMe0LSH9cef1BSyWtX60JpJPeSaytHlF7bCM7d/8bde9x9r9ILniV3/6Mlw4LP+ZZmHmyzmdljSu9M2G1mk5IeVHrhRe5+QtLXlN5x8ZKkiqR7wyR9rQayf1DSn5pZVdIrkj4U+i9ozXskfVjSYG1tVJI+IenNUq7nvZHceZ3zLkn/ZGY3Kf1H5l/d/Ukz+ztJZ939CaX/WBXN7CWlF9s/FC7uTzSS+34zOyKpqjT3PcHSNiBvc85L/wEgEu225AIA0aLQASASFDoARIJCB4BIUOgAEAkKHQAiQaEDQCT+H+Ww0z5fuBGwAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"contour = ax.contourf(\n",
"    contour_plot_x_data,\n",
"    contour_plot_y_data,\n",
"    predictions.round().reshape(contour_plot_x_data.shape),\n",
"    cmap=\"gray\",\n",
"    alpha=0.50,\n",
")\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "d3f39faa",
"metadata": {},
"source": [
"### As a bonus, let's inspect the model parameters"
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7fa65211",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[4.53586054]\n",
" [2.37015319]]\n",
"-14.660321235656738\n"
]
}
],
"source": [
"w = np.array(model.fc.weight.flatten().tolist()).reshape((-1, 1))\n",
"b = model.fc.bias.flatten().tolist()[0]\n",
"\n",
"print(w)\n",
"print(b)"
]
},
{
"cell_type": "markdown",
"id": "544d6e34",
"metadata": {},
"source": [
"They are floating-point numbers, and we can't work with them directly!"
]
},
{
"cell_type": "markdown",
"id": "abf310f2",
"metadata": {},
"source": [
"### So, let's abstract quantization\n",
"\n",
"Here is a quick summary of quantization. We have a range of values and we want to represent them using a small number of bits (n). To do this, we split the range into 2^n sections and map each section to an integer value. Here is a visualization of the process!"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d3ab2aa2",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" version=\"1.1\" width=\"420px\" height=\"195px\" viewBox=\"-0.5 -0.5 420 195\" content=\"&lt;mxfile host=&quot;app.diagrams.net&quot; modified=&quot;2021-08-13T09:47:25.144Z&quot; agent=&quot;5.0 (X11)&quot; etag=&quot;5QhM0DGu1eUjmjeXuyFL&quot; version=&quot;14.9.6&quot; type=&quot;device&quot;&gt;&lt;diagram id=&quot;6rZNNX4_K12e_kCXuZoG&quot; name=&quot;Page-1&quot;&gt;7Zzdb5s6FMD/mkjdw66MHZL0sUl376SrStM6rc8euAGNYAZOk/avnw2YD5t8QAmhJA+tyLFzbM7v2D7HdjtCi9X2vxAHzgO1iTeCwN6O0P0IQhPd8t9C8JoI0MRMBMvQtRORkQse3TeSCkEqXbs2iUoVGaUec4Oy0KK+TyxWkuEwpJtytWfqlVsN8JJogkcLe7r0ybWZk0hnJsjlX4m7dGTLBkhLVlhWTgWRg226KYjQlxFahJSy5Gm1XRBP2E7aJfnevztKs46FxGfHfOFpYrx9//Hw//Lpq8mc+yd4H5mfUy0v2FunL5x2lr1KCxDfvhOG5J8sD0eRa43QPGI4ZLrYYSuPCwz+mOghtmbevL9GZgXuPYSuCAtfeZVNbmdpZqdgYikLiYeZ+1JWj1Pcy0xd1sI36vKGIUg905ikelLHhGNQVhHRdWiR9FtFuyqKTHhAETfVkjBNEX8ovHYuirHVQAhrIfSpTz4WKAgU+5oNQWmKULeg0NBBjdsCpSrqGNR46KBmbYFSFXUMyhw4KKQuLU1BaYo6BjUZOig1mGgMSlXUMajpwEGN2womNEUdg5oNHVRbwYSmqGNQt0MH1VYwoSnqGJTccaidC38gWFoKBKfNYGmRn6poByxuPfxaqBaICtGeDquBC5ju75e6fpbr84ekB+16DmzoOR9/F0Xzg+ltM4dCqmeqik49+utl50NieND0RzNUl9quGdZL3AfF8JDpj2Z4aECfmmG9nH5QDNuaS7XcpGuGVen+xOPWmtvuC39ciseV699sP8kC3lChTEPOyJaVaUYspL/Jgno0zOPmZ9fzFBH23KUvXIJDJlw+fyEhcy3s3aUFK9e2RTPzjeMy8hhgS7S5CXHAZSFd+zYRLwuybgkFZFvXh3YFKNJ5Cj42rvAxCHa7U4lfbVh6yr/C25jMUBloA+3sDBpl8zaOnNgsRpmKkH/DjFvajyUQoIyVPMOEH2A+hLdlTAiMm82HqiI4UxSdej6s2gMozIcFzJM/a3EePH+mPvscxafhd7yCAYJtTEyW5xNlokfUf5eiBz7U8qk4UVduQp2he/YCjI7i7DZR9ytspVdHmqKXE6XoeHqfwmhr8ZqqI3J23olTKr6OrFOPLOM6sk45stQc/Pwjq97NoAsJSdRz1MYhiXYg23FIImle8e47fW2OV83yusbb6ArL4PG2lVCoijrHW+8+xZA2ybTjXDj+ZzJphlE7cazQdWqSVRcuGsRtWpx173LWPo9hCmFbUjInbEOIrxcsqB8Rax3D0Qp/ig5GO6O3XsZpo/fHZaq7GdPcRc4VmqGqtbtnOcMwkh5Y8x16OQpOka2Me5atoKrT25551DCGBLoOieoh0betMaSnADcbR6z7QPzEEgD7espzis1LJdpD8n5ZgZABO0Wkn3WDy+Fh9O0UDh11bF17OoVV4Xky/EQqLfrM1cR/cdh4eThl1/5cZ/zq1GByVGpgGJ36cAv5ZKVXbOWasdcX3ul5ENwIMV4JgP6vKEgKlc/cduL3W1BR9KnotmBnby/FR
9U1L5tziw7a7SRbddXh/TNZ5R7IdebaEavKCwmZV5hnXnkPXJDIpqAL52Ya/eImG7ty289t2rPxNq463NW4vZFQZPxBsvetbggfHZWKdVqmm5fuCerf3xjy4KjgCbN2PIF/zP+3SHJckf+DFvTlLw==&lt;/diagram&gt;&lt;/mxfile&gt;\"><defs/><g><path d=\"M 14.37 84 L 361.63 84\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 9.12 84 L 16.12 80.5 L 14.37 84 L 16.12 87.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 366.88 84 L 359.88 87.5 L 361.63 84 L 359.88 80.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 48 94 L 48 74\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 88 94 L 88 74\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 128 94 L 128 74\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 168 94 L 168 74\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 208 94 L 208 74\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 248 94 L 248 74\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 288 94 L 288 74\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 328 94 L 328 74\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 48 71 L 60.93 58.07 Q 68 51 78 51 L 98 51 Q 108 51 115.07 58.07 L 123.5 66.5\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 127.21 70.21 L 119.78 67.73 L 123.5 66.5 L 124.73 62.78 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 134.37 123 L 141.63 123\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 
129.12 123 L 136.12 119.5 L 134.37 123 L 136.12 126.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 146.88 123 L 139.88 126.5 L 141.63 123 L 139.88 119.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 154.37 123 L 181.63 123\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 149.12 123 L 156.12 119.5 L 154.37 123 L 156.12 126.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 186.88 123 L 179.88 126.5 L 181.63 123 L 179.88 119.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 194.37 123 L 221.63 123\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 189.12 123 L 196.12 119.5 L 194.37 123 L 196.12 126.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 226.88 123 L 219.88 126.5 L 221.63 123 L 219.88 119.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 234.37 123 L 241.63 123\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 229.12 123 L 236.12 119.5 L 234.37 123 L 236.12 126.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 246.88 123 L 239.88 126.5 L 241.63 123 L 239.88 119.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><rect x=\"108\" y=\"94\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: 
flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 104px; margin-left: 109px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div>min(x)</div></div></div></div></foreignObject><text x=\"128\" y=\"108\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"12px\" text-anchor=\"middle\">min(x)</text></switch></g><rect x=\"228\" y=\"94\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 104px; margin-left: 229px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \">max(x)</div></div></div></foreignObject><text x=\"248\" y=\"108\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"12px\" text-anchor=\"middle\">max(x)</text></switch></g><path d=\"M 138 148 L 138 128\" fill=\"none\" stroke=\"#000000\" stroke-width=\"2\" stroke-miterlimit=\"10\" stroke-dasharray=\"2 6\" pointer-events=\"stroke\"/><rect x=\"118\" y=\"152\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" 
requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 162px; margin-left: 119px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">Map</font></div><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">to 0<br style=\"font-size: 10px\"/></font></div></div></div></div></foreignObject><text x=\"138\" y=\"165\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">Map...</text></switch></g><rect x=\"148\" y=\"152\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 162px; margin-left: 149px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">Map</font></div><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">to 1<br style=\"font-size: 10px\"/></font></div></div></div></div></foreignObject><text x=\"168\" y=\"165\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" 
text-anchor=\"middle\">Map...</text></switch></g><path d=\"M 168 148 L 168 128\" fill=\"none\" stroke=\"#000000\" stroke-width=\"2\" stroke-miterlimit=\"10\" stroke-dasharray=\"2 6\" pointer-events=\"stroke\"/><path d=\"M 208 148 L 208 128\" fill=\"none\" stroke=\"#000000\" stroke-width=\"2\" stroke-miterlimit=\"10\" stroke-dasharray=\"2 6\" pointer-events=\"stroke\"/><path d=\"M 238 148 L 238 128\" fill=\"none\" stroke=\"#000000\" stroke-width=\"2\" stroke-miterlimit=\"10\" stroke-dasharray=\"2 6\" pointer-events=\"stroke\"/><path d=\"M 294.37 68.66 L 321.63 68.66\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 289.12 68.66 L 296.12 65.16 L 294.37 68.66 L 296.12 72.16 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 326.88 68.66 L 319.88 72.16 L 321.63 68.66 L 319.88 65.16 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><rect x=\"288\" y=\"18.66\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 29px; margin-left: 289px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><font style=\"font-size: 10px\">Distance<br/>Between<br/>Consecutive<br/>Values</font></div></div></div></foreignObject><text x=\"308\" y=\"32\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"12px\" 
text-anchor=\"middle\">Distan...</text></switch></g><rect x=\"188\" y=\"152\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 162px; margin-left: 189px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">Map</font></div><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">to 2</font></div></div></div></div></foreignObject><text x=\"208\" y=\"165\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">Map...</text></switch></g><rect x=\"218\" y=\"152\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 162px; margin-left: 219px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div style=\"font-size: 10px\"><font 
style=\"font-size: 10px\">Map</font></div><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">to 3</font></div></div></div></div></foreignObject><text x=\"238\" y=\"165\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">Map...</text></switch></g><rect x=\"128\" y=\"174\" width=\"120\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 118px; height: 1px; padding-top: 184px; margin-left: 129px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \">(when n = 2)</div></div></div></foreignObject><text x=\"188\" y=\"187\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">(when n = 2)</text></switch></g><rect x=\"28\" y=\"94\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 104px; margin-left: 29px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; 
pointer-events: all; white-space: normal; word-wrap: normal; \">0</div></div></div></foreignObject><text x=\"48\" y=\"107\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">0</text></switch></g><rect x=\"308\" y=\"18.66\" width=\"110\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 108px; height: 1px; padding-top: 29px; margin-left: 309px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div><font style=\"font-size: 12px\">= 1 / scale</font></div><div><font style=\"font-size: 12px\">= 1 / q</font></div></div></div></div></foreignObject><text x=\"363\" y=\"32\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">= 1 / scale...</text></switch></g><rect x=\"128\" y=\"24\" width=\"140\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 138px; height: 1px; padding-top: 34px; margin-left: 129px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; 
color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><font style=\"font-size: 12px\">x =</font><font style=\"font-size: 12px\"> (x   + zp  ) / q </font></div></div></div></foreignObject><text x=\"198\" y=\"37\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">x = (x   + zp  ) / q </text></switch></g><rect x=\"167\" y=\"29\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 39px; margin-left: 168px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div><font style=\"font-size: 10px\">q</font></div></div></div></div></foreignObject><text x=\"187\" y=\"42\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">q</text></switch></g><rect x=\"199\" y=\"29\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 39px; margin-left: 200px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div 
style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div>x</div></div></div></div></foreignObject><text x=\"219\" y=\"42\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">x</text></switch></g><rect x=\"227\" y=\"29\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 39px; margin-left: 228px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div>x</div></div></div></div></foreignObject><text x=\"247\" y=\"42\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">x</text></switch></g><rect x=\"48\" y=\"28\" width=\"80\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 78px; height: 1px; padding-top: 38px; margin-left: 49px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: 
#000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div>zero point<br/></div><div>zp = 2</div></div></div></div></foreignObject><text x=\"88\" y=\"41\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">zero point...</text></switch></g></g><switch><g requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"/><a transform=\"translate(0,-5)\" xlink:href=\"https://www.diagrams.net/doc/faq/svg-export-text-problems\" target=\"_blank\"><text text-anchor=\"middle\" font-size=\"10px\" x=\"50%\" y=\"100%\">Viewer does not support full SVG 1.1</text></a></switch></svg>"
],
"text/plain": [
"<IPython.core.display.SVG object>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import SVG\n",
"SVG(filename=\"figures/QuantizationVisualized.svg\")"
]
},
{
"cell_type": "markdown",
"id": "e4314038",
"metadata": {},
"source": [
"If you want to learn more, head to https://intellabs.github.io/distiller/algo_quantization.html"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "a8bab855",
"metadata": {},
"outputs": [],
"source": [
"class QuantizationParameters:\n",
"    def __init__(self, q, zp, n):\n",
"        # q = scale factor = 1 / distance between consecutive values\n",
"        # zp = zero point which is used to determine the beginning of the quantized range\n",
"        #      (quantized 0 = the beginning of the quantized range = zp * distance between consecutive values)\n",
"        # n = number of bits\n",
"\n",
"        # e.g.,\n",
"\n",
"        # n = 2\n",
"        # zp = 2\n",
"        # q = 0.66\n",
"        # distance between consecutive values = 1 / q = 1.5151\n",
"\n",
"        # quantized 0 = zp / q = zp * distance between consecutive values = 3.0303\n",
"        # quantized 1 = quantized 0 + distance between consecutive values = 4.5454\n",
"        # quantized 2 = quantized 1 + distance between consecutive values = 6.0606\n",
"        # quantized 3 = quantized 2 + distance between consecutive values = 7.5757\n",
"\n",
"        self.q = q\n",
"        self.zp = zp\n",
"        self.n = n\n",
"\n",
"class QuantizedArray:\n",
"    def __init__(self, values, parameters):\n",
"        # values = quantized values\n",
"        # parameters = parameters used during quantization\n",
"\n",
"        # e.g.,\n",
"\n",
"        # values = [1, 0, 2, 1]\n",
"        # parameters = QuantizationParameters(q=0.66, zp=2, n=2)\n",
"\n",
"        # original array = [4.5454, 3.0303, 6.0606, 4.5454]\n",
"\n",
"        self.values = np.array(values)\n",
"        self.parameters = parameters\n",
"\n",
"    @staticmethod\n",
"    def of(x, n):\n",
"        if not isinstance(x, np.ndarray):\n",
"            x = np.array(x)\n",
"\n",
"        min_x = x.min()\n",
"        max_x = x.max()\n",
"\n",
"        if min_x == max_x:  # encoding single valued arrays\n",
"\n",
"            if min_x == 0.0:  # encoding 0s\n",
"\n",
"                # dequantization = (x_q + zp_x) / q_x = 0 --> q_x = 1 && zp_x = 0 && x_q = 0\n",
"                q_x = 1\n",
"                zp_x = 0\n",
"                x_q = np.zeros(x.shape, dtype=np.uint)\n",
"\n",
"            elif min_x < 0.0:  # encoding negative scalars\n",
"\n",
"                # dequantization = (x_q + zp_x) / q_x = x (with x < 0) --> q_x = 1 / |x| & zp_x = -1 & x_q = 0\n",
"                q_x = abs(1 / min_x)\n",
"                zp_x = -1\n",
"                x_q = np.zeros(x.shape, dtype=np.uint)\n",
"\n",
"            else:  # encoding positive scalars\n",
"\n",
"                # dequantization = (x_q + zp_x) / q_x = x --> q_x = 1 / x & zp_x = 0 & x_q = 1\n",
"                q_x = 1 / min_x\n",
"                zp_x = 0\n",
"                x_q = np.ones(x.shape, dtype=np.uint)\n",
"\n",
"        else:  # encoding multi valued arrays\n",
"\n",
"            # distance between consecutive values = range of x / number of intervals between quantized values = (max_x - min_x) / (2^n - 1)\n",
"            # q = 1 / distance between consecutive values\n",
"            q_x = (2**n - 1) / (max_x - min_x)\n",
"\n",
"            # zp = what should be added to 0 to get min_x -> min_x = (0 + zp) / q -> zp = min_x * q\n",
"            zp_x = int(round(min_x * q_x))\n",
"\n",
"            # x = (x_q + zp) / q -> x_q = (x * q) - zp\n",
"            x_q = ((q_x * x) - zp_x).round().astype(np.uint)\n",
"\n",
"        return QuantizedArray(x_q, QuantizationParameters(q_x, zp_x, n))\n",
"\n",
"    def dequantize(self):\n",
"        # x = (x_q + zp) / q\n",
"        return (self.values.astype(np.float32) + float(self.parameters.zp)) / self.parameters.q\n",
"\n",
"    def affine(self, w, b, min_y, max_y, n_y):\n",
"        # the formulas used in this method were derived from the following equations\n",
"        #\n",
"        # x = (x_q + zp_x) / q_x\n",
"        # w = (w_q + zp_w) / q_w\n",
"        # b = (b_q + zp_b) / q_b\n",
"        #\n",
"        # (x * w) + b = ((x_q + zp_x) / q_x) * ((w_q + zp_w) / q_w) + ((b_q + zp_b) / q_b)\n",
"        #             = y = (y_q + zp_y) / q_y\n",
"        #\n",
"        # So, ((x_q + zp_x) / q_x) * ((w_q + zp_w) / q_w) + ((b_q + zp_b) / q_b) = (y_q + zp_y) / q_y\n",
"        # We can calculate zp_y and q_y from min_y, max_y, n_y. So, the only unknown is y_q and it can be solved.\n",
"\n",
"        x_q = self.values\n",
"        w_q = w.values\n",
"        b_q = b.values\n",
"\n",
"        q_x = self.parameters.q\n",
"        q_w = w.parameters.q\n",
"        q_b = b.parameters.q\n",
"\n",
"        zp_x = self.parameters.zp\n",
"        zp_w = w.parameters.zp\n",
"        zp_b = b.parameters.zp\n",
"\n",
"        q_y = (2**n_y - 1) / (max_y - min_y)\n",
"        zp_y = int(round(min_y * q_y))\n",
"\n",
"        y_q = (q_y / (q_x * q_w)) * ((x_q + zp_x) @ (w_q + zp_w) + (q_x * q_w / q_b) * (b_q + zp_b))\n",
"        y_q -= min_y * q_y\n",
"        y_q = y_q.round().clip(0, 2**n_y - 1).astype(np.uint)\n",
"\n",
"        return QuantizedArray(y_q, QuantizationParameters(q_y, zp_y, n_y))\n",
"\n",
"class QuantizedFunction:\n",
"    def __init__(self, table, input_parameters=None, output_parameters=None):\n",
"        self.table = table\n",
"        self.input_parameters = input_parameters\n",
"        self.output_parameters = output_parameters\n",
"\n",
"    @staticmethod\n",
"    def of(f, input_bits, output_bits):\n",
"        domain = np.array(range(2**input_bits), dtype=np.uint)\n",
"        table = f(domain).round().clip(0, 2**output_bits - 1).astype(np.uint)\n",
"        return QuantizedFunction(table)\n",
"\n",
"    @staticmethod\n",
"    def plain(f, input_parameters, output_bits):\n",
"        n = input_parameters.n\n",
"\n",
"        domain = np.array(range(2**n), dtype=np.uint)\n",
"        inputs = QuantizedArray(domain, input_parameters).dequantize()\n",
"\n",
"        outputs = f(inputs)\n",
"        quantized_outputs = QuantizedArray.of(outputs, output_bits)\n",
"\n",
"        table = quantized_outputs.values\n",
"        output_parameters = quantized_outputs.parameters\n",
"\n",
"        return QuantizedFunction(table, input_parameters, output_parameters)\n",
"\n",
"    def apply(self, x):\n",
"        assert x.parameters == self.input_parameters\n",
"        return QuantizedArray(self.table[x.values], self.output_parameters)"
]
},
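{
"cell_type": "markdown",
"id": "c41d7a90",
"metadata": {},
"source": [
"To make the arithmetic in `QuantizedArray.of` and `dequantize` concrete, here is a minimal standalone sketch of the multi-valued branch. It is an illustration rather than part of the notebook's pipeline: the helper names `quantize_sketch` and `dequantize_sketch` and the sample values are assumptions, and the extra `clip` is a defensive addition that the class above does not have.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def quantize_sketch(x, n):\n",
"    # Hypothetical helper; assumes x holds at least two distinct values.\n",
"    x = np.asarray(x, dtype=np.float64)\n",
"    q = (2**n - 1) / (x.max() - x.min())  # scale = 1 / distance between levels\n",
"    zp = int(round(x.min() * q))          # zero point\n",
"    x_q = np.round(q * x - zp).clip(0, 2**n - 1).astype(np.uint64)\n",
"    return x_q, q, zp\n",
"\n",
"def dequantize_sketch(x_q, q, zp):\n",
"    # x = (x_q + zp) / q\n",
"    return (x_q.astype(np.float64) + zp) / q\n",
"\n",
"values = np.array([3.0, 4.5, 6.0, 7.5])\n",
"x_q, q, zp = quantize_sketch(values, n=2)\n",
"print(x_q, q, zp)\n",
"print(dequantize_sketch(x_q, q, zp))\n",
"```\n",
"\n",
"These four sample values sit exactly on the four 2-bit levels, so the round trip recovers them up to floating-point error. Values that fall between two levels would be rounded to the nearest one, which is where quantization error comes from."
]
},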
{
"cell_type": "markdown",
"id": "e5be0800",
"metadata": {},
"source": [
"### Let's quantize our model parameters\n",
"\n",
"Since the parameters consist of only a few scalars, single-bit quantization is sufficient."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "3ec0ad9b",
"metadata": {},
"outputs": [],
"source": [
"parameter_bits = 1\n",
"\n",
"w_q = QuantizedArray.of(w, parameter_bits)\n",
"b_q = QuantizedArray.of(b, parameter_bits)"
]
},
{
"cell_type": "markdown",
"id": "b43c0371",
"metadata": {},
"source": [
"### And quantize our inputs"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "20cea447",
"metadata": {},
"outputs": [],
"source": [
"input_bits = 5\n",
"\n",
"x = inputs\n",
"x_q = QuantizedArray.of(inputs, input_bits)"
]
},
{
"cell_type": "markdown",
"id": "ca76b68d",
"metadata": {},
"source": [
"### Time to run quantized inference"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8728e939",
"metadata": {},
"outputs": [],
"source": [
"output_bits = 7\n",
"\n",
"intermediate = x @ w + b\n",
"intermediate_q = x_q.affine(w_q, b_q, intermediate.min(), intermediate.max(), output_bits)\n",
"\n",
"sigmoid = QuantizedFunction.plain(lambda x: 1 / (1 + np.exp(-x)), intermediate_q.parameters, output_bits)\n",
"y_q = sigmoid.apply(intermediate_q)\n",
"\n",
"quantized_predictions = y_q.dequantize()"
]
},
{
"cell_type": "markdown",
"id": "ab782b4a",
"metadata": {},
"source": [
"### And visualize the results"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "9d2bb5da",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQaElEQVR4nO3df2xdZ33H8fd3jTtTx0siUrGkDssqkY2QUKAZ6QTavERbf6FW05hGt8FaDaXaSje0SkPjj1Ybf01oiK4IoqhUoRtrmaBibVXWoYQuGqyeQltwSqaqagOERgq0iuncdkrW7/4419Qxtu+xc67P9eP3S7rqvec8vueTp/Ynx889NzcyE0nS8vczbQeQJDXDQpekQljoklQIC12SCmGhS1IhVrV14KGhoVy3bl1bh1ehTp8+zYUXXsj555/fdhSpJx5//PEfZeaFs+1rrdDXrVvHzTff3NbhVajnnnuOG2+8kc2bN7cdReqJoaGh7861zyUXSSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBWi/EKf+ZmpfoaqpEJ1/dcWI2ITcDfwBiCBfZl5+4wxAdwOXAW8BFyfmY81H3eBHnkEXnkFLr8cIqoyf/hhGByE0dG200mNGR+HgwdhYgLWrIFdu2D79rZTla0f57zOGfoZ4JbM3ApcBtwUEVtnjLkSeFPntgf4TKMpFyOzKvOxsarEp8p8bKza7pm6CjE+Dg88AKdOVd/Wp05Vj8fH205Wrn6d865n6Jl5AjjRuf9iRBwFLgK+M23YtcDdmZnAoxGxNiI2dL62HRHVmTlUJT42Vt3fufO1M3apAAcPwunTZ287fbra3vYZY6n6dc4XtIYeEZuBtwNjM3ZdBHx/2uPjnW0zv35PRByOiMOTk5MLjLoI00t9imWuwkxMLGy7zl2/znntQo+I1cCXgA9n5o8Xc7DM3JeZOzJzx9DQ0GKeYqEHrJZZpptafpEKsWbNwrbr3PXrnNcq9IgYoCrzz2fmfbMM+QGwadrjkc629kxfM9+5E269tfrv9DV1qQC7dsHAwNnbBgaq7eqNfp3zOle5BPBZ4GhmfmKOYfcDH4qIe4GdwESr6+dQLasMDp69Zj61/DI46LKLijG1ZttvV1yUrF/nvM6HRL8LeD8wHhFPdLZ9FHgjQGbuBR6iumTxaarLFm9oPOlijI5WZ+JT5T1V6pa5CrN9e/tlstL045zXucrlP4B5G7BzdctNTYVq1MzytswlFar8d4pK0gphoUtSIeqsoUvLyssvv8yJE/Vek1+9ejXDw8M9TiQtDQtdRdm4cSMHDhyoPX7Xrl1s2rTJUlcRXHJRcTKz9u2ZZ55pO67UGAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJ0LfSIuCsiTkbEkTn2r4mIByLiWxHxZETc0HxMSVI3dc7Q9wNXzLP/JuA7mXkJMAr8XUScf+7RJEkL0bXQM/MQ8MJ8Q4DhiAhgdWfsmWbiSZLqWtXAc3wKuB94DhgGfi8zX51tYETsAfYArF27toFDS5KmNPGi6OXAE8BG4G3ApyLi52YbmJn7MnNHZu4YGhpq4NCSpClNFPoNwH1ZeRp4FvjlBp5XkrQATRT694DdABHxBuCXgGcaeF5J0gJ0XUOPiHuorl5ZHxHHgduAAYDM3At8DNgfEeNAAB/JzB/1LLEkaVZdCz0zr+uy/zngtxpLJElaFN8pKkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCNPGJRdKydezYMS6++OLa41etWsWWLVt6mEhaPAtdK1pmcuDAgVpjjxw5wi
233NLjRNLiueQiSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUiK6FHhF3RcTJiDgyz5jRiHgiIp6MiH9vNqIkqY46Z+j7gSvm2hkRa4FPA9dk5luA320kmSRpQboWemYeAl6YZ8jvA/dl5vc64082lE2StABNrKFvAdZFxCMR8c2I+MBcAyNiT0QcjojDk5OTDRxakjSliU8sWgVcCuwGXgf8Z0Q8mplPzRyYmfuAfQAjIyPZwLElSR1NFPpx4PnMnAQmI+IQcAnwU4UuSeqdJpZc/gV4d0SsiogLgJ3A0QaeV5K0AF3P0CPiHmAUWB8Rx4HbgAGAzNybmUcj4l+BbwOvAndm5pyXOEqSeqNroWfmdTXGfBz4eCOJJEmL4jtFJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgqxqu0A0nJy5swZnnrqqQV9zYYNGxgeHu5RIuk1FrpU07Zt27j99tsX9DXbt29n9+7dvPnNb+5RKuk1Frq0ANu2bVvQ+CeffJLdu3f3KI10NtfQJakQFrokFcJCl6RCdC30iLgrIk5GxJEu434lIs5ExHubiydJqqvOGfp+4Ir5BkTEecDfAv/WQCZJ0iJ0LfTMPAS80GXYzcCXgJNNhJIkLdw5r6FHxEXAbwOfqTF2T0QcjojDk5OT53poSdI0Tbwo+kngI5n5areBmbkvM3dk5o6hoaEGDi1JmtLEG4t2APdGBMB64KqIOJOZX27guSVJNZ1zoWfmL07dj4j9wIOWuSQtva6FHhH3AKPA+og4DtwGDABk5t6eppMk1da10DPzurpPlpnXn1MaSdKi+U5RSSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQXQs9Iu6KiJMRcWSO/X8QEd+OiPGI+EZEXNJ8TElSN3XO0PcDV8yz/1ng1zNzO/AxYF8DuSRJC7Sq24DMPBQRm+fZ/41pDx8FRhrIJUlaoK6FvkB/DHxlrp0RsQfYA7B27dqGDy31rxdffLH22OHh4R4mUckaK/SI+A2qQn/3XGMycx+dJZmRkZFs6thSv9q6dSsHDhxg8+bNtcZffPHFXHDBBbXHS9M1UugR8VbgTuDKzHy+ieeUSpGZPPvss7XGfv3rX+fGG2/scSKV6pwvW4yINwL3Ae/PzKfOPZIkaTG6nqFHxD3AKLA+Io4DtwEDAJm5F7gVeD3w6YgAOJOZO3oVWJI0uzpXuVzXZf8HgQ82lkiStCi+U1SSCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKkT5hZ45/2M1zzmXWrGq24CIuAt4D3AyM7fNsj+A24GrgJeA6zPzsaaDLsojj8Arr8Dll0NEVSwPPwyDgzA62na6MjnnWiHGx+HgQZiYgDVrYNcu2L693Ux1ztD3A1fMs/9K4E2d2x7gM+ceqwGZVbGMjVWFMlUsY2PVds8am+eca4UYH4cHHoBTp6pv61Onqsfj4+3m6nqGnpmHImLzPEOuBe7OzAQejYi1EbEhM080FXJRIqqzRKgKZWysur9z52tnj2qWc64V4uBBOH367G2nT1fb2zxLb2IN/SLg+9MeH+9s+ykRsSciDkfE4cnJyQYO3cX0gplisfSWc64VYG
JiYduXypK+KJqZ+zJzR2buGBoaWooDVr/yTze1FKDecM61AqxZs7DtS6XrkksNPwA2TXs80tnWrunrt1O/8k89Bs8ae8E51wqxa1e1Zj592WVgoNrepiYK/X7gQxFxL7ATmGh9/Ryq4hgcPHv9dmopYHDQYukF51wrxNQ6eb9d5VLnssV7gFFgfUQcB24DBgAycy/wENUli09TXbZ4Q6/CLtjoaHXWOFUkUwVjsfSOc64VYvv29gt8pjpXuVzXZX8CNzWWqGkzi8Ri6T3nXGpF+e8UlaQVwkKXpEJY6JJUiCaucpHUoJdeeomjR4/WGnveeeexZcuWHifScmGhS31k48aN3HHHHbXHX3311QwPD7Nhw4YeptJyYaFLfWbr1q21xx47doxLL720h2m0nLiGLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgoR2dIHD0TED4HvLuEh1wM/WsLjNWm5Zl+uuWH5Zl+uuWH5Zl/q3L+QmRfOtqO1Ql9qEXE4M3e0nWMxlmv25Zoblm/25Zoblm/2fsrtkoskFcJCl6RCrKRC39d2gHOwXLMv19ywfLMv19ywfLP3Te4Vs4YuSaVbSWfoklQ0C12SClFUoUfEXRFxMiKOzLE/IuLvI+LpiPh2RLxjqTPOpUb20YiYiIgnOrdblzrjbCJiU0R8LSK+ExFPRsSfzzKm7+a9Zu5+nfPBiPiviPhWJ/tfzzLmZyPiC505H4uIzS1EnZmpTu7rI+KH0+b8g21knUtEnBcRj0fEg7Psa3/OM7OYG/BrwDuAI3Psvwr4ChDAZcBY25kXkH0UeLDtnLPk2gC8o3N/GHgK2Nrv814zd7/OeQCrO/cHgDHgshlj/hTY27n/PuALyyT39cCn2s46z5/hL4B/mu37oh/mvKgz9Mw8BLwwz5Brgbuz8iiwNiL64qNeamTvS5l5IjMf69x/ETgKXDRjWN/Ne83cfakzj//TeTjQuc28uuFa4HOd+18EdkdELFHEWdXM3bciYgS4GrhzjiGtz3lRhV7DRcD3pz0+zjL5Ie741c6vq1+JiLe0HWamzq+Yb6c685qur+d9ntzQp3Pe+dX/CeAk8NXMnHPOM/MMMAG8fklDzqJGboDf6SzNfTEiNi1twnl9EvhL4NU59rc+5yut0Jezx6j+DYdLgDuAL7cb52wRsRr4EvDhzPxx23nq6pK7b+c8M/8vM98GjADvjIhtLUeqpUbuB4DNmflW4Ku8dsbbqoh4D3AyM7/Zdpb5rLRC/wEw/W/8kc62vpeZP576dTUzHwIGImJ9y7EAiIgBqlL8fGbeN8uQvpz3brn7ec6nZOYp4GvAFTN2/WTOI2IVsAZ4fknDzWOu3Jn5fGb+b+fhnUC/fGDqu4BrIuIYcC+wKyL+ccaY1ud8pRX6/cAHOlddXAZMZOaJtkPVERE/P7UeFxHvpPp/1/oPaCfTZ4GjmfmJOYb13bzXyd3Hc35hRKzt3H8d8JvAf88Ydj/wR5377wUOZufVurbUyT3jtZVrqF7baF1m/lVmjmTmZqoXPA9m5h/OGNb6nK9ayoP1WkTcQ3VlwvqIOA7cRvXCC5m5F3iI6oqLp4GXgBvaSfrTamR/L/AnEXEGeBl4X9s/oB3vAt4PjHfWRgE+CrwR+nre6+Tu1znfAHwuIs6j+kvmnzPzwYj4G+BwZt5P9ZfVP0TE01Qvtr+vvbg/USf3n0XENcAZqtzXt5a2hn6bc9/6L0mFWGlLLpJULAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFeL/Admi1qH7N00UAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Remove the previous contour before drawing the new one\n",
"for collection in contour.collections:\n",
"    collection.remove()\n",
"\n",
"contour = ax.contourf(\n",
"    contour_plot_x_data,\n",
"    contour_plot_y_data,\n",
"    quantized_predictions.round().reshape(contour_plot_x_data.shape),\n",
"    cmap=\"gray\",\n",
"    alpha=0.50,\n",
")\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "4834cdfc",
"metadata": {},
"source": [
"### Now it's time to make the inference homomorphic"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "fcf4ea26",
"metadata": {},
"outputs": [],
"source": [
"q_y = (2**output_bits - 1) / (intermediate.max() - intermediate.min())\n",
"zp_y = int(round(intermediate.min() * q_y))\n",
"\n",
"q_x = x_q.parameters.q\n",
"q_w = w_q.parameters.q\n",
"q_b = b_q.parameters.q\n",
"\n",
"zp_x = x_q.parameters.zp\n",
"zp_w = w_q.parameters.zp\n",
"zp_b = b_q.parameters.zp\n",
"\n",
"x_q = x_q.values\n",
"w_q = w_q.values\n",
"b_q = b_q.values"
]
},
{
"cell_type": "markdown",
"id": "43e47369",
"metadata": {},
"source": [
"### Simplification to the rescue!\n",
"\n",
"The `y_q` formula in `QuantizedArray.affine(...)` can be rewritten to make it easier to implement homomorphically. Here is the breakdown.\n",
"```\n",
"(q_y / (q_x * q_w)) * ((x_q + zp_x) @ (w_q + zp_w) + (q_x * q_w / q_b) * (b_q + zp_b)) - (min_y * q_y)\n",
"^^^^^^^^^^^^^^^^^^^    ^^^^^^^^^^^^   ^^^^^^^^^^^^   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^    ^^^^^^^^^^^^^\n",
"constant (c1)          can be done    constant (c2)  constant (c3)                       constant (c4)\n",
"                       on the circuit\n",
"\n",
"                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"                       can be done on the circuit\n",
"\n",
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"cannot be done on the circuit because of floating-point operations, so it will be a single table lookup\n",
"```"
]
},
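{
"cell_type": "markdown",
"id": "e92b03f5",
"metadata": {},
"source": [
"As a quick sanity check of the breakdown above, the following standalone sketch compares the original expression with its `c1`..`c4` form. All quantization parameters and arrays in it are made-up numbers chosen for illustration, not the values computed in this notebook.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Made-up quantization parameters, for illustration only.\n",
"q_x, q_w, q_b, q_y = 2.0, 1.0, 0.5, 3.0\n",
"zp_x, zp_w, zp_b = 1, 1, -1\n",
"min_y = 2.0\n",
"\n",
"x_q = np.array([[0, 3], [2, 1]])\n",
"w_q = np.array([[1], [0]])\n",
"b_q = np.array([0])\n",
"\n",
"# Original expression (before rounding and clipping).\n",
"direct = (q_y / (q_x * q_w)) * (\n",
"    (x_q + zp_x) @ (w_q + zp_w) + (q_x * q_w / q_b) * (b_q + zp_b)\n",
") - min_y * q_y\n",
"\n",
"# Rewritten form: an integer-only accumulation followed by a\n",
"# floating-point function of that single accumulated value.\n",
"c1 = q_y / (q_x * q_w)\n",
"c2 = w_q + zp_w\n",
"c3 = (q_x * q_w / q_b) * (b_q + zp_b)\n",
"c4 = min_y * q_y\n",
"\n",
"acc = (x_q + zp_x) @ c2           # integers only, circuit-friendly\n",
"rewritten = c1 * (acc + c3) - c4  # depends on acc alone, so it can be tabulated\n",
"\n",
"print(np.allclose(direct, rewritten))\n",
"```\n",
"\n",
"Because `rewritten` depends only on the integer `acc`, every floating-point step can be precomputed for each possible value of `acc` and stored in a lookup table."
]
},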
{
"cell_type": "code",
"execution_count": 18,
"id": "2de0cf20",
"metadata": {},
"outputs": [],
"source": [
"c1 = q_y / (q_x * q_w)\n",
"c2 = w_q + zp_w\n",
"c3 = (q_x * q_w / q_b) * (b_q + zp_b)\n",
"c4 = intermediate.min() * q_y\n",
"\n",
"def f(x):\n",
"    values = ((c1 * (x + c3)) - c4).round().clip(0, 2**output_bits - 1).astype(np.uint)\n",
"    after_affine_q = QuantizedArray(values, intermediate_q.parameters)\n",
"\n",
"    sigmoid = QuantizedFunction.plain(lambda x: 1 / (1 + np.exp(-x)), after_affine_q.parameters, output_bits)\n",
"    y_q = sigmoid.apply(after_affine_q)\n",
"\n",
"    return y_q.values\n",
"\n",
"f_q = QuantizedFunction.of(f, output_bits, output_bits)\n",
"\n",
"from hdk.common.extensions.table import LookupTable\n",
"table = LookupTable([int(entry) for entry in f_q.table])\n",
"\n",
"w_0 = int(c2.flatten()[0])\n",
"w_1 = int(c2.flatten()[1])\n",
"\n",
"def infer(x_0, x_1):\n",
"    return table[((x_0 + zp_x) * w_0) + ((x_1 + zp_x) * w_1)]"
]
},
{
"cell_type": "markdown",
"id": "93eb9499",
"metadata": {},
"source": [
"### Time to compile our quantized inference function"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "a80895fd",
"metadata": {},
"outputs": [],
"source": [
"from hdk.common.data_types.integers import Integer\n",
"from hdk.common.values import EncryptedValue\n",
"from hdk.hnumpy.compile import compile_numpy_function_into_op_graph\n",
"\n",
"dataset = []\n",
"for x_i in x_q:\n",
"    dataset.append((int(x_i[0]), int(x_i[1])))\n",
"\n",
"homomorphic_model = compile_numpy_function_into_op_graph(\n",
"    infer,\n",
"    {\n",
"        \"x_0\": EncryptedValue(Integer(input_bits, is_signed=False)),\n",
"        \"x_1\": EncryptedValue(Integer(input_bits, is_signed=False)),\n",
"    },\n",
"    iter(dataset),\n",
")"
]
},
{
"cell_type": "markdown",
"id": "f0b08a0f",
"metadata": {},
"source": [
"### Here is the textual representation of the operation graph"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "2cc4e11d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"%0 = Constant(2) # Integer<unsigned, 8 bits>\n",
"%1 = Constant(1) # Integer<unsigned, 8 bits>\n",
"%2 = x_0 # Integer<unsigned, 7 bits>\n",
"%3 = Constant(6) # Integer<unsigned, 8 bits>\n",
"%4 = x_1 # Integer<unsigned, 7 bits>\n",
"%5 = Constant(6) # Integer<unsigned, 8 bits>\n",
"%6 = Add(2, 3) # Integer<unsigned, 7 bits>\n",
"%7 = Add(4, 5) # Integer<unsigned, 7 bits>\n",
"%8 = Mul(6, 0) # Integer<unsigned, 7 bits>\n",
"%9 = Mul(7, 1) # Integer<unsigned, 7 bits>\n",
"%10 = Add(8, 9) # Integer<unsigned, 7 bits>\n",
"%11 = TLU(10) # Integer<unsigned, 7 bits>\n",
"return(%11)\n"
]
}
],
"source": [
"from hdk.common.debugging import get_printable_graph\n",
"print(get_printable_graph(homomorphic_model, show_data_types=True))"
]
},
{
"cell_type": "markdown",
"id": "ade14f17",
"metadata": {},
"source": [
"### Finally, it's time to run homomorphic inference\n",
"\n",
"Or, at least, simulate it until the compiler integration is complete."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "dd2d03d7",
"metadata": {},
"outputs": [],
"source": [
"homomorphic_predictions = []\n",
"for x_0, x_1 in map(lambda x_i: (int(x_i[0]), int(x_i[1])), x_q):\n",
"    evaluation = homomorphic_model.evaluate({0: x_0, 1: x_1})\n",
"    inference = QuantizedArray(evaluation[homomorphic_model.output_nodes[0]], y_q.parameters)\n",
"    homomorphic_predictions.append(inference.dequantize())\n",
"homomorphic_predictions = np.array(homomorphic_predictions, dtype=np.float32)"
]
},
{
"cell_type": "markdown",
"id": "443fbc03",
"metadata": {},
"source": [
"### And visualize it"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "57050b5d",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQaElEQVR4nO3df2xdZ33H8fd3jTtTx0siUrGkDssqkY2QUKAZ6QTavERbf6FW05hGt8FaDaXaSje0SkPjj1Ybf01oiK4IoqhUoRtrmaBibVXWoYQuGqyeQltwSqaqagOERgq0iuncdkrW7/4419Qxtu+xc67P9eP3S7rqvec8vueTp/Ynx889NzcyE0nS8vczbQeQJDXDQpekQljoklQIC12SCmGhS1IhVrV14KGhoVy3bl1bh1ehTp8+zYUXXsj555/fdhSpJx5//PEfZeaFs+1rrdDXrVvHzTff3NbhVajnnnuOG2+8kc2bN7cdReqJoaGh7861zyUXSSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBWi/EKf+ZmpfoaqpEJ1/dcWI2ITcDfwBiCBfZl5+4wxAdwOXAW8BFyfmY81H3eBHnkEXnkFLr8cIqoyf/hhGByE0dG200mNGR+HgwdhYgLWrIFdu2D79rZTla0f57zOGfoZ4JbM3ApcBtwUEVtnjLkSeFPntgf4TKMpFyOzKvOxsarEp8p8bKza7pm6CjE+Dg88AKdOVd/Wp05Vj8fH205Wrn6d865n6Jl5AjjRuf9iRBwFLgK+M23YtcDdmZnAoxGxNiI2dL62HRHVmTlUJT42Vt3fufO1M3apAAcPwunTZ287fbra3vYZY6n6dc4XtIYeEZuBtwNjM3ZdBHx/2uPjnW0zv35PRByOiMOTk5MLjLoI00t9imWuwkxMLGy7zl2/znntQo+I1cCXgA9n5o8Xc7DM3JeZOzJzx9DQ0GKeYqEHrJZZpptafpEKsWbNwrbr3PXrnNcq9IgYoCrzz2fmfbMM+QGwadrjkc629kxfM9+5E269tfrv9DV1qQC7dsHAwNnbBgaq7eqNfp3zOle5BPBZ4GhmfmKOYfcDH4qIe4GdwESr6+dQLasMDp69Zj61/DI46LKLijG1ZttvV1yUrF/nvM6HRL8LeD8wHhFPdLZ9FHgjQGbuBR6iumTxaarLFm9oPOlijI5WZ+JT5T1V6pa5CrN9e/tlstL045zXucrlP4B5G7BzdctNTYVq1MzytswlFar8d4pK0gphoUtSIeqsoUvLyssvv8yJE/Vek1+9ejXDw8M9TiQtDQtdRdm4cSMHDhyoPX7Xrl1s2rTJUlcRXHJRcTKz9u2ZZ55pO67UGAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJ0LfSIuCsiTkbEkTn2r4mIByLiWxHxZETc0HxMSVI3dc7Q9wNXzLP/JuA7mXkJMAr8XUScf+7RJEkL0bXQM/MQ8MJ8Q4DhiAhgdWfsmWbiSZLqWtXAc3wKuB94DhgGfi8zX51tYETsAfYArF27toFDS5KmNPGi6OXAE8BG4G3ApyLi52YbmJn7MnNHZu4YGhpq4NCSpClNFPoNwH1ZeRp4FvjlBp5XkrQATRT694DdABHxBuCXgGcaeF5J0gJ0XUOPiHuorl5ZHxHHgduAAYDM3At8DNgfEeNAAB/JzB/1LLEkaVZdCz0zr+uy/zngtxpLJElaFN8pKkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCNPGJRdKydezYMS6++OLa41etWsWWLVt6mEhaPAtdK1pmcuDAgVpjjxw5wi
233NLjRNLiueQiSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUiK6FHhF3RcTJiDgyz5jRiHgiIp6MiH9vNqIkqY46Z+j7gSvm2hkRa4FPA9dk5luA320kmSRpQboWemYeAl6YZ8jvA/dl5vc64082lE2StABNrKFvAdZFxCMR8c2I+MBcAyNiT0QcjojDk5OTDRxakjSliU8sWgVcCuwGXgf8Z0Q8mplPzRyYmfuAfQAjIyPZwLElSR1NFPpx4PnMnAQmI+IQcAnwU4UuSeqdJpZc/gV4d0SsiogLgJ3A0QaeV5K0AF3P0CPiHmAUWB8Rx4HbgAGAzNybmUcj4l+BbwOvAndm5pyXOEqSeqNroWfmdTXGfBz4eCOJJEmL4jtFJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgqxqu0A0nJy5swZnnrqqQV9zYYNGxgeHu5RIuk1FrpU07Zt27j99tsX9DXbt29n9+7dvPnNb+5RKuk1Frq0ANu2bVvQ+CeffJLdu3f3KI10NtfQJakQFrokFcJCl6RCdC30iLgrIk5GxJEu434lIs5ExHubiydJqqvOGfp+4Ir5BkTEecDfAv/WQCZJ0iJ0LfTMPAS80GXYzcCXgJNNhJIkLdw5r6FHxEXAbwOfqTF2T0QcjojDk5OT53poSdI0Tbwo+kngI5n5areBmbkvM3dk5o6hoaEGDi1JmtLEG4t2APdGBMB64KqIOJOZX27guSVJNZ1zoWfmL07dj4j9wIOWuSQtva6FHhH3AKPA+og4DtwGDABk5t6eppMk1da10DPzurpPlpnXn1MaSdKi+U5RSSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQXQs9Iu6KiJMRcWSO/X8QEd+OiPGI+EZEXNJ8TElSN3XO0PcDV8yz/1ng1zNzO/AxYF8DuSRJC7Sq24DMPBQRm+fZ/41pDx8FRhrIJUlaoK6FvkB/DHxlrp0RsQfYA7B27dqGDy31rxdffLH22OHh4R4mUckaK/SI+A2qQn/3XGMycx+dJZmRkZFs6thSv9q6dSsHDhxg8+bNtcZffPHFXHDBBbXHS9M1UugR8VbgTuDKzHy+ieeUSpGZPPvss7XGfv3rX+fGG2/scSKV6pwvW4yINwL3Ae/PzKfOPZIkaTG6nqFHxD3AKLA+Io4DtwEDAJm5F7gVeD3w6YgAOJOZO3oVWJI0uzpXuVzXZf8HgQ82lkiStCi+U1SSCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKkT5hZ45/2M1zzmXWrGq24CIuAt4D3AyM7fNsj+A24GrgJeA6zPzsaaDLsojj8Arr8Dll0NEVSwPPwyDgzA62na6MjnnWiHGx+HgQZiYgDVrYNcu2L693Ux1ztD3A1fMs/9K4E2d2x7gM+ceqwGZVbGMjVWFMlUsY2PVds8am+eca4UYH4cHHoBTp6pv61Onqsfj4+3m6nqGnpmHImLzPEOuBe7OzAQejYi1EbEhM080FXJRIqqzRKgKZWysur9z52tnj2qWc64V4uBBOH367G2nT1fb2zxLb2IN/SLg+9MeH+9s+ykRsSciDkfE4cnJyQYO3cX0gplisfSWc64VYG
JiYduXypK+KJqZ+zJzR2buGBoaWooDVr/yTze1FKDecM61AqxZs7DtS6XrkksNPwA2TXs80tnWrunrt1O/8k89Bs8ae8E51wqxa1e1Zj592WVgoNrepiYK/X7gQxFxL7ATmGh9/Ryq4hgcPHv9dmopYHDQYukF51wrxNQ6eb9d5VLnssV7gFFgfUQcB24DBgAycy/wENUli09TXbZ4Q6/CLtjoaHXWOFUkUwVjsfSOc64VYvv29gt8pjpXuVzXZX8CNzWWqGkzi8Ri6T3nXGpF+e8UlaQVwkKXpEJY6JJUiCaucpHUoJdeeomjR4/WGnveeeexZcuWHifScmGhS31k48aN3HHHHbXHX3311QwPD7Nhw4YeptJyYaFLfWbr1q21xx47doxLL720h2m0nLiGLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgoR2dIHD0TED4HvLuEh1wM/WsLjNWm5Zl+uuWH5Zl+uuWH5Zl/q3L+QmRfOtqO1Ql9qEXE4M3e0nWMxlmv25Zoblm/25Zoblm/2fsrtkoskFcJCl6RCrKRC39d2gHOwXLMv19ywfLMv19ywfLP3Te4Vs4YuSaVbSWfoklQ0C12SClFUoUfEXRFxMiKOzLE/IuLvI+LpiPh2RLxjqTPOpUb20YiYiIgnOrdblzrjbCJiU0R8LSK+ExFPRsSfzzKm7+a9Zu5+nfPBiPiviPhWJ/tfzzLmZyPiC505H4uIzS1EnZmpTu7rI+KH0+b8g21knUtEnBcRj0fEg7Psa3/OM7OYG/BrwDuAI3Psvwr4ChDAZcBY25kXkH0UeLDtnLPk2gC8o3N/GHgK2Nrv814zd7/OeQCrO/cHgDHgshlj/hTY27n/PuALyyT39cCn2s46z5/hL4B/mu37oh/mvKgz9Mw8BLwwz5Brgbuz8iiwNiL64qNeamTvS5l5IjMf69x/ETgKXDRjWN/Ne83cfakzj//TeTjQuc28uuFa4HOd+18EdkdELFHEWdXM3bciYgS4GrhzjiGtz3lRhV7DRcD3pz0+zjL5Ie741c6vq1+JiLe0HWamzq+Yb6c685qur+d9ntzQp3Pe+dX/CeAk8NXMnHPOM/MMMAG8fklDzqJGboDf6SzNfTEiNi1twnl9EvhL4NU59rc+5yut0Jezx6j+DYdLgDuAL7cb52wRsRr4EvDhzPxx23nq6pK7b+c8M/8vM98GjADvjIhtLUeqpUbuB4DNmflW4Ku8dsbbqoh4D3AyM7/Zdpb5rLRC/wEw/W/8kc62vpeZP576dTUzHwIGImJ9y7EAiIgBqlL8fGbeN8uQvpz3brn7ec6nZOYp4GvAFTN2/WTOI2IVsAZ4fknDzWOu3Jn5fGb+b+fhnUC/fGDqu4BrIuIYcC+wKyL+ccaY1ud8pRX6/cAHOlddXAZMZOaJtkPVERE/P7UeFxHvpPp/1/oPaCfTZ4GjmfmJOYb13bzXyd3Hc35hRKzt3H8d8JvAf88Ydj/wR5377wUOZufVurbUyT3jtZVrqF7baF1m/lVmjmTmZqoXPA9m5h/OGNb6nK9ayoP1WkTcQ3VlwvqIOA7cRvXCC5m5F3iI6oqLp4GXgBvaSfrTamR/L/AnEXEGeBl4X9s/oB3vAt4PjHfWRgE+CrwR+nre6+Tu1znfAHwuIs6j+kvmnzPzwYj4G+BwZt5P9ZfVP0TE01Qvtr+vvbg/USf3n0XENcAZqtzXt5a2hn6bc9/6L0mFWGlLLpJULAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFeL/Admi1qH7N00UAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Remove the previous contour before drawing the new one\n",
"for collection in contour.collections:\n",
"    collection.remove()\n",
"\n",
"contour = ax.contourf(\n",
"    contour_plot_x_data,\n",
"    contour_plot_y_data,\n",
"    homomorphic_predictions.round().reshape(contour_plot_x_data.shape),\n",
"    cmap=\"gray\",\n",
"    alpha=0.50,\n",
")\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "53ecca94",
"metadata": {},
"source": [
"### Enjoy!"
]
}
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}