back to 6 for test_fold_conv_sgd

This commit is contained in:
George Hotz
2023-04-14 07:34:00 -07:00
parent 133521e730
commit f7f416d6f4
2 changed files with 4 additions and 4 deletions

View File

@@ -180,7 +180,7 @@ class TestOpt(unittest.TestCase):
def test_fold_conv_sgd(self):
# TODO: with Tensor.training
Tensor.training = True
img = Tensor.ones(1,3,4,4)
img = Tensor.ones(2,3,4,4)
c1 = nn.Conv2d(3,32,3)
opt = optim.SGD(optim.get_parameters(c1))
with CLCache():
@@ -190,7 +190,7 @@ class TestOpt(unittest.TestCase):
# TODO: this should be 4, but the sum output child stays around
# with pushing_permutes it can be 3
# TODO: broken with optim fixes
assert len(GlobalCounters.cache) in [4,5,6,7,8], f"optimizer didn't fold conv-backward SGD, got {len(GlobalCounters.cache)}"
assert len(GlobalCounters.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(GlobalCounters.cache)}"
Tensor.training = False
def test_fold_conv_batchnorm_sgd(self):