refactor merge_views matcher [pr] (#10188)

This commit is contained in:
qazal
2025-05-07 14:22:06 +03:00
committed by GitHub
parent 94e07725a6
commit 3a32fa228c

View File

@@ -15,22 +15,22 @@ from tinygrad.spec import type_verify, sched_spec
import sys
sys.setrecursionlimit(10000)
# *** UOp merge views ***
# **** UOp merge views
merge_views = PatternMatcher([
# merge adjacent views
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v2"),), name="v1"), lambda v1,v2: v2.replace(arg=v2.arg+v1.arg)),
# merge unmasked const views
(UPat(Ops.VIEW, name="v", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const"),)),
lambda v,const: const.replace(src=(const.src[0].replace(arg=const.st+v.st),)) if all(x.mask is None for x in (const.st+v.st).views) else None),
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v1"),), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
# merge view on load/store/valid
(UPat(Ops.VIEW, name="v", src=(UPat((Ops.LOAD, Ops.STORE, Ops.VALID), name="b"),)),
lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
# remove view if it's a contiguous and the shapes match
(UPat(Ops.VIEW, name="v", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda v,x: x if v.arg.contiguous and x.shape == v.shape else None),
# remove mask if there's a zero in the masked dim
(UPat(Ops.VIEW, name="v", src=(UPat(),)),
lambda v: v.const_like(0) if (mask:=v.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
(UPat(Ops.VIEW, src=(UPat((Ops.LOAD, Ops.STORE, Ops.VALID), name="x"),), name="view"),
lambda x,view: x.replace(src=tuple((s.st+view.st).to_uop() if s.op is Ops.VIEW else s for s in x.src))),
# merge view on const if it's not masked
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x").view(name="view"),
lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
# replace view with base if it's contiguous and the shapes match
(UPat(GroupOp.All-{Ops.DEVICE}, name="x").view(name="view"), lambda x,view: x if view.st.contiguous and x.shape == view.shape else None),
# replace masked view with zero if it can collapse
(UPat(Ops.VIEW, src=(UPat(),), name="view"),
lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
# movement ops apply a new view on the base
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.view(mop.st)),
])