From 29d58013870e19829eae42bc6eecc24961ffed41 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Fri, 11 Aug 2023 13:22:07 -0400 Subject: [PATCH] distributed collectives (#1519) * feat: world * feat: tests * feat: no more backwards * feat: recv into * feat: whoops * feat: test in ci * feat: some debug logging * feat: workflow naming * feat: need to set pythonpath * feat: just send to same device * feat: allreduce * feat: test * feat: need contiguous * feat: test in ci * feat: exit with correct code * feat: don't need that * feat: opencl wait_for just doesn't work * feat: synchronize on out * feat: try? * feat: try again? * feat: add extra realizes * feat: print * feat: seed * feat: tol * feat: test ones and zeros * feat: remove print * feat: are you just flaky * feat: separate scatter and gather? * feat: just try synchronizing * feat: remove print again * feat: bring back difference * feat: no sync * feat: revert that * feat: back to wait_for * fix: typo --- .github/workflows/test.yml | 1 + extra/dist/collectives.py | 43 ++++++++++++++++++++ extra/dist/world.py | 2 +- test/external/dist/test_collectives.py | 56 ++++++++++++++++++++++++++ test/external/dist/test_world.py | 4 ++ tinygrad/runtime/ops_gpu.py | 2 +- 6 files changed, 106 insertions(+), 2 deletions(-) create mode 100644 extra/dist/collectives.py create mode 100644 test/external/dist/test_collectives.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b91005f485..8ac1e08db6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -158,6 +158,7 @@ jobs: name: Test multigpu run: | PYTHONPATH="." python test/external/dist/test_world.py + PYTHONPATH="." 
python test/external/dist/test_collectives.py testmetalwebgpu: name: Metal and WebGPU Tests diff --git a/extra/dist/collectives.py b/extra/dist/collectives.py new file mode 100644 index 0000000000..f499198e40 --- /dev/null +++ b/extra/dist/collectives.py @@ -0,0 +1,43 @@ +from tinygrad.tensor import Tensor, Device +from tinygrad.helpers import getenv + +from extra.dist import world + +def allreduce(t:Tensor, cache_id=None) -> Tensor: + RANK, WORLD_SIZE = getenv("RANK"), getenv("WORLD_SIZE") + cache_id = f"{RANK}-{cache_id}" if cache_id is not None else None + + # flatten + flattened = t.flatten() + + # pad to evenly divide + if flattened.shape[0] % WORLD_SIZE != 0: + flattened = Tensor.cat(flattened, Tensor.zeros(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE))) + + # chunk + chunks = flattened.chunk(WORLD_SIZE, dim=0) + reduced = chunks[RANK] + + next_rank = (RANK + 1) % WORLD_SIZE + prev_rank = ((RANK - 1) + WORLD_SIZE) % WORLD_SIZE + + # scatter reduce + current_chunk_index = RANK + for i in range(WORLD_SIZE - 1): + world.send(reduced, next_rank, cache_id=f"{cache_id}-{i}-s" if cache_id is not None else None) + current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE + recv_buf = Tensor.empty(*reduced.shape) + world.recv(recv_buf, prev_rank) + reduced = chunks[current_chunk_index] + recv_buf + + # gather + chunks[current_chunk_index] = reduced + current_chunk_index = (RANK + 1) % WORLD_SIZE + for i in range(WORLD_SIZE - 1): + world.send(reduced, next_rank, cache_id=f"{cache_id}-{i}-g" if cache_id is not None else None) + current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE + recv_buf = Tensor.empty(*reduced.shape) + world.recv(recv_buf, prev_rank) + reduced = chunks[current_chunk_index] = recv_buf + + return Tensor.cat(*chunks, dim=0).shrink(((0, t.numel()),)).reshape(*t.shape) diff --git a/extra/dist/world.py b/extra/dist/world.py index 05046dc7c7..467fa6323a 100644 --- a/extra/dist/world.py +++ b/extra/dist/world.py @@ 
-59,7 +59,7 @@ def _send_lb(x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> None # receive a lazybuffer from the target rank def _recv_lb(x:LazyBuffer, target_rank:int) -> LazyBuffer: - _recv_rb(x.realize().realized, target_rank) + _recv_rb(x.contiguous().realize().realized, target_rank) return x class Send(Function): diff --git a/test/external/dist/test_collectives.py b/test/external/dist/test_collectives.py new file mode 100644 index 0000000000..4d4829502b --- /dev/null +++ b/test/external/dist/test_collectives.py @@ -0,0 +1,56 @@ +from extra import dist +from tinygrad.jit import TinyJit +if __name__ == "__main__": + dist.preinit() + +from extra.dist import collectives +from tinygrad.helpers import CI, getenv +from tinygrad.tensor import Tensor +import numpy as np + +@TinyJit +def allreduce_jit(t:Tensor, cache_id=None) -> Tensor: + return collectives.allreduce(t, cache_id=cache_id).realize() + +SIZE = 2048 if not CI else 2 +SIZE_2 = 255 if not CI else 3 + +def run(): + # set a deterministic seed so that both ranks generate the same random tensor + Tensor.manual_seed(42) + + rank = getenv("RANK") + + # loop 3 times to make sure it works with the jit + for _ in range(3): + # create a tensor to send + t = Tensor.zeros(SIZE, SIZE) if rank == 0 else Tensor.ones(SIZE, SIZE) + t2 = allreduce_jit(t.contiguous().realize(), cache_id="test") + assert np.allclose(np.ones((SIZE, SIZE)), t2.numpy()) + + # reset jit + allreduce_jit.cnt = 0 + + # test uneven chunk sizes + for _ in range(3): + # create a tensor to send + t = Tensor.ones(SIZE_2, SIZE_2, SIZE_2) if rank == 0 else Tensor.zeros(SIZE_2, SIZE_2, SIZE_2) + t2 = allreduce_jit(t.contiguous().realize(), cache_id="test2") + assert np.allclose(np.ones((SIZE_2, SIZE_2, SIZE_2)), t2.numpy()) + + print(f"rank {rank} passed") + +if __name__ == "__main__": + devices = ["gpu:0", "gpu:1" if not CI else "gpu:0"] + world_size = len(devices) + + dist.init_oob(world_size) + + processes = [] + for rank, device in 
enumerate(devices): + processes.append(dist.spawn(rank, device, fn=run, args=())) + for p in processes: p.join() + + # exit with error code if any of the processes failed + for p in processes: + if p.exitcode != 0: exit(p.exitcode) diff --git a/test/external/dist/test_world.py b/test/external/dist/test_world.py index 68fe902abc..f805716c62 100644 --- a/test/external/dist/test_world.py +++ b/test/external/dist/test_world.py @@ -59,3 +59,7 @@ if __name__ == "__main__": for rank, device in enumerate(devices): processes.append(dist.spawn(rank, device, fn=run, args=())) for p in processes: p.join() + + # exit with error code if any of the processes failed + for p in processes: + if p.exitcode != 0: exit(p.exitcode) diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index c24b740f6b..ea18ef6788 100644 --- a/tinygrad/runtime/ops_gpu.py +++ b/tinygrad/runtime/ops_gpu.py @@ -48,7 +48,7 @@ class CLBuffer(RawBufferCopyInOut): assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}" buf = cl.Buffer(CL.cl_ctxs[self._buf.device], cl.mem_flags.WRITE_ONLY | cl.mem_flags.USE_HOST_PTR, 0, hostbuf=x.data) mapped, event = cl.enqueue_map_buffer(CL.cl_queue[self._buf.device], buf, cl.map_flags.WRITE, 0, self.size, dtype=self.dtype.np, is_blocking=False) - with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event]) + with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event] + ([self.event] if hasattr(self, "event") else [])) class CLProgram: def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None):