{ "cells": [ { "cell_type": "markdown", "id": "0fe629d6", "metadata": {}, "source": [ "# Quantized Logistic Regression\n", "\n", "Currently, **hdk** only supports unsigned integers of up to 7 bits. Nevertheless, we want to evaluate a logistic regression model with it. Luckily, we can make use of **quantization** to overcome this limitation!" ] }, { "cell_type": "markdown", "id": "d0cfb561", "metadata": {}, "source": [ "### Let's start by importing some libraries to develop our logistic regression model" ] }, { "cell_type": "code", "execution_count": 1, "id": "3c1d929c", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch" ] }, { "cell_type": "markdown", "id": "69d25f7c", "metadata": {}, "source": [ "### And some helpers for visualization" ] }, { "cell_type": "code", "execution_count": 2, "id": "a89c1a6c", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from IPython.display import display" ] }, { "cell_type": "markdown", "id": "7729c1de", "metadata": {}, "source": [ "### We need a dataset, a handcrafted one for simplicity" ] }, { "cell_type": "code", "execution_count": 3, "id": "b77a9e82", "metadata": {}, "outputs": [], "source": [ "x = torch.tensor([[1, 1], [1, 2], [2, 1], [4, 1], [3, 2], [4, 2]]).float()\n", "y = torch.tensor([[0], [0], [0], [1], [1], [1]]).float()" ] }, { "cell_type": "markdown", "id": "cc8673ff", "metadata": {}, "source": [ "### Let's visualize our dataset to get a grasp of it" ] }, { "cell_type": "code", "execution_count": 4, "id": "35a98d1a", "metadata": {}, "outputs": [], "source": [ "plt.ioff()\n", "fig, ax = plt.subplots(1)" ] }, { "cell_type": "code", "execution_count": 5, "id": "56703410", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAPTklEQVR4nO3df4zkd13H8efruPPHpcgZbqO117v1D1ABKbQr1ED0lCgHmBJjTagVbCO5RKsup4mNEOkpaaIhchQbOC6lOdT1wNAGSgNGImAlhJo9LO2VCmmEKweNt7S5omBMznv7x3eW7q27O7N3szuzn30+ksnO9/v93Hxf/XTvtd/5zMxtqgpJ0sa3ZdQBJEnDYaFLUiMsdElqhIUuSY2w0CWpEVtHdeKdO3fW5OTkqE4vSRvS8ePHv1lVE0sdG1mhT05OMjs7O6rTS9KGlOTkcsdccpGkRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSI9ov9MW/M9XfoSqpUX0LPcnlST6V5ItJHk4yvcSYJHlXkkeTPJjkyrWJu0oHD8KBA0+XeFW3ffDgKFNJQzczA5OTsGVL93VmZtSJ2jeOcz7IFfpZ4A+q6nnA1cBNSZ63aMyrgOf0bvuB9ww15YWogjNn4Lbbni71Awe67TNnvFJXM2ZmYP9+OHmy+7Y+ebLbHoeCadW4znlqlcWW5CPA7VX1iQX73gt8uqqO9ba/BOytqseXe5ypqala838PfWGJz5uehkOHIFnbc0vrZHKyK5TF9uyBr351vdNsDqOc8yTHq2pqqWOrWkNPMgm8GLh/0aHLgK8t2D7V27f4z+9PMptkdm5ubjWnvjBJV94LWeZqzGOPrW6/Lt64zvnAhZ7kEuAu4E1V9a0LOVlVHamqqaqamphY8jcoDdf8FfpCC9fUpQbs3r26/bp44zrnAxV6km10ZT5TVXcvMeTrwOULtnf19o3OwuWW6Wk4d677unBNXWrArbfC9u3n79u+vduvtTGucz7Iu1wCvA94pKrescywe4A39N7tcjXw1Err5+sigR07zl8zP3So296xw2UXNeP66+HIkW79Num+HjnS7dfaGNc57/uiaJKXA/8MPASc6+1+M7AboKoO90r/dmAf8B3gxqpa8RXPdXlRtAt4fnkv3pakDWSlF0W39vvDVfUZYMUGrO6nwk0XFm+NLS5vy1xSo9r/pKgkbRIWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUiL6FnuTOJKeTnFjm+LOSfDTJF5I8nOTG4ceUJPUzyBX6UWDfCsdvAr5YVVcAe4G/SPI9Fx9NkrQafQu9qu4DnlxpCPDMJAEu6Y09O5x4kqRBDWMN/XbgJ4BvAA8B01V1bqmBSfYnmU0yOzc3N4RTS5LmDaPQXwk8APwI8CLg9iQ/sNTAqjpSVVNVNTUxMTGEU0uS5g2j0G8E7q7Oo8BXgB8fwuNKklZhGIX+GPAKgCQ/BPwY8O9DeFxJ0ips7TcgyTG6d6/sTHIKuAXYBlBVh4G3AUeTPAQEuLmqvrlmiSVJS+pb6FV1XZ/j3wB+cWiJJEkXxE+KSlIjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmN6FvoSe5McjrJiRXG7E3yQJKHk/zTcCNKkgYxyBX6UWDfcgeT7ADeDVxTVc8HfnUoySRJq9K30KvqPuDJFYb8GnB3VT3WG396SNk
kSaswjDX05wI/mOTTSY4necNyA5PsTzKbZHZubm4Ip5YkzRtGoW8FrgJeA7wS+OMkz11qYFUdqaqpqpqamJgYwqklSfO2DuExTgFPVNW3gW8nuQ+4AvjyEB5bkjSgYVyhfwR4eZKtSbYDLwUeGcLjSpJWoe8VepJjwF5gZ5JTwC3ANoCqOlxVjyT5e+BB4BxwR1Ut+xZHSdLa6FvoVXXdAGPeDrx9KIkkSRfET4pKUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY3oW+hJ7kxyOsmJPuN+KsnZJNcOL54kaVCDXKEfBfatNCDJM4A/B/5hCJkkSRegb6FX1X3Ak32G/S5wF3B6GKEkSat30WvoSS4Dfhl4zwBj9yeZTTI7Nzd3saeWJC0wjBdF3wncXFXn+g2sqiNVNVVVUxMTE0M4tSRp3tYhPMYU8IEkADuBVyc5W1UfHsJjS5IGdNGFXlU/On8/yVHgXstcktZf30JPcgzYC+xMcgq4BdgGUFWH1zSdJGlgfQu9qq4b9MGq6oaLSiNJumB+UlSSGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJakTfQk9yZ5LTSU4sc/z6JA8meSjJZ5NcMfyYkqR+BrlCPwrsW+H4V4CfraqfBN4GHBlCLknSKm3tN6Cq7ksyucLxzy7Y/Bywawi5JEmrNOw19N8EPr7cwST7k8wmmZ2bmxvyqSVpcxtaoSf5ObpCv3m5MVV1pKqmqmpqYmJiWKeWJDHAkssgkrwQuAN4VVU9MYzHlCStzkVfoSfZDdwNvL6qvnzxkSRJF6LvFXqSY8BeYGeSU8AtwDaAqjoMvBV4NvDuJABnq2pqrQJLkpY2yLtcrutz/I3AG4eWSJJ0QfykqCQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUiPYLvWrlbQ2fcy6NRN9CT3JnktNJTixzPEneleTRJA8muXL4MS/QwYNw4MDThVLVbR88OMpUbXPOtUnMzMDkJGzZ0n2dmRl1osGu0I8C+1Y4/irgOb3bfuA9Fx9rCKrgzBm47banC+bAgW77zBmvGteCc65NYmYG9u+Hkye7b+uTJ7vtkZd6VfW9AZPAiWWOvRe4bsH2l4BL+z3mVVddVWvu3Lmq6emqbs672/R0t19rwznXJrBnz/nf4vO3PXvW/tzAbC3Tq6kBrpqSTAL3VtULljh2L/BnVfWZ3vY/AjdX1ewSY/fTXcWze/fuq06ePHkBP4JWqap7TjTv3DlI1v68m5lzrsZt2bL0E86k+3ZfS0mOV9XUkrnW9tTnq6ojVTVVVVMTExPrccLuKf9CC9d3NXzOuTaB3btXt3+9DKPQvw5cvmB7V2/faC1cv52e7n5sTk+fv76r4XLOtUnceits337+vu3bu/2jtHUIj3EP8DtJPgC8FHiqqh4fwuNenAR27OgK5dChbvvQoe7Yjh0uAawF51ybxPXXd1/f8hZ47LHuyvzWW5/ePyp919CTHAP2AjuB/wBuAbYBVNXhJAFup3snzHeAG5daP19samqqZmf7Drt4VecXyeJtDZ9zLq2ZldbQ+16hV9V1fY4XcNMFZlt7i4vEYll7zrk0Eu1/UlSSNgkLXZIaYaFLUiM
sdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJasRAv7FoTU6czAHr8CuLvmsn8M11PN8wbdTsGzU3bNzsGzU3bNzs6517T1Ut+RuCRlbo6y3J7HL/5OS426jZN2pu2LjZN2pu2LjZxym3Sy6S1AgLXZIasZkK/cioA1yEjZp9o+aGjZt9o+aGjZt9bHJvmjV0SWrdZrpCl6SmWeiS1IimCj3JnUlOJzmxzPEkeVeSR5M8mOTK9c64nAGy703yVJIHere3rnfGpSS5PMmnknwxycNJppcYM3bzPmDucZ3z70vyL0m+0Mv+J0uM+d4kH+zN+f1JJkcQdXGmQXLfkGRuwZy/cRRZl5PkGUn+Ncm9Sxwb/ZxXVTM34GeAK4ETyxx/NfBxIMDVwP2jzryK7HuBe0edc4lclwJX9u4/E/gy8Lxxn/cBc4/rnAe4pHd/G3A/cPWiMb8NHO7dfx3wwQ2S+wbg9lFnXeG/4feBv13q+2Ic5rypK/Squg94coUhrwX+qjqfA3YkuXR90q1sgOxjqaoer6rP9+7/J/AIcNmiYWM37wPmHku9efyv3ua23m3xuxteC7y/d/9DwCuSZJ0iLmnA3GMryS7gNcAdywwZ+Zw3VegDuAz42oLtU2yQv8Q9P917uvrxJM8fdZjFek8xX0x35bXQWM/7CrlhTOe899T/AeA08ImqWnbOq+os8BTw7HUNuYQBcgP8Sm9p7kNJLl/fhCt6J/CHwLlljo98zjdboW9kn6f7NxyuAP4S+PBo45wvySXAXcCbqupbo84zqD65x3bOq+p/q+pFwC7gJUleMOJIAxkg90eByap6IfAJnr7iHakkvwScrqrjo86yks1W6F8HFv7E39XbN/aq6lvzT1er6mPAtiQ7RxwLgCTb6EpxpqruXmLIWM57v9zjPOfzquoM8Clg36JD353zJFuBZwFPrGu4FSyXu6qeqKr/6W3eAVy1ztGW8zLgmiRfBT4A/HySv1k0ZuRzvtkK/R7gDb13XVwNPFVVj4861CCS/PD8elySl9D9vxv5X9BepvcBj1TVO5YZNnbzPkjuMZ7ziSQ7eve/H/gF4N8WDbsH+I3e/WuBT1bv1bpRGST3otdWrqF7bWPkquqPqmpXVU3SveD5yar69UXDRj7nW9fzZGstyTG6dybsTHIKuIXuhReq6jDwMbp3XDwKfAe4cTRJ/78Bsl8L/FaSs8B/A68b9V/QnpcBrwce6q2NArwZ2A1jPe+D5B7XOb8UeH+SZ9D9kPm7qro3yZ8Cs1V1D90Pq79O8ijdi+2vG13c7xok9+8luQY4S5f7hpGlHcC4zbkf/ZekRmy2JRdJapaFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhrxf09l6LOTuZAtAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "x_min, x_max = x[:, 0].min(), x[:, 0].max()\n", "x_deviation = x_max - x_min\n", "\n", "y_min, y_max = x[:, 1].min(), x[:, 1].max()\n", "y_deviation = y_max - y_min\n", "\n", "ax.set_xlim(x_min - (x_deviation / 10), x_max + (x_deviation / 10))\n", "ax.set_ylim(y_min - (y_deviation / 10), y_max + (y_deviation / 10))\n", "\n", "ax.scatter(\n", "    np.array([x_i[0] for x_i, y_i in zip(x, y) if y_i == 0], dtype=np.float32),\n", "    np.array([x_i[1] for x_i, y_i in zip(x, y) if y_i == 0], dtype=np.float32),\n", "    marker=\"x\",\n", "    color=\"red\",\n", ")\n", "ax.scatter(\n", "    np.array([x_i[0] for x_i, y_i in zip(x, y) if y_i == 1], dtype=np.float32),\n", "    np.array([x_i[1] for x_i, y_i in zip(x, y) if y_i == 1], dtype=np.float32),\n", "    marker=\"o\",\n", "    color=\"blue\",\n", ")\n", "display(fig)" ] }, { "cell_type": "markdown", "id": "e31b82e8", "metadata": {}, "source": [ "### Now, we need a model so let's define it" ] }, { "cell_type": "code", "execution_count": 6, "id": "cc5e72a2", "metadata": {}, "outputs": [], "source": [ "class Model(torch.nn.Module):\n", "    def __init__(self, n):\n", "        super(Model, self).__init__()\n", "        self.fc = torch.nn.Linear(n, 1)\n", "\n", "    def forward(self, x):\n", "        output = torch.sigmoid(self.fc(x))\n", "        return output" ] }, { "cell_type": "markdown", "id": "cefd8346", "metadata": {}, "source": [ "### And create one\n", "\n", "The main purpose of this tutorial is not to train a logistic regression model but to use it homomorphically. So we will not discuss how the model is trained." 
] }, { "cell_type": "code", "execution_count": 7, "id": "b9879f4d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 1 | Loss: 0.5758869647979736\n", "Epoch: 101 | Loss: 0.13611836731433868\n", "Epoch: 201 | Loss: 0.08021673560142517\n", "Epoch: 301 | Loss: 0.05636058747768402\n", "Epoch: 401 | Loss: 0.043306026607751846\n", "Epoch: 501 | Loss: 0.03511128947138786\n", "Epoch: 601 | Loss: 0.029501130804419518\n", "Epoch: 701 | Loss: 0.025424323976039886\n", "Epoch: 801 | Loss: 0.02233024500310421\n", "Epoch: 901 | Loss: 0.01990305446088314\n", "Epoch: 1001 | Loss: 0.0179488193243742\n", "Epoch: 1101 | Loss: 0.01634199731051922\n", "Epoch: 1201 | Loss: 0.014997857622802258\n", "Epoch: 1301 | Loss: 0.013856985606253147\n", "Epoch: 1401 | Loss: 0.012876608408987522\n", "Epoch: 1501 | Loss: 0.012025204487144947\n" ] } ], "source": [ "model = Model(x.shape[1])\n", "\n", "optimizer = torch.optim.SGD(model.parameters(), lr=1)\n", "criterion = torch.nn.BCELoss()\n", "\n", "epochs = 1501\n", "for e in range(1, epochs + 1):\n", "    optimizer.zero_grad()\n", "\n", "    out = model(x)\n", "    loss = criterion(out, y)\n", "\n", "    loss.backward()\n", "    optimizer.step()\n", "\n", "    if e % 100 == 1 or e == epochs:\n", "        print(\"Epoch:\", e, \"|\", \"Loss:\", loss.item())" ] }, { "cell_type": "markdown", "id": "01cfc83f", "metadata": {}, "source": [ "### Time to make some predictions" ] }, { "cell_type": "code", "execution_count": 8, "id": "78356d37", "metadata": {}, "outputs": [], "source": [ "contour_plot_x_data = np.linspace(x_min - (x_deviation / 10), x_max + 2 * (x_deviation / 10), 250)\n", "contour_plot_y_data = np.linspace(y_min - (y_deviation / 10), y_max + 2 * (y_deviation / 10), 250)\n", "contour_plot_x_data, contour_plot_y_data = np.meshgrid(contour_plot_x_data, contour_plot_y_data)\n", "\n", "inputs = np.stack((contour_plot_x_data.flatten(), contour_plot_y_data.flatten()), axis=1)\n", "predictions = 
model(torch.tensor(inputs).float()).detach().numpy()" ] }, { "cell_type": "markdown", "id": "58160140", "metadata": {}, "source": [ "### Let's visualize our predictions to see how our model performs" ] }, { "cell_type": "code", "execution_count": 9, "id": "2a623999", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAT40lEQVR4nO3df2xdZ33H8c+XxluYkyURQcyuw7I/6KbDDeVHRjpRbVnQltJBqilMo9tgrYYibV06tElD449WG39NaAjPFURRqbLescIEFasrWIdy6aKV1VNoC27syaprg22sBRzfkPg2P6773R/nGhzX9r22j+9z7nPfL8nKvec88fn0afLJ8XPOvdfcXQCA1ve60AEAANmg0AEgEhQ6AESCQgeASFDoABCJLaEO3NnZ6bt27Qp1eOTMlStXtGPHDm3btk033XRT6DhAbj3//PM/cvc3LrcvWKHv2rVLx48fD3V45MzQ0JCOHDmi22+/Xdu3bw8dB8itzs7O7620jyUX5EKSJJqfn9fExISmp6dDxwFaEoWO3BgbG1OxWNTVq1dDRwFaEoUOAJGg0AEgEhQ6AESCQkfuXLt2LXQEoCVR6MiVcrmsarWq4eHh0FGAlkOhI1eSJFFfX1/oGEBLotABIBIUOgBEgkJHLrm7xsfHQ8cAWgqFjtxJkkS9vb2qVCq8DQCwBvEX+tLPTOUzVFtCoVBQqVQKHQNoKXXfbdHM9kh6VNKbJLmkk+7eu2SMSeqVdKekiqR73P257OOu0dNPS1euSIcPS2ZpmT/1lLR1q3TwYOh0QGYGB6VSSbp4UdqxQzp0SNq3L3SquOVxzhs5Q69K+it3TyTdJuk+M0uWjHmfpLfUvo5J+lymKdfDPS3zgYG0xBfKfGAg3c6ZOiIxOCj190vlcvrHulxOnw8Ohk4Wr7zOed0zdHefljRde3zJzIYl3SxpaNGwuyQ96u4u6Vkz22lmXbXfG4ZZemYupSU+MJA+PnDgp2fsQARKJen69Ru3Xb+ebg99xhirvM75mtbQzWyvpHdIGliy62ZJE4ueT9a2Lf39x8zsrJmdnZubW2PUdVhc6gso85YxNTWl2dlZ7nap4+LFtW3HxuV1zhsudDPbJukrkj7m7j9ez8Hc/aS773f3/Z2dnev5Fms9YLrMstjC8gtyr7u7W729vXrllVdCR8m1HTvWth0bl9c5b6jQzaxDaZl/wd0fX2bIlKQ9i5731LaFs3jN/MAB6YEH0l8Xr6kj9173uvhvxNqoQ4ekjo4bt3V0pNuxOfI6543c5WKSPi9p2N0/vcKwJyT9uZl9UdIBSReDrp9L6bLK1q03rpkvLL9s3cqyC6KxsGabtzsuYpbXOW/kQ6LfI+nDkgbN7IXatk9IerMkufsJSV9TesviS0pvW7w386TrcfBgeia+UN4LpU6ZIzL79oUvk3aTxzlv5C6X/5K0agPW7m65L6tQmVpa3pQ5gEg1coYOBJMkiUZHR+Xu2rVrl7q6ukJHAnKLK07IvbGxMd4GAGgAhQ4AkaDQASASFDoARIJCR8sol8uhIwC5RqGjJbi7RkdHNTIyEjoKkFsUOlpGf39/6AhArlHoABAJCh0AIkGho6VUq1XW0YEVUOhoGYVCQX19fRoZGdH0dNg38wTyiEJHS0mShE8wAlZAoQNAJCh0AIgEhQ4Ak
aDQ0XKmpqY0OzvL3S7AEhQ6Wk53d7d6e3tDxwByh0IHgEhQ6AAQCQodACJBoaNlVatVXbp0KXQMIDcodLSkQqGgUqmkiYkJSh2oodDRstxdL7/8cugYQG5Q6AAQCQodACJBoQNAJCh0tDwujAKpuoVuZo+Y2Xkze3GF/TvMrN/MvmNm58zs3uxjAssbGxtTqVTSzMxM6ChAcI2coZ+SdMcq+++TNOTut0o6KOkfzOxnNh4NaMzU1FToCEAu1C10dz8j6cJqQyRtNzOTtK02tppNPABAo7Zk8D0ekvSEpB9I2i7p99391eUGmtkxScckaefOnRkcGgCwIIuLooclvSCpW9LbJT1kZj+/3EB3P+nu+919f2dnZwaHBqQLFy6oUqloeHiYi6Noa1kU+r2SHvfUS5LGJP1KBt8XaEihUFBfXx8fHo22l0Whf1/SeyXJzN4k6Zcl8XpsAGiyumvoZvaY0rtXdpvZpKQHJXVIkrufkPRJSafMbFCSSfq4u/9o0xIDAJZVt9Dd/e46+38g6bczSwQAWBdeKYooJEmi+fl5Xb58OXQUIBgKHdF45plnNDs7y8VRtC0KHdHo7u5WsVgMHQMIhkIHgEhQ6AAQCQodACJBoSM6lUpF09PToWMATUehIyoLF0bL5XLoKEDTUeiITrlc5tZFtCUKHQAiQaEDQCQodACIBIWOKM3Pz2toaIi7XdBWKHREJ0kSjY2NqVgs6urVq6HjAE1DoQNAJCh0AIgEhY6oXbt2LXQEoGkodESrXC6rWq1qeHg4dBSgKSh0RCtJEvX19YWOATQNhQ4AkaDQASASFDoARIJCR/TcnXdfRFug0BG1JEnU29vLh16gLVDoiF6hUFCpVAodA9h0FDoARIJCB4BI1C10M3vEzM6b2YurjDloZi+Y2Tkz+89sIwIAGtHIGfopSXestNPMdkr6rKQj7v5WSb+XSTIgY7Ozs9ztgqjVLXR3PyPpwipD/kDS4+7+/dr48xllAzLj7urt7eXNuhC1LNbQb5G0y8yeNrNvm9lHVhpoZsfM7KyZnZ2bm8vg0ACABVsy+h7vkvReSa+X9N9m9qy7jywd6O4nJZ2UpJ6eHs/g2ACAmiwKfVLSjLvPSZozszOSbpX0mkIHAGyeLJZc/k3S7Wa2xcx+TtIBSbwBNXJpfn5ely5dCh0D2BR1z9DN7DFJByXtNrNJSQ9K6pAkdz/h7sNm9u+SvivpVUkPu/uKtzgCoRQKBY2OjsrdtWvXLnV1dYWOBGSqbqG7+90NjPmUpE9lkgjYRGNjYxofH9fRo0dDRwEyxytFASASFDoARIJCR1viwihilMVti0BLOXfunPbu3StJuuWWW8KGATLEGTraTpIk6u/vDx0DyByFDgCRoNABIBIUOgBEgkJH26pWqxoZ4S2HEA8KHW2pUCior69PIyMj3MKIaFDoaFtJkoSOAGSKQgeASFDoABAJCh1t7/Lly6EjAJmg0NHWnnnmGc3Ozmp8fDx0FGDDKHS0te7ubhWLxdAxgExQ6AAQCQodACJBoaPtXbhwQZVKhRcYoeVR6Gh7hUJBpVJJExMTlDpaGoUOSHJ3vfzyy6FjABtCoQNAJCh0AIgEhQ4AkaDQgUUmJye5MIqWRaEDNWNjYzp9+rRmZmZCRwHWhUIHFpmamgodAVi3uoVuZo+Y2Xkze7HOuF81s6qZfTC7eACARjVyhn5K0h2rDTCzmyT9vaT/yCATAGAd6ha6u5+RdKHOsOOSviLpfBahAABrt+E1dDO7WdLvSvpcA2OPmdlZMzs7Nze30UMDmVt4X5fh4eHQUYA1y+Ki6GckfdzdX6030N1Puvt+d9/f2dmZwaGBbBUKBfX19Wl8fJzbF9FytmTwPfZL+qKZSdJuSXeaWdXdv5rB9wYANGjDhe7uv7Tw2MxOSXqSMgeA5qtb6Gb2mKSDknab2aSkByV1SJK7n9jUdACAhtUtdHe/u9Fv5u73bCgNkBPz8/OamZnR9u3bQ0cBG
sYrRYElkiRRf3+/KpWKxsfHQ8cBGkahA8soFAoqFouhYwBrQqEDQCQodACIBIUOrKJSqWh6ejp0DKAhFDqwgu7ubhWLRZXL5dBRgIZQ6MAqKHO0EgodACJBoQNAJCh0AIgEhQ7U4e4aGhribhfkHoUOrCJJEp0+fVqlUil0FKAuCh0AIkGhA0AkKHQAiASFDjSIzxhF3lHoQAPOnTunarWqkZGR0FGAFVHoQAOSJFFvb2/oGMCqKHQAiASFDgCRoNABIBIUOrAG1WqVD45GblHoQIMKhYJ6e3v5FCPkFoUOrEGhUOB9XZBbFDoARIJCB4BIUOjAOszOznJxFLlTt9DN7BEzO29mL66w/w/N7LtmNmhm3zKzW7OPCeSHu6u3t1fXrl0LHQW4QSNn6Kck3bHK/jFJv+Hu+yR9UtLJDHIBANZoS70B7n7GzPausv9bi54+K6kng1wAgDXKeg39TyR9faWdZnbMzM6a2dm5ubmMDw0A7a3uGXqjzOw3lRb67SuNcfeTqi3J9PT0eFbHBkKYn58PHQG4QSaFbmZvk/SwpPe5+0wW3xPIs0KhoNHRUbm79uzZo+3bt4eOBGx8ycXM3izpcUkfdnfe/R9tY2xsjFeNIlfqnqGb2WOSDkrabWaTkh6U1CFJ7n5C0gOS3iDps2YmSVV3379ZgQEAy2vkLpe76+z/qKSPZpYIALAuvFIUACJBoQMbNDk5GToCIIlCBzbE3TU6OqqREe4HQHgUOrBB/f39oSMAkih0AIgGhQ4AkaDQASASFDqQgWq1yoVRBEehAxtUKBTU19enkZERXbp0KXQctDEKHchAkiShIwAUOgDEgkIHgEhQ6EBGxsfHNTExofHx8dBR0KYodCAj7q5isRg6BtoYhQ4AkaDQASASFDoARIJCBzJWqVR4gRGCoNCBDHV3d6tUKmlycpJSR9NR6EDGzp07x62LCIJCB4BIUOgAEAkKHQAiQaEDm2B+fp4Lo2g6Ch3IWJIkGhsb0+nTpzUzMxM6DtoIhQ5skqmpqdAR0GYodACIRPyF7r76c2SPOQeC2FJvgJk9Iun9ks67e2GZ/SapV9KdkiqS7nH357IOui5PPy1duSIdPiyZpcXy1FPS1q3SwYOh08WJOUebGByUSiXp4kVpxw7p0CFp376wmRo5Qz8l6Y5V9r9P0ltqX8ckfW7jsTLgnhbLwEBaKAvFMjCQbuesMXvM+Q0uXLigSqWi4eHh0FGQscFBqb9fKpfTP9blcvp8cDBsrrpn6O5+xsz2rjLkLkmPurtLetbMdppZl7tPZxVyXczSs0QpLZSBgfTxgQM/PXtEtpjzGxQKBfX19en+++8PHQUZK5Wk69dv3Hb9ero95Fl6FmvoN0uaWPR8srbtNczsmJmdNbOzc3NzGRy6jsUFs6ANi6WpmHO0gYsX17a9WZp6UdTdT7r7fnff39nZ2YwDpj/yL7awFIDNwZyjDezYsbbtzZJFoU9J2rPoeU9tW1iL128PHJAeeCD9dfH6LrLFnC/L3TU9HXYFEtk6dEjq6LhxW0dHuj2kLAr9CUkfsdRtki4GXz+X0h/xt269cf328OH0+datLAFsBub8NZIkUbFY1OzsLKUekX37pA98QNq5M/1jvXNn+jz0XS7mdc6azOwxSQcl7Zb0f5IelNQhSe5+onbb4kNK74SpSLrX3c/WO3BPT48fP358Q+Eb4n5jkSx9juwx569hZjp69Ki6urpCR0GL6+zs/La7719uXyN3udxdZ79Lum+d2Tbf0iJp82JpCuYcCCL+V4oCQJug0AEgEhQ60CRcGMVmo9CBJnB3FYtFPvACm4pCB5qkXC6HjoDIUegAEAkKHQAiQaEDQCQodKCJqtWqhoaGuNsFm4JCB5okSRKdPn1apVIpdBREikIHgEhQ6AAQibrvtrhpBzb7oaTvNfGQuyX9qInHy1KrZm/V3FLrZm/V3FLrZm927l909zcutyNYoTebmZ1d6S0n865Vs7dqbql1s
7dqbql1s+cpN0suABAJCh0AItFOhX4ydIANaNXsrZpbat3srZpbat3sucndNmvoABC7djpDB4CoUegAEImoCt3MHjGz82b24gr7zcz+0cxeMrPvmtk7m51xJQ1kP2hmF83shdrXA83OuBwz22Nm3zSzITM7Z2Z/scyY3M17g7nzOudbzex/zOw7tex/u8yYnzWzL9XmfMDM9gaIujRTI7nvMbMfLprzj4bIuhIzu8nMnjezJ5fZF37O3T2aL0m/Lumdkl5cYf+dkr4uySTdJmkgdOY1ZD8o6cnQOZfJ1SXpnbXH2yWNSEryPu8N5s7rnJukbbXHHZIGJN22ZMyfSTpRe/whSV9qkdz3SHoodNZV/hv+UtK/LPfnIg9zHtUZurufkXRhlSF3SXrUU89K2mlmXc1Jt7oGsueSu0+7+3O1x5ckDUu6ecmw3M17g7lzqTaPl2tPO2pfS+9uuEvSP9Uef1nSe83MmhRxWQ3mzi0z65H0O5IeXmFI8DmPqtAbcLOkiUXPJ9Uif4lrfq324+rXzeytocMsVfsR8x1Kz7wWy/W8r5Jbyumc1370f0HSeUnfcPcV59zdq5IuSnpDU0Muo4HcknS0tjT3ZTPb09yEq/qMpL+W9OoK+4PPebsVeit7Tul7ONwqqU/SV8PGuZGZbZP0FUkfc/cfh87TqDq5czvn7j7v7m+X1CPp3WZWCBypIQ3k7pe0193fJukb+ukZb1Bm9n5J593926GzrKbdCn1K0uJ/8Xtq23LP3X+88OOqu39NUoeZ7Q4cS5JkZh1KS/EL7v74MkNyOe/1cud5zhe4e1nSNyXdsWTXT+bczLZI2iFppqnhVrFSbnefcfertacPS3pXk6Ot5D2SjpjZuKQvSjpkZv+8ZEzwOW+3Qn9C0kdqd13cJumiu7fER8eY2S8srMeZ2buV/r8L/he0lunzkobd/dMrDMvdvDeSO8dz/kYz21l7/HpJvyXpf5cMe0LSH9cef1BSyWtX60JpJPeSaytHlF7bCM7d/8bde9x9r9ILniV3/6Mlw4LP+ZZmHmyzmdljSu9M2G1mk5IeVHrhRe5+QtLXlN5x8ZKkiqR7wyR9rQayf1DSn5pZVdIrkj4U+i9ozXskfVjSYG1tVJI+IenNUq7nvZHceZ3zLkn/ZGY3Kf1H5l/d/Ukz+ztJZ939CaX/WBXN7CWlF9s/FC7uTzSS+34zOyKpqjT3PcHSNiBvc85L/wEgEu225AIA0aLQASASFDoARIJCB4BIUOgAEAkKHQAiQaEDQCT+H+Ww0z5fuBGwAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "contour = ax.contourf(\n", "    contour_plot_x_data,\n", "    contour_plot_y_data,\n", "    predictions.round().reshape(contour_plot_x_data.shape),\n", "    cmap=\"gray\",\n", "    alpha=0.50,\n", ")\n", "display(fig)" ] }, { "cell_type": "markdown", "id": "d3f39faa", "metadata": {}, "source": [ "### As a bonus, let's inspect the model parameters" ] }, { "cell_type": "code", "execution_count": 10, "id": "7fa65211", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[4.53586054]\n", " [2.37015319]]\n", "-14.660321235656738\n" ] } ], "source": [ "w = np.array(model.fc.weight.flatten().tolist()).reshape((-1, 1))\n", "b = model.fc.bias.flatten().tolist()[0]\n", "\n", "print(w)\n", "print(b)" ] }, { "cell_type": "markdown", "id": "544d6e34", "metadata": {}, "source": [ "They are floating-point numbers, and **hdk** only works with small unsigned integers, so we can't use them directly!" ] }, { "cell_type": "markdown", "id": "abf310f2", "metadata": {}, "source": [ "### So, let's abstract quantization\n", "\n", "Here is a quick summary of quantization. We have a range of values and we want to represent them using a small number of bits (n). To do this, we split the range into 2^n sections and map each section to a value. Here is a visualization of the process!" ] }, { "cell_type": "code", "execution_count": 11, "id": "d3ab2aa2", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "
[Figure: quantization visualized for n = 2. The range from min(x) to max(x) is split into sections mapped to 0, 1, 2 and 3; the distance between consecutive values is 1 / q (q = scale); dequantization is x = (x_q + zp_x) / q_x; the example zero point is zp = 2.]
" ], "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from IPython.display import SVG\n", "SVG(filename=\"figures/QuantizationVisualized.svg\")" ] }, { "cell_type": "markdown", "id": "e4314038", "metadata": {}, "source": [ "If you want to learn more, head to https://intellabs.github.io/distiller/algo_quantization.html" ] }, { "cell_type": "code", "execution_count": 12, "id": "a8bab855", "metadata": {}, "outputs": [], "source": [ "class QuantizationParameters:\n", "    def __init__(self, q, zp, n):\n", "        # q = scale factor = 1 / distance between consecutive values\n", "        # zp = zero point which is used to determine the beginning of the quantized range\n", "        # (quantized 0 = the beginning of the quantized range = zp * distance between consecutive values)\n", "        # n = number of bits\n", "\n", "        # e.g.,\n", "\n", "        # n = 2\n", "        # zp = 2\n", "        # q = 0.66\n", "        # distance between consecutive values = 1 / q = 1.5151\n", "\n", "        # quantized 0 = zp / q = zp * distance between consecutive values = 3.0303\n", "        # quantized 1 = quantized 0 + distance between consecutive values = 4.5454\n", "        # quantized 2 = quantized 1 + distance between consecutive values = 6.0606\n", "        # quantized 3 = quantized 2 + distance between consecutive values = 7.5757\n", "\n", "        self.q = q\n", "        self.zp = zp\n", "        self.n = n\n", "\n", "class QuantizedArray:\n", "    def __init__(self, values, parameters):\n", "        # values = quantized values\n", "        # parameters = parameters used during quantization\n", "\n", "        # e.g.,\n", "\n", "        # values = [1, 0, 2, 1]\n", "        # parameters = QuantizationParameters(q=0.66, zp=2, n=2)\n", "\n", "        # original array = [4.5454, 3.0303, 6.0606, 4.5454]\n", "\n", "        self.values = np.array(values)\n", "        self.parameters = parameters\n", "\n", "    @staticmethod\n", "    def of(x, n):\n", "        if not isinstance(x, np.ndarray):\n", "            x = np.array(x)\n", "\n", "        min_x = x.min()\n", "        max_x = x.max()\n", "\n", "        if min_x == max_x:  # encoding single valued arrays\n", "\n", "            if min_x == 0.0:  # encoding 0s\n", "\n", "                # dequantization = (x_q + zp_x) / q_x = 0 --> q_x = 1 && zp_x = 0 && x_q = 0\n", "                q_x = 1\n", "                zp_x = 0\n", "                x_q = np.zeros(x.shape, dtype=np.uint)\n", "\n", "            elif min_x < 0.0:  # encoding negative scalars\n", "\n", "                # dequantization = (x_q + zp_x) / q_x = -x --> q_x = 1 / x & zp_x = -1 & x_q = 0\n", "                q_x = abs(1 / min_x)\n", "                zp_x = -1\n", "                x_q = np.zeros(x.shape, dtype=np.uint)\n", "\n", "            else:  # encoding positive scalars\n", "\n", "                # dequantization = (x_q + zp_x) / q_x = x --> q_x = 1 / x & zp_x = 0 & x_q = 1\n", "                q_x = 1 / min_x\n", "                zp_x = 0\n", "                x_q = np.ones(x.shape, dtype=np.uint)\n", "\n", "        else:  # encoding multi valued arrays\n", "\n", "            # distance between consecutive values = range of x / number of intervals between quantized values = (max_x - min_x) / (2^n - 1)\n", "            # q = 1 / distance between consecutive values\n", "            q_x = (2**n - 1) / (max_x - min_x)\n", "\n", "            # zp = what should be added to 0 to get min_x -> min_x = (0 + zp) / q -> zp = min_x * q\n", "            zp_x = int(round(min_x * q_x))\n", "\n", "            # x = (x_q + zp) / q -> x_q = (x * q) - zp\n", "            x_q = ((q_x * x) - zp_x).round().astype(np.uint)\n", "\n", "        return QuantizedArray(x_q, QuantizationParameters(q_x, zp_x, n))\n", "\n", "    def dequantize(self):\n", "        # x = (x_q + zp) / q\n", "        return (self.values.astype(np.float32) + float(self.parameters.zp)) / self.parameters.q\n", "\n", "    def affine(self, w, b, min_y, max_y, n_y):\n", "        # the formulas used in this method were derived from the following equations\n", "        #\n", "        # x = (x_q + zp_x) / q_x\n", "        # w = (w_q + zp_w) / q_w\n", "        # b = (b_q + zp_b) / q_b\n", "        #\n", "        # (x * w) + b = ((x_q + zp_x) / q_x) * ((w_q + zp_w) / q_w) + ((b_q + zp_b) / q_b)\n", "        #             = y = (y_q + zp_y) / q_y\n", "        #\n", "        # So, ((x_q + zp_x) / q_x) * ((w_q + zp_w) / q_w) + ((b_q + zp_b) / q_b) = (y_q + zp_y) / q_y\n", "        # We can calculate zp_y and q_y from min_y, max_y, n_y. So, the only unknown is y_q and it can be solved.\n", "\n", "        x_q = self.values\n", "        w_q = w.values\n", "        b_q = b.values\n", "\n", "        q_x = self.parameters.q\n", "        q_w = w.parameters.q\n", "        q_b = b.parameters.q\n", "\n", "        zp_x = self.parameters.zp\n", "        zp_w = w.parameters.zp\n", "        zp_b = b.parameters.zp\n", "\n", "        q_y = (2**n_y - 1) / (max_y - min_y)\n", "        zp_y = int(round(min_y * q_y))\n", "\n", "        y_q = (q_y / (q_x * q_w)) * ((x_q + zp_x) @ (w_q + zp_w) + (q_x * q_w / q_b) * (b_q + zp_b))\n", "        y_q -= min_y * q_y\n", "        y_q = y_q.round().clip(0, 2**n_y - 1).astype(np.uint)\n", "\n", "        return QuantizedArray(y_q, QuantizationParameters(q_y, zp_y, n_y))\n", "\n", "class QuantizedFunction:\n", "    def __init__(self, table, input_parameters=None, output_parameters=None):\n", "        self.table = table\n", "        self.input_parameters = input_parameters\n", "        self.output_parameters = output_parameters\n", "\n", "    @staticmethod\n", "    def of(f, input_bits, output_bits):\n", "        domain = np.array(range(2**input_bits), dtype=np.uint)\n", "        table = f(domain).round().clip(0, 2**output_bits - 1).astype(np.uint)\n", "        return QuantizedFunction(table)\n", "\n", "    @staticmethod\n", "    def plain(f, input_parameters, output_bits):\n", "        n = input_parameters.n\n", "\n", "        domain = np.array(range(2**n), dtype=np.uint)\n", "        inputs = QuantizedArray(domain, input_parameters).dequantize()\n", "\n", "        outputs = f(inputs)\n", "        quantized_outputs = QuantizedArray.of(outputs, output_bits)\n", "\n", "        table = quantized_outputs.values\n", "        output_parameters = quantized_outputs.parameters\n", "\n", "        return QuantizedFunction(table, input_parameters, output_parameters)\n", "\n", "    def apply(self, x):\n", "        assert x.parameters == self.input_parameters\n", "        return QuantizedArray(self.table[x.values], self.output_parameters)" ] }, { "cell_type": "markdown", "id": "e5be0800", "metadata": {}, "source": [ "### Let's quantize our model parameters\n", "\n", "Since the parameters only 
consist of scalars, we can use a single bit quantization." ] }, { "cell_type": "code", "execution_count": 13, "id": "3ec0ad9b", "metadata": {}, "outputs": [], "source": [ "parameter_bits = 1\n", "\n", "w_q = QuantizedArray.of(w, parameter_bits)\n", "b_q = QuantizedArray.of(b, parameter_bits)" ] }, { "cell_type": "markdown", "id": "b43c0371", "metadata": {}, "source": [ "### And quantize our inputs" ] }, { "cell_type": "code", "execution_count": 14, "id": "20cea447", "metadata": {}, "outputs": [], "source": [ "input_bits = 5\n", "\n", "x = inputs\n", "x_q = QuantizedArray.of(inputs, input_bits)" ] }, { "cell_type": "markdown", "id": "ca76b68d", "metadata": {}, "source": [ "### Time to make quantized inference" ] }, { "cell_type": "code", "execution_count": 15, "id": "8728e939", "metadata": {}, "outputs": [], "source": [ "output_bits = 7\n", "\n", "intermediate = x @ w + b\n", "intermediate_q = x_q.affine(w_q, b_q, intermediate.min(), intermediate.max(), output_bits)\n", "\n", "sigmoid = QuantizedFunction.plain(lambda x: 1 / (1 + np.exp(-x)), intermediate_q.parameters, output_bits)\n", "y_q = sigmoid.apply(intermediate_q)\n", "\n", "quantized_predictions = y_q.dequantize()" ] }, { "cell_type": "markdown", "id": "ab782b4a", "metadata": {}, "source": [ "### And visualize the results" ] }, { "cell_type": "code", "execution_count": 16, "id": "9d2bb5da", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQaElEQVR4nO3df2xdZ33H8fd3jTtTx0siUrGkDssqkY2QUKAZ6QTavERbf6FW05hGt8FaDaXaSje0SkPjj1Ybf01oiK4IoqhUoRtrmaBibVXWoYQuGqyeQltwSqaqagOERgq0iuncdkrW7/4419Qxtu+xc67P9eP3S7rqvec8vueTp/Ynx889NzcyE0nS8vczbQeQJDXDQpekQljoklQIC12SCmGhS1IhVrV14KGhoVy3bl1bh1ehTp8+zYUXXsj555/fdhSpJx5//PEfZeaFs+1rrdDXrVvHzTff3NbhVajnnnuOG2+8kc2bN7cdReqJoaGh7861zyUXSSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBWi/EKf+ZmpfoaqpEJ1/dcWI2ITcDfwBiCBfZl5+4wxAdwOXAW8BFyfmY81H3eBHnkEXnkFLr8cIqoyf/hhGByE0dG200mNGR+HgwdhYgLWrIFdu2D79rZTla0f57zOGfoZ4JbM3ApcBtwUEVtnjLkSeFPntgf4TKMpFyOzKvOxsarEp8p8bKza7pm6CjE+Dg88AKdOVd/Wp05Vj8fH205Wrn6d865n6Jl5AjjRuf9iRBwFLgK+M23YtcDdmZnAoxGxNiI2dL62HRHVmTlUJT42Vt3fufO1M3apAAcPwunTZ287fbra3vYZY6n6dc4XtIYeEZuBtwNjM3ZdBHx/2uPjnW0zv35PRByOiMOTk5MLjLoI00t9imWuwkxMLGy7zl2/znntQo+I1cCXgA9n5o8Xc7DM3JeZOzJzx9DQ0GKeYqEHrJZZpptafpEKsWbNwrbr3PXrnNcq9IgYoCrzz2fmfbMM+QGwadrjkc629kxfM9+5E269tfrv9DV1qQC7dsHAwNnbBgaq7eqNfp3zOle5BPBZ4GhmfmKOYfcDH4qIe4GdwESr6+dQLasMDp69Zj61/DI46LKLijG1ZttvV1yUrF/nvM6HRL8LeD8wHhFPdLZ9FHgjQGbuBR6iumTxaarLFm9oPOlijI5WZ+JT5T1V6pa5CrN9e/tlstL045zXucrlP4B5G7BzdctNTYVq1MzytswlFar8d4pK0gphoUtSIeqsoUvLyssvv8yJE/Vek1+9ejXDw8M9TiQtDQtdRdm4cSMHDhyoPX7Xrl1s2rTJUlcRXHJRcTKz9u2ZZ55pO67UGAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJ0LfSIuCsiTkbEkTn2r4mIByLiWxHxZETc0HxMSVI3dc7Q9wNXzLP/JuA7mXkJMAr8XUScf+7RJEkL0bXQM/MQ8MJ8Q4DhiAhgdWfsmWbiSZLqWtXAc3wKuB94DhgGfi8zX51tYETsAfYArF27toFDS5KmNPGi6OXAE8BG4G3ApyLi52YbmJn7MnNHZu4YGhpq4NCSpClNFPoNwH1ZeRp4FvjlBp5XkrQATRT694DdABHxBuCXgGcaeF5J0gJ0XUOPiHuorl5ZHxHHgduAAYDM3At8DNgfEeNAAB/JzB/1LLEkaVZdCz0zr+uy/zngtxpLJElaFN8pKkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCNPGJRdKydezYMS6++OLa41etWsWWLVt6mEhaPAtdK1pmcuDAgVpjjxw5wi233NLjRNLiueQ
iSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUiK6FHhF3RcTJiDgyz5jRiHgiIp6MiH9vNqIkqY46Z+j7gSvm2hkRa4FPA9dk5luA320kmSRpQboWemYeAl6YZ8jvA/dl5vc64082lE2StABNrKFvAdZFxCMR8c2I+MBcAyNiT0QcjojDk5OTDRxakjSliU8sWgVcCuwGXgf8Z0Q8mplPzRyYmfuAfQAjIyPZwLElSR1NFPpx4PnMnAQmI+IQcAnwU4UuSeqdJpZc/gV4d0SsiogLgJ3A0QaeV5K0AF3P0CPiHmAUWB8Rx4HbgAGAzNybmUcj4l+BbwOvAndm5pyXOEqSeqNroWfmdTXGfBz4eCOJJEmL4jtFJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgqxqu0A0nJy5swZnnrqqQV9zYYNGxgeHu5RIuk1FrpU07Zt27j99tsX9DXbt29n9+7dvPnNb+5RKuk1Frq0ANu2bVvQ+CeffJLdu3f3KI10NtfQJakQFrokFcJCl6RCdC30iLgrIk5GxJEu434lIs5ExHubiydJqqvOGfp+4Ir5BkTEecDfAv/WQCZJ0iJ0LfTMPAS80GXYzcCXgJNNhJIkLdw5r6FHxEXAbwOfqTF2T0QcjojDk5OT53poSdI0Tbwo+kngI5n5areBmbkvM3dk5o6hoaEGDi1JmtLEG4t2APdGBMB64KqIOJOZX27guSVJNZ1zoWfmL07dj4j9wIOWuSQtva6FHhH3AKPA+og4DtwGDABk5t6eppMk1da10DPzurpPlpnXn1MaSdKi+U5RSSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQXQs9Iu6KiJMRcWSO/X8QEd+OiPGI+EZEXNJ8TElSN3XO0PcDV8yz/1ng1zNzO/AxYF8DuSRJC7Sq24DMPBQRm+fZ/41pDx8FRhrIJUlaoK6FvkB/DHxlrp0RsQfYA7B27dqGDy31rxdffLH22OHh4R4mUckaK/SI+A2qQn/3XGMycx+dJZmRkZFs6thSv9q6dSsHDhxg8+bNtcZffPHFXHDBBbXHS9M1UugR8VbgTuDKzHy+ieeUSpGZPPvss7XGfv3rX+fGG2/scSKV6pwvW4yINwL3Ae/PzKfOPZIkaTG6nqFHxD3AKLA+Io4DtwEDAJm5F7gVeD3w6YgAOJOZO3oVWJI0uzpXuVzXZf8HgQ82lkiStCi+U1SSCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKkT5hZ45/2M1zzmXWrGq24CIuAt4D3AyM7fNsj+A24GrgJeA6zPzsaaDLsojj8Arr8Dll0NEVSwPPwyDgzA62na6MjnnWiHGx+HgQZiYgDVrYNcu2L693Ux1ztD3A1fMs/9K4E2d2x7gM+ceqwGZVbGMjVWFMlUsY2PVds8am+eca4UYH4cHHoBTp6pv61Onqsfj4+3m6nqGnpmHImLzPEOuBe7OzAQejYi1EbEhM080FXJRIqqzRKgKZWysur9z52tnj2qWc64V4uBBOH367G2nT1fb2zxLb2IN/SLg+9MeH+9s+ykRsSciDkfE4cnJyQYO3cX0gplisfSWc64VYGJiYduXypK+KJq
Z+zJzR2buGBoaWooDVr/yTze1FKDecM61AqxZs7DtS6XrkksNPwA2TXs80tnWrunrt1O/8k89Bs8ae8E51wqxa1e1Zj592WVgoNrepiYK/X7gQxFxL7ATmGh9/Ryq4hgcPHv9dmopYHDQYukF51wrxNQ6eb9d5VLnssV7gFFgfUQcB24DBgAycy/wENUli09TXbZ4Q6/CLtjoaHXWOFUkUwVjsfSOc64VYvv29gt8pjpXuVzXZX8CNzWWqGkzi8Ri6T3nXGpF+e8UlaQVwkKXpEJY6JJUiCaucpHUoJdeeomjR4/WGnveeeexZcuWHifScmGhS31k48aN3HHHHbXHX3311QwPD7Nhw4YeptJyYaFLfWbr1q21xx47doxLL720h2m0nLiGLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgoR2dIHD0TED4HvLuEh1wM/WsLjNWm5Zl+uuWH5Zl+uuWH5Zl/q3L+QmRfOtqO1Ql9qEXE4M3e0nWMxlmv25Zoblm/25Zoblm/2fsrtkoskFcJCl6RCrKRC39d2gHOwXLMv19ywfLMv19ywfLP3Te4Vs4YuSaVbSWfoklQ0C12SClFUoUfEXRFxMiKOzLE/IuLvI+LpiPh2RLxjqTPOpUb20YiYiIgnOrdblzrjbCJiU0R8LSK+ExFPRsSfzzKm7+a9Zu5+nfPBiPiviPhWJ/tfzzLmZyPiC505H4uIzS1EnZmpTu7rI+KH0+b8g21knUtEnBcRj0fEg7Psa3/OM7OYG/BrwDuAI3Psvwr4ChDAZcBY25kXkH0UeLDtnLPk2gC8o3N/GHgK2Nrv814zd7/OeQCrO/cHgDHgshlj/hTY27n/PuALyyT39cCn2s46z5/hL4B/mu37oh/mvKgz9Mw8BLwwz5Brgbuz8iiwNiL64qNeamTvS5l5IjMf69x/ETgKXDRjWN/Ne83cfakzj//TeTjQuc28uuFa4HOd+18EdkdELFHEWdXM3bciYgS4GrhzjiGtz3lRhV7DRcD3pz0+zjL5Ie741c6vq1+JiLe0HWamzq+Yb6c685qur+d9ntzQp3Pe+dX/CeAk8NXMnHPOM/MMMAG8fklDzqJGboDf6SzNfTEiNi1twnl9EvhL4NU59rc+5yut0Jezx6j+DYdLgDuAL7cb52wRsRr4EvDhzPxx23nq6pK7b+c8M/8vM98GjADvjIhtLUeqpUbuB4DNmflW4Ku8dsbbqoh4D3AyM7/Zdpb5rLRC/wEw/W/8kc62vpeZP576dTUzHwIGImJ9y7EAiIgBqlL8fGbeN8uQvpz3brn7ec6nZOYp4GvAFTN2/WTOI2IVsAZ4fknDzWOu3Jn5fGb+b+fhnUC/fGDqu4BrIuIYcC+wKyL+ccaY1ud8pRX6/cAHOlddXAZMZOaJtkPVERE/P7UeFxHvpPp/1/oPaCfTZ4GjmfmJOYb13bzXyd3Hc35hRKzt3H8d8JvAf88Ydj/wR5377wUOZufVurbUyT3jtZVrqF7baF1m/lVmjmTmZqoXPA9m5h/OGNb6nK9ayoP1WkTcQ3VlwvqIOA7cRvXCC5m5F3iI6oqLp4GXgBvaSfrTamR/L/AnEXEGeBl4X9s/oB3vAt4PjHfWRgE+CrwR+nre6+Tu1znfAHwuIs6j+kvmnzPzwYj4G+BwZt5P9ZfVP0TE01Qvtr+vvbg/USf3n0XENcAZqtzXt5a2hn6bc9/6L0mFWGlLLpJULAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFeL/Admi1qH7N00UAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "for column in contour.collections:\n", " plt.gca().collections.remove(column)\n", " \n", "contour = ax.contourf(\n", " contour_plot_x_data,\n", " contour_plot_y_data,\n", " quantized_predictions.round().reshape(contour_plot_x_data.shape),\n", " cmap=\"gray\",\n", " alpha=0.50,\n", ")\n", "display(fig)" ] }, { "cell_type": "markdown", "id": "4834cdfc", "metadata": {}, "source": [ "### Now it's time to make the inference homomorphic" ] }, { "cell_type": "code", "execution_count": 17, "id": "fcf4ea26", "metadata": {}, "outputs": [], "source": [ "q_y = (2**output_bits - 1) / (intermediate.max() - intermediate.min())\n", "zp_y = int(round(intermediate.min() * q_y))\n", "\n", "q_x = x_q.parameters.q\n", "q_w = w_q.parameters.q\n", "q_b = b_q.parameters.q\n", "\n", "zp_x = x_q.parameters.zp\n", "zp_w = w_q.parameters.zp\n", "zp_b = b_q.parameters.zp\n", "\n", "x_q = x_q.values\n", "w_q = w_q.values\n", "b_q = b_q.values" ] }, { "cell_type": "markdown", "id": "43e47369", "metadata": {}, "source": [ "### Simplification to rescue!\n", "\n", "The `y_q` formula in `QuantizedArray.affine(...)` can be rewritten to make it easier to implement in homomorphically. 
Here is the breakdown.\n", "```\n", "(q_y / (q_x * q_w)) * ((x_q + zp_x) @ (w_q + zp_w) + (q_x * q_w / q_b) * (b_q + zp_b)) - (min_y * q_y)\n", "\n", "q_y / (q_x * q_w)                 -> constant (c1)\n", "x_q + zp_x                        -> can be done on the circuit\n", "w_q + zp_w                        -> constant (c2)\n", "(q_x * q_w / q_b) * (b_q + zp_b)  -> constant (c3)\n", "min_y * q_y                       -> constant (c4)\n", "\n", "(x_q + zp_x) @ (w_q + zp_w)       -> can be done on the circuit\n", "\n", "The whole expression cannot be done on the circuit because of its floating point operations,\n", "so everything applied after the matrix product becomes a single table lookup.\n", "```" ] }, { "cell_type": "code", "execution_count": 18, "id": "2de0cf20", "metadata": {}, "outputs": [], "source": [ "# Constants named in the breakdown above.\n", "c1 = q_y / (q_x * q_w)\n", "c2 = w_q + zp_w\n", "c3 = (q_x * q_w / q_b) * (b_q + zp_b)\n", "c4 = intermediate.min() * q_y\n", "\n", "# Everything applied after the integer matrix product: rescale, round, clip, then the quantized sigmoid.\n", "def f(x):\n", "    values = ((c1 * (x + c3)) - c4).round().clip(0, 2**output_bits - 1).astype(np.uint)\n", "    after_affine_q = QuantizedArray(values, intermediate_q.parameters)\n", "\n", "    sigmoid = QuantizedFunction.plain(lambda x: 1 / (1 + np.exp(-x)), after_affine_q.parameters, output_bits)\n", "    y_q = sigmoid.apply(after_affine_q)\n", "\n", "    return y_q.values\n", "\n", "f_q = QuantizedFunction.of(f, output_bits, output_bits)\n", "\n", "# Wrap the integer entries of f_q.table in a lookup table the circuit can index.\n", "from hdk.common.extensions.table import LookupTable\n", "table = LookupTable([int(entry) for entry in f_q.table])\n", "\n", "# Zero-point-shifted weights as plain Python integers.\n", "w_0 = int(c2.flatten()[0])\n", "w_1 = int(c2.flatten()[1])\n", "\n", "# The circuit itself: an integer dot product followed by one table lookup.\n", "def infer(x_0, x_1):\n", "    return table[((x_0 + zp_x) * w_0) + ((x_1 + zp_x) * w_1)]" ] }, { "cell_type": "markdown", "id": "93eb9499", "metadata": {}, "source": [ "### Time to compile our quantized inference function" ] }, { "cell_type": "code", "execution_count": 19, "id": "a80895fd", "metadata": {}, "outputs": [], "source": [ "from hdk.common.data_types.integers import Integer\n", "from hdk.common.data_types.values import EncryptedValue\n", "from hdk.hnumpy.compile import
compile_numpy_function_into_op_graph\n", "\n", "# Sample inputs passed to the compiler: each quantized data point as a pair of Python integers.\n", "dataset = []\n", "for x_i in x_q:\n", "    dataset.append((int(x_i[0]), int(x_i[1])))\n", "\n", "homomorphic_model = compile_numpy_function_into_op_graph(\n", "    infer,\n", "    {\n", "        \"x_0\": EncryptedValue(Integer(input_bits, is_signed=False)),\n", "        \"x_1\": EncryptedValue(Integer(input_bits, is_signed=False)),\n", "    },\n", "    iter(dataset),\n", ")" ] }, { "cell_type": "markdown", "id": "f0b08a0f", "metadata": {}, "source": [ "### Here is the textual representation of the operation graph" ] }, { "cell_type": "code", "execution_count": 20, "id": "2cc4e11d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "%0 = ConstantInput(2) # Integer\n", "%1 = ConstantInput(1) # Integer\n", "%2 = x_0 # Integer\n", "%3 = ConstantInput(6) # Integer\n", "%4 = x_1 # Integer\n", "%5 = ConstantInput(6) # Integer\n", "%6 = Add(2, 3) # Integer\n", "%7 = Add(4, 5) # Integer\n", "%8 = Mul(6, 0) # Integer\n", "%9 = Mul(7, 1) # Integer\n", "%10 = Add(8, 9) # Integer\n", "%11 = TLU(10) # Integer\n", "return(%11)\n" ] } ], "source": [ "from hdk.common.debugging import get_printable_graph\n", "print(get_printable_graph(homomorphic_model, show_data_types=True))" ] }, { "cell_type": "markdown", "id": "ade14f17", "metadata": {}, "source": [ "### Finally, it's time to run homomorphic inference\n", "\n", "Or, at least, simulate it until the compiler integration is complete."
] }, { "cell_type": "code", "execution_count": 21, "id": "dd2d03d7", "metadata": {}, "outputs": [], "source": [ "homomorphic_predictions = []\n", "for x_0, x_1 in map(lambda x_i: (int(x_i[0]), int(x_i[1])), x_q):\n", " evaluation = homomorphic_model.evaluate({0: x_0, 1: x_1})\n", " inference = QuantizedArray(evaluation[homomorphic_model.output_nodes[0]], y_q.parameters)\n", " homomorphic_predictions.append(inference.dequantize())\n", "homomorphic_predictions = np.array(homomorphic_predictions, dtype=np.float32)" ] }, { "cell_type": "markdown", "id": "443fbc03", "metadata": {}, "source": [ "### And visualize it" ] }, { "cell_type": "code", "execution_count": 22, "id": "57050b5d", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQaElEQVR4nO3df2xdZ33H8fd3jTtTx0siUrGkDssqkY2QUKAZ6QTavERbf6FW05hGt8FaDaXaSje0SkPjj1Ybf01oiK4IoqhUoRtrmaBibVXWoYQuGqyeQltwSqaqagOERgq0iuncdkrW7/4419Qxtu+xc67P9eP3S7rqvec8vueTp/Ynx889NzcyE0nS8vczbQeQJDXDQpekQljoklQIC12SCmGhS1IhVrV14KGhoVy3bl1bh1ehTp8+zYUXXsj555/fdhSpJx5//PEfZeaFs+1rrdDXrVvHzTff3NbhVajnnnuOG2+8kc2bN7cdReqJoaGh7861zyUXSSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBWi/EKf+ZmpfoaqpEJ1/dcWI2ITcDfwBiCBfZl5+4wxAdwOXAW8BFyfmY81H3eBHnkEXnkFLr8cIqoyf/hhGByE0dG200mNGR+HgwdhYgLWrIFdu2D79rZTla0f57zOGfoZ4JbM3ApcBtwUEVtnjLkSeFPntgf4TKMpFyOzKvOxsarEp8p8bKza7pm6CjE+Dg88AKdOVd/Wp05Vj8fH205Wrn6d865n6Jl5AjjRuf9iRBwFLgK+M23YtcDdmZnAoxGxNiI2dL62HRHVmTlUJT42Vt3fufO1M3apAAcPwunTZ287fbra3vYZY6n6dc4XtIYeEZuBtwNjM3ZdBHx/2uPjnW0zv35PRByOiMOTk5MLjLoI00t9imWuwkxMLGy7zl2/znntQo+I1cCXgA9n5o8Xc7DM3JeZOzJzx9DQ0GKeYqEHrJZZpptafpEKsWbNwrbr3PXrnNcq9IgYoCrzz2fmfbMM+QGwadrjkc629kxfM9+5E269tfrv9DV1qQC7dsHAwNnbBgaq7eqNfp3zOle5BPBZ4GhmfmKOYfcDH4qIe4GdwESr6+dQLasMDp69Zj61/DI46LKLijG1ZttvV1yUrF/nvM6HRL8LeD8wHhFPdLZ9FHgjQGbuBR6iumTxaarLFm9oPOlijI5WZ+JT5T1V6pa5CrN9e/tlstL045zXucrlP
4B5G7BzdctNTYVq1MzytswlFar8d4pK0gphoUtSIeqsoUvLyssvv8yJE/Vek1+9ejXDw8M9TiQtDQtdRdm4cSMHDhyoPX7Xrl1s2rTJUlcRXHJRcTKz9u2ZZ55pO67UGAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJ0LfSIuCsiTkbEkTn2r4mIByLiWxHxZETc0HxMSVI3dc7Q9wNXzLP/JuA7mXkJMAr8XUScf+7RJEkL0bXQM/MQ8MJ8Q4DhiAhgdWfsmWbiSZLqWtXAc3wKuB94DhgGfi8zX51tYETsAfYArF27toFDS5KmNPGi6OXAE8BG4G3ApyLi52YbmJn7MnNHZu4YGhpq4NCSpClNFPoNwH1ZeRp4FvjlBp5XkrQATRT694DdABHxBuCXgGcaeF5J0gJ0XUOPiHuorl5ZHxHHgduAAYDM3At8DNgfEeNAAB/JzB/1LLEkaVZdCz0zr+uy/zngtxpLJElaFN8pKkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCNPGJRdKydezYMS6++OLa41etWsWWLVt6mEhaPAtdK1pmcuDAgVpjjxw5wi233NLjRNLiueQiSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUiK6FHhF3RcTJiDgyz5jRiHgiIp6MiH9vNqIkqY46Z+j7gSvm2hkRa4FPA9dk5luA320kmSRpQboWemYeAl6YZ8jvA/dl5vc64082lE2StABNrKFvAdZFxCMR8c2I+MBcAyNiT0QcjojDk5OTDRxakjSliU8sWgVcCuwGXgf8Z0Q8mplPzRyYmfuAfQAjIyPZwLElSR1NFPpx4PnMnAQmI+IQcAnwU4UuSeqdJpZc/gV4d0SsiogLgJ3A0QaeV5K0AF3P0CPiHmAUWB8Rx4HbgAGAzNybmUcj4l+BbwOvAndm5pyXOEqSeqNroWfmdTXGfBz4eCOJJEmL4jtFJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgqxqu0A0nJy5swZnnrqqQV9zYYNGxgeHu5RIuk1FrpU07Zt27j99tsX9DXbt29n9+7dvPnNb+5RKuk1Frq0ANu2bVvQ+CeffJLdu3f3KI10NtfQJakQFrokFcJCl6RCdC30iLgrIk5GxJEu434lIs5ExHubiydJqqvOGfp+4Ir5BkTEecDfAv/WQCZJ0iJ0LfTMPAS80GXYzcCXgJNNhJIkLdw5r6FHxEXAbwOfqTF2T0QcjojDk5OT53poSdI0Tbwo+kngI5n5areBmbkvM3dk5o6hoaEGDi1JmtLEG4t2APdGBMB64KqIOJOZX27guSVJNZ1zoWfmL07dj4j9wIOWuSQtva6FHhH3AKPA+og4DtwGDABk5t6eppMk1da10DPzurpPlpnXn1MaSdKi+U5RSSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQXQs9Iu6KiJMRcWSO/X8QEd+OiPGI+EZEXNJ8TElSN3XO0PcDV8yz/1ng1zNzO/AxYF8DuSRJC7Sq24DMPBQRm+fZ/41pDx8FRhrIJUlaoK6FvkB/DHxlrp0RsQfYA7B27dqGDy31rxdffLH22OHh4
R4mUckaK/SI+A2qQn/3XGMycx+dJZmRkZFs6thSv9q6dSsHDhxg8+bNtcZffPHFXHDBBbXHS9M1UugR8VbgTuDKzHy+ieeUSpGZPPvss7XGfv3rX+fGG2/scSKV6pwvW4yINwL3Ae/PzKfOPZIkaTG6nqFHxD3AKLA+Io4DtwEDAJm5F7gVeD3w6YgAOJOZO3oVWJI0uzpXuVzXZf8HgQ82lkiStCi+U1SSCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKkT5hZ45/2M1zzmXWrGq24CIuAt4D3AyM7fNsj+A24GrgJeA6zPzsaaDLsojj8Arr8Dll0NEVSwPPwyDgzA62na6MjnnWiHGx+HgQZiYgDVrYNcu2L693Ux1ztD3A1fMs/9K4E2d2x7gM+ceqwGZVbGMjVWFMlUsY2PVds8am+eca4UYH4cHHoBTp6pv61Onqsfj4+3m6nqGnpmHImLzPEOuBe7OzAQejYi1EbEhM080FXJRIqqzRKgKZWysur9z52tnj2qWc64V4uBBOH367G2nT1fb2zxLb2IN/SLg+9MeH+9s+ykRsSciDkfE4cnJyQYO3cX0gplisfSWc64VYGJiYduXypK+KJqZ+zJzR2buGBoaWooDVr/yTze1FKDecM61AqxZs7DtS6XrkksNPwA2TXs80tnWrunrt1O/8k89Bs8ae8E51wqxa1e1Zj592WVgoNrepiYK/X7gQxFxL7ATmGh9/Ryq4hgcPHv9dmopYHDQYukF51wrxNQ6eb9d5VLnssV7gFFgfUQcB24DBgAycy/wENUli09TXbZ4Q6/CLtjoaHXWOFUkUwVjsfSOc64VYvv29gt8pjpXuVzXZX8CNzWWqGkzi8Ri6T3nXGpF+e8UlaQVwkKXpEJY6JJUiCaucpHUoJdeeomjR4/WGnveeeexZcuWHifScmGhS31k48aN3HHHHbXHX3311QwPD7Nhw4YeptJyYaFLfWbr1q21xx47doxLL720h2m0nLiGLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgoR2dIHD0TED4HvLuEh1wM/WsLjNWm5Zl+uuWH5Zl+uuWH5Zl/q3L+QmRfOtqO1Ql9qEXE4M3e0nWMxlmv25Zoblm/25Zoblm/2fsrtkoskFcJCl6RCrKRC39d2gHOwXLMv19ywfLMv19ywfLP3Te4Vs4YuSaVbSWfoklQ0C12SClFUoUfEXRFxMiKOzLE/IuLvI+LpiPh2RLxjqTPOpUb20YiYiIgnOrdblzrjbCJiU0R8LSK+ExFPRsSfzzKm7+a9Zu5+nfPBiPiviPhWJ/tfzzLmZyPiC505H4uIzS1EnZmpTu7rI+KH0+b8g21knUtEnBcRj0fEg7Psa3/OM7OYG/BrwDuAI3Psvwr4ChDAZcBY25kXkH0UeLDtnLPk2gC8o3N/GHgK2Nrv814zd7/OeQCrO/cHgDHgshlj/hTY27n/PuALyyT39cCn2s46z5/hL4B/mu37oh/mvKgz9Mw8BLwwz5Brgbuz8iiwNiL64qNeamTvS5l5IjMf69x/ETgKXDRjWN/Ne83cfakzj//TeTjQuc28uuFa4HOd+18EdkdELFHEWdXM3bciYgS4GrhzjiGtz3lRhV7DRcD3pz0+zjL5Ie741c6vq1+JiLe0HWamzq+Yb6c685qur+d9ntzQp3Pe+dX/CeAk8NXMnHPOM/MMMAG8fklDzqJGboDf6SzNfTEiNi1twnl9EvhL4NU59rc+5yut0Jezx6j+DYdLgDuAL7cb52wRsRr4EvDhzPxx23nq6pK7b+c8M/8vM98GjADvjIhtLUeqpUbuB4DNmflW4Ku8dsbbqoh4D3AyM7/Zdpb5rLRC/wEw/W/8k
c62vpeZP576dTUzHwIGImJ9y7EAiIgBqlL8fGbeN8uQvpz3brn7ec6nZOYp4GvAFTN2/WTOI2IVsAZ4fknDzWOu3Jn5fGb+b+fhnUC/fGDqu4BrIuIYcC+wKyL+ccaY1ud8pRX6/cAHOlddXAZMZOaJtkPVERE/P7UeFxHvpPp/1/oPaCfTZ4GjmfmJOYb13bzXyd3Hc35hRKzt3H8d8JvAf88Ydj/wR5377wUOZufVurbUyT3jtZVrqF7baF1m/lVmjmTmZqoXPA9m5h/OGNb6nK9ayoP1WkTcQ3VlwvqIOA7cRvXCC5m5F3iI6oqLp4GXgBvaSfrTamR/L/AnEXEGeBl4X9s/oB3vAt4PjHfWRgE+CrwR+nre6+Tu1znfAHwuIs6j+kvmnzPzwYj4G+BwZt5P9ZfVP0TE01Qvtr+vvbg/USf3n0XENcAZqtzXt5a2hn6bc9/6L0mFWGlLLpJULAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFeL/Admi1qH7N00UAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "for column in contour.collections:\n", " plt.gca().collections.remove(column)\n", " \n", "contour = ax.contourf(\n", " contour_plot_x_data,\n", " contour_plot_y_data,\n", " homomorphic_predictions.round().reshape(contour_plot_x_data.shape),\n", " cmap=\"gray\",\n", " alpha=0.50,\n", ")\n", "display(fig)" ] }, { "cell_type": "markdown", "id": "53ecca94", "metadata": {}, "source": [ "### Enjoy!" ] } ], "metadata": {}, "nbformat": 4, "nbformat_minor": 5 }