split pm_substitute_recurse (#12460)

This commit is contained in:
George Hotz
2025-10-06 09:35:50 +08:00
committed by GitHub
parent 1216fff781
commit 46e8ea15c1
2 changed files with 3 additions and 2 deletions

View File

@@ -1925,7 +1925,6 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(loss, 4))
np.testing.assert_allclose(loss.item(), 0.878309, atol=1e-5, rtol=1e-6)
@expect_rangeify_fails
def test_const_folding_alt(self):
t = Tensor.full((2,), 1.)
lt = (t < 0.)

View File

@@ -749,7 +749,9 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
tsink = graph_rewrite(tsink, pm_rangeify, ctx=(rangeify_ctx:=RangeifyContext()), bottom_up=True, name="rangeify")
# NOTE: sym (vs symbolic_simple) breaks things here because ranges with len 1 aren't handled right
tsink = graph_rewrite(tsink, symbolic_simple+pm_reduce_unparented, name="symbolic") # this supports const folding
tsink = graph_rewrite(tsink, pm_cleanups+pm_substitute_recurse, bottom_up=True, name="remove costly buffers")
tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")
# TODO: can you substitute and remove costly buffers at the same time?
tsink = graph_rewrite(tsink, pm_substitute_recurse, bottom_up=True, name="run substitutes")
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rangeify_ctx, name="limit buffers")
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph