diff --git a/circuits/BatchNormalization2D.circom b/circuits/BatchNormalization2D.circom new file mode 100644 index 0000000..adedd36 --- /dev/null +++ b/circuits/BatchNormalization2D.circom @@ -0,0 +1,16 @@ +pragma circom 2.0.3; + +template BatchNormalization2D(nRows, nCols, nChannels) { + signal input in[nRows][nCols][nChannels]; + signal input a[nChannels]; + signal input b[nChannels]; + signal output out[nRows][nCols][nChannels]; + + for (var i=0; i" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.compile(\"adam\", \"mse\")\n", + "model.fit(X_train, y_train, epochs=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[,\n", + " ,\n", + " ,\n", + " ]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.layers[1].weights" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[[0.53603175, 0.16288177, 0.54647377],\n", + " [0.58968269, 0.51633636, 0.09084824],\n", + " [0.58954409, 0.60274177, 0.16628126],\n", + " [0.03036366, 0.25776433, 0.97264483],\n", + " [0.06753911, 0.4969747 , 0.02626603]],\n", + "\n", + " [[0.92805506, 0.92962356, 0.94991846],\n", + " [0.40984699, 0.57242913, 0.73624703],\n", + " [0.27120968, 0.30428539, 0.6547197 ],\n", + " [0.6895789 , 0.12203021, 0.56160566],\n", + " [0.35853814, 0.61396961, 0.30326431]],\n", + "\n", + " [[0.49895694, 0.26192641, 0.41918769],\n", + " [0.56496371, 0.13934069, 0.77930897],\n", + " [0.9652276 , 0.68000352, 0.59384582],\n", + " [0.18267196, 0.26760574, 0.93864666],\n", + " [0.49916607, 0.63215712, 0.38614211]],\n", + "\n", + " [[0.46365438, 0.3845917 , 0.6604073 ],\n", + " [0.59669509, 0.22802217, 0.62536791],\n", + " [0.37852067, 0.51773501, 0.96948045],\n", + " [0.46492378, 0.09701206, 0.90831063],\n", + " [0.31265477, 0.43007139, 0.82608669]],\n", + "\n", + " [[0.32252988, 0.28388506, 0.15159293],\n", + " [0.54518128, 0.73664414, 0.27618411],\n", + " [0.41446863, 0.45379391, 0.65724072],\n", + " [0.1670575 , 0.82368301, 0.41525341],\n", + " [0.05091919, 0.78432432, 0.29655634]]]])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X = np.random.rand(1,5,5,3)\n", + "X" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[[ 0.4612311 , 0.02846454, 0.47584018],\n", + " [ 0.5232636 , 0.4369807 , -0.05033689],\n", + " [ 0.5231033 , 0.5368464 , 0.03677658],\n", + " [-0.12343314, 0.13812803, 0.9680018 ],\n", + " [-0.08045009, 0.41460282, -0.12491935]],\n", + "\n", + " [[ 0.9144969 , 0.9146502 , 0.94175637],\n", + " [ 0.31533363, 0.50181156, 0.69499886],\n", + " [ 0.1550382 , 0.19189616, 0.6008475 ],\n", + " [ 0.63876563, -0.01875092, 0.4933151 ],\n", + " [ 0.25600925, 0.54982334, 0.19497083]],\n", + "\n", + " [[ 0.41836447, 0.14293846, 0.3288444 ],\n", + " [ 0.49468288, 0.00125621, 0.7447288 ],\n", + " [ 0.9574766 , 0.62614405, 0.53054756],\n", + " [ 0.05266898, 0.14950253, 0.92873925],\n", + " [ 0.41860625, 0.57084405, 0.29068184]],\n", + "\n", + " [[ 0.37754688, 0.28471267, 0.6074158 ],\n", + " [ 0.53137136, 0.10375258, 0.5669507 ],\n", + " [ 0.27911344, 0.43859717, 0.96434736],\n", + " [ 0.37901458, -0.04766643, 0.8937058 ],\n", + " [ 0.20295788, 0.33727726, 0.7987498 ]],\n", + "\n", + " [[ 0.2143757 , 0.16831785, 0.01981383],\n", + " [ 0.47181004, 0.69160825, 0.16369738],\n", + " [ 0.32067728, 0.3646953 , 0.6037588 ],\n", + " [ 0.03461521, 0.7922061 , 0.32430092],\n", + " [-0.0996664 , 0.74671614, 0.18722416]]]], dtype=float32)" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new_model = Model(inputs, model.layers[1].output)\n", + "y = new_model.predict(X)\n", + "y" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "gamma = model.layers[1].weights[0].numpy()\n", + "beta = model.layers[1].weights[1].numpy()\n", + "moving_mean = model.layers[1].weights[2].numpy()\n", + "moving_var = model.layers[1].weights[3].numpy()\n", + "epsilon = model.layers[1].epsilon" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "a = gamma/(moving_var+epsilon)**.5\n", + "b = beta-gamma*moving_mean/(moving_var+epsilon)**.5" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[[-2.03794953e-08, 5.91488761e-09, 1.45154126e-08],\n", + " [ 7.98100520e-08, 1.98810133e-08, 7.91968675e-09],\n", + " [ 4.81973802e-08, -2.20449419e-08, -1.88879183e-09],\n", + " [ 2.64212180e-10, 1.55602272e-08, -7.21641903e-08],\n", + " [-8.82741547e-10, -2.68981404e-08, 1.23617632e-08]],\n", + "\n", + " [[ 2.51851505e-08, -1.65522146e-08, -4.00145346e-08],\n", + " [ 2.10715177e-08, -7.90676278e-08, -8.24527382e-08],\n", + " [ 5.33205935e-09, -1.58572262e-08, 1.61795255e-09],\n", + " [ 4.38547865e-10, 3.73349793e-09, -5.56143572e-08],\n", + " [ 3.41888787e-08, 3.81396181e-09, 5.69365846e-09]],\n", + "\n", + " [[ 2.33330081e-08, 6.56881610e-09, -1.21710513e-08],\n", + " [-2.89912733e-09, 4.82403046e-09, -1.72175791e-08],\n", + " [ 6.43033263e-08, -4.13130072e-08, -1.50603818e-09],\n", + " [ 7.43607972e-09, 4.41849135e-09, 2.61667081e-08],\n", + " [ 1.46215137e-08, -6.68016564e-08, -4.45506274e-08]],\n", + "\n", + " [[ 1.13439502e-08, -1.31488602e-08, 1.62490282e-08],\n", + " [-3.02629793e-08, 4.91692291e-09, -2.13291591e-08],\n", + " [-4.75657103e-09, -4.41433378e-08, -1.26229787e-07],\n", + " [ 9.41011957e-09, 2.70651909e-09, -1.61074364e-08],\n", + " [ 1.15060916e-09, 3.34245048e-09, -5.07417969e-08]],\n", + "\n", + " [[ 4.33865668e-09, 4.46010895e-09, -6.16794888e-09],\n", + " [ 3.34336615e-08, -4.74926964e-09, -1.07198432e-08],\n", + " [ 2.74091595e-08, 1.35368101e-08, -5.05023202e-08],\n", + " [ 4.70631774e-09, -2.86037326e-08, -1.74001409e-08],\n", + " [-1.16552118e-08, 3.97495532e-08, -2.19518315e-09]]]])" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y-(gamma*(X-moving_mean)/((moving_var+epsilon)**.5)+beta)" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[[-1.61475423e-08, 5.64923888e-09, -2.40047562e-08],\n", + " [ 8.45288702e-08, 8.12395617e-09, -3.95743559e-09],\n", + " [ 5.29149408e-08, -3.66111857e-08, -1.81769167e-08],\n", + " [-9.26127786e-11, 1.22097854e-08, -1.35605033e-07],\n", + " [-9.02211070e-10, -3.80257174e-08, 4.26113381e-09]],\n", + "\n", + " [[ 3.29745908e-08, -4.17459400e-08, -1.02126435e-07],\n", + " [ 2.41583839e-08, -9.26483562e-08, -1.32070041e-07],\n", + " [ 7.16083581e-09, -2.07201465e-08, -4.32319784e-08],\n", + " [ 6.06389250e-09, 4.79600271e-09, -9.50193745e-08],\n", + " [ 3.68101333e-08, -1.11173180e-08, -1.86046535e-08]],\n", + "\n", + " [[ 2.72285189e-08, 3.08305820e-09, -4.32480702e-08],\n", + " [ 1.59537394e-09, 5.32374209e-09, -6.93529612e-08],\n", + " [ 7.24300957e-08, -5.83911616e-08, -4.27963234e-08],\n", + " [ 8.46140438e-09, 7.48088591e-10, -3.52860654e-08],\n", + " [ 1.85189223e-08, -8.23242431e-08, -7.36952816e-08]],\n", + "\n", + " [[ 1.49191015e-08, -2.06226749e-08, -2.89334895e-08],\n", + " [-2.54805258e-08, 2.53344934e-09, -6.44627220e-08],\n", + " [-1.95398131e-09, -5.59458676e-08, -1.89485590e-07],\n", + " [ 1.29967902e-08, 4.58240647e-09, -7.57862879e-08],\n", + " [ 3.35548667e-09, -5.60998636e-09, -1.05612541e-07]],\n", + "\n", + " [[ 6.63314770e-09, 2.60438060e-10, -2.15971630e-08],\n", + " [ 3.77486434e-08, -2.36689068e-08, -3.34346205e-08],\n", + " [ 3.05379656e-08, 3.81311405e-09, -9.54996694e-08],\n", + " [ 5.58994609e-09, -5.03531512e-08, -4.82471003e-08],\n", + " [-1.18255018e-08, 1.92797525e-08, -2.61012420e-08]]]])" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y-(a*X+b)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1.1562214, 1.1557811, 1.1548455], dtype=float32)" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "in_json = {\n", + " \"in\": (X*1000).round().astype(int).flatten().tolist(),\n", + " \"a\": (a*1000).round().astype(int).flatten().tolist(),\n", + " \"b\": (b*1000*1000).round().astype(int).flatten().tolist()\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [], + "source": [ + "out_json = {\n", + " \"out\": (y*1000000).round().astype(int).flatten().tolist()\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "import json" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"batchNormalization_input.json\", \"w\") as f:\n", + " json.dump(in_json, f)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"batchNormalization_output.json\", \"w\") as f:\n", + " json.dump(out_json, f)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1, 5, 5, 3)" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3,)" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3,)" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "b.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.6 ('tf24')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.6" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "11280bdb37aa6bc5d4cf1e4de756386eb1f9eecd8dcdefa77636dfac7be2370d" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/models/batchNormalization_input.json b/models/batchNormalization_input.json new file mode 100644 index 0000000..9e9aa19 --- /dev/null +++ b/models/batchNormalization_input.json @@ -0,0 +1 @@ +{"in": [536, 163, 546, 590, 516, 91, 590, 603, 166, 30, 258, 973, 68, 497, 26, 928, 930, 950, 410, 572, 736, 271, 304, 655, 690, 122, 562, 359, 614, 303, 499, 262, 419, 565, 139, 779, 965, 680, 594, 183, 268, 939, 499, 632, 386, 464, 385, 660, 597, 228, 625, 379, 518, 969, 465, 97, 908, 313, 430, 826, 323, 284, 152, 545, 737, 276, 414, 454, 657, 167, 824, 415, 51, 784, 297], "a": [1156, 1156, 1155], "b": [-158540, -159791, -155253]} \ No newline at end of file diff --git a/models/batchNormalization_output.json b/models/batchNormalization_output.json new file mode 100644 index 0000000..408a367 --- /dev/null +++ b/models/batchNormalization_output.json @@ -0,0 +1 @@ +{"out": [461231, 28465, 475840, 523264, 436981, -50337, 523103, 536846, 36777, -123433, 138128, 968002, -80450, 414603, -124919, 914497, 914650, 941756, 315334, 501812, 694999, 155038, 191896, 600847, 638766, -18751, 493315, 256009, 549823, 194971, 418364, 142938, 328844, 494683, 1256, 744729, 957477, 626144, 530548, 52669, 149503, 928739, 418606, 570844, 290682, 377547, 284713, 607416, 531371, 103753, 566951, 279113, 438597, 964347, 379015, -47666, 893706, 202958, 337277, 798750, 214376, 168318, 19814, 471810, 691608, 163697, 320677, 364695, 603759, 34615, 792206, 324301, -99666, 746716, 187224]} \ No newline at end of file diff --git a/test/BatchNormalization.js b/test/BatchNormalization.js new file mode 100644 index 0000000..c149746 --- /dev/null +++ b/test/BatchNormalization.js @@ -0,0 +1,48 @@ +const chai = require("chai"); +const { Console } = require("console"); +const path = require("path"); + +const wasm_tester = require("circom_tester").wasm; + +const F1Field = require("ffjavascript").F1Field; +const Scalar = require("ffjavascript").Scalar; +exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617"); +const Fr = new F1Field(exports.p); + +const assert = chai.assert; + + + +describe("BatchNormalization layer test", function () { + this.timeout(100000000); + + it("(5,5,3) -> (5,5,3)", async () => { + let json = require("../models/batchNormalization_input.json"); + let OUTPUT = require("../models/batchNormalization_output.json"); + + const circuit = await wasm_tester(path.join(__dirname, "circuits", "batchNormalization_test.circom")); + + const a = []; + const b = []; + + for (var i=0; i