efficient version of adam (#20)
* counteracted bias initialization
* test new adam
* add optimizer tests
* rename helper function names to fix the test
* remove redundant import
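A note on what the change does (not part of the commit message): instead of materializing the bias-corrected moments mhat = m / (1 - b1^t) and vhat = v / (1 - b2^t) for every parameter, the correction is folded into a single per-step scalar a = lr * sqrt(1 - b2^t) / (1 - b1^t) that scales the raw moments directly. The two forms differ only in where eps enters the denominator (eps versus eps * sqrt(1 - b2^t)), which is negligible at the default eps = 1e-8; a small numerical sketch of this equivalence follows the optim.py diff below.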
test/test_optim.py (new file, 72 lines)
@@ -0,0 +1,72 @@
import numpy as np
import torch
import unittest
from tinygrad.tensor import Tensor
from tinygrad.optim import Adam, SGD, RMSprop

x_init = np.random.randn(1,3).astype(np.float32)
W_init = np.random.randn(3,3).astype(np.float32)
m_init = np.random.randn(1,3).astype(np.float32)

def step_tinygrad(optim, kwargs={}):
  net = TinyNet()
  optim = optim([net.x, net.W], **kwargs)
  out = net.forward()
  out.backward()
  optim.step()
  return net.x.data, net.W.data

def step_pytorch(optim, kwargs={}):
  net = TorchNet()
  optim = optim([net.x, net.W], **kwargs)
  out = net.forward()
  out.backward()
  optim.step()
  return net.x.detach().numpy(), net.W.detach().numpy()


class TinyNet():
  def __init__(self):
    self.x = Tensor(x_init.copy())
    self.W = Tensor(W_init.copy())
    self.m = Tensor(m_init.copy())

  def forward(self):
    out = self.x.dot(self.W).relu()
    out = out.logsoftmax()
    out = out.mul(self.m).add(self.m).sum()
    return out


class TorchNet():
  def __init__(self):
    self.x = torch.tensor(x_init.copy(), requires_grad=True)
    self.W = torch.tensor(W_init.copy(), requires_grad=True)
    self.m = torch.tensor(m_init.copy())

  def forward(self):
    out = self.x.matmul(self.W).relu()
    out = torch.nn.functional.log_softmax(out, dim=1)
    out = out.mul(self.m).add(self.m).sum()
    return out


class TestOptim(unittest.TestCase):
  def test_adam(self):
    for x,y in zip(step_tinygrad(Adam),
                   step_pytorch(torch.optim.Adam)):
      np.testing.assert_allclose(x, y, atol=1e-5)

  def test_sgd(self):
    for x,y in zip(step_tinygrad(SGD, kwargs={'lr': 0.001}),
                   step_pytorch(torch.optim.SGD, kwargs={'lr': 0.001})):
      np.testing.assert_allclose(x, y, atol=1e-5)

  def test_rmsprop(self):
    for x,y in zip(step_tinygrad(RMSprop, kwargs={'lr': 0.001, 'decay': 0.99}),
                   step_pytorch(torch.optim.RMSprop,
                                kwargs={'lr': 0.001, 'alpha': 0.99})):
      np.testing.assert_allclose(x, y, atol=1e-5)

if __name__ == '__main__':
  unittest.main()
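A brief usage note (not part of the commit): assuming NumPy and PyTorch are installed, the new tests can be run from the repository root with python3 test/test_optim.py, or through any unittest-compatible runner.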
tinygrad/optim.py
@@ -43,10 +43,11 @@ class Adam(Optimizer):
   def step(self):
     self.t += 1
+    a = self.lr * (
+      np.sqrt(1 - np.power(self.b2, self.t)) /
+      (1 - np.power(self.b1, self.t)))
     for i,t in enumerate(self.params):
       self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * t.grad
       self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * np.square(t.grad)
-      mhat = self.m[i] / (1. - self.b1**self.t)
-      vhat = self.v[i] / (1. - self.b2**self.t)
-      t.data -= self.lr * mhat / (np.sqrt(vhat) + self.eps)
+      t.data -= a * self.m[i] / (np.sqrt(self.v[i]) + self.eps)
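The removed lines divide by the bias-correction terms per parameter on every step; the added lines precompute the combined factor a once per step. Below is a minimal standalone sketch of the equivalence, assuming nothing beyond NumPy; the names grad, m, v, a are illustrative and do not touch tinygrad itself.

# Standalone NumPy sketch: compare the explicit mhat/vhat update (removed lines)
# with the folded step size a (added lines).
import numpy as np

lr, b1, b2, eps, t = 0.001, 0.9, 0.999, 1e-8, 3
np.random.seed(0)
grad = np.random.randn(3, 3).astype(np.float32)

# run a few moment updates with a fixed gradient, as Adam would
m = np.zeros_like(grad)
v = np.zeros_like(grad)
for _ in range(t):
  m = b1 * m + (1 - b1) * grad
  v = b2 * v + (1 - b2) * np.square(grad)

# old path: explicit bias-corrected moments
mhat = m / (1 - b1**t)
vhat = v / (1 - b2**t)
update_old = lr * mhat / (np.sqrt(vhat) + eps)

# new path: bias correction folded into a per-step scalar
a = lr * np.sqrt(1 - b2**t) / (1 - b1**t)
update_new = a * m / (np.sqrt(v) + eps)

# the forms differ only in how eps enters (eps vs eps * sqrt(1 - b2**t)),
# so the gap is tiny at eps = 1e-8
print(np.abs(update_old - update_new).max())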