mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)
* feat: world * feat: tests * feat: no more backwards * feat: recv into * feat: whoops * feat: test in ci * feat: some debug logging * feat: workflow naming * feat: need to set pythonpath * feat: just send to same device * feat: allreduce * feat: test * feat: need contiguous * feat: test in ci * feat: exit with correct code * feat: don't need that * feat: opencl wait_for just doesn't work * feat: synchronize on out * feat: try? * feat: try again? * feat: add extra realizes * feat: print * feat: seed * feat: tol * feat: test ones and zeros * feat: remove print * feat: are you just flaky * feat: seperate scatter and gather? * feat: just try synchronizing * feat: remove print again * feat: bring back difference * feat: no sync * feat: revert that * feat: back to wait_for * fix: typo
44 lines · 1.6 KiB · Python
from tinygrad.tensor import Tensor, Device
from tinygrad.helpers import getenv
from extra.dist import world
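
# ring allreduce in two passes: a scatter-reduce pass leaves each rank holding one
# fully reduced chunk, then a gather pass circulates those chunks until every rank
# has the complete result; each pass costs WORLD_SIZE - 1 send/recv steps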
def allreduce(t:Tensor, cache_id=None) -> Tensor:
  RANK, WORLD_SIZE = getenv("RANK"), getenv("WORLD_SIZE")
  # prefix with the rank so each rank gets its own cache entries
  cache_id = f"{RANK}-{cache_id}" if cache_id is not None else None

  # flatten
  flattened = t.flatten()

  # pad to evenly divide
  if flattened.shape[0] % WORLD_SIZE != 0:
    flattened = Tensor.cat(flattened, Tensor.zeros(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE)))

  # chunk
  chunks = flattened.chunk(WORLD_SIZE, dim=0)
  reduced = chunks[RANK]  # each rank starts the reduction from its own chunk

  # ring neighbors
  next_rank = (RANK + 1) % WORLD_SIZE
  prev_rank = ((RANK - 1) + WORLD_SIZE) % WORLD_SIZE

  # scatter reduce: each step, send the running partial to next_rank, receive a
  # partial from prev_rank, and fold it into the next chunk down the ring
  current_chunk_index = RANK
  for i in range(WORLD_SIZE - 1):
    world.send(reduced, next_rank, cache_id=f"{cache_id}-{i}-s" if cache_id is not None else None)
    current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE
    recv_buf = Tensor.empty(*reduced.shape)
    world.recv(recv_buf, prev_rank)
    reduced = chunks[current_chunk_index] + recv_buf
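  # worked example of the index walk (illustrative numbers): with WORLD_SIZE=4 and
  # RANK=1, current_chunk_index visits 1 -> 0 -> 3 -> 2, so after the loop this rank
  # holds the fully reduced chunk at index 2 == (RANK + 1) % WORLD_SIZE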

  # gather: circulate the fully reduced chunks so every rank fills in all of them
  chunks[current_chunk_index] = reduced
  current_chunk_index = (RANK + 1) % WORLD_SIZE
  for i in range(WORLD_SIZE - 1):
    world.send(reduced, next_rank, cache_id=f"{cache_id}-{i}-g" if cache_id is not None else None)
    current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE
    recv_buf = Tensor.empty(*reduced.shape)
    world.recv(recv_buf, prev_rank)
    reduced = chunks[current_chunk_index] = recv_buf

  # reassemble, drop the zero padding, and restore the original shape
  return Tensor.cat(*chunks, dim=0).shrink(((0, t.numel()),)).reshape(*t.shape)
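
To sanity-check the index arithmetic without tinygrad or a transport layer, here is a minimal single-process sketch. ring_allreduce_sim is a hypothetical helper written for this page, not part of extra.dist; plain numbers stand in for tensor chunks and the sent list stands in for the in-flight send/recv buffers.

def ring_allreduce_sim(per_rank_chunks):
  # per_rank_chunks[r][i] is rank r's local value for chunk i
  world_size = len(per_rank_chunks)
  chunks = [list(c) for c in per_rank_chunks]
  reduced = [chunks[r][r] for r in range(world_size)]  # each rank starts from its own chunk
  idx = list(range(world_size))                        # per-rank current_chunk_index
  # scatter-reduce pass: every rank sends its partial and folds in the one from prev
  for _ in range(world_size - 1):
    sent = list(reduced)
    for r in range(world_size):
      idx[r] = (idx[r] - 1) % world_size
      reduced[r] = chunks[r][idx[r]] + sent[(r - 1) % world_size]
  # gather pass: circulate the fully reduced chunks
  for r in range(world_size):
    chunks[r][idx[r]] = reduced[r]
  for _ in range(world_size - 1):
    sent = list(reduced)
    for r in range(world_size):
      idx[r] = (idx[r] - 1) % world_size
      reduced[r] = chunks[r][idx[r]] = sent[(r - 1) % world_size]
  return chunks

print(ring_allreduce_sim([[1, 2, 3], [10, 20, 30], [100, 200, 300]]))
# every rank ends with the elementwise sum: [[111, 222, 333], [111, 222, 333], [111, 222, 333]]

One detail the sketch makes visible: after scatter-reduce, idx[r] already equals (r + 1) % world_size, so the explicit current_chunk_index = (RANK + 1) % WORLD_SIZE before the gather loop in the file above restates that invariant rather than changing it.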