make ring allreduce chunks a multiple of 2^n if possible (#4302)

in resnet, instead of chunking as [43691, 43691, 43691, 43691, 43690, 43690], chunk as [43712, 43712, 43680, 43680, 43680, 43680] and those can have 32 local.

more than 2X faster for the applicable kernels and overall 1% for resnet
This commit is contained in:
chenyu
2024-04-25 23:45:28 -04:00
committed by GitHub
parent 1e37c4a7a1
commit 1891ebb655

View File

@@ -19,8 +19,9 @@ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
if DEBUG >= 3: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{dim} | {lbs[0].dtype}")
if not use_ring:
return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
base, left = dim // n_lbs, dim % n_lbs
c_lens = [base + 1 if left - i > 0 else base for i in range(n_lbs)]
factor = max(f for f in [32, 16, 8, 4, 2, 1] if dim % f == 0)
base, left = (dim // factor) // n_lbs, (dim // factor) % n_lbs
c_lens = [(base + 1) * factor if i < left else base * factor for i in range(n_lbs)]
acc = 0
chunks = [(acc, (acc := acc + i)) for i in c_lens if i > 0]
chunked = [[lb.reshape((dim,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]