v1.3.0 - added Conv1D layer, added strides compatibility to Conv2D

This commit is contained in:
Cathie So
2022-06-28 01:04:19 +08:00
parent 0a93f87abc
commit adb9eddff0
18 changed files with 723 additions and 19 deletions

40
circuits/Conv1D.circom Normal file
View File

@@ -0,0 +1,40 @@
pragma circom 2.0.3;
include "./circomlib-matrix/matElemMul.circom";
include "./circomlib-matrix/matElemSum.circom";
include "./util.circom";
// Conv1D layer with valid padding
template Conv1D (nInputs, nChannels, nFilters, kernelSize, strides) {
signal input in[nInputs][nChannels];
signal input weights[kernelSize][nChannels][nFilters];
signal input bias[nFilters];
signal output out[(nInputs-kernelSize)\strides+1][nFilters];
component mul[(nInputs-kernelSize)\strides+1][nChannels][nFilters];
component elemSum[(nInputs-kernelSize)\strides+1][nChannels][nFilters];
component sum[(nInputs-kernelSize)\strides+1][nFilters];
for (var i=0; i<(nInputs-kernelSize)\strides+1; i++) {
for (var j=0; j<nChannels; j++) {
for (var k=0; k<nFilters; k++) {
mul[i][j][k] = matElemMul(kernelSize,1);
for (var x=0; x<kernelSize; x++) {
mul[i][j][k].a[x][0] <== in[i*strides+x][j];
mul[i][j][k].b[x][0] <== weights[x][j][k];
}
elemSum[i][j][k] = matElemSum(kernelSize,1);
for (var x=0; x<kernelSize; x++) {
elemSum[i][j][k].a[x][0] <== mul[i][j][k].out[x][0];
}
}
}
for (var k=0; k<nFilters; k++) {
sum[i][k] = Sum(nChannels);
for (var j=0; j<nChannels; j++) {
sum[i][k].in[j] <== elemSum[i][j][k].out;
}
out[i][k] <== sum[i][k].out + bias[k];
}
}
}

View File

@@ -4,25 +4,25 @@ include "./circomlib-matrix/matElemMul.circom";
include "./circomlib-matrix/matElemSum.circom";
include "./util.circom";
// Conv2D layer with strides=0
template Conv2D (nRows, nCols, nChannels, nFilters, kernelSize) {
// Conv2D layer with valid padding
template Conv2D (nRows, nCols, nChannels, nFilters, kernelSize, strides) {
signal input in[nRows][nCols][nChannels];
signal input weights[kernelSize][kernelSize][nChannels][nFilters];
signal input bias[nFilters];
signal output out[nRows-kernelSize+1][nCols-kernelSize+1][nFilters];
signal output out[(nRows-kernelSize)\strides+1][(nCols-kernelSize)\strides+1][nFilters];
component mul[nRows-kernelSize+1][nCols-kernelSize+1][nChannels][nFilters];
component elemSum[nRows-kernelSize+1][nCols-kernelSize+1][nChannels][nFilters];
component sum[nRows-kernelSize+1][nCols-kernelSize+1][nFilters];
component mul[(nRows-kernelSize)\strides+1][(nCols-kernelSize)\strides+1][nChannels][nFilters];
component elemSum[(nRows-kernelSize)\strides+1][(nCols-kernelSize)\strides+1][nChannels][nFilters];
component sum[(nRows-kernelSize)\strides+1][(nCols-kernelSize)\strides+1][nFilters];
for (var i=0; i<nRows-kernelSize+1; i++) {
for (var j=0; j<nCols-kernelSize+1; j++) {
for (var i=0; i<(nRows-kernelSize)\strides+1; i++) {
for (var j=0; j<(nCols-kernelSize)\strides+1; j++) {
for (var k=0; k<nChannels; k++) {
for (var m=0; m<nFilters; m++) {
mul[i][j][k][m] = matElemMul(kernelSize,kernelSize);
for (var x=0; x<kernelSize; x++) {
for (var y=0; y<kernelSize; y++) {
mul[i][j][k][m].a[x][y] <== in[i+x][j+y][k];
mul[i][j][k][m].a[x][y] <== in[i*strides+x][j*strides+y][k];
mul[i][j][k][m].b[x][y] <== weights[x][y][k][m];
}
}

1
models/conv1D_input.json Normal file
View File

@@ -0,0 +1 @@
{"in": [845, 408, 904, 651, 599, 600, 449, 255, 400, 577, 960, 82, 666, 4, 646, 855, 386, 851, 275, 201, 188, 901, 338, 848, 362, 394, 641, 350, 372, 484, 111, 605, 584, 871, 646, 484, 964, 294, 220, 376, 136, 772, 428, 411, 283, 666, 954, 664, 415, 375, 231, 543, 139, 853, 253, 377, 750, 755, 528, 462], "weights": [134, 152, 23, -53, 313, -297, 272, -150, 108, -367, 142, -291, -486, -424, 77, 323, -111, -40, 266, 8, 119, 295, -428, 32], "bias": [0.0, 0.0]}

View File

@@ -0,0 +1 @@
{"out": [721859, -487530, -65677, -480889, 216206, -453824, 155316, -596682, 230192, -17062, -51941, -495391]}

238
models/conv1d.ipynb Normal file
View File

@@ -0,0 +1,238 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.layers import Input, Conv1D\n",
"from tensorflow.keras import Model\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"inputs = Input(shape=(20,3))\n",
"x = Conv1D(2,4,3)(inputs)\n",
"model = Model(inputs, x)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"input_1 (InputLayer) [(None, 20, 3)] 0 \n",
"_________________________________________________________________\n",
"conv1d (Conv1D) (None, 6, 2) 26 \n",
"=================================================================\n",
"Total params: 26\n",
"Trainable params: 26\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<tf.Variable 'conv1d/kernel:0' shape=(4, 3, 2) dtype=float32, numpy=\n",
" array([[[ 0.13351655, 0.15176392],\n",
" [ 0.02279764, -0.05288953],\n",
" [ 0.31283128, -0.2973013 ]],\n",
" \n",
" [[ 0.27222842, -0.14998454],\n",
" [ 0.10822219, -0.36693448],\n",
" [ 0.14209783, -0.29082325]],\n",
" \n",
" [[-0.48623034, -0.4235212 ],\n",
" [ 0.07678127, 0.32315302],\n",
" [-0.11050957, -0.03980196]],\n",
" \n",
" [[ 0.26589322, 0.00770217],\n",
" [ 0.11929381, 0.29532135],\n",
" [-0.42807332, 0.03231919]]], dtype=float32)>,\n",
" <tf.Variable 'conv1d/bias:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.weights"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[0.84452152, 0.40819381, 0.90425158],\n",
" [0.65092274, 0.59893334, 0.60028247],\n",
" [0.44949689, 0.254961 , 0.40024966],\n",
" [0.57719296, 0.96000249, 0.08217882],\n",
" [0.66562102, 0.00446413, 0.6464117 ],\n",
" [0.85546508, 0.38582714, 0.8505983 ],\n",
" [0.27450673, 0.201367 , 0.18818527],\n",
" [0.90095107, 0.33781381, 0.84773899],\n",
" [0.36179829, 0.39354172, 0.64100907],\n",
" [0.35045615, 0.37234526, 0.48415795],\n",
" [0.11139122, 0.60499841, 0.58387442],\n",
" [0.87088195, 0.64640523, 0.48402816],\n",
" [0.9638806 , 0.29440207, 0.22027009],\n",
" [0.3764113 , 0.13575624, 0.7720717 ],\n",
" [0.42784105, 0.41073501, 0.28311926],\n",
" [0.66622245, 0.95378854, 0.66375939],\n",
" [0.41543282, 0.37529526, 0.23072244],\n",
" [0.54334445, 0.1388458 , 0.85307472],\n",
" [0.25258015, 0.37725648, 0.75018739],\n",
" [0.75532678, 0.52800654, 0.46152891]]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = np.random.rand(1,20,3)\n",
"X"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[ 0.7218593 , -0.4875301 ],\n",
" [-0.06567734, -0.48088893],\n",
" [ 0.21620643, -0.45382428],\n",
" [ 0.1553162 , -0.5966817 ],\n",
" [ 0.23019192, -0.01706157],\n",
" [-0.05194061, -0.49539083]]], dtype=float32)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = model.predict(X)\n",
"y"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"in_json = {\n",
" \"in\": (X*1000).round().astype(int).flatten().tolist(),\n",
" \"weights\": (model.weights[0].numpy()*1000).round().astype(int).flatten().tolist(),\n",
" \"bias\": np.zeros(model.weights[1].numpy().shape).tolist()\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"out_json = {\n",
" \"out\": (y*1000000).round().astype(int).flatten().tolist()\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import json"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"with open(\"conv1D_input.json\", \"w\") as f:\n",
" json.dump(in_json, f)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"with open(\"conv1D_output.json\", \"w\") as f:\n",
" json.dump(out_json, f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tf24",
"language": "python",
"name": "tf24"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1 @@
{"in": [973, 455, 27, 779, 573, 643, 791, 682, 559, 200, 164, 118, 262, 737, 415, 916, 189, 868, 905, 880, 569, 158, 662, 768, 912, 183, 835, 547, 665, 882, 760, 511, 931, 911, 564, 152, 37, 373, 105, 447, 543, 840, 795, 91, 357, 518, 561, 588, 847, 638, 660, 925, 632, 747, 731, 405, 980, 638, 895, 327, 718, 461, 718, 158, 810, 519, 161, 782, 715, 862, 535, 393, 409, 14, 791, 405, 854, 711, 403, 406, 349, 327, 459, 349, 534, 930, 727, 482, 222, 238, 704, 916, 922, 134, 835, 747, 551, 935, 413, 462, 117, 963, 746, 181, 927, 679, 843, 490, 562, 944, 920, 234, 287, 500, 877, 33, 594, 141, 379, 934, 333, 767, 35, 121, 69, 326, 597, 595, 10, 951, 478, 18, 797, 203, 334, 619, 975, 733, 203, 223, 546, 968, 964, 252, 399, 236, 546, 318, 551, 708, 235, 173, 473, 639, 872, 414, 431, 180, 288, 79, 425, 298, 779, 265, 194, 890, 177, 936, 955, 714, 22, 775, 211, 717, 211, 682, 993, 441, 351, 184, 727, 344, 453, 700, 742, 475, 907, 793, 925, 354, 277, 626, 803, 380, 886, 706, 735, 646, 762, 341, 252, 414, 812, 329, 819, 432, 642, 257, 757, 562, 421, 332, 893, 254, 75, 167, 421, 207, 818, 805, 627, 877, 665, 805, 20, 495, 286, 590, 743, 648, 284, 138, 29, 394, 727, 503, 112, 20, 542, 601, 538, 918, 390, 37, 925, 913, 347, 4, 175, 165, 398, 430, 523, 173, 81, 614, 508, 262, 361, 619, 4, 12, 702, 301, 563, 847, 882, 787, 70, 65, 314, 29, 109, 828, 132, 570, 340, 942, 938, 283, 529, 359, 948, 358, 255, 801, 951, 682, 23, 986, 691, 544, 206, 910, 985, 620, 898, 872, 771, 845], "weights": [90, 64, 172, 119, -231, 77, 56, 87, 109, -132, 72, -49, -144, -204, -186, -236, -135, -200, -38, 245, 96, 129, 139, -61, -154, -157, 175, 234, -94, 137, 139, -52, 123, 31, -228, 164, 141, 0, 199, 256, 174, -155, 212, 134, 239, 158, 12, -174, -134, -88, -182, 49, 191, -183, 244, 117, -15, 149, 1, 143, -60, -36, -44, -107, 178, 221, -140, -153, 37, -244, -257, -119, 92, -150, 97, -77, -102, -13, 197, -218, 196, 54, -239, 138, 153, 237, 252, 272, 114, -62, -14, 253, 161, 236, -116, -99], "bias": [0, 0]}

View File

@@ -0,0 +1 @@
{"out": [604066, 49637, 825491, 594827, 930285, 228053, 1282346, 594940, 897624, 379819, 1158895, 446232, 593408, 64606, 1260976, 251827, 504049, 658720]}

View File

@@ -241,6 +241,336 @@
" json.dump(out_json, f)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"inputs = Input(shape=(10,10,3))\n",
"x = Conv2D(2, 4, 3)(inputs)\n",
"model = Model(inputs, x)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"input_1 (InputLayer) [(None, 10, 10, 3)] 0 \n",
"_________________________________________________________________\n",
"conv2d (Conv2D) (None, 3, 3, 2) 98 \n",
"=================================================================\n",
"Total params: 98\n",
"Trainable params: 98\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<tf.Variable 'conv2d/kernel:0' shape=(4, 4, 3, 2) dtype=float32, numpy=\n",
" array([[[[ 9.01020467e-02, 6.39823377e-02],\n",
" [ 1.72438890e-01, 1.19038552e-01],\n",
" [-2.31111139e-01, 7.65160322e-02]],\n",
" \n",
" [[ 5.60960472e-02, 8.70717168e-02],\n",
" [ 1.09150052e-01, -1.31718382e-01],\n",
" [ 7.23880231e-02, -4.93483245e-02]],\n",
" \n",
" [[-1.44102752e-01, -2.04105571e-01],\n",
" [-1.85530186e-01, -2.36467183e-01],\n",
" [-1.34880558e-01, -1.99714184e-01]],\n",
" \n",
" [[-3.83580476e-02, 2.45451868e-01],\n",
" [ 9.63985324e-02, 1.29009217e-01],\n",
" [ 1.38856411e-01, -6.06727898e-02]]],\n",
" \n",
" \n",
" [[[-1.54422879e-01, -1.56738907e-01],\n",
" [ 1.74920231e-01, 2.34346807e-01],\n",
" [-9.43229645e-02, 1.37158900e-01]],\n",
" \n",
" [[ 1.38553441e-01, -5.22587299e-02],\n",
" [ 1.23419642e-01, 3.11977565e-02],\n",
" [-2.27711692e-01, 1.64391518e-01]],\n",
" \n",
" [[ 1.41024947e-01, 1.13755465e-04],\n",
" [ 1.99488789e-01, 2.56015778e-01],\n",
" [ 1.74426168e-01, -1.55410171e-01]],\n",
" \n",
" [[ 2.12402374e-01, 1.33800596e-01],\n",
" [ 2.38744080e-01, 1.57932788e-01],\n",
" [ 1.21993423e-02, -1.74463570e-01]]],\n",
" \n",
" \n",
" [[[-1.34224743e-01, -8.76786262e-02],\n",
" [-1.82140529e-01, 4.94003594e-02],\n",
" [ 1.90961450e-01, -1.82582170e-01]],\n",
" \n",
" [[ 2.43920863e-01, 1.16816819e-01],\n",
" [-1.51267648e-02, 1.49045289e-01],\n",
" [ 5.38736582e-04, 1.42932892e-01]],\n",
" \n",
" [[-6.00645095e-02, -3.57142389e-02],\n",
" [-4.42285985e-02, -1.07084349e-01],\n",
" [ 1.78276598e-01, 2.21306920e-01]],\n",
" \n",
" [[-1.39693484e-01, -1.52673334e-01],\n",
" [ 3.69200110e-02, -2.44445860e-01],\n",
" [-2.57119805e-01, -1.18823618e-01]]],\n",
" \n",
" \n",
" [[[ 9.21463966e-02, -1.50290370e-01],\n",
" [ 9.74731147e-02, -7.73224682e-02],\n",
" [-1.02264047e-01, -1.25368834e-02]],\n",
" \n",
" [[ 1.96688592e-01, -2.18126133e-01],\n",
" [ 1.96370363e-01, 5.42735457e-02],\n",
" [-2.39483550e-01, 1.38274789e-01]],\n",
" \n",
" [[ 1.53252542e-01, 2.36619353e-01],\n",
" [ 2.52445579e-01, 2.71820903e-01],\n",
" [ 1.13983452e-01, -6.24387860e-02]],\n",
" \n",
" [[-1.44442618e-02, 2.52830923e-01],\n",
" [ 1.60841286e-01, 2.35763371e-01],\n",
" [-1.15650281e-01, -9.92593616e-02]]]], dtype=float32)>,\n",
" <tf.Variable 'conv2d/bias:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.weights"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[[0.97251485, 0.45513694, 0.02700021],\n",
" [0.77939851, 0.57285348, 0.64336143],\n",
" [0.7910932 , 0.68158499, 0.55903757],\n",
" [0.19978621, 0.16401386, 0.1178243 ],\n",
" [0.26202894, 0.73739062, 0.41515085],\n",
" [0.91602252, 0.1889969 , 0.86812371],\n",
" [0.90520645, 0.88021945, 0.56918999],\n",
" [0.1576814 , 0.66244825, 0.76843253],\n",
" [0.91151503, 0.18348028, 0.83457797],\n",
" [0.54744986, 0.6653615 , 0.88219377]],\n",
"\n",
" [[0.75996691, 0.51053919, 0.93055488],\n",
" [0.91144281, 0.56385737, 0.15222672],\n",
" [0.03655235, 0.37340661, 0.10463077],\n",
" [0.44692914, 0.54256831, 0.84009866],\n",
" [0.79509769, 0.09059503, 0.35705721],\n",
" [0.51836358, 0.56141664, 0.58831904],\n",
" [0.84739405, 0.63802925, 0.66017923],\n",
" [0.92534248, 0.63225917, 0.7473312 ],\n",
" [0.73096606, 0.40520727, 0.97955684],\n",
" [0.63831414, 0.89495898, 0.32695978]],\n",
"\n",
" [[0.71829301, 0.46062667, 0.71811823],\n",
" [0.15848069, 0.80969418, 0.51871761],\n",
" [0.16135808, 0.78216571, 0.71516724],\n",
" [0.86211842, 0.53539272, 0.39332206],\n",
" [0.40901593, 0.01396001, 0.79064577],\n",
" [0.40514149, 0.85380516, 0.71092302],\n",
" [0.40279167, 0.40553908, 0.34943154],\n",
" [0.32739584, 0.45935947, 0.34916069],\n",
" [0.53435728, 0.93021973, 0.72660915],\n",
" [0.48201736, 0.22166786, 0.23815264]],\n",
"\n",
" [[0.70384743, 0.91639052, 0.92217664],\n",
" [0.13377693, 0.83509932, 0.7465723 ],\n",
" [0.55129428, 0.93502285, 0.4134988 ],\n",
" [0.46219387, 0.11715096, 0.96329141],\n",
" [0.74610843, 0.18125007, 0.92666074],\n",
" [0.67935838, 0.84330332, 0.4897472 ],\n",
" [0.56189416, 0.94376182, 0.91971249],\n",
" [0.2344166 , 0.28675204, 0.50005015],\n",
" [0.87694834, 0.03292064, 0.59372731],\n",
" [0.1410371 , 0.37930507, 0.93440982]],\n",
"\n",
" [[0.33332779, 0.76704737, 0.03504045],\n",
" [0.12148567, 0.06901141, 0.32551204],\n",
" [0.59735274, 0.59523581, 0.00993422],\n",
" [0.9513948 , 0.47815931, 0.018291 ],\n",
" [0.79727936, 0.20293482, 0.33432458],\n",
" [0.61879149, 0.97450524, 0.73288953],\n",
" [0.20291612, 0.2230533 , 0.54581403],\n",
" [0.96841368, 0.96407929, 0.25184421],\n",
" [0.39925087, 0.23617574, 0.54586969],\n",
" [0.31766469, 0.55082408, 0.70822638]],\n",
"\n",
" [[0.23476765, 0.17316881, 0.47290996],\n",
" [0.63928832, 0.87168719, 0.41376668],\n",
" [0.43124419, 0.18044566, 0.28764125],\n",
" [0.07918316, 0.42466569, 0.29817 ],\n",
" [0.77903568, 0.26465035, 0.19378377],\n",
" [0.88960928, 0.17741272, 0.93643759],\n",
" [0.95486371, 0.71402737, 0.02170905],\n",
" [0.77465114, 0.21071205, 0.71740116],\n",
" [0.2105675 , 0.68209689, 0.99265747],\n",
" [0.44123258, 0.35068216, 0.18443967]],\n",
"\n",
" [[0.72729028, 0.34422837, 0.45330759],\n",
" [0.70008861, 0.74244728, 0.47519843],\n",
" [0.90665756, 0.79266484, 0.92471306],\n",
" [0.35356468, 0.27658691, 0.62612034],\n",
" [0.80262629, 0.38014853, 0.88572335],\n",
" [0.70577279, 0.73463977, 0.64607726],\n",
" [0.76175153, 0.34114261, 0.25165355],\n",
" [0.41433933, 0.81211749, 0.32920431],\n",
" [0.81885149, 0.43244061, 0.64193369],\n",
" [0.25711737, 0.75723915, 0.56240932]],\n",
"\n",
" [[0.4214149 , 0.33210466, 0.89250414],\n",
" [0.25449979, 0.07459641, 0.16747531],\n",
" [0.42138569, 0.20694074, 0.81756281],\n",
" [0.80522426, 0.62663345, 0.87747416],\n",
" [0.66534828, 0.80525886, 0.01988047],\n",
" [0.4948794 , 0.28630048, 0.58954056],\n",
" [0.74320486, 0.64790933, 0.28396301],\n",
" [0.13750094, 0.02872952, 0.39362795],\n",
" [0.72667091, 0.50329443, 0.11188484],\n",
" [0.01977563, 0.54219458, 0.60052706]],\n",
"\n",
" [[0.53801627, 0.91787638, 0.38982677],\n",
" [0.03749426, 0.92513169, 0.91275971],\n",
" [0.34702074, 0.00438524, 0.17541972],\n",
" [0.16459438, 0.39827819, 0.43025267],\n",
" [0.52254362, 0.17287154, 0.08078938],\n",
" [0.61399286, 0.50806298, 0.26215701],\n",
" [0.36095081, 0.61864805, 0.00384717],\n",
" [0.01225557, 0.70154596, 0.30105766],\n",
" [0.56348969, 0.84660337, 0.88242305],\n",
" [0.78735834, 0.06987492, 0.06523453]],\n",
"\n",
" [[0.31370938, 0.02861093, 0.10943704],\n",
" [0.82845993, 0.13236614, 0.56989642],\n",
" [0.33997596, 0.94155797, 0.93786532],\n",
" [0.28295679, 0.52881401, 0.35867141],\n",
" [0.94786737, 0.35825851, 0.25476474],\n",
" [0.80097957, 0.95050832, 0.68227972],\n",
" [0.02266812, 0.98557197, 0.69057788],\n",
" [0.54403132, 0.20586683, 0.90978653],\n",
" [0.9848536 , 0.62006799, 0.89784117],\n",
" [0.87154093, 0.77054866, 0.84496778]]]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = np.random.rand(1,10,10,3)\n",
"X"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[[0.60406584, 0.04963745],\n",
" [0.8254909 , 0.59482723],\n",
" [0.93028533, 0.22805347]],\n",
"\n",
" [[1.2823461 , 0.5949404 ],\n",
" [0.897624 , 0.3798192 ],\n",
" [1.1588949 , 0.44623226]],\n",
"\n",
" [[0.5934084 , 0.06460603],\n",
" [1.2609756 , 0.25182724],\n",
" [0.5040486 , 0.6587203 ]]]], dtype=float32)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = model.predict(X)\n",
"y"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"in_json = {\n",
" \"in\": (X*1000).round().astype(int).flatten().tolist(),\n",
" \"weights\": (model.weights[0].numpy()*1000).round().astype(int).flatten().tolist(),\n",
" \"bias\": [0,0]\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"out_json = {\n",
" \"out\": (y*1000000).round().astype(int).flatten().tolist()\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"with open(\"conv2D_stride_input.json\", \"w\") as f:\n",
" json.dump(in_json, f)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"with open(\"conv2D_stride_output.json\", \"w\") as f:\n",
" json.dump(out_json, f)"
]
},
{
"cell_type": "code",
"execution_count": null,

4
package-lock.json generated
View File

@@ -1,12 +1,12 @@
{
"name": "circomlib-ml",
"version": "1.2.1",
"version": "1.3.0",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "circomlib-ml",
"version": "1.2.1",
"version": "1.3.0",
"license": "GPL-3.0",
"devDependencies": {
"blake-hash": "^2.0.0",

View File

@@ -1,6 +1,6 @@
{
"name": "circomlib-ml",
"version": "1.2.1",
"version": "1.3.0",
"description": "Circuits library for machine learning in circom",
"main": "index.js",
"directories": {

49
test/Conv1D.js Normal file
View File

@@ -0,0 +1,49 @@
const chai = require("chai");
const { Console } = require("console");
const path = require("path");
const wasm_tester = require("circom_tester").wasm;
const F1Field = require("ffjavascript").F1Field;
const Scalar = require("ffjavascript").Scalar;
exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
const Fr = new F1Field(exports.p);
const assert = chai.assert;
const json = require("../models/conv1D_input.json");
const OUTPUT = require("../models/conv1D_output.json");
describe("Conv1D layer test", function () {
this.timeout(100000000);
it("(20,3) -> (6,2)", async () => {
const circuit = await wasm_tester(path.join(__dirname, "circuits", "Conv1D_test.circom"));
//await circuit.loadConstraints();
//assert.equal(circuit.nVars, 618);
//assert.equal(circuit.constraints.length, 486);
let INPUT = {};
for (const [key, value] of Object.entries(json)) {
if (Array.isArray(value)) {
let tmpArray = [];
for (let i = 0; i < value.flat().length; i++) {
tmpArray.push(Fr.e(value.flat()[i]));
}
INPUT[key] = tmpArray;
} else {
INPUT[key] = Fr.e(value);
}
}
const witness = await circuit.calculateWitness(INPUT, true);
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<6*2; i++) {
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(5000));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(5000));
}
});
});

View File

@@ -11,13 +11,15 @@ const Fr = new F1Field(exports.p);
const assert = chai.assert;
const json = require("../models/conv2D_input.json");
const OUTPUT = require("../models/conv2D_output.json");
describe("Conv2D layer test", function () {
this.timeout(100000000);
it("(5,5,3) -> (3,3,2)", async () => {
let json = require("../models/conv2D_input.json");
let OUTPUT = require("../models/conv2D_output.json");
const circuit = await wasm_tester(path.join(__dirname, "circuits", "Conv2D_test.circom"));
//await circuit.loadConstraints();
//assert.equal(circuit.nVars, 618);
@@ -44,4 +46,35 @@ describe("Conv2D layer test", function () {
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(5000));
}
});
it("(10,10,3) -> (3,3,2)", async () => {
let json = require("../models/conv2D_stride_input.json");
let OUTPUT = require("../models/conv2D_stride_output.json");
const circuit = await wasm_tester(path.join(__dirname, "circuits", "Conv2D_stride_test.circom"));
//await circuit.loadConstraints();
//assert.equal(circuit.nVars, 618);
//assert.equal(circuit.constraints.length, 486);
const weights = [];
for (var i=0; i<json.weights.length; i++) {
weights.push(Fr.e(json.weights[i]));
}
const INPUT = {
"in": json.in,
"weights": weights,
"bias": ["0","0"]
}
const witness = await circuit.calculateWitness(INPUT, true);
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<3*3*2; i++) {
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(5000));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(5000));
}
});
});

View File

@@ -0,0 +1,5 @@
pragma circom 2.0.3;
include "../../circuits/Conv1D.circom";
component main = Conv1D(20, 3, 2, 4, 3);

View File

@@ -0,0 +1,5 @@
pragma circom 2.0.3;
include "../../circuits/Conv2D.circom";
component main = Conv2D(10, 10, 3, 2, 4, 3);

View File

@@ -2,4 +2,4 @@ pragma circom 2.0.3;
include "../../circuits/Conv2D.circom";
component main = Conv2D(5, 5, 3, 2, 3);
component main = Conv2D(5, 5, 3, 2, 3,1);

View File

@@ -16,10 +16,10 @@ template mnist_convnet() {
signal input dense_bias[10];
signal output out;
component conv2d_1 = Conv2D(28,28,1,4,3);
component conv2d_1 = Conv2D(28,28,1,4,3,1);
component poly_1[26][26][4];
component sum2d_1 = SumPooling2D(26,26,4,2);
component conv2d_2 = Conv2D(13,13,4,8,3);
component conv2d_2 = Conv2D(13,13,4,8,3,1);
component poly_2[11][11][8];
component sum2d_2 = SumPooling2D(11,11,8,2);
component dense = Dense(200,10);

View File

@@ -13,7 +13,7 @@ template mnist_poly() {
signal input dense_bias[10];
signal output out;
component conv2d = Conv2D(28,28,1,1,3);
component conv2d = Conv2D(28,28,1,1,3,1);
component poly[26*26];
component dense = Dense(676,10);
component argmax = ArgMax(10);

View File

@@ -13,7 +13,7 @@ template mnist() {
signal input dense_bias[10];
signal output out;
component conv2d = Conv2D(28,28,1,1,3);
component conv2d = Conv2D(28,28,1,1,3,1);
component relu[26*26];
component dense = Dense(676,10);
component argmax = ArgMax(10);