add collapse_view to the scheduler [pr] (#8440)

This commit is contained in:
qazal
2024-12-29 15:30:29 +02:00
committed by GitHub
parent 98b2854f14
commit a44cd1e6f7
3 changed files with 38 additions and 1 deletions

View File

@@ -1981,6 +1981,39 @@ class TestView(unittest.TestCase):
run_schedule(sched)
np.testing.assert_allclose(b.numpy(), np.pad(a.numpy(), ((0, 5), (0, 0)))[5:])
# a*VIEW(x), where VIEW(x) = 0
# x collapses along with its children
def test_parent_view_collapses(self):
a = Tensor([1, 2])
b = Tensor.arange(3).contiguous()
bv = b.pad(((0, 2),))[-2:]
# this becomes a late a*0
late_mul = a*bv
check_schedule(late_mul, 0)
# the arange doesn't realize
self.assertIsNone(b.lazydata.base.realized)
# mul doesn't realize
self.assertIsNone(late_mul.lazydata.base.realized)
self.assertEqual(late_mul.tolist(), [0, 0])
# SINK has two branches:
# a*VIEW(x), where VIEW(x) = 0
# x+2
# as long as one child realizes, x does not collapse
def test_parent_multiple_children_no_collapse(self):
a = Tensor([1, 2])
b = Tensor.arange(3).contiguous()
bv = b.pad(((0, 2),))[-2:]
late_mul = a*bv
other_child = b+2
s = check_schedule([late_mul, other_child], 2)
# the arange realizes
self.assertIsNotNone(b.lazydata.base.realized)
# mul still collapses
self.assertIsNone(late_mul.lazydata.base.realized)
run_schedule(s)
self.assertEqual(other_child.tolist(), [2, 3, 4])
def tensor_rewrite(t) -> UOp: return graph_rewrite(t.lazydata.base, remove_movement_ops+symbolic)
class TestBigGraph(unittest.TestCase):
def test_sink_childless_const(self):