Add test for .softmax.argmax (#3559)

* Add broken test for known issue

* skip PYTHON

* skip PYTHON

* fix commit

---------

Co-authored-by: schlimeszn <schlimeszn@gmail.com>
Co-authored-by: reddyn <nikidsniper@gmail.com>
reddyn12 committed 2024-03-02 23:51:52 -05:00 (via GitHub)
parent ee41fafdab
commit 660df3cff1


@@ -752,7 +752,10 @@ class TestOps(unittest.TestCase):
     helper_test_op([(10,10,10)], lambda x: x.softmax(0), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(10,10,10)], lambda x: x.softmax(1), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(10,10,10)], lambda x: x.softmax(2), atol=1e-7, grad_atol=1e-7)
+  @unittest.skipIf(CI and Device.DEFAULT in ["CLANG", "PYTHON"], "Broken ISSUE #3552")
+  def test_softmax_argmax(self):
+    helper_test_op([(45,65)], lambda x: x.softmax(0).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(45,65)], lambda x: x.softmax(1).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7)
   def test_log_softmax(self):
     helper_test_op([(45,65)], torch.nn.LogSoftmax(dim=1), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
     helper_test_op([(45)], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
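
For reference, a minimal standalone sketch of the pattern this test exercises, assuming a tinygrad install; Tensor, softmax, argmax, and .numpy() are tinygrad APIs, and the NumPy reference plays the role that the torch reference plays inside helper_test_op:

    import numpy as np
    from tinygrad import Tensor

    # random input matching the test shape
    x = np.random.randn(45, 65).astype(np.float32)

    # tinygrad: softmax over axis 0, then a global (flattened) argmax, forward only
    out = int(Tensor(x).softmax(0).argmax().numpy())

    # NumPy reference for the same computation
    sm = np.exp(x - x.max(0)) / np.exp(x - x.max(0)).sum(0)
    ref = int(np.argmax(sm))

    # this is the comparison that was failing on the skipped backends (issue #3552)
    assert out == ref, f"mismatch: {out} vs {ref}"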