non recursive top_down_rewrite (#10729)

* non recursive top_down_rewrite

* nicer algorithm

* rewrite bottom up also

* only top down is broken?

* simpler iterative algo

* no recursion errors

* top down and bottom up

* unified rewrite

* simpler rewrite

* clean up comments

* move that comment
This commit is contained in:
George Hotz
2025-06-09 16:33:04 -07:00
committed by GitHub
parent 53cbd4254b
commit 81ef879da3
4 changed files with 32 additions and 14 deletions

View File

@@ -84,6 +84,7 @@ class TestSchedule(unittest.TestCase):
with Context(FUSE_ARANGE=1, NOOPT=1): self.test_arange_avgpool2d(kcount=1)
# linearizer error
@unittest.skip("recursion error no longer raised")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "needs supports_float4 to fail")
def test_arange_avgpool2d_fused(self):
with self.assertRaises(RecursionError):

View File

@@ -153,6 +153,7 @@ class TestSoftmaxFusion(unittest.TestCase):
np.testing.assert_allclose(sout.numpy(), out.numpy())
@unittest.skip("recursion error no longer raised")
def test_softmax_bw(self):
print("*** softmax bw ***")
self.test.requires_grad_()

View File

@@ -253,6 +253,7 @@ class TestSubstitute(unittest.TestCase):
# broken due to infinite recursion
# NOTE: VIZ hangs and doesn't recover if you click this one
@unittest.skip("recursion error no longer raised")
def test_assert_inf_recurse(self):
a = UOp.variable('a', 0, 10)
n1 = a.sin()

View File

@@ -996,30 +996,45 @@ class RewriteContext:
self.pm: PatternMatcher = pm
self.ctx = ctx
self.replace: dict[UOp, UOp] = {}
def top_down_rewrite(self, n:UOp) -> UOp:
if (rn := self.replace.get(n)) is not None: return rn
new_src = tuple([self.top_down_rewrite(x) for x in n.src])
new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg)
self.replace[n] = ret = n if new_n is None else self.top_down_rewrite(new_n)
return ret
def bottom_up_rewrite(self, n:UOp) -> UOp:
if (rn := self.replace.get(n)) is not None: return rn
new_n = self.pm.fixed_point_rewrite(n, self.ctx)
new_src = tuple([self.bottom_up_rewrite(x) for x in new_n.src])
self.replace[n] = ret = new_n if new_src == new_n.src else self.bottom_up_rewrite(UOp(new_n.op, new_n.dtype, new_src, new_n.arg))
return ret
def unified_rewrite(self, root:UOp, bottom_up=False) -> UOp:
stack: list[tuple[UOp, int, UOp]] = [(root, 0, root)]
while stack:
n, stage, new_n = stack.pop()
if n in self.replace: continue # skip any nodes we have seen
if stage == 0:
# if bottom up, we rewrite this node early. in both cases, we add its parents to the stack
if bottom_up: new_n = self.pm.fixed_point_rewrite(new_n, self.ctx)
stack.append((n, 1, new_n))
for x in reversed(new_n.src): stack.append((x, 0, x))
elif stage == 1:
if (new_src:=tuple([self.replace[x] for x in new_n.src])) == new_n.src:
# if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
if bottom_up or (new_src_n:=self.pm.rewrite(new_n, self.ctx)) is None:
self.replace[n] = new_n
continue
else:
# if srcs changed from rewrites, construct a new UOp with the new srcs
new_src_n = UOp(new_n.op, new_n.dtype, new_src, new_n.arg)
# trigger a rewrite of new_src_n, then after that rewrite is done, link it back to n
stack.append((n, 2, new_src_n))
stack.append((new_src_n, 0, new_src_n))
else:
# in stage 2, we link the result of new_n to the result of n
self.replace[n] = self.replace[new_n]
return self.replace[root]
@track_matches
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> UOp:
rewrite_ctx = RewriteContext(pm, ctx)
return rewrite_ctx.bottom_up_rewrite(sink) if bottom_up else rewrite_ctx.top_down_rewrite(sink)
return rewrite_ctx.unified_rewrite(sink, bottom_up)
@track_matches
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, input_map:dict[UOp, UOp]|None=None) -> dict[UOp, UOp]:
rewrite_ctx = RewriteContext(pm, ctx)
new_map: dict[UOp, UOp] = {}
for k in sink.toposort():
new_map[k] = v = rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)
new_map[k] = v = rewrite_ctx.unified_rewrite(k, bottom_up)
if k.metadata is not None: all_metadata[v] = tuple(dedup(all_metadata.get(v, ())))+k.metadata
if input_map is not None:
for k,v in input_map.items(): new_map[k] = new_map.get(v,v)