fix backward convs (#746)

* fix backward convs

* no pushing in reduce

* late cout

* test_fold_4convs_sgd
This commit is contained in:
George Hotz
2023-04-14 10:42:11 -07:00
committed by GitHub
parent f7f416d6f4
commit 17e37157b6
3 changed files with 72 additions and 29 deletions

View File

@@ -193,6 +193,34 @@ class TestOpt(unittest.TestCase):
assert len(GlobalCounters.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(GlobalCounters.cache)}"
Tensor.training = False
def test_fold_2convs_sgd(self):
  """Run one SGD step through two chained bias-free convs under a CLCache budget.

  Builds a constant-ones input, two `nn.Conv2d` layers and an `optim.SGD`
  optimizer over their parameters, then performs zero_grad / forward /
  backward / step inside `CLCache(allowed=9)`.

  NOTE(review): `CLCache` is defined elsewhere; presumably it fails the test
  when more than `allowed` kernels are recorded -- confirm against its
  definition.
  """
  # TODO: with Tensor.training
  Tensor.training = True
  try:
    img = Tensor.ones(2,3,64,64)
    c1 = nn.Conv2d(3,16,3,bias=False)
    c2 = nn.Conv2d(16,32,3,bias=False)
    opt = optim.SGD(optim.get_parameters([c1, c2]))
    with CLCache(allowed=9):
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      opt.step()
  finally:
    # Reset even when the body raises, so the global flag never leaks into later tests.
    Tensor.training = False
def test_fold_4convs_sgd(self):
  """Run one SGD step through four chained bias-free convs under a CLCache budget.

  Builds a constant-ones input, four `nn.Conv2d` layers and an `optim.SGD`
  optimizer over their parameters, then performs zero_grad / forward /
  backward / step inside `CLCache(allowed=19)`.

  NOTE(review): `CLCache` is defined elsewhere; presumably it fails the test
  when more than `allowed` kernels are recorded -- confirm against its
  definition.
  """
  # TODO: with Tensor.training
  Tensor.training = True
  try:
    img = Tensor.ones(2,3,64,64)
    c1 = nn.Conv2d(3,4,3,bias=False)
    c2 = nn.Conv2d(4,8,3,bias=False)
    c3 = nn.Conv2d(8,16,3,bias=False)
    c4 = nn.Conv2d(16,32,3,bias=False)
    opt = optim.SGD(optim.get_parameters([c1, c2, c3, c4]))
    with CLCache(allowed=19):
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      opt.step()
  finally:
    # Reset even when the body raises, so the global flag never leaks into later tests.
    Tensor.training = False
def test_fold_conv_batchnorm_sgd(self):
# TODO: with Tensor.training
Tensor.training = True