From ed121d235c7edbc4f0bd23c788dc2fcfe4abf455 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Mon, 6 Jan 2025 10:43:58 +0200
Subject: [PATCH] spec for CAST_BEFORE_VIEW=1 [pr] (#8512)

---
Review notes (commentary only; text between the three-dash line and the
diff is not part of the change that gets applied):

  This patch only adds tests to TestSchedule; no library code is touched.
  Each test builds a view of a small tensor, casts it, and pins what the
  scheduler realizes.

  * test_cast_padded_view: arange(4) reshaped to (1, 4), padded with one
    extra row, then cast to float. After realize() the base buffer is
    asserted to hold 4 elements (the unpadded count); only after
    .contiguous().realize() does it hold 8, with the padded row read back
    as zeros. NOTE(review): this looks like the cast being applied to the
    base before the pad, per the subject line - confirm against the
    scheduler.

  * test_cast_after_shrink: same input, shrunk to its first two elements,
    then cast to float. The realized base buffer is asserted to still hold
    4 elements, and 2 only after .contiguous(). The in-code NOTE flags this
    as behavior that may be reconsidered.

  * test_cast_const_view: ones((4, 4)) in float32 cast to int32.
    check_schedule(..., 0) followed by the assertIsNone on base.realized
    pins that nothing is realized for the cast alone; .contiguous() is
    checked with check_schedule(..., 1). check_schedule and run_schedule
    are defined outside this patch - presumably the second argument is the
    expected kernel count; verify against the helper.

  * test_cast_padded_const: a scalar int32 constant 1 reshaped to (1, 1),
    padded with one row on each side, then cast to float32. Same 0 / 1
    check_schedule pattern; expected values are [[0], [1], [0]].

  Two of the expected lists use int literals although the tensor was cast
  to float ([[0, 1]] and [[0], [1], [0]]). assertListEqual still passes
  because Python compares 0 == 0.0 as equal, so these assertions do not
  pin the element type.

 test/test_schedule.py | 36 ++++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/test/test_schedule.py b/test/test_schedule.py
index 69c9e9ca9e..a8eecf4932 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -1437,6 +1437,42 @@ class TestSchedule(unittest.TestCase):
   def test_late_fusion_post_expand(self):
     self._test_fusion([(32, 32)], lambda a:a-a.sum(1), 2)
 
+  def test_cast_padded_view(self):
+    a = Tensor.arange(4).reshape(1, 4)
+    casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float)
+    casted_view.realize()
+    self.assertEqual(casted_view.lazydata.base.realized.size, 4)
+    realized_view = casted_view.contiguous().realize()
+    self.assertEqual(realized_view.lazydata.base.realized.size, 8)
+    self.assertListEqual(realized_view.tolist(), [[0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 0.0, 0.0]])
+
+  # NOTE: we might want to reconsider pushing this cast before the shrink
+  def test_cast_after_shrink(self):
+    a = Tensor.arange(4).reshape(1, 4)
+    casted_view = a.shrink(((0, 1), (0, 2))).cast(dtypes.float)
+    casted_view.realize()
+    self.assertEqual(casted_view.lazydata.base.realized.size, 4)
+    realized_view = casted_view.contiguous().realize()
+    self.assertEqual(realized_view.lazydata.base.realized.size, 2)
+    self.assertListEqual(realized_view.tolist(), [[0, 1]])
+
+  def test_cast_const_view(self):
+    a = Tensor.ones((4, 4), dtype=dtypes.float32)
+    casted_view = a.cast(dtypes.int32)
+    run_schedule(check_schedule(casted_view, 0))
+    self.assertIsNone(casted_view.lazydata.base.realized)
+    realized_const_view = casted_view.contiguous()
+    run_schedule(check_schedule(realized_const_view, 1))
+    self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
+
+  def test_cast_padded_const(self):
+    a = Tensor(1, dtype=dtypes.int32).reshape(1, 1).pad(((1, 1), None))
+    casted_view = a.cast(dtypes.float32)
+    run_schedule(check_schedule(casted_view, 0))
+    realized_const_view = casted_view.contiguous()
+    run_schedule(check_schedule(realized_const_view, 1))
+    self.assertListEqual(realized_const_view.tolist(), [[0], [1], [0]])
+
 class TestIndexing(unittest.TestCase):
   def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int):
     with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):