From 4075208127ff6522b5d56b0c63ff5341eacc1a7b Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 11 Dec 2023 19:33:49 -0500 Subject: [PATCH] some dtype creation spec test cases (#2722) --- test/test_dtype.py | 16 ++++++++++++++++ tinygrad/tensor.py | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/test/test_dtype.py b/test/test_dtype.py index b6c4a0139e..4036dbc175 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -218,5 +218,21 @@ class TestHelpers(unittest.TestCase): def test_scalar(self, dtype, amt): assert dtype.vec(amt).scalar() == dtype +class TestTypeSpec(unittest.TestCase): + def test_creation(self): + assert Tensor([]).dtype == Tensor.default_type + # assert Tensor([1]).dtype == dtypes.int + assert Tensor([1.1]).dtype == Tensor.default_type + + def test_const_full(self): + assert Tensor.ones([2,3]).dtype == Tensor.default_type + assert Tensor.zeros([2,3]).dtype == Tensor.default_type + assert Tensor.full([2,3], 3.3).dtype == Tensor.default_type + # assert Tensor.full([2,3], 3).dtype == dtypes.int + + def test_reduce_0d_default(self): + assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type + # assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index b371c6718e..49cee2d692 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -473,7 +473,7 @@ class Tensor: axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis)) axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_] shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_) - if 0 in self.shape and 0 not in shape: return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0, mlops.Max: -float("inf")}[fxn]) + if 0 in self.shape and 0 not in shape: return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0.0, mlops.Max: -float("inf")}[fxn]) ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)])) return ret if keepdim else ret.reshape(shape=shape)