mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
dtypes.from_py to convert py types to dtypes (#2826)
also updated some tests to test against default dtypes
This commit is contained in:
@@ -156,25 +156,25 @@ class TestTinygrad(unittest.TestCase):
|
||||
for datatype in [dtypes.float16, dtypes.float32, dtypes.int8, dtypes.int32, dtypes.int64, dtypes.uint8]:
|
||||
a = Tensor([1, 2, 3], dtype=datatype)
|
||||
b = Tensor.zeros_like(a)
|
||||
assert a.dtype == b.dtype, f"a.dtype and b.dtype should be {datatype}"
|
||||
assert a.shape == b.shape, f"shape mismatch (Tensor.zeros_like){a.shape} != (torch){b.shape}"
|
||||
assert a.dtype == b.dtype, f"dtype mismatch {a.dtype=} != {b.dtype}"
|
||||
assert a.shape == b.shape, f"shape mismatch {a.shape} != {b.shape}"
|
||||
|
||||
a = Tensor([1, 2, 3])
|
||||
b = Tensor.zeros_like(a, dtype=dtypes.int8)
|
||||
assert a.dtype == dtypes.int32 and b.dtype == dtypes.int8, "a.dtype should be int and b.dtype should be char"
|
||||
assert a.shape == b.shape, f"shape mismatch (Tensor.zeros_like){a.shape} != (torch){b.shape}"
|
||||
assert a.dtype == dtypes.default_int and b.dtype == dtypes.int8, "a.dtype should be int and b.dtype should be char"
|
||||
assert a.shape == b.shape, f"shape mismatch {a.shape} != {b.shape}"
|
||||
|
||||
def test_ones_like_has_same_dtype_and_shape(self):
|
||||
for datatype in [dtypes.float16, dtypes.float32, dtypes.int8, dtypes.int32, dtypes.int64, dtypes.uint8]:
|
||||
a = Tensor([1, 2, 3], dtype=datatype)
|
||||
b = Tensor.ones_like(a)
|
||||
assert a.dtype == b.dtype, f"a.dtype and b.dtype should be {datatype}"
|
||||
assert a.shape == b.shape, f"shape mismatch (Tensor.ones_like){a.shape} != (torch){b.shape}"
|
||||
assert a.dtype == b.dtype, f"dtype mismatch {a.dtype=} != {b.dtype}"
|
||||
assert a.shape == b.shape, f"shape mismatch {a.shape} != {b.shape}"
|
||||
|
||||
a = Tensor([1, 2, 3])
|
||||
b = Tensor.ones_like(a, dtype=dtypes.int8)
|
||||
assert a.dtype == dtypes.int32 and b.dtype == dtypes.int8, "a.dtype should be int and b.dtype should be char"
|
||||
assert a.shape == b.shape, f"shape mismatch (Tensor.ones_like){a.shape} != (torch){b.shape}"
|
||||
assert a.dtype == dtypes.default_int and b.dtype == dtypes.int8, "a.dtype should be int and b.dtype should be char"
|
||||
assert a.shape == b.shape, f"shape mismatch {a.shape} != {b.shape}"
|
||||
|
||||
def test_ndim(self):
|
||||
assert Tensor(1).ndim == 0
|
||||
@@ -240,7 +240,7 @@ class TestTinygrad(unittest.TestCase):
|
||||
|
||||
def test_tensor_list_dtype(self):
|
||||
for arr in ([1], [[[1]]], [[1,1],[1,1]], [[[1,1],[1,1]],[[1,1],[1,1]]]):
|
||||
assert Tensor(arr).dtype == dtypes.int32
|
||||
assert Tensor(arr).dtype == dtypes.default_int
|
||||
assert Tensor(arr, dtype=dtypes.float32).dtype == dtypes.float32
|
||||
assert Tensor(arr, dtype=dtypes.float64).dtype == dtypes.float64
|
||||
|
||||
@@ -258,7 +258,7 @@ class TestTinygrad(unittest.TestCase):
|
||||
# mixture of bool and int
|
||||
for arr in ([True, 3], [[True],[3]], [[[True]], [[3]]], [[True, 3], [3, True]]):
|
||||
t = Tensor(arr)
|
||||
assert t.dtype == dtypes.int32
|
||||
assert t.dtype == dtypes.default_int
|
||||
np.testing.assert_allclose(t.numpy(), np.array(arr))
|
||||
|
||||
# mixture of bool, int and float
|
||||
|
||||
@@ -145,6 +145,8 @@ class dtypes:
|
||||
def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
|
||||
@staticmethod
|
||||
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
|
||||
@staticmethod # NOTE: isinstance(True, int) is True in python
|
||||
def from_py(x) -> DType: return dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int
|
||||
@staticmethod
|
||||
def fields() -> Dict[str, DType]: return DTYPES_DICT
|
||||
bool: Final[DType] = DType(0, 1, "bool", np.bool_)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
|
||||
from __future__ import annotations
|
||||
import time, math
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any, Iterable, Set, DefaultDict, cast
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, Set, DefaultDict, cast
|
||||
from collections import defaultdict
|
||||
from functools import partialmethod, reduce
|
||||
from itertools import accumulate
|
||||
@@ -15,6 +15,8 @@ from tinygrad.device import Device, Buffer
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.realize import run_schedule
|
||||
|
||||
# **** start with two base classes, Tensor and Function ****
|
||||
|
||||
class Function:
|
||||
def __init__(self, device:str, *tensors:Tensor):
|
||||
self.device = device
|
||||
@@ -34,8 +36,6 @@ class Function:
|
||||
|
||||
import tinygrad.mlops as mlops
|
||||
|
||||
# **** start with two base classes, Tensor and Function ****
|
||||
|
||||
class Tensor:
|
||||
__slots__ = "lazydata", "requires_grad", "grad", "_ctx"
|
||||
__deletable__ = ('_ctx',)
|
||||
@@ -43,10 +43,11 @@ class Tensor:
|
||||
class train:
|
||||
def __init__(self, val=True): self.val = val
|
||||
def __enter__(self): self.prev, Tensor.training = Tensor.training, self.val
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): Tensor.training = self.prev
|
||||
def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
|
||||
|
||||
no_grad: ClassVar[bool] = False
|
||||
def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray, bytes], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None): # noqa: E501
|
||||
def __init__(self, data:Union[None, bool, int, float, List, Tuple, LazyBuffer, np.ndarray, bytes],
|
||||
device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
|
||||
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
|
||||
device = Device.canonicalize(device)
|
||||
# tensors have gradients, buffers do not
|
||||
@@ -59,16 +60,14 @@ class Tensor:
|
||||
# internal variables used for autograd graph construction
|
||||
self._ctx: Optional[Function] = None
|
||||
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
|
||||
elif isinstance(data, bool): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.bool, device, data)
|
||||
elif isinstance(data, int): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.default_int, device, data)
|
||||
elif isinstance(data, float): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.default_float, device, data)
|
||||
elif isinstance(data, (bool, int, float)): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
|
||||
elif isinstance(data, bytes): data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8))
|
||||
elif data is None: data = LazyBuffer.fromCPU(np.array([], dtype=(dtype or dtypes.default_float).np))
|
||||
elif isinstance(data, list):
|
||||
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool
|
||||
elif d and all_int(d): dtype = dtype or dtypes.default_int
|
||||
else: dtype = dtype or dtypes.default_float
|
||||
# NOTE: cast at the end for the types that do not have a numpy dtype
|
||||
# NOTE: cast at the end for the dtypes that do not have a numpy dtype
|
||||
data = LazyBuffer.fromCPU(np.array(data, dtype.np)).cast(dtype)
|
||||
elif isinstance(data, np.ndarray):
|
||||
if data.shape == (): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
|
||||
@@ -169,11 +168,8 @@ class Tensor:
|
||||
# ***** creation helper functions *****
|
||||
|
||||
@staticmethod
|
||||
def full(shape:Tuple[sint, ...], fill_value, **kwargs):
|
||||
# TODO: dtypes.from_py
|
||||
dtype = kwargs.pop("dtype",
|
||||
dtypes.default_float if isinstance(fill_value,float) else dtypes.bool if isinstance(fill_value,bool) else dtypes.default_int)
|
||||
return Tensor(fill_value, dtype=dtype, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
|
||||
def full(shape:Tuple[sint, ...], fill_value: Union[bool, int, float], **kwargs):
|
||||
return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
|
||||
|
||||
@staticmethod
|
||||
def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0.0, **kwargs)
|
||||
@@ -326,7 +322,7 @@ class Tensor:
|
||||
# treat internal tuples and lists as Tensors and standardize indices to list type
|
||||
if isinstance(indices, (tuple, list)):
|
||||
# special case <indices: List[int]>, a lil ugly
|
||||
if isinstance(indices, list) and all(isinstance(i, int) for i in indices): indices = [Tensor(indices, dtype=dtypes.int32, requires_grad=False, device=self.device)] # noqa: E501
|
||||
if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, dtype=dtypes.int32, requires_grad=False, device=self.device)]
|
||||
else: indices = [Tensor(list(i), dtype=dtypes.int32, requires_grad=False, device=self.device) if isinstance(i, (tuple, list)) else i for i in indices] # noqa: E501
|
||||
else: indices = [indices]
|
||||
|
||||
@@ -376,7 +372,7 @@ class Tensor:
|
||||
new_shape = list(ret.shape)
|
||||
for dim in type_dim[None]: new_shape.insert(dim, 1)
|
||||
for dim in (dims_collapsed := [dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int])]): new_shape.pop(dim)
|
||||
for dim_sh in new_shape: assert isinstance(dim_sh, int), f"does not support symbolic shape {dim_sh}"
|
||||
assert all_int(new_shape), f"does not support symbolic shape {new_shape}"
|
||||
|
||||
ret = ret.reshape(tuple(new_shape))
|
||||
|
||||
@@ -733,7 +729,7 @@ class Tensor:
|
||||
if 0 in self.shape: return self, self.full_like(y)
|
||||
if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
|
||||
else:
|
||||
y_dtype = dtypes.bool if isinstance(y, bool) else dtypes.default_int if isinstance(y, int) else dtypes.default_float
|
||||
y_dtype = dtypes.from_py(y)
|
||||
x = x.cast(y_dtype)
|
||||
y = Tensor(y, self.device, y_dtype, requires_grad=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user