Ver 1.1.0 - Added Polynomial activation layer and test case, updated ArgMax to support up to 254 bits, updated README

This commit is contained in:
Cathie So
2022-06-03 00:00:16 +08:00
parent 8eeb665d27
commit 52e9350681
12 changed files with 662 additions and 74 deletions

View File

@@ -0,0 +1,61 @@
pragma circom 2.0.3;
include "../../circuits/Conv2D.circom";
include "../../circuits/Dense.circom";
include "../../circuits/ArgMax.circom";
include "../../circuits/Poly.circom";
template mnist_poly() {
signal input in[28][28][1];
signal input conv2d_weights[3][3][1][1];
signal input conv2d_bias[1];
signal input dense_weights[676][10];
signal input dense_bias[10];
signal output out;
component conv2d = Conv2D(28,28,1,1,3);
component poly[26*26];
component dense = Dense(676,10);
component argmax = ArgMax(10);
for (var i=0; i<28; i++) {
for (var j=0; j<28; j++) {
conv2d.in[i][j][0] <== in[i][j][0];
}
}
for (var i=0; i<3; i++) {
for (var j=0; j<3; j++) {
conv2d.weights[i][j][0][0] <== conv2d_weights[i][j][0][0];
}
}
conv2d.bias[0] <== conv2d_bias[0];
var idx = 0;
for (var i=0; i<26; i++) {
for (var j=0; j<26; j++) {
poly[idx] = Poly(10**18);
poly[idx].in <== conv2d.out[i][j][0];
dense.in[idx] <== poly[idx].out;
for (var k=0; k<10; k++) {
dense.weights[idx][k] <== dense_weights[idx][k];
}
idx++;
}
}
for (var i=0; i<10; i++) {
dense.bias[i] <== dense_bias[i];
}
for (var i=0; i<10; i++) {
log(dense.out[i]);
argmax.in[i] <== dense.out[i];
}
out <== argmax.out;
}
component main = mnist_poly();

View File

@@ -51,6 +51,7 @@ template mnist() {
}
for (var i=0; i<10; i++) {
log(dense.out[i]);
argmax.in[i] <== dense.out[i];
}

View File

@@ -18,8 +18,8 @@ describe("mnist test", function () {
it("should return correct output", async () => {
const circuit = await wasm_tester(path.join(__dirname, "circuits", "mnist_test.circom"));
await circuit.loadConstraints();
assert.equal(circuit.nVars, 368866);
assert.equal(circuit.constraints.length, 362663);
assert.equal(circuit.nVars, 371086);
assert.equal(circuit.constraints.length, 364883);
const conv2d_weights = [];
const conv2d_bias = [];

60
test/mnist_poly.js Normal file
View File

@@ -0,0 +1,60 @@
const chai = require("chai");
const path = require("path");
const wasm_tester = require("circom_tester").wasm;
const F1Field = require("ffjavascript").F1Field;
const Scalar = require("ffjavascript").Scalar;
exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
const Fr = new F1Field(exports.p);
const assert = chai.assert;
const json = require("../models/mnist_poly_input.json");
describe("mnist poly test", function () {
this.timeout(100000000);
it("should return correct output", async () => {
const circuit = await wasm_tester(path.join(__dirname, "circuits", "mnist_poly_test.circom"));
await circuit.loadConstraints();
assert.equal(circuit.nVars, 23622);
assert.equal(circuit.constraints.length, 16067);
const conv2d_weights = [];
const conv2d_bias = [];
const dense_weights = [];
const dense_bias = [];
for (var i=0; i<json.conv2d_weights.length; i++) {
conv2d_weights.push(Fr.e(json.conv2d_weights[i]));
}
for (var i=0; i<json.conv2d_bias.length; i++) {
conv2d_bias.push(Fr.e(json.conv2d_bias[i]));
}
for (var i=0; i<json.dense_weights.length; i++) {
dense_weights.push(Fr.e(json.dense_weights[i]));
}
for (var i=0; i<json.dense_bias.length; i++) {
dense_bias.push(Fr.e(json.dense_bias[i]));
}
const INPUT = {
"in": json.in,
"conv2d_weights": conv2d_weights,
"conv2d_bias": conv2d_bias,
"dense_weights": dense_weights,
"dense_bias": dense_bias
}
const witness = await circuit.calculateWitness(INPUT, true);
//console.log(witness[1]);
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
assert(Fr.eq(Fr.e(witness[1]),Fr.e(7)));
});
});