mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Ring allreduce in multitensor (#3000)
* Ring allreduce v3 * Configurable size, number of gpus and jit in benchmark * ScheduleBarrier v0 * GB/s that make sense * ScheduleBarrier v0.1 * Fallback on 2 GPUs * ScheduleBarrier v0.2 * ScheduleBarrier v0.3 * ScheduleBarrier v0.3.1 * ScheduleBarrier v0.3.2 * Replace ScheduleBarrier with automatic optimization * unused import * fix comment * typing * better fallback * python 3.8 --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> Co-authored-by: chenyu <chenyu@fastmail.com> Co-authored-by: nimlgen <138685161+nimlgen@users.noreply.github.com>
This commit is contained in:
63
test/external/external_benchmark_multitensor_allreduce.py
vendored
Normal file
63
test/external/external_benchmark_multitensor_allreduce.py
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
import time
|
||||
import functools
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.ops import ReduceOps, BinaryOps, GlobalCounters
|
||||
from tinygrad.features.multi import MultiLazyBuffer, ring_allreduce
|
||||
from tinygrad.features.jit import TinyJit
|
||||
from tinygrad.realize import create_schedule, run_schedule
|
||||
from tinygrad.helpers import getenv
|
||||
from typing import List, Union
|
||||
|
||||
def naive_allreduce(op: ReduceOps, lbs: List[LazyBuffer]):
  """Reference all-reduce: every device copies every shard to itself and folds them locally."""
  binop = {ReduceOps.SUM: BinaryOps.ADD, ReduceOps.MAX: BinaryOps.MAX}[op]
  out = []
  for target in lbs:
    local_copies = [src.copy_to_device(target.device) for src in lbs]
    out.append(functools.reduce(lambda acc, shard: acc.e(binop, shard), local_copies))
  return out
|
||||
|
||||
def realize(x: Union[LazyBuffer, List[LazyBuffer]]):
  """Schedule and run one LazyBuffer (or a list of them), then block until every device is done."""
  lbs = [x] if not isinstance(x, list) else x
  run_schedule(create_schedule(lbs))
  # synchronize so wall-clock timing around this call measures actual device work
  for lb in lbs:
    Device[lb.device].synchronize()
|
||||
|
||||
def test(impl, devs: List[str], N: int, iters:int = 10):
  """Benchmark an allreduce implementation over N floats per device.

  Returns (avg GFLOP/s, avg GB/s, avg seconds) over `iters` timed iterations;
  two extra warm-up iterations are run first and excluded from the averages.
  """
  # Wrap the implementation so the whole reduce can optionally go through TinyJit.
  def _wrapped(fn, op: ReduceOps, t: Tensor) -> Tensor:
    return Tensor(MultiLazyBuffer(fn(op, t.lazydata.lbs), 0), device=devs)
  run = TinyJit(_wrapped) if getenv("USEJIT", 1) == 1 else _wrapped

  total_secs = total_gflops = total_gbs = 0
  for it in range(-2, iters):
    GlobalCounters.reset()
    # fresh realized shards each iteration; each device holds a different constant
    lbs = [Tensor.full((N,), float(1+i), device=d).contiguous().lazydata for i,d in enumerate(devs)]
    realize(lbs)
    start = time.time()
    realize(run(impl, ReduceOps.SUM, Tensor(MultiLazyBuffer(lbs, 0), device=devs)).lazydata.lbs)
    elapsed = time.time() - start
    if it < 0:
      # First time is slow due to kernel compilation
      continue
    i_gflops = GlobalCounters.global_ops/elapsed/10**9
    i_gbs = (N*4)/elapsed/10**9
    print(f"{impl.__name__} iter {it+1}/{iters}: {elapsed:.6f} sec {i_gflops:.2f} GFLOP/s {i_gbs:.2f} GB/s")
    total_secs += elapsed
    total_gflops += i_gflops
    total_gbs += i_gbs

  return (total_gflops/iters, total_gbs/iters, total_secs/iters)
|
||||
|
||||
|
||||
def main():
  """Compare ring vs naive allreduce throughput across GPUS devices, SZ MB per device."""
  dev, n_gpus = Device.DEFAULT, getenv("GPUS", 6)  # number of gpus
  devs = tuple(f"{dev}:{x}" for x in range(n_gpus))

  sz = getenv("SZ", 1000) * 10**6  # bytes of data on each gpu
  N = sz // 4                      # float32 element count (4 bytes each)

  print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
  ring_gflops, ring_gbs, ring_secs = test(ring_allreduce, devs, N)
  naive_gflops, naive_gbs, naive_secs = test(naive_allreduce, devs, N)
  print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s")
  print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")


if __name__ == "__main__":
  main()
|
||||
@@ -1,11 +1,13 @@
|
||||
import unittest, functools
|
||||
import unittest, functools, random
|
||||
from typing import List
|
||||
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit
|
||||
from tinygrad.device import BufferCopy
|
||||
from tinygrad.ops import LoadOps, ReduceOps
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.ops import LoadOps, ReduceOps, BinaryOps
|
||||
from tinygrad.helpers import CI, prod
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict
|
||||
from tinygrad.realize import create_schedule
|
||||
from tinygrad.features.multi import ring_allreduce, MultiLazyBuffer
|
||||
from random import randint
|
||||
import numpy as np
|
||||
from hypothesis import given, strategies as strat, settings
|
||||
|
||||
@@ -90,6 +92,24 @@ class TestMultiTensor(unittest.TestCase):
|
||||
fn = f(n)
|
||||
np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "CLANG", "clang is slow")
def test_fuzz_allreduce(self):
  """Fuzz ring_allreduce against a naive all-to-all sum on random shapes and 2-4 devices."""
  # Reference implementation: every shard is copied to every device and summed there.
  def _naive(lbs):
    return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]

  random.seed(41)  # fixed seed keeps the fuzzed shapes reproducible
  for it in range(100):
    for n in range(2, 4+1):
      # random rank 1-4; axis 0 is a multiple of n so sharding is even
      t = Tensor.rand(tuple([(n if i == 0 else 1) * randint(1, 10) for i in range(randint(1, 4))])).shard_(tuple([d0, d1, d2, d3][:n]), 0)
      expected = Tensor(MultiLazyBuffer(_naive(t.lazydata.lbs), 0))
      actual = Tensor(MultiLazyBuffer(ring_allreduce(ReduceOps.SUM, t.lazydata.lbs), 0))
      diff = expected - actual
      flat_abs = diff.reshape((prod(diff.shape),)).abs()
      mean_err = flat_abs.mean().numpy()
      max_err = flat_abs.max().numpy()
      assert mean_err < 1e-6, f"big mean error, iteration {it}_{n}"
      assert max_err < 1e-6, f"big max error, iteration {it}_{n}"
|
||||
|
||||
|
||||
def _test_matmul_shard_axis(self, shard_x, shard_w, device):
|
||||
X = Tensor.kaiming_uniform(N, N).realize()
|
||||
W = Tensor.kaiming_uniform(N, N).realize()
|
||||
|
||||
@@ -1,16 +1,43 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Union, Any, Tuple, List
|
||||
import functools, itertools, operator
|
||||
from tinygrad.helpers import all_same, dedup, round_up, prod, DEBUG
|
||||
from tinygrad.helpers import all_same, all_int, dedup, round_up, prod, DEBUG, getenv
|
||||
from tinygrad.dtype import DType, Scalar
|
||||
from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.shape.shapetracker import sint
|
||||
|
||||
def ring_allreduce(op: ReduceOps, lbs: List[LazyBuffer]):
  """All-reduce `op` over the per-device buffers in `lbs`, returning one reduced copy per device.

  Uses the ring algorithm (scatter-reduce then allgather) so each device only talks to its
  neighbors; falls back to a naive all-to-all reduce when the ring gives no benefit.
  """
  assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
  assert all_same([lb.shape[0] for lb in lbs]), "allreduce with uneven shards is undefined"
  bop = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[op]

  n_lbs, dim = len(lbs), prod(lbs[0].shape)
  # Ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
  # so just fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
  if n_lbs < 3 or dim < 256_000 or not getenv("RING", 1):
    return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]

  # Split the flattened buffer into n_lbs near-equal chunks; the first (dim % n_lbs) get one extra.
  base, extra = divmod(dim, n_lbs)
  chunk_sizes = [base + (1 if i < extra else 0) for i in range(n_lbs)]
  bounds = list(itertools.accumulate(chunk_sizes, initial=0))
  chunks = [(s, e) for s, e in zip(bounds, bounds[1:]) if e > s]
  chunked = [[lb.reshape((dim,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]

  # Scatter-reduce: after n_lbs-1 steps, chunk i is fully reduced on device (i-1) % n_lbs.
  for step in range(n_lbs - 1):
    for i in range(len(chunks)):
      src, dst = (i+step) % n_lbs, (i+step+1) % n_lbs
      chunked[dst][i] = chunked[dst][i].e(bop, chunked[src][i].copy_to_device(chunked[dst][i].device, force=True))

  # Allgather: circulate each fully-reduced chunk to the remaining devices.
  for step in range(n_lbs - 1):
    for i in range(len(chunks)):
      src, dst = (i+step-1) % n_lbs, (i+step) % n_lbs
      chunked[dst][i] = chunked[src][i].copy_to_device(chunked[dst][i].device, force=True)

  # Reassemble: zero-pad each chunk to full size and ADD them together (chunks are disjoint,
  # so ADD is correct even for a MAX reduce), then restore the original shape.
  pads = [((s, dim-e),) for s, e in chunks]
  return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y),
                           [c.pad(pads[i]) for i,c in enumerate(dev_chunks)]).reshape(lbs[0].shape)
          for dev_chunks in chunked]
|
||||
|
||||
def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
|
||||
if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
|
||||
@@ -90,7 +117,7 @@ class MultiLazyBuffer:
|
||||
# all-reduce on sharded axes
|
||||
new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
|
||||
reduced_parts = [x.r(op, axis) if r else x.const(0, shape=new_shape) for x,r in zip(self.lbs, self.real)]
|
||||
if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
|
||||
if all(self.real): return MultiLazyBuffer(ring_allreduce(op, reduced_parts), None)
|
||||
return MultiLazyBuffer(reduced_parts, None, self.real)
|
||||
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
|
||||
return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)
|
||||
|
||||
@@ -91,12 +91,12 @@ class LazyBuffer:
|
||||
wait = LazyBuffer.loadop(LoadOps.WAIT, (0,), dtypes.uint32, device, src=(sync,), enable_cache=True)
|
||||
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self, wait), enable_cache=False)
|
||||
|
||||
def copy_to_device(self, device:str) -> LazyBuffer:
|
||||
def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer:
|
||||
# no COPY
|
||||
if self.device == device: return self
|
||||
|
||||
# double COPY = one COPY
|
||||
if self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is LoadOps.COPY:
|
||||
if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is LoadOps.COPY:
|
||||
return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)
|
||||
|
||||
# const doesn't have to be copied (issues with disk tensor)
|
||||
|
||||
Reference in New Issue
Block a user