Files
circomlib-ml/models/flatten.ipynb
drCathieSo.eth 30a43a2712 Version 2.0.0 (#5)
* feat: `Poly` renamed to `ZeLU` with scaling implemented

* fix: assertion in `ZeLU`

* feat: `AveragePooling2D` with scaling

* feat: `BatchNorm` with scaling

* feat: `Conv1D` with scaling

* feat: `Conv2D` with scaling

* feat: `Dense` with scaling

* fix: assertion in `Dense`

* feat: `GlobalAveragePooling2D` with scaling

* feat: input-only `ArgMax`

* feat: input-only `Flatten2D`

* feat: input-only `GlobalMaxPooling2D`

* feat: input-only `MaxPooling2D`

* feat: input-only `ReLU`

* test: precision up to 36 decimals

* chore: clean up

* test: model1 with 36 decimals

* fix: ReLU should use `p//2` as threshold

* test: clean up

* test: mnist model with 18 decimals

* build: Update package.json version to 2.0.0

* chore: Update README with warning message
2023-10-24 02:50:34 +07:00

235 lines
6.6 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Flatten test fixture\n",
"\n",
"Builds a Keras model made of a single `Flatten` layer on a `(5, 5, 3)` input, runs it on one random sample, and writes the input and the expected output, scaled by `1e36` and converted to integers, to `flatten2D_input.json`.\n",
"\n",
"The JSON file is presumably the test vector for the circom `Flatten2D` template in this repository (TODO: confirm against the circuit test)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.layers import Input, Flatten\n",
"from tensorflow.keras import Model\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"inputs = Input(shape=(5,5,3))\n",
"x = Flatten()(inputs)\n",
"model = Model(inputs, x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inspect the model\n",
"\n",
"`Flatten` only reshapes its input, so the model has no parameters (`Total params: 0`) and `model.weights` is empty. Only the input and the expected output need to be exported."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" input_1 (InputLayer) [(None, 5, 5, 3)] 0 \n",
" \n",
" flatten (Flatten) (None, 75) 0 \n",
" \n",
"=================================================================\n",
"Total params: 0\n",
"Trainable params: 0\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.weights"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Random input\n",
"\n",
"One sample of uniform random values in `[0, 1)`, with a leading batch dimension of 1.\n",
"\n",
"**TODO:** no random seed is set, so re-running this cell produces different values than the stored output below."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[[0.9191584 , 0.41015604, 0.0493302 ],\n",
" [0.20412956, 0.14984944, 0.71595293],\n",
" [0.57980447, 0.28233206, 0.30881941],\n",
" [0.98703541, 0.91977126, 0.89591016],\n",
" [0.29365768, 0.89541076, 0.97098122]],\n",
"\n",
" [[0.28270309, 0.85760979, 0.12266525],\n",
" [0.2386079 , 0.93741419, 0.83312648],\n",
" [0.02935679, 0.68497567, 0.37248647],\n",
" [0.76807667, 0.72347087, 0.84375984],\n",
" [0.89233681, 0.87703334, 0.53846864]],\n",
"\n",
" [[0.14028452, 0.61585222, 0.34271206],\n",
" [0.45404173, 0.26365195, 0.05140719],\n",
" [0.36253999, 0.51529482, 0.15006 ],\n",
" [0.82061228, 0.08937872, 0.65234282],\n",
" [0.31024437, 0.09785702, 0.40629764]],\n",
"\n",
" [[0.75192339, 0.55825739, 0.86978978],\n",
" [0.76105885, 0.54160411, 0.72517187],\n",
" [0.28701856, 0.31868524, 0.46890464],\n",
" [0.0902 , 0.3022873 , 0.48529066],\n",
" [0.24453082, 0.93271481, 0.08555694]],\n",
"\n",
" [[0.52171579, 0.22363436, 0.85212827],\n",
" [0.9823001 , 0.64424366, 0.96495129],\n",
" [0.61750385, 0.53921774, 0.75703119],\n",
" [0.57267588, 0.18643057, 0.26532282],\n",
" [0.22546175, 0.0340469 , 0.19259163]]]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = np.random.rand(1,5,5,3)\n",
"X"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run the model\n",
"\n",
"`verbose=0` suppresses the Keras progress bar. The result is `X` flattened to shape `(1, 75)` in row-major order, returned as `float32`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.9191584 , 0.41015604, 0.0493302 , 0.20412956, 0.14984943,\n",
" 0.71595293, 0.5798045 , 0.28233206, 0.3088194 , 0.9870354 ,\n",
" 0.91977125, 0.89591014, 0.2936577 , 0.8954108 , 0.97098124,\n",
" 0.2827031 , 0.8576098 , 0.12266525, 0.2386079 , 0.93741417,\n",
" 0.8331265 , 0.02935679, 0.6849757 , 0.37248647, 0.76807666,\n",
" 0.72347087, 0.84375983, 0.8923368 , 0.87703335, 0.53846866,\n",
" 0.14028452, 0.61585224, 0.34271204, 0.45404172, 0.26365197,\n",
" 0.05140718, 0.36253998, 0.5152948 , 0.15006 , 0.82061225,\n",
" 0.08937872, 0.6523428 , 0.31024438, 0.09785703, 0.40629762,\n",
" 0.7519234 , 0.5582574 , 0.8697898 , 0.76105887, 0.5416041 ,\n",
" 0.72517186, 0.28701857, 0.31868523, 0.46890464, 0.0902 ,\n",
" 0.3022873 , 0.48529068, 0.24453081, 0.9327148 , 0.08555695,\n",
" 0.5217158 , 0.22363436, 0.85212827, 0.9823001 , 0.64424366,\n",
" 0.9649513 , 0.6175039 , 0.5392177 , 0.7570312 , 0.5726759 ,\n",
" 0.18643057, 0.26532283, 0.22546175, 0.0340469 , 0.19259164]],\n",
" dtype=float32)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = model.predict(X, verbose=0)\n",
"y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Export the test vector\n",
"\n",
"Circuit inputs are integers, so each value is multiplied by `1e36` and rounded. The scaled values (up to about `1e36`) do not fit in `int64`, so they are converted with Python's arbitrary-precision `int` and written as decimal strings. Readers that parse JSON numbers as doubles therefore cannot truncate them.\n",
"\n",
"`Flatten` only reorders values, so the expected output is the float64 input `X` in row-major order. This is checked against the float32 prediction `y` before export."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"SCALE = 1e36\n",
"\n",
"# Flatten is a pure reshape, so `out` can reuse the float64 input instead of the\n",
"# lower-precision float32 prediction; verify that assumption first.\n",
"np.testing.assert_allclose(y, X.reshape(1, -1), rtol=1e-6)\n",
"\n",
"# `.astype(int)` would overflow int64 at this magnitude (every element collapses\n",
"# to the same value); int() on each element is exact, and decimal strings avoid\n",
"# precision loss in JSON parsers that decode numbers as doubles.\n",
"X_scaled = [str(int(v)) for v in (X * SCALE).round().flatten()]\n",
"\n",
"in_json = {\n",
"    \"in\": X_scaled,\n",
"    \"out\": X_scaled\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import json"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"with open(\"flatten2D_input.json\", \"w\") as f:\n",
" json.dump(in_json, f)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "sklearn",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}