mirror of
https://github.com/socathie/circomlib-ml.git
synced 2026-01-09 14:08:04 -05:00
Ver 1.1.0 - Added Polynomial activation layer and test case, updated ArgMax to support up to 254 bits, updated README
This commit is contained in:
61
test/circuits/mnist_poly_test.circom
Normal file
61
test/circuits/mnist_poly_test.circom
Normal file
@@ -0,0 +1,61 @@
|
||||
pragma circom 2.0.3;
|
||||
|
||||
include "../../circuits/Conv2D.circom";
|
||||
include "../../circuits/Dense.circom";
|
||||
include "../../circuits/ArgMax.circom";
|
||||
include "../../circuits/Poly.circom";
|
||||
|
||||
template mnist_poly() {
|
||||
signal input in[28][28][1];
|
||||
signal input conv2d_weights[3][3][1][1];
|
||||
signal input conv2d_bias[1];
|
||||
signal input dense_weights[676][10];
|
||||
signal input dense_bias[10];
|
||||
signal output out;
|
||||
|
||||
component conv2d = Conv2D(28,28,1,1,3);
|
||||
component poly[26*26];
|
||||
component dense = Dense(676,10);
|
||||
component argmax = ArgMax(10);
|
||||
|
||||
for (var i=0; i<28; i++) {
|
||||
for (var j=0; j<28; j++) {
|
||||
conv2d.in[i][j][0] <== in[i][j][0];
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<3; i++) {
|
||||
for (var j=0; j<3; j++) {
|
||||
conv2d.weights[i][j][0][0] <== conv2d_weights[i][j][0][0];
|
||||
}
|
||||
}
|
||||
|
||||
conv2d.bias[0] <== conv2d_bias[0];
|
||||
|
||||
var idx = 0;
|
||||
|
||||
for (var i=0; i<26; i++) {
|
||||
for (var j=0; j<26; j++) {
|
||||
poly[idx] = Poly(10**18);
|
||||
poly[idx].in <== conv2d.out[i][j][0];
|
||||
dense.in[idx] <== poly[idx].out;
|
||||
for (var k=0; k<10; k++) {
|
||||
dense.weights[idx][k] <== dense_weights[idx][k];
|
||||
}
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
|
||||
for (var i=0; i<10; i++) {
|
||||
dense.bias[i] <== dense_bias[i];
|
||||
}
|
||||
|
||||
for (var i=0; i<10; i++) {
|
||||
log(dense.out[i]);
|
||||
argmax.in[i] <== dense.out[i];
|
||||
}
|
||||
|
||||
out <== argmax.out;
|
||||
}
|
||||
|
||||
component main = mnist_poly();
|
||||
@@ -51,6 +51,7 @@ template mnist() {
|
||||
}
|
||||
|
||||
for (var i=0; i<10; i++) {
|
||||
log(dense.out[i]);
|
||||
argmax.in[i] <== dense.out[i];
|
||||
}
|
||||
|
||||
|
||||
@@ -18,8 +18,8 @@ describe("mnist test", function () {
|
||||
it("should return correct output", async () => {
|
||||
const circuit = await wasm_tester(path.join(__dirname, "circuits", "mnist_test.circom"));
|
||||
await circuit.loadConstraints();
|
||||
assert.equal(circuit.nVars, 368866);
|
||||
assert.equal(circuit.constraints.length, 362663);
|
||||
assert.equal(circuit.nVars, 371086);
|
||||
assert.equal(circuit.constraints.length, 364883);
|
||||
|
||||
const conv2d_weights = [];
|
||||
const conv2d_bias = [];
|
||||
|
||||
60
test/mnist_poly.js
Normal file
60
test/mnist_poly.js
Normal file
@@ -0,0 +1,60 @@
|
||||
const chai = require("chai");
|
||||
const path = require("path");
|
||||
|
||||
const wasm_tester = require("circom_tester").wasm;
|
||||
|
||||
const F1Field = require("ffjavascript").F1Field;
|
||||
const Scalar = require("ffjavascript").Scalar;
|
||||
exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
|
||||
const Fr = new F1Field(exports.p);
|
||||
|
||||
const assert = chai.assert;
|
||||
|
||||
const json = require("../models/mnist_poly_input.json");
|
||||
|
||||
describe("mnist poly test", function () {
|
||||
this.timeout(100000000);
|
||||
|
||||
it("should return correct output", async () => {
|
||||
const circuit = await wasm_tester(path.join(__dirname, "circuits", "mnist_poly_test.circom"));
|
||||
await circuit.loadConstraints();
|
||||
assert.equal(circuit.nVars, 23622);
|
||||
assert.equal(circuit.constraints.length, 16067);
|
||||
|
||||
const conv2d_weights = [];
|
||||
const conv2d_bias = [];
|
||||
const dense_weights = [];
|
||||
const dense_bias = [];
|
||||
|
||||
for (var i=0; i<json.conv2d_weights.length; i++) {
|
||||
conv2d_weights.push(Fr.e(json.conv2d_weights[i]));
|
||||
}
|
||||
|
||||
for (var i=0; i<json.conv2d_bias.length; i++) {
|
||||
conv2d_bias.push(Fr.e(json.conv2d_bias[i]));
|
||||
}
|
||||
|
||||
for (var i=0; i<json.dense_weights.length; i++) {
|
||||
dense_weights.push(Fr.e(json.dense_weights[i]));
|
||||
}
|
||||
|
||||
for (var i=0; i<json.dense_bias.length; i++) {
|
||||
dense_bias.push(Fr.e(json.dense_bias[i]));
|
||||
}
|
||||
|
||||
const INPUT = {
|
||||
"in": json.in,
|
||||
"conv2d_weights": conv2d_weights,
|
||||
"conv2d_bias": conv2d_bias,
|
||||
"dense_weights": dense_weights,
|
||||
"dense_bias": dense_bias
|
||||
}
|
||||
|
||||
const witness = await circuit.calculateWitness(INPUT, true);
|
||||
|
||||
//console.log(witness[1]);
|
||||
|
||||
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
|
||||
assert(Fr.eq(Fr.e(witness[1]),Fr.e(7)));
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user