mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
use copy_multi in alu_multi [pr] (#15143)
* use copy_multi in alu_multi [pr]
* copy to anything
This commit is contained in:
@@ -59,8 +59,8 @@ def alu_multi(root:UOp):
|
||||
# same axis, just copy through
|
||||
srcs.append(mlb.src[0])
|
||||
else:
|
||||
-# axis mismatch, unshard it, send it to all devices, and shard it correctly
|
||||
-srcs.append(mlb.src[0]._unshard(mlb.axis).allreduce(Ops.ADD, mlb.device)._shard(axis))
|
||||
+# axis mismatch, copy to all devices, and shard it correctly
|
||||
+srcs.append(copy_multi(mlb, mlb.device)._shard(axis))
|
||||
return srcs[0].alu(root.op, *srcs[1:]).multi(axis)
|
||||
|
||||
def reduce_multi(root:UOp, multi:UOp):
|
||||
@@ -103,8 +103,7 @@ def flip_multi(root:UOp, multi:UOp):
|
||||
assert multi.axis is None or not root.marg[multi.axis], "flipping not supported on sharded axis"
|
||||
return multi.src[0].flip([i for i,x in enumerate(root.marg) if x]).multi(multi.axis)
|
||||
|
||||
-# from multiple devices -> one
|
||||
-def copy_multi(multi:UOp, device:UOp):
|
||||
+def copy_multi(multi:UOp, device:str | tuple[str, ...] | UOp):
|
||||
assert multi.axis is not None, "all multi ops have axis"
|
||||
return multi.src[0]._unshard(multi.axis).allreduce(Ops.ADD, device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user