Files
tinygrad/extra/dist/collectives.py
wozeparrot 29d5801387 distributed collectives (#1519)
* feat: world

* feat: tests

* feat: no more backwards

* feat: recv into

* feat: whoops

* feat: test in ci

* feat: some debug logging

* feat: workflow naming

* feat: need to set pythonpath

* feat: just send to same device

* feat: allreduce

* feat: test

* feat: need contiguous

* feat: test in ci

* feat: exit with correct code

* feat: don't need that

* feat: opencl wait_for just doesn't work

* feat: synchronize on out

* feat: try?

* feat: try again?

* feat: add extra realizes

* feat: print

* feat: seed

* feat: tol

* feat: test ones and zeros

* feat: remove print

* feat: are you just flaky

* feat: separate scatter and gather?

* feat: just try synchronizing

* feat: remove print again

* feat: bring back difference

* feat: no sync

* feat: revert that

* feat: back to wait_for

* fix: typo
2023-08-11 10:22:07 -07:00

44 lines
1.6 KiB
Python

from tinygrad.tensor import Tensor, Device
from tinygrad.helpers import getenv
from extra.dist import world
def allreduce(t:Tensor, cache_id=None) -> Tensor:
  """Sum tensor `t` element-wise across every rank using a ring all-reduce.

  Two phases over a logical ring (each rank sends to the next rank and
  receives from the previous one):
    1. reduce-scatter: after WORLD_SIZE-1 steps this rank owns the fully
       summed chunk at one index.
    2. all-gather: the summed chunks are circulated so every rank ends up
       with all of them.
  `cache_id`, when provided, is prefixed with the rank and suffixed per step
  so each send buffer gets a distinct cache key. Returns a tensor shaped
  like `t` containing the global sum.
  """
  RANK, WORLD_SIZE = getenv("RANK"), getenv("WORLD_SIZE")
  cache_id = f"{RANK}-{cache_id}" if cache_id is not None else None

  # operate on a flat view, zero-padded so it splits evenly into WORLD_SIZE chunks
  flat = t.flatten()
  remainder = flat.shape[0] % WORLD_SIZE
  if remainder != 0:
    flat = Tensor.cat(flat, Tensor.zeros(WORLD_SIZE - remainder))
  chunks = flat.chunk(WORLD_SIZE, dim=0)

  # ring neighbours (Python % keeps the -1 case non-negative)
  next_rank, prev_rank = (RANK + 1) % WORLD_SIZE, (RANK - 1) % WORLD_SIZE

  # phase 1: reduce-scatter — walk the chunk index backwards each step,
  # accumulating the partial sum received from the previous rank
  current = RANK
  reduced = chunks[current]
  for i in range(WORLD_SIZE - 1):
    world.send(reduced, next_rank, cache_id=f"{cache_id}-{i}-s" if cache_id is not None else None)
    current = (current - 1) % WORLD_SIZE
    recv_buf = Tensor.empty(*reduced.shape)
    world.recv(recv_buf, prev_rank)
    reduced = chunks[current] + recv_buf
  chunks[current] = reduced

  # phase 2: all-gather — forward the most recently received summed chunk
  current = (RANK + 1) % WORLD_SIZE
  for i in range(WORLD_SIZE - 1):
    world.send(reduced, next_rank, cache_id=f"{cache_id}-{i}-g" if cache_id is not None else None)
    current = (current - 1) % WORLD_SIZE
    recv_buf = Tensor.empty(*reduced.shape)
    world.recv(recv_buf, prev_rank)
    reduced = chunks[current] = recv_buf

  # reassemble, drop the zero padding, restore the caller's shape
  return Tensor.cat(*chunks, dim=0).shrink(((0, t.numel()),)).reshape(*t.shape)