From 5ecfe549e718912fe2e5aaea4647d27abf404d7f Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 4 Mar 2026 15:17:58 +0800 Subject: [PATCH] allreduce is a function with LATE_ALLREDUCE=1 (#15119) * allreduce as a function * allreduce function * support allreduce function * LATE_ALLREDUCE --- tinygrad/schedule/allreduce.py | 63 ++++++++++++++++++++++++++++++++++ tinygrad/schedule/multi.py | 62 ++++----------------------------- tinygrad/schedule/rangeify.py | 16 ++++++++- tinygrad/uop/ops.py | 1 + 4 files changed, 86 insertions(+), 56 deletions(-) create mode 100644 tinygrad/schedule/allreduce.py diff --git a/tinygrad/schedule/allreduce.py b/tinygrad/schedule/allreduce.py new file mode 100644 index 0000000000..70c09a5ea8 --- /dev/null +++ b/tinygrad/schedule/allreduce.py @@ -0,0 +1,63 @@ +import functools, itertools +from tinygrad.helpers import all_int, prod, DEBUG, RING, ALL2ALL, getenv +from tinygrad.uop.ops import Ops, UOp + +# *** allreduce implementation *** +def handle_allreduce(buf:UOp, red:UOp) -> UOp|None: + if not isinstance(buf.device, tuple): return None + assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}" + ndev, shape, numel = len(buf.device), buf.shape, prod(buf.shape) + + # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically) + # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks. + use_all2all = (ALL2ALL >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1)) + use_ring = not use_all2all and (RING >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1)) + if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {ndev}x{numel} | {buf.dtype}") + + # contiguous before we copy it + buf = buf.contiguous() + + # naive: copy to all devices. if you shrink later, that'll be handled + if not use_ring and not use_all2all: + return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(ndev)]) + + # chunk data into ndev pieces + factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1) + base, left = divmod(numel // factor, ndev) + chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0))) + + # reduce-scatter + reduced_chunks:list[UOp] = [] + for i,(s,e) in enumerate(chunks): + if use_all2all: + chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(ndev)] + reduced_chunks.append(functools.reduce(lambda x,y: x.alu(red.arg, y), chunks_on_i)) + else: + chunk, reduced = buf.reshape((numel,)).shrink(((s,e),)), buf.reshape((numel,)).shrink(((s,e),)) + for step in range(ndev-1): + src, dest = (i+step)%ndev, (i+step+1)%ndev + cp = reduced.copy_to_device(buf.device[dest], src if isinstance(reduced.device, tuple) else None) + reduced = cp.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest)) + reduced_chunks.append(reduced) + + # allgather + copied_chunks:list[UOp] = [] + for i,rc in enumerate(reduced_chunks): + if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg)) + elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev)))) + else: + chain:list[UOp] = [rc] + for step in range(ndev-1): + chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev])) + copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev)))) + + # reassemble + return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape) + +def create_allreduce_function(buf:UOp, red:UOp, output:UOp|None=None) -> UOp|None: + # BUFFER without unique have unique added later + if output is None: output = UOp(Ops.BUFFER, red.dtype, (UOp(Ops.NOOP), red.src[1]), red.size).reshape(red.shape) + to = red.param_like(0) + src = buf.param_like(1) + red = UOp(Ops.ALLREDUCE, dtype=red.dtype, src=(src, red.src[1]), arg=red.arg) + return output.after(to.assign(handle_allreduce(src, red)).sink().call(output, buf.contiguous(), name="allreduce", precompile=True)) diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index 73a5ff92ab..240c655101 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -1,59 +1,7 @@ -import functools, itertools -from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, getenv +from tinygrad.helpers import all_same, prod, getenv from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite, should_resolve_call from tinygrad.dtype import dtypes - -# *** allreduce implementation *** -def handle_allreduce(buf:UOp, red:UOp) -> UOp|None: - if not isinstance(buf.device, tuple): return None - assert all_int(buf.shape), f"does not support symbolic shape {buf.shape}" - ndev, shape, numel = len(buf.device), buf.shape, prod(buf.shape) - - # ring allreduce doesn't provide a benefit with only 2 nodes or where number of elements is less than 256k (empirically) - # fallback to naive allreduce to save on kernel dispatch, chunking and reassembling chunks. - use_all2all = (ALL2ALL >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and ALL2ALL >= 1)) - use_ring = not use_all2all and (RING >= 2 or (ndev > 2 and numel > getenv("RING_ALLREDUCE_THRESHOLD", 256_000) and RING >= 1)) - if DEBUG >= 2: print(f"{'ALL2ALL' if use_all2all else 'RING' if use_ring else 'NAIVE'} ALLREDUCE {ndev}x{numel} | {buf.dtype}") - - # contiguous before we copy it - buf = buf.contiguous() - - # naive: copy to all devices. if you shrink later, that'll be handled - if not use_ring and not use_all2all: - return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(ndev)]) - - # chunk data into ndev pieces - factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1) - base, left = divmod(numel // factor, ndev) - chunks = list(itertools.pairwise(itertools.accumulate([(base + 1) * factor] * left + [base * factor] * (ndev - left), initial=0))) - - # reduce-scatter - reduced_chunks:list[UOp] = [] - for i,(s,e) in enumerate(chunks): - if use_all2all: - chunks_on_i = [buf.mselect(j).reshape((numel,)).shrink(((s,e),)).copy_to_device(buf.device[i]) for j in range(ndev)] - reduced_chunks.append(functools.reduce(lambda x,y: x.alu(red.arg, y), chunks_on_i)) - else: - chunk, reduced = buf.reshape((numel,)).shrink(((s,e),)), buf.reshape((numel,)).shrink(((s,e),)) - for step in range(ndev-1): - src, dest = (i+step)%ndev, (i+step+1)%ndev - cp = reduced.copy_to_device(buf.device[dest], src if isinstance(reduced.device, tuple) else None) - reduced = cp.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest)) - reduced_chunks.append(reduced) - - # allgather - copied_chunks:list[UOp] = [] - for i,rc in enumerate(reduced_chunks): - if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg)) - elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev)))) - else: - chain:list[UOp] = [rc] - for step in range(ndev-1): - chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev])) - copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev)))) - - # reassemble - return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape) +from tinygrad.schedule.allreduce import handle_allreduce # ***** multi rewrite MSELECT/MSTACK ***** @@ -71,7 +19,6 @@ def mstack_early_shrink(ms:UOp, shrink:UOp): return ms.replace(src=tuple(ret)) replace_allreduce = PatternMatcher([ - (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce), # BROADCAST: explicitly expand broadcast copies and combine with MSTACK (UPat(Ops.COPY, name="c", src=(UPat(GroupOp.All-{Ops.CONST}, name="x"), UPat(Ops.DEVICE))), lambda c,x: UOp(Ops.MSTACK, c.dtype, tuple(x.copy_to_device(d) for d in c.device)) if isinstance(c.device, tuple) and isinstance(x.device, str) else None), @@ -87,6 +34,11 @@ replace_allreduce = PatternMatcher([ lambda s,v,ms: v.replace(src=(s.mselect(ms.arg),)+v.src[1:])), ]) +_early_allreduce = PatternMatcher([ + (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce), +]) +if not getenv("LATE_ALLREDUCE", 0): replace_allreduce = _early_allreduce + replace_allreduce + # ***** multi functions ***** def alu_multi(root:UOp): diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 5b839b14df..1189702712 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -10,6 +10,7 @@ from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify from tinygrad.codegen.opt import Opt from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op from tinygrad.schedule.multi import multi_pm +from tinygrad.schedule.allreduce import create_allreduce_function # creation can recurse a lot import sys @@ -103,6 +104,10 @@ earliest_rewrites = mop_cleanup+PatternMatcher([ # resolve calls (UPat(Ops.CALL, name="c"), resolve_call), + # resolve allreduce (must be bottom up) + (UPat(Ops.ASSIGN, src=(UPat.var("output"), UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"))), create_allreduce_function), + (UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), create_allreduce_function), + # split_reduceop (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop), @@ -369,6 +374,12 @@ def flatten_bufferize(x:UOp): return ret pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)]) +def resolve_anonymous_buffer(ctx:itertools.count, b:UOp, c:UOp) -> UOp|None: + dab = b.replace(src=(UOp(Ops.LUNIQUE, arg=next(ctx)),)+b.src[1:]) + nc_src = tuple(dab if x is b else x for x in c.src) + if nc_src == c.src: return None + return dab.after(c.replace(src=nc_src)) + pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([ (UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), lambda ctx,x,idx: bufferize_to_store(ctx, x, idx, allow_locals=False)), @@ -382,7 +393,10 @@ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([ # remove MOP on AFTER (UPat(Ops.AFTER, src=(UPat.var("x"), UPat(GroupOp.Movement, name="y"))), lambda x,y: x.after(y.src[0])), # remove double AFTER - (UPat(Ops.AFTER, src=(UPat.var("x"), UPat(Ops.AFTER, name="y"))), lambda x,y: x.after(*y.src[1:])) + (UPat(Ops.AFTER, src=(UPat.var("x"), UPat(Ops.AFTER, name="y"))), lambda x,y: x.after(*y.src[1:])), + + # resolve anonymous buffers + (UPat(Ops.AFTER, src=(UPat(Ops.BUFFER, src=(UPat(Ops.NOOP),), name="b", allow_any_len=True), UPat(Ops.CALL, name="c"))), resolve_anonymous_buffer), ]) pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([ diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index e95d3e1157..829a3631f3 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -926,6 +926,7 @@ def should_resolve_call(c:UOp) -> bool: # don't resolve real kernel calls, sink or program if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return False if c.src[0].op in {Ops.PROGRAM, Ops.LINEAR, Ops.COPY}: return False + if c.arg.precompile: return False return True # ******** ops in python ********