mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 15:28:10 -05:00
minor multi cleanups [pr] (#8625)
This commit is contained in:
@@ -14,11 +14,10 @@ def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]:
|
||||
if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {lbs[0].dtype}")
|
||||
if not use_ring: return [functools.reduce(lambda x,y: x.alu(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
|
||||
|
||||
factor = next(f for f in [32, 16, 8, 4, 2, 1] if numel % f == 0)
|
||||
factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
|
||||
base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
|
||||
chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
|
||||
acc = 0
|
||||
chunks = [(acc, (acc := acc + i)) for i in chunk_sizes if i > 0]
|
||||
chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
|
||||
chunked = [[lb.reshape((numel,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
|
||||
|
||||
# scatter-reduce
|
||||
@@ -64,9 +63,8 @@ class MultiLazyBuffer(MathTrait):
|
||||
def __repr__(self): return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
|
||||
|
||||
def copy_to_device(self, device:str) -> UOp:
|
||||
if self.axis is None:
|
||||
# if we already have a copy on the device, return that
|
||||
return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
|
||||
# if we already have a copy on the device, return that
|
||||
if self.axis is None: return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
|
||||
# copy lbs to device, pad to final shape, and sum
|
||||
llbs:list[UOp] = []
|
||||
for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
|
||||
@@ -78,8 +76,7 @@ class MultiLazyBuffer(MathTrait):
|
||||
# passthroughs
|
||||
@property
|
||||
def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs)
|
||||
def cast(self, dtype:DType, bitcast:bool=False):
|
||||
return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
|
||||
def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
|
||||
def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
|
||||
def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
|
||||
def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
|
||||
|
||||
Reference in New Issue
Block a user