cleanup multi [pr] (#7491)
@@ -75,9 +75,7 @@ class MultiLazyBuffer(MathTrait):
   def copy_to_device(self, device:str) -> LazyBuffer:
     if self.axis is None:
       # if we already have a copy on the device, return that
-      for lb in self.real_lbs:
-        if lb.device == device: return lb
-      return self.real_lbs[0].copy_to_device(device)
+      return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
     # copy lbs to device, pad to final shape, and sum
     llbs:List[LazyBuffer] = []
     for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
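
The explicit scan-then-fallback loop becomes a single next() with a default. One nuance: the default argument of next() is evaluated eagerly, so self.real_lbs[0].copy_to_device(device) builds its lazy copy op even when a local copy exists; since nothing is realized unless that buffer is used, this is presumably harmless. A minimal standalone sketch of the pattern, using plain device strings instead of LazyBuffers (hypothetical helper names, not tinygrad API):

def pick_loop(devices, target, fallback):
  # old style: scan for a match, fall back only if none is found
  for d in devices:
    if d == target: return d
  return fallback

def pick_next(devices, target, fallback):
  # new style: a generator plus next()'s default argument does the same scan
  return next((d for d in devices if d == target), fallback)

devs = ["CLANG", "GPU:0", "GPU:1"]
assert pick_loop(devs, "GPU:1", "<copy>") == pick_next(devs, "GPU:1", "<copy>") == "GPU:1"
assert pick_loop(devs, "CUDA", "<copy>") == pick_next(devs, "CUDA", "<copy>") == "<copy>"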
@@ -87,7 +85,7 @@ class MultiLazyBuffer(MathTrait):
     return functools.reduce(operator.add, llbs)
 
   # passthroughs
-  def is_realized(self) -> bool: return all(lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True)
+  def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs)
   def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True):
     return MultiLazyBuffer([x.cast(dtype, bitcast, allow_buffer_view) for x in self.lbs], self.axis, self.real)
   def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
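
is_realized now iterates self.real_lbs directly instead of re-filtering self.lbs against the real mask. Assuming real_lbs is the list of lbs whose real flag is set (as its other uses in this diff suggest), the two generator expressions check the same buffers. A tiny sketch with a stand-in class (hypothetical, not the tinygrad types):

from typing import NamedTuple

class FakeLB(NamedTuple):
  realized: bool  # stands in for lb.base.realized is not None

lbs  = [FakeLB(True), FakeLB(False), FakeLB(True)]
real = [True, False, True]

old = all(lb.realized for lb, r in zip(lbs, real) if r is True)  # old form: filter via zip
real_lbs = [lb for lb, r in zip(lbs, real) if r]                 # what real_lbs presumably computes
new = all(lb.realized for lb in real_lbs)                        # new form
assert old == new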
@@ -104,7 +102,7 @@ class MultiLazyBuffer(MathTrait):
     # NOTE: they all have to share an axis, we always choose [-1]
     axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
     srcs:List[List[LazyBuffer]] = []
-    not_all_real = any(not all(mlb.real) for mlb in msrcs)
+    not_all_real = not all(all(mlb.real) for mlb in msrcs)
     new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
     assert any(new_real), "output contains no real lb"
     for mlb in msrcs:
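
The not_all_real rewrite is pure De Morgan: any(not all(m)) is the same predicate as not all(all(m)), so only the spelling changes. A quick exhaustive check over small boolean masks:

from itertools import product

# any shard-mask not fully real  <=>  not every shard-mask fully real
for masks in product([(True, True), (True, False), (False, False)], repeat=2):
  assert any(not all(m) for m in masks) == (not all(all(m) for m in masks))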
@@ -113,8 +111,8 @@ class MultiLazyBuffer(MathTrait):
       else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
     new_real_lbs:Dict[int,LazyBuffer] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
     # NOTE: const dtype should match real
-    real_dtype = next(iter(new_real_lbs.values())).dtype
-    return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(real_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
+    new_dtype = next(iter(new_real_lbs.values())).dtype
+    return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
 
   def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
     if self.axis is not None and self.axis in axis:
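
Positions whose real flag is cleared get no entry in new_real_lbs, so the returned list fills them with a zero constant cast to the dtype of the real outputs; the rename from real_dtype to new_dtype changes nothing else. A toy sketch of the dict.get(i, default) fill pattern, with ints standing in for LazyBuffers:

real = [True, False, True]
computed = {i: (i + 1) * 10 for i, r in enumerate(real) if r}  # like new_real_lbs: only real slots
filled = [computed.get(i, 0) for i in range(len(real))]        # like const_like(0).cast(new_dtype)
assert filled == [10, 0, 30]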
@@ -139,20 +137,18 @@ class MultiLazyBuffer(MathTrait):
       # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
       # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
       new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
-      assert all(prod(lb.shape[self.axis:]) % prod(arg[new_axis + 1:]) == 0 for lb in self.lbs),\
-        f"reshape cannot move items between shards {self.shape} {arg} {self.bounds}"
-      return MultiLazyBuffer([x.reshape(
-        tuple(s if a != new_axis else prod(x.shape[self.axis:]) // prod(arg[new_axis + 1:]) for a, s in enumerate(arg))
-      ) for x in self.lbs], new_axis, self.real)
+      assert all(prod(lb.shape[self.axis:])%prod(arg[new_axis+1:])==0 for lb in self.lbs), f"reshape cannot move items between shards {self=} {arg=}"
+      lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[self.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in self.lbs]
+      return MultiLazyBuffer(lbs, new_axis, self.real)
 
   def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
     assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
     # pad on shard axis -> fill others with zeros and set real to all True
     if self.axis is not None and arg[self.axis] != (0,0):
       # pad back to whole axis, remove real mask
-      assert all(arg[i] == (0, 0) or i == self.axis for i in range(len(self.shape))), "cannot pad sharded and non-sharded axis at the same time"
-      assert arg[self.axis] == (sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i < self.real.index(True)), \
-        sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i > self.real.index(True))), "can only pad to whole axis"
+      assert all(arg[i] == (0, 0) for i in range(len(self.shape)) if i != self.axis), "cannot pad sharded and non-sharded axis at the same time"
+      dim, bound = sum(lb.shape[self.axis] for lb in self.lbs), self.bounds[self.real.index(True)]
+      assert arg[self.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
       return MultiLazyBuffer([x if r else x.const_like(0) for x,r in zip(self.lbs, self.real)], self.axis)
     return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
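
The rewritten pad asserts state the same "pad back to the whole axis" condition: with a single real shard, the left pad must equal that shard's lower bound and the right pad must equal the gap from its upper bound to the full unsharded dimension. A small numeric check, assuming bounds are the usual cumulative (start, end) ranges of the shards along the sharded axis:

from itertools import accumulate

shard_sizes = [4, 4, 4]                     # per-shard sizes along the sharded axis
ends   = list(accumulate(shard_sizes))      # [4, 8, 12]
bounds = list(zip([0] + ends[:-1], ends))   # [(0, 4), (4, 8), (8, 12)]
real   = [False, True, False]               # only the middle shard is real
ri     = real.index(True)

# old form: sum the shard sizes strictly before / strictly after the real shard
old_pad = (sum(s for i, s in enumerate(shard_sizes) if i < ri),
           sum(s for i, s in enumerate(shard_sizes) if i > ri))
# new form: derive the same pair from the real shard's bounds and the full dim
dim, bound = sum(shard_sizes), bounds[ri]
new_pad = (bound[0], dim - bound[1])
assert old_pad == new_pad == (4, 4)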