mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-28 00:08:16 -05:00
simpler MultiLazyBuffer alu [pr] (#8622)
This commit is contained in:
@@ -102,12 +102,10 @@ class MultiLazyBuffer(MathTrait):
|
||||
assert any(new_real), "output contains no real lb"
|
||||
for mlb in msrcs:
|
||||
if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs)
|
||||
elif mlb.axis is None and axis is not None:
|
||||
assert bounds is not None
|
||||
srcs.append(to_sharded(mlb.lbs, axis, bounds))
|
||||
else:
|
||||
assert axis is not None and bounds is not None
|
||||
srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
|
||||
if mlb.axis is None: srcs.append(to_sharded(mlb.lbs, axis, bounds))
|
||||
else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
|
||||
new_real_lbs:dict[int,UOp] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
|
||||
# NOTE: const dtype should match real
|
||||
new_dtype = next(iter(new_real_lbs.values())).dtype
|
||||
|
||||
Reference in New Issue
Block a user