diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 2d3dc65b76..9c2302dd88 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -467,18 +467,17 @@ class UOpGraph: # NOTE: relinearizering should be okay #assert self._uops is None, "already linearized" - # fixup gated stores with an IF block to save extra local loads + # replace UOps.STORE gate with an IF block @functools.lru_cache(None) - def _dfs(u:UOp, gate:UOp) -> UOp: + def _replace_gates(u:UOp, gate:UOp) -> UOp: if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER: if_uop = UOp(UOps.IF, None, (gate, u.src[-1])) return UOp(u.op, u.dtype, u.src[:-1]+(if_uop,), u.arg) - if (replace_source:=tuple(_dfs(x, gate) for x in u.src)) != u.src: return UOp(u.op, u.dtype, replace_source, u.arg) + if (replace_source:=tuple(_replace_gates(x, gate) for x in u.src)) != u.src: return UOp(u.op, u.dtype, replace_source, u.arg) return u sink_srcs = list(self.sink.src) for i, s in enumerate(sink_srcs): - # breaks for WMMA - if s.op is UOps.STORE and len(s.src) == 4 and (rw:=_dfs(s, s.src[3])) != s: + if s.op is UOps.STORE and len(s.src) == 4 and (rw:=_replace_gates(s, s.src[3])) != s: sink_srcs[i] = UOp(rw.op, rw.dtype, rw.src[:3], rw.arg) sink = UOp(UOps.SINK, None, tuple(sink_srcs))