dtypes.default_float and dtypes.default_int (#2824)

chenyu authored on 2023-12-18 12:21:44 -05:00, committed by GitHub
parent 8aab19ce3d
commit 0723f26c80
9 changed files with 108 additions and 82 deletions
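
In short: the class-level Tensor.default_type is replaced by a module-level dtypes.default_float (and, per the title, a matching dtypes.default_int). A minimal sketch of the new usage, assuming dtypes is importable from the top-level tinygrad package at this commit (it may instead live in tinygrad.helpers or tinygrad.dtype):

from tinygrad import Tensor, dtypes  # import path is an assumption for this commit

# before: Tensor.default_type = dtypes.float16
# after: the run-wide float dtype is set once on dtypes and read back where needed
dtypes.default_float = dtypes.float16
x = Tensor.ones(4, 4).cast(dtypes.default_float)
print(x.dtype)  # dtypes.float16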

@@ -12,7 +12,7 @@ if __name__ == "__main__":
 # https://siboehm.com/articles/22/CUDA-MMM
 import random, time
 import numpy as np
-from typing import Any, Dict, Optional, SupportsIndex, Type, Union
+from typing import Any, Dict, Optional, SupportsIndex
 from extra.datasets import fetch_cifar, cifar_mean, cifar_std
 from tinygrad import nn
 from tinygrad.nn.state import get_state_dict
@@ -27,10 +27,10 @@ from tinygrad.jit import TinyJit
 BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000)
 if getenv("HALF", 0):
-  Tensor.default_type = dtypes.float16
-  np_dtype: Type[Union[np.float16, np.float32]] = np.float16
+  dtypes.default_float = dtypes.float16
+  np_dtype = np.float16
 else:
-  Tensor.default_type = dtypes.float32
+  dtypes.default_float = dtypes.float32
   np_dtype = np.float32
 class BatchNorm(nn.BatchNorm2d):
@@ -52,13 +52,13 @@ class ConvGroup:
     x = x.max_pool2d(2)
     x = x.float()
     x = self.norm1(x)
-    x = x.cast(Tensor.default_type)
+    x = x.cast(dtypes.default_float)
     x = x.gelu()
     residual = x
     x = self.conv2(x)
     x = x.float()
     x = self.norm2(x)
-    x = x.cast(Tensor.default_type)
+    x = x.cast(dtypes.default_float)
     x = x.gelu()
     return x + residual
@@ -277,7 +277,8 @@ def train_cifar():
   X_train = pad_reflect(X_train, size=hyp['net']['pad_amount'])
   # Convert data and labels to the default dtype
-  X_train, Y_train, X_test, Y_test = X_train.cast(Tensor.default_type), Y_train.cast(Tensor.default_type), X_test.cast(Tensor.default_type), Y_test.cast(Tensor.default_type)
+  X_train, Y_train = X_train.cast(dtypes.default_float), Y_train.cast(dtypes.default_float)
+  X_test, Y_test = X_test.cast(dtypes.default_float), Y_test.cast(dtypes.default_float)
   # parse the training params into bias and non-bias
   params_dict = get_state_dict(model)
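
The hunks above only exercise dtypes.default_float; by analogy, a guess at how the dtypes.default_int named in the commit title would be used (the int32 default and the import path are assumptions, not taken from this diff):

from tinygrad import Tensor, dtypes  # import path is an assumption for this commit

# assumed analogue of default_float for integer data
dtypes.default_int = dtypes.int32
labels = Tensor([3, 1, 4, 1, 5]).cast(dtypes.default_int)
print(labels.dtype)  # dtypes.int32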