make maximum split the grad like torch when equal (#738)

* make maximum split grad

* added test for maximum split grad when equal

* minor expr simplification

* (2-eq)/2 only once

* update test because one more sum output child stays
Author: worldwalker2000
Date: 2023-04-14 00:17:46 -07:00
Committed by: GitHub
Parent: 06ed958abd
Commit: 552a048a33
3 changed files with 9 additions and 5 deletions
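The commit messages above reference computing (2-eq)/2 only once. As a rough sketch of the rule that factor implements (plain NumPy for illustration; this is not tinygrad's actual mlops code, and maximum_backward is a hypothetical helper name), where the two inputs tie, each side receives half of the upstream gradient, matching torch's behavior:

```python
import numpy as np

def maximum_backward(x, y, grad_output):
    # hypothetical standalone illustration of the split-gradient rule
    eq = (x == y).astype(grad_output.dtype)  # 1.0 on ties, 0.0 elsewhere
    # (2 - eq) / 2 is 1.0 where one input strictly wins and 0.5 on ties;
    # computing it once lets both input gradients share the same factor
    split = (2.0 - eq) / 2.0
    grad_x = grad_output * (x >= y) * split  # x wins strictly, or ties
    grad_y = grad_output * (y >= x) * split  # y wins strictly, or ties
    return grad_x, grad_y
```

For example, maximum_backward(np.array([1., 2.]), np.array([1., 0.]), np.ones(2)) yields ([0.5, 1.0], [0.5, 0.0]): the tied element splits its gradient evenly, and the two gradients always sum to the upstream gradient.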

@@ -190,7 +190,7 @@ class TestOpt(unittest.TestCase):
     # TODO: this should be 4, but the sum output child stays around
     # with pushing_permutes it can be 3
     # TODO: broken with optim fixes
-    assert len(GlobalCounters.cache) in [4,5,6,7], f"optimizer didn't fold conv-backward SGD, got {len(GlobalCounters.cache)}"
+    assert len(GlobalCounters.cache) in [4,5,6,7,8], f"optimizer didn't fold conv-backward SGD, got {len(GlobalCounters.cache)}"
     Tensor.training = False
   def test_fold_conv_batchnorm_sgd(self):