_resolve_dim cleanup (#7736)

no duplicated self.ndim+outer
This commit is contained in:
chenyu
2024-11-16 11:05:39 -05:00
committed by GitHub
parent e777211a00
commit f2f7384b67

View File

@@ -1303,10 +1303,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
final_shape = [r*s for r,s in zip(repeats, base_shape)]
return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)
def _resolve_dim(self, dim:int, *, outer:bool=False) -> int:
if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer):
raise IndexError(f"{dim=} out of range {[-max(1, self.ndim+outer), max(1, self.ndim+outer)-1]}")
return dim + self.ndim+outer if dim < 0 else dim
def _resolve_dim(self, dim:int, *, extra:bool=False) -> int:
total = self.ndim + int(extra)
if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
return dim + total if dim < 0 else dim
def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
"""
@@ -1389,7 +1389,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
print(t.unsqueeze(1).numpy())
```
"""
dim = self._resolve_dim(dim, outer=True)
dim = self._resolve_dim(dim, extra=True)
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
@property