diff --git a/tinygrad/ops.py b/tinygrad/ops.py index eedde4cabf..78d2d503c7 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -719,17 +719,15 @@ if TRACK_MATCH_STATS: # *** simple graph rewrite engine *** -class RewriteContext: - def __init__(self, pm): - self.pm: PatternMatcher = pm - self.nodes: Dict[Tuple, UOp] = {} - self.replace: Dict[UOp, UOp] = {} - def rewrite(self, n:UOp) -> UOp: - if rn := self.replace.get(n): return rn - replace_source = (n.op, n.dtype, new_src:=tuple(self.rewrite(y) for y in n.src), n.arg) - if found := self.nodes.get(replace_source): self.replace[n] = found +def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: + nodes: Dict[Tuple, UOp] = {} + replace: Dict[UOp, UOp] = {} + def __inner_rewrite(n:UOp) -> UOp: + if rn := replace.get(n): return rn + replace_source = (n.op, n.dtype, new_src:=tuple(__inner_rewrite(y) for y in n.src), n.arg) + if found := nodes.get(replace_source): replace[n] = found else: x = UOp(*replace_source) if new_src != n.src else n - self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x + nodes[replace_source] = replace[n] = found = __inner_rewrite(new_x) if (new_x := pm.rewrite(x)) else x return found -def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: return RewriteContext(pm).rewrite(sink) + return __inner_rewrite(sink)