Mirror of https://github.com/tinygrad/tinygrad.git
allreduce is a function with LATE_ALLREDUCE=1 (#15119)
* allreduce as a function
* allreduce function
* support allreduce function
* LATE_ALLREDUCE
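For context: with LATE_ALLREDUCE=1 the Ops.ALLREDUCE op is no longer expanded inline by the multi-device rewrites; it survives until the earliest scheduler rewrites, where create_allreduce_function (added below) wraps it into a named, precompiled "allreduce" function. A minimal repro sketch follows — the shapes and CPU device names are illustrative assumptions, not part of this commit:

# hypothetical repro: shard a tensor and reduce over the sharded axis, which
# places an Ops.ALLREDUCE in the graph. run as
#   LATE_ALLREDUCE=1 DEBUG=2 python repro.py
# to exercise the function path instead of the inline ring/naive expansion.
from tinygrad import Tensor

t = Tensor.ones(512, 512).contiguous().shard(("CPU:0", "CPU:1"), axis=0)
out = t.sum(0)          # cross-device reduce over the sharded axis
print(out.numpy()[:4])  # realization triggers the scheduler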
tinygrad/schedule/allreduce.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import functools, itertools
from tinygrad.helpers import all_int, prod, DEBUG, RING, ALL2ALL, getenv
from tinygrad.uop.ops import Ops, UOp

# *** allreduce implementation ***
def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
  if not isinstance(buf.device, tuple): return None
  assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}"
  ndev, shape, numel = len(buf.device), buf.shape, prod(buf.shape)

  # ring allreduce doesn't provide a benefit with only 2 nodes or where the number of elements is less than 256k (empirically)
  # fall back to naive allreduce to save on kernel dispatch, chunking, and reassembling chunks.
  use_all2all = (ALL2ALL >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1))
  use_ring = not use_all2all and (RING >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
  if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {ndev}x{numel} | {buf.dtype}")

  # make contiguous before we copy it
  buf = buf.contiguous()

  # naive: copy to all devices. if you shrink later, that'll be handled
  if not use_ring and not use_all2all:
    return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(ndev)])

  # chunk data into ndev pieces
  factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
  base, left = divmod(numel // factor, ndev)
  chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0)))

  # reduce-scatter
  reduced_chunks:list[UOp] = []
  for i,(s,e) in enumerate(chunks):
    if use_all2all:
      chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(ndev)]
      reduced_chunks.append(functools.reduce(lambda x,y: x.alu(red.arg, y), chunks_on_i))
    else:
      chunk, reduced = buf.reshape((numel,)).shrink(((s,e),)), buf.reshape((numel,)).shrink(((s,e),))
      for step in range(ndev-1):
        src, dest = (i+step)%ndev, (i+step+1)%ndev
        cp = reduced.copy_to_device(buf.device[dest], src if isinstance(reduced.device, tuple) else None)
        reduced = cp.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
      reduced_chunks.append(reduced)

  # allgather
  copied_chunks:list[UOp] = []
  for i,rc in enumerate(reduced_chunks):
    if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
    elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev))))
    else:
      chain:list[UOp] = [rc]
      for step in range(ndev-1):
        chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev]))
      copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev))))

  # reassemble
  return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape)

def create_allreduce_function(buf:UOp, red:UOp, output:UOp|None=None) -> UOp|None:
  # a BUFFER without a unique gets its unique added later
  if output is None: output = UOp(Ops.BUFFER, red.dtype, (UOp(Ops.NOOP), red.src[1]), red.size).reshape(red.shape)
  to = red.param_like(0)
  src = buf.param_like(1)
  red = UOp(Ops.ALLREDUCE, dtype=red.dtype, src=(src, red.src[1]), arg=red.arg)
  return output.after(to.assign(handle_allreduce(src, red)).sink().call(output, buf.contiguous(), name="allreduce", precompile=True))
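To make the chunking arithmetic and the two ring phases concrete, here is a standalone pure-Python model (an illustrative sketch: plain lists stand in for device buffers, and chunk_bounds/ring_allreduce are names invented for this example). It reuses the exact chunk-boundary arithmetic from handle_allreduce and checks that a reduce-scatter followed by an allgather leaves every device holding the full sum:

import itertools

def chunk_bounds(numel: int, ndev: int) -> list[tuple[int, int]]:
  # same arithmetic as handle_allreduce: pick the largest factor of numel
  # from [32, 16, 8, 4, 2], then split numel//factor as evenly as possible
  factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
  base, left = divmod(numel // factor, ndev)
  return list(itertools.pairwise(itertools.accumulate(
    [(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0)))

def ring_allreduce(data: list[list[int]]) -> list[list[int]]:
  # data[d] is device d's local buffer; all buffers have the same length
  ndev, numel = len(data), len(data[0])
  chunks = chunk_bounds(numel, ndev)
  # reduce-scatter: chunk i starts on device i and walks the ring,
  # accumulating each visited device's local contribution
  reduced = []
  for i, (s, e) in enumerate(chunks):
    acc = data[i][s:e]
    for step in range(ndev - 1):
      dest = (i + step + 1) % ndev
      acc = [a + b for a, b in zip(acc, data[dest][s:e])]
    reduced.append(acc)
  # allgather: every device ends up with every fully reduced chunk
  return [[v for acc in reduced for v in acc] for _ in range(ndev)]

print(chunk_bounds(10, 4))  # [(0, 4), (4, 6), (6, 8), (8, 10)]
out = ring_allreduce([[d] * 10 for d in range(4)])
assert all(buf == [6] * 10 for buf in out)  # 0+1+2+3 summed on every device

Worked through by hand for numel=10, ndev=4: factor is 2 (the largest of 32/16/8/4/2 dividing 10), divmod(10 // 2, 4) gives base=1 and left=1, so the first chunk gets (base+1)*factor = 4 elements and the remaining three get base*factor = 2, yielding the boundaries (0,4), (4,6), (6,8), (8,10) printed above.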
@@ -1,59 +1,7 @@
-import functools, itertools
-from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, getenv
+from tinygrad.helpers import all_same, prod, getenv
 from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite, should_resolve_call
 from tinygrad.dtype import dtypes
-
-# *** allreduce implementation ***
-def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
-  if not isinstance(buf.device, tuple): return None
-  assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}"
-  ndev, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
-
-  # ring allreduce doesn't provide a benefit with only 2 nodes or where the number of elements is less than 256k (empirically)
-  # fall back to naive allreduce to save on kernel dispatch, chunking, and reassembling chunks.
-  use_all2all = (ALL2ALL >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1))
-  use_ring = not use_all2all and (RING >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
-  if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {ndev}x{numel} | {buf.dtype}")
-
-  # make contiguous before we copy it
-  buf = buf.contiguous()
-
-  # naive: copy to all devices. if you shrink later, that'll be handled
-  if not use_ring and not use_all2all:
-    return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(ndev)])
-
-  # chunk data into ndev pieces
-  factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
-  base, left = divmod(numel // factor, ndev)
-  chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0)))
-
-  # reduce-scatter
-  reduced_chunks:list[UOp] = []
-  for i,(s,e) in enumerate(chunks):
-    if use_all2all:
-      chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(ndev)]
-      reduced_chunks.append(functools.reduce(lambda x,y: x.alu(red.arg, y), chunks_on_i))
-    else:
-      chunk, reduced = buf.reshape((numel,)).shrink(((s,e),)), buf.reshape((numel,)).shrink(((s,e),))
-      for step in range(ndev-1):
-        src, dest = (i+step)%ndev, (i+step+1)%ndev
-        cp = reduced.copy_to_device(buf.device[dest], src if isinstance(reduced.device, tuple) else None)
-        reduced = cp.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
-      reduced_chunks.append(reduced)
-
-  # allgather
-  copied_chunks:list[UOp] = []
-  for i,rc in enumerate(reduced_chunks):
-    if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
-    elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev))))
-    else:
-      chain:list[UOp] = [rc]
-      for step in range(ndev-1):
-        chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev]))
-      copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev))))
-
-  # reassemble
-  return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape)
+from tinygrad.schedule.allreduce import handle_allreduce

 # ***** multi rewrite MSELECT/MSTACK *****

@@ -71,7 +19,6 @@ def mstack_early_shrink(ms:UOp, shrink:UOp):
   return ms.replace(src=tuple(ret))

 replace_allreduce = PatternMatcher([
-  (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
   # BROADCAST: explicitly expand broadcast copies and combine with MSTACK
   (UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
     UOp(Ops.MSTACK, c.dtype, tuple(x.copy_to_device(d) for d in c.device)) if isinstance(c.device, tuple) and isinstance(x.device, str) else None),

@@ -87,6 +34,11 @@ replace_allreduce = PatternMatcher([
     lambda s,v,ms: v.replace(src=(s.mselect(ms.arg),)+v.src[1:])),
 ])

+_early_allreduce = PatternMatcher([
+  (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
+])
+if not getenv("LATE_ALLREDUCE", 0): replace_allreduce = _early_allreduce + replace_allreduce
+
 # ***** multi functions *****

 def alu_multi(root:UOp):
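A note on the gating above: combining matchers with "+" concatenates their rule lists, so prepending _early_allreduce when LATE_ALLREDUCE is off restores the old behavior of rewriting Ops.ALLREDUCE inline via handle_allreduce; with LATE_ALLREDUCE=1 no rule here matches the op, and it survives to be wrapped by create_allreduce_function. A toy first-match-wins model of that composition (not tinygrad's actual PatternMatcher, just a sketch of the assumed ordering semantics):

# toy model: a matcher is an ordered list of (predicate, rewrite) rules and
# '+' concatenates, so rules from the left operand are tried first
from typing import Callable

class ToyMatcher:
  def __init__(self, rules: list[tuple[Callable[[str], bool], Callable[[str], str]]]):
    self.rules = rules
  def __add__(self, other: "ToyMatcher") -> "ToyMatcher":
    return ToyMatcher(self.rules + other.rules)
  def rewrite(self, op: str) -> str | None:
    # first matching rule wins; None means the op is left untouched
    return next((fn(op) for pred, fn in self.rules if pred(op)), None)

early = ToyMatcher([(lambda op: op == "ALLREDUCE", lambda op: "inline ring expansion")])
rest  = ToyMatcher([(lambda op: op == "COPY", lambda op: "broadcast expansion")])

for flag in (0, 1):
  matcher = rest if flag else early + rest
  # flag off: ALLREDUCE is rewritten inline. flag on: no rule matches, so the
  # op survives for create_allreduce_function downstream.
  print(flag, matcher.rewrite("ALLREDUCE"))  # 0 inline ring expansion / 1 None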
@@ -10,6 +10,7 @@ from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
 from tinygrad.codegen.opt import Opt
 from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
 from tinygrad.schedule.multi import multi_pm
+from tinygrad.schedule.allreduce import create_allreduce_function

 # creation can recurse a lot
 import sys
@@ -103,6 +104,10 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
   # resolve calls
   (UPat(Ops.CALL, name="c"), resolve_call),

+  # resolve allreduce (must be bottom up)
+  (UPat(Ops.ASSIGN, src=(UPat.var("output"), UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"))), create_allreduce_function),
+  (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), create_allreduce_function),
+
   # split_reduceop
   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
@@ -369,6 +374,12 @@ def flatten_bufferize(x:UOp):
   return ret
 pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])

+def resolve_anonymous_buffer(ctx:itertools.count, b:UOp, c:UOp) -> UOp|None:
+  dab = b.replace(src=(UOp(Ops.LUNIQUE, arg=next(ctx)),)+b.src[1:])
+  nc_src = tuple(dab if x is b else x for x in c.src)
+  if nc_src == c.src: return None
+  return dab.after(c.replace(src=nc_src))
+
 pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
   (UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), lambda ctx,x,idx: bufferize_to_store(ctx, x, idx, allow_locals=False)),
@@ -382,7 +393,10 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
   # remove MOP on AFTER
   (UPat(Ops.AFTER, src=(UPat.var("x"), UPat(GroupOp.Movement, name="y"))), lambda x,y: x.after(y.src[0])),
   # remove double AFTER
-  (UPat(Ops.AFTER, src=(UPat.var("x"), UPat(Ops.AFTER, name="y"))), lambda x,y: x.after(*y.src[1:]))
+  (UPat(Ops.AFTER, src=(UPat.var("x"), UPat(Ops.AFTER, name="y"))), lambda x,y: x.after(*y.src[1:])),
+
+  # resolve anonymous buffers
+  (UPat(Ops.AFTER, src=(UPat(Ops.BUFFER, src=(UPat(Ops.NOOP),), name="b", allow_any_len=True), UPat(Ops.CALL, name="c"))), resolve_anonymous_buffer),
 ])

 pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
@@ -926,6 +926,7 @@ def should_resolve_call(c:UOp) -> bool:
   # don't resolve real kernel calls, sink or program
   if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return False
   if c.src[0].op in {Ops.PROGRAM, Ops.LINEAR, Ops.COPY}: return False
+  if c.arg.precompile: return False
   return True

 # ******** ops in python ********