From 38554322659fbe7e19c3cc7052465645274db5b9 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 22 Dec 2023 01:07:44 -0500 Subject: [PATCH] don't use numpy to create Tensor(None) (#2909) * don't use numpy to create Tensor(None) empty suffices * parentheses --- test/test_dtype.py | 1 + tinygrad/tensor.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_dtype.py b/test/test_dtype.py index 73170fd562..231dd704c9 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -252,6 +252,7 @@ class TestTypeSpec(unittest.TestCase): def test_creation(self, default_int, default_float): dtypes.default_int, dtypes.default_float = default_int, default_float assert Tensor(True).dtype == dtypes.bool + assert Tensor(None).dtype == dtypes.default_float assert Tensor(2).dtype == dtypes.default_int assert Tensor(2.34).dtype == dtypes.default_float assert Tensor([]).dtype == dtypes.default_float diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 3ef2c17349..a91ecc0ad3 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -62,7 +62,7 @@ class Tensor: if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported" elif isinstance(data, (bool, int, float)): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data) elif isinstance(data, bytes): data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8)) - elif data is None: data = LazyBuffer.fromCPU(np.array([], dtype=(dtype or dtypes.default_float).np)) + elif data is None: data = LazyBuffer.loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device) elif isinstance(data, list): if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool elif d and all_int(d): dtype = dtype or dtypes.default_int