Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-25 06:48:22 -05:00
make ring allreduce chunks a multiple of 2^n if possible (#4302)
in resnet, instead of chunking as [43691, 43691, 43691, 43691, 43690, 43690], chunk as [43712, 43712, 43680, 43680, 43680, 43680]; those chunks can use a local size of 32. more than 2x faster for the applicable kernels and about 1% faster overall for resnet. a worked example of the two chunkings follows.
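As a minimal standalone sketch of the change (dim = 262144 and n_lbs = 6 are inferred from the chunk sizes quoted above, not stated explicitly in the commit), the old and new chunk-length computations compare as follows:

# Sketch of the chunk-length computation before and after this change.
# dim = 262144 and n_lbs = 6 are inferred from the resnet example above.
dim, n_lbs = 262144, 6

# old: spread the remainder one element at a time, giving odd-sized chunks
base, left = dim // n_lbs, dim % n_lbs
old_lens = [base + 1 if left - i > 0 else base for i in range(n_lbs)]

# new: take the largest factor (up to 32) dividing dim, distribute whole
# multiples of it, so every chunk length stays a multiple of that factor
factor = max(f for f in [32, 16, 8, 4, 2, 1] if dim % f == 0)
base, left = (dim // factor) // n_lbs, (dim // factor) % n_lbs
new_lens = [(base + 1) * factor if i < left else base * factor for i in range(n_lbs)]

print(old_lens)  # [43691, 43691, 43691, 43691, 43690, 43690]
print(new_lens)  # [43712, 43712, 43680, 43680, 43680, 43680]
assert sum(old_lens) == sum(new_lens) == dim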
@@ -19,8 +19,9 @@ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
   if DEBUG >= 3: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{dim} | {lbs[0].dtype}")
   if not use_ring:
     return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
-  base, left = dim // n_lbs, dim % n_lbs
-  c_lens = [base + 1 if left - i > 0 else base for i in range(n_lbs)]
+  factor = max(f for f in [32, 16, 8, 4, 2, 1] if dim % f == 0)
+  base, left = (dim // factor) // n_lbs, (dim // factor) % n_lbs
+  c_lens = [(base + 1) * factor if i < left else base * factor for i in range(n_lbs)]
   acc = 0
   chunks = [(acc, (acc := acc + i)) for i in c_lens if i > 0]
   chunked = [[lb.reshape((dim,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
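The unchanged lines at the end of the hunk turn c_lens into (start, end) boundaries that each LazyBuffer is shrunk to. Here is a hedged sketch of that step, reusing the new chunk lengths inferred above; the multiple-of-32 boundaries are what let the per-chunk kernels use a local size of 32, per the commit message:

# Sketch of how c_lens becomes (start, end) chunk boundaries, mirroring the
# walrus-based list comprehension in all_reduce. c_lens is the new chunking
# from the resnet example above.
c_lens = [43712, 43712, 43680, 43680, 43680, 43680]

acc = 0
chunks = [(acc, (acc := acc + i)) for i in c_lens if i > 0]
print(chunks)
# [(0, 43712), (43712, 87424), (87424, 131104), (131104, 174784),
#  (174784, 218464), (218464, 262144)]

# every boundary is a multiple of 32, so each chunk's copy/reduce kernel
# can use a local size of 32 (the speedup claimed in the commit message)
assert all(s % 32 == 0 and e % 32 == 0 for s, e in chunks)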