From cededd8eb46b8342fe40444b2980035eb9c4565e Mon Sep 17 00:00:00 2001
From: chenyu
Date: Sat, 6 Jul 2024 21:55:59 -0400
Subject: [PATCH] minor multi cleanup (#5311)

add type, move around and some newlines
---
 tinygrad/multi.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/tinygrad/multi.py b/tinygrad/multi.py
index bd558e86ad..d3df31843b 100644
--- a/tinygrad/multi.py
+++ b/tinygrad/multi.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Optional, Union, Any, Tuple, List
+from typing import Optional, Union, Any, Tuple, List, Dict
 import functools, itertools, operator
 from tinygrad.helpers import all_same, all_int, dedup, round_up, prod, DEBUG, RING
 from tinygrad.dtype import DType, ConstType
@@ -80,7 +80,8 @@ class MultiLazyBuffer:
       # if we already have a copy on the device, return that
       for lb in self.real_lbs:
         if lb.device == device: return lb
-      return self.lbs[self.real.index(True)].copy_to_device(device)
+      return self.real_lbs[0].copy_to_device(device)
+    # copy lbs to device, pad to final shape, and sum
     llbs:List[LazyBuffer] = []
     for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
       if not real: continue
@@ -112,14 +113,11 @@ class MultiLazyBuffer:
       if mlb.axis == axis or not_all_real: srcs.append(mlb.lbs)
       elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
       else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
-    new_real_lbs = {i:lsrcs[0].e(op, *lsrcs[1:], arg=arg) for i,(lsrcs,r) in enumerate(zip(zip(*srcs),new_real)) if r}
+    new_real_lbs:Dict[int,LazyBuffer] = {i:lsrcs[0].e(op, *lsrcs[1:], arg=arg) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
     # NOTE: const dtype should match real
     real_dtype = next(iter(new_real_lbs.values())).dtype
     return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const(0).cast(real_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
 
-  def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
-    return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
-
   def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
     if self.axis is not None and self.axis in axis:
       # all-reduce on sharded axes
@@ -131,6 +129,9 @@ class MultiLazyBuffer:
 
   # *** movement ops ***
 
+  def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
+    return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
+
   def reshape(self, arg:Tuple[sint, ...]):
     if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
     arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
@@ -155,13 +156,16 @@ class MultiLazyBuffer:
              sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i > self.real.index(True))), "can only pad to whole axis"
       return MultiLazyBuffer([x if r else x.const(0) for x,r in zip(self.lbs, self.real)], self.axis)
     return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
+
   def expand(self, arg:Tuple[sint, ...]):
     # NOTE: this assert isn't needed, sharded axis can have dim 1
     assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
     return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real)
+
   def permute(self, arg:Tuple[int, ...]):
     # all permutes supported!
     return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real)
+
   def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
     assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
     if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
@@ -171,6 +175,7 @@ class MultiLazyBuffer:
       return MultiLazyBuffer([lb if i==idx else lb.const(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
     return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
                            self.axis, self.real)
+
   def stride(self, arg:Tuple[int, ...]):
     assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
     return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real)