From 877a5d4c455f7e26d5575fa9f0b9c207c165c20a Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 19 Feb 2026 09:02:15 -0500 Subject: [PATCH] improve types and simplify allgather in multi [pr] (#14878) --- tinygrad/schedule/multi.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index 7d3b3fada3..b6ca8b85aa 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -1,4 +1,3 @@ -from typing import cast import functools, itertools from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, VIZ, getenv from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite @@ -29,7 +28,7 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None: chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0))) # reduce-scatter - reduced_chunks = [] + reduced_chunks:list[UOp] = [] for i,(s,e) in enumerate(chunks): if use_all2all: chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(ndev)] @@ -43,16 +42,15 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None: reduced_chunks.append(reduced) # allgather - copied_chunks = [] + copied_chunks:list[UOp] = [] for i,rc in enumerate(reduced_chunks): if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg)) elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev)))) else: - this_chunk: list[UOp|None] = [None] * ndev - this_chunk[(i+ndev-1)%ndev] = rc + chain:list[UOp] = [rc] for step in range(ndev-1): - this_chunk[(i+step)%ndev] = rc = rc.copy_to_device(buf.device[(i+step)%ndev]) - copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk)))) + chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev])) + copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev)))) # reassemble return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape) @@ -97,7 +95,7 @@ def alu_multi(root:UOp): axis = root.axis assert axis is not None - srcs = [] + srcs:list[UOp] = [] for mlb in msrcs: if mlb.axis == axis: # same axis, just copy through