diff --git a/tinygrad/multi.py b/tinygrad/multi.py
index 4d39b1630c..1f79811cef 100644
--- a/tinygrad/multi.py
+++ b/tinygrad/multi.py
@@ -75,9 +75,7 @@ class MultiLazyBuffer(MathTrait):
   def copy_to_device(self, device:str) -> LazyBuffer:
     if self.axis is None:
       # if we already have a copy on the device, return that
-      for lb in self.real_lbs:
-        if lb.device == device: return lb
-      return self.real_lbs[0].copy_to_device(device)
+      return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
     # copy lbs to device, pad to final shape, and sum
     llbs:List[LazyBuffer] = []
     for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
@@ -87,7 +85,7 @@ class MultiLazyBuffer(MathTrait):
     return functools.reduce(operator.add, llbs)
 
   # passthroughs
-  def is_realized(self) -> bool: return all(lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True)
+  def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs)
   def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True):
     return MultiLazyBuffer([x.cast(dtype, bitcast, allow_buffer_view) for x in self.lbs], self.axis, self.real)
   def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
@@ -104,7 +102,7 @@ class MultiLazyBuffer(MathTrait):
     # NOTE: they all have to share an axis, we always choose [-1]
     axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
     srcs:List[List[LazyBuffer]] = []
-    not_all_real = any(not all(mlb.real) for mlb in msrcs)
+    not_all_real = not all(all(mlb.real) for mlb in msrcs)
     new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
     assert any(new_real), "output contains no real lb"
     for mlb in msrcs:
@@ -113,8 +111,8 @@ class MultiLazyBuffer(MathTrait):
       else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
     new_real_lbs:Dict[int,LazyBuffer] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
     # NOTE: const dtype should match real
-    real_dtype = next(iter(new_real_lbs.values())).dtype
-    return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(real_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
+    new_dtype = next(iter(new_real_lbs.values())).dtype
+    return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
 
   def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
     if self.axis is not None and self.axis in axis:
@@ -139,20 +137,18 @@ class MultiLazyBuffer(MathTrait):
     # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
     # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
     new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
-    assert all(prod(lb.shape[self.axis:]) % prod(arg[new_axis + 1:]) == 0 for lb in self.lbs),\
-      f"reshape cannot move items between shards {self.shape} {arg} {self.bounds}"
-    return MultiLazyBuffer([x.reshape(
-      tuple(s if a != new_axis else prod(x.shape[self.axis:]) // prod(arg[new_axis + 1:]) for a, s in enumerate(arg))
-    ) for x in self.lbs], new_axis, self.real)
+    assert all(prod(lb.shape[self.axis:])%prod(arg[new_axis+1:])==0 for lb in self.lbs), f"reshape cannot move items between shards {self=} {arg=}"
+    lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[self.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in self.lbs]
+    return MultiLazyBuffer(lbs, new_axis, self.real)
 
   def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
     assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
     # pad on shard axis -> fill others with zeros and set real to all True
     if self.axis is not None and arg[self.axis] != (0,0):
       # pad back to whole axis, remove real mask
-      assert all(arg[i] == (0, 0) or i == self.axis for i in range(len(self.shape))), "cannot pad sharded and non-sharded axis at the same time"
-      assert arg[self.axis] == (sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i < self.real.index(True)), \
-        sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i > self.real.index(True))), "can only pad to whole axis"
+      assert all(arg[i] == (0, 0) for i in range(len(self.shape)) if i != self.axis), "cannot pad sharded and non-sharded axis at the same time"
+      dim, bound = sum(lb.shape[self.axis] for lb in self.lbs), self.bounds[self.real.index(True)]
+      assert arg[self.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
       return MultiLazyBuffer([x if r else x.const_like(0) for x,r in zip(self.lbs, self.real)], self.axis)
     return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)