From 7860a808015568d4324985a481a97174be338ae5 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 14 Jan 2025 19:19:13 -0500 Subject: [PATCH] simpler MultiLazyBuffer alu [pr] (#8622) --- tinygrad/multi.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tinygrad/multi.py b/tinygrad/multi.py index 8e65f46ac9..c9990534ce 100644 --- a/tinygrad/multi.py +++ b/tinygrad/multi.py @@ -102,12 +102,10 @@ class MultiLazyBuffer(MathTrait): assert any(new_real), "output contains no real lb" for mlb in msrcs: if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs) - elif mlb.axis is None and axis is not None: - assert bounds is not None - srcs.append(to_sharded(mlb.lbs, axis, bounds)) else: assert axis is not None and bounds is not None - srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds)) + if mlb.axis is None: srcs.append(to_sharded(mlb.lbs, axis, bounds)) + else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds)) new_real_lbs:dict[int,UOp] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r} # NOTE: const dtype should match real new_dtype = next(iter(new_real_lbs.values())).dtype