mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
use UOp methods in allreduce.py [pr] (#15137)
except the one line with Ops.BUFFER and Ops.NOOP, not sure what that's for
This commit is contained in:
@@ -19,7 +19,7 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
|
||||
|
||||
# naive: copy to all devices. if you shrink later, that'll be handled
|
||||
if not use_ring and not use_all2all:
|
||||
return functools.reduce(lambda x,y: x.alu(red.arg, y), [UOp(Ops.COPY, buf.dtype, (buf.mselect(i), red.src[1])) for i in range(ndev)])
|
||||
return functools.reduce(lambda x,y: x.alu(red.arg, y), [buf.mselect(i).copy_to_device(red.src[1]) for i in range(ndev)])
|
||||
|
||||
# chunk data into ndev pieces
|
||||
factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
|
||||
@@ -44,12 +44,12 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
|
||||
copied_chunks:list[UOp] = []
|
||||
for i,rc in enumerate(reduced_chunks):
|
||||
if isinstance(red.src[1].arg, str): copied_chunks.append(rc.copy_to_device(red.src[1].arg))
|
||||
elif use_all2all: copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(rc.copy_to_device(buf.device[j]) for j in range(ndev))))
|
||||
elif use_all2all: copied_chunks.append(UOp.mstack(*(rc.copy_to_device(buf.device[j]) for j in range(ndev))))
|
||||
else:
|
||||
chain:list[UOp] = [rc]
|
||||
for step in range(ndev-1):
|
||||
chain.append(rc := rc.copy_to_device(buf.device[(i+step)%ndev]))
|
||||
copied_chunks.append(UOp(Ops.MSTACK, buf.dtype, tuple(chain[(j-i+1)%ndev] for j in range(ndev))))
|
||||
copied_chunks.append(UOp.mstack(*(chain[(j-i+1)%ndev] for j in range(ndev))))
|
||||
|
||||
# reassemble
|
||||
return UOp.sum(*[c.pad(((s,numel-e),)) for (s,e),c in zip(chunks, copied_chunks)]).reshape(shape)
|
||||
@@ -59,5 +59,5 @@ def create_allreduce_function(buf:UOp, red:UOp, output:UOp|None=None) -> UOp|Non
|
||||
if output is None: output = UOp(Ops.BUFFER, red.dtype, (UOp(Ops.NOOP), red.src[1]), red.size).reshape(red.shape)
|
||||
to = red.param_like(0)
|
||||
src = buf.param_like(1)
|
||||
red = UOp(Ops.ALLREDUCE, dtype=red.dtype, src=(src, red.src[1]), arg=red.arg)
|
||||
red = src.allreduce(red.arg, red.src[1])
|
||||
return output.after(to.assign(handle_allreduce(src, red)).sink().call(output, buf.contiguous(), name="allreduce", precompile=True))
|
||||
|
||||
@@ -558,6 +558,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
inp = self if arg is None else UOp(Ops.MSELECT, self.dtype, src=(self,), arg=arg)
|
||||
return UOp(Ops.COPY, self.dtype, (inp, UOp(Ops.DEVICE, arg=device) if not isinstance(device, UOp) else device))
|
||||
def mselect(self, arg:int) -> UOp: return UOp(Ops.MSELECT, self.dtype, (self,), arg)
|
||||
def mstack(self, *srcs: UOp) -> UOp: return UOp(Ops.MSTACK, self.dtype, (self,)+srcs)
|
||||
@property
|
||||
def metadata(self) -> tuple[Metadata, ...]|None: return all_metadata.get(self, None)
|
||||
def encdec(self, *src, arg=None): return UOp(Ops.ENCDEC, self.dtype, src=(self,)+src, arg=arg)
|
||||
|
||||
Reference in New Issue
Block a user