mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
detect infinite loop in graph rewrite [pr] (#11036)
This commit is contained in:
@@ -290,6 +290,15 @@ class TestRecurse(unittest.TestCase):
|
||||
pm = PatternMatcher([(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x)])
|
||||
graph_rewrite(a, pm)
|
||||
|
||||
def test_inf_loop(self):
|
||||
a = UOp.variable('a', 0, 10)
|
||||
pm = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG)),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
||||
])
|
||||
with self.assertRaises(RuntimeError):
|
||||
graph_rewrite(a, pm)
|
||||
|
||||
def bidir_append(ctx, x, b): ctx.append((x.arg if x.op is Ops.CONST else "+", b))
|
||||
class TestBidirectional(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
|
||||
@@ -868,6 +868,7 @@ class RewriteContext:
|
||||
def unified_rewrite(self, root:UOp) -> UOp:
|
||||
stack: list[tuple[UOp, int, UOp]] = [(root, 0, root)]
|
||||
while stack:
|
||||
if len(stack) >= 200000: raise RuntimeError("infinite loop in graph_rewrite")
|
||||
n, stage, new_n = stack.pop()
|
||||
if n in self.replace: continue # skip any nodes we have seen
|
||||
if stage == 0:
|
||||
|
||||
Reference in New Issue
Block a user