{ "cells": [ { "cell_type": "markdown", "id": "b760a0f6", "metadata": {}, "source": [ "# Quantized Linear Regression\n", "\n", "Currently, **Concrete** only supports unsigned integers of up to 7 bits. Nevertheless, we want to evaluate a linear regression model with it. Luckily, we can make use of **quantization** to overcome this limitation!" ] }, { "cell_type": "markdown", "id": "253288cf", "metadata": {}, "source": [ "### Let's start by importing NumPy to develop our linear regression model" ] }, { "cell_type": "code", "execution_count": 1, "id": "6200ab62", "metadata": {}, "outputs": [], "source": [ "import numpy as np" ] }, { "cell_type": "markdown", "id": "f43e2387", "metadata": {}, "source": [ "### And some helpers for visualization" ] }, { "cell_type": "code", "execution_count": 2, "id": "d104c8df", "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "\n", "import matplotlib.pyplot as plt\n", "from IPython.display import display" ] }, { "cell_type": "markdown", "id": "53e676b8", "metadata": {}, "source": [ "### We need an inputset; a handcrafted one will do for simplicity" ] }, { "cell_type": "code", "execution_count": 3, "id": "d451e829", "metadata": {}, "outputs": [], "source": [ "x = np.array([[130], [110], [100], [145], [160], [185], [200], [80], [50]], dtype=np.float32)\n", "y = np.array([325, 295, 268, 400, 420, 500, 520, 220, 120], dtype=np.float32)" ] }, { "cell_type": "markdown", "id": "75f4fdb7", "metadata": {}, "source": [ "### Let's visualize our inputset to get a grasp of it" ] }, { "cell_type": "code", "execution_count": 4, "id": "2a124a62", "metadata": {}, "outputs": [], "source": [ "plt.ioff()\n", "fig, ax = plt.subplots(1)" ] }, { "cell_type": "code", "execution_count": 5, "id": "edcd361b", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ax.scatter(x[:, 0], y, marker=\"x\", color=\"red\")\n", "display(fig)" ] }, { "cell_type": "markdown", "id": "5c8310ab", "metadata": {}, "source": [ "### Now we need a model, so let's define it\n", "\n", "The main purpose of this tutorial is not to train a linear regression model but to evaluate it homomorphically, so we will not discuss how the model is trained." ] }, { "cell_type": "code", "execution_count": 6, "id": "91d4a1da", "metadata": {}, "outputs": [], "source": [ "class Model:\n", "    w = None\n", "    b = None\n", "\n", "    def fit(self, x, y):\n", "        a = np.ones((x.shape[0], x.shape[1] + 1), dtype=np.float32)\n", "        a[:, 1:] = x\n", "\n", "        regularization_contribution = np.identity(x.shape[1] + 1, dtype=np.float32)\n", "        regularization_contribution[0][0] = 0\n", "\n", "        parameters = np.linalg.pinv(a.T @ a + regularization_contribution) @ a.T @ y\n", "\n", "        self.b = parameters[0]\n", "        self.w = parameters[1:].reshape(-1, 1)\n", "\n", "        return self\n", "\n", "    def evaluate(self, x):\n", "        return x @ self.w + self.b" ] }, { "cell_type": "markdown", "id": "faa5247c", "metadata": {}, "source": [ "### And create one" ] }, { "cell_type": "code", "execution_count": 7, "id": "682fb2d8", "metadata": {}, "outputs": [], "source": [ "model = Model().fit(x, y)" ] }, { "cell_type": "markdown", "id": "084fb296", "metadata": {}, "source": [ "### Time to make some predictions" ] }, { "cell_type": "code", "execution_count": 8, "id": "4953b03e", "metadata": {}, "outputs": [], "source": [ "inputs = np.linspace(40, 210, 100).reshape(-1, 1)\n", "predictions = model.evaluate(inputs)" ] }, { "cell_type": "markdown", "id": "f28155cf", "metadata": {}, "source": [ "### Let's visualize our predictions to see how our model performs" ] }, { "cell_type": "code", "execution_count": 9, "id": "111574ed", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ax.plot(inputs, predictions, color=\"blue\")\n", "display(fig)" ] }, { "cell_type": "markdown", "id": "23852861", "metadata": {}, "source": [ "### As a bonus, let's inspect the model parameters" ] }, { "cell_type": "code", "execution_count": 10, "id": "7877cb2e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[2.669915]]\n", "-3.2335143\n" ] } ], "source": [ "print(model.w)\n", "print(model.b)" ] }, { "cell_type": "markdown", "id": "de63118c", "metadata": {}, "source": [ "They are floating-point numbers, and Concrete cannot work with them directly!" ] }, { "cell_type": "markdown", "id": "2d959640", "metadata": {}, "source": [ "### So, let's build a small quantization abstraction\n", "\n", "Here is a quick summary of quantization. We have a range of values, and we want to represent them using a small number of bits (n). To do this, we split the range into 2^n sections and map each section to an integer between 0 and 2^n - 1. Here is a visualization of the process!" ] }, { "cell_type": "code", "execution_count": 11, "id": "9da2e1a4", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "
[Figure: QuantizationVisualized.svg, the range from min(x) to max(x) split into 2^n sections for n = 2, mapped to 0, 1, 2 and 3; distance between consecutive values = 1 / q (= 1 / scale); zero point zp = 2; x = (x_q + zp_x) / q_x]
" ], "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from IPython.display import SVG\n", "SVG(filename=\"figures/QuantizationVisualized.svg\")" ] }, { "cell_type": "markdown", "id": "45d12e7a", "metadata": {}, "source": [ "If you want to learn more, head to https://intellabs.github.io/distiller/algo_quantization.html" ] }, { "cell_type": "code", "execution_count": 12, "id": "2541cdb7", "metadata": {}, "outputs": [], "source": [ "class QuantizationParameters:\n", "    def __init__(self, q, zp, n):\n", "        # q = scale factor = 1 / distance between consecutive values\n", "        # zp = zero point, which determines where the quantized range begins\n", "        # (quantized 0 = the beginning of the quantized range = zp * distance between consecutive values)\n", "        # n = number of bits\n", "\n", "        # e.g.,\n", "\n", "        # n = 2\n", "        # zp = 2\n", "        # q = 0.66\n", "        # distance between consecutive values = 1 / q = 1.5151\n", "\n", "        # quantized 0 = zp / q = zp * distance between consecutive values = 3.0303\n", "        # quantized 1 = quantized 0 + distance between consecutive values = 4.5454\n", "        # quantized 2 = quantized 1 + distance between consecutive values = 6.0606\n", "        # quantized 3 = quantized 2 + distance between consecutive values = 7.5757\n", "\n", "        self.q = q\n", "        self.zp = zp\n", "        self.n = n\n", "\n", "class QuantizedArray:\n", "    def __init__(self, values, parameters):\n", "        # values = quantized values\n", "        # parameters = parameters used during quantization\n", "\n", "        # e.g.,\n", "\n", "        # values = [1, 0, 2, 1]\n", "        # parameters = QuantizationParameters(q=0.66, zp=2, n=2)\n", "\n", "        # original array = [4.5454, 3.0303, 6.0606, 4.5454]\n", "\n", "        self.values = np.array(values)\n", "        self.parameters = parameters\n", "\n", "    @staticmethod\n", "    def of(x, n):\n", "        if not isinstance(x, np.ndarray):\n", "            x = np.array(x)\n", "\n", "        min_x = x.min()\n", "        max_x = x.max()\n", "\n", "        if min_x == max_x:  # encoding single valued arrays\n", "\n", "            if min_x == 0.0:  # encoding 0s\n", "\n", "                # dequantization = (x_q + zp_x) / q_x = 0 --> q_x = 1 & zp_x = 0 & x_q = 0\n", "                q_x = 1\n", "                zp_x = 0\n", "                x_q = np.zeros(x.shape, dtype=np.uint)\n", "\n", "            elif min_x < 0.0:  # encoding negative scalars\n", "\n", "                # dequantization = (x_q + zp_x) / q_x = x (x < 0) --> q_x = 1 / |x| & zp_x = -1 & x_q = 0\n", "                q_x = abs(1 / min_x)\n", "                zp_x = -1\n", "                x_q = np.zeros(x.shape, dtype=np.uint)\n", "\n", "            else:  # encoding positive scalars\n", "\n", "                # dequantization = (x_q + zp_x) / q_x = x --> q_x = 1 / x & zp_x = 0 & x_q = 1\n", "                q_x = 1 / min_x\n", "                zp_x = 0\n", "                x_q = np.ones(x.shape, dtype=np.uint)\n", "\n", "        else:  # encoding multi valued arrays\n", "\n", "            # distance between consecutive values = range of x / number of gaps between the 2^n quantized values = (max_x - min_x) / (2^n - 1)\n", "            # q = 1 / distance between consecutive values\n", "            q_x = (2**n - 1) / (max_x - min_x)\n", "\n", "            # zp = what should be added to 0 to get min_x -> min_x = (0 + zp) / q -> zp = min_x * q\n", "            zp_x = int(round(min_x * q_x))\n", "\n", "            # x = (x_q + zp) / q -> x_q = (x * q) - zp\n", "            x_q = ((q_x * x) - zp_x).round().astype(np.uint)\n", "\n", "        return QuantizedArray(x_q, QuantizationParameters(q_x, zp_x, n))\n", "\n", "    def dequantize(self):\n", "        # x = (x_q + zp) / q\n", "        return (self.values.astype(np.float32) + float(self.parameters.zp)) / self.parameters.q\n", "\n", "    def affine(self, w, b, min_y, max_y, n_y):\n", "        # the formulas used in this method were derived from the following equations\n", "        #\n", "        # x = (x_q + zp_x) / q_x\n", "        # w = (w_q + zp_w) / q_w\n", "        # b = (b_q + zp_b) / q_b\n", "        #\n", "        # (x * w) + b = ((x_q + zp_x) / q_x) * ((w_q + zp_w) / q_w) + ((b_q + zp_b) / q_b)\n", "        #             = y = (y_q + zp_y) / q_y\n", "        #\n", "        # So, ((x_q + zp_x) / q_x) * ((w_q + zp_w) / q_w) + ((b_q + zp_b) / q_b) = (y_q + zp_y) / q_y\n", "        # We can calculate q_y and zp_y from min_y, max_y and n_y, so the only unknown is y_q, which we can solve for.\n", "\n", "        x_q = self.values\n", "        w_q = w.values\n", "        b_q = b.values\n", "\n", "        q_x = self.parameters.q\n", "        q_w = w.parameters.q\n", "        q_b = b.parameters.q\n", "\n", "        zp_x = self.parameters.zp\n", "        zp_w = w.parameters.zp\n", "        zp_b = b.parameters.zp\n", "\n", "        q_y = (2**n_y - 1) / (max_y - min_y)\n", "        zp_y = int(round(min_y * q_y))\n", "\n", "        y_q = (q_y / (q_x * q_w)) * ((x_q + zp_x) @ (w_q + zp_w) + (q_x * q_w / q_b) * (b_q + zp_b))\n", "        y_q -= min_y * q_y\n", "        y_q = y_q.round().clip(0, 2**n_y - 1).astype(np.uint)\n", "\n", "        return QuantizedArray(y_q, QuantizationParameters(q_y, zp_y, n_y))\n", "\n", "class QuantizedFunction:\n", "    def __init__(self, table):\n", "        self.table = table\n", "\n", "    @staticmethod\n", "    def of(f, input_bits, output_bits):\n", "        domain = np.array(range(2**input_bits), dtype=np.uint)\n", "        table = f(domain).round().clip(0, 2**output_bits - 1).astype(np.uint)\n", "        return QuantizedFunction(table)" ] }, { "cell_type": "markdown", "id": "ab82ae87", "metadata": {}, "source": [ "### Let's quantize our model parameters\n", "\n", "Since each parameter is a single scalar, a single bit is enough to quantize it."
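, "\n", "\n", "As a worked check with the values printed above (w = 2.669915 and b = -3.2335143): each parameter holds a single value, so `QuantizedArray.of` takes its single-valued branch, stores one quantized value per parameter, and dequantization returns the original scalar (up to floating-point rounding):\n", "\n", "$$w_q = 1,\\quad zp_w = 0,\\quad q_w = \\frac{1}{2.669915} \\quad\\Rightarrow\\quad \\frac{w_q + zp_w}{q_w} = 2.669915$$\n", "\n", "$$b_q = 0,\\quad zp_b = -1,\\quad q_b = \\frac{1}{3.2335143} \\quad\\Rightarrow\\quad \\frac{b_q + zp_b}{q_b} = -3.2335143$$"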
] }, { "cell_type": "code", "execution_count": 13, "id": "c8b08ef4", "metadata": {}, "outputs": [], "source": [ "parameter_bits = 1\n", "\n", "w_q = QuantizedArray.of(model.w, parameter_bits)\n", "b_q = QuantizedArray.of(model.b, parameter_bits)" ] }, { "cell_type": "markdown", "id": "e2528092", "metadata": {}, "source": [ "### And quantize our inputs" ] }, { "cell_type": "code", "execution_count": 14, "id": "affe644e", "metadata": {}, "outputs": [], "source": [ "input_bits = 6\n", "\n", "x_q = QuantizedArray.of(inputs, input_bits)" ] }, { "cell_type": "markdown", "id": "a5a50eb8", "metadata": {}, "source": [ "### Time to run quantized inference" ] }, { "cell_type": "code", "execution_count": 15, "id": "0fdfd3d9", "metadata": {}, "outputs": [], "source": [ "output_bits = 7\n", "\n", "min_y = predictions.min()\n", "max_y = predictions.max()\n", "y_q = x_q.affine(w_q, b_q, min_y, max_y, output_bits)\n", "\n", "quantized_predictions = y_q.dequantize()" ] }, { "cell_type": "markdown", "id": "5fb15eb4", "metadata": {}, "source": [ "### And visualize the results" ] }, { "cell_type": "code", "execution_count": 16, "id": "8076a406", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ax.plot(inputs, quantized_predictions, color=\"black\")\n", "display(fig)" ] }, { "cell_type": "markdown", "id": "af6bc89e", "metadata": {}, "source": [ "### Now it's time to make the inference homomorphic" ] }, { "cell_type": "code", "execution_count": 17, "id": "cbda8067", "metadata": {}, "outputs": [], "source": [ "q_y = (2**output_bits - 1) / (max_y - min_y)\n", "zp_y = int(round(min_y * q_y))\n", "\n", "q_x = x_q.parameters.q\n", "q_w = w_q.parameters.q\n", "q_b = b_q.parameters.q\n", "\n", "zp_x = x_q.parameters.zp\n", "zp_w = w_q.parameters.zp\n", "zp_b = b_q.parameters.zp\n", "\n", "x_q = x_q.values\n", "w_q = w_q.values\n", "b_q = b_q.values" ] }, { "cell_type": "markdown", "id": "b8e95e3d", "metadata": {}, "source": [ "### Simplification to the rescue!\n", "\n", "The `y_q` formula in `QuantizedArray.affine(...)` can be split into parts, which makes it easier to implement homomorphically. Here is the breakdown:\n", "\n", "```\n", "(q_y / (q_x * q_w)) * ((x_q + zp_x) @ (w_q + zp_w) + (q_x * q_w / q_b) * (b_q + zp_b)) - (min_y * q_y)\n", "^^^^^^^^^^^^^^^^^^^    ^^^^^^^^^^^^   ^^^^^^^^^^^^   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^    ^^^^^^^^^^^^^\n", "constant (c1)          can be done    constant (c2)  constant (c3)                       constant (c4)\n", "                       on the circuit\n", "\n", "                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "                       can be done on the circuit\n", "\n", "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "cannot be done on the circuit because it involves floating-point operations, so it becomes a single table lookup\n", "```" ] }, { "cell_type": "markdown", "id": "c6e101ae", "metadata": {}, "source": [ "### Let's import the Concrete numpy package now!"
] }, { "cell_type": "code", "execution_count": 18, "id": "4da7aed5", "metadata": {}, "outputs": [], "source": [ "import concrete.numpy as hnp" ] }, { "cell_type": "code", "execution_count": 19, "id": "d3816fa5", "metadata": {}, "outputs": [], "source": [ "c1 = q_y / (q_x * q_w)\n", "c2 = w_q + zp_w\n", "c3 = (q_x * q_w / q_b) * (b_q + zp_b)\n", "c4 = min_y * q_y\n", "\n", "f = lambda intermediate: (c1 * (intermediate + c3)) - c4\n", "f_q = QuantizedFunction.of(f, input_bits + parameter_bits, output_bits)\n", "\n", "table = hnp.LookupTable([int(entry) for entry in f_q.table])\n", "\n", "w_0 = int(c2.flatten()[0])\n", "\n", "def infer(x_0):\n", "    return table[(x_0 + zp_x) * w_0]" ] }, { "cell_type": "markdown", "id": "01d67c28", "metadata": {}, "source": [ "### Let's compile our quantized inference function into its homomorphic equivalent" ] }, { "cell_type": "code", "execution_count": 20, "id": "81304aca", "metadata": {}, "outputs": [], "source": [ "inputset = []\n", "for x_i in x_q:\n", "    inputset.append((int(x_i[0]),))\n", "\n", "circuit = hnp.compile_numpy_function(\n", "    infer,\n", "    {\"x_0\": hnp.EncryptedScalar(hnp.Integer(input_bits, is_signed=False))},\n", "    inputset,\n", ")" ] }, { "cell_type": "markdown", "id": "c62af039", "metadata": {}, "source": [ "### Here are some representations of the FHE circuit" ] }, { "cell_type": "code", "execution_count": 21, "id": "0c533af6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "%0 = Constant(1)  # ClearScalar\n", "%1 = x_0  # EncryptedScalar\n", "%2 = Constant(15)  # ClearScalar\n", "%3 = Add(1, 2)  # EncryptedScalar\n", "%4 = Mul(3, 0)  # EncryptedScalar\n", "%5 = TLU(4)  # EncryptedScalar\n", "return(%5)\n", "\n" ] } ], "source": [ "print(circuit)" ] }, { "cell_type": "code", "execution_count": 22, "id": "c1fc0f48", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from PIL import Image\n", "file = Image.open(circuit.draw())\n", "file.show()\n", "file.close()" ] }, { "cell_type": "markdown", "id": "46753da7", "metadata": {}, "source": [ "### Finally, let's run homomorphic inference" ] }, { "cell_type": "code", "execution_count": 23, "id": "c0b246f7", "metadata": {}, "outputs": [], "source": [ "homomorphic_predictions = []\n", "for x_i in map(lambda x_i: int(x_i[0]), x_q):\n", "    inference = QuantizedArray(circuit.run(x_i), y_q.parameters)\n", "    homomorphic_predictions.append(inference.dequantize())\n", "homomorphic_predictions = np.array(homomorphic_predictions, dtype=np.float32)" ] }, { "cell_type": "markdown", "id": "68f67b3f", "metadata": {}, "source": [ "### And visualize it" ] }, { "cell_type": "code", "execution_count": 24, "id": "92c7f2f5", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ax.plot(inputs, homomorphic_predictions, color=\"green\")\n", "display(fig)" ] }, { "cell_type": "markdown", "id": "c18dbdd1", "metadata": {}, "source": [ "### Enjoy!" ] } ], "metadata": { "execution": { "timeout": 10800 } }, "nbformat": 4, "nbformat_minor": 5 }