From 0deaaf2bc86f15d32c3399f1e454babf8d130f1c Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 3 May 2024 04:14:23 +0300 Subject: [PATCH] partial fusion spec (#4398) --- test/test_schedule.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/test/test_schedule.py b/test/test_schedule.py index 8470d5330f..ee19279ad1 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -712,5 +712,40 @@ class TestSchedule(unittest.TestCase): schedule = check_schedule(b, 2) assert schedule[0].ast[0].src[0].op is ReduceOps.MAX + # pattern in test_transformer + def test_partial_fuse1(self): + a = Tensor.empty(16, 16) + b = Tensor.empty(16, 16) + c = a.sum() + 2 + d = (a.sum() - b.sum()) * 4 + check_schedule([c, d], 3) + + # pattern in conv + def test_partial_fuse2(self): + a = Tensor.empty(16, 16) + b = Tensor.empty(16, 16) + c = a.sum() + 2 + d = b.sum() - c + check_schedule([c, d], 2) + + # pattern in adam + def test_partial_fuse3(self): + a = Tensor.empty(16, 16) + b = Tensor.empty(16, 16) + c = a.sum() + 2 + d = a.sum() * 2 + e = c * d + f = b.sum() - e + check_schedule([c, d, e, f], 3) + + def test_partial_fuse4(self): + a = Tensor.empty(16, 16) + b = Tensor.empty(16, 16) + c = a.sum() + 2 + d = a.sum() * 2 + e = c * d + f = (b - d).sum() - e + check_schedule([c, d, e, f], 3) + if __name__ == '__main__': unittest.main(verbosity=2)