all2all (#13902)
* all2all
* um
* fix
* x
* um
* simler
* mypy
* fix
* t
* cmnts
@@ -1,7 +1,7 @@
 from tinygrad import Tensor, Device, GlobalCounters, TinyJit, dtypes
-from tinygrad.helpers import getenv, Context, RING, DEBUG
+from tinygrad.helpers import getenv, Context, DEBUG

-def test(devs: list[str], N: int, iters:int = 10):
+def test(devs: list[str], N: int, iters:int = 10, name:str = "allreduce"):
   @TinyJit
   def f(t: Tensor) -> Tensor: t.sum(0).realize()

@@ -17,39 +17,33 @@ def test(devs: list[str], N: int, iters:int = 10):
     i_secs = GlobalCounters.time_sum_s
     i_gflops = GlobalCounters.global_ops/i_secs/10**9
     i_gbs = (N*4)/i_secs/10**9
-    print(f"{'ring_allreduce' if RING >= 2 else 'naive_allreduce'} iter {i+1}/{iters}: {i_secs:.6f} sec {i_gflops:.2f} GFLOP/s {i_gbs:.2f} GB/s")
+    print(f"{name} iter {i+1}/{iters}: {i_secs:.6f} sec {i_gflops:.2f} GFLOP/s {i_gbs:.2f} GB/s")
     secs += i_secs
     gflops += i_gflops
     gbs += i_gbs

   return (gflops/iters, gbs/iters, secs/iters)

-def run(sz, n_gpus=6, iters=10, use_ring=False):
+def run(sz, n_gpus=6, iters=10, ring=0, all2all=0):
   devs = tuple([f"{Device.DEFAULT}:{x}" for x in range(n_gpus)])
   N = sz // dtypes.float32.itemsize
-  with Context(RING=(2 if use_ring else 0), DEBUG=max(DEBUG.value, 2)): return test(devs, N, iters=iters)
+  name = "all2all" if all2all else ("ring" if ring else "naive")
+  with Context(RING=(2 if ring else 0), ALL2ALL=(2 if all2all else 0), JIT_BATCH_SIZE=0, DEBUG=max(DEBUG.value, 2)):
+    return test(devs, N, iters=iters, name=name)

 def main():
-  ONLY_RING = getenv("ONLY_RING", 0)
   n_gpus = getenv("GPUS", 6)
   iters = getenv("ITERS", 10)

-  if getenv("BENCHMARK_SPLIT"):
-    l, r = 0, 512
-    while r - l > 1:
-      m = (l + r) // 2
-      (ring_gflops, ring_gbs, ring_secs) = run(m * 1024 * 4, n_gpus=n_gpus, iters=100, use_ring=True)
-      (naive_gflops, naive_gbs, naive_secs) = run(m * 1024 * 4, n_gpus=n_gpus, iters=100, use_ring=False)
-      if ring_secs > naive_secs: l = m
-      else: r = m
-    print("Better split", r * 1024, "elements")
-  else:
-    sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu
-    print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
-    (ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True, n_gpus=n_gpus, iters=iters)
-    if not ONLY_RING: (naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False, n_gpus=n_gpus, iters=iters)
-    print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s")
-    if not ONLY_RING: print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")
+  sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu
+  print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
+  results = {}
+  for name, kwargs in [("naive", {}), ("ring", {"ring": 2}), ("all2all", {"all2all": 2})]:
+    results[name] = run(sz, n_gpus=n_gpus, iters=iters, **kwargs)

+  print("\n=== RESULTS ===")
+  for name, (gflops, gbs, secs) in results.items():
+    print(f"{name.upper()}:\n {secs:.6f} seconds/iter\n {gflops:.2f} GFLOP/s\n {gbs:.2f} GB/s")

 if __name__ == "__main__":
   main()
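The rewritten benchmark times all three allreduce strategies in one invocation instead of toggling RING per run. For intuition (standard collective-communication accounting, not part of this diff): once the buffer is chunked, ring and all2all move the same total bytes per GPU, but ring serializes them into n-1 dependent steps while all2all issues the copies concurrently, which is what the extra SDMA queues added later in this commit exploit. A rough sketch of the per-GPU send volume:

# back-of-envelope model of bytes each GPU sends for an N-byte buffer
# across n GPUs (textbook accounting, assumptions labeled, not from this diff)
def bytes_sent_per_gpu(N: int, n: int) -> dict[str, float]:
  return {
    "naive":   (n - 1) * N,          # full buffer copied to every peer
    "ring":    2 * (n - 1) / n * N,  # reduce-scatter + allgather, n-1 serial hops each
    "all2all": 2 * (n - 1) / n * N,  # same volume, but the chunk copies are concurrent
  }

print(bytes_sent_per_gpu(N=1_000_000_000, n=6))

Under this model the strategies differ mainly in dependency depth, not volume, so the benchmark's GB/s column is where all2all's concurrency should show up.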
@@ -256,6 +256,11 @@ class TestMultiTensor(unittest.TestCase):
       a,b = _test_allreduce(Tensor.rand(256, 256))
       np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)

+  def test_allreduce_all2all(self):
+    with Context(ALL2ALL=2):
+      a,b = _test_allreduce(Tensor.rand(256, 256))
+      np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
+
   def test_copy_jit(self):
     @TinyJit
     def copy_tensor(x:Tensor): return (x.to(f"{x.device.split(':')[0]}:1") + 1)
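The new test pins ALL2ALL=2 so the all2all path runs regardless of the size heuristic. Assuming the surrounding file is tinygrad's multitensor test suite (the file path is not shown in this diff, so the module name below is a guess), it can be run in isolation:

import unittest
# the dotted module path is an assumption based on the class name in the hunk above
suite = unittest.defaultTestLoader.loadTestsFromName("test.test_multitensor.TestMultiTensor.test_allreduce_all2all")
unittest.TextTestRunner().run(suite)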
@@ -181,8 +181,8 @@ JIT, JIT_BATCH_SIZE = ContextVar("JIT", 2 if OSX and ARCH_X86 else 1), ContextVa
 WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
 USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
 TRANSCENDENTAL, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("NOLOCALS", 0)
-SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
-LRU = ContextVar("LRU", 1)
+SPLIT_REDUCEOP, NO_MEMORY_PLANNER, LRU = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("LRU", 1)
+RING, ALL2ALL = ContextVar("RING", 1), ContextVar("ALL2ALL", 0)
 CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
 VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
 CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
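The new ALL2ALL knob follows the same ContextVar pattern as RING: seeded from an environment variable at import, overridable in a scoped block. A minimal sketch of how callers flip it (assumes a tinygrad checkout; the defaults shown are the ones declared above):

from tinygrad.helpers import Context, ALL2ALL, RING

print(RING.value, ALL2ALL.value)  # 1 0 by default
with Context(ALL2ALL=2):          # force the all2all allreduce path
  print(ALL2ALL.value)            # 2
print(ALL2ALL.value)              # restored to 0 on exit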
@@ -1,6 +1,6 @@
-import collections, time
+import collections, itertools, time
 from typing import Any, cast
-from tinygrad.helpers import round_up, PROFILE, merge_dicts, getenv, dedup, suppress_finalizing
+from tinygrad.helpers import round_up, PROFILE, ALL2ALL, merge_dicts, getenv, dedup, suppress_finalizing
 from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator, MMIOInterface
 from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
 from tinygrad.dtype import dtypes
@@ -49,7 +49,9 @@ class HCQGraph(MultiGraphRunner):
     self.ji_schedule: dict[int, tuple[HCQCompiled, HWQueue, list, list, HCQSignal, int|None]] = {}

     self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
-    self.copy_queues: dict[HCQCompiled, HWQueue] = {} # lazy allocation
+    self.copy_queues: dict[tuple[HCQCompiled, int], HWQueue] = {} # lazy allocation, keyed by (device, queue_idx)
+    self.num_copy_queues: int = getenv("HCQ_NUM_SDMA", 2 if ALL2ALL >= 1 else 1)
+    self.copy_queue_cnt: collections.defaultdict[HCQCompiled, itertools.count] = collections.defaultdict(itertools.count)

     self.signals: dict[Any, HCQSignal] = {**{dev: dev.new_signal(value=0) for dev in self.devices if not dev._is_cpu()},
       **{"KICK": self.devices[0].new_signal(value=0)}, **{dev: self.devices[0].new_signal(value=0) for dev in self.devices if dev._is_cpu()}}
@@ -85,7 +87,8 @@ class HCQGraph(MultiGraphRunner):
         enqueue_queue = self.comp_queues[enqueue_dev]
       else:
         assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
-        enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
+        queue_idx = next(self.copy_queue_cnt[enqueue_dev]) % self.num_copy_queues
+        enqueue_queue = self.copy_queues.setdefault((enqueue_dev, queue_idx), enqueue_dev.hw_copy_queue_t(queue_idx=queue_idx))

       out_signal = self.signals.setdefault(enqueue_queue, self.devices[0].new_signal(value=0))

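The queue_idx expression round-robins copy work across each device's SDMA queues: copy_queue_cnt hands out a fresh itertools.count the first time a device is touched, and the modulo folds it into [0, num_copy_queues). A standalone model of just that indexing:

import collections, itertools

num_copy_queues = 2  # matches the HCQ_NUM_SDMA default when ALL2ALL >= 1
copy_queue_cnt = collections.defaultdict(itertools.count)
for dev in ["NV:0", "NV:0", "NV:0", "NV:1"]:  # hypothetical device names
  queue_idx = next(copy_queue_cnt[dev]) % num_copy_queues
  print(dev, queue_idx)  # NV:0 alternates 0, 1, 0; NV:1 starts at its own 0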
@@ -175,14 +178,17 @@ class HCQGraph(MultiGraphRunner):

     for dev in self.devices:
       for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
-        if dep_dev in self.copy_queues: self.comp_queues[dev].wait(self.signals[(copy_q:=self.copy_queues[dep_dev])], cast(int, last_j[copy_q]) + 1)
+        for copy_q in self._dev_copy_queues(dep_dev):
+          if copy_q in self.signals: self.comp_queues[dev].wait(self.signals[copy_q], cast(int, last_j[copy_q]) + 1)

       self.comp_queues[dev].signal(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev] + 1).bind(dev)
-      if dev in self.copy_queues: self.copy_queues[dev].bind(dev)
+      for copy_q in self._dev_copy_queues(dev): copy_q.bind(dev)

     self.last_timeline: dict[HCQCompiled, tuple[HCQSignal, int]] = {dev: (dev.timeline_signal, 0) for dev in self.devices}
     self.queue_signals_to_reset = [self.signals[q] for q in list(self.comp_queues.values()) + list(self.copy_queues.values()) if q in self.signals]

+  def _dev_copy_queues(self, dev): return [q for (d, _), q in self.copy_queues.items() if d == dev]
+
   def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
     # Wait and restore signals
     self.kickoff_value += 1
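With copy_queues keyed by (device, queue_idx), recovering one device's queues becomes a scan over the dict, which is exactly what the new _dev_copy_queues helper does. A toy model with strings standing in for queue objects:

# hypothetical keys/values; the real dict maps (HCQCompiled, int) -> HWQueue
copy_queues = {("NV:0", 0): "q0", ("NV:0", 1): "q1", ("NV:1", 0): "q2"}
def dev_copy_queues(dev): return [q for (d, _), q in copy_queues.items() if d == dev]
assert dev_copy_queues("NV:0") == ["q0", "q1"]
assert dev_copy_queues("NV:2") == []  # a device whose queues were never lazily allocated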
@@ -205,8 +211,7 @@ class HCQGraph(MultiGraphRunner):

     for dev in self.devices:
       self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {}))
-      if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals_local)
+      for copy_queue in self._dev_copy_queues(dev): copy_queue.submit(dev, hcq_var_vals_local)

       self.last_timeline[dev] = (dev.timeline_signal, dev.next_timeline())

     if wait:
@@ -354,7 +354,7 @@ class HCQCompiled(Compiled, Generic[SignalType]):
   cpu_devices: list[HCQCompiled] = []

   def __init__(self, device:str, allocator:HCQAllocatorBase, compilers:CompilerSet, runtime, signal_t:Type[SignalType],
-               comp_queue_t:Callable[[], HWQueue], copy_queue_t:Callable[[], HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000):
+               comp_queue_t:Callable[..., HWQueue], copy_queue_t:Callable[..., HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000):
     self.device_id:int = int(device.split(":")[1]) if ":" in device else 0

     from tinygrad.runtime.graph.hcq import HCQGraph
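Loosening Callable[[], HWQueue] to Callable[..., HWQueue] is what lets hw_copy_queue_t accept the new queue_idx keyword under mypy (one of the fixups in the commit message). A toy illustration; the types here are stand-ins, not the real HCQ classes:

from typing import Callable

class Queue: ...
def make_queue(queue_idx: int = 0) -> Queue: return Queue()

f: Callable[..., Queue] = make_queue  # ... permits any argument list
q0, q1 = f(), f(queue_idx=1)          # both calls type-check; Callable[[], Queue] would reject the second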
@@ -1,6 +1,6 @@
 from typing import cast
 import functools, itertools, operator
-from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv
+from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, getenv
 from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
 from tinygrad.device import Device

@@ -35,44 +35,48 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
   if not isinstance(buf.device, tuple): return None
   assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}"
   n_lbs, shape, numel = len(buf.device), buf.shape, prod(buf.shape)

   # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
   # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
-  use_ring = (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
-  if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {buf.dtype}")
+  use_all2all = (ALL2ALL >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1))
+  use_ring = not use_all2all and (RING >= 2 or (n_lbs > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
+  if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {n_lbs}x{numel} | {buf.dtype}")

   # contiguous before we copy it
   buf = buf.contiguous()

-  # copy to all devices. if you shrink later, that'll be handled
-  if not use_ring: return functools.reduce(lambda x,y: x.alu(red.arg, y),
-    [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(len(buf.device))])
+  # naive: copy to all devices. if you shrink later, that'll be handled
+  if not use_ring and not use_all2all:
+    return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(n_lbs)])

-  # new ring reduce
+  # chunk data into n_lbs pieces
   factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
   base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
-  chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
-  chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
+  chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (n_lbs - left), initial=0)))

-  # extract chunks and scatter-reduce
+  # reduce-scatter
   reduced_chunks = []
   for i,(s,e) in enumerate(chunks):
-    chunk = buf.reshape((numel,)).shrink(((s,e),))
-    reduced_chunk = chunk
-    for step in range(n_lbs-1):
-      src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
-      # copy the chunk from the src device to the dest (operating device), and select the chunk on the dest device
-      reduced_chunk = reduced_chunk.copy_to_device(buf.device[dest], src if isinstance(reduced_chunk.device, tuple) else None) \
-        .alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
-    reduced_chunks.append(reduced_chunk)
+    if use_all2all:
+      chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(n_lbs)]
+      reduced_chunks.append(functools.reduce(lambda x,y: x.alu(red.arg, y), chunks_on_i))
+    else:
+      chunk, reduced = buf.reshape((numel,)).shrink(((s,e),)), buf.reshape((numel,)).shrink(((s,e),))
+      for step in range(n_lbs-1):
+        src, dest = (i+step)%n_lbs, (i+step+1)%n_lbs
+        # copy the chunk from the src device to the dest (operating device), and select the chunk on the dest device
+        cp = reduced.copy_to_device(buf.device[dest], src if isinstance(reduced.device, tuple) else None)
+        reduced = cp.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
+      reduced_chunks.append(reduced)

   # allgather
   copied_chunks = []
-  for i,c in enumerate(reduced_chunks):
-    this_chunk: list[UOp|None] = [None] * len(buf.device)
-    this_chunk[(i+len(buf.device)-1)%n_lbs] = c
-    for step in range(n_lbs-1):
-      dest = (i+step)%n_lbs
-      this_chunk[dest] = c = c.copy_to_device(buf.device[dest])
-    copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
+  for i,rc in enumerate(reduced_chunks):
+    if use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(n_lbs))))
+    else:
+      this_chunk: list[UOp|None] = [None] * n_lbs
+      this_chunk[(i+n_lbs-1)%n_lbs] = rc
+      for step in range(n_lbs-1):
+        this_chunk[(i+step)%n_lbs] = rc = rc.copy_to_device(buf.device[(i+step)%n_lbs])
+      copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))

   # reassemble
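To summarize the new path: in the reduce-scatter phase, device i pulls chunk i from every peer and reduces it locally in a single dependent step (ring needs n-1 sequential hops), and in the allgather phase each reduced chunk is broadcast to every device and reassembled via MSTACK. A pure-Python model of the data movement, with lists standing in for per-device buffers (a sketch, not the UOp graph itself):

# model: data[d] is device d's copy of the buffer; all devices end with the elementwise sum
def all2all_allreduce(data: list[list[int]]) -> list[list[int]]:
  n = len(data)                                   # number of devices
  size = len(data[0]); assert size % n == 0       # evenly chunkable, for simplicity
  chunks = [(k*size//n, (k+1)*size//n) for k in range(n)]
  # reduce-scatter: device i gathers chunk i from all peers and reduces it locally
  reduced = [[sum(d[j] for d in data) for j in range(s, e)] for (s, e) in chunks]
  # allgather: every reduced chunk is broadcast, then concatenated in order
  out = [c for r in reduced for c in r]
  return [out[:] for _ in range(n)]

assert all2all_allreduce([[1,2,3,4],[10,20,30,40]]) == [[11,22,33,44],[11,22,33,44]]

Total bytes moved match ring allreduce; the win is the much shorter dependency chain, provided enough copy queues exist for the transfers to actually overlap, which is why the HCQ graph changes above key the SDMA queues per (device, queue_idx).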