diff --git a/test/unit/test_allreduce.py b/test/unit/test_allreduce.py
index 7eb54bd9d3..a30309660a 100644
--- a/test/unit/test_allreduce.py
+++ b/test/unit/test_allreduce.py
@@ -4,7 +4,6 @@ from tinygrad.helpers import Context
 from tinygrad.uop.ops import Ops
 
 class TestRingAllReduce(unittest.TestCase):
-  @unittest.skip("still broken")
   def test_schedule_ring(self):
     with Context(RING=2):
       N = 4
diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py
index 2c9884820d..340989660b 100644
--- a/tinygrad/engine/grouper.py
+++ b/tinygrad/engine/grouper.py
@@ -58,7 +58,7 @@ def mselect_reorder_view(ms:UOp, view:UOp, base:UOp):
   st = st.substitute({dnums[0]:dnums[0].const_like(ms.arg)})
   return base.mselect(ms.arg).view(st)
 
-ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT}
+ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK}
 
 sym = symbolic_simple+PatternMatcher([
   # UOp with size 0 is zero
@@ -123,6 +123,10 @@ replace_contiguous = PatternMatcher([
 
 def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
 
+def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
+  for s in rb.src:
+    if s.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
+
 def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
   st = unwrap(view.st)
   # always realize unsafe pad ops before masked view
@@ -139,8 +143,8 @@ do_realize = PatternMatcher([
   (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}, name="tr"), realize),
   # realize before expand or unsafe pad ops
   (UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
-  # realize before COPY and MSELECT
-  (UPat((Ops.COPY, Ops.MSELECT), src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), allow_any_len=True), realize),
+  # realize parents of COPY, MSELECT, MSTACK
+  (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
 ])
 
 def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:dict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
@@ -243,7 +247,7 @@ def create_kernel(x:UOp, b:UOp|None=None):
     buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
   return buffer.assign(kernel).reshape(x.shape)
 
-DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MULTI}
+DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.MULTI}
 
 def append_to_kernel(x:UOp):
   new_srcs: list[UOp] = []
   metadata = x.arg.metadata
@@ -402,6 +406,8 @@ fix_kernel_ops = PatternMatcher([
 replace_globals = PatternMatcher([
   # replace ASSIGN with the target BUFFER
   (UPat(Ops.ASSIGN, src=(UPat(Ops.BUFFER), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
+  # HACK: select the 0 branch of MSTACK (the device is wrong after this, is that okay?)
+  (UPat(Ops.MSTACK, name="x"), lambda x: x.src[0]),
 ])
 
 def fix_kernel_ast(k:UOp) -> UOp|None:
@@ -411,7 +417,11 @@
   # push views to edges
   ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
   # replace buffer with define_global + add load/store last
-  bufs = tuple(s.buf_uop if s.op is not Ops.MSELECT else s.src[0].buf_uop for s in k.src)
+  bufs = []
+  for s in k.src:
+    # traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
+    while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
+    bufs.append(s.buf_uop)
   ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, bufs, bottom_up=True, name="replace buffer")
   if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
     raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
@@ -513,7 +523,7 @@ def limit_bufs(root:UOp):
   # count number of unique buffers flowing into this op
   bufs: set[UOp] = set()
   def gate_input(u:UOp):
-    if (is_load:=(u.op in {Ops.BUFFER, Ops.GBARRIER, Ops.ASSIGN})): bufs.add(u)
+    if (is_load:=(u.op in {Ops.BUFFER, Ops.GBARRIER, Ops.ASSIGN, Ops.MSTACK})): bufs.add(u)
     return not is_load
   root.toposort(gate=gate_input)
   # NOTE: this -1 is for the output buffer
diff --git a/tinygrad/engine/multi.py b/tinygrad/engine/multi.py
index d76b538dc7..82a9dadc5f 100644
--- a/tinygrad/engine/multi.py
+++ b/tinygrad/engine/multi.py
@@ -1,3 +1,4 @@
+from typing import cast
 import functools, itertools, operator
 from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv
 from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp
@@ -38,9 +39,19 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
       .alu(red.arg, chunk.copy_to_device(buf.device[dest], dest))
     reduced_chunks.append(reduced_chunk)
 
-  # allgather + reassemble
+  # allgather
+  copied_chunks = []
+  for i,c in enumerate(reduced_chunks):
+    this_chunk = [None] * len(buf.device)
+    this_chunk[(i+len(buf.device)-1)%n_lbs] = c
+    for step in range(n_lbs-1):
+      dest = (i+step)%n_lbs
+      this_chunk[dest] = c = c.copy_to_device(buf.device[dest])
+    copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(cast(list[UOp], this_chunk))))
+
+  # reassemble
   pads = [((s,numel-e),) for s,e in chunks]
-  return functools.reduce(operator.add, [c.copy_to_device(buf.device).pad(pad) for pad,c in zip(pads, reduced_chunks)]).reshape(shape)
+  return functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads, copied_chunks)]).reshape(shape)
 
 replace_allreduce = PatternMatcher([(UPat(Ops.ALLREDUCE, src=(UPat.var("buf"), UPat()), name="red"), handle_allreduce),])
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index ee7df3113f..3aefd61fb2 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -48,10 +48,13 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
       if s.op is Ops.ASSIGN:
         children[s.src[1]].append(k)
         in_degree[k] += 1
-      elif s.op is Ops.MSELECT:
-        if s.src[0].op is not Ops.BUFFER:
-          children[s.src[0].src[1]].append(k)
-          in_degree[k] += 1
+      elif s.op in {Ops.MSELECT, Ops.MSTACK}:
+        for ss in s.src:
+          if ss.op is Ops.MSELECT: ss = ss.src[0]
+          if ss.op is not Ops.BUFFER:
+            assert ss.op is Ops.ASSIGN
+            children[ss.src[1]].append(k)
+            in_degree[k] += 1
       elif s.op is Ops.BUFFER: pass  # a BUFFER is already realized, nothing to do here
       else:
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 3de6b5390f..ff8a8c497c 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -531,12 +531,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.op is Ops.MSELECT:
       assert isinstance(self.src[0].device, tuple), "mselect must be on tuple device"
       return self.src[0].device[self.arg]
+    if self.op is Ops.MSTACK: return tuple(cast(str, x.device) for x in self.src)
     if self.op in {Ops.COPY, Ops.BUFFER, Ops.ALLREDUCE}: return self.src[1].device
     return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
   @property
   def buf_uop(self) -> UOp:
     if self.op is Ops.BUFFER: return self
     if self.op is Ops.MSELECT: return self.src[0].buf_uop.mselect(self.arg)
+    if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.buf_uop for x in self.src))
     assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
     return self.src[0].base
   @property
@@ -549,6 +551,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
       ret = self.src[0].buffer
       assert isinstance(ret, MultiBuffer)
       return ret.bufs[self.arg]
+    if self.op is Ops.MSTACK:
+      ret = MultiBuffer.__new__(MultiBuffer)
+      ret.bufs = [cast(Buffer, x.buffer) for x in self.src]
+      return ret
     assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
     if (cret:=buffers.get(self)) is not None: return cret
     rdtype = self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base
diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py
index e7c02cc844..d4928c4bc5 100644
--- a/tinygrad/uop/spec.py
+++ b/tinygrad/uop/spec.py
@@ -50,13 +50,16 @@ def validate_kernel(k:UOp):
 
 assign_spec = PatternMatcher([
   # KERNEL can attach to an ASSIGN to describe the compute required to realize a BUFFER
-  (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.MSELECT)), name="k"), validate_kernel),
+  (UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.MSELECT, Ops.MSTACK)), name="k"), validate_kernel),
 
   # ASSIGN has a target and a value. It can also optionally depend on other assigns
   (UPat(Ops.ASSIGN, name="x"), lambda x: len(x.src) >= 2 and all(s.op is Ops.ASSIGN for s in x.src[2:])),
 
   # MSELECT chooses one of the multi buffers
   (UPat(Ops.MSELECT, name="x"), lambda x: isinstance(x.src[0].device, tuple) and x.arg < len(x.src[0].device)),
+
+  # MSTACK combines buffers into multi
+  (UPat(Ops.MSTACK, name="x"), lambda x: all(isinstance(x.device, str) for x in x.src)),
 ])
 
 # *** this is the spec of a Tensor in UOp ***
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index f15aa45252..186c0d38bc 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -15,7 +15,7 @@
 uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4",
                Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.IGNORE: "#00C000", **{x:"#D8F9E4" for x in GroupOp.Movement},
                **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF", Ops.BLOCK: "#C4A484",
               Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
-               Ops.ALLREDUCE: "#ff40a0", Ops.GBARRIER: "#FFC14D", Ops.MSELECT: "#d040a0"}
+               Ops.ALLREDUCE: "#ff40a0", Ops.GBARRIER: "#FFC14D", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0"}
 
 # VIZ API
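Note on the new allgather in handle_allreduce: each reduced chunk i starts on device (i-1) mod n_lbs after the scatter-reduce phase and is then copied around the ring for n_lbs-1 steps, so every device ends up holding every chunk before each per-chunk list is wrapped in an MSTACK. A minimal standalone sketch of that indexing follows; plain lists stand in for UOps and MSTACK, and the helper name simulate_ring_allgather is hypothetical, not part of tinygrad.

# hypothetical sketch of the ring allgather placement used in handle_allreduce
def simulate_ring_allgather(n_lbs:int) -> list[list[str]]:
  copied_chunks = []
  for i in range(n_lbs):
    this_chunk = [""] * n_lbs
    # reduced chunk i ends the scatter-reduce phase on device (i-1) mod n_lbs
    this_chunk[(i+n_lbs-1)%n_lbs] = f"chunk{i}"
    # n_lbs-1 ring copies place it on every other device
    for step in range(n_lbs-1):
      this_chunk[(i+step)%n_lbs] = f"chunk{i}"
    assert all(this_chunk), "after the ring pass every device holds this chunk"
    copied_chunks.append(this_chunk)  # in the real code this per-device list becomes an MSTACK
  return copied_chunks

for row in simulate_ring_allgather(4): print(row)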