Ring allreduce in multitensor (#3000)

* Ring allreduce v3

* Configurable size, number of gpus and jit in benchmark

* ScheduleBarrier v0

* GB/s that make sense

* ScheduleBarrier v0.1

* Fallback on 2 GPUs

* ScheduleBarrier v0.2

* ScheduleBarrier v0.3

* ScheduleBarrier v0.3.1

* ScheduleBarrier v0.3.2

* Replace ScheduleBarrier with automatic optimization

* unused import

* fix comment

* typing

* better fallback

* python 3.8

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
Co-authored-by: nimlgen <138685161+nimlgen@users.noreply.github.com>
This commit is contained in:
uuuvn
2024-03-20 20:20:01 +02:00
committed by GitHub
parent 455f7bea9b
commit c5bf9e4c96
4 changed files with 120 additions and 10 deletions

View File

@@ -0,0 +1,63 @@
import time
import functools
from tinygrad import Tensor, Device
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import ReduceOps, BinaryOps, GlobalCounters
from tinygrad.features.multi import MultiLazyBuffer, ring_allreduce
from tinygrad.features.jit import TinyJit
from tinygrad.realize import create_schedule, run_schedule
from tinygrad.helpers import getenv
from typing import List, Union
def naive_allreduce(op: ReduceOps, lbs: List[LazyBuffer]):
  """Reference allreduce: every device pulls a copy of every shard and folds them locally.

  O(n^2) traffic (each of the n devices copies all n shards) — used as a baseline
  against the bandwidth-optimal ring implementation.
  """
  binop = {ReduceOps.SUM: BinaryOps.ADD, ReduceOps.MAX: BinaryOps.MAX}[op]
  reduced = []
  for dst in lbs:
    # gather every shard onto dst's device, then fold left with the binary op
    local_copies = [src.copy_to_device(dst.device) for src in lbs]
    acc = local_copies[0]
    for c in local_copies[1:]:
      acc = acc.e(binop, c)
    reduced.append(acc)
  return reduced
def realize(x: Union[LazyBuffer, List[LazyBuffer]]):
  """Schedule and execute the given lazy buffer(s), then block until every involved device is idle."""
  lbs = [x] if not isinstance(x, list) else x
  run_schedule(create_schedule(lbs))
  # synchronize so timing measurements outside this call see completed work
  for buf in lbs:
    Device[buf.device].synchronize()
def test(impl, devs: List[str], N: int, iters:int = 10):
  """Benchmark an allreduce implementation across `devs`.

  Args:
    impl: allreduce function taking (ReduceOps, List[LazyBuffer]) -> List[LazyBuffer].
    devs: device names to shard across.
    N: number of float32 elements per device.
    iters: measured iterations (two extra warmup iterations are run and discarded).

  Returns:
    (avg GFLOP/s, avg GB/s, avg seconds per iteration).
  """
  def _wrapped(impl, op: ReduceOps, t: Tensor) -> Tensor:
    return Tensor(MultiLazyBuffer(impl(op, t.lazydata.lbs), 0), device=devs)
  _jitted = TinyJit(_wrapped) if getenv("USEJIT", 1) == 1 else _wrapped
  secs, gflops, gbs = 0, 0, 0
  for i in range(-2, iters):
    GlobalCounters.reset()
    lbs = [Tensor.full((N,), float(1+i), device=d).contiguous().lazydata for i,d in enumerate(devs)]
    realize(lbs)
    # perf_counter is monotonic and high-resolution; time.time() can jump (NTP, DST)
    # and has coarse granularity on some platforms — wrong clock for benchmarking.
    start = time.perf_counter()
    realize(_jitted(impl, ReduceOps.SUM, Tensor(MultiLazyBuffer(lbs, 0), device=devs)).lazydata.lbs)
    end = time.perf_counter()
    if i < 0:
      # First time is slow due to kernel compilation
      continue
    i_secs = end-start
    i_gflops = GlobalCounters.global_ops/i_secs/10**9
    i_gbs = (N*4)/i_secs/10**9
    print(f"{impl.__name__} iter {i+1}/{iters}: {i_secs:.6f} sec {i_gflops:.2f} GFLOP/s {i_gbs:.2f} GB/s")
    secs += i_secs
    gflops += i_gflops
    gbs += i_gbs
  return (gflops/iters, gbs/iters, secs/iters)
def main():
  """Benchmark ring vs naive allreduce on GPUS devices and print the comparison."""
  dev, n_gpus = Device.DEFAULT, getenv("GPUS", 6)  # number of gpus
  devs = tuple(f"{dev}:{x}" for x in range(n_gpus))
  f32 = 4  # 4 bytes
  sz = getenv("SZ", 1000) * 10**6  # size of data on each gpu
  N = sz//f32
  print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
  ring_gflops, ring_gbs, ring_secs = test(ring_allreduce, devs, N)
  naive_gflops, naive_gbs, naive_secs = test(naive_allreduce, devs, N)
  print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s")
  print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")

if __name__ == "__main__":
  main()

View File

@@ -1,11 +1,13 @@
import unittest, functools
import unittest, functools, random
from typing import List
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit
from tinygrad.device import BufferCopy
from tinygrad.ops import LoadOps, ReduceOps
from tinygrad.helpers import CI
from tinygrad.ops import LoadOps, ReduceOps, BinaryOps
from tinygrad.helpers import CI, prod
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.realize import create_schedule
from tinygrad.features.multi import ring_allreduce, MultiLazyBuffer
from random import randint
import numpy as np
from hypothesis import given, strategies as strat, settings
@@ -90,6 +92,24 @@ class TestMultiTensor(unittest.TestCase):
fn = f(n)
np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)
@unittest.skipIf(CI and Device.DEFAULT == "CLANG", "clang is slow")
def test_fuzz_allreduce(self):
  """Fuzz ring_allreduce against a naive gather-and-fold reference on random shapes."""
  def naive_allreduce(lbs):
    # reference: copy every shard to each device and left-fold with ADD
    results = []
    for lb in lbs:
      shards = [x.copy_to_device(lb.device) for x in lbs]
      results.append(functools.reduce(lambda a, b: a.e(BinaryOps.ADD, b), shards))
    return results
  random.seed(41)
  for it in range(100):
    for n in range(2, 4+1):
      # random rank 1-4 shape; axis 0 is a multiple of n so it shards evenly
      ndims = randint(1, 4)
      shape = tuple((n if i == 0 else 1) * randint(1, 10) for i in range(ndims))
      t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0)
      a = Tensor(MultiLazyBuffer(naive_allreduce(t.lazydata.lbs), 0))
      b = Tensor(MultiLazyBuffer(ring_allreduce(ReduceOps.SUM, t.lazydata.lbs), 0))
      flat_err = (a - b).reshape((prod((a - b).shape),)).abs()
      mean_err = flat_err.mean().numpy()
      max_err = flat_err.max().numpy()
      assert mean_err < 1e-6, f"big mean error, iteration {it}_{n}"
      assert max_err < 1e-6, f"big max error, iteration {it}_{n}"
def _test_matmul_shard_axis(self, shard_x, shard_w, device):
X = Tensor.kaiming_uniform(N, N).realize()
W = Tensor.kaiming_uniform(N, N).realize()