allreduce is a function with LATE_ALLREDUCE=1 (#15119)

* allreduce as a function

* allreduce function

* support allreduce function

* LATE_ALLREDUCE
This commit is contained in:
George Hotz
2026-03-04 15:17:58 +08:00
committed by GitHub
parent e7e70a3c95
commit 5ecfe549e7
4 changed files with 86 additions and 56 deletions

View File

@@ -0,0 +1,63 @@
import functools, itertools
from tinygrad.helpers import all_int, prod, DEBUG, RING, ALL2ALL, getenv
from tinygrad.uop.ops import Ops, UOp
# *** allreduce implementation ***
def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
if not isinstance(buf.device, tuple): return None
assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}"
ndev, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
# ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
# fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
use_all2all = (ALL2ALL >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1))
use_ring = not use_all2all and (RING >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {ndev}x{numel} | {buf.dtype}")
# contiguous before we copy it
buf = buf.contiguous()
# naive: copy to all devices. if you shrink later, that'll be handled
if not use_ring and not use_all2all:
return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(ndev)])
# chunk data into ndev pieces
factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
base, left = divmod(numel // factor, ndev)
chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0)))
# reduce-scatter
reduced_chunks:list[UOp] = []
for i,(s,e) in enumerate(chunks):
if use_all2all:
chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(ndev)]
reduced_chunks.append(functools.reduce(lambda x,y: x.alu(red.arg, y), chunks_on_i))
else:
chunk, reduced = buf.reshape((numel,)).shrink(((s,e),)), buf.reshape((numel,)).shrink(((s,e),))
for step in range(ndev-1):
src, dest = (i+step)%ndev, (i+step+1)%ndev
cp = reduced.copy_to_device(buf.device[dest], src if isinstance(reduced.device, tuple) else None)
reduced = cp.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
reduced_chunks.append(reduced)
# allgather
copied_chunks:list[UOp] = []
for i,rc in enumerate(reduced_chunks):
if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev))))
else:
chain:list[UOp] = [rc]
for step in range(ndev-1):
chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev]))
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev))))
# reassemble
return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape)
def create_allreduce_function(buf:UOp, red:UOp, output:UOp|None=None) -> UOp|None:
# BUFFER without unique have unique added later
if output is None: output = UOp(Ops.BUFFER, red.dtype, (UOp(Ops.NOOP), red.src[1]), red.size).reshape(red.shape)
to = red.param_like(0)
src = buf.param_like(1)
red = UOp(Ops.ALLREDUCE, dtype=red.dtype, src=(src, red.src[1]), arg=red.arg)
return output.after(to.assign(handle_allreduce(src, red)).sink().call(output, buf.contiguous(), name="allreduce", precompile=True))

View File

@@ -1,59 +1,7 @@
import functools, itertools
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, getenv
from tinygrad.helpers import all_same, prod, getenv
from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite, should_resolve_call
from tinygrad.dtype import dtypes
# *** allreduce implementation ***
def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
if not isinstance(buf.device, tuple): return None
assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}"
ndev, shape, numel = len(buf.device), buf.shape, prod(buf.shape)
# ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically)
# fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks.
use_all2all = (ALL2ALL >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1))
use_ring = not use_all2all and (RING >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1))
if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {ndev}x{numel} | {buf.dtype}")
# contiguous before we copy it
buf = buf.contiguous()
# naive: copy to all devices. if you shrink later, that'll be handled
if not use_ring and not use_all2all:
return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(ndev)])
# chunk data into ndev pieces
factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
base, left = divmod(numel // factor, ndev)
chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0)))
# reduce-scatter
reduced_chunks:list[UOp] = []
for i,(s,e) in enumerate(chunks):
if use_all2all:
chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(ndev)]
reduced_chunks.append(functools.reduce(lambda x,y: x.alu(red.arg, y), chunks_on_i))
else:
chunk, reduced = buf.reshape((numel,)).shrink(((s,e),)), buf.reshape((numel,)).shrink(((s,e),))
for step in range(ndev-1):
src, dest = (i+step)%ndev, (i+step+1)%ndev
cp = reduced.copy_to_device(buf.device[dest], src if isinstance(reduced.device, tuple) else None)
reduced = cp.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
reduced_chunks.append(reduced)
# allgather
copied_chunks:list[UOp] = []
for i,rc in enumerate(reduced_chunks):
if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev))))
else:
chain:list[UOp] = [rc]
for step in range(ndev-1):
chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev]))
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev))))
# reassemble
return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape)
from tinygrad.schedule.allreduce import handle_allreduce
# ***** multi rewrite MSELECT/MSTACK *****
@@ -71,7 +19,6 @@ def mstack_early_shrink(ms:UOp, shrink:UOp):
return ms.replace(src=tuple(ret))
replace_allreduce = PatternMatcher([
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
# BROADCAST: explicitly expand broadcast copies and combine with MSTACK
(UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x:
UOp(Ops.MSTACK, c.dtype, tuple(x.copy_to_device(d) for d in c.device)) if isinstance(c.device, tuple) and isinstance(x.device, str) else None),
@@ -87,6 +34,11 @@ replace_allreduce = PatternMatcher([
lambda s,v,ms: v.replace(src=(s.mselect(ms.arg),)+v.src[1:])),
])
_early_allreduce = PatternMatcher([
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),
])
if not getenv("LATE_ALLREDUCE", 0): replace_allreduce = _early_allreduce + replace_allreduce
# ***** multi functions *****
def alu_multi(root:UOp):

View File

@@ -10,6 +10,7 @@ from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
from tinygrad.codegen.opt import Opt
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.allreduce import create_allreduce_function
# creation can recurse a lot
import sys
@@ -103,6 +104,10 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# resolve calls
(UPat(Ops.CALL, name="c"), resolve_call),
# resolve allreduce (must be bottom up)
(UPat(Ops.ASSIGN, src=(UPat.var("output"), UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"))), create_allreduce_function),
(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), create_allreduce_function),
# split_reduceop
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
@@ -369,6 +374,12 @@ def flatten_bufferize(x:UOp):
return ret
pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
def resolve_anonymous_buffer(ctx:itertools.count, b:UOp, c:UOp) -> UOp|None:
dab = b.replace(src=(UOp(Ops.LUNIQUE, arg=next(ctx)),)+b.src[1:])
nc_src = tuple(dab if x is b else x for x in c.src)
if nc_src == c.src: return None
return dab.after(c.replace(src=nc_src))
pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
(UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), lambda ctx,x,idx: bufferize_to_store(ctx, x, idx, allow_locals=False)),
@@ -382,7 +393,10 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
# remove MOP on AFTER
(UPat(Ops.AFTER, src=(UPat.var("x"), UPat(GroupOp.Movement, name="y"))), lambda x,y: x.after(y.src[0])),
# remove double AFTER
(UPat(Ops.AFTER, src=(UPat.var("x"), UPat(Ops.AFTER, name="y"))), lambda x,y: x.after(*y.src[1:]))
(UPat(Ops.AFTER, src=(UPat.var("x"), UPat(Ops.AFTER, name="y"))), lambda x,y: x.after(*y.src[1:])),
# resolve anonymous buffers
(UPat(Ops.AFTER, src=(UPat(Ops.BUFFER, src=(UPat(Ops.NOOP),), name="b", allow_any_len=True), UPat(Ops.CALL, name="c"))), resolve_anonymous_buffer),
])
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([

View File

@@ -926,6 +926,7 @@ def should_resolve_call(c:UOp) -> bool:
# don't resolve real kernel calls, sink or program
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return False
if c.src[0].op in {Ops.PROGRAM, Ops.LINEAR, Ops.COPY}: return False
if c.arg.precompile: return False
return True
# ******** ops in python ********