mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-24 06:18:01 -05:00
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Optional, Union, Any, Tuple, List
+from typing import Optional, Union, Any, Tuple, List, Dict
 import functools, itertools, operator
 from tinygrad.helpers import all_same, all_int, dedup, round_up, prod, DEBUG, RING
 from tinygrad.dtype import DType, ConstType
@@ -80,7 +80,8 @@ class MultiLazyBuffer:
+      # if we already have a copy on the device, return that
       for lb in self.real_lbs:
         if lb.device == device: return lb
-      return self.lbs[self.real.index(True)].copy_to_device(device)
+      return self.real_lbs[0].copy_to_device(device)
     # copy lbs to device, pad to final shape, and sum
     llbs:List[LazyBuffer] = []
     for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
       if not real: continue
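For readers skimming the diff: when the buffer is sharded on an axis, copy_to_device gathers it by copying each real shard to the target device, padding it out to the full shape at its bounds, and summing the padded pieces. Below is a minimal standalone sketch of that pad-and-sum gather, using plain Python lists in place of LazyBuffers; the names are illustrative, not tinygrad API.

# pad-and-sum gather sketch: lists stand in for per-device shards
from typing import List, Tuple

def gather(shards: List[List[int]], bounds: List[Tuple[int, int]]) -> List[int]:
  total = bounds[-1][1]                                    # full length of the sharded axis
  out = [0] * total
  for shard, (start, end) in zip(shards, bounds):
    padded = [0] * start + shard + [0] * (total - end)     # pad shard to the full shape at its bounds
    out = [a + b for a, b in zip(out, padded)]             # sum the padded pieces
  return out

assert gather([[1, 2], [3, 4]], [(0, 2), (2, 4)]) == [1, 2, 3, 4]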
@@ -112,14 +113,11 @@ class MultiLazyBuffer:
       if mlb.axis == axis or not_all_real: srcs.append(mlb.lbs)
       elif mlb.axis is None and axis is not None: srcs.append(to_sharded(mlb.lbs, axis))
       else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis))
-    new_real_lbs = {i:lsrcs[0].e(op, *lsrcs[1:], arg=arg) for i,(lsrcs,r) in enumerate(zip(zip(*srcs),new_real)) if r}
+    new_real_lbs:Dict[int,LazyBuffer] = {i:lsrcs[0].e(op, *lsrcs[1:], arg=arg) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
     # NOTE: const dtype should match real
     real_dtype = next(iter(new_real_lbs.values())).dtype
     return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const(0).cast(real_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
 
-  def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
-    return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
-
   def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
     if self.axis is not None and self.axis in axis:
       # all-reduce on sharded axes
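The "const dtype should match real" note is there because an elementwise op can change dtype (a comparison, for instance, yields bool), so the zero placeholders for non-real shards are cast to the dtype of the computed results rather than the input dtype. A small sketch of the same idea, using numpy purely for illustration:

import numpy as np

real_result = np.array([1.0, 2.0]) < np.array([2.0, 1.0])  # elementwise op changes dtype to bool
placeholder = np.zeros(2, dtype=real_result.dtype)         # zero placeholder cast to the result dtype
assert placeholder.dtype == real_result.dtype == np.bool_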
@@ -131,6 +129,9 @@ class MultiLazyBuffer:
 
   # *** movement ops ***
 
+  def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
+    return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
+
   def reshape(self, arg:Tuple[sint, ...]):
     if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
     arg_acc:List[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
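_shape_to_single_shard, now grouped with the movement ops it serves, maps a requested global shape to the shape one shard should use: every axis takes the requested size except the sharded axis, which keeps that shard's local size. A standalone sketch with plain tuples (the function and argument names here are illustrative only):

from typing import Tuple

def shape_to_single_shard(shape: Tuple[int, ...], shard_shape: Tuple[int, ...], axis: int) -> Tuple[int, ...]:
  # keep the shard's own size on the sharded axis, take the requested size elsewhere
  return tuple(shard_shape[axis] if a == axis else s for a, s in enumerate(shape))

# a (4, 8) target split 1+3 along axis 0: each shard keeps its local size on that axis
assert shape_to_single_shard((4, 8), (1, 8), axis=0) == (1, 8)
assert shape_to_single_shard((4, 8), (3, 8), axis=0) == (3, 8)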
@@ -155,13 +156,16 @@ class MultiLazyBuffer:
                                 sum(lb.shape[self.axis] for i,lb in enumerate(self.lbs) if i > self.real.index(True))), "can only pad to whole axis"
       return MultiLazyBuffer([x if r else x.const(0) for x,r in zip(self.lbs, self.real)], self.axis)
     return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
 
   def expand(self, arg:Tuple[sint, ...]):
     # NOTE: this assert isn't needed, sharded axis can have dim 1
     assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
     return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real)
 
   def permute(self, arg:Tuple[int, ...]):
     # all permutes supported!
     return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real)
 
   def shrink(self, arg:Tuple[Tuple[sint, sint], ...]):
     assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
     if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
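One subtlety in permute above: the sharded axis must be re-tracked through the permutation, which is what arg.index(self.axis) does. A minimal standalone sketch (illustrative names, not tinygrad API):

from typing import Optional, Tuple

def permute_sharded_axis(arg: Tuple[int, ...], axis: Optional[int]) -> Optional[int]:
  # the old sharded axis now lives wherever it appears in the permutation
  return arg.index(axis) if axis is not None else None

assert permute_sharded_axis((2, 0, 1), 0) == 1      # old axis 0 moves to position 1
assert permute_sharded_axis((2, 0, 1), None) is None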
@@ -171,6 +175,7 @@ class MultiLazyBuffer:
       return MultiLazyBuffer([lb if i==idx else lb.const(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
     return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
                            self.axis, self.real)
 
   def stride(self, arg:Tuple[int, ...]):
     assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
     return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real)
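In shrink, if the requested range on the sharded axis matches one shard's bounds exactly, only that shard stays real and the others are replaced with zero consts, as in the first return above. A standalone sketch of the bounds lookup and the resulting real mask (illustrative names only):

from typing import List, Tuple

def shrink_real_mask(bounds: List[Tuple[int, int]], arg: Tuple[int, int]) -> List[bool]:
  idx = bounds.index(arg)                        # which shard covers exactly this range
  return [i == idx for i in range(len(bounds))]

assert shrink_real_mask([(0, 2), (2, 4)], (2, 4)) == [False, True]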