improve types and simplify allgather in multi [pr] (#14878)

This commit is contained in:
chenyu
2026-02-19 09:02:15 -05:00
committed by GitHub
parent 9317e96881
commit 877a5d4c45

View File

@@ -1,4 +1,3 @@
from typing import cast
import functools, itertools
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, VIZ, getenv
from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
@@ -29,7 +28,7 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0)))
# reduce-scatter
reduced_chunks = []
reduced_chunks:list[UOp] = []
for i,(s,e) in enumerate(chunks):
if use_all2all:
chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(ndev)]
@@ -43,16 +42,15 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
reduced_chunks.append(reduced)
# allgather
copied_chunks = []
copied_chunks:list[UOp] = []
for i,rc in enumerate(reduced_chunks):
if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev))))
else:
this_chunk: list[UOp|None] = [None] * ndev
this_chunk[(i+ndev-1)%ndev] = rc
chain:list[UOp] = [rc]
for step in range(ndev-1):
this_chunk[(i+step)%ndev] = rc = rc.copy_to_device(buf.device[(i+step)%ndev])
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev]))
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev))))
# reassemble
return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape)
@@ -97,7 +95,7 @@ def alu_multi(root:UOp):
axis = root.axis
assert axis is not None
srcs = []
srcs:list[UOp] = []
for mlb in msrcs:
if mlb.axis == axis:
# same axis, just copy through