mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
improve types and simplify allgather in multi [pr] (#14878)
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from typing import cast
|
||||
import functools, itertools
|
||||
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, VIZ, getenv
|
||||
from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
|
||||
@@ -29,7 +28,7 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
|
||||
chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0)))
|
||||
|
||||
# reduce-scatter
|
||||
reduced_chunks = []
|
||||
reduced_chunks:list[UOp] = []
|
||||
for i,(s,e) in enumerate(chunks):
|
||||
if use_all2all:
|
||||
chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(ndev)]
|
||||
@@ -43,16 +42,15 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
|
||||
reduced_chunks.append(reduced)
|
||||
|
||||
# allgather
|
||||
copied_chunks = []
|
||||
copied_chunks:list[UOp] = []
|
||||
for i,rc in enumerate(reduced_chunks):
|
||||
if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
|
||||
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev))))
|
||||
else:
|
||||
this_chunk: list[UOp|None] = [None] * ndev
|
||||
this_chunk[(i+ndev-1)%ndev] = rc
|
||||
chain:list[UOp] = [rc]
|
||||
for step in range(ndev-1):
|
||||
this_chunk[(i+step)%ndev] = rc = rc.copy_to_device(buf.device[(i+step)%ndev])
|
||||
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
|
||||
chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev]))
|
||||
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev))))
|
||||
|
||||
# reassemble
|
||||
return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape)
|
||||
@@ -97,7 +95,7 @@ def alu_multi(root:UOp):
|
||||
axis = root.axis
|
||||
assert axis is not None
|
||||
|
||||
srcs = []
|
||||
srcs:list[UOp] = []
|
||||
for mlb in msrcs:
|
||||
if mlb.axis == axis:
|
||||
# same axis, just copy through
|
||||
|
||||
Reference in New Issue
Block a user