From 3a32fa228c7bcb346ae276082bfc16d87b99fac8 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Wed, 7 May 2025 14:22:06 +0300
Subject: [PATCH] refactor merge_views matcher [pr] (#10188)

---
 tinygrad/engine/grouper.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py
index 41a8469b8f..e34c50de88 100644
--- a/tinygrad/engine/grouper.py
+++ b/tinygrad/engine/grouper.py
@@ -15,22 +15,22 @@ from tinygrad.spec import type_verify, sched_spec
 import sys
 sys.setrecursionlimit(10000)
 
-# *** UOp merge views ***
+# **** UOp merge views
 
 merge_views = PatternMatcher([
   # merge adjacent views
-  (UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v2"),), name="v1"), lambda v1,v2: v2.replace(arg=v2.arg+v1.arg)),
-  # merge unmasked const views
-  (UPat(Ops.VIEW, name="v", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const"),)),
-   lambda v,const: const.replace(src=(const.src[0].replace(arg=const.st+v.st),)) if all(x.mask is None for x in (const.st+v.st).views) else None),
+  (UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v1"),), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
   # merge view on load/store/valid
-  (UPat(Ops.VIEW, name="v", src=(UPat((Ops.LOAD, Ops.STORE, Ops.VALID), name="b"),)),
-   lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
-  # remove view if it's a contiguous and the shapes match
-  (UPat(Ops.VIEW, name="v", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda v,x: x if v.arg.contiguous and x.shape == v.shape else None),
-  # remove mask if there's a zero in the masked dim
-  (UPat(Ops.VIEW, name="v", src=(UPat(),)),
-   lambda v: v.const_like(0) if (mask:=v.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
+  (UPat(Ops.VIEW, src=(UPat((Ops.LOAD, Ops.STORE, Ops.VALID), name="x"),), name="view"),
+   lambda x,view: x.replace(src=tuple((s.st+view.st).to_uop() if s.op is Ops.VIEW else s for s in x.src))),
+  # merge view on const if it's not masked
+  (UPat((Ops.CONST, Ops.DEFINE_VAR), name="x").view(name="view"),
+   lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
+  # replace view with base if it's contiguous and the shapes match
+  (UPat(GroupOp.All-{Ops.DEVICE}, name="x").view(name="view"), lambda x,view: x if view.st.contiguous and x.shape == view.shape else None),
+  # replace masked view with zero if it can collapse
+  (UPat(Ops.VIEW, src=(UPat(),), name="view"),
+   lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
   # movement ops apply a new view on the base
   (UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.view(mop.st)),
 ])