Add unsqueeze (#574)

* Add unsqueeze

* remove UNSQUEEZE from llops part of readme

* make it an hlop
Author: Connor Henderson
Date: 2023-02-20 23:14:59 -05:00
Committed by: GitHub
Parent: cfad2902d5
Commit: 9670bf1fd1
2 changed files with 9 additions and 0 deletions

@@ -214,6 +214,12 @@ class TestOps(unittest.TestCase):
     helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1)), lambda x: x.flip(axis=(0,1)))
     helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)), lambda x: x.flip(axis=(0,1,3)))
     helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(3,)))
 
+  def test_unsqueeze(self):
+    helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0))
+    helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 4), lambda x: x.unsqueeze(dim=4))
+    helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -1), lambda x: x.unsqueeze(dim=-1))
+    helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -3), lambda x: x.unsqueeze(dim=-3))
+
   def test_flatten(self):
     for axis in range(3):
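
For reference, these are the shapes the four new cases exercise on a (4,3,6,6) input; a minimal standalone PyTorch check (not part of the diff):

import torch

x = torch.zeros(4, 3, 6, 6)
print(torch.unsqueeze(x, 0).shape)   # torch.Size([1, 4, 3, 6, 6])
print(torch.unsqueeze(x, 4).shape)   # torch.Size([4, 3, 6, 6, 1])
print(torch.unsqueeze(x, -1).shape)  # torch.Size([4, 3, 6, 6, 1])
print(torch.unsqueeze(x, -3).shape)  # torch.Size([4, 3, 1, 6, 6])

helper_test_op runs the same lambdas on torch and tinygrad tensors of the given shape and compares forward (and gradient) results, so the new tinygrad unsqueeze has to reproduce exactly these shapes.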

@@ -378,6 +378,9 @@ class Tensor:
   def permute(self, order, *args): return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args): return mlops.Flip.apply(self, axis=argfix(axis, *args))
   def slice(self, arg): return mlops.Slice.apply(self, arg=arg)
+  def unsqueeze(self, dim):
+    if dim < 0: dim = len(self.shape) + dim + 1
+    return mlops.Reshape.apply(self, shape=self.shape[:dim] + (1,) + self.shape[dim:])
   def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
     x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight) # type: ignore
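
The hlop builds unsqueeze out of the existing Reshape movement op: a negative dim is first normalized against the new rank with len(self.shape) + dim + 1 (so -1 appends a trailing axis), then a 1 is spliced into the shape tuple at that position. A minimal sketch of just that shape arithmetic, using a hypothetical helper name for illustration:

def unsqueezed_shape(shape: tuple, dim: int) -> tuple:
  # mirror the hlop: negative dims are taken relative to the new rank, len(shape) + 1
  if dim < 0: dim = len(shape) + dim + 1
  return shape[:dim] + (1,) + shape[dim:]

assert unsqueezed_shape((4, 3, 6, 6), 0) == (1, 4, 3, 6, 6)
assert unsqueezed_shape((4, 3, 6, 6), -1) == (4, 3, 6, 6, 1)  # -1 normalizes to 4
assert unsqueezed_shape((4, 3, 6, 6), -3) == (4, 3, 1, 6, 6)  # -3 normalizes to 2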