concrete/examples/QuantizedLinearRegression.ipynb
2021-08-18 08:08:19 +01:00

{
"cells": [
{
"cell_type": "markdown",
"id": "0fe629d6",
"metadata": {},
"source": [
"# Quantized Linear Regression\n",
"\n",
"Currently, **hdk** only supports unsigned integers of up to 7 bits. Nevertheless, we want to use it to evaluate a linear regression model, which works on floating-point numbers. Luckily, we can use **quantization** to overcome this limitation!"
]
},
{
"cell_type": "markdown",
"id": "d0cfb561",
"metadata": {},
"source": [
"### Let's start by importing some libraries to develop our linear regression model"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "3c1d929c",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "markdown",
"id": "69d25f7c",
"metadata": {},
"source": [
"### And some helpers for visualization"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a89c1a6c",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from IPython.display import display"
]
},
{
"cell_type": "markdown",
"id": "7729c1de",
"metadata": {},
"source": [
"### We need a dataset, a handcrafted one for simplicity"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b77a9e82",
"metadata": {},
"outputs": [],
"source": [
"x = np.array([[130], [110], [100], [145], [160], [185], [200], [80], [50]], dtype=np.float32)\n",
"y = np.array([325, 295, 268, 400, 420, 500, 520, 220, 120], dtype=np.float32)"
]
},
{
"cell_type": "markdown",
"id": "cc8673ff",
"metadata": {},
"source": [
"### Let's visualize our dataset to get a grasp of it"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "35a98d1a",
"metadata": {},
"outputs": [],
"source": [
"plt.ioff()\n",
"fig, ax = plt.subplots(1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "56703410",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ax.scatter(x[:, 0], y, marker=\"x\", color=\"red\")\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "e31b82e8",
"metadata": {},
"source": [
"### Now, we need a model, so let's define it\n",
"\n",
"The main purpose of this tutorial is not to train a linear regression model but to evaluate it homomorphically, so we will not discuss how the model is trained."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "cc5e72a2",
"metadata": {},
"outputs": [],
"source": [
"class Model:\n",
"    w = None\n",
"    b = None\n",
"\n",
"    def fit(self, x, y):\n",
"        # design matrix: a leading column of ones for the bias, followed by the features\n",
"        a = np.ones((x.shape[0], x.shape[1] + 1), dtype=np.float32)\n",
"        a[:, 1:] = x\n",
"\n",
"        # L2 (ridge) penalty of strength 1 on the weights; the bias (index 0) is not penalized\n",
"        regularization_contribution = np.identity(x.shape[1] + 1, dtype=np.float32)\n",
"        regularization_contribution[0][0] = 0\n",
"\n",
"        # closed-form solution of the regularized least-squares problem\n",
"        parameters = np.linalg.pinv(a.T @ a + regularization_contribution) @ a.T @ y\n",
"\n",
"        self.b = parameters[0]\n",
"        self.w = parameters[1:].reshape(-1, 1)\n",
"\n",
"        return self\n",
"\n",
"    def evaluate(self, x):\n",
"        return x @ self.w + self.b"
]
},
{
"cell_type": "markdown",
"id": "cefd8346",
"metadata": {},
"source": [
"### And create one"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b9879f4d",
"metadata": {},
"outputs": [],
"source": [
"model = Model().fit(x, y)"
]
},
{
"cell_type": "markdown",
"id": "01cfc83f",
"metadata": {},
"source": [
"### Time to make some predictions"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "78356d37",
"metadata": {},
"outputs": [],
"source": [
"inputs = np.linspace(40, 210, 100).reshape(-1, 1)\n",
"predictions = model.evaluate(inputs)"
]
},
{
"cell_type": "markdown",
"id": "58160140",
"metadata": {},
"source": [
"### Let's visualize our predictions to see how our model performs"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2a623999",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ax.plot(inputs, predictions, color=\"blue\")\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "d3f39faa",
"metadata": {},
"source": [
"### As a bonus, let's inspect the model parameters"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7fa65211",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[2.669915]]\n",
"-3.2335143\n"
]
}
],
"source": [
"print(model.w)\n",
"print(model.b)"
]
},
{
"cell_type": "markdown",
"id": "544d6e34",
"metadata": {},
"source": [
"They are floating-point numbers, and **hdk** can't work with them directly!"
]
},
{
"cell_type": "markdown",
"id": "abf310f2",
"metadata": {},
"source": [
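"### So, let's build an abstraction for quantization\n",
"\n",
"Here is a quick summary of quantization. We have a range of values and we want to represent them using a small number of bits (n). To do this, we split the range into 2^n sections and map each section to an integer from 0 to 2^n - 1. Here is a visualization of the process for n = 2: the zero point (zp = 2) is the number of steps between 0 and min(x), the distance between consecutive values is 1 / scale = 1 / q, and a quantized value x_q is mapped back to the original range with x = (x_q + zp_x) / q_x."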
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a7b3b993",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<IPython.core.display.SVG object>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import SVG\n",
"SVG(filename=\"figures/QuantizationVisualized.svg\")"
]
},
{
"cell_type": "markdown",
"id": "9cbd7e1d",
"metadata": {},
"source": [
"If you want to learn more, head to https://intellabs.github.io/distiller/algo_quantization.html"
]
},
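{
"cell_type": "markdown",
"id": "3e7a1c52",
"metadata": {},
"source": [
"Before wrapping these formulas in classes, here is a minimal standalone sketch of the same min/max scheme applied to a few made-up values with n = 2. The helpers `quantize` and `dequantize` and the sample values are hypothetical and only serve as an illustration; the formulas follow the comments in the next cell: q = (2^n - 1) / (max - min), zp = round(min * q), x_q = round(q * x - zp) and x = (x_q + zp) / q.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def quantize(values, n):\n",
"    # hypothetical helper: min/max quantization of an array to n-bit unsigned integers\n",
"    values = np.asarray(values, dtype=np.float64)\n",
"    q = (2**n - 1) / (values.max() - values.min())  # scale = 1 / distance between consecutive values\n",
"    zp = int(round(values.min() * q))  # zero point: min(values) is approximately zp / q\n",
"    values_q = np.clip(np.round(q * values - zp), 0, 2**n - 1).astype(np.int64)\n",
"    return values_q, q, zp\n",
"\n",
"def dequantize(values_q, q, zp):\n",
"    # hypothetical helper: x = (x_q + zp) / q\n",
"    return (values_q + zp) / q\n",
"\n",
"sample = [3.0, 3.8, 5.2, 7.5]\n",
"sample_q, q, zp = quantize(sample, n=2)\n",
"recovered = dequantize(sample_q, q, zp)\n",
"print(sample_q)  # [0 1 1 3]\n",
"print(recovered)  # close to [3.0, 4.5, 4.5, 7.5]\n",
"```\n",
"\n",
"Both 3.8 and 5.2 land on 1: with only 2^2 = 4 levels, each value moves to the nearest level, so the reconstruction error stays within half the distance between consecutive values (here 1.5 / 2 = 0.75)."
]
},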
{
"cell_type": "code",
"execution_count": 12,
"id": "a8bab855",
"metadata": {},
"outputs": [],
"source": [
"class QuantizationParameters:\n",
"    def __init__(self, q, zp, n):\n",
"        # q = scale factor = 1 / distance between consecutive values\n",
"        # zp = zero point which is used to determine the beginning of the quantized range\n",
"        # (quantized 0 = the beginning of the quantized range = zp * distance between consecutive values)\n",
"        # n = number of bits\n",
"\n",
"        # e.g.,\n",
"\n",
"        # n = 2\n",
"        # zp = 2\n",
"        # q = 0.66\n",
"        # distance between consecutive values = 1 / q = 1.5151\n",
"\n",
"        # quantized 0 = zp / q = zp * distance between consecutive values = 3.0303\n",
"        # quantized 1 = quantized 0 + distance between consecutive values = 4.5454\n",
"        # quantized 2 = quantized 1 + distance between consecutive values = 6.0606\n",
"        # quantized 3 = quantized 2 + distance between consecutive values = 7.5757\n",
"\n",
"        self.q = q\n",
"        self.zp = zp\n",
"        self.n = n\n",
"\n",
"class QuantizedArray:\n",
"    def __init__(self, values, parameters):\n",
"        # values = quantized values\n",
"        # parameters = parameters used during quantization\n",
"\n",
"        # e.g.,\n",
"\n",
"        # values = [1, 0, 2, 1]\n",
"        # parameters = QuantizationParameters(q=0.66, zp=2, n=2)\n",
"\n",
"        # original array = [4.5454, 3.0303, 6.0606, 4.5454]\n",
"\n",
"        self.values = np.array(values)\n",
"        self.parameters = parameters\n",
"\n",
"    @staticmethod\n",
"    def of(x, n):\n",
"        if not isinstance(x, np.ndarray):\n",
"            x = np.array(x)\n",
"\n",
"        min_x = x.min()\n",
"        max_x = x.max()\n",
"\n",
"        if min_x == max_x:  # encoding single valued arrays\n",
"\n",
"            if min_x == 0.0:  # encoding 0s\n",
"\n",
"                # dequantization = (x_q + zp_x) / q_x = 0 --> q_x = 1 && zp_x = 0 && x_q = 0\n",
"                q_x = 1\n",
"                zp_x = 0\n",
"                x_q = np.zeros(x.shape, dtype=np.uint)\n",
"\n",
"            elif min_x < 0.0:  # encoding negative scalars\n",
"\n",
"                # dequantization = (x_q + zp_x) / q_x = x (x < 0) --> q_x = 1 / |x| & zp_x = -1 & x_q = 0\n",
"                q_x = abs(1 / min_x)\n",
"                zp_x = -1\n",
"                x_q = np.zeros(x.shape, dtype=np.uint)\n",
"\n",
"            else:  # encoding positive scalars\n",
"\n",
"                # dequantization = (x_q + zp_x) / q_x = x --> q_x = 1 / x & zp_x = 0 & x_q = 1\n",
"                q_x = 1 / min_x\n",
"                zp_x = 0\n",
"                x_q = np.ones(x.shape, dtype=np.uint)\n",
"\n",
"        else:  # encoding multi valued arrays\n",
"\n",
"            # distance between consecutive values = range of x / (number of distinct quantized values - 1) = (max_x - min_x) / (2^n - 1)\n",
"            # q = 1 / distance between consecutive values\n",
"            q_x = (2**n - 1) / (max_x - min_x)\n",
"\n",
"            # zp = what should be added to 0 to get min_x -> min_x = (0 + zp) / q -> zp = min_x * q\n",
"            zp_x = int(round(min_x * q_x))\n",
"\n",
"            # x = (x_q + zp) / q -> x_q = (x * q) - zp\n",
"            x_q = ((q_x * x) - zp_x).round().astype(np.uint)\n",
"\n",
"        return QuantizedArray(x_q, QuantizationParameters(q_x, zp_x, n))\n",
"\n",
"    def dequantize(self):\n",
"        # x = (x_q + zp) / q\n",
"        return (self.values.astype(np.float32) + float(self.parameters.zp)) / self.parameters.q\n",
"\n",
"    def affine(self, w, b, min_y, max_y, n_y):\n",
"        # the formulas used in this method were derived from the following equations\n",
"        #\n",
"        # x = (x_q + zp_x) / q_x\n",
"        # w = (w_q + zp_w) / q_w\n",
"        # b = (b_q + zp_b) / q_b\n",
"        #\n",
"        # (x * w) + b = ((x_q + zp_x) / q_x) * ((w_q + zp_w) / q_w) + ((b_q + zp_b) / q_b)\n",
"        #             = y = (y_q + zp_y) / q_y\n",
"        #\n",
"        # So, ((x_q + zp_x) / q_x) * ((w_q + zp_w) / q_w) + ((b_q + zp_b) / q_b) = (y_q + zp_y) / q_y\n",
"        # We can calculate zp_y and q_y from min_y, max_y, n_y. So, the only unknown is y_q and it can be solved.\n",
"\n",
"        # signed integers keep additions with negative zero points (e.g. zp_b = -1) well defined\n",
"        x_q = self.values.astype(np.int64)\n",
"        w_q = w.values.astype(np.int64)\n",
"        b_q = b.values.astype(np.int64)\n",
"\n",
"        q_x = self.parameters.q\n",
"        q_w = w.parameters.q\n",
"        q_b = b.parameters.q\n",
"\n",
"        zp_x = self.parameters.zp\n",
"        zp_w = w.parameters.zp\n",
"        zp_b = b.parameters.zp\n",
"\n",
"        q_y = (2**n_y - 1) / (max_y - min_y)\n",
"        zp_y = int(round(min_y * q_y))\n",
"\n",
"        # q_y * y, with y = (x * w) + b expressed through the quantized operands\n",
"        y_q = (q_y / (q_x * q_w)) * ((x_q + zp_x) @ (w_q + zp_w) + (q_x * q_w / q_b) * (b_q + zp_b))\n",
"        # (y_q + zp_y) / q_y = y -> y_q = (q_y * y) - zp_y\n",
"        y_q -= zp_y\n",
"        y_q = y_q.round().clip(0, 2**n_y - 1).astype(np.uint)\n",
"\n",
"        return QuantizedArray(y_q, QuantizationParameters(q_y, zp_y, n_y))\n",
"\n",
"class QuantizedFunction:\n",
"    def __init__(self, table):\n",
"        self.table = table\n",
"\n",
"    @staticmethod\n",
"    def of(f, input_bits, output_bits):\n",
"        # lookup table: f evaluated on every possible input_bits-bit value, clipped to the output range\n",
"        domain = np.array(range(2**input_bits), dtype=np.uint)\n",
"        table = f(domain).round().clip(0, 2**output_bits - 1).astype(np.uint)\n",
"        return QuantizedFunction(table)"
]
},
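{
"cell_type": "markdown",
"id": "roundtrip-sketch-md",
"metadata": {},
"source": [
"### A quick check of the quantization round trip\n",
"\n",
"Before applying `QuantizedArray` to the model, it can help to see the multi-valued formulas in isolation. The next cell is a minimal standalone sketch: `quantize_demo` and `dequantize_demo` are hypothetical helpers that copy the multi-valued branch of `QuantizedArray.of` and its inverse, applied here to the handcrafted `x` values with 6 bits. Since consecutive levels are `1 / q` apart, each dequantized value should land within about half a step (`0.5 / q`) of the original, plus any error introduced by rounding the zero point."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "roundtrip-sketch",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def quantize_demo(values, n):\n",
"    # Hypothetical standalone copy of the multi-valued branch of QuantizedArray.of\n",
"    # (assumes values.max() > values.min())\n",
"    lo, hi = float(values.min()), float(values.max())\n",
"    scale = (2**n - 1) / (hi - lo)        # 1 / distance between consecutive levels\n",
"    zero_point = int(round(lo * scale))   # (0 + zero_point) / scale ~= lo\n",
"    levels = ((scale * values) - zero_point).round().clip(0, 2**n - 1).astype(np.uint)\n",
"    return levels, scale, zero_point\n",
"\n",
"def dequantize_demo(levels, scale, zero_point):\n",
"    # Same inverse mapping as QuantizedArray.dequantize: (x_q + zp) / q\n",
"    return (levels.astype(np.float64) + zero_point) / scale\n",
"\n",
"# The handcrafted x values from the dataset, quantized to 6 bits\n",
"demo_values = np.array([130, 110, 100, 145, 160, 185, 200, 80, 50], dtype=np.float64)\n",
"demo_levels, demo_scale, demo_zp = quantize_demo(demo_values, 6)\n",
"demo_error = np.abs(dequantize_demo(demo_levels, demo_scale, demo_zp) - demo_values).max()\n",
"print(demo_levels, demo_error, 0.5 / demo_scale)"
]
},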
{
"cell_type": "markdown",
"id": "e5be0800",
"metadata": {},
"source": [
"### Let's quantize our model parameters\n",
"\n",
"Since the parameters only consist of scalars, we can use a single bit quantization."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "3ec0ad9b",
"metadata": {},
"outputs": [],
"source": [
"parameter_bits = 1\n",
"\n",
"w_q = QuantizedArray.of(model.w, parameter_bits)\n",
"b_q = QuantizedArray.of(model.b, parameter_bits)"
]
},
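{
"cell_type": "markdown",
"id": "scalar-sketch-md",
"metadata": {},
"source": [
"### Why one bit is enough for a scalar\n",
"\n",
"The next cell copies only the single-valued branch of `QuantizedArray.of` into a hypothetical standalone helper, `quantize_scalar_demo`, to show that a scalar survives 1-bit quantization exactly (up to floating-point error): the stored value is `0` or `1`, and the scalar is carried by the scale `q` and, for negative values, by the zero point `zp = -1`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "scalar-sketch",
"metadata": {},
"outputs": [],
"source": [
"def quantize_scalar_demo(value):\n",
"    # Hypothetical standalone copy of the single-valued branch of QuantizedArray.of\n",
"    # Returns (x_q, q, zp) such that (x_q + zp) / q ~= value\n",
"    if value == 0.0:\n",
"        return 0, 1.0, 0               # (0 + 0) / 1 = 0\n",
"    if value < 0.0:\n",
"        return 0, abs(1 / value), -1   # (0 - 1) / (1 / |value|) = value\n",
"    return 1, 1 / value, 0             # (1 + 0) / (1 / value) = value\n",
"\n",
"def dequantize_scalar_demo(x_q, q, zp):\n",
"    # Same inverse mapping as QuantizedArray.dequantize: (x_q + zp) / q\n",
"    return (x_q + zp) / q\n",
"\n",
"for demo_value in [2.5, -0.75, 0.0]:\n",
"    demo_encoded = quantize_scalar_demo(demo_value)\n",
"    print(demo_value, demo_encoded, dequantize_scalar_demo(*demo_encoded))"
]
},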
{
"cell_type": "markdown",
"id": "b43c0371",
"metadata": {},
"source": [
"### And quantize our inputs"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "20cea447",
"metadata": {},
"outputs": [],
"source": [
"input_bits = 6\n",
"\n",
"x_q = QuantizedArray.of(inputs, input_bits)"
]
},
{
"cell_type": "markdown",
"id": "ca76b68d",
"metadata": {},
"source": [
"### Time to make quantized inference"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8728e939",
"metadata": {},
"outputs": [],
"source": [
"output_bits = 7\n",
"\n",
"min_y = predictions.min()\n",
"max_y = predictions.max()\n",
"y_q = x_q.affine(w_q, b_q, min_y, max_y, output_bits)\n",
"\n",
"quantized_predictions = y_q.dequantize()"
]
},
{
"cell_type": "markdown",
"id": "ab782b4a",
"metadata": {},
"source": [
"### And visualize the results"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "9d2bb5da",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ax.plot(inputs, quantized_predictions, color=\"black\")\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "4834cdfc",
"metadata": {},
"source": [
"### Now it's time to make the inference homomorphic"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "fcf4ea26",
"metadata": {},
"outputs": [],
"source": [
"q_y = (2**output_bits - 1) / (max_y - min_y)\n",
"zp_y = int(round(min_y * q_y))\n",
"\n",
"q_x = x_q.parameters.q\n",
"q_w = w_q.parameters.q\n",
"q_b = b_q.parameters.q\n",
"\n",
"zp_x = x_q.parameters.zp\n",
"zp_w = w_q.parameters.zp\n",
"zp_b = b_q.parameters.zp\n",
"\n",
"x_q = x_q.values\n",
"w_q = w_q.values\n",
"b_q = b_q.values"
]
},
{
"cell_type": "markdown",
"id": "43e47369",
"metadata": {},
"source": [
"### Simplification to rescue!\n",
"\n",
"The `y_q` formula in `QuantizedArray.affine(...)` can be rewritten to make it easier to implement in homomorphically. Here is the breakdown.\n",
"```\n",
"(q_y / (q_x * q_w)) * ((x_q + zp_x) @ (w_q + zp_w) + (q_x * q_w / q_b) * (b_q + zp_b)) - (min_y * q_y)\n",
"^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^ ^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^\n",
"constant (c1) can be done constant (c2) constant (c3) constant (c4)\n",
" on the circuit \n",
" \n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" can be done on the circuit\n",
" \n",
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"cannot be done on the circuit because of floating point operation so will be a single table lookup\n",
"```"
]
},
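{
"cell_type": "markdown",
"id": "split-check-sketch-md",
"metadata": {},
"source": [
"### Checking the breakdown numerically\n",
"\n",
"The rewrite is pure algebra, so it can be checked without the compiler. The next cell is a sketch with made-up quantization parameters (none of them come from the trained model): `check_affine_split` is a hypothetical helper that computes `(x @ w + b) * q_y - min_y * q_y` from dequantized floats and compares it with the split form `c1 * ((x_q + zp_x) * w_0 + c3) - c4` for every 6-bit input. Only the integer product `(x_q + zp_x) * w_0` would run on the circuit; everything else is folded into constants. All names are local to the function so the notebook's own `q_x`, `c1`, etc. are left untouched."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "split-check-sketch",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def check_affine_split():\n",
"    # Made-up quantization parameters, for illustration only\n",
"    q_x, zp_x = 0.42, 21\n",
"    q_w, zp_w = 0.5, 0\n",
"    q_b, zp_b = 0.25, -1\n",
"    q_y, min_y = 0.3, 40.0\n",
"\n",
"    x_q = np.arange(64).reshape(-1, 1)   # every 6-bit input as a column\n",
"    w_q = np.array([[1]])                # 1-bit weight\n",
"    b_q = np.array([0])                  # 1-bit bias\n",
"\n",
"    # Reference: dequantize, compute y = x @ w + b in floats, rescale onto y's grid\n",
"    x = (x_q + zp_x) / q_x\n",
"    w = (w_q + zp_w) / q_w\n",
"    b = (b_q + zp_b) / q_b\n",
"    reference = (x @ w + b) * q_y - min_y * q_y\n",
"\n",
"    # Split form: c1..c4 folded offline, only the integer product stays on the circuit\n",
"    c1 = q_y / (q_x * q_w)\n",
"    c2 = w_q + zp_w\n",
"    c3 = (q_x * q_w / q_b) * (b_q + zp_b)\n",
"    c4 = min_y * q_y\n",
"    w_0 = int(c2.flatten()[0])\n",
"    intermediate = (x_q + zp_x) * w_0    # integers only\n",
"    split = c1 * (intermediate + c3) - c4\n",
"\n",
"    return reference, split\n",
"\n",
"demo_reference, demo_split = check_affine_split()\n",
"print(np.allclose(demo_reference, demo_split))"
]
},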
{
"cell_type": "code",
"execution_count": 18,
"id": "2de0cf20",
"metadata": {},
"outputs": [],
"source": [
"c1 = q_y / (q_x * q_w)\n",
"c2 = w_q + zp_w\n",
"c3 = (q_x * q_w / q_b) * (b_q + zp_b)\n",
"c4 = min_y * q_y\n",
"\n",
"f = lambda intermediate: (c1 * (intermediate + c3)) - c4\n",
"f_q = QuantizedFunction.of(f, input_bits + parameter_bits, output_bits)\n",
"\n",
"from hdk.common.extensions.table import LookupTable\n",
"table = LookupTable([int(entry) for entry in f_q.table])\n",
"\n",
"w_0 = int(c2.flatten()[0])\n",
"\n",
"def infer(x_0):\n",
" return table[(x_0 + zp_x) * w_0]"
]
},
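{
"cell_type": "markdown",
"id": "table-sketch-md",
"metadata": {},
"source": [
"### Why a table lookup can replace the floating-point part\n",
"\n",
"The index passed to `table[...]` is a small unsigned integer, so any function of it, even a floating-point one, can be evaluated ahead of time for every possible input. The next cell sketches this idea with a hypothetical `tabulate_demo` helper that mirrors `QuantizedFunction.of`, and a made-up function `demo_g` standing in for `c1 * (intermediate + c3) - c4`: indexing the precomputed table gives the same result as evaluating, rounding, and clipping directly."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "table-sketch",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def tabulate_demo(g, in_bits, out_bits):\n",
"    # Hypothetical standalone copy of QuantizedFunction.of: evaluate g on every\n",
"    # possible input, then round and clip into the unsigned output range\n",
"    domain = np.arange(2**in_bits)\n",
"    return g(domain).round().clip(0, 2**out_bits - 1).astype(np.uint)\n",
"\n",
"def demo_g(v):\n",
"    # Made-up float-valued function standing in for c1 * (v + c3) - c4\n",
"    return 0.8 * (v + 2.5) - 10.0\n",
"\n",
"demo_table = tabulate_demo(demo_g, 7, 7)\n",
"demo_points = np.array([0, 13, 77, 127])\n",
"demo_direct = np.clip(np.round(demo_g(demo_points)), 0, 127).astype(np.uint)\n",
"print(demo_table[demo_points], demo_direct)"
]
},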
{
"cell_type": "markdown",
"id": "93eb9499",
"metadata": {},
"source": [
"### Time to compile our quantized inference function"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "a80895fd",
"metadata": {},
"outputs": [],
"source": [
"from hdk.common.data_types.integers import Integer\n",
"from hdk.common.data_types.values import EncryptedValue\n",
"from hdk.hnumpy.compile import compile_numpy_function_into_op_graph\n",
"\n",
"dataset = []\n",
"for x_i in x_q:\n",
" dataset.append((int(x_i[0]),))\n",
"\n",
"homomorphic_model = compile_numpy_function_into_op_graph(\n",
" infer,\n",
" {\"x_0\": EncryptedValue(Integer(input_bits, is_signed=False))},\n",
" iter(dataset),\n",
")"
]
},
{
"cell_type": "markdown",
"id": "f0b08a0f",
"metadata": {},
"source": [
"### Here is the textual representation of the operation graph"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "2cc4e11d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"%0 = ConstantInput(1) # Integer<unsigned, 1 bits>\n",
"%1 = x_0 # Integer<unsigned, 6 bits>\n",
"%2 = ConstantInput(15) # Integer<unsigned, 4 bits>\n",
"%3 = Add(1, 2) # Integer<unsigned, 7 bits>\n",
"%4 = Mul(3, 0) # Integer<unsigned, 7 bits>\n",
"%5 = ArbitraryFunction(4) # Integer<unsigned, 7 bits>\n",
"return(%5)\n"
]
}
],
"source": [
"from hdk.common.debugging import get_printable_graph\n",
"print(get_printable_graph(homomorphic_model, show_data_types=True))"
]
},
{
"cell_type": "markdown",
"id": "ade14f17",
"metadata": {},
"source": [
"### Finally, it's time to make homomorphic inference\n",
"\n",
"Or, at least, simulate it until the compiler integration is complete."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "dd2d03d7",
"metadata": {},
"outputs": [],
"source": [
"homomorphic_predictions = []\n",
"for x_i in map(lambda x_i: int(x_i[0]), x_q):\n",
" evaluation = homomorphic_model.evaluate({0: x_i})\n",
" inference = QuantizedArray(evaluation[homomorphic_model.output_nodes[0]], y_q.parameters)\n",
" homomorphic_predictions.append(inference.dequantize())\n",
"homomorphic_predictions = np.array(homomorphic_predictions, dtype=np.float32)"
]
},
{
"cell_type": "markdown",
"id": "443fbc03",
"metadata": {},
"source": [
"### And visualize it"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "57050b5d",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ax.plot(inputs, homomorphic_predictions, color=\"green\")\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "53ecca94",
"metadata": {},
"source": [
"### Enjoy!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}