fix argmax/min on int32 min (#8118)

Author: chenyu
Date: 2024-12-09 02:29:23 -05:00
committed by GitHub
parent c814de2dd4
commit ccf54c2375
2 changed files with 25 additions and 11 deletions
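Context for the fix: tinygrad's Tensor.argmin reduced to (-self).argmax(), and negating INT32_MIN wraps around in two's complement, so integer tensors containing -2**31 yielded a wrong index. A minimal numpy sketch (mine, not from the commit) of the failure mode and of why bitwise NOT sidesteps it:

import numpy as np

x = np.array([0, -2**31], dtype=np.int32)

# the old argmin-via-negation: -(-2**31) wraps back to -2**31 in int32,
# so (-x) == [0, -2147483648] and its argmax is 0, not the true argmin 1
print((-x).argmax())  # 0 (wrong)

# bitwise NOT reverses order with no overflow: ~v == -v - 1 for ints,
# and ~(-2**31) == 2**31 - 1 stays in range
print((~x).argmax())  # 1 (correct: x.argmin() is 1)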


@@ -844,7 +844,8 @@ class TestOps(unittest.TestCase):
def test_argmax(self):
# check if it returns the first index for multiple occurrences
- self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy())
- helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[2, 2]])
+ helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[1, 2, 2]])
+ np.testing.assert_equal(Tensor([2,2]).argmax().numpy(), np.array(0))
+ np.testing.assert_equal(Tensor([1,2,2]).argmax().numpy(), np.array(1))
helper_test_op([(10,20)], lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True)
@@ -854,9 +855,16 @@ class TestOps(unittest.TestCase):
# regression test for bitwise_not then argmax
helper_test_op(None, lambda x: (~x).argmax().type(torch.int32), lambda x: (~x).argmax(), forward_only=True, vals=[[2, 2]])
+ helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[0, -2**31]])
+ helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[-2**31, 0]])
+ # NOTE: torch does not support this on bool
+ helper_test_op(None, lambda x: x.type(torch.int32).argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[False, True]])
+ helper_test_op(None, lambda x: x.type(torch.int32).argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[True, False]])
def test_argmin(self):
# check if it returns the first index for multiple occurrences
self.assertEqual(torch.tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy())
- helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[2, 2]])
+ helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[3, 2, 2]])
+ np.testing.assert_equal(Tensor([2,2]).argmin().numpy(), np.array(0))
+ np.testing.assert_equal(Tensor([3,2,2]).argmin().numpy(), np.array(1))
helper_test_op([(10,20)], lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True)
@@ -864,6 +872,12 @@ class TestOps(unittest.TestCase):
helper_test_op([(10,20)], lambda x: x.argmin(1, False).type(torch.int32), lambda x: x.argmin(1, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmin(1, True).type(torch.int32), lambda x: x.argmin(1, True), forward_only=True)
+ helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[0, -2**31]])
+ helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[-2**31, 0]])
+ # NOTE: torch does not support this on bool
+ helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[False, True]])
+ helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[True, False]])
def test_einsum(self):
# matrix transpose
helper_test_op([(150,150)], lambda a: torch.einsum('ij->ji', a), lambda a: Tensor.einsum('ij->ji', a))
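The np.testing pins above encode the first-index convention for ties that the test comments call out (and that numpy shares): among equal extrema, the smallest index wins. A quick illustration (not from the commit):

import numpy as np

assert np.array([2, 2]).argmax() == 0     # tie: first index wins
assert np.array([1, 2, 2]).argmax() == 1  # first of the two maxima
assert np.array([3, 2, 2]).argmin() == 1  # first of the two minima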
@@ -1099,10 +1113,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,3)], lambda x: x.min().mul(0.5))
helper_test_op([()], lambda x: x.min())
- helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]])
- helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]])
- helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[False, True]])
- helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[True, False]])
+ helper_test_op(None, lambda x: x.min(), forward_only=True, vals=[[0, -2**31]])
+ helper_test_op(None, lambda x: x.min(), forward_only=True, vals=[[-2**31, 0]])
+ helper_test_op(None, lambda x: x.min(), forward_only=True, vals=[[False, True]])
+ helper_test_op(None, lambda x: x.min(), forward_only=True, vals=[[True, False]])
def test_max(self):
helper_test_op([(45,3)], lambda x: x.max())
@@ -1111,10 +1125,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1))
helper_test_op([()], lambda x: x.max())
- helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]])
- helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]])
- helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[False, True]])
- helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[True, False]])
+ helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[0, -2**31]])
+ helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[-2**31, 0]])
+ helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[False, True]])
+ helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[True, False]])
@unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
def test_any(self):
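The min/max test rewrite above drops the explicit .type(torch.int32) / .cast(dtypes.int32) plumbing: the vals already carry integer or bool data, so both sides can reduce them as-is. A parity check in the same spirit (my sketch; assumes tinygrad's default int dtype holds -2**31):

import torch
from tinygrad import Tensor

vals = [0, -2**31]
# plain min/max agree with torch without any dtype casting
assert Tensor(vals).min().item() == torch.tensor(vals).min().item() == -2**31
assert Tensor(vals).max().item() == torch.tensor(vals).max().item() == 0
# bool reductions: min acts like all(), max like any()
assert Tensor([True, False]).min().item() == False
assert Tensor([True, False]).max().item() == True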


@@ -1885,7 +1885,7 @@ class Tensor(SimpleMathTrait):
print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
```
"""
- return (-self).argmax(axis=axis, keepdim=keepdim)
+ return (-self if self.is_floating_point() else ~self).argmax(axis=axis, keepdim=keepdim)
def rearrange(self, formula: str, **sizes) -> Tensor:
"""