* all2all

* um

* fix

* x

* um

* simler

* mypy

* fix

* t

* cmnts
nimlgen
2025-12-31 16:38:32 +03:00
committed by GitHub
parent f7ee644950
commit 25440f0f72
6 changed files with 66 additions and 58 deletions

View File

@@ -1,7 +1,7 @@
from tinygrad import Tensor, Device, GlobalCounters, TinyJit, dtypes
from tinygrad.helpers import getenv, Context, RING, DEBUG
from tinygrad.helpers import getenv, Context, DEBUG
def test(devs: list[str], N: int, iters:int = 10):
def test(devs: list[str], N: int, iters:int = 10, name:str = "allreduce"):
@TinyJit
def f(t: Tensor) -> Tensor: t.sum(0).realize()
@@ -17,39 +17,33 @@ def test(devs: list[str], N: int, iters:int = 10):
i_secs = GlobalCounters.time_sum_s
i_gflops = GlobalCounters.global_ops/i_secs/10**9
i_gbs = (N*4)/i_secs/10**9
print(f"{'ring_allreduce' if RING >= 2 else 'naive_allreduce'} iter {i+1}/{iters}: {i_secs:.6f} sec {i_gflops:.2f} GFLOP/s {i_gbs:.2f} GB/s")
print(f"{name} iter {i+1}/{iters}: {i_secs:.6f} sec {i_gflops:.2f} GFLOP/s {i_gbs:.2f} GB/s")
secs += i_secs
gflops += i_gflops
gbs += i_gbs
return (gflops/iters, gbs/iters, secs/iters)
def run(sz, n_gpus=6, iters=10, use_ring=False):
def run(sz, n_gpus=6, iters=10, ring=0, all2all=0):
devs = tuple([f"{Device.DEFAULT}:{x}" for x in range(n_gpus)])
N = sz // dtypes.float32.itemsize
with Context(RING=(2 if use_ring else 0), DEBUG=max(DEBUG.value, 2)): return test(devs, N, iters=iters)
name = "all2all" if all2all else ("ring" if ring else "naive")
with Context(RING=(2 if ring else 0), ALL2ALL=(2 if all2all else 0), JIT_BATCH_SIZE=0, DEBUG=max(DEBUG.value, 2)):
return test(devs, N, iters=iters, name=name)
def main():
ONLY_RING = getenv("ONLY_RING", 0)
n_gpus = getenv("GPUS", 6)
iters = getenv("ITERS", 10)
sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu
print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
if getenv("BENCHMARK_SPLIT"):
l, r = 0, 512
while r - l > 1:
m = (l + r) // 2
(ring_gflops, ring_gbs, ring_secs) = run(m * 1024 * 4, n_gpus=n_gpus, iters=100, use_ring=True)
(naive_gflops, naive_gbs, naive_secs) = run(m * 1024 * 4, n_gpus=n_gpus, iters=100, use_ring=False)
if ring_secs > naive_secs: l = m
else: r = m
print("Better split", r * 1024, "elements")
else:
sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu
print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
(ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True, n_gpus=n_gpus, iters=iters)
if not ONLY_RING: (naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False, n_gpus=n_gpus, iters=iters)
print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s")
if not ONLY_RING: print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")
results = {}
for name, kwargs in [("naive", {}), ("ring", {"ring": 2}), ("all2all", {"all2all": 2})]:
results[name] = run(sz, n_gpus=n_gpus, iters=iters, **kwargs)
print("\n=== RESULTS ===")
for name, (gflops, gbs, secs) in results.items():
print(f"{name.upper()}:\n {secs:.6f} seconds/iter\n {gflops:.2f} GFLOP/s\n {gbs:.2f} GB/s")
if __name__ == "__main__":
main()
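Editor's note: the benchmark now sweeps all three strategies through the Context flags instead of the old ring/naive toggle. As a rough, self-contained sketch of that pattern (assuming only the Context, RING and ALL2ALL names visible in this diff, plus Tensor.shard; not part of the commit itself):

# Sketch: force one allreduce strategy per run via Context.
# Flag values mirror the diff: 2 forces a strategy, 0 disables it.
from tinygrad import Tensor, Device
from tinygrad.helpers import Context

def bench_once(strategy: str, n_gpus: int = 2, numel: int = 1_000_000):
  devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(n_gpus))
  flags = {"naive": dict(RING=0, ALL2ALL=0), "ring": dict(RING=2, ALL2ALL=0), "all2all": dict(ALL2ALL=2)}[strategy]
  with Context(**flags):
    t = Tensor.rand(n_gpus, numel).shard(devs, axis=0)   # one shard per device
    t.sum(0).realize()                                   # reduce over the sharded axis -> allreduce

for s in ("naive", "ring", "all2all"): bench_once(s)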

View File

@@ -256,6 +256,11 @@ class TestMultiTensor(unittest.TestCase):
a,b = _test_allreduce(Tensor.rand(256, 256))
np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
def test_allreduce_all2all(self):
with Context(ALL2ALL=2):
a,b = _test_allreduce(Tensor.rand(256, 256))
np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
def test_copy_jit(self):
@TinyJit
def copy_tensor(x:Tensor): return (x.to(f"{x.device.split(':')[0]}:1") + 1)
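Editor's note: the new test only flips ALL2ALL around the existing _test_allreduce helper, which is not part of this diff. A rough sketch of the kind of check such a helper performs (sharded sum against a single-device reference; names here are illustrative, not the repo's):

# Illustrative only: _test_allreduce itself is not shown in this diff.
import numpy as np
from tinygrad import Tensor, Device
from tinygrad.helpers import Context

def check_all2all_allreduce(n_gpus: int = 2):
  devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(n_gpus))
  x = Tensor.rand(256, 256)
  expected = x.numpy().sum(0)                        # single-device reference
  with Context(ALL2ALL=2):                           # force the all2all lowering
    got = x.shard(devs, axis=0).sum(0).numpy()       # sum along the sharded axis -> allreduce
  np.testing.assert_almost_equal(got, expected, decimal=5)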

View File

@@ -181,8 +181,8 @@ JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVa
WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
LRU = ContextVar("LRU", 1)
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("LRU", 1)
RING, ALL2ALL = ContextVar("RING", 1), ContextVar("ALL2ALL", 0)
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
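Editor's note: ALL2ALL is an ordinary tinygrad ContextVar alongside RING, so it can be set through the environment or scoped with Context. A minimal usage sketch, assuming only the defaults shown in this hunk:

# ALL2ALL behaves like any other ContextVar: env-var default, scoped Context override.
from tinygrad.helpers import Context, ALL2ALL, RING

print(RING.value, ALL2ALL.value)    # defaults per this diff: 1 and 0 (RING/ALL2ALL env vars override)
with Context(RING=0, ALL2ALL=2):    # scoped override, restored on exit
  print(RING.value, ALL2ALL.value)  # -> 0 2
print(RING.value, ALL2ALL.value)    # back to the defaults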

View File

@@ -1,6 +1,6 @@
import collections, time
import collections, itertools, time
from typing import Any, cast
from tinygrad.helpers import round_up, PROFILE, merge_dicts, getenv, dedup, suppress_finalizing
from tinygrad.helpers import round_up, PROFILE, ALL2ALL, merge_dicts, getenv, dedup, suppress_finalizing
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator, MMIOInterface
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.dtype import dtypes
@@ -49,7 +49,9 @@ class HCQGraph(MultiGraphRunner):
self.ji_schedule: dict[int, tuple[HCQCompiled, HWQueue, list, list, HCQSignal, int|None]] = {}
self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
self.copy_queues: dict[HCQCompiled, HWQueue] = {} # lazy allocation
self.copy_queues: dict[tuple[HCQCompiled, int], HWQueue] = {} # lazy allocation, keyed by (device, queue_idx)
self.num_copy_queues: int = getenv("HCQ_NUM_SDMA", 2 if ALL2ALL >= 1 else 1)
self.copy_queue_cnt: collections.defaultdict[HCQCompiled, itertools.count] = collections.defaultdict(itertools.count)
self.signals: dict[Any, HCQSignal] = {**{dev: dev.new_signal(value=0) for dev in self.devices if not dev._is_cpu()},
**{"KICK": self.devices[0].new_signal(value=0)}, **{dev: self.devices[0].new_signal(value=0) for dev in self.devices if dev._is_cpu()}}
@@ -85,7 +87,8 @@ class HCQGraph(MultiGraphRunner):
enqueue_queue = self.comp_queues[enqueue_dev]
else:
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
queue_idx = next(self.copy_queue_cnt[enqueue_dev]) % self.num_copy_queues
enqueue_queue = self.copy_queues.setdefault((enqueue_dev, queue_idx), enqueue_dev.hw_copy_queue_t(queue_idx=queue_idx))
out_signal = self.signals.setdefault(enqueue_queue, self.devices[0].new_signal(value=0))
@@ -175,14 +178,17 @@ class HCQGraph(MultiGraphRunner):
for dev in self.devices:
for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)
for copy_q in self._dev_copy_queues(dep_dev):
if copy_q in self.signals: self.comp_queues[dev].wait(self.signals[copy_q], cast(int, last_j[copy_q]) + 1)
self.comp_queues[dev].signal(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev] + 1).bind(dev)
if dev in self.copy_queues: self.copy_queues[dev].bind(dev)
for copy_q in self._dev_copy_queues(dev): copy_q.bind(dev)
self.last_timeline: dict[HCQCompiled, tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]
def _dev_copy_queues(self, dev): return [q for (d, _), q in self.copy_queues.items() if d == dev]
def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
# Wait and restore signals
self.kickoff_value += 1
@@ -205,8 +211,7 @@ class HCQGraph(MultiGraphRunner):
for dev in self.devices:
self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {}))
if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals_local)
for copy_queue in self._dev_copy_queues(dev): copy_queue.submit(dev, hcq_var_vals_local)
self.last_timeline[dev] = (dev.timeline_signal, dev.next_timeline())
if wait:
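Editor's note: the graph change spreads one device's copies over several hardware copy queues. Queues are now keyed by (device, queue_idx), HCQ_NUM_SDMA controls how many exist (defaulting to 2 when ALL2ALL is enabled), and a per-device itertools.count assigns them round-robin. A standalone sketch of just that selection logic, with strings standing in for the real HWQueue objects:

# Round-robin copy-queue selection, isolated from the graph code above.
# Strings stand in for the real dev.hw_copy_queue_t(queue_idx=...) objects.
import collections, itertools

num_copy_queues = 2                                         # HCQ_NUM_SDMA
copy_queues: dict[tuple[str, int], str] = {}                # keyed by (device, queue_idx)
copy_queue_cnt = collections.defaultdict(itertools.count)   # one counter per device

def queue_for(dev: str) -> str:
  queue_idx = next(copy_queue_cnt[dev]) % num_copy_queues   # 0, 1, 0, 1, ... per device
  return copy_queues.setdefault((dev, queue_idx), f"{dev}:sdma{queue_idx}")

print([queue_for("GPU:0") for _ in range(4)])   # back-to-back copies alternate between the two SDMA queues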

View File

@@ -354,7 +354,7 @@ class HCQCompiled(Compiled, Generic[SignalType]):
cpu_devices: list[HCQCompiled] = []
def __init__(self, device:str, allocator:HCQAllocatorBase, compilers:CompilerSet, runtime, signal_t:Type[SignalType],
comp_queue_t:Callable[[], HWQueue], copy_queue_t:Callable[[], HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000):
comp_queue_t:Callable[..., HWQueue], copy_queue_t:Callable[..., HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000):
self.device_id:int = int(device.split(":")[1]) if ":" in device else 0
from tinygrad.runtime.graph.hcq import HCQGraph
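Editor's note: the widened Callable[..., HWQueue] annotation is what lets the graph pass queue_idx to the copy-queue factory; a zero-argument Callable[[], ...] would reject that call under mypy. A tiny illustration (Queue and make_queue are placeholders, not repo code):

# Placeholder types; only the Callable annotations matter here.
from typing import Callable

class Queue: ...
def make_queue(queue_idx: int = 0) -> Queue: return Queue()

strict: Callable[[], Queue] = make_queue   # assignment is fine, but strict(queue_idx=1) fails mypy
loose: Callable[..., Queue] = make_queue   # any call signature is allowed
q = loose(queue_idx=1)                     # mirrors hw_copy_queue_t(queue_idx=queue_idx) above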

View File

@@ -1,6 +1,6 @@
from typing import cast
import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, getenv
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
from tinygrad.device import Device
@@ -35,45 +35,49 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
if not isinstance(buf.device, tuple): return None
assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}"
n_lbs, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
# ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
# fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {buf.dtype}")
use_all2all = (ALL2ALL >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1))
use_ring = not use_all2all and (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {n_lbs}x{numel} | {buf.dtype}")
# contiguous before we copy it
buf = buf.contiguous()
# copy to all devices. if you shrink later, that'll be handled
if not use_ring: return functools.reduce(lambda x,y: x.alu(red.arg, y),
[UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(len(buf.device))])
# naive: copy to all devices. if you shrink later, that'll be handled
if not use_ring and not use_all2all:
return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(n_lbs)])
# new ring reduce
# chunk data into n_lbs pieces
factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (n_lbs - left), initial=0)))
# extract chunks and scatter-reduce
# reduce-scatter
reduced_chunks = []
for i,(s,e) in enumerate(chunks):
chunk = buf.reshape((numel,)).shrink(((s,e),))
reduced_chunk = chunk
for step in range(n_lbs-1):
src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
# copy the chunk from the src device to the dest (operating device), and select the chunk on the dest device
reduced_chunk = reduced_chunk.copy_to_device(buf.device[dest], src if isinstance(reduced_chunk.device, tuple) else None) \
.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
reduced_chunks.append(reduced_chunk)
if use_all2all:
chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(n_lbs)]
reduced_chunks.append(functools.reduce(lambda x,y: x.alu(red.arg, y), chunks_on_i))
else:
chunk, reduced = buf.reshape((numel,)).shrink(((s,e),)), buf.reshape((numel,)).shrink(((s,e),))
for step in range(n_lbs-1):
src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
cp = reduced.copy_to_device(buf.device[dest], src if isinstance(reduced.device, tuple) else None)
reduced = cp.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
reduced_chunks.append(reduced)
# allgather
copied_chunks = []
for i,c in enumerate(reduced_chunks):
this_chunk: list[UOp|None] = [None] * len(buf.device)
this_chunk[(i+len(buf.device)-1)%n_lbs] = c
for step in range(n_lbs-1):
dest = (i+step)%n_lbs
this_chunk[dest] = c = c.copy_to_device(buf.device[dest])
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
for i,rc in enumerate(reduced_chunks):
if use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs))))
else:
this_chunk: list[UOp|None] = [None] * n_lbs
this_chunk[(i+n_lbs-1)%n_lbs] = rc
for step in range(n_lbs-1):
this_chunk[(i+step)%n_lbs] = rc = rc.copy_to_device(buf.device[(i+step)%n_lbs])
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
# reassemble
pads = [((s,numel-e),) for s,e in chunks]
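Editor's note: to separate the schedules from the UOp plumbing above, here is a plain-numpy sketch of the data movement only (chunking simplified; the real code also picks a divisibility factor and builds COPY/MSTACK UOps). All2all has device i pull chunk i from every peer and reduce it locally in one step before broadcasting, while ring forwards the partial chunk around the devices n-1 times.

# Data-movement sketch only; not the tinygrad implementation.
import numpy as np

def chunk_bounds(numel: int, n: int) -> list[tuple[int, int]]:
  base, left = numel // n, numel % n                     # near-equal chunks covering [0, numel)
  sizes = [base + 1] * left + [base] * (n - left)
  starts = np.cumsum([0] + sizes)
  return list(zip(starts[:-1], starts[1:]))

def allreduce_all2all(shards: list[np.ndarray]) -> list[np.ndarray]:
  n, numel = len(shards), shards[0].size
  chunks = chunk_bounds(numel, n)
  # reduce-scatter: device i receives chunk i from every peer and sums it locally (one step)
  reduced = [sum(shards[j][s:e] for j in range(n)) for (s, e) in chunks]
  # allgather: every reduced chunk is copied to every device and reassembled
  full = np.concatenate(reduced)
  return [full.copy() for _ in range(n)]

def allreduce_ring(shards: list[np.ndarray]) -> list[np.ndarray]:
  n, numel = len(shards), shards[0].size
  chunks = chunk_bounds(numel, n)
  reduced = []
  for i, (s, e) in enumerate(chunks):
    acc = shards[i][s:e].copy()
    for step in range(n - 1):                            # chunk i walks around the ring
      dest = (i + step + 1) % n
      acc = acc + shards[dest][s:e]                      # "copy to dest, add dest's local chunk"
    reduced.append(acc)
  full = np.concatenate(reduced)                         # allgather step condensed: broadcast the result
  return [full.copy() for _ in range(n)]

shards = [np.arange(12, dtype=np.float32) * (i + 1) for i in range(3)]
expected = sum(shards)
assert all(np.allclose(out, expected) for fn in (allreduce_all2all, allreduce_ring) for out in fn(shards))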