From 3b00a778ba8b8b0f62375f6024411325d89c4f04 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 17 Mar 2025 19:02:02 +0800 Subject: [PATCH] fix view_left for unsafe pad ops [pr] (#9478) --- test/test_schedule.py | 10 ++++++++++ tinygrad/engine/schedule.py | 2 +- tinygrad/ops.py | 5 ++++- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index da866903c5..e09d8e3bd0 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1971,6 +1971,16 @@ class TestSwizzle(unittest.TestCase): t = a_reduce+b_reduce with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): run_schedule(check_schedule(t, 1)) + def test_unsafe_pad(self): + x = Tensor.full((2,2), 1.0).contiguous() + y = x*x.sum((1,)).reciprocal() + t = y.pad(((0,1),None)).contiguous() + swizzled = swizzle_rewrite(t.lazydata) + sched = check_schedule(swizzled.sink(), 3) + output_buffer = sched[-1].bufs[0] + run_schedule(sched) + self.assertListEqual(output_buffer.as_buffer().cast("f").tolist(), [0.5, 0.5, 0.5, 0.5, 0., 0.]) + def store_val(si:ScheduleItem): return si.ast.src[0].src[2] zero_pm = UPat(Ops.CONST, arg=0) class TestView(unittest.TestCase): diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index ef19139859..4fa1222620 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -113,7 +113,7 @@ sym = symbolic_simple+PatternMatcher([ # **** UOp realization -DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK} +DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK, Ops.CONTIGUOUS} @dataclass(frozen=True) class GrouperContext: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 43b745a619..a4545ba763 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -992,8 +992,11 @@ merge_views = PatternMatcher([ (UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.view(mop.st)), ]) -# view before elementwise ops view_left = merge_views+PatternMatcher([ + # do not push masked view before unsafe pad ops + (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.UnsafePad, name="e"),)), + lambda e,vm: e.contiguous().view(vm.st) if any(v.mask is not None for v in vm.st.views) else None), + # view before elementwise ops (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)), lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))), ])