diff --git a/test/test_ops.py b/test/test_ops.py
index 5abca633e7..58b983ed08 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -510,6 +510,10 @@ class TestOps(unittest.TestCase):
     helper_test_op([(20,30)], lambda x: torch.cumsum(x, dim=1), lambda x: Tensor.cumsum(x, axis=1))
     helper_test_op([(20,30,40)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2))
     helper_test_op([(20,30,40)], lambda x: torch.cumsum(x, dim=-1), lambda x: Tensor.cumsum(x, axis=-1))
+  def test_cumsum_zero_axis(self):
+    helper_test_op([(2,0,4)], lambda x: torch.cumsum(x, dim=1), lambda x: Tensor.cumsum(x, axis=1))
+    helper_test_op([(0,3)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0))
+    helper_test_op([(2,3,0)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2))
   def test_argmax(self):
     self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy()) # check if returns first index for same max
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index ff29ce1d17..c820209877 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -735,7 +735,8 @@ class Tensor:
     return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
 
   def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
-    return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
+    pl_sz = self.shape[axis] - int(not _first_zero and self.shape[axis] != 0)
+    return self.transpose(axis,-1).pad2d((pl_sz,0))._pool((self.shape[axis] or 1,)).sum(-1).transpose(axis,-1)
   def cumsum(self, axis:int=0) -> Tensor:
     # TODO: someday the optimizer will find this on it's own
     # for now this is a two stage cumsum