From 836cf42c2ea5c35d25b1aa6e72269bdabec05c25 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 3 Feb 2025 19:00:14 -0500 Subject: [PATCH] fix rand_like for multi (#8880) --- .github/workflows/benchmark.yml | 6 ++---- test/test_multitensor.py | 21 ++++++++++++++------- tinygrad/tensor.py | 3 ++- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index cf5b117aad..d16a6c8a7f 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -300,9 +300,8 @@ jobs: - name: Run 10 MLPerf ResNet50 training steps (6 gpu) run: NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt - name: Run 10 MLPerf Bert training steps (6 gpu) - # TODO: remove DISABLE_DROPOUT once dropout is fixed # TODO: remove BERT_LAYERS once scheduler is fast - run: NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=6 DISABLE_DROPOUT=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee train_bert.txt + run: NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=6 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee train_bert.txt - uses: actions/upload-artifact@v4 with: name: Speed (NVIDIA Training) @@ -498,9 +497,8 @@ jobs: - name: Run 10 MLPerf ResNet50 training steps (6 gpu) run: AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt - name: Run 10 MLPerf Bert training steps (6 gpu) - # TODO: remove DISABLE_DROPOUT once dropout is fixed # TODO: remove BERT_LAYERS once scheduler is fast - run: AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=6 DISABLE_DROPOUT=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee train_bert.txt + run: AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=6 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee train_bert.txt - uses: actions/upload-artifact@v4 with: name: Speed (AMD Training) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 9913a04605..841a29e4d4 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -651,14 +651,21 @@ class TestMultiTensor(unittest.TestCase): self.assertEqual(t.lazydata.axis, t2.lazydata.axis) def test_rand_like_from_alu(self): - # TODO: fix this, which will also fix multi device dropout - a = Tensor.ones(4, 4).shard(devices_2, axis=0) - with self.assertRaises(ValueError): - (a + a).rand_like() + a = Tensor.ones(4, 4).shard(devices_4, axis=0) + aa = a + a + self.assertEqual(aa.device, devices_4) + self.assertEqual(aa.lazydata.axis, 0) + raa = aa.rand_like() + self.assertEqual(raa.device, devices_4) + self.assertEqual(raa.lazydata.axis, 0) - b = Tensor.empty(4, 4).shard(devices_2, axis=None) - with self.assertRaises(ValueError): - (a + b).rand_like() + b = Tensor.empty(4, 4).shard(devices_4, axis=None) + ab = a + b + self.assertEqual(ab.device, devices_4) + self.assertEqual(ab.lazydata.axis, 0) + rab = ab.rand_like() + self.assertEqual(rab.device, devices_4) + self.assertEqual(rab.lazydata.axis, 0) @unittest.skip("no longer supports uneven shard") def test_rand_like_uneven_shard(self): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index de142aa06f..d90b931b47 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -736,7 +736,8 @@ class Tensor(SimpleMathTrait): if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor") if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device) contiguous = kwargs.pop("contiguous", True) - rands = [Tensor.rand(*lb.shape, device=cast(str, lb.device), dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.src] + sharded_shape = tuple(s//len(self.device) if a==self.lazydata.axis else s for a,s in enumerate(self.shape)) + rands = [Tensor.rand(sharded_shape, device=d, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for d in self.device] return Tensor(UOp.multi(*rands, axis=self.lazydata.axis), device=self.device, dtype=dtype, **kwargs) return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)