mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-15 17:15:48 -05:00
one more opt test
This commit is contained in:
@@ -59,6 +59,21 @@ class TestOpt(unittest.TestCase):
|
||||
assert len(CL.CACHE) == 3, "optimizer didn't fold batchnorm"
|
||||
Tensor.training = False
|
||||
|
||||
  def test_fold_conv_sgd(self):
    """Check that a Conv2d forward/backward plus an SGD step folds into 5 kernels.

    Records kernel launches via the CLCache context and asserts on the
    length of CL.CACHE afterwards.

    NOTE(review): this block was recovered from a garbled diff view; the
    indentation (in particular the assert sitting inside the ``with
    CLCache():`` block) is reconstructed and should be confirmed against
    the upstream file.
    """
    # TODO: with Tensor.training
    # Enable training mode so backward/optimizer state is tracked; restored
    # to False at the end of the test (no try/finally in the original).
    Tensor.training = True
    img = Tensor.ones(1,3,4,4)
    c1 = nn.Conv2d(3,32,3)
    opt = optim.SGD(optim.get_parameters(c1))
    with CLCache():
      opt.zero_grad()
      c1(img).relu().sum().backward()
      opt.step()
      # TODO: this should be 4, but the sum output child stays around
      # with pushing_permutes it can be 3
      assert len(CL.CACHE) == 5, "optimizer didn't fold conv-backward SGD"
    Tensor.training = False
|
||||
|
||||
def test_fold_conv_batchnorm_sgd(self):
|
||||
# TODO: with Tensor.training
|
||||
Tensor.training = True
|
||||
|
||||
Reference in New Issue
Block a user