minor multi cleanups [pr] (#8625)

This commit is contained in:
chenyu
2025-01-14 22:25:23 -05:00
committed by GitHub
parent 504ad08e73
commit 7fb1c7af61

View File

@@ -14,11 +14,10 @@ def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]:
if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {lbs[0].dtype}")
if not use_ring: return [functools.reduce(lambda x,y: x.alu(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
factor = next(f for f in [32, 16, 8, 4, 2, 1] if numel % f == 0)
factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
acc = 0
chunks = [(acc, (acc := acc + i)) for i in chunk_sizes if i > 0]
chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
chunked = [[lb.reshape((numel,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
# scatter-reduce
@@ -64,9 +63,8 @@ class MultiLazyBuffer(MathTrait):
# debug repr: shows self.axis and self.real, then a newline and one "<device> <st>" line per entry of self.lbs
# chr(10) is "\n"; presumably spelled this way because backslashes were not allowed inside f-string fields before Python 3.12
def __repr__(self): return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
def copy_to_device(self, device:str) -> UOp:
if self.axis is None:
# if we already have a copy on the device, return that
return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
# if we already have a copy on the device, return that
if self.axis is None: return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
# copy lbs to device, pad to final shape, and sum
llbs:list[UOp] = []
for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
@@ -78,8 +76,7 @@ class MultiLazyBuffer(MathTrait):
# passthroughs
# True only when every lb in self.real_lbs has lb.base.realized set (not None); vacuously True when real_lbs is empty
@property
def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs)
# cast: calls cast(dtype, bitcast) on each entry of self.lbs and wraps the results in a new MultiLazyBuffer with the same axis and real
# the two-line form and the one-line form listed below return the identical expression
def cast(self, dtype:DType, bitcast:bool=False):
return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
# calls const_like(b) on each entry of self.lbs and wraps the results in a new MultiLazyBuffer with the same axis and real
def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
# position-wise assign: the i-th lb of self gets .assign(i-th lb of x); zip stops at the shorter list; axis and real come from self
def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
# calls contiguous() on each entry of self.lbs and wraps the results in a new MultiLazyBuffer with the same axis and real
def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)