fix reduce axis of 0d tensors (#5089)

`x.sum(())` is fine, and `x.sum((1,))` should throw IndexError
This commit is contained in:
chenyu
2024-06-21 13:51:40 -04:00
committed by GitHub
parent 3ff048b68c
commit 166a2b19b5
2 changed files with 10 additions and 3 deletions

View File

@@ -776,7 +776,14 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)))
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)))
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1))
helper_test_op([()], lambda x: x.sum(), Tensor.sum)
helper_test_op([()], lambda x: x.sum())
helper_test_op([()], lambda x: x.sum(0))
helper_test_op([()], lambda x: x.sum(-1))
helper_test_op([()], lambda x: x.sum(()))
self.helper_test_exception([(3,4,5,6)], lambda x: x.sum(5), lambda x: x.sum(5), expected=IndexError)
self.helper_test_exception([()], lambda x: x.sum(1), lambda x: x.sum(1), expected=IndexError)
self.helper_test_exception([()], lambda x: x.sum((1,)), lambda x: x.sum((1,)), expected=IndexError)
def test_sum_with_zeros_shape(self):
helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,)))
helper_test_op([(4, 0)], lambda x: x.sum(axis=(1,)))

View File

@@ -1263,8 +1263,8 @@ class Tensor:
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
if self.ndim == 0:
if axis is not None and axis not in [-1, 0]: raise IndexError(f"{axis=} out of range of [-1, 0]")
axis = None
if axis is not None and any(a not in [-1, 0] for a in fully_flatten([axis])): raise IndexError(f"{axis=} out of range of [-1, 0]")
axis = ()
axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
axis_ = tuple(self._resolve_dim(x) for x in axis_)
ret = fxn.apply(self, axis=axis_)