diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 4c4cedb953..ec9c807b2f 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -717,15 +717,17 @@ if TRACK_MATCH_STATS: # *** simple graph rewrite engine *** -def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: - nodes: Dict[Tuple, UOp] = {} - replace: Dict[UOp, UOp] = {} - def __inner_rewrite(n:UOp) -> UOp: - if rn := replace.get(n): return rn - replace_source = (n.op, n.dtype, new_src:=tuple(__inner_rewrite(y) for y in n.src), n.arg) - if found := nodes.get(replace_source): replace[n] = found +class RewriteContext: + def __init__(self, pm): + self.pm: PatternMatcher = pm + self.nodes: Dict[Tuple, UOp] = {} + self.replace: Dict[UOp, UOp] = {} + def rewrite(self, n:UOp) -> UOp: + if rn := self.replace.get(n): return rn + replace_source = (n.op, n.dtype, new_src:=tuple(self.rewrite(y) for y in n.src), n.arg) + if found := self.nodes.get(replace_source): self.replace[n] = found else: x = UOp(*replace_source) if new_src != n.src else n - nodes[replace_source] = replace[n] = found = __inner_rewrite(new_x) if (new_x := pm.rewrite(x)) else x + self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x return found - return __inner_rewrite(sink) +def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: return RewriteContext(pm).rewrite(sink)