support LAMB param offload (#13730)

also added Tensor.shard_like
chenyu
2025-12-16 19:56:30 -05:00
committed by GitHub
parent cf0c28d5ae
commit fda73c8180
5 changed files with 61 additions and 4 deletions
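The commit message also mentions Tensor.shard_like, whose diff is not shown in this excerpt. A minimal usage sketch, assuming the method simply shards a tensor across the same devices and axis as a reference tensor (the exact signature and behavior are assumptions, not taken from the diff below):

from tinygrad import Tensor, Device

# hypothetical example: shard_like is assumed to mirror another tensor's sharding
GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
x = Tensor.ones(4, 8).shard(GPUS, axis=1)  # reference tensor, split along axis 1
w = Tensor.ones(4, 8).shard_like(x)        # assumed: shard w the same way as x
assert w.device == x.device                # expected under that assumption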


@@ -2,9 +2,10 @@ import numpy as np
 import torch
 import unittest
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.nn.optim import Adam, SGD, AdamW, Muon
+from tinygrad.nn.optim import Adam, SGD, AdamW, Muon, LAMB
 from tinygrad.helpers import CI
 from tinygrad.device import is_dtype_supported
+from test.helpers import needs_second_gpu
 np.random.seed(1337)
 x_init = np.random.randn(1,4).astype(np.float32)
@@ -163,5 +164,31 @@ class TestOptim(unittest.TestCase):
     optimizer.step()
     Tensor.training = old_state
 
+  def test_lamb_cpu_offload(self):
+    # test that LAMB works when optimizer params (m, v, b1_t, b2_t) are moved to CPU
+    t = Tensor(x_init.copy(), requires_grad=True)
+    opt = LAMB([t])
+    # move optimizer state to CPU
+    for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")
+    # run a step
+    t.sum().backward()
+    opt.step()
+    self.assertEqual(t.device, Device.DEFAULT)
+    self.assertEqual(opt.m[0].device, "CPU")
+
+  @needs_second_gpu
+  def test_lamb_cpu_offload_multi(self):
+    ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
+    t = Tensor(x_init.copy(), requires_grad=True).shard(ds, axis=1)
+    ds = t.device
+    opt = LAMB([t])
+    # move optimizer state to CPU
+    for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")
+    # run a step
+    t.sum().backward()
+    opt.step()
+    self.assertEqual(t.device, ds)
+    self.assertEqual(opt.m[0].device, "CPU")
+
 if __name__ == '__main__':
   unittest.main()
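For reference, a minimal sketch of the offload pattern these tests exercise, outside the unittest harness. The optimizer-state attributes (opt.m, opt.v, opt.b1_t, opt.b2_t) and the to_("CPU") call come from the diff above; the toy parameter, learning rate, and loop count are made up for illustration:

from tinygrad import Tensor, Device
from tinygrad.nn.optim import LAMB

w = Tensor.randn(16, 16, requires_grad=True)  # parameter stays on the default device
opt = LAMB([w], lr=1e-3)                      # illustrative learning rate

# offload the LAMB state (first/second moments and the beta^t counters) to CPU
for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")

Tensor.training = True
for _ in range(3):
  opt.zero_grad()
  (w * w).sum().backward()  # toy objective
  opt.step()                # runs with the param on Device.DEFAULT and state on CPU
Tensor.training = False

assert w.device == Device.DEFAULT
assert opt.m[0].device == "CPU"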