mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 14:28:09 -05:00
fixed point rewrite [pr] (#10732)
This commit is contained in:
@@ -879,6 +879,12 @@ class PatternMatcher:
|
||||
if (ret:=match(uop, ctx)) is not None: return ret
|
||||
return None
|
||||
|
||||
def fixed_point_rewrite(self, uop:UOp, ctx=None) -> UOp:
|
||||
# apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match
|
||||
new_n: UOp|None = uop
|
||||
while new_n is not None: last_n, new_n = new_n, self.rewrite(new_n, ctx)
|
||||
return last_n
|
||||
|
||||
# *** tracking pattern matcher ***
|
||||
|
||||
TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
|
||||
@@ -998,10 +1004,9 @@ class RewriteContext:
|
||||
return ret
|
||||
def bottom_up_rewrite(self, n:UOp) -> UOp:
|
||||
if (rn := self.replace.get(n)) is not None: return rn
|
||||
new_n: UOp|None = n
|
||||
while new_n is not None: last_n, new_n = new_n, self.pm.rewrite(new_n, self.ctx)
|
||||
new_src = tuple([self.bottom_up_rewrite(x) for x in last_n.src])
|
||||
self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
|
||||
new_n = self.pm.fixed_point_rewrite(n, self.ctx)
|
||||
new_src = tuple([self.bottom_up_rewrite(x) for x in new_n.src])
|
||||
self.replace[n] = ret = new_n if new_src == new_n.src else self.bottom_up_rewrite(UOp(new_n.op, new_n.dtype, new_src, new_n.arg))
|
||||
return ret
|
||||
|
||||
@track_matches
|
||||
|
||||
Reference in New Issue
Block a user