{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "TB8jFLoLwZ8K"
},
"source": [
"# N-BEATS Time Series Forecasting\n",
"In this tutorial we utilize the N-BEATS (Neural basis expansion analysis for interpretable time series forecasting\n",
") for forecasting the price of ethereum.\n",
"\n",
"For more details regarding N-BEATS, visit this link [https://arxiv.org/abs/1905.10437](https://arxiv.org/abs/1905.10437)\n",
"\n",
"The code for N-BEATS used is adapted from [nbeats-pytorch](https://pypi.org/project/nbeats-pytorch/)"
]
},
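{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick reference for the code below: N-BEATS uses doubly residual stacking. Each block $\\ell$ receives the residual backcast left by the previous block, subtracts the part of the signal it can explain, and adds a partial forecast,\n",
"\n",
"$$x_{\\ell} = x_{\\ell-1} - \\hat{x}_{\\ell-1}, \\qquad \\hat{y} = \\sum_{\\ell} \\hat{y}_{\\ell}$$\n",
"\n",
"where $x_0$ is the input window, $\\hat{x}_{\\ell}$ is block $\\ell$'s backcast estimate, and $\\hat{y}_{\\ell}$ its forecast contribution. This is exactly the subtraction/addition loop in `NBEATS.forward` further down."
]
},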
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ccy0MgZLwY1Z"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import torch\n",
"from torch import nn, optim\n",
"from torch.nn import functional as F\n",
"from torch.nn.functional import mse_loss, l1_loss, binary_cross_entropy, cross_entropy\n",
"from torch.optim import Optimizer\n",
"import matplotlib.pyplot as plt\n",
"import requests\n",
"import json\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"import numpy as np\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ovxRhGv0xS0i"
},
"outputs": [],
"source": [
"# Fetch data\n",
"coins = [\"ETH\"]\n",
"days_ago_to_fetch = 2000 # 7 years\n",
"coin_history = {}\n",
"hist_length = 0\n",
"average_returns = {}\n",
"cumulative_returns = {}\n",
"\n",
"def index_history_coin(hist):\n",
" hist = hist.set_index('time')\n",
" hist.index = pd.to_datetime(hist.index, unit='s')\n",
" return hist\n",
"\n",
"def filter_history_by_date(hist):\n",
" result = hist[hist.index.year >= 2017]\n",
" return result\n",
"\n",
"def fetch_history_coin(coin):\n",
" endpoint_url = \"https://min-api.cryptocompare.com/data/histoday?fsym={}&tsym=USD&limit={:d}\".format(coin, days_ago_to_fetch)\n",
" res = requests.get(endpoint_url)\n",
" hist = pd.DataFrame(json.loads(res.content)['Data'])\n",
" hist = index_history_coin(hist)\n",
" hist = filter_history_by_date(hist)\n",
" return hist\n",
"\n",
"def get_history_from_file(filename):\n",
" return pd.read_csv(filename)\n",
"\n",
"\n",
"for coin in coins:\n",
" coin_history[coin] = fetch_history_coin(coin)\n",
"\n",
"hist_length = len(coin_history[coins[0]])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CNeFMmvpx5ig"
},
"outputs": [],
"source": [
"# save to file\n",
"# coin_history['ETH'].to_csv(\"eth_price.csv\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_wPJcU8EyOsF"
},
"outputs": [],
"source": [
"# read from file\n",
"coin_history['ETH'] = get_history_from_file(\"eth_price.csv\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ij_DBZl7yQqE",
"outputId": "3b7838de-fa00-4560-cbcb-62c11e311a0f"
},
"outputs": [],
"source": [
"# calculate returns\n",
"\n",
"def add_all_returns():\n",
" for coin in coins:\n",
" hist = coin_history[coin]\n",
" hist['return'] = (hist['close'] - hist['open']) / hist['open']\n",
" average = hist[\"return\"].mean()\n",
" average_returns[coin] = average\n",
" cumulative_returns[coin] = (hist[\"return\"] + 1).prod() - 1\n",
" hist['excess_return'] = hist['return'] - average\n",
" coin_history[coin] = hist\n",
"\n",
"add_all_returns()\n",
"\n",
"# display data\n",
"cumulative_returns"
]
},
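{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the quantities computed above are the daily return $r_t = (c_t - o_t) / o_t$ from the open $o_t$ and close $c_t$, the cumulative return $R = \\prod_t (1 + r_t) - 1$, and the excess return $e_t = r_t - \\bar{r}$, where $\\bar{r}$ is the mean daily return."
]
},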
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nW3xWLCNyeGN",
"outputId": "a5d7f42b-447b-4ec4-8b42-4bd845ba3b3b"
},
"outputs": [],
"source": [
"# average return per day\n",
"average_returns"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "J9AOW3czykk-",
"outputId": "23d9564c-ae0e-4bb4-cf3e-bdb46e9e1639"
},
"outputs": [],
"source": [
"# Excess matrix\n",
"excess_matrix = np.zeros((hist_length, len(coins)))\n",
"\n",
"for i in range(0, hist_length):\n",
" for idx, coin in enumerate(coins):\n",
" excess_matrix[i][idx] = coin_history[coin].iloc[i]['excess_return']\n",
"\n",
"excess_matrix"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 424
},
"id": "A3GuFe_-yq_Q",
"outputId": "3313aa46-88ef-4dbb-e07d-6cc6c9584eb9"
},
"outputs": [],
"source": [
"# pretty print excess matrix\n",
"pretty_matrix = pd.DataFrame(excess_matrix).copy()\n",
"pretty_matrix.columns = coins\n",
"pretty_matrix.index = coin_history[coins[0]].index\n",
"\n",
"pretty_matrix"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "rRYQ63frys8K",
"outputId": "d4df3245-f4e6-4511-dd8a-7c7e52cb4982"
},
"outputs": [],
"source": [
"# Risk modelling\n",
"\n",
"# variance co-var matrix\n",
"product_matrix = np.matmul(excess_matrix.transpose(), excess_matrix)\n",
"var_covar_matrix = product_matrix / hist_length\n",
"\n",
"var_covar_matrix"
]
},
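{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above builds the (biased) sample variance-covariance matrix from the excess-return matrix $E \\in \\mathbb{R}^{T \\times n}$ ($T$ days, $n$ coins):\n",
"\n",
"$$\\Sigma = \\frac{1}{T} E^{\\top} E$$\n",
"\n",
"With a single coin this is a $1 \\times 1$ matrix containing the variance of its daily returns."
]
},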
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"id": "F_0y258Iz2-1",
"outputId": "f161c55d-5600-41da-f065-821c87340f33"
},
"outputs": [],
"source": [
"# pretty var_covar\n",
"pretty_var_covar = pd.DataFrame(var_covar_matrix).copy()\n",
"pretty_var_covar.columns = coins\n",
"pretty_var_covar.index = coins\n",
"\n",
"pretty_var_covar"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_HS7sq_Xz4Js"
},
"outputs": [],
"source": [
"# Std dev\n",
"\n",
"std_dev = np.zeros((len(coins), 1))\n",
"neg_std_dev = np.zeros((len(coins), 1))\n",
"\n",
"for idx, coin in enumerate(coins):\n",
" std_dev[idx][0] = np.std(coin_history[coin]['return'])\n",
" coin_history[coin]['downside_return'] = 0\n",
"\n",
" coin_history[coin].loc[coin_history[coin]['return'] < 0,\n",
" 'downside_return'] = coin_history[coin]['return']**2\n",
" neg_std_dev[idx][0] = np.sqrt(coin_history[coin]['downside_return'].mean())"
]
},
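{
"cell_type": "markdown",
"metadata": {},
"source": [
"`neg_std_dev` is the downside deviation: only days with negative returns contribute, so\n",
"\n",
"$$\\sigma^{-} = \\sqrt{\\frac{1}{T} \\sum_{t=1}^{T} \\min(r_t, 0)^2}$$"
]
},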
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"id": "gLpf-k1az77u",
"outputId": "d6dc063f-05b7-4a9f-de59-cc60b0cfee5e"
},
"outputs": [],
"source": [
"# pretty std\n",
"pretty_std = pd.DataFrame(std_dev).copy()\n",
"pretty_neg_std = pd.DataFrame(neg_std_dev).copy()\n",
"pretty_comb = pd.concat([pretty_std, pretty_neg_std], axis=1)\n",
"\n",
"pretty_comb.columns = ['std dev', 'neg std dev']\n",
"pretty_comb.index = coins\n",
"\n",
"pretty_comb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MAd19PQcz9A5"
},
"outputs": [],
"source": [
"# std_product_mat\n",
"std_product_matrix = np.matmul(std_dev, std_dev.transpose())\n",
"\n",
"# neg_prod_mat\n",
"neg_std_product_matrix = np.matmul(neg_std_dev, neg_std_dev.transpose())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"id": "O8aD5ZiYz-8E",
"outputId": "945391f5-8e3b-4369-cc0c-1cf927f72ef2"
},
"outputs": [],
"source": [
"pretty_std_prod = pd.DataFrame(std_product_matrix).copy()\n",
"pretty_std_prod.columns = coins\n",
"pretty_std_prod.index = coins\n",
"\n",
"pretty_std_prod"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81
},
"id": "KecwUguO0Ago",
"outputId": "f77f4bd2-3314-4d7e-8525-078825a83a8c"
},
"outputs": [],
"source": [
"# Corr matrix\n",
"corr_matrix = var_covar_matrix / std_product_matrix\n",
"pretty_corr = pd.DataFrame(corr_matrix).copy()\n",
"pretty_corr.columns = coins\n",
"pretty_corr.index = coins\n",
"\n",
"pretty_corr"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 424
},
"id": "m62kaZFu0C0p",
"outputId": "b3c10014-afe1-4cdb-e5a3-3a1361b54501"
},
"outputs": [],
"source": [
"# see additional stuff we have added to the DF\n",
"coin_history['ETH']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JdtkDBc90rM6",
"outputId": "19d1c7af-809e-4b1b-a1a0-b50bba67d425"
},
"outputs": [],
"source": [
"def simulate_portfolio_growth(initial_amount, daily_returns):\n",
" portfolio_value = [initial_amount]\n",
" for ret in daily_returns:\n",
" portfolio_value.append(portfolio_value[-1] * (1 + ret))\n",
" return portfolio_value\n",
"\n",
"initial_investment = 100000\n",
"\n",
"eth_portfolio = simulate_portfolio_growth(initial_investment, coin_history[\"ETH\"]['return'])\n",
"\n",
"print(\"ETH Portfolio Growth:\", eth_portfolio)"
]
},
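{
"cell_type": "markdown",
"metadata": {},
"source": [
"The simulation simply compounds the daily returns: starting from an initial value $V_0$, the portfolio value after day $t$ is\n",
"\n",
"$$V_t = V_0 \\prod_{s=1}^{t} (1 + r_s)$$"
]
},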
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 564
},
"id": "ADBacxHA27v0",
"outputId": "3f5f8f35-4efc-473d-a5af-12515fa897b6"
},
"outputs": [],
"source": [
"# Plotting the growth\n",
"plt.figure(figsize=(10,6))\n",
"plt.plot(eth_portfolio, label='ETH Portfolio', color='blue')\n",
"plt.title('Portfolio Growth Over Time')\n",
"plt.xlabel('Days')\n",
"plt.ylabel('Portfolio Value')\n",
"plt.legend()\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fWAx5OZ-302J",
"outputId": "a94346ac-4c16-428d-eac1-d8fbbd9208b4"
},
"outputs": [],
"source": [
"# close dataframe\n",
"eth_df = coin_history['ETH'][['close']].copy()\n",
"\n",
"# Convert to tensor\n",
"close_tensor = torch.tensor(eth_df.values)\n",
"\n",
"# return dataframe\n",
"eth_df = coin_history['ETH'][['return']].copy()\n",
"\n",
"# Convert to tensor\n",
"return_tensor = torch.tensor(eth_df.values)\n",
"\n",
"# return_tensor = torch.tensor(eth_df.values)\n",
"print(close_tensor)\n",
"print(return_tensor)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4M6hIqZ15aqs"
},
"outputs": [],
"source": [
"# Adapted from https://pypi.org/project/nbeats-pytorch/\n",
"# author = {Philippe Remy},\n",
"\n",
"def squeeze_last_dim(tensor):\n",
" if len(tensor.shape) == 3 and tensor.shape[-1] == 1: # (128, 10, 1) => (128, 10).\n",
" return tensor[..., 0]\n",
" return tensor\n",
"\n",
"\n",
"def seasonality_model(thetas, t, device):\n",
" p = thetas.size()[-1]\n",
" assert p <= thetas.shape[1], 'thetas_dim is too big.'\n",
" p1, p2 = (p // 2, p // 2) if p % 2 == 0 else (p // 2, p // 2 + 1)\n",
" s1 = torch.tensor(np.array([np.cos(2 * np.pi * i * t) for i in range(p1)])).float() # H/2-1\n",
" s2 = torch.tensor(np.array([np.sin(2 * np.pi * i * t) for i in range(p2)])).float()\n",
" S = torch.cat([s1, s2])\n",
" return thetas.mm(S.to(device))\n",
"\n",
"\n",
"def trend_model(thetas, t, device):\n",
" p = thetas.size()[-1]\n",
" assert p <= 4, 'thetas_dim is too big.'\n",
" T = torch.tensor(np.array([t ** i for i in range(p)])).float()\n",
" return thetas.mm(T.to(device))\n",
"\n",
"\n",
"def linear_space(backcast_length, forecast_length, is_forecast=True):\n",
" horizon = forecast_length if is_forecast else backcast_length\n",
" return np.arange(0, horizon) / horizon\n",
"\n",
"class Block(nn.Module):\n",
"\n",
" def __init__(self, units, thetas_dim, device, backcast_length=10, forecast_length=5, share_thetas=False,\n",
" nb_harmonics=None):\n",
" super(Block, self).__init__()\n",
" self.units = units\n",
" self.thetas_dim = thetas_dim\n",
" self.backcast_length = backcast_length\n",
" self.forecast_length = forecast_length\n",
" self.share_thetas = share_thetas\n",
" self.fc1 = nn.Linear(backcast_length, units)\n",
" self.fc2 = nn.Linear(units, units)\n",
" self.fc3 = nn.Linear(units, units)\n",
" self.fc4 = nn.Linear(units, units)\n",
" self.device = device\n",
" self.backcast_linspace = linear_space(backcast_length, forecast_length, is_forecast=False)\n",
" self.forecast_linspace = linear_space(backcast_length, forecast_length, is_forecast=True)\n",
" if share_thetas:\n",
" self.theta_f_fc = self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False)\n",
" else:\n",
" self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False)\n",
" self.theta_f_fc = nn.Linear(units, thetas_dim, bias=False)\n",
"\n",
" def forward(self, x):\n",
" x = squeeze_last_dim(x)\n",
" x = F.relu(self.fc1(x.to(self.device)))\n",
" x = F.relu(self.fc2(x))\n",
" x = F.relu(self.fc3(x))\n",
" x = F.relu(self.fc4(x))\n",
" return x\n",
"\n",
" def __str__(self):\n",
" block_type = type(self).__name__\n",
" return f'{block_type}(units={self.units}, thetas_dim={self.thetas_dim}, ' \\\n",
" f'backcast_length={self.backcast_length}, forecast_length={self.forecast_length}, ' \\\n",
" f'share_thetas={self.share_thetas}) at @{id(self)}'\n",
"\n",
"\n",
"class SeasonalityBlock(Block):\n",
"\n",
" def __init__(self, units, thetas_dim, device, backcast_length=10, forecast_length=5, nb_harmonics=None):\n",
" if nb_harmonics:\n",
" super(SeasonalityBlock, self).__init__(units, nb_harmonics, device, backcast_length,\n",
" forecast_length, share_thetas=True)\n",
" else:\n",
" super(SeasonalityBlock, self).__init__(units, forecast_length, device, backcast_length,\n",
" forecast_length, share_thetas=True)\n",
"\n",
" def forward(self, x):\n",
" x = super(SeasonalityBlock, self).forward(x)\n",
" backcast = seasonality_model(self.theta_b_fc(x), self.backcast_linspace, self.device)\n",
" forecast = seasonality_model(self.theta_f_fc(x), self.forecast_linspace, self.device)\n",
" return backcast, forecast\n",
"\n",
"\n",
"class TrendBlock(Block):\n",
"\n",
" def __init__(self, units, thetas_dim, device, backcast_length=10, forecast_length=5, nb_harmonics=None):\n",
" super(TrendBlock, self).__init__(units, thetas_dim, device, backcast_length,\n",
" forecast_length, share_thetas=True)\n",
"\n",
" def forward(self, x):\n",
" x = super(TrendBlock, self).forward(x)\n",
" backcast = trend_model(self.theta_b_fc(x), self.backcast_linspace, self.device)\n",
" forecast = trend_model(self.theta_f_fc(x), self.forecast_linspace, self.device)\n",
" return backcast, forecast\n",
"\n",
"\n",
"\n",
"class GenericBlock(Block):\n",
"\n",
" def __init__(self, units, thetas_dim, device, backcast_length=10, forecast_length=5, nb_harmonics=None):\n",
" super(GenericBlock, self).__init__(units, thetas_dim, device, backcast_length, forecast_length)\n",
"\n",
" self.backcast_fc = nn.Linear(thetas_dim, backcast_length)\n",
" self.forecast_fc = nn.Linear(thetas_dim, forecast_length)\n",
"\n",
" def forward(self, x):\n",
" # no constraint for generic arch.\n",
" x = super(GenericBlock, self).forward(x)\n",
"\n",
" theta_b = self.theta_b_fc(x)\n",
" theta_f = self.theta_f_fc(x)\n",
"\n",
" backcast = self.backcast_fc(theta_b) # generic. 3.3.\n",
" forecast = self.forecast_fc(theta_f) # generic. 3.3.\n",
"\n",
" return backcast, forecast\n",
"\n",
"\n",
"class NBEATS(nn.Module):\n",
" SEASONALITY_BLOCK = 'seasonality'\n",
" TREND_BLOCK = 'trend'\n",
" GENERIC_BLOCK = 'generic'\n",
"\n",
" def __init__(\n",
" self,\n",
" device=torch.device(\"cpu\"),\n",
" stack_types=(GENERIC_BLOCK, GENERIC_BLOCK),\n",
" nb_blocks_per_stack=1,\n",
" forecast_length=7,\n",
" backcast_length=14,\n",
" theta_dims=(2,2),\n",
" share_weights_in_stack=False,\n",
" hidden_layer_units=32,\n",
" nb_harmonics=None,\n",
" ):\n",
" super(NBEATS, self).__init__()\n",
" self.forecast_length = forecast_length\n",
" self.backcast_length = backcast_length\n",
" self.hidden_layer_units = hidden_layer_units\n",
" self.nb_blocks_per_stack = nb_blocks_per_stack\n",
" self.share_weights_in_stack = share_weights_in_stack\n",
" self.nb_harmonics = nb_harmonics # for seasonal data\n",
" self.stack_types = stack_types\n",
" self.stacks = nn.ModuleList()\n",
" self.thetas_dim = theta_dims\n",
" self.device = device\n",
" print('| N-Beats')\n",
" for stack_id in range(len(self.stack_types)):\n",
" stack = self.create_stack(stack_id)\n",
" self.stacks.append(stack)\n",
" self.to(self.device)\n",
" # self.asset_weight_layer = nn.Softmax(dim=1)\n",
" # self.asset_classes = asset_classes\n",
"\n",
"\n",
" def create_stack(self, stack_id):\n",
" stack_type = self.stack_types[stack_id]\n",
" print(f'| -- Stack {stack_type.title()} (#{stack_id}) (share_weights_in_stack={self.share_weights_in_stack})')\n",
" blocks = nn.ModuleList()\n",
" for block_id in range(self.nb_blocks_per_stack):\n",
" block_init = NBEATS.select_block(stack_type)\n",
" if self.share_weights_in_stack and block_id != 0:\n",
" block = blocks[-1] # pick up the last one when we share weights.\n",
" else:\n",
" block = block_init(\n",
" self.hidden_layer_units, self.thetas_dim[stack_id],\n",
" self.device, self.backcast_length, self.forecast_length,\n",
" self.nb_harmonics\n",
" )\n",
" print(f' | -- {block}')\n",
" blocks.append(block)\n",
" return blocks\n",
"\n",
" @staticmethod\n",
" def select_block(block_type):\n",
" if block_type == NBEATS.SEASONALITY_BLOCK:\n",
" return SeasonalityBlock\n",
" elif block_type == NBEATS.TREND_BLOCK:\n",
" return TrendBlock\n",
" else:\n",
" return GenericBlock\n",
"\n",
"\n",
" def get_generic_and_interpretable_outputs(self):\n",
" g_pred = sum([a['value'][0] for a in self._intermediary_outputs if 'generic' in a['layer'].lower()])\n",
" i_pred = sum([a['value'][0] for a in self._intermediary_outputs if 'generic' not in a['layer'].lower()])\n",
" outputs = {o['layer']: o['value'][0] for o in self._intermediary_outputs}\n",
" return g_pred, i_pred,\n",
"\n",
" def forward(self, backcast):\n",
" self._intermediary_outputs = []\n",
" backcast = squeeze_last_dim(backcast)\n",
" forecast = torch.zeros(size=(backcast.size()[0], self.forecast_length,)) # maybe batch size here.\n",
" for stack_id in range(len(self.stacks)):\n",
" for block_id in range(len(self.stacks[stack_id])):\n",
" b, f = self.stacks[stack_id][block_id](backcast)\n",
" backcast = backcast.to(self.device) - b\n",
" forecast = forecast.to(self.device) + f\n",
" block_type = self.stacks[stack_id][block_id].__class__.__name__\n",
" layer_name = f'stack_{stack_id}-{block_type}_{block_id}'\n",
"\n",
" return backcast, forecast\n"
]
},
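{
"cell_type": "markdown",
"metadata": {},
"source": [
"The interpretable blocks project their $\\theta$ coefficients onto fixed basis functions over the normalized time grid $t \\in [0, 1)$ produced by `linear_space`: `trend_model` uses a low-order polynomial basis and `seasonality_model` a Fourier basis,\n",
"\n",
"$$\\hat{y}_{\\mathrm{trend}}(t) = \\sum_{i=0}^{p-1} \\theta_i t^{i}, \\qquad \\hat{y}_{\\mathrm{seas}}(t) = \\sum_{i=0}^{p_1-1} \\theta_i \\cos(2\\pi i t) + \\sum_{j=0}^{p_2-1} \\theta_{p_1+j} \\sin(2\\pi j t)$$\n",
"\n",
"`GenericBlock` instead learns unconstrained linear maps from $\\theta$ to the backcast and forecast, which is the configuration this notebook actually trains (`stack_types=(GENERIC_BLOCK, GENERIC_BLOCK)`)."
]
},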
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tTH313qbRLMG",
"outputId": "34275438-5548-4a8b-f4a9-d96056ebde1c"
},
"outputs": [],
"source": [
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"class TimeSeriesDataset(Dataset):\n",
" def __init__(self, close_data, return_data, backcast_length, forecast_length, shuffle=True):\n",
" self.close_data = close_data\n",
" self.return_data = return_data\n",
" self.backcast_length = backcast_length\n",
" self.forecast_length = forecast_length\n",
" self.indices = list(range(len(self.close_data) - self.backcast_length - self.forecast_length + 1))\n",
" if shuffle:\n",
" np.random.shuffle(self.indices)\n",
"\n",
" def __len__(self):\n",
" return len(self.close_data) - self.backcast_length - self.forecast_length + 1\n",
"\n",
" def __getitem__(self, idx):\n",
" start = idx\n",
" end = idx + self.backcast_length\n",
" x = self.close_data[start:end] # Take columns as needed\n",
" y = self.close_data[end:end+self.forecast_length] # Adjust as per forecast columns needed\n",
" return x, y\n",
"\n",
"# Hyperparameters\n",
"BACKCAST_LENGTH = 14\n",
"FORECAST_LENGTH = 7\n",
"\n",
"train_length = round(len(close_tensor) * 0.7)\n",
"train_dataset = TimeSeriesDataset(close_tensor[0:train_length], return_tensor[0:train_length], BACKCAST_LENGTH, FORECAST_LENGTH)\n",
"test_dataset = TimeSeriesDataset(close_tensor[train_length:], return_tensor[train_length:], BACKCAST_LENGTH, FORECAST_LENGTH)\n",
"train_loader = DataLoader(train_dataset)\n",
"\n",
"model = NBEATS(forecast_length=FORECAST_LENGTH, backcast_length=BACKCAST_LENGTH, device=('cuda' if torch.cuda.is_available() else 'cpu'))\n",
"model = model.to('cuda' if torch.cuda.is_available() else 'cpu')\n"
]
},
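{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check of the windowing: each sample should be a backcast\n",
"# window of shape (batch, BACKCAST_LENGTH, 1) paired with a forecast window\n",
"# of shape (batch, FORECAST_LENGTH, 1).\n",
"x_sample, y_sample = next(iter(train_loader))\n",
"print(\"backcast window:\", x_sample.shape)\n",
"print(\"forecast window:\", y_sample.shape)"
]
},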
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JJ8-nh2GLKN_",
"outputId": "0d761daa-0f14-4a50-be41-b17993a4a182"
},
"outputs": [],
"source": [
"EPOCHS = 1\n",
"\n",
"num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Number of trainable parameters in model: {num_parameters}\")\n",
"\n",
"criterion = torch.nn.L1Loss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
"\n",
"for epoch in range(EPOCHS):\n",
" total_loss = 0.0\n",
" for batch_idx, (x, y) in enumerate(train_loader):\n",
" # Zero gradients\n",
" optimizer.zero_grad()\n",
"\n",
" x = x.clone().detach().to(dtype=torch.float)\n",
" x = x.to('cuda' if torch.cuda.is_available() else 'cpu')\n",
" y = y.clone().detach().to(dtype=torch.float)\n",
" y = y.to('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"\n",
" # Forward pass\n",
" forecast = model(x)\n",
"\n",
" loss = criterion(forecast[0], y)\n",
"\n",
" # Backprop and optimize\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" # add positive gain for logging\n",
" total_loss += loss # add the positive gain_loss for logging\n",
"\n",
" avg_loss = total_loss / len(train_loader)\n",
" print(f\"Epoch {epoch+1}/{EPOCHS}, Average Loss: {avg_loss:.4f}\")"
]
},
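{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal held-out evaluation sketch: test_dataset was built above, so we\n",
"# report the same L1 criterion on its windows before exporting the model.\n",
"test_loader = DataLoader(test_dataset)\n",
"\n",
"model.eval()\n",
"test_loss = 0.0\n",
"with torch.no_grad():\n",
"    for x, y in test_loader:\n",
"        x = x.to(dtype=torch.float).to(device)\n",
"        y = y.to(dtype=torch.float).to(device)\n",
"        _, forecast = model(x)\n",
"        test_loss += criterion(forecast, squeeze_last_dim(y)).item()\n",
"\n",
"print(f\"Test Average L1 Loss: {test_loss / len(test_loader):.4f}\")"
]
},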
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jWLKwNFLYDOk"
},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import ezkl\n",
"import os\n",
"import json"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dhOHiCmt4pUn"
},
"outputs": [],
"source": [
"model_path = os.path.join('network.onnx')\n",
"compiled_model_path = os.path.join('network.compiled')\n",
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"settings_path = os.path.join('settings.json')\n",
"\n",
"witness_path = os.path.join('witness.json')\n",
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xsZ9xg7I48l4",
"outputId": "6dec08c6-f55e-4df1-b957-55d025286018"
},
"outputs": [],
"source": [
"# After training, export to onnx (network.onnx) and create a data file (input.json)\n",
"x_export = None\n",
"for batch_idx, (x, y) in enumerate(train_loader):\n",
" x_export = x.clone().detach().to(dtype=torch.float)\n",
" break\n",
"\n",
"# Flips the neural net into inference mode\n",
"model.eval()\n",
"\n",
" # Export the model\n",
"torch.onnx.export(model, # model being run\n",
" x_export, # model input (or a tuple for multiple inputs)\n",
" model_path, # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=10, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"data_array = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data = [data_array])\n",
"\n",
" # Serialize data into file:\n",
"json.dump( data, open(data_path, 'w' ))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5qdEFK_75GUb"
},
"outputs": [],
"source": [
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"private\"\n",
"run_args.param_visibility = \"fixed\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"\n",
"!RUST_LOG=trace\n",
"# TODO: Dictionary outputs\n",
"res = ezkl.gen_settings(model_path, settings_path)\n",
"assert res == True\n",
"\n",
"res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", max_logrows = 20, scales = [3])\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pxDJPz-Q5LPF"
},
"outputs": [],
"source": [
"res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)\n",
"assert res == True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ptcb4SGA5Qeb"
},
"outputs": [],
"source": [
"# srs path\n",
"res = await ezkl.get_srs( settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OE7t0okU5WBQ"
},
"outputs": [],
"source": [
"res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)\n",
"assert os.path.isfile(witness_path)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"id": "12YIcFr85X9-"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"spawning module 2\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"quotient_poly_degree 4\n",
"n 262144\n",
"extended_k 20\n"
]
}
],
"source": [
"res = ezkl.setup(\n",
" compiled_model_path,\n",
" vk_path,\n",
" pk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CSbWeZB35awS"
},
"outputs": [],
"source": [
"proof_path = os.path.join('test.pf')\n",
"\n",
"res = ezkl.prove(\n",
" witness_path,\n",
" compiled_model_path,\n",
" pk_path,\n",
" proof_path,\n",
" \n",
" \"single\",\n",
" )\n",
"\n",
"print(res)\n",
"assert os.path.isfile(proof_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aGt8f4LS5dTP"
},
"outputs": [],
"source": [
"# VERIFY IT\n",
"\n",
"res = ezkl.verify(\n",
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
"print(\"verified\")"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 0
}