fusion tests from test_opt (#4357)

* opt tests

* more sgd

* batchnorm

* models stay in external

Author: qazal
Date: 2024-05-01 16:44:12 +03:00
Committed by: GitHub
Parent: 995d264666
Commit: ea06f657df
2 changed files with 87 additions and 117 deletions
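
This commit moves the kernel-fusion assertions out of the CLCache-based opt tests (first file below, class TestOpt) and into schedule-level tests (second file, class TestSchedule). As rough orientation only: CLCache and check_schedule are helpers defined in the respective test files, not part of the tinygrad package, the call shapes below are taken from the diff itself, and the comments reflect my reading of it. The same conv+batchnorm check looks like this in the two styles:

from tinygrad import Tensor, nn  # CLCache / check_schedule come from the test files themselves

# old style: realize inside a CLCache block and count the kernels that get dispatched
with Tensor.train():
  img = Tensor.ones(1,3,8,8)
  c1 = nn.Conv2d(3,32,3)
  bn = nn.BatchNorm2d(32, track_running_stats=False)
  with CLCache() as cache:
    bn(c1(img)).relu().realize()
  assert cache.count == 4

# new style: build the output lazily and assert on how many kernels get scheduled;
# the third argument appears to list tensors whose realization is scheduled beforehand
with Tensor.train():
  img = Tensor.empty(1,3,8,8)
  c1 = nn.Conv2d(3,32,3)
  bn = nn.BatchNorm2d(32, track_running_stats=False)
  out = bn(c1(img)).relu()
  check_schedule(out, 4, [c1.weight, c1.bias])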


@@ -5,7 +5,6 @@ import torch
from tinygrad import nn, GlobalCounters, Tensor, Device
from tinygrad.helpers import getenv
from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters
from tinygrad.engine.realize import capturing
@@ -165,103 +164,6 @@ class TestOpt(unittest.TestCase):
    d.realize()
    np.testing.assert_allclose(d.numpy(), na*nb+nc, rtol=1e-5, atol=1e-7)

  def test_fold_reduce_elementwise(self):
    img = Tensor.ones(32).contiguous()
    addme = Tensor.ones(1)
    with CLCache() as cache:
      ret = img.sum() + addme
      ret.realize()
    assert cache.count == 1, "optimizer didn't fold reduce/elementwise"
    assert ret.item() == 33

  def test_fold_batchnorm(self):
    with Tensor.train():
      img = Tensor.ones(1,32,4,4).contiguous()
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      with CLCache() as cache:
        img_bn = bn(img).realize()
        print(img_bn)
      assert cache.count == 3, f"optimizer didn't fold batchnorm, got {cache.count}"

  def test_fold_conv_sgd(self):
    with Tensor.train():
      img = Tensor.ones(2,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      opt = optim.SGD(get_parameters(c1))
      with CLCache() as cache:
        opt.zero_grad()
        c1(img).relu().sum().backward()
        opt.step()
      assert cache.count == 5, f"optimizer didn't fold conv-backward SGD, got {cache.count}"

  def test_fold_conv_adam(self):
    with Tensor.train():
      img = Tensor.ones(2,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      opt = optim.Adam(get_parameters(c1), lr=1e-4)
      with CLCache(allowed=10):
        opt.zero_grad()
        c1(img).relu().sum().backward()
        opt.step()

  def test_fold_2convs_adam(self):
    with Tensor.train():
      img = Tensor.ones(2,3,64,64)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,3,bias=False)
      opt = optim.Adam(get_parameters([c1, c2]), lr=1e-4)
      with CLCache(allowed=13):
        opt.zero_grad()
        c2(c1(img).relu()).relu().sum().backward()
        opt.step()

  def test_fold_2convs_sgd(self):
    with Tensor.train():
      img = Tensor.ones(2,3,64,64)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,3,bias=False)
      opt = optim.SGD(get_parameters([c1, c2]))
      with CLCache(allowed=8):
        opt.zero_grad()
        c2(c1(img).relu()).relu().sum().backward()
        opt.step()

  def test_fold_2convs_sgd_nesterov_momentum_wd(self):
    with Tensor.train():
      img = Tensor.ones(2,3,64,64)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,3,bias=False)
      opt = optim.SGD(get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
      with CLCache(allowed=10):
        opt.zero_grad()
        c2(c1(img).relu()).relu().sum().backward()
        opt.step()

  def test_fold_4convs_sgd(self):
    with Tensor.train():
      img = Tensor.ones(2,3,64,64)
      c1 = nn.Conv2d(3,4,3,bias=False)
      c2 = nn.Conv2d(4,8,3,bias=False)
      c3 = nn.Conv2d(8,16,3,bias=False)
      c4 = nn.Conv2d(16,32,3,bias=False)
      opt = optim.SGD(get_parameters([c1, c2, c3, c4]))
      with CLCache(allowed=18):
        opt.zero_grad()
        c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
        opt.step()

  def test_fold_conv_batchnorm_sgd(self):
    with Tensor.train():
      img = Tensor.ones(1,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      opt = optim.SGD(get_parameters([c1, bn]))
      with CLCache(allowed=16): # this is too high
        img_bn = bn(c1(img)).elu().sum()
        opt.zero_grad()
        img_bn.backward()
        opt.step()

  def test_fold_conv_batchnorm_notrain(self):
    img = Tensor.ones(1,3,8,8)
    c1 = nn.Conv2d(3,32,3)
@@ -272,16 +174,6 @@ class TestOpt(unittest.TestCase):
      bn(c1(img)).relu().realize()
    assert cache.count == 1, f"optimizer didn't fold conv-batchnorm at test time, got {cache.count}"

  def test_fold_conv_batchnorm(self):
    with Tensor.train():
      img = Tensor.ones(1,3,8,8)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      with CLCache() as cache:
        img_conv = bn(c1(img)).relu().realize()
        print(img_conv)
      assert cache.count == 4, f"optimizer didn't fold conv-batchnorm, got {cache.count}"

  def test_fold_conv_elu(self):
    img = Tensor.ones(1,4,8,8)
    c1 = nn.Conv2d(4, 4, kernel_size=3)
@@ -300,15 +192,6 @@ class TestOpt(unittest.TestCase):
      print(img_conv)
    assert cache.count == 2, "optimizer didn't fold conv/relu"

  def test_fold_conv_relu_nobias(self):
    img = Tensor.ones(1,4,8,8)
    c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
    c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
    with CLCache() as cache:
      img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize()
      print(img_conv)
    assert cache.count == 2, "optimizer didn't fold conv/relu"

  def test_permute_was_pushed(self):
    a = Tensor.randn(16, 16, 16)
    with CLCache(2):

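CLCache itself is defined in the opt test file and is not part of this diff. Judging from the capturing import in the first hunk and from how the tests read cache.count, it appears to be a context manager that registers itself as a capture target and counts the work dispatched while it is active. The sketch below only illustrates that idea; the class name, the add() hook, and the exit-time check are assumptions, not the actual implementation:

from tinygrad.engine.realize import capturing  # the list imported by the opt test file above

class KernelCounter:
  # illustrative stand-in for CLCache: count whatever gets dispatched while active
  def __init__(self, allowed=None):
    self.allowed, self.count = allowed, 0
  def add(self, _item):
    # assumed hook: called once per dispatched item while something is capturing
    self.count += 1
  def __enter__(self):
    capturing.append(self)
    return self
  def __exit__(self, *exc):
    capturing.clear()
    if self.allowed is not None:
      assert self.count == self.allowed, f"expected {self.allowed} kernels, got {self.count}"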

@@ -190,6 +190,26 @@ class TestSchedule(unittest.TestCase):
      out = bn(img)
      check_schedule(out, 3)

  def test_fold_conv_batchnorm(self):
    with Tensor.train():
      img = Tensor.empty(1,3,8,8)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      out = bn(c1(img)).relu()
      check_schedule(out, 4, [c1.weight, c1.bias])

  def test_fold_conv_batchnorm_sgd(self):
    with Tensor.train():
      img = Tensor.ones(1,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      opt = nn.optim.SGD(nn.state.get_parameters([c1, bn]))
      img_bn = bn(c1(img)).elu().sum()
      opt.zero_grad()
      img_bn.backward()
      # this is too high
      check_schedule(opt.schedule_step(), 18)

  def test_fold_conv_relu(self):
    c1 = nn.Conv2d(3,16,3)
@@ -198,6 +218,13 @@ class TestSchedule(unittest.TestCase):
    out = c1(img).relu()
    check_schedule(out, 1, [c1.weight, c1.bias])

  def test_fold_conv_relu_nobias(self):
    img = Tensor.ones(1,4,8,8)
    c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
    c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
    out = img.sequential([c1, Tensor.relu, c2, Tensor.relu])
    check_schedule(out, 2, [c1.weight, c2.weight, img])

  def test_fold_conv_elu(self):
    c1 = nn.Conv2d(3,16,3)
@@ -556,6 +583,66 @@ class TestSchedule(unittest.TestCase):
      layer(x).relu().sum().backward()
      check_schedule(opt.schedule_step(), 14)

  def test_adam_conv_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
      opt.zero_grad()
      c1(img).relu().sum().backward()
      check_schedule(opt.schedule_step(), 14)

  def test_adam_2convs_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,3,bias=False)
      opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 15)

  def test_sgd_conv_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      opt = nn.optim.SGD(nn.state.get_parameters(c1))
      opt.zero_grad()
      c1(img).relu().sum().backward()
      check_schedule(opt.schedule_step(), 7)

  def test_sgd_2convs_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,3,bias=False)
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]))
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 7)

  def test_fold_2convs_sgd_nesterov_momentum_wd(self):
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,3,bias=False)
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 9)

  def test_sgd_4convs_fuse(self):
    with Tensor.train():
      img = Tensor.empty(2,3,64,64)
      c1 = nn.Conv2d(3,4,3,bias=False)
      c2 = nn.Conv2d(4,8,3,bias=False)
      c3 = nn.Conv2d(8,16,3,bias=False)
      c4 = nn.Conv2d(16,32,3,bias=False)
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 22)

  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_prefer_half_buffer(self):
    x = Tensor.ones(4).contiguous().realize()
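
One note on the optimizer tests above: they pass opt.schedule_step() to check_schedule instead of calling opt.step(). schedule_step() returns the update tensors without realizing them, so the test can count how many kernels the whole parameter update schedules; opt.step() essentially realizes that same set immediately. The snippet below restates test_sgd_conv_fuse from the diff with that reading spelled out in comments (check_schedule is the test-file helper):

with Tensor.train():
  img = Tensor.empty(2,3,4,4)
  c1 = nn.Conv2d(3,32,3)
  opt = nn.optim.SGD(nn.state.get_parameters(c1))
  opt.zero_grad()
  c1(img).relu().sum().backward()
  updates = opt.schedule_step()   # Tensors for the SGD update, returned unrealized
  check_schedule(updates, 7)      # the whole update is expected to need 7 kernels
  # opt.step() would realize the update immediately instead of returning it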