* all2all

* um

* fix

* x

* um

* simpler

* mypy

* fix

* t

* comments
Author: nimlgen
Date: 2025-12-31 16:38:32 +03:00
Committed by: GitHub
Parent: f7ee644950
Commit: 25440f0f72
6 changed files with 66 additions and 58 deletions
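
In short: this adds an all2all allreduce schedule alongside the existing naive and ring ones. A new ALL2ALL ContextVar (helpers.py) selects it, handle_allreduce (multi.py) gains an all2all reduce-scatter/allgather path, HCQGraph (runtime/graph/hcq.py) can now spread copies round-robin across multiple SDMA copy queues per device (HCQ_NUM_SDMA, defaulting to 2 when ALL2ALL is enabled), a test is added in test_multitensor.py, and the allreduce benchmark is reworked to compare all three schedules in one run.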

--- a/test/external/external_benchmark_multitensor_allreduce.py
+++ b/test/external/external_benchmark_multitensor_allreduce.py

@@ -1,7 +1,7 @@
 from tinygrad import Tensor, Device, GlobalCounters, TinyJit, dtypes
-from tinygrad.helpers import getenv, Context, RING, DEBUG
+from tinygrad.helpers import getenv, Context, DEBUG

-def test(devs: list[str], N: int, iters:int = 10):
+def test(devs: list[str], N: int, iters:int = 10, name:str = "allreduce"):
   @TinyJit
   def f(t: Tensor) -> Tensor: t.sum(0).realize()
@@ -17,39 +17,33 @@ def test(devs: list[str], N: int, iters:int = 10):
     i_secs = GlobalCounters.time_sum_s
     i_gflops = GlobalCounters.global_ops/i_secs/10**9
     i_gbs = (N*4)/i_secs/10**9
-    print(f"{'ring_allreduce' if RING >= 2 else 'naive_allreduce'} iter {i+1}/{iters}: {i_secs:.6f} sec {i_gflops:.2f} GFLOP/s {i_gbs:.2f} GB/s")
+    print(f"{name} iter {i+1}/{iters}: {i_secs:.6f} sec {i_gflops:.2f} GFLOP/s {i_gbs:.2f} GB/s")
     secs += i_secs
     gflops += i_gflops
     gbs += i_gbs
   return (gflops/iters, gbs/iters, secs/iters)

-def run(sz, n_gpus=6, iters=10, use_ring=False):
+def run(sz, n_gpus=6, iters=10, ring=0, all2all=0):
   devs = tuple([f"{Device.DEFAULT}:{x}" for x in range(n_gpus)])
   N = sz // dtypes.float32.itemsize
-  with Context(RING=(2 if use_ring else 0), DEBUG=max(DEBUG.value, 2)): return test(devs, N, iters=iters)
+  name = "all2all" if all2all else ("ring" if ring else "naive")
+  with Context(RING=(2 if ring else 0), ALL2ALL=(2 if all2all else 0), JIT_BATCH_SIZE=0, DEBUG=max(DEBUG.value, 2)):
+    return test(devs, N, iters=iters, name=name)

 def main():
-  ONLY_RING = getenv("ONLY_RING", 0)
   n_gpus = getenv("GPUS", 6)
   iters = getenv("ITERS", 10)
+  sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu
+  print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")

-  if getenv("BENCHMARK_SPLIT"):
-    l, r = 0, 512
-    while r - l > 1:
-      m = (l + r) // 2
-      (ring_gflops, ring_gbs, ring_secs) = run(m * 1024 * 4, n_gpus=n_gpus, iters=100, use_ring=True)
-      (naive_gflops, naive_gbs, naive_secs) = run(m * 1024 * 4, n_gpus=n_gpus, iters=100, use_ring=False)
-      if ring_secs > naive_secs: l = m
-      else: r = m
-    print("Better split", r * 1024, "elements")
-  else:
-    sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu
-    print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
-    (ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True, n_gpus=n_gpus, iters=iters)
-    if not ONLY_RING: (naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False, n_gpus=n_gpus, iters=iters)
-    print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s")
-    if not ONLY_RING: print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")
+  results = {}
+  for name, kwargs in [("naive", {}), ("ring", {"ring": 2}), ("all2all", {"all2all": 2})]:
+    results[name] = run(sz, n_gpus=n_gpus, iters=iters, **kwargs)
+
+  print("\n=== RESULTS ===")
+  for name, (gflops, gbs, secs) in results.items():
+    print(f"{name.upper()}:\n {secs:.6f} seconds/iter\n {gflops:.2f} GFLOP/s\n {gbs:.2f} GB/s")

 if __name__ == "__main__":
   main()
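
With the BENCHMARK_SPLIT and ONLY_RING paths gone, a single run now benchmarks all three schedules back to back. Assuming the script path (not preserved on this page) is test/external/external_benchmark_multitensor_allreduce.py, a typical invocation is:

    GPUS=6 ITERS=10 SZ=1000 python test/external/external_benchmark_multitensor_allreduce.py

where GPUS, ITERS, and SZ (megabytes per GPU) map onto the getenv calls above.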

--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py

@@ -256,6 +256,11 @@ class TestMultiTensor(unittest.TestCase):
       a,b = _test_allreduce(Tensor.rand(256, 256))
       np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)

+  def test_allreduce_all2all(self):
+    with Context(ALL2ALL=2):
+      a,b = _test_allreduce(Tensor.rand(256, 256))
+      np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
+
   def test_copy_jit(self):
     @TinyJit
     def copy_tensor(x:Tensor): return (x.to(f"{x.device.split(':')[0]}:1") + 1)

--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py

@@ -181,8 +181,8 @@ JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVa
 WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
 USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
 TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
-SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
-LRU = ContextVar("LRU", 1)
+SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("LRU", 1)
+RING, ALL2ALL = ContextVar("RING", 1), ContextVar("ALL2ALL", 0)
 CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
 VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
 CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
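
ALL2ALL follows the same convention RING already uses: 0 disables the path, 1 enables it behind the heuristic in multi.py (more than 2 devices and more than RING_ALLREDUCE_THRESHOLD elements), and 2 forces it, with ALL2ALL taking precedence over RING. A minimal usage sketch mirroring the new test; the device names and shard call are illustrative, not from this diff:

    from tinygrad import Tensor
    from tinygrad.helpers import Context

    # force the all2all schedule for any allreduce issued in this scope
    with Context(ALL2ALL=2):
      t = Tensor.rand(256, 256).shard(("NV:0", "NV:1", "NV:2", "NV:3"), axis=0)
      t.sum(0).realize()  # reducing over the sharded axis triggers handle_allreduce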

--- a/tinygrad/runtime/graph/hcq.py
+++ b/tinygrad/runtime/graph/hcq.py

@@ -1,6 +1,6 @@
-import collections, time
+import collections, itertools, time
 from typing import Any, cast
-from tinygrad.helpers import round_up, PROFILE, merge_dicts, getenv, dedup, suppress_finalizing
+from tinygrad.helpers import round_up, PROFILE, ALL2ALL, merge_dicts, getenv, dedup, suppress_finalizing
 from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator, MMIOInterface
 from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
 from tinygrad.dtype import dtypes
@@ -49,7 +49,9 @@ class HCQGraph(MultiGraphRunner):
     self.ji_schedule: dict[int, tuple[HCQCompiled, HWQueue, list, list, HCQSignal, int|None]] = {}
     self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
-    self.copy_queues: dict[HCQCompiled, HWQueue] = {} # lazy allocation
+    self.copy_queues: dict[tuple[HCQCompiled, int], HWQueue] = {} # lazy allocation, keyed by (device, queue_idx)
+    self.num_copy_queues: int = getenv("HCQ_NUM_SDMA", 2 if ALL2ALL >= 1 else 1)
+    self.copy_queue_cnt: collections.defaultdict[HCQCompiled, itertools.count] = collections.defaultdict(itertools.count)

     self.signals: dict[Any, HCQSignal] = {**{dev: dev.new_signal(value=0) for dev in self.devices if not dev._is_cpu()},
                                           **{"KICK": self.devices[0].new_signal(value=0)}, **{dev: self.devices[0].new_signal(value=0) for dev in self.devices if dev._is_cpu()}}
@@ -85,7 +87,8 @@ class HCQGraph(MultiGraphRunner):
         enqueue_queue = self.comp_queues[enqueue_dev]
       else:
         assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
-        enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
+        queue_idx = next(self.copy_queue_cnt[enqueue_dev]) % self.num_copy_queues
+        enqueue_queue = self.copy_queues.setdefault((enqueue_dev, queue_idx), enqueue_dev.hw_copy_queue_t(queue_idx=queue_idx))

       out_signal = self.signals.setdefault(enqueue_queue, self.devices[0].new_signal(value=0))
@@ -175,14 +178,17 @@ class HCQGraph(MultiGraphRunner):
     for dev in self.devices:
       for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
-        if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)
+        for copy_q in self._dev_copy_queues(dep_dev):
+          if copy_q in self.signals: self.comp_queues[dev].wait(self.signals[copy_q], cast(int, last_j[copy_q]) + 1)

       self.comp_queues[dev].signal(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev] + 1).bind(dev)
-      if dev in self.copy_queues: self.copy_queues[dev].bind(dev)
+      for copy_q in self._dev_copy_queues(dev): copy_q.bind(dev)

     self.last_timeline: dict[HCQCompiled, tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
     self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]

+  def _dev_copy_queues(self, dev): return [q for (d, _), q in self.copy_queues.items() if d == dev]
+
   def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
     # Wait and restore signals
     self.kickoff_value += 1
@@ -205,8 +211,7 @@ class HCQGraph(MultiGraphRunner):
     for dev in self.devices:
       self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {}))
-      if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals_local)
+      for copy_queue in self._dev_copy_queues(dev): copy_queue.submit(dev, hcq_var_vals_local)
       self.last_timeline[dev] = (dev.timeline_signal, dev.next_timeline())

     if wait:
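
The queue bookkeeping here is compact enough to miss: collections.defaultdict(itertools.count) gives every device its own 0, 1, 2, ... counter, and the modulo spreads successive copies round-robin over num_copy_queues SDMA queues, allocating each queue lazily on first use. A standalone sketch of just that selection logic (toy string queues, no real HWQueue):

    import collections, itertools

    num_copy_queues = 2  # the HCQ_NUM_SDMA default when ALL2ALL >= 1
    copy_queue_cnt: dict = collections.defaultdict(itertools.count)
    copy_queues: dict[tuple[str, int], str] = {}

    def pick_copy_queue(dev: str) -> str:
      queue_idx = next(copy_queue_cnt[dev]) % num_copy_queues
      # lazy allocation, keyed by (device, queue_idx), like copy_queues.setdefault above
      return copy_queues.setdefault((dev, queue_idx), f"{dev}-sdma{queue_idx}")

    print([pick_copy_queue("NV:0") for _ in range(4)])  # alternates NV:0-sdma0, NV:0-sdma1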

--- a/tinygrad/runtime/support/hcq.py
+++ b/tinygrad/runtime/support/hcq.py

@@ -354,7 +354,7 @@ class HCQCompiled(Compiled, Generic[SignalType]):
   cpu_devices: list[HCQCompiled] = []

   def __init__(self, device:str, allocator:HCQAllocatorBase, compilers:CompilerSet, runtime, signal_t:Type[SignalType],
-               comp_queue_t:Callable[[], HWQueue], copy_queue_t:Callable[[], HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000):
+               comp_queue_t:Callable[..., HWQueue], copy_queue_t:Callable[..., HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000):
     self.device_id:int = int(device.split(":")[1]) if ":" in device else 0
     from tinygrad.runtime.graph.hcq import HCQGraph
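
The loosened signatures (Callable[..., HWQueue] instead of Callable[[], HWQueue]) are what allow HCQGraph above to pass queue_idx= when it instantiates the extra copy queues.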

--- a/tinygrad/multi.py
+++ b/tinygrad/multi.py

@@ -1,6 +1,6 @@
 from typing import cast
 import functools, itertools, operator
-from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv
+from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, getenv
 from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
 from tinygrad.device import Device
@@ -35,45 +35,49 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
   if not isinstance(buf.device, tuple): return None
   assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}"
   n_lbs, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
   # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
   # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
-  use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
-  if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {buf.dtype}")
+  use_all2all = (ALL2ALL >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1))
+  use_ring = not use_all2all and (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
+  if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {n_lbs}x{numel} | {buf.dtype}")
   # contiguous before we copy it
   buf = buf.contiguous()
-  # copy to all devices. if you shrink later, that'll be handled
-  if not use_ring: return functools.reduce(lambda x,y: x.alu(red.arg, y),
-    [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(len(buf.device))])
+  # naive: copy to all devices. if you shrink later, that'll be handled
+  if not use_ring and not use_all2all:
+    return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(n_lbs)])

-  # new ring reduce
+  # chunk data into n_lbs pieces
   factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
   base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
-  chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
-  chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
+  chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (n_lbs - left), initial=0)))

-  # extract chunks and scatter-reduce
+  # reduce-scatter
   reduced_chunks = []
   for i,(s,e) in enumerate(chunks):
-    chunk = buf.reshape((numel,)).shrink(((s,e),))
-    reduced_chunk = chunk
-    for step in range(n_lbs-1):
-      src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
-      # copy the chunk from the src device to the dest (operating device), and select the chunk on the dest device
-      reduced_chunk = reduced_chunk.copy_to_device(buf.device[dest], src if isinstance(reduced_chunk.device, tuple) else None) \
-        .alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
-    reduced_chunks.append(reduced_chunk)
+    if use_all2all:
+      chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(n_lbs)]
+      reduced_chunks.append(functools.reduce(lambda x,y: x.alu(red.arg, y), chunks_on_i))
+    else:
+      chunk, reduced = buf.reshape((numel,)).shrink(((s,e),)), buf.reshape((numel,)).shrink(((s,e),))
+      for step in range(n_lbs-1):
+        src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
+        cp = reduced.copy_to_device(buf.device[dest], src if isinstance(reduced.device, tuple) else None)
+        reduced = cp.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
+      reduced_chunks.append(reduced)

   # allgather
   copied_chunks = []
-  for i,c in enumerate(reduced_chunks):
-    this_chunk: list[UOp|None] = [None] * len(buf.device)
-    this_chunk[(i+len(buf.device)-1)%n_lbs] = c
-    for step in range(n_lbs-1):
-      dest = (i+step)%n_lbs
-      this_chunk[dest] = c = c.copy_to_device(buf.device[dest])
-    copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
+  for i,rc in enumerate(reduced_chunks):
+    if use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs))))
+    else:
+      this_chunk: list[UOp|None] = [None] * n_lbs
+      this_chunk[(i+n_lbs-1)%n_lbs] = rc
+      for step in range(n_lbs-1):
+        this_chunk[(i+step)%n_lbs] = rc = rc.copy_to_device(buf.device[(i+step)%n_lbs])
+      copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))

   # reassemble
   pads = [((s,numel-e),) for s,e in chunks]
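
The comments name the phases tersely, so here is a minimal NumPy sketch (not tinygrad code; arrays stand in for per-device buffers) that reproduces the chunking plus the per-chunk data movement of both schedules and checks them against the naive elementwise sum:

    import functools, itertools
    import numpy as np

    def chunk_bounds(numel: int, n: int) -> list[tuple[int, int]]:
      # same chunking as handle_allreduce: pick a small divisor, then split as evenly as possible
      factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
      base, left = (numel // factor) // n, (numel // factor) % n
      sizes = [(base + 1) * factor] * left + [base * factor] * (n - left)
      return list(itertools.pairwise(itertools.accumulate(sizes, initial=0)))

    def ring_allreduce(bufs: list[np.ndarray]) -> np.ndarray:
      n = len(bufs)
      reduced = []
      for i, (s, e) in enumerate(chunk_bounds(bufs[0].size, n)):
        acc = bufs[i][s:e].copy()
        for step in range(n - 1):           # chunk i hops around the ring...
          dest = (i + step + 1) % n
          acc = acc + bufs[dest][s:e]       # ...accumulating each device's piece
        reduced.append(acc)                 # the allgather phase then distributes it
      return np.concatenate(reduced)

    def all2all_allreduce(bufs: list[np.ndarray]) -> np.ndarray:
      n = len(bufs)
      # device i pulls chunk i from every peer at once and reduces locally;
      # the MSTACK step then broadcasts each reduced chunk to all devices
      reduced = [functools.reduce(np.add, [b[s:e] for b in bufs]) for s, e in chunk_bounds(bufs[0].size, n)]
      return np.concatenate(reduced)

    bufs = [np.random.rand(1000).astype(np.float32) for _ in range(6)]
    np.testing.assert_allclose(ring_allreduce(bufs), sum(bufs), rtol=1e-4)
    np.testing.assert_allclose(all2all_allreduce(bufs), sum(bufs), rtol=1e-4)

Both schedules compute the same sums; the difference is the traffic pattern. Ring moves each chunk over n-1 sequential hops between neighbors, while all2all issues copies between all pairs at once, which is what the extra per-device SDMA queues in graph/hcq.py are there to absorb.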