mirror of
https://github.com/socathie/circomlib-ml.git
synced 2026-01-09 14:08:04 -05:00
Model optimized for precision w/ proof size tradeoff
This commit is contained in:
497
models/mnist_latest_optimized.ipynb
Normal file
497
models/mnist_latest_optimized.ipynb
Normal file
@@ -0,0 +1,497 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# model architecture modified from https://keras.io/examples/vision/mnist_convnet/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tensorflow.keras.layers import Input, Conv2D, AveragePooling2D, Flatten, Softmax, Dense, Lambda, BatchNormalization, ReLU\n",
|
||||
"from tensorflow.keras import Model\n",
|
||||
"from tensorflow.keras.datasets import mnist\n",
|
||||
"from tensorflow.keras.utils import to_categorical\n",
|
||||
"from tensorflow.keras.optimizers import SGD\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import tensorflow as tf"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"(X_train, y_train), (X_test, y_test) = mnist.load_data()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Convert y_train into one-hot format\n",
|
||||
"temp = []\n",
|
||||
"for i in range(len(y_train)):\n",
|
||||
" temp.append(to_categorical(y_train[i], num_classes=10))\n",
|
||||
"y_train = np.array(temp)\n",
|
||||
"# Convert y_test into one-hot format\n",
|
||||
"temp = []\n",
|
||||
"for i in range(len(y_test)): \n",
|
||||
" temp.append(to_categorical(y_test[i], num_classes=10))\n",
|
||||
"y_test = np.array(temp)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#reshaping\n",
|
||||
"X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)\n",
|
||||
"X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inputs = Input(shape=(28,28,1))\n",
|
||||
"out = Lambda(lambda x: x/1000)(inputs)\n",
|
||||
"out = Conv2D(4, 3, use_bias=False)(out)\n",
|
||||
"out = BatchNormalization()(out)\n",
|
||||
"out = ReLU()(out)\n",
|
||||
"# out = Lambda(lambda x: x**2+x)(out)\n",
|
||||
"out = AveragePooling2D()(out)\n",
|
||||
"# out = Lambda(lambda x: x*4)(out)\n",
|
||||
"out = Conv2D(8, 3, use_bias=False)(out)\n",
|
||||
"out = BatchNormalization()(out)\n",
|
||||
"out = ReLU()(out)\n",
|
||||
"# out = Lambda(lambda x: x**2+x)(out)\n",
|
||||
"out = AveragePooling2D()(out)\n",
|
||||
"# out = Lambda(lambda x: x*4)(out)\n",
|
||||
"out = Flatten()(out)\n",
|
||||
"out = Dense(10, activation=None)(out)\n",
|
||||
"out = Softmax()(out)\n",
|
||||
"model = Model(inputs, out)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model: \"model\"\n",
|
||||
"_________________________________________________________________\n",
|
||||
"Layer (type) Output Shape Param # \n",
|
||||
"=================================================================\n",
|
||||
"input_1 (InputLayer) [(None, 28, 28, 1)] 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"lambda (Lambda) (None, 28, 28, 1) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"conv2d (Conv2D) (None, 26, 26, 4) 36 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"batch_normalization (BatchNo (None, 26, 26, 4) 16 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"re_lu (ReLU) (None, 26, 26, 4) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"average_pooling2d (AveragePo (None, 13, 13, 4) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"conv2d_1 (Conv2D) (None, 11, 11, 8) 288 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"batch_normalization_1 (Batch (None, 11, 11, 8) 32 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"re_lu_1 (ReLU) (None, 11, 11, 8) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"average_pooling2d_1 (Average (None, 5, 5, 8) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"flatten (Flatten) (None, 200) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"dense (Dense) (None, 10) 2010 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"softmax (Softmax) (None, 10) 0 \n",
|
||||
"=================================================================\n",
|
||||
"Total params: 2,382\n",
|
||||
"Trainable params: 2,358\n",
|
||||
"Non-trainable params: 24\n",
|
||||
"_________________________________________________________________\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.summary()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.compile(\n",
|
||||
" loss='categorical_crossentropy',\n",
|
||||
" optimizer=SGD(learning_rate=0.01, momentum=0.9),\n",
|
||||
" metrics=['acc']\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/15\n",
|
||||
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.3852 - acc: 0.8843 - val_loss: 0.0964 - val_acc: 0.9703\n",
|
||||
"Epoch 2/15\n",
|
||||
"1875/1875 [==============================] - 3s 1ms/step - loss: 0.0878 - acc: 0.9740 - val_loss: 0.0707 - val_acc: 0.9770\n",
|
||||
"Epoch 3/15\n",
|
||||
"1875/1875 [==============================] - 3s 1ms/step - loss: 0.0677 - acc: 0.9793 - val_loss: 0.0647 - val_acc: 0.9785\n",
|
||||
"Epoch 4/15\n",
|
||||
"1875/1875 [==============================] - 3s 1ms/step - loss: 0.0598 - acc: 0.9818 - val_loss: 0.0611 - val_acc: 0.9811\n",
|
||||
"Epoch 5/15\n",
|
||||
"1875/1875 [==============================] - 3s 1ms/step - loss: 0.0555 - acc: 0.9827 - val_loss: 0.0592 - val_acc: 0.9817\n",
|
||||
"Epoch 6/15\n",
|
||||
"1875/1875 [==============================] - 3s 1ms/step - loss: 0.0506 - acc: 0.9833 - val_loss: 0.0720 - val_acc: 0.9768\n",
|
||||
"Epoch 7/15\n",
|
||||
"1875/1875 [==============================] - 3s 1ms/step - loss: 0.0454 - acc: 0.9852 - val_loss: 0.0496 - val_acc: 0.9850\n",
|
||||
"Epoch 8/15\n",
|
||||
"1875/1875 [==============================] - 3s 1ms/step - loss: 0.0447 - acc: 0.9867 - val_loss: 0.0427 - val_acc: 0.9863\n",
|
||||
"Epoch 9/15\n",
|
||||
"1875/1875 [==============================] - 3s 1ms/step - loss: 0.0409 - acc: 0.9872 - val_loss: 0.0467 - val_acc: 0.9854\n",
|
||||
"Epoch 10/15\n",
|
||||
"1875/1875 [==============================] - 3s 1ms/step - loss: 0.0411 - acc: 0.9877 - val_loss: 0.0442 - val_acc: 0.9858\n",
|
||||
"Epoch 11/15\n",
|
||||
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.0376 - acc: 0.9881 - val_loss: 0.0404 - val_acc: 0.9872\n",
|
||||
"Epoch 12/15\n",
|
||||
"1875/1875 [==============================] - 3s 1ms/step - loss: 0.0377 - acc: 0.9889 - val_loss: 0.0466 - val_acc: 0.9860\n",
|
||||
"Epoch 13/15\n",
|
||||
"1875/1875 [==============================] - 3s 2ms/step - loss: 0.0361 - acc: 0.9885 - val_loss: 0.0521 - val_acc: 0.9845\n",
|
||||
"Epoch 14/15\n",
|
||||
"1875/1875 [==============================] - 3s 1ms/step - loss: 0.0352 - acc: 0.9898 - val_loss: 0.0422 - val_acc: 0.9853\n",
|
||||
"Epoch 15/15\n",
|
||||
"1875/1875 [==============================] - 3s 1ms/step - loss: 0.0348 - acc: 0.9899 - val_loss: 0.0431 - val_acc: 0.9860\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<tensorflow.python.keras.callbacks.History at 0x2a1c54bb0>"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.fit(X_train, y_train, epochs=15, batch_size=32, validation_data=(X_test, y_test))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"((28, 28, 1), 0, 255)"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"X = X_test[0]\n",
|
||||
"X.shape, X.min(), X.max()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model2 = Model(model.input, model.layers[-2].output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[ -6.2468867, -6.3268614, 5.410377 , 5.8547926, -5.245466 ,\n",
|
||||
" -3.587707 , -17.595942 , 17.122202 , -2.3852873, 3.7686806]],\n",
|
||||
" dtype=float32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model2.predict(X_test[[0]]) - model.layers[-2].weights[1].numpy()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[6.3815404e-11, 6.2323799e-11, 8.0492591e-06, 1.0712383e-05,\n",
|
||||
" 1.7036507e-10, 1.0719555e-09, 7.9848600e-16, 9.9997973e-01,\n",
|
||||
" 4.0607868e-09, 1.3976586e-06]], dtype=float32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.predict(X_test[[0]])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<matplotlib.image.AxesImage at 0x2a48ef760>"
|
||||
]
|
||||
},
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAANh0lEQVR4nO3df6zddX3H8dfL/sJeYFKwtSuVKqKxOsHlCppuSw3DAYYUo2w0GekSZskGCSxmG2ExkmxxjIiETWdSR2clCFOBQLRzksaNkLHKhZRSKFuRdVh71wvUrUXgtqXv/XG/LJdyz+dezvd7zve07+cjuTnnfN/ne77vfHtf/X7v+XzP+TgiBODY95a2GwDQH4QdSIKwA0kQdiAJwg4kMbufG5vreXGchvq5SSCVV/QLHYhxT1WrFXbb50u6RdIsSX8XETeUnn+chnSOz62zSQAFm2NTx1rXp/G2Z0n6qqQLJC2XtNr28m5fD0Bv1fmb/WxJT0fEMxFxQNKdklY10xaAptUJ+xJJP530eFe17HVsr7U9YnvkoMZrbA5AHXXCPtWbAG+49jYi1kXEcEQMz9G8GpsDUEedsO+StHTS41Ml7a7XDoBeqRP2hyWdYftdtudKulTSfc20BaBpXQ+9RcQh21dJ+idNDL2tj4gnGusMQKNqjbNHxEZJGxvqBUAPcbkskARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IIlaUzbb3ilpv6RXJR2KiOEmmgLQvFphr3w8Ip5v4HUA9BCn8UASdcMekn5o+xHba6d6gu21tkdsjxzUeM3NAehW3dP4FRGx2/ZCSffbfioiHpj8hIhYJ2mdJJ3oBVFzewC6VOvIHhG7q9sxSfdIOruJpgA0r+uw2x6yfcJr9yV9QtK2phoD0Kw6p/GLJN1j+7XX+VZE/KCRrgA0ruuwR8Qzks5ssBcAPcTQG5AEYQeSIOxAEoQdSIKwA0k08UGYFF747Mc61t552dPFdZ8aW1SsHxifU6wvuaNcn7/rxY61w1ueLK6LPDiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjLPP0J/88bc61j499PPyyqfX3PjKcnnnoZc61m557uM1N370+vHYaR1rQzf9UnHd2Zseabqd1nFkB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkHNG/SVpO9II4x+f2bXtN+sVnzulYe/5D5f8zT9pe3sc/f7+L9bkf+p9i/cYP3t2xdt5bXy6u+/2Xji/WPzm/82fl63o5DhTrm8eHivWVxx3setvv+f4Vxfp71z7c9Wu3aXNs0r7YO+UvFEd2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCz7PP0NB3Nxdq9V77xHqr62/esbJj7S9WLCtv+1/K33l/48r3dNHRzMx++XCxPrR1tFg/+YG7ivVfmdv5+/bn7yx/F/+xaNoju+31tsdsb5u0bIHt+23vqG5P6m2bAOqayWn8NySdf8SyayVtiogzJG2qHgMYYNOGPSIekLT3iMWrJG2o7m+QdHGzbQFoWrdv0C2KiFFJqm4Xdnqi7bW2R2yPHNR4l5sDUFfP342PiHURMRwRw3M0r9ebA9BBt2HfY3uxJFW3Y821BKAXug37fZLWVPfXSLq3mXYA9Mq04+y279DEN5efYnuXpC9IukHSt21fLulZSZf0skmUHfrvPR1rQ3d1rknSq9O89tB3X+iio2bs+f2PFesfmFv+9f3S3vd1rC37+2eK6x4qVo9O04Y9IlZ3KB2d30IBJMXlskAShB1IgrADSRB2IAnCDiTBR1zRmtmnLS3Wv3LdV4r1OZ5VrH/nlt/sWDt59KHiuscijuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7GjNU3+0pFj/yLzyVNZPHChPR73gyZfedE/HMo7sQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+zoqfFPfqRj7dHP3DzN2uUZhP7g6quL9bf+64+nef1cOLIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMs6Onnr2g8/HkeJfH0Vf/53nF+vwfPFasR7Gaz7RHdtvrbY/Z3jZp2fW2f2Z7S/VzYW/bBFDXTE7jvyHp/CmW3xwRZ1U/G5ttC0DTpg17RDwgaW8fegHQQ3XeoLvK9tbqNP+kTk+yvdb2iO2RgxqvsTkAdXQb9q9JOl3SWZJGJd3U6YkRsS4ihiNieM40H2wA0DtdhT0i9kTEqxFxWNLXJZ3dbFsAmtZV2G0vnvTwU5K2dXougMEw7Ti77TskrZR0iu1dkr4gaaXtszQxlLlT0hW9axGD7C0nnFCsX/brD3as7Tv8SnHdsS++u1ifN/5wsY7XmzbsEbF6isW39qAXAD3E5bJAEoQdSIKwA0kQdiAJwg4kwUdcUcuO6z9QrH/vlL/tWFu149PFdedtZGitSRzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtlR9L+/+9Fifevv/HWx/pNDBzvWXvyrU4vrztNosY43hyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOHtys5f8crF+zef/oVif5/Kv0KWPXdax9vZ/5PPq/cSRHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJz9GOfZ5X/iM7+3q1i/5PgXivXb9y8s1hd9vvPx5HBxTTRt2iO77aW2f2R7u+0nbF9dLV9g+37bO6rbk3rfLoBuzeQ0/pCkz0XE+yV9VNKVtpdLulbSpog4Q9Km6jGAATVt2CNiNCIere7vl7Rd0hJJqyRtqJ62QdLFPeoRQAPe1Bt0tpdJ+rCkzZIWRcSoNPEfgqQp/3izvdb2iO2Rgxqv2S6Abs047LaPl3SXpGsiYt9M14uIdRExHBHDczSvmx4BNGBGYbc9RxNBvz0i7q4W77G9uKovljTWmxYBNGHaoTfblnSrpO0R8eVJpfskrZF0Q3V7b086RD1nvq9Y/vOFt9V6+a9+8ZJi/W2PPVTr9dGcmYyzr5B0maTHbW+pll2niZB/2/blkp6VVP5XB9CqacMeEQ9Kcofyuc22A6BXuFwWSIKwA0kQdiAJwg4kQdiBJPiI6zFg1vL3dqytvbPe5Q/L119ZrC+77d9qvT76hyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOPsx4Kk/7PzFvhfNn/GXCk3p1H8+UH5CRK3XR/9wZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnPwq8ctHZxfqmi24qVOc32wyOWhzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJmczPvlTSNyW9Q9JhSesi4hbb10v6rKTnqqdeFxEbe9VoZrtXzCrW3zm7+7H02/cvLNbn7Ct/np1Psx89ZnJRzSFJn4uIR22fIOkR2/dXtZsj4ku9aw9AU2YyP/uopNHq/n7b2yUt6XVjAJr1pv5mt71M0oclba4WXWV7q+31tqf8biTba22P2B45qPF63QLo2ozDbvt4SXdJuiYi9kn6mqTTJZ2liSP/lBdoR8S6iBiOiOE5mle/YwBdmVHYbc/RRNBvj4i7JSki9kTEqxFxWNLXJZU/rQGgVdOG3bYl3Sppe0R8edLyxZOe9ilJ25pvD0BTZvJu/ApJl0l63PaWatl1klbbPksToy87JV3Rg/5Q01++sLxYf+i3lhXrMfp4g92gTTN5N/5BSZ6ixJg6cBThCjogCcIOJEHYgSQIO5AEYQeSIOxAEo4+Trl7ohfEOT63b9sDstkcm7Qv9k41VM6RHciCsANJEHYgCcIOJEHYgSQIO5AEYQeS6Os4u+3nJP3XpEWnSHq+bw28OYPa26D2JdFbt5rs7bSIePtUhb6G/Q0bt0ciYri1BgoGtbdB7Uuit271qzdO44EkCDuQRNthX9fy9ksGtbdB7Uuit271pbdW/2YH0D9tH9kB9AlhB5JoJey2z7f977aftn1tGz10Ynun7cdtb7E90nIv622P2d42adkC2/fb3lHdTjnHXku9XW/7Z9W+22L7wpZ6W2r7R7a3237C9tXV8lb3XaGvvuy3vv/NbnuWpP+QdJ6kXZIelrQ6Ip7sayMd2N4paTgiWr8Aw/ZvSHpR0jcj4oPVshsl7Y2IG6r/KE+KiD8dkN6ul/Ri29N4V7MVLZ48zbikiyX9nlrcd4W+flt92G9tHNnPlvR0RDwTEQck3SlpVQt9DLyIeEDS3iMWr5K0obq/QRO/LH3XobeBEBGjEfFodX+/pNemGW913xX66os2wr5E0k8nPd6lwZrvPST90PYjtte23cwUFkXEqDTxyyNpYcv9HGnaabz76Yhpxgdm33Uz/XldbYR9qu/HGqTxvxUR8auSLpB0ZXW6ipmZ0TTe/TLFNOMDodvpz+tqI+y7JC2d9PhUSbtb6GNKEbG7uh2TdI8GbyrqPa/NoFvdjrXcz/8bpGm8p5pmXAOw79qc/ryNsD8s6Qzb77I9V9Klku5roY83sD1UvXEi20OSPqHBm4r6PklrqvtrJN3bYi+vMyjTeHeaZlwt77vWpz+PiL7/SLpQE+/I/0TSn7XRQ4e+3i3psernibZ7k3SHJk7rDmrijOhySSdL2iRpR3W7YIB6u03S45K2aiJYi1vq7dc08afhVklbqp8L2953hb76st+4XBZIgivogCQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJ/wNGNvRI2D7VDgAAAABJRU5ErkJggg==",
|
||||
"text/plain": [
|
||||
"<Figure size 432x288 with 1 Axes>"
|
||||
]
|
||||
},
|
||||
"metadata": {
|
||||
"needs_background": "light"
|
||||
},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"plt.imshow(X)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"12\n",
|
||||
"(3, 3, 1, 4)\n",
|
||||
"(4,)\n",
|
||||
"(4,)\n",
|
||||
"(4,)\n",
|
||||
"(4,)\n",
|
||||
"(3, 3, 4, 8)\n",
|
||||
"(8,)\n",
|
||||
"(8,)\n",
|
||||
"(8,)\n",
|
||||
"(8,)\n",
|
||||
"(200, 10)\n",
|
||||
"(10,)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(len(model.weights))\n",
|
||||
"for weights in model.weights:\n",
|
||||
" print(weights.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gamma = model.layers[3].weights[0].numpy()\n",
|
||||
"beta = model.layers[3].weights[1].numpy()\n",
|
||||
"moving_mean = model.layers[3].weights[2].numpy()\n",
|
||||
"moving_var = model.layers[3].weights[3].numpy()\n",
|
||||
"epsilon = model.layers[3].epsilon"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(array([ 8.792204, 9.772785, 10.312538, 8.802087], dtype=float32),\n",
|
||||
" array([-1.1708343 , -0.05037522, -1.342272 , -0.06242919], dtype=float32))"
|
||||
]
|
||||
},
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"a1 = gamma/(moving_var+epsilon)**.5\n",
|
||||
"b1 = beta-gamma*moving_mean/(moving_var+epsilon)**.5\n",
|
||||
"a1, b1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gamma = model.layers[7].weights[0].numpy()\n",
|
||||
"beta = model.layers[7].weights[1].numpy()\n",
|
||||
"moving_mean = model.layers[7].weights[2].numpy()\n",
|
||||
"moving_var = model.layers[7].weights[3].numpy()\n",
|
||||
"epsilon = model.layers[7].epsilon"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(array([1.4212797, 1.5523269, 1.5874738, 1.593189 , 2.330205 , 1.5186743,\n",
|
||||
" 1.3124124, 1.6141205], dtype=float32),\n",
|
||||
" array([-0.939713 , -0.15037137, -1.5690781 , 1.2707491 , -0.06281102,\n",
|
||||
" -1.2463849 , 0.5863233 , -0.38434148], dtype=float32))"
|
||||
]
|
||||
},
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"a2 = gamma/(moving_var+epsilon)**.5\n",
|
||||
"b2 = beta-gamma*moving_mean/(moving_var+epsilon)**.5\n",
|
||||
"a2, b2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"in_json = {\n",
|
||||
" \"in\": X.astype(int).flatten().tolist(), # X is already 1000 times to begin with\n",
|
||||
" \"conv2d_1_weights\": (model.layers[2].weights[0].numpy()*(10**3)).round().astype(int).flatten().tolist(),\n",
|
||||
" \"conv2d_1_bias\": (np.array([0]*4)*(10**3)**2).round().astype(int).flatten().tolist(),\n",
|
||||
" \"bn_1_a\": (a1*(10**3)).round().astype(int).flatten().tolist(),\n",
|
||||
" \"bn_1_b\": (b1*(10**3)**3).round().astype(int).flatten().tolist(),\n",
|
||||
" # average pooling will scale another 10**2 times\n",
|
||||
" \"conv2d_2_weights\": (model.layers[6].weights[0].numpy()*(10**3)).round().astype(int).flatten().tolist(),\n",
|
||||
" \"conv2d_2_bias\": (np.array([0]*8)*(10**(3*4+2))).round().astype(int).flatten().tolist(),\n",
|
||||
" \"bn_2_a\": (a2*(10**3)).round().astype(int).flatten().tolist(),\n",
|
||||
" \"bn_2_b\": (b2*(10**(3*5+2))).round().astype(int).flatten().tolist(),\n",
|
||||
" # average pooling will scale another 10**2 times\n",
|
||||
" \"dense_weights\":(model.layers[11].weights[0].numpy()*(10**3)).round().astype(int).flatten().tolist(),\n",
|
||||
" \"dense_bias\": np.zeros(model.layers[11].weights[1].numpy().shape).tolist() # zero because we are not doing softmax in circom, just argmax\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(\"mnist_latest_optimized_input.json\", \"w\") as f:\n",
|
||||
" json.dump(in_json, f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"interpreter": {
|
||||
"hash": "11280bdb37aa6bc5d4cf1e4de756386eb1f9eecd8dcdefa77636dfac7be2370d"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.8.6 ('tf24')",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.6"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
1
models/mnist_latest_optimized_input.json
Normal file
1
models/mnist_latest_optimized_input.json
Normal file
File diff suppressed because one or more lines are too long
413
models/mnist_quantization.ipynb
Normal file
413
models/mnist_quantization.ipynb
Normal file
@@ -0,0 +1,413 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
},
|
||||
"accelerator": "GPU",
|
||||
"gpuClass": "standard"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!pip install -q tensorflow-model-optimization"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "tW_c250QErxC"
|
||||
},
|
||||
"execution_count": 1,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"id": "zwDu9CS6Eeh4"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tensorflow.keras.layers import Input, Conv2D, AveragePooling2D, Flatten, Softmax, Dense, ReLU, BatchNormalization\n",
|
||||
"from tensorflow.keras import Model\n",
|
||||
"from tensorflow.keras.datasets import mnist\n",
|
||||
"from tensorflow.keras.utils import to_categorical\n",
|
||||
"from tensorflow.keras.optimizers import SGD\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import tensorflow as tf"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import tensorflow_model_optimization as tfmot\n",
|
||||
"quantize_model = tfmot.quantization.keras.quantize_model"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Or13lcHmEg7t"
|
||||
},
|
||||
"execution_count": 3,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"(X_train, y_train), (X_test, y_test) = mnist.load_data()"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "S150IfZ6Enw-"
|
||||
},
|
||||
"execution_count": 4,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Convert y_train into one-hot format\n",
|
||||
"temp = []\n",
|
||||
"for i in range(len(y_train)):\n",
|
||||
" temp.append(to_categorical(y_train[i], num_classes=10))\n",
|
||||
"y_train = np.array(temp)\n",
|
||||
"# Convert y_test into one-hot format\n",
|
||||
"temp = []\n",
|
||||
"for i in range(len(y_test)): \n",
|
||||
" temp.append(to_categorical(y_test[i], num_classes=10))\n",
|
||||
"y_test = np.array(temp)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "-2leSVDxExXc"
|
||||
},
|
||||
"execution_count": 5,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"#reshaping\n",
|
||||
"X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)\n",
|
||||
"X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "9jXXRXAAEySd"
|
||||
},
|
||||
"execution_count": 6,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"inputs = Input(shape=(28,28,1))\n",
|
||||
"# out = Lambda(lambda x: x/100)(inputs)\n",
|
||||
"out = Conv2D(4, 3, use_bias=False)(inputs)\n",
|
||||
"out = BatchNormalization()(out)\n",
|
||||
"out = ReLU()(out)\n",
|
||||
"# out = Lambda(lambda x: x**2+x)(out)\n",
|
||||
"out = AveragePooling2D()(out)\n",
|
||||
"# out = Lambda(lambda x: x*4)(out)\n",
|
||||
"out = Conv2D(8, 3, use_bias=False)(out)\n",
|
||||
"out = BatchNormalization()(out)\n",
|
||||
"out = ReLU()(out)\n",
|
||||
"# out = Lambda(lambda x: x**2+x)(out)\n",
|
||||
"out = AveragePooling2D()(out)\n",
|
||||
"# out = Lambda(lambda x: x*4)(out)\n",
|
||||
"out = Flatten()(out)\n",
|
||||
"out = Dense(10, activation=None)(out)\n",
|
||||
"out = Softmax()(out)\n",
|
||||
"model = Model(inputs, out)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "5TUZqKqVEzcE"
|
||||
},
|
||||
"execution_count": 7,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"q_aware_model = quantize_model(model)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "QMu4SdQDE0bv"
|
||||
},
|
||||
"execution_count": 8,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"q_aware_model.summary()"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "yNB8G40_E1nl",
|
||||
"outputId": "1b528140-5262-461a-a56f-84f34152cf0c"
|
||||
},
|
||||
"execution_count": 9,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Model: \"model\"\n",
|
||||
"_________________________________________________________________\n",
|
||||
" Layer (type) Output Shape Param # \n",
|
||||
"=================================================================\n",
|
||||
" input_1 (InputLayer) [(None, 28, 28, 1)] 0 \n",
|
||||
" \n",
|
||||
" quantize_layer (QuantizeLay (None, 28, 28, 1) 3 \n",
|
||||
" er) \n",
|
||||
" \n",
|
||||
" quant_conv2d (QuantizeWrapp (None, 26, 26, 4) 45 \n",
|
||||
" erV2) \n",
|
||||
" \n",
|
||||
" quant_batch_normalization ( (None, 26, 26, 4) 17 \n",
|
||||
" QuantizeWrapperV2) \n",
|
||||
" \n",
|
||||
" quant_re_lu (QuantizeWrappe (None, 26, 26, 4) 3 \n",
|
||||
" rV2) \n",
|
||||
" \n",
|
||||
" quant_average_pooling2d (Qu (None, 13, 13, 4) 3 \n",
|
||||
" antizeWrapperV2) \n",
|
||||
" \n",
|
||||
" quant_conv2d_1 (QuantizeWra (None, 11, 11, 8) 305 \n",
|
||||
" pperV2) \n",
|
||||
" \n",
|
||||
" quant_batch_normalization_1 (None, 11, 11, 8) 33 \n",
|
||||
" (QuantizeWrapperV2) \n",
|
||||
" \n",
|
||||
" quant_re_lu_1 (QuantizeWrap (None, 11, 11, 8) 3 \n",
|
||||
" perV2) \n",
|
||||
" \n",
|
||||
" quant_average_pooling2d_1 ( (None, 5, 5, 8) 3 \n",
|
||||
" QuantizeWrapperV2) \n",
|
||||
" \n",
|
||||
" quant_flatten (QuantizeWrap (None, 200) 1 \n",
|
||||
" perV2) \n",
|
||||
" \n",
|
||||
" quant_dense (QuantizeWrappe (None, 10) 2015 \n",
|
||||
" rV2) \n",
|
||||
" \n",
|
||||
" quant_softmax (QuantizeWrap (None, 10) 1 \n",
|
||||
" perV2) \n",
|
||||
" \n",
|
||||
"=================================================================\n",
|
||||
"Total params: 2,432\n",
|
||||
"Trainable params: 2,358\n",
|
||||
"Non-trainable params: 74\n",
|
||||
"_________________________________________________________________\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"q_aware_model.compile(\n",
|
||||
" loss='categorical_crossentropy',\n",
|
||||
" optimizer=SGD(learning_rate=0.01, momentum=0.9),\n",
|
||||
" metrics=['acc']\n",
|
||||
" )"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "svS4mnhGE2dp"
|
||||
},
|
||||
"execution_count": 10,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"q_aware_model.fit(X_train, y_train, epochs=15, batch_size=32, validation_data=(X_test, y_test))"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "EIEaD9vOE32A",
|
||||
"outputId": "a8c6e838-3faa-438b-dd22-4a932f91bb89"
|
||||
},
|
||||
"execution_count": 11,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Epoch 1/15\n",
|
||||
"1875/1875 [==============================] - 23s 11ms/step - loss: 0.2013 - acc: 0.9406 - val_loss: 0.1018 - val_acc: 0.9677\n",
|
||||
"Epoch 2/15\n",
|
||||
"1875/1875 [==============================] - 11s 6ms/step - loss: 0.0842 - acc: 0.9743 - val_loss: 0.0808 - val_acc: 0.9740\n",
|
||||
"Epoch 3/15\n",
|
||||
"1875/1875 [==============================] - 11s 6ms/step - loss: 0.0694 - acc: 0.9788 - val_loss: 0.0584 - val_acc: 0.9813\n",
|
||||
"Epoch 4/15\n",
|
||||
"1875/1875 [==============================] - 11s 6ms/step - loss: 0.0608 - acc: 0.9815 - val_loss: 0.0538 - val_acc: 0.9824\n",
|
||||
"Epoch 5/15\n",
|
||||
"1875/1875 [==============================] - 11s 6ms/step - loss: 0.0556 - acc: 0.9832 - val_loss: 0.0481 - val_acc: 0.9853\n",
|
||||
"Epoch 6/15\n",
|
||||
"1875/1875 [==============================] - 11s 6ms/step - loss: 0.0506 - acc: 0.9843 - val_loss: 0.0464 - val_acc: 0.9850\n",
|
||||
"Epoch 7/15\n",
|
||||
"1875/1875 [==============================] - 11s 6ms/step - loss: 0.0485 - acc: 0.9849 - val_loss: 0.0670 - val_acc: 0.9779\n",
|
||||
"Epoch 8/15\n",
|
||||
"1875/1875 [==============================] - 11s 6ms/step - loss: 0.0453 - acc: 0.9858 - val_loss: 0.0448 - val_acc: 0.9860\n",
|
||||
"Epoch 9/15\n",
|
||||
"1875/1875 [==============================] - 11s 6ms/step - loss: 0.0432 - acc: 0.9867 - val_loss: 0.0566 - val_acc: 0.9823\n",
|
||||
"Epoch 10/15\n",
|
||||
"1875/1875 [==============================] - 11s 6ms/step - loss: 0.0419 - acc: 0.9874 - val_loss: 0.0385 - val_acc: 0.9870\n",
|
||||
"Epoch 11/15\n",
|
||||
"1875/1875 [==============================] - 11s 6ms/step - loss: 0.0397 - acc: 0.9876 - val_loss: 0.0453 - val_acc: 0.9850\n",
|
||||
"Epoch 12/15\n",
|
||||
"1875/1875 [==============================] - 11s 6ms/step - loss: 0.0389 - acc: 0.9880 - val_loss: 0.0597 - val_acc: 0.9816\n",
|
||||
"Epoch 13/15\n",
|
||||
"1875/1875 [==============================] - 11s 6ms/step - loss: 0.0372 - acc: 0.9885 - val_loss: 0.0439 - val_acc: 0.9847\n",
|
||||
"Epoch 14/15\n",
|
||||
"1875/1875 [==============================] - 11s 6ms/step - loss: 0.0360 - acc: 0.9887 - val_loss: 0.0357 - val_acc: 0.9884\n",
|
||||
"Epoch 15/15\n",
|
||||
"1875/1875 [==============================] - 12s 6ms/step - loss: 0.0356 - acc: 0.9889 - val_loss: 0.0400 - val_acc: 0.9857\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<keras.callbacks.History at 0x7f8045540d10>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"execution_count": 11
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)\n",
|
||||
"converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
|
||||
"\n",
|
||||
"model = converter.convert()"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "M1-DT7H_E4xN",
|
||||
"outputId": "0c43a3df-9006-4f78-c8e4-4e98c68c7fbf"
|
||||
},
|
||||
"execution_count": 12,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"WARNING:absl:Found untraced functions such as conv2d_layer_call_fn, conv2d_layer_call_and_return_conditional_losses, _jit_compiled_convolution_op, re_lu_layer_call_fn, re_lu_layer_call_and_return_conditional_losses while saving (showing 5 of 16). These functions will not be directly callable after loading.\n",
|
||||
"/usr/local/lib/python3.7/dist-packages/tensorflow/lite/python/convert.py:766: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.\n",
|
||||
" warnings.warn(\"Statistics for quantized inputs were expected, but not \"\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"'''\n",
|
||||
"Create interpreter, allocate tensors\n",
|
||||
"'''\n",
|
||||
"tflite_interpreter = tf.lite.Interpreter(model_content=model)\n",
|
||||
"tflite_interpreter.allocate_tensors()\n",
|
||||
"\n",
|
||||
"'''\n",
|
||||
"Check input/output details\n",
|
||||
"'''\n",
|
||||
"input_details = tflite_interpreter.get_input_details()\n",
|
||||
"output_details = tflite_interpreter.get_output_details()\n",
|
||||
"\n",
|
||||
"print(\"== Input details ==\")\n",
|
||||
"print(\"name:\", input_details[0]['name'])\n",
|
||||
"print(\"shape:\", input_details[0]['shape'])\n",
|
||||
"print(\"type:\", input_details[0]['dtype'])\n",
|
||||
"print(\"\\n== Output details ==\")\n",
|
||||
"print(\"name:\", output_details[0]['name'])\n",
|
||||
"print(\"shape:\", output_details[0]['shape'])\n",
|
||||
"print(\"type:\", output_details[0]['dtype'])\n",
|
||||
"\n",
|
||||
"'''\n",
|
||||
"This gives a list of dictionaries. \n",
|
||||
"'''\n",
|
||||
"tensor_details = tflite_interpreter.get_tensor_details()\n",
|
||||
"\n",
|
||||
"for dict in tensor_details:\n",
|
||||
" i = dict['index']\n",
|
||||
" tensor_name = dict['name']\n",
|
||||
" scales = dict['quantization_parameters']['scales']\n",
|
||||
" zero_points = dict['quantization_parameters']['zero_points']\n",
|
||||
" tensor = tflite_interpreter.tensor(i)()\n",
|
||||
"\n",
|
||||
" print(i, type, tensor_name, scales.shape, zero_points.shape, tensor.shape)\n",
|
||||
" # print(tensor)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "LHt-zIshF0pv",
|
||||
"outputId": "64c1d7bc-2d24-42c5-d223-6ad12cfde19c"
|
||||
},
|
||||
"execution_count": 13,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"== Input details ==\n",
|
||||
"name: serving_default_input_1:0\n",
|
||||
"shape: [ 1 28 28 1]\n",
|
||||
"type: <class 'numpy.float32'>\n",
|
||||
"\n",
|
||||
"== Output details ==\n",
|
||||
"name: StatefulPartitionedCall:0\n",
|
||||
"shape: [ 1 10]\n",
|
||||
"type: <class 'numpy.float32'>\n",
|
||||
"0 <class 'type'> serving_default_input_1:0 (0,) (0,) (1, 28, 28, 1)\n",
|
||||
"1 <class 'type'> model/quant_flatten/Const (0,) (0,) (2,)\n",
|
||||
"2 <class 'type'> model/quant_dense/BiasAdd/ReadVariableOp (1,) (1,) (10,)\n",
|
||||
"3 <class 'type'> model/quant_batch_normalization_1/FusedBatchNormV3 (8,) (8,) (8,)\n",
|
||||
"4 <class 'type'> model/quant_batch_normalization/FusedBatchNormV3 (4,) (4,) (4,)\n",
|
||||
"5 <class 'type'> model/quantize_layer/AllValuesQuantize/FakeQuantWithMinMaxVars;model/quantize_layer/AllValuesQuantize/FakeQuantWithMinMaxVars/ReadVariableOp;model/quantize_layer/AllValuesQuantize/FakeQuantWithMinMaxVars/ReadVariableOp_1 (1,) (1,) (1, 28, 28, 1)\n",
|
||||
"6 <class 'type'> model/quant_conv2d/Conv2D;model/quant_conv2d/LastValueQuant/FakeQuantWithMinMaxVarsPerChannel (4,) (4,) (4, 3, 3, 1)\n",
|
||||
"7 <class 'type'> model/quant_re_lu/Relu;model/quant_batch_normalization/FusedBatchNormV3;model/quant_conv2d/Conv2D (1,) (1,) (1, 26, 26, 4)\n",
|
||||
"8 <class 'type'> model/quant_average_pooling2d/AvgPool (1,) (1,) (1, 13, 13, 4)\n",
|
||||
"9 <class 'type'> model/quant_average_pooling2d/MovingAvgQuantize/FakeQuantWithMinMaxVars;model/quant_average_pooling2d_1/MovingAvgQuantize/FakeQuantWithMinMaxVars/ReadVariableOp;model/quant_average_pooling2d/MovingAvgQuantize/FakeQuantWithMinMaxVars/ReadVariableOp_1 (1,) (1,) (1, 13, 13, 4)\n",
|
||||
"10 <class 'type'> model/quant_conv2d_1/Conv2D;model/quant_conv2d_1/LastValueQuant/FakeQuantWithMinMaxVarsPerChannel (8,) (8,) (8, 3, 3, 4)\n",
|
||||
"11 <class 'type'> model/quant_re_lu_1/Relu;model/quant_batch_normalization_1/FusedBatchNormV3;model/quant_conv2d_1/Conv2D (1,) (1,) (1, 11, 11, 8)\n",
|
||||
"12 <class 'type'> model/quant_average_pooling2d_1/AvgPool (1,) (1,) (1, 5, 5, 8)\n",
|
||||
"13 <class 'type'> model/quant_average_pooling2d_1/AvgPool1 (1,) (1,) (1, 5, 5, 8)\n",
|
||||
"14 <class 'type'> model/quant_flatten/Reshape;model/quant_average_pooling2d_1/MovingAvgQuantize/FakeQuantWithMinMaxVars (1,) (1,) (1, 200)\n",
|
||||
"15 <class 'type'> model/quant_dense/MatMul;model/quant_dense/LastValueQuant/FakeQuantWithMinMaxVars (1,) (1,) (10, 200)\n",
|
||||
"16 <class 'type'> model/quant_dense/MatMul;model/quant_dense/BiasAdd (1,) (1,) (1, 10)\n",
|
||||
"17 <class 'type'> model/quant_softmax/Softmax (1,) (1,) (1, 10)\n",
|
||||
"18 <class 'type'> StatefulPartitionedCall:0 (0,) (0,) (1, 10)\n",
|
||||
"25 <class 'type'> (0,) (0,) (1, 26, 26, 9)\n",
|
||||
"26 <class 'type'> (0,) (0,) (1, 11, 11, 36)\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"id": "PkoSWwvnIKsP"
|
||||
},
|
||||
"execution_count": 13,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
}
|
||||
Reference in New Issue
Block a user