mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
raise RuntimeError for uneven shard [pr] (#8593)
no 7B llama on 6 GPUs skip 70B
This commit is contained in:
16
.github/workflows/benchmark.yml
vendored
16
.github/workflows/benchmark.yml
vendored
@@ -204,10 +204,10 @@ jobs:
|
||||
run: NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
|
||||
- name: Run LLaMA-3 8B on 4 GPUs
|
||||
run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
|
||||
- name: Run LLaMA-3 8B on 6 GPUs
|
||||
run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
|
||||
- name: Run LLaMA-2 70B
|
||||
run: NV=1 CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
|
||||
# - name: Run LLaMA-3 8B on 6 GPUs
|
||||
# run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
|
||||
# - name: Run LLaMA-2 70B
|
||||
# run: NV=1 CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
|
||||
- name: Run Mixtral 8x7B
|
||||
run: time NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing | tee mixtral.txt
|
||||
- name: Run GPT2
|
||||
@@ -391,12 +391,12 @@ jobs:
|
||||
run: AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
|
||||
- name: Run LLaMA-3 8B on 4 GPUs
|
||||
run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
|
||||
- name: Run LLaMA-3 8B on 6 GPUs
|
||||
run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
|
||||
# - name: Run LLaMA-3 8B on 6 GPUs
|
||||
# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
|
||||
- name: Restore amdgpu
|
||||
run: sudo modprobe amdgpu
|
||||
- name: Run LLaMA-2 70B
|
||||
run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
|
||||
# - name: Run LLaMA-2 70B
|
||||
# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
|
||||
- name: Run Mixtral 8x7B
|
||||
run: time AMD=1 python3 examples/mixtral.py --temperature 0 --count 10 --timing | tee mixtral.txt
|
||||
- name: Run GPT2
|
||||
|
||||
@@ -157,6 +157,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
|
||||
strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
|
||||
def test_simple_reduce(self, N, devices, rop, shard_axis, reduce_axis, sign):
|
||||
N = N * len(devices)
|
||||
X = Tensor.rand(N*N).reshape(N, N).mul(sign)
|
||||
n = X.numpy()
|
||||
X.shard_(devices, shard_axis)
|
||||
@@ -438,6 +439,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
assert isinstance(jf.jit_cache[4].prg, BufferCopy)
|
||||
assert isinstance(jf.jit_cache[5].prg, graph_d1)
|
||||
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_uneven_shard(self):
|
||||
for N in range(1, 6):
|
||||
X = Tensor.rand(4, 1, 257).contiguous().realize()
|
||||
@@ -450,6 +452,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1)))
|
||||
np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1)))
|
||||
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_uneven_multiple_zeros(self):
|
||||
for data in ([1, 2, 3, 4], [1, 2, 3], [1, 2], [1], []):
|
||||
for N in (1, 2, 3, 4):
|
||||
@@ -458,6 +461,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
X = ((Tensor(data).shard(devices, axis=0) + 1).realize() - 1).realize()
|
||||
np.testing.assert_equal(X.numpy(), data)
|
||||
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_uneven_shard_with_empty(self):
|
||||
N = 4
|
||||
X = Tensor.rand(16, 1, 3).contiguous().realize()
|
||||
@@ -470,6 +474,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
# test reshape with empty shard
|
||||
np.testing.assert_equal(X.shard(devices, 0).reshape(8, 1, 6).numpy(), np_x.reshape(8, 1, 6))
|
||||
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_multiple_uneven_shard(self):
|
||||
N = 4
|
||||
X = Tensor.rand(4, 1, 257).contiguous().realize()
|
||||
@@ -531,7 +536,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
with self.assertRaises((AssertionError, ValueError)):
|
||||
t0.reshape((26*15,7))
|
||||
|
||||
@unittest.skip("no longer supports splits")
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_reshape_on_axis_uneven(self):
|
||||
def reshape_helper(t0, t, t_axis):
|
||||
np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())
|
||||
@@ -603,6 +608,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
self.assertEqual(t.dtype, t2.dtype)
|
||||
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
|
||||
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_rand_like_uneven_shard(self):
|
||||
t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1)
|
||||
t2 = Tensor.rand_like(t)
|
||||
@@ -653,6 +659,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
assert set(unique) == {0, 2}, unique
|
||||
assert 200 < counts[0] < 312, counts[0]
|
||||
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_dropout_on_uneven_shard_axis(self):
|
||||
with Tensor.train():
|
||||
X = Tensor.ones(256).shard(devices_3, axis=0)
|
||||
@@ -814,6 +821,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
|
||||
np.testing.assert_allclose(a.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), b.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), rtol=1e-7, atol=1e-3)
|
||||
np.testing.assert_allclose(a.flip(-1).numpy(), b.flip(-1).numpy(), rtol=1e-7, atol=1e-3)
|
||||
|
||||
@unittest.skip("no longer supports uneven shard")
|
||||
def test_uneven(self):
|
||||
t = Tensor.arange(24).reshape(3, 8).contiguous().realize()
|
||||
t.shard_([f"{Device.DEFAULT}:{i}" for i in range(2)], axis=0)
|
||||
|
||||
@@ -551,7 +551,7 @@ class TestNN(unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
|
||||
def test_load_state_dict_sharded_model(self):
|
||||
devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
|
||||
devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3")
|
||||
|
||||
layer = Conv2d(3, 5, kernel_size=3)
|
||||
layer.weight.shard_(devices, 3)
|
||||
@@ -572,7 +572,7 @@ class TestNN(unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
|
||||
def test_load_state_dict_sharded_dict(self):
|
||||
devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
|
||||
devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3")
|
||||
|
||||
layer = Conv2d(3, 5, kernel_size=3)
|
||||
state_dict = {
|
||||
@@ -589,7 +589,7 @@ class TestNN(unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
|
||||
def test_load_state_dict_sharded_model_dict_same_axis(self):
|
||||
devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
|
||||
devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3")
|
||||
|
||||
layer = Conv2d(3, 5, kernel_size=3)
|
||||
layer.weight.shard_(devices, 3)
|
||||
@@ -610,7 +610,8 @@ class TestNN(unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
|
||||
def test_load_state_dict_sharded_model_dict_different_axis(self):
|
||||
devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
|
||||
devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3")
|
||||
devices5 = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3", f"{Device.DEFAULT}:4", f"{Device.DEFAULT}:5")
|
||||
|
||||
layer = Conv2d(3, 5, kernel_size=3)
|
||||
layer.weight.shard_(devices, 3)
|
||||
@@ -619,14 +620,14 @@ class TestNN(unittest.TestCase):
|
||||
# different shard axis
|
||||
state_dict = {
|
||||
'weight': Tensor.randn(5, 3, 3, 3).shard(devices, None),
|
||||
'bias': Tensor.randn(5).shard(devices, 0),
|
||||
'bias': Tensor.randn(5).shard(devices5, 0),
|
||||
}
|
||||
load_state_dict(layer, state_dict)
|
||||
|
||||
# NOTE: model and state_dict shard differently, use the state_dict sharding # TODO: revisit this?
|
||||
self.assertEqual(layer.weight.device, devices)
|
||||
self.assertEqual(layer.weight.lazydata.axis, None)
|
||||
self.assertEqual(layer.bias.device, devices)
|
||||
self.assertEqual(layer.bias.device, devices5)
|
||||
self.assertEqual(layer.bias.lazydata.axis, 0)
|
||||
np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
|
||||
np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())
|
||||
|
||||
@@ -38,7 +38,7 @@ def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]:
|
||||
return [functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads,lb_c)]).reshape(shape) for lb_c in chunked]
|
||||
|
||||
def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) -> list[UOp]:
|
||||
if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
|
||||
if lbs[0].shape[axis] % len(lbs) != 0: raise RuntimeError(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
|
||||
return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]
|
||||
|
||||
class MultiLazyBuffer(MathTrait):
|
||||
|
||||
@@ -399,7 +399,7 @@ class Tensor(SimpleMathTrait):
|
||||
Shards the tensor across the given devices. Optionally specify which axis to shard on.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.empty(2, 3)
|
||||
t = Tensor.empty(2, 4)
|
||||
print(t.shard((t.device, t.device), axis=1).lazydata)
|
||||
```
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user