diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py
index e79c34e6c8..7d3b3fada3 100644
--- a/tinygrad/schedule/multi.py
+++ b/tinygrad/schedule/multi.py
@@ -8,36 +8,36 @@ from tinygrad.dtype import dtypes
 def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
   if not isinstance(buf.device, tuple): return None
   assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}"
-  n_lbs, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
+  ndev, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
   # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
   # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
-  use_all2all = (ALL2ALL >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1))
-  use_ring = not use_all2all and (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
-  if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {n_lbs}x{numel} | {buf.dtype}")
+  use_all2all = (ALL2ALL >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1))
+  use_ring = not use_all2all and (RING >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
+  if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {ndev}x{numel} | {buf.dtype}")
 
   # contiguous before we copy it
   buf = buf.contiguous()
 
   # naive: copy to all devices. if you shrink later, that'll be handled
   if not use_ring and not use_all2all:
-    return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(n_lbs)])
+    return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(ndev)])
 
-  # chunk data into n_lbs pieces
+  # chunk data into ndev pieces
   factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
-  base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
-  chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (n_lbs - left), initial=0)))
+  base, left = divmod(numel // factor, ndev)
+  chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0)))
 
   # reduce-scatter
   reduced_chunks = []
   for i,(s,e) in enumerate(chunks):
     if use_all2all:
-      chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(n_lbs)]
+      chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(ndev)]
       reduced_chunks.append(functools.reduce(lambda x,y: x.alu(red.arg, y), chunks_on_i))
     else:
       chunk, reduced = buf.reshape((numel,)).shrink(((s,e),)), buf.reshape((numel,)).shrink(((s,e),))
-      for step in range(n_lbs-1):
-        src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
+      for step in range(ndev-1):
+        src, dest = (i+step)%ndev, (i+step+1)%ndev
         cp = reduced.copy_to_device(buf.device[dest], src if isinstance(reduced.device, tuple) else None)
         reduced = cp.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
       reduced_chunks.append(reduced)
@@ -46,12 +46,12 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
   copied_chunks = []
   for i,rc in enumerate(reduced_chunks):
     if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
-    elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs))))
+    elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev))))
     else:
-      this_chunk: list[UOp|None] = [None] * n_lbs
-      this_chunk[(i+n_lbs-1)%n_lbs] = rc
-      for step in range(n_lbs-1):
-        this_chunk[(i+step)%n_lbs] = rc = rc.copy_to_device(buf.device[(i+step)%n_lbs])
+      this_chunk: list[UOp|None] = [None] * ndev
+      this_chunk[(i+ndev-1)%ndev] = rc
+      for step in range(ndev-1):
+        this_chunk[(i+step)%ndev] = rc = rc.copy_to_device(buf.device[(i+step)%ndev])
       copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
 
   # reassemble