Latest MNIST model with AvgPooling2D and BN

Cathie So
2022-11-18 21:38:41 +08:00
parent b95a34e2ac
commit 2806c81b3d
4 changed files with 679 additions and 0 deletions

497 models/mnist_latest.ipynb Normal file

@@ -0,0 +1,497 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# model architecture modified from https://keras.io/examples/vision/mnist_convnet/"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.layers import Input, Conv2D, AveragePooling2D, Flatten, Softmax, Dense, Lambda, BatchNormalization\n",
"from tensorflow.keras import Model\n",
"from tensorflow.keras.datasets import mnist\n",
"from tensorflow.keras.utils import to_categorical\n",
"from tensorflow.keras.optimizers import Adam, SGD\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"(X_train, y_train), (X_test, y_test) = mnist.load_data()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Convert y_train into one-hot format\n",
"temp = []\n",
"for i in range(len(y_train)):\n",
" temp.append(to_categorical(y_train[i], num_classes=10))\n",
"y_train = np.array(temp)\n",
"# Convert y_test into one-hot format\n",
"temp = []\n",
"for i in range(len(y_test)): \n",
" temp.append(to_categorical(y_test[i], num_classes=10))\n",
"y_test = np.array(temp)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"#reshaping\n",
"X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)\n",
"X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"inputs = Input(shape=(28,28,1))\n",
"out = Lambda(lambda x: x/100)(inputs)\n",
"out = Conv2D(4, 3, use_bias=False)(out)\n",
"out = BatchNormalization()(out)\n",
"out = Lambda(lambda x: x**2+x)(out)\n",
"out = AveragePooling2D()(out)\n",
"# out = Lambda(lambda x: x*4)(out)\n",
"out = Conv2D(8, 3, use_bias=False)(out)\n",
"out = BatchNormalization()(out)\n",
"out = Lambda(lambda x: x**2+x)(out)\n",
"out = AveragePooling2D()(out)\n",
"# out = Lambda(lambda x: x*4)(out)\n",
"out = Flatten()(out)\n",
"out = Dense(10, activation=None)(out)\n",
"out = Softmax()(out)\n",
"model = Model(inputs, out)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"input_1 (InputLayer) [(None, 28, 28, 1)] 0 \n",
"_________________________________________________________________\n",
"lambda (Lambda) (None, 28, 28, 1) 0 \n",
"_________________________________________________________________\n",
"conv2d (Conv2D) (None, 26, 26, 4) 36 \n",
"_________________________________________________________________\n",
"batch_normalization (BatchNo (None, 26, 26, 4) 16 \n",
"_________________________________________________________________\n",
"lambda_1 (Lambda) (None, 26, 26, 4) 0 \n",
"_________________________________________________________________\n",
"average_pooling2d (AveragePo (None, 13, 13, 4) 0 \n",
"_________________________________________________________________\n",
"conv2d_1 (Conv2D) (None, 11, 11, 8) 288 \n",
"_________________________________________________________________\n",
"batch_normalization_1 (Batch (None, 11, 11, 8) 32 \n",
"_________________________________________________________________\n",
"lambda_2 (Lambda) (None, 11, 11, 8) 0 \n",
"_________________________________________________________________\n",
"average_pooling2d_1 (Average (None, 5, 5, 8) 0 \n",
"_________________________________________________________________\n",
"flatten (Flatten) (None, 200) 0 \n",
"_________________________________________________________________\n",
"dense (Dense) (None, 10) 2010 \n",
"_________________________________________________________________\n",
"softmax (Softmax) (None, 10) 0 \n",
"=================================================================\n",
"Total params: 2,382\n",
"Trainable params: 2,358\n",
"Non-trainable params: 24\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"model.compile(\n",
" loss='categorical_crossentropy',\n",
" optimizer=SGD(learning_rate=0.01, momentum=0.9),\n",
" metrics=['acc']\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/15\n",
"1875/1875 [==============================] - 9s 5ms/step - loss: 0.3128 - acc: 0.9105 - val_loss: 0.0669 - val_acc: 0.9780\n",
"Epoch 2/15\n",
"1875/1875 [==============================] - 8s 4ms/step - loss: 0.0758 - acc: 0.9770 - val_loss: 0.0558 - val_acc: 0.9819\n",
"Epoch 3/15\n",
"1875/1875 [==============================] - 8s 5ms/step - loss: 0.0601 - acc: 0.9815 - val_loss: 0.0501 - val_acc: 0.9836\n",
"Epoch 4/15\n",
"1875/1875 [==============================] - 9s 5ms/step - loss: 0.0537 - acc: 0.9835 - val_loss: 0.0446 - val_acc: 0.9854\n",
"Epoch 5/15\n",
"1875/1875 [==============================] - 9s 5ms/step - loss: 0.0467 - acc: 0.9851 - val_loss: 0.0434 - val_acc: 0.9852\n",
"Epoch 6/15\n",
"1875/1875 [==============================] - 10s 5ms/step - loss: 0.0456 - acc: 0.9862 - val_loss: 0.0450 - val_acc: 0.9854\n",
"Epoch 7/15\n",
"1875/1875 [==============================] - 9s 5ms/step - loss: 0.0423 - acc: 0.9864 - val_loss: 0.0351 - val_acc: 0.9879\n",
"Epoch 8/15\n",
"1875/1875 [==============================] - 8s 5ms/step - loss: 0.0407 - acc: 0.9871 - val_loss: 0.0447 - val_acc: 0.9857\n",
"Epoch 9/15\n",
"1875/1875 [==============================] - 9s 5ms/step - loss: 0.0407 - acc: 0.9866 - val_loss: 0.0370 - val_acc: 0.9881\n",
"Epoch 10/15\n",
"1875/1875 [==============================] - 9s 5ms/step - loss: 0.0363 - acc: 0.9894 - val_loss: 0.0403 - val_acc: 0.9868\n",
"Epoch 11/15\n",
"1875/1875 [==============================] - 9s 5ms/step - loss: 0.0364 - acc: 0.9879 - val_loss: 0.0497 - val_acc: 0.9832\n",
"Epoch 12/15\n",
"1875/1875 [==============================] - 9s 5ms/step - loss: 0.0337 - acc: 0.9903 - val_loss: 0.0394 - val_acc: 0.9879\n",
"Epoch 13/15\n",
"1875/1875 [==============================] - 9s 5ms/step - loss: 0.0337 - acc: 0.9889 - val_loss: 0.0365 - val_acc: 0.9881\n",
"Epoch 14/15\n",
"1875/1875 [==============================] - 9s 5ms/step - loss: 0.0331 - acc: 0.9901 - val_loss: 0.0435 - val_acc: 0.9867\n",
"Epoch 15/15\n",
"1875/1875 [==============================] - 9s 5ms/step - loss: 0.0333 - acc: 0.9893 - val_loss: 0.0353 - val_acc: 0.9884\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x16bc87250>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(X_train, y_train, epochs=15, batch_size=32, validation_data=(X_test, y_test))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((28, 28, 1), 0, 255)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = X_test[0]\n",
"X.shape, X.min(), X.max()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"model2 = Model(model.input, model.layers[-2].output)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ -6.012382 , -6.4132414, 7.764461 , 9.919108 , -10.74805 ,\n",
" -5.17411 , -15.251893 , 20.096642 , -0.7057824, 5.5778084]],\n",
" dtype=float32)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model2.predict(X_test[[0]]) - model.layers[-2].weights[1].numpy()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.0013891e-12, 3.3880292e-12, 4.8922980e-06, 3.8929902e-05,\n",
" 4.1282377e-14, 1.0473054e-11, 3.6530270e-16, 9.9995577e-01,\n",
" 6.8406081e-10, 3.4419332e-07]], dtype=float32)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.predict(X_test[[0]])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x1777af070>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAANh0lEQVR4nO3df6zddX3H8dfL/sJeYFKwtSuVKqKxOsHlCppuSw3DAYYUo2w0GekSZskGCSxmG2ExkmxxjIiETWdSR2clCFOBQLRzksaNkLHKhZRSKFuRdVh71wvUrUXgtqXv/XG/LJdyz+dezvd7zve07+cjuTnnfN/ne77vfHtf/X7v+XzP+TgiBODY95a2GwDQH4QdSIKwA0kQdiAJwg4kMbufG5vreXGchvq5SSCVV/QLHYhxT1WrFXbb50u6RdIsSX8XETeUnn+chnSOz62zSQAFm2NTx1rXp/G2Z0n6qqQLJC2XtNr28m5fD0Bv1fmb/WxJT0fEMxFxQNKdklY10xaAptUJ+xJJP530eFe17HVsr7U9YnvkoMZrbA5AHXXCPtWbAG+49jYi1kXEcEQMz9G8GpsDUEedsO+StHTS41Ml7a7XDoBeqRP2hyWdYftdtudKulTSfc20BaBpXQ+9RcQh21dJ+idNDL2tj4gnGusMQKNqjbNHxEZJGxvqBUAPcbkskARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IIlaUzbb3ilpv6RXJR2KiOEmmgLQvFphr3w8Ip5v4HUA9BCn8UASdcMekn5o+xHba6d6gu21tkdsjxzUeM3NAehW3dP4FRGx2/ZCSffbfioiHpj8hIhYJ2mdJJ3oBVFzewC6VOvIHhG7q9sxSfdIOruJpgA0r+uw2x6yfcJr9yV9QtK2phoD0Kw6p/GLJN1j+7XX+VZE/KCRrgA0ruuwR8Qzks5ssBcAPcTQG5AEYQeSIOxAEoQdSIKwA0k08UGYFF747Mc61t552dPFdZ8aW1SsHxifU6wvuaNcn7/rxY61w1ueLK6LPDiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjLPP0J/88bc61j499PPyyqfX3PjKcnnnoZc61m557uM1N370+vHYaR1rQzf9UnHd2Zseabqd1nFkB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkHNG/SVpO9II4x+f2bXtN+sVnzulYe/5D5f8zT9pe3sc/f7+L9bkf+p9i/cYP3t2xdt5bXy6u+/2Xji/WPzm/82fl63o5DhTrm8eHivWVxx3setvv+f4Vxfp71z7c9Wu3aXNs0r7YO+UvFEd2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCz7PP0NB3Nxdq9V77xHqr62/esbJj7S9WLCtv+1/K33l/48r3dNHRzMx++XCxPrR1tFg/+YG7ivVfmdv5+/bn7yx/F/+xaNoju+31tsdsb5u0bIHt+23vqG5P6m2bAOqayWn8NySdf8SyayVtiogzJG2qHgMYYNOGPSIekLT3iMWrJG2o7m+QdHGzbQFoWrdv0C2KiFFJqm4Xdnqi7bW2R2yPHNR4l5sDUFfP342PiHURMRwRw3M0r9ebA9BBt2HfY3uxJFW3Y821BKAXug37fZLWVPfXSLq3mXYA9Mq04+y279DEN5efYnuXpC9IukHSt21fLulZSZf0skmUHfrvPR1rQ3d1rknSq9O89tB3X+iio2bs+f2PFesfmFv+9f3S3vd1rC37+2eK6x4qVo9O04Y9IlZ3KB2d30IBJMXlskAShB1IgrADSRB2IAnCDiTBR1zRmtmnLS3Wv3LdV4r1OZ5VrH/nlt/sWDt59KHiuscijuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7GjNU3+0pFj/yLzyVNZPHChPR73gyZfedE/HMo7sQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+zoqfFPfqRj7dHP3DzN2uUZhP7g6quL9bf+64+nef1cOLIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMs6Onnr2g8/HkeJfH0Vf/53nF+vwfPFasR7Gaz7RHdtvrbY/Z3jZp2fW2f2Z7S/VzYW/bBFDXTE7jvyHp/CmW3xwRZ1U/G5ttC0DTpg17RDwgaW8fegHQQ3XeoLvK9tbqNP+kTk+yvdb2iO2RgxqvsTkAdXQb9q9JOl3SWZJGJd3U6YkRsS4ihiNieM40H2wA0DtdhT0i9kTEqxFxWNLXJZ3dbFsAmtZV2G0vnvTwU5K2dXougMEw7Ti77TskrZR0iu1dkr4gaaXtszQxlLlT0hW9axGD7C0nnFCsX/brD3as7Tv8SnHdsS++u1ifN/5wsY7XmzbsEbF6isW39qAXAD3E5bJAEoQdSIKwA0kQdiAJwg4kwUdcUcuO6z9QrH/vlL/tWFu149PFdedtZGitSRzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtlR9L+/+9Fifevv/HWx/pNDBzvWXvyrU4vrztNosY43hyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOHtys5f8crF+zef/oVif5/Kv0KWPXdax9vZ/5PPq/cSRHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJz9GOfZ5X/iM7+3q1i/5PgXivXb9y8s1hd9vvPx5HBxTTRt2iO77aW2f2R7u+0nbF9dLV9g+37bO6rbk3rfLoBuzeQ0/pCkz0XE+yV9VNKVtpdLulbSpog4Q9Km6jGAATVt2CNiNCIere7vl7Rd0hJJqyRtqJ62QdLFPeoRQAPe1Bt0tpdJ+rCkzZIWRcSoNPEfgqQp/3izvdb2iO2Rgxqv2S6Abs047LaPl3SXpGsiYt9M14uIdRExHBHDczSvmx4BNGBGYbc9RxNBvz0i7q4W77G9uKovljTWmxYBNGHaoTfblnSrpO0R8eVJpfskrZF0Q3V7b086RD1nvq9Y/vOFt9V6+a9+8ZJi/W2PPVTr9dGcmYyzr5B0maTHbW+pll2niZB/2/blkp6VVP5XB9CqacMeEQ9Kcofyuc22A6BXuFwWSIKwA0kQdiAJwg4kQdiBJPiI6zFg1vL3dqytvbPe5Q/L119ZrC+77d9qvT76hyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOPsx4Kk/7PzFvhfNn/GXCk3p1H8+UH5CRK3XR/9wZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnPwq8ctHZxfqmi24qVOc32wyOWhzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJmczPvlTSNyW9Q9JhSesi4hbb10v6rKTnqqdeFxEbe9VoZrtXzCrW3zm7+7H02/cvLNbn7Ct/np1Psx89ZnJRzSFJn4uIR22fIOkR2/dXtZsj4ku9aw9AU2YyP/uopNHq/n7b2yUt6XVjAJr1p
v5mt71M0oclba4WXWV7q+31tqf8biTba22P2B45qPF63QLo2ozDbvt4SXdJuiYi9kn6mqTTJZ2liSP/lBdoR8S6iBiOiOE5mle/YwBdmVHYbc/RRNBvj4i7JSki9kTEqxFxWNLXJZU/rQGgVdOG3bYl3Sppe0R8edLyxZOe9ilJ25pvD0BTZvJu/ApJl0l63PaWatl1klbbPksToy87JV3Rg/5Q01++sLxYf+i3lhXrMfp4g92gTTN5N/5BSZ6ixJg6cBThCjogCcIOJEHYgSQIO5AEYQeSIOxAEo4+Trl7ohfEOT63b9sDstkcm7Qv9k41VM6RHciCsANJEHYgCcIOJEHYgSQIO5AEYQeS6Os4u+3nJP3XpEWnSHq+bw28OYPa26D2JdFbt5rs7bSIePtUhb6G/Q0bt0ciYri1BgoGtbdB7Uuit271qzdO44EkCDuQRNthX9fy9ksGtbdB7Uuit271pbdW/2YH0D9tH9kB9AlhB5JoJey2z7f977aftn1tGz10Ynun7cdtb7E90nIv622P2d42adkC2/fb3lHdTjnHXku9XW/7Z9W+22L7wpZ6W2r7R7a3237C9tXV8lb3XaGvvuy3vv/NbnuWpP+QdJ6kXZIelrQ6Ip7sayMd2N4paTgiWr8Aw/ZvSHpR0jcj4oPVshsl7Y2IG6r/KE+KiD8dkN6ul/Ri29N4V7MVLZ48zbikiyX9nlrcd4W+flt92G9tHNnPlvR0RDwTEQck3SlpVQt9DLyIeEDS3iMWr5K0obq/QRO/LH3XobeBEBGjEfFodX+/pNemGW913xX66os2wr5E0k8nPd6lwZrvPST90PYjtte23cwUFkXEqDTxyyNpYcv9HGnaabz76Yhpxgdm33Uz/XldbYR9qu/HGqTxvxUR8auSLpB0ZXW6ipmZ0TTe/TLFNOMDodvpz+tqI+y7JC2d9PhUSbtb6GNKEbG7uh2TdI8GbyrqPa/NoFvdjrXcz/8bpGm8p5pmXAOw79qc/ryNsD8s6Qzb77I9V9Klku5roY83sD1UvXEi20OSPqHBm4r6PklrqvtrJN3bYi+vMyjTeHeaZlwt77vWpz+PiL7/SLpQE+/I/0TSn7XRQ4e+3i3psernibZ7k3SHJk7rDmrijOhySSdL2iRpR3W7YIB6u03S45K2aiJYi1vq7dc08afhVklbqp8L2953hb76st+4XBZIgivogCQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJ/wNGNvRI2D7VDgAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(X)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12\n",
"(3, 3, 1, 4)\n",
"(4,)\n",
"(4,)\n",
"(4,)\n",
"(4,)\n",
"(3, 3, 4, 8)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(200, 10)\n",
"(10,)\n"
]
}
],
"source": [
"print(len(model.weights))\n",
"for weights in model.weights:\n",
" print(weights.shape)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"gamma = model.layers[3].weights[0].numpy()\n",
"beta = model.layers[3].weights[1].numpy()\n",
"moving_mean = model.layers[3].weights[2].numpy()\n",
"moving_var = model.layers[3].weights[3].numpy()\n",
"epsilon = model.layers[3].epsilon"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([0.93366355, 1.3045508 , 0.52121127, 1.162181 ], dtype=float32),\n",
" array([-0.41578865, 0.18303993, 1.0352895 , 0.01960986], dtype=float32))"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a1 = gamma/(moving_var+epsilon)**.5\n",
"b1 = beta-gamma*moving_mean/(moving_var+epsilon)**.5\n",
"a1, b1"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"gamma = model.layers[7].weights[0].numpy()\n",
"beta = model.layers[7].weights[1].numpy()\n",
"moving_mean = model.layers[7].weights[2].numpy()\n",
"moving_var = model.layers[7].weights[3].numpy()\n",
"epsilon = model.layers[7].epsilon"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([0.07683876, 0.06093166, 0.08585645, 0.07198451, 0.07364487,\n",
" 0.07903063, 0.08180231, 0.06326427], dtype=float32),\n",
" array([-0.11803538, -0.9628091 , 0.41078255, -1.1416371 , -0.23502782,\n",
" -0.7226403 , -1.27386 , 0.06447531], dtype=float32))"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a2 = gamma/(moving_var+epsilon)**.5\n",
"b2 = beta-gamma*moving_mean/(moving_var+epsilon)**.5\n",
"a2, b2"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"in_json = {\n",
" \"in\": X.astype(int).flatten().tolist(), # X is already 100 times to begin with\n",
" \"conv2d_1_weights\": (model.layers[2].weights[0].numpy()*(10**2)).round().astype(int).flatten().tolist(),\n",
" \"conv2d_1_bias\": (np.array([0]*4)*(10**2)**2).round().astype(int).flatten().tolist(),\n",
" \"bn_1_a\": (a1*(10**2)).round().astype(int).flatten().tolist(),\n",
" \"bn_1_b\": (b1*(10**2)**3).round().astype(int).flatten().tolist(),\n",
" # poly layer would be (10**2)**3=10**6 times as well\n",
" # average pooling will scale another 10**2 times\n",
" \"conv2d_2_weights\": (model.layers[6].weights[0].numpy()*(10**2)).round().astype(int).flatten().tolist(),\n",
" \"conv2d_2_bias\": (np.array([0]*8)*((10**2)**8)).round().astype(int).flatten().tolist(),\n",
" \"bn_2_a\": (a2*(10**2)).round().astype(int).flatten().tolist(),\n",
" \"bn_2_b\": (b2*(10**2)**9).round().astype(int).flatten().tolist(),\n",
" # poly layer would be (10**2)**9=10**18 times as well\n",
" # average pooling will scale another 10**2 times\n",
" \"dense_weights\":(model.layers[11].weights[0].numpy()*(10**2)).round().astype(int).flatten().tolist(),\n",
" \"dense_bias\": np.zeros(model.layers[11].weights[1].numpy().shape).tolist() # zero because we are not doing softmax in circom, just argmax\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"import json"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"with open(\"mnist_latest_input.json\", \"w\") as f:\n",
" json.dump(in_json, f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "11280bdb37aa6bc5d4cf1e4de756386eb1f9eecd8dcdefa77636dfac7be2370d"
},
"kernelspec": {
"display_name": "Python 3.8.6 ('tf24')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

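Note on the a1/b1 and a2/b2 cells above: at inference time a BatchNormalization layer computes y = gamma * (x - moving_mean) / sqrt(moving_var + epsilon) + beta, which is the affine map y = a*x + b with a = gamma / sqrt(moving_var + epsilon) and b = beta - a * moving_mean; those folded coefficients are what get exported to the circuit as bn_*_a and bn_*_b. A minimal sanity check of the folding, assuming the notebook's model, X_test, a1 and b1 are still in scope (layer indices as in the model summary: layers[2] is the first Conv2D, layers[3] the first BatchNormalization):

import numpy as np
from tensorflow.keras import Model

# Outputs of the first Conv2D and its BatchNormalization for one test image.
conv_out = Model(model.input, model.layers[2].output).predict(X_test[[0]])
bn_out = Model(model.input, model.layers[3].output).predict(X_test[[0]])

# The folded affine map should reproduce the BN output (a1, b1 broadcast over the channel axis).
np.testing.assert_allclose(a1 * conv_out + b1, bn_out, rtol=1e-4, atol=1e-5)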
1 models/mnist_latest_input.json Normal file
File diff suppressed because one or more lines are too long


@@ -0,0 +1,136 @@
pragma circom 2.0.3;
include "../../circuits/Conv2D.circom";
include "../../circuits/Dense.circom";
include "../../circuits/ArgMax.circom";
include "../../circuits/Poly.circom";
include "../../circuits/AveragePooling2D.circom";
include "../../circuits/BatchNormalization2D.circom";
template mnist_latest() {
    signal input in[28][28][1];
    signal input conv2d_1_weights[3][3][1][4];
    signal input conv2d_1_bias[4];
    signal input bn_1_a[4];
    signal input bn_1_b[4];
    signal input conv2d_2_weights[3][3][4][8];
    signal input conv2d_2_bias[8];
    signal input bn_2_a[8];
    signal input bn_2_b[8];
    signal input dense_weights[200][10];
    signal input dense_bias[10];
    signal output out;
    component conv2d_1 = Conv2D(28,28,1,4,3,1);
    component bn_1 = BatchNormalization2D(26,26,4);
    component poly_1[26][26][4];
    component avg2d_1 = AveragePooling2D(26,26,4,2,2,25);
    component conv2d_2 = Conv2D(13,13,4,8,3,1);
    component bn_2 = BatchNormalization2D(11,11,8);
    component poly_2[11][11][8];
    component avg2d_2 = AveragePooling2D(11,11,8,2,2,25);
    component dense = Dense(200,10);
    component argmax = ArgMax(10);
    for (var i=0; i<28; i++) {
        for (var j=0; j<28; j++) {
            conv2d_1.in[i][j][0] <== in[i][j][0];
        }
    }
    for (var m=0; m<4; m++) {
        for (var i=0; i<3; i++) {
            for (var j=0; j<3; j++) {
                conv2d_1.weights[i][j][0][m] <== conv2d_1_weights[i][j][0][m];
            }
        }
        conv2d_1.bias[m] <== conv2d_1_bias[m];
    }
    for (var k=0; k<4; k++) {
        bn_1.a[k] <== bn_1_a[k];
        bn_1.b[k] <== bn_1_b[k];
        for (var i=0; i<26; i++) {
            for (var j=0; j<26; j++) {
                bn_1.in[i][j][k] <== conv2d_1.out[i][j][k];
            }
        }
    }
    for (var i=0; i<26; i++) {
        for (var j=0; j<26; j++) {
            for (var k=0; k<4; k++) {
                poly_1[i][j][k] = Poly(10**6);
                poly_1[i][j][k].in <== bn_1.out[i][j][k];
                avg2d_1.in[i][j][k] <== poly_1[i][j][k].out;
            }
        }
    }
    for (var i=0; i<13; i++) {
        for (var j=0; j<13; j++) {
            for (var k=0; k<4; k++) {
                conv2d_2.in[i][j][k] <== avg2d_1.out[i][j][k];
            }
        }
    }
    for (var m=0; m<8; m++) {
        for (var i=0; i<3; i++) {
            for (var j=0; j<3; j++) {
                for (var k=0; k<4; k++) {
                    conv2d_2.weights[i][j][k][m] <== conv2d_2_weights[i][j][k][m];
                }
            }
        }
        conv2d_2.bias[m] <== conv2d_2_bias[m];
    }
    for (var k=0; k<8; k++) {
        bn_2.a[k] <== bn_2_a[k];
        bn_2.b[k] <== bn_2_b[k];
        for (var i=0; i<11; i++) {
            for (var j=0; j<11; j++) {
                bn_2.in[i][j][k] <== conv2d_2.out[i][j][k];
            }
        }
    }
    for (var i=0; i<11; i++) {
        for (var j=0; j<11; j++) {
            for (var k=0; k<8; k++) {
                poly_2[i][j][k] = Poly(10**18);
                poly_2[i][j][k].in <== bn_2.out[i][j][k];
                avg2d_2.in[i][j][k] <== poly_2[i][j][k].out;
            }
        }
    }
    var idx = 0;
    for (var i=0; i<5; i++) {
        for (var j=0; j<5; j++) {
            for (var k=0; k<8; k++) {
                dense.in[idx] <== avg2d_2.out[i][j][k];
                for (var m=0; m<10; m++) {
                    dense.weights[idx][m] <== dense_weights[idx][m];
                }
                idx++;
            }
        }
    }
    for (var i=0; i<10; i++) {
        dense.bias[i] <== dense_bias[i];
    }
    for (var i=0; i<10; i++) {
        log(dense.out[i]);
        argmax.in[i] <== dense.out[i];
    }
    out <== argmax.out;
}
component main = mnist_latest();

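The constants in this circuit mirror the fixed-point scheme from the notebook's in_json cell: weights and BN coefficients are quantized by 10**2, Poly(n) is assumed to compute x**2 + n*x (the scaled form of the x**2 + x activation), and the trailing 25 in AveragePooling2D(...,2,2,25) is assumed to multiply each 2x2 sum by 25, i.e. 10**2 times the true average. A rough bookkeeping sketch of the resulting activation scale, in powers of 10**2 (assumptions as stated, not taken from the component templates themselves):

# Track the activation scale as an exponent of 10**2 through the quantized pipeline.
# Assumes Poly(n) = x**2 + n*x and AveragePooling2D(..., 25) multiplies the 2x2 sum by 25.
s = 1      # raw 0..255 input is already 10**2 times the x/100-normalized float input
s += 1     # conv2d_1 weights quantized by 10**2    -> conv2d_1_bias at (10**2)**2
s += 1     # bn_1_a quantized by 10**2              -> bn_1_b at (10**2)**3
s *= 2     # Poly(10**6), n = (10**2)**3            -> (10**2)**6 after the activation
s += 1     # avg pool: 2x2 sum * 25 = 10**2 * mean  -> (10**2)**7
s += 1     # conv2d_2 weights quantized by 10**2    -> conv2d_2_bias at (10**2)**8
s += 1     # bn_2_a quantized by 10**2              -> bn_2_b at (10**2)**9
s *= 2     # Poly(10**18), n = (10**2)**9           -> (10**2)**18
s += 1     # second avg pool                        -> (10**2)**19 entering Dense
print(s)   # 19; every logit shares one scale, so the ArgMax result is unchanged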
45 test/mnist_latest.js Normal file

@@ -0,0 +1,45 @@
const chai = require("chai");
const path = require("path");
const wasm_tester = require("circom_tester").wasm;
const F1Field = require("ffjavascript").F1Field;
const Scalar = require("ffjavascript").Scalar;
exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
const Fr = new F1Field(exports.p);
const assert = chai.assert;
const json = require("../models/mnist_latest_input.json");
describe("mnist latest test", function () {
    this.timeout(100000000);
    it("should return correct output", async () => {
        const circuit = await wasm_tester(path.join(__dirname, "circuits", "mnist_latest_test.circom"));
        //await circuit.loadConstraints();
        //assert.equal(circuit.nVars, 70524);
        //assert.equal(circuit.constraints.length, 67403);
        let INPUT = {};
        for (const [key, value] of Object.entries(json)) {
            if (Array.isArray(value)) {
                let tmpArray = [];
                for (let i = 0; i < value.flat().length; i++) {
                    tmpArray.push(Fr.e(value.flat()[i]));
                }
                INPUT[key] = tmpArray;
            } else {
                INPUT[key] = Fr.e(value);
            }
        }
        const witness = await circuit.calculateWitness(INPUT, true);
        //console.log(witness[1]);
        assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
        assert(Fr.eq(Fr.e(witness[1]),Fr.e(7)));
    });
});
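The two assertions check, first, the constant-one signal that a circom witness places at index 0, and second, that the circuit's single output (the ArgMax) equals 7, which matches the float model's prediction for X_test[0] in the notebook. A one-line check in Python, assuming the notebook's model and X_test are loaded:

import numpy as np
# Expected circuit output: argmax of the float model on the first test image (a 7).
print(int(np.argmax(model.predict(X_test[[0]]))))  # -> 7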