* feat: train cifar using multigpu
* feat: split eval batch across 5
* feat: cleaner allreduce
* feat: 93.88%
* feat: cleaner batch chunking from bert
* feat: cleaner grad sync
* feat: tinygrad argmax
* feat: make it work with different gpu counts
* feat: move some stuff into the normal __init__
* feat: autodetect gpu count
* feat: move import inside
42 lines · 1.6 KiB · Python

from tinygrad.tensor import Tensor, Device
from tinygrad.helpers import getenv
from extra.dist import world
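
# Ring allreduce: sum a tensor across WORLD_SIZE processes in 2*(WORLD_SIZE-1)
# send/recv steps. A scatter-reduce pass leaves each rank holding one fully
# summed chunk, then an allgather pass circulates those chunks to every rank.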
def allreduce(t:Tensor, cache_id=None) -> Tensor:
  RANK, WORLD_SIZE = getenv("RANK"), getenv("WORLD_SIZE")
  # prefix the cache id with this process's rank so cached transfers don't collide
  cache_id = f"{RANK}-{cache_id}" if cache_id is not None else None

  # flatten to 1D so the tensor can be split into equal chunks
  flattened = t.flatten()

  # pad with zeros so the element count divides evenly by WORLD_SIZE
  if flattened.shape[0] % WORLD_SIZE != 0:
    flattened = Tensor.cat(flattened, Tensor.zeros(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE)))
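  # e.g. a 10-element tensor on WORLD_SIZE=4 ranks is padded with 2 zeros
  # to 12 elements, giving 4 chunks of 3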

  # split into WORLD_SIZE equal chunks, one per rank
  chunks = flattened.chunk(WORLD_SIZE, dim=0)

  # ring topology: every rank only sends to next_rank and receives from prev_rank
  next_rank = (RANK + 1) % WORLD_SIZE
  prev_rank = ((RANK - 1) + WORLD_SIZE) % WORLD_SIZE
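  # each rank transfers 2*(WORLD_SIZE-1) chunks of ~t.numel()/WORLD_SIZE elements,
  # so per-rank traffic stays near 2*t.numel() regardless of the number of ranks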

  # scatter-reduce: walk the chunk index backwards from RANK; after WORLD_SIZE-1
  # rounds each rank holds one chunk containing the sum across all ranks
  current_chunk_index = RANK
  for i in range(WORLD_SIZE - 1):
    world.send(chunks[current_chunk_index], next_rank, cache_id=f"{cache_id}-{i}-s" if cache_id is not None else None)
    current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE
    recv_buf = Tensor.empty(*chunks[current_chunk_index].shape)
    world.recv(recv_buf, prev_rank)
    # accumulate the partial sum received from the previous rank
    chunks[current_chunk_index] += recv_buf
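  # e.g. WORLD_SIZE=3, RANK=0: send chunk 0, receive-and-add into chunk 2, then
  # send chunk 2, receive-and-add into chunk 1; chunk 1 now holds the full sum,
  # matching the (RANK + 1) % WORLD_SIZE starting index of the gather below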

  # allgather: circulate the fully reduced chunks around the ring; assign rather
  # than add, since each incoming chunk is already the complete sum
  current_chunk_index = (RANK + 1) % WORLD_SIZE
  for i in range(WORLD_SIZE - 1):
    world.send(chunks[current_chunk_index], next_rank, cache_id=f"{cache_id}-{i}-g" if cache_id is not None else None)
    current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE
    recv_buf = Tensor.empty(*chunks[current_chunk_index].shape)
    world.recv(recv_buf, prev_rank)
    chunks[current_chunk_index].assign(recv_buf)

  # stitch the chunks back together, drop the zero padding, restore the original shape
  return Tensor.cat(*chunks, dim=0).shrink(((0, t.numel()),)).reshape(*t.shape)
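
# Usage sketch (illustrative, not part of the original file): assumes the
# extra.dist runtime has already spawned one process per GPU with RANK and
# WORLD_SIZE set in the environment; the helper name and cache_id strings
# below are hypothetical choices, not a fixed API.
def sync_grads(params):  # hypothetical helper: average gradients across ranks
  world_size = getenv("WORLD_SIZE")
  for i, p in enumerate(params):
    if p.grad is not None:
      # sum each gradient across all ranks, then divide to get the mean
      p.grad.assign(allreduce(p.grad, cache_id=f"grad-{i}") / world_size)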