mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
881 lines
71 KiB
Plaintext
{
"cells": [
{
"cell_type": "markdown",
"id": "0fe629d6",
"metadata": {},
"source": [
"# Quantized Logistic Regression\n",
"\n",
"Currently, **hdk** only supports unsigned integers of up to 7 bits. Nevertheless, we want to evaluate a logistic regression model with it. Luckily, we can use **quantization** to overcome this limitation!"
]
},
{
"cell_type": "markdown",
"id": "d0cfb561",
"metadata": {},
"source": [
"### Let's start by importing some libraries to develop our logistic regression model"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "3c1d929c",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch"
]
},
{
"cell_type": "markdown",
"id": "69d25f7c",
"metadata": {},
"source": [
"### And some helpers for visualization"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a89c1a6c",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from IPython.display import display"
]
},
{
"cell_type": "markdown",
"id": "7729c1de",
"metadata": {},
"source": [
"### We need a dataset, a handcrafted one for simplicity"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b77a9e82",
"metadata": {},
"outputs": [],
"source": [
"x = torch.tensor([[1, 1], [1, 2], [2, 1], [4, 1], [3, 2], [4, 2]]).float()\n",
"y = torch.tensor([[0], [0], [0], [1], [1], [1]]).float()"
]
},
{
"cell_type": "markdown",
"id": "cc8673ff",
"metadata": {},
"source": [
"### Let's visualize our dataset to get a grasp of it"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "35a98d1a",
"metadata": {},
"outputs": [],
"source": [
"plt.ioff()\n",
"fig, ax = plt.subplots(1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "56703410",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAPTklEQVR4nO3df4zkd13H8efruPPHpcgZbqO117v1D1ABKbQr1ED0lCgHmBJjTagVbCO5RKsup4mNEOkpaaIhchQbOC6lOdT1wNAGSgNGImAlhJo9LO2VCmmEKweNt7S5omBMznv7x3eW7q27O7N3szuzn30+ksnO9/v93Hxf/XTvtd/5zMxtqgpJ0sa3ZdQBJEnDYaFLUiMsdElqhIUuSY2w0CWpEVtHdeKdO3fW5OTkqE4vSRvS8ePHv1lVE0sdG1mhT05OMjs7O6rTS9KGlOTkcsdccpGkRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSI9ov9MW/M9XfoSqpUX0LPcnlST6V5ItJHk4yvcSYJHlXkkeTPJjkyrWJu0oHD8KBA0+XeFW3ffDgKFNJQzczA5OTsGVL93VmZtSJ2jeOcz7IFfpZ4A+q6nnA1cBNSZ63aMyrgOf0bvuB9ww15YWogjNn4Lbbni71Awe67TNnvFJXM2ZmYP9+OHmy+7Y+ebLbHoeCadW4znlqlcWW5CPA7VX1iQX73gt8uqqO9ba/BOytqseXe5ypqala838PfWGJz5uehkOHIFnbc0vrZHKyK5TF9uyBr351vdNsDqOc8yTHq2pqqWOrWkNPMgm8GLh/0aHLgK8t2D7V27f4z+9PMptkdm5ubjWnvjBJV94LWeZqzGOPrW6/Lt64zvnAhZ7kEuAu4E1V9a0LOVlVHamqqaqamphY8jcoDdf8FfpCC9fUpQbs3r26/bp44zrnAxV6km10ZT5TVXcvMeTrwOULtnf19o3OwuWW6Wk4d677unBNXWrArbfC9u3n79u+vduvtTGucz7Iu1wCvA94pKrescywe4A39N7tcjXw1Err5+sigR07zl8zP3So296xw2UXNeP66+HIkW79Num+HjnS7dfaGNc57/uiaJKXA/8MPASc6+1+M7AboKoO90r/dmAf8B3gxqpa8RXPdXlRtAt4fnkv3pakDWSlF0W39vvDVfUZYMUGrO6nwk0XFm+NLS5vy1xSo9r/pKgkbRIWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUiL6FnuTOJKeTnFjm+LOSfDTJF5I8nOTG4ceUJPUzyBX6UWDfCsdvAr5YVVcAe4G/SPI9Fx9NkrQafQu9qu4DnlxpCPDMJAEu6Y09O5x4kqRBDWMN/XbgJ4BvAA8B01V1bqmBSfYnmU0yOzc3N4RTS5LmDaPQXwk8APwI8CLg9iQ/sNTAqjpSVVNVNTUxMTGEU0uS5g2j0G8E7q7Oo8BXgB8fwuNKklZhGIX+GPAKgCQ/BPwY8O9DeFxJ0ips7TcgyTG6d6/sTHIKuAXYBlBVh4G3AUeTPAQEuLmqvrlmiSVJS+pb6FV1XZ/j3wB+cWiJJEkXxE+KSlIjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmN6FvoSe5McjrJiRXG7E3yQJKHk/zTcCNKkgYxyBX6UWDfcgeT7ADeDVxTVc8HfnUoySRJq9K30KvqPuDJFYb8Gn
B3VT3WG396SNkkSaswjDX05wI/mOTTSY4necNyA5PsTzKbZHZubm4Ip5YkzRtGoW8FrgJeA7wS+OMkz11qYFUdqaqpqpqamJgYwqklSfO2DuExTgFPVNW3gW8nuQ+4AvjyEB5bkjSgYVyhfwR4eZKtSbYDLwUeGcLjSpJWoe8VepJjwF5gZ5JTwC3ANoCqOlxVjyT5e+BB4BxwR1Ut+xZHSdLa6FvoVXXdAGPeDrx9KIkkSRfET4pKUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY3oW+hJ7kxyOsmJPuN+KsnZJNcOL54kaVCDXKEfBfatNCDJM4A/B/5hCJkkSRegb6FX1X3Ak32G/S5wF3B6GKEkSat30WvoSS4Dfhl4zwBj9yeZTTI7Nzd3saeWJC0wjBdF3wncXFXn+g2sqiNVNVVVUxMTE0M4tSRp3tYhPMYU8IEkADuBVyc5W1UfHsJjS5IGdNGFXlU/On8/yVHgXstcktZf30JPcgzYC+xMcgq4BdgGUFWH1zSdJGlgfQu9qq4b9MGq6oaLSiNJumB+UlSSGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJakTfQk9yZ5LTSU4sc/z6JA8meSjJZ5NcMfyYkqR+BrlCPwrsW+H4V4CfraqfBN4GHBlCLknSKm3tN6Cq7ksyucLxzy7Y/Bywawi5JEmrNOw19N8EPr7cwST7k8wmmZ2bmxvyqSVpcxtaoSf5ObpCv3m5MVV1pKqmqmpqYmJiWKeWJDHAkssgkrwQuAN4VVU9MYzHlCStzkVfoSfZDdwNvL6qvnzxkSRJF6LvFXqSY8BeYGeSU8AtwDaAqjoMvBV4NvDuJABnq2pqrQJLkpY2yLtcrutz/I3AG4eWSJJ0QfykqCQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUiPYLvWrlbQ2fcy6NRN9CT3JnktNJTixzPEneleTRJA8muXL4MS/QwYNw4MDThVLVbR88OMpUbXPOtUnMzMDkJGzZ0n2dmRl1osGu0I8C+1Y4/irgOb3bfuA9Fx9rCKrgzBm47banC+bAgW77zBmvGteCc65NYmYG9u+Hkye7b+uTJ7vtkZd6VfW9AZPAiWWOvRe4bsH2l4BL+z3mVVddVWvu3Lmq6emqbs672/R0t19rwznXJrBnz/nf4vO3PXvW/tzAbC3Tq6kBrpqSTAL3VtULljh2L/BnVfWZ3vY/AjdX1ewSY/fTXcWze/fuq06ePHkBP4JWqap7TjTv3DlI1v68m5lzrsZt2bL0E86k+3ZfS0mOV9XUkrnW9tTnq6ojVTVVVVMTExPrccLuKf9CC9d3NXzOuTaB3btXt3+9DKPQvw5cvmB7V2/faC1cv52e7n5sTk+fv76r4XLOtUnceits337+vu3bu/2jtHUIj3EP8DtJPgC8FHiqqh4fwuNenAR27OgK5dChbvvQoe7Yjh0uAawF51ybxPXXd1/f8hZ47LHuyvzWW5/ePyp919CTHAP2AjuB/wBuAbYBVNXhJAFup3snzHeAG5daP19samqqZmf7Drt4VecXyeJtDZ9zLq2ZldbQ+16hV9V1fY4XcNMFZlt7i4vEYll7zrk0Eu1/UlSSNg
kLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJasRAv7FoTU6czAHr8CuLvmsn8M11PN8wbdTsGzU3bNzsGzU3bNzs6517T1Ut+RuCRlbo6y3J7HL/5OS426jZN2pu2LjZN2pu2LjZxym3Sy6S1AgLXZIasZkK/cioA1yEjZp9o+aGjZt9o+aGjZt9bHJvmjV0SWrdZrpCl6SmWeiS1IimCj3JnUlOJzmxzPEkeVeSR5M8mOTK9c64nAGy703yVJIHere3rnfGpSS5PMmnknwxycNJppcYM3bzPmDucZ3z70vyL0m+0Mv+J0uM+d4kH+zN+f1JJkcQdXGmQXLfkGRuwZy/cRRZl5PkGUn+Ncm9Sxwb/ZxXVTM34GeAK4ETyxx/NfBxIMDVwP2jzryK7HuBe0edc4lclwJX9u4/E/gy8Lxxn/cBc4/rnAe4pHd/G3A/cPWiMb8NHO7dfx3wwQ2S+wbg9lFnXeG/4feBv13q+2Ic5rypK/Squg94coUhrwX+qjqfA3YkuXR90q1sgOxjqaoer6rP9+7/J/AIcNmiYWM37wPmHku9efyv3ua23m3xuxteC7y/d/9DwCuSZJ0iLmnA3GMryS7gNcAdywwZ+Zw3VegDuAz42oLtU2yQv8Q9P917uvrxJM8fdZjFek8xX0x35bXQWM/7CrlhTOe899T/AeA08ImqWnbOq+os8BTw7HUNuYQBcgP8Sm9p7kNJLl/fhCt6J/CHwLlljo98zjdboW9kn6f7NxyuAP4S+PBo45wvySXAXcCbqupbo84zqD65x3bOq+p/q+pFwC7gJUleMOJIAxkg90eByap6IfAJnr7iHakkvwScrqrjo86yks1W6F8HFv7E39XbN/aq6lvzT1er6mPAtiQ7RxwLgCTb6EpxpqruXmLIWM57v9zjPOfzquoM8Clg36JD353zJFuBZwFPrGu4FSyXu6qeqKr/6W3eAVy1ztGW8zLgmiRfBT4A/HySv1k0ZuRzvtkK/R7gDb13XVwNPFVVj4861CCS/PD8elySl9D9vxv5X9BepvcBj1TVO5YZNnbzPkjuMZ7ziSQ7eve/H/gF4N8WDbsH+I3e/WuBT1bv1bpRGST3otdWrqF7bWPkquqPqmpXVU3SveD5yar69UXDRj7nW9fzZGstyTG6dybsTHIKuIXuhReq6jDwMbp3XDwKfAe4cTRJ/78Bsl8L/FaSs8B/A68b9V/QnpcBrwce6q2NArwZ2A1jPe+D5B7XOb8UeH+SZ9D9kPm7qro3yZ8Cs1V1D90Pq79O8ijdi+2vG13c7xok9+8luQY4S5f7hpGlHcC4zbkf/ZekRmy2JRdJapaFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhrxf09l6LOTuZAtAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"x_min, x_max = x[:, 0].min(), x[:, 0].max()\n",
"x_deviation = x_max - x_min\n",
"\n",
"y_min, y_max = x[:, 1].min(), x[:, 1].max()\n",
"y_deviation = y_max - y_min\n",
"\n",
"ax.set_xlim(x_min - (x_deviation / 10), x_max + (x_deviation / 10))\n",
"ax.set_ylim(y_min - (y_deviation / 10), y_max + (y_deviation / 10))\n",
"\n",
"ax.scatter(\n",
"    np.array([x_i[0] for x_i, y_i in zip(x, y) if y_i == 0], dtype=np.float32),\n",
"    np.array([x_i[1] for x_i, y_i in zip(x, y) if y_i == 0], dtype=np.float32),\n",
"    marker=\"x\",\n",
"    color=\"red\",\n",
")\n",
"ax.scatter(\n",
"    np.array([x_i[0] for x_i, y_i in zip(x, y) if y_i == 1], dtype=np.float32),\n",
"    np.array([x_i[1] for x_i, y_i in zip(x, y) if y_i == 1], dtype=np.float32),\n",
"    marker=\"o\",\n",
"    color=\"blue\",\n",
")\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "e31b82e8",
"metadata": {},
"source": [
"### Now, we need a model, so let's define it"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "cc5e72a2",
"metadata": {},
"outputs": [],
"source": [
"class Model(torch.nn.Module):\n",
"    def __init__(self, n):\n",
"        super(Model, self).__init__()\n",
"        self.fc = torch.nn.Linear(n, 1)\n",
"\n",
"    def forward(self, x):\n",
"        output = torch.sigmoid(self.fc(x))\n",
"        return output"
]
},
{
"cell_type": "markdown",
"id": "cefd8346",
"metadata": {},
"source": [
"### And create one\n",
"\n",
"The main purpose of this tutorial is not to train a logistic regression model but to use it homomorphically, so we will not discuss how the model is trained."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b9879f4d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 1 | Loss: 0.530019998550415\n",
"Epoch: 101 | Loss: 0.1248268187046051\n",
"Epoch: 201 | Loss: 0.07593712955713272\n",
"Epoch: 301 | Loss: 0.05418260768055916\n",
"Epoch: 401 | Loss: 0.04199932515621185\n",
"Epoch: 501 | Loss: 0.03424343094229698\n",
"Epoch: 601 | Loss: 0.028883913531899452\n",
"Epoch: 701 | Loss: 0.024963364005088806\n",
"Epoch: 801 | Loss: 0.021973103284835815\n",
"Epoch: 901 | Loss: 0.019618362188339233\n",
"Epoch: 1001 | Loss: 0.017716625705361366\n",
"Epoch: 1101 | Loss: 0.01614907570183277\n",
"Epoch: 1201 | Loss: 0.014835075475275517\n",
"Epoch: 1301 | Loss: 0.013717765919864178\n",
"Epoch: 1401 | Loss: 0.01275621633976698\n",
"Epoch: 1501 | Loss: 0.011920095421373844\n"
]
}
],
"source": [
"model = Model(x.shape[1])\n",
"\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=1)\n",
"criterion = torch.nn.BCELoss()\n",
"\n",
"epochs = 1501\n",
"for e in range(1, epochs + 1):\n",
"    optimizer.zero_grad()\n",
"\n",
"    out = model(x)\n",
"    loss = criterion(out, y)\n",
"\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"\n",
"    if e % 100 == 1 or e == epochs:\n",
"        print(\"Epoch:\", e, \"|\", \"Loss:\", loss.item())"
]
},
{
"cell_type": "markdown",
"id": "01cfc83f",
"metadata": {},
"source": [
"### Time to make some predictions"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "78356d37",
"metadata": {},
"outputs": [],
"source": [
"contour_plot_x_data = np.linspace(x_min - (x_deviation / 10), x_max + 2 * (x_deviation / 10), 250)\n",
"contour_plot_y_data = np.linspace(y_min - (y_deviation / 10), y_max + 2 * (y_deviation / 10), 250)\n",
"contour_plot_x_data, contour_plot_y_data = np.meshgrid(contour_plot_x_data, contour_plot_y_data)\n",
"\n",
"inputs = np.stack((contour_plot_x_data.flatten(), contour_plot_y_data.flatten()), axis=1)\n",
"predictions = model(torch.tensor(inputs).float()).detach().numpy()"
]
},
{
"cell_type": "markdown",
"id": "58160140",
"metadata": {},
"source": [
"### Let's visualize our predictions to see how our model performs"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2a623999",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAT40lEQVR4nO3df2xdZ33H8c+XxluYkyURQcyuw7I/6MbhhvIjI52oNi9oS+kg1RSm0W2wVkORti4b2qSh8Uerjb8mNITnCqKoVKF3rDBBxeqqXYdy6aKV1VNoC27syapjg22sBRzfkNjNj+t+98e5Bse1fa/t4/uc+9z3S7Jy7zlPfD59mnxy/Jxz7zV3FwCg+b0udAAAQDYodACIBIUOAJGg0AEgEhQ6AERiS6gDt7e3+65du0IdHjlz5coV7dixQ9u2bdNNN90UOg6QWy+88MKP3P2Ny+0LVui7du3SsWPHQh0eOTM4OKjDhw/r9ttv1/bt20PHAXKrvb39eyvtY8kFuZAkiebn5zU+Pq6pqanQcYCmRKEjN0ZHR1UsFnX16tXQUYCmRKEDQCQodACIBIUOAJGg0JE7165dCx0BaEoUOnKlXC6rUqloaGgodBSg6VDoyJUkSdTb2xs6BtCUKHQAiASFDgCRoNABIBIUOnLJ3TU8PBw6BtBU4i/0pZ+Zymeo5t7ChdFKpcL7ugBrUPPdFs1sj6RHJL1Jkks64e49S8aYpB5Jd0qak3SPuz+ffdw1euYZ6coV6dAhySwt86eflrZulbq7Q6fDKpIkUalU0pEjR0JHaQoDA1KpJF28KO3YIR08KO3bFzpV3PI45/WcoVck/bW7J5Juk3SfmSVLxrxf0luqX0clfT7TlOvhnpZ5f39a4gtl3t+fbudMHZEYGJD6+qRyOf1jXS6nzwcGQieLV17nvOYZurtPSZqqPr5kZkOSbpY0uGjYXZIecXeX9JyZ7TSzjurvDcMsPTOX0hLv708fHzjw0zN2IAKlknT9+o3brl9Pt4c+Y4xVXud8TWvoZrZX0jsl9S/ZdbOk8UXPJ6rblv7+o2Z2xszOzM7OrjHqOiwu9QWUedOYnJzUzMyMxsbGQkfJtYsX17YdG5fXOa+70M1sm6SvSfq4u/94PQdz9xPuvt/d97e3t6/nW6z1gOkyy2ILyy/Ivc7OTvX09OiVV14JHSXXduxY23ZsXF7nvK5CN7M2pWX+JXd/bJkhk5L2LHreVd0WzuI18wMHpPvvT39dvKaO3Hvd6+K/EWujDh6U2tpu3NbWlm7H5sjrnNdzl4tJ+oKkIXf/zArDHpf052b2ZUkHJF0Mun4upcsqW7feuGa+sPyydSvLLojGwppt3u64iFle57yeD4l+r6SPSBowsxer2z4p6c2S5O7HJT2p9JbFl5Xetnhv5knXo7s7PRNfKO+FUqfMEZl9+8KXSavJ45zXc5fLf0latQGrd7fcl1WoTC0tb8ocQKTqOUMHgkmSRCMjI3J37dq1Sx0dHaEjAbnFFSfk3ujoqEqlUugYQO5R6AAQCQodACJBoQNAJCh0NI1yuRw6ApBrFDqagrtrZGSED70AVkGho2n09fWFjgDkGoUOAJGg0AEgEhQ6mkqlUmEdHVgBhY6mUSgU1Nvbq+HhYT48GlgGhY6mkiQJn2AErIBCB4BIUOgAEAkKHQAiQaGj6UxOTmpmZoa7XYAlKHQ0nc7OTvX09ISOAeQOhQ4AkaDQASASFDoARIJCR9OqVCq6dOlS6BhAblDoaEqFQkGlUknj4+OUOlBFoaNpubvOnTsXOgaQGxQ6AESCQgeASFDoABAJCh1NjwujQKpmoZvZw2Z23sxeWmH/DjPrM7PvmNlZM7s3+5jA8kZHR1UqlTQ9PR06ChBcPWfoJyXdscr++yQNuvutkrol/aOZ/czGowH1mZycDB0ByIWahe7upyVdWG2IpO1mZpK2VcdWsokHAKjXlgy+x4OSHpf0A0nbJf2+u7+63E
AzOyrpqCTt3Lkzg0MDABZkcVH0kKQXJXVKeoekB83s55cb6O4n3H2/u+9vb2/P4NCAdOHCBc3NzWloaIiLo2hpWRT6vZIe89TLkkYl/UoG3xeoS6FQUG9vLx8ejZaXRaF/X9L7JMnM3iTplyXxemwAaLCaa+hm9qjSu1d2m9mEpAcktUmSux+X9ClJJ81sQJJJ+oS7/2jTEgMAllWz0N397hr7fyDptzNLBABYF14piigkSaL5+Xldvnw5dBQgGAod0Xj22Wc1MzPDxVG0LAod0ejs7FSxWAwdAwiGQgeASFDoABAJCh0AIkGhIzpzc3OampoKHQNoOAodUVm4MFoul0NHARqOQkd0yuUyty6iJVHoABAJCh0AIkGhA0AkKHREaX5+XoODg9ztgpZCoSM6SZJodHRUxWJRV69eDR0HaBgKHQAiQaEDQCQodETt2rVroSMADUOhI1rlclmVSkVDQ0OhowANQaEjWkmSqLe3N3QMoGEodACIBIUOAJGg0AEgEhQ6oufuvPsiWgKFjqglSaKenh4+9AItgUJH9AqFgkqlUugYwKaj0AEgEhQ6AESiZqGb2cNmdt7MXlplTLeZvWhmZ83sP7ONCACoRz1n6Ccl3bHSTjPbKelzkg67+9sk/V4myYCMzczMcLcLolaz0N39tKQLqwz5A0mPufv3q+PPZ5QNyIy7q6enhzfrQtSyWEO/RdIuM3vGzL5tZh9daaCZHTWzM2Z2ZnZ2NoNDAwAWbMnoe7xb0vskvV7Sf5vZc+4+vHSgu5+QdEKSurq6PINjAwCqsij0CUnT7j4radbMTku6VdJrCh0AsHmyWHL5N0m3m9kWM/s5SQck8QbUyKX5+XldunQpdAxgU9Q8QzezRyV1S9ptZhOSHpDUJknuftzdh8zs3yV9V9Krkh5y9xVvcQRCKRQKGhkZkbtr165d6ujoCB0JyFTNQnf3u+sY82lJn84kEbCJRkdHNTY2piNHjoSOAmSOV4oCQCQodACIBIWOlsSFUcQoi9sWgaZy9uxZ7d27V5J0yy23hA0DZIgzdLScJEnU19cXOgaQOQodACJBoQNAJCh0AIgEhY6WValUNDzMWw4hHhQ6WlKhUFBvb6+Gh4e5hRHRoNDRspIkCR0ByBSFDgCRoNABIBIUOlre5cuXQ0cAMkGho6U9++yzmpmZ0djYWOgowIZR6GhpnZ2dKhaLoWMAmaDQASASFDoARIJCR8u7cOGC5ubmeIERmh6FjpZXKBRUKpU0Pj5OqaOpUeiAJHfXuXPnQscANoRCB4BIUOgAEAkKHQAiQaEDi0xMTHBhFE2LQgeqRkdHderUKU1PT4eOAqwLhQ4sMjk5GToCsG41C93MHjaz82b2Uo1xv2pmFTP7UHbxAAD1qucM/aSkO1YbYGY3SfoHSf+RQSYAwDrULHR3Py3pQo1hxyR9TdL5LEIBANZuw2voZnazpN+V9Pk6xh41szNmdmZ2dnajhwYyt/C+LkNDQ6GjAGuWxUXRz0r6hLu/Wmugu59w9/3uvr+9vT2DQwPZKhQK6u3t1djYGLcvoulsyeB77Jf0ZTOTpN2S7jSzirt/PYPvDQCo04YL3d1/aeGxmZ2U9ARlDgCNV7PQzexRSd2SdpvZhKQHJLVJkrsf39R0AIC61Sx0d7+73m/m7vdsKA2QE/Pz85qentb27dtDRwHqxitFgSWSJFFfX5/m5uY0NjYWOg5QNwodWEahUFCxWAwdA1gTCh0AIkGhA0AkKHRgFXNzc5qamgodA6gLhQ6soLOzU8ViUeVyOXQUoC4UOrAKyhzNhEIHgEhQ6AAQCQodACJBoQM1uLsGBwe52wW5R6EDq0iSRKdOnVKpVAodBaiJQgeASFDoABAJCh0AIkGhA3XiM0aRdxQ6UIezZ8+qUqloeHg4dBRgRRQ6UIckSdTT0xM6BrAqCh0AIkGhA0AkKHQAiASFDqxBpVLhg6ORWxQ6UKdCoaCenh4+xQi5RaEDa1AoFHhfF+
QWhQ4AkaDQASASFDqwDjMzM1wcRe7ULHQze9jMzpvZSyvs/0Mz+66ZDZjZt8zs1uxjAvnh7urp6dG1a9dCRwFuUM8Z+klJd6yyf1TSb7j7PkmfknQig1wAgDXaUmuAu582s72r7P/WoqfPSerKIBcAYI2yXkP/E0lPrbTTzI6a2RkzOzM7O5vxoQGgtdU8Q6+Xmf2m0kK/faUx7n5C1SWZrq4uz+rYQAjz8/OhIwA3yKTQzeztkh6S9H53n87iewJ5VigUNDIyInfXnj17tH379tCRgI0vuZjZmyU9Jukj7s67/6NljI6O8qpR5ErNM3Qze1RSt6TdZjYh6QFJbZLk7scl3S/pDZI+Z2aSVHH3/ZsVGACwvHrucrm7xv6PSfpYZokAAOvCK0UBIBIUOrBBExMToSMAkih0YEPcXSMjIxoe5n4AhEehAxvU19cXOgIgiUIHgGhQ6AAQCQodyEClUmEdHcFR6MAGFQoF9fb2qlKp6NKlS6HjoIVR6EAGkiTRuXPnQsdAi6PQASASFDoARIJCB4BIUOhARsbGxjQ+Pq6xsbHQUdCiKHQgI+6uYrEYOgZaGIUOAJGg0AEgEhQ6AESCQgcyNjc3xytGEQSFDmSos7NTpVJJExMTlDoajkIHMnb27FluXUQQFDoARIJCB4BIUOgAEAkKHdgE8/PzXBhFw1HoQMaSJNHo6KhOnTql6enp0HHQQih0YJNMTk6GjoAWQ6EDQCTiL3T31Z8je8w5EMSWWgPM7GFJH5B03t0Ly+w3ST2S7pQ0J+ked38+66Dr8swz0pUr0qFDkllaLE8/LW3dKnV3h04XJ+b8BnNzcxoaGtJb3/rW0FGQsYEBqVSSLl6UduyQDh6U9u0Lm6meM/STku5YZf/7Jb2l+nVU0uc3HisD7mmx9PenhbJQLP396XbOGrPHnN+gs7NTvb29oWNgEwwMSH19Urmc/rEul9PnAwNhc9U8Q3f302a2d5Uhd0l6xN1d0nNmttPMOtx9KquQ62KWniVKaaH096ePDxz46dkjssWco0WUStL16zduu3493R7yLD2LNfSbJY0vej5R3fYaZnbUzM6Y2ZnZ2dkMDl3D4oJZQLFsLuYcLeDixbVtb5SGXhR19xPuvt/d97e3tzfigOmP/IstLAVgczDnaAE7dqxte6NkUeiTkvYset5V3RbW4vXbAwek++9Pf128votsMefLcndNTYVdgUS2Dh6U2tpu3NbWlm4PKYtCf1zSRy11m6SLwdfPpfRH/K1bb1y/PXQofb51K0sAm4E5f40kSVQsFjUzM0OpR2TfPumDH5R27kz/WO/cmT4PfZeLeY2zJjN7VFK3pN2S/k/SA5LaJMndj1dvW3xQ6Z0wc5LudfcztQ7c1dXlx44d21D4urjfWCRLnyN7zPlrmJmOHDmijo6O0FHQ5Nrb27/t7vuX21fPXS5319jvku5bZ7bNt7RIWrxYGoI5B4KI/5WiANAiKHQAiASFDjQIF0ax2Sh0oAHcXcVikQ+8wKai0IEGKZfLoSMgchQ6AESCQgeASFDoABAJCh1ooEqlosHBQe52waag0IEGSZJEp06dUqlUCh0FkaLQASASFDoARKLmuy1u2oHNfijpew085G5JP2rg8bLUrNmbNbfUvNmbNbfUvNkbnfsX3f2Ny+0IVuiNZmZnVnrLybxr1uzNmltq3uzNmltq3ux5ys2SCwBEgkIHgEi0UqGfCB1gA5o1e7Pmlpo3e7Pmlpo3e25yt8waOgDErpXO0AEgahQ6AEQiqkI3s4fN7LyZvbTCfjOzfzKzl83su2b2rkZnXEkd2bvN7KKZvVj9ur/RGZdjZnvM7JtmNmhmZ83sL5cZk7t5rzN3Xud8q5n9j5l9p5r975YZ87Nm9pXqnPeb2d4AUZdmqif3PWb2w0Vz/rEQWVdiZjeZ2Qtm9sQy+8LPubtH8yXp1yW9S9JLK+y/U9JTkkzSbZL6Q2deQ/ZuSU+EzrlMrg5J76o+3i5pWF
KS93mvM3de59wkbas+bpPUL+m2JWP+TNLx6uMPS/pKk+S+R9KDobOu8t/wV5L+Zbk/F3mY86jO0N39tKQLqwy5S9IjnnpO0k4z62hMutXVkT2X3H3K3Z+vPr4kaUjSzUuG5W7e68ydS9V5vFx92lb9Wnp3w12Svlh9/FVJ7zMza1DEZdWZO7fMrEvS70h6aIUhwec8qkKvw82Sxhc9n1CT/CWu+rXqj6tPmdnbQodZqvoj5juVnnktlut5XyW3lNM5r/7o/6Kk85K+4e4rzrm7VyRdlPSGhoZcRh25JelIdWnuq2a2p7EJV/VZSX8j6dUV9gef81Yr9Gb2vNL3cLhVUq+kr4eNcyMz2ybpa5I+7u4/Dp2nXjVy53bO3X3e3d8hqUvSe8ysEDhSXerI3Sdpr7u/XdI39NMz3qDM7AOSzrv7t0NnWU2rFfqkpMX/4ndVt+Weu/944cdVd39SUpuZ7Q4cS5JkZm1KS/FL7v7YMkNyOe+1cud5zhe4e1nSNyXdsWTXT+bczLZI2iFpuqHhVrFSbnefdver1acPSXp3g6Ot5L2SDpvZmKQvSzpoZv+8ZEzwOW+1Qn9c0kerd13cJumiuzfFR8eY2S8srMeZ2XuU/r8L/he0mukLkobc/TMrDMvdvNeTO8dz/kYz21l9/HpJvyXpf5cMe1zSH1cff0hSyatX60KpJ/eSayuHlV7bCM7d/9bdu9x9r9ILniV3/6Mlw4LP+ZZGHmyzmdmjSu9M2G1mE5IeUHrhRe5+XNKTSu+4eFnSnKR7wyR9rTqyf0jSn5pZRdIrkj4c+i9o1XslfUTSQHVtVJI+KenNUq7nvZ7ceZ3zDklfNLOblP4j86/u/oSZ/b2kM+7+uNJ/rIpm9rLSi+0fDhf3J+rJ/RdmdlhSRWnue4KlrUPe5pyX/gNAJFptyQUAokWhA0AkKHQAiASFDgCRoNABIBIUOgBEgkIHgEj8P5U402gz8GSLAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"contour = ax.contourf(\n",
"    contour_plot_x_data,\n",
"    contour_plot_y_data,\n",
"    predictions.round().reshape(contour_plot_x_data.shape),\n",
"    cmap=\"gray\",\n",
"    alpha=0.50,\n",
")\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "d3f39faa",
"metadata": {},
"source": [
"### As a bonus, let's inspect the model parameters"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7fa65211",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[4.54424667]\n",
" [2.37960148]]\n",
"-14.69552993774414\n"
]
}
],
"source": [
"w = np.array(model.fc.weight.flatten().tolist()).reshape((-1, 1))\n",
"b = model.fc.bias.flatten().tolist()[0]\n",
"\n",
"print(w)\n",
"print(b)"
]
},
{
"cell_type": "markdown",
"id": "544d6e34",
"metadata": {},
"source": [
"They are floating-point numbers, so we can't work with them directly in **hdk**!"
]
},
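{
"cell_type": "markdown",
"id": "7c2e9a41",
"metadata": {},
"source": [
"Before quantizing anything, it helps to have a floating-point reference to compare against. The model's forward pass is just `sigmoid(x @ w + b)`, so the next cell is a minimal NumPy sketch that recomputes the predictions for our dataset from the parameters printed above. The parameter values are copied by hand (an assumption: retraining may give slightly different numbers), and `logistic_forward` is a hypothetical helper, not part of **hdk**."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4d81f06",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Parameters copied from the printed output above (retraining may change them slightly)\n",
"w_float = np.array([[4.54424667], [2.37960148]])\n",
"b_float = -14.69552993774414\n",
"\n",
"\n",
"def logistic_forward(x, w, b):\n",
"    # Hypothetical helper: floating-point logistic regression, sigmoid(x @ w + b)\n",
"    z = x @ w + b\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"\n",
"x_np = np.array([[1, 1], [1, 2], [2, 1], [4, 1], [3, 2], [4, 2]], dtype=np.float64)\n",
"float_predictions = logistic_forward(x_np, w_float, b_float)\n",
"\n",
"# Rounding the probabilities gives the predicted classes: [0. 0. 0. 1. 1. 1.]\n",
"print(float_predictions.round().flatten())"
]
},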
{
"cell_type": "markdown",
"id": "abf310f2",
"metadata": {},
"source": [
"### So, let's abstract quantization\n",
"\n",
"Here is a quick summary of quantization. We have a range of values, and we want to represent them using a small number of bits (n). With n bits, we can represent only 2^n distinct integers (0 to 2^n - 1), so we split the range into 2^n sections and map every value in a section to the same integer. The figure below visualizes the process for n = 2."
]
},
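{
"cell_type": "markdown",
"id": "e5a07c3d",
"metadata": {},
"source": [
"The mapping described above can be written in a few lines. The next cell is a minimal sketch, not the **hdk** API: `quantize` and `dequantize` are hypothetical helpers, and the sketch assumes 2^n evenly spaced levels where `min(x)` maps to 0 and `max(x)` maps to 2^n - 1. The distance between consecutive values is then `(max(x) - min(x)) / (2^n - 1)`, and rounding to the nearest level introduces an error of at most half that distance."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f36b2a8",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def quantize(x, n):\n",
"    # Hypothetical helper: map floats in [min(x), max(x)] to unsigned n-bit integers\n",
"    x = np.asarray(x, dtype=np.float64)\n",
"    x_min, x_max = x.min(), x.max()\n",
"    if x_max == x_min:\n",
"        # Degenerate range: every value maps to 0 (the step is arbitrary)\n",
"        return np.zeros(x.shape, dtype=np.uint64), x_min, 1.0\n",
"    step = (x_max - x_min) / (2**n - 1)  # distance between consecutive values\n",
"    q = np.round((x - x_min) / step).astype(np.uint64)\n",
"    return q, x_min, step\n",
"\n",
"\n",
"def dequantize(q, x_min, step):\n",
"    # Map the integers back to approximate floating-point values\n",
"    return x_min + q.astype(np.float64) * step\n",
"\n",
"\n",
"values = np.array([0.0, 0.1, 0.4, 0.6, 0.9, 1.0])\n",
"q, x_min, step = quantize(values, n=2)\n",
"print(q)  # [0 0 1 2 3 3]\n",
"print(dequantize(q, x_min, step))"
]
},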
{
"cell_type": "code",
"execution_count": 11,
"id": "d3ab2aa2",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<svg content=\"<mxfile host="app.diagrams.net" modified="2021-08-13T09:47:25.144Z" agent="5.0 (X11)" etag="5QhM0DGu1eUjmjeXuyFL" version="14.9.6" type="device"><diagram id="6rZNNX4_K12e_kCXuZoG" name="Page-1">7Zzdb5s6FMD/mkjdw66MHZL0sUl376SrStM6rc8euAGNYAZOk/avnw2YD5t8QAmhJA+tyLFzbM7v2D7HdjtCi9X2vxAHzgO1iTeCwN6O0P0IQhPd8t9C8JoI0MRMBMvQtRORkQse3TeSCkEqXbs2iUoVGaUec4Oy0KK+TyxWkuEwpJtytWfqlVsN8JJogkcLe7r0ybWZk0hnJsjlX4m7dGTLBkhLVlhWTgWRg226KYjQlxFahJSy5Gm1XRBP2E7aJfnevztKs46FxGfHfOFpYrx9//Hw//Lpq8mc+yd4H5mfUy0v2FunL5x2lr1KCxDfvhOG5J8sD0eRa43QPGI4ZLrYYSuPCwz+mOghtmbevL9GZgXuPYSuCAtfeZVNbmdpZqdgYikLiYeZ+1JWj1Pcy0xd1sI36vKGIUg905ikelLHhGNQVhHRdWiR9FtFuyqKTHhAETfVkjBNEX8ovHYuirHVQAhrIfSpTz4WKAgU+5oNQWmKULeg0NBBjdsCpSrqGNR46KBmbYFSFXUMyhw4KKQuLU1BaYo6BjUZOig1mGgMSlXUMajpwEGN2womNEUdg5oNHVRbwYSmqGNQt0MH1VYwoSnqGJTccaidC38gWFoKBKfNYGmRn6poByxuPfxaqBaICtGeDquBC5ju75e6fpbr84ekB+16DmzoOR9/F0Xzg+ltM4dCqmeqik49+utl50NieND0RzNUl9quGdZL3AfF8JDpj2Z4aECfmmG9nH5QDNuaS7XcpGuGVen+xOPWmtvuC39ciseV699sP8kC3lChTEPOyJaVaUYspL/Jgno0zOPmZ9fzFBH23KUvXIJDJlw+fyEhcy3s3aUFK9e2RTPzjeMy8hhgS7S5CXHAZSFd+zYRLwuybgkFZFvXh3YFKNJ5Cj42rvAxCHa7U4lfbVh6yr/C25jMUBloA+3sDBpl8zaOnNgsRpmKkH/DjFvajyUQoIyVPMOEH2A+hLdlTAiMm82HqiI4UxSdej6s2gMozIcFzJM/a3EePH+mPvscxafhd7yCAYJtTEyW5xNlokfUf5eiBz7U8qk4UVduQp2he/YCjI7i7DZR9ytspVdHmqKXE6XoeHqfwmhr8ZqqI3J23olTKr6OrFOPLOM6sk45stQc/Pwjq97NoAsJSdRz1MYhiXYg23FIImle8e47fW2OV83yusbb6ArL4PG2lVCoijrHW+8+xZA2ybTjXDj+ZzJphlE7cazQdWqSVRcuGsRtWpx173LWPo9hCmFbUjInbEOIrxcsqB8Rax3D0Qp/ig5GO6O3XsZpo/fHZaq7GdPcRc4VmqGqtbtnOcMwkh5Y8x16OQpOka2Me5atoKrT25551DCGBLoOieoh0betMaSnADcbR6z7QPzEEgD7espzis1LJdpD8n5ZgZABO0Wkn3WDy+Fh9O0UDh11bF17OoVV4Xky/EQqLfrM1cR/cdh4eThl1/5cZ/zq1GByVGpgGJ36cAv5ZKVXbOWasdcX3ul5ENwIMV4JgP6vKEgKlc/cduL3W1BR9KnotmBnby/FR9U1L5tziw7a7SRbddXh/TNZ5R7IdebaEavKCwmZV5hnXnkPXJDIpqAL52Ya/eImG7ty289t2rPxNq463NW4vZFQZPxBsvetbggfHZWKdVqmm5fuCerf3xjy4KjgCbN2PIF/zP+3SHJckf+DFvTlLw==</diagram></mxfile>\" height=\"195px\" version=\"1.1\" viewBox=\"-0.5 -0.5 420 195\" width=\"420px\" 
xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\"><defs/><g><path d=\"M 14.37 84 L 361.63 84\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 9.12 84 L 16.12 80.5 L 14.37 84 L 16.12 87.5 Z\" fill=\"#000000\" pointer-events=\"all\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 366.88 84 L 359.88 87.5 L 361.63 84 L 359.88 80.5 Z\" fill=\"#000000\" pointer-events=\"all\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 48 94 L 48 74\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 88 94 L 88 74\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 128 94 L 128 74\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 168 94 L 168 74\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 208 94 L 208 74\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 248 94 L 248 74\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 288 94 L 288 74\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 328 94 L 328 74\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 48 71 L 60.93 58.07 Q 68 51 78 51 L 98 51 Q 108 51 115.07 58.07 L 123.5 66.5\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 127.21 70.21 L 119.78 67.73 L 123.5 66.5 L 124.73 62.78 Z\" fill=\"#000000\" pointer-events=\"all\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 134.37 123 L 141.63 123\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 129.12 123 L 136.12 119.5 L 134.37 123 L 136.12 126.5 Z\" fill=\"#000000\" pointer-events=\"all\" 
stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 146.88 123 L 139.88 126.5 L 141.63 123 L 139.88 119.5 Z\" fill=\"#000000\" pointer-events=\"all\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 154.37 123 L 181.63 123\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 149.12 123 L 156.12 119.5 L 154.37 123 L 156.12 126.5 Z\" fill=\"#000000\" pointer-events=\"all\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 186.88 123 L 179.88 126.5 L 181.63 123 L 179.88 119.5 Z\" fill=\"#000000\" pointer-events=\"all\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 194.37 123 L 221.63 123\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 189.12 123 L 196.12 119.5 L 194.37 123 L 196.12 126.5 Z\" fill=\"#000000\" pointer-events=\"all\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 226.88 123 L 219.88 126.5 L 221.63 123 L 219.88 119.5 Z\" fill=\"#000000\" pointer-events=\"all\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 234.37 123 L 241.63 123\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 229.12 123 L 236.12 119.5 L 234.37 123 L 236.12 126.5 Z\" fill=\"#000000\" pointer-events=\"all\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 246.88 123 L 239.88 126.5 L 241.63 123 L 239.88 119.5 Z\" fill=\"#000000\" pointer-events=\"all\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><rect fill=\"none\" height=\"20\" pointer-events=\"all\" stroke=\"none\" width=\"40\" x=\"108\" y=\"94\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject height=\"100%\" pointer-events=\"none\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow: visible; text-align: left;\" width=\"100%\"><div style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 104px; margin-left: 109px;\" 
xmlns=\"http://www.w3.org/1999/xhtml\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div>min(x)</div></div></div></div></foreignObject><text fill=\"#000000\" font-family=\"Helvetica\" font-size=\"12px\" text-anchor=\"middle\" x=\"128\" y=\"108\">min(x)</text></switch></g><rect fill=\"none\" height=\"20\" pointer-events=\"all\" stroke=\"none\" width=\"40\" x=\"228\" y=\"94\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject height=\"100%\" pointer-events=\"none\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow: visible; text-align: left;\" width=\"100%\"><div style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 104px; margin-left: 229px;\" xmlns=\"http://www.w3.org/1999/xhtml\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \">max(x)</div></div></div></foreignObject><text fill=\"#000000\" font-family=\"Helvetica\" font-size=\"12px\" text-anchor=\"middle\" x=\"248\" y=\"108\">max(x)</text></switch></g><path d=\"M 138 148 L 138 128\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-dasharray=\"2 6\" stroke-miterlimit=\"10\" stroke-width=\"2\"/><rect fill=\"none\" height=\"20\" pointer-events=\"all\" stroke=\"none\" width=\"40\" x=\"118\" y=\"152\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject height=\"100%\" pointer-events=\"none\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow: visible; text-align: left;\" width=\"100%\"><div style=\"display: flex; align-items: unsafe center; 
justify-content: unsafe center; width: 38px; height: 1px; padding-top: 162px; margin-left: 119px;\" xmlns=\"http://www.w3.org/1999/xhtml\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">Map</font></div><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">to 0<br style=\"font-size: 10px\"/></font></div></div></div></div></foreignObject><text fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\" x=\"138\" y=\"165\">Map...</text></switch></g><rect fill=\"none\" height=\"20\" pointer-events=\"all\" stroke=\"none\" width=\"40\" x=\"148\" y=\"152\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject height=\"100%\" pointer-events=\"none\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow: visible; text-align: left;\" width=\"100%\"><div style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 162px; margin-left: 149px;\" xmlns=\"http://www.w3.org/1999/xhtml\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">Map</font></div><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">to 1<br style=\"font-size: 10px\"/></font></div></div></div></div></foreignObject><text fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\" x=\"168\" y=\"165\">Map...</text></switch></g><path d=\"M 168 148 L 168 128\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" 
stroke-dasharray=\"2 6\" stroke-miterlimit=\"10\" stroke-width=\"2\"/><path d=\"M 208 148 L 208 128\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-dasharray=\"2 6\" stroke-miterlimit=\"10\" stroke-width=\"2\"/><path d=\"M 238 148 L 238 128\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-dasharray=\"2 6\" stroke-miterlimit=\"10\" stroke-width=\"2\"/><path d=\"M 294.37 68.66 L 321.63 68.66\" fill=\"none\" pointer-events=\"stroke\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 289.12 68.66 L 296.12 65.16 L 294.37 68.66 L 296.12 72.16 Z\" fill=\"#000000\" pointer-events=\"all\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><path d=\"M 326.88 68.66 L 319.88 72.16 L 321.63 68.66 L 319.88 65.16 Z\" fill=\"#000000\" pointer-events=\"all\" stroke=\"#000000\" stroke-miterlimit=\"10\"/><rect fill=\"none\" height=\"20\" pointer-events=\"all\" stroke=\"none\" width=\"40\" x=\"288\" y=\"18.66\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject height=\"100%\" pointer-events=\"none\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow: visible; text-align: left;\" width=\"100%\"><div style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 29px; margin-left: 289px;\" xmlns=\"http://www.w3.org/1999/xhtml\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><font style=\"font-size: 10px\">Distance<br/>Between<br/>Consecutive<br/>Values</font></div></div></div></foreignObject><text fill=\"#000000\" font-family=\"Helvetica\" font-size=\"12px\" text-anchor=\"middle\" x=\"308\" y=\"32\">Distan...</text></switch></g><rect fill=\"none\" height=\"20\" pointer-events=\"all\" stroke=\"none\" width=\"40\" x=\"188\" y=\"152\"/><g 
transform=\"translate(-0.5 -0.5)\"><switch><foreignObject height=\"100%\" pointer-events=\"none\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow: visible; text-align: left;\" width=\"100%\"><div style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 162px; margin-left: 189px;\" xmlns=\"http://www.w3.org/1999/xhtml\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">Map</font></div><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">to 2</font></div></div></div></div></foreignObject><text fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\" x=\"208\" y=\"165\">Map...</text></switch></g><rect fill=\"none\" height=\"20\" pointer-events=\"all\" stroke=\"none\" width=\"40\" x=\"218\" y=\"152\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject height=\"100%\" pointer-events=\"none\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow: visible; text-align: left;\" width=\"100%\"><div style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 162px; margin-left: 219px;\" xmlns=\"http://www.w3.org/1999/xhtml\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">Map</font></div><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">to 3</font></div></div></div></div></foreignObject><text 
fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\" x=\"238\" y=\"165\">Map...</text></switch></g><rect fill=\"none\" height=\"20\" pointer-events=\"all\" stroke=\"none\" width=\"120\" x=\"128\" y=\"174\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject height=\"100%\" pointer-events=\"none\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow: visible; text-align: left;\" width=\"100%\"><div style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 118px; height: 1px; padding-top: 184px; margin-left: 129px;\" xmlns=\"http://www.w3.org/1999/xhtml\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \">(when n = 2)</div></div></div></foreignObject><text fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\" x=\"188\" y=\"187\">(when n = 2)</text></switch></g><rect fill=\"none\" height=\"20\" pointer-events=\"all\" stroke=\"none\" width=\"40\" x=\"28\" y=\"94\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject height=\"100%\" pointer-events=\"none\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow: visible; text-align: left;\" width=\"100%\"><div style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 104px; margin-left: 29px;\" xmlns=\"http://www.w3.org/1999/xhtml\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \">0</div></div></div></foreignObject><text fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" 
text-anchor=\"middle\" x=\"48\" y=\"107\">0</text></switch></g><rect fill=\"none\" height=\"20\" pointer-events=\"all\" stroke=\"none\" width=\"110\" x=\"308\" y=\"18.66\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject height=\"100%\" pointer-events=\"none\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow: visible; text-align: left;\" width=\"100%\"><div style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 108px; height: 1px; padding-top: 29px; margin-left: 309px;\" xmlns=\"http://www.w3.org/1999/xhtml\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div><font style=\"font-size: 12px\">= 1 / scale</font></div><div><font style=\"font-size: 12px\">= 1 / q</font></div></div></div></div></foreignObject><text fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\" x=\"363\" y=\"32\">= 1 / scale...</text></switch></g><rect fill=\"none\" height=\"20\" pointer-events=\"all\" stroke=\"none\" width=\"140\" x=\"128\" y=\"24\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject height=\"100%\" pointer-events=\"none\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow: visible; text-align: left;\" width=\"100%\"><div style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 138px; height: 1px; padding-top: 34px; margin-left: 129px;\" xmlns=\"http://www.w3.org/1999/xhtml\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><font style=\"font-size: 12px\">x =</font><font style=\"font-size: 
12px\"> (x + zp ) / q </font></div></div></div></foreignObject><text fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\" x=\"198\" y=\"37\">x = (x + zp ) / q </text></switch></g><rect fill=\"none\" height=\"20\" pointer-events=\"all\" stroke=\"none\" width=\"40\" x=\"167\" y=\"29\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject height=\"100%\" pointer-events=\"none\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow: visible; text-align: left;\" width=\"100%\"><div style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 39px; margin-left: 168px;\" xmlns=\"http://www.w3.org/1999/xhtml\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div><font style=\"font-size: 10px\">q</font></div></div></div></div></foreignObject><text fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\" x=\"187\" y=\"42\">q</text></switch></g><rect fill=\"none\" height=\"20\" pointer-events=\"all\" stroke=\"none\" width=\"40\" x=\"199\" y=\"29\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject height=\"100%\" pointer-events=\"none\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow: visible; text-align: left;\" width=\"100%\"><div style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 39px; margin-left: 200px;\" xmlns=\"http://www.w3.org/1999/xhtml\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; 
\"><div>x</div></div></div></div></foreignObject><text fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\" x=\"219\" y=\"42\">x</text></switch></g><rect fill=\"none\" height=\"20\" pointer-events=\"all\" stroke=\"none\" width=\"40\" x=\"227\" y=\"29\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject height=\"100%\" pointer-events=\"none\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow: visible; text-align: left;\" width=\"100%\"><div style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 39px; margin-left: 228px;\" xmlns=\"http://www.w3.org/1999/xhtml\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div>x</div></div></div></div></foreignObject><text fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\" x=\"247\" y=\"42\">x</text></switch></g><rect fill=\"none\" height=\"20\" pointer-events=\"all\" stroke=\"none\" width=\"80\" x=\"48\" y=\"28\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject height=\"100%\" pointer-events=\"none\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow: visible; text-align: left;\" width=\"100%\"><div style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 78px; height: 1px; padding-top: 38px; margin-left: 49px;\" xmlns=\"http://www.w3.org/1999/xhtml\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div>zero point<br/></div><div>zp = 
2</div></div></div></div></foreignObject><text fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\" x=\"88\" y=\"41\">zero point...</text></switch></g></g><switch><g requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"/><a target=\"_blank\" transform=\"translate(0,-5)\" xlink:href=\"https://www.diagrams.net/doc/faq/svg-export-text-problems\"><text font-size=\"10px\" text-anchor=\"middle\" x=\"50%\" y=\"100%\">Viewer does not support full SVG 1.1</text></a></switch></svg>"
],
"text/plain": [
"<IPython.core.display.SVG object>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import SVG\n",
"SVG(filename=\"figures/QuantizationVisualized.svg\")"
]
},
{
"cell_type": "markdown",
"id": "e4314038",
"metadata": {},
"source": [
"If you want to learn more, head to https://intellabs.github.io/distiller/algo_quantization.html"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "a8bab855",
"metadata": {},
"outputs": [],
"source": [
"class QuantizationParameters:\n",
"    def __init__(self, q, zp, n):\n",
"        # q = scale factor = 1 / distance between consecutive values\n",
"        # zp = zero point, which determines where the quantized range begins\n",
"        #      (quantized 0 = beginning of the quantized range = zp * distance between consecutive values)\n",
"        # n = number of bits\n",
"\n",
"        # e.g.,\n",
"\n",
"        # n = 2\n",
"        # zp = 2\n",
"        # q = 0.66\n",
"        # distance between consecutive values = 1 / q = 1.5151\n",
"\n",
"        # quantized 0 = zp / q = zp * distance between consecutive values = 3.0303\n",
"        # quantized 1 = quantized 0 + distance between consecutive values = 4.5454\n",
"        # quantized 2 = quantized 1 + distance between consecutive values = 6.0606\n",
"        # quantized 3 = quantized 2 + distance between consecutive values = 7.5757\n",
"\n",
"        self.q = q\n",
"        self.zp = zp\n",
"        self.n = n\n",
"\n",
"class QuantizedArray:\n",
"    def __init__(self, values, parameters):\n",
"        # values = quantized values\n",
"        # parameters = parameters used during quantization\n",
"\n",
"        # e.g.,\n",
"\n",
"        # values = [1, 0, 2, 1]\n",
"        # parameters = QuantizationParameters(q=0.66, zp=2, n=2)\n",
"\n",
"        # original array = [4.5454, 3.0303, 6.0606, 4.5454]\n",
"\n",
"        self.values = np.array(values)\n",
"        self.parameters = parameters\n",
"\n",
"    @staticmethod\n",
"    def of(x, n):\n",
"        if not isinstance(x, np.ndarray):\n",
"            x = np.array(x)\n",
"\n",
"        min_x = x.min()\n",
"        max_x = x.max()\n",
"\n",
"        if min_x == max_x:  # encoding single valued arrays\n",
"\n",
"            if min_x == 0.0:  # encoding 0s\n",
"\n",
"                # dequantization = (x_q + zp_x) / q_x = 0 --> q_x = 1 && zp_x = 0 && x_q = 0\n",
"                q_x = 1\n",
"                zp_x = 0\n",
"                x_q = np.zeros(x.shape, dtype=np.uint)\n",
"\n",
"            elif min_x < 0.0:  # encoding negative scalars\n",
"\n",
"                # dequantization = (x_q + zp_x) / q_x = x (with x < 0) --> q_x = 1 / |x| & zp_x = -1 & x_q = 0\n",
"                q_x = abs(1 / min_x)\n",
"                zp_x = -1\n",
"                x_q = np.zeros(x.shape, dtype=np.uint)\n",
"\n",
"            else:  # encoding positive scalars\n",
"\n",
"                # dequantization = (x_q + zp_x) / q_x = x --> q_x = 1 / x & zp_x = 0 & x_q = 1\n",
"                q_x = 1 / min_x\n",
"                zp_x = 0\n",
"                x_q = np.ones(x.shape, dtype=np.uint)\n",
"\n",
"        else:  # encoding multi valued arrays\n",
"\n",
"            # distance between consecutive values = range of x / number of gaps between the 2^n quantized values\n",
"            #                                      = (max_x - min_x) / (2^n - 1)\n",
"            # q = 1 / distance between consecutive values\n",
"            q_x = (2**n - 1) / (max_x - min_x)\n",
"\n",
"            # zp = what should be added to 0 to get min_x -> min_x = (0 + zp) / q -> zp = min_x * q\n",
"            zp_x = int(round(min_x * q_x))\n",
"\n",
"            # x = (x_q + zp) / q -> x_q = (x * q) - zp\n",
"            # (clip guards against the rounding of zp pushing the extremes just outside [0, 2^n - 1])\n",
"            x_q = ((q_x * x) - zp_x).round().clip(0, 2**n - 1).astype(np.uint)\n",
"\n",
"        return QuantizedArray(x_q, QuantizationParameters(q_x, zp_x, n))\n",
"\n",
"    def dequantize(self):\n",
"        # x = (x_q + zp) / q\n",
"        return (self.values.astype(np.float32) + float(self.parameters.zp)) / self.parameters.q\n",
"\n",
"    def affine(self, w, b, min_y, max_y, n_y):\n",
"        # the formulas used in this method were derived from the following equations\n",
"        #\n",
"        # x = (x_q + zp_x) / q_x\n",
"        # w = (w_q + zp_w) / q_w\n",
"        # b = (b_q + zp_b) / q_b\n",
"        #\n",
"        # (x * w) + b = ((x_q + zp_x) / q_x) * ((w_q + zp_w) / q_w) + ((b_q + zp_b) / q_b)\n",
"        #             = y = (y_q + zp_y) / q_y\n",
"        #\n",
"        # So, ((x_q + zp_x) / q_x) * ((w_q + zp_w) / q_w) + ((b_q + zp_b) / q_b) = (y_q + zp_y) / q_y\n",
"        # We can calculate zp_y and q_y from min_y, max_y and n_y, so the only unknown is y_q, which we can solve for.\n",
"\n",
"        x_q = self.values\n",
"        w_q = w.values\n",
"        b_q = b.values\n",
"\n",
"        q_x = self.parameters.q\n",
"        q_w = w.parameters.q\n",
"        q_b = b.parameters.q\n",
"\n",
"        zp_x = self.parameters.zp\n",
"        zp_w = w.parameters.zp\n",
"        zp_b = b.parameters.zp\n",
"\n",
"        q_y = (2**n_y - 1) / (max_y - min_y)\n",
"        zp_y = int(round(min_y * q_y))\n",
"\n",
"        y_q = (q_y / (q_x * q_w)) * ((x_q + zp_x) @ (w_q + zp_w) + (q_x * q_w / q_b) * (b_q + zp_b))\n",
"        y_q -= min_y * q_y\n",
"        y_q = y_q.round().clip(0, 2**n_y - 1).astype(np.uint)\n",
"\n",
"        return QuantizedArray(y_q, QuantizationParameters(q_y, zp_y, n_y))\n",
"\n",
"class QuantizedFunction:\n",
"    def __init__(self, table, input_parameters=None, output_parameters=None):\n",
"        self.table = table\n",
"        self.input_parameters = input_parameters\n",
"        self.output_parameters = output_parameters\n",
"\n",
"    @staticmethod\n",
"    def of(f, input_bits, output_bits):\n",
"        domain = np.array(range(2**input_bits), dtype=np.uint)\n",
"        table = f(domain).round().clip(0, 2**output_bits - 1).astype(np.uint)\n",
"        return QuantizedFunction(table)\n",
"\n",
"    @staticmethod\n",
"    def plain(f, input_parameters, output_bits):\n",
"        n = input_parameters.n\n",
"\n",
"        domain = np.array(range(2**n), dtype=np.uint)\n",
"        inputs = QuantizedArray(domain, input_parameters).dequantize()\n",
"\n",
"        outputs = f(inputs)\n",
"        quantized_outputs = QuantizedArray.of(outputs, output_bits)\n",
"\n",
"        table = quantized_outputs.values\n",
"        output_parameters = quantized_outputs.parameters\n",
"\n",
"        return QuantizedFunction(table, input_parameters, output_parameters)\n",
"\n",
"    def apply(self, x):\n",
"        # QuantizationParameters does not define __eq__, so this checks that\n",
"        # x was quantized with the very same parameters object\n",
"        assert x.parameters == self.input_parameters\n",
"        return QuantizedArray(self.table[x.values], self.output_parameters)"
]
},
{
"cell_type": "markdown",
"id": "e5be0800",
"metadata": {},
"source": [
"### Let's quantize our model parameters\n",
"\n",
"Our model has only a few scalar parameters (two weights and a bias), so a single bit per parameter is enough for this example."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "3ec0ad9b",
"metadata": {},
"outputs": [],
"source": [
"parameter_bits = 1\n",
"\n",
"w_q = QuantizedArray.of(w, parameter_bits)\n",
"b_q = QuantizedArray.of(b, parameter_bits)"
]
},
{
"cell_type": "markdown",
"id": "b43c0371",
"metadata": {},
"source": [
"### And quantize our inputs"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "20cea447",
"metadata": {},
"outputs": [],
"source": [
"input_bits = 5\n",
"\n",
"x = inputs\n",
"x_q = QuantizedArray.of(inputs, input_bits)"
]
},
{
"cell_type": "markdown",
"id": "ca76b68d",
"metadata": {},
"source": [
"### Time to run quantized inference"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8728e939",
"metadata": {},
"outputs": [],
"source": [
"output_bits = 7\n",
"\n",
"intermediate = x @ w + b\n",
"intermediate_q = x_q.affine(w_q, b_q, intermediate.min(), intermediate.max(), output_bits)\n",
"\n",
"sigmoid = QuantizedFunction.plain(lambda x: 1 / (1 + np.exp(-x)), intermediate_q.parameters, output_bits)\n",
"y_q = sigmoid.apply(intermediate_q)\n",
"\n",
"quantized_predictions = y_q.dequantize()"
]
},
{
"cell_type": "markdown",
"id": "ab782b4a",
"metadata": {},
"source": [
"### And visualize the results"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "9d2bb5da",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQaElEQVR4nO3df2xdZ33H8fd3jTtTx0siUrGkDssqkY2QUKAZ6QTavERbf6FW05hGt8FaDaXaSje0SkPjj1Ybf01oiK4IoqhUoRtrmaBibVXWoYQuGqyeQltwSqaqagOERgq0iuncdkrW7/4419Qxtu+xc67P9eP3S7rqvec8vueTp/Ynx889NzcyE0nS8vczbQeQJDXDQpekQljoklQIC12SCmGhS1IhVrV14KGhoVy3bl1bh1ehTp8+zYUXXsj555/fdhSpJx5//PEfZeaFs+1rrdDXrVvHzTff3NbhVajnnnuOG2+8kc2bN7cdReqJoaGh7861zyUXSSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBWi/EKf+ZmpfoaqpEJ1/dcWI2ITcDfwBiCBfZl5+4wxAdwOXAW8BFyfmY81H3eBHnkEXnkFLr8cIqoyf/hhGByE0dG200mNGR+HgwdhYgLWrIFdu2D79rZTla0f57zOGfoZ4JbM3ApcBtwUEVtnjLkSeFPntgf4TKMpFyOzKvOxsarEp8p8bKza7pm6CjE+Dg88AKdOVd/Wp05Vj8fH205Wrn6d865n6Jl5AjjRuf9iRBwFLgK+M23YtcDdmZnAoxGxNiI2dL62HRHVmTlUJT42Vt3fufO1M3apAAcPwunTZ287fbra3vYZY6n6dc4XtIYeEZuBtwNjM3ZdBHx/2uPjnW0zv35PRByOiMOTk5MLjLoI00t9imWuwkxMLGy7zl2/znntQo+I1cCXgA9n5o8Xc7DM3JeZOzJzx9DQ0GKeYqEHrJZZpptafpEKsWbNwrbr3PXrnNcq9IgYoCrzz2fmfbMM+QGwadrjkc629kxfM9+5E269tfrv9DV1qQC7dsHAwNnbBgaq7eqNfp3zOle5BPBZ4GhmfmKOYfcDH4qIe4GdwESr6+dQLasMDp69Zj61/DI46LKLijG1ZttvV1yUrF/nvM6HRL8LeD8wHhFPdLZ9FHgjQGbuBR6iumTxaarLFm9oPOlijI5WZ+JT5T1V6pa5CrN9e/tlstL045zXucrlP4B5G7BzdctNTYVq1MzytswlFar8d4pK0gphoUtSIeqsoUvLyssvv8yJE/Vek1+9ejXDw8M9TiQtDQtdRdm4cSMHDhyoPX7Xrl1s2rTJUlcRXHJRcTKz9u2ZZ55pO67UGAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJ0LfSIuCsiTkbEkTn2r4mIByLiWxHxZETc0HxMSVI3dc7Q9wNXzLP/JuA7mXkJMAr8XUScf+7RJEkL0bXQM/MQ8MJ8Q4DhiAhgdWfsmWbiSZLqWtXAc3wKuB94DhgGfi8zX51tYETsAfYArF27toFDS5KmNPGi6OXAE8BG4G3ApyLi52YbmJn7MnNHZu4YGhpq4NCSpClNFPoNwH1ZeRp4FvjlBp5XkrQATRT694DdABHxBuCXgGcaeF5J0gJ0XUOPiHuorl5ZHxHHgduAAYDM3At8DNgfEeNAAB/JzB/1LLEkaVZdCz0zr+uy/zngtxpLJElaFN8pKkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCNPGJRdKydezYMS6++OLa41etWsWWLVt6mEhaPAtdK1pmcuDAgVpjjxw5wi
233NLjRNLiueQiSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUiK6FHhF3RcTJiDgyz5jRiHgiIp6MiH9vNqIkqY46Z+j7gSvm2hkRa4FPA9dk5luA320kmSRpQboWemYeAl6YZ8jvA/dl5vc64082lE2StABNrKFvAdZFxCMR8c2I+MBcAyNiT0QcjojDk5OTDRxakjSliU8sWgVcCuwGXgf8Z0Q8mplPzRyYmfuAfQAjIyPZwLElSR1NFPpx4PnMnAQmI+IQcAnwU4UuSeqdJpZc/gV4d0SsiogLgJ3A0QaeV5K0AF3P0CPiHmAUWB8Rx4HbgAGAzNybmUcj4l+BbwOvAndm5pyXOEqSeqNroWfmdTXGfBz4eCOJJEmL4jtFJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgqxqu0A0nJy5swZnnrqqQV9zYYNGxgeHu5RIuk1FrpU07Zt27j99tsX9DXbt29n9+7dvPnNb+5RKuk1Frq0ANu2bVvQ+CeffJLdu3f3KI10NtfQJakQFrokFcJCl6RCdC30iLgrIk5GxJEu434lIs5ExHubiydJqqvOGfp+4Ir5BkTEecDfAv/WQCZJ0iJ0LfTMPAS80GXYzcCXgJNNhJIkLdw5r6FHxEXAbwOfqTF2T0QcjojDk5OT53poSdI0Tbwo+kngI5n5areBmbkvM3dk5o6hoaEGDi1JmtLEG4t2APdGBMB64KqIOJOZX27guSVJNZ1zoWfmL07dj4j9wIOWuSQtva6FHhH3AKPA+og4DtwGDABk5t6eppMk1da10DPzurpPlpnXn1MaSdKi+U5RSSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQXQs9Iu6KiJMRcWSO/X8QEd+OiPGI+EZEXNJ8TElSN3XO0PcDV8yz/1ng1zNzO/AxYF8DuSRJC7Sq24DMPBQRm+fZ/41pDx8FRhrIJUlaoK6FvkB/DHxlrp0RsQfYA7B27dqGDy31rxdffLH22OHh4R4mUckaK/SI+A2qQn/3XGMycx+dJZmRkZFs6thSv9q6dSsHDhxg8+bNtcZffPHFXHDBBbXHS9M1UugR8VbgTuDKzHy+ieeUSpGZPPvss7XGfv3rX+fGG2/scSKV6pwvW4yINwL3Ae/PzKfOPZIkaTG6nqFHxD3AKLA+Io4DtwEDAJm5F7gVeD3w6YgAOJOZO3oVWJI0uzpXuVzXZf8HgQ82lkiStCi+U1SSCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKkT5hZ45/2M1zzmXWrGq24CIuAt4D3AyM7fNsj+A24GrgJeA6zPzsaaDLsojj8Arr8Dll0NEVSwPPwyDgzA62na6MjnnWiHGx+HgQZiYgDVrYNcu2L693Ux1ztD3A1fMs/9K4E2d2x7gM+ceqwGZVbGMjVWFMlUsY2PVds8am+eca4UYH4cHHoBTp6pv61Onqsfj4+3m6nqGnpmHImLzPEOuBe7OzAQejYi1EbEhM080FXJRIqqzRKgKZWysur9z52tnj2qWc64V4uBBOH367G2nT1fb2zxLb2IN/SLg+9MeH+9s+ykRsSciDkfE4cnJyQYO3cX0gplisfSWc64VYG
JiYduXypK+KJqZ+zJzR2buGBoaWooDVr/yTze1FKDecM61AqxZs7DtS6XrkksNPwA2TXs80tnWrunrt1O/8k89Bs8ae8E51wqxa1e1Zj592WVgoNrepiYK/X7gQxFxL7ATmGh9/Ryq4hgcPHv9dmopYHDQYukF51wrxNQ6eb9d5VLnssV7gFFgfUQcB24DBgAycy/wENUli09TXbZ4Q6/CLtjoaHXWOFUkUwVjsfSOc64VYvv29gt8pjpXuVzXZX8CNzWWqGkzi8Ri6T3nXGpF+e8UlaQVwkKXpEJY6JJUiCaucpHUoJdeeomjR4/WGnveeeexZcuWHifScmGhS31k48aN3HHHHbXHX3311QwPD7Nhw4YeptJyYaFLfWbr1q21xx47doxLL720h2m0nLiGLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgoR2dIHD0TED4HvLuEh1wM/WsLjNWm5Zl+uuWH5Zl+uuWH5Zl/q3L+QmRfOtqO1Ql9qEXE4M3e0nWMxlmv25Zoblm/25Zoblm/2fsrtkoskFcJCl6RCrKRC39d2gHOwXLMv19ywfLMv19ywfLP3Te4Vs4YuSaVbSWfoklQ0C12SClFUoUfEXRFxMiKOzLE/IuLvI+LpiPh2RLxjqTPOpUb20YiYiIgnOrdblzrjbCJiU0R8LSK+ExFPRsSfzzKm7+a9Zu5+nfPBiPiviPhWJ/tfzzLmZyPiC505H4uIzS1EnZmpTu7rI+KH0+b8g21knUtEnBcRj0fEg7Psa3/OM7OYG/BrwDuAI3Psvwr4ChDAZcBY25kXkH0UeLDtnLPk2gC8o3N/GHgK2Nrv814zd7/OeQCrO/cHgDHgshlj/hTY27n/PuALyyT39cCn2s46z5/hL4B/mu37oh/mvKgz9Mw8BLwwz5Brgbuz8iiwNiL64qNeamTvS5l5IjMf69x/ETgKXDRjWN/Ne83cfakzj//TeTjQuc28uuFa4HOd+18EdkdELFHEWdXM3bciYgS4GrhzjiGtz3lRhV7DRcD3pz0+zjL5Ie741c6vq1+JiLe0HWamzq+Yb6c685qur+d9ntzQp3Pe+dX/CeAk8NXMnHPOM/MMMAG8fklDzqJGboDf6SzNfTEiNi1twnl9EvhL4NU59rc+5yut0Jezx6j+DYdLgDuAL7cb52wRsRr4EvDhzPxx23nq6pK7b+c8M/8vM98GjADvjIhtLUeqpUbuB4DNmflW4Ku8dsbbqoh4D3AyM7/Zdpb5rLRC/wEw/W/8kc62vpeZP576dTUzHwIGImJ9y7EAiIgBqlL8fGbeN8uQvpz3brn7ec6nZOYp4GvAFTN2/WTOI2IVsAZ4fknDzWOu3Jn5fGb+b+fhnUC/fGDqu4BrIuIYcC+wKyL+ccaY1ud8pRX6/cAHOlddXAZMZOaJtkPVERE/P7UeFxHvpPp/1/oPaCfTZ4GjmfmJOYb13bzXyd3Hc35hRKzt3H8d8JvAf88Ydj/wR5377wUOZufVurbUyT3jtZVrqF7baF1m/lVmjmTmZqoXPA9m5h/OGNb6nK9ayoP1WkTcQ3VlwvqIOA7cRvXCC5m5F3iI6oqLp4GXgBvaSfrTamR/L/AnEXEGeBl4X9s/oB3vAt4PjHfWRgE+CrwR+nre6+Tu1znfAHwuIs6j+kvmnzPzwYj4G+BwZt5P9ZfVP0TE01Qvtr+vvbg/USf3n0XENcAZqtzXt5a2hn6bc9/6L0mFWGlLLpJULAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFeL/Admi1qH7N00UAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# remove the previously drawn filled contours before drawing the quantized ones\n",
"for collection in contour.collections:\n",
"    collection.remove()\n",
"\n",
"contour = ax.contourf(\n",
"    contour_plot_x_data,\n",
"    contour_plot_y_data,\n",
"    quantized_predictions.round().reshape(contour_plot_x_data.shape),\n",
"    cmap=\"gray\",\n",
"    alpha=0.50,\n",
")\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "4834cdfc",
"metadata": {},
"source": [
"### Now it's time to make the inference homomorphic"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "fcf4ea26",
"metadata": {},
"outputs": [],
"source": [
"q_y = (2**output_bits - 1) / (intermediate.max() - intermediate.min())\n",
"zp_y = int(round(intermediate.min() * q_y))\n",
"\n",
"q_x = x_q.parameters.q\n",
"q_w = w_q.parameters.q\n",
"q_b = b_q.parameters.q\n",
"\n",
"zp_x = x_q.parameters.zp\n",
"zp_w = w_q.parameters.zp\n",
"zp_b = b_q.parameters.zp\n",
"\n",
"x_q = x_q.values\n",
"w_q = w_q.values\n",
"b_q = b_q.values"
]
},
{
"cell_type": "markdown",
"id": "43e47369",
"metadata": {},
"source": [
"### Simplification to the rescue!\n",
"\n",
"The `y_q` formula in `QuantizedArray.affine(...)` can be rewritten to make it easier to implement homomorphically. Here is the breakdown.\n",
"```\n",
"(q_y / (q_x * q_w)) * ((x_q + zp_x) @ (w_q + zp_w) + (q_x * q_w / q_b) * (b_q + zp_b)) - (min_y * q_y)\n",
"^^^^^^^^^^^^^^^^^^^    ^^^^^^^^^^^^   ^^^^^^^^^^^^   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^    ^^^^^^^^^^^^^\n",
"constant (c1)          can be done    constant (c2)  constant (c3)                       constant (c4)\n",
"                       on the circuit\n",
"\n",
"                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"                       can be done on the circuit\n",
"\n",
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"cannot be done on the circuit because of the floating-point operations,\n",
"so everything around the dot product becomes a single table lookup\n",
"```"
]
},
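{
"cell_type": "markdown",
"id": "7f3c2a91",
"metadata": {},
"source": [
"The same rewrite, written out as a derivation. This is a sketch that assumes the dequantization rule `v = (v_q + zp_v) / q_v` used by `QuantizedArray` for `x`, `w`, `b` and `y`. The name `d` is introduced here only as a label for the integer dot product. Note that `affine` subtracts `min_y * q_y` instead of the rounded `zp_y`, which is why the second line is an approximation:\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"y &= x \\cdot w + b = \\frac{(x_q + zp_x) \\cdot (w_q + zp_w)}{q_x \\, q_w} + \\frac{b_q + zp_b}{q_b} \\\\\n",
"y_q &= q_y \\, y - zp_y \\approx q_y \\, y - \\mathrm{min}_y \\, q_y \\\\\n",
"    &= \\underbrace{\\frac{q_y}{q_x \\, q_w}}_{c_1} \\Big( d + \\underbrace{\\frac{q_x \\, q_w}{q_b} (b_q + zp_b)}_{c_3} \\Big) - \\underbrace{\\mathrm{min}_y \\, q_y}_{c_4} \\\\\n",
"d &= (x_q + zp_x) \\cdot \\underbrace{(w_q + zp_w)}_{c_2}\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"Only the integer `d` has to be computed on the circuit. Everything else (scaling by `c1`, adding `c3`, subtracting `c4`, rounding, clipping and the sigmoid) depends on `d` alone, so it can be tabulated once for every possible value of `d`."
]
},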
{
"cell_type": "code",
"execution_count": 18,
"id": "2de0cf20",
"metadata": {},
"outputs": [],
"source": [
"c1 = q_y / (q_x * q_w)\n",
"c2 = w_q + zp_w\n",
"c3 = (q_x * q_w / q_b) * (b_q + zp_b)\n",
"c4 = intermediate.min() * q_y\n",
"\n",
"def f(x):\n",
"    # x is the integer dot product (x_q + zp_x) @ (w_q + zp_w), i.e. the part computed on the circuit\n",
"    values = ((c1 * (x + c3)) - c4).round().clip(0, 2**output_bits - 1).astype(np.uint)\n",
"    after_affine_q = QuantizedArray(values, intermediate_q.parameters)\n",
"\n",
"    sigmoid = QuantizedFunction.plain(lambda x: 1 / (1 + np.exp(-x)), after_affine_q.parameters, output_bits)\n",
"    y_q = sigmoid.apply(after_affine_q)\n",
"\n",
"    return y_q.values\n",
"\n",
"# tabulate f over all 2**output_bits possible dot product values\n",
"f_q = QuantizedFunction.of(f, output_bits, output_bits)\n",
"\n",
"from hdk.common.extensions.table import LookupTable\n",
"table = LookupTable([int(entry) for entry in f_q.table])\n",
"\n",
"w_0 = int(c2.flatten()[0])\n",
"w_1 = int(c2.flatten()[1])\n",
"\n",
"def infer(x_0, x_1):\n",
"    return table[((x_0 + zp_x) * w_0) + ((x_1 + zp_x) * w_1)]"
]
},
{
"cell_type": "markdown",
"id": "93eb9499",
"metadata": {},
"source": [
"### Time to compile our quantized inference function"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "a80895fd",
"metadata": {},
"outputs": [],
"source": [
"from hdk.common.data_types.integers import Integer\n",
"from hdk.common.data_types.values import EncryptedValue\n",
"from hdk.hnumpy.compile import compile_numpy_function\n",
"\n",
"dataset = []\n",
"for x_i in x_q:\n",
"    dataset.append((int(x_i[0]), int(x_i[1])))\n",
"\n",
"homomorphic_model = compile_numpy_function(\n",
"    infer,\n",
"    {\n",
"        \"x_0\": EncryptedValue(Integer(input_bits, is_signed=False)),\n",
"        \"x_1\": EncryptedValue(Integer(input_bits, is_signed=False)),\n",
"    },\n",
"    iter(dataset),\n",
")"
]
},
{
"cell_type": "markdown",
"id": "f0b08a0f",
"metadata": {},
"source": [
"### Here is the textual representation of the operation graph"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "2cc4e11d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"%0 = ConstantInput(2) # Integer<unsigned, 2 bits>\n",
"%1 = ConstantInput(1) # Integer<unsigned, 1 bits>\n",
"%2 = x_0 # Integer<unsigned, 5 bits>\n",
"%3 = ConstantInput(6) # Integer<unsigned, 3 bits>\n",
"%4 = x_1 # Integer<unsigned, 4 bits>\n",
"%5 = ConstantInput(6) # Integer<unsigned, 3 bits>\n",
"%6 = Add(2, 3) # Integer<unsigned, 6 bits>\n",
"%7 = Add(4, 5) # Integer<unsigned, 5 bits>\n",
"%8 = Mul(6, 0) # Integer<unsigned, 7 bits>\n",
"%9 = Mul(7, 1) # Integer<unsigned, 5 bits>\n",
"%10 = Add(8, 9) # Integer<unsigned, 7 bits>\n",
"%11 = ArbitraryFunction(10) # Integer<unsigned, 7 bits>\n",
"return(%11)\n"
]
}
],
"source": [
"from hdk.common.debugging import get_printable_graph\n",
"print(get_printable_graph(homomorphic_model, show_data_types=True))"
]
},
{
"cell_type": "markdown",
"id": "ade14f17",
"metadata": {},
"source": [
"### Finally, it's time to run homomorphic inference\n",
"\n",
"Or, at least, simulate it until the compiler integration is complete."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "dd2d03d7",
"metadata": {},
"outputs": [],
"source": [
"homomorphic_predictions = []\n",
"for x_0, x_1 in map(lambda x_i: (int(x_i[0]), int(x_i[1])), x_q):\n",
"    evaluation = homomorphic_model.evaluate({0: x_0, 1: x_1})\n",
"    inference = QuantizedArray(evaluation[homomorphic_model.output_nodes[0]], y_q.parameters)\n",
"    homomorphic_predictions.append(inference.dequantize())\n",
"homomorphic_predictions = np.array(homomorphic_predictions, dtype=np.float32)"
]
},
{
"cell_type": "markdown",
"id": "443fbc03",
"metadata": {},
"source": [
"### And visualize it"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "57050b5d",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQaElEQVR4nO3df2xdZ33H8fd3jTtTx0siUrGkDssqkY2QUKAZ6QTavERbf6FW05hGt8FaDaXaSje0SkPjj1Ybf01oiK4IoqhUoRtrmaBibVXWoYQuGqyeQltwSqaqagOERgq0iuncdkrW7/4419Qxtu+xc67P9eP3S7rqvec8vueTp/Ynx889NzcyE0nS8vczbQeQJDXDQpekQljoklQIC12SCmGhS1IhVrV14KGhoVy3bl1bh1ehTp8+zYUXXsj555/fdhSpJx5//PEfZeaFs+1rrdDXrVvHzTff3NbhVajnnnuOG2+8kc2bN7cdReqJoaGh7861zyUXSSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBWi/EKf+ZmpfoaqpEJ1/dcWI2ITcDfwBiCBfZl5+4wxAdwOXAW8BFyfmY81H3eBHnkEXnkFLr8cIqoyf/hhGByE0dG200mNGR+HgwdhYgLWrIFdu2D79rZTla0f57zOGfoZ4JbM3ApcBtwUEVtnjLkSeFPntgf4TKMpFyOzKvOxsarEp8p8bKza7pm6CjE+Dg88AKdOVd/Wp05Vj8fH205Wrn6d865n6Jl5AjjRuf9iRBwFLgK+M23YtcDdmZnAoxGxNiI2dL62HRHVmTlUJT42Vt3fufO1M3apAAcPwunTZ287fbra3vYZY6n6dc4XtIYeEZuBtwNjM3ZdBHx/2uPjnW0zv35PRByOiMOTk5MLjLoI00t9imWuwkxMLGy7zl2/znntQo+I1cCXgA9n5o8Xc7DM3JeZOzJzx9DQ0GKeYqEHrJZZpptafpEKsWbNwrbr3PXrnNcq9IgYoCrzz2fmfbMM+QGwadrjkc629kxfM9+5E269tfrv9DV1qQC7dsHAwNnbBgaq7eqNfp3zOle5BPBZ4GhmfmKOYfcDH4qIe4GdwESr6+dQLasMDp69Zj61/DI46LKLijG1ZttvV1yUrF/nvM6HRL8LeD8wHhFPdLZ9FHgjQGbuBR6iumTxaarLFm9oPOlijI5WZ+JT5T1V6pa5CrN9e/tlstL045zXucrlP4B5G7BzdctNTYVq1MzytswlFar8d4pK0gphoUtSIeqsoUvLyssvv8yJE/Vek1+9ejXDw8M9TiQtDQtdRdm4cSMHDhyoPX7Xrl1s2rTJUlcRXHJRcTKz9u2ZZ55pO67UGAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJ0LfSIuCsiTkbEkTn2r4mIByLiWxHxZETc0HxMSVI3dc7Q9wNXzLP/JuA7mXkJMAr8XUScf+7RJEkL0bXQM/MQ8MJ8Q4DhiAhgdWfsmWbiSZLqWtXAc3wKuB94DhgGfi8zX51tYETsAfYArF27toFDS5KmNPGi6OXAE8BG4G3ApyLi52YbmJn7MnNHZu4YGhpq4NCSpClNFPoNwH1ZeRp4FvjlBp5XkrQATRT694DdABHxBuCXgGcaeF5J0gJ0XUOPiHuorl5ZHxHHgduAAYDM3At8DNgfEeNAAB/JzB/1LLEkaVZdCz0zr+uy/zngtxpLJElaFN8pKkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCNPGJRdKydezYMS6++OLa41etWsWWLVt6mEhaPAtdK1pmcuDAgVpjjxw5wi
233NLjRNLiueQiSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUiK6FHhF3RcTJiDgyz5jRiHgiIp6MiH9vNqIkqY46Z+j7gSvm2hkRa4FPA9dk5luA320kmSRpQboWemYeAl6YZ8jvA/dl5vc64082lE2StABNrKFvAdZFxCMR8c2I+MBcAyNiT0QcjojDk5OTDRxakjSliU8sWgVcCuwGXgf8Z0Q8mplPzRyYmfuAfQAjIyPZwLElSR1NFPpx4PnMnAQmI+IQcAnwU4UuSeqdJpZc/gV4d0SsiogLgJ3A0QaeV5K0AF3P0CPiHmAUWB8Rx4HbgAGAzNybmUcj4l+BbwOvAndm5pyXOEqSeqNroWfmdTXGfBz4eCOJJEmL4jtFJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgqxqu0A0nJy5swZnnrqqQV9zYYNGxgeHu5RIuk1FrpU07Zt27j99tsX9DXbt29n9+7dvPnNb+5RKuk1Frq0ANu2bVvQ+CeffJLdu3f3KI10NtfQJakQFrokFcJCl6RCdC30iLgrIk5GxJEu434lIs5ExHubiydJqqvOGfp+4Ir5BkTEecDfAv/WQCZJ0iJ0LfTMPAS80GXYzcCXgJNNhJIkLdw5r6FHxEXAbwOfqTF2T0QcjojDk5OT53poSdI0Tbwo+kngI5n5areBmbkvM3dk5o6hoaEGDi1JmtLEG4t2APdGBMB64KqIOJOZX27guSVJNZ1zoWfmL07dj4j9wIOWuSQtva6FHhH3AKPA+og4DtwGDABk5t6eppMk1da10DPzurpPlpnXn1MaSdKi+U5RSSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQXQs9Iu6KiJMRcWSO/X8QEd+OiPGI+EZEXNJ8TElSN3XO0PcDV8yz/1ng1zNzO/AxYF8DuSRJC7Sq24DMPBQRm+fZ/41pDx8FRhrIJUlaoK6FvkB/DHxlrp0RsQfYA7B27dqGDy31rxdffLH22OHh4R4mUckaK/SI+A2qQn/3XGMycx+dJZmRkZFs6thSv9q6dSsHDhxg8+bNtcZffPHFXHDBBbXHS9M1UugR8VbgTuDKzHy+ieeUSpGZPPvss7XGfv3rX+fGG2/scSKV6pwvW4yINwL3Ae/PzKfOPZIkaTG6nqFHxD3AKLA+Io4DtwEDAJm5F7gVeD3w6YgAOJOZO3oVWJI0uzpXuVzXZf8HgQ82lkiStCi+U1SSCmGhS1IhLHRJKoSFLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFcJCl6RCWOiSVAgLXZIKYaFLUiEsdEkqhIUuSYWw0CWpEBa6JBXCQpekQljoklQIC12SCmGhS1IhLHRJKkT5hZ45/2M1zzmXWrGq24CIuAt4D3AyM7fNsj+A24GrgJeA6zPzsaaDLsojj8Arr8Dll0NEVSwPPwyDgzA62na6MjnnWiHGx+HgQZiYgDVrYNcu2L693Ux1ztD3A1fMs/9K4E2d2x7gM+ceqwGZVbGMjVWFMlUsY2PVds8am+eca4UYH4cHHoBTp6pv61Onqsfj4+3m6nqGnpmHImLzPEOuBe7OzAQejYi1EbEhM080FXJRIqqzRKgKZWysur9z52tnj2qWc64V4uBBOH367G2nT1fb2zxLb2IN/SLg+9MeH+9s+ykRsSciDkfE4cnJyQYO3cX0gplisfSWc64VYG
JiYduXypK+KJqZ+zJzR2buGBoaWooDVr/yTze1FKDecM61AqxZs7DtS6XrkksNPwA2TXs80tnWrunrt1O/8k89Bs8ae8E51wqxa1e1Zj592WVgoNrepiYK/X7gQxFxL7ATmGh9/Ryq4hgcPHv9dmopYHDQYukF51wrxNQ6eb9d5VLnssV7gFFgfUQcB24DBgAycy/wENUli09TXbZ4Q6/CLtjoaHXWOFUkUwVjsfSOc64VYvv29gt8pjpXuVzXZX8CNzWWqGkzi8Ri6T3nXGpF+e8UlaQVwkKXpEJY6JJUiCaucpHUoJdeeomjR4/WGnveeeexZcuWHifScmGhS31k48aN3HHHHbXHX3311QwPD7Nhw4YeptJyYaFLfWbr1q21xx47doxLL720h2m0nLiGLkmFsNAlqRAWuiQVwkKXpEJY6JJUCAtdkgoR2dIHD0TED4HvLuEh1wM/WsLjNWm5Zl+uuWH5Zl+uuWH5Zl/q3L+QmRfOtqO1Ql9qEXE4M3e0nWMxlmv25Zoblm/25Zoblm/2fsrtkoskFcJCl6RCrKRC39d2gHOwXLMv19ywfLMv19ywfLP3Te4Vs4YuSaVbSWfoklQ0C12SClFUoUfEXRFxMiKOzLE/IuLvI+LpiPh2RLxjqTPOpUb20YiYiIgnOrdblzrjbCJiU0R8LSK+ExFPRsSfzzKm7+a9Zu5+nfPBiPiviPhWJ/tfzzLmZyPiC505H4uIzS1EnZmpTu7rI+KH0+b8g21knUtEnBcRj0fEg7Psa3/OM7OYG/BrwDuAI3Psvwr4ChDAZcBY25kXkH0UeLDtnLPk2gC8o3N/GHgK2Nrv814zd7/OeQCrO/cHgDHgshlj/hTY27n/PuALyyT39cCn2s46z5/hL4B/mu37oh/mvKgz9Mw8BLwwz5Brgbuz8iiwNiL64qNeamTvS5l5IjMf69x/ETgKXDRjWN/Ne83cfakzj//TeTjQuc28uuFa4HOd+18EdkdELFHEWdXM3bciYgS4GrhzjiGtz3lRhV7DRcD3pz0+zjL5Ie741c6vq1+JiLe0HWamzq+Yb6c685qur+d9ntzQp3Pe+dX/CeAk8NXMnHPOM/MMMAG8fklDzqJGboDf6SzNfTEiNi1twnl9EvhL4NU59rc+5yut0Jezx6j+DYdLgDuAL7cb52wRsRr4EvDhzPxx23nq6pK7b+c8M/8vM98GjADvjIhtLUeqpUbuB4DNmflW4Ku8dsbbqoh4D3AyM7/Zdpb5rLRC/wEw/W/8kc62vpeZP576dTUzHwIGImJ9y7EAiIgBqlL8fGbeN8uQvpz3brn7ec6nZOYp4GvAFTN2/WTOI2IVsAZ4fknDzWOu3Jn5fGb+b+fhnUC/fGDqu4BrIuIYcC+wKyL+ccaY1ud8pRX6/cAHOlddXAZMZOaJtkPVERE/P7UeFxHvpPp/1/oPaCfTZ4GjmfmJOYb13bzXyd3Hc35hRKzt3H8d8JvAf88Ydj/wR5377wUOZufVurbUyT3jtZVrqF7baF1m/lVmjmTmZqoXPA9m5h/OGNb6nK9ayoP1WkTcQ3VlwvqIOA7cRvXCC5m5F3iI6oqLp4GXgBvaSfrTamR/L/AnEXEGeBl4X9s/oB3vAt4PjHfWRgE+CrwR+nre6+Tu1znfAHwuIs6j+kvmnzPzwYj4G+BwZt5P9ZfVP0TE01Qvtr+vvbg/USf3n0XENcAZqtzXt5a2hn6bc9/6L0mFWGlLLpJULAtdkgphoUtSISx0SSqEhS5JhbDQJakQFrokFeL/Admi1qH7N00UAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Remove the previously drawn contour so it can be replaced\n",
"for collection in contour.collections:\n",
"    collection.remove()\n",
"\n",
"contour = ax.contourf(\n",
"    contour_plot_x_data,\n",
"    contour_plot_y_data,\n",
"    homomorphic_predictions.round().reshape(contour_plot_x_data.shape),\n",
"    cmap=\"gray\",\n",
"    alpha=0.50,\n",
")\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "53ecca94",
"metadata": {},
"source": [
"### Enjoy!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}