touchup multi dtype in elementwise (#5305)

only need to check real once, also added type annotation
This commit is contained in:
chenyu
2024-07-06 11:54:12 -04:00
committed by GitHub
parent 7ddda9f9f1
commit 356e5d2e54

View File

@@ -104,7 +104,7 @@ class MultiLazyBuffer:
# NOTE: they all have to share an axis, we always choose [-1]
axis = axes[-1] if len(axes := dedup([x.axis for x in msrcs if x.axis is not None])) else None
srcs = []
srcs:List[List[LazyBuffer]] = []
not_all_real = any(not all(mlb.real) for mlb in msrcs)
new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
assert any(new_real), "output contains no real lb"
@@ -115,8 +115,7 @@ class MultiLazyBuffer:
new_real_lbs = {i:lsrcs[0].e(op, *lsrcs[1:], arg=arg) for i,(lsrcs,r) in enumerate(zip(zip(*srcs),new_real)) if r}
# NOTE: const dtype should match real
real_dtype = next(iter(new_real_lbs.values())).dtype
return MultiLazyBuffer([new_real_lbs[i] if r else lsrcs[0].const(0).cast(real_dtype) \
for i, (lsrcs,r) in enumerate(zip(zip(*srcs),new_real))], axis, new_real)
return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const(0).cast(real_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))