fixed point rewrite [pr] (#10732)

This commit is contained in:
George Hotz
2025-06-09 14:46:20 -07:00
committed by GitHub
parent 55cdbb9a20
commit 916bbd5c6b

View File

@@ -879,6 +879,12 @@ class PatternMatcher:
if (ret:=match(uop, ctx)) is not None: return ret
return None
def fixed_point_rewrite(self, uop:UOp, ctx=None) -> UOp:
# apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match
new_n: UOp|None = uop
while new_n is not None: last_n, new_n = new_n, self.rewrite(new_n, ctx)
return last_n
# *** tracking pattern matcher ***
TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
@@ -998,10 +1004,9 @@ class RewriteContext:
return ret
def bottom_up_rewrite(self, n:UOp) -> UOp:
if (rn := self.replace.get(n)) is not None: return rn
new_n: UOp|None = n
while new_n is not None: last_n, new_n = new_n, self.pm.rewrite(new_n, self.ctx)
new_src = tuple([self.bottom_up_rewrite(x) for x in last_n.src])
self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
new_n = self.pm.fixed_point_rewrite(n, self.ctx)
new_src = tuple([self.bottom_up_rewrite(x) for x in new_n.src])
self.replace[n] = ret = new_n if new_src == new_n.src else self.bottom_up_rewrite(UOp(new_n.op, new_n.dtype, new_src, new_n.arg))
return ret
@track_matches