Revert "RewriteContext [run_process_replay] (#6428)" (#6438)

This reverts commit e1d61b048b.
This commit is contained in:
George Hotz
2024-09-09 18:53:18 +08:00
committed by GitHub
parent eda177da84
commit e7dd08448f

View File

@@ -719,17 +719,15 @@ if TRACK_MATCH_STATS:
# *** simple graph rewrite engine ***
class RewriteContext:
def __init__(self, pm):
self.pm: PatternMatcher = pm
self.nodes: Dict[Tuple, UOp] = {}
self.replace: Dict[UOp, UOp] = {}
def rewrite(self, n:UOp) -> UOp:
if rn := self.replace.get(n): return rn
replace_source = (n.op, n.dtype, new_src:=tuple(self.rewrite(y) for y in n.src), n.arg)
if found := self.nodes.get(replace_source): self.replace[n] = found
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
nodes: Dict[Tuple, UOp] = {}
replace: Dict[UOp, UOp] = {}
def __inner_rewrite(n:UOp) -> UOp:
if rn := replace.get(n): return rn
replace_source = (n.op, n.dtype, new_src:=tuple(__inner_rewrite(y) for y in n.src), n.arg)
if found := nodes.get(replace_source): replace[n] = found
else:
x = UOp(*replace_source) if new_src != n.src else n
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x
nodes[replace_source] = replace[n] = found = __inner_rewrite(new_x) if (new_x := pm.rewrite(x)) else x
return found
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: return RewriteContext(pm).rewrite(sink)
return __inner_rewrite(sink)