mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix backward convs (#746)
* fix backward convs * no pushing in reduce * late cout * test_fold_4convs_sgd
This commit is contained in:
28
test/external/external_test_opt.py
vendored
28
test/external/external_test_opt.py
vendored
@@ -193,6 +193,34 @@ class TestOpt(unittest.TestCase):
|
||||
assert len(GlobalCounters.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(GlobalCounters.cache)}"
|
||||
Tensor.training = False
|
||||
|
||||
def test_fold_2convs_sgd(self):
|
||||
# TODO: with Tensor.training
|
||||
Tensor.training = True
|
||||
img = Tensor.ones(2,3,64,64)
|
||||
c1 = nn.Conv2d(3,16,3,bias=False)
|
||||
c2 = nn.Conv2d(16,32,3,bias=False)
|
||||
opt = optim.SGD(optim.get_parameters([c1, c2]))
|
||||
with CLCache(allowed=9):
|
||||
opt.zero_grad()
|
||||
c2(c1(img).relu()).relu().sum().backward()
|
||||
opt.step()
|
||||
Tensor.training = False
|
||||
|
||||
def test_fold_4convs_sgd(self):
|
||||
# TODO: with Tensor.training
|
||||
Tensor.training = True
|
||||
img = Tensor.ones(2,3,64,64)
|
||||
c1 = nn.Conv2d(3,4,3,bias=False)
|
||||
c2 = nn.Conv2d(4,8,3,bias=False)
|
||||
c3 = nn.Conv2d(8,16,3,bias=False)
|
||||
c4 = nn.Conv2d(16,32,3,bias=False)
|
||||
opt = optim.SGD(optim.get_parameters([c1, c2, c3, c4]))
|
||||
with CLCache(allowed=19):
|
||||
opt.zero_grad()
|
||||
c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
|
||||
opt.step()
|
||||
Tensor.training = False
|
||||
|
||||
def test_fold_conv_batchnorm_sgd(self):
|
||||
# TODO: with Tensor.training
|
||||
Tensor.training = True
|
||||
|
||||
Reference in New Issue
Block a user