use copy_multi in alu_multi [pr] (#15143)

* use copy_multi in alu_multi [pr]

* copy to anything
This commit is contained in:
chenyu
2026-03-04 22:53:00 -05:00
committed by GitHub
parent 72a9ed6e23
commit b5370fd52d

View File

@@ -59,8 +59,8 @@ def alu_multi(root:UOp):
# same axis, just copy through
srcs.append(mlb.src[0])
else:
# axis mismatch, unshard it, send it to all devices, and shard it correctly
srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis))
# axis mismatch, copy to all devices, and shard it correctly
srcs.append(copy_multi(mlb, mlb.device)._shard(axis))
return srcs[0].alu(root.op, *srcs[1:]).multi(axis)
def reduce_multi(root:UOp, multi:UOp):
@@ -103,8 +103,7 @@ def flip_multi(root:UOp, multi:UOp):
assert multi.axis is None or not root.marg[multi.axis], "flipping not supported on sharded axis"
return multi.src[0].flip([i for i,x in enumerate(root.marg) if x]).multi(multi.axis)
# from multiple devices -> one
def copy_multi(multi:UOp, device:UOp):
def copy_multi(multi:UOp, device:str | tuple[str, ...] | UOp):
assert multi.axis is not None, "all multi ops have axis"
return multi.src[0]._unshard(multi.axis).allreduce(Ops.ADD, device)