tinygrad/test/external/dist/test_world.py
wozeparrot c29653605e hip multigpu training (#1878)
* feat: move to hip

* feat: special path for RawBufferTransfer

* feat: initial rawbuffertransfer

* feat: hip ipc

* feat: working hip ipc

* feat: need to base device without args

* feat: close mem handle

* feat: modified test

* feat: more multihip stuff

* clean: cleanup

* feat: cleaner

* feat: don't crash

* feat: test more

* clean: way cleaner hip wrapper

* feat: barrier

* feat: barrier

* feat: this breaks stuff

* feat: we can use empty here

* feat: maybe fix tests

* feat: maybe fix tests again?

* fix: probably fix tests

* feat: no waiting here

* feat: wait here

* feat: much larger test

* feat: need to sync here

* feat: make this async

* feat: no waiting!

* feat: cut here

* feat: sync copy

* feat: random imports

* feat: much cleaner world

* feat: restore this

* feat: restore this

* clean: cleanup

* feat: set this
2023-10-24 17:35:53 -04:00


from extra import dist
from tinygrad.jit import TinyJit
if __name__ == "__main__":
  dist.preinit()
from extra.dist import world

from tinygrad.helpers import CI, getenv
from tinygrad.tensor import Tensor
import numpy as np

@TinyJit
def send_jit(t, target_rank, cache_id=None) -> Tensor:
  return world.send(t, target_rank, cache_id=cache_id).realize()

@TinyJit
def recv_jit(t, target_rank, cache_id=None) -> Tensor:
  return world.recv(t, target_rank, cache_id=cache_id).realize()

SIZE = 2048 if not CI else 2

def run():
  # set a deterministic seed so that both ranks generate the same random tensor
  Tensor.manual_seed(42)

  rank = getenv("RANK")

  # loop 3 times to make sure it works with the jit
  for _ in range(3):
    # create a tensor to send
    t = Tensor.randn(SIZE, SIZE)

    # send to rank 1
    if rank == 0:
      send_jit(t, 1, cache_id="test")
    elif rank == 1:
      t2 = Tensor.empty(SIZE, SIZE)
      recv_jit(t2, 0, cache_id="test")

    # recv from rank 1
    if rank == 0:
      t2 = Tensor.empty(SIZE, SIZE)
      recv_jit(t2, 1, cache_id="test2")
    elif rank == 1:
      send_jit(t2, 0, cache_id="test2")

    # check that the received tensor is the same as the sent tensor
    if rank == 0:
      assert np.allclose(t.numpy(), t2.numpy()), f"{t2.numpy()} wasn't equal to {t.numpy()}"

  print(f"rank {rank} passed")

if __name__ == "__main__":
  if getenv("HIP"):
    devices = ["hip:0", "hip:1"]
  else:
    devices = ["gpu:0", "gpu:1" if not CI else "gpu:0"]
  world_size = len(devices)

  dist.init_oob(world_size)

  processes = []
  for rank, device in enumerate(devices):
    processes.append(dist.spawn(rank, device, fn=run, args=()))
  for p in processes: p.join()

  # exit with error code if any of the processes failed
  for p in processes:
    if p.exitcode != 0: exit(p.exitcode)
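
For orientation, below is a minimal sketch of the two-rank round-trip pattern this test exercises, using plain multiprocessing pipes and numpy in place of world.send/world.recv over HIP IPC. Everything here (toy_run, the (4, 4) shape, the Pipe transport) is an illustrative stand-in, not tinygrad code.

import numpy as np
from multiprocessing import Pipe, Process

def toy_run(rank, conn):
  # same seed on both ranks, mirroring Tensor.manual_seed(42) above
  rng = np.random.default_rng(42)
  t = rng.standard_normal((4, 4))
  if rank == 0:
    conn.send(t)        # "send to rank 1"
    t2 = conn.recv()    # "recv from rank 1"
    assert np.allclose(t, t2), "round trip corrupted the tensor"
  else:
    t2 = conn.recv()    # receive from rank 0
    assert np.allclose(t, t2)  # same seed, so rank 1 generated an identical t
    conn.send(t2)       # echo it back
  print(f"rank {rank} passed")

if __name__ == "__main__":
  a, b = Pipe()
  procs = [Process(target=toy_run, args=(0, a)), Process(target=toy_run, args=(1, b))]
  for p in procs: p.start()
  for p in procs: p.join()

As in the real test, seeding both ranks identically lets the receiving rank check the payload against a locally generated copy instead of shipping a reference tensor out of band.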