From 4b163ee27067b478066f571d914aea957b59637e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?G=C3=B6ktu=C4=9F=20Karaka=C5=9Fl=C4=B1?= <20567087+goktug97@users.noreply.github.com>
Date: Wed, 28 Oct 2020 01:54:40 +0300
Subject: [PATCH] efficient version of adam (#20)

* counteracted bias initialization

* test new adam

* add optimizer tests

* rename helper function names to fix the test

* remove redundant import
---
Notes:
  Line structure and 2-space indentation were reconstructed from a
  whitespace-collapsed copy of this patch. The blank context lines around
  Adam.step are inferred from the hunk line counts; check them against the
  tree before applying. The index blob ids are those of the original commit.

 test/test_optim.py | 92 ++++++++++++++++++++++++++++++++++++++++++++++
 tinygrad/optim.py  |  9 +++--
 2 files changed, 98 insertions(+), 3 deletions(-)
 create mode 100644 test/test_optim.py

diff --git a/test/test_optim.py b/test/test_optim.py
new file mode 100644
index 0000000000..927c4d85cc
--- /dev/null
+++ b/test/test_optim.py
@@ -0,0 +1,92 @@
+"""Compare one optimizer step in tinygrad against the same step in PyTorch."""
+import numpy as np
+import torch
+import unittest
+from tinygrad.tensor import Tensor
+from tinygrad.optim import Adam, SGD, RMSprop
+
+# fixed seed so a failing comparison can be reproduced with the same inputs
+np.random.seed(1337)
+x_init = np.random.randn(1,3).astype(np.float32)
+W_init = np.random.randn(3,3).astype(np.float32)
+m_init = np.random.randn(1,3).astype(np.float32)
+
+def step_tinygrad(optim, kwargs=None):
+  """Build a TinyNet, run forward/backward and one `optim` step.
+
+  `optim` is an optimizer class; `kwargs` are passed to its constructor.
+  Returns the `.data` of x and W after the step.
+  """
+  net = TinyNet()
+  # kwargs defaults to None rather than a shared mutable {}
+  optimizer = optim([net.x, net.W], **(kwargs or {}))
+  out = net.forward()
+  out.backward()
+  optimizer.step()
+  return net.x.data, net.W.data
+
+def step_pytorch(optim, kwargs=None):
+  """Build a TorchNet, run forward/backward and one `optim` step.
+
+  `optim` is an optimizer class; `kwargs` are passed to its constructor.
+  Returns x and W after the step as numpy arrays.
+  """
+  net = TorchNet()
+  # kwargs defaults to None rather than a shared mutable {}
+  optimizer = optim([net.x, net.W], **(kwargs or {}))
+  out = net.forward()
+  out.backward()
+  optimizer.step()
+  return net.x.detach().numpy(), net.W.detach().numpy()
+
+
+class TinyNet():
+  """tinygrad side of the comparison, built from the shared initial arrays."""
+  def __init__(self):
+    self.x = Tensor(x_init.copy())
+    self.W = Tensor(W_init.copy())
+    self.m = Tensor(m_init.copy())
+
+  def forward(self):
+    """relu(x.dot(W)) -> logsoftmax -> * m + m -> sum."""
+    out = self.x.dot(self.W).relu()
+    out = out.logsoftmax()
+    out = out.mul(self.m).add(self.m).sum()
+    return out
+
+
+class TorchNet():
+  """PyTorch side of the comparison; only x and W require gradients."""
+  def __init__(self):
+    self.x = torch.tensor(x_init.copy(), requires_grad=True)
+    self.W = torch.tensor(W_init.copy(), requires_grad=True)
+    self.m = torch.tensor(m_init.copy())
+
+  def forward(self):
+    """Same graph as TinyNet.forward, written with torch ops."""
+    out = self.x.matmul(self.W).relu()
+    out = torch.nn.functional.log_softmax(out, dim=1)
+    out = out.mul(self.m).add(self.m).sum()
+    return out
+
+
+class TestOptim(unittest.TestCase):
+  """Each test checks x and W agree to atol=1e-5 after a single step."""
+  def test_adam(self):
+    for x,y in zip(step_tinygrad(Adam),
+                   step_pytorch(torch.optim.Adam)):
+      np.testing.assert_allclose(x, y, atol=1e-5)
+
+  def test_sgd(self):
+    for x,y in zip(step_tinygrad(SGD, kwargs={'lr': 0.001}),
+                   step_pytorch(torch.optim.SGD, kwargs={'lr': 0.001})):
+      np.testing.assert_allclose(x, y, atol=1e-5)
+
+  def test_rmsprop(self):
+    for x,y in zip(step_tinygrad(RMSprop, kwargs={'lr': 0.001, 'decay': 0.99}),
+                   step_pytorch(torch.optim.RMSprop,
+                                kwargs={'lr': 0.001, 'alpha': 0.99})):
+      np.testing.assert_allclose(x, y, atol=1e-5)
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/tinygrad/optim.py b/tinygrad/optim.py
index b2bb5985f3..ab4e8034fb 100644
--- a/tinygrad/optim.py
+++ b/tinygrad/optim.py
@@ -43,10 +43,13 @@ class Adam(Optimizer):
 
   def step(self):
     self.t += 1
+    # fold both bias corrections into a single scalar step size so the loop
+    # no longer builds mhat/vhat; eps is now added to the uncorrected sqrt(v)
+    a = self.lr * (
+      np.sqrt(1 - np.power(self.b2, self.t)) /
+      (1 - np.power(self.b1, self.t)))
     for i,t in enumerate(self.params):
       self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * t.grad
       self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * np.square(t.grad)
-      mhat = self.m[i] / (1. - self.b1**self.t)
-      vhat = self.v[i] / (1. - self.b2**self.t)
-      t.data -= self.lr * mhat / (np.sqrt(vhat) + self.eps)
+      t.data -= a * self.m[i] / (np.sqrt(self.v[i]) + self.eps)
 