diff --git a/test/null/test_sort.py b/test/null/test_sort.py new file mode 100644 index 0000000000..79641cdb25 --- /dev/null +++ b/test/null/test_sort.py @@ -0,0 +1,44 @@ +import unittest +from tinygrad import Tensor +from tinygrad.helpers import GlobalCounters, Context + +class TestSortComplexity(unittest.TestCase): + def _sort_values_ops(self, n:int) -> int: + t = Tensor.randn(n, device="NULL").realize() + GlobalCounters.reset() + t.sort()[0].realize() + return GlobalCounters.global_ops + + def _sort_indices_ops(self, n:int) -> int: + t = Tensor.randn(n, device="NULL").realize() + GlobalCounters.reset() + t.sort()[1].realize() + return GlobalCounters.global_ops + + def _sort_both_ops(self, n:int) -> int: + t = Tensor.randn(n, device="NULL").realize() + values, indices = t.sort() + GlobalCounters.reset() + Tensor.realize(values, indices) + return GlobalCounters.global_ops + + def test_sort_values_complexity_small_noopt(self): + with Context(NOOPT=1, SPLIT_REDUCEOP=0): + ops_64 = self._sort_values_ops(64) + ops_256 = self._sort_values_ops(256) + self.assertLess(ops_256, int(ops_64*7.2), f"value sort growth too high with NOOPT=1 SPLIT_REDUCEOP=0: {ops_64=} {ops_256=}") + + def test_sort_indices_complexity_small_noopt(self): + with Context(NOOPT=1, SPLIT_REDUCEOP=0): + ops_64 = self._sort_indices_ops(64) + ops_256 = self._sort_indices_ops(256) + self.assertLess(ops_256, int(ops_64*8.0), f"index sort growth too high with NOOPT=1 SPLIT_REDUCEOP=0: {ops_64=} {ops_256=}") + + def test_sort_corealize_values_indices_noopt(self): + with Context(NOOPT=1, SPLIT_REDUCEOP=0): + indices_ops = self._sort_indices_ops(256) + both_ops = self._sort_both_ops(256) + self.assertLess(both_ops, int(indices_ops*1.2), f"co-realize should share sort work with NOOPT=1 SPLIT_REDUCEOP=0: {indices_ops=} {both_ops=}") + +if __name__ == '__main__': + unittest.main() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 605142ed9d..6b9a857935 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2756,36 +2756,39 @@ class Tensor(OpMixin): """ x, dim = self, self._resolve_dim(dim) if (orig_len := int(x.shape[dim])) <= 1: return x, x.zeros_like(dtype=dtypes.default_int) + idx = Tensor.arange(orig_len, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim))).expand(x.shape) # pad to power of 2 n_stages = (orig_len-1).bit_length() pads = tuple((0, 2**n_stages - orig_len) if i == dim else None for i in range(x.ndim)) x = x.pad(pads, value=dtypes.min(x.dtype) if descending else dtypes.max(x.dtype)).unflatten(dim, (2,)*n_stages) + idx = idx.pad(pads, value=orig_len).unflatten(dim, (2,)*n_stages) # https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort1.svg for stage in range(1, n_stages+1): if stage != n_stages: # flip so arrows of green boxes point the same way as blue boxes crossover_dim = dim + n_stages - stage - 1 blue_box, green_box = x.split(1, crossover_dim) + blue_idx, green_idx = idx.split(1, crossover_dim) flip_dims = tuple(-i for i in range(1, stage+1+(self.ndim-dim))) x = (blue_box.cat(green_box.flip(flip_dims), dim=crossover_dim)).contiguous() + idx = blue_idx.cat(green_idx.flip(flip_dims), dim=crossover_dim) for substage in range(stage-1, -1, -1): partner_dim = dim + n_stages - substage - 1 x_top, x_bottom = x.split(1, partner_dim) + idx_top, idx_bottom = idx.split(1, partner_dim) x_larger, x_smaller = x_top.maximum(x_bottom), x_top.minimum(x_bottom) - x = (x_larger.cat(x_smaller, dim=partner_dim) if descending else x_smaller.cat(x_larger, dim=partner_dim)).contiguous() + # stable tie-break: for equal values, lower original index comes first + top_goes_first = ((x_top > x_bottom) if descending else (x_top < x_bottom)) | ((x_top == x_bottom) & (idx_top < idx_bottom)) + idx_first, idx_second = top_goes_first.where(idx_top, idx_bottom), top_goes_first.where(idx_bottom, idx_top) + idx = idx_first.cat(idx_second, dim=partner_dim).contiguous() + x = Tensor.cat(*([x_larger, x_smaller] if descending else [x_smaller, x_larger]), dim=partner_dim).contiguous() if stage != n_stages: # flip wires back to undo the crossover blue_box, flipped_green_box = x.split(1, crossover_dim) + blue_idx, flipped_green_idx = idx.split(1, crossover_dim) x = blue_box.cat(flipped_green_box.flip(flip_dims), dim=crossover_dim) - x = x.flatten(dim, dim+n_stages-1).shrink_to(self.shape) - # compute indices for sorted values - mask = Tensor.ones(orig_len, orig_len, dtype=dtypes.bool, device=self.device).tril().reshape((None, None) + (1,)*(self.ndim-dim-1)) - def compute_counts(t:Tensor): return (mask & (t.unsqueeze(dim) == t.unsqueeze(dim+1))).sum(dim+1) - count_orig, count_sorted = compute_counts(self), compute_counts(x) - cond = (self.unsqueeze(dim+1) == x.unsqueeze(dim)) & (count_orig.unsqueeze(dim+1) == count_sorted.unsqueeze(dim)) - idx = Tensor.arange(orig_len, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim))) - idx = (cond * idx.unsqueeze(dim+1)).sum(dim) - return x, idx + idx = blue_idx.cat(flipped_green_idx.flip(flip_dims), dim=crossover_dim) + return x.flatten(dim, dim+n_stages-1).shrink_to(self.shape), idx.flatten(dim, dim+n_stages-1).shrink_to(self.shape) def argsort(self, dim:int=-1, descending:bool=False) -> Tensor: """