Revert "improve Tensor.sort indices (#15070)" (#15072)

This reverts commit e3003631f2.
This commit is contained in:
chenyu
2026-02-28 14:40:30 -05:00
committed by GitHub
parent e3003631f2
commit fe0fa8333b
2 changed files with 10 additions and 57 deletions

View File

@@ -1,44 +0,0 @@
import unittest
from tinygrad import Tensor
from tinygrad.helpers import GlobalCounters, Context
class TestSortComplexity(unittest.TestCase):
  """Measure op-count growth of Tensor.sort on the no-op NULL device.

  Each helper realizes some output of sort() on a fresh random tensor and
  reports GlobalCounters.global_ops for just that realization, so the tests
  can bound how fast the bitonic-sort work grows with input length.
  """
  def _count_ops(self, n:int, output:int) -> int:
    # realize only the selected sort() output (0 = values, 1 = indices) and count ops
    src = Tensor.randn(n, device="NULL").realize()
    GlobalCounters.reset()
    src.sort()[output].realize()
    return GlobalCounters.global_ops

  def _sort_values_ops(self, n:int) -> int: return self._count_ops(n, 0)

  def _sort_indices_ops(self, n:int) -> int: return self._count_ops(n, 1)

  def _sort_both_ops(self, n:int) -> int:
    # schedule both outputs first, then count only the ops of co-realizing them
    src = Tensor.randn(n, device="NULL").realize()
    vals, idxs = src.sort()
    GlobalCounters.reset()
    Tensor.realize(vals, idxs)
    return GlobalCounters.global_ops

  def test_sort_values_complexity_small_noopt(self):
    # 4x more elements (64 -> 256) must stay under ~7.2x the ops for the values output
    with Context(NOOPT=1, SPLIT_REDUCEOP=0):
      ops_64, ops_256 = self._sort_values_ops(64), self._sort_values_ops(256)
      self.assertLess(ops_256, int(ops_64*7.2), f"value sort growth too high with NOOPT=1 SPLIT_REDUCEOP=0: {ops_64=} {ops_256=}")

  def test_sort_indices_complexity_small_noopt(self):
    # the indices output gets a slightly looser growth bound than the values output
    with Context(NOOPT=1, SPLIT_REDUCEOP=0):
      ops_64, ops_256 = self._sort_indices_ops(64), self._sort_indices_ops(256)
      self.assertLess(ops_256, int(ops_64*8.0), f"index sort growth too high with NOOPT=1 SPLIT_REDUCEOP=0: {ops_64=} {ops_256=}")

  def test_sort_corealize_values_indices_noopt(self):
    # realizing values and indices together should cost about the same as indices alone
    with Context(NOOPT=1, SPLIT_REDUCEOP=0):
      indices_ops, both_ops = self._sort_indices_ops(256), self._sort_both_ops(256)
      self.assertLess(both_ops, int(indices_ops*1.2), f"co-realize should share sort work with NOOPT=1 SPLIT_REDUCEOP=0: {indices_ops=} {both_ops=}")
# run the test suite when this file is executed directly
if __name__ == '__main__':
unittest.main()

View File

@@ -2756,39 +2756,36 @@ class Tensor(OpMixin):
"""
x, dim = self, self._resolve_dim(dim)
if (orig_len := int(x.shape[dim])) <= 1: return x, x.zeros_like(dtype=dtypes.default_int)
idx = Tensor.arange(orig_len, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim))).expand(x.shape)
# pad to power of 2
n_stages = (orig_len-1).bit_length()
pads = tuple((0, 2**n_stages - orig_len) if i == dim else None for i in range(x.ndim))
x = x.pad(pads, value=dtypes.min(x.dtype) if descending else dtypes.max(x.dtype)).unflatten(dim, (2,)*n_stages)
idx = idx.pad(pads, value=orig_len).unflatten(dim, (2,)*n_stages)
# https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort1.svg
for stage in range(1, n_stages+1):
if stage != n_stages:
# flip so arrows of green boxes point the same way as blue boxes
crossover_dim = dim + n_stages - stage - 1
blue_box, green_box = x.split(1, crossover_dim)
blue_idx, green_idx = idx.split(1, crossover_dim)
flip_dims = tuple(-i for i in range(1, stage+1+(self.ndim-dim)))
x = (blue_box.cat(green_box.flip(flip_dims), dim=crossover_dim)).contiguous()
idx = blue_idx.cat(green_idx.flip(flip_dims), dim=crossover_dim)
for substage in range(stage-1, -1, -1):
partner_dim = dim + n_stages - substage - 1
x_top, x_bottom = x.split(1, partner_dim)
idx_top, idx_bottom = idx.split(1, partner_dim)
x_larger, x_smaller = x_top.maximum(x_bottom), x_top.minimum(x_bottom)
# stable tie-break: for equal values, lower original index comes first
top_goes_first = ((x_top > x_bottom) if descending else (x_top < x_bottom)) | ((x_top == x_bottom) & (idx_top < idx_bottom))
idx_first, idx_second = top_goes_first.where(idx_top, idx_bottom), top_goes_first.where(idx_bottom, idx_top)
idx = idx_first.cat(idx_second, dim=partner_dim).contiguous()
x = Tensor.cat(*([x_larger, x_smaller] if descending else [x_smaller, x_larger]), dim=partner_dim).contiguous()
x = (x_larger.cat(x_smaller, dim=partner_dim) if descending else x_smaller.cat(x_larger, dim=partner_dim)).contiguous()
if stage != n_stages:
# flip wires back to undo the crossover
blue_box, flipped_green_box = x.split(1, crossover_dim)
blue_idx, flipped_green_idx = idx.split(1, crossover_dim)
x = blue_box.cat(flipped_green_box.flip(flip_dims), dim=crossover_dim)
idx = blue_idx.cat(flipped_green_idx.flip(flip_dims), dim=crossover_dim)
return x.flatten(dim, dim+n_stages-1).shrink_to(self.shape), idx.flatten(dim, dim+n_stages-1).shrink_to(self.shape)
x = x.flatten(dim, dim+n_stages-1).shrink_to(self.shape)
# compute indices for sorted values
mask = Tensor.ones(orig_len, orig_len, dtype=dtypes.bool, device=self.device).tril().reshape((None, None) + (1,)*(self.ndim-dim-1))
def compute_counts(t:Tensor): return (mask & (t.unsqueeze(dim) == t.unsqueeze(dim+1))).sum(dim+1)
count_orig, count_sorted = compute_counts(self), compute_counts(x)
cond = (self.unsqueeze(dim+1) == x.unsqueeze(dim)) & (count_orig.unsqueeze(dim+1) == count_sorted.unsqueeze(dim))
idx = Tensor.arange(orig_len, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim)))
idx = (cond * idx.unsqueeze(dim+1)).sum(dim)
return x, idx
def argsort(self, dim:int=-1, descending:bool=False) -> Tensor:
"""