Files
circomlib-ml/models/depthwiseConv2D.ipynb
Semar Augusto 418917338a Add separable convolution circuit implementation (#7)
* add circuits and tests for separable convolution. Circuits do not yet comply with the repo's quantization method

* make depthwise circuit compliant with quantization method from repo

* make pointwise circuit compliant with quantization method from repo

* separable convolution test works

* clean up

* fix typos and skip failing test

* clean up duplicated code for depthwise conv

* clean up duplicated code for pointwise conv

* clean up duplicated code for separable conv notebook

* chore: update filename to capital case

---------

Co-authored-by: drCathieSo.eth <socathie@users.noreply.github.com>
2023-11-26 18:53:14 +08:00

194 lines
6.4 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "4d60427f-21e9-41b1-a5eb-0d36d2c395ea",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import numpy as np\n",
"import json"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b1962e3f-18b6-43b2-88f8-e81a49f4edbc",
"metadata": {},
"outputs": [],
"source": [
"p = 21888242871839275222246405745257275088548364400416034343698204186575808495617\n",
"CIRCOM_PRIME = 21888242871839275222246405745257275088548364400416034343698204186575808495617\n",
"MAX_POSITIVE = CIRCOM_PRIME // 2\n",
"MAX_NEGATIVE = MAX_POSITIVE + 1 # Smallest field element that decodes to a negative value (it represents the most negative number)\n",
"\n",
"EXPONENT = 15\n",
"\n",
"def from_circom(x):\n",
" if type(x) != int:\n",
" x = int(x)\n",
" if x > MAX_POSITIVE: \n",
" return x - CIRCOM_PRIME\n",
" return x\n",
" \n",
"def to_circom(x):\n",
" if type(x) != int:\n",
" x = int(x)\n",
" if x < 0:\n",
" return x + CIRCOM_PRIME \n",
" return x\n",
"\n",
"class SeparableConv2D(nn.Module):\n",
" '''Separable convolution: a depthwise 3x3 convolution followed by a pointwise 1x1 convolution, both without bias.'''\n",
" def __init__(self, in_channels, out_channels, stride=1):\n",
" super(SeparableConv2D, self).__init__()\n",
" self.dw_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels, bias=False)\n",
" self.pw_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)\n",
"\n",
" def forward(self, x):\n",
" x = self.dw_conv(x)\n",
" x = self.pw_conv(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a7ad1f77-24e0-470e-b4de-63234ac9542b",
"metadata": {},
"outputs": [],
"source": [
"input = torch.randn((1, 3, 5, 5))\n",
"model = SeparableConv2D(3, 6)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "88e91743-d234-4e55-bd65-f4a5b0f5b350",
"metadata": {},
"outputs": [],
"source": [
"def DepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):\n",
" assert(nFilters % nChannels == 0)\n",
" outRows = (nRows - kernelSize)//strides + 1\n",
" outCols = (nCols - kernelSize)//strides + 1\n",
" \n",
" # out = np.zeros((outRows, outCols, nFilters))\n",
" out = [[[0 for _ in range(nFilters)] for _ in range(outCols)] for _ in range(outRows)]\n",
" remainder = [[[0 for _ in range(nFilters)] for _ in range(outCols)] for _ in range(outRows)]\n",
" # remainder = np.zeros((outRows, outCols, nFilters))\n",
" \n",
" for row in range(outRows):\n",
" for col in range(outCols):\n",
" for channel in range(nChannels):\n",
" for x in range(kernelSize):\n",
" for y in range(kernelSize):\n",
" out[row][col][channel] += int(input[row*strides+x, col*strides+y, channel]) * int(weights[x, y, channel])\n",
" \n",
" out[row][col][channel] += int(bias[channel])\n",
" remainder[row][col][channel] = str(int(out[row][col][channel] % n))\n",
" out[row][col][channel] = int(out[row][col][channel] // n)\n",
" \n",
" return out, remainder"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e666c225-f618-43d4-b003-56f9b4699d2e",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"weights = model.dw_conv.weight.squeeze().detach().numpy()\n",
"bias = torch.zeros(weights.shape[0]).numpy()\n",
"\n",
"expected = model.dw_conv(input).detach().numpy()\n",
"\n",
"padded = F.pad(input, (1,1,1,1), \"constant\", 0)\n",
"padded = padded.squeeze().numpy().transpose((1, 2, 0))\n",
"weights = weights.transpose((1, 2, 0))\n",
"\n",
"quantized_image = padded * 10**EXPONENT\n",
"quantized_weights = weights * 10**EXPONENT\n",
"\n",
"actual, rem = DepthwiseConv(7, 7, 3, 3, 3, 1, 10**EXPONENT, quantized_image.round(), quantized_weights.round(), bias)\n",
"\n",
"expected = expected.squeeze().transpose((1, 2, 0))\n",
"expected = expected * 10**EXPONENT\n",
"\n",
"assert(np.allclose(expected, actual, atol=0.00001))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "904ce6c4-f1d4-43f3-80f0-5e3df61d5546",
"metadata": {},
"outputs": [],
"source": [
"weights = model.dw_conv.weight.squeeze().detach().numpy()\n",
"bias = torch.zeros(weights.shape[0]).numpy()\n",
"\n",
"padded = F.pad(input, (1,1,1,1), \"constant\", 0)\n",
"padded = padded.squeeze().numpy().transpose((1, 2, 0))\n",
"weights = weights.transpose((1, 2, 0))\n",
"\n",
"quantized_image = padded * 10**EXPONENT\n",
"quantized_weights = weights * 10**EXPONENT\n",
"\n",
"out, remainder = DepthwiseConv(7, 7, 3, 3, 3, 1, 10**EXPONENT, quantized_image.round(), quantized_weights.round(), bias)\n",
"\n",
"circuit_in = quantized_image.round().astype(int).astype(str).tolist()\n",
"circuit_weights = quantized_weights.round().astype(int).astype(str).tolist()\n",
"circuit_bias = bias.round().astype(int).astype(str).tolist()\n",
"\n",
"input_json_path = \"depthwiseConv2D_input.json\"\n",
"with open(input_json_path, \"w\") as input_file:\n",
" json.dump({\"in\": circuit_in,\n",
" \"weights\": circuit_weights,\n",
" \"remainder\": remainder,\n",
" \"out\": out,\n",
" \"bias\": circuit_bias,\n",
" },\n",
" input_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "523588d7-4c81-4bb9-9dbd-e626b6d2a8a9",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}