From 9670bf1fd1f8f158d80ab32118aeabaa83f455dc Mon Sep 17 00:00:00 2001 From: Connor Henderson Date: Mon, 20 Feb 2023 23:14:59 -0500 Subject: [PATCH] Add unsqueeze (#574) * Add unsqueeze * remove UNSQUEEZE from llops part of readme * make it an hlop --- test/test_ops.py | 6 ++++++ tinygrad/tensor.py | 3 +++ 2 files changed, 9 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index 1775bce711..0467ec2736 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -214,6 +214,12 @@ class TestOps(unittest.TestCase): helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1)), lambda x: x.flip(axis=(0,1))) helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)), lambda x: x.flip(axis=(0,1,3))) helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(3,))) + + def test_unsqueeze(self): + helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0)) + helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 4), lambda x: x.unsqueeze(dim=4)) + helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -1), lambda x: x.unsqueeze(dim=-1)) + helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -3), lambda x: x.unsqueeze(dim=-3)) def test_flatten(self): for axis in range(3): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 21013330a4..6ef058103f 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -378,6 +378,9 @@ class Tensor: def permute(self, order, *args): return mlops.Permute.apply(self, order=argfix(order, *args)) def flip(self, axis, *args): return mlops.Flip.apply(self, axis=argfix(axis, *args)) def slice(self, arg): return mlops.Slice.apply(self, arg=arg) + def unsqueeze(self, dim): + if dim < 0: dim = len(self.shape) + dim + 1 + return mlops.Reshape.apply(self, shape=self.shape[:dim] + (1,) + self.shape[dim:]) def linear(self, weight:Tensor, bias:Optional[Tensor]=None): x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight) # type: ignore