mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
refactor merge_views matcher [pr] (#10188)
This commit is contained in:
@@ -15,22 +15,22 @@ from tinygrad.spec import type_verify, sched_spec
|
||||
import sys
|
||||
sys.setrecursionlimit(10000)
|
||||
|
||||
# *** UOp merge views ***
|
||||
# **** UOp merge views
|
||||
|
||||
merge_views = PatternMatcher([
|
||||
# merge adjacent views
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v2"),), name="v1"), lambda v1,v2: v2.replace(arg=v2.arg+v1.arg)),
|
||||
# merge unmasked const views
|
||||
(UPat(Ops.VIEW, name="v", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const"),)),
|
||||
lambda v,const: const.replace(src=(const.src[0].replace(arg=const.st+v.st),)) if all(x.mask is None for x in (const.st+v.st).views) else None),
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v1"),), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
|
||||
# merge view on load/store/valid
|
||||
(UPat(Ops.VIEW, name="v", src=(UPat((Ops.LOAD, Ops.STORE, Ops.VALID), name="b"),)),
|
||||
lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
|
||||
# remove view if it's a contiguous and the shapes match
|
||||
(UPat(Ops.VIEW, name="v", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda v,x: x if v.arg.contiguous and x.shape == v.shape else None),
|
||||
# remove mask if there's a zero in the masked dim
|
||||
(UPat(Ops.VIEW, name="v", src=(UPat(),)),
|
||||
lambda v: v.const_like(0) if (mask:=v.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
|
||||
(UPat(Ops.VIEW, src=(UPat((Ops.LOAD, Ops.STORE, Ops.VALID), name="x"),), name="view"),
|
||||
lambda x,view: x.replace(src=tuple((s.st+view.st).to_uop() if s.op is Ops.VIEW else s for s in x.src))),
|
||||
# merge view on const if it's not masked
|
||||
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x").view(name="view"),
|
||||
lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
|
||||
# replace view with base if it's contiguous and the shapes match
|
||||
(UPat(GroupOp.All-{Ops.DEVICE}, name="x").view(name="view"), lambda x,view: x if view.st.contiguous and x.shape == view.shape else None),
|
||||
# replace masked view with zero if it can collapse
|
||||
(UPat(Ops.VIEW, src=(UPat(),), name="view"),
|
||||
lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
|
||||
# movement ops apply a new view on the base
|
||||
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.view(mop.st)),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user