diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 20c1cceb61..94b5aefa58 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1303,10 +1303,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method final_shape = [r*s for r,s in zip(repeats, base_shape)] return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape) - def _resolve_dim(self, dim:int, *, outer:bool=False) -> int: - if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer): - raise IndexError(f"{dim=} out of range {[-max(1, self.ndim+outer), max(1, self.ndim+outer)-1]}") - return dim + self.ndim+outer if dim < 0 else dim + def _resolve_dim(self, dim:int, *, extra:bool=False) -> int: + total = self.ndim + int(extra) + if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}") + return dim + total if dim < 0 else dim def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]: """ @@ -1389,7 +1389,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method print(t.unsqueeze(1).numpy()) ``` """ - dim = self._resolve_dim(dim, outer=True) + dim = self._resolve_dim(dim, extra=True) return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:]) @property