From a44cd1e6f7bcfd8e5c7247b8d5bc45acf6bd15ed Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Sun, 29 Dec 2024 15:30:29 +0200
Subject: [PATCH] add collapse_view to the scheduler [pr] (#8440)

---
Notes (reviewer commentary; everything between the "---" line and the first
"diff --git" line is outside the commit content):
    * tinygrad/engine/schedule.py: adds a rule to remove_movement_ops that
      rewrites a VIEW to view.const_like(0) when the mask of the last view in
      its ShapeTracker has a zero-width range (x[1]-x[0] == 0) on any axis.
    * tinygrad/ops.py: deletes the corresponding "instant folding" line from
      the UOp method that builds the VIEW, so this folding is now expressed as
      a PatternMatcher rule instead of being done at construction time.
    * NOTE(review): the deleted ops.py line also returned const_like(0) when
      self.st.size == 0; the new rule only inspects the mask. Confirm that
      size-0 views are still folded elsewhere.
    * test/test_schedule.py: test_parent_view_collapses calls
      check_schedule(late_mul, 0) and asserts that neither b nor late_mul is
      realized; test_parent_multiple_children_no_collapse calls
      check_schedule([late_mul, other_child], 2) and asserts that b is
      realized while late_mul is not.

 test/test_schedule.py       | 33 +++++++++++++++++++++++++++++++++
 tinygrad/engine/schedule.py |  5 +++++
 tinygrad/ops.py             |  1 -
 3 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/test/test_schedule.py b/test/test_schedule.py
index 4187ded416..b0219595ae 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -1981,6 +1981,39 @@ class TestView(unittest.TestCase):
     run_schedule(sched)
     np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:])
 
+  # a*VIEW(x), where VIEW(x) = 0
+  # x collapses along with its children
+  def test_parent_view_collapses(self):
+    a = Tensor([1, 2])
+    b = Tensor.arange(3).contiguous()
+    bv = b.pad(((0, 2),))[-2:]
+    # this becomes a late a*0
+    late_mul = a*bv
+    check_schedule(late_mul, 0)
+    # the arange doesn't realize
+    self.assertIsNone(b.lazydata.base.realized)
+    # mul doesn't realize
+    self.assertIsNone(late_mul.lazydata.base.realized)
+    self.assertEqual(late_mul.tolist(), [0, 0])
+
+  # SINK has two branches:
+  # a*VIEW(x), where VIEW(x) = 0
+  # x+2
+  # as long as one child realizes, x does not collapse
+  def test_parent_multiple_children_no_collapse(self):
+    a = Tensor([1, 2])
+    b = Tensor.arange(3).contiguous()
+    bv = b.pad(((0, 2),))[-2:]
+    late_mul = a*bv
+    other_child = b+2
+    s = check_schedule([late_mul, other_child], 2)
+    # the arange realizes
+    self.assertIsNotNone(b.lazydata.base.realized)
+    # mul still collapses
+    self.assertIsNone(late_mul.lazydata.base.realized)
+    run_schedule(s)
+    self.assertEqual(other_child.tolist(), [2, 3, 4])
+
 def tensor_rewrite(t) -> UOp: return graph_rewrite(t.lazydata.base, remove_movement_ops+symbolic)
 class TestBigGraph(unittest.TestCase):
   def test_sink_childless_const(self):
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 186e652872..0a7ace87f5 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -553,8 +553,13 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
   buf_uop.buffer.ref(1)
 create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
 
+# **** movement ops
+
 remove_movement_ops = PatternMatcher([
   (UPat(GroupOp.Movement, name="x"), lambda x: x.base.view(unwrap(x.st))),
+  # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
+  (UPat(Ops.VIEW, name="view"),
+   lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
   # merge one src (unrealized) views
   # NOTE: we can't merge realized buffer views here, because the buffer is realized before the view
   (UPat(Ops.VIEW, src=(UPat(Ops.VIEW, src=(UPat.var("x"),), name="v1")), name="v2"),
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index d543dead6c..f5187726d1 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -487,7 +487,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.st is None: return UOp(Ops.VIEW, self.dtype.base if not isinstance(self.dtype, ImageDType) else self.dtype, (self,), new_st)
     ret = UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
     # instant folding rules
-    if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return ret.const_like(0)
     if new_st.contiguous and self.base.shape == new_st.shape: return self.base
     return ret
 