mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 22:38:16 -05:00
rename to top_down_rewrite [pr] (#8583)
This commit is contained in:
@@ -880,11 +880,11 @@ class RewriteContext:
|
||||
self.pm: PatternMatcher = pm
|
||||
self.ctx = ctx
|
||||
self.replace: dict[UOp, UOp] = {}
|
||||
def rewrite(self, n:UOp) -> UOp:
|
||||
def top_down_rewrite(self, n:UOp) -> UOp:
|
||||
if (rn := self.replace.get(n)) is not None: return rn
|
||||
new_src = tuple(map(self.rewrite, n.src))
|
||||
new_src = tuple(map(self.top_down_rewrite, n.src))
|
||||
new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg)
|
||||
self.replace[n] = ret = n if new_n is None else self.rewrite(new_n)
|
||||
self.replace[n] = ret = n if new_n is None else self.top_down_rewrite(new_n)
|
||||
return ret
|
||||
def bottom_up_rewrite(self, n:UOp) -> UOp:
|
||||
if (rn := self.replace.get(n)) is not None: return rn
|
||||
@@ -897,13 +897,13 @@ class RewriteContext:
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp:
|
||||
if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
|
||||
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
|
||||
return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).rewrite(sink)
|
||||
return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink)
|
||||
|
||||
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> dict[UOp, UOp]:
|
||||
if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
|
||||
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
|
||||
rewrite_ctx = RewriteContext(pm, ctx)
|
||||
return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.rewrite(k)) for k in list(sink.toposort)[::-1]}
|
||||
return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
|
||||
|
||||
# ***** uop type spec *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user