fix sort for small dim (#11601)

* fix sort for small dim

* fixed test_sort_empty
This commit is contained in:
chenyu
2025-08-09 22:17:41 -07:00
committed by GitHub
parent ef17af85c6
commit dfb702ef33
3 changed files with 4 additions and 2 deletions

View File

@@ -1093,6 +1093,9 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[True, False]])
def test_sort(self):
for shape in [(0,), (0,5), (1,), (1,5)]:
helper_test_op([shape], lambda x: x.sort(0).values, lambda x: x.sort(0)[0], forward_only=True)
helper_test_op([shape], lambda x: x.sort(0).indices.type(torch.int32), lambda x: x.sort(0)[1], forward_only=True)
for dim in [-1, 0, 1]:
for descending in [True, False]:
helper_test_op([(8,8,6)], lambda x: x.sort(dim, descending).values, lambda x: x.sort(dim, descending)[0], forward_only=True)