mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 00:38:10 -05:00
all_reduce cosmetic change [pr] (#7490)
This commit is contained in:
@@ -11,36 +11,35 @@ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
|
||||
assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
|
||||
assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
|
||||
bop = REDUCE_ALU[op]
|
||||
n_lbs, shape, numel = len(lbs), lbs[0].shape, prod(lbs[0].shape)
|
||||
# ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
|
||||
# fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
|
||||
use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
|
||||
if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {lbs[0].dtype}")
|
||||
if not use_ring: return [functools.reduce(lambda x,y: x.alu(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
|
||||
|
||||
n_lbs, dim = len(lbs), prod(lbs[0].shape)
|
||||
# Ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
|
||||
# so just fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
|
||||
use_ring = (RING >= 2 or (n_lbs > 2 and dim > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
|
||||
if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{dim} | {lbs[0].dtype}")
|
||||
if not use_ring:
|
||||
return [functools.reduce(lambda x,y: x.alu(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
|
||||
factor = max(f for f in [32, 16, 8, 4, 2, 1] if dim % f == 0)
|
||||
base, left = (dim // factor) // n_lbs, (dim // factor) % n_lbs
|
||||
c_lens = [(base + 1) * factor if i < left else base * factor for i in range(n_lbs)]
|
||||
factor = next(f for f in [32, 16, 8, 4, 2, 1] if numel % f == 0)
|
||||
base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
|
||||
chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
|
||||
acc = 0
|
||||
chunks = [(acc, (acc := acc + i)) for i in c_lens if i > 0]
|
||||
chunked = [[lb.reshape((dim,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
|
||||
chunks = [(acc, (acc := acc + i)) for i in chunk_sizes if i > 0]
|
||||
chunked = [[lb.reshape((numel,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
|
||||
|
||||
# Scatter-reduce step
|
||||
for step in range(n_lbs - 1):
|
||||
# scatter-reduce
|
||||
for step in range(n_lbs-1):
|
||||
for i in range(len(chunks)):
|
||||
s, r = (i+step)%n_lbs, (i+step+1)%n_lbs
|
||||
chunked[r][i] = chunked[r][i].alu(bop, chunked[s][i].copy_to_device(chunked[r][i].device, force=True))
|
||||
src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
|
||||
chunked[dest][i] = chunked[dest][i].alu(bop, chunked[src][i].copy_to_device(chunked[dest][i].device, force=True))
|
||||
|
||||
# Allgather step
|
||||
for step in range(n_lbs - 1):
|
||||
# allgather
|
||||
for step in range(n_lbs-1):
|
||||
for i in range(len(chunks)):
|
||||
s, r = (i+step-1)%n_lbs, (i+step)%n_lbs
|
||||
chunked[r][i] = chunked[s][i].copy_to_device(chunked[r][i].device, force=True)
|
||||
src, dest = (i+step-1)%n_lbs, (i+step)%n_lbs
|
||||
chunked[dest][i] = chunked[src][i].copy_to_device(chunked[dest][i].device, force=True)
|
||||
|
||||
# Assemble chunks back
|
||||
pads = [((s,dim-e),) for s,e in chunks]
|
||||
return [functools.reduce(operator.add, [c.pad(pads[i]) for i,c in enumerate(lb_c)]).reshape(lbs[0].shape) for lb_c in chunked]
|
||||
# assemble chunks back
|
||||
pads = [((s,numel-e),) for s,e in chunks]
|
||||
return [functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads,lb_c)]).reshape(shape) for lb_c in chunked]
|
||||
|
||||
def to_sharded(lbs:List[LazyBuffer], axis:int, bounds: Tuple[Tuple[int, int], ...]) -> List[LazyBuffer]:
|
||||
if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
|
||||
|
||||
Reference in New Issue
Block a user