From 7fb1c7af6101f757880e83bc5c007b85afb29c06 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 14 Jan 2025 22:25:23 -0500 Subject: [PATCH] minor multi cleanups [pr] (#8625) --- tinygrad/multi.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tinygrad/multi.py b/tinygrad/multi.py index c9990534ce..0dbb3520cc 100644 --- a/tinygrad/multi.py +++ b/tinygrad/multi.py @@ -14,11 +14,10 @@ def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]: if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {lbs[0].dtype}") if not use_ring: return [functools.reduce(lambda x,y: x.alu(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs] - factor = next(f for f in [32, 16, 8, 4, 2, 1] if numel % f == 0) + factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1) base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left) - acc = 0 - chunks = [(acc, (acc := acc + i)) for i in chunk_sizes if i > 0] + chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0))) chunked = [[lb.reshape((numel,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs] # scatter-reduce @@ -64,9 +63,8 @@ class MultiLazyBuffer(MathTrait): def __repr__(self): return f"" def copy_to_device(self, device:str) -> UOp: - if self.axis is None: - # if we already have a copy on the device, return that - return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device)) + # if we already have a copy on the device, return that + if self.axis is None: return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device)) # copy lbs to device, pad to final shape, and sum llbs:list[UOp] = [] for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds): @@ -78,8 +76,7 @@ class MultiLazyBuffer(MathTrait): # passthroughs @property def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs) - def cast(self, dtype:DType, bitcast:bool=False): - return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real) + def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real) def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real) def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real) def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)