diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 1d5236d835..0596f22055 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -204,10 +204,10 @@ jobs: run: NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt - name: Run LLaMA-3 8B on 4 GPUs run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt - - name: Run LLaMA-3 8B on 6 GPUs - run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt - - name: Run LLaMA-2 70B - run: NV=1 CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt + # - name: Run LLaMA-3 8B on 6 GPUs + # run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt + # - name: Run LLaMA-2 70B + # run: NV=1 CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt - name: Run Mixtral 8x7B run: time NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing | tee mixtral.txt - name: Run GPT2 @@ -391,12 +391,12 @@ jobs: run: AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt - name: Run LLaMA-3 8B on 4 GPUs run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt - - name: Run LLaMA-3 8B on 6 GPUs - run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt + # - name: Run LLaMA-3 8B on 6 GPUs + # run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt - name: Restore amdgpu run: sudo modprobe amdgpu - - name: Run LLaMA-2 70B - run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt + # - name: Run LLaMA-2 70B + # run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt - name: Run Mixtral 8x7B run: time AMD=1 python3 examples/mixtral.py --temperature 0 --count 10 --timing | tee mixtral.txt - name: Run GPT2 diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 67cb9757e1..f53056c406 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -157,6 +157,7 @@ class TestMultiTensor(unittest.TestCase): strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)), strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1))) def test_simple_reduce(self, N, devices, rop, shard_axis, reduce_axis, sign): + N = N * len(devices) X = Tensor.rand(N*N).reshape(N, N).mul(sign) n = X.numpy() X.shard_(devices, shard_axis) @@ -438,6 +439,7 @@ class TestMultiTensor(unittest.TestCase): assert isinstance(jf.jit_cache[4].prg, BufferCopy) assert isinstance(jf.jit_cache[5].prg, graph_d1) + @unittest.skip("no longer supports uneven shard") def test_uneven_shard(self): for N in range(1, 6): X = Tensor.rand(4, 1, 257).contiguous().realize() @@ -450,6 +452,7 @@ class TestMultiTensor(unittest.TestCase): np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1))) np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1))) + @unittest.skip("no longer supports uneven shard") def test_uneven_multiple_zeros(self): for data in ([1, 2, 3, 4], [1, 2, 3], [1, 2], [1], []): for N in (1, 2, 3, 4): @@ -458,6 +461,7 @@ class TestMultiTensor(unittest.TestCase): X = ((Tensor(data).shard(devices, axis=0) + 1).realize() - 1).realize() np.testing.assert_equal(X.numpy(), data) + @unittest.skip("no longer supports uneven shard") def test_uneven_shard_with_empty(self): N = 4 X = Tensor.rand(16, 1, 3).contiguous().realize() @@ -470,6 +474,7 @@ class TestMultiTensor(unittest.TestCase): # test reshape with empty shard np.testing.assert_equal(X.shard(devices, 0).reshape(8, 1, 6).numpy(), np_x.reshape(8, 1, 6)) + @unittest.skip("no longer supports uneven shard") def test_multiple_uneven_shard(self): N = 4 X = Tensor.rand(4, 1, 257).contiguous().realize() @@ -531,7 +536,7 @@ class TestMultiTensor(unittest.TestCase): with self.assertRaises((AssertionError, ValueError)): t0.reshape((26*15,7)) - @unittest.skip("no longer supports splits") + @unittest.skip("no longer supports uneven shard") def test_reshape_on_axis_uneven(self): def reshape_helper(t0, t, t_axis): np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy()) @@ -603,6 +608,7 @@ class TestMultiTensor(unittest.TestCase): self.assertEqual(t.dtype, t2.dtype) self.assertEqual(t.lazydata.axis, t2.lazydata.axis) + @unittest.skip("no longer supports uneven shard") def test_rand_like_uneven_shard(self): t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1) t2 = Tensor.rand_like(t) @@ -653,6 +659,7 @@ class TestMultiTensor(unittest.TestCase): assert set(unique) == {0, 2}, unique assert 200 < counts[0] < 312, counts[0] + @unittest.skip("no longer supports uneven shard") def test_dropout_on_uneven_shard_axis(self): with Tensor.train(): X = Tensor.ones(256).shard(devices_3, axis=0) @@ -814,6 +821,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): np.testing.assert_allclose(a.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), b.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), rtol=1e-7, atol=1e-3) np.testing.assert_allclose(a.flip(-1).numpy(), b.flip(-1).numpy(), rtol=1e-7, atol=1e-3) + @unittest.skip("no longer supports uneven shard") def test_uneven(self): t = Tensor.arange(24).reshape(3, 8).contiguous().realize() t.shard_([f"{Device.DEFAULT}:{i}" for i in range(2)], axis=0) diff --git a/test/test_nn.py b/test/test_nn.py index e36b805c48..738574a989 100755 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -551,7 +551,7 @@ class TestNN(unittest.TestCase): @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI") def test_load_state_dict_sharded_model(self): - devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2") + devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3") layer = Conv2d(3, 5, kernel_size=3) layer.weight.shard_(devices, 3) @@ -572,7 +572,7 @@ class TestNN(unittest.TestCase): @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI") def test_load_state_dict_sharded_dict(self): - devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2") + devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3") layer = Conv2d(3, 5, kernel_size=3) state_dict = { @@ -589,7 +589,7 @@ class TestNN(unittest.TestCase): @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI") def test_load_state_dict_sharded_model_dict_same_axis(self): - devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2") + devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3") layer = Conv2d(3, 5, kernel_size=3) layer.weight.shard_(devices, 3) @@ -610,7 +610,8 @@ class TestNN(unittest.TestCase): @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI") def test_load_state_dict_sharded_model_dict_different_axis(self): - devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2") + devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3") + devices5 = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3", f"{Device.DEFAULT}:4", f"{Device.DEFAULT}:5") layer = Conv2d(3, 5, kernel_size=3) layer.weight.shard_(devices, 3) @@ -619,14 +620,14 @@ class TestNN(unittest.TestCase): # different shard axis state_dict = { 'weight': Tensor.randn(5, 3, 3, 3).shard(devices, None), - 'bias': Tensor.randn(5).shard(devices, 0), + 'bias': Tensor.randn(5).shard(devices5, 0), } load_state_dict(layer, state_dict) # NOTE: model and state_dict shard differently, use the state_dict sharding # TODO: revisit this? self.assertEqual(layer.weight.device, devices) self.assertEqual(layer.weight.lazydata.axis, None) - self.assertEqual(layer.bias.device, devices) + self.assertEqual(layer.bias.device, devices5) self.assertEqual(layer.bias.lazydata.axis, 0) np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy()) np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy()) diff --git a/tinygrad/multi.py b/tinygrad/multi.py index 07a3430f7b..919d33a97f 100644 --- a/tinygrad/multi.py +++ b/tinygrad/multi.py @@ -38,7 +38,7 @@ def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]: return [functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads,lb_c)]).reshape(shape) for lb_c in chunked] def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) -> list[UOp]: - if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}") + if lbs[0].shape[axis] % len(lbs) != 0: raise RuntimeError(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}") return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))] class MultiLazyBuffer(MathTrait): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 60b9735bd9..171c7afe31 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -399,7 +399,7 @@ class Tensor(SimpleMathTrait): Shards the tensor across the given devices. Optionally specify which axis to shard on. ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.empty(2, 3) + t = Tensor.empty(2, 4) print(t.shard((t.device, t.device), axis=1).lazydata) ``` """