diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 4a19c53327..8b22093956 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -41,7 +41,7 @@ def fully_flatten(l):
   return [l]
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
 def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
-def round_up(num, amt:int): return (num+amt-1)//amt * amt
+def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
 def data64(data: int) -> Tuple[int, int]: return (data >> 32, data & 0xFFFFFFFF)
 def data64_le(data: int) -> Tuple[int, int]: return (data & 0xFFFFFFFF, data >> 32)
 def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 13cdac3b5b..bc1b11bec9 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -357,17 +357,17 @@ class Tensor:
     """
     assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
-    canonical_devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
+    devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
     if axis is not None:
       if axis < 0: axis += len(self.shape)
       if splits is None:
-        sz = round_up(self.shape[axis], len(devices)) // len(devices)
-        splits = tuple([max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))])
+        if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
+        sz = round_up(total, len(devices)) // len(devices)
+        splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))])
       assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
       boundaries = tuple(itertools.accumulate(splits))
       bounds = tuple(zip((0,) + boundaries, boundaries))
-    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis, bounds),
-                  device=canonical_devices, requires_grad=self.requires_grad)
+    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis, bounds), device=devices, requires_grad=self.requires_grad)
   def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None):
     """
@@ -1075,6 +1075,7 @@ class Tensor:
       if any(abs(st) != 1 for st in strides):
         strides = tuple(abs(s) for s in strides)
         # pad shape to multiple of stride
+        if not all_int(ret.shape): raise RuntimeError("symbolic shape not supported")
         ret = ret.pad(tuple((0, round_up(s, st) - s) for s, st in zip(ret.shape, strides)))
         ret = ret.reshape(tuple(flatten((s // st, st) for s, st in zip(ret.shape, strides))))
         ret = ret.shrink(tuple(flatten(((0, s), (0, 1)) for s in ret.shape[::2]))).reshape(ret.shape[::2])
@@ -2100,12 +2101,11 @@ class Tensor:
     # TODO: someday the optimizer will find this on it's own
     # for now this is a two stage cumsum
     SPLIT = 256
-    if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
-    ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
-    ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1)
+    if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumsum(axis)
+    ret = self.transpose(axis,-1).pad2d((round_up(s, SPLIT)-s, 0)).unflatten(-1, (-1, SPLIT))._cumsum(-1)
     base_add = ret[..., -1]._cumsum(-1, _first_zero=True)
     base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
-    def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -self.shape[axis]:].transpose(axis,-1)
+    def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
     return fix(ret) + fix(base_add)
   @staticmethod
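
Reviewer notes (illustrative sketches in plain Python, not part of the patch):

The helpers.py change only tightens round_up's signature to int -> int; the rounding arithmetic is unchanged. A minimal sanity check of that contract:

def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt  # smallest multiple of amt that is >= num

assert round_up(255, 256) == 256   # rounds up to the next multiple
assert round_up(256, 256) == 256   # exact multiples are unchanged
assert round_up(257, 256) == 512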
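
The new shard() path first refuses a symbolic shape[axis], then derives per-device splits via round_up. A worked example of that arithmetic (hypothetical sizes, reusing round_up from the note above):

total, ndev = 10, 4                                      # shape[axis] == 10 sharded over 4 devices
sz = round_up(total, ndev) // ndev                       # 12 // 4 == 3 elements per device, rounded up
splits = tuple(max(0, min(sz, total - sz*i)) for i in range(ndev))
assert splits == (3, 3, 3, 1) and sum(splits) == total   # last device takes the remainder

The added isinstance guard makes sharding a symbolic axis an explicit RuntimeError instead of letting round_up operate on a non-int dimension.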
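
The cumsum hunk threads the now guaranteed-int size s through the existing two-stage scheme: cumsum inside SPLIT-sized blocks, then add a cumsum of the per-block totals. A list-based sketch of that idea (my paraphrase, not the Tensor implementation):

import itertools

def two_stage_cumsum(xs:list, split:int=4) -> list:
  blocks = [xs[i:i+split] for i in range(0, len(xs), split)]            # stage 1: cut into SPLIT-sized blocks
  partial = [list(itertools.accumulate(b)) for b in blocks]             # cumsum within each block
  base = [0] + list(itertools.accumulate(p[-1] for p in partial))[:-1]  # stage 2: running total of earlier blocks (first block adds 0, cf. _first_zero)
  return [v + base[i] for i, p in enumerate(partial) for v in p]        # add each block's base back in

assert two_stage_cumsum(list(range(10))) == list(itertools.accumulate(range(10)))

With the new early return, symbolic sizes simply fall back to the single-stage _cumsum rather than reaching round_up/pad2d with a non-int.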