From 1d64c12f2b2454ccb6a0c6d66e0b738e03f34bec Mon Sep 17 00:00:00 2001 From: geohotstan <135171913+geohotstan@users.noreply.github.com> Date: Mon, 10 Mar 2025 08:01:42 +0800 Subject: [PATCH] add Topk to tensor (#9343) * terrible but somewhat working impl * linux behaves differently than macos? * slightly better impl * small clean up; haven't figured this out yet * better * torch has different behavior on linux and macos for duplicated values * add sum docs * fix test * add torch return_type test * add an exception test * wrap_fxn instead, and move op lower in order * better repeated values test * rerun ci --- docs/tensor/ops.md | 1 + extra/onnx.py | 6 ++++- extra/torch_backend/backend.py | 7 +----- extra/torch_backend/test.py | 9 ++++++- test/external/external_test_onnx_backend.py | 1 - test/test_ops.py | 28 +++++++++++++++++++++ tinygrad/tensor.py | 26 +++++++++++++++++++ 7 files changed, 69 insertions(+), 9 deletions(-) diff --git a/docs/tensor/ops.md b/docs/tensor/ops.md index 488b632bbf..4a4a12cb13 100644 --- a/docs/tensor/ops.md +++ b/docs/tensor/ops.md @@ -34,6 +34,7 @@ ::: tinygrad.Tensor.interpolate ::: tinygrad.Tensor.scatter ::: tinygrad.Tensor.scatter_reduce +::: tinygrad.Tensor.topk ## Neural Network (functional) diff --git a/extra/onnx.py b/extra/onnx.py index 8f838b00a2..430b26be92 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -82,7 +82,7 @@ class OnnxNode: # ***** python const ***** required_input_python_consts: dict[str, tuple[int, ...]] = { "Tile": (1,), "Range": (0,1,2), "Expand": (1,), "Reshape": (1,), "Squeeze": (1,), "Unsqueeze": (1,), "Trilu": (1,), "ConstantOfShape": (0,), - "CumSum": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,), + "CumSum": (1,), "TopK": (1,), "Pad": (1,2,3), "MaxUnpool": (2,), "Dropout": (1,2), "CenterCropPad": (1,), "OneHot": (1,), "Compress": (1,), "ImageDecoder": (0,), "AffineGrid": (1,), "Resize": (1,2,3), "Upsample": (1,), "Split": (1,), "Slice": (1,2,3,4), **{"Reduce"+r: (1,) for r in ("Max", "Min", "Sum", "Mean", "SumSquare", "Prod", "L1", "L2", "LogSum", "LogSumExp")}, **{optim: (1,) for optim in ("Adam", "Adagrad", "Momentum")} @@ -522,6 +522,10 @@ def get_onnx_ops(): return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated + def TopK(X:Tensor, K:int|list[int], axis:int=-1, largest:int=1, sorted:int=1): + val, idx = X.topk(K if isinstance(K, int) else K[0], axis, largest, sorted) + return val, idx.cast(dtypes.int64) + # ***** Neural Network Ops ***** def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9, training_mode:int=0, spatial=1, is_test=0): diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index b27bb773f0..1066a9f284 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -35,12 +35,6 @@ def masked_select(self, mask): # err, bad return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()])) -@torch.library.impl("aten::topk", "privateuseone") -def topk(self, k, dim=-1, largest=True, sorted=True): - # TODO: move to tinygrad - t1, t2 = torch.topk(self.cpu(), k, dim, largest, sorted) - return torch.return_types.topk((t1.tiny(), t2.tiny())) - @torch.library.impl("aten::_index_put_impl_", "privateuseone") def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False): # TODO: move to tinygrad @@ -369,6 +363,7 @@ 
tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{ "aten.add.Tensor": lambda input,other,alpha=1: input+alpha*other, "aten.linspace": lambda start, stop, steps, dtype=None, **kwargs: Tensor.linspace(start, stop, steps, **({"dtype": _from_torch_dtype(dtype)} if dtype is not None else {})), + "aten.topk": Tensor.topk, "aten::view.dtype": lambda self, dtype: self.bitcast(_from_torch_dtype(dtype)), "aten.constant_pad_nd": lambda self, padding, value=0.0: self.pad(padding, mode="constant", value=value), "aten.logsumexp": lambda self, axis, keepdim=False: self.logsumexp(axis[0], keepdim=keepdim), diff --git a/extra/torch_backend/test.py b/extra/torch_backend/test.py index ec5e5ac2e9..39634ca577 100644 --- a/extra/torch_backend/test.py +++ b/extra/torch_backend/test.py @@ -62,7 +62,7 @@ class TestTorchBackend(unittest.TestCase): a += b np.testing.assert_equal(a.cpu().numpy(), [3,3,3,3]) - def test_exp2(qself): + def test_exp2(self): a = torch.ones(4, device=device) b = a.exp2() np.testing.assert_equal(b.cpu().numpy(), [2,2,2,2]) @@ -91,6 +91,13 @@ class TestTorchBackend(unittest.TestCase): res2 = x ^ y print(res2.cpu()) + def test_topk(self): + # test topk return_types + a = torch.tensor([1, 3, 2, 4], device=device) + out = torch.topk(a, k=2) + np.testing.assert_equal(out.values.cpu().numpy(), [4, 3]) + np.testing.assert_equal(out.indices.cpu().numpy(), [3, 1]) + @unittest.skip("meh") def test_str(self): a = torch.ones(4, device=device) diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py index 0818f8bffb..d12fcd9f88 100644 --- a/test/external/external_test_onnx_backend.py +++ b/test/external/external_test_onnx_backend.py @@ -146,7 +146,6 @@ backend_test.exclude('test_sequence_*') backend_test.exclude('test_nonmaxsuppression_*') backend_test.exclude('test_reversesequence_*') backend_test.exclude('test_roialign_*') -backend_test.exclude('test_top_k_*') backend_test.exclude('test_tfidfvectorizer_*') backend_test.exclude('test_stft_*') backend_test.exclude('test_melweightmatrix_*') diff --git a/test/test_ops.py b/test/test_ops.py index 77c9be2580..497a6e9eee 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1047,6 +1047,27 @@ class TestOps(unittest.TestCase): helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[False, True]]) helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[True, False]]) + def test_topk(self): + helper_test_op([(10)], lambda x: x.topk(3).values, lambda x: x.topk(3)[0], forward_only=True) + helper_test_op([(10)], lambda x: x.topk(3).indices.type(torch.int32), lambda x: x.topk(3)[1], forward_only=True) + for dim in [0, 1, -1]: + for largest in [True, False]: + for sorted_ in [True]: # TODO support False + helper_test_op([(10,20,30)], + lambda x: x.topk(5, dim, largest, sorted_).values, + lambda x: x.topk(5, dim, largest, sorted_)[0], forward_only=True) + helper_test_op([(10,20,30)], + lambda x: x.topk(5, dim, largest, sorted_).indices.type(torch.int32), + lambda x: x.topk(5, dim, largest, sorted_)[1], forward_only=True) + # repeated values + value, indices = Tensor([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0]).topk(3) + np.testing.assert_equal(value.numpy(), [1, 1, 1]) + np.testing.assert_equal(indices.numpy(), [0, 1, 3]) + value, indices = Tensor([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0]).topk(3, largest=False) + np.testing.assert_equal(value.numpy(), [0, 
0, 0]) + np.testing.assert_equal(indices.numpy(), [2, 4, 6]) + self.helper_test_exception([(4)], lambda x: x.topk(5), lambda x: x.topk(5), expected=(RuntimeError, ValueError)) + def test_einsum(self): # matrix transpose helper_test_op([(150,150)], lambda a: torch.einsum('ij->ji', a), lambda a: Tensor.einsum('ij->ji', a)) @@ -2634,6 +2655,13 @@ class TestOps(unittest.TestCase): lambda x: x.gather(dim=0, index=Tensor([2, 1, 0, 1, 2])), vals=[[1., 2., 3.]]) + @unittest.expectedFailure + def test_gather_failure(self): + # gather with inf values do not work, other values results in nan + helper_test_op(None, lambda x: x.gather(dim=0, index=torch.tensor([2, 1, 0, 1, 2], requires_grad=False)), + lambda x: x.gather(dim=0, index=Tensor([2, 1, 0, 1, 2])), + vals=[[-float("inf"), 2., 3.]]) + def test_scatter(self): b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False) a = Tensor(b.detach().cpu().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index df2418706f..8bb77a5d83 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2563,6 +2563,32 @@ class Tensor(SimpleMathTrait): return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count) raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'") + def topk(self, k, dim=-1, largest=True, sorted_=True): + """ + Computes the top-k elements of the tensor along the specified `dim`. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([[0.1, 0.5, 1.2, 3.4, 2.1], [2.2, 1.9, 0.3, 4.5, 0.8]]) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + topk_values, topk_indices = t.topk(2, dim=1) + print(topk_values.numpy()) + print(topk_indices.numpy()) + ``` + """ + if not sorted_: raise NotImplementedError("topk with sorted_=False is not supported") + if k > self.shape[dim]: raise ValueError(f"selected index {k=} is out of range") + x, dim = self, self._resolve_dim(dim) + select_fxn, mask_value = (Tensor.argmax, dtypes.min(self.dtype)) if largest else (Tensor.argmin, dtypes.max(self.dtype)) + indices: list[Tensor] = [] + for _ in range(k): + idx = select_fxn(x, dim, keepdim=True) + indices.append(idx) + x = x.scatter(dim, idx, mask_value) + combined_indices = indices[0].cat(*indices[1:], dim=dim) + return self.gather(dim, combined_indices), combined_indices + # ***** unary ops ***** def logical_not(self) -> Tensor:
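
Note (not part of the patch): a minimal NumPy sketch of the selection strategy the new `Tensor.topk` uses above — run k rounds of `argmax`/`argmin` along `dim` with keepdim, record the index, and scatter a sentinel (the patch uses `dtypes.min`/`dtypes.max` of the input dtype) into that slot so it cannot win again, then gather the values with the concatenated indices. The helper name `topk_sketch` and the ±inf sentinel are illustrative assumptions, not tinygrad code; by construction the results come out already sorted, which is why only `sorted_=True` is supported.

```python
# Illustrative NumPy sketch (assumed helper, not tinygrad code) of the approach
# Tensor.topk takes: k rounds of argmax/argmin + masking, then a final gather.
import numpy as np

def topk_sketch(x: np.ndarray, k: int, dim: int = -1, largest: bool = True):
    if k > x.shape[dim]: raise ValueError(f"selected index k={k} is out of range")
    dim = dim % x.ndim
    work = x.astype(np.float64)                 # assumes float data for the sentinel
    select = np.argmax if largest else np.argmin
    sentinel = -np.inf if largest else np.inf   # stands in for dtypes.min / dtypes.max
    parts = []
    for _ in range(k):
        idx = np.expand_dims(select(work, axis=dim), dim)  # keepdim-style index
        parts.append(idx)
        np.put_along_axis(work, idx, sentinel, axis=dim)   # mask the chosen slot
    indices = np.concatenate(parts, axis=dim)
    return np.take_along_axis(x, indices, axis=dim), indices

vals, idxs = topk_sketch(np.array([[0.1, 0.5, 1.2, 3.4, 2.1],
                                   [2.2, 1.9, 0.3, 4.5, 0.8]]), k=2, dim=1)
print(vals)  # [[3.4 2.1] [4.5 2.2]]
print(idxs)  # [[3 4] [3 0]]
```

Ties resolve to the earliest index, since each round's argmax/argmin returns the first extremal position before it is masked out; the repeated-values expectations in `test_topk` encode exactly this ordering.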
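
A hedged usage check against the new API, mirroring the values and indices hard-coded in `test_topk` (first-occurrence ordering for ties, `largest=False` for the smallest elements):

```python
from tinygrad import Tensor

t = Tensor([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0])
values, indices = t.topk(3)                  # three largest values
print(values.numpy(), indices.numpy())       # expected per test_topk: [1 1 1] [0 1 3]
values, indices = t.topk(3, largest=False)   # three smallest values
print(values.numpy(), indices.numpy())       # expected per test_topk: [0 0 0] [2 4 6]
```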