mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
back to 6 for test_fold_conv_sgd
This commit is contained in:
4
test/external/external_test_opt.py
vendored
4
test/external/external_test_opt.py
vendored
@@ -180,7 +180,7 @@ class TestOpt(unittest.TestCase):
|
||||
def test_fold_conv_sgd(self):
|
||||
# TODO: with Tensor.training
|
||||
Tensor.training = True
|
||||
img = Tensor.ones(1,3,4,4)
|
||||
img = Tensor.ones(2,3,4,4)
|
||||
c1 = nn.Conv2d(3,32,3)
|
||||
opt = optim.SGD(optim.get_parameters(c1))
|
||||
with CLCache():
|
||||
@@ -190,7 +190,7 @@ class TestOpt(unittest.TestCase):
|
||||
# TODO: this should be 4, but the sum output child stays around
|
||||
# with pushing_permutes it can be 3
|
||||
# TODO: broken with optim fixes
|
||||
assert len(GlobalCounters.cache) in [4,5,6,7,8], f"optimizer didn't fold conv-backward SGD, got {len(GlobalCounters.cache)}"
|
||||
assert len(GlobalCounters.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(GlobalCounters.cache)}"
|
||||
Tensor.training = False
|
||||
|
||||
def test_fold_conv_batchnorm_sgd(self):
|
||||
|
||||
Reference in New Issue
Block a user