Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
Add cumprod to Tensor (#9629)
* probably how cumprod should look like
* update _cumalu to work with MUL
* shorter
* cumprod testing
* clean
* more cleanup
* add cumprod to torch backend
* make it look like cumsum
* mypy fix

Co-authored-by: chenyu <chenyu@fastmail.com>
@@ -109,10 +109,6 @@ def index_put(self, indices, values, accumulate=False):
 @torch.library.impl("aten::randperm.generator_out", "privateuseone")
 def randperm_generator(n, generator=None, out=None): out.copy_(torch.randperm(n, generator=generator, device="cpu").tiny())
 
-@torch.library.impl("aten::cumprod", "privateuseone")
-# TODO: move to tinygrad
-def cumprod(self, dim, dtype=None): return aten.cumprod(self.cpu(), dim, dtype=dtype).tiny()
-
 @torch.library.impl("aten::cummax", "privateuseone")
 def cummax(self, dim):
   # TODO: support cummax with indices to match torch
@@ -388,7 +384,7 @@ simple_tensor_methods = [
   # modify
   "tril", "triu",
   # reduce
-  "all", "any", "argmax", "argmin", "cumsum",
+  "all", "any", "argmax", "argmin", "cumsum", "cumprod",
   # complex
   "avg_pool2d", "linspace"]
 
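The two hunks above rewire the torch backend: the CPU round-trip implementation of aten::cumprod (marked "TODO: move to tinygrad") is dropped, and "cumprod" is listed next to "cumsum" in simple_tensor_methods so the op maps straight onto the new Tensor.cumprod. For reference, the aten semantics the backend has to reproduce, shown as a plain CPU torch check (running this on the tinygrad device would additionally require loading the backend, which is outside this diff):

```python
import torch

# Reference behaviour of aten::cumprod on CPU: running products along `dim`.
x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
print(torch.cumprod(x, dim=1))  # tensor([[  1.,   2.,   6.], [  4.,  20., 120.]])
print(torch.cumprod(x, dim=0))  # tensor([[ 1.,  2.,  3.], [ 4., 10., 18.]])
```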
@@ -991,6 +991,26 @@ class TestOps(unittest.TestCase):
     helper_test_op([(0,3)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0))
     helper_test_op([(2,3,0)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2))
 
+  def test_small_cumprod(self):
+    helper_test_op([(10)], lambda x: torch.cumprod(x, dim=0), lambda x: Tensor.cumprod(x, axis=0))
+  def test_simple_cumprod(self):
+    helper_test_op([(512)], lambda x: torch.cumprod(x, dim=0), lambda x: Tensor.cumprod(x, axis=0))
+    helper_test_op([(1022)], lambda x: torch.cumprod(x, dim=0), lambda x: Tensor.cumprod(x, axis=0))
+  def test_cumprod(self):
+    helper_test_op([()], lambda x: torch.cumprod(x, dim=0), lambda x: Tensor.cumprod(x, axis=0))
+    self.helper_test_exception([()], lambda x: torch.cumprod(x, dim=1), lambda x: Tensor.cumprod(x, axis=1), expected=IndexError)
+    helper_test_op([(20,)], lambda x: torch.cumprod(x, dim=0), lambda x: Tensor.cumprod(x, axis=0))
+    self.helper_test_exception([(20,)], lambda x: torch.cumprod(x, dim=1), lambda x: Tensor.cumprod(x, axis=1), expected=IndexError)
+    self.helper_test_exception([(20,)], lambda x: torch.cumprod(x, dim=-2), lambda x: Tensor.cumprod(x, axis=-2), expected=IndexError)
+    helper_test_op([(20, 30)], lambda x: torch.cumprod(x, dim=0), lambda x: Tensor.cumprod(x, axis=0))
+    helper_test_op([(20, 30)], lambda x: torch.cumprod(x, dim=1), lambda x: Tensor.cumprod(x, axis=1))
+    helper_test_op([(20, 30, 40)], lambda x: torch.cumprod(x, dim=2), lambda x: Tensor.cumprod(x, axis=2))
+    helper_test_op([(20, 30, 40)], lambda x: torch.cumprod(x, dim=-1), lambda x: Tensor.cumprod(x, axis=-1))
+  def test_cumprod_zero_axis(self):
+    helper_test_op([(2, 0, 4)], lambda x: torch.cumprod(x, dim=1), lambda x: Tensor.cumprod(x, axis=1))
+    helper_test_op([(0, 3)], lambda x: torch.cumprod(x, dim=0), lambda x: Tensor.cumprod(x, axis=0))
+    helper_test_op([(2, 3, 0)], lambda x: torch.cumprod(x, dim=2), lambda x: Tensor.cumprod(x, axis=2))
+
   def test_small_cummax(self):
     helper_test_op([(10)], lambda x: torch.cummax(x, dim=0).values, lambda x: Tensor.cummax(x, axis=0))
   def test_simple_cummax(self):
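The new tests push the same shapes through torch.cumprod and Tensor.cumprod, including a 0-d tensor, zero-sized axes, and out-of-range axes that should raise IndexError. A standalone sketch of the comparison helper_test_op performs (forward pass only, tolerances chosen here for illustration; the real helper and its backward checks live in the test file):

```python
import numpy as np
import torch
from tinygrad import Tensor

# Same random data through both frameworks, results compared numerically.
a = np.random.randn(20, 30).astype(np.float32)
ref = torch.cumprod(torch.from_numpy(a), dim=1).numpy()
out = Tensor(a).cumprod(axis=1).numpy()
np.testing.assert_allclose(out, ref, rtol=1e-4, atol=1e-4)
```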
@@ -2383,11 +2383,10 @@ class Tensor(SimpleMathTrait):
     return x.dot(self, dtype=dtype) if reverse else self.dot(x, dtype=dtype)
 
   def _cumalu(self, axis:int, op:Ops, _include_initial=False) -> Tensor:
-    assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX)
+    assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX, Ops.MUL)
     pl_sz = self.shape[axis] - int(not _include_initial)
     pooled = self.transpose(axis,-1).pad((pl_sz, -int(_include_initial)), value=identity_element(op, self.dtype))._pool((self.shape[axis],))
-    return (pooled.sum(-1) if op is Ops.ADD else pooled.max(-1)).transpose(axis,-1)
-
+    return cast(Callable[[int], Tensor], {Ops.ADD: pooled.sum, Ops.MAX: pooled.max, Ops.MUL: pooled.prod}[op])(-1).transpose(axis, -1)
   def _split_cumalu(self, axis:int, op:Ops) -> Tensor:
     axis = self._resolve_dim(axis)
     if self.ndim == 0 or 0 in self.shape: return self
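_cumalu builds the cumulative op out of a pad and a sliding window: the scan axis is moved last, padded on the left with the op's identity element (1 for MUL), pooled into windows of the original length, and each window is reduced. The change above only adds Ops.MUL to the assert and dispatches the window reduction to prod alongside sum and max. A rough NumPy sketch of the same trick for one axis (names are illustrative, not tinygrad code):

```python
import numpy as np

def cumprod_via_windows(x: np.ndarray) -> np.ndarray:
  n = len(x)
  # Pad with n-1 copies of the multiplicative identity, mirroring
  # pad(..., value=identity_element(op, dtype)) with pl_sz = n - 1.
  padded = np.concatenate([np.ones(n - 1, dtype=x.dtype), x])
  # Length-n sliding windows stand in for _pool((n,)); window i ends at x[i].
  windows = np.lib.stride_tricks.sliding_window_view(padded, n)
  return windows.prod(axis=-1)  # the Ops.MUL branch; sum/max give cumsum/cummax

print(cumprod_via_windows(np.array([2., 3., 4.])))  # [ 2.  6. 24.]
```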
@@ -2399,7 +2398,7 @@ class Tensor(SimpleMathTrait):
     base = ret[..., -1]._cumalu(-1, op, _include_initial=True)
     base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
     def fix(x: Tensor) -> Tensor: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
-    return fix(ret) + fix(base) if op is Ops.ADD else fix(ret).maximum(fix(base))
+    return {Ops.ADD: lambda a, b: a + b, Ops.MAX: lambda a, b: a.maximum(b), Ops.MUL: lambda a, b: a * b}[op](fix(ret), fix(base))
 
   def cumsum(self, axis:int=0) -> Tensor:
     """
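_split_cumalu (only its tail is shown here) keeps the pooled scan cheap on long axes by working in chunks: ret is the cumulative result inside each chunk and base is the running total carried over from earlier chunks, and the final line merges the two. The merge used to be hardcoded for ADD and MAX; it is now selected per op, with a * b doing the job for Ops.MUL. A small self-contained sketch of that chunked idea for cumprod (chunk size and names are illustrative):

```python
import numpy as np

def chunked_cumprod(x: np.ndarray, chunk: int = 4) -> np.ndarray:
  out, carry = [], 1.0                           # carry plays the role of "base"
  for start in range(0, len(x), chunk):
    part = np.cumprod(x[start:start + chunk])    # within-chunk cumulative, like "ret"
    merged = part * carry                        # the Ops.MUL merge: a * b
    out.append(merged)
    carry = merged[-1]                           # product of everything seen so far
  return np.concatenate(out)

x = np.arange(1., 11.)
assert np.allclose(chunked_cumprod(x), np.cumprod(x))
```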
@@ -2415,6 +2414,20 @@ class Tensor(SimpleMathTrait):
     """
     return self._split_cumalu(axis, Ops.ADD)
 
+  def cumprod(self, axis:int) -> Tensor:
+    """
+    Computes the cumulative product of the elements of the tensor along the specified `axis`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(1, 7).reshape(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.cumprod(axis=0).numpy())
+    ```
+    """
+    return self._split_cumalu(axis, Ops.MUL)
+
   def cummax(self, axis:int=0) -> Tensor:
     """
     Computes the cumulative max of the tensor along the specified `axis`.
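The fenced snippets in the new docstring run at docs build time, so their output is not part of the diff. Worked by hand, the example should print the column-wise running products (a quick check, assuming tinygrad is installed):

```python
from tinygrad import Tensor

t = Tensor.arange(1, 7).reshape(2, 3)
print(t.numpy())                  # [[1 2 3]
                                  #  [4 5 6]]
print(t.cumprod(axis=0).numpy())  # [[ 1  2  3]
                                  #  [ 4 10 18]]
```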