tighter late_buffer_view match [pr] (#14456)

src must be len 2 at that point
This commit is contained in:
chenyu
2026-01-31 07:28:26 -05:00
committed by GitHub
parent d69bc5aa1a
commit 55f806b713

View File

@@ -280,10 +280,10 @@ def late_buffer_view(t:UOp, b:UOp):
if len(shape) == 0: offset = x.src[1].arg
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag),) + b.src[1:])
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag), b.src[1]))
to_bufferview = PatternMatcher([
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t").f(Ops.BUFFERIZE, allow_any_len=True, name="b"), late_buffer_view),
(UPat(Ops.BUFFERIZE, src=(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="t"), UPat()), name="b"), late_buffer_view),
(UPat((Ops.BITCAST, Ops.CONTIGUOUS)).f(Ops.BUFFER_VIEW, name="b"), lambda b: b.replace(src=b.src[0].src)),
])