From 660df3cff15b0535b88278245dd33de66f94ec22 Mon Sep 17 00:00:00 2001 From: reddyn12 <72528507+reddyn12@users.noreply.github.com> Date: Sat, 2 Mar 2024 23:51:52 -0500 Subject: [PATCH] Add test for .softmax.argmax (#3559) * Add broken test for known issue * skip PYTHON * skip PYTHON * fix commit --------- Co-authored-by: schlimeszn Co-authored-by: reddyn --- test/test_ops.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 53c6b4aaea..5abca633e7 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -752,7 +752,10 @@ class TestOps(unittest.TestCase): helper_test_op([(10,10,10)], lambda x: x.softmax(0), atol=1e-7, grad_atol=1e-7) helper_test_op([(10,10,10)], lambda x: x.softmax(1), atol=1e-7, grad_atol=1e-7) helper_test_op([(10,10,10)], lambda x: x.softmax(2), atol=1e-7, grad_atol=1e-7) - + @unittest.skipIf(CI and Device.DEFAULT in ["CLANG", "PYTHON"], "Broken ISSUE #3552") + def test_softmax_argmax(self): + helper_test_op([(45,65)], lambda x: x.softmax(0).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7) + helper_test_op([(45,65)], lambda x: x.softmax(1).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7) def test_log_softmax(self): helper_test_op([(45,65)], torch.nn.LogSoftmax(dim=1), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) helper_test_op([(45)], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)