replace gates in uopgraph [run_process_replay]

This commit is contained in:
qazal
2024-07-22 14:22:29 +03:00
parent 3dc2b48042
commit 11d9035fe0

View File

@@ -467,18 +467,17 @@ class UOpGraph:
# NOTE: relinearizering should be okay
#assert self._uops is None, "already linearized"
# fixup gated stores with an IF block to save extra local loads
# replace UOps.STORE gate with an IF block
@functools.lru_cache(None)
def _dfs(u:UOp, gate:UOp) -> UOp:
def _replace_gates(u:UOp, gate:UOp) -> UOp:
if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER:
if_uop = UOp(UOps.IF, None, (gate, u.src[-1]))
return UOp(u.op, u.dtype, u.src[:-1]+(if_uop,), u.arg)
if (replace_source:=tuple(_dfs(x, gate) for x in u.src)) != u.src: return UOp(u.op, u.dtype, replace_source, u.arg)
if (replace_source:=tuple(_replace_gates(x, gate) for x in u.src)) != u.src: return UOp(u.op, u.dtype, replace_source, u.arg)
return u
sink_srcs = list(self.sink.src)
for i, s in enumerate(sink_srcs):
# breaks for WMMA
if s.op is UOps.STORE and len(s.src) == 4 and (rw:=_dfs(s, s.src[3])) != s:
if s.op is UOps.STORE and len(s.src) == 4 and (rw:=_replace_gates(s, s.src[3])) != s:
sink_srcs[i] = UOp(rw.op, rw.dtype, rw.src[:3], rw.arg)
sink = UOp(UOps.SINK, None, tuple(sink_srcs))