Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-13 08:58:05 -05:00
also return index in Tensor.cummax (#14117)
* also return index in Tensor.cummax
* fix
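In short: `Tensor.cummax` now returns a `(values, indices)` tuple instead of a bare values tensor, matching the shape of `torch.cummax`'s output. A minimal usage sketch, mirroring the docstring example added below:

```python
from tinygrad import Tensor

t = Tensor([0, 1, -1, 2, -2, 3, -3])
values, indices = t.cummax(0)  # previously just: values = t.cummax(0)
print(values.numpy())   # running maximum along axis 0 -> [0 1 1 2 2 3 3]
print(indices.numpy())  # where each running max was first reached -> [0 1 1 3 3 5 5]
```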
@@ -162,9 +162,8 @@ def randperm_generator(n, generator=None, out=None):

 @torch.library.impl("aten::cummax", "privateuseone")
 def cummax(self, dim):
-  # TODO: support cummax with indices to match torch
-  cummax, indices = aten.cummax(self.cpu(), dim)
-  return (cummax.tiny(), indices.tiny())
+  values, indices = unwrap(self).cummax(dim)
+  return (wrap(values), wrap(indices.cast(dtypes.int64)))

 @torch.library.impl("aten::nonzero", "privateuseone")
 # TODO: move to tinygrad
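The backend half of the change: `aten::cummax` previously round-tripped through CPU (`aten.cummax(self.cpu(), dim)`) because tinygrad's `cummax` had no index output; it now unwraps to a tinygrad `Tensor`, calls the two-output `cummax`, and wraps both results, casting indices to `int64` as torch expects. A hedged sketch of how this surfaces through the torch frontend; the import path and device name are assumptions taken from tinygrad's torch-backend docs, not verified against this revision:

```python
import torch
import tinygrad.frontend.torch  # noqa: F401  # assumed module that registers the "tiny" privateuseone device

x = torch.tensor([0.0, 1.0, -1.0, 2.0], device="tiny")
out = torch.cummax(x, dim=0)  # dispatches to the native impl above, no CPU fallback
print(out.values.cpu())   # tensor([0., 1., 1., 2.])
print(out.indices.cpu())  # tensor([0, 1, 1, 3]) -- int64
```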
@@ -1074,26 +1074,38 @@ class TestOps(unittest.TestCase):
     helper_test_op([(2, 3, 0)],lambda x: torch.cumprod(x, dim=2),lambda x: Tensor.cumprod(x, axis=2))

   def test_small_cummax(self):
-    helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
+    helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
   @slow_test
   def test_simple_cummax(self):
-    helper_test_op([(512)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
-    helper_test_op([(1022)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
+    helper_test_op([(512)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(512)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
+    helper_test_op([(1022)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(1022)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
   @slow_test
   def test_cummax(self):
-    helper_test_op([()], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
-    self.helper_test_exception([()], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1), expected=IndexError)
-    helper_test_op([(20,)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
-    self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1), expected=IndexError)
-    self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=-2).values, lambda x: Tensor.cummax(x, axis=-2), expected=IndexError)
-    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
-    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1))
-    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2))
-    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=-1).values, lambda x: Tensor.cummax(x, axis=-1))
+    helper_test_op([()], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([()], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
+    self.helper_test_exception([()], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)[0], expected=IndexError)
+    helper_test_op([(20,)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(20,)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
+    self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)[0], expected=IndexError)
+    self.helper_test_exception([(20,)], lambda x: torch.cummax(x, dim=-2).values, lambda x: Tensor.cummax(x, axis=-2)[0], expected=IndexError)
+    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
+    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)[0])
+    helper_test_op([(20,30)], lambda x: torch.cummax(x, dim=1).indices.int(), lambda x: Tensor.cummax(x, axis=1)[1], forward_only=True)
+    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2)[0])
+    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=2).indices.int(), lambda x: Tensor.cummax(x, axis=2)[1], forward_only=True)
+    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=-1).values, lambda x: Tensor.cummax(x, axis=-1)[0])
+    helper_test_op([(20,30,40)], lambda x: torch.cummax(x, dim=-1).indices.int(), lambda x: Tensor.cummax(x, axis=-1)[1], forward_only=True)
   def test_cummax_zero_axis(self):
-    helper_test_op([(2,0,4)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1))
-    helper_test_op([(0,3)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
-    helper_test_op([(2,3,0)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2))
+    helper_test_op([(2,0,4)], lambda x: torch.cummax(x, dim=1).values, lambda x: Tensor.cummax(x, axis=1)[0])
+    helper_test_op([(2,0,4)], lambda x: torch.cummax(x, dim=1).indices.int(), lambda x: Tensor.cummax(x, axis=1)[1], forward_only=True)
+    helper_test_op([(0,3)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0)[0])
+    helper_test_op([(0,3)], lambda x: torch.cummax(x, dim=0).indices.int(), lambda x: Tensor.cummax(x, axis=0)[1], forward_only=True)
+    helper_test_op([(2,3,0)], lambda x: torch.cummax(x, dim=2).values, lambda x: Tensor.cummax(x, axis=2)[0])
+    helper_test_op([(2,3,0)], lambda x: torch.cummax(x, dim=2).indices.int(), lambda x: Tensor.cummax(x, axis=2)[1], forward_only=True)

   def test_argmax(self):
     # check if it returns the first index for multiple occurrences
@@ -182,7 +182,7 @@ class TestSchedule(unittest.TestCase):
     assert not a.uop.is_realized

   def test_simplify_padded_const(self):
-    a = Tensor.empty(1022).cummax(axis=0)
+    a, _ = Tensor.empty(1022).cummax(axis=0)
     check_schedule(a, 3)

   def test_basic_binop_fusion(self):
@@ -1997,7 +1997,7 @@ class Tensor(OpMixin):
     x = self.transpose(axis, -1)
     last_dim_size = x.shape[-1]
     x_unsqueezed = x.unsqueeze(-2).expand((None,)*(self.ndim-1)+(last_dim_size, None))
-    x_cummax = x.cummax(-1)
+    x_cummax, _ = x.cummax(-1)
     mask = Tensor.ones(last_dim_size, last_dim_size, requires_grad=False, device=self.device).tril()
     ret = mask.where(x_unsqueezed - x_cummax.unsqueeze(-1), dtypes.min(self.dtype)).exp().sum(-1).log() + x_cummax
     return ret.transpose(-1, axis)
@@ -2443,19 +2443,23 @@ class Tensor(OpMixin):
     """
     return self._split_cumalu(axis, Ops.MUL)

-  def cummax(self, axis:int=0) -> Tensor:
+  def cummax(self, axis:int=0) -> tuple[Tensor, Tensor]:
     """
-    Computes the cumulative max of the tensor along the specified `axis`.
+    Computes the cumulative max of the tensor along `axis`, returning (values, indices).

     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([0, 1, -1, 2, -2, 3, -3])
     print(t.numpy())
     ```
     ```python exec="true" source="above" session="tensor" result="python"
-    print(t.cummax(0).numpy())
+    values, indices = t.cummax(0)
+    print(values.numpy())
+    print(indices.numpy())
     ```
     """
-    return self._split_cumalu(axis, Ops.MAX)
+    if self.ndim == 0: return self._split_cumalu(axis, Ops.MAX), Tensor.zeros(self.shape, dtype=dtypes.int32, device=self.device)
+    values, n = self._split_cumalu(axis, Ops.MAX), self.shape[axis]
+    x, values_t = self.transpose(axis, -1), values.transpose(axis, -1)
+    match = (x.unsqueeze(-1) == values_t.unsqueeze(-2)) * Tensor.ones(n, n, requires_grad=False, device=self.device).triu()
+    idx = (-(match * Tensor.arange(n, 0, -1, requires_grad=False, device=self.device).reshape(n, 1)).max(-2) + n).cast(dtypes.int32)
+    return values, idx.transpose(-1, axis)

   @staticmethod
   def _tri(r:sint, c:sint, diagonal:int=0, device=None, requires_grad:bool|None=None) -> Tensor:
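A gloss on the new index computation in `cummax`: with the scan axis moved last, `match[..., i, j]` is true when position `i <= j` holds the running max at `j` (the `triu` mask enforces `i <= j`), and weighting row `i` by `n - i` before the `max(-2)` reduction makes it select the smallest matching `i`, i.e. the first occurrence. The scalar (`ndim == 0`) branch short-circuits to a zero index. A rough NumPy re-derivation of the same trick, 1-D only and not tinygrad code:

```python
import numpy as np

def cummax_with_indices(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    n = x.shape[0]
    values = np.maximum.accumulate(x)  # the running max ("values" in the diff)
    # match[i, j]: position i (with i <= j) holds the running max at position j
    match = (x[:, None] == values[None, :]) & np.triu(np.ones((n, n), dtype=bool))
    # weight earlier rows higher so the max over i picks the FIRST matching index
    weights = np.arange(n, 0, -1)[:, None]  # row i gets weight n - i
    idx = (n - (match * weights).max(axis=0)).astype(np.int32)
    return values, idx

x = np.array([0, 1, -1, 2, -2, 3, -3])
values, idx = cummax_with_indices(x)
print(values)  # [0 1 1 2 2 3 3]
print(idx)     # [0 1 1 3 3 5 5]
```

Note the n x n match matrix makes this quadratic in the scan length, the same pattern the `logcumsumexp` hunk above already relies on for its `tril` mask.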