replace pow in LAMB by tracking b1**t and b2**t per step (#4582)

* replace pow in LAMB by tracking b1**t and b2**t per step

* remove t, add [self.b1_t, self.b2_t] to return

* adam has one less kernel
commit 7afca52796 (parent 9b02aef45a)
Author: chenyu
Date:   2024-05-14 13:08:22 -04:00 (committed by GitHub)

2 changed files with 13 additions and 12 deletions
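For context, a minimal plain-Python sketch of the trick (illustrative only, not tinygrad's actual optimizer code; the class and variable names here are made up): instead of computing b1**t and b2**t with a pow kernel on every step, keep running products that are updated by one multiply per step.

  # hypothetical sketch of tracking b1**t / b2**t instead of calling pow each step
  class AdamSketch:
    def __init__(self, params, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
      self.params, self.lr, self.b1, self.b2, self.eps = params, lr, b1, b2, eps
      self.m = [0.0] * len(params)     # first-moment estimates
      self.v = [0.0] * len(params)     # second-moment estimates
      self.b1_t, self.b2_t = 1.0, 1.0  # running b1**t and b2**t

    def step(self, grads):
      self.b1_t *= self.b1  # one multiply replaces pow(b1, t)
      self.b2_t *= self.b2  # one multiply replaces pow(b2, t)
      for i, g in enumerate(grads):
        self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * g
        self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * g * g
        m_hat = self.m[i] / (1 - self.b1_t)  # bias correction via tracked products
        v_hat = self.v[i] / (1 - self.b2_t)
        self.params[i] -= self.lr * m_hat / (v_hat ** 0.5 + self.eps)

Per the commit notes, the step counter t is removed and [self.b1_t, self.b2_t] are added to the returned tensors so the running products are realized each step; dropping the pow computation is what removes one kernel from each Adam schedule in the tests below.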


@@ -204,7 +204,7 @@ class TestSchedule(unittest.TestCase):
   def test_fold_conv_batchnorm_optim(self):
     # this is too high
-    for optim, cnt in [(nn.optim.Adam, 20), (nn.optim.SGD, 17)]:
+    for optim, cnt in [(nn.optim.Adam, 19), (nn.optim.SGD, 17)]:
       with self.subTest(optim=optim.__name__):
         with Tensor.train():
           img = Tensor.ones(1,3,4,4)
@@ -609,7 +609,7 @@ class TestSchedule(unittest.TestCase):
       layer = nn.Linear(768, 768*4)
       opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
       layer(x).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 12)
+      check_schedule(opt.schedule_step(), 11)

   def test_adam_conv_fuse(self):
     with Tensor.train():
@@ -618,7 +618,7 @@ class TestSchedule(unittest.TestCase):
       opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
       opt.zero_grad()
       c1(img).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 12)
+      check_schedule(opt.schedule_step(), 11)

   def test_adam_2convs_fuse(self):
     with Tensor.train():
@@ -628,7 +628,7 @@ class TestSchedule(unittest.TestCase):
       opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
       opt.zero_grad()
       c2(c1(img).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 14)
+      check_schedule(opt.schedule_step(), 13)

   def test_sgd_conv_fuse(self):
     with Tensor.train():