mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
dtypes.default_float and dtypes.default_int (#2824)
This commit is contained in:
@@ -12,7 +12,7 @@ if __name__ == "__main__":
|
||||
# https://siboehm.com/articles/22/CUDA-MMM
|
||||
import random, time
|
||||
import numpy as np
|
||||
from typing import Any, Dict, Optional, SupportsIndex, Type, Union
|
||||
from typing import Any, Dict, Optional, SupportsIndex
|
||||
from extra.datasets import fetch_cifar, cifar_mean, cifar_std
|
||||
from tinygrad import nn
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
@@ -27,10 +27,10 @@ from tinygrad.jit import TinyJit
|
||||
BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000)
|
||||
|
||||
if getenv("HALF", 0):
|
||||
Tensor.default_type = dtypes.float16
|
||||
np_dtype: Type[Union[np.float16, np.float32]] = np.float16
|
||||
dtypes.default_float = dtypes.float16
|
||||
np_dtype = np.float16
|
||||
else:
|
||||
Tensor.default_type = dtypes.float32
|
||||
dtypes.default_float = dtypes.float32
|
||||
np_dtype = np.float32
|
||||
|
||||
class BatchNorm(nn.BatchNorm2d):
|
||||
@@ -52,13 +52,13 @@ class ConvGroup:
|
||||
x = x.max_pool2d(2)
|
||||
x = x.float()
|
||||
x = self.norm1(x)
|
||||
x = x.cast(Tensor.default_type)
|
||||
x = x.cast(dtypes.default_float)
|
||||
x = x.gelu()
|
||||
residual = x
|
||||
x = self.conv2(x)
|
||||
x = x.float()
|
||||
x = self.norm2(x)
|
||||
x = x.cast(Tensor.default_type)
|
||||
x = x.cast(dtypes.default_float)
|
||||
x = x.gelu()
|
||||
|
||||
return x + residual
|
||||
@@ -277,7 +277,8 @@ def train_cifar():
|
||||
X_train = pad_reflect(X_train, size=hyp['net']['pad_amount'])
|
||||
|
||||
# Convert data and labels to the default dtype
|
||||
X_train, Y_train, X_test, Y_test = X_train.cast(Tensor.default_type), Y_train.cast(Tensor.default_type), X_test.cast(Tensor.default_type), Y_test.cast(Tensor.default_type)
|
||||
X_train, Y_train = X_train.cast(dtypes.default_float), Y_train.cast(dtypes.default_float)
|
||||
X_test, Y_test = X_test.cast(dtypes.default_float), Y_test.cast(dtypes.default_float)
|
||||
|
||||
# parse the training params into bias and non-bias
|
||||
params_dict = get_state_dict(model)
|
||||
|
||||
Reference in New Issue
Block a user