feat: late to_bufferview (#12271)

This commit is contained in:
wozeparrot
2025-09-29 00:29:43 -07:00
committed by GitHub
parent e01a3eb59a
commit a982480512

View File

@@ -53,12 +53,6 @@ earliest_rewrites = double_reshape+PatternMatcher([
# copy only to different device
(UPat(Ops.COPY, src=(UPat.var("x"), UPat()), name="copy"), lambda x,copy: x.f(Ops.NOOP, tag=copy.tag) if x.device == copy.device else None),
# handle disk
# TODO: this doesn't need to use st.views
(UPat.var("x").f((Ops.BITCAST, Ops.CONTIGUOUS), name="t"),
lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (t.size, x.st.views[0].offset), tag=t.tag).reshape(t.shape) if isinstance(x.device, str) \
and x.device.startswith("DISK") else None),
# contiguous/buffer/copy/assign is already contiguous
#(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
])
@@ -407,6 +401,25 @@ pm_cleanups = double_reshape+pm_mops+PatternMatcher([
.f(Ops.INDEX, allow_any_len=True, name="x"), UPat()), name="copy"), pre_bufferize),
])
def late_buffer_view(x, t, b):
  # Rewrite callback for to_bufferview.
  #   x: the matched INDEX node
  #   t: the matched BITCAST/CONTIGUOUS node
  #   b: the matched BUFFERIZE node
  # Returns b unchanged unless t.device is a string starting with "DISK"; otherwise returns b with
  # its first source swapped for a BUFFER_VIEW built on x.base.
  dev = t.device
  if not (isinstance(dev, str) and dev.startswith("DISK")): return b
  # every source of b after the first is treated as a range; its extent is vmax+1
  extents = [int(r.vmax+1) for r in b.src[1:]]
  size = prod(extents)
  if extents:
    # with ranges present, the offset is the sum of the minimum value of each index source of x
    offset = sum(idx.vmin for idx in x.src[1:])
  else:
    # no ranges: the offset is read from the arg of x's first index source (presumably a constant -- TODO confirm)
    offset = x.src[1].arg
  view = UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag)
  # keep the remaining sources of b as they were
  return b.replace(src=(view,) + b.src[1:])
# Pattern table: two BUFFERIZE-rooted patterns that dispatch to late_buffer_view, plus one cleanup rule for BUFFER_VIEW sources.
to_bufferview = PatternMatcher([
# BUFFERIZE ("b") over a BITCAST/CONTIGUOUS ("t") over an INDEX ("x"); allow_any_len=True, so extra BUFFERIZE sources presumably do not block the match
(UPat(Ops.INDEX, name="x").f((Ops.BITCAST, Ops.CONTIGUOUS), name="t").f(Ops.BUFFERIZE, allow_any_len=True, name="b"), late_buffer_view),
# same chain, but with one additional op (GroupOp.All -- presumably any op, TODO confirm) between "t" and the BUFFERIZE
(UPat(Ops.INDEX, name="x").f((Ops.BITCAST, Ops.CONTIGUOUS), name="t").f(GroupOp.All).f(Ops.BUFFERIZE, allow_any_len=True, name="b"),
late_buffer_view),
# a BUFFER_VIEW fed by a BITCAST/CONTIGUOUS takes over that op's sources, dropping the BITCAST/CONTIGUOUS itself
(UPat((Ops.BITCAST, Ops.CONTIGUOUS)).f(Ops.BUFFER_VIEW, name="b"), lambda b: b.replace(src=b.src[0].src)),
])
# *****************
# 4. put in buffers for bufferize
# TODO: should BUFFERIZE look a lot more like STORE
@@ -454,7 +467,7 @@ def bufferize_to_store(x:UOp):
# TODO: how is this unified?
return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).forced_reshape(shape, dtype=x.dtype)
pm_add_buffers = pm_mops+PatternMatcher([
pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([
(UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
# move RESHAPEs through MSELECT/MSTACK