simplify _cumsum with _first_zero=True (#4782)

handled the case with 0 in shape output of _cumsum, and _cumsum returns the correct shape with _first_zero=True
This commit is contained in:
chenyu
2024-05-30 13:19:33 -04:00
committed by GitHub
parent 4921de1945
commit c4d1283049

View File

@@ -1710,8 +1710,9 @@ class Tensor:
return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
pl_sz = self.shape[axis] - int(not _first_zero and self.shape[axis] != 0)
return self.transpose(axis,-1).pad2d((pl_sz,0))._pool((self.shape[axis] or 1,)).sum(-1).transpose(axis,-1)
assert self.shape[axis] != 0
pl_sz = self.shape[axis] - int(not _first_zero)
return self.transpose(axis,-1).pad2d((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
def cumsum(self, axis:int=0) -> Tensor:
"""
Computes the cumulative sum of the tensor along the specified axis.
@@ -1727,14 +1728,14 @@ class Tensor:
```
"""
axis = self._resolve_dim(axis)
if self.ndim == 0: return self
if self.ndim == 0 or 0 in self.shape: return self
# TODO: someday the optimizer will find this on it's own
# for now this is a two stage cumsum
SPLIT = 256
if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1)
base_add = ret[..., -1]._cumsum(-1, _first_zero=True)[..., :-1]
base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -self.shape[axis]:].transpose(axis,-1)
return fix(ret) + fix(base_add)