mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 07:18:40 -05:00
@@ -1303,10 +1303,10 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
final_shape = [r*s for r,s in zip(repeats, base_shape)]
|
||||
return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)
|
||||
|
||||
def _resolve_dim(self, dim:int, *, outer:bool=False) -> int:
|
||||
if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer):
|
||||
raise IndexError(f"{dim=} out of range {[-max(1, self.ndim+outer), max(1, self.ndim+outer)-1]}")
|
||||
return dim + self.ndim+outer if dim < 0 else dim
|
||||
def _resolve_dim(self, dim:int, *, extra:bool=False) -> int:
|
||||
total = self.ndim + int(extra)
|
||||
if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
|
||||
return dim + total if dim < 0 else dim
|
||||
|
||||
def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
|
||||
"""
|
||||
@@ -1389,7 +1389,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
print(t.unsqueeze(1).numpy())
|
||||
```
|
||||
"""
|
||||
dim = self._resolve_dim(dim, outer=True)
|
||||
dim = self._resolve_dim(dim, extra=True)
|
||||
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user