use _resolve_dim in argmax (#3846)
also added a comment documenting the behavior when there are multiple occurrences, and more tests
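For context, the diff below swaps a hand-rolled negative-axis fixup for the shared _resolve_dim helper. A minimal sketch of that normalize-and-validate pattern, assuming numpy-style axis handling (the free function resolve_dim here is illustrative, not tinygrad's exact method):

def resolve_dim(dim: int, ndim: int) -> int:
  # map a possibly-negative axis into [0, ndim), rejecting out-of-range values
  if not -ndim <= dim < ndim:
    raise IndexError(f"dim {dim} out of range for {ndim}-dimensional tensor")
  return dim + ndim if dim < 0 else dim

assert resolve_dim(-1, 3) == 2
assert resolve_dim(1, 3) == 1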
@@ -540,14 +540,20 @@ class TestOps(unittest.TestCase):
     helper_test_op([(2,3,0)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2))
 
   def test_argmax(self):
-    self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy()) # check if returns first index for same max
+    # check if it returns the first index for multiple occurrences
+    self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy())
+    np.testing.assert_equal(Tensor([2,2]).argmax().numpy(), np.array(0))
+    np.testing.assert_equal(Tensor([1,2,2]).argmax().numpy(), np.array(1))
     helper_test_op([(10,20)], lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True)
     helper_test_op([(10,20)], lambda x: x.argmax(0, False).type(torch.int32), lambda x: x.argmax(0, False), forward_only=True)
     helper_test_op([(10,20)], lambda x: x.argmax(1, False).type(torch.int32), lambda x: x.argmax(1, False), forward_only=True)
     helper_test_op([(10,20)], lambda x: x.argmax(1, True).type(torch.int32), lambda x: x.argmax(1, True), forward_only=True)
 
   def test_argmin(self):
+    # check if it returns the first index for multiple occurrences
+    self.assertEqual(torch.tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy())
+    np.testing.assert_equal(Tensor([2,2]).argmin().numpy(), np.array(0))
+    np.testing.assert_equal(Tensor([3,2,2]).argmin().numpy(), np.array(1))
     helper_test_op([(10,20)], lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True)
     helper_test_op([(10,20)], lambda x: x.argmin(0, False).type(torch.int32), lambda x: x.argmin(0, False), forward_only=True)
     helper_test_op([(10,20)], lambda x: x.argmin(1, False).type(torch.int32), lambda x: x.argmin(1, False), forward_only=True)
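The new assertions pin down the tie-breaking contract: argmax/argmin return the first index among equal extrema, matching numpy. A quick REPL-style check (assuming a tinygrad install; the comments give the expected values):

import numpy as np
from tinygrad import Tensor

print(Tensor([2, 2]).argmax().numpy())     # 0: first of the tied maxima
print(Tensor([1, 2, 2]).argmax().numpy())  # 1
print(np.argmax(np.array([1, 2, 2])))      # 1: numpy agrees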
@@ -633,10 +633,11 @@ class Tensor:
     return m - ss.log()
 
   def argmax(self, axis=None, keepdim=False):
+    # NOTE: return the first index if there are multiple occurrences of the maximum values
     if axis is None:
       idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
       return (prod(self.shape) - idx.max() - 1).cast(dtypes.int32)
-    axis = axis + len(self.shape) if axis < 0 else axis
+    axis = self._resolve_dim(axis)
     m = self == self.max(axis=axis, keepdim=True)
     idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
     return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32)
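Why this returns the first occurrence: the equality mask is multiplied by a descending arange, so earlier positions of the maximum get larger scores, the inner max picks the earliest one, and the final subtraction converts that score back into an index. The same idea in a self-contained NumPy sketch (the array x is illustrative):

import numpy as np

x = np.array([1, 3, 3, 2])
n = x.size
mask = (x == x.max())                   # [False, True, True, False]
scores = mask * np.arange(n-1, -1, -1)  # [0, 2, 1, 0]: earlier max -> bigger score
first = n - scores.max() - 1            # 4 - 2 - 1 = 1, the first maximum
assert first == np.argmax(x)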