simpler MultiLazyBuffer alu [pr] (#8622)

This commit is contained in:
chenyu
2025-01-14 19:19:13 -05:00
committed by GitHub
parent 930728c069
commit 7860a80801

View File

@@ -102,12 +102,10 @@ class MultiLazyBuffer(MathTrait):
assert any(new_real), "output contains no real lb"
for mlb in msrcs:
if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs)
elif mlb.axis is None and axis is not None:
assert bounds is not None
srcs.append(to_sharded(mlb.lbs, axis, bounds))
else:
assert axis is not None and bounds is not None
srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
if mlb.axis is None: srcs.append(to_sharded(mlb.lbs, axis, bounds))
else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
new_real_lbs:dict[int,UOp] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
# NOTE: const dtype should match real
new_dtype = next(iter(new_real_lbs.values())).dtype