mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix RING with single dest (#14482)
This commit is contained in:
@@ -33,10 +33,6 @@ def _test_allreduce(t:Tensor):
|
||||
b.realize()
|
||||
return aa, b
|
||||
|
||||
def _test_multiple_to_single_device(ring:int) -> Tensor:
|
||||
t = Tensor.empty(32).shard(devices_4, 0).to(Device.DEFAULT)
|
||||
with Context(RING=ring, SCACHE=0): return t.realize()
|
||||
|
||||
@unittest.skipIf(not_support_multi_device(), "no multi")
|
||||
class TestMultiTensor(unittest.TestCase):
|
||||
@needs_second_gpu
|
||||
@@ -261,13 +257,16 @@ class TestMultiTensor(unittest.TestCase):
|
||||
np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
|
||||
|
||||
def test_multiple_to_single_device_naive(self):
|
||||
t = _test_multiple_to_single_device(0)
|
||||
with Context(RING=0):
|
||||
t = Tensor.arange(32).shard(devices_4, 0).to(Device.DEFAULT).realize()
|
||||
self.assertEqual(t.device, Device.DEFAULT)
|
||||
np.testing.assert_equal(t.numpy(), np.arange(32))
|
||||
|
||||
@unittest.skip("TODO: ring allreduce ignores target device")
|
||||
def test_multiple_to_single_device_ring(self):
|
||||
t = _test_multiple_to_single_device(2)
|
||||
with Context(RING=2):
|
||||
t = Tensor.arange(32).shard(devices_4, 0).to(Device.DEFAULT).realize()
|
||||
self.assertEqual(t.device, Device.DEFAULT)
|
||||
np.testing.assert_equal(t.numpy(), np.arange(32))
|
||||
|
||||
def test_allreduce_all2all(self):
|
||||
with Context(ALL2ALL=2):
|
||||
|
||||
@@ -71,7 +71,8 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
|
||||
# allgather
|
||||
copied_chunks = []
|
||||
for i,rc in enumerate(reduced_chunks):
|
||||
if use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs))))
|
||||
if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
|
||||
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs))))
|
||||
else:
|
||||
this_chunk: list[UOp|None] = [None] * n_lbs
|
||||
this_chunk[(i+n_lbs-1)%n_lbs] = rc
|
||||
|
||||
Reference in New Issue
Block a user