mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 14:28:09 -05:00
replace gates in uopgraph [run_process_replay]
This commit is contained in:
@@ -467,18 +467,17 @@ class UOpGraph:
|
||||
# NOTE: relinearizering should be okay
|
||||
#assert self._uops is None, "already linearized"
|
||||
|
||||
# fixup gated stores with an IF block to save extra local loads
|
||||
# replace UOps.STORE gate with an IF block
|
||||
@functools.lru_cache(None)
|
||||
def _dfs(u:UOp, gate:UOp) -> UOp:
|
||||
def _replace_gates(u:UOp, gate:UOp) -> UOp:
|
||||
if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER:
|
||||
if_uop = UOp(UOps.IF, None, (gate, u.src[-1]))
|
||||
return UOp(u.op, u.dtype, u.src[:-1]+(if_uop,), u.arg)
|
||||
if (replace_source:=tuple(_dfs(x, gate) for x in u.src)) != u.src: return UOp(u.op, u.dtype, replace_source, u.arg)
|
||||
if (replace_source:=tuple(_replace_gates(x, gate) for x in u.src)) != u.src: return UOp(u.op, u.dtype, replace_source, u.arg)
|
||||
return u
|
||||
sink_srcs = list(self.sink.src)
|
||||
for i, s in enumerate(sink_srcs):
|
||||
# breaks for WMMA
|
||||
if s.op is UOps.STORE and len(s.src) == 4 and (rw:=_dfs(s, s.src[3])) != s:
|
||||
if s.op is UOps.STORE and len(s.src) == 4 and (rw:=_replace_gates(s, s.src[3])) != s:
|
||||
sink_srcs[i] = UOp(rw.op, rw.dtype, rw.src[:3], rw.arg)
|
||||
sink = UOp(UOps.SINK, None, tuple(sink_srcs))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user