fix RING with single dest (#14482)

This commit is contained in:
chenyu
2026-02-01 12:14:46 -05:00
committed by GitHub
parent 3ff390159b
commit 6deeccc192
2 changed files with 8 additions and 8 deletions

View File

@@ -33,10 +33,6 @@ def _test_allreduce(t:Tensor):
b.realize()
return aa, b
def _test_multiple_to_single_device(ring:int) -> Tensor:
t = Tensor.empty(32).shard(devices_4, 0).to(Device.DEFAULT)
with Context(RING=ring, SCACHE=0): return t.realize()
@unittest.skipIf(not_support_multi_device(), "no multi")
class TestMultiTensor(unittest.TestCase):
@needs_second_gpu
@@ -261,13 +257,16 @@ class TestMultiTensor(unittest.TestCase):
np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
def test_multiple_to_single_device_naive(self):
t = _test_multiple_to_single_device(0)
with Context(RING=0):
t = Tensor.arange(32).shard(devices_4, 0).to(Device.DEFAULT).realize()
self.assertEqual(t.device, Device.DEFAULT)
np.testing.assert_equal(t.numpy(), np.arange(32))
@unittest.skip("TODO: ring allreduce ignores target device")
def test_multiple_to_single_device_ring(self):
t = _test_multiple_to_single_device(2)
with Context(RING=2):
t = Tensor.arange(32).shard(devices_4, 0).to(Device.DEFAULT).realize()
self.assertEqual(t.device, Device.DEFAULT)
np.testing.assert_equal(t.numpy(), np.arange(32))
def test_allreduce_all2all(self):
with Context(ALL2ALL=2):

View File

@@ -71,7 +71,8 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
# allgather
copied_chunks = []
for i,rc in enumerate(reduced_chunks):
if use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs))))
if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs))))
else:
this_chunk: list[UOp|None] = [None] * n_lbs
this_chunk[(i+n_lbs-1)%n_lbs] = rc