mirror of
https://github.com/socathie/circomlib-ml.git
synced 2026-01-10 06:28:08 -05:00
ver 1.0.1 - restructured ReLU, added ArgMax from zk-mnist, mnist test case
This commit is contained in:
40
circuits/ArgMax.circom
Normal file
40
circuits/ArgMax.circom
Normal file
@@ -0,0 +1,40 @@
|
||||
// from 0xZKML/zk-mnist
|
||||
|
||||
pragma circom 2.0.3;
|
||||
|
||||
include "../node_modules/circomlib/circuits/comparators.circom";
|
||||
include "../node_modules/circomlib/circuits/switcher.circom";
|
||||
|
||||
template ArgMax (n) {
|
||||
signal input in[n];
|
||||
signal output out;
|
||||
component gts[n]; // store comparators
|
||||
component switchers[n+1]; // switcher for comparing maxs
|
||||
component aswitchers[n+1]; // switcher for arg max
|
||||
|
||||
signal maxs[n+1];
|
||||
signal amaxs[n+1];
|
||||
|
||||
maxs[0] <== in[0];
|
||||
amaxs[0] <== 0;
|
||||
for(var i = 0; i < n; i++) {
|
||||
gts[i] = GreaterThan(30);
|
||||
switchers[i+1] = Switcher();
|
||||
aswitchers[i+1] = Switcher();
|
||||
|
||||
gts[i].in[1] <== maxs[i];
|
||||
gts[i].in[0] <== in[i];
|
||||
|
||||
switchers[i+1].sel <== gts[i].out;
|
||||
switchers[i+1].L <== maxs[i];
|
||||
switchers[i+1].R <== in[i];
|
||||
|
||||
aswitchers[i+1].sel <== gts[i].out;
|
||||
aswitchers[i+1].L <== amaxs[i];
|
||||
aswitchers[i+1].R <== i;
|
||||
amaxs[i+1] <== aswitchers[i+1].outL;
|
||||
maxs[i+1] <== switchers[i+1].outL;
|
||||
}
|
||||
|
||||
out <== amaxs[n];
|
||||
}
|
||||
@@ -3,19 +3,13 @@ pragma circom 2.0.3;
|
||||
include "util.circom";
|
||||
|
||||
// ReLU layer
|
||||
template ReLU (m,n) {
|
||||
signal input in[m][n];
|
||||
signal output out[m][n];
|
||||
template ReLU () {
|
||||
signal input in;
|
||||
signal output out;
|
||||
|
||||
component isPositive[m][n];
|
||||
component isPositive = IsPositive();
|
||||
|
||||
isPositive.in <== in;
|
||||
|
||||
for (var i=0; i<m; i++) {
|
||||
for (var j=0; j<n; j++) {
|
||||
isPositive[i][j] = IsPositive();
|
||||
|
||||
isPositive[i][j].in <== in[i][j];
|
||||
|
||||
out[i][j] <== in[i][j] * isPositive[i][j].out;
|
||||
}
|
||||
}
|
||||
out <== in * isPositive.out;
|
||||
}
|
||||
356
models/mnist.ipynb
Normal file
356
models/mnist.ipynb
Normal file
@@ -0,0 +1,356 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tensorflow.keras.layers import Input, Conv2D, Dense, ReLU, Flatten\n",
|
||||
"from tensorflow.keras import Model\n",
|
||||
"from tensorflow.keras.datasets import mnist\n",
|
||||
"from tensorflow.keras.utils import to_categorical\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"(X_train, y_train), (X_test, y_test) = mnist.load_data()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Convert y_train into one-hot format\n",
|
||||
"temp = []\n",
|
||||
"for i in range(len(y_train)):\n",
|
||||
" temp.append(to_categorical(y_train[i], num_classes=10))\n",
|
||||
"y_train = np.array(temp)\n",
|
||||
"# Convert y_test into one-hot format\n",
|
||||
"temp = []\n",
|
||||
"for i in range(len(y_test)): \n",
|
||||
" temp.append(to_categorical(y_test[i], num_classes=10))\n",
|
||||
"y_test = np.array(temp)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#reshaping\n",
|
||||
"X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)\n",
|
||||
"X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inputs = Input(shape=(28,28,1))\n",
|
||||
"out = Conv2D(1,3)(inputs)\n",
|
||||
"out = ReLU()(out)\n",
|
||||
"out = Flatten()(out)\n",
|
||||
"out = Dense(10, activation='softmax')(out)\n",
|
||||
"model = Model(inputs, out)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model: \"model\"\n",
|
||||
"_________________________________________________________________\n",
|
||||
"Layer (type) Output Shape Param # \n",
|
||||
"=================================================================\n",
|
||||
"input_1 (InputLayer) [(None, 28, 28, 1)] 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"conv2d (Conv2D) (None, 26, 26, 1) 10 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"re_lu (ReLU) (None, 26, 26, 1) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"flatten (Flatten) (None, 676) 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"dense (Dense) (None, 10) 6770 \n",
|
||||
"=================================================================\n",
|
||||
"Total params: 6,780\n",
|
||||
"Trainable params: 6,780\n",
|
||||
"Non-trainable params: 0\n",
|
||||
"_________________________________________________________________\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.summary()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.compile(\n",
|
||||
" loss='categorical_crossentropy',\n",
|
||||
" optimizer='adam',\n",
|
||||
" metrics=['acc']\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/10\n",
|
||||
"1875/1875 [==============================] - 1s 630us/step - loss: 25.0087 - acc: 0.7775 - val_loss: 0.9836 - val_acc: 0.8439\n",
|
||||
"Epoch 2/10\n",
|
||||
"1875/1875 [==============================] - 1s 516us/step - loss: 0.5950 - acc: 0.8797 - val_loss: 0.2956 - val_acc: 0.9179\n",
|
||||
"Epoch 3/10\n",
|
||||
"1875/1875 [==============================] - 1s 512us/step - loss: 0.3047 - acc: 0.9162 - val_loss: 0.2921 - val_acc: 0.9207\n",
|
||||
"Epoch 4/10\n",
|
||||
"1875/1875 [==============================] - 1s 578us/step - loss: 0.2999 - acc: 0.9158 - val_loss: 0.3009 - val_acc: 0.9103\n",
|
||||
"Epoch 5/10\n",
|
||||
"1875/1875 [==============================] - 1s 533us/step - loss: 0.3126 - acc: 0.9144 - val_loss: 0.2853 - val_acc: 0.9210\n",
|
||||
"Epoch 6/10\n",
|
||||
"1875/1875 [==============================] - 1s 513us/step - loss: 0.3161 - acc: 0.9130 - val_loss: 0.3630 - val_acc: 0.9095\n",
|
||||
"Epoch 7/10\n",
|
||||
"1875/1875 [==============================] - 1s 560us/step - loss: 0.3093 - acc: 0.9162 - val_loss: 0.3176 - val_acc: 0.9085\n",
|
||||
"Epoch 8/10\n",
|
||||
"1875/1875 [==============================] - 1s 570us/step - loss: 0.2946 - acc: 0.9174 - val_loss: 0.2941 - val_acc: 0.9158\n",
|
||||
"Epoch 9/10\n",
|
||||
"1875/1875 [==============================] - 1s 524us/step - loss: 0.2761 - acc: 0.9224 - val_loss: 0.3004 - val_acc: 0.9216\n",
|
||||
"Epoch 10/10\n",
|
||||
"1875/1875 [==============================] - 1s 558us/step - loss: 0.2763 - acc: 0.9222 - val_loss: 0.3071 - val_acc: 0.9220\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<tensorflow.python.keras.callbacks.History at 0x11a8a3730>"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(28, 28, 1)"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"X = X_test[0]\n",
|
||||
"X.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[1.5824087e-11, 6.1573762e-21, 3.0310914e-11, 1.3590869e-04,\n",
|
||||
" 2.2671966e-12, 3.1620637e-10, 1.7203982e-17, 9.9986279e-01,\n",
|
||||
" 1.9324822e-09, 1.3374932e-06]], dtype=float32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"y = model.predict(X_test[[0]])\n",
|
||||
"y"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<matplotlib.image.AxesImage at 0x11ef012b0>"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAANh0lEQVR4nO3df6zddX3H8dfL/sJeYFKwtSuVKqKxOsHlCppuSw3DAYYUo2w0GekSZskGCSxmG2ExkmxxjIiETWdSR2clCFOBQLRzksaNkLHKhZRSKFuRdVh71wvUrUXgtqXv/XG/LJdyz+dezvd7zve07+cjuTnnfN/ne77vfHtf/X7v+XzP+TgiBODY95a2GwDQH4QdSIKwA0kQdiAJwg4kMbufG5vreXGchvq5SSCVV/QLHYhxT1WrFXbb50u6RdIsSX8XETeUnn+chnSOz62zSQAFm2NTx1rXp/G2Z0n6qqQLJC2XtNr28m5fD0Bv1fmb/WxJT0fEMxFxQNKdklY10xaAptUJ+xJJP530eFe17HVsr7U9YnvkoMZrbA5AHXXCPtWbAG+49jYi1kXEcEQMz9G8GpsDUEedsO+StHTS41Ml7a7XDoBeqRP2hyWdYftdtudKulTSfc20BaBpXQ+9RcQh21dJ+idNDL2tj4gnGusMQKNqjbNHxEZJGxvqBUAPcbkskARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IIlaUzbb3ilpv6RXJR2KiOEmmgLQvFphr3w8Ip5v4HUA9BCn8UASdcMekn5o+xHba6d6gu21tkdsjxzUeM3NAehW3dP4FRGx2/ZCSffbfioiHpj8hIhYJ2mdJJ3oBVFzewC6VOvIHhG7q9sxSfdIOruJpgA0r+uw2x6yfcJr9yV9QtK2phoD0Kw6p/GLJN1j+7XX+VZE/KCRrgA0ruuwR8Qzks5ssBcAPcTQG5AEYQeSIOxAEoQdSIKwA0k08UGYFF747Mc61t552dPFdZ8aW1SsHxifU6wvuaNcn7/rxY61w1ueLK6LPDiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjLPP0J/88bc61j499PPyyqfX3PjKcnnnoZc61m557uM1N370+vHYaR1rQzf9UnHd2Zseabqd1nFkB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkHNG/SVpO9II4x+f2bXtN+sVnzulYe/5D5f8zT9pe3sc/f7+L9bkf+p9i/cYP3t2xdt5bXy6u+/2Xji/WPzm/82fl63o5DhTrm8eHivWVxx3setvv+f4Vxfp71z7c9Wu3aXNs0r7YO+UvFEd2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCz7PP0NB3Nxdq9V77xHqr62/esbJj7S9WLCtv+1/K33l/48r3dNHRzMx++XCxPrR1tFg/+YG7ivVfmdv5+/bn7yx/F/+xaNoju+31tsdsb5u0bIHt+23vqG5P6m2bAOqayWn8NySdf8SyayVtiogzJG2qHgMYYNOGPSIekLT3iMWrJG2o7m+QdHGzbQFoWrdv0C2KiFFJqm4Xdnqi7bW2R2yPHNR4l5sDUFfP342PiHURMRwRw3M0r9ebA9BBt2HfY3uxJFW3Y821BKAXug37fZLWVPfXSLq3mXYA9Mq04+y279DEN5efYnuXpC9IukHSt21fLulZSZf0skmUHfrvPR1rQ3d1rknSq9O89tB3X+iio2bs+f2PFesfmFv+9f3S3vd1rC37+2eK6x4qVo9O04Y9IlZ3KB2d30IBJMXlskAShB1IgrADSRB2IAnCDiTBR1zRmtmnLS3Wv3LdV4r1OZ5VrH/nlt/sWDt59KHiuscijuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7GjNU3+0pFj/yLzyVNZPHChPR73gyZfedE/HMo7sQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+zoqfFPfqRj7dHP3DzN2uUZhP7g6quL9bf+64+nef1cOLIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMs6Onnr2g8/HkeJfH0Vf/53nF+vwfPFasR7Gaz7RHdtvrbY/Z3jZp2fW2f2Z7S/VzYW/bBFDXTE7jvyHp/CmW3xwRZ1U/G5ttC0DTpg17RDwgaW8fegHQQ3XeoLvK9tbqNP+kTk+yvdb2iO2RgxqvsTkAdXQb9q9JOl3SWZJGJd3U6YkRsS4ihiNieM40H2wA0DtdhT0i9kTEqxFxWNLXJZ3dbFsAmtZV2G0vnvTwU5K2dXougMEw7Ti77TskrZR0iu1dkr4gaaXtszQxlLlT0hW9axGD7C0nnFCsX/brD3as7Tv8SnHdsS++u1ifN/5wsY7XmzbsEbF6isW39qAXAD3E5bJAEoQdSIKwA0kQdiAJwg4kwUdcUcuO6z9QrH/vlL/tWFu149PFdedtZGitSRzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtlR9L+/+9Fifevv/HWx/pNDBzvWXvyrU4vrztNosY43hyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOHtys5f8crF+zef/oVif5/Kv0KWPXdax9vZ/5PPq/cSRHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJz9GOfZ5X/iM7+3q1i/5PgXivXb9y8s1hd9vvPx5HBxTTRt2iO77aW2f2R7u+0nbF9dLV9g+37bO6rbk3rfLoBuzeQ0/pCkz0XE+yV9VNKVtpdLulbSpog4Q9Km6jGAATVt2CNiNCIere7vl7Rd0hJJqyRtqJ62QdLFPeoRQAPe1Bt0tpdJ+rCkzZIWRcSoNPEfgqQp/3izvdb2iO2Rgxqv2S6Abs047LaPl3SXpGsiYt9M14uIdRExHBHDczSvmx4BNGBGYbc9RxNBvz0i7q4W77G9uKovljTWmxYBNGHaoTfblnSrpO0R8eVJpfskrZF0Q3V7b086RD1nvq9Y/vOFt9V6+a9+8ZJi/W2PPVTr9dGcmYyzr5B0maTHbW+pll2niZB/2/blkp6VVP5XB9CqacMeEQ9Kcofyuc22A6BXuFwWSIKwA0kQdiAJwg4kQdiBJPiI6zFg1vL3dqytvbPe5Q/L119ZrC+77d9qvT76hyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOPsx4Kk/7PzFvhfNn/GXCk3p1H8+UH5CRK3XR/9wZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnPwq8ctHZxfqmi24qVOc32wyOWhzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJmczPvlTSNyW9Q9JhSesi4hbb10v6rKTnqqdeFxEbe9VoZrtXzCrW3zm7+7H02/cvLNbn7Ct/np1Psx89ZnJRzSFJn4uIR22fIOkR2/dXtZsj4ku9aw9AU2YyP/uopNHq/n7b2yUt6XVjAJr1pv5mt71M0oclba4WXWV7q+31tqf8biTba22P2B45qPF63QLo2ozDbvt4SXdJuiYi9kn6mqTTJZ2liSP/lBdoR8S6iBiOiOE5mle/YwBdmVHYbc/RRNBvj4i7JSki9kTEqxFxWNLXJZU/rQGgVdOG3bYl3Sppe0R8edLyxZOe9ilJ25pvD0BTZvJu/ApJl0l63PaWatl1klbbPksToy87JV3Rg/5Q01++sLxYf+i3lhXrMfp4g92gTTN5N/5BSZ6ixJg6cBThCjogCcIOJEHYgSQIO5AEYQeSIOxAEo4+Trl7ohfEOT63b9sDstkcm7Qv9k41VM6RHciCsANJEHYgCcIOJEHYgSQIO5AEYQeS6Os4u+3nJP3XpEWnSHq+bw28OYPa26D2JdFbt5rs7bSIePtUhb6G/Q0bt0ciYri1BgoGtbdB7Uuit271qzdO44EkCDuQRNthX9fy9ksGtbdB7Uuit271pbdW/2YH0D9tH9kB9AlhB5JoJey2z7f977aftn1tGz10Ynun7cdtb7E90nIv622P2d42adkC2/fb3lHdTjnHXku9XW/7Z9W+22L7wpZ6W2r7R7a3237C9tXV8lb3XaGvvuy3vv/NbnuWpP+QdJ6kXZIelrQ6Ip7sayMd2N4paTgiWr8Aw/ZvSHpR0jcj4oPVshsl7Y2IG6r/KE+KiD8dkN6ul/Ri29N4V7MVLZ48zbikiyX9nlrcd4W+flt92G9tHNnPlvR0RDwTEQck3SlpVQt9DLyIeEDS3iMWr5K0obq/QRO/LH3XobeBEBGjEfFodX+/pNemGW913xX66os2wr5E0k8nPd6lwZrvPST90PYjtte23cwUFkXEqDTxyyNpYcv9HGnaabz76Yhpxgdm33Uz/XldbYR9qu/HGqTxvxUR8auSLpB0ZXW6ipmZ0TTe/TLFNOMDodvpz+tqI+y7JC2d9PhUSbtb6GNKEbG7uh2TdI8GbyrqPa/NoFvdjrXcz/8bpGm8p5pmXAOw79qc/ryNsD8s6Qzb77I9V9Klku5roY83sD1UvXEi20OSPqHBm4r6PklrqvtrJN3bYi+vMyjTeHeaZlwt77vWpz+PiL7/SLpQE+/I/0TSn7XRQ4e+3i3psernibZ7k3SHJk7rDmrijOhySSdL2iRpR3W7YIB6u03S45K2aiJYi1vq7dc08afhVklbqp8L2953hb76st+4XBZIgivogCQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJ/wNGNvRI2D7VDgAAAABJRU5ErkJggg==",
|
||||
"text/plain": [
|
||||
"<Figure size 432x288 with 1 Axes>"
|
||||
]
|
||||
},
|
||||
"metadata": {
|
||||
"needs_background": "light"
|
||||
},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"plt.imshow(X)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 1) dtype=float32, numpy=\n",
|
||||
" array([[[[-0.01493068]],\n",
|
||||
" \n",
|
||||
" [[ 0.00165418]],\n",
|
||||
" \n",
|
||||
" [[ 0.01054219]]],\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" [[[ 0.01664183]],\n",
|
||||
" \n",
|
||||
" [[ 0.01126822]],\n",
|
||||
" \n",
|
||||
" [[ 0.00358304]]],\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" [[[-0.00954879]],\n",
|
||||
" \n",
|
||||
" [[-0.00813981]],\n",
|
||||
" \n",
|
||||
" [[ 0.0057789 ]]]], dtype=float32)>,\n",
|
||||
" <tf.Variable 'conv2d/bias:0' shape=(1,) dtype=float32, numpy=array([0.02201873], dtype=float32)>,\n",
|
||||
" <tf.Variable 'dense/kernel:0' shape=(676, 10) dtype=float32, numpy=\n",
|
||||
" array([[-0.0390173 , 0.44924182, -0.27243868, ..., 0.31146002,\n",
|
||||
" -0.28309578, -0.11028677],\n",
|
||||
" [-0.01179668, 0.55754834, -0.2430928 , ..., 0.10775765,\n",
|
||||
" -0.38104013, -0.06206248],\n",
|
||||
" [-0.03452352, 0.46869493, -0.24846792, ..., 0.2193103 ,\n",
|
||||
" -0.37002295, -0.06395224],\n",
|
||||
" ...,\n",
|
||||
" [ 0.01729827, 0.4342358 , -0.11721515, ..., 0.16974102,\n",
|
||||
" -0.10287298, -0.01753861],\n",
|
||||
" [ 0.04901088, 0.47999075, -0.00402304, ..., 0.255874 ,\n",
|
||||
" -0.30667993, -0.06307992],\n",
|
||||
" [-0.01235796, 0.54158044, -0.165757 , ..., 0.25959158,\n",
|
||||
" -0.46181145, -0.01021514]], dtype=float32)>,\n",
|
||||
" <tf.Variable 'dense/bias:0' shape=(10,) dtype=float32, numpy=\n",
|
||||
" array([-0.14666335, 0.6149221 , -0.08960506, -0.33668435, 0.22668077,\n",
|
||||
" 0.47629577, 0.00557043, 0.3598895 , -0.6543999 , -0.10022835],\n",
|
||||
" dtype=float32)>]"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.weights"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"in_json = {\n",
|
||||
" \"in\": X.astype(int).flatten().tolist(),\n",
|
||||
" \"conv2d_weights\": (model.weights[0].numpy()*1000).round().astype(int).flatten().tolist(),\n",
|
||||
" \"conv2d_bias\": (model.weights[1].numpy()*1000).round().astype(int).flatten().tolist(),\n",
|
||||
" \"dense_weights\":(model.weights[2].numpy()*1000).round().astype(int).flatten().tolist(),\n",
|
||||
" \"dense_bias\":(model.weights[3].numpy()*1000000).round().astype(int).flatten().tolist(),\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(\"mnist_input.json\", \"w\") as f:\n",
|
||||
" json.dump(in_json, f)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"interpreter": {
|
||||
"hash": "11280bdb37aa6bc5d4cf1e4de756386eb1f9eecd8dcdefa77636dfac7be2370d"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.8.6 ('tf24')",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.6"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
1
models/mnist_input.json
Normal file
1
models/mnist_input.json
Normal file
File diff suppressed because one or more lines are too long
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "circomlib-ml",
|
||||
"version": "1.0.0",
|
||||
"version": "1.0.1",
|
||||
"description": "Circuits library for machine learning in circom",
|
||||
"main": "index.js",
|
||||
"directories": {
|
||||
|
||||
@@ -2,4 +2,18 @@ pragma circom 2.0.3;
|
||||
|
||||
include "../../circuits/ReLU.circom";
|
||||
|
||||
component main = ReLU(1,3);
|
||||
template relu_test() {
|
||||
signal input in[3];
|
||||
signal output out[3];
|
||||
|
||||
component relu[3];
|
||||
|
||||
for (var i=0; i<3; i++) {
|
||||
relu[i] = ReLU();
|
||||
|
||||
relu[i].in <== in[i];
|
||||
out[i] <== relu[i].out;
|
||||
}
|
||||
}
|
||||
|
||||
component main = relu_test();
|
||||
60
test/circuits/mnist_test.circom
Normal file
60
test/circuits/mnist_test.circom
Normal file
@@ -0,0 +1,60 @@
|
||||
pragma circom 2.0.3;
|
||||
|
||||
include "../../circuits/Conv2D.circom";
|
||||
include "../../circuits/Dense.circom";
|
||||
include "../../circuits/ArgMax.circom";
|
||||
include "../../circuits/ReLU.circom";
|
||||
|
||||
template mnist() {
|
||||
signal input in[28][28][1];
|
||||
signal input conv2d_weights[3][3][1][1];
|
||||
signal input conv2d_bias[1];
|
||||
signal input dense_weights[676][10];
|
||||
signal input dense_bias[10];
|
||||
signal output out;
|
||||
|
||||
component conv2d = Conv2D(28,28,1,1,3);
|
||||
component relu[26*26];
|
||||
component dense = Dense(676,10);
|
||||
component argmax = ArgMax(10);
|
||||
|
||||
for (var i=0; i<28; i++) {
|
||||
for (var j=0; j<28; j++) {
|
||||
conv2d.in[i][j][0] <== in[i][j][0];
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<3; i++) {
|
||||
for (var j=0; j<3; j++) {
|
||||
conv2d.weights[i][j][0][0] <== conv2d_weights[i][j][0][0];
|
||||
}
|
||||
}
|
||||
|
||||
conv2d.bias[0] <== conv2d_bias[0];
|
||||
|
||||
var idx = 0;
|
||||
|
||||
for (var i=0; i<26; i++) {
|
||||
for (var j=0; j<26; j++) {
|
||||
relu[idx] = ReLU();
|
||||
relu[idx].in <== conv2d.out[i][j][0];
|
||||
dense.in[idx] <== relu[idx].out;
|
||||
for (var k=0; k<10; k++) {
|
||||
dense.weights[idx][k] <== dense_weights[idx][k];
|
||||
}
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<10; i++) {
|
||||
dense.bias[i] <== dense_bias[i];
|
||||
}
|
||||
|
||||
for (var i=0; i<10; i++) {
|
||||
argmax.in[i] <== dense.out[i];
|
||||
}
|
||||
|
||||
out <== argmax.out;
|
||||
}
|
||||
|
||||
component main = mnist();
|
||||
@@ -10,7 +10,7 @@ template model1() {
|
||||
signal output out;
|
||||
|
||||
component Dense32 = Dense(3,2);
|
||||
component relu = ReLU(2,1);
|
||||
component relu[2];
|
||||
component Dense21 = Dense(2,1);
|
||||
|
||||
for (var i=0; i<3; i++) {
|
||||
@@ -25,11 +25,12 @@ template model1() {
|
||||
}
|
||||
|
||||
for (var i=0; i<2; i++) {
|
||||
relu.in[i][0] <== Dense32.out[i];
|
||||
relu[i] = ReLU();
|
||||
relu[i].in <== Dense32.out[i];
|
||||
}
|
||||
|
||||
for (var i=0; i<2; i++) {
|
||||
Dense21.in[i] <== relu.out[i][0];
|
||||
Dense21.in[i] <== relu[i].out;
|
||||
Dense21.weights[i][0] <== Dense21weights[i][0];
|
||||
}
|
||||
|
||||
|
||||
60
test/mnist.js
Normal file
60
test/mnist.js
Normal file
@@ -0,0 +1,60 @@
|
||||
const chai = require("chai");
|
||||
const path = require("path");
|
||||
|
||||
const wasm_tester = require("circom_tester").wasm;
|
||||
|
||||
const F1Field = require("ffjavascript").F1Field;
|
||||
const Scalar = require("ffjavascript").Scalar;
|
||||
exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
|
||||
const Fr = new F1Field(exports.p);
|
||||
|
||||
const assert = chai.assert;
|
||||
|
||||
const json = require("../models/mnist_input.json");
|
||||
|
||||
describe("mnist test", function () {
|
||||
this.timeout(100000000);
|
||||
|
||||
it("should return correct output", async () => {
|
||||
const circuit = await wasm_tester(path.join(__dirname, "circuits", "mnist_test.circom"));
|
||||
await circuit.loadConstraints();
|
||||
assert.equal(circuit.nVars, 368866);
|
||||
assert.equal(circuit.constraints.length, 362663);
|
||||
|
||||
const conv2d_weights = [];
|
||||
const conv2d_bias = [];
|
||||
const dense_weights = [];
|
||||
const dense_bias = [];
|
||||
|
||||
for (var i=0; i<json.conv2d_weights.length; i++) {
|
||||
conv2d_weights.push(Fr.e(json.conv2d_weights[i]));
|
||||
}
|
||||
|
||||
for (var i=0; i<json.conv2d_bias.length; i++) {
|
||||
conv2d_bias.push(Fr.e(json.conv2d_bias[i]));
|
||||
}
|
||||
|
||||
for (var i=0; i<json.dense_weights.length; i++) {
|
||||
dense_weights.push(Fr.e(json.dense_weights[i]));
|
||||
}
|
||||
|
||||
for (var i=0; i<json.dense_bias.length; i++) {
|
||||
dense_bias.push(Fr.e(json.dense_bias[i]));
|
||||
}
|
||||
|
||||
const INPUT = {
|
||||
"in": json.in,
|
||||
"conv2d_weights": conv2d_weights,
|
||||
"conv2d_bias": conv2d_bias,
|
||||
"dense_weights": dense_weights,
|
||||
"dense_bias": dense_bias
|
||||
}
|
||||
|
||||
const witness = await circuit.calculateWitness(INPUT, true);
|
||||
|
||||
//console.log(witness[1]);
|
||||
|
||||
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
|
||||
assert(Fr.eq(Fr.e(witness[1]),Fr.e(7)));
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user