mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
batchnorm + conv backward in test_schedule (#4420)
* test both optims * batchnorm_backward
This commit is contained in:
@@ -194,17 +194,19 @@ class TestSchedule(unittest.TestCase):
|
||||
out = bn(c1(img)).relu()
|
||||
check_schedule(out, 4, [c1.weight, c1.bias])
|
||||
|
||||
def test_fold_conv_batchnorm_sgd(self):
|
||||
with Tensor.train():
|
||||
img = Tensor.ones(1,3,4,4)
|
||||
c1 = nn.Conv2d(3,32,3)
|
||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||
opt = nn.optim.SGD(nn.state.get_parameters([c1, bn]))
|
||||
img_bn = bn(c1(img)).elu().sum()
|
||||
opt.zero_grad()
|
||||
img_bn.backward()
|
||||
# this is too high
|
||||
check_schedule(opt.schedule_step(), 17)
|
||||
def test_fold_conv_batchnorm_optim(self):
|
||||
# this is too high
|
||||
for optim, cnt in [(nn.optim.Adam, 20), (nn.optim.SGD, 17)]:
|
||||
with self.subTest(optim=optim.__name__):
|
||||
with Tensor.train():
|
||||
img = Tensor.ones(1,3,4,4)
|
||||
c1 = nn.Conv2d(3,32,3)
|
||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||
opt = optim(nn.state.get_parameters([c1, bn]))
|
||||
img_bn = bn(c1(img)).elu().sum()
|
||||
opt.zero_grad()
|
||||
img_bn.backward()
|
||||
check_schedule(opt.schedule_step(), cnt)
|
||||
|
||||
def test_fold_conv_relu_backward(self):
|
||||
c1 = nn.Conv2d(3,16,3, bias=False)
|
||||
@@ -217,6 +219,16 @@ class TestSchedule(unittest.TestCase):
|
||||
# img.grad is requiring two reduces
|
||||
check_schedule([img.grad, c1.weight.grad], 5)
|
||||
|
||||
def test_fold_batchnorm_backward(self):
|
||||
with Tensor.train():
|
||||
x = Tensor.empty((2, 16, 8, 8)).contiguous()
|
||||
bn = nn.BatchNorm2d(16)
|
||||
bn.weight.requires_grad = bn.bias.requires_grad = x.requires_grad = True
|
||||
fw = bn(x).contiguous_backward().relu().contiguous()
|
||||
fw.sum().backward()
|
||||
# TODO: this is too many
|
||||
check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10)
|
||||
|
||||
def test_fold_conv_relu(self):
|
||||
c1 = nn.Conv2d(3,16,3)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user