Files
circomlib-ml/models/batchNormalization.ipynb
2022-11-12 20:42:10 +08:00

843 lines
29 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow.keras import Model\n",
"from tensorflow.keras.layers import Input, BatchNormalization, Dense, Flatten"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Seed NumPy and TensorFlow before building the model or drawing any data, so the\n",
"# Dense initialisation, the training data, the test sample and therefore the\n",
"# exported circuit test vectors are reproducible on Restart & Run All.\n",
"SEED = 0\n",
"np.random.seed(SEED)\n",
"tf.random.set_seed(SEED)\n",
"\n",
"# BatchNormalization on a (5, 5, 3) input, followed by Flatten + Dense(1) so the\n",
"# model can be fitted against a scalar target and the BN parameters get trained.\n",
"inputs = Input(shape=(5,5,3))\n",
"out = BatchNormalization()(inputs)\n",
"out = Flatten()(out)\n",
"out = Dense(1)(out)\n",
"model = Model(inputs, out)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model_1\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"input_3 (InputLayer) [(None, 5, 5, 3)] 0 \n",
"_________________________________________________________________\n",
"batch_normalization_2 (Batch (None, 5, 5, 3) 12 \n",
"_________________________________________________________________\n",
"flatten_1 (Flatten) (None, 75) 0 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 1) 76 \n",
"=================================================================\n",
"Total params: 88\n",
"Trainable params: 82\n",
"Non-trainable params: 6\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(100, 5, 5, 3)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 100 random training samples; display the shape rather than dumping every value.\n",
"X_train = np.random.rand(100,5,5,3)\n",
"X_train.shape"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(100, 1)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Random scalar targets, one per training sample.\n",
"y_train = np.random.rand(100,1)\n",
"y_train.shape"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"4/4 [==============================] - 0s 566us/step - loss: 1.7214\n",
"Epoch 2/10\n",
"4/4 [==============================] - 0s 660us/step - loss: 1.6521\n",
"Epoch 3/10\n",
"4/4 [==============================] - 0s 1ms/step - loss: 1.6740\n",
"Epoch 4/10\n",
"4/4 [==============================] - 0s 729us/step - loss: 1.6128\n",
"Epoch 5/10\n",
"4/4 [==============================] - 0s 1ms/step - loss: 1.4255\n",
"Epoch 6/10\n",
"4/4 [==============================] - 0s 585us/step - loss: 1.4899\n",
"Epoch 7/10\n",
"4/4 [==============================] - 0s 1ms/step - loss: 1.3811\n",
"Epoch 8/10\n",
"4/4 [==============================] - 0s 673us/step - loss: 1.3549\n",
"Epoch 9/10\n",
"4/4 [==============================] - 0s 653us/step - loss: 1.3313\n",
"Epoch 10/10\n",
"4/4 [==============================] - 0s 1ms/step - loss: 1.2854\n"
]
}
],
"source": [
"# Train briefly so gamma/beta and the moving statistics move away from their initial values.\n",
"model.compile(optimizer=\"adam\", loss=\"mse\")\n",
"# Keep the History object instead of letting its repr clutter the cell output.\n",
"history = model.fit(X_train, y_train, epochs=10)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<tf.Variable 'batch_normalization_2/gamma:0' shape=(3,) dtype=float32, numpy=array([0.9656307, 0.9650906, 0.9646477], dtype=float32)>,\n",
" <tf.Variable 'batch_normalization_2/beta:0' shape=(3,) dtype=float32, numpy=array([0.03126225, 0.03251555, 0.0313088 ], dtype=float32)>,\n",
" <tf.Variable 'batch_normalization_2/moving_mean:0' shape=(3,) dtype=float32, numpy=array([0.16415757, 0.16638677, 0.16154662], dtype=float32)>,\n",
" <tf.Variable 'batch_normalization_2/moving_variance:0' shape=(3,) dtype=float32, numpy=array([0.6964935 , 0.69624424, 0.6967338 ], dtype=float32)>]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.layers[1].weights"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[[0.53603175, 0.16288177, 0.54647377],\n",
" [0.58968269, 0.51633636, 0.09084824],\n",
" [0.58954409, 0.60274177, 0.16628126],\n",
" [0.03036366, 0.25776433, 0.97264483],\n",
" [0.06753911, 0.4969747 , 0.02626603]],\n",
"\n",
" [[0.92805506, 0.92962356, 0.94991846],\n",
" [0.40984699, 0.57242913, 0.73624703],\n",
" [0.27120968, 0.30428539, 0.6547197 ],\n",
" [0.6895789 , 0.12203021, 0.56160566],\n",
" [0.35853814, 0.61396961, 0.30326431]],\n",
"\n",
" [[0.49895694, 0.26192641, 0.41918769],\n",
" [0.56496371, 0.13934069, 0.77930897],\n",
" [0.9652276 , 0.68000352, 0.59384582],\n",
" [0.18267196, 0.26760574, 0.93864666],\n",
" [0.49916607, 0.63215712, 0.38614211]],\n",
"\n",
" [[0.46365438, 0.3845917 , 0.6604073 ],\n",
" [0.59669509, 0.22802217, 0.62536791],\n",
" [0.37852067, 0.51773501, 0.96948045],\n",
" [0.46492378, 0.09701206, 0.90831063],\n",
" [0.31265477, 0.43007139, 0.82608669]],\n",
"\n",
" [[0.32252988, 0.28388506, 0.15159293],\n",
" [0.54518128, 0.73664414, 0.27618411],\n",
" [0.41446863, 0.45379391, 0.65724072],\n",
" [0.1670575 , 0.82368301, 0.41525341],\n",
" [0.05091919, 0.78432432, 0.29655634]]]])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# A single random sample of the model's input shape; it is exported below as the circuit's \"in\" signal.\n",
"X = np.random.random((1, 5, 5, 3))\n",
"X"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[[ 0.4612311 , 0.02846454, 0.47584018],\n",
" [ 0.5232636 , 0.4369807 , -0.05033689],\n",
" [ 0.5231033 , 0.5368464 , 0.03677658],\n",
" [-0.12343314, 0.13812803, 0.9680018 ],\n",
" [-0.08045009, 0.41460282, -0.12491935]],\n",
"\n",
" [[ 0.9144969 , 0.9146502 , 0.94175637],\n",
" [ 0.31533363, 0.50181156, 0.69499886],\n",
" [ 0.1550382 , 0.19189616, 0.6008475 ],\n",
" [ 0.63876563, -0.01875092, 0.4933151 ],\n",
" [ 0.25600925, 0.54982334, 0.19497083]],\n",
"\n",
" [[ 0.41836447, 0.14293846, 0.3288444 ],\n",
" [ 0.49468288, 0.00125621, 0.7447288 ],\n",
" [ 0.9574766 , 0.62614405, 0.53054756],\n",
" [ 0.05266898, 0.14950253, 0.92873925],\n",
" [ 0.41860625, 0.57084405, 0.29068184]],\n",
"\n",
" [[ 0.37754688, 0.28471267, 0.6074158 ],\n",
" [ 0.53137136, 0.10375258, 0.5669507 ],\n",
" [ 0.27911344, 0.43859717, 0.96434736],\n",
" [ 0.37901458, -0.04766643, 0.8937058 ],\n",
" [ 0.20295788, 0.33727726, 0.7987498 ]],\n",
"\n",
" [[ 0.2143757 , 0.16831785, 0.01981383],\n",
" [ 0.47181004, 0.69160825, 0.16369738],\n",
" [ 0.32067728, 0.3646953 , 0.6037588 ],\n",
" [ 0.03461521, 0.7922061 , 0.32430092],\n",
" [-0.0996664 , 0.74671614, 0.18722416]]]], dtype=float32)"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Find the BatchNormalization layer by type instead of by position, and build a\n",
"# sub-model exposing its output so the layer can be checked in isolation.\n",
"bn_layer = next(layer for layer in model.layers if isinstance(layer, BatchNormalization))\n",
"bn_model = Model(inputs, bn_layer.output)\n",
"y = bn_model.predict(X)\n",
"y"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"# Read parameters by name rather than by position in `weights`: with scale=False or\n",
"# center=False Keras creates no gamma / beta, which would shift positional indices.\n",
"# The fallbacks are the identity values the layer effectively uses in that case.\n",
"moving_mean = bn_layer.moving_mean.numpy()\n",
"moving_var = bn_layer.moving_variance.numpy()\n",
"gamma = bn_layer.gamma.numpy() if bn_layer.gamma is not None else np.ones_like(moving_mean)\n",
"beta = bn_layer.beta.numpy() if bn_layer.beta is not None else np.zeros_like(moving_mean)\n",
"epsilon = bn_layer.epsilon"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"# Fold inference-mode batch norm, gamma * (x - mean) / sqrt(var + eps) + beta,\n",
"# into the per-channel affine form a * x + b used for the circuit export.\n",
"inv_std = 1 / np.sqrt(moving_var + epsilon)\n",
"a = gamma * inv_std\n",
"b = beta - moving_mean * a"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"# The layer output must match the inference-mode batch-norm formula; fail loudly\n",
"# instead of printing residuals that have to be inspected by eye.\n",
"expected = gamma * (X - moving_mean) / np.sqrt(moving_var + epsilon) + beta\n",
"np.testing.assert_allclose(y, expected, rtol=1e-5, atol=1e-6)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"# The folded affine form a * X + b, which is what gets exported, must agree as well.\n",
"np.testing.assert_allclose(y, a * X + b, rtol=1e-5, atol=1e-6)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1.1562214, 1.1557811, 1.1548455], dtype=float32)"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"# Fixed-point export: `in` and `a` are scaled by SCALE, so the product a * in sits at\n",
"# SCALE**2. `b` (and the expected `out` below) are exported at SCALE**2 as well so\n",
"# that a * in + b lines up without further rescaling.\n",
"SCALE = 1000\n",
"\n",
"def to_fixed_point(values, scale):\n",
"    \"\"\"Return `values * scale` rounded to the nearest integer, flattened to a list of Python ints.\"\"\"\n",
"    return (np.asarray(values) * scale).round().astype(int).flatten().tolist()\n",
"\n",
"in_json = {\n",
"    \"in\": to_fixed_point(X, SCALE),\n",
"    \"a\": to_fixed_point(a, SCALE),\n",
"    \"b\": to_fixed_point(b, SCALE * SCALE)\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"# Expected layer output at the same SCALE**2 fixed-point scale as a * in + b.\n",
"out_json = {\n",
"    \"out\": to_fixed_point(y, SCALE * SCALE)\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"import json"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Check the exported shapes before writing: one (5, 5, 3) sample and one (a, b)\n",
"# pair per channel, with one expected output value per input value.\n",
"assert X.shape == (1, 5, 5, 3), X.shape\n",
"assert a.shape == b.shape == (X.shape[-1],), (a.shape, b.shape)\n",
"assert len(in_json[\"in\"]) == len(out_json[\"out\"]) == X.size"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"with open(\"batchNormalization_input.json\", \"w\") as f:\n",
" json.dump(in_json, f)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"with open(\"batchNormalization_output.json\", \"w\") as f:\n",
" json.dump(out_json, f)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.6 ('tf24')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "11280bdb37aa6bc5d4cf1e4de756386eb1f9eecd8dcdefa77636dfac7be2370d"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}