mirror of
https://github.com/socathie/circomlib-ml.git
synced 2026-01-09 14:08:04 -05:00
Add UpSampling2D circuit and test files
This commit is contained in:
19
circuits/UpSampling2D.circom
Normal file
19
circuits/UpSampling2D.circom
Normal file
@@ -0,0 +1,19 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
template UpSampling2D(nRows, nCols, nChannels, size) {
|
||||
signal input in[nRows][nCols][nChannels];
|
||||
signal input out[nRows * size][nCols * size][nChannels];
|
||||
|
||||
// nearest neighbor interpolation
|
||||
for (var i = 0; i < nRows; i++) {
|
||||
for (var j = 0; j < nCols; j++) {
|
||||
for (var c = 0; c < nChannels; c++) {
|
||||
for (var k = 0; k < size; k++) {
|
||||
for (var l = 0; l < size; l++) {
|
||||
out[i * size + k][j * size + l][c] === in[i][j][c];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
1
models/upSampling2D_input.json
Normal file
1
models/upSampling2D_input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"in": [[[842674, 497907, 66624], [875287, 832625, 34934]]], "out": [[[842674, 497907, 66624], [842674, 497907, 66624], [875287, 832625, 34934], [875287, 832625, 34934]], [[842674, 497907, 66624], [842674, 497907, 66624], [875287, 832625, 34934], [875287, 832625, 34934]]]}
|
||||
221
models/upSampling2d.ipynb
Normal file
221
models/upSampling2d.ipynb
Normal file
@@ -0,0 +1,221 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tensorflow.keras.layers import Input, UpSampling2D\n",
|
||||
"from tensorflow.keras import Model\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inputs = Input(shape=(1,2,3))\n",
|
||||
"x = UpSampling2D(size=2)(inputs)\n",
|
||||
"model = Model(inputs, x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model: \"model_1\"\n",
|
||||
"_________________________________________________________________\n",
|
||||
" Layer (type) Output Shape Param # \n",
|
||||
"=================================================================\n",
|
||||
" input_2 (InputLayer) [(None, 1, 2, 3)] 0 \n",
|
||||
" \n",
|
||||
" up_sampling2d_1 (UpSampling (None, 2, 4, 3) 0 \n",
|
||||
" 2D) \n",
|
||||
" \n",
|
||||
"=================================================================\n",
|
||||
"Total params: 0\n",
|
||||
"Trainable params: 0\n",
|
||||
"Non-trainable params: 0\n",
|
||||
"_________________________________________________________________\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.summary()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[[[842674, 497907, 66624],\n",
|
||||
" [875287, 832625, 34934]]]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"X = (np.random.rand(1,1,2,3)*1e6).astype(int)\n",
|
||||
"X"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1/1 [==============================] - 0s 27ms/step\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([[[[842674., 497907., 66624.],\n",
|
||||
" [842674., 497907., 66624.],\n",
|
||||
" [875287., 832625., 34934.],\n",
|
||||
" [875287., 832625., 34934.]],\n",
|
||||
"\n",
|
||||
" [[842674., 497907., 66624.],\n",
|
||||
" [842674., 497907., 66624.],\n",
|
||||
" [875287., 832625., 34934.],\n",
|
||||
" [875287., 832625., 34934.]]]], dtype=float32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"y = model.predict(X)\n",
|
||||
"y"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def UpSampling2DInt(nRows, nCols, nChannels, size, input):\n",
|
||||
" out = [[[None for _ in range(nChannels)] for _ in range(nCols*size)] for _ in range(nRows*size)]\n",
|
||||
" for i in range(nRows):\n",
|
||||
" for j in range(nCols):\n",
|
||||
" for c in range(nChannels):\n",
|
||||
" for k in range(size):\n",
|
||||
" for l in range(size):\n",
|
||||
" out[i*size+k][j*size+l][c] = input[i][j][c]\n",
|
||||
" return out\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_in = [[[int(X[0][i][j][k]) for k in range(3)] for j in range(2)] for i in range(1)]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[[[842674, 497907, 66624],\n",
|
||||
" [842674, 497907, 66624],\n",
|
||||
" [875287, 832625, 34934],\n",
|
||||
" [875287, 832625, 34934]],\n",
|
||||
" [[842674, 497907, 66624],\n",
|
||||
" [842674, 497907, 66624],\n",
|
||||
" [875287, 832625, 34934],\n",
|
||||
" [875287, 832625, 34934]]]"
|
||||
]
|
||||
},
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"out = UpSampling2DInt(1, 2, 3, 2, X_in)\n",
|
||||
"out"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"assert np.all(y[0].astype(int) == np.array(out))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"in_json = {\n",
|
||||
" \"in\": X_in,\n",
|
||||
" \"out\": out\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"with open(\"upSampling2D_input.json\", \"w\") as f:\n",
|
||||
" json.dump(in_json, f)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "keras2circom",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
28
test/UpSampling2D.js
Normal file
28
test/UpSampling2D.js
Normal file
@@ -0,0 +1,28 @@
|
||||
const chai = require("chai");
|
||||
const path = require("path");
|
||||
|
||||
const wasm_tester = require("circom_tester").wasm;
|
||||
|
||||
const F1Field = require("ffjavascript").F1Field;
|
||||
const Scalar = require("ffjavascript").Scalar;
|
||||
exports.p = Scalar.fromString("21888242871839275222246405745257275088548364400416034343698204186575808495617");
|
||||
const Fr = new F1Field(exports.p);
|
||||
|
||||
const assert = chai.assert;
|
||||
|
||||
|
||||
|
||||
describe("UpSampling2D layer test", function () {
|
||||
this.timeout(100000000);
|
||||
|
||||
// UpSampling with strides==poolSize
|
||||
it("(1,2,3) -> (2,4,3)", async () => {
|
||||
const INPUT = require("../models/upSampling2D_input.json");
|
||||
|
||||
const circuit = await wasm_tester(path.join(__dirname, "circuits", "UpSampling2D_test.circom"));
|
||||
|
||||
const witness = await circuit.calculateWitness(INPUT, true);
|
||||
|
||||
assert(Fr.eq(Fr.e(witness[0]),Fr.e(1)));
|
||||
});
|
||||
});
|
||||
5
test/circuits/UpSampling2D_test.circom
Normal file
5
test/circuits/UpSampling2D_test.circom
Normal file
@@ -0,0 +1,5 @@
|
||||
pragma circom 2.0.0;
|
||||
|
||||
include "../../circuits/UpSampling2D.circom";
|
||||
|
||||
component main = UpSampling2D(1, 2, 3, 2);
|
||||
Reference in New Issue
Block a user