From 25440f0f72b8ae98925313c58f4ff0456ee54789 Mon Sep 17 00:00:00 2001
From: nimlgen <138685161+nimlgen@users.noreply.github.com>
Date: Wed, 31 Dec 2025 16:38:32 +0300
Subject: [PATCH] all2all (#13902)

* all2all

* um

* fix

* x

* um

* simler

* mypy

* fix

* t

* cmnts
---
 ...xternal_benchmark_multitensor_allreduce.py | 38 ++++++------
 test/test_multitensor.py                      |  5 ++
 tinygrad/helpers.py                           |  4 +-
 tinygrad/runtime/graph/hcq.py                 | 21 +++++---
 tinygrad/runtime/support/hcq.py               |  2 +-
 tinygrad/schedule/multi.py                    | 54 ++++++++++---------
 6 files changed, 66 insertions(+), 58 deletions(-)

diff --git a/test/external/external_benchmark_multitensor_allreduce.py b/test/external/external_benchmark_multitensor_allreduce.py
index e2fd001a0d..7dcc210547 100644
--- a/test/external/external_benchmark_multitensor_allreduce.py
+++ b/test/external/external_benchmark_multitensor_allreduce.py
@@ -1,7 +1,7 @@
 from tinygrad import Tensor, Device, GlobalCounters, TinyJit, dtypes
-from tinygrad.helpers import getenv, Context, RING, DEBUG
+from tinygrad.helpers import getenv, Context, DEBUG
 
-def test(devs: list[str], N: int, iters:int = 10):
+def test(devs: list[str], N: int, iters:int = 10, name:str = "allreduce"):
   @TinyJit
   def f(t: Tensor) -> Tensor:
     t.sum(0).realize()
@@ -17,39 +17,33 @@ def test(devs: list[str], N: int, iters:int = 10):
     i_secs = GlobalCounters.time_sum_s
     i_gflops = GlobalCounters.global_ops/i_secs/10**9
     i_gbs = (N*4)/i_secs/10**9
-    print(f"{'ring_allreduce' if RING >= 2 else 'naive_allreduce'} iter {i+1}/{iters}: {i_secs:.6f} sec {i_gflops:.2f} GFLOP/s {i_gbs:.2f} GB/s")
+    print(f"{name} iter {i+1}/{iters}: {i_secs:.6f} sec {i_gflops:.2f} GFLOP/s {i_gbs:.2f} GB/s")
     secs += i_secs
     gflops += i_gflops
     gbs += i_gbs
   return (gflops/iters, gbs/iters, secs/iters)
 
-def run(sz, n_gpus=6, iters=10, use_ring=False):
+def run(sz, n_gpus=6, iters=10, ring=0, all2all=0):
   devs = tuple([f"{Device.DEFAULT}:{x}" for x in range(n_gpus)])
   N = sz // dtypes.float32.itemsize
-  with Context(RING=(2 if use_ring else 0), DEBUG=max(DEBUG.value, 2)): return test(devs, N, iters=iters)
+  name = "all2all" if all2all else ("ring" if ring else "naive")
+  with Context(RING=(2 if ring else 0), ALL2ALL=(2 if all2all else 0), JIT_BATCH_SIZE=0, DEBUG=max(DEBUG.value, 2)):
+    return test(devs, N, iters=iters, name=name)
 
 def main():
-  ONLY_RING = getenv("ONLY_RING", 0)
   n_gpus = getenv("GPUS", 6)
   iters = getenv("ITERS", 10)
+  sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu
+  print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
 
-  if getenv("BENCHMARK_SPLIT"):
-    l, r = 0, 512
-    while r - l > 1:
-      m = (l + r) // 2
-      (ring_gflops, ring_gbs, ring_secs) = run(m * 1024 * 4, n_gpus=n_gpus, iters=100, use_ring=True)
-      (naive_gflops, naive_gbs, naive_secs) = run(m * 1024 * 4, n_gpus=n_gpus, iters=100, use_ring=False)
-      if ring_secs > naive_secs: l = m
-      else: r = m
-    print("Better split", r * 1024, "elements")
-  else:
-    sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu
-    print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
-    (ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True, n_gpus=n_gpus, iters=iters)
-    if not ONLY_RING: (naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False, n_gpus=n_gpus, iters=iters)
-    print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s")
-    if not ONLY_RING: print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")
+
+  results = {}
+  for name, kwargs in [("naive", {}), ("ring", {"ring": 2}), ("all2all", {"all2all": 2})]:
+    results[name] = run(sz, n_gpus=n_gpus, iters=iters, **kwargs)
+
+  print("\n=== RESULTS ===")
+  for name, (gflops, gbs, secs) in results.items():
+    print(f"{name.upper()}:\n {secs:.6f} seconds/iter\n {gflops:.2f} GFLOP/s\n {gbs:.2f} GB/s")
 
 if __name__ == "__main__":
   main()
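Note on the benchmark rewrite: the allreduce strategy is now selected purely through tinygrad's Context vars (RING, ALL2ALL) instead of a use_ring flag, and all three strategies are timed in one run. A minimal standalone sketch of the same selection pattern outside the benchmark harness (illustrative shapes; assumes at least two devices of the default backend are available):

from tinygrad import Tensor, Device
from tinygrad.helpers import Context

devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
for ctx in ({"RING": 0, "ALL2ALL": 0}, {"RING": 2}, {"ALL2ALL": 2}):  # naive, ring, all2all
  with Context(**ctx):
    t = Tensor.ones(256, 256).shard(devs, axis=0).contiguous().realize()
    out = t.sum(0).realize()  # reducing over the sharded axis triggers an allreduce
    print(ctx, out.numpy()[0])  # 256.0 under every strategy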
diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index e090766904..3916ca9115 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -256,6 +256,11 @@ class TestMultiTensor(unittest.TestCase):
     a,b = _test_allreduce(Tensor.rand(256, 256))
     np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
 
+  def test_allreduce_all2all(self):
+    with Context(ALL2ALL=2):
+      a,b = _test_allreduce(Tensor.rand(256, 256))
+      np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
+
   def test_copy_jit(self):
     @TinyJit
     def copy_tensor(x:Tensor): return (x.to(f"{x.device.split(':')[0]}:1") + 1)
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index e997985ac3..5e08ef2fa4 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -181,8 +181,8 @@ JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVa
 WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
 USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
 TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
-SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
-LRU = ContextVar("LRU", 1)
+SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("LRU", 1)
+RING, ALL2ALL = ContextVar("RING", 1), ContextVar("ALL2ALL", 0)
 CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
 VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
 CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
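Note on the helpers.py change: RING keeps its default of 1 (ring allreduce is chosen when the node-count/size heuristic in multi.py favors it), while the new ALL2ALL defaults to 0 (off); by the multi.py logic below, ALL2ALL=1 opts in via the same heuristic, ALL2ALL=2 forces it, and all2all takes precedence over ring. Both are ordinary tinygrad ContextVars, so they can be set from the environment (e.g. ALL2ALL=2 python train.py) or scoped in code; a small illustrative sketch:

from tinygrad.helpers import Context, RING, ALL2ALL

print(RING.value, ALL2ALL.value)  # defaults: 1 0 (absent env overrides)
with Context(ALL2ALL=2):          # force all2all allreduce inside this block
  assert ALL2ALL.value == 2
assert ALL2ALL.value == 0         # the previous value is restored on exit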
diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py
index f1093917c3..6eaf1fa7b9 100644
--- a/tinygrad/runtime/graph/hcq.py
+++ b/tinygrad/runtime/graph/hcq.py
@@ -1,6 +1,6 @@
-import collections, time
+import collections, itertools, time
 from typing import Any, cast
-from tinygrad.helpers import round_up, PROFILE, merge_dicts, getenv, dedup, suppress_finalizing
+from tinygrad.helpers import round_up, PROFILE, ALL2ALL, merge_dicts, getenv, dedup, suppress_finalizing
 from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator, MMIOInterface
 from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
 from tinygrad.dtype import dtypes
@@ -49,7 +49,9 @@ class HCQGraph(MultiGraphRunner):
     self.ji_schedule: dict[int, tuple[HCQCompiled, HWQueue, list, list, HCQSignal, int|None]] = {}
 
     self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
-    self.copy_queues: dict[HCQCompiled, HWQueue] = {} # lazy allocation
+    self.copy_queues: dict[tuple[HCQCompiled, int], HWQueue] = {} # lazy allocation, keyed by (device, queue_idx)
+    self.num_copy_queues: int = getenv("HCQ_NUM_SDMA", 2 if ALL2ALL >= 1 else 1)
+    self.copy_queue_cnt: collections.defaultdict[HCQCompiled, itertools.count] = collections.defaultdict(itertools.count)
 
     self.signals: dict[Any, HCQSignal] = {**{dev: dev.new_signal(value=0) for dev in self.devices if not dev._is_cpu()},
       **{"KICK": self.devices[0].new_signal(value=0)}, **{dev: self.devices[0].new_signal(value=0) for dev in self.devices if dev._is_cpu()}}
@@ -85,7 +87,8 @@ class HCQGraph(MultiGraphRunner):
         enqueue_queue = self.comp_queues[enqueue_dev]
       else:
         assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
-        enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
+        queue_idx = next(self.copy_queue_cnt[enqueue_dev]) % self.num_copy_queues
+        enqueue_queue = self.copy_queues.setdefault((enqueue_dev, queue_idx), enqueue_dev.hw_copy_queue_t(queue_idx=queue_idx))
 
       out_signal = self.signals.setdefault(enqueue_queue, self.devices[0].new_signal(value=0))
 
@@ -175,14 +178,17 @@ class HCQGraph(MultiGraphRunner):
 
     for dev in self.devices:
       for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
-        if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)
+        for copy_q in self._dev_copy_queues(dep_dev):
+          if copy_q in self.signals: self.comp_queues[dev].wait(self.signals[copy_q], cast(int, last_j[copy_q]) + 1)
 
       self.comp_queues[dev].signal(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev] + 1).bind(dev)
-      if dev in self.copy_queues: self.copy_queues[dev].bind(dev)
+      for copy_q in self._dev_copy_queues(dev): copy_q.bind(dev)
 
     self.last_timeline: dict[HCQCompiled, tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
     self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]
 
+  def _dev_copy_queues(self, dev): return [q for (d, _), q in self.copy_queues.items() if d == dev]
+
   def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
     # Wait and restore signals
     self.kickoff_value += 1
@@ -205,8 +211,7 @@ class HCQGraph(MultiGraphRunner):
 
     for dev in self.devices:
       self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {}))
-      if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals_local)
-
+      for copy_queue in self._dev_copy_queues(dev): copy_queue.submit(dev, hcq_var_vals_local)
       self.last_timeline[dev] = (dev.timeline_signal, dev.next_timeline())
 
     if wait:
diff --git a/tinygrad/runtime/support/hcq.py b/tinygrad/runtime/support/hcq.py
index 43c73ddf7c..f64dbd6188 100644
--- a/tinygrad/runtime/support/hcq.py
+++ b/tinygrad/runtime/support/hcq.py
@@ -354,7 +354,7 @@ class HCQCompiled(Compiled, Generic[SignalType]):
   cpu_devices: list[HCQCompiled] = []
 
   def __init__(self, device:str, allocator:HCQAllocatorBase, compilers:CompilerSet, runtime, signal_t:Type[SignalType],
-               comp_queue_t:Callable[[], HWQueue], copy_queue_t:Callable[[], HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000):
+               comp_queue_t:Callable[..., HWQueue], copy_queue_t:Callable[..., HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000):
     self.device_id:int = int(device.split(":")[1]) if ":" in device else 0
 
     from tinygrad.runtime.graph.hcq import HCQGraph
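Note on the runtime changes: copy_queues is now keyed by (device, queue_idx), and copy jobs on a device are spread round-robin across HCQ_NUM_SDMA hardware copy queues (default 2 when ALL2ALL is enabled, else 1), so the many independent peer-to-peer copies an all2all schedule produces can proceed on separate SDMA engines; the Callable[..., HWQueue] change in support/hcq.py is what lets queue constructors accept the new queue_idx argument. The round-robin keying in isolation, with an illustrative stub standing in for hw_copy_queue_t:

import collections, itertools

class StubCopyQueue:  # stand-in for a real hw_copy_queue_t
  def __init__(self, queue_idx:int): self.queue_idx = queue_idx

num_copy_queues = 2
copy_queues: dict[tuple[str, int], StubCopyQueue] = {}
copy_queue_cnt: collections.defaultdict[str, itertools.count] = collections.defaultdict(itertools.count)

def queue_for(dev:str) -> StubCopyQueue:
  # per-device counter wraps around num_copy_queues, lazily creating each queue
  queue_idx = next(copy_queue_cnt[dev]) % num_copy_queues
  return copy_queues.setdefault((dev, queue_idx), StubCopyQueue(queue_idx))

assert [queue_for("DEV:0").queue_idx for _ in range(4)] == [0, 1, 0, 1]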
diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py
index 23a81ea6c7..a989466983 100644
--- a/tinygrad/schedule/multi.py
+++ b/tinygrad/schedule/multi.py
@@ -1,6 +1,6 @@
 from typing import cast
 import functools, itertools, operator
-from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv
+from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, getenv
 from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
 from tinygrad.device import Device
 
@@ -35,45 +35,49 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
   if not isinstance(buf.device, tuple): return None
   assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}"
   n_lbs, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
+
   # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
   # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
-  use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
-  if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {buf.dtype}")
+  use_all2all = (ALL2ALL >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1))
+  use_ring = not use_all2all and (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
+  if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {n_lbs}x{numel} | {buf.dtype}")
 
   # contiguous before we copy it
   buf = buf.contiguous()
 
-  # copy to all devices. if you shrink later, that'll be handled
-  if not use_ring: return functools.reduce(lambda x,y: x.alu(red.arg, y),
-    [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(len(buf.device))])
+  # naive: copy to all devices. if you shrink later, that'll be handled
+  if not use_ring and not use_all2all:
+    return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(n_lbs)])
 
-  # new ring reduce
+  # chunk data into n_lbs pieces
   factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
   base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
-  chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
-  chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
+  chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (n_lbs - left), initial=0)))
 
-  # extract chunks and scatter-reduce
+  # reduce-scatter
   reduced_chunks = []
   for i,(s,e) in enumerate(chunks):
-    chunk = buf.reshape((numel,)).shrink(((s,e),))
-    reduced_chunk = chunk
-    for step in range(n_lbs-1):
-      src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
-      # copy the chunk from the src device to the dest (operating device), and select the chunk on the dest device
-      reduced_chunk = reduced_chunk.copy_to_device(buf.device[dest], src if isinstance(reduced_chunk.device, tuple) else None) \
-        .alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
-    reduced_chunks.append(reduced_chunk)
+    if use_all2all:
+      chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(n_lbs)]
+      reduced_chunks.append(functools.reduce(lambda x,y: x.alu(red.arg, y), chunks_on_i))
+    else:
+      chunk, reduced = buf.reshape((numel,)).shrink(((s,e),)), buf.reshape((numel,)).shrink(((s,e),))
+      for step in range(n_lbs-1):
+        src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
+        cp = reduced.copy_to_device(buf.device[dest], src if isinstance(reduced.device, tuple) else None)
+        reduced = cp.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
+      reduced_chunks.append(reduced)
 
   # allgather
   copied_chunks = []
-  for i,c in enumerate(reduced_chunks):
-    this_chunk: list[UOp|None] = [None] * len(buf.device)
-    this_chunk[(i+len(buf.device)-1)%n_lbs] = c
-    for step in range(n_lbs-1):
-      dest = (i+step)%n_lbs
-      this_chunk[dest] = c = c.copy_to_device(buf.device[dest])
-    copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
+  for i,rc in enumerate(reduced_chunks):
+    if use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs))))
+    else:
+      this_chunk: list[UOp|None] = [None] * n_lbs
+      this_chunk[(i+n_lbs-1)%n_lbs] = rc
+      for step in range(n_lbs-1):
+        this_chunk[(i+step)%n_lbs] = rc = rc.copy_to_device(buf.device[(i+step)%n_lbs])
+      copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
 
   # reassemble
   pads = [((s,numel-e),) for s,e in chunks]
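For intuition on the new multi.py schedule: with n shards, all2all reduce-scatter has the owner of chunk c pull that chunk directly from all n devices and reduce locally, so every transfer is one hop and all chunks move in parallel, whereas ring reduce forwards a partial sum to the next neighbor for n-1 dependent steps; both then allgather the reduced chunks. Each device still moves roughly 2*(n-1)/n of the buffer either way, so the apparent win is parallelism across links and SDMA queues rather than less traffic. A pure-Python simulation of the all2all data movement (illustrative only, no tinygrad):

def all2all_allreduce(shards: list[list[float]]) -> list[list[float]]:
  # shards[d][c] is chunk c resident on device d
  n = len(shards)
  # reduce-scatter: device c copies chunk c from every device and sums it locally
  reduced = [sum(shards[d][c] for d in range(n)) for c in range(n)]
  # allgather: each reduced chunk is copied directly from its owner to all peers
  return [list(reduced) for _ in range(n)]

out = all2all_allreduce([[1., 2., 3.], [10., 20., 30.], [100., 200., 300.]])
assert out[0] == out[1] == out[2] == [111., 222., 333.]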