From fe03725b2172bf8927ced5734a2b28415969b9ef Mon Sep 17 00:00:00 2001
From: chenyu
Date: Wed, 3 Apr 2024 12:31:24 -0400
Subject: [PATCH] const fold cast unrealized_unpadded_const (#4047)

* const fold unrealized_unpadded_const

changed the underlying arg directly

* CAST_BEFORE_VIEW folds some

* fix const index in getitem
---
 test/test_const_folding.py | 11 ++++++++++-
 test/test_ops.py           |  5 +++++
 test/test_uops.py          |  5 +----
 tinygrad/lazy.py           |  2 ++
 tinygrad/tensor.py         |  2 +-
 5 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/test/test_const_folding.py b/test/test_const_folding.py
index 2daa9f1f80..38cacb5fba 100644
--- a/test/test_const_folding.py
+++ b/test/test_const_folding.py
@@ -19,7 +19,6 @@ class TestSimpleConstFolding(unittest.TestCase):
     _check_ast_count(0, Tensor.ones(4) + Tensor.ones(4))
     _check_ast_count(0, Tensor.ones(4) / Tensor.ones(4))
 
-  @unittest.expectedFailure
   def test_cast(self):
     _check_ast_count(0, Tensor.ones(4).cast(dtypes.int16))
     _check_ast_count(0, Tensor.full(4, fill_value=-1).cast(dtypes.uint16))
@@ -89,6 +88,16 @@ class TestMovedConstFolding(unittest.TestCase):
   def test_add_padded_one(self):
     _check_ast_count(1, Tensor([1.0, 2, 3, 4]) * Tensor.ones(2).pad(((1, 1),)))
 
+  def test_cast_padded(self):
+    # NOTE: this is folded due to CAST_BEFORE_VIEW
+    _check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
+    np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
+    _check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
+    np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
+    # not folded
+    _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
+    np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
+
 @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
 class TestMultiConstFolding(unittest.TestCase):
   def test_multi_const_folding_literal(self):
diff --git a/test/test_ops.py b/test/test_ops.py
index cecf244cc5..2c5328bcf8 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -907,6 +907,11 @@ class TestOps(unittest.TestCase):
     helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, None, -1])
     helper_test_op([(3,3,3)], lambda x: x[None, None, 1, None, 2, 0:2])
 
+  def test_slice_with_const_tensor(self):
+    t = Tensor.zeros(1, dtype=dtypes.int)
+    helper_test_op([(3,3,3)], lambda x: x[:, [0], :], lambda x: x[:, t, :])
+    helper_test_op([(3,3,3)], lambda x: x[:, [0], :], lambda x: x[:, t.contiguous(), :])
+
   def test_slice_one_endpoint_out_of_bounds(self):
     helper_test_op([(3,3,3)], lambda x: x[0:4])
     helper_test_op([(3,3,3)], lambda x: x[-6:4])
diff --git a/test/test_uops.py b/test/test_uops.py
index 0584be645a..ef5f0f27d7 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -203,10 +203,7 @@ class TestConstantFolding(unittest.TestCase):
   def test_cast_const(self):
     t = Tensor(1, dtype=dtypes.float).cast(dtypes.int)
     si = create_schedule([t.lazydata])
-    assert len(si) == 1
-    si = si[0]
-    lin = Device[Device.DEFAULT].get_linearizer(si.ast[0]).linearize()
-    assert all(uop.uop is not UOps.CAST for uop in lin.uops.uops), f"{[uop.uop for uop in lin.uops.uops]} contains non-folded constant cast"
+    assert len(si) == 0
 
   def test_bitcast_const(self):
     t = Tensor(1, dtype=dtypes.float).bitcast(dtypes.int)
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 1e2e260164..8c896f0520 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -78,6 +78,8 @@ class LazyBuffer:
   def cast(self, dtype:DType, bitcast:bool=False):
     if self.dtype == dtype: return self
     if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
+    if self.is_unrealized_unpadded_const() and not bitcast:
+      return create_lazybuffer(self.device, self.st, dtype, LoadOps.CONST, dtypes.as_const(self.base.arg, dtype))
     # TODO: applying this makes gpt2 slower
     if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
       return self.base.cast(dtype, bitcast)._view(self.st)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index cbd31eb9a6..fb7ac9d595 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -422,7 +422,7 @@ class Tensor:
     else: indices = [indices]
 
     # turn scalar Tensors into const val for int indexing if possible
-    indices = [self._to_const_val(i) if isinstance(i, Tensor) else i for i in indices]
+    indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
     # move Tensor indices to the same device as self
     indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]
 
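
Illustration of the cast folding this patch adds (a sketch against this revision, not part of the diff; it assumes `Tensor` and `dtypes` are importable from the top-level `tinygrad` package as the tests do):

    from tinygrad import Tensor, dtypes

    # unpadded const: the cast rewrites the CONST lazybuffer's arg directly, so no kernel AST is built
    print(Tensor.ones(4).cast(dtypes.int16).numpy())                 # [1 1 1 1]

    # padded const: still folded via CAST_BEFORE_VIEW because int16 is not wider than the float32 source
    print(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy())  # [0 1 1 1 1 0]

    # padded const cast to a wider dtype (int64) is not folded and still builds one kernel
    print(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy())  # [0 1 1 1 1 0]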
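
And a sketch of the getitem change: only true scalar (shape ()) Tensors are collapsed to a const index now, so a shape-(1,) const Tensor keeps behaving like fancy indexing. Variable names below are illustrative only:

    x = Tensor.arange(27).reshape(3, 3, 3)
    t = Tensor.zeros(1, dtype=dtypes.int)   # shape (1,): stays a Tensor index, not folded to an int
    print(x[:, t, :].shape)                 # (3, 1, 3), same as x[:, [0], :]

    s = Tensor(0)                           # shape (): still folded to a plain int index
    print(x[:, s, :].shape)                 # (3, 3), the indexed dimension is removed like x[:, 0, :]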