mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
fix argmax/min on int32 min (#8118)
This commit is contained in:
@@ -844,7 +844,8 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
def test_argmax(self):
|
||||
# check if it returns the first index for multiple occurences
|
||||
self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy())
|
||||
helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[2, 2]])
|
||||
helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[1, 2, 2]])
|
||||
np.testing.assert_equal(Tensor([2,2]).argmax().numpy(), np.array(0))
|
||||
np.testing.assert_equal(Tensor([1,2,2]).argmax().numpy(), np.array(1))
|
||||
helper_test_op([(10,20)], lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True)
|
||||
@@ -854,9 +855,16 @@ class TestOps(unittest.TestCase):
|
||||
# regression test for bitwise_not then argmax
|
||||
helper_test_op(None, lambda x: (~x).argmax().type(torch.int32), lambda x: (~x).argmax(), forward_only=True, vals=[[2, 2]])
|
||||
|
||||
helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[0, -2**31]])
|
||||
helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[-2**31, 0]])
|
||||
# NOTE: torch does not support this on bool
|
||||
helper_test_op(None, lambda x: x.type(torch.int32).argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[False, True]])
|
||||
helper_test_op(None, lambda x: x.type(torch.int32).argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[True, False]])
|
||||
|
||||
def test_argmin(self):
|
||||
# check if it returns the first index for multiple occurences
|
||||
self.assertEqual(torch.tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy())
|
||||
helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[2, 2]])
|
||||
helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[3, 2, 2]])
|
||||
np.testing.assert_equal(Tensor([2,2]).argmin().numpy(), np.array(0))
|
||||
np.testing.assert_equal(Tensor([3,2,2]).argmin().numpy(), np.array(1))
|
||||
helper_test_op([(10,20)], lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True)
|
||||
@@ -864,6 +872,12 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(10,20)], lambda x: x.argmin(1, False).type(torch.int32), lambda x: x.argmin(1, False), forward_only=True)
|
||||
helper_test_op([(10,20)], lambda x: x.argmin(1, True).type(torch.int32), lambda x: x.argmin(1, True), forward_only=True)
|
||||
|
||||
helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[0, -2**31]])
|
||||
helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[-2**31, 0]])
|
||||
# NOTE: torch does not support this on bool
|
||||
helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[False, True]])
|
||||
helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[True, False]])
|
||||
|
||||
def test_einsum(self):
|
||||
# matrix transpose
|
||||
helper_test_op([(150,150)], lambda a: torch.einsum('ij->ji', a), lambda a: Tensor.einsum('ij->ji', a))
|
||||
@@ -1099,10 +1113,10 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,3)], lambda x: x.min().mul(0.5))
|
||||
helper_test_op([()], lambda x: x.min())
|
||||
|
||||
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]])
|
||||
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]])
|
||||
helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[False, True]])
|
||||
helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[True, False]])
|
||||
helper_test_op(None, lambda x: x.min(), forward_only=True, vals=[[0, -2**31]])
|
||||
helper_test_op(None, lambda x: x.min(), forward_only=True, vals=[[-2**31, 0]])
|
||||
helper_test_op(None, lambda x: x.min(), forward_only=True, vals=[[False, True]])
|
||||
helper_test_op(None, lambda x: x.min(), forward_only=True, vals=[[True, False]])
|
||||
|
||||
def test_max(self):
|
||||
helper_test_op([(45,3)], lambda x: x.max())
|
||||
@@ -1111,10 +1125,10 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1))
|
||||
helper_test_op([()], lambda x: x.max())
|
||||
|
||||
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]])
|
||||
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]])
|
||||
helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[False, True]])
|
||||
helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[True, False]])
|
||||
helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[0, -2**31]])
|
||||
helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[-2**31, 0]])
|
||||
helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[False, True]])
|
||||
helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[True, False]])
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
|
||||
def test_any(self):
|
||||
|
||||
@@ -1885,7 +1885,7 @@ class Tensor(SimpleMathTrait):
|
||||
print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
|
||||
```
|
||||
"""
|
||||
return (-self).argmax(axis=axis, keepdim=keepdim)
|
||||
return (-self if self.is_floating_point() else ~self).argmax(axis=axis, keepdim=keepdim)
|
||||
|
||||
def rearrange(self, formula: str, **sizes) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user