tensor roll (#6375)

* tensor roll function and tests

* fix type annotations

* reduce line count

* more readable
This commit is contained in:
Irakli Salia
2024-09-07 01:14:28 +04:00
committed by GitHub
parent dfb818788e
commit 2e01efc35f
2 changed files with 36 additions and 0 deletions

View File

@@ -1375,6 +1375,21 @@ class TestOps(unittest.TestCase):
helper_test_op([(4,3,6,6)], lambda x: x.unflatten(3, (3, 2)))
helper_test_op([(4,3,6,6)], lambda x: x.unflatten(-1, (3, 2, 1)))
def test_roll(self):
helper_test_op([(2, 4)], lambda x: torch.roll(x, 1, 0), lambda x: x.roll(1, 0))
helper_test_op([(2, 4)], lambda x: torch.roll(x, -1, 0), lambda x: x.roll(-1, 0))
helper_test_op([(2, 4)], lambda x: torch.roll(x, shifts=(2, 1), dims=(0, 1)), lambda x: x.roll(shifts=(2, 1), dims=(0, 1)))
helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, 1, 0), lambda x: x.roll(1, 0))
helper_test_op([(2, 4)], lambda x: torch.roll(x, 1, -1), lambda x: x.roll(1, -1))
helper_test_op([(2, 4)], lambda x: torch.roll(x, -1, -1), lambda x: x.roll(-1, -1))
helper_test_op([(2, 4)], lambda x: torch.roll(x, 5, 0), lambda x: x.roll(5, 0))
helper_test_op([(2, 4)], lambda x: torch.roll(x, -5, 0), lambda x: x.roll(-5, 0))
helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, shifts=(2, -3), dims=(0, 2)), lambda x: x.roll(shifts=(2, -3), dims=(0, 2)))
helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, shifts=(1, 2, -1), dims=(0, 1, 2)), lambda x: x.roll(shifts=(1, 2, -1), dims=(0, 1, 2)))
helper_test_op([(2, 4)], lambda x: torch.roll(x, 0, 0), lambda x: x.roll(0, 0))
helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, shifts=(0, 0), dims=(0, 1)), lambda x: x.roll(shifts=(0, 0), dims=(0, 1)))
helper_test_op([(2, 4, 6)], lambda x: torch.roll(x, shifts=(0, 2), dims=(0, 1)), lambda x: x.roll(shifts=(0, 2), dims=(0, 1)))
def test_detach(self):
helper_test_op([(4,3,6,6)], lambda x: x.detach(), forward_only=True)
helper_test_op([()], lambda x: x.detach(), forward_only=True)