Flatten2D layer

This commit is contained in:
Cathie So
2022-11-11 20:13:47 +08:00
parent adb9eddff0
commit 6c0254308f
6 changed files with 302 additions and 0 deletions

18
circuits/Flatten2D.circom Normal file
View File

@@ -0,0 +1,18 @@
pragma circom 2.0.3;
// Conv2D layer with valid padding
template Flatten2D (nRows, nCols, nChannels) {
signal input in[nRows][nCols][nChannels];
signal output out[nRows*nCols*nChannels];
var idx = 0;
for (var i=0; i<nRows; i++) {
for (var j=0; j<nCols; j++) {
for (var k=0; k<nChannels; k++) {
out[idx] <== in[i][j][k];
idx++;
}
}
}
}

239
models/flatten.ipynb Normal file
View File

@@ -0,0 +1,239 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.layers import Input, Flatten\n",
"from tensorflow.keras import Model\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"inputs = Input(shape=(5,5,3))\n",
"x = Flatten()(inputs)\n",
"model = Model(inputs, x)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"input_1 (InputLayer) [(None, 5, 5, 3)] 0 \n",
"_________________________________________________________________\n",
"flatten (Flatten) (None, 75) 0 \n",
"=================================================================\n",
"Total params: 0\n",
"Trainable params: 0\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.weights"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[[0.0852918 , 0.44266819, 0.55968985],\n",
" [0.80456378, 0.7812197 , 0.93979271],\n",
" [0.60230819, 0.56047002, 0.96776284],\n",
" [0.6283819 , 0.8851144 , 0.84297738],\n",
" [0.32538761, 0.69431566, 0.41653711]],\n",
"\n",
" [[0.88323994, 0.49007438, 0.52891214],\n",
" [0.75972182, 0.85386461, 0.64519541],\n",
" [0.7372128 , 0.34378423, 0.89210331],\n",
" [0.68795543, 0.15510637, 0.84101805],\n",
" [0.34090138, 0.49322424, 0.00737098]],\n",
"\n",
" [[0.59636987, 0.80149115, 0.3530605 ],\n",
" [0.06466776, 0.40315287, 0.15753059],\n",
" [0.54568182, 0.95573263, 0.64114777],\n",
" [0.6558953 , 0.10547539, 0.82302922],\n",
" [0.60415623, 0.58044333, 0.46934783]],\n",
"\n",
" [[0.89339205, 0.49514657, 0.66308424],\n",
" [0.23049011, 0.71922777, 0.19032885],\n",
" [0.23228222, 0.16731365, 0.89744304],\n",
" [0.64359666, 0.4594629 , 0.11503616],\n",
" [0.62930732, 0.09412582, 0.04021055]],\n",
"\n",
" [[0.40104213, 0.9882606 , 0.20996853],\n",
" [0.44420542, 0.47306763, 0.98680773],\n",
" [0.95270149, 0.97320959, 0.54052338],\n",
" [0.04304848, 0.31208349, 0.9046649 ],\n",
" [0.00495649, 0.39177585, 0.67277488]]]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = np.random.rand(1,5,5,3)\n",
"X"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.0852918 , 0.4426682 , 0.5596899 , 0.80456376, 0.7812197 ,\n",
" 0.9397927 , 0.6023082 , 0.56047004, 0.9677628 , 0.6283819 ,\n",
" 0.8851144 , 0.8429774 , 0.3253876 , 0.6943157 , 0.4165371 ,\n",
" 0.8832399 , 0.4900744 , 0.5289121 , 0.7597218 , 0.8538646 ,\n",
" 0.6451954 , 0.7372128 , 0.3437842 , 0.8921033 , 0.68795544,\n",
" 0.15510637, 0.8410181 , 0.34090137, 0.49322423, 0.00737098,\n",
" 0.59636986, 0.80149114, 0.35306048, 0.06466776, 0.40315288,\n",
" 0.15753059, 0.54568183, 0.95573264, 0.6411478 , 0.6558953 ,\n",
" 0.10547539, 0.8230292 , 0.60415626, 0.5804433 , 0.46934783,\n",
" 0.893392 , 0.49514657, 0.6630842 , 0.2304901 , 0.7192278 ,\n",
" 0.19032885, 0.23228222, 0.16731365, 0.89744306, 0.64359665,\n",
" 0.4594629 , 0.11503616, 0.6293073 , 0.09412582, 0.04021055,\n",
" 0.40104213, 0.98826057, 0.20996854, 0.44420543, 0.47306764,\n",
" 0.9868077 , 0.9527015 , 0.9732096 , 0.54052335, 0.04304848,\n",
" 0.31208348, 0.9046649 , 0.00495649, 0.39177585, 0.6727749 ]],\n",
" dtype=float32)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = model.predict(X)\n",
"y"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"in_json = {\n",
" \"in\": (X*1000).round().astype(int).flatten().tolist()\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"out_json = {\n",
" \"out\": (y*1000).round().astype(int).flatten().tolist()\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import json"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"with open(\"flatten2D_input.json\", \"w\") as f:\n",
" json.dump(in_json, f)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"with open(\"flatten2D_output.json\", \"w\") as f:\n",
" json.dump(out_json, f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tf24",
"language": "python",
"name": "tf24"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1 @@
{"in": [85, 443, 560, 805, 781, 940, 602, 560, 968, 628, 885, 843, 325, 694, 417, 883, 490, 529, 760, 854, 645, 737, 344, 892, 688, 155, 841, 341, 493, 7, 596, 801, 353, 65, 403, 158, 546, 956, 641, 656, 105, 823, 604, 580, 469, 893, 495, 663, 230, 719, 190, 232, 167, 897, 644, 459, 115, 629, 94, 40, 401, 988, 210, 444, 473, 987, 953, 973, 541, 43, 312, 905, 5, 392, 673]}

View File

@@ -0,0 +1 @@
{"out": [85, 443, 560, 805, 781, 940, 602, 560, 968, 628, 885, 843, 325, 694, 417, 883, 490, 529, 760, 854, 645, 737, 344, 892, 688, 155, 841, 341, 493, 7, 596, 801, 353, 65, 403, 158, 546, 956, 641, 656, 105, 823, 604, 580, 469, 893, 495, 663, 230, 719, 190, 232, 167, 897, 644, 459, 115, 629, 94, 40, 401, 988, 210, 444, 473, 987, 953, 973, 541, 43, 312, 905, 5, 392, 673]}

38
test/Flatten2D.js Normal file
View File

@@ -0,0 +1,38 @@
const chai = require("chai");
const { Console } = require("console");
const path = require("path");
const wasm_tester = require("circom_tester").wasm;
const F1Field = require("ffjavascript").F1Field;
const Scalar = require("ffjavascript").Scalar;
exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
const Fr = new F1Field(exports.p);
const assert = chai.assert;
describe("Flatten2D layer test", function () {
this.timeout(100000000);
it("(5,5,3) -> 75", async () => {
let json = require("../models/flatten2D_input.json");
let OUTPUT = require("../models/flatten2D_output.json");
const circuit = await wasm_tester(path.join(__dirname, "circuits", "flatten2D_test.circom"));
const INPUT = {
"in": json.in
}
const witness = await circuit.calculateWitness(INPUT, true);
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
for (var i=0; i<75; i++) {
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(5000));
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(5000));
}
});
});

View File

@@ -0,0 +1,5 @@
pragma circom 2.0.3;
include "../../circuits/Flatten2D.circom";
component main = Flatten2D(5, 5, 3);