mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
non recursive top_down_rewrite (#10729)
* non recursive top_down_rewrite * nicer algorithm * rewrite bottom up also * only top down is broken? * simpler iterative algo * no recursion errors * top down and bottom up * unified rewrite * simpler rewrite * clean up comments * move that comment
This commit is contained in:
@@ -84,6 +84,7 @@ class TestSchedule(unittest.TestCase):
|
||||
with Context(FUSE_ARANGE=1, NOOPT=1): self.test_arange_avgpool2d(kcount=1)
|
||||
|
||||
# linearizer error
|
||||
@unittest.skip("recursion error no longer raised")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "needs supports_float4 to fail")
|
||||
def test_arange_avgpool2d_fused(self):
|
||||
with self.assertRaises(RecursionError):
|
||||
|
||||
@@ -153,6 +153,7 @@ class TestSoftmaxFusion(unittest.TestCase):
|
||||
|
||||
np.testing.assert_allclose(sout.numpy(), out.numpy())
|
||||
|
||||
@unittest.skip("recursion error no longer raised")
|
||||
def test_softmax_bw(self):
|
||||
print("*** softmax bw ***")
|
||||
self.test.requires_grad_()
|
||||
|
||||
@@ -253,6 +253,7 @@ class TestSubstitute(unittest.TestCase):
|
||||
|
||||
# broken due to infinite recursion
|
||||
# NOTE: VIZ hangs and doesn't recover if you click this one
|
||||
@unittest.skip("recursion error no longer raised")
|
||||
def test_assert_inf_recurse(self):
|
||||
a = UOp.variable('a', 0, 10)
|
||||
n1 = a.sin()
|
||||
|
||||
@@ -996,30 +996,45 @@ class RewriteContext:
|
||||
self.pm: PatternMatcher = pm
|
||||
self.ctx = ctx
|
||||
self.replace: dict[UOp, UOp] = {}
|
||||
def top_down_rewrite(self, n:UOp) -> UOp:
|
||||
if (rn := self.replace.get(n)) is not None: return rn
|
||||
new_src = tuple([self.top_down_rewrite(x) for x in n.src])
|
||||
new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg)
|
||||
self.replace[n] = ret = n if new_n is None else self.top_down_rewrite(new_n)
|
||||
return ret
|
||||
def bottom_up_rewrite(self, n:UOp) -> UOp:
|
||||
if (rn := self.replace.get(n)) is not None: return rn
|
||||
new_n = self.pm.fixed_point_rewrite(n, self.ctx)
|
||||
new_src = tuple([self.bottom_up_rewrite(x) for x in new_n.src])
|
||||
self.replace[n] = ret = new_n if new_src == new_n.src else self.bottom_up_rewrite(UOp(new_n.op, new_n.dtype, new_src, new_n.arg))
|
||||
return ret
|
||||
|
||||
def unified_rewrite(self, root:UOp, bottom_up=False) -> UOp:
|
||||
stack: list[tuple[UOp, int, UOp]] = [(root, 0, root)]
|
||||
while stack:
|
||||
n, stage, new_n = stack.pop()
|
||||
if n in self.replace: continue # skip any nodes we have seen
|
||||
if stage == 0:
|
||||
# if bottom up, we rewrite this node early. in both cases, we add its parents to the stack
|
||||
if bottom_up: new_n = self.pm.fixed_point_rewrite(new_n, self.ctx)
|
||||
stack.append((n, 1, new_n))
|
||||
for x in reversed(new_n.src): stack.append((x, 0, x))
|
||||
elif stage == 1:
|
||||
if (new_src:=tuple([self.replace[x] for x in new_n.src])) == new_n.src:
|
||||
# if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
|
||||
if bottom_up or (new_src_n:=self.pm.rewrite(new_n, self.ctx)) is None:
|
||||
self.replace[n] = new_n
|
||||
continue
|
||||
else:
|
||||
# if srcs changed from rewrites, construct a new UOp with the new srcs
|
||||
new_src_n = UOp(new_n.op, new_n.dtype, new_src, new_n.arg)
|
||||
# trigger a rewrite of new_src_n, then after that rewrite is done, link it back to n
|
||||
stack.append((n, 2, new_src_n))
|
||||
stack.append((new_src_n, 0, new_src_n))
|
||||
else:
|
||||
# in stage 2, we link the result of new_n to the result of n
|
||||
self.replace[n] = self.replace[new_n]
|
||||
return self.replace[root]
|
||||
|
||||
@track_matches
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> UOp:
|
||||
rewrite_ctx = RewriteContext(pm, ctx)
|
||||
return rewrite_ctx.bottom_up_rewrite(sink) if bottom_up else rewrite_ctx.top_down_rewrite(sink)
|
||||
return rewrite_ctx.unified_rewrite(sink, bottom_up)
|
||||
|
||||
@track_matches
|
||||
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, input_map:dict[UOp, UOp]|None=None) -> dict[UOp, UOp]:
|
||||
rewrite_ctx = RewriteContext(pm, ctx)
|
||||
new_map: dict[UOp, UOp] = {}
|
||||
for k in sink.toposort():
|
||||
new_map[k] = v = rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)
|
||||
new_map[k] = v = rewrite_ctx.unified_rewrite(k, bottom_up)
|
||||
if k.metadata is not None: all_metadata[v] = tuple(dedup(all_metadata.get(v, ())))+k.metadata
|
||||
if input_map is not None:
|
||||
for k,v in input_map.items(): new_map[k] = new_map.get(v,v)
|
||||
|
||||
Reference in New Issue
Block a user