Add cumprod to Tensor (#9629)

* probably how cumprod should look like

* update _cumalu to work with MUL

* shorter

* cumprod testing

* clean

* more cleanup

* add cumprod to torch backend.

* make it look like cumsum

* mypy fix

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Yvon Manzi
2025-03-31 03:49:18 +02:00
committed by GitHub
parent d52e91db7b
commit 6652003839
3 changed files with 38 additions and 9 deletions

View File

@@ -109,10 +109,6 @@ def index_put(self, indices, values, accumulate=False):
@torch.library.impl("aten::randperm.generator_out", "privateuseone")
def randperm_generator(n, generator=None, out=None): out.copy_(torch.randperm(n, generator=generator, device="cpu").tiny())
@torch.library.impl("aten::cumprod", "privateuseone")
# TODO: move to tinygrad
def cumprod(self, dim, dtype=None): return aten.cumprod(self.cpu(), dim, dtype=dtype).tiny()
@torch.library.impl("aten::cummax", "privateuseone")
def cummax(self, dim):
# TODO: support cummax with indices to match torch
@@ -388,7 +384,7 @@ simple_tensor_methods = [
# modify
"tril", "triu",
# reduce
"all", "any", "argmax", "argmin", "cumsum",
"all", "any", "argmax", "argmin", "cumsum", "cumprod",
# complex
"avg_pool2d", "linspace"]

View File

@@ -991,6 +991,26 @@ class TestOps(unittest.TestCase):
helper_test_op([(0,3)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0))
helper_test_op([(2,3,0)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2))
def test_small_cumprod(self):
helper_test_op([(10)],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0))
def test_simple_cumprod(self):
helper_test_op([(512)],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0))
helper_test_op([(1022)],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0))
def test_cumprod(self):
helper_test_op([()],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0))
self.helper_test_exception([()],lambda x: torch.cumprod(x, dim=1),lambda x: Tensor.cumprod(x, axis=1),expected=IndexError)
helper_test_op([(20,)],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0))
self.helper_test_exception([(20,)],lambda x: torch.cumprod(x, dim=1),lambda x: Tensor.cumprod(x, axis=1),expected=IndexError)
self.helper_test_exception([(20,)],lambda x: torch.cumprod(x, dim=-2),lambda x: Tensor.cumprod(x, axis=-2),expected=IndexError)
helper_test_op([(20, 30)],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0))
helper_test_op([(20, 30)],lambda x: torch.cumprod(x, dim=1),lambda x: Tensor.cumprod(x, axis=1))
helper_test_op([(20, 30, 40)],lambda x: torch.cumprod(x, dim=2),lambda x: Tensor.cumprod(x, axis=2))
helper_test_op([(20, 30, 40)],lambda x: torch.cumprod(x, dim=-1),lambda x: Tensor.cumprod(x, axis=-1))
def test_cumprod_zero_axis(self):
helper_test_op([(2, 0, 4)],lambda x: torch.cumprod(x, dim=1),lambda x: Tensor.cumprod(x, axis=1))
helper_test_op([(0, 3)],lambda x: torch.cumprod(x, dim=0),lambda x: Tensor.cumprod(x, axis=0))
helper_test_op([(2, 3, 0)],lambda x: torch.cumprod(x, dim=2),lambda x: Tensor.cumprod(x, axis=2))
def test_small_cummax(self):
helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
def test_simple_cummax(self):

View File

@@ -2383,11 +2383,10 @@ class Tensor(SimpleMathTrait):
return x.dot(self, dtype=dtype) if reverse else self.dot(x, dtype=dtype)
def _cumalu(self, axis:int, op:Ops, _include_initial=False) -> Tensor:
assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX)
assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX, Ops.MUL)
pl_sz = self.shape[axis] - int(not _include_initial)
pooled = self.transpose(axis,-1).pad((pl_sz, -int(_include_initial)), value=identity_element(op, self.dtype))._pool((self.shape[axis],))
return (pooled.sum(-1) if op is Ops.ADD else pooled.max(-1)).transpose(axis,-1)
return cast(Callable[[int], Tensor], {Ops.ADD: pooled.sum, Ops.MAX: pooled.max, Ops.MUL: pooled.prod}[op])(-1).transpose(axis, -1)
def _split_cumalu(self, axis:int, op:Ops) -> Tensor:
axis = self._resolve_dim(axis)
if self.ndim == 0 or 0 in self.shape: return self
@@ -2399,7 +2398,7 @@ class Tensor(SimpleMathTrait):
base = ret[..., -1]._cumalu(-1, op, _include_initial=True)
base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
def fix(x: Tensor) -> Tensor: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
return fix(ret) + fix(base) if op is Ops.ADD else fix(ret).maximum(fix(base))
return {Ops.ADD: lambda a, b: a + b, Ops.MAX: lambda a, b: a.maximum(b), Ops.MUL: lambda a, b: a * b}[op](fix(ret), fix(base))
def cumsum(self, axis:int=0) -> Tensor:
"""
@@ -2415,6 +2414,20 @@ class Tensor(SimpleMathTrait):
"""
return self._split_cumalu(axis, Ops.ADD)
def cumprod(self, axis:int) -> Tensor:
"""
Computes the cumulative product of the elements of the tensor along the specified `axis`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(1, 7).reshape(2, 3)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.cumprod(axis=0).numpy())
```
"""
return self._split_cumalu(axis, Ops.MUL)
def cummax(self, axis:int=0) -> Tensor:
"""
Computes the cumulative max of the tensor along the specified `axis`.