diff --git a/test/test_optim.py b/test/test_optim.py
index 1f44caf35d..23ff404b26 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -2,7 +2,7 @@ import numpy as np
 import torch
 import unittest
 from tinygrad.tensor import Tensor
-from tinygrad.nn.optim import Adam, SGD, RMSprop
+from tinygrad.nn.optim import Adam, SGD, RMSprop, AdamW
 
 np.random.seed(1337)
 x_init = np.random.randn(1,4).astype(np.float32)
@@ -41,6 +41,7 @@ class TestOptim(unittest.TestCase):
   def _test_sgd(self, steps, opts, atol, rtol): self._test_optim(SGD, torch.optim.SGD, steps, opts, atol, rtol)
   def _test_rmsprop(self, steps, opts, atol, rtol): self._test_optim(RMSprop, torch.optim.RMSprop, steps, opts, atol, rtol)
   def _test_adam(self, steps, opts, atol, rtol): self._test_optim(Adam, torch.optim.Adam, steps, opts, atol, rtol)
+  def _test_adamw(self, steps, opts, atol, rtol): self._test_optim(AdamW, torch.optim.AdamW, steps, opts, atol, rtol)
 
   def test_sgd(self): self._test_sgd(1, {'lr': 0.001}, 1e-6, 0)
   def test_sgd_high_lr(self): self._test_sgd(1, {'lr': 10}, 1e-6, 1e-5)
@@ -58,12 +59,17 @@ class TestOptim(unittest.TestCase):
   def test_rmsprop_high_lr(self): self._test_rmsprop(1, {'lr': 10, 'alpha': 0.99}, 1e-5, 1e-5)
   def test_adam(self): self._test_adam(1, {'lr': 0.001}, 1e-5, 0)
   def test_adam_high_lr(self): self._test_adam(1, {'lr': 10}, 1e-5, 1e-5)
+  def test_adamw(self): self._test_adamw(1, {'lr': 0.001}, 1e-5, 0)
+  def test_adamw_high_lr(self): self._test_adamw(1, {'lr': 10}, 1e-5, 1e-5)
 
   def test_multistep_rmsprop(self): self._test_rmsprop(10, {'lr': 0.001}, 1e-5, 0)
   def test_multistep_rmsprop_high_lr(self): self._test_rmsprop(10, {'lr': 10}, 1e-5, 3e-4)
 
   def test_multistep_adam(self): self._test_adam(10, {'lr': 0.001}, 1e-5, 0)
   def test_multistep_adam_high_lr(self): self._test_adam(10, {'lr': 10}, 1e-5, 3e-4)
+
+  def test_multistep_adamw(self): self._test_adamw(10, {'lr': 0.001}, 1e-5, 0)
+  def test_multistep_adamw_high_lr(self): self._test_adamw(10, {'lr': 10}, 1e-5, 3e-4)
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py
index e2a7b83fe5..d6a2cd15c2 100644
--- a/tinygrad/nn/optim.py
+++ b/tinygrad/nn/optim.py
@@ -59,11 +59,11 @@ class RMSprop(Optimizer):
       t.assign(t.detach() - (g * self.lr).div(self.v[i].sqrt() + self.eps))
     self.realize(self.v)
 
-class Adam(Optimizer):
-  def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
+class AdamW(Optimizer):
+  def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
     super().__init__(params)
     # NOTE: self.t is a tensor so Adam can be jitted
-    self.lr, self.b1, self.b2, self.eps, self.t = lr, b1, b2, eps, Tensor([0], requires_grad=False).realize()
+    self.lr, self.b1, self.b2, self.eps, self.wd, self.t = lr, b1, b2, eps, wd, Tensor([0], requires_grad=False).realize()
 
     self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
     self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
@@ -76,9 +76,11 @@ class Adam(Optimizer):
       g = t.grad.realize()
       self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g).realize()
       self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).realize()
-      t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps))
+      t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps) - self.lr * self.wd * t.detach())
     self.realize([self.t] + self.m + self.v)
 
+def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return AdamW(params, lr, b1, b2, eps, 0.0)
+
 def get_state_dict(obj, prefix:str='') -> Dict[str, Tensor]:
   if isinstance(obj, Tensor): return {prefix.strip('.'):obj}
   if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix)
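
For reference, a minimal usage sketch of the AdamW introduced by this diff (not part of the patch). It assumes the Optimizer base class in tinygrad/nn/optim.py exposes the usual zero_grad() and step() methods that the existing SGD/RMSprop/Adam optimizers are driven through in test_optim.py; the tensor shapes and loss are arbitrary.

    import numpy as np
    from tinygrad.tensor import Tensor
    from tinygrad.nn.optim import AdamW

    np.random.seed(0)
    w = Tensor(np.random.randn(4, 4).astype(np.float32), requires_grad=True)
    x = Tensor(np.random.randn(1, 4).astype(np.float32))

    # Decoupled weight decay: each step also subtracts lr * wd * w,
    # on top of the usual Adam moment-based update. Passing wd=0.0
    # (what the new Adam wrapper does) recovers plain Adam.
    opt = AdamW([w], lr=1e-3, wd=0.01)

    for _ in range(10):
      opt.zero_grad()
      loss = x.dot(w).relu().sum()
      loss.backward()
      opt.step()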