@@ -63,6 +63,28 @@ class TestMultiTensor(unittest.TestCase):
     assert GlobalCounters.kernel_count == 0
     (X + X).realize()
 
+  def test_shard_like(self):
+    X = Tensor.ones(256).shard(devices_2, 0)
+    Y = Tensor.zeros(256).shard_like(X)
+    self.assertEqual(Y.device, X.device)
+    self.assertEqual(Y.uop.axis, 0)
+    # also test with axis=None
+    X2 = Tensor.ones(256).shard(devices_2, axis=None)
+    Y2 = Tensor.zeros(256).shard_like(X2)
+    self.assertEqual(Y2.device, X2.device)
+    self.assertEqual(Y2.uop.axis, None)
+    # test with single device
+    X3 = Tensor.ones(256)
+    Y3 = Tensor.zeros(256).shard_like(X3)
+    self.assertEqual(Y3.device, X3.device)
+    # cannot shard_like multi unless it's a no-op
+    X4 = Tensor.ones(256).shard(devices_2, 0)
+    Y4 = Tensor.ones(256).shard(devices_2, 0).shard_like(X4)
+    self.assertEqual(Y4.device, X4.device)
+    self.assertEqual(Y4.uop.axis, 0)
+    with self.assertRaises(RuntimeError):
+      Tensor.ones(256).shard(devices_2, None).shard_like(X4)
+
   def _test_shard_op(self, op, out, n=4):
     t = Tensor.ones(n).contiguous().realize().shard(devices_2, 0)
     r = op(t).realize()
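
The new test leans on `devices_2`, which this hunk does not define; it comes from the top of the test file. A minimal sketch of the assumed definition so the test body above reads standalone (the exact expression is an assumption, not part of this diff):

    # assumed definition, mirroring how the multi-tensor tests derive device tuples from Device.DEFAULT
    from tinygrad import Device
    devices_2 = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
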
@@ -2,9 +2,10 @@ import numpy as np
 import torch
 import unittest
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.nn.optim import Adam, SGD, AdamW, Muon
+from tinygrad.nn.optim import Adam, SGD, AdamW, Muon, LAMB
 from tinygrad.helpers import CI
 from tinygrad.device import is_dtype_supported
+from test.helpers import needs_second_gpu
 
 np.random.seed(1337)
 x_init = np.random.randn(1,4).astype(np.float32)
@@ -163,5 +164,31 @@ class TestOptim(unittest.TestCase):
     optimizer.step()
     Tensor.training = old_state
 
+  def test_lamb_cpu_offload(self):
+    # test that LAMB works when optimizer params (m, v, b1_t, b2_t) are moved to CPU
+    t = Tensor(x_init.copy(), requires_grad=True)
+    opt = LAMB([t])
+    # move optimizer state to CPU
+    for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")
+    # run a step
+    t.sum().backward()
+    opt.step()
+    self.assertEqual(t.device, Device.DEFAULT)
+    self.assertEqual(opt.m[0].device, "CPU")
+
+  @needs_second_gpu
+  def test_lamb_cpu_offload_multi(self):
+    ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
+    t = Tensor(x_init.copy(), requires_grad=True).shard(ds, axis=1)
+    ds = t.device
+    opt = LAMB([t])
+    # move optimizer state to CPU
+    for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")
+    # run a step
+    t.sum().backward()
+    opt.step()
+    self.assertEqual(t.device, ds)
+    self.assertEqual(opt.m[0].device, "CPU")
+
 if __name__ == '__main__':
   unittest.main()
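
Outside the unittest harness, the offload pattern these tests pin down looks roughly like the sketch below. It only uses calls that appear in the hunk (`LAMB`, `opt.m`, `opt.v`, `opt.b1_t`, `opt.b2_t`, `to_`); the explicit `Tensor.training = True` is an assumption about what a standalone script needs, since the test class manages training mode elsewhere:

    # sketch: keep LAMB's state on CPU while the parameter stays on the default device
    import numpy as np
    from tinygrad import Tensor, Device
    from tinygrad.nn.optim import LAMB

    Tensor.training = True  # optimizer steps run in training mode (assumption for a standalone script)
    t = Tensor(np.random.randn(1, 4).astype(np.float32), requires_grad=True)
    opt = LAMB([t])
    # offload first/second moments and the beta accumulators
    for p in opt.m + opt.v + [opt.b1_t, opt.b2_t]: p.to_("CPU")

    t.sum().backward()
    opt.step()
    assert t.device == Device.DEFAULT and opt.m[0].device == "CPU"
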
@@ -163,11 +163,12 @@ class LAMB(Optimizer):
     self.b1_t *= self.b1
     self.b2_t *= self.b2
     for i, (t, g) in enumerate(zip(params, grads)):
+      if g.device != self.m[i].device: g = g.contiguous().to(self.m[i].device)
       self.m[i].assign((self.b1 * self.m[i] + (1.0 - self.b1) * g).cast(self.m[i].dtype))
       self.v[i].assign((self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).cast(self.v[i].dtype))
       m_hat = self.m[i] / (1.0 - self.b1_t)
       v_hat = self.v[i] / (1.0 - self.b2_t)
-      up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
+      up = (m_hat / (v_hat.sqrt() + self.eps)).shard_like(t) + self.wd * t.detach()
       if not self.adam:
         r1 = t.detach().square().sum().sqrt()
         r2 = up.square().sum().sqrt()
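
Two things happen here for the offload case: with `m`/`v` on CPU and the parameter (and therefore its gradient) on an accelerator or a multi-device tuple, the gradient is first moved onto the state's device before the moment updates, and the computed update is then placed back onto the parameter's placement via `.shard_like(t)` before it is combined with `t`, which would otherwise mix devices in a single elementwise op. A single-device sketch of the same flow (the names `t`, `m`, `up` are illustrative, not the optimizer's own code path):

    from tinygrad import Tensor

    t = Tensor.ones(4)                          # parameter on the default device
    m = Tensor.zeros(4).to("CPU")               # offloaded optimizer state
    up = (m + 1.0).shard_like(t) + t.detach()   # shard_like puts the CPU-computed update back on t's placement
    up.realize()
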
@@ -396,7 +396,7 @@ class Tensor(OpMixin):
     print(t.shard((t.device, t.device), axis=1).uop)
     ```
     """
-    assert isinstance(self.device, str), "can't shard a MultiLazyBuffer"
+    if not isinstance(self.device, str): raise RuntimeError("can't shard a MultiLazyBuffer")
     if len(devices) == 1: return self.to(devices[0])
     devices = tuple(canonicalize_device(x) for x in devices)
     mlb = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices)
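
Raising `RuntimeError` instead of asserting gives callers a real exception to catch (and keeps the check alive under `python -O`), which is what the new `assertRaises(RuntimeError)` test relies on. A quick sketch of when it fires, assuming two devices derived from `Device.DEFAULT`:

    from tinygrad import Tensor, Device

    devices_2 = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))  # assumption
    x = Tensor.ones(256).shard(devices_2, 0)
    try:
      x.shard(devices_2, 0)      # x.device is already a tuple -> RuntimeError
    except RuntimeError as e:
      print(e)                   # can't shard a MultiLazyBuffer
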
@@ -408,6 +408,13 @@ class Tensor(OpMixin):
     """
     return self.replace(self.shard(devices, axis))
 
+  def shard_like(self, y:Tensor) -> Tensor:
+    """
+    Shards the tensor the same way as `y` (same devices and axis).
+    """
+    if isinstance(y.device, str): return self.to(y.device)
+    return self if isinstance(self.device, tuple) and (y.device, y.uop.axis) == (self.device, self.uop.axis) else self.shard(y.device, y.uop.axis)
+
   CHUNK_SIZE = 2**20
   def fs_load(self, size:int) -> Tensor:
     """
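
The three branches of `shard_like`, spelled out: a single-device target means a plain `.to()`, a matching multi-device placement is a no-op, and anything else is delegated to `.shard()` (which now raises `RuntimeError` if the source is already multi-device). A behavior sketch under an assumed two-device tuple:

    from tinygrad import Tensor, Device

    devices_2 = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))  # assumption
    a = Tensor.ones(256)                         # single device
    b = Tensor.ones(256).shard(devices_2, 0)     # sharded on axis 0

    Tensor.zeros(256).shard_like(a)   # target is single-device -> .to(a.device)
    Tensor.zeros(256).shard_like(b)   # target is sharded -> .shard(b.device, b.uop.axis)
    b.shard_like(b)                   # placement already matches -> returned as-is
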
@@ -616,7 +616,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     if self.op is Ops.BUFFERIZE: return self.arg.device
     if self.op is Ops.AFTER: return self.src[0]._device
     if self.op is Ops.MSELECT:
-      assert isinstance(self.src[0].device, tuple), "mselect must be on tuple device"
+      assert isinstance(self.src[0].device, tuple), f"mselect must be on tuple device, getting {self.src[0].device}"
       return self.src[0].device[self.arg]
     if self.op is Ops.MSTACK: return tuple(cast(str, x.device) for x in self.src)
     if self.op in {Ops.COPY, Ops.BUFFER, Ops.ALLREDUCE}: return self.src[1].device