From 3372bea322c1e14a501ab9599e593273fbba9cf7 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Sun, 28 Apr 2024 18:14:02 +0300
Subject: [PATCH] reduce children fusion tests (#4321)

* base tests

* real-world tests
---
 test/external/external_test_opt.py | 32 ++++++++++++++
 test/test_schedule.py              | 70 +++++++++++++++++-------------
 2 files changed, 72 insertions(+), 30 deletions(-)

diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py
index 22690e26f1..e022205d1d 100644
--- a/test/external/external_test_opt.py
+++ b/test/external/external_test_opt.py
@@ -194,6 +194,27 @@ class TestOpt(unittest.TestCase):
         opt.step()
       assert cache.count == 5, f"optimizer didn't fold conv-backward SGD, got {cache.count}"
 
+  def test_fold_conv_adam(self):
+    with Tensor.train():
+      img = Tensor.ones(2,3,4,4)
+      c1 = nn.Conv2d(3,32,3)
+      opt = optim.Adam(get_parameters(c1), lr=1e-4)
+      with CLCache(allowed=10):
+        opt.zero_grad()
+        c1(img).relu().sum().backward()
+        opt.step()
+
+  def test_fold_2convs_adam(self):
+    with Tensor.train():
+      img = Tensor.ones(2,3,64,64)
+      c1 = nn.Conv2d(3,16,3,bias=False)
+      c2 = nn.Conv2d(16,32,3,bias=False)
+      opt = optim.Adam(get_parameters([c1, c2]), lr=1e-4)
+      with CLCache(allowed=13):
+        opt.zero_grad()
+        c2(c1(img).relu()).relu().sum().backward()
+        opt.step()
+
   def test_fold_2convs_sgd(self):
     with Tensor.train():
       img = Tensor.ones(2,3,64,64)
@@ -205,6 +226,17 @@ class TestOpt(unittest.TestCase):
         c2(c1(img).relu()).relu().sum().backward()
         opt.step()
 
+  def test_fold_2convs_sgd_nesterov_momentum_wd(self):
+    with Tensor.train():
+      img = Tensor.ones(2,3,64,64)
+      c1 = nn.Conv2d(3,16,3,bias=False)
+      c2 = nn.Conv2d(16,32,3,bias=False)
+      opt = optim.SGD(get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
+      with CLCache(allowed=10):
+        opt.zero_grad()
+        c2(c1(img).relu()).relu().sum().backward()
+        opt.step()
+
   def test_fold_4convs_sgd(self):
     with Tensor.train():
       img = Tensor.ones(2,3,64,64)
diff --git a/test/test_schedule.py b/test/test_schedule.py
index e6fa47ecca..10bebcbded 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -453,33 +453,28 @@ class TestSchedule(unittest.TestCase):
     out = x.contiguous() + y.contiguous()
     check_schedule(out, 2)
 
-  def test_group_fuse(self):
-    a = Tensor.empty((4, 4))
+  def test_reduce_same_size(self):
+    a = Tensor.empty(4, 4)
     out0 = a.sum() + 2
     out1 = a.sum() + 4
-    check_schedule([out0, out1], 1)
+    out2 = out0 * out1
+    check_schedule([out0, out1, out2], 2)
 
-  def test_group_inner_deps_fuse(self):
-    a = Tensor.empty((4, 4))
-    out0 = a.sum() + 2
-    out1 = a.sum() + out0 + 4
-    check_schedule([out0, out1], 1)
-
-  def test_group_outside_reduce(self):
-    a = Tensor.empty((4, 4))
-    b = Tensor.empty((4, 4))
-    out0 = a.sum() + 2
-    # b.sum() is not a descendant of the fused nodes
-    out1 = a.sum() + b.sum() + 4
-    check_schedule([out0, out1], 3) # TODO: this can fuse
-
-  def test_reduce_multiple_paths_fuse(self):
+  def test_reduce_multiple_paths(self):
     a = Tensor.empty(4, 4)
     out0 = a.sum().exp2()
     # out1 has two paths to a.sum()
     out1 = a.sum() + out0
     check_schedule([out0, out1], 1)
 
+  def test_reduce_ext_reduce_child(self):
+    a = Tensor.empty((4, 4))
+    b = Tensor.empty((4, 4))
+    # b.sum() is not a descendant of the fused nodes
+    out0 = a.sum() + b.sum() + 2
+    out1 = a.sum() + b.sum() + 4
+    check_schedule([out0, out1], 4)
+
   def test_reduce_multiple_paths_midreduce(self):
     a = Tensor.empty(4, 4)
     r = a.sum()
@@ -489,6 +484,14 @@
     out2 = r + out1
     check_schedule([r, out0, out1, out2], 4)
 
+  def test_reduce_multiple_paths_midreduce_fused(self):
+    a = Tensor.empty(4, 4)
+    b = Tensor.empty(4, 4)
+    out0 = a.sum() + 4
+    out1 = b.max() + out0*2
+    out2 = a.sum() + out1
+    check_schedule([out0, out1, out2], 4)
+
   def test_reduce_multiple_paths_midexpand(self):
     a = Tensor.empty(4, 4)
     b = Tensor.empty(4, 4, 4)
@@ -499,26 +502,33 @@
     out1 = r + e[0][0][0]
     check_schedule([r, out0, out1, e], 4)
 
-  def test_group_midreduce_nofuse(self):
-    a = Tensor.empty((4, 4))
-    b = Tensor.empty((4, 4))
-    out0 = a.sum() + 2
-    out1 = a.sum() + b.sum() + 4
-    check_schedule([out0, out1], 3)
-
-  def test_group_midexpand_nofuse(self):
+  def test_reduce_expand_child(self):
     a = Tensor.empty((32, 32, 32))
     b = Tensor.empty((1, 16))
     out0 = a.sum() + 2
     out1 = a.sum() + b
     check_schedule([out0, out1], 4)
 
-  def test_group_midshrink_fuse(self):
+  def test_reduce_shrink_child(self):
     a = Tensor.empty(100, 100)
     b = Tensor.empty(10,)
-    out0 = a.sum() + b[0]
-    out1 = a.sum() + 2
-    check_schedule([out0, out1], 1)
+    c = a.sum() + b[0]
+    d = a.sum() + 2
+    check_schedule([c, d], 1)
+
+  def test_reduce_multiple_paths_midshrink(self):
+    a = Tensor.empty(4, 4)
+    r = a.sum(axis=1)
+    out0 = r.exp2()
+    out1 = out0[0] + out0
+    check_schedule([r, out0, out1], 3)
+
+  def test_reduce_shrink_output(self):
+    a = Tensor.empty(4, 4)
+    r = a.sum(keepdim=True)
+    out0 = r.exp2()
+    out1 = out0[0] + Tensor.empty(1, )
+    check_schedule([r, out0, out1], 3)
 
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_prefer_half_buffer(self):