diff --git a/tinygrad/features/multi.py b/tinygrad/features/multi.py index cac2915c71..7fc4b6ad84 100644 --- a/tinygrad/features/multi.py +++ b/tinygrad/features/multi.py @@ -19,8 +19,9 @@ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]: if DEBUG >= 3: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{dim} | {lbs[0].dtype}") if not use_ring: return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs] - base, left = dim // n_lbs, dim % n_lbs - c_lens = [base + 1 if left - i > 0 else base for i in range(n_lbs)] + factor = max(f for f in [32, 16, 8, 4, 2, 1] if dim % f == 0) + base, left = (dim // factor) // n_lbs, (dim // factor) % n_lbs + c_lens = [(base + 1) * factor if i < left else base * factor for i in range(n_lbs)] acc = 0 chunks = [(acc, (acc := acc + i)) for i in c_lens if i > 0] chunked = [[lb.reshape((dim,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]