diff --git a/test/models/test_mnist.py b/test/models/test_mnist.py
index 3e2c6d0d1c..54ae2cf356 100644
--- a/test/models/test_mnist.py
+++ b/test/models/test_mnist.py
@@ -112,7 +112,7 @@ class TestMNIST(unittest.TestCase):
   def test_rmsprop(self):
     np.random.seed(1337)
     model = TinyBobNet()
-    optimizer = optim.RMSprop(model.parameters(), lr=0.0002)
+    optimizer = optim.RMSprop(model.parameters(), lr=0.0002, alpha=0.9)
     train(model, X_train, Y_train, optimizer, steps=400)
     assert evaluate(model, X_test, Y_test) > 0.95
diff --git a/test/test_optim.py b/test/test_optim.py
index 01ea23d657..1f44caf35d 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -2,72 +2,68 @@ import numpy as np
 import torch
 import unittest
 from tinygrad.tensor import Tensor
-from tinygrad.nn.optim import Adam, SGD, RMSprop, get_parameters
-
-x_init = np.random.randn(1,3).astype(np.float32)
-W_init = np.random.randn(3,3).astype(np.float32)
-m_init = np.random.randn(1,3).astype(np.float32)
-
-def step_tinygrad(optim, kwargs={}):
-  net = TinyNet()
-  optim = optim([net.x, net.W], **kwargs)
-  out = net.forward()
-  out.backward()
-  optim.step()
-  return net.x.cpu().numpy(), net.W.cpu().numpy()
-
-def step_pytorch(optim, kwargs={}):
-  net = TorchNet()
-  optim = optim([net.x, net.W], **kwargs)
-  out = net.forward()
-  out.backward()
-  optim.step()
-  return net.x.detach().numpy(), net.W.detach().numpy()
+from tinygrad.nn.optim import Adam, SGD, RMSprop
+
+np.random.seed(1337)
+x_init = np.random.randn(1,4).astype(np.float32)
+W_init = np.random.randn(4,4).astype(np.float32)
+m_init = np.random.randn(1,4).astype(np.float32)
 
 class TinyNet():
-  def __init__(self):
-    self.x = Tensor(x_init.copy())
-    self.W = Tensor(W_init.copy())
-    self.m = Tensor(m_init.copy())
-
-  def forward(self):
-    out = self.x.dot(self.W).relu()
-    out = out.log_softmax()
-    out = out.mul(self.m).add(self.m).sum()
-    return out
-
-
-class TorchNet():
-  def __init__(self):
-    self.x = torch.tensor(x_init.copy(), requires_grad=True)
-    self.W = torch.tensor(W_init.copy(), requires_grad=True)
-    self.m = torch.tensor(m_init.copy())
+  def __init__(self, tensor):
+    self.x = tensor(x_init.copy(), requires_grad=True)
+    self.W = tensor(W_init.copy(), requires_grad=True)
+    self.m = tensor(m_init.copy())
 
   def forward(self):
     out = self.x.matmul(self.W).relu()
-    out = torch.nn.functional.log_softmax(out, dim=1)
+    out = out.log_softmax(1)
     out = out.mul(self.m).add(self.m).sum()
     return out
 
+def step(tensor, optim, steps=1, kwargs={}):
+  net = TinyNet(tensor)
+  optim = optim([net.x, net.W], **kwargs)
+  for _ in range(steps):
+    out = net.forward()
+    optim.zero_grad()
+    out.backward()
+    optim.step()
+  return net.x.detach().numpy(), net.W.detach().numpy()
+
 class TestOptim(unittest.TestCase):
-  def test_adam(self):
-    for x,y in zip(step_tinygrad(Adam),
-                   step_pytorch(torch.optim.Adam)):
-      np.testing.assert_allclose(x, y, atol=1e-4)
+  def _test_optim(self, tinygrad_optim, torch_optim, steps, opts, atol, rtol):
+    for x,y in zip(step(Tensor, tinygrad_optim, steps, kwargs=opts),
+                   step(torch.tensor, torch_optim, steps, kwargs=opts)):
+      np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
 
-  def test_sgd(self):
-    for x,y in zip(step_tinygrad(SGD, kwargs={'lr': 0.001}),
-                   step_pytorch(torch.optim.SGD, kwargs={'lr': 0.001})):
-      np.testing.assert_allclose(x, y, atol=1e-5)
+  def _test_sgd(self, steps, opts, atol, rtol): self._test_optim(SGD, torch.optim.SGD, steps, opts, atol, rtol)
+  def _test_rmsprop(self, steps, opts, atol, rtol): self._test_optim(RMSprop, torch.optim.RMSprop, steps, opts, atol, rtol)
+  def _test_adam(self, steps, opts, atol, rtol): self._test_optim(Adam, torch.optim.Adam, steps, opts, atol, rtol)
 
-  def test_rmsprop(self):
-    for x,y in zip(step_tinygrad(RMSprop, kwargs={'lr': 0.001, 'decay': 0.99}),
-                   step_pytorch(torch.optim.RMSprop,
-                                kwargs={'lr': 0.001, 'alpha': 0.99})):
-      np.testing.assert_allclose(x, y, atol=1e-5)
+  def test_sgd(self): self._test_sgd(1, {'lr': 0.001}, 1e-6, 0)
+  def test_sgd_high_lr(self): self._test_sgd(1, {'lr': 10}, 1e-6, 1e-5)
+
+  def test_multistep_sgd(self): self._test_sgd(10, {'lr': 0.001}, 1e-6, 0)
+  def test_multistep_sgd_high_lr(self): self._test_sgd(10, {'lr': 10}, 1e-6, 3e-4)
+
+  def test_multistep_sgd_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9}, 1e-6, 0)
+  def test_multistep_sgd_high_lr_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9}, 1e-5, 3e-4)
+
+  def test_multistep_sgd_nesterov_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True}, 1e-5, 0)
+  def test_multistep_sgd_high_lr_nesterov_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'nesterov': True}, 1e-5, 3e-4)
+
+  def test_rmsprop(self): self._test_rmsprop(1, {'lr': 0.001, 'alpha': 0.99}, 1e-5, 0)
+  def test_rmsprop_high_lr(self): self._test_rmsprop(1, {'lr': 10, 'alpha': 0.99}, 1e-5, 1e-5)
+  def test_adam(self): self._test_adam(1, {'lr': 0.001}, 1e-5, 0)
+  def test_adam_high_lr(self): self._test_adam(1, {'lr': 10}, 1e-5, 1e-5)
+
+  def test_multistep_rmsprop(self): self._test_rmsprop(10, {'lr': 0.001}, 1e-5, 0)
+  def test_multistep_rmsprop_high_lr(self): self._test_rmsprop(10, {'lr': 10}, 1e-5, 3e-4)
+
+  def test_multistep_adam(self): self._test_adam(10, {'lr': 0.001}, 1e-5, 0)
+  def test_multistep_adam_high_lr(self): self._test_adam(10, {'lr': 10}, 1e-5, 3e-4)
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py
index 72f27a17d3..b92389e815 100644
--- a/tinygrad/nn/optim.py
+++ b/tinygrad/nn/optim.py
@@ -23,6 +23,7 @@ class Optimizer:
 
   def realize(self, extra=None):
     # TODO: corealize
+    # NOTE: in extra is too late for most of the params due to issues with assign
    for p in extra + self.params + self.buffers if extra is not None else self.params + self.buffers:
       p.realize()
 
@@ -36,25 +37,26 @@ class SGD(Optimizer):
   def step(self) -> None:
     for i, t in enumerate(self.params):
       assert t.grad is not None
-      g = t.grad
+      g = t.grad.realize()
       if self.momentum:
-        self.b[i].assign(self.momentum * self.b[i] + g)
+        self.b[i].assign(self.momentum * self.b[i] + g).realize()  # NOTE: self.b[i] is zero on the first run, no if required
         g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i]
       t.assign(t.detach() - g * self.lr)
     self.realize(self.b)
 
 class RMSprop(Optimizer):
-  def __init__(self, params: List[Tensor], lr=0.001, decay=0.9, eps=1e-8):
+  def __init__(self, params: List[Tensor], lr=0.001, alpha=0.99, eps=1e-8):
     super().__init__(params)
-    self.lr, self.decay, self.eps = lr, decay, eps
+    self.lr, self.alpha, self.eps = lr, alpha, eps
     self.v = [Tensor.zeros(*t.shape, device=params[0].device, requires_grad=False) for t in self.params]
 
   def step(self) -> None:
     for i, t in enumerate(self.params):
       assert t.grad is not None
-      self.v[i].assign(self.decay * self.v[i] + (1.0 - self.decay) * (t.grad * t.grad))
-      t.assign(t.detach() - (t.grad * self.lr).div(self.v[i].sqrt() + self.eps))
+      g = t.grad.realize()
+      self.v[i].assign(self.alpha * self.v[i] + (1.0 - self.alpha) * (g * g)).realize()
+      t.assign(t.detach() - (g * self.lr).div(self.v[i].sqrt() + self.eps))
     self.realize(self.v)
 
 class Adam(Optimizer):
@@ -67,12 +69,13 @@ class Adam(Optimizer):
     self.v = [Tensor.zeros(*t.shape, device=params[0].device, requires_grad=False) for t in self.params]
 
   def step(self) -> None:
-    self.t = self.t + 1
+    self.t.assign(self.t + 1).realize()
     a = self.lr * ((1.0 - self.b2**self.t)**0.5) / (1.0 - self.b1**self.t)
     for i, t in enumerate(self.params):
       assert t.grad is not None
-      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
-      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
+      g = t.grad.realize()
+      self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g).realize()
+      self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).realize()
       t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps))
     self.realize([self.t] + self.m + self.v)
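
Two notes on the behavioral changes above. RMSprop's smoothing-constant argument is renamed from decay (default 0.9) to alpha (default 0.99), matching the keyword torch.optim.RMSprop expects, which is why the test kwargs can now be passed verbatim to both implementations. Adam keeps its bias correction folded into the single step size a; the sketch below is not part of the patch, just a plain-NumPy check with example moment values (eps omitted) that this folded form matches the textbook m_hat/v_hat formulation:

import numpy as np

lr, b1, b2, t = 0.001, 0.9, 0.999, 10   # example Adam hyperparameters and step count
m, v = 0.05, 0.02                       # example first/second moment estimates
# folded step size, as computed in the patched Adam.step
a = lr * ((1.0 - b2**t)**0.5) / (1.0 - b1**t)
update_folded = a * m / np.sqrt(v)
# textbook bias-corrected form
m_hat, v_hat = m / (1.0 - b1**t), v / (1.0 - b2**t)
update_textbook = lr * m_hat / np.sqrt(v_hat)
assert np.isclose(update_folded, update_textbook)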