mirror of https://github.com/tinygrad/tinygrad.git
replace pow in LAMB by tracking b1**t and b2**t per step (#4582)
* replace pow in LAMB by tracking b1**t and b2**t per step
* remove t, add [self.b1_t, self.b2_t] to return
* adam has one less kernel
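For context, a minimal sketch of the trick (assumed, simplified names; not the actual tinygrad optimizer code): instead of recomputing b1 ** t and b2 ** t with a pow on every step, the optimizer carries running products and multiplies them by b1 / b2 once per step.

# Illustrative sketch of the change, under the assumptions above.
class LambBefore:
  def __init__(self, b1=0.9, b2=0.999):
    self.b1, self.b2, self.t = b1, b2, 0
  def bias_correction(self):
    self.t += 1
    return self.b1 ** self.t, self.b2 ** self.t  # pow recomputed every step

class LambAfter:
  def __init__(self, b1=0.9, b2=0.999):
    self.b1, self.b2 = b1, b2
    self.b1_t, self.b2_t = 1.0, 1.0  # b1**0, b2**0
  def bias_correction(self):
    self.b1_t *= self.b1  # running product, now equal to b1**t
    self.b2_t *= self.b2  # running product, now equal to b2**t
    return self.b1_t, self.b2_t

In the real optimizer these trackers are Tensors, which is why, per the commit message, the step now also returns [self.b1_t, self.b2_t] so they are realized each step; dropping the per-step pow is what removes one kernel from the Adam schedules in the tests below.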
@@ -204,7 +204,7 @@ class TestSchedule(unittest.TestCase):
 
   def test_fold_conv_batchnorm_optim(self):
     # this is too high
-    for optim, cnt in [(nn.optim.Adam, 20), (nn.optim.SGD, 17)]:
+    for optim, cnt in [(nn.optim.Adam, 19), (nn.optim.SGD, 17)]:
       with self.subTest(optim=optim.__name__):
         with Tensor.train():
           img = Tensor.ones(1,3,4,4)
@@ -609,7 +609,7 @@ class TestSchedule(unittest.TestCase):
     layer = nn.Linear(768, 768*4)
     opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
     layer(x).relu().sum().backward()
-    check_schedule(opt.schedule_step(), 12)
+    check_schedule(opt.schedule_step(), 11)
 
   def test_adam_conv_fuse(self):
     with Tensor.train():
@@ -618,7 +618,7 @@ class TestSchedule(unittest.TestCase):
       opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
       opt.zero_grad()
       c1(img).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 12)
+      check_schedule(opt.schedule_step(), 11)
 
   def test_adam_2convs_fuse(self):
     with Tensor.train():
@@ -628,7 +628,7 @@ class TestSchedule(unittest.TestCase):
       opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
       opt.zero_grad()
       c2(c1(img).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 14)
+      check_schedule(opt.schedule_step(), 13)
 
   def test_sgd_conv_fuse(self):
     with Tensor.train():