Argmax/Argmin Feature (#1576)

* implemented argmax and argmin

* lint

* lint

* match torch behaviour

* format

* removed flip
This commit is contained in:
Umut Zengin
2023-08-21 04:46:46 +03:00
committed by GitHub
parent 1900acda09
commit 35bf21276f
2 changed files with 22 additions and 0 deletions

View File

@@ -392,6 +392,20 @@ class TestOps(unittest.TestCase):
helper_test_op([(20,30)], lambda x: torch.cumsum(x, dim=1), lambda x: Tensor.cumsum(x, axis=1), atol=1e-6)
helper_test_op([(20,30,40)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2), atol=1e-6)
helper_test_op([(20,30,40)], lambda x: torch.cumsum(x, dim=-1), lambda x: Tensor.cumsum(x, axis=-1), atol=1e-6)
def test_argmax(self):
self.assertEqual(torch.Tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy()) # check if returns first index for same max
helper_test_op([(10,20)], lambda x: x.argmax(), lambda x: x.argmax(), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(0, False), lambda x: x.argmax(0, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(1, False), lambda x: x.argmax(1, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(1, True), lambda x: x.argmax(1, True), forward_only=True)
def test_argmin(self):
self.assertEqual(torch.Tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy())
helper_test_op([(10,20)], lambda x: x.argmin(), lambda x: x.argmin(), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmin(0, False), lambda x: x.argmin(0, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmin(1, False), lambda x: x.argmin(1, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmin(1, True), lambda x: x.argmin(1, True), forward_only=True)
def test_matmul_simple(self):
helper_test_op([(4), (4,4)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
def test_matmul(self):