v1.2.0 - Added SumPooling2D layer, added high accuracy (98%+) model for MNIST

This commit is contained in:
Cathie So
2022-06-25 23:49:13 +08:00
parent 52e9350681
commit 8b339b41f8
17 changed files with 932 additions and 108 deletions

39
test/SumPooling2D.js Normal file
View File

@@ -0,0 +1,39 @@
const chai = require("chai");
const { Console } = require("console");
const path = require("path");
const wasm_tester = require("circom_tester").wasm;
const F1Field = require("ffjavascript").F1Field;
const Scalar = require("ffjavascript").Scalar;
exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
const Fr = new F1Field(exports.p);
const assert = chai.assert;
const json = require("../models/sumPooling2D_input.json");
const OUTPUT = require("../models/sumPooling2D_output.json");
describe("SumPooling2D layer test", function () {
this.timeout(100000000);
it("(5,5,3) -> (2,2,3)", async () => {
const circuit = await wasm_tester(path.join(__dirname, "circuits", "SumPooling2D_test.circom"));
await circuit.loadConstraints();
assert.equal(circuit.nVars, 76);
assert.equal(circuit.constraints.length, 0);
const INPUT = {
"in": json.in
}
const witness = await circuit.calculateWitness(INPUT, true);
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<2*2*3; i++) {
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(2));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(2));
}
});
});

View File

@@ -0,0 +1,5 @@
pragma circom 2.0.3;
include "../../circuits/SumPooling2D.circom";
component main = SumPooling2D(5, 5, 3, 2);

View File

@@ -0,0 +1,108 @@
pragma circom 2.0.3;
include "../../circuits/Conv2D.circom";
include "../../circuits/Dense.circom";
include "../../circuits/ArgMax.circom";
include "../../circuits/Poly.circom";
include "../../circuits/SumPooling2D.circom";
template mnist_convnet() {
signal input in[28][28][1];
signal input conv2d_1_weights[3][3][1][4];
signal input conv2d_1_bias[4];
signal input conv2d_2_weights[3][3][4][8];
signal input conv2d_2_bias[8];
signal input dense_weights[200][10];
signal input dense_bias[10];
signal output out;
component conv2d_1 = Conv2D(28,28,1,4,3);
component poly_1[26][26][4];
component sum2d_1 = SumPooling2D(26,26,4,2);
component conv2d_2 = Conv2D(13,13,4,8,3);
component poly_2[11][11][8];
component sum2d_2 = SumPooling2D(11,11,8,2);
component dense = Dense(200,10);
component argmax = ArgMax(10);
for (var i=0; i<28; i++) {
for (var j=0; j<28; j++) {
conv2d_1.in[i][j][0] <== in[i][j][0];
}
}
for (var m=0; m<4; m++) {
for (var i=0; i<3; i++) {
for (var j=0; j<3; j++) {
conv2d_1.weights[i][j][0][m] <== conv2d_1_weights[i][j][0][m];
}
}
conv2d_1.bias[m] <== conv2d_1_bias[m];
}
for (var i=0; i<26; i++) {
for (var j=0; j<26; j++) {
for (var k=0; k<4; k++) {
poly_1[i][j][k] = Poly(10**6);
poly_1[i][j][k].in <== conv2d_1.out[i][j][k];
sum2d_1.in[i][j][k] <== poly_1[i][j][k].out;
}
}
}
for (var i=0; i<13; i++) {
for (var j=0; j<13; j++) {
for (var k=0; k<4; k++) {
conv2d_2.in[i][j][k] <== sum2d_1.out[i][j][k];
}
}
}
for (var m=0; m<8; m++) {
for (var i=0; i<3; i++) {
for (var j=0; j<3; j++) {
for (var k=0; k<4; k++) {
conv2d_2.weights[i][j][k][m] <== conv2d_2_weights[i][j][k][m];
}
}
}
conv2d_2.bias[m] <== conv2d_2_bias[m];
}
for (var i=0; i<11; i++) {
for (var j=0; j<11; j++) {
for (var k=0; k<8; k++) {
poly_2[i][j][k] = Poly(10**15);
poly_2[i][j][k].in <== conv2d_2.out[i][j][k];
sum2d_2.in[i][j][k] <== poly_2[i][j][k].out;
}
}
}
var idx = 0;
for (var i=0; i<5; i++) {
for (var j=0; j<5; j++) {
for (var k=0; k<8; k++) {
dense.in[idx] <== sum2d_2.out[i][j][k];
for (var m=0; m<10; m++) {
dense.weights[idx][m] <== dense_weights[idx][m];
}
idx++;
}
}
}
for (var i=0; i<10; i++) {
dense.bias[i] <== dense_bias[i];
}
for (var i=0; i<10; i++) {
log(dense.out[i]);
argmax.in[i] <== dense.out[i];
}
out <== argmax.out;
}
component main = mnist_convnet();

45
test/mnist_convnet.js Normal file
View File

@@ -0,0 +1,45 @@
const chai = require("chai");
const path = require("path");
const wasm_tester = require("circom_tester").wasm;
const F1Field = require("ffjavascript").F1Field;
const Scalar = require("ffjavascript").Scalar;
exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
const Fr = new F1Field(exports.p);
const assert = chai.assert;
const json = require("../models/mnist_convnet_input.json");
describe("mnist convnet test", function () {
this.timeout(100000000);
it("should return correct output", async () => {
const circuit = await wasm_tester(path.join(__dirname, "circuits", "mnist_convnet_test.circom"));
await circuit.loadConstraints();
assert.equal(circuit.nVars, 70524);
assert.equal(circuit.constraints.length, 67403);
let INPUT = {};
for (const [key, value] of Object.entries(json)) {
if (Array.isArray(value)) {
let tmpArray = [];
for (let i = 0; i < value.flat().length; i++) {
tmpArray.push(Fr.e(value.flat()[i]));
}
INPUT[key] = tmpArray;
} else {
INPUT[key] = Fr.e(value);
}
}
const witness = await circuit.calculateWitness(INPUT, true);
//console.log(witness[1]);
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
assert(Fr.eq(Fr.e(witness[1]),Fr.e(7)));
});
});