reduce children fusion tests (#4321)

* base tests

* real-world tests
This commit is contained in:
qazal
2024-04-28 18:14:02 +03:00
committed by GitHub
parent f3de17912f
commit 3372bea322
2 changed files with 72 additions and 30 deletions

View File

@@ -453,33 +453,28 @@ class TestSchedule(unittest.TestCase):
out = x.contiguous() + y.contiguous()
check_schedule(out, 2)
def test_group_fuse(self):
a = Tensor.empty((4, 4))
def test_reduce_same_size(self):
a = Tensor.empty(4, 4)
out0 = a.sum() + 2
out1 = a.sum() + 4
check_schedule([out0, out1], 1)
out2 = out0 * out1
check_schedule([out0, out1, out2], 2)
def test_group_inner_deps_fuse(self):
a = Tensor.empty((4, 4))
out0 = a.sum() + 2
out1 = a.sum() + out0 + 4
check_schedule([out0, out1], 1)
def test_group_outside_reduce(self):
a = Tensor.empty((4, 4))
b = Tensor.empty((4, 4))
out0 = a.sum() + 2
# b.sum() is not a descendant of the fused nodes
out1 = a.sum() + b.sum() + 4
check_schedule([out0, out1], 3) # TODO: this can fuse
def test_reduce_multiple_paths_fuse(self):
def test_reduce_multiple_paths(self):
a = Tensor.empty(4, 4)
out0 = a.sum().exp2()
# out1 has two paths to a.sum()
out1 = a.sum() + out0
check_schedule([out0, out1], 1)
def test_reduce_ext_reduce_child(self):
a = Tensor.empty((4, 4))
b = Tensor.empty((4, 4))
# b.sum() is not a descendant of the fused nodes
out0 = a.sum() + b.sum() + 2
out1 = a.sum() + b.sum() + 4
check_schedule([out0, out1], 4)
def test_reduce_multiple_paths_midreduce(self):
a = Tensor.empty(4, 4)
r = a.sum()
@@ -489,6 +484,14 @@ class TestSchedule(unittest.TestCase):
out2 = r + out1
check_schedule([r, out0, out1, out2], 4)
def test_reduce_multiple_paths_midreduce_fused(self):
a = Tensor.empty(4, 4)
b = Tensor.empty(4, 4)
out0 = a.sum() + 4
out1 = b.max() + out0*2
out2 = a.sum() + out1
check_schedule([out0, out1, out2], 4)
def test_reduce_multiple_paths_midexpand(self):
a = Tensor.empty(4, 4)
b = Tensor.empty(4, 4, 4)
@@ -499,26 +502,33 @@ class TestSchedule(unittest.TestCase):
out1 = r + e[0][0][0]
check_schedule([r, out0, out1, e], 4)
def test_group_midreduce_nofuse(self):
a = Tensor.empty((4, 4))
b = Tensor.empty((4, 4))
out0 = a.sum() + 2
out1 = a.sum() + b.sum() + 4
check_schedule([out0, out1], 3)
def test_group_midexpand_nofuse(self):
def test_reduce_expand_child(self):
a = Tensor.empty((32, 32, 32))
b = Tensor.empty((1, 16))
out0 = a.sum() + 2
out1 = a.sum() + b
check_schedule([out0, out1], 4)
def test_group_midshrink_fuse(self):
def test_reduce_shrink_child(self):
a = Tensor.empty(100, 100)
b = Tensor.empty(10,)
out0 = a.sum() + b[0]
out1 = a.sum() + 2
check_schedule([out0, out1], 1)
c = a.sum() + b[0]
d = a.sum() + 2
check_schedule([c, d], 1)
def test_reduce_multiple_paths_midshrink(self):
a = Tensor.empty(4, 4)
r = a.sum(axis=1)
out0 = r.exp2()
out1 = out0[0] + out0
check_schedule([r, out0, out1], 3)
def test_reduce_shrink_output(self):
a = Tensor.empty(4, 4)
r = a.sum(keepdim=True)
out0 = r.exp2()
out1 = out0[0] + Tensor.empty(1, )
check_schedule([r, out0, out1], 3)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_prefer_half_buffer(self):