RewriteContext [run_process_replay] (#6428)

This commit is contained in:
George Hotz
2024-09-09 16:49:02 +08:00
committed by GitHub
parent 935b6b658f
commit e1d61b048b

View File

@@ -721,15 +721,17 @@ if TRACK_MATCH_STATS:
# *** simple graph rewrite engine ***
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
nodes: Dict[Tuple, UOp] = {}
replace: Dict[UOp, UOp] = {}
def __inner_rewrite(n:UOp) -> UOp:
if rn := replace.get(n): return rn
replace_source = (n.op, n.dtype, new_src:=tuple(__inner_rewrite(y) for y in n.src), n.arg)
if found := nodes.get(replace_source): replace[n] = found
class RewriteContext:
def __init__(self, pm):
self.pm: PatternMatcher = pm
self.nodes: Dict[Tuple, UOp] = {}
self.replace: Dict[UOp, UOp] = {}
def rewrite(self, n:UOp) -> UOp:
if rn := self.replace.get(n): return rn
replace_source = (n.op, n.dtype, new_src:=tuple(self.rewrite(y) for y in n.src), n.arg)
if found := self.nodes.get(replace_source): self.replace[n] = found
else:
x = UOp(*replace_source) if new_src != n.src else n
nodes[replace_source] = replace[n] = found = __inner_rewrite(new_x) if (new_x := pm.rewrite(x)) else x
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x
return found
return __inner_rewrite(sink)
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: return RewriteContext(pm).rewrite(sink)