From f271cd682b60aab4c28ed8583c85778971e84a15 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Wed, 20 Mar 2024 20:17:30 -0400
Subject: [PATCH] use _resolve_dim in argmax (#3846)

also added a comment on the behavior when there are multiple occurrences, and more tests
---
 test/test_ops.py   | 8 +++++++-
 tinygrad/tensor.py | 3 ++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 21f59b0af0..035647fb69 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -540,14 +540,20 @@ class TestOps(unittest.TestCase):
     helper_test_op([(2,3,0)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2))
 
   def test_argmax(self):
-    self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy()) # check if returns first index for same max
+    # check that it returns the first index for multiple occurrences
+    self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy())
+    np.testing.assert_equal(Tensor([2,2]).argmax().numpy(), np.array(0))
+    np.testing.assert_equal(Tensor([1,2,2]).argmax().numpy(), np.array(1))
     helper_test_op([(10,20)], lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True)
     helper_test_op([(10,20)], lambda x: x.argmax(0, False).type(torch.int32), lambda x: x.argmax(0, False), forward_only=True)
     helper_test_op([(10,20)], lambda x: x.argmax(1, False).type(torch.int32), lambda x: x.argmax(1, False), forward_only=True)
     helper_test_op([(10,20)], lambda x: x.argmax(1, True).type(torch.int32), lambda x: x.argmax(1, True), forward_only=True)
 
   def test_argmin(self):
+    # check that it returns the first index for multiple occurrences
     self.assertEqual(torch.tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy())
+    np.testing.assert_equal(Tensor([2,2]).argmin().numpy(), np.array(0))
+    np.testing.assert_equal(Tensor([3,2,2]).argmin().numpy(), np.array(1))
     helper_test_op([(10,20)], lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True)
     helper_test_op([(10,20)], lambda x: x.argmin(0, False).type(torch.int32), lambda x: x.argmin(0, False), forward_only=True)
     helper_test_op([(10,20)], lambda x: x.argmin(1, False).type(torch.int32), lambda x: x.argmin(1, False), forward_only=True)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index a0ebae39a2..365dfa248a 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -633,10 +633,11 @@ class Tensor:
     return m - ss.log()
 
   def argmax(self, axis=None, keepdim=False):
+    # NOTE: returns the first index if there are multiple occurrences of the maximum value
     if axis is None:
       idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
       return (prod(self.shape) - idx.max() - 1).cast(dtypes.int32)
-    axis = axis + len(self.shape) if axis < 0 else axis
+    axis = self._resolve_dim(axis)
     m = self == self.max(axis=axis, keepdim=True)
     idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
     return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32)
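For context, both branches of argmax rely on the same trick the new NOTE documents: a boolean mask of the max positions is multiplied by a descending arange, so the first occurrence carries the largest weight, and a max-reduction plus an index flip recovers the first index. Below is a minimal NumPy sketch of the flat (axis=None) case; argmax_first is a hypothetical helper for illustration, not part of tinygrad.

import numpy as np

def argmax_first(x: np.ndarray) -> int:
  # hypothetical helper, mirrors the axis=None branch above
  m = (x == x.max())                       # mask of positions holding the maximum
  idx = m * np.arange(len(x) - 1, -1, -1)  # first occurrence gets the largest weight
  return int(len(x) - idx.max() - 1)       # flip the reversed index back

assert argmax_first(np.array([2, 2])) == 0   # first index wins on ties
assert argmax_first(np.array([1, 2, 2])) == 1

argmin reduces to the same trick applied to the negated tensor, which is why the new argmin tests also expect the first index on ties.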