support LAMB param offload (#13730)

also added Tensor.shard_like
chenyu authored 2025-12-16 19:56:30 -05:00, committed by GitHub
parent cf0c28d5ae, commit fda73c8180
5 changed files with 61 additions and 4 deletions
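
At a glance: LAMB can now keep its optimizer state (m, v, b1_t, b2_t) on a different device than the parameters it updates, and the new Tensor.shard_like is what moves the computed update back onto each parameter's device layout. A minimal usage sketch mirroring the new tests (this is not part of the diff itself; it assumes the CPU backend is available next to the default device):

from tinygrad import Tensor, Device
from tinygrad.nn.optim import LAMB

t = Tensor.randn(1, 4, requires_grad=True)    # parameter stays on the default device
opt = LAMB([t])
# offload the optimizer state to CPU
for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")

with Tensor.train():
  t.sum().backward()
  opt.step()                                  # gradients are moved to the state's device inside LAMB

assert t.device == Device.DEFAULT
assert opt.m[0].device == "CPU"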


@@ -63,6 +63,28 @@ class TestMultiTensor(unittest.TestCase):
     assert GlobalCounters.kernel_count == 0
     (X + X).realize()

+  def test_shard_like(self):
+    X = Tensor.ones(256).shard(devices_2, 0)
+    Y = Tensor.zeros(256).shard_like(X)
+    self.assertEqual(Y.device, X.device)
+    self.assertEqual(Y.uop.axis, 0)
+    # also test with axis=None
+    X2 = Tensor.ones(256).shard(devices_2, axis=None)
+    Y2 = Tensor.zeros(256).shard_like(X2)
+    self.assertEqual(Y2.device, X2.device)
+    self.assertEqual(Y2.uop.axis, None)
+    # test with single device
+    X3 = Tensor.ones(256)
+    Y3 = Tensor.zeros(256).shard_like(X3)
+    self.assertEqual(Y3.device, X3.device)
+    # cannot shard_like multi unless it's a no-op
+    X4 = Tensor.ones(256).shard(devices_2, 0)
+    Y4 = Tensor.ones(256).shard(devices_2, 0).shard_like(X4)
+    self.assertEqual(Y4.device, X4.device)
+    self.assertEqual(Y4.uop.axis, 0)
+    with self.assertRaises(RuntimeError):
+      Tensor.ones(256).shard(devices_2, None).shard_like(X4)
+
   def _test_shard_op(self, op, out, n=4):
     t = Tensor.ones(n).contiguous().realize().shard(devices_2, 0)
     r = op(t).realize()


@@ -2,9 +2,10 @@ import numpy as np
 import torch
 import unittest
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.nn.optim import Adam, SGD, AdamW, Muon
+from tinygrad.nn.optim import Adam, SGD, AdamW, Muon, LAMB
 from tinygrad.helpers import CI
 from tinygrad.device import is_dtype_supported
+from test.helpers import needs_second_gpu

 np.random.seed(1337)
 x_init = np.random.randn(1,4).astype(np.float32)
@@ -163,5 +164,31 @@
     optimizer.step()
     Tensor.training = old_state

+  def test_lamb_cpu_offload(self):
+    # test that LAMB works when optimizer params (m, v, b1_t, b2_t) are moved to CPU
+    t = Tensor(x_init.copy(), requires_grad=True)
+    opt = LAMB([t])
+    # move optimizer state to CPU
+    for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")
+    # run a step
+    t.sum().backward()
+    opt.step()
+    self.assertEqual(t.device, Device.DEFAULT)
+    self.assertEqual(opt.m[0].device, "CPU")
+
+  @needs_second_gpu
+  def test_lamb_cpu_offload_multi(self):
+    ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
+    t = Tensor(x_init.copy(), requires_grad=True).shard(ds, axis=1)
+    ds = t.device
+    opt = LAMB([t])
+    # move optimizer state to CPU
+    for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")
+    # run a step
+    t.sum().backward()
+    opt.step()
+    self.assertEqual(t.device, ds)
+    self.assertEqual(opt.m[0].device, "CPU")
+
 if __name__ == '__main__':
   unittest.main()


@@ -163,11 +163,12 @@ class LAMB(Optimizer):
     self.b1_t *= self.b1
     self.b2_t *= self.b2
     for i, (t, g) in enumerate(zip(params, grads)):
+      if g.device != self.m[i].device: g = g.contiguous().to(self.m[i].device)
       self.m[i].assign((self.b1 * self.m[i] + (1.0 - self.b1) * g).cast(self.m[i].dtype))
       self.v[i].assign((self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).cast(self.v[i].dtype))
       m_hat = self.m[i] / (1.0 - self.b1_t)
       v_hat = self.v[i] / (1.0 - self.b2_t)
-      up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
+      up = (m_hat / (v_hat.sqrt() + self.eps)).shard_like(t) + self.wd * t.detach()
       if not self.adam:
         r1 = t.detach().square().sum().sqrt()
         r2 = up.square().sum().sqrt()
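
The two changed lines above are the whole offload path: a gradient that lives on the parameter's device(s) is first copied to wherever the first-moment buffer lives, the moments and the raw update are computed there, and shard_like(t) then moves (or re-shards) the update back onto the parameter's device layout before weight decay and the trust-ratio scaling are applied. A standalone sketch of that data movement, using a plain SGD-style update instead of the full LAMB math (names and constants here are illustrative, not part of the diff; assumes a CPU backend next to the default device):

from tinygrad import Tensor

t = Tensor.ones(4).contiguous().realize()                 # parameter on the default device
m = Tensor.zeros(4, device="CPU").contiguous().realize()  # offloaded optimizer state
g = Tensor.full((4,), 0.5)                                # stand-in gradient, produced on t.device

if g.device != m.device: g = g.contiguous().to(m.device)  # same guard as the new line in LAMB
m.assign(0.9 * m + 0.1 * g).realize()                     # state update runs on CPU
up = m.shard_like(t)                                      # bring the update back to t's device layout
t.assign(t - 0.01 * up).realize()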


@@ -396,7 +396,7 @@ class Tensor(OpMixin):
     print(t.shard((t.device, t.device), axis=1).uop)
     ```
     """
-    assert isinstance(self.device, str), "can't shard a MultiLazyBuffer"
+    if not isinstance(self.device, str): raise RuntimeError("can't shard a MultiLazyBuffer")
     if len(devices) == 1: return self.to(devices[0])
     devices = tuple(canonicalize_device(x) for x in devices)
     mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
@@ -408,6 +408,13 @@
     """
     return self.replace(self.shard(devices, axis))

+  def shard_like(self, y:Tensor) -> Tensor:
+    """
+    Shards the tensor the same way as `y` (same devices and axis).
+    """
+    if isinstance(y.device, str): return self.to(y.device)
+    return self if isinstance(self.device, tuple) and (y.device, y.uop.axis) == (self.device, self.uop.axis) else self.shard(y.device, y.uop.axis)
+
   CHUNK_SIZE = 2**20
   def fs_load(self, size:int) -> Tensor:
     """


@@ -616,7 +616,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     if self.op is Ops.BUFFERIZE: return self.arg.device
     if self.op is Ops.AFTER: return self.src[0]._device
     if self.op is Ops.MSELECT:
-      assert isinstance(self.src[0].device, tuple), "mselect must be on tuple device"
+      assert isinstance(self.src[0].device, tuple), f"mselect must be on tuple device, getting {self.src[0].device}"
       return self.src[0].device[self.arg]
     if self.op is Ops.MSTACK: return tuple(cast(str, x.device) for x in self.src)
     if self.op in {Ops.COPY, Ops.BUFFER, Ops.ALLREDUCE}: return self.src[1].device