from extra import dist
from tinygrad.jit import TinyJit
if __name__ == "__main__":
  # dist.preinit() must run before the rest of the tinygrad imports below
  dist.preinit()

from extra.dist import world
from tinygrad.helpers import CI, getenv
from tinygrad.tensor import Tensor
import numpy as np

# jit the send and recv wrappers so the transfer is compiled once and replayed on each call
@TinyJit
def send_jit(t, target_rank, cache_id=None) -> Tensor:
  return world.send(t, target_rank, cache_id=cache_id).realize()

@TinyJit
def recv_jit(t, target_rank, cache_id=None) -> Tensor:
  return world.recv(t, target_rank, cache_id=cache_id).realize()

# full-size tensor locally, a tiny one in CI to keep the test fast
SIZE = 2048 if not CI else 2
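
# note: .realize() inside the jitted wrappers forces the transfer to actually
# execute on every replay; cache_id appears to pair up the two sides of a
# transfer so matching send/recv calls reuse the same cached buffers (an
# inference from the matching "test"/"test2" ids below, not a documented contract)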

def run():
  # set a deterministic seed so that both ranks generate the same random tensor
  Tensor.manual_seed(42)

  rank = getenv("RANK")

  # loop 3 times to make sure it works with the jit
  for _ in range(3):
    # create a tensor to send
    t = Tensor.randn(SIZE, SIZE)

    # first hop: rank 0 sends the tensor to rank 1
    if rank == 0:
      send_jit(t, 1, cache_id="test")
    elif rank == 1:
      t2 = Tensor.empty(SIZE, SIZE)
      recv_jit(t2, 0, cache_id="test")

    # second hop: rank 1 sends it back, rank 0 receives
    if rank == 0:
      t2 = Tensor.empty(SIZE, SIZE)
      recv_jit(t2, 1, cache_id="test2")
    elif rank == 1:
      send_jit(t2, 0, cache_id="test2")

    # check that the tensor that made the round trip matches the original
    if rank == 0:
      assert np.allclose(t.numpy(), t2.numpy())

  print(f"rank {rank} passed")

if __name__ == "__main__":
  # two devices locally, but CI machines only expose a single GPU
  devices = ["gpu:0", "gpu:1" if not CI else "gpu:0"]
  world_size = len(devices)

  # set up the out-of-band channel, then spawn one process per device
  dist.init_oob(world_size)

  processes = []
  for rank, device in enumerate(devices):
    processes.append(dist.spawn(rank, device, fn=run, args=()))
  for p in processes: p.join()

  # exit with an error code if any of the processes failed
  for p in processes:
    if p.exitcode != 0: exit(p.exitcode)
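
# usage sketch (assumptions, not part of the original file): run from the repo
# root so the extra/ and tinygrad/ packages are importable, e.g.
#   PYTHONPATH=. python <path/to/this/file>.py
# dist.spawn presumably exports RANK into each child's environment, which is
# what getenv("RANK") reads inside run()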