From 1891ebb655829d4b984c2bcdde105aa799a57313 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 25 Apr 2024 23:45:28 -0400 Subject: [PATCH] make ring allreduce chunks a multiple of 2^n if possible (#4302) in resnet, instead of chunking as [43691, 43691, 43691, 43691, 43690, 43690], chunk as [43712, 43712, 43680, 43680, 43680, 43680] and those can have 32 local. more than 2X faster for the applicable kernels and overall 1% for resnet --- tinygrad/features/multi.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tinygrad/features/multi.py b/tinygrad/features/multi.py index cac2915c71..7fc4b6ad84 100644 --- a/tinygrad/features/multi.py +++ b/tinygrad/features/multi.py @@ -19,8 +19,9 @@ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]: if DEBUG >= 3: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{dim} | {lbs[0].dtype}") if not use_ring: return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs] - base, left = dim // n_lbs, dim % n_lbs - c_lens = [base + 1 if left - i > 0 else base for i in range(n_lbs)] + factor = max(f for f in [32, 16, 8, 4, 2, 1] if dim % f == 0) + base, left = (dim // factor) // n_lbs, (dim // factor) % n_lbs + c_lens = [(base + 1) * factor if i < left else base * factor for i in range(n_lbs)] acc = 0 chunks = [(acc, (acc := acc + i)) for i in c_lens if i > 0] chunked = [[lb.reshape((dim,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]