mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
remove RMSprop, nobody uses it anymore
This commit is contained in:
@@ -2,7 +2,7 @@ import numpy as np
|
||||
import torch
|
||||
import unittest
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn.optim import Adam, SGD, RMSprop, AdamW
|
||||
from tinygrad.nn.optim import Adam, SGD, AdamW
|
||||
|
||||
np.random.seed(1337)
|
||||
x_init = np.random.randn(1,4).astype(np.float32)
|
||||
@@ -39,7 +39,6 @@ class TestOptim(unittest.TestCase):
|
||||
np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
|
||||
|
||||
def _test_sgd(self, steps, opts, atol, rtol): self._test_optim(SGD, torch.optim.SGD, steps, opts, atol, rtol)
|
||||
def _test_rmsprop(self, steps, opts, atol, rtol): self._test_optim(RMSprop, torch.optim.RMSprop, steps, opts, atol, rtol)
|
||||
def _test_adam(self, steps, opts, atol, rtol): self._test_optim(Adam, torch.optim.Adam, steps, opts, atol, rtol)
|
||||
def _test_adamw(self, steps, opts, atol, rtol): self._test_optim(AdamW, torch.optim.AdamW, steps, opts, atol, rtol)
|
||||
|
||||
@@ -55,19 +54,14 @@ class TestOptim(unittest.TestCase):
|
||||
def test_multistep_sgd_nesterov_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True}, 1e-5, 0)
|
||||
def test_multistep_sgd_high_lr_nesterov_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'nesterov': True}, 1e-5, 3e-4)
|
||||
|
||||
def test_rmsprop(self): self._test_rmsprop(1, {'lr': 0.001, 'alpha': 0.99}, 1e-5, 0)
|
||||
def test_rmsprop_high_lr(self): self._test_rmsprop(1, {'lr': 10, 'alpha': 0.99}, 1e-5, 1e-5)
|
||||
def test_adam(self): self._test_adam(1, {'lr': 0.001}, 1e-5, 0)
|
||||
def test_adam_high_lr(self): self._test_adam(1, {'lr': 10}, 1e-5, 1e-5)
|
||||
def test_adamw(self): self._test_adamw(1, {'lr': 0.001}, 1e-5, 0)
|
||||
def test_adamw_high_lr(self): self._test_adamw(1, {'lr': 10}, 1e-5, 1e-5)
|
||||
|
||||
def test_multistep_rmsprop(self): self._test_rmsprop(10, {'lr': 0.001}, 1e-5, 0)
|
||||
def test_multistep_rmsprop_high_lr(self): self._test_rmsprop(10, {'lr': 10}, 1e-5, 3e-4)
|
||||
|
||||
def test_multistep_adam(self): self._test_adam(10, {'lr': 0.001}, 1e-5, 0)
|
||||
def test_multistep_adam_high_lr(self): self._test_adam(10, {'lr': 10}, 1e-5, 3e-4)
|
||||
|
||||
|
||||
def test_multistep_adamw(self): self._test_adamw(10, {'lr': 0.001}, 1e-5, 0)
|
||||
def test_multistep_adamw_high_lr(self): self._test_adamw(10, {'lr': 10}, 1e-5, 3e-4)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user