This commit is contained in:
qazal
2024-07-22 14:15:03 +03:00
parent dc21e63bd2
commit 3dc2b48042

View File

@@ -478,9 +478,8 @@ class UOpGraph:
sink_srcs = list(self.sink.src)
for i, s in enumerate(sink_srcs):
# breaks for WMMA
if all(x.op is not UOps.WMMA for x in s.parents):
if s.op is UOps.STORE and len(s.src) == 4 and (rw:=_dfs(s, s.src[3])) != s:
sink_srcs[i] = UOp(rw.op, rw.dtype, rw.src[:3], rw.arg)
if s.op is UOps.STORE and len(s.src) == 4 and (rw:=_dfs(s, s.src[3])) != s:
sink_srcs[i] = UOp(rw.op, rw.dtype, rw.src[:3], rw.arg)
sink = UOp(UOps.SINK, None, tuple(sink_srcs))
# do graph rewrite