merge all one op views [pr] (#8412)

* merge all one op views [pr]

* does this work?

* this won't work (yet)

* apply movement ops on top of the BUFFER

* buffer needs to become base next

---------

Co-authored-by: qazal <qazal.software@gmail.com>
Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
This commit is contained in:
George Hotz
2024-12-29 16:06:49 -05:00
committed by GitHub
parent 0d2400fc7c
commit a3c359b28b

View File

@@ -521,7 +521,6 @@ def unbind_variable(ctx:ScheduleContext, bind:UOp, st:UOp):
return UOp.const(bind.dtype, bind).valid(unwrap(st.st))
def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
assert st.size == b.size and unwrap(st.st).contiguous, f"ShapeTracker of realized {b} BUFFER must match the BUFFER size {st}"
# NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
@@ -556,14 +555,13 @@ create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER,
# **** movement ops
remove_movement_ops = PatternMatcher([
(UPat(GroupOp.Movement, name="x"), lambda x: x.base.view(unwrap(x.st))),
# NOTE: movement ops are always applied to base
(UPat(GroupOp.Movement, name="mov", src=(UPat.any(UPat.var("x").view(), UPat.var("x")))), lambda x,mov: x.view(unwrap(mov.st))),
# some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
(UPat(Ops.VIEW, name="view"),
lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
# merge one src (unrealized) views
# NOTE: we can't merge realized buffer views here, because the buffer is realized before the view
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, src=(UPat.var("x"),), name="v1")), name="v2"),
lambda x,v1,v2: v1.replace(arg=v1.arg+v2.arg) if x.op is not Ops.BUFFER else None),
# merge one src views.
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, src=(UPat(),), name="v1")), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
# merge unmasked const views
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),