Add weight decay to SGD (#883)
* feat: add weight decay to sgd
* fix: fix tests
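In the SGD update, weight decay with coefficient λ (the new `weight_decay` argument) adds λ*w to each parameter's gradient, so a step becomes w <- w - lr * (∇L(w) + λ*w). The diff below wires that term into tinygrad's SGD and extends the optimizer tests with `_wd` variants that are checked against torch.optim.SGD.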
@@ -17,6 +17,7 @@ class TinyNet():
   def forward(self):
     out = self.x.matmul(self.W).relu()
     # print(out.detach().numpy())
     out = out.log_softmax(1)
     out = out.mul(self.m).add(self.m).sum()
     return out
@@ -41,18 +42,26 @@ class TestOptim(unittest.TestCase):
   def _test_sgd(self, steps, opts, atol, rtol): self._test_optim(SGD, torch.optim.SGD, steps, opts, atol, rtol)
   def _test_adam(self, steps, opts, atol, rtol): self._test_optim(Adam, torch.optim.Adam, steps, opts, atol, rtol)
   def _test_adamw(self, steps, opts, atol, rtol): self._test_optim(AdamW, torch.optim.AdamW, steps, opts, atol, rtol)

   def test_sgd(self): self._test_sgd(1, {'lr': 0.001}, 1e-6, 0)
   def test_sgd_high_lr(self): self._test_sgd(1, {'lr': 10}, 1e-6, 1e-5)
+  def test_sgd_wd(self): self._test_sgd(1, {'lr': 0.001, 'weight_decay': 0.1}, 1e-6, 0)
+  def test_sgd_high_lr_wd(self): self._test_sgd(1, {'lr': 10, 'weight_decay': 0.1}, 1e-6, 1e-5)

   def test_multistep_sgd(self): self._test_sgd(10, {'lr': 0.001}, 1e-6, 0)
   def test_multistep_sgd_high_lr(self): self._test_sgd(10, {'lr': 10}, 1e-6, 3e-4)
+  def test_multistep_sgd_wd(self): self._test_sgd(10, {'lr': 0.001, 'weight_decay': 0.1}, 1e-6, 0)
+  def test_multistep_sgd_high_lr_wd(self): self._test_sgd(10, {'lr': 9, 'weight_decay': 0.1}, 1e-6, 3e-4)

   def test_multistep_sgd_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9}, 1e-6, 0)
   def test_multistep_sgd_high_lr_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9}, 1e-5, 3e-4)
+  def test_multistep_sgd_momentum_wd(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0.1}, 1e-6, 0)
+  def test_multistep_sgd_high_lr_momentum_wd(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'weight_decay': 0.1}, 1e-5, 3e-4)

   def test_multistep_sgd_nesterov_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True}, 1e-5, 0)
   def test_multistep_sgd_high_lr_nesterov_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'nesterov': True}, 1e-5, 3e-4)
+  def test_multistep_sgd_nesterov_momentum_wd(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 0)
+  def test_multistep_sgd_high_lr_nesterov_momentum_wd(self): self._test_sgd(10, {'lr': 9, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 3e-4)

   def test_adam(self): self._test_adam(1, {'lr': 0.001}, 1e-5, 0)
   def test_adam_high_lr(self): self._test_adam(1, {'lr': 10}, 1e-5, 1e-5)
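The new `weight_decay` keyword slots in next to the existing `lr`, `momentum`, and `nesterov` options; the `_test_optim` helper (not shown in this diff) presumably runs tinygrad's SGD and torch.optim.SGD on the same small network and asserts the resulting values agree within atol/rtol. A minimal usage sketch of the new option, with hypothetical tensor shapes and an import path that may differ between tinygrad versions:

from tinygrad.tensor import Tensor
from tinygrad.nn.optim import SGD  # import path is an assumption; older versions exposed tinygrad.optim

# hypothetical toy parameter (the tests use fixed numpy init arrays instead)
W = Tensor.uniform(4, 4)
W.requires_grad = True

opt = SGD([W], lr=0.001, momentum=0.9, weight_decay=0.1, nesterov=True)
loss = W.relu().sum()  # any scalar loss works
loss.backward()
opt.step()             # gradient becomes grad + 0.1 * W before the momentum update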
@@ -21,16 +21,16 @@ class Optimizer:
       p.realize()

 class SGD(Optimizer):
-  def __init__(self, params: List[Tensor], lr=0.001, momentum=0, nesterov=False):
+  def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False):
     super().__init__(params)
-    self.lr, self.momentum, self.nesterov = lr, momentum, nesterov
+    self.lr, self.momentum, self.wd, self.nesterov = lr, momentum, weight_decay, nesterov
     self.b = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []

   # https://pytorch.org/docs/stable/generated/torch.optim.SGD.html
   def step(self) -> None:
     for i, t in enumerate(self.params):
       assert t.grad is not None
-      g = t.grad.realize()
+      g = t.grad.realize() + self.wd * t.detach()
       if self.momentum:
         self.b[i].assign(self.momentum * self.b[i] + g).realize()  # NOTE: self.b[i] is zero on the first run, no if required
         g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
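For reference, the decay added here is classic L2-style weight decay folded into the gradient before the momentum/Nesterov logic (the same formulation as torch.optim.SGD, which the tests compare against), not the decoupled decay used by AdamW. A rough numpy sketch of one step, assuming the usual p <- p - lr * g parameter update that falls outside the captured hunk:

import numpy as np

def sgd_step(p, grad, buf, lr=0.001, momentum=0.9, weight_decay=0.1, nesterov=False):
  # mirrors: g = t.grad.realize() + self.wd * t.detach()
  g = grad + weight_decay * p
  if momentum:
    # mirrors: self.b[i].assign(self.momentum * self.b[i] + g)
    buf = momentum * buf + g
    g = g + momentum * buf if nesterov else buf
  # assumed final update, not visible in the hunk above: t.assign(t.detach() - g * self.lr)
  return p - lr * g, buf

# tiny usage example
p, buf = np.ones(3), np.zeros(3)
p, buf = sgd_step(p, grad=np.array([0.5, -0.5, 0.0]), buf=buf)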