mirror of
https://github.com/socathie/circomlib-ml.git
synced 2026-01-09 14:08:04 -05:00
Flatten2D layer
This commit is contained in:
18
circuits/Flatten2D.circom
Normal file
18
circuits/Flatten2D.circom
Normal file
@@ -0,0 +1,18 @@
|
||||
pragma circom 2.0.3;
|
||||
|
||||
// Conv2D layer with valid padding
|
||||
template Flatten2D (nRows, nCols, nChannels) {
|
||||
signal input in[nRows][nCols][nChannels];
|
||||
signal output out[nRows*nCols*nChannels];
|
||||
|
||||
var idx = 0;
|
||||
|
||||
for (var i=0; i<nRows; i++) {
|
||||
for (var j=0; j<nCols; j++) {
|
||||
for (var k=0; k<nChannels; k++) {
|
||||
out[idx] <== in[i][j][k];
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
239
models/flatten.ipynb
Normal file
239
models/flatten.ipynb
Normal file
@@ -0,0 +1,239 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tensorflow.keras.layers import Input, Flatten\n",
|
||||
"from tensorflow.keras import Model\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inputs = Input(shape=(5,5,3))\n",
|
||||
"x = Flatten()(inputs)\n",
|
||||
"model = Model(inputs, x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model: \"model\"\n",
|
||||
"_________________________________________________________________\n",
|
||||
"Layer (type) Output Shape Param # \n",
|
||||
"=================================================================\n",
|
||||
"input_1 (InputLayer) [(None, 5, 5, 3)] 0 \n",
|
||||
"_________________________________________________________________\n",
|
||||
"flatten (Flatten) (None, 75) 0 \n",
|
||||
"=================================================================\n",
|
||||
"Total params: 0\n",
|
||||
"Trainable params: 0\n",
|
||||
"Non-trainable params: 0\n",
|
||||
"_________________________________________________________________\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.summary()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[]"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.weights"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[[[0.0852918 , 0.44266819, 0.55968985],\n",
|
||||
" [0.80456378, 0.7812197 , 0.93979271],\n",
|
||||
" [0.60230819, 0.56047002, 0.96776284],\n",
|
||||
" [0.6283819 , 0.8851144 , 0.84297738],\n",
|
||||
" [0.32538761, 0.69431566, 0.41653711]],\n",
|
||||
"\n",
|
||||
" [[0.88323994, 0.49007438, 0.52891214],\n",
|
||||
" [0.75972182, 0.85386461, 0.64519541],\n",
|
||||
" [0.7372128 , 0.34378423, 0.89210331],\n",
|
||||
" [0.68795543, 0.15510637, 0.84101805],\n",
|
||||
" [0.34090138, 0.49322424, 0.00737098]],\n",
|
||||
"\n",
|
||||
" [[0.59636987, 0.80149115, 0.3530605 ],\n",
|
||||
" [0.06466776, 0.40315287, 0.15753059],\n",
|
||||
" [0.54568182, 0.95573263, 0.64114777],\n",
|
||||
" [0.6558953 , 0.10547539, 0.82302922],\n",
|
||||
" [0.60415623, 0.58044333, 0.46934783]],\n",
|
||||
"\n",
|
||||
" [[0.89339205, 0.49514657, 0.66308424],\n",
|
||||
" [0.23049011, 0.71922777, 0.19032885],\n",
|
||||
" [0.23228222, 0.16731365, 0.89744304],\n",
|
||||
" [0.64359666, 0.4594629 , 0.11503616],\n",
|
||||
" [0.62930732, 0.09412582, 0.04021055]],\n",
|
||||
"\n",
|
||||
" [[0.40104213, 0.9882606 , 0.20996853],\n",
|
||||
" [0.44420542, 0.47306763, 0.98680773],\n",
|
||||
" [0.95270149, 0.97320959, 0.54052338],\n",
|
||||
" [0.04304848, 0.31208349, 0.9046649 ],\n",
|
||||
" [0.00495649, 0.39177585, 0.67277488]]]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"X = np.random.rand(1,5,5,3)\n",
|
||||
"X"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[0.0852918 , 0.4426682 , 0.5596899 , 0.80456376, 0.7812197 ,\n",
|
||||
" 0.9397927 , 0.6023082 , 0.56047004, 0.9677628 , 0.6283819 ,\n",
|
||||
" 0.8851144 , 0.8429774 , 0.3253876 , 0.6943157 , 0.4165371 ,\n",
|
||||
" 0.8832399 , 0.4900744 , 0.5289121 , 0.7597218 , 0.8538646 ,\n",
|
||||
" 0.6451954 , 0.7372128 , 0.3437842 , 0.8921033 , 0.68795544,\n",
|
||||
" 0.15510637, 0.8410181 , 0.34090137, 0.49322423, 0.00737098,\n",
|
||||
" 0.59636986, 0.80149114, 0.35306048, 0.06466776, 0.40315288,\n",
|
||||
" 0.15753059, 0.54568183, 0.95573264, 0.6411478 , 0.6558953 ,\n",
|
||||
" 0.10547539, 0.8230292 , 0.60415626, 0.5804433 , 0.46934783,\n",
|
||||
" 0.893392 , 0.49514657, 0.6630842 , 0.2304901 , 0.7192278 ,\n",
|
||||
" 0.19032885, 0.23228222, 0.16731365, 0.89744306, 0.64359665,\n",
|
||||
" 0.4594629 , 0.11503616, 0.6293073 , 0.09412582, 0.04021055,\n",
|
||||
" 0.40104213, 0.98826057, 0.20996854, 0.44420543, 0.47306764,\n",
|
||||
" 0.9868077 , 0.9527015 , 0.9732096 , 0.54052335, 0.04304848,\n",
|
||||
" 0.31208348, 0.9046649 , 0.00495649, 0.39177585, 0.6727749 ]],\n",
|
||||
" dtype=float32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"y = model.predict(X)\n",
|
||||
"y"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"in_json = {\n",
|
||||
" \"in\": (X*1000).round().astype(int).flatten().tolist()\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"out_json = {\n",
|
||||
" \"out\": (y*1000).round().astype(int).flatten().tolist()\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(\"flatten2D_input.json\", \"w\") as f:\n",
|
||||
" json.dump(in_json, f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(\"flatten2D_output.json\", \"w\") as f:\n",
|
||||
" json.dump(out_json, f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "tf24",
|
||||
"language": "python",
|
||||
"name": "tf24"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.6"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
1
models/flatten2D_input.json
Normal file
1
models/flatten2D_input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"in": [85, 443, 560, 805, 781, 940, 602, 560, 968, 628, 885, 843, 325, 694, 417, 883, 490, 529, 760, 854, 645, 737, 344, 892, 688, 155, 841, 341, 493, 7, 596, 801, 353, 65, 403, 158, 546, 956, 641, 656, 105, 823, 604, 580, 469, 893, 495, 663, 230, 719, 190, 232, 167, 897, 644, 459, 115, 629, 94, 40, 401, 988, 210, 444, 473, 987, 953, 973, 541, 43, 312, 905, 5, 392, 673]}
|
||||
1
models/flatten2D_output.json
Normal file
1
models/flatten2D_output.json
Normal file
@@ -0,0 +1 @@
|
||||
{"out": [85, 443, 560, 805, 781, 940, 602, 560, 968, 628, 885, 843, 325, 694, 417, 883, 490, 529, 760, 854, 645, 737, 344, 892, 688, 155, 841, 341, 493, 7, 596, 801, 353, 65, 403, 158, 546, 956, 641, 656, 105, 823, 604, 580, 469, 893, 495, 663, 230, 719, 190, 232, 167, 897, 644, 459, 115, 629, 94, 40, 401, 988, 210, 444, 473, 987, 953, 973, 541, 43, 312, 905, 5, 392, 673]}
|
||||
38
test/Flatten2D.js
Normal file
38
test/Flatten2D.js
Normal file
@@ -0,0 +1,38 @@
|
||||
const chai = require("chai");
|
||||
const { Console } = require("console");
|
||||
const path = require("path");
|
||||
|
||||
const wasm_tester = require("circom_tester").wasm;
|
||||
|
||||
const F1Field = require("ffjavascript").F1Field;
|
||||
const Scalar = require("ffjavascript").Scalar;
|
||||
exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
|
||||
const Fr = new F1Field(exports.p);
|
||||
|
||||
const assert = chai.assert;
|
||||
|
||||
|
||||
|
||||
describe("Flatten2D layer test", function () {
|
||||
this.timeout(100000000);
|
||||
|
||||
it("(5,5,3) -> 75", async () => {
|
||||
let json = require("../models/flatten2D_input.json");
|
||||
let OUTPUT = require("../models/flatten2D_output.json");
|
||||
|
||||
const circuit = await wasm_tester(path.join(__dirname, "circuits", "flatten2D_test.circom"));
|
||||
|
||||
const INPUT = {
|
||||
"in": json.in
|
||||
}
|
||||
|
||||
const witness = await circuit.calculateWitness(INPUT, true);
|
||||
|
||||
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
|
||||
|
||||
for (var i=0; i<75; i++) {
|
||||
assert((witness[i+1]-Fr.e(OUTPUT.out[i]))<Fr.e(5000));
|
||||
assert((Fr.e(OUTPUT.out[i])-witness[i+1])<Fr.e(5000));
|
||||
}
|
||||
});
|
||||
});
|
||||
5
test/circuits/Flatten2D_test.circom
Normal file
5
test/circuits/Flatten2D_test.circom
Normal file
@@ -0,0 +1,5 @@
|
||||
pragma circom 2.0.3;
|
||||
|
||||
include "../../circuits/Flatten2D.circom";
|
||||
|
||||
component main = Flatten2D(5, 5, 3);
|
||||
Reference in New Issue
Block a user