diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index b0292cee29..2d3dc65b76 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -478,9 +478,8 @@ class UOpGraph: sink_srcs = list(self.sink.src) for i, s in enumerate(sink_srcs): # breaks for WMMA - if all(x.op is not UOps.WMMA for x in s.parents): - if s.op is UOps.STORE and len(s.src) == 4 and (rw:=_dfs(s, s.src[3])) != s: - sink_srcs[i] = UOp(rw.op, rw.dtype, rw.src[:3], rw.arg) + if s.op is UOps.STORE and len(s.src) == 4 and (rw:=_dfs(s, s.src[3])) != s: + sink_srcs[i] = UOp(rw.op, rw.dtype, rw.src[:3], rw.arg) sink = UOp(UOps.SINK, None, tuple(sink_srcs)) # do graph rewrite