mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
do not view_left assign + elementwise sources always have a shape [pr] (#9491)
This commit is contained in:
@@ -994,9 +994,9 @@ merge_views = PatternMatcher([
|
||||
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# do not push masked view before unsafe pad ops
|
||||
(UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.UnsafePad, name="e"),)),
|
||||
lambda e,vm: e.contiguous().view(vm.st) if any(v.mask is not None for v in vm.st.views) else None),
|
||||
(UPat(Ops.VIEW, src=(UPat(GroupOp.UnsafePad, name="e"),), name="view"),
|
||||
lambda e,view: e.contiguous().view(view.st) if any(v.mask is not None for v in view.st.views) else None),
|
||||
# view before elementwise ops
|
||||
(UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
|
||||
lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
|
||||
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST}, name="e"),), name="view"),
|
||||
lambda e,view: e.replace(src=tuple(s.view(s.st+view.st) if s.op is Ops.VIEW else s.view(view.st) for s in e.src))),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user