mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
422 lines
74 KiB
Plaintext
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fully Connected Neural Network\n",
    "\n",
    "In this example, we show how one can train a neural network on a specific task (here, Iris Classification) and use Concrete Numpy to make the model work in FHE settings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define our neural network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FCIris(torch.nn.Module):\n",
    "    \"\"\"Neural network for Iris classification\n",
    "\n",
    "    We define a fully connected network with two (2) fully connected (fc) layers that\n",
    "    perform feature extraction and one (fc) layer to produce the final classification.\n",
    "    We will use 3 neurons on all layers to ensure that the FHE accumulators\n",
    "    do not overflow (we are currently only allowed a maximum bit width of 7).\n",
    "    More information on this is available at https://docs.zama.ai/concrete-numpy/main/user/howto/reduce_needed_precision.html#limitations-for-fhe-friendly-neural-network.\n",
    "\n",
    "    Due to accumulator limits, we have to design a network with only a few neurons on each layer.\n",
    "    This is in contrast to a traditional approach where the number of neurons increases after\n",
    "    each layer or block.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size):\n",
    "        super().__init__()\n",
    "\n",
    "        # The first layer processes the input data, in our case 4-dimensional vectors\n",
    "        self.linear1 = nn.Linear(input_size, 3)\n",
    "        self.sigmoid1 = nn.Sigmoid()\n",
    "        # Next, we add one intermediate layer\n",
    "        self.linear2 = nn.Linear(3, 3)\n",
    "        self.sigmoid2 = nn.Sigmoid()\n",
    "        # Finally, we add the decision layer for the 3 output classes (one logit per class)\n",
    "        self.decision = nn.Linear(3, 3)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.linear1(x)\n",
    "        x = self.sigmoid1(x)\n",
    "        x = self.linear2(x)\n",
    "        x = self.sigmoid2(x)\n",
    "        x = self.decision(x)\n",
    "\n",
    "        return x\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define all required variables to train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get iris dataset\n",
    "from sklearn.datasets import load_iris\n",
    "X, y = load_iris(return_X_y=True)\n",
    "\n",
    "# Split into train and test\n",
    "from sklearn.model_selection import train_test_split\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)\n",
    "\n",
    "# Convert to tensors\n",
    "X_train = torch.tensor(X_train).float()\n",
    "X_test = torch.tensor(X_test).float()\n",
    "y_train = torch.tensor(y_train)\n",
    "y_test = torch.tensor(y_test)\n",
    "\n",
    "# Initialize our model\n",
    "model = FCIris(X.shape[1])\n",
    "\n",
    "# Define our loss function\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# Define our optimizer\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n",
    "\n",
    "# Define the number of iterations\n",
    "n_iters = 50001\n",
    "\n",
    "# Define the batch size\n",
    "batch_size = 16"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    for iteration in range(n_iters):\n",
    "        # Get a random batch of training data\n",
    "        idx = torch.randperm(X_train.size()[0])\n",
    "        X_batch = X_train[idx][:batch_size]\n",
    "        y_batch = y_train[idx][:batch_size]\n",
    "\n",
    "        # Forward pass\n",
    "        y_pred = model(X_batch)\n",
    "\n",
    "        # Compute loss\n",
    "        loss = criterion(y_pred, y_batch)\n",
    "\n",
    "        # Backward pass\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "\n",
    "        # Update weights\n",
    "        optimizer.step()\n",
    "\n",
    "        if iteration % 1000 == 0:\n",
    "            # Print iteration number, loss and accuracy on the current batch\n",
    "            accuracy = torch.sum(torch.argmax(y_pred, dim=1) == y_batch).item() / y_batch.size()[0]\n",
    "            print(f'Iterations: {iteration:02} | Loss: {loss.item():.4f} | Accuracy: {100*accuracy:.2f}%')\n",
    "            # Stop early once the sampled batch is classified perfectly\n",
    "            if accuracy == 1:\n",
    "                break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compile the model\n",
    "\n",
    "`compile_torch_model` first quantizes `model` with `n_bits` of precision, using `X_train` as the calibration dataset, and then compiles the quantized model to its FHE counterpart. Here we use 3 bits of precision. In some edge cases, the network accumulators can overflow (i.e., extreme quantized values occur in both the inputs and the weights, which is unlikely). In such a case, we need to retrain the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training a FHE friendly quantized network.\n",
      "Iterations: 00 | Loss: 1.2000 | Accuracy: 18.75%\n",
      "Iterations: 1000 | Loss: 0.5623 | Accuracy: 75.00%\n",
      "Iterations: 2000 | Loss: 0.3556 | Accuracy: 87.50%\n",
      "Iterations: 3000 | Loss: 0.0646 | Accuracy: 100.00%\n",
      "Compiling the model to FHE.\n",
      "The network is trained and FHE friendly.\n"
     ]
    }
   ],
   "source": [
    "from concrete.torch.compile import compile_torch_model\n",
    "print(\"Training a FHE friendly quantized network.\")\n",
    "# Retry up to 10 times: train, then try to compile with 3 bits of precision\n",
    "for trial in range(10):\n",
    "    try:\n",
    "        train()\n",
    "        print(\"Compiling the model to FHE.\")\n",
    "        quantized_compiled_module = compile_torch_model(\n",
    "            model,\n",
    "            X_train,\n",
    "            n_bits=3,\n",
    "        )\n",
    "        print(\"The network is trained and FHE friendly.\")\n",
    "        break\n",
    "    except Exception as e:\n",
    "        # Only an accumulator overflow triggers a retrain; re-raise anything else\n",
    "        if str(e).startswith(\"max_bit_width of some nodes is too high\"):\n",
    "            print('The network is not fully FHE friendly, retrain.')\n",
    "            train()\n",
    "        else:\n",
    "            raise"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict with the torch model in clear"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = model(X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict with the quantized model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We now have a module in full numpy.\n",
    "# Convert data to a numpy array.\n",
    "X_train_numpy = X_train.numpy()\n",
    "X_test_numpy = X_test.numpy()\n",
    "y_train_numpy = y_train.numpy()\n",
    "y_test_numpy = y_test.numpy()\n",
    "\n",
    "quant_model_predictions = quantized_compiled_module(X_test_numpy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict in FHE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 38/38 [03:03<00:00, 4.84s/it]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "homomorphic_quant_predictions = []\n",
    "for x_q in tqdm(X_test_numpy):\n",
    "    homomorphic_quant_predictions.append(\n",
    "        quantized_compiled_module.forward_fhe.run(np.array([x_q]).astype(np.uint8))\n",
    "    )\n",
    "homomorphic_predictions = quantized_compiled_module.dequantize_output(\n",
    "    np.array(homomorphic_quant_predictions, dtype=np.float32).reshape(quant_model_predictions.shape)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Print the accuracy of both models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy: 94.74%\n",
      "Test Accuracy Quantized Inference: 89.47%\n",
      "Test Accuracy Homomorphic Inference: 89.47%\n"
     ]
    }
   ],
   "source": [
    "print(f'Test Accuracy: {100*(y_pred.argmax(1) == y_test).float().mean():.2f}%')\n",
    "print(f'Test Accuracy Quantized Inference: {100*(quant_model_predictions.argmax(1) == y_test_numpy).mean():.2f}%')\n",
    "print(f'Test Accuracy Homomorphic Inference: {100*(homomorphic_predictions.argmax(1) == y_test_numpy).mean():.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 864x432 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "# Project the training data to 2D so the decision regions can be drawn\n",
    "pca = PCA(n_components=2)\n",
    "X_train_2d = pca.fit_transform(X_train_numpy)\n",
    "\n",
    "b_min = np.min(X_train_2d, axis=0)\n",
    "b_max = np.max(X_train_2d, axis=0)\n",
    "\n",
    "# Build a 128 x 128 grid covering the projected training data\n",
    "grid_dims = tuple([np.linspace(b_min[i], b_max[i], 128) for i in range(X_train_2d.shape[1])])\n",
    "ndgrid_tuple = np.meshgrid(*grid_dims)\n",
    "grid_2d = np.vstack([g.ravel() for g in ndgrid_tuple]).transpose()\n",
    "\n",
    "# Map the grid back to the original 4D feature space\n",
    "grid_test = pca.inverse_transform(grid_2d)\n",
    "\n",
    "grid_pred_all = quantized_compiled_module(grid_test)\n",
    "grid_pred_all_original = model(torch.tensor(grid_test).float()).detach().numpy()\n",
    "\n",
    "pred_classes = np.argmax(grid_pred_all, axis=1).astype(np.int32)\n",
    "pred_classes_original = np.argmax(grid_pred_all_original, axis=1).astype(np.int32)\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "cmap = 'autumn'\n",
    "# Create two subplots side by side\n",
    "plt.clf()\n",
    "fig, axs = plt.subplots(1, 2, figsize=(12, 6))\n",
    "\n",
    "# Plot original model contour plot\n",
    "axs[0].contourf(ndgrid_tuple[0], ndgrid_tuple[1], pred_classes_original.reshape(ndgrid_tuple[0].shape), cmap=cmap)\n",
    "\n",
    "# Plot the scatter with marker borders\n",
    "axs[0].scatter(X_train_2d[:, 0], X_train_2d[:, 1], c=y_train_numpy, s=50, edgecolors='k', cmap=cmap)\n",
    "\n",
    "# Add title\n",
    "axs[0].set_title('Original Inference')\n",
    "\n",
    "# Plot quantized model contour plot\n",
    "axs[1].contourf(ndgrid_tuple[0], ndgrid_tuple[1], pred_classes.reshape(ndgrid_tuple[0].shape), cmap=cmap)\n",
    "\n",
    "# Plot the scatter with marker borders\n",
    "axs[1].scatter(X_train_2d[:, 0], X_train_2d[:, 1], c=y_train_numpy, s=50, edgecolors='k', cmap=cmap)\n",
    "\n",
    "# Add title\n",
    "axs[1].set_title('Quantized Inference')\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the above plot, we show the decision boundaries for both the original and the quantized model. The quantized model has its decision boundaries (colored regions) slightly shifted compared to the original model. This is due to the low-bit post-training quantization applied to the model.\n",
    "\n",
    "Here we do not compute the contour plot for the FHE inference, as this would be very costly, but it should be close to that of the quantized model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "In this notebook, we presented the steps needed to run inference with a model (a torch neural network) over homomorphically encrypted data:\n",
    "- We first trained a fully connected neural network yielding ~95% accuracy.\n",
    "- Then, we quantized it using Concrete Numpy. As we can see, the extreme post-training quantization (only 3 bits of precision for weights, inputs and activations) made the neural network accuracy drop slightly (~89%).\n",
    "- We then used the model compiled to its FHE equivalent to get our FHE predictions over the test set.\n",
    "\n",
    "The homomorphic inference achieves the same accuracy as the quantized model inference on this test set (89.47%).\n",
    "\n",
    "Disclaimer: post-training quantization with such a low bit width (<=3) can yield different results for the quantized model, which mainly depend on the range of the learned weights."
   ]
  }
 ],
 "metadata": {
  "execution": {
   "timeout": 10800
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}