mirror of
https://github.com/socathie/circomlib-ml.git
synced 2026-04-23 03:00:32 -04:00
* add circuits and tests for separable convolution. Circuits do not yet comply with repo's quantization * make depthwise circuit compliant with quantization method from repo * make pointwise circuit compliant with quantization method from repo * separable convolution test works * clean up * fix typos and skip failing test * clean up duplicated code for depthwise conv * clean up duplicated code for pointwise conv * clean up duplicated code for separable conv notebook * chore: update filename to capital case --------- Co-authored-by: drCathieSo.eth <socathie@users.noreply.github.com>
194 lines
6.4 KiB
Plaintext
194 lines
6.4 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "4d60427f-21e9-41b1-a5eb-0d36d2c395ea",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"import torch.nn as nn\n",
|
|
"import torch.nn.functional as F\n",
|
|
"import numpy as np\n",
|
|
"import json"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "b1962e3f-18b6-43b2-88f8-e81a49f4edbc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"p = 21888242871839275222246405745257275088548364400416034343698204186575808495617\n",
|
|
"CIRCOM_PRIME = 21888242871839275222246405745257275088548364400416034343698204186575808495617\n",
|
|
"MAX_POSITIVE = CIRCOM_PRIME // 2\n",
|
|
"MAX_NEGATIVE = MAX_POSITIVE + 1 # The most negative number\n",
|
|
"\n",
|
|
"EXPONENT = 15\n",
|
|
"\n",
|
|
"def from_circom(x):\n",
|
|
" if type(x) != int:\n",
|
|
" x = int(x)\n",
|
|
" if x > MAX_POSITIVE: \n",
|
|
" return x - CIRCOM_PRIME\n",
|
|
" return x\n",
|
|
" \n",
|
|
"def to_circom(x):\n",
|
|
" if type(x) != int:\n",
|
|
" x = int(x)\n",
|
|
" if x < 0:\n",
|
|
" return x + CIRCOM_PRIME \n",
|
|
" return x\n",
|
|
"\n",
|
|
"class SeparableConv2D(nn.Module):\n",
|
|
" '''Separable convolution: a 3x3 depthwise conv followed by a 1x1 pointwise conv'''\n",
|
|
" def __init__(self, in_channels, out_channels, stride=1):\n",
|
|
" super(SeparableConv2D, self).__init__()\n",
|
|
" self.dw_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels, bias=False)\n",
|
|
" self.pw_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)\n",
|
|
"\n",
|
|
" def forward(self, x):\n",
|
|
" x = self.dw_conv(x)\n",
|
|
" x = self.pw_conv(x)\n",
|
|
" return x"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "a7ad1f77-24e0-470e-b4de-63234ac9542b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"input = torch.randn((1, 3, 5, 5))\n",
|
|
"model = SeparableConv2D(3, 6)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "88e91743-d234-4e55-bd65-f4a5b0f5b350",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def DepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):\n",
|
|
" assert(nFilters % nChannels == 0)\n",
|
|
" outRows = (nRows - kernelSize)//strides + 1\n",
|
|
" outCols = (nCols - kernelSize)//strides + 1\n",
|
|
" \n",
|
|
" # out = np.zeros((outRows, outCols, nFilters))\n",
|
|
" out = [[[0 for _ in range(nFilters)] for _ in range(outCols)] for _ in range(outRows)]\n",
|
|
" remainder = [[[0 for _ in range(nFilters)] for _ in range(outCols)] for _ in range(outRows)]\n",
|
|
" # remainder = np.zeros((outRows, outCols, nFilters))\n",
|
|
" \n",
|
|
" for row in range(outRows):\n",
|
|
" for col in range(outCols):\n",
|
|
" for channel in range(nChannels):\n",
|
|
" for x in range(kernelSize):\n",
|
|
" for y in range(kernelSize):\n",
|
|
" out[row][col][channel] += int(input[row*strides+x, col*strides+y, channel]) * int(weights[x, y, channel])\n",
|
|
" \n",
|
|
" out[row][col][channel] += int(bias[channel])\n",
|
|
" remainder[row][col][channel] = str(int(out[row][col][channel] % n))\n",
|
|
" out[row][col][channel] = int(out[row][col][channel] // n)\n",
|
|
" \n",
|
|
" return out, remainder"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "e666c225-f618-43d4-b003-56f9b4699d2e",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"weights = model.dw_conv.weight.squeeze().detach().numpy()\n",
|
|
"bias = torch.zeros(weights.shape[0]).numpy()\n",
|
|
"\n",
|
|
"expected = model.dw_conv(input).detach().numpy()\n",
|
|
"\n",
|
|
"padded = F.pad(input, (1,1,1,1), \"constant\", 0)\n",
|
|
"padded = padded.squeeze().numpy().transpose((1, 2, 0))\n",
|
|
"weights = weights.transpose((1, 2, 0))\n",
|
|
"\n",
|
|
"quantized_image = padded * 10**EXPONENT\n",
|
|
"quantized_weights = weights * 10**EXPONENT\n",
|
|
"\n",
|
|
"actual, rem = DepthwiseConv(7, 7, 3, 3, 3, 1, 10**EXPONENT, quantized_image.round(), quantized_weights.round(), bias)\n",
|
|
"\n",
|
|
"expected = expected.squeeze().transpose((1, 2, 0))\n",
|
|
"expected = expected * 10**EXPONENT\n",
|
|
"\n",
|
|
"assert(np.allclose(expected, actual, atol=0.00001))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "904ce6c4-f1d4-43f3-80f0-5e3df61d5546",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"weights = model.dw_conv.weight.squeeze().detach().numpy()\n",
|
|
"bias = torch.zeros(weights.shape[0]).numpy()\n",
|
|
"\n",
|
|
"padded = F.pad(input, (1,1,1,1), \"constant\", 0)\n",
|
|
"padded = padded.squeeze().numpy().transpose((1, 2, 0))\n",
|
|
"weights = weights.transpose((1, 2, 0))\n",
|
|
"\n",
|
|
"quantized_image = padded * 10**EXPONENT\n",
|
|
"quantized_weights = weights * 10**EXPONENT\n",
|
|
"\n",
|
|
"out, remainder = DepthwiseConv(7, 7, 3, 3, 3, 1, 10**EXPONENT, quantized_image.round(), quantized_weights.round(), bias)\n",
|
|
"\n",
|
|
"circuit_in = quantized_image.round().astype(int).astype(str).tolist()\n",
|
|
"circuit_weights = quantized_weights.round().astype(int).astype(str).tolist()\n",
|
|
"circuit_bias = bias.round().astype(int).astype(str).tolist()\n",
|
|
"\n",
|
|
"input_json_path = \"depthwiseConv2D_input.json\"\n",
|
|
"with open(input_json_path, \"w\") as input_file:\n",
|
|
" json.dump({\"in\": circuit_in,\n",
|
|
" \"weights\": circuit_weights,\n",
|
|
" \"remainder\": remainder,\n",
|
|
" \"out\": out,\n",
|
|
" \"bias\": circuit_bias,\n",
|
|
" },\n",
|
|
" input_file)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "523588d7-4c81-4bb9-9dbd-e626b6d2a8a9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.5"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|