Ring allreduce in multitensor (#3000)

* Ring allreduce v3

* Configurable size, number of gpus and jit in benchmark

* ScheduleBarrier v0

* GB/s that make sense

* ScheduleBarrier v0.1

* Fallback on 2 GPUs

* ScheduleBarrier v0.2

* ScheduleBarrier v0.3

* ScheduleBarrier v0.3.1

* ScheduleBarrier v0.3.2

* Replace ScheduleBarrier with automatic optimization

* unused import

* fix comment

* typing

* better fallback

* python 3.8

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
Co-authored-by: nimlgen <138685161+nimlgen@users.noreply.github.com>
This commit is contained in:
uuuvn
2024-03-20 20:20:01 +02:00
committed by GitHub
parent 455f7bea9b
commit c5bf9e4c96
4 changed files with 120 additions and 10 deletions

View File

@@ -0,0 +1,63 @@
import time
import functools
from tinygrad import Tensor, Device
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import ReduceOps, BinaryOps, GlobalCounters
from tinygrad.features.multi import MultiLazyBuffer, ring_allreduce
from tinygrad.features.jit import TinyJit
from tinygrad.realize import create_schedule, run_schedule
from tinygrad.helpers import getenv
from typing import List, Union
def naive_allreduce(op: ReduceOps, lbs: List[LazyBuffer]):
  """Baseline allreduce: every device pulls a copy of every shard and folds them locally.

  Moves (n-1)*N elements onto each device, so bandwidth scales poorly with device count;
  it exists here purely as the reference point for the ring benchmark.
  """
  binop = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[op]
  out = []
  for dst in lbs:
    gathered = [src.copy_to_device(dst.device) for src in lbs]
    out.append(functools.reduce(lambda a, b: a.e(binop, b), gathered))
  return out
def realize(x: Union[LazyBuffer, List[LazyBuffer]]):
  """Schedule and run the given lazy buffer(s), then wait for the devices to finish."""
  # Normalize to a list so a single LazyBuffer can be passed directly.
  targets = x if isinstance(x, list) else [x]
  run_schedule(create_schedule(targets))
  # Block until every involved device has actually finished executing.
  for buf in targets:
    Device[buf.device].synchronize()
def test(impl, devs: List[str], N: int, iters:int = 10):
  """Benchmark one allreduce implementation.

  Runs two untimed warmup iterations (kernel compilation / JIT capture), then
  `iters` timed iterations of an N-element fp32 SUM allreduce across `devs`.
  Returns (mean GFLOP/s, mean GB/s, mean seconds per iteration).
  """
  def _wrapped(impl, op: ReduceOps, t: Tensor) -> Tensor:
    return Tensor(MultiLazyBuffer(impl(op, t.lazydata.lbs), 0), device=devs)
  _jitted = TinyJit(_wrapped) if getenv("USEJIT", 1) == 1 else _wrapped
  secs, gflops, gbs = 0, 0, 0
  for i in range(-2, iters):
    GlobalCounters.reset()
    # Device j starts holding the constant 1+j, so results are easy to eyeball.
    # (renamed from `i` to avoid shadowing the benchmark loop variable)
    lbs = [Tensor.full((N,), float(1+j), device=d).contiguous().lazydata for j,d in enumerate(devs)]
    realize(lbs)
    # perf_counter is monotonic and high-resolution; time.time can jump (NTP etc.)
    start = time.perf_counter()
    realize(_jitted(impl, ReduceOps.SUM, Tensor(MultiLazyBuffer(lbs, 0), device=devs)).lazydata.lbs)
    end = time.perf_counter()
    if i < 0:
      # First time is slow due to kernel compilation
      continue
    i_secs = end-start
    i_gflops = GlobalCounters.global_ops/i_secs/10**9
    i_gbs = (N*4)/i_secs/10**9
    print(f"{impl.__name__} iter {i+1}/{iters}: {i_secs:.6f} sec {i_gflops:.2f} GFLOP/s {i_gbs:.2f} GB/s")
    secs += i_secs
    gflops += i_gflops
    gbs += i_gbs
  return (gflops/iters, gbs/iters, secs/iters)
def main():
  """Benchmark ring vs naive allreduce across GPUS devices and print a comparison."""
  dev = Device.DEFAULT
  n_gpus = getenv("GPUS", 6)  # number of gpus
  devs = tuple(f"{dev}:{x}" for x in range(n_gpus))
  sz = getenv("SZ", 1000) * 10**6  # bytes of data held on each gpu
  N = sz // 4  # fp32 element count per gpu (4 bytes per element)
  print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
  ring_gflops, ring_gbs, ring_secs = test(ring_allreduce, devs, N)
  naive_gflops, naive_gbs, naive_secs = test(naive_allreduce, devs, N)
  print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s")
  print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")
# Script entry point: allows importing this module without running the benchmark.
if __name__ == "__main__":
  main()

View File

@@ -1,11 +1,13 @@
import unittest, functools
import unittest, functools, random
from typing import List
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit
from tinygrad.device import BufferCopy
from tinygrad.ops import LoadOps, ReduceOps
from tinygrad.helpers import CI
from tinygrad.ops import LoadOps, ReduceOps, BinaryOps
from tinygrad.helpers import CI, prod
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.realize import create_schedule
from tinygrad.features.multi import ring_allreduce, MultiLazyBuffer
from random import randint
import numpy as np
from hypothesis import given, strategies as strat, settings
@@ -90,6 +92,24 @@ class TestMultiTensor(unittest.TestCase):
fn = f(n)
np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)
@unittest.skipIf(CI and Device.DEFAULT == "CLANG", "clang is slow")
def test_fuzz_allreduce(self):
def naive_allreduce(lbs):
return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
random.seed(41)
for it in range(100):
for n in range(2, 4+1):
t = Tensor.rand(tuple([(n if i == 0 else 1) * randint(1, 10) for i in range(randint(1, 4))])).shard_(tuple([d0, d1, d2, d3][:n]), 0)
a = Tensor(MultiLazyBuffer(naive_allreduce(t.lazydata.lbs), 0))
b = Tensor(MultiLazyBuffer(ring_allreduce(ReduceOps.SUM, t.lazydata.lbs), 0))
diff = a - b
mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy()
max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy()
assert mean_err < 1e-6, f"big mean error, iteration {it}_{n}"
assert max_err < 1e-6, f"big max error, iteration {it}_{n}"
def _test_matmul_shard_axis(self, shard_x, shard_w, device):
X = Tensor.kaiming_uniform(N, N).realize()
W = Tensor.kaiming_uniform(N, N).realize()

View File

@@ -1,16 +1,43 @@
from __future__ import annotations
from typing import Optional, Union, Any, Tuple, List
import functools, itertools, operator
from tinygrad.helpers import all_same, dedup, round_up, prod, DEBUG
from tinygrad.helpers import all_same, all_int, dedup, round_up, prod, DEBUG, getenv
from tinygrad.dtype import DType, Scalar
from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.shapetracker import sint
def ring_allreduce(op: ReduceOps, lbs: List[LazyBuffer]):
  """Allreduce across devices using the bandwidth-optimal ring algorithm.

  Every returned buffer holds the full reduction of all inputs, one per input
  device. Falls back to a naive all-to-all reduce when the ring gives no benefit.
  NOTE: the pasted diff left a dangling old `all_reduce` header and an
  unconditional early return here that made the ring path unreachable; the
  naive return belongs only inside the fallback branch below.
  """
  assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
  assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
  bop = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[op]
  n_lbs, dim = len(lbs), prod(lbs[0].shape)
  # Ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
  # so just fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
  if n_lbs < 3 or dim < 256_000 or not getenv("RING", 1):
    return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
  # Split the flattened buffer into n_lbs near-equal chunks; the first `left` chunks get one extra element.
  base, left = dim // n_lbs, dim % n_lbs
  c_lens = [base + 1 if left - i > 0 else base for i in range(n_lbs)]
  acc = 0
  chunks = [(acc, (acc := acc + i)) for i in c_lens if i > 0]
  chunked = [[lb.reshape((dim,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
  # Scatter-reduce step: each chunk travels around the ring accumulating; after
  # n_lbs-1 rounds one device holds the fully reduced version of each chunk.
  for step in range(n_lbs - 1):
    for i in range(len(chunks)):
      s, r = (i+step)%n_lbs, (i+step+1)%n_lbs
      chunked[r][i] = chunked[r][i].e(bop, chunked[s][i].copy_to_device(chunked[r][i].device, force=True))
  # Allgather step: circulate the reduced chunks so every device gets all of them.
  # force=True keeps the intermediate copies from being folded away.
  for step in range(n_lbs - 1):
    for i in range(len(chunks)):
      s, r = (i+step-1)%n_lbs, (i+step)%n_lbs
      chunked[r][i] = chunked[s][i].copy_to_device(chunked[r][i].device, force=True)
  # Assemble chunks back: pad each chunk to full length and ADD; pads are disjoint,
  # so the sum is just concatenation back into the original shape.
  pads = [((s,dim-e),) for s,e in chunks]
  return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [c.pad(pads[i]) for i,c in enumerate(lb_c)]).reshape(lbs[0].shape) for lb_c in chunked]
def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
@@ -90,7 +117,7 @@ class MultiLazyBuffer:
# all-reduce on sharded axes
new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
reduced_parts = [x.r(op, axis) if r else x.const(0, shape=new_shape) for x,r in zip(self.lbs, self.real)]
if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
if all(self.real): return MultiLazyBuffer(ring_allreduce(op, reduced_parts), None)
return MultiLazyBuffer(reduced_parts, None, self.real)
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)

View File

@@ -91,12 +91,12 @@ class LazyBuffer:
wait = LazyBuffer.loadop(LoadOps.WAIT, (0,), dtypes.uint32, device, src=(sync,), enable_cache=True)
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self, wait), enable_cache=False)
def copy_to_device(self, device:str) -> LazyBuffer:
def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer:
# no COPY
if self.device == device: return self
# double COPY = one COPY
if self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is LoadOps.COPY:
if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is LoadOps.COPY:
return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)
# const doesn't have to be copied (issues with disk tensor)