infinite loop detect in fixed_point_rewrite [pr] (#11038)

This commit is contained in:
George Hotz
2025-06-30 08:57:29 -07:00
committed by GitHub
parent bc15e98f5c
commit b829331219
2 changed files with 14 additions and 1 deletions

View File

@@ -299,6 +299,15 @@ class TestRecurse(unittest.TestCase):
with self.assertRaises(RuntimeError):
graph_rewrite(a, pm)
def test_inf_loop_bottom_up(self):
a = UOp.variable('a', 0, 10)
pm = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG)),
(UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
])
with self.assertRaises(RuntimeError):
graph_rewrite(a, pm, bottom_up=True)
def bidir_append(ctx, x, b): ctx.append((x.arg if x.op is Ops.CONST else "+", b))
class TestBidirectional(unittest.TestCase):
def test_simple(self):

View File

@@ -733,7 +733,11 @@ class PatternMatcher:
def fixed_point_rewrite(self, uop:UOp, ctx=None) -> UOp:
# apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match
new_n: UOp|None = uop
while new_n is not None: last_n, new_n = new_n, self.rewrite(new_n, ctx)
seen = set()
while new_n is not None:
if new_n in seen: raise RuntimeError("infinite loop in fixed_point_rewrite")
seen.add(new_n)
last_n, new_n = new_n, self.rewrite(new_n, ctx)
return last_n
# *** non-blocking UOp tracker ***