From fda73c818068d2bb52afad1e036857f8485f4352 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Tue, 16 Dec 2025 19:56:30 -0500
Subject: [PATCH] support LAMB param offload (#13730)

also added Tensor.shard_like
---
 test/test_multitensor.py | 22 ++++++++++++++++++++++
 test/test_optim.py       | 29 ++++++++++++++++++++++++++++-
 tinygrad/nn/optim.py     |  3 ++-
 tinygrad/tensor.py       |  9 ++++++++-
 tinygrad/uop/ops.py      |  2 +-
 5 files changed, 61 insertions(+), 4 deletions(-)

diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index 3ad7b36ca9..bbb0bc6952 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -63,6 +63,28 @@ class TestMultiTensor(unittest.TestCase):
     assert GlobalCounters.kernel_count == 0
     (X + X).realize()
 
+  def test_shard_like(self):
+    X = Tensor.ones(256).shard(devices_2, 0)
+    Y = Tensor.zeros(256).shard_like(X)
+    self.assertEqual(Y.device, X.device)
+    self.assertEqual(Y.uop.axis, 0)
+    # also test with axis=None
+    X2 = Tensor.ones(256).shard(devices_2, axis=None)
+    Y2 = Tensor.zeros(256).shard_like(X2)
+    self.assertEqual(Y2.device, X2.device)
+    self.assertEqual(Y2.uop.axis, None)
+    # test with single device
+    X3 = Tensor.ones(256)
+    Y3 = Tensor.zeros(256).shard_like(X3)
+    self.assertEqual(Y3.device, X3.device)
+    # cannot shard_like multi unless it's a no-op
+    X4 = Tensor.ones(256).shard(devices_2, 0)
+    Y4 = Tensor.ones(256).shard(devices_2, 0).shard_like(X4)
+    self.assertEqual(Y4.device, X4.device)
+    self.assertEqual(Y4.uop.axis, 0)
+    with self.assertRaises(RuntimeError):
+      Tensor.ones(256).shard(devices_2, None).shard_like(X4)
+
   def _test_shard_op(self, op, out, n=4):
     t = Tensor.ones(n).contiguous().realize().shard(devices_2, 0)
     r = op(t).realize()
diff --git a/test/test_optim.py b/test/test_optim.py
index 40bbd01636..5ed5925710 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -2,9 +2,10 @@ import numpy as np
 import torch
 import unittest
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.nn.optim import Adam, SGD, AdamW, Muon
+from tinygrad.nn.optim import Adam, SGD, AdamW, Muon, LAMB
 from tinygrad.helpers import CI
 from tinygrad.device import is_dtype_supported
+from test.helpers import needs_second_gpu
 
 np.random.seed(1337)
 x_init = np.random.randn(1,4).astype(np.float32)
@@ -163,5 +164,31 @@ class TestOptim(unittest.TestCase):
       optimizer.step()
     Tensor.training = old_state
 
+  def test_lamb_cpu_offload(self):
+    # test that LAMB works when optimizer params (m, v, b1_t, b2_t) are moved to CPU
+    t = Tensor(x_init.copy(), requires_grad=True)
+    opt = LAMB([t])
+    # move optimizer state to CPU
+    for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")
+    # run a step
+    t.sum().backward()
+    opt.step()
+    self.assertEqual(t.device, Device.DEFAULT)
+    self.assertEqual(opt.m[0].device, "CPU")
+
+  @needs_second_gpu
+  def test_lamb_cpu_offload_multi(self):
+    ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
+    t = Tensor(x_init.copy(), requires_grad=True).shard(ds, axis=1)
+    ds = t.device
+    opt = LAMB([t])
+    # move optimizer state to CPU
+    for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")
+    # run a step
+    t.sum().backward()
+    opt.step()
+    self.assertEqual(t.device, ds)
+    self.assertEqual(opt.m[0].device, "CPU")
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py
index bcca52f0e0..60086772b9 100644
--- a/tinygrad/nn/optim.py
+++ b/tinygrad/nn/optim.py
@@ -163,11 +163,12 @@ class LAMB(Optimizer):
     self.b1_t *= self.b1
     self.b2_t *= self.b2
     for i, (t, g) in enumerate(zip(params, grads)):
+      if g.device != self.m[i].device: g = g.contiguous().to(self.m[i].device)
       self.m[i].assign((self.b1 * self.m[i] + (1.0 - self.b1) * g).cast(self.m[i].dtype))
       self.v[i].assign((self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).cast(self.v[i].dtype))
       m_hat = self.m[i] / (1.0 - self.b1_t)
       v_hat = self.v[i] / (1.0 - self.b2_t)
-      up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
+      up = (m_hat / (v_hat.sqrt() + self.eps)).shard_like(t) + self.wd * t.detach()
       if not self.adam:
         r1 = t.detach().square().sum().sqrt()
         r2 = up.square().sum().sqrt()
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index d6b68b683b..369675f4f7 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -396,7 +396,7 @@ class Tensor(OpMixin):
     print(t.shard((t.device, t.device), axis=1).uop)
     ```
     """
-    assert isinstance(self.device, str), "can't shard a MultiLazyBuffer"
+    if not isinstance(self.device, str): raise RuntimeError("can't shard a MultiLazyBuffer")
     if len(devices) == 1: return self.to(devices[0])
     devices = tuple(canonicalize_device(x) for x in devices)
     mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
@@ -408,6 +408,13 @@ class Tensor(OpMixin):
     """
     return self.replace(self.shard(devices, axis))
 
+  def shard_like(self, y:Tensor) -> Tensor:
+    """
+    Shards the tensor the same way as `y` (same devices and axis).
+    """
+    if isinstance(y.device, str): return self.to(y.device)
+    return self if isinstance(self.device, tuple) and (y.device, y.uop.axis) == (self.device, self.uop.axis) else self.shard(y.device, y.uop.axis)
+
   CHUNK_SIZE = 2**20
   def fs_load(self, size:int) -> Tensor:
     """
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 3829483c1c..7ff428ce1b 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -616,7 +616,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     if self.op is Ops.BUFFERIZE: return self.arg.device
     if self.op is Ops.AFTER: return self.src[0]._device
     if self.op is Ops.MSELECT:
-      assert isinstance(self.src[0].device, tuple), "mselect must be on tuple device"
+      assert isinstance(self.src[0].device, tuple), f"mselect must be on tuple device, getting {self.src[0].device}"
       return self.src[0].device[self.arg]
     if self.op is Ops.MSTACK: return tuple(cast(str, x.device) for x in self.src)
     if self.op in {Ops.COPY, Ops.BUFFER, Ops.ALLREDUCE}: return self.src[1].device
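
Note (not part of the applied diff): a minimal usage sketch of the two features this patch
adds, based on the tests above. It assumes a working tinygrad install, that two devices of
Device.DEFAULT are available for the sharding part, and that the optimizer step runs under
Tensor.train(); the tensor shapes and the "CPU" offload target are illustrative only.

    from tinygrad import Tensor, Device
    from tinygrad.nn.optim import LAMB

    # LAMB param offload: params stay on the default device while the optimizer state
    # (m, v, b1_t, b2_t) is moved to CPU; per this patch, step() moves the grad to the
    # state's device and shard_like's the update back onto the param's placement.
    t = Tensor.randn(1, 4, requires_grad=True)
    opt = LAMB([t])
    for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")
    with Tensor.train():
      t.sum().backward()
      opt.step()

    # Tensor.shard_like: give a tensor the same devices and axis as another tensor.
    devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))  # assumes 2 devices exist
    X = Tensor.ones(256).shard(devices, axis=0)
    Y = Tensor.zeros(256).shard_like(X)
    assert Y.device == X.device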