diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 95a17a974c..b91005f485 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -104,8 +104,8 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        task: [optimage, openpilot]
-    name: ${{ matrix.task=='optimage'&&'GPU OPT and IMAGE Tests'||'openpilot (OpenCL) Tests'}}
+        task: [optimage, openpilot, multigpu]
+    name: ${{ matrix.task=='optimage'&&'GPU OPT and IMAGE Tests'|| matrix.task=='openpilot'&&'openpilot (OpenCL) Tests'|| matrix.task=='multigpu'&&'MultiGPU Tests'}}
     runs-on: ubuntu-20.04
     timeout-minutes: 20
 
@@ -154,6 +154,10 @@ jobs:
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test tensor core ops
      run: GPU=1 TC=2 python3 -m pytest -n=auto test/test_ops.py
+    - if: ${{ matrix.task == 'multigpu' }}
+      name: Test multigpu
+      run: |
+        PYTHONPATH="." python test/external/dist/test_world.py
 
   testmetalwebgpu:
     name: Metal and WebGPU Tests
diff --git a/.gitignore b/.gitignore
index 0b132e239b..b834f91a73 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,7 +7,7 @@ notebooks
 *.pyc
 *.so
 build
-dist
+/dist
 *.egg-info
 /env
 a.out
diff --git a/extra/dist/__init__.py b/extra/dist/__init__.py
new file mode 100644
index 0000000000..d870805739
--- /dev/null
+++ b/extra/dist/__init__.py
@@ -0,0 +1,61 @@
+# this file needs to be very careful with its imports so as to not accidentally initialize the runtimes
+from multiprocessing.connection import Connection
+from typing import Any, Callable, List, Tuple
+from tinygrad.helpers import DEBUG, getenv
+import multiprocessing as mp
+import os
+
+# this needs to be called before everything else if you are using distributed
+def preinit():
+  os.environ["DELAYED_RUNTIME_INIT"] = "1"
+  mp.set_start_method("spawn")
+
+# out-of-band communication/synchronization
+class _OOB:
+  def __init__(self, pipes:List[Tuple[Connection, Connection]]):
+    self.pipes = pipes
+
+  # send some data to a target rank, blocks until data is received
+  def send(self, data:Any, target_rank:int):
+    self.pipes[getenv("RANK") * getenv("WORLD_SIZE") + target_rank][1].send(data)
+
+  # receive some data from a target rank, blocks until data is received
+  def recv(self, target_rank:int) -> Any:
+    return self.pipes[target_rank * getenv("WORLD_SIZE") + getenv("RANK")][0].recv()
+OOB = None
+
+def init_oob(world_size:int):
+  os.environ["WORLD_SIZE"] = str(world_size)
+
+  global OOB
+  OOB = _OOB([mp.Pipe(False) for _ in range(world_size * world_size)])
+
+# this runs in the spawned process so we can do all the delayed runtime initialization
+def _process_wrap(rank:int, device:str, oob:_OOB, fn:Callable, args=()):
+  # setup the rank
+  os.environ["RANK"] = str(rank)
+
+  # setup out of band communication
+  global OOB
+  OOB = oob
+
+  # do specific runtime initialization for distributed
+  from tinygrad.lazy import Device
+  device, device_num = Device.canonicalize(device), 0 if ":" not in device else int(device.split(":")[-1])
+  if "GPU" in device:
+    from tinygrad.runtime.ops_gpu import CL
+    CL.post_init(device_num)
+  elif "HIP" in device:
+    import extra.hip_wrapper as hip
+    hip.hipSetDevice(device_num)
+  if DEBUG >= 1: print(f"distributed process {rank} initialized runtime for device {device}")
+
+  # convert device to be process specific
+  Device.DEFAULT = device.split(":")[0]
+
+  fn(*args)
+
+# wrapper around mp.Process that initializes the runtime
+def spawn(rank:int, device:str, fn:Callable, args=()) -> mp.Process:
+  (p := mp.Process(target=_process_wrap, args=(rank, device, OOB, fn, args))).start()
+  return p
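For readers skimming the patch, the pipe bookkeeping in _OOB above is the whole out-of-band channel: a world of N ranks gets N*N one-way pipes, and the pipe at index sender * WORLD_SIZE + receiver carries messages from sender to receiver (the sender writes end [1], the receiver reads end [0]). The sketch below is illustrative only and not part of the patch; the oob_send/oob_recv names are hypothetical stand-ins, and it runs in a single process just to show the indexing, while the real code hands the same pipe list to spawned processes.

# illustrative sketch of the _OOB pipe indexing (not part of the diff)
import multiprocessing as mp

world_size = 2
# mp.Pipe(False) returns a (receive_end, send_end) pair, matching the patch
pipes = [mp.Pipe(False) for _ in range(world_size * world_size)]

def oob_send(data, src_rank, dst_rank):
  # the sender writes into the pipe addressed by (src_rank, dst_rank)
  pipes[src_rank * world_size + dst_rank][1].send(data)

def oob_recv(src_rank, dst_rank):
  # the receiver reads from that same pipe
  return pipes[src_rank * world_size + dst_rank][0].recv()

oob_send("hello from rank 0", 0, 1)
assert oob_recv(0, 1) == "hello from rank 0"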
diff --git a/extra/dist/world.py b/extra/dist/world.py
new file mode 100644
index 0000000000..05046dc7c7
--- /dev/null
+++ b/extra/dist/world.py
@@ -0,0 +1,77 @@
+from typing import Any, Optional, Tuple
+from extra import dist
+from multiprocessing import shared_memory
+from tinygrad.helpers import DEBUG, GlobalCounters, colored
+from tinygrad.lazy import LazyBuffer
+from tinygrad.runtime.lib import RawBufferCopyIn, RawBufferCopyInOut
+from tinygrad.runtime.ops_shm import RawShmBuffer
+from tinygrad.tensor import Tensor, Function
+import numpy as np
+
+# fake the function signature of ASTRunner so we can put it in the cache
+def __send_rb(args:Tuple[RawBufferCopyInOut, RawShmBuffer, int, Any], jit=False, force_wait=False):
+  args[0]._copyout(np.frombuffer(args[1]._buffer(), dtype=args[0].dtype.np))
+  dist.OOB.send(args[3], args[2])
+  if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} sent {args[0]} to rank {args[2]}")
+
+def __recv_rb(args:Tuple[RawBufferCopyIn, RawShmBuffer, int], jit=False, force_wait=False):
+  dist.OOB.recv(args[2])
+  args[0]._copyin(args[1].toCPU())
+  if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} recv {args[0]} from rank {args[2]}")
+
+# send a rawbuffer from our rank to the target rank
+def _send_rb(x:RawBufferCopyInOut, target_rank:int, cache_id:Optional[str]=None):
+  assert isinstance(x, RawBufferCopyInOut), "we only support RawBufferCopyInOut for now"
+  # cache the shared memory so we don't have to create it every time
+  if cache_id is not None and cache_id in _send_rb.shared_memory_cache:
+    shm_name = _send_rb.shared_memory_cache[cache_id]
+  else:
+    shm_name = (s := shared_memory.SharedMemory(create=True, size=x.size * x.dtype.itemsize)).name
+    s.close()
+    if cache_id is not None: _send_rb.shared_memory_cache[cache_id] = shm_name
+  # copy the buffer into shared memory
+  device = f"{shm_name},{cache_id}" if cache_id is not None else shm_name
+  rb = RawShmBuffer(x.size, x.dtype, device=device)
+  __send_rb((x, rb, target_rank, (shm_name, cache_id)))
+
+  # jit support
+  if GlobalCounters.cache is not None: GlobalCounters.cache.append((__send_rb, [x, rb, target_rank, None]))
+setattr(_send_rb, "shared_memory_cache", {})
+
+# receive a rawbuffer from the target rank
+def _recv_rb(x:RawBufferCopyIn, target_rank:int):
+  assert isinstance(x, RawBufferCopyIn), "we only support RawBufferCopyIn for now"
+  extra = dist.OOB.recv(target_rank)
+  device = f"{extra[0]},{extra[1]}" if extra[1] is not None else f"{extra[0]}"
+  rb = RawShmBuffer(x.size, x.dtype, device=device)
+  x._copyin(rb.toCPU())
+  if DEBUG >= 2: print(f"**** got {x} from rank {target_rank}")
+
+  if extra[1] is None:
+    (s := shared_memory.SharedMemory(name=extra[0])).close()
+    s.unlink()
+
+  # jit support
+  if GlobalCounters.cache is not None: GlobalCounters.cache.append((__recv_rb, [x, rb, target_rank]))
+
+# sends a lazybuffer from our rank to the target rank
+def _send_lb(x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> None: _send_rb(x.contiguous().realize().realized, target_rank, cache_id=cache_id)
+
+# receive a lazybuffer from the target rank
+def _recv_lb(x:LazyBuffer, target_rank:int) -> LazyBuffer:
+  _recv_rb(x.realize().realized, target_rank)
+  return x
+
+class Send(Function):
+  def forward(self, x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> LazyBuffer:
+    self.target_rank, self.shape, self.dtype = target_rank, x.shape, x.dtype
+    _send_lb(x, target_rank, cache_id=cache_id)
+    return x
+
+class Recv(Function):
+  def forward(self, x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> LazyBuffer:
+    self.target_rank, self.cache_id = target_rank, cache_id
+    return _recv_lb(x, target_rank)
+
+def send(x:Tensor, target_rank:int, cache_id:Optional[str]=None) -> Tensor: return Send.apply(x, target_rank=target_rank, cache_id=cache_id)
+def recv(x:Tensor, target_rank:int, cache_id:Optional[str]=None) -> Tensor: return Recv.apply(x, target_rank=target_rank, cache_id=cache_id)
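Note that the tensor payload never travels through the _OOB pipes: _send_rb copies the raw buffer into a SharedMemory block and only the block's name (plus the optional cache_id) goes out of band, while _recv_rb attaches by name, copies the data in, and unlinks uncached blocks. The sketch below is a rough stdlib-plus-numpy illustration of that handoff for the cache_id=None path, not the tinygrad code itself, and it runs both "sides" in one process for brevity.

# illustrative sketch of the shared-memory handoff (not part of the diff)
from multiprocessing import shared_memory
import numpy as np

data = np.arange(16, dtype=np.float32)           # stands in for the sender's raw buffer

# "sender": copy the buffer into shared memory and hand over only the name
shm = shared_memory.SharedMemory(create=True, size=data.nbytes)
np.frombuffer(shm.buf, dtype=data.dtype)[:data.size] = data
handle = shm.name                                # what OOB.send would carry to the target rank

# "receiver": attach by name, copy the data out into its own buffer
peer = shared_memory.SharedMemory(name=handle)
received = np.frombuffer(peer.buf, dtype=np.float32)[:data.size].copy()
peer.close()

# uncached blocks are unlinked once the data has been copied out
shm.close()
shm.unlink()
assert np.array_equal(received, data)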
diff --git a/test/external/dist/test_world.py b/test/external/dist/test_world.py
new file mode 100644
index 0000000000..68fe902abc
--- /dev/null
+++ b/test/external/dist/test_world.py
@@ -0,0 +1,61 @@
+from extra import dist
+from tinygrad.jit import TinyJit
+if __name__ == "__main__":
+  dist.preinit()
+
+from extra.dist import world
+from tinygrad.helpers import CI, getenv
+from tinygrad.tensor import Tensor
+import numpy as np
+
+@TinyJit
+def send_jit(t, target_rank, cache_id=None) -> Tensor:
+  return world.send(t, target_rank, cache_id=cache_id).realize()
+
+@TinyJit
+def recv_jit(t, target_rank, cache_id=None) -> Tensor:
+  return world.recv(t, target_rank, cache_id=cache_id).realize()
+
+SIZE = 2048 if not CI else 2
+
+def run():
+  # set a deterministic seed so that both ranks generate the same random tensor
+  Tensor.manual_seed(42)
+
+  rank = getenv("RANK")
+
+  # loop 3 times to make sure it works with the jit
+  for _ in range(3):
+    # create a tensor to send
+    t = Tensor.randn(SIZE, SIZE)
+
+    # send to rank 1
+    if rank == 0:
+      send_jit(t, 1, cache_id="test")
+    elif rank == 1:
+      t2 = Tensor.empty(SIZE, SIZE)
+      recv_jit(t2, 0, cache_id="test")
+
+    # recv from rank 1
+    if rank == 0:
+      t2 = Tensor.empty(SIZE, SIZE)
+      recv_jit(t2, 1, cache_id="test2")
+    elif rank == 1:
+      send_jit(t2, 0, cache_id="test2")
+
+    # check that the received tensor is the same as the sent tensor
+    if rank == 0:
+      assert np.allclose(t.numpy(), t2.numpy())
+
+  print(f"rank {rank} passed")
+
+if __name__ == "__main__":
+  devices = ["gpu:0", "gpu:1" if not CI else "gpu:0"]
+  world_size = len(devices)
+
+  dist.init_oob(world_size)
+
+  processes = []
+  for rank, device in enumerate(devices):
+    processes.append(dist.spawn(rank, device, fn=run, args=()))
+  for p in processes: p.join()
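To run the new test locally the same way the added CI step does (outside CI it expects two GPU devices, "gpu:0" and "gpu:1"; under CI both ranks share "gpu:0" and SIZE drops from 2048 to 2):

PYTHONPATH="." python test/external/dist/test_world.py

Setting DEBUG=2 additionally makes the send/recv paths in extra/dist/world.py print a line for every transfer.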