one more opt test

This commit is contained in:
George Hotz
2022-10-28 18:37:53 -07:00
parent dd543fbc7a
commit 7909786dbf

View File

@@ -59,6 +59,21 @@ class TestOpt(unittest.TestCase):
assert len(CL.CACHE) == 3, "optimizer didn't fold batchnorm"
Tensor.training = False
def test_fold_conv_sgd(self):
# TODO: with Tensor.training
Tensor.training = True
img = Tensor.ones(1,3,4,4)
c1 = nn.Conv2d(3,32,3)
opt = optim.SGD(optim.get_parameters(c1))
with CLCache():
opt.zero_grad()
c1(img).relu().sum().backward()
opt.step()
# TODO: this should be 4, but the sum output child stays around
# with pushing_permutes it can be 3
assert len(CL.CACHE) == 5, "optimizer didn't fold conv-backward SGD"
Tensor.training = False
def test_fold_conv_batchnorm_sgd(self):
# TODO: with Tensor.training
Tensor.training = True