mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Added tensor.squeeze and support for testing exceptions (#1241)
* WIP: `tensor.squeeze` function * Added `test_except` param to `helper_test_op` to avoid false positives * Extracted new method `helper_test_exception` for testing exceptions * Made `squeeze` not throw IndexError when ndim == 0 and dim <= 0 to match PyTorch
This commit is contained in:
@@ -350,6 +350,13 @@ class Tensor:
|
||||
slice_params[i][dim] = (k, min(self.shape[dim], k+self.shape[dim]//num))
|
||||
return [self.slice(p) for p in slice_params]
|
||||
|
||||
def squeeze(self, dim=None):
|
||||
if dim is None: return self if 1 not in self.shape else self.reshape(*[size for size in self.shape if size != 1])
|
||||
if dim <= 0 and self.ndim == 0: return self # This is to match PyTorch behavior
|
||||
if not -self.ndim <= dim < self.ndim: raise IndexError(f"Dimension out of range (expected to be in range of [{-self.ndim if self.ndim > 0 else self.ndim-1}, {self.ndim-1 if self.ndim > 0 else self.ndim}], but got {dim})")
|
||||
if dim < 0: dim += self.ndim
|
||||
return self if self.shape[dim] != 1 else self.reshape(*[size for idx, size in enumerate(self.shape) if idx != dim])
|
||||
|
||||
def unsqueeze(self, dim):
|
||||
if dim < 0: dim = len(self.shape) + dim + 1
|
||||
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
|
||||
|
||||
Reference in New Issue
Block a user