check Tensor.permute input arg is a valid permutation (#5069)

also added support of negative axes
This commit is contained in:
chenyu
2024-06-20 10:01:28 -04:00
committed by GitHub
parent 24c89a2a33
commit f4355d0f1b
2 changed files with 9 additions and 1 deletions

View File

@@ -1112,9 +1112,15 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,3)], lambda x: x.T)
helper_test_op([(3,3,3)], lambda x: x.transpose(1,2))
helper_test_op([(3,3,3)], lambda x: x.transpose(0,2))
def test_permute(self):
helper_test_op([(1,2,3,4)], lambda x: x.permute((3,0,2,1)))
helper_test_op([(3,4,5,6)], lambda x: x.permute((3,2,1,0)))
helper_test_op([(3,4,5,6)], lambda x: x.permute((-2,-1,1,0)))
helper_test_op([()], lambda x: x.permute(()))
self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,2)), lambda x: x.permute((0,2)), expected=RuntimeError)
self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,1,2,3,3,3)), lambda x: x.permute((0,1,2,3,3,3)), expected=RuntimeError)
self.helper_test_exception([(3,4,5,6)], lambda x: x.permute((0,0,1,2,3)), lambda x: x.permute((0,0,1,2,3)), expected=RuntimeError)
def test_reshape(self):
helper_test_op([(4,3,6,6)], lambda x: x.reshape((-1,3,6,6)))