Add MaxPooling2Dsame and MaxPooling2Dsame_stride test circuits and their dependencies

This commit is contained in:
drCathieSo.eth
2024-02-05 01:33:50 +08:00
parent cdb3a290e5
commit 526da93d40
7 changed files with 742 additions and 0 deletions

View File

@@ -0,0 +1,73 @@
pragma circom 2.0.0;
include "./MaxPooling2D.circom";
template MaxPooling2Dsame (nRows, nCols, nChannels, poolSize, strides) {
signal input in[nRows][nCols][nChannels];
var rowPadding, colPadding;
if (nRows % strides == 0) {
rowPadding = (poolSize - strides) > 0 ? (poolSize - strides) : 0;
} else {
rowPadding = (poolSize - (nRows % strides)) > 0 ? (poolSize - (nRows % strides)) : 0;
}
if (nCols % strides == 0) {
colPadding = (poolSize - strides) > 0 ? (poolSize - strides) : 0;
} else {
colPadding = (poolSize - (nCols % strides)) > 0 ? (poolSize - (nCols % strides)) : 0;
}
signal input out[(nRows+rowPadding-poolSize)\strides+1][(nCols+colPadding-poolSize)\strides+1][nChannels];
component max2d = MaxPooling2D(nRows+rowPadding, nCols+colPadding, nChannels, poolSize, strides);
for (var i = rowPadding\2; i < rowPadding\2+nRows; i++) {
for (var j = colPadding\2; j < colPadding\2+nCols; j++) {
for (var k = 0; k < nChannels; k++) {
max2d.in[i][j][k] <== in[i-rowPadding\2][j-colPadding\2][k];
}
}
}
for (var i = 0; i< rowPadding\2; i++) {
for (var j = 0; j < nCols+colPadding; j++) {
for (var k = 0; k < nChannels; k++) {
max2d.in[i][j][k] <== 0;
}
}
}
for (var i = nRows+rowPadding\2; i< nRows+rowPadding; i++) {
for (var j = 0; j < nCols+colPadding; j++) {
for (var k = 0; k < nChannels; k++) {
max2d.in[i][j][k] <== 0;
}
}
}
for (var i = rowPadding\2; i < nRows+rowPadding\2; i++) {
for (var j = 0; j < colPadding\2; j++) {
for (var k = 0; k < nChannels; k++) {
max2d.in[i][j][k] <== 0;
}
}
}
for (var i = rowPadding\2; i < nRows+rowPadding\2; i++) {
for (var j = nCols+colPadding\2; j < nCols+colPadding; j++) {
for (var k = 0; k < nChannels; k++) {
max2d.in[i][j][k] <== 0;
}
}
}
for (var i = 0; i < (nRows+rowPadding-poolSize)\strides+1; i++) {
for (var j = 0; j < (nCols+colPadding-poolSize)\strides+1; j++) {
for (var k = 0; k < nChannels; k++) {
max2d.out[i][j][k] <== out[i][j][k];
}
}
}
}

View File

@@ -0,0 +1 @@
{"in": [[[985416895178787714183661703755988992, 805695864472562858264688082666127360, 632684451822347108628256674661007360], [962060468041561843488482131179470848, 8869073729764975861586898758664192, 677759670322146260067784271585083392], [467925439750069089712591214043201536, 813606982115283460844798537969434624, 347772122728895153321033327318138880], [908354388868282470618698257109352448, 691746563537225063033040168500068352, 981726193592448973649175667344932864], [107153282843697739853753056552288256, 916572289086665811955284394381410304, 816245550879711653823401560087986176]], [[570382642241468206642626555033944064, 902803940435826284139749932730941440, 777663362123877860598598645145665536], [634090669863853333267758182352027648, 825958123785832877484312556952616960, 381513410888697931311864158837800960], [795520454944760995904638766836350976, 216359595257833349149815253067890688, 874616717227600117488934168588976128], [746811470919424536534678233059164160, 475841014459437513096347703977181184, 930028665971741586694297106155307008], [114000116831430861403008392105033728, 974091109200621567612816335791194112, 874043654050301690820171495435665408]], [[539579607228014027803293647267430400, 508632843040824474037175176904310784, 890125946940297760199841005954924544], [278016685349965289070216885173223424, 552815374291668108366329210074038272, 500049722551334607063608137326002176], [313744264694903791749843753654812672, 418498494396639822766365224919367680, 817633569064628434285947918592507904], [680992631793690390926240128859897856, 217157254616318320755856276537212928, 602450814158750475602742768018915328], [230448677275229241444208872890826752, 894021844181003495121759207996522496, 738029169437394307130485433159385088]], [[656815180323522869339469446928924672, 454311340827568891467175719507853312, 111869508785107817930652937351069696], [632327323641954593527382646479388672, 99140057192692252460452964236001280, 469778003485259478255996251364392960], [539803270562814329659938110829494272, 9999223425718462488177430625255424, 323853684321289258856386212080386048], [511898330068137781770564935289405440, 235408855051092995844022897467195392, 488270711104327731196314985370222592], [973477114549060379932212627888930816, 163297737054907472179883527553155072, 734044941826309056667877810110464000]], [[20108031993913978534061672446820352, 732003881291962206032259471665790976, 752966593417693633098737987950215168], [470805829710577633964522992960536576, 899977578211573860664606514781093888, 777543729617406477245147183797239808], [450476427222540130797322504023048192, 608466014932603115839705401413074944, 220254487946003773078748363501862912], [813185240376988213714402965620523008, 864052363876363850452980196205133824, 441362157036501413461183215236546560], [537904308558723633624141672348647424, 459317359421049476909261905824055296, 9517425402804758462154346481057792]]], "out": [[["985416895178787714183661703755988992", "902803940435826284139749932730941440", "777663362123877860598598645145665536"], ["908354388868282470618698257109352448", "813606982115283460844798537969434624", "981726193592448973649175667344932864"], ["114000116831430861403008392105033728", "974091109200621567612816335791194112", "874043654050301690820171495435665408"]], [["656815180323522869339469446928924672", "552815374291668108366329210074038272", "890125946940297760199841005954924544"], ["680992631793690390926240128859897856", "418498494396639822766365224919367680", "817633569064628434285947918592507904"], ["973477114549060379932212627888930816", "894021844181003495121759207996522496", "738029169437394307130485433159385088"]], [["470805829710577633964522992960536576", "899977578211573860664606514781093888", "777543729617406477245147183797239808"], ["813185240376988213714402965620523008", "864052363876363850452980196205133824", "441362157036501413461183215236546560"], ["537904308558723633624141672348647424", "459317359421049476909261905824055296", "9517425402804758462154346481057792"]]]}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,619 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.layers import Input, MaxPooling2D\n",
"from tensorflow.keras import Model\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"inputs = Input(shape=(5,5,3))\n",
"x = MaxPooling2D(pool_size=2, padding='same')(inputs)\n",
"model = Model(inputs, x)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" input_1 (InputLayer) [(None, 5, 5, 3)] 0 \n",
" \n",
" max_pooling2d (MaxPooling2D (None, 3, 3, 3) 0 \n",
" ) \n",
" \n",
"=================================================================\n",
"Total params: 0\n",
"Trainable params: 0\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[[0.9854169 , 0.80569586, 0.63268445],\n",
" [0.96206047, 0.00886907, 0.67775967],\n",
" [0.46792544, 0.81360698, 0.34777212],\n",
" [0.90835439, 0.69174656, 0.98172619],\n",
" [0.10715328, 0.91657229, 0.81624555]],\n",
"\n",
" [[0.57038264, 0.90280394, 0.77766336],\n",
" [0.63409067, 0.82595812, 0.38151341],\n",
" [0.79552045, 0.2163596 , 0.87461672],\n",
" [0.74681147, 0.47584101, 0.93002867],\n",
" [0.11400012, 0.97409111, 0.87404365]],\n",
"\n",
" [[0.53957961, 0.50863284, 0.89012595],\n",
" [0.27801669, 0.55281537, 0.50004972],\n",
" [0.31374426, 0.41849849, 0.81763357],\n",
" [0.68099263, 0.21715725, 0.60245081],\n",
" [0.23044868, 0.89402184, 0.73802917]],\n",
"\n",
" [[0.65681518, 0.45431134, 0.11186951],\n",
" [0.63232732, 0.09914006, 0.469778 ],\n",
" [0.53980327, 0.00999922, 0.32385368],\n",
" [0.51189833, 0.23540886, 0.48827071],\n",
" [0.97347711, 0.16329774, 0.73404494]],\n",
"\n",
" [[0.02010803, 0.73200388, 0.75296659],\n",
" [0.47080583, 0.89997758, 0.77754373],\n",
" [0.45047643, 0.60846601, 0.22025449],\n",
" [0.81318524, 0.86405236, 0.44136216],\n",
" [0.53790431, 0.45931736, 0.00951743]]]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = np.random.rand(1,5,5,3)\n",
"X"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1/1 [==============================] - 0s 42ms/step\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-02-05 01:27:05.945781: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n"
]
},
{
"data": {
"text/plain": [
"array([[[[0.9854169 , 0.90280396, 0.77766335],\n",
" [0.9083544 , 0.813607 , 0.98172617],\n",
" [0.11400012, 0.9740911 , 0.87404364]],\n",
"\n",
" [[0.6568152 , 0.5528154 , 0.89012593],\n",
" [0.6809926 , 0.4184985 , 0.81763357],\n",
" [0.9734771 , 0.89402187, 0.7380292 ]],\n",
"\n",
" [[0.47080582, 0.89997756, 0.7775437 ],\n",
" [0.8131852 , 0.86405236, 0.44136214],\n",
" [0.5379043 , 0.45931736, 0.00951743]]]], dtype=float32)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = model.predict(X)\n",
"y"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"X_in = [[[int(X[0][i][j][k]*1e36) for k in range(3)] for j in range(5)] for i in range(5)]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def MaxPooling2DInt(nRows, nCols, nChannels, poolSize, strides, input):\n",
" out = [[[str(max(int(input[i*strides + x][j*strides + y][k]) for x in range(poolSize) for y in range(poolSize))) for k in range(nChannels)] for j in range((nCols - poolSize) // strides + 1)] for i in range((nRows - poolSize) // strides + 1)]\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def MaxPooling2DsameInt(nRows, nCols, nChannels, poolSize, strides, input):\n",
" if nRows % strides == 0:\n",
" rowPadding = max(poolSize - strides, 0)\n",
" else:\n",
" rowPadding = max(poolSize - nRows % strides, 0)\n",
" if nCols % strides == 0:\n",
" colPadding = max(poolSize - strides, 0)\n",
" else:\n",
" colPadding = max(poolSize - nCols % strides, 0)\n",
" \n",
" _input = [[[0 for _ in range(nChannels)] for _ in range(nCols + colPadding)] for _ in range(nRows + rowPadding)]\n",
"\n",
" for i in range(nRows):\n",
" for j in range(nCols):\n",
" for k in range(nChannels):\n",
" _input[i+rowPadding//2][j+colPadding//2][k] = input[i][j][k]\n",
" \n",
" out = MaxPooling2DInt(nRows + rowPadding, nCols + colPadding, nChannels, poolSize, strides, _input)\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[['985416895178787714183661703755988992',\n",
" '902803940435826284139749932730941440',\n",
" '777663362123877860598598645145665536'],\n",
" ['908354388868282470618698257109352448',\n",
" '813606982115283460844798537969434624',\n",
" '981726193592448973649175667344932864'],\n",
" ['114000116831430861403008392105033728',\n",
" '974091109200621567612816335791194112',\n",
" '874043654050301690820171495435665408']],\n",
" [['656815180323522869339469446928924672',\n",
" '552815374291668108366329210074038272',\n",
" '890125946940297760199841005954924544'],\n",
" ['680992631793690390926240128859897856',\n",
" '418498494396639822766365224919367680',\n",
" '817633569064628434285947918592507904'],\n",
" ['973477114549060379932212627888930816',\n",
" '894021844181003495121759207996522496',\n",
" '738029169437394307130485433159385088']],\n",
" [['470805829710577633964522992960536576',\n",
" '899977578211573860664606514781093888',\n",
" '777543729617406477245147183797239808'],\n",
" ['813185240376988213714402965620523008',\n",
" '864052363876363850452980196205133824',\n",
" '441362157036501413461183215236546560'],\n",
" ['537904308558723633624141672348647424',\n",
" '459317359421049476909261905824055296',\n",
" '9517425402804758462154346481057792']]]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out = MaxPooling2DsameInt(5, 5, 3, 2, 2, X_in)\n",
"out"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"in_json = {\n",
" \"in\": X_in,\n",
" \"out\": out\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import json"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"with open(\"maxPooling2Dsame_input.json\", \"w\") as f:\n",
" json.dump(in_json, f)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"inputs = Input(shape=(10,10,3))\n",
"x = MaxPooling2D(pool_size=2, strides=3, padding='same')(inputs)\n",
"model = Model(inputs, x)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model_1\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" input_2 (InputLayer) [(None, 10, 10, 3)] 0 \n",
" \n",
" max_pooling2d_1 (MaxPooling (None, 4, 4, 3) 0 \n",
" 2D) \n",
" \n",
"=================================================================\n",
"Total params: 0\n",
"Trainable params: 0\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[[1.81060846e-01, 8.73335066e-01, 4.66591631e-01],\n",
" [4.19512826e-01, 5.04687534e-01, 5.55108518e-01],\n",
" [8.89727395e-01, 3.04715421e-01, 3.96868333e-01],\n",
" [3.19583099e-01, 7.59947198e-04, 8.95376898e-01],\n",
" [1.86082695e-01, 5.89361183e-01, 2.29939007e-01],\n",
" [3.77869457e-01, 5.93831349e-01, 6.17898741e-01],\n",
" [2.08890055e-01, 4.52999632e-02, 4.81015031e-02],\n",
" [9.68292068e-01, 4.24571806e-01, 4.18124731e-02],\n",
" [3.30703359e-02, 6.13358807e-01, 9.62525235e-01],\n",
" [8.80909151e-01, 3.16377883e-02, 4.07229564e-01]],\n",
"\n",
" [[7.22047037e-01, 2.87769873e-01, 6.33057960e-01],\n",
" [5.87593115e-01, 9.95928671e-01, 9.47646805e-02],\n",
" [6.73764008e-01, 4.41940517e-01, 2.05872675e-01],\n",
" [2.74105248e-01, 3.41288944e-01, 7.03928155e-01],\n",
" [2.53407876e-01, 9.54086619e-01, 6.95949832e-01],\n",
" [7.72974380e-01, 6.94646688e-01, 2.39040667e-01],\n",
" [7.61119704e-01, 8.41368944e-01, 9.12971104e-01],\n",
" [3.19331761e-01, 2.36474093e-01, 3.11432211e-01],\n",
" [6.18384476e-01, 2.23794997e-01, 7.33884991e-01],\n",
" [1.00735392e-01, 7.42390580e-01, 7.43563278e-01]],\n",
"\n",
" [[1.82788443e-01, 9.69266407e-01, 6.79896409e-02],\n",
" [5.47375743e-01, 7.00178164e-01, 1.62010347e-02],\n",
" [2.70368172e-01, 9.53685968e-01, 4.62933777e-01],\n",
" [4.80574859e-01, 1.55085614e-01, 2.38047469e-01],\n",
" [4.86919140e-01, 1.94033355e-02, 1.80402947e-01],\n",
" [3.96449294e-01, 4.95605585e-01, 2.26605072e-01],\n",
" [1.74141020e-02, 3.46487475e-03, 8.98099350e-01],\n",
" [7.85190855e-01, 3.05525061e-01, 2.04565391e-01],\n",
" [8.23700964e-01, 2.97490709e-01, 6.50482927e-02],\n",
" [1.26448774e-01, 8.10615035e-01, 5.09045959e-01]],\n",
"\n",
" [[7.13680159e-01, 8.45319480e-01, 9.11203285e-01],\n",
" [8.27233800e-01, 4.12657299e-01, 6.97025851e-01],\n",
" [2.81018305e-01, 6.48955744e-01, 8.91153606e-01],\n",
" [6.29371477e-01, 7.75929101e-01, 7.79093686e-01],\n",
" [1.23716197e-01, 3.10122093e-01, 8.37055316e-01],\n",
" [1.68089475e-01, 9.54284278e-01, 3.18407200e-01],\n",
" [3.62606007e-01, 7.78830849e-01, 3.99089020e-01],\n",
" [8.51653764e-02, 5.34511941e-01, 6.04920400e-01],\n",
" [2.71115350e-01, 4.36692189e-01, 6.96241628e-01],\n",
" [8.65780466e-01, 3.22369175e-01, 7.43159548e-01]],\n",
"\n",
" [[9.28690706e-02, 3.94089568e-01, 5.61323421e-01],\n",
" [5.90368083e-01, 9.74901900e-01, 9.47030261e-01],\n",
" [6.30178577e-01, 5.13484751e-01, 1.18185924e-01],\n",
" [4.38458544e-01, 5.63764554e-01, 7.66318218e-01],\n",
" [7.18648243e-01, 1.70367043e-01, 7.70642982e-01],\n",
" [2.22980409e-01, 3.74329609e-01, 4.80409175e-01],\n",
" [9.45617766e-01, 8.81617847e-01, 1.19580346e-01],\n",
" [9.15676461e-01, 9.70518691e-01, 2.43214092e-01],\n",
" [2.41451379e-01, 4.88548632e-01, 3.70190637e-01],\n",
" [4.59905504e-01, 3.80889550e-01, 3.85698952e-01]],\n",
"\n",
" [[2.72104416e-01, 7.98516947e-01, 2.72194053e-01],\n",
" [9.48709871e-01, 2.66222380e-01, 9.42522429e-01],\n",
" [7.33569990e-01, 3.47061477e-01, 5.33050356e-01],\n",
" [6.11738759e-01, 8.68186611e-02, 3.75463635e-01],\n",
" [2.14948078e-01, 3.99863394e-01, 1.90922352e-01],\n",
" [5.10853285e-01, 8.50874635e-01, 3.81366847e-01],\n",
" [8.39025985e-01, 6.98107126e-01, 8.44114497e-01],\n",
" [6.99026648e-01, 6.91391860e-01, 5.95508900e-01],\n",
" [3.75264962e-01, 2.19625912e-02, 9.65963233e-01],\n",
" [8.73625070e-01, 1.70130110e-02, 4.28592791e-01]],\n",
"\n",
" [[5.30376524e-01, 5.19931919e-01, 8.79621802e-01],\n",
" [5.93691399e-01, 9.43767391e-01, 3.85923387e-01],\n",
" [2.66910663e-02, 9.36911794e-01, 1.02611787e-01],\n",
" [3.59377042e-01, 2.67568222e-01, 5.35446422e-01],\n",
" [5.51850227e-01, 6.35787754e-01, 1.98619411e-01],\n",
" [8.49409850e-01, 6.72271595e-01, 2.39559395e-01],\n",
" [1.73809713e-02, 6.90213328e-01, 4.68996474e-01],\n",
" [2.98860826e-01, 9.70887693e-02, 7.59385182e-01],\n",
" [2.57726817e-01, 9.57823991e-01, 7.18212290e-01],\n",
" [7.01041664e-01, 7.78681302e-01, 2.83077120e-01]],\n",
"\n",
" [[7.26341251e-01, 9.88746010e-02, 8.11023617e-01],\n",
" [5.17637504e-01, 1.22369589e-01, 6.26059218e-01],\n",
" [8.89381042e-01, 4.69513890e-01, 4.41358856e-01],\n",
" [1.31543858e-01, 2.52923839e-02, 4.59211802e-01],\n",
" [2.97316029e-01, 3.74157507e-01, 1.46629093e-01],\n",
" [5.42787121e-01, 4.83436833e-01, 6.48266145e-01],\n",
" [5.87451856e-01, 6.62348938e-01, 9.09155419e-01],\n",
" [4.19004871e-01, 1.82945864e-02, 6.63249102e-01],\n",
" [2.32421673e-01, 6.30460531e-01, 6.90273718e-02],\n",
" [1.00823603e-01, 6.38197200e-01, 1.54316174e-01]],\n",
"\n",
" [[3.95104675e-01, 9.54576105e-01, 7.18833793e-01],\n",
" [6.55325673e-01, 2.09750597e-01, 5.08179296e-01],\n",
" [8.35434924e-01, 3.11732329e-01, 2.53760179e-01],\n",
" [6.01300571e-01, 1.74387890e-01, 1.28901152e-01],\n",
" [3.60114137e-01, 6.03481824e-01, 6.42517616e-01],\n",
" [4.74822395e-01, 4.06953697e-02, 8.06656676e-01],\n",
" [2.35227783e-01, 1.86636675e-01, 2.92355800e-01],\n",
" [7.57334531e-01, 3.02550198e-01, 8.78392401e-01],\n",
" [9.89429375e-01, 9.17356225e-01, 9.94972892e-01],\n",
" [3.38833764e-01, 2.39923972e-01, 2.49753676e-02]],\n",
"\n",
" [[5.34504729e-02, 2.60437560e-01, 4.75257720e-01],\n",
" [5.10602046e-01, 7.24407672e-01, 1.27708925e-01],\n",
" [8.89866378e-01, 7.68391514e-01, 8.10330840e-01],\n",
" [9.04964699e-01, 9.07169014e-01, 3.30947299e-01],\n",
" [3.25628891e-01, 3.62460672e-01, 6.09180261e-01],\n",
" [3.97278800e-01, 2.87846815e-01, 9.51001737e-01],\n",
" [3.94260450e-01, 3.00806280e-01, 8.51360452e-01],\n",
" [3.62097670e-01, 2.24539340e-01, 7.61618704e-01],\n",
" [2.50380406e-01, 7.56360963e-01, 1.16230049e-01],\n",
" [9.01240322e-01, 3.73212987e-01, 7.36122529e-01]]]])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = np.random.rand(1,10,10,3)\n",
"X"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1/1 [==============================] - 0s 30ms/step\n"
]
},
{
"data": {
"text/plain": [
"array([[[[0.72204703, 0.99592865, 0.63305795],\n",
" [0.3195831 , 0.9540866 , 0.8953769 ],\n",
" [0.96829206, 0.841369 , 0.9129711 ],\n",
" [0.88090914, 0.7423906 , 0.7435633 ]],\n",
"\n",
" [[0.8272338 , 0.9749019 , 0.94703025],\n",
" [0.71864825, 0.7759291 , 0.8370553 ],\n",
" [0.9456178 , 0.9705187 , 0.6049204 ],\n",
" [0.8657805 , 0.38088953, 0.74315953]],\n",
"\n",
" [[0.72634125, 0.94376737, 0.8796218 ],\n",
" [0.5518502 , 0.6357877 , 0.5354464 ],\n",
" [0.5874519 , 0.6902133 , 0.9091554 ],\n",
" [0.70104164, 0.7786813 , 0.28307712]],\n",
"\n",
" [[0.51060206, 0.7244077 , 0.47525772],\n",
" [0.9049647 , 0.90716904, 0.6091803 ],\n",
" [0.39426044, 0.30080628, 0.85136044],\n",
" [0.90124035, 0.373213 , 0.73612255]]]], dtype=float32)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = model.predict(X)\n",
"y"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"X_in = [[[int(X[0][i][j][k]*1e36) for k in range(3)] for j in range(10)] for i in range(10)]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[['722047037048510692294577429613117440',\n",
" '995928671466629356695108803213918208',\n",
" '633057959701190742109452742097371136'],\n",
" ['319583098696415993297833594590330880',\n",
" '954086618948581700566160227891675136',\n",
" '895376897621437771937017335681908736'],\n",
" ['968292067860118608616536728156504064',\n",
" '841368944375512546814487235963912192',\n",
" '912971104374087013640201200717529088'],\n",
" ['880909150990983977881328708709515264',\n",
" '742390579543848987849194909946871808',\n",
" '743563277676321729157199064014520320']],\n",
" [['827233800468269198331599375468331008',\n",
" '974901899794306633179826070214410240',\n",
" '947030261209948995604238219295588352'],\n",
" ['718648242610498931670595372570378240',\n",
" '775929100688812511565570103251566592',\n",
" '837055315861977089490188905893330944'],\n",
" ['945617766167062485510007332488085504',\n",
" '970518691217475177799962810208223232',\n",
" '604920400036749440823474048758448128'],\n",
" ['865780466166194901147341462043623424',\n",
" '380889549714929988856634904844173312',\n",
" '743159547886228828110355377787240448']],\n",
" [['726341250795347218817727679288049664',\n",
" '943767390679553173919884888860786688',\n",
" '879621802069332614144323070292131840'],\n",
" ['551850226842807715717985301323317248',\n",
" '635787754066812906470475510293987328',\n",
" '535446422069610157460354392595103744'],\n",
" ['587451855996849487170757966572814336',\n",
" '690213328307762114218988812757893120',\n",
" '909155419360394098803395006243536896'],\n",
" ['701041664244132545037467114551640064',\n",
" '778681301680931345113753214299144192',\n",
" '283077119909484707357592139667079168']],\n",
" [['510602045587970767184337129422454784',\n",
" '724407671943644015539146665914007552',\n",
" '475257720371555822213696431964815360'],\n",
" ['904964699224751401721193296243458048',\n",
" '907169014433129231084536890038157312',\n",
" '609180260647294075284256963443032064'],\n",
" ['394260450257245198773068562662686720',\n",
" '300806280328724832256760381278519296',\n",
" '851360451534335061119480218741899264'],\n",
" ['901240322382283207026350889182953472',\n",
" '373212986501062202469532290256994304',\n",
" '736122528767077629789884872958935040']]]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out = MaxPooling2DsameInt(10, 10, 3, 2, 3, X_in)\n",
"out"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"in_json = {\n",
" \"in\": X_in,\n",
" \"out\": out\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"with open(\"maxPooling2Dsame_stride_input.json\", \"w\") as f:\n",
" json.dump(in_json, f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "sklearn",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

37
test/MaxPooling2Dsame.js Normal file
View File

@@ -0,0 +1,37 @@
const chai = require("chai");
const path = require("path");
const wasm_tester = require("circom_tester").wasm;
const F1Field = require("ffjavascript").F1Field;
const Scalar = require("ffjavascript").Scalar;
exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
const Fr = new F1Field(exports.p);
const assert = chai.assert;
describe("MaxPooling2Dsame layer test", function () {
this.timeout(100000000);
// MaxPooling with strides==poolSize
it("(5,5,3) -> (3,3,3)", async () => {
const INPUT = require("../models/maxPooling2Dsame_input.json");
const circuit = await wasm_tester(path.join(__dirname, "circuits", "MaxPooling2Dsame_test.circom"));
const witness = await circuit.calculateWitness(INPUT, true);
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
});
// MaxPooling with strides!=poolSize
it("(10,10,3) -> (4,4,3)", async () => {
const INPUT = require("../models/maxPooling2Dsame_stride_input.json");
const circuit = await wasm_tester(path.join(__dirname, "circuits", "MaxPooling2Dsame_stride_test.circom"));
const witness = await circuit.calculateWitness(INPUT, true);
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
});
});

View File

@@ -0,0 +1,5 @@
pragma circom 2.0.0;
include "../../circuits/MaxPooling2Dsame.circom";
component main = MaxPooling2Dsame(10, 10, 3, 2, 3);

View File

@@ -0,0 +1,6 @@
pragma circom 2.0.0;
include "../../circuits/MaxPooling2Dsame.circom";
// poolSize=strides - default Keras settings
component main = MaxPooling2Dsame(5, 5, 3, 2, 2);