diff --git a/test/test_schedule.py b/test/test_schedule.py
index c6e7cd0b00..cbcb6a6ca4 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -84,6 +84,7 @@ class TestSchedule(unittest.TestCase):
     with Context(FUSE_ARANGE=1, NOOPT=1): self.test_arange_avgpool2d(kcount=1)
 
   # linearizer error
+  @unittest.skip("recursion error no longer raised")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "needs supports_float4 to fail")
   def test_arange_avgpool2d_fused(self):
     with self.assertRaises(RecursionError):
diff --git a/test/test_softmax_fusion.py b/test/test_softmax_fusion.py
index fd0dedc068..23e1a4b773 100644
--- a/test/test_softmax_fusion.py
+++ b/test/test_softmax_fusion.py
@@ -153,6 +153,7 @@ class TestSoftmaxFusion(unittest.TestCase):
 
     np.testing.assert_allclose(sout.numpy(), out.numpy())
 
+  @unittest.skip("recursion error no longer raised")
   def test_softmax_bw(self):
     print("*** softmax bw ***")
     self.test.requires_grad_()
diff --git a/test/unit/test_graph_rewrite.py b/test/unit/test_graph_rewrite.py
index 97816f2e89..385bfda897 100644
--- a/test/unit/test_graph_rewrite.py
+++ b/test/unit/test_graph_rewrite.py
@@ -253,6 +253,7 @@ class TestSubstitute(unittest.TestCase):
 
   # broken due to infinite recursion
   # NOTE: VIZ hangs and doesn't recover if you click this one
+  @unittest.skip("recursion error no longer raised")
   def test_assert_inf_recurse(self):
     a = UOp.variable('a', 0, 10)
     n1 = a.sin()
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index e28224505a..5e06b71913 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -996,30 +996,45 @@ class RewriteContext:
     self.pm: PatternMatcher = pm
     self.ctx = ctx
     self.replace: dict[UOp, UOp] = {}
-  def top_down_rewrite(self, n:UOp) -> UOp:
-    if (rn := self.replace.get(n)) is not None: return rn
-    new_src = tuple([self.top_down_rewrite(x) for x in n.src])
-    new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg)
-    self.replace[n] = ret = n if new_n is None else self.top_down_rewrite(new_n)
-    return ret
-  def bottom_up_rewrite(self, n:UOp) -> UOp:
-    if (rn := self.replace.get(n)) is not None: return rn
-    new_n = self.pm.fixed_point_rewrite(n, self.ctx)
-    new_src = tuple([self.bottom_up_rewrite(x) for x in new_n.src])
-    self.replace[n] = ret = new_n if new_src == new_n.src else self.bottom_up_rewrite(UOp(new_n.op, new_n.dtype, new_src, new_n.arg))
-    return ret
+
+  def unified_rewrite(self, root:UOp, bottom_up=False) -> UOp:  # explicit-stack rewrite of root; returns self.replace[root]
+    stack: list[tuple[UOp, int, UOp]] = [(root, 0, root)]  # entries: (node, stage, current candidate for that node)
+    while stack:  # NOTE(review): unbounded loop; rules that never settle presumably spin here rather than raise RecursionError - confirm
+      n, stage, new_n = stack.pop()
+      if n in self.replace: continue  # skip any nodes that already have a recorded result in this context
+      if stage == 0:
+        # stage 0: if bottom up, we rewrite this node early. in both cases, we add its srcs (parents) to the stack
+        if bottom_up: new_n = self.pm.fixed_point_rewrite(new_n, self.ctx)
+        stack.append((n, 1, new_n))  # sits below the srcs pushed next, so it is popped after them
+        for x in reversed(new_n.src): stack.append((x, 0, x))  # reversed so new_n.src[0] is popped first
+      elif stage == 1:
+        if (new_src:=tuple([self.replace[x] for x in new_n.src])) == new_n.src:
+          # srcs unchanged: if top down, do the rewrite. if no rewrite or bottom up, this node is done so we add it to the dict
+          if bottom_up or (new_src_n:=self.pm.rewrite(new_n, self.ctx)) is None:
+            self.replace[n] = new_n
+            continue
+        else:
+          # if srcs changed from rewrites, construct a new UOp with the new srcs
+          new_src_n = UOp(new_n.op, new_n.dtype, new_src, new_n.arg)
+        # new_src_n is the pm.rewrite result or the rebuilt UOp: rewrite it (stage 0), then link it back to n (stage 2)
+        stack.append((n, 2, new_src_n))
+        stack.append((new_src_n, 0, new_src_n))
+      else:
+        # stage 2: new_n has been processed, so n resolves to whatever new_n resolved to
+        self.replace[n] = self.replace[new_n]
+    return self.replace[root]
 
 @track_matches
 def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> UOp:
   rewrite_ctx = RewriteContext(pm, ctx)
-  return rewrite_ctx.bottom_up_rewrite(sink) if bottom_up else rewrite_ctx.top_down_rewrite(sink)
+  return rewrite_ctx.unified_rewrite(sink, bottom_up)
 
 @track_matches
 def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, input_map:dict[UOp, UOp]|None=None) -> dict[UOp, UOp]:
   rewrite_ctx = RewriteContext(pm, ctx)
   new_map: dict[UOp, UOp] = {}
   for k in sink.toposort():
-    new_map[k] = v = rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)
+    new_map[k] = v = rewrite_ctx.unified_rewrite(k, bottom_up)
     if k.metadata is not None: all_metadata[v] = tuple(dedup(all_metadata.get(v, ())))+k.metadata
   if input_map is not None:
     for k,v in input_map.items(): new_map[k] = new_map.get(v,v)