SumPooling2D with strides

This commit is contained in:
Cathie So
2022-11-11 20:38:33 +08:00
parent 6c0254308f
commit 0ab4025c14
9 changed files with 290 additions and 14 deletions

View File

@@ -0,0 +1 @@
{"in": [315, 32, 291, 357, 873, 772, 617, 795, 489, 982, 888, 697, 252, 828, 575, 621, 768, 660, 325, 934, 748, 456, 968, 848, 21, 278, 182, 159, 694, 435, 709, 604, 947, 562, 655, 274, 4, 86, 305, 115, 150, 585, 728, 552, 917, 435, 81, 684, 395, 204, 105, 788, 635, 8, 85, 757, 950, 922, 69, 378, 360, 357, 651, 295, 495, 405, 95, 87, 393, 388, 509, 634, 270, 678, 93, 764, 670, 288, 553, 602, 354, 558, 870, 261, 873, 788, 48, 724, 677, 823, 718, 509, 520, 748, 972, 88, 158, 732, 407, 933, 877, 345, 424, 805, 391, 814, 568, 126, 118, 611, 813, 276, 575, 106, 531, 146, 81, 341, 276, 195, 130, 626, 563, 57, 938, 718, 576, 371, 539, 459, 309, 735, 641, 594, 838, 807, 795, 697, 799, 264, 292, 451, 283, 504, 295, 318, 493, 825, 394, 339, 111, 353, 440, 499, 344, 776, 764, 887, 58, 572, 75, 510, 225, 477, 290, 433, 587, 693, 952, 872, 107, 931, 194, 955, 910, 860, 674, 94, 758, 620, 434, 973, 896, 917, 690, 53, 721, 58, 303, 556, 116, 249, 407, 661, 222, 530, 232, 828, 204, 235, 365, 55, 597, 80, 113, 985, 576, 612, 658, 784, 395, 787, 376, 862, 594, 508, 414, 964, 837, 770, 418, 174, 59, 933, 638, 976, 64, 346, 423, 521, 627, 40, 807, 270, 219, 190, 45, 162, 209, 341, 131, 902, 825, 732, 248, 359, 888, 920, 691, 81, 636, 426, 445, 991, 156, 787, 47, 464, 535, 799, 330, 311, 654, 773, 564, 576, 429, 575, 672, 529, 784, 152, 928, 100, 473, 444, 748, 46, 58, 373, 125, 612, 74, 655, 188, 170, 128, 483, 178, 857, 414, 217, 189, 277, 67, 169, 469, 399, 177, 945]}

View File

@@ -0,0 +1 @@
{"out": [1944, 2164, 2284, 2077, 2417, 2773, 1963, 2742, 1709, 1652, 3045, 1889, 2457, 2585, 2309, 1644, 1733, 1717, 2609, 3043, 1833, 1792, 2128, 1283, 722, 2160, 1342]}

View File

@@ -151,7 +151,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -178,6 +178,250 @@
" json.dump(out_json, f)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"inputs = Input(shape=(10,10,3))\n",
"x = AveragePooling2D(pool_size=2, strides=3)(inputs)\n",
"x = Lambda(lambda x: x*4)(x)\n",
"model = Model(inputs, x)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model_2\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"input_3 (InputLayer) [(None, 10, 10, 3)] 0 \n",
"_________________________________________________________________\n",
"average_pooling2d_2 (Average (None, 3, 3, 3) 0 \n",
"_________________________________________________________________\n",
"lambda_2 (Lambda) (None, 3, 3, 3) 0 \n",
"=================================================================\n",
"Total params: 0\n",
"Trainable params: 0\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[[0.31514958, 0.03200121, 0.29129004],\n",
" [0.35725668, 0.87252739, 0.77162311],\n",
" [0.61707883, 0.7945887 , 0.48907944],\n",
" [0.98197306, 0.88814753, 0.69652672],\n",
" [0.2518265 , 0.82753267, 0.57464263],\n",
" [0.62115028, 0.76805041, 0.65967975],\n",
" [0.32491691, 0.93353364, 0.74831234],\n",
" [0.45570461, 0.96830864, 0.8476189 ],\n",
" [0.02149766, 0.27808247, 0.18207897],\n",
" [0.15949257, 0.69372265, 0.43455872]],\n",
"\n",
" [[0.70915807, 0.60410698, 0.94721792],\n",
" [0.5621233 , 0.65546021, 0.27357865],\n",
" [0.00414209, 0.08635782, 0.30528659],\n",
" [0.11492599, 0.15002234, 0.58496289],\n",
" [0.72848003, 0.55169839, 0.91708802],\n",
" [0.43479205, 0.08069621, 0.68404234],\n",
" [0.3946513 , 0.20447291, 0.10467492],\n",
" [0.78817621, 0.63518792, 0.00827133],\n",
" [0.0853401 , 0.75656605, 0.95034115],\n",
" [0.92239164, 0.06871402, 0.37783711]],\n",
"\n",
" [[0.36000095, 0.35659628, 0.6507159 ],\n",
" [0.29486935, 0.49464939, 0.40502335],\n",
" [0.09509896, 0.08726498, 0.39326876],\n",
" [0.38827707, 0.50908505, 0.63443643],\n",
" [0.27030144, 0.67783072, 0.09309034],\n",
" [0.76360544, 0.67003754, 0.28767228],\n",
" [0.55305299, 0.60216561, 0.3544107 ],\n",
" [0.55839884, 0.86964781, 0.26053367],\n",
" [0.87306012, 0.78756102, 0.04817508],\n",
" [0.72406774, 0.67679246, 0.82272016]],\n",
"\n",
" [[0.71765743, 0.50852032, 0.52047892],\n",
" [0.7484707 , 0.97207503, 0.08778545],\n",
" [0.15780167, 0.73192822, 0.40718403],\n",
" [0.93263197, 0.8772701 , 0.34486053],\n",
" [0.42436095, 0.80504181, 0.39139203],\n",
" [0.81358273, 0.56754054, 0.12608038],\n",
" [0.11843567, 0.61136361, 0.81339895],\n",
" [0.27636648, 0.57453166, 0.10632468],\n",
" [0.53090786, 0.14594835, 0.08140653],\n",
" [0.34118642, 0.27554414, 0.19515355]],\n",
"\n",
" [[0.12974003, 0.6264065 , 0.56250089],\n",
" [0.05655555, 0.93847961, 0.71849845],\n",
" [0.57644684, 0.37077012, 0.53949152],\n",
" [0.45904117, 0.30854737, 0.73517714],\n",
" [0.64076017, 0.59373326, 0.83758554],\n",
" [0.80707699, 0.79461191, 0.69655474],\n",
" [0.79872758, 0.26420269, 0.29237624],\n",
" [0.45087863, 0.28258419, 0.50447663],\n",
" [0.29494657, 0.31770288, 0.49309187],\n",
" [0.82460949, 0.3940875 , 0.33865267]],\n",
"\n",
" [[0.1108653 , 0.35294351, 0.44014634],\n",
" [0.4988099 , 0.34405962, 0.77622373],\n",
" [0.76444373, 0.88689451, 0.05756076],\n",
" [0.57160174, 0.0752442 , 0.5098132 ],\n",
" [0.22539676, 0.47741414, 0.28993556],\n",
" [0.43298235, 0.58710277, 0.69306001],\n",
" [0.9521223 , 0.87239108, 0.10672981],\n",
" [0.93125144, 0.19405455, 0.95483289],\n",
" [0.91030892, 0.85961313, 0.67439157],\n",
" [0.09377237, 0.75818836, 0.61985122]],\n",
"\n",
" [[0.4344654 , 0.97297157, 0.89560878],\n",
" [0.91664946, 0.68966445, 0.0530751 ],\n",
" [0.72099738, 0.05779864, 0.30259649],\n",
" [0.55598956, 0.11611106, 0.24856552],\n",
" [0.40690072, 0.66148966, 0.22159354],\n",
" [0.53035294, 0.23237414, 0.82781172],\n",
" [0.20375017, 0.23486322, 0.36461596],\n",
" [0.05525619, 0.59671011, 0.08001122],\n",
" [0.11250979, 0.98519728, 0.57553523],\n",
" [0.6117834 , 0.65811775, 0.78386287]],\n",
"\n",
" [[0.39532528, 0.78660638, 0.37617851],\n",
" [0.86246711, 0.59398046, 0.50843286],\n",
" [0.41395181, 0.96399598, 0.8374128 ],\n",
" [0.76981858, 0.41760042, 0.17438256],\n",
" [0.05937649, 0.93289121, 0.63833505],\n",
" [0.97571178, 0.06364159, 0.34572432],\n",
" [0.42278241, 0.52111442, 0.62746908],\n",
" [0.0401781 , 0.80713288, 0.26990436],\n",
" [0.21850787, 0.19009324, 0.04497292],\n",
" [0.16176602, 0.20893733, 0.34094974]],\n",
"\n",
" [[0.13094336, 0.90151022, 0.82541695],\n",
" [0.73192844, 0.24791076, 0.3587372 ],\n",
" [0.88818939, 0.92023872, 0.69098959],\n",
" [0.08104613, 0.6361497 , 0.42552169],\n",
" [0.44517886, 0.99055202, 0.15580116],\n",
" [0.78742252, 0.04735346, 0.46423316],\n",
" [0.53474903, 0.79917168, 0.33019955],\n",
" [0.31087978, 0.65384266, 0.77275665],\n",
" [0.56393354, 0.5761927 , 0.4287843 ],\n",
" [0.57457285, 0.67154059, 0.52881047]],\n",
"\n",
" [[0.78373278, 0.15164648, 0.92791502],\n",
" [0.0999971 , 0.47319914, 0.44424683],\n",
" [0.74758969, 0.04583226, 0.0579972 ],\n",
" [0.37325021, 0.12464474, 0.61199188],\n",
" [0.07404238, 0.65504221, 0.18787021],\n",
" [0.16955187, 0.12750002, 0.48252436],\n",
" [0.17829354, 0.85701326, 0.41402596],\n",
" [0.21677806, 0.18949005, 0.27735136],\n",
" [0.06721647, 0.16941253, 0.46916978],\n",
" [0.39921131, 0.17705116, 0.94534667]]]])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = np.random.rand(1,10,10,3)\n",
"X"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[[1.9436876 , 2.1640959 , 2.2837098 ],\n",
" [2.0772057 , 2.4174008 , 2.7732203 ],\n",
" [1.963449 , 2.741503 , 1.7088774 ]],\n",
"\n",
" [[1.6524236 , 3.0454814 , 1.8892637 ],\n",
" [2.4567943 , 2.5845926 , 2.3090153 ],\n",
" [1.6444083 , 1.7326821 , 1.7165766 ]],\n",
"\n",
" [[2.6089072 , 3.043223 , 1.8332952 ],\n",
" [1.7920854 , 2.1280923 , 1.2828767 ],\n",
" [0.72196686, 2.1598206 , 1.3420006 ]]]], dtype=float32)"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = model.predict(X)\n",
"y"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"in_json = {\n",
" \"in\": (X*1000).round().astype(int).flatten().tolist()\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"out_json = {\n",
" \"out\": (y*1000).round().astype(int).flatten().tolist()\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"with open(\"sumPooling2D_stride_input.json\", \"w\") as f:\n",
" json.dump(in_json, f)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"with open(\"sumPooling2D_stride_output.json\", \"w\") as f:\n",
" json.dump(out_json, f)"
]
},
{
"cell_type": "code",
"execution_count": null,