Added tensor.squeeze and support for testing exceptions (#1241)

* WIP: `tensor.squeeze` function

* Added `test_except` param to `helper_test_op` to avoid false positives

* Extracted new method `helper_test_exception` for testing exceptions

* Made `squeeze` not throw IndexError when ndim == 0 and dim <= 0 to match PyTorch
This commit is contained in:
Stan
2023-07-15 09:33:24 +02:00
committed by GitHub
parent a8f3b3f4ed
commit 264d467f2b
2 changed files with 41 additions and 8 deletions

View File

@@ -350,6 +350,13 @@ class Tensor:
slice_params[i][dim] = (k, min(self.shape[dim], k+self.shape[dim]//num))
return [self.slice(p) for p in slice_params]
def squeeze(self, dim=None):
if dim is None: return self if 1 not in self.shape else self.reshape(*[size for size in self.shape if size != 1])
if dim <= 0 and self.ndim == 0: return self # This is to match PyTorch behavior
if not -self.ndim <= dim < self.ndim: raise IndexError(f"Dimension out of range (expected to be in range of [{-self.ndim if self.ndim > 0 else self.ndim-1}, {self.ndim-1 if self.ndim > 0 else self.ndim}], but got {dim})")
if dim < 0: dim += self.ndim
return self if self.shape[dim] != 1 else self.reshape(*[size for idx, size in enumerate(self.shape) if idx != dim])
def unsqueeze(self, dim):
if dim < 0: dim = len(self.shape) + dim + 1
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])