raise RuntimeError for uneven shard [pr] (#8593)

no LLaMA-3 8B benchmark on 6 GPUs

skip the LLaMA-2 70B benchmark (also sharded across 6 GPUs)
chenyu committed 2025-01-14 14:51:48 -05:00 (committed by GitHub)
parent d5a646d492
commit 393eec3201
5 changed files with 26 additions and 17 deletions
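As a quick orientation before the diffs: sharding a tensor along an axis that the device count does not divide evenly used to be supported and now raises. A minimal sketch of the new behavior, assuming two copies of the default device (device names and shapes here are illustrative, not taken from the diff):

```python
from tinygrad import Tensor, Device

devices = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")

# axis 1 has size 4: splits evenly as 2 + 2 across the two devices
Tensor.empty(2, 4).shard(devices, axis=1)

# axis 1 has size 3: it cannot be split evenly across two devices, so with this
# commit the shard is rejected instead of falling back to an uneven split
try:
  Tensor.empty(2, 3).shard(devices, axis=1).realize()
except RuntimeError as e:
  print("uneven shard rejected:", e)
```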

.github/workflows/benchmark.yml

@@ -204,10 +204,10 @@ jobs:
run: NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
- name: Run LLaMA-3 8B on 4 GPUs
run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
- - name: Run LLaMA-3 8B on 6 GPUs
- run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
- - name: Run LLaMA-2 70B
- run: NV=1 CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
+ # - name: Run LLaMA-3 8B on 6 GPUs
+ # run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
+ # - name: Run LLaMA-2 70B
+ # run: NV=1 CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
- name: Run Mixtral 8x7B
run: time NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing | tee mixtral.txt
- name: Run GPT2
@@ -391,12 +391,12 @@ jobs:
run: AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
- name: Run LLaMA-3 8B on 4 GPUs
run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
- - name: Run LLaMA-3 8B on 6 GPUs
- run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
+ # - name: Run LLaMA-3 8B on 6 GPUs
+ # run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
- name: Restore amdgpu
run: sudo modprobe amdgpu
- - name: Run LLaMA-2 70B
- run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
+ # - name: Run LLaMA-2 70B
+ # run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
- name: Run Mixtral 8x7B
run: time AMD=1 python3 examples/mixtral.py --temperature 0 --count 10 --timing | tee mixtral.txt
- name: Run GPT2
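Both the NV and AMD 6-GPU LLaMA jobs (8B with --shard 6, and 70B which only ran with --shard 6) are commented out, presumably because their sharded dimensions are not multiples of 6 and would now trip the new RuntimeError. A rough divisibility check; the head counts below are the commonly published model configs, an assumption rather than something stated in this diff:

```python
# attention head counts (n_heads, n_kv_heads) as a proxy for the sharded dimensions
configs = {"LLaMA-3 8B": (32, 8), "LLaMA-2 70B": (64, 8)}

for name, heads in configs.items():
  for gpus in (4, 6):
    even = all(h % gpus == 0 for h in heads)
    print(f"{name} sharded over {gpus} GPUs: {'even' if even else 'uneven'}")
# 4 GPUs divides both configs evenly; 6 GPUs divides neither, hence the disabled jobs
```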

test/test_multitensor.py

@@ -157,6 +157,7 @@ class TestMultiTensor(unittest.TestCase):
strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
def test_simple_reduce(self, N, devices, rop, shard_axis, reduce_axis, sign):
+ N = N * len(devices)
X = Tensor.rand(N*N).reshape(N, N).mul(sign)
n = X.numpy()
X.shard_(devices, shard_axis)
@@ -438,6 +439,7 @@ class TestMultiTensor(unittest.TestCase):
assert isinstance(jf.jit_cache[4].prg, BufferCopy)
assert isinstance(jf.jit_cache[5].prg, graph_d1)
+ @unittest.skip("no longer supports uneven shard")
def test_uneven_shard(self):
for N in range(1, 6):
X = Tensor.rand(4, 1, 257).contiguous().realize()
@@ -450,6 +452,7 @@ class TestMultiTensor(unittest.TestCase):
np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1)))
np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1)))
+ @unittest.skip("no longer supports uneven shard")
def test_uneven_multiple_zeros(self):
for data in ([1, 2, 3, 4], [1, 2, 3], [1, 2], [1], []):
for N in (1, 2, 3, 4):
@@ -458,6 +461,7 @@ class TestMultiTensor(unittest.TestCase):
X = ((Tensor(data).shard(devices, axis=0) + 1).realize() - 1).realize()
np.testing.assert_equal(X.numpy(), data)
+ @unittest.skip("no longer supports uneven shard")
def test_uneven_shard_with_empty(self):
N = 4
X = Tensor.rand(16, 1, 3).contiguous().realize()
@@ -470,6 +474,7 @@ class TestMultiTensor(unittest.TestCase):
# test reshape with empty shard
np.testing.assert_equal(X.shard(devices, 0).reshape(8, 1, 6).numpy(), np_x.reshape(8, 1, 6))
+ @unittest.skip("no longer supports uneven shard")
def test_multiple_uneven_shard(self):
N = 4
X = Tensor.rand(4, 1, 257).contiguous().realize()
@@ -531,7 +536,7 @@ class TestMultiTensor(unittest.TestCase):
with self.assertRaises((AssertionError, ValueError)):
t0.reshape((26*15,7))
@unittest.skip("no longer supports splits")
@unittest.skip("no longer supports uneven shard")
def test_reshape_on_axis_uneven(self):
def reshape_helper(t0, t, t_axis):
np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())
@@ -603,6 +608,7 @@ class TestMultiTensor(unittest.TestCase):
self.assertEqual(t.dtype, t2.dtype)
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
+ @unittest.skip("no longer supports uneven shard")
def test_rand_like_uneven_shard(self):
t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1)
t2 = Tensor.rand_like(t)
@@ -653,6 +659,7 @@ class TestMultiTensor(unittest.TestCase):
assert set(unique) == {0, 2}, unique
assert 200 < counts[0] < 312, counts[0]
+ @unittest.skip("no longer supports uneven shard")
def test_dropout_on_uneven_shard_axis(self):
with Tensor.train():
X = Tensor.ones(256).shard(devices_3, axis=0)
@@ -814,6 +821,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
np.testing.assert_allclose(a.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), b.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.flip(-1).numpy(), b.flip(-1).numpy(), rtol=1e-7, atol=1e-3)
+ @unittest.skip("no longer supports uneven shard")
def test_uneven(self):
t = Tensor.arange(24).reshape(3, 8).contiguous().realize()
t.shard_([f"{Device.DEFAULT}:{i}" for i in range(2)], axis=0)
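The uneven-shard tests above are skipped rather than rewritten. If a regression test for the new behavior were wanted, a hypothetical companion test (not part of this commit) could assert the RuntimeError directly, reusing the prime-sized axis from the skipped test_uneven_shard:

```python
import unittest
from tinygrad import Tensor, Device

class TestUnevenShardRaises(unittest.TestCase):
  def test_uneven_shard_raises(self):
    devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
    # 257 is prime, so sharding this axis over more than one device is always uneven
    X = Tensor.rand(4, 1, 257).contiguous().realize()
    with self.assertRaises(RuntimeError):
      X.shard(devices, axis=2).realize()

if __name__ == "__main__":
  unittest.main()
```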

test/test_nn.py

@@ -551,7 +551,7 @@ class TestNN(unittest.TestCase):
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
def test_load_state_dict_sharded_model(self):
- devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
+ devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3")
layer = Conv2d(3, 5, kernel_size=3)
layer.weight.shard_(devices, 3)
@@ -572,7 +572,7 @@ class TestNN(unittest.TestCase):
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
def test_load_state_dict_sharded_dict(self):
- devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
+ devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3")
layer = Conv2d(3, 5, kernel_size=3)
state_dict = {
@@ -589,7 +589,7 @@ class TestNN(unittest.TestCase):
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
def test_load_state_dict_sharded_model_dict_same_axis(self):
- devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
+ devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3")
layer = Conv2d(3, 5, kernel_size=3)
layer.weight.shard_(devices, 3)
@@ -610,7 +610,8 @@ class TestNN(unittest.TestCase):
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
def test_load_state_dict_sharded_model_dict_different_axis(self):
- devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
+ devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3")
+ devices5 = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3", f"{Device.DEFAULT}:4", f"{Device.DEFAULT}:5")
layer = Conv2d(3, 5, kernel_size=3)
layer.weight.shard_(devices, 3)
@@ -619,14 +620,14 @@ class TestNN(unittest.TestCase):
# different shard axis
state_dict = {
'weight': Tensor.randn(5, 3, 3, 3).shard(devices, None),
- 'bias': Tensor.randn(5).shard(devices, 0),
+ 'bias': Tensor.randn(5).shard(devices5, 0),
}
load_state_dict(layer, state_dict)
# NOTE: model and state_dict shard differently, use the state_dict sharding # TODO: revisit this?
self.assertEqual(layer.weight.device, devices)
self.assertEqual(layer.weight.lazydata.axis, None)
- self.assertEqual(layer.bias.device, devices)
+ self.assertEqual(layer.bias.device, devices5)
self.assertEqual(layer.bias.lazydata.axis, 0)
np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())
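The device tuples in these tests grow from two to three entries because the sharded axes must now divide evenly: Conv2d(3, 5, kernel_size=3) has a weight of shape (5, 3, 3, 3) sharded on axis 3 (size 3) and a bias of shape (5,) sharded on axis 0, hence the separate five-device tuple for the bias. The arithmetic, using only the shapes from the test:

```python
# Conv2d(3, 5, kernel_size=3): weight (5, 3, 3, 3) sharded on axis 3, bias (5,) on axis 0
weight_shape, bias_shape = (5, 3, 3, 3), (5,)

print(weight_shape[3] % 2, weight_shape[3] % 3)  # 1 0 -> axis 3 needs 3 devices, not 2
print(bias_shape[0] % 3, bias_shape[0] % 5)      # 2 0 -> the bias needs the 5-device tuple
```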

tinygrad/multi.py

@@ -38,7 +38,7 @@ def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]:
return [functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads,lb_c)]).reshape(shape) for lb_c in chunked]
def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) -> list[UOp]:
- if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
+ if lbs[0].shape[axis] % len(lbs) != 0: raise RuntimeError(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]
class MultiLazyBuffer(MathTrait):
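The functional change is this hunk: to_sharded used to only log an uneven split at DEBUG >= 3 and now rejects it outright; the per-device bounds themselves are computed by the caller and passed in. A standalone sketch (illustrative, not the tinygrad implementation) of what even bounds look like and where the guard fires:

```python
def even_bounds(sz: int, n: int) -> tuple[tuple[int, int], ...]:
  # mirror the new guard: refuse to split an axis that n does not divide
  if sz % n != 0: raise RuntimeError(f"multi axis uneven: {sz=} {n=}")
  step = sz // n
  return tuple((i * step, (i + 1) * step) for i in range(n))

print(even_bounds(8, 4))   # ((0, 2), (2, 4), (4, 6), (6, 8))
try:
  even_bounds(257, 2)      # the shape the skipped uneven-shard tests relied on
except RuntimeError as e:
  print(e)                 # multi axis uneven: sz=257 n=2
```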

tinygrad/tensor.py

@@ -399,7 +399,7 @@ class Tensor(SimpleMathTrait):
Shards the tensor across the given devices. Optionally specify which axis to shard on.
```python exec="true" source="above" session="tensor" result="python"
- t = Tensor.empty(2, 3)
+ t = Tensor.empty(2, 4)
print(t.shard((t.device, t.device), axis=1).lazydata)
```
"""