diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 00ca34a710..e4172569d5 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -32,7 +32,7 @@ def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 e def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0] def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s) def ansilen(s:str): return len(ansistrip(s)) -def make_tuple(x:Union[int, Tuple[int, ...]], cnt) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x +def make_tuple(x:Union[int, Sequence[int]], cnt:int) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x) def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist] def fully_flatten(l): if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index e09252e57b..02e93e1475 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1447,10 +1447,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method if self.ndim == 0: if axis is not None and any(a not in [-1, 0] for a in fully_flatten([axis])): raise IndexError(f"{axis=} out of range of [-1, 0]") axis = () - axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis)) - axis_ = tuple(self._resolve_dim(x) for x in axis_) - ret = fxn.apply(self, axis=axis_) - return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis_)) + axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1))) + ret = fxn.apply(self, axis=axis) + return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis)) def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None): """