Added test_chunk and fixed (#1283)

This commit is contained in:
Umut Zengin
2023-07-20 05:21:26 +03:00
committed by GitHub
parent 3f2497160c
commit 74e63fe4ee
2 changed files with 29 additions and 4 deletions

View File

@@ -98,6 +98,31 @@ class TestOps(unittest.TestCase):
helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True)
helper_test_op([], lambda: torch.eye(1), lambda: Tensor.eye(1), forward_only=True)
def test_chunk(self):
tor = torch.arange(13).repeat(8, 1).chunk(6, 1)
ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 1)
assert len(tor) == len(ten)
for i in range(len(tor)):
helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
tor = torch.arange(13).repeat(8, 1).chunk(6, 0)
ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 0)
assert len(tor) == len(ten)
for i in range(len(tor)):
helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
tor = torch.arange(13).repeat(8, 1).chunk(3, -1)
ten = Tensor.arange(13).repeat((8, 1)).chunk(3, -1)
assert len(tor) == len(ten)
for i in range(len(tor)):
helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
tor = torch.arange(13).repeat(8, 3, 3).chunk(3, -2)
ten = Tensor.arange(13).repeat((8, 3, 3)).chunk(3, -2)
assert len(tor) == len(ten)
for i in range(len(tor)):
helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
def test_arange(self):
helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True)
helper_test_op([], lambda: torch.arange(5, 10, 3), lambda: Tensor.arange(10, 5, 3), forward_only=True)