From 166a2b19b5ae55e190120d79d36c1e774e67f7d6 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 21 Jun 2024 13:51:40 -0400 Subject: [PATCH] fix reduce axis of 0d tensors (#5089) `x.sum(())` is fine, and `x.sum((1,))` should throw IndexError --- test/test_ops.py | 9 ++++++++- tinygrad/tensor.py | 4 ++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 7116c8f754..b2aff9d98e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -776,7 +776,14 @@ class TestOps(unittest.TestCase): helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2))) helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2))) helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1)) - helper_test_op([()], lambda x: x.sum(), Tensor.sum) + helper_test_op([()], lambda x: x.sum()) + helper_test_op([()], lambda x: x.sum(0)) + helper_test_op([()], lambda x: x.sum(-1)) + helper_test_op([()], lambda x: x.sum(())) + self.helper_test_exception([(3,4,5,6)], lambda x: x.sum(5), lambda x: x.sum(5), expected=IndexError) + self.helper_test_exception([()], lambda x: x.sum(1), lambda x: x.sum(1), expected=IndexError) + self.helper_test_exception([()], lambda x: x.sum((1,)), lambda x: x.sum((1,)), expected=IndexError) + def test_sum_with_zeros_shape(self): helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,))) helper_test_op([(4, 0)], lambda x: x.sum(axis=(1,))) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 6a94d49951..a674c0a9e5 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1263,8 +1263,8 @@ class Tensor: def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor: if self.ndim == 0: - if axis is not None and axis not in [-1, 0]: raise IndexError(f"{axis=} out of range of [-1, 0]") - axis = None + if axis is not None and any(a not in [-1, 0] for a in fully_flatten([axis])): raise IndexError(f"{axis=} out of range of [-1, 0]") + axis = () axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis)) axis_ = tuple(self._resolve_dim(x) for x in axis_) ret = fxn.apply(self, axis=axis_)