diff --git a/docs/tensor/ops.md b/docs/tensor/ops.md
index 4a4a12cb13..2a2a872ae4 100644
--- a/docs/tensor/ops.md
+++ b/docs/tensor/ops.md
@@ -34,6 +34,7 @@
 ::: tinygrad.Tensor.interpolate
 ::: tinygrad.Tensor.scatter
 ::: tinygrad.Tensor.scatter_reduce
+::: tinygrad.Tensor.sort
 ::: tinygrad.Tensor.topk
 
 ## Neural Network (functional)
diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index ac8f7c7c76..508bcb39da 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -209,6 +209,22 @@ def _copy_from(src: torch.Tensor, dest, non_blocking=False):
 def cat_out(tensors, dim=0, out=None): unwrap(out).assign(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim))
 
+@torch.library.impl("aten::topk.values", "privateuseone")
+@inplace_fn(["values", "indices"])
+def topk_values(input, k, dim=None, largest=True, sorted=True, values=None, indices=None):
+  out_values, out_indices = unwrap(input).topk(k, dim if dim is not None else -1, largest, sorted)
+  unwrap(values).assign(out_values)
+  unwrap(indices).assign(out_indices.cast(dtypes.int64))
+  return wrap(out_values), wrap(out_indices)
+
+@torch.library.impl("aten::sort.values_stable", "privateuseone")
+@inplace_fn(["values", "indices"])
+def sort_values(input, dim=-1, descending=False, stable=True, values=None, indices=None):
+  out_values, out_indices = unwrap(input).sort(dim, descending)
+  unwrap(values).assign(out_values)
+  unwrap(indices).assign(out_indices.cast(dtypes.int64))
+  return wrap(out_values), wrap(out_indices)
+
 # register some decompositions
 from torch._decomp import get_decompositions
 aten = torch.ops.aten
@@ -419,7 +435,6 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
   "aten.add.Tensor": lambda input,other,alpha=1: input+alpha*other,
   "aten.linspace": lambda start, stop, steps, dtype=None, **kwargs: Tensor.linspace(start, stop, steps, **({"dtype": _from_torch_dtype(dtype)} if dtype is not None else {})),
-  "aten.topk": Tensor.topk,
   "aten::view.dtype": lambda self, dtype: self.bitcast(_from_torch_dtype(dtype)),
   "aten.constant_pad_nd": lambda self, padding, value=0.0: self.pad(padding, mode="constant", value=value),
   "aten.logsumexp": lambda self, axis, keepdim=False: self.logsumexp(axis[0], keepdim=keepdim),
diff --git a/test/test_ops.py b/test/test_ops.py
index be1699391e..5fc60c0686 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1047,6 +1047,20 @@ class TestOps(unittest.TestCase):
     helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[False, True]])
     helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[True, False]])
 
+  def test_sort(self):
+    for dim in [-1, 0, 1]:
+      for descending in [True, False]:
+        helper_test_op([(8,45,65)], lambda x: x.sort(dim, descending).values, lambda x: x.sort(dim, descending)[0], forward_only=True)
+        helper_test_op([(8,45,65)], lambda x: x.sort(dim, descending).indices.type(torch.int32), lambda x: x.sort(dim, descending)[1],
+                       forward_only=True)
+    # repeated values
+    helper_test_op(None, lambda x: x.sort(stable=True).values, lambda x: x.sort()[0], forward_only=True, vals=[[0, 1] * 9])
+    helper_test_op(None, lambda x: x.sort(stable=True).indices.type(torch.int32), lambda x: x.sort()[1], forward_only=True, vals=[[0, 1] * 9])
+    helper_test_op(None, lambda x: x.sort(stable=True, descending=True).values,
+                   lambda x: x.sort(descending=True)[0], forward_only=True, vals=[[0, 1] * 9])
+    helper_test_op(None, lambda x: x.sort(stable=True, descending=True).indices.type(torch.int32),
+                   lambda x: x.sort(descending=True)[1], forward_only=True, vals=[[0, 1] * 9])
+
   def test_topk(self):
     helper_test_op([(10)], lambda x: x.topk(3).values, lambda x: x.topk(3)[0], forward_only=True)
     helper_test_op([(10)], lambda x: x.topk(3).indices.type(torch.int32), lambda x: x.topk(3)[1], forward_only=True)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 2f1d5633b5..5e979a6098 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2563,10 +2563,63 @@ class Tensor(SimpleMathTrait):
       return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
     raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
 
-  def topk(self, k, dim=-1, largest=True, sorted_=True):
+  def sort(self, dim:int=-1, descending:bool=False):
+    """
+    Performs a bitonic sort on the tensor along the specified dimension.
+
+    Order of indices for equivalent elements is always preserved.
+
+    See: https://en.wikipedia.org/wiki/Bitonic_sorter
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[0.1, 0.5, 1.2, 3.4, 2.1], [2.2, 1.9, 0.3, 4.5, 0.8]])
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    sorted_values, indices = t.sort(dim=1, descending=True)
+    print(sorted_values.numpy())
+    print(indices.numpy())
+    ```
+    """
+    x, dim = self, self._resolve_dim(dim)
+    # pad to power of 2
+    orig_len = x.shape[dim]
+    n_stages = math.ceil(math.log2(orig_len))
+    fill_value = dtypes.min(x.dtype) if descending else dtypes.max(x.dtype)
+    pads = tuple((0, 2**n_stages - orig_len) if i == dim else None for i in range(x.ndim))
+    x = x.pad(pads, value=fill_value).unflatten(dim, (2,)*n_stages)
+    # https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort1.svg
+    for stage in range(1, n_stages+1):
+      if stage != n_stages:
+        # flip so arrows of green boxes point the same way as blue boxes
+        crossover_dim = dim + n_stages - stage - 1
+        blue_box, green_box = x.split(1, crossover_dim)
+        flip_dims = tuple(-i for i in range(1, stage+1+(self.ndim-dim)))
+        x = (blue_box.cat(green_box.flip(flip_dims), dim=crossover_dim)).contiguous()
+      for substage in range(stage-1, -1, -1):
+        partner_dim = dim + n_stages - substage - 1
+        x_top, x_bottom = x.split(1, partner_dim)
+        x_larger, x_smaller = x_top.maximum(x_bottom), x_top.minimum(x_bottom)
+        x = (x_larger.cat(x_smaller, dim=partner_dim) if descending else x_smaller.cat(x_larger, dim=partner_dim)).contiguous()
+      if stage != n_stages:
+        # flip wires back to undo the crossover
+        blue_box, flipped_green_box = x.split(1, crossover_dim)
+        x = blue_box.cat(flipped_green_box.flip(flip_dims), dim=crossover_dim)
+    x = x.flatten(dim, dim+n_stages-1).shrink(tuple((0, orig_len) if i == dim else None for i in range(x.ndim)))
+    # compute indices for sorted values
+    idx = Tensor.arange(orig_len, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim))).expand(x.shape)
+    def compute_counts(t:Tensor): return ((idx.unsqueeze(dim) <= idx.unsqueeze(dim+1)) & (t.unsqueeze(dim) == t.unsqueeze(dim+1))).sum(dim+1)
+    count_orig, count_sorted = compute_counts(self), compute_counts(x)
+    cond = (self.unsqueeze(dim+1) == x.unsqueeze(dim)) & (count_orig.unsqueeze(dim+1) == count_sorted.unsqueeze(dim))
+    idx = (cond * idx.unsqueeze(dim+1)).sum(dim)
+    return x, idx
+
+  def topk(self, k:int, dim:int=-1, largest:bool=True, sorted_:bool=True):
     """
     Computes the top-k elements of the tensor along the specified `dim`.
 
+    Order of indices for equivalent elements is always preserved.
+
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([[0.1, 0.5, 1.2, 3.4, 2.1], [2.2, 1.9, 0.3, 4.5, 0.8]])
     print(t.numpy())
@@ -2578,16 +2631,10 @@ class Tensor(SimpleMathTrait):
     ```
     """
     if not sorted_: raise NotImplementedError("topk with sorted_=False is not supported")
-    if k > self.shape[dim]: raise ValueError(f"selected index {k=} is out of range")
-    x, dim = self, self._resolve_dim(dim)
-    select_fxn, mask_value = (Tensor.argmax, dtypes.min(self.dtype)) if largest else (Tensor.argmin, dtypes.max(self.dtype))
-    indices: list[Tensor] = []
-    for _ in range(k):
-      idx = select_fxn(x, dim, keepdim=True)
-      indices.append(idx)
-      x = x.scatter(dim, idx, mask_value)
-    combined_indices = indices[0].cat(*indices[1:], dim=dim)
-    return self.gather(dim, combined_indices), combined_indices
+    if k > self.shape[dim:=self._resolve_dim(dim)]: raise ValueError(f"selected index {k=} is out of range")
+    x, idx = self.sort(dim, descending=largest)
+    shrink_to_k = tuple((0, k) if i == dim else None for i in range(self.ndim))
+    return x.shrink(shrink_to_k), idx.shrink(shrink_to_k)
 
   # ***** unary ops *****
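For reference, a minimal usage sketch of the new `Tensor.sort` and the sort-based `topk` (the example tensor and the printed outputs below are illustrative assumptions, not taken from the patch):

```python
from tinygrad import Tensor

# stable sort along the last dim; equal elements keep their original relative order
t = Tensor([[3.0, 1.0, 2.0, 1.0]])
values, indices = t.sort(dim=-1, descending=True)
print(values.numpy())   # expected: [[3. 2. 1. 1.]]
print(indices.numpy())  # expected: [[0 2 1 3]], the two 1.0s stay in original order

# topk now shrinks the sorted result, so it inherits the same tie-ordering guarantee
top_values, top_indices = t.topk(2, dim=-1, largest=True)
print(top_values.numpy())   # expected: [[3. 2.]]
print(top_indices.numpy())  # expected: [[0 2]]
```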