diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 9e4ff9d0d5..7f6b141345 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -153,9 +153,9 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time" return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)) -# push VIEW to stores +# push VIEW to children view_right = merge_views+PatternMatcher([ - # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> STORE(.., new_val).view() + # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val)) (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))), lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))), # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view() diff --git a/tinygrad/ops.py b/tinygrad/ops.py index cf40c1fbed..c028c82745 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1316,18 +1316,19 @@ Variable = UOp ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]] -# *** uop swizzling *** +# *** UOp merge views and swizzling *** merge_views = PatternMatcher([ - (UPat(Ops.VIEW, name="s0").view(name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st)), - (UPat(Ops.VIEW, name="mv", src=(UPat.var("x"),)), lambda mv,x: x if mv.st.contiguous and x.st is not None and x.shape == mv.shape else None), + # VIEW(VIEW) merges to a single VIEW + (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)), + (UPat(Ops.VIEW, name="vm", src=(UPat.var("x"),)), lambda vm,x: x if vm.st.contiguous and x.st is not None and x.shape == vm.shape else None), ]) -# push VIEW to loads +# push VIEW to parents view_left = merge_views+PatternMatcher([ - # VIEW before elementwise ops - (UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), - lambda e,v: e.replace(src=tuple(s if s.st is None else s.view(v.st) if s is s.base else s.base.view(s.st+v.st) for s in e.src))), - # early merge VIEW buffer ops - (UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))), + # VIEW before elementwise/buffer ops + (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)), + lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))), + (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)), + lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))), ])