use UOp methods in allreduce.py [pr] (#15137)

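Replaces direct UOp(...) constructions in allreduce.py with the equivalent UOp methods, except for the one line that builds Ops.BUFFER with an Ops.NOOP source; it is not clear what that line is for, so it is left unchanged.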
This commit is contained in:
chenyu
2026-03-04 17:15:33 -05:00
committed by GitHub
parent 34594bcaaf
commit 106d18b792
2 changed files with 5 additions and 4 deletions

View File

@@ -19,7 +19,7 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
# naive: copy to all devices. if you shrink later, that'll be handled
if not use_ring and not use_all2all:
return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(ndev)])
return functools.reduce(lambda x,y: x.alu(red.arg, y), [buf.mselect(i).copy_to_device(red.src[1]) for i in range(ndev)])
# chunk data into ndev pieces
factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
@@ -44,12 +44,12 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
copied_chunks:list[UOp] = []
for i,rc in enumerate(reduced_chunks):
if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev))))
elif use_all2all: copied_chunks.append(UOp.mstack(*(rc.copy_to_device(buf.device[j]) for j in range(ndev))))
else:
chain:list[UOp] = [rc]
for step in range(ndev-1):
chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev]))
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev))))
copied_chunks.append(UOp.mstack(*(chain[(j-i+1)%ndev] for j in range(ndev))))
# reassemble
return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape)
@@ -59,5 +59,5 @@ def create_allreduce_function(buf:UOp, red:UOp, output:UOp|None=None) -> UOp|Non
if output is None: output = UOp(Ops.BUFFER, red.dtype, (UOp(Ops.NOOP), red.src[1]), red.size).reshape(red.shape)
to = red.param_like(0)
src = buf.param_like(1)
red = UOp(Ops.ALLREDUCE, dtype=red.dtype, src=(src, red.src[1]), arg=red.arg)
red = src.allreduce(red.arg, red.src[1])
return output.after(to.assign(handle_allreduce(src, red)).sink().call(output, buf.contiguous(), name="allreduce", precompile=True))

View File

@@ -558,6 +558,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self,), arg=arg)
return UOp(Ops.COPY, self.dtype, (inp, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device))
# Build an Ops.MSELECT UOp that carries self's dtype, has self as its only source, and stores `arg` as its arg.
def mselect(self, arg:int) -> UOp: return UOp(Ops.MSELECT, self.dtype, (self,), arg)
# Build an Ops.MSTACK UOp whose sources are self followed by `srcs`; the dtype is taken from self.
# When invoked as UOp.mstack(*items), the first item is bound to self, so every item ends up in the source tuple in order.
def mstack(self, *srcs: UOp) -> UOp: return UOp(Ops.MSTACK, self.dtype, (self,)+srcs)
# Look up this UOp in all_metadata; returns None when there is no entry for it.
@property
def metadata(self) -> tuple[Metadata, ...]|None: return all_metadata.get(self, None)
# Build an Ops.ENCDEC UOp with self's dtype, sources (self, *src), and the given `arg` (None by default).
def encdec(self, *src, arg=None): return UOp(Ops.ENCDEC, self.dtype, src=(self,)+src, arg=arg)