@@ -63,6 +63,28 @@ class TestMultiTensor(unittest.TestCase):
     assert GlobalCounters.kernel_count == 0
     (X + X).realize()
 
+  def test_shard_like(self):
+    X = Tensor.ones(256).shard(devices_2, 0)
+    Y = Tensor.zeros(256).shard_like(X)
+    self.assertEqual(Y.device, X.device)
+    self.assertEqual(Y.uop.axis, 0)
+    # also test with axis=None
+    X2 = Tensor.ones(256).shard(devices_2, axis=None)
+    Y2 = Tensor.zeros(256).shard_like(X2)
+    self.assertEqual(Y2.device, X2.device)
+    self.assertEqual(Y2.uop.axis, None)
+    # test with single device
+    X3 = Tensor.ones(256)
+    Y3 = Tensor.zeros(256).shard_like(X3)
+    self.assertEqual(Y3.device, X3.device)
+    # cannot shard_like multi unless it's a no-op
+    X4 = Tensor.ones(256).shard(devices_2, 0)
+    Y4 = Tensor.ones(256).shard(devices_2, 0).shard_like(X4)
+    self.assertEqual(Y4.device, X4.device)
+    self.assertEqual(Y4.uop.axis, 0)
+    with self.assertRaises(RuntimeError):
+      Tensor.ones(256).shard(devices_2, None).shard_like(X4)
+
   def _test_shard_op(self, op, out, n=4):
     t = Tensor.ones(n).contiguous().realize().shard(devices_2, 0)
     r = op(t).realize()
@@ -2,9 +2,10 @@ import numpy as np
 import torch
 import unittest
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.nn.optim import Adam, SGD, AdamW, Muon
+from tinygrad.nn.optim import Adam, SGD, AdamW, Muon, LAMB
 from tinygrad.helpers import CI
 from tinygrad.device import is_dtype_supported
+from test.helpers import needs_second_gpu
 
 np.random.seed(1337)
 x_init = np.random.randn(1,4).astype(np.float32)
@@ -163,5 +164,31 @@ class TestOptim(unittest.TestCase):
     optimizer.step()
     Tensor.training = old_state
 
+  def test_lamb_cpu_offload(self):
+    # test that LAMB works when optimizer params (m, v, b1_t, b2_t) are moved to CPU
+    t = Tensor(x_init.copy(), requires_grad=True)
+    opt = LAMB([t])
+    # move optimizer state to CPU
+    for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")
+    # run a step
+    t.sum().backward()
+    opt.step()
+    self.assertEqual(t.device, Device.DEFAULT)
+    self.assertEqual(opt.m[0].device, "CPU")
+
+  @needs_second_gpu
+  def test_lamb_cpu_offload_multi(self):
+    ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
+    t = Tensor(x_init.copy(), requires_grad=True).shard(ds, axis=1)
+    ds = t.device
+    opt = LAMB([t])
+    # move optimizer state to CPU
+    for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")
+    # run a step
+    t.sum().backward()
+    opt.step()
+    self.assertEqual(t.device, ds)
+    self.assertEqual(opt.m[0].device, "CPU")
+
 if __name__ == '__main__':
   unittest.main()
@@ -163,11 +163,12 @@ class LAMB(Optimizer):
     self.b1_t *= self.b1
     self.b2_t *= self.b2
     for i, (t, g) in enumerate(zip(params, grads)):
+      if g.device != self.m[i].device: g = g.contiguous().to(self.m[i].device)
       self.m[i].assign((self.b1 * self.m[i] + (1.0 - self.b1) * g).cast(self.m[i].dtype))
       self.v[i].assign((self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).cast(self.v[i].dtype))
       m_hat = self.m[i] / (1.0 - self.b1_t)
       v_hat = self.v[i] / (1.0 - self.b2_t)
-      up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
+      up = (m_hat / (v_hat.sqrt() + self.eps)).shard_like(t) + self.wd * t.detach()
       if not self.adam:
         r1 = t.detach().square().sum().sqrt()
         r2 = up.square().sum().sqrt()
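
Together with the tests above, the two added lines are what make CPU offload of the LAMB state work: the incoming gradient is first copied onto whatever device holds the moment buffers, and the finished update is re-placed to match the parameter via shard_like before it is applied. A minimal usage sketch, assuming a working tinygrad install (the shape and the use of Device.DEFAULT are illustrative, not part of this diff):

from tinygrad import Tensor, Device
from tinygrad.nn.optim import LAMB

w = Tensor.randn(1, 4, requires_grad=True)                   # parameter stays on Device.DEFAULT
opt = LAMB([w])
for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")  # offload optimizer state to CPU

Tensor.training = True
w.sum().backward()
opt.step()           # grad moves to the CPU-resident state, update lands back on w.device
Tensor.training = False
assert w.device == Device.DEFAULT and opt.m[0].device == "CPU"
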
@@ -396,7 +396,7 @@ class Tensor(OpMixin):
     print(t.shard((t.device, t.device), axis=1).uop)
     ```
     """
-    assert isinstance(self.device, str), "can't shard a MultiLazyBuffer"
+    if not isinstance(self.device, str): raise RuntimeError("can't shard a MultiLazyBuffer")
     if len(devices) == 1: return self.to(devices[0])
     devices = tuple(canonicalize_device(x) for x in devices)
     mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
@@ -408,6 +408,13 @@ class Tensor(OpMixin):
     """
     return self.replace(self.shard(devices, axis))
 
+  def shard_like(self, y:Tensor) -> Tensor:
+    """
+    Shards the tensor the same way as `y` (same devices and axis).
+    """
+    if isinstance(y.device, str): return self.to(y.device)
+    return self if isinstance(self.device, tuple) and (y.device, y.uop.axis) == (self.device, self.uop.axis) else self.shard(y.device, y.uop.axis)
+
   CHUNK_SIZE = 2**20
   def fs_load(self, size:int) -> Tensor:
     """
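
For reference, a small sketch of how shard_like behaves, mirroring the new test above (the two-device tuple built from Device.DEFAULT is illustrative):

from tinygrad import Tensor, Device

devices_2 = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))

X = Tensor.ones(256).shard(devices_2, 0)
Y = Tensor.zeros(256).shard_like(X)                  # re-sharded onto X's devices and axis
assert Y.device == X.device and Y.uop.axis == 0

Z = Tensor.zeros(256).shard_like(Tensor.ones(256))   # y on a single device: falls back to .to()
assert isinstance(Z.device, str)
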
@@ -616,7 +616,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     if self.op is Ops.BUFFERIZE: return self.arg.device
     if self.op is Ops.AFTER: return self.src[0]._device
     if self.op is Ops.MSELECT:
-      assert isinstance(self.src[0].device, tuple), "mselect must be on tuple device"
+      assert isinstance(self.src[0].device, tuple), f"mselect must be on tuple device, getting {self.src[0].device}"
       return self.src[0].device[self.arg]
     if self.op is Ops.MSTACK: return tuple(cast(str, x.device) for x in self.src)
     if self.op in {Ops.COPY, Ops.BUFFER, Ops.ALLREDUCE}: return self.src[1].device