detect infinite loop in graph rewrite [pr] (#11036)

This commit is contained in:
George Hotz
2025-06-30 08:15:13 -07:00
committed by GitHub
parent 710d734ce7
commit cb531dba42
2 changed files with 10 additions and 0 deletions

View File

@@ -290,6 +290,15 @@ class TestRecurse(unittest.TestCase):
pm = PatternMatcher([(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x)])
graph_rewrite(a, pm)
def test_inf_loop(self):
a = UOp.variable('a', 0, 10)
pm = PatternMatcher([
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG)),
(UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
])
with self.assertRaises(RuntimeError):
graph_rewrite(a, pm)
def bidir_append(ctx, x, b): ctx.append((x.arg if x.op is Ops.CONST else "+", b))
class TestBidirectional(unittest.TestCase):
def test_simple(self):

View File

@@ -868,6 +868,7 @@ class RewriteContext:
def unified_rewrite(self, root:UOp) -> UOp:
stack: list[tuple[UOp, int, UOp]] = [(root, 0, root)]
while stack:
if len(stack) >= 200000: raise RuntimeError("infinite loop in graph_rewrite")
n, stage, new_n = stack.pop()
if n in self.replace: continue # skip any nodes we have seen
if stage == 0: