Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 23:48:01 -05:00.
simplify _cumsum with _first_zero=True (#4782)
Handle tensors with a 0 in their shape at the `cumsum` level, and make `_cumsum` return the correctly shaped result when called with `_first_zero=True`.
This commit is contained in:
@@ -1710,8 +1710,9 @@ class Tensor:
|
||||
return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
|
||||
|
||||
def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
  # Single-stage cumulative sum along `axis`, implemented as pad + sliding
  # window (_pool) + sum.  With _first_zero=False this is the inclusive scan;
  # with _first_zero=True it is the exclusive scan (element i is the sum of
  # inputs [0, i), so element 0 is 0), with the same length as the input.
  # Zero-size axes are the caller's problem: cumsum() filters them out first.
  assert self.shape[axis] != 0
  # inclusive: left-pad shape-1 zeros.  exclusive: left-pad shape zeros and
  # crop one element on the right (pad2d with a negative pad shrinks the dim),
  # which shifts every window one step earlier and drops the final total.
  pl_sz = self.shape[axis] - int(not _first_zero)
  return self.transpose(axis,-1).pad2d((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
|
||||
def cumsum(self, axis:int=0) -> Tensor:
  """
  Computes the cumulative sum of the tensor along the specified axis.
  """
  axis = self._resolve_dim(axis)
  # a scalar, or any tensor with a zero-size dimension, is its own cumsum
  if self.ndim == 0 or 0 in self.shape: return self
  # TODO: someday the optimizer will find this on it's own
  # for now this is a two stage cumsum
  SPLIT = 256
  if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
  # stage 1: left-pad the axis up to a multiple of SPLIT, then scan each
  # SPLIT-wide chunk independently
  ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
  ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1)
  # stage 2: the exclusive scan of each chunk's last element (its total) is the
  # carry-in every chunk must add; _first_zero=True keeps the chunk count
  base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
  base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
  # re-flatten the chunks, slice off the left padding, restore the axis order
  def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -self.shape[axis]:].transpose(axis,-1)
  return fix(ret) + fix(base_add)
|
||||
|
||||
Reference in New Issue
Block a user