diff --git a/circuits/Conv1D.circom b/circuits/Conv1D.circom new file mode 100644 index 0000000..b1685b7 --- /dev/null +++ b/circuits/Conv1D.circom @@ -0,0 +1,40 @@ +pragma circom 2.0.3; + +include "./circomlib-matrix/matElemMul.circom"; +include "./circomlib-matrix/matElemSum.circom"; +include "./util.circom"; + +// Conv1D layer with valid padding +template Conv1D (nInputs, nChannels, nFilters, kernelSize, strides) { + signal input in[nInputs][nChannels]; + signal input weights[kernelSize][nChannels][nFilters]; + signal input bias[nFilters]; + signal output out[(nInputs-kernelSize)\strides+1][nFilters]; + + component mul[(nInputs-kernelSize)\strides+1][nChannels][nFilters]; + component elemSum[(nInputs-kernelSize)\strides+1][nChannels][nFilters]; + component sum[(nInputs-kernelSize)\strides+1][nFilters]; + + for (var i=0; i<(nInputs-kernelSize)\strides+1; i++) { + for (var j=0; j,\n", + " ]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.weights" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[0.84452152, 0.40819381, 0.90425158],\n", + " [0.65092274, 0.59893334, 0.60028247],\n", + " [0.44949689, 0.254961 , 0.40024966],\n", + " [0.57719296, 0.96000249, 0.08217882],\n", + " [0.66562102, 0.00446413, 0.6464117 ],\n", + " [0.85546508, 0.38582714, 0.8505983 ],\n", + " [0.27450673, 0.201367 , 0.18818527],\n", + " [0.90095107, 0.33781381, 0.84773899],\n", + " [0.36179829, 0.39354172, 0.64100907],\n", + " [0.35045615, 0.37234526, 0.48415795],\n", + " [0.11139122, 0.60499841, 0.58387442],\n", + " [0.87088195, 0.64640523, 0.48402816],\n", + " [0.9638806 , 0.29440207, 0.22027009],\n", + " [0.3764113 , 0.13575624, 0.7720717 ],\n", + " [0.42784105, 0.41073501, 0.28311926],\n", + " [0.66622245, 0.95378854, 0.66375939],\n", + " [0.41543282, 0.37529526, 0.23072244],\n", + " [0.54334445, 0.1388458 , 0.85307472],\n", + " [0.25258015, 0.37725648, 0.75018739],\n", + " [0.75532678, 0.52800654, 0.46152891]]])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X = np.random.rand(1,20,3)\n", + "X" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[ 0.7218593 , -0.4875301 ],\n", + " [-0.06567734, -0.48088893],\n", + " [ 0.21620643, -0.45382428],\n", + " [ 0.1553162 , -0.5966817 ],\n", + " [ 0.23019192, -0.01706157],\n", + " [-0.05194061, -0.49539083]]], dtype=float32)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y = model.predict(X)\n", + "y" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "in_json = {\n", + " \"in\": (X*1000).round().astype(int).flatten().tolist(),\n", + " \"weights\": (model.weights[0].numpy()*1000).round().astype(int).flatten().tolist(),\n", + " \"bias\": np.zeros(model.weights[1].numpy().shape).tolist()\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "out_json = {\n", + " \"out\": (y*1000000).round().astype(int).flatten().tolist()\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import json" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"conv1D_input.json\", \"w\") as f:\n", + " json.dump(in_json, f)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"conv1D_output.json\", \"w\") as f:\n", + " json.dump(out_json, f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tf24", + "language": "python", + "name": "tf24" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.6" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/models/conv2D_stride_input.json b/models/conv2D_stride_input.json new file mode 100644 index 0000000..9e27c25 --- /dev/null +++ b/models/conv2D_stride_input.json @@ -0,0 +1 @@ +{"in": [973, 455, 27, 779, 573, 643, 791, 682, 559, 200, 164, 118, 262, 737, 415, 916, 189, 868, 905, 880, 569, 158, 662, 768, 912, 183, 835, 547, 665, 882, 760, 511, 931, 911, 564, 152, 37, 373, 105, 447, 543, 840, 795, 91, 357, 518, 561, 588, 847, 638, 660, 925, 632, 747, 731, 405, 980, 638, 895, 327, 718, 461, 718, 158, 810, 519, 161, 782, 715, 862, 535, 393, 409, 14, 791, 405, 854, 711, 403, 406, 349, 327, 459, 349, 534, 930, 727, 482, 222, 238, 704, 916, 922, 134, 835, 747, 551, 935, 413, 462, 117, 963, 746, 181, 927, 679, 843, 490, 562, 944, 920, 234, 287, 500, 877, 33, 594, 141, 379, 934, 333, 767, 35, 121, 69, 326, 597, 595, 10, 951, 478, 18, 797, 203, 334, 619, 975, 733, 203, 223, 546, 968, 964, 252, 399, 236, 546, 318, 551, 708, 235, 173, 473, 639, 872, 414, 431, 180, 288, 79, 425, 298, 779, 265, 194, 890, 177, 936, 955, 714, 22, 775, 211, 717, 211, 682, 993, 441, 351, 184, 727, 344, 453, 700, 742, 475, 907, 793, 925, 354, 277, 626, 803, 380, 886, 706, 735, 646, 762, 341, 252, 414, 812, 329, 819, 432, 642, 257, 757, 562, 421, 332, 893, 254, 75, 167, 421, 207, 818, 805, 627, 877, 665, 805, 20, 495, 286, 590, 743, 648, 284, 138, 29, 394, 727, 503, 112, 20, 542, 601, 538, 918, 390, 37, 925, 913, 347, 4, 175, 165, 398, 430, 523, 173, 81, 614, 508, 262, 361, 619, 4, 12, 702, 301, 563, 847, 882, 787, 70, 65, 314, 29, 109, 828, 132, 570, 340, 942, 938, 283, 529, 359, 948, 358, 255, 801, 951, 682, 23, 986, 691, 544, 206, 910, 985, 620, 898, 872, 771, 845], "weights": [90, 64, 172, 119, -231, 77, 56, 87, 109, -132, 72, -49, -144, -204, -186, -236, -135, -200, -38, 245, 96, 129, 139, -61, -154, -157, 175, 234, -94, 137, 139, -52, 123, 31, -228, 164, 141, 0, 199, 256, 174, -155, 212, 134, 239, 158, 12, -174, -134, -88, -182, 49, 191, -183, 244, 117, -15, 149, 1, 143, -60, -36, -44, -107, 178, 221, -140, -153, 37, -244, -257, -119, 92, -150, 97, -77, -102, -13, 197, -218, 196, 54, -239, 138, 153, 237, 252, 272, 114, -62, -14, 253, 161, 236, -116, -99], "bias": [0, 0]} \ No newline at end of file diff --git a/models/conv2D_stride_output.json b/models/conv2D_stride_output.json new file mode 100644 index 0000000..20fa7df --- /dev/null +++ b/models/conv2D_stride_output.json @@ -0,0 +1 @@ +{"out": [604066, 49637, 825491, 594827, 930285, 228053, 1282346, 594940, 897624, 379819, 1158895, 446232, 593408, 64606, 1260976, 251827, 504049, 658720]} \ No newline at end of file diff --git a/models/conv2d.ipynb b/models/conv2d.ipynb index 7c77c2c..d8850b1 100644 --- a/models/conv2d.ipynb +++ b/models/conv2d.ipynb @@ -241,6 +241,336 @@ " json.dump(out_json, f)" ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "inputs = Input(shape=(10,10,3))\n", + "x = Conv2D(2, 4, 3)(inputs)\n", + "model = Model(inputs, x)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model\"\n", + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "input_1 (InputLayer) [(None, 10, 10, 3)] 0 \n", + "_________________________________________________________________\n", + "conv2d (Conv2D) (None, 3, 3, 2) 98 \n", + "=================================================================\n", + "Total params: 98\n", + "Trainable params: 98\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[,\n", + " ]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.weights" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[[0.97251485, 0.45513694, 0.02700021],\n", + " [0.77939851, 0.57285348, 0.64336143],\n", + " [0.7910932 , 0.68158499, 0.55903757],\n", + " [0.19978621, 0.16401386, 0.1178243 ],\n", + " [0.26202894, 0.73739062, 0.41515085],\n", + " [0.91602252, 0.1889969 , 0.86812371],\n", + " [0.90520645, 0.88021945, 0.56918999],\n", + " [0.1576814 , 0.66244825, 0.76843253],\n", + " [0.91151503, 0.18348028, 0.83457797],\n", + " [0.54744986, 0.6653615 , 0.88219377]],\n", + "\n", + " [[0.75996691, 0.51053919, 0.93055488],\n", + " [0.91144281, 0.56385737, 0.15222672],\n", + " [0.03655235, 0.37340661, 0.10463077],\n", + " [0.44692914, 0.54256831, 0.84009866],\n", + " [0.79509769, 0.09059503, 0.35705721],\n", + " [0.51836358, 0.56141664, 0.58831904],\n", + " [0.84739405, 0.63802925, 0.66017923],\n", + " [0.92534248, 0.63225917, 0.7473312 ],\n", + " [0.73096606, 0.40520727, 0.97955684],\n", + " [0.63831414, 0.89495898, 0.32695978]],\n", + "\n", + " [[0.71829301, 0.46062667, 0.71811823],\n", + " [0.15848069, 0.80969418, 0.51871761],\n", + " [0.16135808, 0.78216571, 0.71516724],\n", + " [0.86211842, 0.53539272, 0.39332206],\n", + " [0.40901593, 0.01396001, 0.79064577],\n", + " [0.40514149, 0.85380516, 0.71092302],\n", + " [0.40279167, 0.40553908, 0.34943154],\n", + " [0.32739584, 0.45935947, 0.34916069],\n", + " [0.53435728, 0.93021973, 0.72660915],\n", + " [0.48201736, 0.22166786, 0.23815264]],\n", + "\n", + " [[0.70384743, 0.91639052, 0.92217664],\n", + " [0.13377693, 0.83509932, 0.7465723 ],\n", + " [0.55129428, 0.93502285, 0.4134988 ],\n", + " [0.46219387, 0.11715096, 0.96329141],\n", + " [0.74610843, 0.18125007, 0.92666074],\n", + " [0.67935838, 0.84330332, 0.4897472 ],\n", + " [0.56189416, 0.94376182, 0.91971249],\n", + " [0.2344166 , 0.28675204, 0.50005015],\n", + " [0.87694834, 0.03292064, 0.59372731],\n", + " [0.1410371 , 0.37930507, 0.93440982]],\n", + "\n", + " [[0.33332779, 0.76704737, 0.03504045],\n", + " [0.12148567, 0.06901141, 0.32551204],\n", + " [0.59735274, 0.59523581, 0.00993422],\n", + " [0.9513948 , 0.47815931, 0.018291 ],\n", + " [0.79727936, 0.20293482, 0.33432458],\n", + " [0.61879149, 0.97450524, 0.73288953],\n", + " [0.20291612, 0.2230533 , 0.54581403],\n", + " [0.96841368, 0.96407929, 0.25184421],\n", + " [0.39925087, 0.23617574, 0.54586969],\n", + " [0.31766469, 0.55082408, 0.70822638]],\n", + "\n", + " [[0.23476765, 0.17316881, 0.47290996],\n", + " [0.63928832, 0.87168719, 0.41376668],\n", + " [0.43124419, 0.18044566, 0.28764125],\n", + " [0.07918316, 0.42466569, 0.29817 ],\n", + " [0.77903568, 0.26465035, 0.19378377],\n", + " [0.88960928, 0.17741272, 0.93643759],\n", + " [0.95486371, 0.71402737, 0.02170905],\n", + " [0.77465114, 0.21071205, 0.71740116],\n", + " [0.2105675 , 0.68209689, 0.99265747],\n", + " [0.44123258, 0.35068216, 0.18443967]],\n", + "\n", + " [[0.72729028, 0.34422837, 0.45330759],\n", + " [0.70008861, 0.74244728, 0.47519843],\n", + " [0.90665756, 0.79266484, 0.92471306],\n", + " [0.35356468, 0.27658691, 0.62612034],\n", + " [0.80262629, 0.38014853, 0.88572335],\n", + " [0.70577279, 0.73463977, 0.64607726],\n", + " [0.76175153, 0.34114261, 0.25165355],\n", + " [0.41433933, 0.81211749, 0.32920431],\n", + " [0.81885149, 0.43244061, 0.64193369],\n", + " [0.25711737, 0.75723915, 0.56240932]],\n", + "\n", + " [[0.4214149 , 0.33210466, 0.89250414],\n", + " [0.25449979, 0.07459641, 0.16747531],\n", + " [0.42138569, 0.20694074, 0.81756281],\n", + " [0.80522426, 0.62663345, 0.87747416],\n", + " [0.66534828, 0.80525886, 0.01988047],\n", + " [0.4948794 , 0.28630048, 0.58954056],\n", + " [0.74320486, 0.64790933, 0.28396301],\n", + " [0.13750094, 0.02872952, 0.39362795],\n", + " [0.72667091, 0.50329443, 0.11188484],\n", + " [0.01977563, 0.54219458, 0.60052706]],\n", + "\n", + " [[0.53801627, 0.91787638, 0.38982677],\n", + " [0.03749426, 0.92513169, 0.91275971],\n", + " [0.34702074, 0.00438524, 0.17541972],\n", + " [0.16459438, 0.39827819, 0.43025267],\n", + " [0.52254362, 0.17287154, 0.08078938],\n", + " [0.61399286, 0.50806298, 0.26215701],\n", + " [0.36095081, 0.61864805, 0.00384717],\n", + " [0.01225557, 0.70154596, 0.30105766],\n", + " [0.56348969, 0.84660337, 0.88242305],\n", + " [0.78735834, 0.06987492, 0.06523453]],\n", + "\n", + " [[0.31370938, 0.02861093, 0.10943704],\n", + " [0.82845993, 0.13236614, 0.56989642],\n", + " [0.33997596, 0.94155797, 0.93786532],\n", + " [0.28295679, 0.52881401, 0.35867141],\n", + " [0.94786737, 0.35825851, 0.25476474],\n", + " [0.80097957, 0.95050832, 0.68227972],\n", + " [0.02266812, 0.98557197, 0.69057788],\n", + " [0.54403132, 0.20586683, 0.90978653],\n", + " [0.9848536 , 0.62006799, 0.89784117],\n", + " [0.87154093, 0.77054866, 0.84496778]]]])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X = np.random.rand(1,10,10,3)\n", + "X" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[[0.60406584, 0.04963745],\n", + " [0.8254909 , 0.59482723],\n", + " [0.93028533, 0.22805347]],\n", + "\n", + " [[1.2823461 , 0.5949404 ],\n", + " [0.897624 , 0.3798192 ],\n", + " [1.1588949 , 0.44623226]],\n", + "\n", + " [[0.5934084 , 0.06460603],\n", + " [1.2609756 , 0.25182724],\n", + " [0.5040486 , 0.6587203 ]]]], dtype=float32)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y = model.predict(X)\n", + "y" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "in_json = {\n", + " \"in\": (X*1000).round().astype(int).flatten().tolist(),\n", + " \"weights\": (model.weights[0].numpy()*1000).round().astype(int).flatten().tolist(),\n", + " \"bias\": [0,0]\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "out_json = {\n", + " \"out\": (y*1000000).round().astype(int).flatten().tolist()\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"conv2D_stride_input.json\", \"w\") as f:\n", + " json.dump(in_json, f)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"conv2D_stride_output.json\", \"w\") as f:\n", + " json.dump(out_json, f)" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/package-lock.json b/package-lock.json index 572b777..dc64ca3 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "circomlib-ml", - "version": "1.2.1", + "version": "1.3.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "circomlib-ml", - "version": "1.2.1", + "version": "1.3.0", "license": "GPL-3.0", "devDependencies": { "blake-hash": "^2.0.0", diff --git a/package.json b/package.json index 41e179b..2b3f33b 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "circomlib-ml", - "version": "1.2.1", + "version": "1.3.0", "description": "Circuits library for machine learning in circom", "main": "index.js", "directories": { diff --git a/test/Conv1D.js b/test/Conv1D.js new file mode 100644 index 0000000..4a70b5a --- /dev/null +++ b/test/Conv1D.js @@ -0,0 +1,49 @@ +const chai = require("chai"); +const { Console } = require("console"); +const path = require("path"); + +const wasm_tester = require("circom_tester").wasm; + +const F1Field = require("ffjavascript").F1Field; +const Scalar = require("ffjavascript").Scalar; +exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617"); +const Fr = new F1Field(exports.p); + +const assert = chai.assert; + +const json = require("../models/conv1D_input.json"); +const OUTPUT = require("../models/conv1D_output.json"); + +describe("Conv1D layer test", function () { + this.timeout(100000000); + + it("(20,3) -> (6,2)", async () => { + const circuit = await wasm_tester(path.join(__dirname, "circuits", "Conv1D_test.circom")); + //await circuit.loadConstraints(); + //assert.equal(circuit.nVars, 618); + //assert.equal(circuit.constraints.length, 486); + + let INPUT = {}; + + for (const [key, value] of Object.entries(json)) { + if (Array.isArray(value)) { + let tmpArray = []; + for (let i = 0; i < value.flat().length; i++) { + tmpArray.push(Fr.e(value.flat()[i])); + } + INPUT[key] = tmpArray; + } else { + INPUT[key] = Fr.e(value); + } + } + + const witness = await circuit.calculateWitness(INPUT, true); + + assert(Fr.eq(Fr.e(witness[0]),Fr.e(1))); + + for (var i=0; i<6*2; i++) { + assert((witness[i+1]-Fr.e(OUTPUT.out[i])) (3,3,2)", async () => { + let json = require("../models/conv2D_input.json"); + let OUTPUT = require("../models/conv2D_output.json"); + const circuit = await wasm_tester(path.join(__dirname, "circuits", "Conv2D_test.circom")); //await circuit.loadConstraints(); //assert.equal(circuit.nVars, 618); @@ -44,4 +46,35 @@ describe("Conv2D layer test", function () { assert((Fr.e(OUTPUT.out[i])-witness[i+1]) (3,3,2)", async () => { + let json = require("../models/conv2D_stride_input.json"); + let OUTPUT = require("../models/conv2D_stride_output.json"); + + const circuit = await wasm_tester(path.join(__dirname, "circuits", "Conv2D_stride_test.circom")); + //await circuit.loadConstraints(); + //assert.equal(circuit.nVars, 618); + //assert.equal(circuit.constraints.length, 486); + + const weights = []; + + for (var i=0; i