tinygrad/test/test_optim.py

import numpy as np
import torch
import unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad.nn.optim import Adam, SGD, AdamW, Muon, LAMB
from tinygrad.device import is_dtype_supported
from test.helpers import needs_second_gpu, slow
np.random.seed(1337)
x_init = np.random.randn(1,4).astype(np.float32)
W_init = np.random.randn(4,4).astype(np.float32)
m_init = np.random.randn(1,4).astype(np.float32)
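
# TeenyNet / TinyNet below are built from these fixed numpy initializers, so the
# tinygrad run and the torch run in each test start from identical weights.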
class TeenyNet:
  def __init__(self, tensor):
    self.x = tensor(x_init.copy(), requires_grad=True)
    self.W = tensor(W_init.copy(), requires_grad=True)
  def forward(self):
    return (self.x * self.W).sum()
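
# TinyNet adds a matmul -> relu -> log_softmax path, so the optimizers see gradients
# from a slightly deeper graph than TeenyNet's single elementwise product.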
class TinyNet:
  def __init__(self, tensor):
    self.x = tensor(x_init.copy(), requires_grad=True)
    self.W = tensor(W_init.copy(), requires_grad=True)
    self.m = tensor(m_init.copy())
  def forward(self):
    out = self.x.matmul(self.W).relu()
    # print(out.detach().numpy())
    out = out.log_softmax(1)
    out = out.mul(self.m).add(self.m).sum()
    return out

def step(tensor, optim, steps=1, teeny=False, **kwargs):
  net = TeenyNet(tensor) if teeny else TinyNet(tensor)
  optim = optim([net.x, net.W], **kwargs)
  for _ in range(steps):
    out = net.forward()
    optim.zero_grad()
    out.backward()
    optim.step()
  return net.x.detach().numpy(), net.W.detach().numpy()
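
# step() runs the same tiny training loop with either tinygrad or torch tensors and
# optimizers. A rough illustration of how the tests below use it (not itself a test):
#   x_tg, W_tg = step(Tensor, SGD, steps=1, lr=0.001)
#   x_pt, W_pt = step(torch.tensor, torch.optim.SGD, steps=1, lr=0.001)
#   np.testing.assert_allclose(x_tg, x_pt, atol=1e-6, rtol=0)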
@slow
class TestOptim(unittest.TestCase):
  def setUp(self):
    self.old_training = Tensor.training
    Tensor.training = True
  def tearDown(self):
    Tensor.training = self.old_training
  def _test_optim(self, tinygrad_optim, torch_optim, steps, opts, atol, rtol):
    for x,y in zip(step(Tensor, tinygrad_optim, steps, **opts),
                   step(torch.tensor, torch_optim, steps, **opts)):
      np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
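
  # per-optimizer shorthands: each pairs a tinygrad optimizer with its torch counterpart
  # and compares the updated x and W within the given tolerances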
  def _test_sgd(self, steps, opts, atol, rtol): self._test_optim(SGD, torch.optim.SGD, steps, opts, atol, rtol)
  def _test_adam(self, steps, opts, atol, rtol): self._test_optim(Adam, torch.optim.Adam, steps, opts, atol, rtol)
  def _test_adamw(self, steps, opts, atol, rtol): self._test_optim(AdamW, torch.optim.AdamW, steps, opts, atol, rtol)
  def _test_muon(self, steps, opts, atol, rtol): self._test_optim(Muon, torch.optim.Muon, steps, opts, atol, rtol)
  def test_multistep_sgd_high_lr_teeny(self): self._test_sgd(2, {'lr': 1.1, 'teeny': True}, 1e-6, 1e-5)
  def test_multistep_adam_high_lr_teeny(self): self._test_adam(2, {'lr': 1.1, 'teeny': True}, 2e-4, 5e-4)
  def test_multistep_muon_high_lr_teeny(self): self._test_muon(2, {'lr': 1.1, 'teeny': True}, 1e-2, 5e-4)
  def test_sgd(self): self._test_sgd(1, {'lr': 0.001}, 1e-6, 0)
  def test_sgd_high_lr(self): self._test_sgd(1, {'lr': 10}, 1e-6, 1e-5)
  def test_sgd_wd(self): self._test_sgd(1, {'lr': 0.001, 'weight_decay': 0.1}, 1e-6, 0)
  def test_sgd_high_lr_wd(self): self._test_sgd(1, {'lr': 10, 'weight_decay': 0.1}, 1e-6, 1e-5)
  def test_multistep_sgd(self): self._test_sgd(10, {'lr': 0.001}, 1e-6, 0)
  def test_multistep_sgd_high_lr(self): self._test_sgd(10, {'lr': 10}, 1e-6, 3e-4)
  def test_multistep_sgd_wd(self): self._test_sgd(10, {'lr': 0.001, 'weight_decay': 0.1}, 1e-6, 0)
  def test_multistep_sgd_high_lr_wd(self): self._test_sgd(10, {'lr': 9, 'weight_decay': 0.1}, 1e-6, 3e-4)
  def test_multistep_sgd_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9}, 1e-6, 0)
  def test_multistep_sgd_high_lr_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9}, 1e-5, 3e-4)
  def test_multistep_sgd_momentum_wd(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0.1}, 1e-6, 0)
  def test_multistep_sgd_high_lr_momentum_wd(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'weight_decay': 0.1}, 1e-5, 3e-4)
  def test_multistep_sgd_nesterov_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True}, 1e-5, 0)
  def test_multistep_sgd_high_lr_nesterov_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'nesterov': True}, 1e-5, 3e-4)
  def test_multistep_sgd_nesterov_momentum_wd(self):
    self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 0)
  def test_multistep_sgd_high_lr_nesterov_momentum_wd(self):
    self._test_sgd(10, {'lr': 9, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 3e-4)
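
  # Muon orthogonalizes the momentum update with Newton-Schulz iterations
  # (ns_steps / ns_coefficients), so tolerances here are looser than for SGD above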
  def test_muon(self): self._test_muon(1, {'lr': 0.001}, 1e-3, 0)
  # TODO: disabled due to big atol
  # def test_muon_high_lr(self): self._test_muon(1, {'lr': 10}, 1e-6, 3e-4)
  def test_muon_wd(self): self._test_muon(1, {'lr': 0.001, 'weight_decay': 0.01}, 1e-3, 3e-4)
  # TODO: disabled due to big atol
  # def test_muon_high_lr_wd(self): self._test_muon(1, {'lr': 10, 'weight_decay': 0.01}, 1e-6, 5e-4)
  # NOTE: momentum set to 0.95 by default, nesterov set to True by default
  def test_multistep_muon_momentum_wd(self): self._test_muon(10, {'lr': 0.001, 'weight_decay': 0.01}, 3e-3, 0)
  # ns defaults are numerically unstable, but it is tolerable in real training (see nsteps/nparam tests)
  # TODO: disabled due to big atol
  # def test_multistep_muon_high_lr_momentum_wd(self): self._test_muon(10, {'lr': 10, 'weight_decay': 0.01}, 1e-1, 3e-4)
  def test_multistep_muon_no_nesterov_momentum(self): self._test_muon(10, {'lr': 0.001, 'nesterov': False}, 1e-3, 0)
  # TODO: disabled due to big atol
  # def test_multistep_muon_high_lr_no_nesterov_momentum(self): self._test_muon(10, {'lr': 10, 'nesterov': False}, 5e-2, 1e-1)
  def test_muon_ns_steps(self): self._test_muon(1, {'lr': 0.001, 'ns_steps': 3}, 1e-4, 0)
  # TODO: disabled due to big atol
  # def test_muon_high_lr_ns_steps(self): self._test_muon(1, {'lr': 10, 'ns_steps': 3}, 1e-5, 3e-4)
  def test_muon_ns_coefficients(self): self._test_muon(1, {'lr': 0.001, 'ns_coefficients': (2.0,-1.5,0.5)}, 1e-5, 3e-4)
  # TODO: disabled due to big atol
  # def test_muon_high_lr_ns_coefficients(self): self._test_muon(1, {'lr': 10, 'ns_coefficients': (2.0,-1.5,0.5)}, 1e-5, 3e-4)
  def test_muon_momentum_wd_ns_steps_ns_coefficients(self):
    self._test_muon(10, {'lr': 0.001, 'momentum': 0.90, 'weight_decay': 0.01, 'ns_steps': 3, 'ns_coefficients': (2.0,-1.5,0.5)}, 1e-4, 0)
  # TODO: disabled due to big atol
  # def test_multistep_muon_high_lr_momentum_wd_ns_steps_ns_coefficients(self):
  #   self._test_muon(10, {'lr': 10, 'momentum': 0.90, 'weight_decay': 0.01, 'ns_steps': 3, 'ns_coefficients': (2.0,-1.5,0.5)}, 1e-5, 3e-4)
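
  # Adam / AdamW, single step and ten steps, at small and large learning rates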
  def test_adam(self): self._test_adam(1, {'lr': 0.001}, 1e-5, 0)
  def test_adam_high_lr(self): self._test_adam(1, {'lr': 10}, 1e-4, 1e-4)
  def test_adamw(self): self._test_adamw(1, {'lr': 0.001}, 1e-5, 0)
  def test_adamw_high_lr(self): self._test_adamw(1, {'lr': 10}, 1e-4, 1e-4)
  def test_multistep_adam(self): self._test_adam(10, {'lr': 0.001}, 1e-5, 0)
  def test_multistep_adam_high_lr(self): self._test_adam(10, {'lr': 10}, 2e-3, 5e-4)
  def test_multistep_adamw(self): self._test_adamw(10, {'lr': 0.001}, 1e-5, 0)
  def test_multistep_adamw_high_lr(self): self._test_adamw(10, {'lr': 10}, 5e-4, 2e-3)
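
  # passing the same tensor to the optimizer twice should behave like passing it once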
  def test_duped_weights(self):
    for Opt in [Adam, AdamW, SGD]:
      losses = []
      for i in range(2):
        w = Tensor(x_init.copy())
        opt = Opt([w], lr=0.1) if i == 0 else Opt([w, w], lr=0.1)
        loss = None
        for _ in range(3):
          loss = w.sum()
          opt.zero_grad()
          loss.backward()
          opt.step()
        losses.append(loss.numpy())
      np.testing.assert_allclose(losses[0], losses[1], atol=1e-4, rtol=0)
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_mixed_precision(self):
    old_default_float, dtypes.default_float = dtypes.default_float, dtypes.half
    # weight update would overflow without upcasting
    self._test_sgd(10, {'lr': 1e10}, 1e-6, 3e-4)
    self._test_adam(1, {'lr': 1e10}, 1e-4, 1e-4)
    self._test_adamw(1, {'lr': 1e10}, 1e-4, 1e-4)
    dtypes.default_float = old_default_float
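
  # optimizer.step() is expected to raise RuntimeError unless Tensor.training is True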
  def test_assert_tensor_train(self):
    t = Tensor.ones((1,1), requires_grad=True)
    optimizer = Adam([t])
    optimizer.zero_grad()
    old_state = Tensor.training
    t.sum().backward()
    Tensor.training = False
    self.assertRaises(RuntimeError, optimizer.step)
    Tensor.training = True
    optimizer.step()
    Tensor.training = old_state
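
  # LAMB's optimizer state (m, v and the beta step counters) can be kept on CPU
  # while the parameter itself stays on the default device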
  def test_lamb_cpu_offload(self):
    # test that LAMB works when optimizer params (m, v, b1_t, b2_t) are moved to CPU
    t = Tensor(x_init.copy(), requires_grad=True)
    opt = LAMB([t])
    # move optimizer state to CPU
    for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")
    # run a step
    t.sum().backward()
    opt.step()
    self.assertEqual(t.device, Device.DEFAULT)
    self.assertEqual(opt.m[0].device, "CPU")
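
  # same CPU offload check, but with the parameter sharded across two devices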
  @needs_second_gpu
  def test_lamb_cpu_offload_multi(self):
    ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
    t = Tensor(x_init.copy(), requires_grad=True).shard(ds, axis=1)
    ds = t.device
    opt = LAMB([t])
    # move optimizer state to CPU
    for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")
    # run a step
    t.sum().backward()
    opt.step()
    self.assertEqual(t.device, ds)
    self.assertEqual(opt.m[0].device, "CPU")

if __name__ == '__main__':
  unittest.main()