user _resolve_dim in argmax (#3846)

also added comment of the behavior if there are multple, and more tests
This commit is contained in:
chenyu
2024-03-20 20:17:30 -04:00
committed by GitHub
parent 5c4cf62d2c
commit f271cd682b
2 changed files with 9 additions and 2 deletions

View File

@@ -540,14 +540,20 @@ class TestOps(unittest.TestCase):
helper_test_op([(2,3,0)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2))
def test_argmax(self):
self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy()) # check if returns first index for same max
# check if it returns the first index for multiple occurences
self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy())
np.testing.assert_equal(Tensor([2,2]).argmax().numpy(), np.array(0))
np.testing.assert_equal(Tensor([1,2,2]).argmax().numpy(), np.array(1))
helper_test_op([(10,20)], lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(0, False).type(torch.int32), lambda x: x.argmax(0, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(1, False).type(torch.int32), lambda x: x.argmax(1, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(1, True).type(torch.int32), lambda x: x.argmax(1, True), forward_only=True)
def test_argmin(self):
# check if it returns the first index for multiple occurences
self.assertEqual(torch.tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy())
np.testing.assert_equal(Tensor([2,2]).argmin().numpy(), np.array(0))
np.testing.assert_equal(Tensor([3,2,2]).argmin().numpy(), np.array(1))
helper_test_op([(10,20)], lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmin(0, False).type(torch.int32), lambda x: x.argmin(0, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmin(1, False).type(torch.int32), lambda x: x.argmin(1, False), forward_only=True)

View File

@@ -633,10 +633,11 @@ class Tensor:
return m - ss.log()
def argmax(self, axis=None, keepdim=False):
# NOTE: return the first index if there are multiple occurrences of the maximum values
if axis is None:
idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
return (prod(self.shape) - idx.max() - 1).cast(dtypes.int32)
axis = axis + len(self.shape) if axis < 0 else axis
axis = self._resolve_dim(axis)
m = self == self.max(axis=axis, keepdim=True)
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32)