From fe0fa8333b243b2c3c09f9ca014eff8df3c185a1 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 28 Feb 2026 14:40:30 -0500 Subject: [PATCH] Revert "improve Tensor.sort indices (#15070)" (#15072) This reverts commit e3003631f2dc87ef60f0b52bd5c750b3cc91618d. --- test/null/test_sort.py | 44 ------------------------------------------ tinygrad/tensor.py | 23 ++++++++++------------ 2 files changed, 10 insertions(+), 57 deletions(-) delete mode 100644 test/null/test_sort.py diff --git a/test/null/test_sort.py b/test/null/test_sort.py deleted file mode 100644 index 79641cdb25..0000000000 --- a/test/null/test_sort.py +++ /dev/null @@ -1,44 +0,0 @@ -import unittest -from tinygrad import Tensor -from tinygrad.helpers import GlobalCounters, Context - -class TestSortComplexity(unittest.TestCase): - def _sort_values_ops(self, n:int) -> int: - t = Tensor.randn(n, device="NULL").realize() - GlobalCounters.reset() - t.sort()[0].realize() - return GlobalCounters.global_ops - - def _sort_indices_ops(self, n:int) -> int: - t = Tensor.randn(n, device="NULL").realize() - GlobalCounters.reset() - t.sort()[1].realize() - return GlobalCounters.global_ops - - def _sort_both_ops(self, n:int) -> int: - t = Tensor.randn(n, device="NULL").realize() - values, indices = t.sort() - GlobalCounters.reset() - Tensor.realize(values, indices) - return GlobalCounters.global_ops - - def test_sort_values_complexity_small_noopt(self): - with Context(NOOPT=1, SPLIT_REDUCEOP=0): - ops_64 = self._sort_values_ops(64) - ops_256 = self._sort_values_ops(256) - self.assertLess(ops_256, int(ops_64*7.2), f"value sort growth too high with NOOPT=1 SPLIT_REDUCEOP=0: {ops_64=} {ops_256=}") - - def test_sort_indices_complexity_small_noopt(self): - with Context(NOOPT=1, SPLIT_REDUCEOP=0): - ops_64 = self._sort_indices_ops(64) - ops_256 = self._sort_indices_ops(256) - self.assertLess(ops_256, int(ops_64*8.0), f"index sort growth too high with NOOPT=1 SPLIT_REDUCEOP=0: {ops_64=} {ops_256=}") - - def test_sort_corealize_values_indices_noopt(self): - with Context(NOOPT=1, SPLIT_REDUCEOP=0): - indices_ops = self._sort_indices_ops(256) - both_ops = self._sort_both_ops(256) - self.assertLess(both_ops, int(indices_ops*1.2), f"co-realize should share sort work with NOOPT=1 SPLIT_REDUCEOP=0: {indices_ops=} {both_ops=}") - -if __name__ == '__main__': - unittest.main() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 6b9a857935..605142ed9d 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2756,39 +2756,36 @@ class Tensor(OpMixin): """ x, dim = self, self._resolve_dim(dim) if (orig_len := int(x.shape[dim])) <= 1: return x, x.zeros_like(dtype=dtypes.default_int) - idx = Tensor.arange(orig_len, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim))).expand(x.shape) # pad to power of 2 n_stages = (orig_len-1).bit_length() pads = tuple((0, 2**n_stages - orig_len) if i == dim else None for i in range(x.ndim)) x = x.pad(pads, value=dtypes.min(x.dtype) if descending else dtypes.max(x.dtype)).unflatten(dim, (2,)*n_stages) - idx = idx.pad(pads, value=orig_len).unflatten(dim, (2,)*n_stages) # https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort1.svg for stage in range(1, n_stages+1): if stage != n_stages: # flip so arrows of green boxes point the same way as blue boxes crossover_dim = dim + n_stages - stage - 1 blue_box, green_box = x.split(1, crossover_dim) - blue_idx, green_idx = idx.split(1, crossover_dim) flip_dims = tuple(-i for i in range(1, stage+1+(self.ndim-dim))) x = (blue_box.cat(green_box.flip(flip_dims), dim=crossover_dim)).contiguous() - idx = blue_idx.cat(green_idx.flip(flip_dims), dim=crossover_dim) for substage in range(stage-1, -1, -1): partner_dim = dim + n_stages - substage - 1 x_top, x_bottom = x.split(1, partner_dim) - idx_top, idx_bottom = idx.split(1, partner_dim) x_larger, x_smaller = x_top.maximum(x_bottom), x_top.minimum(x_bottom) - # stable tie-break: for equal values, lower original index comes first - top_goes_first = ((x_top > x_bottom) if descending else (x_top < x_bottom)) | ((x_top == x_bottom) & (idx_top < idx_bottom)) - idx_first, idx_second = top_goes_first.where(idx_top, idx_bottom), top_goes_first.where(idx_bottom, idx_top) - idx = idx_first.cat(idx_second, dim=partner_dim).contiguous() - x = Tensor.cat(*([x_larger, x_smaller] if descending else [x_smaller, x_larger]), dim=partner_dim).contiguous() + x = (x_larger.cat(x_smaller, dim=partner_dim) if descending else x_smaller.cat(x_larger, dim=partner_dim)).contiguous() if stage != n_stages: # flip wires back to undo the crossover blue_box, flipped_green_box = x.split(1, crossover_dim) - blue_idx, flipped_green_idx = idx.split(1, crossover_dim) x = blue_box.cat(flipped_green_box.flip(flip_dims), dim=crossover_dim) - idx = blue_idx.cat(flipped_green_idx.flip(flip_dims), dim=crossover_dim) - return x.flatten(dim, dim+n_stages-1).shrink_to(self.shape), idx.flatten(dim, dim+n_stages-1).shrink_to(self.shape) + x = x.flatten(dim, dim+n_stages-1).shrink_to(self.shape) + # compute indices for sorted values + mask = Tensor.ones(orig_len, orig_len, dtype=dtypes.bool, device=self.device).tril().reshape((None, None) + (1,)*(self.ndim-dim-1)) + def compute_counts(t:Tensor): return (mask & (t.unsqueeze(dim) == t.unsqueeze(dim+1))).sum(dim+1) + count_orig, count_sorted = compute_counts(self), compute_counts(x) + cond = (self.unsqueeze(dim+1) == x.unsqueeze(dim)) & (count_orig.unsqueeze(dim+1) == count_sorted.unsqueeze(dim)) + idx = Tensor.arange(orig_len, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim))) + idx = (cond * idx.unsqueeze(dim+1)).sum(dim) + return x, idx def argsort(self, dim:int=-1, descending:bool=False) -> Tensor: """