remove custom splits in Tensor.shard [pr] (#8602)

towards even split only
chenyu
2025-01-13 21:29:13 -05:00
committed by GitHub
parent 227d96d7a3
commit d443e91d82
2 changed files with 16 additions and 21 deletions
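
In caller terms, the change drops the optional splits parameter from Tensor.shard and Tensor.shard_: only the device tuple and the axis are accepted, and the axis is always divided as evenly as possible. A rough before/after sketch using a call from the tests below (X and devices as defined there):

# before this commit: callers could pick arbitrary per-device sizes
X.shard(devices, 0, (2, 2, 12, 0))
# after: splits are computed internally with an even ceildiv split
X.shard(devices, 0)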

View File

@@ -460,18 +460,15 @@ class TestMultiTensor(unittest.TestCase):
   def test_uneven_shard_with_empty(self):
     N = 4
-    X = Tensor.rand(16, 1, 17).contiguous().realize()
+    X = Tensor.rand(16, 1, 3).contiguous().realize()
     np_x = X.numpy()
     devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
     # test empty shard
-    np.testing.assert_equal(X.shard(devices, 0, (2, 2, 12, 0)).numpy(), np_x)
+    np.testing.assert_equal(X.shard(devices, 0).numpy(), np_x)
     # test reshape with empty shard
-    np.testing.assert_equal(X.shard(devices, 0, (2, 2, 12, 0)).reshape(8, 1, 34).numpy(), np_x.reshape(8, 1, 34))
-    # test elementwise with empty shard
-    np.testing.assert_equal((X.shard(devices, 0, (2, 2, 12, 0)) + X.shard(devices, 0, (0, 0, 1, 15))).numpy(), np_x + np_x)
+    np.testing.assert_equal(X.shard(devices, 0).reshape(8, 1, 6).numpy(), np_x.reshape(8, 1, 6))
   def test_multiple_uneven_shard(self):
     N = 4
@@ -479,8 +476,8 @@ class TestMultiTensor(unittest.TestCase):
     Y = Tensor.rand(4, 1, 257).contiguous().realize()
     np_x, np_y = X.numpy(), Y.numpy()
     devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
-    X.shard_(devices, 2, (2, 38, 47, 170))
-    Y.shard_(devices, 2, (34, 53, 51, 119))
+    X.shard_(devices, 2)
+    Y.shard_(devices, 2)
     np.testing.assert_equal(X.numpy(), np_x)
     np.testing.assert_equal(Y.numpy(), np_y)
     np.testing.assert_equal((X + Y).numpy(), np_x + np_y)
@@ -534,6 +531,7 @@ class TestMultiTensor(unittest.TestCase):
     with self.assertRaises((AssertionError, ValueError)):
       t0.reshape((26*15,7))
+  @unittest.skip("no longer supports splits")
   def test_reshape_on_axis_uneven(self):
     def reshape_helper(t0, t, t_axis):
       np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())
@@ -606,7 +604,7 @@ class TestMultiTensor(unittest.TestCase):
     self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
   def test_rand_like_uneven_shard(self):
-    t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1, splits=(14, 7, 21))
+    t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1)
     t2 = Tensor.rand_like(t)
     self.assertEqual(t.shape, t2.shape)
     self.assertEqual(t.device, t2.device)
@@ -657,7 +655,7 @@ class TestMultiTensor(unittest.TestCase):
   def test_dropout_on_uneven_shard_axis(self):
     with Tensor.train():
-      X = Tensor.ones(256).shard(devices_3, axis=0, splits=(100, 50, 106))
+      X = Tensor.ones(256).shard(devices_3, axis=0)
       output = X.dropout(0.5).numpy()
       unique, counts = np.unique(output, return_counts=True)
       assert set(unique) == {0, 2}, unique
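
As a reference for the expected shard sizes in the tests above, here is a minimal standalone sketch of the even split rule (plain Python; even_splits and ndev are illustrative names, and the integer ceildiv stands in for tinygrad's helper). It mirrors the formula in the tensor.py hunk below:

def even_splits(total: int, ndev: int) -> tuple[int, ...]:
  # largest chunk is ceildiv(total, ndev); trailing devices get the remainder, possibly 0
  sz = (total + ndev - 1) // ndev
  return tuple(max(0, min(sz, total - sz*i)) for i in range(ndev))

print(even_splits(257, 4))  # (65, 65, 65, 62): the size-257 axis in test_multiple_uneven_shard
print(even_splits(3, 4))    # (1, 1, 1, 0): with more devices than elements, a shard can still be empty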

View File

@@ -394,33 +394,30 @@ class Tensor(SimpleMathTrait):
     if self.grad is not None and real.grad is not None: self.grad.replace(real.grad)
     return self.replace(real)
-  def shard(self, devices:tuple[str, ...], axis:Optional[int]=None, splits:Optional[tuple[int, ...]]=None) -> Tensor:
+  def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> Tensor:
     """
-    Shards the tensor across the given devices. Optionally specify which axis to shard on, and how to split it across devices.
+    Shards the tensor across the given devices. Optionally specify which axis to shard on.
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor.empty(2, 3)
-    print(t.shard((t.device, t.device), axis=1, splits=(2, 1)).lazydata)
+    print(t.shard((t.device, t.device), axis=1).lazydata)
     ```
     """
     assert isinstance(self.lazydata, UOp), "can't shard a MultiLazyBuffer"
     devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
     if axis is not None:
       axis = self._resolve_dim(axis)
-      if splits is None:
-        if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
-        sz = ceildiv(total, len(devices))
-        splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))])
-      assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
+      if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
+      sz = ceildiv(total, len(devices))
+      splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))])
       bounds = tuple(itertools.pairwise(itertools.accumulate(splits, initial=0)))
     return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis, bounds), device=devices, requires_grad=self.requires_grad)
-  def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None, splits:Optional[tuple[int, ...]]=None):
+  def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None):
     """
     Shards the tensor across the given devices in place.
     """
-    return self.replace(self.shard(devices, axis, splits))
+    return self.replace(self.shard(devices, axis))
   @staticmethod
   def from_uop(y:UOp, **kwargs) -> Tensor:
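
For completeness, a small sketch of how the per-device bounds fall out of those even splits (bounds_from_splits is an illustrative name; the pairwise/accumulate construction is the unchanged line kept in shard()):

import itertools

def bounds_from_splits(splits: tuple[int, ...]) -> tuple[tuple[int, int], ...]:
  # running sum of split sizes, then consecutive pairs -> one (start, end) per device
  return tuple(itertools.pairwise(itertools.accumulate(splits, initial=0)))

# (5, 5, 5, 2) is the even split of a size-17 axis across 4 devices
print(bounds_from_splits((5, 5, 5, 2)))  # ((0, 5), (5, 10), (10, 15), (15, 17))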