mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
fix reduce axis of 0d tensors (#5089)
`x.sum(())` is fine, and `x.sum((1,))` should throw IndexError
This commit is contained in:
@@ -776,7 +776,14 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)))
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1))
|
||||
helper_test_op([()], lambda x: x.sum(), Tensor.sum)
|
||||
helper_test_op([()], lambda x: x.sum())
|
||||
helper_test_op([()], lambda x: x.sum(0))
|
||||
helper_test_op([()], lambda x: x.sum(-1))
|
||||
helper_test_op([()], lambda x: x.sum(()))
|
||||
self.helper_test_exception([(3,4,5,6)], lambda x: x.sum(5), lambda x: x.sum(5), expected=IndexError)
|
||||
self.helper_test_exception([()], lambda x: x.sum(1), lambda x: x.sum(1), expected=IndexError)
|
||||
self.helper_test_exception([()], lambda x: x.sum((1,)), lambda x: x.sum((1,)), expected=IndexError)
|
||||
|
||||
def test_sum_with_zeros_shape(self):
|
||||
helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,)))
|
||||
helper_test_op([(4, 0)], lambda x: x.sum(axis=(1,)))
|
||||
|
||||
@@ -1263,8 +1263,8 @@ class Tensor:
|
||||
|
||||
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
|
||||
if self.ndim == 0:
|
||||
if axis is not None and axis not in [-1, 0]: raise IndexError(f"{axis=} out of range of [-1, 0]")
|
||||
axis = None
|
||||
if axis is not None and any(a not in [-1, 0] for a in fully_flatten([axis])): raise IndexError(f"{axis=} out of range of [-1, 0]")
|
||||
axis = ()
|
||||
axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
|
||||
axis_ = tuple(self._resolve_dim(x) for x in axis_)
|
||||
ret = fxn.apply(self, axis=axis_)
|
||||
|
||||
Reference in New Issue
Block a user