diff --git a/tinygrad/ops.py b/tinygrad/ops.py index a4545ba763..301647a72b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -994,9 +994,9 @@ merge_views = PatternMatcher([ view_left = merge_views+PatternMatcher([ # do not push masked view before unsafe pad ops - (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.UnsafePad, name="e"),)), - lambda e,vm: e.contiguous().view(vm.st) if any(v.mask is not None for v in vm.st.views) else None), + (UPat(Ops.VIEW, src=(UPat(GroupOp.UnsafePad, name="e"),), name="view"), + lambda e,view: e.contiguous().view(view.st) if any(v.mask is not None for v in view.st.views) else None), # view before elementwise ops - (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)), - lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))), + (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST}, name="e"),), name="view"), + lambda e,view: e.replace(src=tuple(s.view(s.st+view.st) if s.op is Ops.VIEW else s.view(view.st) for s in e.src))), ])