mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
raise RuntimeError for uneven shard [pr] (#8593)
no 7B llama on 6 GPUs; skip 70B
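As the title says, sharding a tensor along an axis whose size does not split evenly across the target devices is now expected to raise RuntimeError instead of producing an uneven shard, which is why the uneven-shard tests below are skipped. A minimal sketch of that behavior, assuming the usual Tensor.shard API; the device names are illustrative, while the (4, 1, 257) shape echoes the skipped tests:

from tinygrad import Tensor, Device

devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(3))

x = Tensor.rand(4, 1, 257)            # 257 % 3 != 0, so axis 2 cannot shard evenly
try:
  x.shard(devices, axis=2)            # expected to raise after this change
except RuntimeError as e:
  print("uneven shard rejected:", e)

Tensor.rand(4, 1, 258).shard(devices, axis=2)   # 258 % 3 == 0, still allowed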
@@ -157,6 +157,7 @@ class TestMultiTensor(unittest.TestCase):
         strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
         strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
  def test_simple_reduce(self, N, devices, rop, shard_axis, reduce_axis, sign):
    N = N * len(devices)
    X = Tensor.rand(N*N).reshape(N, N).mul(sign)
    n = X.numpy()
    X.shard_(devices, shard_axis)

@@ -438,6 +439,7 @@ class TestMultiTensor(unittest.TestCase):
    assert isinstance(jf.jit_cache[4].prg, BufferCopy)
    assert isinstance(jf.jit_cache[5].prg, graph_d1)

  @unittest.skip("no longer supports uneven shard")
  def test_uneven_shard(self):
    for N in range(1, 6):
      X = Tensor.rand(4, 1, 257).contiguous().realize()

@@ -450,6 +452,7 @@ class TestMultiTensor(unittest.TestCase):
      np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1)))
      np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1)))

  @unittest.skip("no longer supports uneven shard")
  def test_uneven_multiple_zeros(self):
    for data in ([1, 2, 3, 4], [1, 2, 3], [1, 2], [1], []):
      for N in (1, 2, 3, 4):

@@ -458,6 +461,7 @@ class TestMultiTensor(unittest.TestCase):
        X = ((Tensor(data).shard(devices, axis=0) + 1).realize() - 1).realize()
        np.testing.assert_equal(X.numpy(), data)

  @unittest.skip("no longer supports uneven shard")
  def test_uneven_shard_with_empty(self):
    N = 4
    X = Tensor.rand(16, 1, 3).contiguous().realize()

@@ -470,6 +474,7 @@ class TestMultiTensor(unittest.TestCase):
    # test reshape with empty shard
    np.testing.assert_equal(X.shard(devices, 0).reshape(8, 1, 6).numpy(), np_x.reshape(8, 1, 6))

  @unittest.skip("no longer supports uneven shard")
  def test_multiple_uneven_shard(self):
    N = 4
    X = Tensor.rand(4, 1, 257).contiguous().realize()

@@ -531,7 +536,7 @@ class TestMultiTensor(unittest.TestCase):
    with self.assertRaises((AssertionError, ValueError)):
      t0.reshape((26*15,7))

  @unittest.skip("no longer supports splits")
  @unittest.skip("no longer supports uneven shard")
  def test_reshape_on_axis_uneven(self):
    def reshape_helper(t0, t, t_axis):
      np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())

@@ -603,6 +608,7 @@ class TestMultiTensor(unittest.TestCase):
    self.assertEqual(t.dtype, t2.dtype)
    self.assertEqual(t.lazydata.axis, t2.lazydata.axis)

  @unittest.skip("no longer supports uneven shard")
  def test_rand_like_uneven_shard(self):
    t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1)
    t2 = Tensor.rand_like(t)

@@ -653,6 +659,7 @@ class TestMultiTensor(unittest.TestCase):
    assert set(unique) == {0, 2}, unique
    assert 200 < counts[0] < 312, counts[0]

  @unittest.skip("no longer supports uneven shard")
  def test_dropout_on_uneven_shard_axis(self):
    with Tensor.train():
      X = Tensor.ones(256).shard(devices_3, axis=0)

@@ -814,6 +821,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
    np.testing.assert_allclose(a.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), b.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), rtol=1e-7, atol=1e-3)
    np.testing.assert_allclose(a.flip(-1).numpy(), b.flip(-1).numpy(), rtol=1e-7, atol=1e-3)

  @unittest.skip("no longer supports uneven shard")
  def test_uneven(self):
    t = Tensor.arange(24).reshape(3, 8).contiguous().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(2)], axis=0)
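
The skipped tests exercised shard-axis sizes like 257 that do not divide by the device count. A possible caller-side workaround, sketched here as an assumption rather than anything prescribed by this PR (and assuming Tensor.pad accepts per-dimension (before, after) pairs), is to pad the shard axis up to a multiple of the device count and account for the padding downstream:

from tinygrad import Tensor, Device

devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(3))

x = Tensor.rand(4, 1, 257)                          # 257 is not divisible by 3
pad_amt = (-x.shape[2]) % len(devices)              # pad 257 up to 258
x_padded = x.pad(((0, 0), (0, 0), (0, pad_amt))).contiguous()
x_sharded = x_padded.shard(devices, axis=2)         # even shard: 86 elements per device
# The padded tail must be masked or dropped before results are consumed.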