From f2f7384b670268506816eadf0116f4f255f5f2b0 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 16 Nov 2024 11:05:39 -0500 Subject: [PATCH] _resolve_dim cleanup (#7736) no duplicated self.ndim+outer --- tinygrad/tensor.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 20c1cceb61..94b5aefa58 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1303,10 +1303,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method final_shape = [r*s for r,s in zip(repeats, base_shape)] return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape) - def _resolve_dim(self, dim:int, *, outer:bool=False) -> int: - if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer): - raise IndexError(f"{dim=} out of range {[-max(1, self.ndim+outer), max(1, self.ndim+outer)-1]}") - return dim + self.ndim+outer if dim < 0 else dim + def _resolve_dim(self, dim:int, *, extra:bool=False) -> int: + total = self.ndim + int(extra) + if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}") + return dim + total if dim < 0 else dim def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]: """ @@ -1389,7 +1389,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method print(t.unsqueeze(1).numpy()) ``` """ - dim = self._resolve_dim(dim, outer=True) + dim = self._resolve_dim(dim, extra=True) return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:]) @property