hot fix use DEBUG >= 3 for allreduce message (#3869)

This commit is contained in:
chenyu
2024-03-21 23:40:44 -04:00
committed by GitHub
parent 6729f20aab
commit dca69df197

View File

@@ -16,7 +16,7 @@ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
# Ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
# so just fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
use_ring = (RING >= 2 or (n_lbs > 2 and dim > 256_000 and RING >= 1))
if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{dim} | {lbs[0].dtype}")
if DEBUG >= 3: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{dim} | {lbs[0].dtype}")
if not use_ring:
return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
base, left = dim // n_lbs, dim % n_lbs