also return index in Tensor.cummax (#14117)

* also return index in Tensor.cummax

* fix
chenyu
2026-01-12 22:42:10 -05:00
committed by GitHub
parent 7c967399a4
commit 05fcb57696
4 changed files with 42 additions and 27 deletions

extra/torch_backend/backend.py

@@ -162,9 +162,8 @@ def randperm_generator(n, generator=None, out=None):
 @torch.library.impl("aten::cummax", "privateuseone")
 def cummax(self, dim):
-  # TODO: support cummax with indices to match torch
-  cummax, indices = aten.cummax(self.cpu(), dim)
-  return (cummax.tiny(), indices.tiny())
+  values, indices = unwrap(self).cummax(dim)
+  return (wrap(values), wrap(indices.cast(dtypes.int64)))
 @torch.library.impl("aten::nonzero", "privateuseone")
 # TODO: move to tinygrad
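The shim no longer round-trips through the CPU via `aten.cummax`; it computes both outputs with tinygrad's own `cummax` (`unwrap`/`wrap` convert between torch and tinygrad tensors in this backend) and upcasts the indices, since torch reports `int64` indices. A minimal sketch of the torch contract being matched, using plain CPU torch rather than the tiny backend:

```python
# torch.cummax returns a named tuple: the running maxima plus the index where
# each maximum was attained -- the behavior the shim above must reproduce.
import torch

x = torch.tensor([0., 1., -1., 2.])
values, indices = torch.cummax(x, dim=0)
print(values)        # tensor([0., 1., 1., 2.])
print(indices)       # tensor([0, 1, 1, 3])
print(indices.dtype) # torch.int64, hence the cast(dtypes.int64) above
```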

test/test_ops.py

@@ -1074,26 +1074,38 @@ class TestOps(unittest.TestCase):
     helper_test_op([(2, 3, 0)],lambda x: torch.cumprod(x, dim=2),lambda x: Tensor.cumprod(x, axis=2))
   def test_small_cummax(self):
-    helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
+    helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
   @slow_test
   def test_simple_cummax(self):
-    helper_test_op([(512)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
-    helper_test_op([(1022)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
+    helper_test_op([(512)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(512)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
+    helper_test_op([(1022)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(1022)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
   @slow_test
   def test_cummax(self):
-    helper_test_op([()], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
-    self.helper_test_exception([()], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1), expected=IndexError)
-    helper_test_op([(20,)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
-    self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1), expected=IndexError)
-    self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=-2).values, lambda x: Tensor.cummax(x, axis=-2), expected=IndexError)
-    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
-    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1))
-    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2))
-    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=-1).values, lambda x: Tensor.cummax(x, axis=-1))
+    helper_test_op([()], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([()], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
+    self.helper_test_exception([()], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)[0], expected=IndexError)
+    helper_test_op([(20,)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(20,)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
+    self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)[0], expected=IndexError)
+    self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=-2).values, lambda x: Tensor.cummax(x, axis=-2)[0], expected=IndexError)
+    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
+    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)[0])
+    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=1).indices.int(), lambda x: Tensor.cummax(x, axis=1)[1], forward_only=True)
+    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2)[0])
+    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=2).indices.int(), lambda x: Tensor.cummax(x, axis=2)[1], forward_only=True)
+    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=-1).values, lambda x: Tensor.cummax(x, axis=-1)[0])
+    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=-1).indices.int(), lambda x: Tensor.cummax(x, axis=-1)[1], forward_only=True)
   def test_cummax_zero_axis(self):
-    helper_test_op([(2,0,4)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1))
-    helper_test_op([(0,3)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
-    helper_test_op([(2,3,0)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2))
+    helper_test_op([(2,0,4)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)[0])
+    helper_test_op([(2,0,4)], lambda x: torch.cummax(x, dim=1).indices.int(), lambda x: Tensor.cummax(x, axis=1)[1], forward_only=True)
+    helper_test_op([(0,3)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(0,3)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
+    helper_test_op([(2,3,0)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2)[0])
+    helper_test_op([(2,3,0)], lambda x: torch.cummax(x, dim=2).indices.int(), lambda x: Tensor.cummax(x, axis=2)[1], forward_only=True)
   def test_argmax(self):
     # check if it returns the first index for multiple occurrences
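Throughout these tests the values output is exercised with gradients, while the indices output is checked with `forward_only=True` (integer outputs have no gradient) and torch's `int64` indices are narrowed with `.int()` to line up with tinygrad's `int32`. A standalone sketch of what one such indices assertion boils down to, assuming numpy for the comparison:

```python
# Hypothetical standalone version of an indices check from the tests above.
import numpy as np
import torch
from tinygrad import Tensor

data = np.random.randn(10).astype(np.float32)
torch_idx = torch.cummax(torch.tensor(data), dim=0).indices.int().numpy()
tiny_idx = Tensor(data).cummax(axis=0)[1].numpy()
np.testing.assert_equal(torch_idx, tiny_idx)  # forward-only: no backward pass for ints
```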

test/test_schedule.py

@@ -182,7 +182,7 @@ class TestSchedule(unittest.TestCase):
     assert not a.uop.is_realized
   def test_simplify_padded_const(self):
-    a = Tensor.empty(1022).cummax(axis=0)
+    a, _ = Tensor.empty(1022).cummax(axis=0)
     check_schedule(a, 3)
   def test_basic_binop_fusion(self):

tinygrad/tensor.py

@@ -1997,7 +1997,7 @@ class Tensor(OpMixin):
     x = self.transpose(axis, -1)
     last_dim_size = x.shape[-1]
     x_unsqueezed = x.unsqueeze(-2).expand((None,)*(self.ndim-1)+(last_dim_size, None))
-    x_cummax = x.cummax(-1)
+    x_cummax, _ = x.cummax(-1)
     mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril()
     ret = mask.where(x_unsqueezed - x_cummax.unsqueeze(-1), dtypes.min(self.dtype)).exp().sum(-1).log() + x_cummax
     return ret.transpose(-1, axis)
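`logcumsumexp` only needed its call site updated: it still uses the running max m_j = cummax(x)_j as the stabilizer in log(sum_{i<=j} exp(x_i)) = m_j + log(sum_{i<=j} exp(x_i - m_j)), and now simply discards the indices. A numpy re-derivation of that identity (hypothetical reference helper, not the repo's code):

```python
import numpy as np

def logcumsumexp_ref(x):
    # Stable log(cumsum(exp(x))): shift by the running max before exponentiating,
    # then add it back, so no intermediate ever overflows.
    m = np.maximum.accumulate(x)
    return np.array([np.log(np.exp(x[:j+1] - m[j]).sum()) + m[j] for j in range(len(x))])

x = np.array([1000.0, 1001.0, 999.0])
print(logcumsumexp_ref(x))  # finite results
# naive np.log(np.cumsum(np.exp(x))) overflows to inf here
```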
@@ -2443,19 +2443,23 @@ class Tensor(OpMixin):
"""
return self._split_cumalu(axis, Ops.MUL)
def cummax(self, axis:int=0) -> Tensor:
def cummax(self, axis:int=0) -> tuple[Tensor, Tensor]:
"""
Computes the cumulative max of the tensor along the specified `axis`.
Computes the cumulative max of the tensor along `axis`, returning (values, indices).
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([0, 1, -1, 2, -2, 3, -3])
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.cummax(0).numpy())
values, indices = t.cummax(0)
print(values.numpy())
print(indices.numpy())
```
"""
return self._split_cumalu(axis, Ops.MAX)
if self.ndim == 0: return self._split_cumalu(axis, Ops.MAX), Tensor.zeros(self.shape, dtype=dtypes.int32, device=self.device)
values, n = self._split_cumalu(axis, Ops.MAX), self.shape[axis]
x, values_t = self.transpose(axis, -1), values.transpose(axis, -1)
match = (x.unsqueeze(-1) == values_t.unsqueeze(-2)) * Tensor.ones(n, n, requires_grad=False, device=self.device).triu()
idx = (-(match * Tensor.arange(n, 0, -1, requires_grad=False, device=self.device).reshape(n, 1)).max(-2) + n).cast(dtypes.int32)
return values, idx.transpose(-1, axis)
@staticmethod
def _tri(r:sint, c:sint, diagonal:int=0, device=None, requires_grad:bool|None=None) -> Tensor:
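The index recovery in the new `cummax` body works on the transposed tensor with the scan axis last: `match[i, j]` flags positions `i <= j` (the `triu` mask) where `x[i]` equals the running max at step `j`; weighting row `i` by `n - i` and taking the max over `i` then selects the smallest matching `i`, so ties resolve to the first occurrence, and the scalar `ndim == 0` case short-circuits to a zero index. A numpy re-derivation of the same trick (illustrative only):

```python
import numpy as np

x = np.array([0, 1, -1, 2, -2, 3, -3])
n = len(x)
values = np.maximum.accumulate(x)                # running max, as _split_cumalu(axis, Ops.MAX)
keep = np.triu(np.ones((n, n), dtype=bool))      # only i <= j can produce the max at step j
match = (x[:, None] == values[None, :]) & keep   # x[i] attains the running max at step j
weights = np.arange(n, 0, -1)[:, None]           # row i weighted by n - i
idx = n - (match * weights).max(axis=0)          # max over i picks the smallest matching i
print(values)  # [0 1 1 2 2 3 3]
print(idx)     # [0 1 1 3 3 5 5]
```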