do not view_left assign + elementwise sources always have a shape [pr] (#9491)

This commit is contained in:
qazal
2025-03-18 17:42:51 +08:00
committed by GitHub
parent 117b7a16ef
commit cde4fd3be3

View File

@@ -994,9 +994,9 @@ merge_views = PatternMatcher([
view_left = merge_views+PatternMatcher([
# do not push masked view before unsafe pad ops
(UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.UnsafePad, name="e"),)),
lambda e,vm: e.contiguous().view(vm.st) if any(v.mask is not None for v in vm.st.views) else None),
(UPat(Ops.VIEW, src=(UPat(GroupOp.UnsafePad, name="e"),), name="view"),
lambda e,view: e.contiguous().view(view.st) if any(v.mask is not None for v in view.st.views) else None),
# view before elementwise ops
(UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST}, name="e"),), name="view"),
lambda e,view: e.replace(src=tuple(s.view(s.st+view.st) if s.op is Ops.VIEW else s.view(view.st) for s in e.src))),
])