fix view_left for unsafe pad ops [pr] (#9478)

This commit is contained in:
qazal
2025-03-17 19:02:02 +08:00
committed by GitHub
parent 813f713edc
commit 3b00a778ba
3 changed files with 15 additions and 2 deletions

View File

@@ -1971,6 +1971,16 @@ class TestSwizzle(unittest.TestCase):
t = a_reduce+b_reduce
with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): run_schedule(check_schedule(t, 1))
def test_unsafe_pad(self):
x = Tensor.full((2,2), 1.0).contiguous()
y = x*x.sum((1,)).reciprocal()
t = y.pad(((0,1),None)).contiguous()
swizzled = swizzle_rewrite(t.lazydata)
sched = check_schedule(swizzled.sink(), 3)
output_buffer = sched[-1].bufs[0]
run_schedule(sched)
self.assertListEqual(output_buffer.as_buffer().cast("f").tolist(), [0.5, 0.5, 0.5, 0.5, 0., 0.])
def store_val(si:ScheduleItem): return si.ast.src[0].src[2]
zero_pm = UPat(Ops.CONST, arg=0)
class TestView(unittest.TestCase):

View File

@@ -113,7 +113,7 @@ sym = symbolic_simple+PatternMatcher([
# **** UOp realization
DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK}
DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK, Ops.CONTIGUOUS}
@dataclass(frozen=True)
class GrouperContext:

View File

@@ -992,8 +992,11 @@ merge_views = PatternMatcher([
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.view(mop.st)),
])
# view before elementwise ops
view_left = merge_views+PatternMatcher([
# do not push masked view before unsafe pad ops
(UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.UnsafePad, name="e"),)),
lambda e,vm: e.contiguous().view(vm.st) if any(v.mask is not None for v in vm.st.views) else None),
# view before elementwise ops
(UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
])