AveragePooling2D layer

This commit is contained in:
Cathie So
2022-11-18 16:51:32 +08:00
parent 9c51c60bbf
commit a17cc6f145
9 changed files with 564 additions and 0 deletions

63
test/AveragePooling2D.js Normal file
View File

@@ -0,0 +1,63 @@
const chai = require("chai");
const { Console } = require("console");
const path = require("path");
const wasm_tester = require("circom_tester").wasm;
const F1Field = require("ffjavascript").F1Field;
const Scalar = require("ffjavascript").Scalar;
exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
const Fr = new F1Field(exports.p);
const assert = chai.assert;
describe("AveragePooling2D layer test", function () {
this.timeout(100000000);
// AveragePooling with strides==poolSize
it("(5,5,3) -> (2,2,3)", async () => {
const json = require("../models/averagePooling2D_input.json");
const OUTPUT = require("../models/averagePooling2D_output.json");
const circuit = await wasm_tester(path.join(__dirname, "circuits", "AveragePooling2D_test.circom"));
//await circuit.loadConstraints();
//assert.equal(circuit.nVars, 76);
//assert.equal(circuit.constraints.length, 0);
const INPUT = {
"in": json.in
}
const witness = await circuit.calculateWitness(INPUT, true);
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<2*2*3; i++) {
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(500));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(500));
}
});
// AveragePooling with strides!=poolSize
it("(10,10,3) -> (4,4,3)", async () => {
const json = require("../models/averagePooling2D_stride_input.json");
const OUTPUT = require("../models/averagePooling2D_stride_output.json");
const circuit = await wasm_tester(path.join(__dirname, "circuits", "AveragePooling2D_stride_test.circom"));
const INPUT = {
"in": json.in
}
const witness = await circuit.calculateWitness(INPUT, true);
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<4*4*3; i++) {
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(1000));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(1000));
}
});
});

View File

@@ -0,0 +1,6 @@
pragma circom 2.0.3;
include "../../circuits/AveragePooling2D.circom";
// poolSize!=strides
component main = AveragePooling2D(10, 10, 3, 3, 2, 111);

View File

@@ -0,0 +1,6 @@
pragma circom 2.0.3;
include "../../circuits/AveragePooling2D.circom";
// poolSize=strides - default Keras settings
component main = AveragePooling2D(5, 5, 3, 2, 2, 250);