minor multi cleanup (#5311)

Add a type annotation, move some code around, and adjust a few newlines.
This commit is contained in:
chenyu
2024-07-06 21:55:59 -04:00
committed by GitHub
parent 8a99514462
commit cededd8eb4

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Optional, Union, Any, Tuple, List
from typing import Optional, Union, Any, Tuple, List, Dict
import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, dedup, round_up, prod, DEBUG, RING
from tinygrad.dtype import DType, ConstType
@@ -80,7 +80,8 @@ class MultiLazyBuffer:
# if we already have a copy on the device, return that
for lb in self.real_lbs:
if lb.device == device: return lb
return self.lbs[self.real.index(True)].copy_to_device(device)
return self.real_lbs[0].copy_to_device(device)
# copy lbs to device, pad to final shape, and sum
llbs:List[LazyBuffer] = []
for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
if not real: continue
@@ -112,14 +113,11 @@ class MultiLazyBuffer:
if mlb.axis == axis or not_all_real: srcs.append(mlb.lbs)
elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
new_real_lbs = {i:lsrcs[0].e(op, *lsrcs[1:], arg=arg) for i,(lsrcs,r) in enumerate(zip(zip(*srcs),new_real)) if r}
new_real_lbs:Dict[int,LazyBuffer] = {i:lsrcs[0].e(op, *lsrcs[1:], arg=arg) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
# NOTE: const dtype should match real
real_dtype = next(iter(new_real_lbs.values())).dtype
return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const(0).cast(real_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
  """Return `shape` with the sharded-axis dimension replaced by `lb`'s local size on that axis.

  All other dimensions are passed through unchanged.
  """
  out:List[sint] = []
  for dim, size in enumerate(shape):
    # only the sharded axis differs per shard; take the shard's own extent there
    out.append(lb.shape[self.axis] if dim == self.axis else size)
  return tuple(out)
def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
if self.axis is not None and self.axis in axis:
# all-reduce on sharded axes
@@ -131,6 +129,9 @@ class MultiLazyBuffer:
# *** movement ops ***
def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
  # Map a global shape to one shard's local shape: on the sharded axis use the
  # shard buffer's own extent, everywhere else keep the global dimension.
  return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
def reshape(self, arg:Tuple[sint, ...]):
if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
@@ -155,13 +156,16 @@ class MultiLazyBuffer:
sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i > self.real.index(True))), "can only pad to whole axis"
return MultiLazyBuffer([x if r else x.const(0) for x,r in zip(self.lbs, self.real)], self.axis)
return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
def expand(self, arg:Tuple[sint, ...]):
  """Expand every shard to its per-shard slice of `arg`; axis and real mask are preserved."""
  # NOTE: this assert isn't needed, sharded axis can have dim 1
  if self.axis is not None:
    assert arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
  expanded = [lb.expand(self._shape_to_single_shard(arg, lb)) for lb in self.lbs]
  return MultiLazyBuffer(expanded, self.axis, self.real)
def permute(self, arg:Tuple[int, ...]):
  """Permute every shard with `arg`; the sharded axis moves to its new position."""
  # all permutes supported!
  if self.axis is None:
    new_axis = None
  else:
    # the sharded axis is relocated to wherever `arg` placed it
    new_axis = arg.index(self.axis)
  return MultiLazyBuffer([lb.permute(arg) for lb in self.lbs], new_axis, self.real)
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
@@ -171,6 +175,7 @@ class MultiLazyBuffer:
return MultiLazyBuffer([lb if i==idx else lb.const(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
self.axis, self.real)
def stride(self, arg:Tuple[int, ...]):
  """Apply `stride` to every shard; a non-unit stride (flip) on the sharded axis is rejected."""
  if self.axis is not None:
    assert arg[self.axis] == 1, "flipping not supported on sharded axis"
  strided = [lb.stride(arg) for lb in self.lbs]
  return MultiLazyBuffer(strided, self.axis, self.real)