diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 4e9d52fb29..375a6825f4 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -33,10 +33,6 @@ def _test_allreduce(t:Tensor): b.realize() return aa, b -def _test_multiple_to_single_device(ring:int) -> Tensor: - t = Tensor.empty(32).shard(devices_4, 0).to(Device.DEFAULT) - with Context(RING=ring, SCACHE=0): return t.realize() - @unittest.skipIf(not_support_multi_device(), "no multi") class TestMultiTensor(unittest.TestCase): @needs_second_gpu @@ -261,13 +257,16 @@ class TestMultiTensor(unittest.TestCase): np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5) def test_multiple_to_single_device_naive(self): - t = _test_multiple_to_single_device(0) + with Context(RING=0): + t = Tensor.arange(32).shard(devices_4, 0).to(Device.DEFAULT).realize() self.assertEqual(t.device, Device.DEFAULT) + np.testing.assert_equal(t.numpy(), np.arange(32)) - @unittest.skip("TODO: ring allreduce ignores target device") def test_multiple_to_single_device_ring(self): - t = _test_multiple_to_single_device(2) + with Context(RING=2): + t = Tensor.arange(32).shard(devices_4, 0).to(Device.DEFAULT).realize() self.assertEqual(t.device, Device.DEFAULT) + np.testing.assert_equal(t.numpy(), np.arange(32)) def test_allreduce_all2all(self): with Context(ALL2ALL=2): diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index b6a502c845..b9a41c3f77 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -71,7 +71,8 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None: # allgather copied_chunks = [] for i,rc in enumerate(reduced_chunks): - if use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs)))) + if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg)) + elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs)))) else: this_chunk: list[UOp|None] = [None] * n_lbs this_chunk[(i+n_lbs-1)%n_lbs] = rc