mirror of https://github.com/tinygrad/tinygrad.git
type annotation to round_up (#6898)
* type annotation to round_up; also cleaned up places where round_up was potentially called on a symbolic shape
* fix
@@ -41,7 +41,7 @@ def fully_flatten(l):
   return [l]
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
 def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
-def round_up(num, amt:int): return (num+amt-1)//amt * amt
+def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
 def data64(data: int) -> Tuple[int, int]: return (data >> 32, data & 0xFFFFFFFF)
 def data64_le(data: int) -> Tuple[int, int]: return (data & 0xFFFFFFFF, data >> 32)
 def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
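As a quick sanity check on the new signature (the values below are illustrative only, not part of the commit): round_up rounds its first argument up to the next multiple of the second, and the annotation now pins both arguments to plain ints, which is why the symbolic call sites further down get explicit guards.

def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt

assert round_up(10, 4) == 12    # next multiple of 4 at or above 10
assert round_up(12, 4) == 12    # already a multiple, unchanged
assert round_up(0, 256) == 0    # zero stays zero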
@@ -357,17 +357,17 @@ class Tensor:
 
     """
     assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
-    canonical_devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
+    devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
     if axis is not None:
       if axis < 0: axis += len(self.shape)
       if splits is None:
-        sz = round_up(self.shape[axis], len(devices)) // len(devices)
-        splits = tuple([max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))])
+        if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
+        sz = round_up(total, len(devices)) // len(devices)
+        splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))])
       assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
       boundaries = tuple(itertools.accumulate(splits))
       bounds = tuple(zip((0,) + boundaries, boundaries))
-    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis, bounds),
-                  device=canonical_devices, requires_grad=self.requires_grad)
+    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis, bounds), device=devices, requires_grad=self.requires_grad)
 
   def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None):
     """
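For intuition, a minimal sketch (with made-up sizes, not from the commit) of how the default splits and bounds fall out once total is known to be a plain int, e.g. sharding an axis of length 10 across 3 devices:

import itertools
def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt

total, ndev = 10, 3                                                   # hypothetical axis length and device count
sz = round_up(total, ndev) // ndev                                    # 12 // 3 = 4 elements per device
splits = tuple(max(0, min(sz, total - sz*i)) for i in range(ndev))    # (4, 4, 2)
assert sum(splits) == total
boundaries = tuple(itertools.accumulate(splits))                      # (4, 8, 10)
bounds = tuple(zip((0,) + boundaries, boundaries))                    # ((0, 4), (4, 8), (8, 10))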
@@ -1075,6 +1075,7 @@ class Tensor:
     if any(abs(st) != 1 for st in strides):
       strides = tuple(abs(s) for s in strides)
       # pad shape to multiple of stride
+      if not all_int(ret.shape): raise RuntimeError("symbolic shape not supprted")
       ret = ret.pad(tuple((0, round_up(s, st) - s) for s, st in zip(ret.shape, strides)))
       ret = ret.reshape(tuple(flatten((s // st, st) for s, st in zip(ret.shape, strides))))
       ret = ret.shrink(tuple(flatten(((0, s), (0, 1)) for s in ret.shape[::2]))).reshape(ret.shape[::2])
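The guarded block implements strided indexing via pad/reshape/shrink, which needs concrete sizes to compute the padding; hence the new all_int check. A small worked example with made-up numbers (dimension of size 7, stride 3):

def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt

s, st = 7, 3
pad = round_up(s, st) - s             # pad 7 up to 9, a multiple of the stride
groups = (s + pad) // st              # reshape into 3 groups of width 3
kept = groups                         # shrinking each group to its first element keeps 3 values
assert kept == len(range(0, s, st))   # same count as stepping through indices [0, 3, 6]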
@@ -2100,12 +2101,11 @@ class Tensor:
     # TODO: someday the optimizer will find this on it's own
     # for now this is a two stage cumsum
     SPLIT = 256
-    if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
-    ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
-    ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1)
+    if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumsum(axis)
+    ret = self.transpose(axis,-1).pad2d((round_up(s, SPLIT)-s, 0)).unflatten(-1, (-1, SPLIT))._cumsum(-1)
     base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
     base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
-    def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -self.shape[axis]:].transpose(axis,-1)
+    def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
     return fix(ret) + fix(base_add)
 
   @staticmethod
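For reference, a hedged NumPy sketch of the two-stage trick these lines keep implementing (using an illustrative SPLIT of 4 rather than 256, and NumPy rather than tinygrad Tensors): cumsum inside fixed-size chunks, then a shifted cumsum of the chunk totals added back, with left padding so the length divides evenly.

import numpy as np

def two_stage_cumsum(x: np.ndarray, split: int = 4) -> np.ndarray:
  s = x.shape[-1]
  pad = -s % split                                       # same as round_up(s, split) - s, applied on the left
  xp = np.concatenate([np.zeros(pad, x.dtype), x]).reshape(-1, split)
  ret = xp.cumsum(-1)                                    # stage 1: cumsum within each chunk
  base = np.concatenate([[0], ret[:-1, -1]]).cumsum()    # stage 2: running totals of earlier chunks (first chunk adds 0)
  return (ret + base[:, None]).reshape(-1)[-s:]          # add the offsets back and drop the padding

x = np.arange(1, 11)
assert np.array_equal(two_stage_cumsum(x), np.cumsum(x))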