mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
infinite loop detect in fixed_point_rewrite [pr] (#11038)
This commit is contained in:
@@ -299,6 +299,15 @@ class TestRecurse(unittest.TestCase):
|
||||
with self.assertRaises(RuntimeError):
|
||||
graph_rewrite(a, pm)
|
||||
|
||||
def test_inf_loop_bottom_up(self):
|
||||
a = UOp.variable('a', 0, 10)
|
||||
pm = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: x.replace(op=Ops.DEFINE_REG)),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda x: x.replace(op=Ops.DEFINE_VAR)),
|
||||
])
|
||||
with self.assertRaises(RuntimeError):
|
||||
graph_rewrite(a, pm, bottom_up=True)
|
||||
|
||||
def bidir_append(ctx, x, b): ctx.append((x.arg if x.op is Ops.CONST else "+", b))
|
||||
class TestBidirectional(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
|
||||
@@ -733,7 +733,11 @@ class PatternMatcher:
|
||||
def fixed_point_rewrite(self, uop:UOp, ctx=None) -> UOp:
|
||||
# apply rewrite rules until a fixed point is reached. may return `uop` itself if PatternMatcher doesn't match
|
||||
new_n: UOp|None = uop
|
||||
while new_n is not None: last_n, new_n = new_n, self.rewrite(new_n, ctx)
|
||||
seen = set()
|
||||
while new_n is not None:
|
||||
if new_n in seen: raise RuntimeError("infinite loop in fixed_point_rewrite")
|
||||
seen.add(new_n)
|
||||
last_n, new_n = new_n, self.rewrite(new_n, ctx)
|
||||
return last_n
|
||||
|
||||
# *** non-blocking UOp tracker ***
|
||||
|
||||
Reference in New Issue
Block a user