From bbb2dd8141ca5d6033c61a91b1f34da1fb1f13fc Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 26 Jan 2025 09:58:05 -0500 Subject: [PATCH] move VALID creation after merging the views (#8757) * do valid creation later * work for view_left * only view(const) makes valids in view_left * cleaner bind diff --- tinygrad/engine/schedule.py | 13 +++++-------- tinygrad/ops.py | 5 +++++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index f8f06b821b..3b6828688d 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -199,6 +199,8 @@ to_si = PatternMatcher([ (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)), # once images are loaded they become the base dtype (UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), + # CONST(VIEW) becomes VALID too, TODO: doesn't have to + (UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype, x.const_arg).valid(st.st)), ]) # LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel @@ -438,11 +440,11 @@ do_realize = PatternMatcher([ (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer), ]) -# **** rewrite VIEW into LOAD/STORE/VALID or fuse the underlying UOp +# **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp def unbind_variable(ctx:ScheduleContext, bind:UOp, var:UOp, val:UOp): - assert isinstance(val.src[1].const_arg, int), f"expected BIND value to be int {val}" - ctx.var_vals[ret:=var.replace(src=())] = val.src[1].const_arg + assert isinstance(val.const_arg, int), f"expected BIND value to be int {val}" + ctx.var_vals[ret:=var.replace(src=())] = val.const_arg return ret.valid(unwrap(bind.st)) def load_realized(ctx:ScheduleContext, b:UOp, st:UOp): @@ -456,8 +458,6 @@ def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp): return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop())) break_sched = PatternMatcher([ - # CONST is always fused and generated - (UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype, x.const_arg).valid(st.st)), (UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.var("val"))), unbind_variable), # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized), @@ -481,9 +481,6 @@ remove_movement_ops = merge_views+PatternMatcher([ # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW) (UPat(Ops.VIEW, name="view"), lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None), - # merge unmasked const views - (UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)), - lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None), ]) @track_rewrites(named=True) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 950589a80a..7f30ebe5a0 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1322,10 +1322,15 @@ merge_views = PatternMatcher([ # VIEW(VIEW) merges to a single VIEW (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)), (UPat(Ops.VIEW, name="vm", src=(UPat.var("x"),)), lambda vm,x: x if vm.st.contiguous and x.st is not None and x.shape == vm.shape else None), + # merge unmasked const views + (UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)), + lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None), ]) # push VIEW to parents view_left = merge_views+PatternMatcher([ + # VIEW(CONST) becomes VALID + (UPat(Ops.VIEW, name="vm", src=(UPat.cvar("x"),)), lambda vm,x: UOp.const(x.dtype, x.const_arg).valid(vm.st)), # VIEW before elementwise/buffer ops (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)), lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),