update tests

This commit is contained in:
Cathie So
2022-12-09 12:16:13 +08:00
parent 0a0c3a710f
commit 05bf4645d0
7 changed files with 130 additions and 78 deletions

View File

@@ -1,5 +1,4 @@
const chai = require("chai"); const chai = require("chai");
const { Console } = require("console");
const path = require("path"); const path = require("path");
const wasm_tester = require("circom_tester").wasm; const wasm_tester = require("circom_tester").wasm;
@@ -34,10 +33,18 @@ describe("AveragePooling2D layer test", function () {
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1))); assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<2*2*3; i++) { let ape = 0;
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(500));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(500)); for (var i=0; i<OUTPUT.out.length; i++) {
console.log("actual", OUTPUT.out[i], "predicted", Fr.toString(witness[i+1]));
ape += Math.abs((OUTPUT.out[i]-parseInt(Fr.toString(witness[i+1])))/OUTPUT.out[i]);
} }
const mape = ape/OUTPUT.out.length;
console.log("mean absolute % error", mape);
assert(mape < 0.01);
}); });
// AveragePooling with strides!=poolSize // AveragePooling with strides!=poolSize
@@ -55,9 +62,17 @@ describe("AveragePooling2D layer test", function () {
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1))); assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<4*4*3; i++) { let ape = 0;
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(1000));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(1000)); for (var i=0; i<OUTPUT.out.length; i++) {
console.log("actual", OUTPUT.out[i], "predicted", Fr.toString(witness[i+1]));
ape += Math.abs((OUTPUT.out[i]-parseInt(Fr.toString(witness[i+1])))/OUTPUT.out[i]);
} }
const mape = ape/OUTPUT.out.length;
console.log("mean absolute % error", mape);
assert(mape < 0.01);
}); });
}); });

View File

@@ -1,5 +1,4 @@
const chai = require("chai"); const chai = require("chai");
const { Console } = require("console");
const path = require("path"); const path = require("path");
const wasm_tester = require("circom_tester").wasm; const wasm_tester = require("circom_tester").wasm;
@@ -22,27 +21,35 @@ describe("BatchNormalization layer test", function () {
const circuit = await wasm_tester(path.join(__dirname, "circuits", "batchNormalization_test.circom")); const circuit = await wasm_tester(path.join(__dirname, "circuits", "batchNormalization_test.circom"));
const a = []; let INPUT = {};
const b = [];
for (var i=0; i<json.a.length; i++) { for (const [key, value] of Object.entries(json)) {
a.push(Fr.e(json.a[i])); if (Array.isArray(value)) {
b.push(Fr.e(json.b[i])); let tmpArray = [];
} for (let i = 0; i < value.flat().length; i++) {
tmpArray.push(Fr.e(value.flat()[i]));
const INPUT = { }
"in": json.in, INPUT[key] = tmpArray;
"a": a, } else {
"b": b INPUT[key] = Fr.e(value);
}
} }
const witness = await circuit.calculateWitness(INPUT, true); const witness = await circuit.calculateWitness(INPUT, true);
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1))); assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<5*5*3; i++) { let ape = 0;
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(1000));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(1000)); for (var i=0; i<OUTPUT.out.length; i++) {
// console.log("actual", OUTPUT.out[i], "predicted", Fr.toString(witness[i+1]));
ape += Math.abs((OUTPUT.out[i]-parseInt(Fr.toString(witness[i+1])))/OUTPUT.out[i]);
} }
const mape = ape/OUTPUT.out.length;
console.log("mean absolute % error", mape);
assert(mape < 0.01);
}); });
}); });

View File

@@ -1,5 +1,4 @@
const chai = require("chai"); const chai = require("chai");
const { Console } = require("console");
const path = require("path"); const path = require("path");
const wasm_tester = require("circom_tester").wasm; const wasm_tester = require("circom_tester").wasm;
@@ -41,9 +40,17 @@ describe("Conv1D layer test", function () {
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1))); assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<6*2; i++) { let ape = 0;
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(5000));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(5000)); for (var i=0; i<OUTPUT.out.length; i++) {
// console.log("actual", OUTPUT.out[i], "predicted", Fr.toString(witness[i+1]));
ape += Math.abs((OUTPUT.out[i]-parseInt(Fr.toString(witness[i+1])))/OUTPUT.out[i]);
} }
const mape = ape/OUTPUT.out.length;
console.log("mean absolute % error", mape);
assert(mape < 0.01);
}); });
}); });

View File

@@ -1,5 +1,4 @@
const chai = require("chai"); const chai = require("chai");
const { Console } = require("console");
const path = require("path"); const path = require("path");
const wasm_tester = require("circom_tester").wasm; const wasm_tester = require("circom_tester").wasm;
@@ -21,30 +20,37 @@ describe("Conv2D layer test", function () {
let OUTPUT = require("../models/conv2D_output.json"); let OUTPUT = require("../models/conv2D_output.json");
const circuit = await wasm_tester(path.join(__dirname, "circuits", "Conv2D_test.circom")); const circuit = await wasm_tester(path.join(__dirname, "circuits", "Conv2D_test.circom"));
//await circuit.loadConstraints();
//assert.equal(circuit.nVars, 618);
//assert.equal(circuit.constraints.length, 486);
const weights = []; let INPUT = {};
for (var i=0; i<json.weights.length; i++) { for (const [key, value] of Object.entries(json)) {
weights.push(Fr.e(json.weights[i])); if (Array.isArray(value)) {
} let tmpArray = [];
for (let i = 0; i < value.flat().length; i++) {
const INPUT = { tmpArray.push(Fr.e(value.flat()[i]));
"in": json.in, }
"weights": weights, INPUT[key] = tmpArray;
"bias": ["0","0"] } else {
INPUT[key] = Fr.e(value);
}
} }
const witness = await circuit.calculateWitness(INPUT, true); const witness = await circuit.calculateWitness(INPUT, true);
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1))); assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<3*3*2; i++) { let ape = 0;
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(5000));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(5000)); for (var i=0; i<OUTPUT.out.length; i++) {
// console.log("actual", OUTPUT.out[i], "predicted", Fr.toString(witness[i+1]));
ape += Math.abs((OUTPUT.out[i]-parseInt(Fr.toString(witness[i+1])))/OUTPUT.out[i]);
} }
const mape = ape/OUTPUT.out.length;
console.log("mean absolute % error", mape);
assert(mape < 0.01);
}); });
it("(10,10,3) -> (3,3,2)", async () => { it("(10,10,3) -> (3,3,2)", async () => {
@@ -52,29 +58,36 @@ describe("Conv2D layer test", function () {
let OUTPUT = require("../models/conv2D_stride_output.json"); let OUTPUT = require("../models/conv2D_stride_output.json");
const circuit = await wasm_tester(path.join(__dirname, "circuits", "Conv2D_stride_test.circom")); const circuit = await wasm_tester(path.join(__dirname, "circuits", "Conv2D_stride_test.circom"));
//await circuit.loadConstraints();
//assert.equal(circuit.nVars, 618);
//assert.equal(circuit.constraints.length, 486);
const weights = []; let INPUT = {};
for (var i=0; i<json.weights.length; i++) { for (const [key, value] of Object.entries(json)) {
weights.push(Fr.e(json.weights[i])); if (Array.isArray(value)) {
} let tmpArray = [];
for (let i = 0; i < value.flat().length; i++) {
const INPUT = { tmpArray.push(Fr.e(value.flat()[i]));
"in": json.in, }
"weights": weights, INPUT[key] = tmpArray;
"bias": ["0","0"] } else {
INPUT[key] = Fr.e(value);
}
} }
const witness = await circuit.calculateWitness(INPUT, true); const witness = await circuit.calculateWitness(INPUT, true);
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1))); assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<3*3*2; i++) { let ape = 0;
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(5000));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(5000)); for (var i=0; i<OUTPUT.out.length; i++) {
// console.log("actual", OUTPUT.out[i], "predicted", Fr.toString(witness[i+1]));
ape += Math.abs((OUTPUT.out[i]-parseInt(Fr.toString(witness[i+1])))/OUTPUT.out[i]);
} }
const mape = ape/OUTPUT.out.length;
console.log("mean absolute % error", mape);
assert(mape < 0.01);
}); });
}); });

View File

@@ -1,5 +1,4 @@
const chai = require("chai"); const chai = require("chai");
const { Console } = require("console");
const path = require("path"); const path = require("path");
const wasm_tester = require("circom_tester").wasm; const wasm_tester = require("circom_tester").wasm;
@@ -30,9 +29,8 @@ describe("Flatten2D layer test", function () {
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1))); assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<75; i++) { for (var i=0; i<OUTPUT.out.length; i++) {
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(5000)); assert(Fr.eq(Fr.e(OUTPUT.out[i]), witness[i+1]));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(5000));
} }
}); });
}); });

View File

@@ -1,5 +1,4 @@
const chai = require("chai"); const chai = require("chai");
const { Console } = require("console");
const path = require("path"); const path = require("path");
const wasm_tester = require("circom_tester").wasm; const wasm_tester = require("circom_tester").wasm;
@@ -34,9 +33,8 @@ describe("MaxPooling2D layer test", function () {
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1))); assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<2*2*3; i++) { for (var i=0; i<OUTPUT.out.length; i++) {
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(1)); assert(Fr.eq(Fr.e(OUTPUT.out[i]),witness[i+1]));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(1));
} }
}); });
@@ -55,9 +53,8 @@ describe("MaxPooling2D layer test", function () {
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1))); assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<3*3*3; i++) { for (var i=0; i<OUTPUT.out.length; i++) {
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(1)); assert(Fr.eq(Fr.e(OUTPUT.out[i]),witness[i+1]));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(1));
} }
}); });
}); });

View File

@@ -1,5 +1,4 @@
const chai = require("chai"); const chai = require("chai");
const { Console } = require("console");
const path = require("path"); const path = require("path");
const wasm_tester = require("circom_tester").wasm; const wasm_tester = require("circom_tester").wasm;
@@ -34,10 +33,18 @@ describe("SumPooling2D layer test", function () {
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1))); assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<2*2*3; i++) { let ape = 0;
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(2));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(2)); for (var i=0; i<OUTPUT.out.length; i++) {
// console.log("actual", OUTPUT.out[i], "predicted", Fr.toString(witness[i+1]));
ape += Math.abs((OUTPUT.out[i]-parseInt(Fr.toString(witness[i+1])))/OUTPUT.out[i]);
} }
const mape = ape/OUTPUT.out.length;
console.log("mean absolute % error", mape);
assert(mape < 0.01);
}); });
// SumPooling with strides!=poolSize // SumPooling with strides!=poolSize
@@ -55,9 +62,17 @@ describe("SumPooling2D layer test", function () {
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1))); assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<3*3*3; i++) { let ape = 0;
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(3));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(3)); for (var i=0; i<OUTPUT.out.length; i++) {
// console.log("actual", OUTPUT.out[i], "predicted", Fr.toString(witness[i+1]));
ape += Math.abs((OUTPUT.out[i]-parseInt(Fr.toString(witness[i+1])))/OUTPUT.out[i]);
} }
const mape = ape/OUTPUT.out.length;
console.log("mean absolute % error", mape);
assert(mape < 0.01);
}); });
}); });