diff --git a/tinygrad/ops.py b/tinygrad/ops.py index da5faffbcc..ee810a4c57 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -880,11 +880,11 @@ class RewriteContext: self.pm: PatternMatcher = pm self.ctx = ctx self.replace: dict[UOp, UOp] = {} - def rewrite(self, n:UOp) -> UOp: + def top_down_rewrite(self, n:UOp) -> UOp: if (rn := self.replace.get(n)) is not None: return rn - new_src = tuple(map(self.rewrite, n.src)) + new_src = tuple(map(self.top_down_rewrite, n.src)) new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg) - self.replace[n] = ret = n if new_n is None else self.rewrite(new_n) + self.replace[n] = ret = n if new_n is None else self.top_down_rewrite(new_n) return ret def bottom_up_rewrite(self, n:UOp) -> UOp: if (rn := self.replace.get(n)) is not None: return rn @@ -897,13 +897,13 @@ class RewriteContext: def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp: if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink)) - return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).rewrite(sink) + return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink) def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> dict[UOp, UOp]: if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink)) rewrite_ctx = RewriteContext(pm, ctx) - return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.rewrite(k)) for k in list(sink.toposort)[::-1]} + return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]} # ***** uop type spec *****