mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
minor handle_allreduce cleanup [pr] (#14876)
no more lbs, also use a divmod
This commit is contained in:
@@ -8,36 +8,36 @@ from tinygrad.dtype import dtypes
|
||||
def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
|
||||
if not isinstance(buf.device, tuple): return None
|
||||
assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}"
|
||||
n_lbs, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
|
||||
ndev, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
|
||||
|
||||
# ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
|
||||
# fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
|
||||
use_all2all = (ALL2ALL >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1))
|
||||
use_ring = not use_all2all and (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
|
||||
if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {n_lbs}x{numel} | {buf.dtype}")
|
||||
use_all2all = (ALL2ALL >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1))
|
||||
use_ring = not use_all2all and (RING >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
|
||||
if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {ndev}x{numel} | {buf.dtype}")
|
||||
|
||||
# contiguous before we copy it
|
||||
buf = buf.contiguous()
|
||||
|
||||
# naive: copy to all devices. if you shrink later, that'll be handled
|
||||
if not use_ring and not use_all2all:
|
||||
return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(n_lbs)])
|
||||
return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(ndev)])
|
||||
|
||||
# chunk data into n_lbs pieces
|
||||
# chunk data into ndev pieces
|
||||
factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
|
||||
base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
|
||||
chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (n_lbs - left), initial=0)))
|
||||
base, left = divmod(numel // factor, ndev)
|
||||
chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0)))
|
||||
|
||||
# reduce-scatter
|
||||
reduced_chunks = []
|
||||
for i,(s,e) in enumerate(chunks):
|
||||
if use_all2all:
|
||||
chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(n_lbs)]
|
||||
chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(ndev)]
|
||||
reduced_chunks.append(functools.reduce(lambda x,y: x.alu(red.arg, y), chunks_on_i))
|
||||
else:
|
||||
chunk, reduced = buf.reshape((numel,)).shrink(((s,e),)), buf.reshape((numel,)).shrink(((s,e),))
|
||||
for step in range(n_lbs-1):
|
||||
src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
|
||||
for step in range(ndev-1):
|
||||
src, dest = (i+step)%ndev, (i+step+1)%ndev
|
||||
cp = reduced.copy_to_device(buf.device[dest], src if isinstance(reduced.device, tuple) else None)
|
||||
reduced = cp.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
|
||||
reduced_chunks.append(reduced)
|
||||
@@ -46,12 +46,12 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
|
||||
copied_chunks = []
|
||||
for i,rc in enumerate(reduced_chunks):
|
||||
if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
|
||||
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs))))
|
||||
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev))))
|
||||
else:
|
||||
this_chunk: list[UOp|None] = [None] * n_lbs
|
||||
this_chunk[(i+n_lbs-1)%n_lbs] = rc
|
||||
for step in range(n_lbs-1):
|
||||
this_chunk[(i+step)%n_lbs] = rc = rc.copy_to_device(buf.device[(i+step)%n_lbs])
|
||||
this_chunk: list[UOp|None] = [None] * ndev
|
||||
this_chunk[(i+ndev-1)%ndev] = rc
|
||||
for step in range(ndev-1):
|
||||
this_chunk[(i+step)%ndev] = rc = rc.copy_to_device(buf.device[(i+step)%ndev])
|
||||
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
|
||||
|
||||
# reassemble
|
||||
|
||||
Reference in New Issue
Block a user