Add bitonic cat sort (#9422)

* poc

* repeated values fail, sigh

* is this being timed out?

* fix up down names

* bitonic v2, does this run?

* bitonic v3, faster

* bitonic v3.1, faster

* bitonic v3.1.1, same speed unlucky

* support dim and indices

* bitonic v3.2, simpler code, TODO repeated indices

* bruv gimme green for once cmon

* cat (stack) implementation, slow but maybe one day when cat is fast meow

* revert to v3.2

* bitonic v4, who let the cats out edition

* clean up variable names

* figured out repeated indices :D

* ruff check --fix

* use sort for topk

* add Tensor.sort everywhere

* fix docs and add some types

* slightly better variable names

* am I doing torch inplace correctly?

* delegate sort to values_stable

* add a contig, faster first sort

* maybe don't test_inplace

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
geohotstan
2025-03-18 00:01:23 +08:00
committed by GitHub
parent f53be010d7
commit 53d6f1e1bb
4 changed files with 89 additions and 12 deletions

View File

@@ -1047,6 +1047,20 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[True, False]])
def test_sort(self):
for dim in [-1, 0, 1]:
for descending in [True, False]:
helper_test_op([(8,45,65)], lambda x: x.sort(dim, descending).values, lambda x: x.sort(dim, descending)[0], forward_only=True)
helper_test_op([(8,45,65)], lambda x: x.sort(dim, descending).indices.type(torch.int32), lambda x: x.sort(dim, descending)[1],
forward_only=True)
# repeated values
helper_test_op(None, lambda x: x.sort(stable=True).values, lambda x: x.sort()[0], forward_only=True, vals=[[0, 1] * 9])
helper_test_op(None, lambda x: x.sort(stable=True).indices.type(torch.int32), lambda x: x.sort()[1], forward_only=True, vals=[[0, 1] * 9])
helper_test_op(None, lambda x: x.sort(stable=True, descending=True).values,
lambda x: x.sort(descending=True)[0], forward_only=True, vals=[[0, 1] * 9])
helper_test_op(None, lambda x: x.sort(stable=True, descending=True).indices.type(torch.int32),
lambda x: x.sort(descending=True)[1], forward_only=True, vals=[[0, 1] * 9])
def test_topk(self):
helper_test_op([(10)], lambda x: x.topk(3).values, lambda x: x.topk(3)[0], forward_only=True)
helper_test_op([(10)], lambda x: x.topk(3).indices.type(torch.int32), lambda x: x.topk(3)[1], forward_only=True)