diff --git a/test/test_ops.py b/test/test_ops.py index 03262e66e2..ac05d62e0f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -317,7 +317,8 @@ class TestOps(unittest.TestCase): with self.assertRaises(AssertionError): a = Tensor(3.14) a.matmul(a) - + def test_simple_cumsum(self): + helper_test_op([(1024)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6) def test_cumsum(self): helper_test_op([(20)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6) helper_test_op([(20,30)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 4432f2c68d..35a4d38e4b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -143,7 +143,7 @@ class Tensor: # ***** creation helper functions ***** @staticmethod - def full(shape:Tuple[int, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape).contiguous() + def full(shape:Tuple[int, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape) @staticmethod def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)