diff --git a/test/test_dtype.py b/test/test_dtype.py index 73170fd562..231dd704c9 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -252,6 +252,7 @@ class TestTypeSpec(unittest.TestCase): def test_creation(self, default_int, default_float): dtypes.default_int, dtypes.default_float = default_int, default_float assert Tensor(True).dtype == dtypes.bool + assert Tensor(None).dtype == dtypes.default_float assert Tensor(2).dtype == dtypes.default_int assert Tensor(2.34).dtype == dtypes.default_float assert Tensor([]).dtype == dtypes.default_float diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 3ef2c17349..a91ecc0ad3 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -62,7 +62,7 @@ class Tensor: if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported" elif isinstance(data, (bool, int, float)): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data) elif isinstance(data, bytes): data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8)) - elif data is None: data = LazyBuffer.fromCPU(np.array([], dtype=(dtype or dtypes.default_float).np)) + elif data is None: data = LazyBuffer.loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device) elif isinstance(data, list): if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool elif d and all_int(d): dtype = dtype or dtypes.default_int