@@ -2,9 +2,10 @@ import numpy as np
 import torch
 import unittest
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.nn.optim import Adam, SGD, AdamW, Muon
+from tinygrad.nn.optim import Adam, SGD, AdamW, Muon, LAMB
 from tinygrad.helpers import CI
 from tinygrad.device import is_dtype_supported
 from test.helpers import needs_second_gpu
 
 np.random.seed(1337)
 x_init = np.random.randn(1,4).astype(np.float32)
@@ -163,5 +164,31 @@ class TestOptim(unittest.TestCase):
     optimizer.step()
     Tensor.training = old_state
 
+  def test_lamb_cpu_offload(self):
+    # test that LAMB works when its optimizer state tensors (m, v, b1_t, b2_t) are moved to CPU
+    t = Tensor(x_init.copy(), requires_grad=True)
+    opt = LAMB([t])
+    # move optimizer state to CPU
+    for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")
+    # run a step
+    t.sum().backward()
+    opt.step()
+    self.assertEqual(t.device, Device.DEFAULT)
+    self.assertEqual(opt.m[0].device, "CPU")
+
+  @needs_second_gpu
+  def test_lamb_cpu_offload_multi(self):
+    ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
+    t = Tensor(x_init.copy(), requires_grad=True).shard(ds, axis=1)
+    ds = t.device
+    opt = LAMB([t])
+    # move optimizer state to CPU
+    for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")
+    # run a step
+    t.sum().backward()
+    opt.step()
+    self.assertEqual(t.device, ds)
+    self.assertEqual(opt.m[0].device, "CPU")
+
 if __name__ == '__main__':
   unittest.main()
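
For context, a minimal sketch (not part of the diff) of the CPU-offload pattern the new tests exercise. It uses only names confirmed above: LAMB's state attributes m, v, b1_t and b2_t, the in-place to_() device move, and the Tensor.training flag the surrounding test class toggles; Tensor.randn standing in for a real parameter is this sketch's own choice.

from tinygrad import Tensor, Device
from tinygrad.nn.optim import LAMB

Tensor.training = True                      # tinygrad optimizers only step in training mode
w = Tensor.randn(1, 4, requires_grad=True)  # parameter lives on Device.DEFAULT
opt = LAMB([w])
# relocate LAMB's state (first/second moments, beta-power scalars) to CPU,
# leaving the parameter itself on the accelerator
for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")
w.sum().backward()
opt.step()                                  # update still lands on w's own device
assert w.device == Device.DEFAULT
assert opt.m[0].device == "CPU"

The property under test is exactly this split: to_("CPU") moves only the optimizer state, and opt.step() still materializes the parameter update on the parameter's device (or shard devices, in the multi test).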