From 106d18b7929e250c81edecfc18c69dbbda97298a Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 4 Mar 2026 17:15:33 -0500 Subject: [PATCH] use UOp methods in allreduce.py [pr] (#15137) except the one line with Ops.BUFFER and Ops.NOOP, not sure what that's for --- tinygrad/schedule/allreduce.py | 8 ++++---- tinygrad/uop/ops.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tinygrad/schedule/allreduce.py b/tinygrad/schedule/allreduce.py index 70c09a5ea8..6c75a4f48d 100644 --- a/tinygrad/schedule/allreduce.py +++ b/tinygrad/schedule/allreduce.py @@ -19,7 +19,7 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None: # naive: copy to all devices. if you shrink later, that'll be handled if not use_ring and not use_all2all: - return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(ndev)]) + return functools.reduce(lambda x,y: x.alu(red.arg, y), [buf.mselect(i).copy_to_device(red.src[1]) for i in range(ndev)]) # chunk data into ndev pieces factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1) @@ -44,12 +44,12 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None: copied_chunks:list[UOp] = [] for i,rc in enumerate(reduced_chunks): if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg)) - elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev)))) + elif use_all2all: copied_chunks.append(UOp.mstack(*(rc.copy_to_device(buf.device[j]) for j in range(ndev)))) else: chain:list[UOp] = [rc] for step in range(ndev-1): chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev])) - copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev)))) + copied_chunks.append(UOp.mstack(*(chain[(j-i+1)%ndev] for j in range(ndev)))) # reassemble return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape) @@ -59,5 +59,5 @@ def create_allreduce_function(buf:UOp, red:UOp, output:UOp|None=None) -> UOp|Non if output is None: output = UOp(Ops.BUFFER, red.dtype, (UOp(Ops.NOOP), red.src[1]), red.size).reshape(red.shape) to = red.param_like(0) src = buf.param_like(1) - red = UOp(Ops.ALLREDUCE, dtype=red.dtype, src=(src, red.src[1]), arg=red.arg) + red = src.allreduce(red.arg, red.src[1]) return output.after(to.assign(handle_allreduce(src, red)).sink().call(output, buf.contiguous(), name="allreduce", precompile=True)) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index a40e4aa2d9..31b19f4262 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -558,6 +558,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self,), arg=arg) return UOp(Ops.COPY, self.dtype, (inp, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device)) def mselect(self, arg:int) -> UOp: return UOp(Ops.MSELECT, self.dtype, (self,), arg) + def mstack(self, *srcs: UOp) -> UOp: return UOp(Ops.MSTACK, self.dtype, (self,)+srcs) @property def metadata(self) -> tuple[Metadata, ...]|None: return all_metadata.get(self, None) def encdec(self, *src, arg=None): return UOp(Ops.ENCDEC, self.dtype, src=(self,)+src, arg=arg)