mirror of https://github.com/tinygrad/tinygrad.git
remove custom splits in Tensor.shard [pr] (#8602)
towards even split only
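In short, the optional `splits` argument is dropped from `Tensor.shard` and `Tensor.shard_`; the split along the shard axis is now always computed evenly from the number of devices. A minimal before/after sketch of the call sites changed below — not part of the diff itself, and it assumes a backend that exposes multiple device instances like `{Device.DEFAULT}:0..3`:

```python
from tinygrad import Tensor, Device

N = 4
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
X = Tensor.rand(16, 1, 3).contiguous().realize()

# before this PR: callers could pass arbitrary per-device split sizes
#   X.shard(devices, 0, (2, 2, 12, 0))   # no longer accepted
# after this PR: only the axis is given, the split is computed evenly
sharded = X.shard(devices, 0)
print(sharded.lazydata)
```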
@@ -460,18 +460,15 @@ class TestMultiTensor(unittest.TestCase):
 
   def test_uneven_shard_with_empty(self):
     N = 4
-    X = Tensor.rand(16, 1, 17).contiguous().realize()
+    X = Tensor.rand(16, 1, 3).contiguous().realize()
     np_x = X.numpy()
     devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
 
     # test empty shard
-    np.testing.assert_equal(X.shard(devices, 0, (2, 2, 12, 0)).numpy(), np_x)
+    np.testing.assert_equal(X.shard(devices, 0).numpy(), np_x)
 
     # test reshape with empty shard
-    np.testing.assert_equal(X.shard(devices, 0, (2, 2, 12, 0)).reshape(8, 1, 34).numpy(), np_x.reshape(8, 1, 34))
-
-    # test elementwise with empty shard
-    np.testing.assert_equal((X.shard(devices, 0, (2, 2, 12, 0)) + X.shard(devices, 0, (0, 0, 1, 15))).numpy(), np_x + np_x)
+    np.testing.assert_equal(X.shard(devices, 0).reshape(8, 1, 6).numpy(), np_x.reshape(8, 1, 6))
 
   def test_multiple_uneven_shard(self):
     N = 4
@@ -479,8 +476,8 @@ class TestMultiTensor(unittest.TestCase):
     Y = Tensor.rand(4, 1, 257).contiguous().realize()
     np_x, np_y = X.numpy(), Y.numpy()
     devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
-    X.shard_(devices, 2, (2, 38, 47, 170))
-    Y.shard_(devices, 2, (34, 53, 51, 119))
+    X.shard_(devices, 2)
+    Y.shard_(devices, 2)
     np.testing.assert_equal(X.numpy(), np_x)
     np.testing.assert_equal(Y.numpy(), np_y)
     np.testing.assert_equal((X + Y).numpy(), np_x + np_y)
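Uneven shapes like `Y` above (257 elements on axis 2 across 4 devices) still shard fine; the per-device sizes are just no longer user-chosen. A rough standalone sketch of the even-split rule — the same formula appears in the tensor.py hunk further down; `ceildiv` is written out inline here so the snippet has no imports:

```python
ceildiv = lambda a, b: -(a // -b)   # ceiling division for positive b

total, ndev = 257, 4                # Y.shape[2] and len(devices) from the test above
sz = ceildiv(total, ndev)           # 65
splits = tuple(max(0, min(sz, total - sz*i)) for i in range(ndev))
print(splits)                       # (65, 65, 65, 62) -- uneven, but computed, not user-chosen
```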
@@ -534,6 +531,7 @@ class TestMultiTensor(unittest.TestCase):
     with self.assertRaises((AssertionError, ValueError)):
       t0.reshape((26*15,7))
 
+  @unittest.skip("no longer supports splits")
   def test_reshape_on_axis_uneven(self):
     def reshape_helper(t0, t, t_axis):
       np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())
@@ -606,7 +604,7 @@ class TestMultiTensor(unittest.TestCase):
     self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
 
   def test_rand_like_uneven_shard(self):
-    t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1, splits=(14, 7, 21))
+    t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1)
     t2 = Tensor.rand_like(t)
     self.assertEqual(t.shape, t2.shape)
     self.assertEqual(t.device, t2.device)
@@ -657,7 +655,7 @@ class TestMultiTensor(unittest.TestCase):
 
   def test_dropout_on_uneven_shard_axis(self):
     with Tensor.train():
-      X = Tensor.ones(256).shard(devices_3, axis=0, splits=(100, 50, 106))
+      X = Tensor.ones(256).shard(devices_3, axis=0)
       output = X.dropout(0.5).numpy()
       unique, counts = np.unique(output, return_counts=True)
       assert set(unique) == {0, 2}, unique
@@ -394,33 +394,30 @@ class Tensor(SimpleMathTrait):
     if self.grad is not None and real.grad is not None: self.grad.replace(real.grad)
     return self.replace(real)
 
-  def shard(self, devices:tuple[str, ...], axis:Optional[int]=None, splits:Optional[tuple[int, ...]]=None) -> Tensor:
+  def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> Tensor:
     """
-    Shards the tensor across the given devices. Optionally specify which axis to shard on, and how to split it across devices.
+    Shards the tensor across the given devices. Optionally specify which axis to shard on.
 
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor.empty(2, 3)
-    print(t.shard((t.device, t.device), axis=1, splits=(2, 1)).lazydata)
+    print(t.shard((t.device, t.device), axis=1).lazydata)
     ```
-
     """
     assert isinstance(self.lazydata, UOp), "can't shard a MultiLazyBuffer"
     devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
     if axis is not None:
       axis = self._resolve_dim(axis)
-      if splits is None:
-        if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
-        sz = ceildiv(total, len(devices))
-        splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))])
-      assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
+      if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
+      sz = ceildiv(total, len(devices))
+      splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))])
       bounds = tuple(itertools.pairwise(itertools.accumulate(splits, initial=0)))
     return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis, bounds), device=devices, requires_grad=self.requires_grad)
 
-  def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None, splits:Optional[tuple[int, ...]]=None):
+  def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None):
     """
     Shards the tensor across the given devices in place.
     """
-    return self.replace(self.shard(devices, axis, splits))
+    return self.replace(self.shard(devices, axis))
 
   @staticmethod
   def from_uop(y:UOp, **kwargs) -> Tensor:
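For reference, a self-contained sketch of what the split and bounds computation above produces now that custom splits are gone. The numbers mirror the dropout test earlier (256 elements on the shard axis, 3 devices); this is an illustration, not the library code, and it assumes Python 3.10+ for `itertools.pairwise`:

```python
import itertools

def even_split_bounds(total: int, ndev: int):
  # same rule as Tensor.shard above: ceildiv-sized chunks, trailing chunks may be smaller or empty
  sz = -(total // -ndev)  # ceildiv(total, ndev)
  splits = tuple(max(0, min(sz, total - sz*i)) for i in range(ndev))
  bounds = tuple(itertools.pairwise(itertools.accumulate(splits, initial=0)))
  return splits, bounds

print(even_split_bounds(256, 3))  # ((86, 86, 84), ((0, 86), (86, 172), (172, 256)))
```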