From a3c359b28b8bee380b7086489ccbb6d32e674881 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 29 Dec 2024 16:06:49 -0500 Subject: [PATCH] merge all one op views [pr] (#8412) * merge all one op views [pr] * does this work? * this won't work (yet) * apply movement ops on top of the BUFFER * buffer needs to become base next --------- Co-authored-by: qazal Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com> --- tinygrad/engine/schedule.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 0a7ace87f5..f17c061195 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -521,7 +521,6 @@ def unbind_variable(ctx:ScheduleContext, bind:UOp, st:UOp): return UOp.const(bind.dtype, bind).valid(unwrap(st.st)) def load_realized(ctx:ScheduleContext, b:UOp, st:UOp): - assert st.size == b.size and unwrap(st.st).contiguous, f"ShapeTracker of realized {b} BUFFER must match the BUFFER size {st}" # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop())) @@ -556,14 +555,13 @@ create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, # **** movement ops remove_movement_ops = PatternMatcher([ - (UPat(GroupOp.Movement, name="x"), lambda x: x.base.view(unwrap(x.st))), + # NOTE: movement ops are always applied to base + (UPat(GroupOp.Movement, name="mov", src=(UPat.any(UPat.var("x").view(), UPat.var("x")))), lambda x,mov: x.view(unwrap(mov.st))), # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW) (UPat(Ops.VIEW, name="view"), lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None), - # merge one src (unrealized) views - # NOTE: we can't merge realized buffer views here, because the buffer is realized before the view - (UPat(Ops.VIEW, src=(UPat(Ops.VIEW, src=(UPat.var("x"),), name="v1")), name="v2"), - lambda x,v1,v2: v1.replace(arg=v1.arg+v2.arg) if x.op is not Ops.BUFFER else None), + # merge one src views. + (UPat(Ops.VIEW, src=(UPat(Ops.VIEW, src=(UPat(),), name="v1")), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)), # merge unmasked const views (UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)), lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),