minimal mstack pr to fix allreduce (#10649)

* minimal mstack pr to fix allreduce

* fix webgpu
This commit is contained in:
George Hotz
2025-06-05 15:14:53 -07:00
committed by GitHub
parent 4c315f8e17
commit baba274a76
7 changed files with 47 additions and 15 deletions

View File

@@ -4,7 +4,6 @@ from tinygrad.helpers import Context
from tinygrad.uop.ops import Ops
class TestRingAllReduce(unittest.TestCase):
@unittest.skip("still broken")
def test_schedule_ring(self):
with Context(RING=2):
N = 4

View File

@@ -58,7 +58,7 @@ def mselect_reorder_view(ms:UOp, view:UOp, base:UOp):
st = st.substitute({dnums[0]:dnums[0].const_like(ms.arg)})
return base.mselect(ms.arg).view(st)
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT}
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK}
sym = symbolic_simple+PatternMatcher([
# UOp with size 0 is zero
@@ -123,6 +123,10 @@ replace_contiguous = PatternMatcher([
# Record `tr` as a key of `ctx` (value None). Used as a do_realize callback to mark a matched UOp.
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
# Add every source of `rb` whose op is not in ALWAYS_CONTIGUOUS to `ctx` (value None).
# do_realize applies this to COPY, MSELECT and MSTACK nodes.
def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
for s in rb.src:
# sources whose op is in ALWAYS_CONTIGUOUS are skipped and never added here
if s.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
st = unwrap(view.st)
# always realize unsafe pad ops before masked view
@@ -139,8 +143,8 @@ do_realize = PatternMatcher([
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}, name="tr"), realize),
# realize before expand or unsafe pad ops
(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
# realize before COPY and MSELECT
(UPat((Ops.COPY, Ops.MSELECT), src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), allow_any_len=True), realize),
# realize parents of COPY, MSELECT, MSTACK
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
])
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:dict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
@@ -243,7 +247,7 @@ def create_kernel(x:UOp, b:UOp|None=None):
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
return buffer.assign(kernel).reshape(x.shape)
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MULTI}
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.MULTI}
def append_to_kernel(x:UOp):
new_srcs: list[UOp] = []
metadata = x.arg.metadata
@@ -402,6 +406,8 @@ fix_kernel_ops = PatternMatcher([
replace_globals = PatternMatcher([
# replace ASSIGN with the target BUFFER
(UPat(Ops.ASSIGN, src=(UPat(Ops.BUFFER), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
# HACK: select the 0 branch of MSTACK (the device is wrong after this, is that okay?)
(UPat(Ops.MSTACK, name="x"), lambda x: x.src[0]),
])
def fix_kernel_ast(k:UOp) -> UOp|None:
@@ -411,7 +417,11 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
# push views to edges
ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
# replace buffer with define_global + add load/store last
bufs = tuple(s.buf_uop if s.op is not Ops.MSELECT else s.src[0].buf_uop for s in k.src)
bufs = []
for s in k.src:
# traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
bufs.append(s.buf_uop)
ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, bufs, bottom_up=True, name="replace buffer")
if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
@@ -513,7 +523,7 @@ def limit_bufs(root:UOp):
# count number of unique buffers flowing into this op
bufs: set[UOp] = set()
def gate_input(u:UOp):
if (is_load:=(u.op in {Ops.BUFFER, Ops.GBARRIER, Ops.ASSIGN})): bufs.add(u)
if (is_load:=(u.op in {Ops.BUFFER, Ops.GBARRIER, Ops.ASSIGN, Ops.MSTACK})): bufs.add(u)
return not is_load
root.toposort(gate=gate_input)
# NOTE: this -1 is for the output buffer

View File

@@ -1,3 +1,4 @@
from typing import cast
import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp
@@ -38,9 +39,19 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
reduced_chunks.append(reduced_chunk)
# allgather + reassemble
# allgather
copied_chunks = []
for i,c in enumerate(reduced_chunks):
this_chunk = [None] * len(buf.device)
this_chunk[(i+len(buf.device)-1)%n_lbs] = c
for step in range(n_lbs-1):
dest = (i+step)%n_lbs
this_chunk[dest] = c = c.copy_to_device(buf.device[dest])
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
# reassemble
pads = [((s,numel-e),) for s,e in chunks]
return functools.reduce(operator.add, [c.copy_to_device(buf.device).pad(pad) for pad,c in zip(pads, reduced_chunks)]).reshape(shape)
return functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads, copied_chunks)]).reshape(shape)
replace_allreduce = PatternMatcher([(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),])

View File

@@ -48,10 +48,13 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
if s.op is Ops.ASSIGN:
children[s.src[1]].append(k)
in_degree[k] += 1
elif s.op is Ops.MSELECT:
if s.src[0].op is not Ops.BUFFER:
children[s.src[0].src[1]].append(k)
in_degree[k] += 1
elif s.op in {Ops.MSELECT, Ops.MSTACK}:
for ss in s.src:
if ss.op is Ops.MSELECT: ss = ss.src[0]
if ss.op is not Ops.BUFFER:
assert ss.op is Ops.ASSIGN
children[ss.src[1]].append(k)
in_degree[k] += 1
elif s.op is Ops.BUFFER:
pass # a BUFFER is already realized, nothing to do here
else:

View File

@@ -531,12 +531,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.MSELECT:
assert isinstance(self.src[0].device, tuple), "mselect must be on tuple device"
return self.src[0].device[self.arg]
if self.op is Ops.MSTACK: return tuple(cast(str, x.device) for x in self.src)
if self.op in {Ops.COPY, Ops.BUFFER, Ops.ALLREDUCE}: return self.src[1].device
return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
@property
def buf_uop(self) -> UOp:
# Resolve this node to its buffer-level UOp. Only BUFFER, MSELECT, MSTACK and ASSIGN are handled.
# a BUFFER is its own buf_uop
if self.op is Ops.BUFFER: return self
# MSELECT: resolve the source first, then re-apply the same selection index (self.arg) to the result
if self.op is Ops.MSELECT: return self.src[0].buf_uop.mselect(self.arg)
# MSTACK: build a new MSTACK of the same dtype whose sources are the buf_uop of each original source
if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.buf_uop for x in self.src))
# any other op must be an ASSIGN, otherwise this assertion fails
assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
# for an ASSIGN, return the base of its first source
return self.src[0].base
@property
@@ -549,6 +551,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
ret = self.src[0].buffer
assert isinstance(ret, MultiBuffer)
return ret.bufs[self.arg]
if self.op is Ops.MSTACK:
ret = MultiBuffer.__new__(MultiBuffer)
ret.bufs = [cast(Buffer, x.buffer) for x in self.src]
return ret
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
if (cret:=buffers.get(self)) is not None: return cret
rdtype = self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base

View File

@@ -50,13 +50,16 @@ def validate_kernel(k:UOp):
assign_spec = PatternMatcher([
# KERNEL can attach to an ASSIGN to describe the compute required to realize a BUFFER
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.MSELECT)), name="k"), validate_kernel),
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.MSELECT, Ops.MSTACK)), name="k"), validate_kernel),
# ASSIGN has a target and a value. It can also optionally depend on other assigns
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
# MSELECT chooses one of the multi buffers
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
# MSTACK combines buffers into multi
(UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)),
])
# *** this is the spec of a Tensor in UOp ***

View File

@@ -15,7 +15,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.IGNORE: "#00C000",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
Ops.ALLREDUCE: "#ff40a0", Ops.GBARRIER: "#FFC14D", Ops.MSELECT: "#d040a0"}
Ops.ALLREDUCE: "#ff40a0", Ops.GBARRIER: "#FFC14D", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0"}
# VIZ API