From b5370fd52dfa8be7bfb70b05ba8cf0f6f2c37b6f Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 4 Mar 2026 22:53:00 -0500 Subject: [PATCH] use copy_multi in alu_multi [pr] (#15143) * use copy_multi in alu_multi [pr] * copy to anything --- tinygrad/schedule/multi.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index 240c655101..cc80af7863 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -59,8 +59,8 @@ def alu_multi(root:UOp): # same axis, just copy through srcs.append(mlb.src[0]) else: - # axis mismatch, unshard it, send it to all devices, and shard it correctly - srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis)) + # axis mismatch, copy to all devices, and shard it correctly + srcs.append(copy_multi(mlb, mlb.device)._shard(axis)) return srcs[0].alu(root.op, *srcs[1:]).multi(axis) def reduce_multi(root:UOp, multi:UOp): @@ -103,8 +103,7 @@ def flip_multi(root:UOp, multi:UOp): assert multi.axis is None or not root.marg[multi.axis], "flipping not supported on sharded axis" return multi.src[0].flip([i for i,x in enumerate(root.marg) if x]).multi(multi.axis) -# from multiple devices -> one -def copy_multi(multi:UOp, device:UOp): +def copy_multi(multi:UOp, device:str | tuple[str, ...] | UOp): assert multi.axis is not None, "all multi ops have axis" return multi.src[0]._unshard(multi.axis).allreduce(Ops.ADD, device)