Mirror of https://github.com/socathie/circomlib-ml.git, synced 2026-01-07 21:24:01 -05:00
v1.2.0 - Added SumPooling2D layer, added high accuracy (98%+) model for MNIST
@@ -32,7 +32,7 @@ In `models/mnist_poly.ipynb`, a sample model of Conv2d-Poly-Dense layers was tra

- Weights in the `Dense` layer were scaled by `10**9` times for precision again.

- Biases in the `Dense` layer had been omitted for simplicity, since the `ArgMax` layer is not affected by the biases. However, if the biases were to be included (for example in a deeper network as an intermediate layer), they would have to be scaled by `(10**9)**5=10**45` times to adjust correctly.

We can easily see that a deeper network would have to sacrifice precision, due to the limitation that Circom works over a finite field with a modulus `p` of around 254 bits. As `log10(2**254)~76`, we need to make sure the total scaling does not aggregate to exceed `10**76` (or even less) times. On average, each layer of a network with `l` layers should be scaled by at most `76//l` times.

We can easily see that a deeper network would have to sacrifice precision, due to the limitation that Circom works over a finite field with a modulus `p` of around 254 bits. As `log10(2**254)~76`, we need to make sure the total scaling does not aggregate to exceed `10**76` (or even less) times. On average, each layer of a network with `l` layers should be scaled by at most `10**(76//l)` times.
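As a rough editorial illustration of this budget (a minimal sketch, not code from the repository; the helper name is hypothetical), the per-layer scaling exponent for a given depth can be computed as:

```python
# Sketch of the scaling budget described above, assuming the ~254-bit Circom prime,
# i.e. roughly 76 decimal digits of total headroom for all scaling factors combined.
def max_scaling_exponent(num_layers: int, total_digits: int = 76) -> int:
    """Largest e such that scaling each of num_layers layers by 10**e
    keeps the combined factor below 10**total_digits."""
    return total_digits // num_layers

print(max_scaling_exponent(2))   # 38
print(max_scaling_exponent(5))   # 15, so scaling 5 layers by 10**9 each (10**45 total) fits
```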

## Circuits to be added:

- max/sum-pooling

## TODO:

- add strides parameter to `Conv2D` and `SumPooling2D`
@@ -4,7 +4,7 @@ include "./circomlib-matrix/matElemMul.circom";
include "./circomlib-matrix/matElemSum.circom";
include "./util.circom";

// Conv2D layer
// Conv2D layer with strides=0
template Conv2D (nRows, nCols, nChannels, nFilters, kernelSize) {
    signal input in[nRows][nCols][nChannels];
    signal input weights[kernelSize][kernelSize][nChannels][nFilters];
26 circuits/SumPooling2D.circom Normal file
@@ -0,0 +1,26 @@
pragma circom 2.0.3;

include "./circomlib-matrix/matElemSum.circom";
include "./util.circom";

// SumPooling2D layer, basically AveragePooling2D layer with a constant scaling, more optimized for circom, strides=poolSize like Keras default
template SumPooling2D (nRows, nCols, nChannels, poolSize) {
    signal input in[nRows][nCols][nChannels];
    signal output out[nRows\poolSize][nCols\poolSize][nChannels];

    component elemSum[nRows\poolSize][nCols\poolSize][nChannels];

    for (var i=0; i<nRows\poolSize; i++) {
        for (var j=0; j<nCols\poolSize; j++) {
            for (var k=0; k<nChannels; k++) {
                elemSum[i][j][k] = matElemSum(poolSize,poolSize);
                for (var x=0; x<poolSize; x++) {
                    for (var y=0; y<poolSize; y++) {
                        elemSum[i][j][k].a[x][y] <== in[i*poolSize+x][j*poolSize+y][k];
                    }
                }
                out[i][j][k] <== elemSum[i][j][k].out;
            }
        }
    }
}
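Editorial aside (not part of the commit): the comment in `SumPooling2D.circom` above can be checked numerically, since summing over a `poolSize x poolSize` window equals Keras-style average pooling multiplied by `poolSize**2`. A minimal NumPy sketch, assuming non-overlapping windows with strides equal to `poolSize`; the helper name is illustrative:

```python
import numpy as np

def sum_pool_2d(x, pool_size):
    """Sum pooling over non-overlapping pool_size x pool_size windows (strides = pool_size).
    x has shape (rows, cols, channels); trailing rows/cols that do not fit are dropped,
    mirroring the integer division nRows\poolSize in the circuit."""
    r, c, ch = x.shape
    out = np.zeros((r // pool_size, c // pool_size, ch), dtype=x.dtype)
    for i in range(r // pool_size):
        for j in range(c // pool_size):
            window = x[i*pool_size:(i+1)*pool_size, j*pool_size:(j+1)*pool_size, :]
            out[i, j, :] = window.sum(axis=(0, 1))
    return out

# Sum pooling equals average pooling scaled by pool_size**2, which is why the models
# below follow Keras AveragePooling2D with a Lambda(lambda x: x*4) layer for pool size 2.
x = np.random.rand(5, 5, 3)
avg = np.array([[x[2*i:2*i+2, 2*j:2*j+2].mean(axis=(0, 1)) for j in range(2)] for i in range(2)])
assert np.allclose(sum_pool_2d(x, 2), avg * 4)
```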
@@ -126,31 +126,31 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/10\n",
|
||||
"1875/1875 [==============================] - 2s 1ms/step - loss: 19.5098 - acc: 0.7534 - val_loss: 1.0343 - val_acc: 0.8552\n",
|
||||
"1875/1875 [==============================] - 1s 640us/step - loss: 19.4036 - acc: 0.7586 - val_loss: 1.0043 - val_acc: 0.8697\n",
|
||||
"Epoch 2/10\n",
|
||||
"1875/1875 [==============================] - 1s 749us/step - loss: 0.6325 - acc: 0.8769 - val_loss: 0.3121 - val_acc: 0.9113\n",
|
||||
"1875/1875 [==============================] - 1s 483us/step - loss: 0.6379 - acc: 0.8767 - val_loss: 0.3041 - val_acc: 0.9168\n",
|
||||
"Epoch 3/10\n",
|
||||
"1875/1875 [==============================] - 1s 621us/step - loss: 0.3079 - acc: 0.9126 - val_loss: 0.3085 - val_acc: 0.9119\n",
|
||||
"1875/1875 [==============================] - 1s 461us/step - loss: 0.3112 - acc: 0.9124 - val_loss: 0.2960 - val_acc: 0.9166\n",
|
||||
"Epoch 4/10\n",
|
||||
"1875/1875 [==============================] - 1s 686us/step - loss: 0.3070 - acc: 0.9111 - val_loss: 0.2962 - val_acc: 0.9152\n",
|
||||
"1875/1875 [==============================] - 1s 486us/step - loss: 0.3033 - acc: 0.9151 - val_loss: 0.2906 - val_acc: 0.9196\n",
|
||||
"Epoch 5/10\n",
|
||||
"1875/1875 [==============================] - 2s 932us/step - loss: 0.3172 - acc: 0.9097 - val_loss: 0.3112 - val_acc: 0.9102\n",
|
||||
"1875/1875 [==============================] - 1s 462us/step - loss: 0.3027 - acc: 0.9152 - val_loss: 0.2918 - val_acc: 0.9184\n",
|
||||
"Epoch 6/10\n",
|
||||
"1875/1875 [==============================] - 1s 653us/step - loss: 0.2994 - acc: 0.9140 - val_loss: 0.2838 - val_acc: 0.9236\n",
|
||||
"1875/1875 [==============================] - 1s 499us/step - loss: 0.3119 - acc: 0.9163 - val_loss: 0.2949 - val_acc: 0.9206\n",
|
||||
"Epoch 7/10\n",
|
||||
"1875/1875 [==============================] - 1s 610us/step - loss: 0.2918 - acc: 0.9176 - val_loss: 0.3029 - val_acc: 0.9172\n",
|
||||
"1875/1875 [==============================] - 1s 459us/step - loss: 0.3006 - acc: 0.9178 - val_loss: 0.2878 - val_acc: 0.9163\n",
|
||||
"Epoch 8/10\n",
|
||||
"1875/1875 [==============================] - 1s 684us/step - loss: 0.2884 - acc: 0.9199 - val_loss: 0.4005 - val_acc: 0.9010\n",
|
||||
"1875/1875 [==============================] - 1s 481us/step - loss: 0.2898 - acc: 0.9219 - val_loss: 0.2750 - val_acc: 0.9246\n",
|
||||
"Epoch 9/10\n",
|
||||
"1875/1875 [==============================] - 1s 749us/step - loss: 0.2869 - acc: 0.9196 - val_loss: 0.2758 - val_acc: 0.9234\n",
|
||||
"1875/1875 [==============================] - 1s 457us/step - loss: 0.2803 - acc: 0.9232 - val_loss: 0.2823 - val_acc: 0.9243\n",
|
||||
"Epoch 10/10\n",
|
||||
"1875/1875 [==============================] - 1s 623us/step - loss: 0.2813 - acc: 0.9190 - val_loss: 0.2861 - val_acc: 0.9244\n"
|
||||
"1875/1875 [==============================] - 1s 483us/step - loss: 0.2751 - acc: 0.9244 - val_loss: 0.3078 - val_acc: 0.9240\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<tensorflow.python.keras.callbacks.History at 0x12e8a7a90>"
|
||||
"<tensorflow.python.keras.callbacks.History at 0x156182c70>"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
@@ -200,8 +200,8 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[ -6.9767556, -20.781387 , -4.4705815, 5.46858 , -4.1251545,\n",
|
||||
" -6.630072 , -19.10683 , 13.8626585, -1.8164067, 1.7395192]],\n",
|
||||
"array([[ -4.7280216, -14.685291 , -1.7771893, 4.1886683, -5.325138 ,\n",
|
||||
" -2.5930152, -11.220819 , 8.523948 , -1.0928547, 1.6447238]],\n",
|
||||
" dtype=float32)"
|
||||
]
|
||||
},
|
||||
@@ -211,7 +211,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model2.predict(X_test[[0]])"
|
||||
"model2.predict(X_test[[0]]) - model.weights[3].numpy()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -222,9 +222,9 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[8.9013635e-10, 8.9987182e-16, 1.0911241e-08, 2.2615024e-04,\n",
|
||||
" 1.5413157e-08, 1.2589813e-09, 4.8021055e-15, 9.9976820e-01,\n",
|
||||
" 1.5508422e-07, 5.4310872e-06]], dtype=float32)"
|
||||
"array([[9.0031011e-07, 1.0410367e-10, 2.1521399e-05, 5.5709118e-03,\n",
|
||||
" 8.1659567e-07, 1.7811672e-05, 1.4445434e-09, 9.9381268e-01,\n",
|
||||
" 1.9175804e-05, 5.5624702e-04]], dtype=float32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
@@ -244,7 +244,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<matplotlib.image.AxesImage at 0x138d48160>"
|
||||
"<matplotlib.image.AxesImage at 0x157410be0>"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
@@ -277,43 +277,43 @@
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 1) dtype=float32, numpy=\n",
|
||||
" array([[[[ 0.00499532]],\n",
|
||||
" array([[[[-0.0031894 ]],\n",
|
||||
" \n",
|
||||
" [[ 0.00937087]],\n",
|
||||
" [[ 0.01395892]],\n",
|
||||
" \n",
|
||||
" [[-0.01510819]]],\n",
|
||||
" [[-0.01202957]]],\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" [[[ 0.01347945]],\n",
|
||||
" [[[ 0.01273286]],\n",
|
||||
" \n",
|
||||
" [[-0.0076614 ]],\n",
|
||||
" [[ 0.00727907]],\n",
|
||||
" \n",
|
||||
" [[-0.00259014]]],\n",
|
||||
" [[-0.00096878]]],\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" [[[ 0.00545743]],\n",
|
||||
" [[[-0.01515301]],\n",
|
||||
" \n",
|
||||
" [[-0.00029777]],\n",
|
||||
" [[ 0.00046251]],\n",
|
||||
" \n",
|
||||
" [[ 0.00870595]]]], dtype=float32)>,\n",
|
||||
" <tf.Variable 'conv2d/bias:0' shape=(1,) dtype=float32, numpy=array([0.01775588], dtype=float32)>,\n",
|
||||
" [[ 0.00114259]]]], dtype=float32)>,\n",
|
||||
" <tf.Variable 'conv2d/bias:0' shape=(1,) dtype=float32, numpy=array([0.01268097], dtype=float32)>,\n",
|
||||
" <tf.Variable 'dense/kernel:0' shape=(676, 10) dtype=float32, numpy=\n",
|
||||
" array([[-0.2504781 , 0.46691304, -0.07233011, ..., -0.00884265,\n",
|
||||
" -0.3515338 , -0.04016368],\n",
|
||||
" [-0.07648478, 0.4378654 , 0.04816974, ..., 0.14323361,\n",
|
||||
" -0.27090573, 0.00745438],\n",
|
||||
" [-0.03752105, 0.35444894, -0.11057906, ..., 0.01232609,\n",
|
||||
" -0.3564083 , -0.09844871],\n",
|
||||
" array([[-0.20955266, 0.40033442, -0.1978653 , ..., 0.32269812,\n",
|
||||
" -0.2694615 , 0.06487904],\n",
|
||||
" [-0.07851421, 0.36309943, -0.11626323, ..., 0.15947255,\n",
|
||||
" -0.3209268 , 0.13795587],\n",
|
||||
" [-0.21611924, 0.3673279 , -0.12897709, ..., 0.32853407,\n",
|
||||
" -0.4003139 , 0.03121791],\n",
|
||||
" ...,\n",
|
||||
" [-0.23108476, 0.13397478, -0.25663534, ..., 0.12351229,\n",
|
||||
" -0.24287994, -0.02877483],\n",
|
||||
" [-0.2296395 , 0.33206815, -0.11455406, ..., 0.07636273,\n",
|
||||
" -0.45205915, -0.12896551],\n",
|
||||
" [-0.09236625, 0.41229486, 0.17483838, ..., 0.07781664,\n",
|
||||
" -0.35363778, -0.1946791 ]], dtype=float32)>,\n",
|
||||
" [-0.22975925, 0.25338897, -0.34073418, ..., -0.11573094,\n",
|
||||
" -0.19876844, 0.06433479],\n",
|
||||
" [-0.23668756, 0.3116613 , -0.12199575, ..., 0.19775279,\n",
|
||||
" -0.16193576, -0.01519678],\n",
|
||||
" [-0.22618775, 0.45856324, 0.08593661, ..., 0.2934398 ,\n",
|
||||
" -0.3539578 , -0.04992562]], dtype=float32)>,\n",
|
||||
" <tf.Variable 'dense/bias:0' shape=(10,) dtype=float32, numpy=\n",
|
||||
" array([-0.4428678 , 0.6187441 , 0.33343422, -0.2756751 , 0.01635278,\n",
|
||||
" 0.9593875 , -0.06502363, 0.39655322, -0.9115569 , -0.3696858 ],\n",
|
||||
" array([-0.2101967 , 0.681965 , 0.01303466, -0.39655647, 0.28932407,\n",
|
||||
" 0.6396667 , -0.15234365, 0.45215404, -0.7866978 , -0.15671334],\n",
|
||||
" dtype=float32)>]"
|
||||
]
|
||||
},
|
||||
|
||||
408 models/mnist_convnet.ipynb Normal file
@@ -0,0 +1,408 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# model architecture modified from https://keras.io/examples/vision/mnist_convnet/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tensorflow.keras.layers import Input, Conv2D, AveragePooling2D, Flatten, Lambda, Softmax, Dense\n",
|
||||
"from tensorflow.keras import Model\n",
|
||||
"from tensorflow.keras.datasets import mnist\n",
|
||||
"from tensorflow.keras.utils import to_categorical\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import tensorflow as tf"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"(X_train, y_train), (X_test, y_test) = mnist.load_data()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Convert y_train into one-hot format\n",
|
||||
"temp = []\n",
|
||||
"for i in range(len(y_train)):\n",
|
||||
" temp.append(to_categorical(y_train[i], num_classes=10))\n",
|
||||
"y_train = np.array(temp)\n",
|
||||
"# Convert y_test into one-hot format\n",
|
||||
"temp = []\n",
|
||||
"for i in range(len(y_test)): \n",
|
||||
" temp.append(to_categorical(y_test[i], num_classes=10))\n",
|
||||
"y_test = np.array(temp)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#reshaping\n",
|
||||
"X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)\n",
|
||||
"X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inputs = Input(shape=(28,28,1))\n",
|
||||
"out = Lambda(lambda x: x/1000)(inputs)\n",
|
||||
"out = Conv2D(4, 3)(out)\n",
|
||||
"out = Lambda(lambda x: x**2+x)(out)\n",
|
||||
"out = AveragePooling2D()(out)\n",
|
||||
"out = Lambda(lambda x: x*4)(out)\n",
|
||||
"out = Conv2D(8, 3)(out)\n",
|
||||
"out = Lambda(lambda x: x**2+x)(out)\n",
|
||||
"out = AveragePooling2D()(out)\n",
|
||||
"out = Lambda(lambda x: x*4)(out)\n",
|
||||
"out = Flatten()(out)\n",
|
||||
"out = Dense(10, activation=None)(out)\n",
|
||||
"out = Softmax()(out)\n",
|
||||
"model = Model(inputs, out)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model: \"model\"\n",
|
||||
"_________________________________________________________________\n",
|
||||
"Layer (type) Output Shape Param # \n",
|
||||
"=================================================================\n",
|
||||
"input_1 (InputLayer) [(None, 28, 28, 1)] 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"lambda (Lambda) (None, 28, 28, 1) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"conv2d (Conv2D) (None, 26, 26, 4) 40 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"lambda_1 (Lambda) (None, 26, 26, 4) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"average_pooling2d (AveragePo (None, 13, 13, 4) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"lambda_2 (Lambda) (None, 13, 13, 4) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"conv2d_1 (Conv2D) (None, 11, 11, 8) 296 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"lambda_3 (Lambda) (None, 11, 11, 8) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"average_pooling2d_1 (Average (None, 5, 5, 8) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"lambda_4 (Lambda) (None, 5, 5, 8) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"flatten (Flatten) (None, 200) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"dense (Dense) (None, 10) 2010 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"softmax (Softmax) (None, 10) 0 \n",
|
||||
"=================================================================\n",
|
||||
"Total params: 2,346\n",
|
||||
"Trainable params: 2,346\n",
|
||||
"Non-trainable params: 0\n",
|
||||
"_________________________________________________________________\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.summary()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.compile(\n",
|
||||
" loss='categorical_crossentropy',\n",
|
||||
" optimizer='adam',\n",
|
||||
" metrics=['acc']\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/15\n",
|
||||
"469/469 [==============================] - 5s 10ms/step - loss: 0.9272 - acc: 0.7005 - val_loss: 0.2338 - val_acc: 0.9316\n",
|
||||
"Epoch 2/15\n",
|
||||
"469/469 [==============================] - 4s 9ms/step - loss: 0.2102 - acc: 0.9396 - val_loss: 0.1324 - val_acc: 0.9587\n",
|
||||
"Epoch 3/15\n",
|
||||
"469/469 [==============================] - 5s 10ms/step - loss: 0.1367 - acc: 0.9605 - val_loss: 0.0977 - val_acc: 0.9709\n",
|
||||
"Epoch 4/15\n",
|
||||
"469/469 [==============================] - 5s 10ms/step - loss: 0.1066 - acc: 0.9684 - val_loss: 0.0837 - val_acc: 0.9762\n",
|
||||
"Epoch 5/15\n",
|
||||
"469/469 [==============================] - 4s 9ms/step - loss: 0.0962 - acc: 0.9727 - val_loss: 0.0762 - val_acc: 0.9776\n",
|
||||
"Epoch 6/15\n",
|
||||
"469/469 [==============================] - 4s 9ms/step - loss: 0.0830 - acc: 0.9755 - val_loss: 0.0709 - val_acc: 0.9788\n",
|
||||
"Epoch 7/15\n",
|
||||
"469/469 [==============================] - 4s 9ms/step - loss: 0.0766 - acc: 0.9764 - val_loss: 0.0639 - val_acc: 0.9806\n",
|
||||
"Epoch 8/15\n",
|
||||
"469/469 [==============================] - 4s 9ms/step - loss: 0.0702 - acc: 0.9790 - val_loss: 0.0620 - val_acc: 0.9815\n",
|
||||
"Epoch 9/15\n",
|
||||
"469/469 [==============================] - 4s 9ms/step - loss: 0.0703 - acc: 0.9782 - val_loss: 0.0573 - val_acc: 0.9837\n",
|
||||
"Epoch 10/15\n",
|
||||
"469/469 [==============================] - 4s 9ms/step - loss: 0.0639 - acc: 0.9802 - val_loss: 0.0570 - val_acc: 0.9821\n",
|
||||
"Epoch 11/15\n",
|
||||
"469/469 [==============================] - 4s 9ms/step - loss: 0.0627 - acc: 0.9804 - val_loss: 0.0535 - val_acc: 0.9839\n",
|
||||
"Epoch 12/15\n",
|
||||
"469/469 [==============================] - 4s 9ms/step - loss: 0.0607 - acc: 0.9813 - val_loss: 0.0515 - val_acc: 0.9832\n",
|
||||
"Epoch 13/15\n",
|
||||
"469/469 [==============================] - 4s 9ms/step - loss: 0.0562 - acc: 0.9829 - val_loss: 0.0498 - val_acc: 0.9852\n",
|
||||
"Epoch 14/15\n",
|
||||
"469/469 [==============================] - 4s 9ms/step - loss: 0.0538 - acc: 0.9834 - val_loss: 0.0493 - val_acc: 0.9843\n",
|
||||
"Epoch 15/15\n",
|
||||
"469/469 [==============================] - 4s 9ms/step - loss: 0.0528 - acc: 0.9841 - val_loss: 0.0489 - val_acc: 0.9838\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<tensorflow.python.keras.callbacks.History at 0x12fe77370>"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.fit(X_train, y_train, epochs=15, batch_size=128, validation_data=(X_test, y_test))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"((28, 28, 1), 0, 255)"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"X = X_test[0]\n",
|
||||
"X.shape, X.min(), X.max()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model2 = Model(model.input, model.layers[-2].output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[ -8.095224 , -5.8302927, -1.2153628, 2.650765 , -19.186575 ,\n",
|
||||
" -5.7322216, -26.104668 , 15.262588 , -4.949901 , -0.8113966]],\n",
|
||||
" dtype=float32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model2.predict(X_test[[0]]) - model.weights[5].numpy()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[7.0393115e-11, 7.1936546e-10, 6.7945763e-08, 2.8899435e-06,\n",
|
||||
" 9.6229656e-16, 7.9535123e-10, 1.0399456e-18, 9.9999690e-01,\n",
|
||||
" 1.6393171e-09, 9.2139523e-08]], dtype=float32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.predict(X_test[[0]])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<matplotlib.image.AxesImage at 0x13a72d970>"
|
||||
]
|
||||
},
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAANh0lEQVR4nO3df6zddX3H8dfL/sJeYFKwtSuVKqKxOsHlCppuSw3DAYYUo2w0GekSZskGCSxmG2ExkmxxjIiETWdSR2clCFOBQLRzksaNkLHKhZRSKFuRdVh71wvUrUXgtqXv/XG/LJdyz+dezvd7zve07+cjuTnnfN/ne77vfHtf/X7v+XzP+TgiBODY95a2GwDQH4QdSIKwA0kQdiAJwg4kMbufG5vreXGchvq5SSCVV/QLHYhxT1WrFXbb50u6RdIsSX8XETeUnn+chnSOz62zSQAFm2NTx1rXp/G2Z0n6qqQLJC2XtNr28m5fD0Bv1fmb/WxJT0fEMxFxQNKdklY10xaAptUJ+xJJP530eFe17HVsr7U9YnvkoMZrbA5AHXXCPtWbAG+49jYi1kXEcEQMz9G8GpsDUEedsO+StHTS41Ml7a7XDoBeqRP2hyWdYftdtudKulTSfc20BaBpXQ+9RcQh21dJ+idNDL2tj4gnGusMQKNqjbNHxEZJGxvqBUAPcbkskARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IIlaUzbb3ilpv6RXJR2KiOEmmgLQvFphr3w8Ip5v4HUA9BCn8UASdcMekn5o+xHba6d6gu21tkdsjxzUeM3NAehW3dP4FRGx2/ZCSffbfioiHpj8hIhYJ2mdJJ3oBVFzewC6VOvIHhG7q9sxSfdIOruJpgA0r+uw2x6yfcJr9yV9QtK2phoD0Kw6p/GLJN1j+7XX+VZE/KCRrgA0ruuwR8Qzks5ssBcAPcTQG5AEYQeSIOxAEoQdSIKwA0k08UGYFF747Mc61t552dPFdZ8aW1SsHxifU6wvuaNcn7/rxY61w1ueLK6LPDiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjLPP0J/88bc61j499PPyyqfX3PjKcnnnoZc61m557uM1N370+vHYaR1rQzf9UnHd2Zseabqd1nFkB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkHNG/SVpO9II4x+f2bXtN+sVnzulYe/5D5f8zT9pe3sc/f7+L9bkf+p9i/cYP3t2xdt5bXy6u+/2Xji/WPzm/82fl63o5DhTrm8eHivWVxx3setvv+f4Vxfp71z7c9Wu3aXNs0r7YO+UvFEd2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCz7PP0NB3Nxdq9V77xHqr62/esbJj7S9WLCtv+1/K33l/48r3dNHRzMx++XCxPrR1tFg/+YG7ivVfmdv5+/bn7yx/F/+xaNoju+31tsdsb5u0bIHt+23vqG5P6m2bAOqayWn8NySdf8SyayVtiogzJG2qHgMYYNOGPSIekLT3iMWrJG2o7m+QdHGzbQFoWrdv0C2KiFFJqm4Xdnqi7bW2R2yPHNR4l5sDUFfP342PiHURMRwRw3M0r9ebA9BBt2HfY3uxJFW3Y821BKAXug37fZLWVPfXSLq3mXYA9Mq04+y279DEN5efYnuXpC9IukHSt21fLulZSZf0skmUHfrvPR1rQ3d1rknSq9O89tB3X+iio2bs+f2PFesfmFv+9f3S3vd1rC37+2eK6x4qVo9O04Y9IlZ3KB2d30IBJMXlskAShB1IgrADSRB2IAnCDiTBR1zRmtmnLS3Wv3LdV4r1OZ5VrH/nlt/sWDt59KHiuscijuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7GjNU3+0pFj/yLzyVNZPHChPR73gyZfedE/HMo7sQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+zoqfFPfqRj7dHP3DzN2uUZhP7g6quL9bf+64+nef1cOLIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMs6Onnr2g8/HkeJfH0Vf/53nF+vwfPFasR7Gaz7RHdtvrbY/Z3jZp2fW2f2Z7S/VzYW/bBFDXTE7jvyHp/CmW3xwRZ1U/G5ttC0DTpg17RDwgaW8fegHQQ3XeoLvK9tbqNP+kTk+yvdb2iO2RgxqvsTkAdXQb9q9JOl3SWZJGJd3U6YkRsS4ihiNieM40H2wA0DtdhT0i9kTEqxFxWNLXJZ3dbFsAmtZV2G0vnvTwU5K2dXougMEw7Ti77TskrZR0iu1dkr4gaaXtszQxlLlT0hW9axGD7C0nnFCsX/brD3as7Tv8SnHdsS++u1ifN/5wsY7XmzbsEbF6isW39qAXAD3E5bJAEoQdSIKwA0kQdiAJwg4kwUdcUcuO6z9QrH/vlL/tWFu149PFdedtZGitSRzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtlR9L+/+9Fifevv/HWx/pNDBzvWXvyrU4vrztNosY43hyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOHtys5f8crF+zef/oVif5/Kv0KWPXdax9vZ/5PPq/cSRHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJz9GOfZ5X/iM7+3q1i/5PgXivXb9y8s1hd9vvPx5HBxTTRt2iO77aW2f2R7u+0nbF9dLV9g+37bO6rbk3rfLoBuzeQ0/pCkz0XE+yV9VNKVtpdLulbSpog4Q9Km6jGAATVt2CNiNCIere7vl7Rd0hJJqyRtqJ62QdLFPeoRQAPe1Bt0tpdJ+rCkzZIWRcSoNPEfgqQp/3izvdb2iO2Rgxqv2S6Abs047LaPl3SXpGsiYt9M14uIdRExHBHDczSvmx4BNGBGYbc9RxNBvz0i7q4W77G9uKovljTWmxYBNGHaoTfblnSrpO0R8eVJpfskrZF0Q3V7b086RD1nvq9Y/vOFt9V6+a9+8ZJi/W2PPVTr9dGcmYyzr5B0maTHbW+pll2niZB/2/blkp6VVP5XB9CqacMeEQ9Kcofyuc22A6BXuFwWSIKwA0kQdiAJwg4kQdiBJPiI6zFg1vL3dqytvbPe5Q/L119ZrC+77d9qvT76hyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOPsx4Kk/7PzFvhfNn/GXCk3p1H8+UH5CRK3XR/9wZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnPwq8ctHZxfqmi24qVOc32wyOWhzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJmczPvlTSNyW9Q9JhSesi4hbb10v6rKTnqqdeFxEbe9VoZrtXzCrW3zm7+7H02/cvLNbn7Ct/np1Psx89ZnJRzSFJn4uIR22fIOkR2/dXtZsj4ku9aw9AU2YyP/uopNHq/n7b2yUt6XVjAJr1p
v5mt71M0oclba4WXWV7q+31tqf8biTba22P2B45qPF63QLo2ozDbvt4SXdJuiYi9kn6mqTTJZ2liSP/lBdoR8S6iBiOiOE5mle/YwBdmVHYbc/RRNBvj4i7JSki9kTEqxFxWNLXJZU/rQGgVdOG3bYl3Sppe0R8edLyxZOe9ilJ25pvD0BTZvJu/ApJl0l63PaWatl1klbbPksToy87JV3Rg/5Q01++sLxYf+i3lhXrMfp4g92gTTN5N/5BSZ6ixJg6cBThCjogCcIOJEHYgSQIO5AEYQeSIOxAEo4+Trl7ohfEOT63b9sDstkcm7Qv9k41VM6RHciCsANJEHYgCcIOJEHYgSQIO5AEYQeS6Os4u+3nJP3XpEWnSHq+bw28OYPa26D2JdFbt5rs7bSIePtUhb6G/Q0bt0ciYri1BgoGtbdB7Uuit271qzdO44EkCDuQRNthX9fy9ksGtbdB7Uuit271pbdW/2YH0D9tH9kB9AlhB5JoJey2z7f977aftn1tGz10Ynun7cdtb7E90nIv622P2d42adkC2/fb3lHdTjnHXku9XW/7Z9W+22L7wpZ6W2r7R7a3237C9tXV8lb3XaGvvuy3vv/NbnuWpP+QdJ6kXZIelrQ6Ip7sayMd2N4paTgiWr8Aw/ZvSHpR0jcj4oPVshsl7Y2IG6r/KE+KiD8dkN6ul/Ri29N4V7MVLZ48zbikiyX9nlrcd4W+flt92G9tHNnPlvR0RDwTEQck3SlpVQt9DLyIeEDS3iMWr5K0obq/QRO/LH3XobeBEBGjEfFodX+/pNemGW913xX66os2wr5E0k8nPd6lwZrvPST90PYjtte23cwUFkXEqDTxyyNpYcv9HGnaabz76Yhpxgdm33Uz/XldbYR9qu/HGqTxvxUR8auSLpB0ZXW6ipmZ0TTe/TLFNOMDodvpz+tqI+y7JC2d9PhUSbtb6GNKEbG7uh2TdI8GbyrqPa/NoFvdjrXcz/8bpGm8p5pmXAOw79qc/ryNsD8s6Qzb77I9V9Klku5roY83sD1UvXEi20OSPqHBm4r6PklrqvtrJN3bYi+vMyjTeHeaZlwt77vWpz+PiL7/SLpQE+/I/0TSn7XRQ4e+3i3psernibZ7k3SHJk7rDmrijOhySSdL2iRpR3W7YIB6u03S45K2aiJYi1vq7dc08afhVklbqp8L2953hb76st+4XBZIgivogCQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJ/wNGNvRI2D7VDgAAAABJRU5ErkJggg==",
|
||||
"text/plain": [
|
||||
"<Figure size 432x288 with 1 Axes>"
|
||||
]
|
||||
},
|
||||
"metadata": {
|
||||
"needs_background": "light"
|
||||
},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"plt.imshow(X)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"6\n",
|
||||
"(3, 3, 1, 4)\n",
|
||||
"(4,)\n",
|
||||
"(3, 3, 4, 8)\n",
|
||||
"(8,)\n",
|
||||
"(200, 10)\n",
|
||||
"(10,)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(len(model.weights))\n",
|
||||
"for weights in model.weights:\n",
|
||||
" print(weights.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"in_json = {\n",
|
||||
" \"in\": X.astype(int).flatten().tolist(), # X is already 1000 times to begin with\n",
|
||||
" \"conv2d_1_weights\": (model.weights[0].numpy()*(10**3)).round().astype(int).flatten().tolist(),\n",
|
||||
" \"conv2d_1_bias\": (model.weights[1].numpy()*(10**3)*(10**3)).round().astype(int).flatten().tolist(),\n",
|
||||
" # poly layer would be (10**3)**2=10**6 times as well\n",
|
||||
" \"conv2d_2_weights\": (model.weights[2].numpy()*(10**3)).round().astype(int).flatten().tolist(),\n",
|
||||
" \"conv2d_2_bias\": (model.weights[3].numpy()*((10**3)**5)).round().astype(int).flatten().tolist(),\n",
|
||||
" # poly layer would be (10**3)**5=10**15 times as well\n",
|
||||
" \"dense_weights\":(model.weights[4].numpy()*(10**3)).round().astype(int).flatten().tolist(),\n",
|
||||
" \"dense_bias\": np.zeros(model.weights[5].numpy().shape).tolist() # zero because we are not doing softmax in circom, just argmax\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(\"mnist_convnet_input.json\", \"w\") as f:\n",
|
||||
" json.dump(in_json, f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"interpreter": {
|
||||
"hash": "11280bdb37aa6bc5d4cf1e4de756386eb1f9eecd8dcdefa77636dfac7be2370d"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.8.6 ('tf24')",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.6"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
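Editorial note (not part of the notebook): the scaling comments in the `in_json` cell above can be tallied end to end. The sketch below assumes the circuit's `Poly(n)` layer computes `out = in*in + n*in`, so it doubles a layer's decimal scale exponent; that assumption is suggested by the `Poly(10**6)` and `Poly(10**15)` parameters in `test/circuits/mnist_convnet_test.circom` further down, not stated in the notebook itself.

```python
# Hedged bookkeeping of the decimal scaling exponents fed to the circuit inputs above.
scale = 3        # "in": raw 0-255 pixels are 1000x the model input (Lambda x/1000)
scale += 3       # conv2d_1: weights scaled by 10**3, bias by 10**scale = 10**6
assert scale == 6
scale *= 2       # first Poly layer (assumed x**2 + n*x) squares the scale -> 10**12
scale += 3       # conv2d_2: weights 10**3, bias by 10**scale = 10**15, i.e. (10**3)**5
assert scale == 15
scale *= 2       # second Poly layer -> 10**30
scale += 3       # dense: weights 10**3; bias left at 0 since only ArgMax is taken
assert scale == 33 and scale < 76   # well within the ~10**76 field budget
```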
1 models/mnist_convnet_input.json Normal file
File diff suppressed because one or more lines are too long
@@ -130,31 +130,31 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/10\n",
|
||||
"1875/1875 [==============================] - 2s 977us/step - loss: 0.8025 - acc: 0.7974 - val_loss: 0.3139 - val_acc: 0.9094\n",
|
||||
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.7945 - acc: 0.7904 - val_loss: 0.2996 - val_acc: 0.9128\n",
|
||||
"Epoch 2/10\n",
|
||||
"1875/1875 [==============================] - 2s 935us/step - loss: 0.3174 - acc: 0.9057 - val_loss: 0.2842 - val_acc: 0.9156\n",
|
||||
"1875/1875 [==============================] - 2s 891us/step - loss: 0.3095 - acc: 0.9074 - val_loss: 0.2835 - val_acc: 0.9179\n",
|
||||
"Epoch 3/10\n",
|
||||
"1875/1875 [==============================] - 2s 950us/step - loss: 0.2982 - acc: 0.9117 - val_loss: 0.2834 - val_acc: 0.9201\n",
|
||||
"1875/1875 [==============================] - 2s 892us/step - loss: 0.2993 - acc: 0.9119 - val_loss: 0.2725 - val_acc: 0.9225\n",
|
||||
"Epoch 4/10\n",
|
||||
"1875/1875 [==============================] - 2s 941us/step - loss: 0.2956 - acc: 0.9119 - val_loss: 0.2831 - val_acc: 0.9172\n",
|
||||
"1875/1875 [==============================] - 2s 890us/step - loss: 0.2810 - acc: 0.9160 - val_loss: 0.2666 - val_acc: 0.9199\n",
|
||||
"Epoch 5/10\n",
|
||||
"1875/1875 [==============================] - 2s 956us/step - loss: 0.2800 - acc: 0.9165 - val_loss: 0.2655 - val_acc: 0.9233\n",
|
||||
"1875/1875 [==============================] - 2s 901us/step - loss: 0.2723 - acc: 0.9192 - val_loss: 0.2606 - val_acc: 0.9231\n",
|
||||
"Epoch 6/10\n",
|
||||
"1875/1875 [==============================] - 2s 871us/step - loss: 0.2773 - acc: 0.9182 - val_loss: 0.2695 - val_acc: 0.9217\n",
|
||||
"1875/1875 [==============================] - 2s 895us/step - loss: 0.2691 - acc: 0.9205 - val_loss: 0.2595 - val_acc: 0.9243\n",
|
||||
"Epoch 7/10\n",
|
||||
"1875/1875 [==============================] - 2s 859us/step - loss: 0.2782 - acc: 0.9174 - val_loss: 0.2653 - val_acc: 0.9219\n",
|
||||
"1875/1875 [==============================] - 2s 898us/step - loss: 0.2608 - acc: 0.9220 - val_loss: 0.2535 - val_acc: 0.9267\n",
|
||||
"Epoch 8/10\n",
|
||||
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2694 - acc: 0.9211 - val_loss: 0.2658 - val_acc: 0.9226\n",
|
||||
"1875/1875 [==============================] - 2s 948us/step - loss: 0.2593 - acc: 0.9243 - val_loss: 0.2546 - val_acc: 0.9277\n",
|
||||
"Epoch 9/10\n",
|
||||
"1875/1875 [==============================] - 2s 970us/step - loss: 0.2620 - acc: 0.9217 - val_loss: 0.2626 - val_acc: 0.9240\n",
|
||||
"1875/1875 [==============================] - 2s 903us/step - loss: 0.2557 - acc: 0.9241 - val_loss: 0.2650 - val_acc: 0.9229\n",
|
||||
"Epoch 10/10\n",
|
||||
"1875/1875 [==============================] - 2s 975us/step - loss: 0.2613 - acc: 0.9238 - val_loss: 0.2727 - val_acc: 0.9219\n"
|
||||
"1875/1875 [==============================] - 2s 892us/step - loss: 0.2517 - acc: 0.9256 - val_loss: 0.2534 - val_acc: 0.9244\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<tensorflow.python.keras.callbacks.History at 0x11e088e20>"
|
||||
"<tensorflow.python.keras.callbacks.History at 0x16c9ca5b0>"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
@@ -204,9 +204,9 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[ -0.47087878, -6.194486 , 3.474044 , 8.786284 ,\n",
|
||||
" -0.59760684, 1.9168414 , -10.066298 , 15.407355 ,\n",
|
||||
" -3.065394 , 3.696538 ]], dtype=float32)"
|
||||
"array([[-0.6522014 , -4.587718 , 3.597416 , 9.337387 , -0.67536736,\n",
|
||||
" 1.8601764 , -9.597713 , 16.50223 , -2.1612463 , 5.6732273 ]],\n",
|
||||
" dtype=float32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
@@ -215,7 +215,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model2.predict(X_test[[0]])"
|
||||
"model2.predict(X_test[[0]]) - model.weights[3].numpy()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -226,9 +226,9 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[1.26936129e-07, 4.14814960e-10, 6.55908707e-06, 1.33021001e-03,\n",
|
||||
" 1.11827255e-07, 1.38216228e-06, 8.63670854e-12, 9.98653293e-01,\n",
|
||||
" 9.47985068e-09, 8.19353590e-06]], dtype=float32)"
|
||||
"array([[3.4876599e-08, 6.4541522e-10, 2.5880247e-06, 8.2364457e-04,\n",
|
||||
" 3.2681388e-08, 4.5281831e-07, 4.5530966e-12, 9.9915338e-01,\n",
|
||||
" 8.1701694e-09, 1.9931980e-05]], dtype=float32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
@@ -248,7 +248,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<matplotlib.image.AxesImage at 0x12f167bb0>"
|
||||
"<matplotlib.image.AxesImage at 0x177f54d90>"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
@@ -277,48 +277,24 @@
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"4\n",
|
||||
"(3, 3, 1, 1)\n",
|
||||
"(1,)\n",
|
||||
"(676, 10)\n",
|
||||
"(10,)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 1) dtype=float32, numpy=\n",
|
||||
" array([[[[ 0.52698594]],\n",
|
||||
" \n",
|
||||
" [[ 0.08442891]],\n",
|
||||
" \n",
|
||||
" [[ 0.01869087]]],\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" [[[-1.3245686 ]],\n",
|
||||
" \n",
|
||||
" [[-1.3917689 ]],\n",
|
||||
" \n",
|
||||
" [[-1.8389475 ]]],\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" [[[-0.27898604]],\n",
|
||||
" \n",
|
||||
" [[-0.448968 ]],\n",
|
||||
" \n",
|
||||
" [[-0.3638724 ]]]], dtype=float32)>,\n",
|
||||
" <tf.Variable 'conv2d/bias:0' shape=(1,) dtype=float32, numpy=array([0.3618775], dtype=float32)>,\n",
|
||||
" <tf.Variable 'dense/kernel:0' shape=(676, 10) dtype=float32, numpy=\n",
|
||||
" array([[-0.03674026, -0.18812837, 0.01979426, ..., 0.05463602,\n",
|
||||
" 0.01662535, -0.00871159],\n",
|
||||
" [ 0.02171878, -0.18518244, 0.11909918, ..., -0.03949559,\n",
|
||||
" 0.02754857, 0.0684126 ],\n",
|
||||
" [-0.01097946, -0.07011281, 0.12056817, ..., -0.05811585,\n",
|
||||
" 0.09220186, -0.07498543],\n",
|
||||
" ...,\n",
|
||||
" [-0.0361205 , -0.14020608, 0.12612993, ..., -0.09845617,\n",
|
||||
" -0.01772444, 0.07013445],\n",
|
||||
" [-0.07452048, -0.2042869 , 0.07853597, ..., 0.01920246,\n",
|
||||
" -0.02843344, -0.07306738],\n",
|
||||
" [ 0.00806268, -0.0407033 , 0.0513223 , ..., -0.0267654 ,\n",
|
||||
" 0.11145967, -0.10121571]], dtype=float32)>,\n",
|
||||
" <tf.Variable 'dense/bias:0' shape=(10,) dtype=float32, numpy=\n",
|
||||
" array([-0.01817241, -0.07616118, 0.054643 , 0.06845271, -0.05770804,\n",
|
||||
" 0.0395476 , -0.03720428, -0.03012715, 0.02956592, -0.02378519],\n",
|
||||
" dtype=float32)>]"
|
||||
"<tf.Variable 'dense/bias:0' shape=(10,) dtype=float32, numpy=\n",
|
||||
"array([-0.0237312 , -0.0778916 , 0.03348656, 0.05635946, -0.06557445,\n",
|
||||
" 0.02756624, -0.02198206, -0.00755905, 0.03398759, -0.00089445],\n",
|
||||
" dtype=float32)>"
|
||||
]
|
||||
},
|
||||
"execution_count": 14,
|
||||
@@ -327,7 +303,10 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.weights"
|
||||
"print(len(model.weights))\n",
|
||||
"for weights in model.weights:\n",
|
||||
" print(weights.shape)\n",
|
||||
"model.weights[3]"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
File diff suppressed because one or more lines are too long
1 models/sumPooling2D_input.json Normal file
@@ -0,0 +1 @@
{"in": [831, 157, 238, 3, 837, 953, 307, 219, 149, 136, 987, 121, 472, 514, 749, 747, 253, 504, 145, 8, 436, 699, 775, 404, 295, 994, 850, 888, 339, 962, 765, 591, 71, 944, 170, 448, 544, 838, 952, 305, 152, 947, 908, 589, 19, 701, 431, 710, 206, 201, 230, 820, 168, 674, 101, 356, 367, 587, 793, 933, 780, 463, 547, 204, 375, 599, 529, 842, 30, 291, 104, 536, 252, 833, 670]}
1 models/sumPooling2D_output.json Normal file
@@ -0,0 +1 @@
{"out": [1726, 1255, 2131, 1438, 2975, 1523, 2616, 1393, 1458, 1770, 1514, 2940]}
211 models/sumPooling2d.ipynb Normal file
@@ -0,0 +1,211 @@
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tensorflow.keras.layers import Input, AveragePooling2D, Lambda\n",
|
||||
"from tensorflow.keras import Model\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inputs = Input(shape=(5,5,3))\n",
|
||||
"x = AveragePooling2D(pool_size=2)(inputs)\n",
|
||||
"x = Lambda(lambda x: x*4)(x)\n",
|
||||
"model = Model(inputs, x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model: \"model\"\n",
|
||||
"_________________________________________________________________\n",
|
||||
"Layer (type) Output Shape Param # \n",
|
||||
"=================================================================\n",
|
||||
"input_1 (InputLayer) [(None, 5, 5, 3)] 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"average_pooling2d (AveragePo (None, 2, 2, 3) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"lambda (Lambda) (None, 2, 2, 3) 0 \n",
|
||||
"=================================================================\n",
|
||||
"Total params: 0\n",
|
||||
"Trainable params: 0\n",
|
||||
"Non-trainable params: 0\n",
|
||||
"_________________________________________________________________\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.summary()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[[[0.83128186, 0.15650764, 0.23798145],\n",
|
||||
" [0.00277366, 0.8374127 , 0.95278315],\n",
|
||||
" [0.3074389 , 0.21931738, 0.14886067],\n",
|
||||
" [0.13590018, 0.98728255, 0.12085182],\n",
|
||||
" [0.47212572, 0.51380922, 0.74891219]],\n",
|
||||
"\n",
|
||||
" [[0.74680338, 0.2533205 , 0.5039968 ],\n",
|
||||
" [0.14475403, 0.00791911, 0.4361197 ],\n",
|
||||
" [0.69925568, 0.77507624, 0.40388991],\n",
|
||||
" [0.29508251, 0.99375606, 0.84959701],\n",
|
||||
" [0.88844918, 0.33910189, 0.9617212 ]],\n",
|
||||
"\n",
|
||||
" [[0.76480625, 0.591287 , 0.0714191 ],\n",
|
||||
" [0.94371681, 0.1695303 , 0.4476252 ],\n",
|
||||
" [0.54372616, 0.83818804, 0.95211573],\n",
|
||||
" [0.30485104, 0.15165265, 0.94709317],\n",
|
||||
" [0.90827137, 0.58854675, 0.01857002]],\n",
|
||||
"\n",
|
||||
" [[0.70123418, 0.43090173, 0.7096038 ],\n",
|
||||
" [0.20637783, 0.20096581, 0.22956612],\n",
|
||||
" [0.81978383, 0.16775403, 0.67412096],\n",
|
||||
" [0.1011535 , 0.35596916, 0.36702071],\n",
|
||||
" [0.5874605 , 0.79341372, 0.93292159]],\n",
|
||||
"\n",
|
||||
" [[0.77997124, 0.46311399, 0.5465576 ],\n",
|
||||
" [0.20406287, 0.37547625, 0.59862253],\n",
|
||||
" [0.52933135, 0.84249092, 0.02969684],\n",
|
||||
" [0.29114617, 0.10405779, 0.5359062 ],\n",
|
||||
" [0.25197146, 0.83297465, 0.67025403]]]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"X = np.random.rand(1,5,5,3)\n",
|
||||
"X"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[[[1.7256129, 1.2551599, 2.1308813],\n",
|
||||
" [1.4376774, 2.9754324, 1.5231993]],\n",
|
||||
"\n",
|
||||
" [[2.6161351, 1.3926848, 1.4582142],\n",
|
||||
" [1.7695144, 1.5135639, 2.9403505]]]], dtype=float32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"y = model.predict(X)\n",
|
||||
"y"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"in_json = {\n",
|
||||
" \"in\": (X*1000).round().astype(int).flatten().tolist()\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"out_json = {\n",
|
||||
" \"out\": (y*1000).round().astype(int).flatten().tolist()\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(\"sumPooling2D_input.json\", \"w\") as f:\n",
|
||||
" json.dump(in_json, f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(\"sumPooling2D_output.json\", \"w\") as f:\n",
|
||||
" json.dump(out_json, f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "tf24",
|
||||
"language": "python",
|
||||
"name": "tf24"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.6"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -1,6 +1,6 @@
{
  "name": "circomlib-ml",
  "version": "1.1.0",
  "version": "1.2.0",
  "description": "Circuits library for machine learning in circom",
  "main": "index.js",
  "directories": {
39 test/SumPooling2D.js Normal file
@@ -0,0 +1,39 @@
const chai = require("chai");
const { Console } = require("console");
const path = require("path");

const wasm_tester = require("circom_tester").wasm;

const F1Field = require("ffjavascript").F1Field;
const Scalar = require("ffjavascript").Scalar;
exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
const Fr = new F1Field(exports.p);

const assert = chai.assert;

const json = require("../models/sumPooling2D_input.json");
const OUTPUT = require("../models/sumPooling2D_output.json");

describe("SumPooling2D layer test", function () {
    this.timeout(100000000);

    it("(5,5,3) -> (2,2,3)", async () => {
        const circuit = await wasm_tester(path.join(__dirname, "circuits", "SumPooling2D_test.circom"));
        await circuit.loadConstraints();
        assert.equal(circuit.nVars, 76);
        assert.equal(circuit.constraints.length, 0);

        const INPUT = {
            "in": json.in
        }

        const witness = await circuit.calculateWitness(INPUT, true);

        assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));

        for (var i=0; i<2*2*3; i++) {
            assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(2));
            assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(2));
        }
    });
});
5 test/circuits/SumPooling2D_test.circom Normal file
@@ -0,0 +1,5 @@
pragma circom 2.0.3;

include "../../circuits/SumPooling2D.circom";

component main = SumPooling2D(5, 5, 3, 2);
108 test/circuits/mnist_convnet_test.circom Normal file
@@ -0,0 +1,108 @@
|
||||
pragma circom 2.0.3;
|
||||
|
||||
include "../../circuits/Conv2D.circom";
|
||||
include "../../circuits/Dense.circom";
|
||||
include "../../circuits/ArgMax.circom";
|
||||
include "../../circuits/Poly.circom";
|
||||
include "../../circuits/SumPooling2D.circom";
|
||||
|
||||
template mnist_convnet() {
|
||||
signal input in[28][28][1];
|
||||
signal input conv2d_1_weights[3][3][1][4];
|
||||
signal input conv2d_1_bias[4];
|
||||
signal input conv2d_2_weights[3][3][4][8];
|
||||
signal input conv2d_2_bias[8];
|
||||
signal input dense_weights[200][10];
|
||||
signal input dense_bias[10];
|
||||
signal output out;
|
||||
|
||||
component conv2d_1 = Conv2D(28,28,1,4,3);
|
||||
component poly_1[26][26][4];
|
||||
component sum2d_1 = SumPooling2D(26,26,4,2);
|
||||
component conv2d_2 = Conv2D(13,13,4,8,3);
|
||||
component poly_2[11][11][8];
|
||||
component sum2d_2 = SumPooling2D(11,11,8,2);
|
||||
component dense = Dense(200,10);
|
||||
component argmax = ArgMax(10);
|
||||
|
||||
for (var i=0; i<28; i++) {
|
||||
for (var j=0; j<28; j++) {
|
||||
conv2d_1.in[i][j][0] <== in[i][j][0];
|
||||
}
|
||||
}
|
||||
|
||||
for (var m=0; m<4; m++) {
|
||||
for (var i=0; i<3; i++) {
|
||||
for (var j=0; j<3; j++) {
|
||||
conv2d_1.weights[i][j][0][m] <== conv2d_1_weights[i][j][0][m];
|
||||
}
|
||||
}
|
||||
conv2d_1.bias[m] <== conv2d_1_bias[m];
|
||||
}
|
||||
|
||||
for (var i=0; i<26; i++) {
|
||||
for (var j=0; j<26; j++) {
|
||||
for (var k=0; k<4; k++) {
|
||||
poly_1[i][j][k] = Poly(10**6);
|
||||
poly_1[i][j][k].in <== conv2d_1.out[i][j][k];
|
||||
sum2d_1.in[i][j][k] <== poly_1[i][j][k].out;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<13; i++) {
|
||||
for (var j=0; j<13; j++) {
|
||||
for (var k=0; k<4; k++) {
|
||||
conv2d_2.in[i][j][k] <== sum2d_1.out[i][j][k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (var m=0; m<8; m++) {
|
||||
for (var i=0; i<3; i++) {
|
||||
for (var j=0; j<3; j++) {
|
||||
for (var k=0; k<4; k++) {
|
||||
conv2d_2.weights[i][j][k][m] <== conv2d_2_weights[i][j][k][m];
|
||||
}
|
||||
}
|
||||
}
|
||||
conv2d_2.bias[m] <== conv2d_2_bias[m];
|
||||
}
|
||||
|
||||
for (var i=0; i<11; i++) {
|
||||
for (var j=0; j<11; j++) {
|
||||
for (var k=0; k<8; k++) {
|
||||
poly_2[i][j][k] = Poly(10**15);
|
||||
poly_2[i][j][k].in <== conv2d_2.out[i][j][k];
|
||||
sum2d_2.in[i][j][k] <== poly_2[i][j][k].out;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var idx = 0;
|
||||
|
||||
for (var i=0; i<5; i++) {
|
||||
for (var j=0; j<5; j++) {
|
||||
for (var k=0; k<8; k++) {
|
||||
dense.in[idx] <== sum2d_2.out[i][j][k];
|
||||
for (var m=0; m<10; m++) {
|
||||
dense.weights[idx][m] <== dense_weights[idx][m];
|
||||
}
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<10; i++) {
|
||||
dense.bias[i] <== dense_bias[i];
|
||||
}
|
||||
|
||||
for (var i=0; i<10; i++) {
|
||||
log(dense.out[i]);
|
||||
argmax.in[i] <== dense.out[i];
|
||||
}
|
||||
|
||||
out <== argmax.out;
|
||||
}
|
||||
|
||||
component main = mnist_convnet();
|
||||
45 test/mnist_convnet.js Normal file
@@ -0,0 +1,45 @@
|
||||
const chai = require("chai");
|
||||
const path = require("path");
|
||||
|
||||
const wasm_tester = require("circom_tester").wasm;
|
||||
|
||||
const F1Field = require("ffjavascript").F1Field;
|
||||
const Scalar = require("ffjavascript").Scalar;
|
||||
exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
|
||||
const Fr = new F1Field(exports.p);
|
||||
|
||||
const assert = chai.assert;
|
||||
|
||||
const json = require("../models/mnist_convnet_input.json");
|
||||
|
||||
describe("mnist convnet test", function () {
|
||||
this.timeout(100000000);
|
||||
|
||||
it("should return correct output", async () => {
|
||||
const circuit = await wasm_tester(path.join(__dirname, "circuits", "mnist_convnet_test.circom"));
|
||||
await circuit.loadConstraints();
|
||||
assert.equal(circuit.nVars, 70524);
|
||||
assert.equal(circuit.constraints.length, 67403);
|
||||
|
||||
let INPUT = {};
|
||||
|
||||
for (const [key, value] of Object.entries(json)) {
|
||||
if (Array.isArray(value)) {
|
||||
let tmpArray = [];
|
||||
for (let i = 0; i < value.flat().length; i++) {
|
||||
tmpArray.push(Fr.e(value.flat()[i]));
|
||||
}
|
||||
INPUT[key] = tmpArray;
|
||||
} else {
|
||||
INPUT[key] = Fr.e(value);
|
||||
}
|
||||
}
|
||||
|
||||
const witness = await circuit.calculateWitness(INPUT, true);
|
||||
|
||||
//console.log(witness[1]);
|
||||
|
||||
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
|
||||
assert(Fr.eq(Fr.e(witness[1]),Fr.e(7)));
|
||||
});
|
||||
});
|
||||