diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py
index e022205d1d..d486b932a8 100644
--- a/test/external/external_test_opt.py
+++ b/test/external/external_test_opt.py
@@ -5,7 +5,6 @@ import torch
 from tinygrad import nn, GlobalCounters, Tensor, Device
 from tinygrad.helpers import getenv
-from tinygrad.nn import optim
 from tinygrad.nn.state import get_parameters
 from tinygrad.engine.realize import capturing
@@ -165,103 +164,6 @@ class TestOpt(unittest.TestCase):
     d.realize()
     np.testing.assert_allclose(d.numpy(), na*nb+nc, rtol=1e-5, atol=1e-7)
 
-  def test_fold_reduce_elementwise(self):
-    img = Tensor.ones(32).contiguous()
-    addme = Tensor.ones(1)
-    with CLCache() as cache:
-      ret = img.sum() + addme
-      ret.realize()
-    assert cache.count == 1, "optimizer didn't fold reduce/elementwise"
-    assert ret.item() == 33
-
-  def test_fold_batchnorm(self):
-    with Tensor.train():
-      img = Tensor.ones(1,32,4,4).contiguous()
-      bn = nn.BatchNorm2d(32, track_running_stats=False)
-      with CLCache() as cache:
-        img_bn = bn(img).realize()
-        print(img_bn)
-      assert cache.count == 3, f"optimizer didn't fold batchnorm, got {cache.count}"
-
-  def test_fold_conv_sgd(self):
-    with Tensor.train():
-      img = Tensor.ones(2,3,4,4)
-      c1 = nn.Conv2d(3,32,3)
-      opt = optim.SGD(get_parameters(c1))
-      with CLCache() as cache:
-        opt.zero_grad()
-        c1(img).relu().sum().backward()
-        opt.step()
-      assert cache.count == 5, f"optimizer didn't fold conv-backward SGD, got {cache.count}"
-
-  def test_fold_conv_adam(self):
-    with Tensor.train():
-      img = Tensor.ones(2,3,4,4)
-      c1 = nn.Conv2d(3,32,3)
-      opt = optim.Adam(get_parameters(c1), lr=1e-4)
-      with CLCache(allowed=10):
-        opt.zero_grad()
-        c1(img).relu().sum().backward()
-        opt.step()
-
-  def test_fold_2convs_adam(self):
-    with Tensor.train():
-      img = Tensor.ones(2,3,64,64)
-      c1 = nn.Conv2d(3,16,3,bias=False)
-      c2 = nn.Conv2d(16,32,3,bias=False)
-      opt = optim.Adam(get_parameters([c1, c2]), lr=1e-4)
-      with CLCache(allowed=13):
-        opt.zero_grad()
-        c2(c1(img).relu()).relu().sum().backward()
-        opt.step()
-
-  def test_fold_2convs_sgd(self):
-    with Tensor.train():
-      img = Tensor.ones(2,3,64,64)
-      c1 = nn.Conv2d(3,16,3,bias=False)
-      c2 = nn.Conv2d(16,32,3,bias=False)
-      opt = optim.SGD(get_parameters([c1, c2]))
-      with CLCache(allowed=8):
-        opt.zero_grad()
-        c2(c1(img).relu()).relu().sum().backward()
-        opt.step()
-
-  def test_fold_2convs_sgd_nesterov_momentum_wd(self):
-    with Tensor.train():
-      img = Tensor.ones(2,3,64,64)
-      c1 = nn.Conv2d(3,16,3,bias=False)
-      c2 = nn.Conv2d(16,32,3,bias=False)
-      opt = optim.SGD(get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
-      with CLCache(allowed=10):
-        opt.zero_grad()
-        c2(c1(img).relu()).relu().sum().backward()
-        opt.step()
-
-  def test_fold_4convs_sgd(self):
-    with Tensor.train():
-      img = Tensor.ones(2,3,64,64)
-      c1 = nn.Conv2d(3,4,3,bias=False)
-      c2 = nn.Conv2d(4,8,3,bias=False)
-      c3 = nn.Conv2d(8,16,3,bias=False)
-      c4 = nn.Conv2d(16,32,3,bias=False)
-      opt = optim.SGD(get_parameters([c1, c2, c3, c4]))
-      with CLCache(allowed=18):
-        opt.zero_grad()
-        c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-        opt.step()
-
-  def test_fold_conv_batchnorm_sgd(self):
-    with Tensor.train():
-      img = Tensor.ones(1,3,4,4)
-      c1 = nn.Conv2d(3,32,3)
-      bn = nn.BatchNorm2d(32, track_running_stats=False)
-      opt = optim.SGD(get_parameters([c1, bn]))
-      with CLCache(allowed=16): # this is too high
-        img_bn = bn(c1(img)).elu().sum()
-        opt.zero_grad()
-        img_bn.backward()
-        opt.step()
-
   def test_fold_conv_batchnorm_notrain(self):
     img = Tensor.ones(1,3,8,8)
     c1 = nn.Conv2d(3,32,3)
@@ -272,16 +174,6 @@ class TestOpt(unittest.TestCase):
       bn(c1(img)).relu().realize()
     assert cache.count == 1, f"optimizer didn't fold conv-batchnorm at test time, got {cache.count}"
 
-  def test_fold_conv_batchnorm(self):
-    with Tensor.train():
-      img = Tensor.ones(1,3,8,8)
-      c1 = nn.Conv2d(3,32,3)
-      bn = nn.BatchNorm2d(32, track_running_stats=False)
-      with CLCache() as cache:
-        img_conv = bn(c1(img)).relu().realize()
-        print(img_conv)
-      assert cache.count == 4, f"optimizer didn't fold conv-batchnorm, got {cache.count}"
-
   def test_fold_conv_elu(self):
     img = Tensor.ones(1,4,8,8)
     c1 = nn.Conv2d(4, 4, kernel_size=3)
@@ -300,15 +192,6 @@ class TestOpt(unittest.TestCase):
       print(img_conv)
     assert cache.count == 2, "optimizer didn't fold conv/relu"
 
-  def test_fold_conv_relu_nobias(self):
-    img = Tensor.ones(1,4,8,8)
-    c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
-    c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
-    with CLCache() as cache:
-      img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize()
-      print(img_conv)
-    assert cache.count == 2, "optimizer didn't fold conv/relu"
-
   def test_permute_was_pushed(self):
     a = Tensor.randn(16, 16, 16)
     with CLCache(2):
diff --git a/test/test_schedule.py b/test/test_schedule.py
index f35f967c7d..86c3e99c0a 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -190,6 +190,26 @@ class TestSchedule(unittest.TestCase):
       out = bn(img)
       check_schedule(out, 3)
 
+  def test_fold_conv_batchnorm(self):
+    with Tensor.train():
+      img = Tensor.empty(1,3,8,8)
+      c1 = nn.Conv2d(3,32,3)
+      bn = nn.BatchNorm2d(32, track_running_stats=False)
+      out = bn(c1(img)).relu()
+      check_schedule(out, 4, [c1.weight, c1.bias])
+
+  def test_fold_conv_batchnorm_sgd(self):
+    with Tensor.train():
+      img = Tensor.ones(1,3,4,4)
+      c1 = nn.Conv2d(3,32,3)
+      bn = nn.BatchNorm2d(32, track_running_stats=False)
+      opt = nn.optim.SGD(nn.state.get_parameters([c1, bn]))
+      img_bn = bn(c1(img)).elu().sum()
+      opt.zero_grad()
+      img_bn.backward()
+      # this is too high
+      check_schedule(opt.schedule_step(), 18)
+
   def test_fold_conv_relu(self):
     c1 = nn.Conv2d(3,16,3)
 
@@ -198,6 +218,13 @@ class TestSchedule(unittest.TestCase):
     out = c1(img).relu()
     check_schedule(out, 1, [c1.weight, c1.bias])
 
+  def test_fold_conv_relu_nobias(self):
+    img = Tensor.ones(1,4,8,8)
+    c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
+    c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
+    out = img.sequential([c1, Tensor.relu, c2, Tensor.relu])
+    check_schedule(out, 2, [c1.weight, c2.weight, img])
+
   def test_fold_conv_elu(self):
     c1 = nn.Conv2d(3,16,3)
 
@@ -556,6 +583,66 @@ class TestSchedule(unittest.TestCase):
       layer(x).relu().sum().backward()
       check_schedule(opt.schedule_step(), 14)
 
+  def test_adam_conv_fuse(self):
+    with Tensor.train():
+      img = Tensor.empty(2,3,4,4)
+      c1 = nn.Conv2d(3,32,3)
+      opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
+      opt.zero_grad()
+      c1(img).relu().sum().backward()
+      check_schedule(opt.schedule_step(), 14)
+
+  def test_adam_2convs_fuse(self):
+    with Tensor.train():
+      img = Tensor.empty(2,3,4,4)
+      c1 = nn.Conv2d(3,16,3,bias=False)
+      c2 = nn.Conv2d(16,32,3,bias=False)
+      opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
+      opt.zero_grad()
+      c2(c1(img).relu()).relu().sum().backward()
+      check_schedule(opt.schedule_step(), 15)
+
+  def test_sgd_conv_fuse(self):
+    with Tensor.train():
+      img = Tensor.empty(2,3,4,4)
+      c1 = nn.Conv2d(3,32,3)
+      opt = nn.optim.SGD(nn.state.get_parameters(c1))
+      opt.zero_grad()
+      c1(img).relu().sum().backward()
+      check_schedule(opt.schedule_step(), 7)
+
+  def test_sgd_2convs_fuse(self):
+    with Tensor.train():
+      img = Tensor.empty(2,3,4,4)
+      c1 = nn.Conv2d(3,16,3,bias=False)
+      c2 = nn.Conv2d(16,32,3,bias=False)
+      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]))
+      opt.zero_grad()
+      c2(c1(img).relu()).relu().sum().backward()
+      check_schedule(opt.schedule_step(), 7)
+
+  def test_fold_2convs_sgd_nesterov_momentum_wd(self):
+    with Tensor.train():
+      img = Tensor.empty(2,3,4,4)
+      c1 = nn.Conv2d(3,16,3,bias=False)
+      c2 = nn.Conv2d(16,32,3,bias=False)
+      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
+      opt.zero_grad()
+      c2(c1(img).relu()).relu().sum().backward()
+      check_schedule(opt.schedule_step(), 9)
+
+  def test_sgd_4convs_fuse(self):
+    with Tensor.train():
+      img = Tensor.empty(2,3,64,64)
+      c1 = nn.Conv2d(3,4,3,bias=False)
+      c2 = nn.Conv2d(4,8,3,bias=False)
+      c3 = nn.Conv2d(8,16,3,bias=False)
+      c4 = nn.Conv2d(16,32,3,bias=False)
+      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
+      opt.zero_grad()
+      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
+      check_schedule(opt.schedule_step(), 22)
+
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_prefer_half_buffer(self):
     x = Tensor.ones(4).contiguous().realize()
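Note: this patch migrates the fusion tests from CLCache launch counting in test/external/external_test_opt.py to check_schedule kernel counting in test/test_schedule.py, and the optimizer tests now assert on the schedule returned by opt.schedule_step() rather than running opt.step() under a capture cache. Below is a minimal sketch of the same counting idea using GlobalCounters, for illustration only; it is not how check_schedule is implemented, and exact kernel totals depend on the backend and tinygrad version.

from tinygrad import GlobalCounters, Tensor, nn
from tinygrad.nn.state import get_parameters

with Tensor.train():
  img = Tensor.empty(2,3,4,4)
  c1 = nn.Conv2d(3,32,3)
  opt = nn.optim.SGD(get_parameters(c1))
  opt.zero_grad()
  c1(img).relu().sum().backward()
  start = GlobalCounters.kernel_count         # kernels launched so far
  opt.step()                                  # realizes the whole optimizer update
  print(GlobalCounters.kernel_count - start)  # fused kernel count, e.g. 7 per test_sgd_conv_fuse above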