mirror of
https://github.com/socathie/circomlib-ml.git
synced 2026-01-09 14:08:04 -05:00
MaxPooling2D layer
This commit is contained in:
25
circuits/MaxPooling2D.circom
Normal file
25
circuits/MaxPooling2D.circom
Normal file
@@ -0,0 +1,25 @@
|
||||
pragma circom 2.0.3;
|
||||
|
||||
include "./util.circom";
|
||||
|
||||
// MaxPooling2D layer
|
||||
template MaxPooling2D (nRows, nCols, nChannels, poolSize, strides) {
|
||||
signal input in[nRows][nCols][nChannels];
|
||||
signal output out[(nRows-poolSize)\strides+1][(nCols-poolSize)\strides+1][nChannels];
|
||||
|
||||
component max[(nRows-poolSize)\strides+1][(nCols-poolSize)\strides+1][nChannels];
|
||||
|
||||
for (var i=0; i<(nRows-poolSize)\strides+1; i++) {
|
||||
for (var j=0; j<(nCols-poolSize)\strides+1; j++) {
|
||||
for (var k=0; k<nChannels; k++) {
|
||||
max[i][j][k] = Max(poolSize*poolSize);
|
||||
for (var x=0; x<poolSize; x++) {
|
||||
for (var y=0; y<poolSize; y++) {
|
||||
max[i][j][k].in[x*poolSize+y] <== in[i*strides+x][j*strides+y][k];
|
||||
}
|
||||
}
|
||||
out[i][j][k] <== max[i][j][k].out;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,8 @@ pragma circom 2.0.3;
|
||||
|
||||
include "./circomlib/sign.circom";
|
||||
include "./circomlib/bitify.circom";
|
||||
include "./circomlib/comparators.circom";
|
||||
include "./circomlib/switcher.circom";
|
||||
|
||||
template IsNegative() {
|
||||
signal input in;
|
||||
@@ -45,4 +47,31 @@ template Sum(nInputs) {
|
||||
}
|
||||
|
||||
out <== partialSum[nInputs-1];
|
||||
}
|
||||
|
||||
template Max(n) {
|
||||
signal input in[n];
|
||||
signal output out;
|
||||
|
||||
component gts[n]; // store comparators
|
||||
component switchers[n+1]; // switcher for comparing maxs
|
||||
|
||||
signal maxs[n+1];
|
||||
|
||||
maxs[0] <== in[0];
|
||||
for(var i = 0; i < n; i++) {
|
||||
gts[i] = GreaterThan(252); // changed to 252 (maximum) for better compatibility
|
||||
switchers[i+1] = Switcher();
|
||||
|
||||
gts[i].in[1] <== maxs[i];
|
||||
gts[i].in[0] <== in[i];
|
||||
|
||||
switchers[i+1].sel <== gts[i].out;
|
||||
switchers[i+1].L <== maxs[i];
|
||||
switchers[i+1].R <== in[i];
|
||||
|
||||
maxs[i+1] <== switchers[i+1].outL;
|
||||
}
|
||||
|
||||
out <== maxs[n];
|
||||
}
|
||||
1
models/maxPooling2D_input.json
Normal file
1
models/maxPooling2D_input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"in": [305, 346, 857, 25, 681, 959, 542, 544, 585, 888, 355, 682, 592, 102, 869, 706, 817, 833, 375, 910, 180, 634, 499, 941, 292, 418, 978, 301, 234, 455, 806, 539, 293, 312, 10, 857, 346, 917, 944, 666, 366, 391, 182, 724, 485, 479, 424, 594, 269, 182, 826, 396, 336, 14, 262, 146, 209, 130, 498, 852, 274, 725, 260, 467, 606, 63, 446, 725, 491, 711, 384, 154, 937, 145, 41]}
|
||||
1
models/maxPooling2D_output.json
Normal file
1
models/maxPooling2D_output.json
Normal file
@@ -0,0 +1 @@
|
||||
{"out": [706, 910, 959, 888, 544, 978, 806, 539, 857, 666, 917, 944]}
|
||||
1
models/maxPooling2D_stride_input.json
Normal file
1
models/maxPooling2D_stride_input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"in": [966, 287, 646, 980, 70, 226, 16, 860, 319, 245, 695, 843, 722, 455, 700, 532, 532, 706, 319, 930, 713, 589, 295, 53, 601, 31, 364, 707, 875, 350, 730, 634, 272, 729, 301, 61, 16, 100, 456, 862, 654, 461, 947, 756, 418, 897, 638, 971, 694, 676, 189, 315, 658, 992, 483, 428, 558, 317, 146, 686, 855, 833, 329, 564, 526, 24, 47, 518, 799, 733, 679, 954, 776, 416, 96, 8, 667, 833, 325, 315, 164, 464, 503, 823, 825, 622, 891, 137, 943, 368, 497, 775, 616, 943, 330, 374, 306, 189, 980, 914, 711, 349, 21, 222, 554, 770, 666, 242, 565, 349, 510, 120, 288, 129, 401, 567, 515, 874, 307, 760, 750, 261, 701, 800, 37, 354, 51, 608, 318, 249, 593, 685, 784, 271, 791, 920, 456, 953, 376, 797, 843, 948, 209, 697, 272, 810, 40, 108, 457, 162, 617, 211, 529, 333, 39, 320, 15, 408, 149, 945, 66, 71, 682, 546, 32, 387, 840, 581, 333, 736, 376, 99, 410, 211, 516, 638, 318, 936, 392, 160, 113, 179, 2, 429, 543, 426, 232, 240, 969, 934, 907, 907, 814, 585, 578, 392, 49, 286, 648, 435, 858, 283, 867, 248, 585, 389, 72, 135, 469, 948, 428, 684, 177, 625, 503, 588, 860, 885, 611, 584, 779, 107, 697, 892, 135, 480, 226, 622, 113, 517, 352, 0, 368, 779, 187, 872, 831, 869, 241, 896, 357, 668, 380, 735, 345, 168, 157, 345, 246, 76, 108, 454, 885, 355, 100, 180, 159, 908, 549, 547, 235, 559, 241, 283, 579, 546, 84, 595, 745, 44, 862, 849, 399, 370, 636, 303, 832, 363, 874, 877, 148, 136, 555, 916, 208, 541, 667, 441, 770, 886, 919, 942, 953, 151, 810, 840, 130, 460, 821, 200]}
|
||||
1
models/maxPooling2D_stride_output.json
Normal file
1
models/maxPooling2D_stride_output.json
Normal file
@@ -0,0 +1 @@
|
||||
{"out": [980, 634, 646, 947, 756, 843, 694, 930, 992, 943, 775, 701, 914, 711, 791, 948, 797, 843, 625, 684, 588, 934, 907, 907, 648, 867, 858]}
|
||||
455
models/maxPooling2d.ipynb
Normal file
455
models/maxPooling2d.ipynb
Normal file
@@ -0,0 +1,455 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tensorflow.keras.layers import Input, AveragePooling2D, Lambda\n",
|
||||
"from tensorflow.keras import Model\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inputs = Input(shape=(5,5,3))\n",
|
||||
"x = AveragePooling2D(pool_size=2)(inputs)\n",
|
||||
"x = Lambda(lambda x: x*4)(x)\n",
|
||||
"model = Model(inputs, x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model: \"model\"\n",
|
||||
"_________________________________________________________________\n",
|
||||
"Layer (type) Output Shape Param # \n",
|
||||
"=================================================================\n",
|
||||
"input_1 (InputLayer) [(None, 5, 5, 3)] 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"average_pooling2d (AveragePo (None, 2, 2, 3) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"lambda (Lambda) (None, 2, 2, 3) 0 \n",
|
||||
"=================================================================\n",
|
||||
"Total params: 0\n",
|
||||
"Trainable params: 0\n",
|
||||
"Non-trainable params: 0\n",
|
||||
"_________________________________________________________________\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.summary()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[[[0.83128186, 0.15650764, 0.23798145],\n",
|
||||
" [0.00277366, 0.8374127 , 0.95278315],\n",
|
||||
" [0.3074389 , 0.21931738, 0.14886067],\n",
|
||||
" [0.13590018, 0.98728255, 0.12085182],\n",
|
||||
" [0.47212572, 0.51380922, 0.74891219]],\n",
|
||||
"\n",
|
||||
" [[0.74680338, 0.2533205 , 0.5039968 ],\n",
|
||||
" [0.14475403, 0.00791911, 0.4361197 ],\n",
|
||||
" [0.69925568, 0.77507624, 0.40388991],\n",
|
||||
" [0.29508251, 0.99375606, 0.84959701],\n",
|
||||
" [0.88844918, 0.33910189, 0.9617212 ]],\n",
|
||||
"\n",
|
||||
" [[0.76480625, 0.591287 , 0.0714191 ],\n",
|
||||
" [0.94371681, 0.1695303 , 0.4476252 ],\n",
|
||||
" [0.54372616, 0.83818804, 0.95211573],\n",
|
||||
" [0.30485104, 0.15165265, 0.94709317],\n",
|
||||
" [0.90827137, 0.58854675, 0.01857002]],\n",
|
||||
"\n",
|
||||
" [[0.70123418, 0.43090173, 0.7096038 ],\n",
|
||||
" [0.20637783, 0.20096581, 0.22956612],\n",
|
||||
" [0.81978383, 0.16775403, 0.67412096],\n",
|
||||
" [0.1011535 , 0.35596916, 0.36702071],\n",
|
||||
" [0.5874605 , 0.79341372, 0.93292159]],\n",
|
||||
"\n",
|
||||
" [[0.77997124, 0.46311399, 0.5465576 ],\n",
|
||||
" [0.20406287, 0.37547625, 0.59862253],\n",
|
||||
" [0.52933135, 0.84249092, 0.02969684],\n",
|
||||
" [0.29114617, 0.10405779, 0.5359062 ],\n",
|
||||
" [0.25197146, 0.83297465, 0.67025403]]]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"X = np.random.rand(1,5,5,3)\n",
|
||||
"X"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[[[1.7256129, 1.2551599, 2.1308813],\n",
|
||||
" [1.4376774, 2.9754324, 1.5231993]],\n",
|
||||
"\n",
|
||||
" [[2.6161351, 1.3926848, 1.4582142],\n",
|
||||
" [1.7695144, 1.5135639, 2.9403505]]]], dtype=float32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"y = model.predict(X)\n",
|
||||
"y"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"in_json = {\n",
|
||||
" \"in\": (X*1000).round().astype(int).flatten().tolist()\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"out_json = {\n",
|
||||
" \"out\": (y*1000).round().astype(int).flatten().tolist()\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(\"sumPooling2D_input.json\", \"w\") as f:\n",
|
||||
" json.dump(in_json, f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(\"sumPooling2D_output.json\", \"w\") as f:\n",
|
||||
" json.dump(out_json, f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inputs = Input(shape=(10,10,3))\n",
|
||||
"x = AveragePooling2D(pool_size=2, strides=3)(inputs)\n",
|
||||
"x = Lambda(lambda x: x*4)(x)\n",
|
||||
"model = Model(inputs, x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model: \"model_2\"\n",
|
||||
"_________________________________________________________________\n",
|
||||
"Layer (type) Output Shape Param # \n",
|
||||
"=================================================================\n",
|
||||
"input_3 (InputLayer) [(None, 10, 10, 3)] 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"average_pooling2d_2 (Average (None, 3, 3, 3) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"lambda_2 (Lambda) (None, 3, 3, 3) 0 \n",
|
||||
"=================================================================\n",
|
||||
"Total params: 0\n",
|
||||
"Trainable params: 0\n",
|
||||
"Non-trainable params: 0\n",
|
||||
"_________________________________________________________________\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.summary()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[[[0.31514958, 0.03200121, 0.29129004],\n",
|
||||
" [0.35725668, 0.87252739, 0.77162311],\n",
|
||||
" [0.61707883, 0.7945887 , 0.48907944],\n",
|
||||
" [0.98197306, 0.88814753, 0.69652672],\n",
|
||||
" [0.2518265 , 0.82753267, 0.57464263],\n",
|
||||
" [0.62115028, 0.76805041, 0.65967975],\n",
|
||||
" [0.32491691, 0.93353364, 0.74831234],\n",
|
||||
" [0.45570461, 0.96830864, 0.8476189 ],\n",
|
||||
" [0.02149766, 0.27808247, 0.18207897],\n",
|
||||
" [0.15949257, 0.69372265, 0.43455872]],\n",
|
||||
"\n",
|
||||
" [[0.70915807, 0.60410698, 0.94721792],\n",
|
||||
" [0.5621233 , 0.65546021, 0.27357865],\n",
|
||||
" [0.00414209, 0.08635782, 0.30528659],\n",
|
||||
" [0.11492599, 0.15002234, 0.58496289],\n",
|
||||
" [0.72848003, 0.55169839, 0.91708802],\n",
|
||||
" [0.43479205, 0.08069621, 0.68404234],\n",
|
||||
" [0.3946513 , 0.20447291, 0.10467492],\n",
|
||||
" [0.78817621, 0.63518792, 0.00827133],\n",
|
||||
" [0.0853401 , 0.75656605, 0.95034115],\n",
|
||||
" [0.92239164, 0.06871402, 0.37783711]],\n",
|
||||
"\n",
|
||||
" [[0.36000095, 0.35659628, 0.6507159 ],\n",
|
||||
" [0.29486935, 0.49464939, 0.40502335],\n",
|
||||
" [0.09509896, 0.08726498, 0.39326876],\n",
|
||||
" [0.38827707, 0.50908505, 0.63443643],\n",
|
||||
" [0.27030144, 0.67783072, 0.09309034],\n",
|
||||
" [0.76360544, 0.67003754, 0.28767228],\n",
|
||||
" [0.55305299, 0.60216561, 0.3544107 ],\n",
|
||||
" [0.55839884, 0.86964781, 0.26053367],\n",
|
||||
" [0.87306012, 0.78756102, 0.04817508],\n",
|
||||
" [0.72406774, 0.67679246, 0.82272016]],\n",
|
||||
"\n",
|
||||
" [[0.71765743, 0.50852032, 0.52047892],\n",
|
||||
" [0.7484707 , 0.97207503, 0.08778545],\n",
|
||||
" [0.15780167, 0.73192822, 0.40718403],\n",
|
||||
" [0.93263197, 0.8772701 , 0.34486053],\n",
|
||||
" [0.42436095, 0.80504181, 0.39139203],\n",
|
||||
" [0.81358273, 0.56754054, 0.12608038],\n",
|
||||
" [0.11843567, 0.61136361, 0.81339895],\n",
|
||||
" [0.27636648, 0.57453166, 0.10632468],\n",
|
||||
" [0.53090786, 0.14594835, 0.08140653],\n",
|
||||
" [0.34118642, 0.27554414, 0.19515355]],\n",
|
||||
"\n",
|
||||
" [[0.12974003, 0.6264065 , 0.56250089],\n",
|
||||
" [0.05655555, 0.93847961, 0.71849845],\n",
|
||||
" [0.57644684, 0.37077012, 0.53949152],\n",
|
||||
" [0.45904117, 0.30854737, 0.73517714],\n",
|
||||
" [0.64076017, 0.59373326, 0.83758554],\n",
|
||||
" [0.80707699, 0.79461191, 0.69655474],\n",
|
||||
" [0.79872758, 0.26420269, 0.29237624],\n",
|
||||
" [0.45087863, 0.28258419, 0.50447663],\n",
|
||||
" [0.29494657, 0.31770288, 0.49309187],\n",
|
||||
" [0.82460949, 0.3940875 , 0.33865267]],\n",
|
||||
"\n",
|
||||
" [[0.1108653 , 0.35294351, 0.44014634],\n",
|
||||
" [0.4988099 , 0.34405962, 0.77622373],\n",
|
||||
" [0.76444373, 0.88689451, 0.05756076],\n",
|
||||
" [0.57160174, 0.0752442 , 0.5098132 ],\n",
|
||||
" [0.22539676, 0.47741414, 0.28993556],\n",
|
||||
" [0.43298235, 0.58710277, 0.69306001],\n",
|
||||
" [0.9521223 , 0.87239108, 0.10672981],\n",
|
||||
" [0.93125144, 0.19405455, 0.95483289],\n",
|
||||
" [0.91030892, 0.85961313, 0.67439157],\n",
|
||||
" [0.09377237, 0.75818836, 0.61985122]],\n",
|
||||
"\n",
|
||||
" [[0.4344654 , 0.97297157, 0.89560878],\n",
|
||||
" [0.91664946, 0.68966445, 0.0530751 ],\n",
|
||||
" [0.72099738, 0.05779864, 0.30259649],\n",
|
||||
" [0.55598956, 0.11611106, 0.24856552],\n",
|
||||
" [0.40690072, 0.66148966, 0.22159354],\n",
|
||||
" [0.53035294, 0.23237414, 0.82781172],\n",
|
||||
" [0.20375017, 0.23486322, 0.36461596],\n",
|
||||
" [0.05525619, 0.59671011, 0.08001122],\n",
|
||||
" [0.11250979, 0.98519728, 0.57553523],\n",
|
||||
" [0.6117834 , 0.65811775, 0.78386287]],\n",
|
||||
"\n",
|
||||
" [[0.39532528, 0.78660638, 0.37617851],\n",
|
||||
" [0.86246711, 0.59398046, 0.50843286],\n",
|
||||
" [0.41395181, 0.96399598, 0.8374128 ],\n",
|
||||
" [0.76981858, 0.41760042, 0.17438256],\n",
|
||||
" [0.05937649, 0.93289121, 0.63833505],\n",
|
||||
" [0.97571178, 0.06364159, 0.34572432],\n",
|
||||
" [0.42278241, 0.52111442, 0.62746908],\n",
|
||||
" [0.0401781 , 0.80713288, 0.26990436],\n",
|
||||
" [0.21850787, 0.19009324, 0.04497292],\n",
|
||||
" [0.16176602, 0.20893733, 0.34094974]],\n",
|
||||
"\n",
|
||||
" [[0.13094336, 0.90151022, 0.82541695],\n",
|
||||
" [0.73192844, 0.24791076, 0.3587372 ],\n",
|
||||
" [0.88818939, 0.92023872, 0.69098959],\n",
|
||||
" [0.08104613, 0.6361497 , 0.42552169],\n",
|
||||
" [0.44517886, 0.99055202, 0.15580116],\n",
|
||||
" [0.78742252, 0.04735346, 0.46423316],\n",
|
||||
" [0.53474903, 0.79917168, 0.33019955],\n",
|
||||
" [0.31087978, 0.65384266, 0.77275665],\n",
|
||||
" [0.56393354, 0.5761927 , 0.4287843 ],\n",
|
||||
" [0.57457285, 0.67154059, 0.52881047]],\n",
|
||||
"\n",
|
||||
" [[0.78373278, 0.15164648, 0.92791502],\n",
|
||||
" [0.0999971 , 0.47319914, 0.44424683],\n",
|
||||
" [0.74758969, 0.04583226, 0.0579972 ],\n",
|
||||
" [0.37325021, 0.12464474, 0.61199188],\n",
|
||||
" [0.07404238, 0.65504221, 0.18787021],\n",
|
||||
" [0.16955187, 0.12750002, 0.48252436],\n",
|
||||
" [0.17829354, 0.85701326, 0.41402596],\n",
|
||||
" [0.21677806, 0.18949005, 0.27735136],\n",
|
||||
" [0.06721647, 0.16941253, 0.46916978],\n",
|
||||
" [0.39921131, 0.17705116, 0.94534667]]]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"X = np.random.rand(1,10,10,3)\n",
|
||||
"X"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[[[1.9436876 , 2.1640959 , 2.2837098 ],\n",
|
||||
" [2.0772057 , 2.4174008 , 2.7732203 ],\n",
|
||||
" [1.963449 , 2.741503 , 1.7088774 ]],\n",
|
||||
"\n",
|
||||
" [[1.6524236 , 3.0454814 , 1.8892637 ],\n",
|
||||
" [2.4567943 , 2.5845926 , 2.3090153 ],\n",
|
||||
" [1.6444083 , 1.7326821 , 1.7165766 ]],\n",
|
||||
"\n",
|
||||
" [[2.6089072 , 3.043223 , 1.8332952 ],\n",
|
||||
" [1.7920854 , 2.1280923 , 1.2828767 ],\n",
|
||||
" [0.72196686, 2.1598206 , 1.3420006 ]]]], dtype=float32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"y = model.predict(X)\n",
|
||||
"y"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"in_json = {\n",
|
||||
" \"in\": (X*1000).round().astype(int).flatten().tolist()\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"out_json = {\n",
|
||||
" \"out\": (y*1000).round().astype(int).flatten().tolist()\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(\"sumPooling2D_stride_input.json\", \"w\") as f:\n",
|
||||
" json.dump(in_json, f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(\"sumPooling2D_stride_output.json\", \"w\") as f:\n",
|
||||
" json.dump(out_json, f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "tf24",
|
||||
"language": "python",
|
||||
"name": "tf24"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.6"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
32
test/Max.js
Normal file
32
test/Max.js
Normal file
@@ -0,0 +1,32 @@
|
||||
const chai = require("chai");
|
||||
const path = require("path");
|
||||
|
||||
const wasm_tester = require("circom_tester").wasm;
|
||||
|
||||
const F1Field = require("ffjavascript").F1Field;
|
||||
const Scalar = require("ffjavascript").Scalar;
|
||||
exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
|
||||
const Fr = new F1Field(exports.p);
|
||||
|
||||
const assert = chai.assert;
|
||||
|
||||
describe("Max test", function () {
|
||||
this.timeout(100000000);
|
||||
|
||||
it("Maximum of 4 numbers", async () => {
|
||||
const circuit = await wasm_tester(path.join(__dirname, "circuits", "Max_test.circom"));
|
||||
//await circuit.loadConstraints();
|
||||
//assert.equal(circuit.nVars, 516);
|
||||
//assert.equal(circuit.constraints.length, 516);
|
||||
|
||||
const INPUT = {
|
||||
"in": ["1","4","2","3"]
|
||||
}
|
||||
const witness = await circuit.calculateWitness(INPUT, true);
|
||||
|
||||
//console.log(witness);
|
||||
|
||||
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
|
||||
assert(Fr.eq(Fr.e(witness[1]),Fr.e(4)));
|
||||
});
|
||||
});
|
||||
63
test/MaxPooling2D.js
Normal file
63
test/MaxPooling2D.js
Normal file
@@ -0,0 +1,63 @@
|
||||
const chai = require("chai");
|
||||
const { Console } = require("console");
|
||||
const path = require("path");
|
||||
|
||||
const wasm_tester = require("circom_tester").wasm;
|
||||
|
||||
const F1Field = require("ffjavascript").F1Field;
|
||||
const Scalar = require("ffjavascript").Scalar;
|
||||
exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
|
||||
const Fr = new F1Field(exports.p);
|
||||
|
||||
const assert = chai.assert;
|
||||
|
||||
|
||||
|
||||
describe("MaxPooling2D layer test", function () {
|
||||
this.timeout(100000000);
|
||||
|
||||
// MaxPooling with strides==poolSize
|
||||
it("(5,5,3) -> (2,2,3)", async () => {
|
||||
const json = require("../models/maxPooling2D_input.json");
|
||||
const OUTPUT = require("../models/maxPooling2D_output.json");
|
||||
|
||||
const circuit = await wasm_tester(path.join(__dirname, "circuits", "MaxPooling2D_test.circom"));
|
||||
//await circuit.loadConstraints();
|
||||
//assert.equal(circuit.nVars, 76);
|
||||
//assert.equal(circuit.constraints.length, 0);
|
||||
|
||||
const INPUT = {
|
||||
"in": json.in
|
||||
}
|
||||
|
||||
const witness = await circuit.calculateWitness(INPUT, true);
|
||||
|
||||
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
|
||||
|
||||
for (var i=0; i<2*2*3; i++) {
|
||||
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(1));
|
||||
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(1));
|
||||
}
|
||||
});
|
||||
|
||||
// MaxPooling with strides!=poolSize
|
||||
it("(10,10,3) -> (3,3,3)", async () => {
|
||||
const json = require("../models/maxPooling2D_stride_input.json");
|
||||
const OUTPUT = require("../models/maxPooling2D_stride_output.json");
|
||||
|
||||
const circuit = await wasm_tester(path.join(__dirname, "circuits", "MaxPooling2D_stride_test.circom"));
|
||||
|
||||
const INPUT = {
|
||||
"in": json.in
|
||||
}
|
||||
|
||||
const witness = await circuit.calculateWitness(INPUT, true);
|
||||
|
||||
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
|
||||
|
||||
for (var i=0; i<3*3*3; i++) {
|
||||
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(1));
|
||||
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(1));
|
||||
}
|
||||
});
|
||||
});
|
||||
5
test/circuits/MaxPooling2D_stride_test.circom
Normal file
5
test/circuits/MaxPooling2D_stride_test.circom
Normal file
@@ -0,0 +1,5 @@
|
||||
pragma circom 2.0.3;
|
||||
|
||||
include "../../circuits/MaxPooling2D.circom";
|
||||
|
||||
component main = MaxPooling2D(10, 10, 3, 2, 3);
|
||||
6
test/circuits/MaxPooling2D_test.circom
Normal file
6
test/circuits/MaxPooling2D_test.circom
Normal file
@@ -0,0 +1,6 @@
|
||||
pragma circom 2.0.3;
|
||||
|
||||
include "../../circuits/MaxPooling2D.circom";
|
||||
|
||||
// poolSize=strides - default Keras settings
|
||||
component main = MaxPooling2D(5, 5, 3, 2, 2);
|
||||
5
test/circuits/Max_test.circom
Normal file
5
test/circuits/Max_test.circom
Normal file
@@ -0,0 +1,5 @@
|
||||
pragma circom 2.0.3;
|
||||
|
||||
include "../../circuits/util.circom";
|
||||
|
||||
component main = Max(4);
|
||||
Reference in New Issue
Block a user