Files
circomlib-ml/models/batchNormalization.ipynb
drCathieSo.eth 30a43a2712 Version 2.0.0 (#5)
* feat: `Poly` renamed to `ZeLU` with scaling implemented

* fix: assertion in `ZeLU`

* feat: `AveragePooling2D` with scaling

* feat: `BatchNorm` with scaling

* feat: `Conv1D` with scaling

* feat: `Conv2D` with scaling

* feat: `Dense` with scaling

* fix: assertion in `Dense`

* feat: `GlobalAveragePooling2D` with scaling

* feat: input-only `ArgMax`

* feat: input-only `Flatten2D`

* feat: input-only `GlobalMaxPooling2D`

* feat: input-only `MaxPooling2D`

* feat: input-only `ReLU`

* test: precision up to 36 decimals

* chore: clean up

* test: model1 with 36 decimals

* fix: ReLU should use `p//2` as threshold

* test: clean up

* test: mnist model with 18 decimals

* build: Update package.json version to 2.0.0

* chore: Update README with warning message
2023-10-24 02:50:34 +07:00

977 lines
38 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Prime modulus of circom's default field (BN254 scalar field); circuit inputs/outputs below are reduced mod p\n",
"p = 21888242871839275222246405745257275088548364400416034343698204186575808495617"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.layers import Input, BatchNormalization, Dense, Flatten\n",
"from tensorflow.keras import Model\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Toy model: BatchNormalization -> Flatten -> Dense(1) on 5x5x3 inputs; the BN layer is what gets exported\n",
"inputs = Input(shape=(5,5,3))\n",
"out = inputs\n",
"for layer in (BatchNormalization(), Flatten(), Dense(1)):\n",
"    out = layer(out)\n",
"model = Model(inputs, out)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" input_1 (InputLayer) [(None, 5, 5, 3)] 0 \n",
" \n",
" batch_normalization (BatchN (None, 5, 5, 3) 12 \n",
" ormalization) \n",
" \n",
" flatten (Flatten) (None, 75) 0 \n",
" \n",
" dense (Dense) (None, 1) 76 \n",
" \n",
"=================================================================\n",
"Total params: 88\n",
"Trainable params: 82\n",
"Non-trainable params: 6\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"# Layer shapes and parameter counts; the BatchNormalization layer is model.layers[1]\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[[0.51546565, 0.46651282, 0.78133227],\n",
" [0.29092001, 0.15424294, 0.31568105],\n",
" [0.42554169, 0.17798232, 0.88401623],\n",
" [0.89061342, 0.56954272, 0.0481971 ],\n",
" [0.91385411, 0.87839928, 0.54333657]],\n",
"\n",
" [[0.06954288, 0.9189417 , 0.16824991],\n",
" [0.61834797, 0.48611683, 0.05698664],\n",
" [0.54958275, 0.40614748, 0.93471022],\n",
" [0.96738304, 0.9310907 , 0.05346684],\n",
" [0.01504873, 0.45594357, 0.92055095]],\n",
"\n",
" [[0.11021081, 0.40673465, 0.56412873],\n",
" [0.5385549 , 0.06119234, 0.5499304 ],\n",
" [0.409637 , 0.72651771, 0.71666944],\n",
" [0.99635655, 0.13653654, 0.10949123],\n",
" [0.2601131 , 0.70956364, 0.51173637]],\n",
"\n",
" [[0.29500952, 0.50029805, 0.48711255],\n",
" [0.58302715, 0.07682466, 0.06070311],\n",
" [0.63143749, 0.55553472, 0.62363371],\n",
" [0.34286961, 0.28562232, 0.10381272],\n",
" [0.5267341 , 0.44532117, 0.74259914]],\n",
"\n",
" [[0.89833895, 0.70393118, 0.3390096 ],\n",
" [0.31743696, 0.27991868, 0.41889825],\n",
" [0.88849899, 0.36099248, 0.8606479 ],\n",
" [0.8115506 , 0.9383675 , 0.29448095],\n",
" [0.97113074, 0.58140933, 0.96698766]]],\n",
"\n",
"\n",
" [[[0.22152899, 0.69173392, 0.10892224],\n",
" [0.45291025, 0.87475282, 0.78825403],\n",
" [0.30206652, 0.65193234, 0.74018285],\n",
" [0.28822764, 0.82794295, 0.88511038],\n",
" [0.65083834, 0.65900742, 0.37215784]],\n",
"\n",
" [[0.42637533, 0.82491213, 0.99759185],\n",
" [0.5311499 , 0.71585492, 0.35118 ],\n",
" [0.24486032, 0.91171808, 0.77877967],\n",
" [0.1899115 , 0.01429132, 0.53225576],\n",
" [0.26831007, 0.3223111 , 0.2793977 ]],\n",
"\n",
" [[0.43628462, 0.855694 , 0.45144496],\n",
" [0.17713725, 0.6788593 , 0.53221073],\n",
" [0.24239035, 0.29885874, 0.00793102],\n",
" [0.47036933, 0.83716811, 0.01747103],\n",
" [0.52148184, 0.88785307, 0.19371033]],\n",
"\n",
" [[0.45964463, 0.11785621, 0.71888447],\n",
" [0.37170738, 0.91271932, 0.28649456],\n",
" [0.3974761 , 0.54316076, 0.56468016],\n",
" [0.65716588, 0.00913184, 0.42107159],\n",
" [0.62804608, 0.19359813, 0.13687296]],\n",
"\n",
" [[0.15432216, 0.24376177, 0.73777443],\n",
" [0.37740194, 0.51219611, 0.18536464],\n",
" [0.00561986, 0.63145406, 0.57055781],\n",
" [0.45424905, 0.46157587, 0.87326481],\n",
" [0.44888708, 0.09268558, 0.54308092]]],\n",
"\n",
"\n",
" [[[0.44003404, 0.82605955, 0.53120479],\n",
" [0.34508592, 0.54358359, 0.63107103],\n",
" [0.18258236, 0.5828249 , 0.35215564],\n",
" [0.35995517, 0.58247612, 0.866913 ],\n",
" [0.37331628, 0.72105525, 0.86011076]],\n",
"\n",
" [[0.0037806 , 0.80184814, 0.19458159],\n",
" [0.01779705, 0.68048193, 0.08877037],\n",
" [0.72192341, 0.94542319, 0.7099795 ],\n",
" [0.38773382, 0.0054821 , 0.86694736],\n",
" [0.59622908, 0.88670846, 0.54186724]],\n",
"\n",
" [[0.91125128, 0.24628816, 0.91306807],\n",
" [0.20544878, 0.6600348 , 0.02261235],\n",
" [0.08034128, 0.50809017, 0.37310936],\n",
" [0.66526679, 0.6763669 , 0.3438879 ],\n",
" [0.54922619, 0.05982564, 0.21239447]],\n",
"\n",
" [[0.09812086, 0.36775343, 0.23264159],\n",
" [0.07952414, 0.14821264, 0.30500863],\n",
" [0.78574487, 0.20629136, 0.28249732],\n",
" [0.18532948, 0.62820359, 0.40885501],\n",
" [0.53428171, 0.01419418, 0.96868481]],\n",
"\n",
" [[0.46122128, 0.60562586, 0.19714442],\n",
" [0.83266218, 0.62255273, 0.48534285],\n",
" [0.78304596, 0.63898229, 0.91084764],\n",
" [0.04902518, 0.69016331, 0.16760015],\n",
" [0.88179687, 0.02644824, 0.78977572]]],\n",
"\n",
"\n",
" ...,\n",
"\n",
"\n",
" [[[0.17074279, 0.10103348, 0.63216146],\n",
" [0.32308489, 0.88627716, 0.84782501],\n",
" [0.1032571 , 0.02331302, 0.36067615],\n",
" [0.01463099, 0.89390044, 0.7491627 ],\n",
" [0.94200448, 0.90298754, 0.55947368]],\n",
"\n",
" [[0.7041535 , 0.39417727, 0.07809514],\n",
" [0.52135847, 0.0290839 , 0.45233884],\n",
" [0.71856329, 0.81167091, 0.9178133 ],\n",
" [0.02404749, 0.49104236, 0.21492727],\n",
" [0.15136666, 0.01310515, 0.29748093]],\n",
"\n",
" [[0.88120014, 0.16076401, 0.44107649],\n",
" [0.25002897, 0.68659302, 0.01653978],\n",
" [0.96055884, 0.40286758, 0.11771713],\n",
" [0.66927334, 0.77825925, 0.30680967],\n",
" [0.31482238, 0.7749523 , 0.11406879]],\n",
"\n",
" [[0.53919419, 0.60783489, 0.55848085],\n",
" [0.32747631, 0.09879024, 0.28296219],\n",
" [0.65211663, 0.70552049, 0.55715698],\n",
" [0.31646225, 0.26783995, 0.47998016],\n",
" [0.37794728, 0.42765623, 0.97190187]],\n",
"\n",
" [[0.53879646, 0.45261775, 0.46913403],\n",
" [0.2670922 , 0.79348712, 0.49027602],\n",
" [0.25085863, 0.10323594, 0.15697154],\n",
" [0.31783142, 0.54529695, 0.90404337],\n",
" [0.11790291, 0.15503706, 0.86422424]]],\n",
"\n",
"\n",
" [[[0.16965053, 0.19585652, 0.57609002],\n",
" [0.88591501, 0.90486881, 0.29608075],\n",
" [0.4261748 , 0.91284255, 0.40807448],\n",
" [0.9686301 , 0.74409545, 0.1479126 ],\n",
" [0.76802611, 0.07944974, 0.7871578 ]],\n",
"\n",
" [[0.77475062, 0.40398437, 0.22643119],\n",
" [0.0104928 , 0.38809633, 0.48776304],\n",
" [0.22491731, 0.94079457, 0.97114209],\n",
" [0.65810206, 0.9100557 , 0.24727223],\n",
" [0.05758432, 0.54125367, 0.01468728]],\n",
"\n",
" [[0.71953982, 0.17153423, 0.74242277],\n",
" [0.69007692, 0.16985619, 0.40592314],\n",
" [0.54681391, 0.5548588 , 0.47674301],\n",
" [0.57911416, 0.28571978, 0.39200726],\n",
" [0.11671085, 0.35797023, 0.65890719]],\n",
"\n",
" [[0.85257004, 0.7955835 , 0.66530167],\n",
" [0.60683054, 0.66667651, 0.58183113],\n",
" [0.95010415, 0.77982033, 0.48321958],\n",
" [0.73954001, 0.21177226, 0.40730859],\n",
" [0.58235681, 0.54621676, 0.72576091]],\n",
"\n",
" [[0.78679585, 0.07508165, 0.7500122 ],\n",
" [0.52093211, 0.58039123, 0.46972807],\n",
" [0.96323894, 0.99445506, 0.73634561],\n",
" [0.79991676, 0.66932491, 0.83293889],\n",
" [0.20586253, 0.41194998, 0.43086448]]],\n",
"\n",
"\n",
" [[[0.39755238, 0.44572733, 0.14839841],\n",
" [0.85064056, 0.46063306, 0.48205451],\n",
" [0.71843235, 0.16096294, 0.68091537],\n",
" [0.53677884, 0.1177746 , 0.2540637 ],\n",
" [0.27248667, 0.53513208, 0.88182579]],\n",
"\n",
" [[0.41187651, 0.77931731, 0.24484454],\n",
" [0.79766554, 0.41731159, 0.95576332],\n",
" [0.72876755, 0.27933312, 0.3082945 ],\n",
" [0.76694795, 0.74242235, 0.62142724],\n",
" [0.97480578, 0.35289907, 0.64357369]],\n",
"\n",
" [[0.00510661, 0.36997422, 0.16220858],\n",
" [0.9872583 , 0.95501977, 0.15935416],\n",
" [0.78867534, 0.54122521, 0.54486478],\n",
" [0.03354764, 0.05670898, 0.47520133],\n",
" [0.04376703, 0.69549825, 0.3957491 ]],\n",
"\n",
" [[0.84051762, 0.79462166, 0.30602544],\n",
" [0.25025524, 0.08666752, 0.69194498],\n",
" [0.67076409, 0.13492286, 0.4590991 ],\n",
" [0.20405993, 0.07903494, 0.70707305],\n",
" [0.83720425, 0.75953423, 0.81481202]],\n",
"\n",
" [[0.19551102, 0.46499979, 0.50975971],\n",
" [0.81569678, 0.34130601, 0.76208935],\n",
" [0.82885678, 0.10476159, 0.13018122],\n",
" [0.26689358, 0.79482095, 0.75352521],\n",
" [0.58703011, 0.11749016, 0.19427432]]]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Random training inputs in [0, 1); NOTE(review): unseeded, so values (and all outputs below) change on every run\n",
"X_train = np.random.rand(100,5,5,3)\n",
"X_train"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.21946477],\n",
" [0.3711923 ],\n",
" [0.24737705],\n",
" [0.16344551],\n",
" [0.57433513],\n",
" [0.76673914],\n",
" [0.7049643 ],\n",
" [0.67571894],\n",
" [0.15141958],\n",
" [0.84983105],\n",
" [0.05552415],\n",
" [0.95604334],\n",
" [0.70679183],\n",
" [0.86247509],\n",
" [0.01407803],\n",
" [0.55736261],\n",
" [0.19422897],\n",
" [0.86008897],\n",
" [0.2779906 ],\n",
" [0.55527211],\n",
" [0.16247257],\n",
" [0.13970005],\n",
" [0.92248072],\n",
" [0.93564244],\n",
" [0.28634139],\n",
" [0.23594509],\n",
" [0.44848018],\n",
" [0.35401733],\n",
" [0.99775489],\n",
" [0.1732102 ],\n",
" [0.81910272],\n",
" [0.46522896],\n",
" [0.99489404],\n",
" [0.55781089],\n",
" [0.13592596],\n",
" [0.43879534],\n",
" [0.81024602],\n",
" [0.94877655],\n",
" [0.75248092],\n",
" [0.30265859],\n",
" [0.42836091],\n",
" [0.59454105],\n",
" [0.43154069],\n",
" [0.67468033],\n",
" [0.58606549],\n",
" [0.1744458 ],\n",
" [0.49333982],\n",
" [0.06740779],\n",
" [0.00943958],\n",
" [0.4099619 ],\n",
" [0.49609765],\n",
" [0.31434293],\n",
" [0.24492819],\n",
" [0.27013361],\n",
" [0.69482176],\n",
" [0.93217975],\n",
" [0.57115641],\n",
" [0.64785497],\n",
" [0.09918379],\n",
" [0.55192952],\n",
" [0.81989582],\n",
" [0.05051461],\n",
" [0.93054128],\n",
" [0.67387306],\n",
" [0.76534683],\n",
" [0.6162766 ],\n",
" [0.96975296],\n",
" [0.79504513],\n",
" [0.42467985],\n",
" [0.11096869],\n",
" [0.22852679],\n",
" [0.65542355],\n",
" [0.35148172],\n",
" [0.77244589],\n",
" [0.35624866],\n",
" [0.2630125 ],\n",
" [0.20241856],\n",
" [0.35342962],\n",
" [0.4200376 ],\n",
" [0.21397308],\n",
" [0.38814001],\n",
" [0.82781448],\n",
" [0.14303959],\n",
" [0.83622505],\n",
" [0.19489373],\n",
" [0.50839397],\n",
" [0.94387189],\n",
" [0.26870297],\n",
" [0.66940445],\n",
" [0.31928558],\n",
" [0.25835398],\n",
" [0.59484484],\n",
" [0.21550068],\n",
" [0.19920632],\n",
" [0.85037438],\n",
" [0.64880356],\n",
" [0.19659417],\n",
" [0.12815503],\n",
" [0.70278136],\n",
" [0.65639473]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Random regression targets; presumably only needed so fit() has something to train against\n",
"y_train = np.random.rand(100,1)\n",
"y_train"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"4/4 [==============================] - 0s 2ms/step - loss: 2.7752\n",
"Epoch 2/10\n",
"4/4 [==============================] - 0s 979us/step - loss: 2.6586\n",
"Epoch 3/10\n",
"4/4 [==============================] - 0s 649us/step - loss: 2.5723\n",
"Epoch 4/10\n",
"4/4 [==============================] - 0s 2ms/step - loss: 2.4804\n",
"Epoch 5/10\n",
"4/4 [==============================] - 0s 1ms/step - loss: 2.3874\n",
"Epoch 6/10\n",
"4/4 [==============================] - 0s 762us/step - loss: 2.2941\n",
"Epoch 7/10\n",
"4/4 [==============================] - 0s 889us/step - loss: 2.2374\n",
"Epoch 8/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-10-23 20:00:14.999431: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"4/4 [==============================] - 0s 15ms/step - loss: 2.1879\n",
"Epoch 9/10\n",
"4/4 [==============================] - 0s 987us/step - loss: 2.1199\n",
"Epoch 10/10\n",
"4/4 [==============================] - 0s 1ms/step - loss: 2.0514\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x12dea6fa0>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Brief training run; presumably just to move the BatchNormalization parameters away from their initial values\n",
"model.compile(\"adam\", \"mse\")\n",
"model.fit(X_train, y_train, epochs=10)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<tf.Variable 'batch_normalization/gamma:0' shape=(3,) dtype=float32, numpy=array([0.9704942, 0.9639488, 0.9704298], dtype=float32)>,\n",
" <tf.Variable 'batch_normalization/beta:0' shape=(3,) dtype=float32, numpy=array([ 0.02682306, -0.02748435, 0.02756111], dtype=float32)>,\n",
" <tf.Variable 'batch_normalization/moving_mean:0' shape=(3,) dtype=float32, numpy=array([0.16209687, 0.1706537 , 0.16507432], dtype=float32)>,\n",
" <tf.Variable 'batch_normalization/moving_variance:0' shape=(3,) dtype=float32, numpy=array([0.6963565, 0.6966989, 0.6962899], dtype=float32)>]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# BatchNormalization weights, in order: gamma, beta, moving_mean, moving_variance\n",
"model.layers[1].weights"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[[0.27455621, 0.27884516, 0.7259278 ],\n",
" [0.53377714, 0.6519488 , 0.86514377],\n",
" [0.49080768, 0.65021787, 0.68419028],\n",
" [0.0618392 , 0.02673969, 0.57868236],\n",
" [0.85494917, 0.71864642, 0.38163087]],\n",
"\n",
" [[0.22703165, 0.33350008, 0.85208601],\n",
" [0.86850824, 0.43578265, 0.70671478],\n",
" [0.75415054, 0.97444638, 0.48606382],\n",
" [0.14784358, 0.45406111, 0.08556609],\n",
" [0.20530247, 0.25450988, 0.51542351]],\n",
"\n",
" [[0.13544047, 0.01932561, 0.22823691],\n",
" [0.64891188, 0.18616972, 0.46087849],\n",
" [0.81283055, 0.5325164 , 0.32305024],\n",
" [0.75351508, 0.14501534, 0.44999686],\n",
" [0.42234362, 0.06347228, 0.63881747]],\n",
"\n",
" [[0.85467296, 0.80260163, 0.14398915],\n",
" [0.57931829, 0.53329564, 0.7187654 ],\n",
" [0.80024988, 0.96224997, 0.8937722 ],\n",
" [0.33025203, 0.91751174, 0.62203916],\n",
" [0.93176317, 0.72790829, 0.14862771]],\n",
"\n",
" [[0.60466651, 0.567648 , 0.37182924],\n",
" [0.76271487, 0.11624398, 0.81946002],\n",
" [0.4389968 , 0.90003205, 0.19711127],\n",
" [0.67622958, 0.53012144, 0.37373778],\n",
" [0.97509719, 0.48172791, 0.63218788]]]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Single test sample (batch of 1) to push through the BatchNormalization layer\n",
"X = np.random.rand(1,5,5,3)\n",
"X"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1/1 [==============================] - 0s 32ms/step\n"
]
},
{
"data": {
"text/plain": [
"array([[[[ 0.1575187 , 0.09737267, 0.67934984],\n",
" [ 0.45877463, 0.5279483 , 0.8411379 ],\n",
" [ 0.40883726, 0.52595073, 0.63084507],\n",
" [-0.0896923 , -0.19356653, 0.50823045],\n",
" [ 0.83202755, 0.60491985, 0.27922952]],\n",
"\n",
" [[ 0.10228759, 0.16044651, 0.8259631 ],\n",
" [ 0.84778535, 0.2784844 , 0.65702164],\n",
" [ 0.7148835 , 0.9001226 , 0.40059495],\n",
" [ 0.01025847, 0.29957846, -0.06483836],\n",
" [ 0.07703484, 0.06928882, 0.43471497]],\n",
"\n",
" [[-0.00415592, -0.20212264, 0.10096472],\n",
" [ 0.5925795 , -0.00957828, 0.37132615],\n",
" [ 0.7830791 , 0.39011884, 0.21115081],\n",
" [ 0.714145 , -0.05707199, 0.3586802 ],\n",
" [ 0.32927114, -0.15117574, 0.5781157 ]],\n",
"\n",
" [[ 0.8317066 , 0.7018073 , 0.00305727],\n",
" [ 0.5117007 , 0.39101806, 0.6710261 ],\n",
" [ 0.7684583 , 0.88604754, 0.8744081 ],\n",
" [ 0.22224607, 0.83441794, 0.558617 ],\n",
" [ 0.92129767, 0.6156084 , 0.0084479 ]],\n",
"\n",
" [[ 0.54115933, 0.43066198, 0.26783872],\n",
" [ 0.72483665, -0.09027521, 0.78804713],\n",
" [ 0.34862483, 0.81424564, 0.06479244],\n",
" [ 0.624327 , 0.38735494, 0.2700567 ],\n",
" [ 0.97165865, 0.33150697, 0.5704112 ]]]], dtype=float32)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sub-model exposing the BatchNormalization output; predict() runs it in inference mode (moving statistics)\n",
"new_model = Model(inputs=inputs, outputs=model.layers[1].output)\n",
"y = new_model.predict(X)\n",
"y"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Unpack the trained BatchNormalization parameters (same order as model.layers[1].weights)\n",
"gamma, beta, moving_mean, moving_var = (w.numpy() for w in model.layers[1].weights)\n",
"epsilon = model.layers[1].epsilon"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Fold inference-mode batch norm into a per-channel affine map y = a*x + b\n",
"bn_std = (moving_var+epsilon)**.5\n",
"a = gamma/bn_std\n",
"b = beta-gamma*moving_mean/bn_std"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[[ 1.38897775e-08, -4.41621682e-09, -1.46314817e-08],\n",
" [ 1.22362858e-08, 2.73325588e-08, -4.77779197e-08],\n",
" [-1.14818392e-08, -8.60447691e-09, -5.83081214e-08],\n",
" [-5.10371694e-09, -8.67723293e-09, 2.98262304e-10],\n",
" [-7.37144357e-09, -6.87929902e-09, -3.09108445e-08]],\n",
"\n",
" [[-3.70238870e-09, 9.08749898e-09, 6.04406457e-08],\n",
" [-8.04363320e-09, -2.61427017e-08, -4.75678711e-08],\n",
" [-2.29807172e-08, -3.35601555e-08, -3.61847423e-08],\n",
" [ 1.76612667e-09, -8.73801476e-10, 2.36584587e-09],\n",
" [ 5.96964625e-09, -1.35698691e-08, -2.09152273e-09]],\n",
"\n",
" [[-2.18008088e-09, 3.40718009e-09, 1.14725877e-09],\n",
" [ 7.83470533e-09, -4.13062541e-09, -2.05064118e-08],\n",
" [ 5.68865103e-08, 3.01907888e-08, -1.38190100e-08],\n",
" [-2.30006845e-08, -8.84019241e-09, -3.75235497e-08],\n",
" [-1.33667987e-08, 2.45791437e-09, 1.05576055e-08]],\n",
"\n",
" [[ 1.63983768e-08, 4.34060976e-09, 7.69416812e-09],\n",
" [ 1.96461017e-08, -2.64836226e-08, -5.55432413e-08],\n",
" [ 1.60446652e-08, 3.83262573e-08, 3.27252936e-08],\n",
" [-1.79053708e-08, 3.49677078e-08, 2.28655761e-09],\n",
" [ 3.06888477e-08, -6.37458319e-09, -7.83266337e-10]],\n",
"\n",
" [[ 3.44015694e-09, -2.27795071e-08, -3.66268776e-09],\n",
" [ 8.27864410e-09, -8.66295551e-09, -9.44552192e-09],\n",
" [ 3.77712691e-08, -4.68591560e-08, -7.31272758e-09],\n",
" [ 1.90310351e-08, 2.21921187e-09, -1.20205449e-08],\n",
" [-1.50930767e-08, -1.72539832e-08, 6.61672261e-09]]]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Keras output vs. the textbook BN formula; must be ~0 up to float32 rounding\n",
"bn_formula_residual = y-(gamma*(X-moving_mean)/((moving_var+epsilon)**.5)+beta)\n",
"assert np.allclose(bn_formula_residual, 0, atol=1e-6)\n",
"bn_formula_residual"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[[ 6.29271463e-09, -1.51612348e-08, 4.78305939e-09],\n",
" [ 2.53706245e-09, 1.56956936e-08, -2.51350980e-08],\n",
" [-2.08326003e-08, -2.02372046e-08, -3.98614322e-08],\n",
" [-1.09757444e-08, -1.88196314e-08, 1.62983274e-08],\n",
" [-1.96752217e-08, -1.86755947e-08, -1.94802087e-08]],\n",
"\n",
" [[-1.09140497e-08, -1.78816328e-09, 8.27806710e-08],\n",
" [-2.04573691e-08, -3.72628548e-08, -2.85988609e-08],\n",
" [-3.44670656e-08, -4.59679019e-08, -2.23324095e-08],\n",
" [-4.80335610e-09, -1.20376463e-08, 6.93103225e-09],\n",
" [-1.06580120e-09, -2.42567173e-08, 1.24416322e-08]],\n",
"\n",
" [[-8.64898020e-09, -6.71749611e-09, 9.02084041e-09],\n",
" [-2.79820678e-09, -1.46541169e-08, -7.23810212e-09],\n",
" [ 4.49242945e-08, 1.88394085e-08, -3.74680098e-09],\n",
" [-3.44818796e-08, -1.92653107e-08, -2.45075741e-08],\n",
" [-2.21623482e-08, -7.77228767e-09, 2.79521442e-08]],\n",
"\n",
" [[ 4.09683865e-09, -7.65636798e-09, 1.36141277e-08],\n",
" [ 9.57756097e-09, -3.78368655e-08, -3.62947893e-08],\n",
" [ 4.18447266e-09, 2.59476647e-08, 5.60319802e-08],\n",
" [-2.59541007e-08, 2.26960550e-08, 1.92920239e-08],\n",
" [ 1.77621440e-08, -1.81930180e-08, 5.24425681e-09]],\n",
"\n",
" [[-6.83394596e-09, -3.42148641e-08, 7.54065804e-09],\n",
" [-3.27715710e-09, -1.90193002e-08, 1.21379391e-08],\n",
" [ 2.88406700e-08, -5.90890261e-08, -1.60918778e-10],\n",
" [ 8.17658918e-09, -9.12644371e-09, -7.72941988e-10],\n",
" [-2.83711992e-08, -2.84839613e-08, 2.38575276e-08]]]])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Keras output vs. the folded affine form a*X + b; must also be ~0 up to float32 rounding\n",
"affine_residual = y-(a*X+b)\n",
"assert np.allclose(affine_residual, 0, atol=1e-6)\n",
"affine_residual"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Fixed-point encoding: x and a are scaled by 1e36, b by 1e72 (= 1e36 * 1e36) so x*a + b share one scale;\n",
"# dividing by n = 10**36 later brings the result back to the 1e36 scale.\n",
"# Cast NumPy scalars to Python float first: under NumPy >= 2 (NEP 50) float32 * 1e72 stays float32 and\n",
"# overflows to inf, which int() cannot convert. float() reproduces the float64 arithmetic of NumPy 1.x.\n",
"nRows, nCols, nChannels = X.shape[1:]\n",
"X_in = [[[int(float(X[0][i][j][k])*1e36) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]\n",
"a_in = [int(float(a[k])*1e36) for k in range(nChannels)]\n",
"b_in = [int(float(b[k])*1e72) for k in range(nChannels)]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def BatchNormalizationInt(nRows, nCols, nChannels, n, X_in, a_in, b_in):\n",
"    \"\"\"Integer batch norm: out = (x*a + b) // n, computed per channel.\n",
"\n",
"    Returns (X, A, B, out, remainder) as nested lists of decimal strings:\n",
"    X, A, B and out are reduced with Python's % p (negatives map into [0, p)),\n",
"    while remainder = (x*a + b) % n is left unreduced (in [0, n) for n > 0).\n",
"    \"\"\"\n",
"    def to_field(v):\n",
"        return str(v % p)\n",
"\n",
"    rows, cols, chans = range(nRows), range(nCols), range(nChannels)\n",
"    X = [[[to_field(X_in[r][c][ch]) for ch in chans] for c in cols] for r in rows]\n",
"    A = [to_field(a_in[ch]) for ch in chans]\n",
"    B = [to_field(b_in[ch]) for ch in chans]\n",
"    out = [[[None] * nChannels for _ in cols] for _ in rows]\n",
"    remainder = [[[None] * nChannels for _ in cols] for _ in rows]\n",
"    for r in rows:\n",
"        for c in cols:\n",
"            for ch in chans:\n",
"                # divmod yields the floor quotient and non-negative remainder, exactly like // and %\n",
"                quotient, rem = divmod(X_in[r][c][ch] * a_in[ch] + b_in[ch], n)\n",
"                out[r][c][ch] = to_field(quotient)\n",
"                remainder[r][c][ch] = str(rem)\n",
"    return X, A, B, out, remainder"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"([[['157518693472490790243750538747957882',\n",
" '97372681162554733176887955075564208',\n",
" '679349834904288174010983420548472594'],\n",
" ['458774623717973004897590573896222594',\n",
" '527948304216263291457870336316333444',\n",
" '841137911182461164557826291121517336'],\n",
" ['408837279648365652067755038590177210',\n",
" '525950750084159001810559673300027574',\n",
" '630845109746686169695072155498068928'],\n",
" ['21888242871839275222246405745257275088548274708124961881757683904418357553907',\n",
" '21888242871839275222246405745257275088548170833903911058199291157862382757945',\n",
" '508230431470837755985412821877993851'],\n",
" ['832027574187245701393604967574752718',\n",
" '604919869501858267645833548353564313',\n",
" '279229541231612628393358090302516442']],\n",
" [['102287601417742280387289538163651204',\n",
" '160446511507058261099060517183710523',\n",
" '825962997148680811368987640325934077'],\n",
" ['847785374117952740336862359672140583',\n",
" '278484441349921535654657902103684959',\n",
" '657021670330123182657390893609651294'],\n",
" ['714883540765130856969303509707571514',\n",
" '900122628880346975580290622927866446',\n",
" '400594972054699631420351232955312561'],\n",
" ['10258476396616863842152239203344050',\n",
" '299578470108401435866442829726136156',\n",
" '21888242871839275222246405745257275088548299562044382966913364566037249574881'],\n",
" ['77034839563439930876087711734475167',\n",
" '69288844285022374454862035586080571',\n",
" '434714960531237722761063240822063770']],\n",
" [['21888242871839275222246405745257275088548360244503864875852868041600221908663',\n",
" '21888242871839275222246405745257275088548162277779161866355363655040683885335',\n",
" '100964708546126595514097434616933959'],\n",
" ['592579486784107682988589813640686405',\n",
" '21888242871839275222246405745257275088548354822146812280026576875315705275880',\n",
" '371326155748081420061854877950882957'],\n",
" ['783079042809927963587133644460156189',\n",
" '390118818517159008054052294744317319',\n",
" '211150813869290941494836799384423040'],\n",
" ['714145039231177651343799972971259701',\n",
" '21888242871839275222246405745257275088548307328444034834297413996208073048117',\n",
" '358680213163427345690379627175515214'],\n",
" ['329271159876734251530552022145153369',\n",
" '21888242871839275222246405745257275088548213224686425649934303811174716724822',\n",
" '578115673723270972200332905482236663']],\n",
" [['831706579403069869857935985454018004',\n",
" '701807327774318518191690249279384790',\n",
" '3057252040078645241914286618837645'],\n",
" ['511700680215072064258322477636183971',\n",
" '391018100666836853673874037122153517',\n",
" '671026146943898182863942504121469642'],\n",
" ['768458302604925592394191747899082690',\n",
" '886047516147519780735326653713713464',\n",
" '874408069845400225234489253410344667'],\n",
" ['222246091689917701515896532244187761',\n",
" '834417916490041372048913944859818527',\n",
" '558616976519438572914283702703657983'],\n",
" ['921297651648561583336105580618598710',\n",
" '615608412338983701988855791011481577',\n",
" '8447895170210076430444298018909113']],\n",
" [['541159338632499440419779390078898826',\n",
" '430662010552296997105829163483722657',\n",
" '267838708966299969534341921959286450'],\n",
" ['724836650787685688836352855376326828',\n",
" '21888242871839275222246405745257275088548274125221931276103456351742703907060',\n",
" '788047122738312175009071698929184108'],\n",
" ['348624796636930142430096079480218740',\n",
" '814245700320563240509607601046393243',\n",
" '64792439502463893211811777556988475'],\n",
" ['624326995779251861311046568986208844',\n",
" '387354949302453930827325825546488462',\n",
" '270056695518959509623672894764882809'],\n",
" ['971658675431593506453919504060069966',\n",
" '331506996028517098372582336349879273',\n",
" '570411181434220420525142442837712479']]],\n",
" [[['960480586600175329574532373343633408',\n",
" '16700446728492295766061281370439680',\n",
" '379639810818394963719156481312948224'],\n",
" ['388476442819984282074088736716488704',\n",
" '605933780470969152068605021142384640',\n",
" '798916055303936274984266387226099712'],\n",
" ['468208661324380829760178449488543744',\n",
" '819358397205938578920215010284142592',\n",
" '21589625405262166471345886990434304'],\n",
" ['697307730873745656946722950128599040',\n",
" '484544812284374194987581109665005568',\n",
" '165375951315867937763795967171100672'],\n",
" ['245646470350672698747537604682448896',\n",
" '516198956417707676498067467442061312',\n",
" '851127713155083345720389185963032576']],\n",
" [['963964631181695048142847483580514304',\n",
" '717110020036717348293837961318891520',\n",
" '559953044877040683984591286389702656'],\n",
" ['18430124250391465600117066342334464',\n",
" '132070511117078294010171654032850944',\n",
" '159098461579690714866905242830635008'],\n",
" ['627161913358867337554883287034363904',\n",
" '214447136123138332377227662874116096',\n",
" '882093469775889964276298230664265728'],\n",
" ['506592892690571316134013534512087040',\n",
" '120413569434570423403255788841467904',\n",
" '363677450590533494361708506394918912'],\n",
" ['1589009160702191880025766302318592',\n",
" '606881345580479787255639207307116544',\n",
" '959500312602351349459070457955221504']],\n",
" [['269867103629848361043137830460588032',\n",
" '990156537691794618122141045295677440',\n",
" '865669682192046160291865685070446592'],\n",
" ['290313896548079828952332075043651584',\n",
" '864848850508538634449993085726228480',\n",
" '602632316683732661972508737477279744'],\n",
" ['522550745079433503690896029362683904',\n",
" '360374127352021396321849249139523584',\n",
" '232224328255243390445419547970240512'],\n",
" ['605064041602877244147525705785147392',\n",
" '199177740608395626314729675502387200',\n",
" '919073451728990826614299904488505344'],\n",
" ['360317325587177670121175102173216768',\n",
" '320159996958358401900225542972506112',\n",
" '149208972882417206108392691362430976']],\n",
" [['403450033176818573530952678489718784',\n",
" '489675097836673889083870450733285376',\n",
" '878348273706607180847862755205578752'],\n",
" ['971247092328484848039979691100078080',\n",
" '614717678863967773495129714821431296',\n",
" '354994706810685515240551471042789376'],\n",
" ['175120840317815376550047548614639616',\n",
" '256810812410679590717730062889648128',\n",
" '32500220595245784184833778997264384'],\n",
" ['820630970242119026496832796320858112',\n",
" '450959634440169857524579916622856192',\n",
" '993116187387119510197936862871420928'],\n",
" ['141575526471072789135353484131958784',\n",
" '424325045683895850778427312979312640',\n",
" '504193266510591745485938882812837888']],\n",
" [['58216472724137800358267959826710528',\n",
" '768052861880136122281533170251726848',\n",
" '416638974419576874572829739330830336'],\n",
" ['257711363123831457072460090222575616',\n",
" '770799881808654090050761531714437120',\n",
" '553150637395247721839782017066074112'],\n",
" ['576779513601307433086800237942538240',\n",
" '70762290041252423558033638227443712',\n",
" '501500349611164148444025032284307456'],\n",
" ['213252321414434873609495117362626560',\n",
" '168126301731513097227967825069473792',\n",
" '210589547132072037234216859745648640'],\n",
" ['192717389299758634733578375953121280',\n",
" '629763589047996081611699278760116224',\n",
" '631810269670576590942597092109975552']]])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Keep the integer inputs (X_in, a_in, b_in) intact and bind the mod-p string encodings to new names,\n",
"# so re-running this cell never feeds strings back into BatchNormalizationInt.\n",
"X_in_str, a_in_str, b_in_str, out, remainder = BatchNormalizationInt(len(X_in), len(X_in[0]), len(a_in), 10**36, X_in, a_in, b_in)\n",
"out, remainder"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Circuit test input; keys presumably match the circom template's input signal names\n",
"in_json = {\n",
"    \"in\": X_in_str,\n",
"    \"a\": a_in_str,\n",
"    \"b\": b_in_str,\n",
"    \"out\": out,\n",
"    \"remainder\": remainder\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"import json"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# Write the circuit test input to batchNormalization_input.json in the working directory\n",
"with open(\"batchNormalization_input.json\", \"w\") as f:\n",
" json.dump(in_json, f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.6 ('tf24')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "11280bdb37aa6bc5d4cf1e4de756386eb1f9eecd8dcdefa77636dfac7be2370d"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}