batchnorm + conv backward in test_schedule (#4420)

* test both optims

* batchnorm_backward
Author: qazal
Date: 2024-05-06 21:40:17 +08:00
Committed by: GitHub
Parent: 3f3c973022
Commit: 6dbe5585b0

@@ -194,17 +194,19 @@ class TestSchedule(unittest.TestCase):
     out = bn(c1(img)).relu()
     check_schedule(out, 4, [c1.weight, c1.bias])
 
-  def test_fold_conv_batchnorm_sgd(self):
-    with Tensor.train():
-      img = Tensor.ones(1,3,4,4)
-      c1 = nn.Conv2d(3,32,3)
-      bn = nn.BatchNorm2d(32, track_running_stats=False)
-      opt = nn.optim.SGD(nn.state.get_parameters([c1, bn]))
-      img_bn = bn(c1(img)).elu().sum()
-      opt.zero_grad()
-      img_bn.backward()
-      # this is too high
-      check_schedule(opt.schedule_step(), 17)
+  def test_fold_conv_batchnorm_optim(self):
+    # this is too high
+    for optim, cnt in [(nn.optim.Adam, 20), (nn.optim.SGD, 17)]:
+      with self.subTest(optim=optim.__name__):
+        with Tensor.train():
+          img = Tensor.ones(1,3,4,4)
+          c1 = nn.Conv2d(3,32,3)
+          bn = nn.BatchNorm2d(32, track_running_stats=False)
+          opt = optim(nn.state.get_parameters([c1, bn]))
+          img_bn = bn(c1(img)).elu().sum()
+          opt.zero_grad()
+          img_bn.backward()
+          check_schedule(opt.schedule_step(), cnt)
 
   def test_fold_conv_relu_backward(self):
     c1 = nn.Conv2d(3,16,3, bias=False)
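
The hunk above folds the SGD-only test into a single test parametrized over both optimizers with unittest's subTest, so a kernel-count regression for Adam and one for SGD are reported independently rather than the first failure masking the second. As a minimal self-contained sketch of that pattern in plain unittest (the (name, cnt) pairs simply mirror the expected kernel counts from the hunk; the placeholder assertion stands in for the real kernel-count check):

import unittest

class SubTestPattern(unittest.TestCase):
  def test_both_optims(self):
    # each pair runs as its own subtest: if the Adam case fails,
    # the SGD case still runs and both results are reported
    for name, cnt in [("Adam", 20), ("SGD", 17)]:
      with self.subTest(optim=name):
        self.assertGreater(cnt, 0)  # placeholder for the real check

if __name__ == "__main__":
  unittest.main()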
@@ -217,6 +219,16 @@ class TestSchedule(unittest.TestCase):
     # img.grad is requiring two reduces
     check_schedule([img.grad, c1.weight.grad], 5)
 
+  def test_fold_batchnorm_backward(self):
+    with Tensor.train():
+      x = Tensor.empty((2, 16, 8, 8)).contiguous()
+      bn = nn.BatchNorm2d(16)
+      bn.weight.requires_grad = bn.bias.requires_grad = x.requires_grad = True
+      fw = bn(x).contiguous_backward().relu().contiguous()
+      fw.sum().backward()
+      # TODO: this is too many
+      check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10)
+
   def test_fold_conv_relu(self):
     c1 = nn.Conv2d(3,16,3)
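
check_schedule in this file asserts how many kernels a lazy tensor graph lowers to; the literals 17, 20, and 10 above are those kernel counts, and the "this is too high" / "this is too many" comments flag counts the author expects better fusion to reduce. For a rough idea of counting kernels directly, here is a sketch assuming recent tinygrad exposes Tensor.schedule() (count_kernels is a hypothetical helper, not part of this commit):

from tinygrad import Tensor, nn

def count_kernels(t: Tensor) -> int:
  # schedule() lowers the lazy graph into a list of schedule items,
  # roughly one per kernel that would be launched
  return len(t.schedule())

with Tensor.train():
  img = Tensor.ones(1, 3, 4, 4)
  c1 = nn.Conv2d(3, 32, 3)
  bn = nn.BatchNorm2d(32, track_running_stats=False)
  print(count_kernels(bn(c1(img)).elu().sum()))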