test reduce graph permutations (#4291)

This commit is contained in:
qazal
2024-04-25 11:34:44 +03:00
committed by GitHub
parent 0f0627bc60
commit 74a1be88f5

View File

@@ -465,6 +465,40 @@ class TestSchedule(unittest.TestCase):
out1 = a.sum() + out0 + 4
check_schedule([out0, out1], 1)
def test_group_outside_reduce(self):
a = Tensor.empty((4, 4))
b = Tensor.empty((4, 4))
out0 = a.sum() + 2
# b.sum() is not a descendant of the fused nodes
out1 = a.sum() + b.sum() + 4
check_schedule([out0, out1], 3) # TODO: this can fuse
def test_reduce_multiple_paths_fuse(self):
a = Tensor.empty(4, 4)
out0 = a.sum().exp2()
# out1 has two paths to a.sum()
out1 = a.sum() + out0
check_schedule([out0, out1], 1)
def test_reduce_multiple_paths_midreduce(self):
a = Tensor.empty(4, 4)
r = a.sum()
out0 = r.exp2()
# reduce node in the indirect path from r to out2
out1 = (a - out0).max()
out2 = r + out1
check_schedule([r, out0, out1, out2], 4)
def test_reduce_multiple_paths_midexpand(self):
a = Tensor.empty(4, 4)
b = Tensor.empty(4, 4, 4)
r = a.sum()
out0 = r.exp2()
# e1 is in the indirect path from a.sum() to out1
e = b + out0
out1 = r + e[0][0][0]
check_schedule([r, out0, out1, e], 4)
def test_group_midreduce_nofuse(self):
a = Tensor.empty((4, 4))
b = Tensor.empty((4, 4))