cleanup multi [pr] (#7491)

Author: chenyu
Date: 2024-11-02 16:38:34 -04:00
Committed by: GitHub
Parent: f8376b3766
Commit: dc9ffb41a8

@@ -75,9 +75,7 @@ class MultiLazyBuffer(MathTrait):
   def copy_to_device(self, device:str) -> LazyBuffer:
     if self.axis is None:
       # if we already have a copy on the device, return that
-      for lb in self.real_lbs:
-        if lb.device == device: return lb
-      return self.real_lbs[0].copy_to_device(device)
+      return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
     # copy lbs to device, pad to final shape, and sum
     llbs:List[LazyBuffer] = []
     for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
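The one-liner uses Python's two-argument next(iterator, default), which yields the first buffer already on the target device or falls back to copying. One subtlety: the default argument is evaluated eagerly, so copy_to_device is called even when a match exists; since LazyBuffer copies are lazy graph nodes rather than immediate transfers, this should be harmless. A minimal sketch of the idiom with a hypothetical Buf stand-in (not tinygrad code):

class Buf:
  def __init__(self, device:str): self.device = device
  def copy_to_device(self, device:str) -> "Buf": return Buf(device)

bufs = [Buf("CLANG"), Buf("GPU")]
# first buffer already on the target device, else a fresh copy of the first one
target = next((b for b in bufs if b.device == "GPU"), bufs[0].copy_to_device("GPU"))
assert target is bufs[1]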
@@ -87,7 +85,7 @@ class MultiLazyBuffer(MathTrait):
     return functools.reduce(operator.add, llbs)
   # passthroughs
-  def is_realized(self) -> bool: return all(lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True)
+  def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs)
   def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True):
     return MultiLazyBuffer([x.cast(dtype, bitcast, allow_buffer_view) for x in self.lbs], self.axis, self.real)
   def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
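The new is_realized assumes a real_lbs helper, presumably a property that filters self.lbs by the self.real mask, so the zip-filter is written once instead of being repeated at each call site. A sketch of that presumed equivalence with plain lists:

lbs, real = ["a", "b", "c"], [True, False, True]
real_lbs = [lb for lb, r in zip(lbs, real) if r]  # what the property presumably returns
# same elements the old `if r is True` filter kept, given real holds actual bools
assert real_lbs == [lb for lb, r in zip(lbs, real) if r is True]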
@@ -104,7 +102,7 @@ class MultiLazyBuffer(MathTrait):
     # NOTE: they all have to share an axis, we always choose [-1]
     axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
     srcs:List[List[LazyBuffer]] = []
-    not_all_real = any(not all(mlb.real) for mlb in msrcs)
+    not_all_real = not all(all(mlb.real) for mlb in msrcs)
     new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
     assert any(new_real), "output contains no real lb"
     for mlb in msrcs:
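The not_all_real rewrite is a De Morgan transformation: any(not all(P)) is exactly the negation of all(all(P)). A quick check over a fully-real and a partially-real case:

for reals in ([[True, True], [True, True]], [[True, False], [True, True]]):
  assert any(not all(r) for r in reals) == (not all(all(r) for r in reals))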
@@ -113,8 +111,8 @@ class MultiLazyBuffer(MathTrait):
       else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
     new_real_lbs:Dict[int,LazyBuffer] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
     # NOTE: const dtype should match real
-    real_dtype = next(iter(new_real_lbs.values())).dtype
-    return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(real_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
+    new_dtype = next(iter(new_real_lbs.values())).dtype
+    return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
   def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
     if self.axis is not None and self.axis in axis:
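Beyond the rename (real_dtype to new_dtype), the return line backfills the non-real shard slots with a zero constant cast to the dtype of the real results, using dict.get with a default so computed outputs pass through untouched. A minimal sketch of the backfill pattern, with plain floats standing in for LazyBuffers:

results = {0: 3.0, 2: 5.0}                        # alu outputs, computed only for real shards
filled = [results.get(i, 0.0) for i in range(4)]  # zero-fill the masked slots
assert filled == [3.0, 0.0, 5.0, 0.0]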
@@ -139,20 +137,18 @@ class MultiLazyBuffer(MathTrait):
     # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
     # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
     new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
-    assert all(prod(lb.shape[self.axis:]) % prod(arg[new_axis + 1:]) == 0 for lb in self.lbs),\
-      f"reshape cannot move items between shards {self.shape} {arg} {self.bounds}"
-    return MultiLazyBuffer([x.reshape(
-      tuple(s if a != new_axis else prod(x.shape[self.axis:]) // prod(arg[new_axis + 1:]) for a, s in enumerate(arg))
-    ) for x in self.lbs], new_axis, self.real)
+    assert all(prod(lb.shape[self.axis:])%prod(arg[new_axis+1:])==0 for lb in self.lbs), f"reshape cannot move items between shards {self=} {arg=}"
+    lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[self.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in self.lbs]
+    return MultiLazyBuffer(lbs, new_axis, self.real)
   def pad(self, arg:Tuple[Tuple[sint, sint], ...]):
     assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
     # pad on shard axis -> fill others with zeros and set real to all True
     if self.axis is not None and arg[self.axis] != (0,0):
       # pad back to whole axis, remove real mask
-      assert all(arg[i] == (0, 0) or i == self.axis for i in range(len(self.shape))), "cannot pad sharded and non-sharded axis at the same time"
-      assert arg[self.axis] == (sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i < self.real.index(True)), \
-        sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i > self.real.index(True))), "can only pad to whole axis"
+      assert all(arg[i] == (0, 0) for i in range(len(self.shape)) if i != self.axis), "cannot pad sharded and non-sharded axis at the same time"
+      dim, bound = sum(lb.shape[self.axis] for lb in self.lbs), self.bounds[self.real.index(True)]
+      assert arg[self.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
       return MultiLazyBuffer([x if r else x.const_like(0) for x,r in zip(self.lbs, self.real)], self.axis)
     return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
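The rewritten pad check is equivalent to the old pair of sums: with the first real shard covering [start, end) of a sharded axis of total length dim, padding back to the whole axis means exactly (start, dim - end), since bounds already accumulate the shard sizes. A worked check with made-up shard sizes:

shard_sizes = [4, 4, 4]             # per-shard extent along the sharded axis
bounds = [(0, 4), (4, 8), (8, 12)]  # running (start, end) of each shard
real = [False, True, False]         # only the middle shard is real
dim = sum(shard_sizes)
start, end = bounds[real.index(True)]
# new form: pad from the real shard's bounds out to the full axis
new = (start, dim - end)
# old form: sum the shard sizes strictly before and strictly after the real shard
old = (sum(s for i, s in enumerate(shard_sizes) if i < real.index(True)),
       sum(s for i, s in enumerate(shard_sizes) if i > real.index(True)))
assert new == old == (4, 4)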