From 43385c7dbffe844be58bc74b984697863ac52642 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 9 Jul 2023 17:31:15 -0700 Subject: [PATCH] remove contiguous on full (#1212) --- test/test_ops.py | 3 ++- tinygrad/tensor.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 03262e66e2..ac05d62e0f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -317,7 +317,8 @@ class TestOps(unittest.TestCase): with self.assertRaises(AssertionError): a = Tensor(3.14) a.matmul(a) - + def test_simple_cumsum(self): + helper_test_op([(1024)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6) def test_cumsum(self): helper_test_op([(20)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6) helper_test_op([(20,30)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 4432f2c68d..35a4d38e4b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -143,7 +143,7 @@ class Tensor: # ***** creation helper functions ***** @staticmethod - def full(shape:Tuple[int, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape).contiguous() + def full(shape:Tuple[int, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape) @staticmethod def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)