mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
minimal mstack pr to fix allreduce (#10649)
* minimal mstack pr to fix allreduce * fix webgpu
This commit is contained in:
@@ -4,7 +4,6 @@ from tinygrad.helpers import Context
|
||||
from tinygrad.uop.ops import Ops
|
||||
|
||||
class TestRingAllReduce(unittest.TestCase):
|
||||
@unittest.skip("still broken")
|
||||
def test_schedule_ring(self):
|
||||
with Context(RING=2):
|
||||
N = 4
|
||||
|
||||
@@ -58,7 +58,7 @@ def mselect_reorder_view(ms:UOp, view:UOp, base:UOp):
|
||||
st = st.substitute({dnums[0]:dnums[0].const_like(ms.arg)})
|
||||
return base.mselect(ms.arg).view(st)
|
||||
|
||||
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT}
|
||||
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK}
|
||||
|
||||
sym = symbolic_simple+PatternMatcher([
|
||||
# UOp with size 0 is zero
|
||||
@@ -123,6 +123,10 @@ replace_contiguous = PatternMatcher([
|
||||
|
||||
def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
|
||||
|
||||
def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
|
||||
for s in rb.src:
|
||||
if s.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
|
||||
|
||||
def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
|
||||
st = unwrap(view.st)
|
||||
# always realize unsafe pad ops before masked view
|
||||
@@ -139,8 +143,8 @@ do_realize = PatternMatcher([
|
||||
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}, name="tr"), realize),
|
||||
# realize before expand or unsafe pad ops
|
||||
(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
|
||||
# realize before COPY and MSELECT
|
||||
(UPat((Ops.COPY, Ops.MSELECT), src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), allow_any_len=True), realize),
|
||||
# realize parents of COPY, MSELECT, MSTACK
|
||||
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
|
||||
])
|
||||
|
||||
def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:dict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
|
||||
@@ -243,7 +247,7 @@ def create_kernel(x:UOp, b:UOp|None=None):
|
||||
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
|
||||
return buffer.assign(kernel).reshape(x.shape)
|
||||
|
||||
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MULTI}
|
||||
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.MULTI}
|
||||
def append_to_kernel(x:UOp):
|
||||
new_srcs: list[UOp] = []
|
||||
metadata = x.arg.metadata
|
||||
@@ -402,6 +406,8 @@ fix_kernel_ops = PatternMatcher([
|
||||
replace_globals = PatternMatcher([
|
||||
# replace ASSIGN with the target BUFFER
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.BUFFER), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
|
||||
# HACK: select the 0 branch of MSTACK (the device is wrong after this, is that okay?)
|
||||
(UPat(Ops.MSTACK, name="x"), lambda x: x.src[0]),
|
||||
])
|
||||
|
||||
def fix_kernel_ast(k:UOp) -> UOp|None:
|
||||
@@ -411,7 +417,11 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
|
||||
# push views to edges
|
||||
ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
|
||||
# replace buffer with define_global + add load/store last
|
||||
bufs = tuple(s.buf_uop if s.op is not Ops.MSELECT else s.src[0].buf_uop for s in k.src)
|
||||
bufs = []
|
||||
for s in k.src:
|
||||
# traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
|
||||
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
|
||||
bufs.append(s.buf_uop)
|
||||
ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, bufs, bottom_up=True, name="replace buffer")
|
||||
if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
|
||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
|
||||
@@ -513,7 +523,7 @@ def limit_bufs(root:UOp):
|
||||
# count number of unique buffers flowing into this op
|
||||
bufs: set[UOp] = set()
|
||||
def gate_input(u:UOp):
|
||||
if (is_load:=(u.op in {Ops.BUFFER, Ops.GBARRIER, Ops.ASSIGN})): bufs.add(u)
|
||||
if (is_load:=(u.op in {Ops.BUFFER, Ops.GBARRIER, Ops.ASSIGN, Ops.MSTACK})): bufs.add(u)
|
||||
return not is_load
|
||||
root.toposort(gate=gate_input)
|
||||
# NOTE: this -1 is for the output buffer
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import cast
|
||||
import functools, itertools, operator
|
||||
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv
|
||||
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp
|
||||
@@ -38,9 +39,19 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
|
||||
.alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
|
||||
reduced_chunks.append(reduced_chunk)
|
||||
|
||||
# allgather + reassemble
|
||||
# allgather
|
||||
copied_chunks = []
|
||||
for i,c in enumerate(reduced_chunks):
|
||||
this_chunk = [None] * len(buf.device)
|
||||
this_chunk[(i+len(buf.device)-1)%n_lbs] = c
|
||||
for step in range(n_lbs-1):
|
||||
dest = (i+step)%n_lbs
|
||||
this_chunk[dest] = c = c.copy_to_device(buf.device[dest])
|
||||
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
|
||||
|
||||
# reassemble
|
||||
pads = [((s,numel-e),) for s,e in chunks]
|
||||
return functools.reduce(operator.add, [c.copy_to_device(buf.device).pad(pad) for pad,c in zip(pads, reduced_chunks)]).reshape(shape)
|
||||
return functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads, copied_chunks)]).reshape(shape)
|
||||
|
||||
replace_allreduce = PatternMatcher([(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),])
|
||||
|
||||
|
||||
@@ -48,10 +48,13 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
|
||||
if s.op is Ops.ASSIGN:
|
||||
children[s.src[1]].append(k)
|
||||
in_degree[k] += 1
|
||||
elif s.op is Ops.MSELECT:
|
||||
if s.src[0].op is not Ops.BUFFER:
|
||||
children[s.src[0].src[1]].append(k)
|
||||
in_degree[k] += 1
|
||||
elif s.op in {Ops.MSELECT, Ops.MSTACK}:
|
||||
for ss in s.src:
|
||||
if ss.op is Ops.MSELECT: ss = ss.src[0]
|
||||
if ss.op is not Ops.BUFFER:
|
||||
assert ss.op is Ops.ASSIGN
|
||||
children[ss.src[1]].append(k)
|
||||
in_degree[k] += 1
|
||||
elif s.op is Ops.BUFFER:
|
||||
pass # a BUFFER is already realized, nothing to do here
|
||||
else:
|
||||
|
||||
@@ -531,12 +531,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if self.op is Ops.MSELECT:
|
||||
assert isinstance(self.src[0].device, tuple), "mselect must be on tuple device"
|
||||
return self.src[0].device[self.arg]
|
||||
if self.op is Ops.MSTACK: return tuple(cast(str, x.device) for x in self.src)
|
||||
if self.op in {Ops.COPY, Ops.BUFFER, Ops.ALLREDUCE}: return self.src[1].device
|
||||
return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
|
||||
@property
|
||||
def buf_uop(self) -> UOp:
|
||||
if self.op is Ops.BUFFER: return self
|
||||
if self.op is Ops.MSELECT: return self.src[0].buf_uop.mselect(self.arg)
|
||||
if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.buf_uop for x in self.src))
|
||||
assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
|
||||
return self.src[0].base
|
||||
@property
|
||||
@@ -549,6 +551,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
ret = self.src[0].buffer
|
||||
assert isinstance(ret, MultiBuffer)
|
||||
return ret.bufs[self.arg]
|
||||
if self.op is Ops.MSTACK:
|
||||
ret = MultiBuffer.__new__(MultiBuffer)
|
||||
ret.bufs = [cast(Buffer, x.buffer) for x in self.src]
|
||||
return ret
|
||||
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
|
||||
if (cret:=buffers.get(self)) is not None: return cret
|
||||
rdtype = self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base
|
||||
|
||||
@@ -50,13 +50,16 @@ def validate_kernel(k:UOp):
|
||||
|
||||
assign_spec = PatternMatcher([
|
||||
# KERNEL can attach to an ASSIGN to describe the compute required to realize a BUFFER
|
||||
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.MSELECT)), name="k"), validate_kernel),
|
||||
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.MSELECT, Ops.MSTACK)), name="k"), validate_kernel),
|
||||
|
||||
# ASSIGN has a target and a value. It can also optionally depend on other assigns
|
||||
(UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
|
||||
|
||||
# MSELECT chooses one of the multi buffers
|
||||
(UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
|
||||
|
||||
# MSTACK combines buffers into multi
|
||||
(UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)),
|
||||
])
|
||||
|
||||
# *** this is the spec of a Tensor in UOp ***
|
||||
|
||||
@@ -15,7 +15,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
||||
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.IGNORE: "#00C000",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
|
||||
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.GBARRIER: "#FFC14D", Ops.MSELECT: "#d040a0"}
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.GBARRIER: "#FFC14D", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0"}
|
||||
|
||||
# VIZ API
|
||||
|
||||
|
||||
Reference in New Issue
Block a user