mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
Add unsqueeze (#574)
* Add unsqueeze * remove UNSQUEEZE from llops part of readme * make it an hlop
This commit is contained in:
@@ -214,6 +214,12 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1)), lambda x: x.flip(axis=(0,1)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)), lambda x: x.flip(axis=(0,1,3)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(3,)))
|
||||
|
||||
def test_unsqueeze(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 4), lambda x: x.unsqueeze(dim=4))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -1), lambda x: x.unsqueeze(dim=-1))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -3), lambda x: x.unsqueeze(dim=-3))
|
||||
|
||||
def test_flatten(self):
|
||||
for axis in range(3):
|
||||
|
||||
@@ -378,6 +378,9 @@ class Tensor:
|
||||
def permute(self, order, *args): return mlops.Permute.apply(self, order=argfix(order, *args))
|
||||
def flip(self, axis, *args): return mlops.Flip.apply(self, axis=argfix(axis, *args))
|
||||
def slice(self, arg): return mlops.Slice.apply(self, arg=arg)
|
||||
def unsqueeze(self, dim):
|
||||
if dim < 0: dim = len(self.shape) + dim + 1
|
||||
return mlops.Reshape.apply(self, shape=self.shape[:dim] + (1,) + self.shape[dim:])
|
||||
|
||||
def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
|
||||
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight) # type: ignore
|
||||
|
||||
Reference in New Issue
Block a user