diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index 240c655101..cc80af7863 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -59,8 +59,8 @@ def alu_multi(root:UOp): # same axis, just copy through srcs.append(mlb.src[0]) else: - # axis mismatch, unshard it, send it to all devices, and shard it correctly - srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis)) + # axis mismatch, copy to all devices, and shard it correctly + srcs.append(copy_multi(mlb, mlb.device)._shard(axis)) return srcs[0].alu(root.op, *srcs[1:]).multi(axis) def reduce_multi(root:UOp, multi:UOp): @@ -103,8 +103,7 @@ def flip_multi(root:UOp, multi:UOp): assert multi.axis is None or not root.marg[multi.axis], "flipping not supported on sharded axis" return multi.src[0].flip([i for i,x in enumerate(root.marg) if x]).multi(multi.axis) -# from multiple devices -> one -def copy_multi(multi:UOp, device:UOp): +def copy_multi(multi:UOp, device:str | tuple[str, ...] | UOp): assert multi.axis is not None, "all multi ops have axis" return multi.src[0]._unshard(multi.axis).allreduce(Ops.ADD, device)