From 916bbd5c6b93d9d0e61b25ca84a615892ce89bfb Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 9 Jun 2025 14:46:20 -0700 Subject: [PATCH] fixed point rewrite [pr] (#10732) --- tinygrad/uop/ops.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 20e60ae9c2..e28224505a 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -879,6 +879,12 @@ class PatternMatcher: if (ret:=match(uop, ctx)) is not None: return ret return None + def fixed_point_rewrite(self, uop:UOp, ctx=None) -> UOp: + # apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match + new_n: UOp|None = uop + while new_n is not None: last_n, new_n = new_n, self.rewrite(new_n, ctx) + return last_n + # *** tracking pattern matcher *** TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0) @@ -998,10 +1004,9 @@ class RewriteContext: return ret def bottom_up_rewrite(self, n:UOp) -> UOp: if (rn := self.replace.get(n)) is not None: return rn - new_n: UOp|None = n - while new_n is not None: last_n, new_n = new_n, self.pm.rewrite(new_n, self.ctx) - new_src = tuple([self.bottom_up_rewrite(x) for x in last_n.src]) - self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg)) + new_n = self.pm.fixed_point_rewrite(n, self.ctx) + new_src = tuple([self.bottom_up_rewrite(x) for x in new_n.src]) + self.replace[n] = ret = new_n if new_src == new_n.src else self.bottom_up_rewrite(UOp(new_n.op, new_n.dtype, new_src, new_n.arg)) return ret @track_matches