mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
Add bitonic cat sort (#9422)
* poc * repeated values fail, sigh * is this being timed out? * fix up down names * bitonic v2, does this run? * bitonic v3, faster * bitonic v3.1, faster * bitonic v3.1.1, same speed unlucky * support dim and indices * bitonic v3.2, simpler code, TODO repeated indices * bruv gimme green for once cmon * cat (stack) implementation, slow but maybe one day when cat is fast meow * revert to v3.2 * bitonic v4, who let the cats out edition * clean up variable names * figured out repeated indices :D * ruff check --fix * use sort for topk * add Tensor.sort everywhere * fix docs and add some types * slightly better variable names * am I doing torch inplace correctly? * delegate sort to values_stable * add a contig, faster first sort * maybe don't test_inplace --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -1047,6 +1047,20 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[False, True]])
|
||||
helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[True, False]])
|
||||
|
||||
def test_sort(self):
|
||||
for dim in [-1, 0, 1]:
|
||||
for descending in [True, False]:
|
||||
helper_test_op([(8,45,65)], lambda x: x.sort(dim, descending).values, lambda x: x.sort(dim, descending)[0], forward_only=True)
|
||||
helper_test_op([(8,45,65)], lambda x: x.sort(dim, descending).indices.type(torch.int32), lambda x: x.sort(dim, descending)[1],
|
||||
forward_only=True)
|
||||
# repeated values
|
||||
helper_test_op(None, lambda x: x.sort(stable=True).values, lambda x: x.sort()[0], forward_only=True, vals=[[0, 1] * 9])
|
||||
helper_test_op(None, lambda x: x.sort(stable=True).indices.type(torch.int32), lambda x: x.sort()[1], forward_only=True, vals=[[0, 1] * 9])
|
||||
helper_test_op(None, lambda x: x.sort(stable=True, descending=True).values,
|
||||
lambda x: x.sort(descending=True)[0], forward_only=True, vals=[[0, 1] * 9])
|
||||
helper_test_op(None, lambda x: x.sort(stable=True, descending=True).indices.type(torch.int32),
|
||||
lambda x: x.sort(descending=True)[1], forward_only=True, vals=[[0, 1] * 9])
|
||||
|
||||
def test_topk(self):
|
||||
helper_test_op([(10)], lambda x: x.topk(3).values, lambda x: x.topk(3)[0], forward_only=True)
|
||||
helper_test_op([(10)], lambda x: x.topk(3).indices.type(torch.int32), lambda x: x.topk(3)[1], forward_only=True)
|
||||
|
||||
Reference in New Issue
Block a user