dtypes.default_float and dtypes.default_int (#2824)

Author: chenyu
Date: 2023-12-18 12:21:44 -05:00
Committed by: GitHub
Commit: 0723f26c80 (parent 8aab19ce3d)
9 changed files with 108 additions and 82 deletions
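
Summary: this commit replaces the single Tensor.default_type class variable with two class variables, dtypes.default_float and dtypes.default_int, which the Tensor constructors fall back to when no explicit dtype is passed. A minimal sketch of the new surface, assuming a tinygrad checkout at this commit:

from tinygrad import Tensor, dtypes

# floats created with no explicit dtype follow dtypes.default_float
dtypes.default_float = dtypes.float16
assert Tensor(2.34).dtype == dtypes.float16

# ints created with no explicit dtype follow dtypes.default_int
dtypes.default_int = dtypes.int64
assert Tensor(2).dtype == dtypes.int64

# restore the shipped defaults (float32 / int32)
dtypes.default_float, dtypes.default_int = dtypes.float32, dtypes.int32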

View File

@@ -12,7 +12,7 @@ if __name__ == "__main__":
# https://siboehm.com/articles/22/CUDA-MMM
import random, time
import numpy as np
from typing import Any, Dict, Optional, SupportsIndex, Type, Union
from typing import Any, Dict, Optional, SupportsIndex
from extra.datasets import fetch_cifar, cifar_mean, cifar_std
from tinygrad import nn
from tinygrad.nn.state import get_state_dict
@@ -27,10 +27,10 @@ from tinygrad.jit import TinyJit
BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000)
if getenv("HALF", 0):
Tensor.default_type = dtypes.float16
np_dtype: Type[Union[np.float16, np.float32]] = np.float16
dtypes.default_float = dtypes.float16
np_dtype = np.float16
else:
Tensor.default_type = dtypes.float32
dtypes.default_float = dtypes.float32
np_dtype = np.float32
class BatchNorm(nn.BatchNorm2d):
@@ -52,13 +52,13 @@ class ConvGroup:
x = x.max_pool2d(2)
x = x.float()
x = self.norm1(x)
x = x.cast(Tensor.default_type)
x = x.cast(dtypes.default_float)
x = x.gelu()
residual = x
x = self.conv2(x)
x = x.float()
x = self.norm2(x)
x = x.cast(Tensor.default_type)
x = x.cast(dtypes.default_float)
x = x.gelu()
return x + residual
@@ -277,7 +277,8 @@ def train_cifar():
X_train = pad_reflect(X_train, size=hyp['net']['pad_amount'])
# Convert data and labels to the default dtype
X_train, Y_train, X_test, Y_test = X_train.cast(Tensor.default_type), Y_train.cast(Tensor.default_type), X_test.cast(Tensor.default_type), Y_test.cast(Tensor.default_type)
X_train, Y_train = X_train.cast(dtypes.default_float), Y_train.cast(dtypes.default_float)
X_test, Y_test = X_test.cast(dtypes.default_float), Y_test.cast(dtypes.default_float)
# parse the training params into bias and non-bias
params_dict = get_state_dict(model)
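
Condensed, the HALF switch in this file now only has to flip dtypes.default_float; a sketch of the resulting pattern (np_dtype is the module-level numpy dtype the script keeps alongside it):

import numpy as np
from tinygrad import dtypes
from tinygrad.helpers import getenv

# HALF=1 trains in fp16; both tinygrad's default float and the numpy dtype follow it
if getenv("HALF", 0):
  dtypes.default_float, np_dtype = dtypes.float16, np.float16
else:
  dtypes.default_float, np_dtype = dtypes.float32, np.float32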

View File

@@ -482,8 +482,8 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None,
output_shape = sizes if sizes else [math.floor(x*s) for x,s in zip(X.shape, scales)]
output_shape_ = sizes if sizes else [x*s for x,s in zip(X.shape, scales)]
scales_ = [os/xs for xs, os in zip(X.shape, output_shape)]
x_out = Tensor.arange(output_shape[-1]).cast(Tensor.default_type)
y_out = Tensor.arange(output_shape[-2]).cast(Tensor.default_type)
x_out = Tensor.arange(output_shape[-1]).cast(dtypes.default_float)
y_out = Tensor.arange(output_shape[-2]).cast(dtypes.default_float)
if mode == "nearest":
x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_, roi)
x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1])
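
The two changed lines in Resize just build the output coordinate grid in the configured default float dtype instead of Tensor.default_type; a tiny sketch of what they produce:

from tinygrad import Tensor, dtypes

# arange is integer by default, so the coordinate grid is cast to the default float dtype
x_out = Tensor.arange(8).cast(dtypes.default_float)
assert x_out.dtype == dtypes.default_float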

View File

@@ -15,8 +15,8 @@ def helper_test_jitted_correctness(gen, train, train_jit):
class TestJittedModels(unittest.TestCase):
def test_jitted_tiny_llama(self):
old_type = Tensor.default_type
Tensor.default_type = dtypes.float16
old_float = dtypes.default_float
dtypes.default_float = dtypes.float16
args_tiny = {"dim": 1024, "hidden_dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
model = Transformer(**args_tiny)
@@ -26,7 +26,7 @@ class TestJittedModels(unittest.TestCase):
@TinyJit
def test_jit(t): return model(t, 0).realize()
helper_test_jitted_correctness(lambda: (Tensor([[1,]]),), test, test_jit)
Tensor.default_type = old_type
dtypes.default_float = old_float
@unittest.skipUnless(not CI, "huge for CI")
def test_jitted_stable_diffusion(self):
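
The hunk above stashes and restores dtypes.default_float by hand around the fp16 case. If that save/restore pattern recurs, it could be wrapped in a small context manager; the helper below (default_float_as) is hypothetical and not part of this commit:

from contextlib import contextmanager
from tinygrad import dtypes

@contextmanager
def default_float_as(dt):
  # hypothetical helper: temporarily override the global default float dtype
  old = dtypes.default_float
  dtypes.default_float = dt
  try:
    yield
  finally:
    dtypes.default_float = old

# usage sketch: with default_float_as(dtypes.float16): model = Transformer(**args_tiny)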

View File

@@ -40,11 +40,11 @@ class TestRealWorld(unittest.TestCase):
gc.collect()
global global_mem_used
global_mem_used = GlobalCounters.mem_used
self.old_type = Tensor.default_type
self.old_float = dtypes.default_float
np.random.seed(2002)
def tearDown(self):
Tensor.default_type = self.old_type
dtypes.default_float = self.old_float
@unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
@unittest.skipIf(CI, "too big for CI")
@@ -68,7 +68,7 @@ class TestRealWorld(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
@unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
def test_llama(self):
Tensor.default_type = dtypes.float16
dtypes.default_float = dtypes.float16
args_tiny = {"dim": 1024, "hidden_dim": 2048, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
model = LLaMaTransformer(**(args_tiny if CI else LLAMA_MODEL_PARAMS["1"]["7B"]["args"]))
@@ -80,7 +80,7 @@ class TestRealWorld(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
def test_gpt2(self):
Tensor.default_type = dtypes.float16
dtypes.default_float = dtypes.float16
args_tiny = {"dim": 1024, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-5, "vocab_size": 1000}
model = GPT2Transformer(**(args_tiny if CI else GPT2_MODEL_PARAMS["gpt2-medium"]))

View File

@@ -224,56 +224,78 @@ class TestHelpers(unittest.TestCase):
assert dtype.vec(amt).scalar() == dtype
class TestTypeSpec(unittest.TestCase):
def test_creation(self):
assert Tensor(True).dtype == dtypes.bool
assert Tensor(2).dtype == dtypes.int
assert Tensor(2.34).dtype == Tensor.default_type
assert Tensor([]).dtype == Tensor.default_type
assert Tensor([1]).dtype == dtypes.int
assert Tensor([1.1]).dtype == Tensor.default_type
assert Tensor([0, 1], dtype=dtypes.bfloat16).dtype == dtypes.bfloat16
def setUp(self):
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
def tearDown(self):
dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
def test_const_full(self):
assert Tensor.ones([2,3]).dtype == Tensor.default_type
assert Tensor.zeros([2,3]).dtype == Tensor.default_type
assert Tensor.full([2,3], 3.3).dtype == Tensor.default_type
assert Tensor.full([2,3], 3).dtype == dtypes.int
def test_set_dtype_default(self):
dtypes.default_int = dtypes.int16
assert dtypes.default_int == dtypes.int16
dtypes.default_int = dtypes.int64
assert dtypes.default_int == dtypes.int64
dtypes.default_int = dtypes.int32
assert dtypes.default_int == dtypes.int32
dtypes.default_float = dtypes.float16
assert dtypes.default_float == dtypes.float16
dtypes.default_float = dtypes.float64
assert dtypes.default_float == dtypes.float64
@given(st.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), st.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
def test_creation(self, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
assert Tensor(True).dtype == dtypes.bool
assert Tensor(2).dtype == dtypes.default_int
assert Tensor(2.34).dtype == dtypes.default_float
assert Tensor([]).dtype == dtypes.default_float
assert Tensor([1]).dtype == dtypes.default_int
assert Tensor([1.1]).dtype == dtypes.default_float
assert Tensor([0,1], dtype=dtypes.bfloat16).dtype == dtypes.bfloat16
assert Tensor.eye(0).dtype == dtypes.default_float
assert Tensor.eye(3).dtype == dtypes.default_float
assert Tensor.eye(3, dtype=dtypes.float16).dtype == dtypes.float16
assert Tensor.eye(3, dtype=dtypes.int64).dtype == dtypes.int64
@given(st.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), st.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
def test_full(self, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
assert Tensor.ones([2,3]).dtype == dtypes.default_float
assert Tensor.zeros([2,3]).dtype == dtypes.default_float
assert Tensor.full([2,3], 3.3).dtype == dtypes.default_float
assert Tensor.full([2,3], 3).dtype == dtypes.default_int
assert Tensor.full([2,3], True).dtype == dtypes.bool
assert Tensor.zeros(3, 3).dtype == dtypes.default_float
assert Tensor.zeros(3, 3, dtype=dtypes.float16).dtype == dtypes.float16
assert Tensor.zeros(3, 3, dtype=dtypes.int64).dtype == dtypes.int64
assert Tensor.ones(3, 3).dtype == dtypes.default_float
assert Tensor.ones(3, 3, dtype=dtypes.float16).dtype == dtypes.float16
assert Tensor.ones(3, 3, dtype=dtypes.int64).dtype == dtypes.int64
assert Tensor.full((3, 3), 3).dtype == dtypes.default_int
assert Tensor.full((3, 3), 3.0).dtype == dtypes.default_float
assert Tensor.full((3, 3), 3, dtype=dtypes.float16).dtype == dtypes.float16
assert Tensor.full((3, 3), 3, dtype=dtypes.int64).dtype == dtypes.int64
def test_reduce_0d_default(self):
assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type
assert Tensor.ones([2,3,0]).sum(2).dtype == dtypes.default_float
# assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int # requires reduceop acc fix
def test_arange(self):
assert Tensor.arange(5).dtype == dtypes.int32
assert Tensor.arange(5.0).dtype == Tensor.default_type
@given(st.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), st.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
def test_arange(self, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
assert Tensor.arange(5).dtype == dtypes.default_int
assert Tensor.arange(5.0).dtype == dtypes.default_float
assert Tensor.arange(5, dtype=dtypes.int16).dtype == dtypes.int16
assert Tensor.arange(5, dtype=dtypes.int64).dtype == dtypes.int64
assert Tensor.arange(5, dtype=dtypes.float16).dtype == dtypes.float16
assert Tensor.arange(3, 9, 0.7).dtype == Tensor.default_type
assert Tensor.arange(3, 8.5, 3).dtype == Tensor.default_type
def test_zeros(self):
assert Tensor.zeros(3, 3).dtype == Tensor.default_type
assert Tensor.zeros(3, 3, dtype= dtypes.float16).dtype == dtypes.float16
assert Tensor.zeros(3, 3, dtype= dtypes.int64).dtype == dtypes.int64
def test_ones(self):
assert Tensor.ones(3, 3).dtype == Tensor.default_type
assert Tensor.ones(3, 3, dtype= dtypes.float16).dtype == dtypes.float16
assert Tensor.ones(3, 3, dtype= dtypes.int64).dtype == dtypes.int64
def test_full(self):
assert Tensor.full((3, 3), 3).dtype == dtypes.int
assert Tensor.full((3, 3), 3.0).dtype == Tensor.default_type
assert Tensor.full((3, 3), 3, dtype= dtypes.float16).dtype == dtypes.float16
assert Tensor.full((3, 3), 3, dtype= dtypes.int64).dtype == dtypes.int64
def test_eye(self):
assert Tensor.eye(0).dtype == Tensor.default_type
assert Tensor.eye(3).dtype == Tensor.default_type
assert Tensor.eye(3, dtype= dtypes.float16).dtype == dtypes.float16
assert Tensor.eye(3, dtype= dtypes.int64).dtype == dtypes.int64
assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float
assert Tensor.arange(3, 8.5, 3).dtype == dtypes.default_float
core_types = list(DTYPES_DICT.values())
floats = [dt for dt in core_types if dtypes.is_float(dt)]
@@ -331,10 +353,10 @@ class TestAutoCastType(unittest.TestCase):
np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-4, atol=1e-4)
def test_broadcast_float(self):
assert (Tensor.rand(4, 4, dtype=dtypes.bool) + 2.3).dtype == Tensor.default_type
assert (Tensor.rand(4, 4, dtype=dtypes.int) + 2.3).dtype == Tensor.default_type
assert (Tensor.rand(4, 4, dtype=dtypes.int8) + 2.3).dtype == Tensor.default_type
assert (Tensor.rand(4, 4, dtype=dtypes.uint64) + 2.3).dtype == Tensor.default_type
assert (Tensor.rand(4, 4, dtype=dtypes.bool) + 2.3).dtype == dtypes.default_float
assert (Tensor.rand(4, 4, dtype=dtypes.int) + 2.3).dtype == dtypes.default_float
assert (Tensor.rand(4, 4, dtype=dtypes.int8) + 2.3).dtype == dtypes.default_float
assert (Tensor.rand(4, 4, dtype=dtypes.uint64) + 2.3).dtype == dtypes.default_float
assert (Tensor.rand(4, 4, dtype=dtypes.float16) + 2.3).dtype == dtypes.float16
assert (Tensor.rand(4, 4, dtype=dtypes.bfloat16) + 2.3).dtype == dtypes.bfloat16
assert (Tensor.rand(4, 4, dtype=dtypes.float32) + 2.3).dtype == dtypes.float32
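
The broadcast assertions in the last hunk boil down to: a python float operand promotes any bool/int tensor to dtypes.default_float, while an existing float dtype wins. A hedged restatement:

from tinygrad import Tensor, dtypes

# python float + non-float tensor -> the configured default float dtype
assert (Tensor.ones(4, dtype=dtypes.int8) + 2.3).dtype == dtypes.default_float
# python float + float tensor -> the tensor's own float dtype is kept
assert (Tensor.ones(4, dtype=dtypes.float16) + 2.3).dtype == dtypes.float16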

View File

@@ -1,7 +1,6 @@
#!/usr/bin/env python
import unittest
from tinygrad import Tensor, Device
from tinygrad.helpers import dtypes
from tinygrad import Tensor, Device, dtypes
from examples.beautiful_mnist import Model as MNIST
from examples.hlb_cifar10 import SpeedyResNet
@@ -23,8 +22,8 @@ class TestHIPCompilationRDNA(unittest.TestCase):
output.numpy()
def test_compile_hip_speedyresnet_hf(self):
old_default_type = Tensor.default_type
Tensor.default_type = dtypes.float16
old_default_float = dtypes.default_float
dtypes.default_float = dtypes.float16
W = Tensor.rand(12,3,2,2)
model = SpeedyResNet(W)
@@ -33,7 +32,7 @@ class TestHIPCompilationRDNA(unittest.TestCase):
output = model(input)
output.numpy()
Tensor.default_type = old_default_type
dtypes.default_float = old_default_float
if __name__ == "__main__":
unittest.main()
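
Sketch of the fp16 compile test above reduced to its dtype bookkeeping: flip the default, build the fp16 tensors, then put the old default back (here with try/finally instead of the test's explicit restore):

from tinygrad import Tensor, dtypes

old_default_float = dtypes.default_float
dtypes.default_float = dtypes.float16
try:
  x = Tensor.rand(1, 3, 32, 32)   # picks up float16 via the new default
  assert x.dtype == dtypes.float16
finally:
  dtypes.default_float = old_default_float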

View File

@@ -252,7 +252,7 @@ class TestTinygrad(unittest.TestCase):
# empty tensor defaults
for arr in ([], [[[]]], [[],[]]):
t = Tensor(arr)
assert t.dtype == Tensor.default_type
assert t.dtype == dtypes.default_float
np.testing.assert_allclose(t.numpy(), np.array(arr))
# mixture of bool and int
@@ -264,7 +264,7 @@ class TestTinygrad(unittest.TestCase):
# mixture of bool, int and float
for arr in ([[True,True],[3.,True]], [[0,1],[3.,4]], [[[0],[1]],[[3.],[4]]], [[[True],[1]],[[3.],[4]]]):
t = Tensor(arr)
assert t.dtype == Tensor.default_type
assert t.dtype == dtypes.default_float
np.testing.assert_allclose(t.numpy(), np.array(arr))
def test_tensor_list_shapes(self):
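
The list-construction rules these tests rely on, restated as a short sketch: all-bool lists stay bool, all-int lists take dtypes.default_int, and anything containing a float (or an empty list) takes dtypes.default_float:

from tinygrad import Tensor, dtypes

assert Tensor([True, False]).dtype == dtypes.bool
assert Tensor([1, 2]).dtype == dtypes.default_int
assert Tensor([[True, True], [3., True]]).dtype == dtypes.default_float  # mixed bool/float
assert Tensor([]).dtype == dtypes.default_float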

View File

@@ -181,6 +181,9 @@ class dtypes:
@staticmethod
def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp, dtypes.float32)
default_float: ClassVar[DType] = float32
default_int: ClassVar[DType] = int32
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
# we don't support weak type and complex type
promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8],
@@ -198,7 +201,8 @@ def least_upper_dtype(*ds:DType) -> DType:
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and v.__class__ is not staticmethod}
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if (
not k.startswith('__') and not k.startswith('default') and not callable(v) and v.__class__ is not staticmethod)}
INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}
class GlobalCounters:
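
Because DTYPES_DICT now also skips names starting with "default", the two new class variables are not reported as dtypes; a quick check, assuming DTYPES_DICT stays importable from tinygrad.helpers as in this file:

from tinygrad.helpers import DTYPES_DICT

assert "default_float" not in DTYPES_DICT and "default_int" not in DTYPES_DICT
assert "float32" in DTYPES_DICT and "int32" in DTYPES_DICT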

View File

@@ -46,7 +46,6 @@ class Tensor:
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): Tensor.training = self.prev
no_grad: ClassVar[bool] = False
default_type: ClassVar[DType] = dtypes.float32
def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray, bytes], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None): # noqa: E501
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
device = Device.canonicalize(device)
@@ -61,14 +60,14 @@ class Tensor:
self._ctx: Optional[Function] = None
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
elif isinstance(data, bool): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.bool, device, data)
elif isinstance(data, int): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.int32, device, data)
elif isinstance(data, float): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
elif isinstance(data, int): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.default_int, device, data)
elif isinstance(data, float): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.default_float, device, data)
elif isinstance(data, bytes): data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8))
elif data is None: data = LazyBuffer.fromCPU(np.array([], dtype=(dtype or Tensor.default_type).np))
elif data is None: data = LazyBuffer.fromCPU(np.array([], dtype=(dtype or dtypes.default_float).np))
elif isinstance(data, list):
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool
elif d and all_int(d): dtype = dtype or dtypes.int32
else: dtype = dtype or Tensor.default_type
elif d and all_int(d): dtype = dtype or dtypes.default_int
else: dtype = dtype or dtypes.default_float
# NOTE: cast at the end for the types that do not have a numpy dtype
data = LazyBuffer.fromCPU(np.array(data, dtype.np)).cast(dtype)
elif isinstance(data, np.ndarray):
@@ -153,7 +152,7 @@ class Tensor:
@staticmethod
def _loadop(op, sz, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs):
assert isinstance(sz, int), f"cannot create with symbolic size {sz}"
return Tensor(LazyBuffer.loadop(op, (sz,), Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs) # noqa: E501
return Tensor(LazyBuffer.loadop(op, (sz,), dtype or dtypes.default_float, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)
@staticmethod
def empty(*shape, **kwargs):
@@ -171,8 +170,9 @@ class Tensor:
@staticmethod
def full(shape:Tuple[sint, ...], fill_value, **kwargs):
# TODO: dtypes.default_type and dtypes.from_py
dtype = kwargs.pop("dtype", Tensor.default_type if isinstance(fill_value,float) else dtypes.bool if isinstance(fill_value,bool) else dtypes.int32)
# TODO: dtypes.from_py
dtype = kwargs.pop("dtype",
dtypes.default_float if isinstance(fill_value,float) else dtypes.bool if isinstance(fill_value,bool) else dtypes.default_int)
return Tensor(fill_value, dtype=dtype, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
@staticmethod
@@ -184,7 +184,7 @@ class Tensor:
@staticmethod
def arange(start, stop=None, step=1, **kwargs):
if stop is None: stop, start = start, 0
dtype = kwargs.pop("dtype", Tensor.default_type if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.int32)
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
return Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).cumsum() + (start - step)
@staticmethod
@@ -202,7 +202,7 @@ class Tensor:
def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
src = Tensor.rand(2, *shape, **kwargs)
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)
@staticmethod
def randint(*shape, low=0, high=10, **kwargs) -> Tensor: return Tensor.uniform(shape, low=low, high=high, dtype=dtypes.int32)
@@ -212,7 +212,7 @@ class Tensor:
@staticmethod
def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor:
dtype = kwargs.pop("dtype", Tensor.default_type)
dtype = kwargs.pop("dtype", dtypes.default_float)
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
@staticmethod
@@ -733,7 +733,7 @@ class Tensor:
if 0 in self.shape: return self, self.full_like(y)
if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
else:
y_dtype = dtypes.bool if isinstance(y, bool) else dtypes.int32 if isinstance(y, int) else Tensor.default_type
y_dtype = dtypes.bool if isinstance(y, bool) else dtypes.default_int if isinstance(y, int) else dtypes.default_float
x = x.cast(y_dtype)
y = Tensor(y, self.device, y_dtype, requires_grad=False)
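
Net effect of the tensor.py changes on the creation helpers, as a closing sketch: python ints map to dtypes.default_int, python floats to dtypes.default_float, and an explicit dtype always overrides both:

from tinygrad import Tensor, dtypes

assert Tensor.full((2, 2), 3).dtype == dtypes.default_int
assert Tensor.full((2, 2), 3.0).dtype == dtypes.default_float
assert Tensor.arange(5).dtype == dtypes.default_int
assert Tensor.arange(0, 5, 0.5).dtype == dtypes.default_float
assert Tensor.arange(5, dtype=dtypes.float16).dtype == dtypes.float16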