smaller inputs for test_sort and test_topk (#10829)

This commit is contained in:
chenyu
2025-06-15 21:21:15 -07:00
committed by GitHub
parent c0329148c7
commit e5d5ae55f9

View File

@@ -1093,8 +1093,8 @@ class TestOps(unittest.TestCase):
def test_sort(self):
for dim in [-1, 0, 1]:
for descending in [True, False]:
helper_test_op([(8,45,6)], lambda x: x.sort(dim, descending).values, lambda x: x.sort(dim, descending)[0], forward_only=True)
helper_test_op([(8,45,6)], lambda x: x.sort(dim, descending).indices.type(torch.int32), lambda x: x.sort(dim, descending)[1],
helper_test_op([(8,8,6)], lambda x: x.sort(dim, descending).values, lambda x: x.sort(dim, descending)[0], forward_only=True)
helper_test_op([(8,8,6)], lambda x: x.sort(dim, descending).indices.type(torch.int32), lambda x: x.sort(dim, descending)[1],
forward_only=True)
# repeated values
helper_test_op(None, lambda x: x.sort(stable=True).values, lambda x: x.sort()[0], forward_only=True, vals=[[0, 1] * 9])
@@ -1110,12 +1110,12 @@ class TestOps(unittest.TestCase):
for dim in [0, 1, -1]:
for largest in [True, False]:
for sorted_ in [True]: # TODO support False
helper_test_op([(10,12,6)],
lambda x: x.topk(5, dim, largest, sorted_).values,
lambda x: x.topk(5, dim, largest, sorted_)[0], forward_only=True)
helper_test_op([(10,12,6)],
lambda x: x.topk(5, dim, largest, sorted_).indices.type(torch.int32),
lambda x: x.topk(5, dim, largest, sorted_)[1], forward_only=True)
helper_test_op([(6,5,4)],
lambda x: x.topk(4, dim, largest, sorted_).values,
lambda x: x.topk(4, dim, largest, sorted_)[0], forward_only=True)
helper_test_op([(5,5,4)],
lambda x: x.topk(4, dim, largest, sorted_).indices.type(torch.int32),
lambda x: x.topk(4, dim, largest, sorted_)[1], forward_only=True)
# repeated values
value, indices = Tensor([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0]).topk(3)
np.testing.assert_equal(value.numpy(), [1, 1, 1])