concrete/docs/user/advanced_examples/QuantizedGeneralizedLinearModel.ipynb
Benoit Chevallier-Mames 2f1e41e4fb docs: updating the doc a bit
refs #1050
2021-12-03 15:25:14 +01:00

{
"cells": [
{
"cell_type": "markdown",
"id": "b760a0f6",
"metadata": {},
"source": [
"# FIXME(Andrei): To be done with 979\n",
"\n",
"FIXME(Andrei): to be done with 979!"
]
},
{
"cell_type": "markdown",
"id": "253288cf",
"metadata": {},
"source": [
"### Let's start by importing some libraries to develop our linear regression model"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6200ab62",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "markdown",
"id": "f43e2387",
"metadata": {},
"source": [
"### And some helpers for visualization"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d104c8df",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from IPython.display import display"
]
},
{
"cell_type": "markdown",
"id": "53e676b8",
"metadata": {},
"source": [
"### We need an inputset, a handcrafted one for simplicity"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d451e829",
"metadata": {},
"outputs": [],
"source": [
"x = np.array([[69], [130], [110], [100], [145], [160], [185], [200], [80], [50]], dtype=np.float32)\n",
"y = np.array([181, 325, 295, 268, 400, 420, 500, 520, 220, 120], dtype=np.float32)"
]
},
{
"cell_type": "markdown",
"id": "75f4fdb7",
"metadata": {},
"source": [
"### Let's visualize our inputset to get a grasp of it"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2a124a62",
"metadata": {},
"outputs": [],
"source": [
"plt.ioff()\n",
"fig, ax = plt.subplots(1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "edcd361b",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ax.scatter(x[:, 0], y, marker=\"x\", color=\"red\")\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "5c8310ab",
"metadata": {},
"source": [
"### Now, we need a model so let's define it\n",
"\n",
"The main purpose of this tutorial is not to train a linear regression model but to use it homomorphically, so we will not discuss how the model is trained."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "91d4a1da",
"metadata": {},
"outputs": [],
"source": [
"class Model:\n",
"    w = None\n",
"    b = None\n",
"\n",
"    def fit(self, x, y):\n",
"        # design matrix: a leading column of ones for the bias, then the features\n",
"        a = np.ones((x.shape[0], x.shape[1] + 1), dtype=np.float32)\n",
"        a[:, 1:] = x\n",
"\n",
"        # ridge penalty (strength 1) on the weights only; the bias is not regularized\n",
"        regularization_contribution = np.identity(x.shape[1] + 1, dtype=np.float32)\n",
"        regularization_contribution[0][0] = 0\n",
"\n",
"        # closed-form solution of the regularized least-squares problem\n",
"        parameters = np.linalg.pinv(a.T @ a + regularization_contribution) @ a.T @ y\n",
"\n",
"        self.b = parameters[0]\n",
"        self.w = parameters[1:].reshape(-1, 1)\n",
"\n",
"        return self\n",
"\n",
"    def evaluate(self, x):\n",
"        return x @ self.w + self.b"
]
},
{
"cell_type": "markdown",
"id": "faa5247c",
"metadata": {},
"source": [
"### And create one"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "682fb2d8",
"metadata": {},
"outputs": [],
"source": [
"model = Model().fit(x, y)"
]
},
{
"cell_type": "markdown",
"id": "084fb296",
"metadata": {},
"source": [
"### Time to make some predictions"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "4953b03e",
"metadata": {},
"outputs": [],
"source": [
"inputs = np.linspace(40, 210, 100).reshape(-1, 1)\n",
"predictions = model.evaluate(inputs)"
]
},
{
"cell_type": "markdown",
"id": "f28155cf",
"metadata": {},
"source": [
"### Let's visualize our predictions to see how our model performs"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "111574ed",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ax.plot(inputs, predictions, color=\"blue\")\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "23852861",
"metadata": {},
"source": [
"### As a bonus, let's inspect the model parameters"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7877cb2e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[2.6698928]]\n",
"-3.2299957\n"
]
}
],
"source": [
"print(model.w)\n",
"print(model.b)"
]
},
{
"cell_type": "markdown",
"id": "de63118c",
"metadata": {},
"source": [
"They are floating-point numbers, and we can't work with them directly: the homomorphic computation works on integers, so we first need to quantize them."
]
},
{
"cell_type": "markdown",
"id": "2d959640",
"metadata": {},
"source": [
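"### So, let's build an abstraction for quantization\n",
"\n",
"Here is a quick summary of quantization. We have a range of values, from min(x) to max(x), and we want to represent them using a small number of bits (n). To do this, we place 2^n evenly spaced levels across the range, numbered 0 to 2^n - 1, so that consecutive levels are (max(x) - min(x)) / (2^n - 1) apart (this distance is 1 / q, where q is the scale), and we map each value to its nearest level. Here is a visualization of the process!"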
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "9da2e1a4",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": "<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" version=\"1.1\" width=\"420px\" height=\"195px\" viewBox=\"-0.5 -0.5 420 195\" content=\"&lt;mxfile host=&quot;app.diagrams.net&quot; modified=&quot;2021-08-13T09:47:25.144Z&quot; agent=&quot;5.0 (X11)&quot; etag=&quot;5QhM0DGu1eUjmjeXuyFL&quot; version=&quot;14.9.6&quot; type=&quot;device&quot;&gt;&lt;diagram id=&quot;6rZNNX4_K12e_kCXuZoG&quot; name=&quot;Page-1&quot;&gt;7Zzdb5s6FMD/mkjdw66MHZL0sUl376SrStM6rc8euAGNYAZOk/avnw2YD5t8QAmhJA+tyLFzbM7v2D7HdjtCi9X2vxAHzgO1iTeCwN6O0P0IQhPd8t9C8JoI0MRMBMvQtRORkQse3TeSCkEqXbs2iUoVGaUec4Oy0KK+TyxWkuEwpJtytWfqlVsN8JJogkcLe7r0ybWZk0hnJsjlX4m7dGTLBkhLVlhWTgWRg226KYjQlxFahJSy5Gm1XRBP2E7aJfnevztKs46FxGfHfOFpYrx9//Hw//Lpq8mc+yd4H5mfUy0v2FunL5x2lr1KCxDfvhOG5J8sD0eRa43QPGI4ZLrYYSuPCwz+mOghtmbevL9GZgXuPYSuCAtfeZVNbmdpZqdgYikLiYeZ+1JWj1Pcy0xd1sI36vKGIUg905ikelLHhGNQVhHRdWiR9FtFuyqKTHhAETfVkjBNEX8ovHYuirHVQAhrIfSpTz4WKAgU+5oNQWmKULeg0NBBjdsCpSrqGNR46KBmbYFSFXUMyhw4KKQuLU1BaYo6BjUZOig1mGgMSlXUMajpwEGN2womNEUdg5oNHVRbwYSmqGNQt0MH1VYwoSnqGJTccaidC38gWFoKBKfNYGmRn6poByxuPfxaqBaICtGeDquBC5ju75e6fpbr84ekB+16DmzoOR9/F0Xzg+ltM4dCqmeqik49+utl50NieND0RzNUl9quGdZL3AfF8JDpj2Z4aECfmmG9nH5QDNuaS7XcpGuGVen+xOPWmtvuC39ciseV699sP8kC3lChTEPOyJaVaUYspL/Jgno0zOPmZ9fzFBH23KUvXIJDJlw+fyEhcy3s3aUFK9e2RTPzjeMy8hhgS7S5CXHAZSFd+zYRLwuybgkFZFvXh3YFKNJ5Cj42rvAxCHa7U4lfbVh6yr/C25jMUBloA+3sDBpl8zaOnNgsRpmKkH/DjFvajyUQoIyVPMOEH2A+hLdlTAiMm82HqiI4UxSdej6s2gMozIcFzJM/a3EePH+mPvscxafhd7yCAYJtTEyW5xNlokfUf5eiBz7U8qk4UVduQp2he/YCjI7i7DZR9ytspVdHmqKXE6XoeHqfwmhr8ZqqI3J23olTKr6OrFOPLOM6sk45stQc/Pwjq97NoAsJSdRz1MYhiXYg23FIImle8e47fW2OV83yusbb6ArL4PG2lVCoijrHW+8+xZA2ybTjXDj+ZzJphlE7cazQdWqSVRcuGsRtWpx173LWPo9hCmFbUjInbEOIrxcsqB8Rax3D0Qp/ig5GO6O3XsZpo/fHZaq7GdPcRc4VmqGqtbtnOcMwkh5Y8x16OQpOka2Me5atoKrT25551DCGBLoOieoh0betMaSnADcbR6z7QPzEEgD7espzis1LJdpD8n5ZgZABO0Wkn3WDy+Fh9O0UDh11bF17OoVV4Xky/EQqLfrM1cR/cdh4eThl1/5cZ/zq1GByVGpgGJ36cAv5ZKVXbOWasdcX3ul5ENwIMV4JgP6vKEgKlc/cduL3
W1BR9KnotmBnby/FR9U1L5tziw7a7SRbddXh/TNZ5R7IdebaEavKCwmZV5hnXnkPXJDIpqAL52Ya/eImG7ty289t2rPxNq463NW4vZFQZPxBsvetbggfHZWKdVqmm5fuCerf3xjy4KjgCbN2PIF/zP+3SHJckf+DFvTlLw==&lt;/diagram&gt;&lt;/mxfile&gt;\"><defs/><g><path d=\"M 14.37 84 L 361.63 84\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 9.12 84 L 16.12 80.5 L 14.37 84 L 16.12 87.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 366.88 84 L 359.88 87.5 L 361.63 84 L 359.88 80.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 48 94 L 48 74\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 88 94 L 88 74\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 128 94 L 128 74\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 168 94 L 168 74\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 208 94 L 208 74\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 248 94 L 248 74\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 288 94 L 288 74\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 328 94 L 328 74\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 48 71 L 60.93 58.07 Q 68 51 78 51 L 98 51 Q 108 51 115.07 58.07 L 123.5 66.5\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 127.21 70.21 L 119.78 67.73 L 123.5 66.5 L 124.73 62.78 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 134.37 123 L 141.63 123\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" 
pointer-events=\"stroke\"/><path d=\"M 129.12 123 L 136.12 119.5 L 134.37 123 L 136.12 126.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 146.88 123 L 139.88 126.5 L 141.63 123 L 139.88 119.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 154.37 123 L 181.63 123\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 149.12 123 L 156.12 119.5 L 154.37 123 L 156.12 126.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 186.88 123 L 179.88 126.5 L 181.63 123 L 179.88 119.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 194.37 123 L 221.63 123\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 189.12 123 L 196.12 119.5 L 194.37 123 L 196.12 126.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 226.88 123 L 219.88 126.5 L 221.63 123 L 219.88 119.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 234.37 123 L 241.63 123\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 229.12 123 L 236.12 119.5 L 234.37 123 L 236.12 126.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 246.88 123 L 239.88 126.5 L 241.63 123 L 239.88 119.5 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><rect x=\"108\" y=\"94\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div 
xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 104px; margin-left: 109px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div>min(x)</div></div></div></div></foreignObject><text x=\"128\" y=\"108\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"12px\" text-anchor=\"middle\">min(x)</text></switch></g><rect x=\"228\" y=\"94\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 104px; margin-left: 229px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \">max(x)</div></div></div></foreignObject><text x=\"248\" y=\"108\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"12px\" text-anchor=\"middle\">max(x)</text></switch></g><path d=\"M 138 148 L 138 128\" fill=\"none\" stroke=\"#000000\" stroke-width=\"2\" stroke-miterlimit=\"10\" stroke-dasharray=\"2 6\" pointer-events=\"stroke\"/><rect x=\"118\" y=\"152\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" 
width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 162px; margin-left: 119px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">Map</font></div><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">to 0<br style=\"font-size: 10px\"/></font></div></div></div></div></foreignObject><text x=\"138\" y=\"165\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">Map...</text></switch></g><rect x=\"148\" y=\"152\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 162px; margin-left: 149px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">Map</font></div><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">to 1<br style=\"font-size: 10px\"/></font></div></div></div></div></foreignObject><text x=\"168\" y=\"165\" fill=\"#000000\" font-family=\"Helvetica\" 
font-size=\"10px\" text-anchor=\"middle\">Map...</text></switch></g><path d=\"M 168 148 L 168 128\" fill=\"none\" stroke=\"#000000\" stroke-width=\"2\" stroke-miterlimit=\"10\" stroke-dasharray=\"2 6\" pointer-events=\"stroke\"/><path d=\"M 208 148 L 208 128\" fill=\"none\" stroke=\"#000000\" stroke-width=\"2\" stroke-miterlimit=\"10\" stroke-dasharray=\"2 6\" pointer-events=\"stroke\"/><path d=\"M 238 148 L 238 128\" fill=\"none\" stroke=\"#000000\" stroke-width=\"2\" stroke-miterlimit=\"10\" stroke-dasharray=\"2 6\" pointer-events=\"stroke\"/><path d=\"M 294.37 68.66 L 321.63 68.66\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"stroke\"/><path d=\"M 289.12 68.66 L 296.12 65.16 L 294.37 68.66 L 296.12 72.16 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><path d=\"M 326.88 68.66 L 319.88 72.16 L 321.63 68.66 L 319.88 65.16 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"all\"/><rect x=\"288\" y=\"18.66\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 29px; margin-left: 289px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><font style=\"font-size: 10px\">Distance<br/>Between<br/>Consecutive<br/>Values</font></div></div></div></foreignObject><text x=\"308\" y=\"32\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"12px\" 
text-anchor=\"middle\">Distan...</text></switch></g><rect x=\"188\" y=\"152\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 162px; margin-left: 189px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">Map</font></div><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">to 2</font></div></div></div></div></foreignObject><text x=\"208\" y=\"165\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">Map...</text></switch></g><rect x=\"218\" y=\"152\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 162px; margin-left: 219px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div style=\"font-size: 10px\"><font 
style=\"font-size: 10px\">Map</font></div><div style=\"font-size: 10px\"><font style=\"font-size: 10px\">to 3</font></div></div></div></div></foreignObject><text x=\"238\" y=\"165\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">Map...</text></switch></g><rect x=\"128\" y=\"174\" width=\"120\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 118px; height: 1px; padding-top: 184px; margin-left: 129px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \">(when n = 2)</div></div></div></foreignObject><text x=\"188\" y=\"187\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">(when n = 2)</text></switch></g><rect x=\"28\" y=\"94\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 104px; margin-left: 29px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; 
pointer-events: all; white-space: normal; word-wrap: normal; \">0</div></div></div></foreignObject><text x=\"48\" y=\"107\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">0</text></switch></g><rect x=\"308\" y=\"18.66\" width=\"110\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 108px; height: 1px; padding-top: 29px; margin-left: 309px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div><font style=\"font-size: 12px\">= 1 / scale</font></div><div><font style=\"font-size: 12px\">= 1 / q</font></div></div></div></div></foreignObject><text x=\"363\" y=\"32\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">= 1 / scale...</text></switch></g><rect x=\"128\" y=\"24\" width=\"140\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 138px; height: 1px; padding-top: 34px; margin-left: 129px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; 
color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><font style=\"font-size: 12px\">x =</font><font style=\"font-size: 12px\"> (x   + zp  ) / q </font></div></div></div></foreignObject><text x=\"198\" y=\"37\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">x = (x   + zp  ) / q </text></switch></g><rect x=\"167\" y=\"29\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 39px; margin-left: 168px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div><font style=\"font-size: 10px\">q</font></div></div></div></div></foreignObject><text x=\"187\" y=\"42\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">q</text></switch></g><rect x=\"199\" y=\"29\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 39px; margin-left: 200px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div 
style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div>x</div></div></div></div></foreignObject><text x=\"219\" y=\"42\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">x</text></switch></g><rect x=\"227\" y=\"29\" width=\"40\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 38px; height: 1px; padding-top: 39px; margin-left: 228px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: #000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div>x</div></div></div></div></foreignObject><text x=\"247\" y=\"42\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">x</text></switch></g><rect x=\"48\" y=\"28\" width=\"80\" height=\"20\" fill=\"none\" stroke=\"none\" pointer-events=\"all\"/><g transform=\"translate(-0.5 -0.5)\"><switch><foreignObject style=\"overflow: visible; text-align: left;\" pointer-events=\"none\" width=\"100%\" height=\"100%\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"><div xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: flex; align-items: unsafe center; justify-content: unsafe center; width: 78px; height: 1px; padding-top: 38px; margin-left: 49px;\"><div style=\"box-sizing: border-box; font-size: 0; text-align: center; \"><div style=\"display: inline-block; font-size: 10px; font-family: Helvetica; color: 
#000000; line-height: 1.2; pointer-events: all; white-space: normal; word-wrap: normal; \"><div>zero point<br/></div><div>zp = 2</div></div></div></div></foreignObject><text x=\"88\" y=\"41\" fill=\"#000000\" font-family=\"Helvetica\" font-size=\"10px\" text-anchor=\"middle\">zero point...</text></switch></g></g><switch><g requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"/><a transform=\"translate(0,-5)\" xlink:href=\"https://www.diagrams.net/doc/faq/svg-export-text-problems\" target=\"_blank\"><text text-anchor=\"middle\" font-size=\"10px\" x=\"50%\" y=\"100%\">Viewer does not support full SVG 1.1</text></a></switch></svg>",
"text/plain": [
"<IPython.core.display.SVG object>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import SVG\n",
"SVG(filename=\"figures/QuantizationVisualized.svg\")"
]
},
{
"cell_type": "markdown",
"id": "45d12e7a",
"metadata": {},
"source": [
"If you want to learn more, head to https://intellabs.github.io/distiller/algo_quantization.html"
]
},
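{
"cell_type": "markdown",
"id": "4c1e7a90",
"metadata": {},
"source": [
"To make the summary above concrete, here is a minimal standalone sketch of n-bit quantization in plain NumPy. The helper names `quantize_sketch` and `dequantize_sketch` are hypothetical (they are not part of Concrete); the formulas are the ones this notebook uses for multi-valued arrays in `QuantizedArray.of`. With n = 2 on values spanning [3.0, 7.5], the levels are 1.5 apart and the zero point is 2, much like in the figure."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b2d5f13",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical helpers (not a Concrete API), following the formulas described above\n",
"def quantize_sketch(values, n):\n",
"    values = np.asarray(values, dtype=np.float64)\n",
"    q = (2**n - 1) / (values.max() - values.min())  # scale = 1 / distance between consecutive levels\n",
"    zp = int(round(values.min() * q))  # zero point, so that min(values) is roughly zp / q\n",
"    values_q = np.rint(q * values - zp).astype(np.int64)  # integers in [0, 2**n - 1]\n",
"    return values_q, q, zp\n",
"\n",
"def dequantize_sketch(values_q, q, zp):\n",
"    return (values_q + zp) / q\n",
"\n",
"demo_q, demo_scale, demo_zp = quantize_sketch([3.0, 3.6, 4.5, 5.4, 6.0, 7.5], n=2)\n",
"print(demo_q, demo_zp)  # [0 0 1 2 2 3] 2\n",
"print(dequantize_sketch(demo_q, demo_scale, demo_zp))"
]
},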
{
"cell_type": "code",
"execution_count": 12,
"id": "2541cdb7",
"metadata": {},
"outputs": [],
"source": [
"class QuantizationParameters:\n",
"    def __init__(self, q, zp, n):\n",
"        # q = scale factor = 1 / distance between consecutive values\n",
"        # zp = zero point which is used to determine the beginning of the quantized range\n",
"        # (quantized 0 = the beginning of the quantized range = zp * distance between consecutive values)\n",
"        # n = number of bits\n",
"\n",
"        # e.g.,\n",
"\n",
"        # n = 2\n",
"        # zp = 2\n",
"        # q = 0.66\n",
"        # distance between consecutive values = 1 / q = 1.5151\n",
"\n",
"        # quantized 0 = zp / q = zp * distance between consecutive values = 3.0303\n",
"        # quantized 1 = quantized 0 + distance between consecutive values = 4.5454\n",
"        # quantized 2 = quantized 1 + distance between consecutive values = 6.0606\n",
"        # quantized 3 = quantized 2 + distance between consecutive values = 7.5757\n",
"\n",
"        self.q = q\n",
"        self.zp = zp\n",
"        self.n = n\n",
"\n",
"class QuantizedArray:\n",
"    def __init__(self, values, parameters):\n",
"        # values = quantized values\n",
"        # parameters = parameters used during quantization\n",
"\n",
"        # e.g.,\n",
"\n",
"        # values = [1, 0, 2, 1]\n",
"        # parameters = QuantizationParameters(q=0.66, zp=2, n=2)\n",
"\n",
"        # original array = [4.5454, 3.0303, 6.0606, 4.5454]\n",
"\n",
"        self.values = np.array(values)\n",
"        self.parameters = parameters\n",
"\n",
"    @staticmethod\n",
"    def of(x, n):\n",
"        if not isinstance(x, np.ndarray):\n",
"            x = np.array(x)\n",
"\n",
"        min_x = x.min()\n",
"        max_x = x.max()\n",
"\n",
"        if min_x == max_x:  # encoding single valued arrays\n",
"\n",
"            if min_x == 0.0:  # encoding 0s\n",
"\n",
"                # dequantization = (x_q + zp_x) / q_x = 0 --> q_x = 1 & zp_x = 0 & x_q = 0\n",
"                q_x = 1\n",
"                zp_x = 0\n",
"                x_q = np.zeros(x.shape, dtype=np.uint)\n",
"\n",
"            elif min_x < 0.0:  # encoding negative scalars\n",
"\n",
"                # dequantization = (x_q + zp_x) / q_x = x (with x < 0) --> q_x = 1 / |x| & zp_x = -1 & x_q = 0\n",
"                q_x = abs(1 / min_x)\n",
"                zp_x = -1\n",
"                x_q = np.zeros(x.shape, dtype=np.uint)\n",
"\n",
"            else:  # encoding positive scalars\n",
"\n",
"                # dequantization = (x_q + zp_x) / q_x = x --> q_x = 1 / x & zp_x = 0 & x_q = 1\n",
"                q_x = 1 / min_x\n",
"                zp_x = 0\n",
"                x_q = np.ones(x.shape, dtype=np.uint)\n",
"\n",
"        else:  # encoding multi valued arrays\n",
"\n",
"            # distance between consecutive values = range of x / number of gaps between the 2^n quantized values = (max_x - min_x) / (2^n - 1)\n",
"            # q = 1 / distance between consecutive values\n",
"            q_x = (2**n - 1) / (max_x - min_x)\n",
"\n",
"            # zp = what should be added to 0 to get min_x -> min_x = (0 + zp) / q -> zp = min_x * q\n",
"            zp_x = int(round(min_x * q_x))\n",
"\n",
"            # x = (x_q + zp) / q -> x_q = (x * q) - zp\n",
"            x_q = ((q_x * x) - zp_x).round().astype(np.uint)\n",
"\n",
"        return QuantizedArray(x_q, QuantizationParameters(q_x, zp_x, n))\n",
"\n",
"    def dequantize(self):\n",
"        # x = (x_q + zp) / q\n",
"        return (self.values.astype(np.float32) + float(self.parameters.zp)) / self.parameters.q\n",
"\n",
"    def affine(self, w, b, min_y, max_y, n_y):\n",
"        # the formulas used in this method were derived from the following equations\n",
"        #\n",
"        # x = (x_q + zp_x) / q_x\n",
"        # w = (w_q + zp_w) / q_w\n",
"        # b = (b_q + zp_b) / q_b\n",
"        #\n",
"        # (x * w) + b = ((x_q + zp_x) / q_x) * ((w_q + zp_w) / q_w) + ((b_q + zp_b) / q_b)\n",
"        #             = y = (y_q + zp_y) / q_y\n",
"        #\n",
"        # So, ((x_q + zp_x) / q_x) * ((w_q + zp_w) / q_w) + ((b_q + zp_b) / q_b) = (y_q + zp_y) / q_y\n",
"        # q_y and zp_y can be computed from min_y, max_y and n_y, so y_q is the only unknown and can be solved for.\n",
"\n",
"        # signed copies, so that adding a negative zero point (e.g. zp = -1 for a negative scalar) is well defined\n",
"        x_q = self.values.astype(np.int64)\n",
"        w_q = w.values.astype(np.int64)\n",
"        b_q = b.values.astype(np.int64)\n",
"\n",
"        q_x = self.parameters.q\n",
"        q_w = w.parameters.q\n",
"        q_b = b.parameters.q\n",
"\n",
"        zp_x = self.parameters.zp\n",
"        zp_w = w.parameters.zp\n",
"        zp_b = b.parameters.zp\n",
"\n",
"        q_y = (2**n_y - 1) / (max_y - min_y)\n",
"        zp_y = int(round(min_y * q_y))\n",
"\n",
"        # q_y * (x @ w + b), expressed with quantized values and their parameters only\n",
"        y_q = (q_y / (q_x * q_w)) * ((x_q + zp_x) @ (w_q + zp_w) + (q_x * q_w / q_b) * (b_q + zp_b))\n",
"        # shift so that min_y maps to quantized 0, then round and clip to n_y bits\n",
"        y_q -= min_y * q_y\n",
"        y_q = y_q.round().clip(0, 2**n_y - 1).astype(np.uint)\n",
"\n",
"        return QuantizedArray(y_q, QuantizationParameters(q_y, zp_y, n_y))\n",
"\n",
"class QuantizedFunction:\n",
"    def __init__(self, table):\n",
"        self.table = table\n",
"\n",
"    @staticmethod\n",
"    def of(f, input_bits, output_bits):\n",
"        domain = np.array(range(2**input_bits), dtype=np.uint)\n",
"        table = f(domain).round().clip(0, 2**output_bits - 1).astype(np.uint)\n",
"        return QuantizedFunction(table)"
]
},
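  {
   "cell_type": "markdown",
   "id": "3f9c2a71",
   "metadata": {},
   "source": [
    "For multi-valued arrays, `QuantizedArray.of` and `QuantizedArray.dequantize` follow the relations written in their comments. Restated as equations, with $n$ the number of bits and $x_{min}$, $x_{max}$ the minimum and maximum of $x$ (this is the convention used in this notebook):\n",
    "\n",
    "$$\n",
    "q_x = \\frac{2^n - 1}{x_{max} - x_{min}}, \\qquad zp_x = \\mathrm{round}(x_{min} \\cdot q_x), \\qquad x_q = \\mathrm{round}(q_x \\cdot x - zp_x), \\qquad x \\approx \\frac{x_q + zp_x}{q_x}\n",
    "$$\n",
    "\n",
    "For instance, with $n = 3$, $x_{min} = 10$ and $x_{max} = 17$, we get $q_x = 7 / 7 = 1$ and $zp_x = 10$, so $x = 13$ is stored as $x_q = 3$ and dequantized back to $(3 + 10) / 1 = 13$."
   ]
  },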
{
"cell_type": "markdown",
"id": "ab82ae87",
"metadata": {},
"source": [
"### Let's quantize our model parameters\n",
"\n",
"Since each parameter is a single scalar, 1-bit quantization is enough to represent it exactly."
]
},
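  {
   "cell_type": "markdown",
   "id": "5d4e8b12",
   "metadata": {},
   "source": [
    "For an array holding a single value $v$, `QuantizedArray.of` chooses the parameters so that dequantization returns $v$ (up to floating point error) while the stored integer $v_q$ is 0 or 1:\n",
    "\n",
    "$$\n",
    "v > 0: \\quad q = \\frac{1}{v},\\; zp = 0,\\; v_q = 1 \\;\\Rightarrow\\; \\frac{1 + 0}{1 / v} = v\n",
    "$$\n",
    "\n",
    "$$\n",
    "v < 0: \\quad q = \\frac{1}{|v|},\\; zp = -1,\\; v_q = 0 \\;\\Rightarrow\\; \\frac{0 - 1}{1 / |v|} = -|v| = v\n",
    "$$\n",
    "\n",
    "Finally, $v = 0$ is encoded with $q = 1$, $zp = 0$ and $v_q = 0$. This is why one bit per parameter is enough for `w` and `b`."
   ]
  },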
{
"cell_type": "code",
"execution_count": 13,
"id": "c8b08ef4",
"metadata": {},
"outputs": [],
"source": [
"parameter_bits = 1\n",
"\n",
"w_q = QuantizedArray.of(model.w, parameter_bits)\n",
"b_q = QuantizedArray.of(model.b, parameter_bits)"
]
},
{
"cell_type": "markdown",
"id": "e2528092",
"metadata": {},
"source": [
"### And quantize our inputs"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "affe644e",
"metadata": {},
"outputs": [],
"source": [
"input_bits = 6\n",
"\n",
"x_q = QuantizedArray.of(inputs, input_bits)"
]
},
{
"cell_type": "markdown",
"id": "a5a50eb8",
"metadata": {},
"source": [
"### Time to run quantized inference"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "0fdfd3d9",
"metadata": {},
"outputs": [],
"source": [
"output_bits = 7\n",
"\n",
"min_y = predictions.min()\n",
"max_y = predictions.max()\n",
"y_q = x_q.affine(w_q, b_q, min_y, max_y, output_bits)\n",
"\n",
"quantized_predictions = y_q.dequantize()"
]
},
{
"cell_type": "markdown",
"id": "5fb15eb4",
"metadata": {},
"source": [
"### And visualize the results"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "8076a406",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ax.plot(inputs, quantized_predictions, color=\"black\")\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "af6bc89e",
"metadata": {},
"source": [
"### Now it's time to make the inference homomorphic"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "cbda8067",
"metadata": {},
"outputs": [],
"source": [
"# quantization parameters of the output, computed as in QuantizedArray.affine(...)\n",
"q_y = (2**output_bits - 1) / (max_y - min_y)\n",
"zp_y = int(round(min_y * q_y))\n",
"\n",
"q_x = x_q.parameters.q\n",
"q_w = w_q.parameters.q\n",
"q_b = b_q.parameters.q\n",
"\n",
"zp_x = x_q.parameters.zp\n",
"zp_w = w_q.parameters.zp\n",
"zp_b = b_q.parameters.zp\n",
"\n",
"# from now on, only the integer values of the quantized arrays are needed\n",
"x_q = x_q.values\n",
"w_q = w_q.values\n",
"b_q = b_q.values"
]
},
{
"cell_type": "markdown",
"id": "b8e95e3d",
"metadata": {},
"source": [
"### Simplification to the rescue!\n",
"\n",
"The `y_q` formula in `QuantizedArray.affine(...)` can be rewritten to make it easier to implement homomorphically. Here is the breakdown.\n",
"```\n",
"(q_y / (q_x * q_w)) * ((x_q + zp_x) @ (w_q + zp_w) + (q_x * q_w / q_b) * (b_q + zp_b)) - (min_y * q_y)\n",
"^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^ ^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^\n",
"constant (c1) can be done constant (c2) constant (c3) constant (c4)\n",
" on the circuit \n",
" \n",
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
" can be done on the circuit\n",
" \n",
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
"cannot be done on the circuit because of floating point operations, so it will be a single table lookup\n",
"```"
]
},
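  {
   "cell_type": "markdown",
   "id": "9a7b6c03",
   "metadata": {},
   "source": [
    "Written as equations, with the constants named as in the breakdown above and $y_{min}$ standing for `min_y`:\n",
    "\n",
    "$$\n",
    "y_q = c_1 \\left( (x_q + zp_x) \\cdot c_2 + c_3 \\right) - c_4\n",
    "$$\n",
    "\n",
    "$$\n",
    "c_1 = \\frac{q_y}{q_x q_w}, \\qquad c_2 = w_q + zp_w, \\qquad c_3 = \\frac{q_x q_w}{q_b} (b_q + zp_b), \\qquad c_4 = y_{min} \\cdot q_y\n",
    "$$\n",
    "\n",
    "For this single-feature model the matrix product reduces to a product of integers. Only $(x_q + zp_x) \\cdot c_2$ touches the encrypted input; everything else is a function of that integer result, so it can be precomputed into a lookup table indexed by it."
   ]
  },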
{
"cell_type": "markdown",
"id": "c6e101ae",
"metadata": {},
"source": [
"### Let's import the Concrete numpy package now!"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "4da7aed5",
"metadata": {},
"outputs": [],
"source": [
"import concrete.numpy as hnp"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "d3816fa5",
"metadata": {},
"outputs": [],
"source": [
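"# constants of the simplified formula (see the breakdown above)\n",
"c1 = q_y / (q_x * q_w)\n",
"c2 = w_q + zp_w\n",
"c3 = (q_x * q_w / q_b) * (b_q + zp_b)\n",
"c4 = min_y * q_y\n",
"\n",
"# the float part of the formula is precomputed for every possible\n",
"# intermediate value and stored in a table\n",
"f = lambda intermediate: (c1 * (intermediate + c3)) - c4\n",
"f_q = QuantizedFunction.of(f, input_bits + parameter_bits, output_bits)\n",
"\n",
"table = hnp.LookupTable([int(entry) for entry in f_q.table])\n",
"\n",
"# the model has a single weight, so c2 holds a single integer\n",
"w_0 = int(c2.flatten()[0])\n",
"\n",
"def infer(x_0):\n",
"    # integer operations on the encrypted input, followed by one table lookup\n",
"    return table[(x_0 + zp_x) * w_0]"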
]
},
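  {
   "cell_type": "markdown",
   "id": "e1f20d54",
   "metadata": {},
   "source": [
    "The table built above stores, for every possible value $i$ of the intermediate integer, the float part of the formula rounded and clipped to the output range (as done in `QuantizedFunction.of`):\n",
    "\n",
    "$$\n",
    "T[i] = \\mathrm{clip}\\left(\\mathrm{round}\\left(c_1 (i + c_3) - c_4\\right), 0, 2^{n_y} - 1\\right), \\qquad i \\in \\{0, \\dots, 2^{n_x + n_w} - 1\\}\n",
    "$$\n",
    "\n",
    "where $n_x$ is `input_bits`, $n_w$ is `parameter_bits` and $n_y$ is `output_bits`. The encrypted computation is then $y_q = T[(x_q + zp_x) \\cdot w_0]$, which assumes that the index $(x_q + zp_x) \\cdot w_0$ stays within $[0, 2^{n_x + n_w})$; otherwise it would fall outside the table."
   ]
  },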
{
"cell_type": "markdown",
"id": "01d67c28",
"metadata": {},
"source": [
"### Let's compile our quantized inference function to its homomorphic equivalent"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "81304aca",
"metadata": {},
"outputs": [],
"source": [
"inputset = [int(x_i[0]) for x_i in x_q]\n",
"\n",
"circuit = hnp.compile_numpy_function(\n",
" infer,\n",
" {\"x_0\": hnp.EncryptedScalar(hnp.Integer(input_bits, is_signed=False))},\n",
" inputset,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c62af039",
"metadata": {},
"source": [
"### Here are some representations of the FHE circuit"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "0c533af6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"%0 = Constant(1) # ClearScalar<Integer<unsigned, 8 bits>>\n",
"%1 = x_0 # EncryptedScalar<Integer<unsigned, 7 bits>>\n",
"%2 = Constant(15) # ClearScalar<Integer<unsigned, 8 bits>>\n",
"%3 = Add(1, 2) # EncryptedScalar<Integer<unsigned, 7 bits>>\n",
"%4 = Mul(3, 0) # EncryptedScalar<Integer<unsigned, 7 bits>>\n",
"%5 = TLU(4) # EncryptedScalar<Integer<unsigned, 7 bits>>\n",
"return(%5)\n",
"\n"
]
}
],
"source": [
"print(circuit)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "c1fc0f48",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=227x423 at 0x7FF17C8E64F0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from PIL import Image\n",
"file = Image.open(circuit.draw())\n",
"file.show()\n",
"file.close()"
]
},
{
"cell_type": "markdown",
"id": "46753da7",
"metadata": {},
"source": [
"### Finally, let's run homomorphic inference"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "c0b246f7",
"metadata": {},
"outputs": [],
"source": [
"homomorphic_predictions = []\n",
"for x_i in (int(x_i[0]) for x_i in x_q):\n",
" inference = QuantizedArray(circuit.run(x_i), y_q.parameters)\n",
" homomorphic_predictions.append(inference.dequantize())\n",
"homomorphic_predictions = np.array(homomorphic_predictions, dtype=np.float32)"
]
},
{
"cell_type": "markdown",
"id": "68f67b3f",
"metadata": {},
"source": [
"### And visualize it"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "92c7f2f5",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ax.plot(inputs, homomorphic_predictions, color=\"green\")\n",
"display(fig)"
]
},
{
"cell_type": "markdown",
"id": "c18dbdd1",
"metadata": {},
"source": [
"### Enjoy!"
]
}
],
"metadata": {
"execution": {
"timeout": 10800
}
},
"nbformat": 4,
"nbformat_minor": 5
}