partial fusion spec (#4398)

This commit is contained in:
qazal
2024-05-03 04:14:23 +03:00
committed by GitHub
parent 2c3b7f8e70
commit 0deaaf2bc8

View File

@@ -712,5 +712,40 @@ class TestSchedule(unittest.TestCase):
schedule = check_schedule(b, 2)
assert schedule[0].ast[0].src[0].op is ReduceOps.MAX
# pattern in test_transformer
def test_partial_fuse1(self):
a = Tensor.empty(16, 16)
b = Tensor.empty(16, 16)
c = a.sum() + 2
d = (a.sum() - b.sum()) * 4
check_schedule([c, d], 3)
# pattern in conv
def test_partial_fuse2(self):
a = Tensor.empty(16, 16)
b = Tensor.empty(16, 16)
c = a.sum() + 2
d = b.sum() - c
check_schedule([c, d], 2)
# pattern in adam
def test_partial_fuse3(self):
a = Tensor.empty(16, 16)
b = Tensor.empty(16, 16)
c = a.sum() + 2
d = a.sum() * 2
e = c * d
f = b.sum() - e
check_schedule([c, d, e, f], 3)
def test_partial_fuse4(self):
a = Tensor.empty(16, 16)
b = Tensor.empty(16, 16)
c = a.sum() + 2
d = a.sum() * 2
e = c * d
f = (b - d).sum() - e
check_schedule([c, d, e, f], 3)
if __name__ == '__main__':
unittest.main(verbosity=2)