mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-28 00:08:16 -05:00
merge all one op views [pr] (#8412)
* merge all one op views [pr] * does this work? * this won't work (yet) * apply movement ops on top of the BUFFER * buffer needs to become base next --------- Co-authored-by: qazal <qazal.software@gmail.com> Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
This commit is contained in:
@@ -521,7 +521,6 @@ def unbind_variable(ctx:ScheduleContext, bind:UOp, st:UOp):
|
||||
return UOp.const(bind.dtype, bind).valid(unwrap(st.st))
|
||||
|
||||
def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
|
||||
assert st.size == b.size and unwrap(st.st).contiguous, f"ShapeTracker of realized {b} BUFFER must match the BUFFER size {st}"
|
||||
# NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
|
||||
return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
|
||||
|
||||
@@ -556,14 +555,13 @@ create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER,
|
||||
# **** movement ops
|
||||
|
||||
remove_movement_ops = PatternMatcher([
|
||||
(UPat(GroupOp.Movement, name="x"), lambda x: x.base.view(unwrap(x.st))),
|
||||
# NOTE: movement ops are always applied to base
|
||||
(UPat(GroupOp.Movement, name="mov", src=(UPat.any(UPat.var("x").view(), UPat.var("x")))), lambda x,mov: x.view(unwrap(mov.st))),
|
||||
# some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
|
||||
(UPat(Ops.VIEW, name="view"),
|
||||
lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
|
||||
# merge one src (unrealized) views
|
||||
# NOTE: we can't merge realized buffer views here, because the buffer is realized before the view
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, src=(UPat.var("x"),), name="v1")), name="v2"),
|
||||
lambda x,v1,v2: v1.replace(arg=v1.arg+v2.arg) if x.op is not Ops.BUFFER else None),
|
||||
# merge one src views.
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, src=(UPat(),), name="v1")), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
|
||||
# merge unmasked const views
|
||||
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
|
||||
lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
|
||||
|
||||
Reference in New Issue
Block a user