diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index 120782bbe1..565fe09e31 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -162,9 +162,8 @@ def randperm_generator(n, generator=None, out=None):
 
 @torch.library.impl("aten::cummax", "privateuseone")
 def cummax(self, dim):
-  # TODO: support cummax with indices to match torch
-  cummax, indices = aten.cummax(self.cpu(), dim)
-  return (cummax.tiny(), indices.tiny())
+  values, indices = unwrap(self).cummax(dim)
+  return (wrap(values), wrap(indices.cast(dtypes.int64)))
 
 @torch.library.impl("aten::nonzero", "privateuseone") # TODO: move to tinygrad
diff --git a/test/test_ops.py b/test/test_ops.py
index 89df107a4c..961f36fe5b 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1074,26 +1074,38 @@ class TestOps(unittest.TestCase):
     helper_test_op([(2, 3, 0)],lambda x: torch.cumprod(x, dim=2),lambda x: Tensor.cumprod(x, axis=2))
 
   def test_small_cummax(self):
-    helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
+    helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
 
   @slow_test
   def test_simple_cummax(self):
-    helper_test_op([(512)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
-    helper_test_op([(1022)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
+    helper_test_op([(512)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(512)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
+    helper_test_op([(1022)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(1022)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
 
   @slow_test
   def test_cummax(self):
-    helper_test_op([()], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
-    self.helper_test_exception([()], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1), expected=IndexError)
-    helper_test_op([(20,)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
-    self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1), expected=IndexError)
-    self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=-2).values, lambda x: Tensor.cummax(x, axis=-2), expected=IndexError)
-    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
-    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1))
-    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2))
-    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=-1).values, lambda x: Tensor.cummax(x, axis=-1))
+    helper_test_op([()], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([()], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
+    self.helper_test_exception([()], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)[0], expected=IndexError)
+    helper_test_op([(20,)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(20,)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
+    self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)[0], expected=IndexError)
+    self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=-2).values, lambda x: Tensor.cummax(x, axis=-2)[0], expected=IndexError)
+    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
+    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)[0])
+    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=1).indices.int(), lambda x: Tensor.cummax(x, axis=1)[1], forward_only=True)
+    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2)[0])
+    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=2).indices.int(), lambda x: Tensor.cummax(x, axis=2)[1], forward_only=True)
+    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=-1).values, lambda x: Tensor.cummax(x, axis=-1)[0])
+    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=-1).indices.int(), lambda x: Tensor.cummax(x, axis=-1)[1], forward_only=True)
 
   def test_cummax_zero_axis(self):
-    helper_test_op([(2,0,4)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1))
-    helper_test_op([(0,3)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
-    helper_test_op([(2,3,0)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2))
+    helper_test_op([(2,0,4)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)[0])
+    helper_test_op([(2,0,4)], lambda x: torch.cummax(x, dim=1).indices.int(), lambda x: Tensor.cummax(x, axis=1)[1], forward_only=True)
+    helper_test_op([(0,3)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(0,3)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
+    helper_test_op([(2,3,0)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2)[0])
+    helper_test_op([(2,3,0)], lambda x: torch.cummax(x, dim=2).indices.int(), lambda x: Tensor.cummax(x, axis=2)[1], forward_only=True)
 
   def test_argmax(self):
     # check if it returns the first index for multiple occurrences
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 02b4024b4c..0a47ba17b3 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -182,7 +182,7 @@ class TestSchedule(unittest.TestCase):
     assert not a.uop.is_realized
 
   def test_simplify_padded_const(self):
-    a = Tensor.empty(1022).cummax(axis=0)
+    a, _ = Tensor.empty(1022).cummax(axis=0)
     check_schedule(a, 3)
 
   def test_basic_binop_fusion(self):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 1bd473e579..deca3f3507 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1997,7 +1997,7 @@ class Tensor(OpMixin):
     x = self.transpose(axis, -1)
     last_dim_size = x.shape[-1]
     x_unsqueezed = x.unsqueeze(-2).expand((None,)*(self.ndim-1)+(last_dim_size, None))
-    x_cummax = x.cummax(-1)
+    x_cummax, _ = x.cummax(-1)
     mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril()
     ret = mask.where(x_unsqueezed - x_cummax.unsqueeze(-1), dtypes.min(self.dtype)).exp().sum(-1).log() + x_cummax
     return ret.transpose(-1, axis)
@@ -2443,19 +2443,23 @@ class Tensor(OpMixin):
     """
     return self._split_cumalu(axis, Ops.MUL)
 
-  def cummax(self, axis:int=0) -> Tensor:
+  def cummax(self, axis:int=0) -> tuple[Tensor, Tensor]:
     """
-    Computes the cumulative max of the tensor along the specified `axis`.
+    Computes the cumulative max of the tensor along `axis`, returning (values, indices).
 
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([0, 1, -1, 2, -2, 3, -3])
-    print(t.numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.cummax(0).numpy())
+    values, indices = t.cummax(0)
+    print(values.numpy())
+    print(indices.numpy())
     ```
     """
-    return self._split_cumalu(axis, Ops.MAX)
+    if self.ndim == 0: return self._split_cumalu(axis, Ops.MAX), Tensor.zeros(self.shape, dtype=dtypes.int32, device=self.device)
+    values, n = self._split_cumalu(axis, Ops.MAX), self.shape[axis]
+    x, values_t = self.transpose(axis, -1), values.transpose(axis, -1)
+    match = (x.unsqueeze(-1) == values_t.unsqueeze(-2)) * Tensor.ones(n, n, requires_grad=False, device=self.device).triu()
+    idx = (-(match * Tensor.arange(n, 0, -1, requires_grad=False, device=self.device).reshape(n, 1)).max(-2) + n).cast(dtypes.int32)
+    return values, idx.transpose(-1, axis)
 
   @staticmethod
   def _tri(r:sint, c:sint, diagonal:int=0, device=None, requires_grad:bool|None=None) -> Tensor:
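
Note on the tensor.py hunk: the index computation is dense enough to be worth unpacking. It compares every input element against every cumulative-max value, keeps only candidate positions at or before each output position via an upper-triangular mask, and weights candidate rows with a descending arange so that taking the max selects the earliest matching index, which mirrors torch.cummax's first-occurrence behavior on ties. A minimal NumPy sketch of the same trick for the 1-D case (the function name and structure are illustrative, not tinygrad API):

```python
# Standalone NumPy sketch of the index-recovery trick from the tensor.py hunk.
import numpy as np

def cummax_with_indices(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
  n = len(x)
  values = np.maximum.accumulate(x)  # cumulative max values
  # match[i, j]: element i equals the cumulative max at position j, with the
  # i <= j condition playing the role of Tensor.ones(n, n).triu() above
  match = (x[:, None] == values[None, :]) & (np.arange(n)[:, None] <= np.arange(n)[None, :])
  # weight row i by n - i; the max over rows then picks the smallest matching
  # i, so ties resolve to the first occurrence, like torch.cummax
  idx = n - (match * np.arange(n, 0, -1)[:, None]).max(axis=0)
  return values, idx

x = np.array([0, 1, -1, 2, -2, 3, -3])
values, idx = cummax_with_indices(x)
print(values)  # [0 1 1 2 2 3 3]
print(idx)     # [0 1 1 3 3 5 5]
```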
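For the backend.py hunk, the cast to int64 exists because torch.cummax returns int64 indices while the tinygrad implementation produces int32 (hence `.indices.int()` on the torch side of the tests). A sketch of an end-to-end parity check, under the assumption that importing extra.torch_backend.backend registers the device and exposes it as "tiny" (suggested by the old code's `.tiny()` calls, but treat both as assumptions):

```python
# Hedged parity check for the new aten::cummax implementation; assumes
# importing extra.torch_backend.backend registers the "tiny" device.
import torch
import extra.torch_backend.backend  # noqa: F401  (side effect: device registration)

a = torch.randn(4, 5)
v_cpu, i_cpu = torch.cummax(a, dim=1)     # reference CPU result
v, i = torch.cummax(a.to("tiny"), dim=1)  # routed through the new impl
assert torch.equal(v_cpu, v.cpu())
assert i.dtype == torch.int64             # indices cast to match torch
assert torch.equal(i_cpu, i.cpu())
```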