cleanup the merge_views upats [pr] (#8738)

This commit is contained in:
qazal
2025-01-24 09:49:54 -05:00
committed by GitHub
parent 07069b9988
commit 0814a79cb4
2 changed files with 12 additions and 11 deletions

View File

@@ -153,9 +153,9 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
# push VIEW to stores
# push VIEW to children
view_right = merge_views+PatternMatcher([
# STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> STORE(.., new_val).view()
# STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
# REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()

View File

@@ -1316,18 +1316,19 @@ Variable = UOp
ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]]
# *** uop swizzling ***
# *** UOp merge views and swizzling ***
merge_views = PatternMatcher([
(UPat(Ops.VIEW, name="s0").view(name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st)),
(UPat(Ops.VIEW, name="mv", src=(UPat.var("x"),)), lambda mv,x: x if mv.st.contiguous and x.st is not None and x.shape == mv.shape else None),
# VIEW(VIEW) merges to a single VIEW
(UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)),
(UPat(Ops.VIEW, name="vm", src=(UPat.var("x"),)), lambda vm,x: x if vm.st.contiguous and x.st is not None and x.shape == vm.shape else None),
])
# push VIEW to loads
# push VIEW to parents
view_left = merge_views+PatternMatcher([
# VIEW before elementwise ops
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"),
lambda e,v: e.replace(src=tuple(s if s.st is None else s.view(v.st) if s is s.base else s.base.view(s.st+v.st) for s in e.src))),
# early merge VIEW buffer ops
(UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
# VIEW before elementwise/buffer ops
(UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
(UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)),
lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
])