#include <tgmath.h> in ops_clang (#3927)

* different clang sqrt/log2/exp2/sin function based on dtype

fixed softmax_argmax issue in #3552 for clang.

* tgmath.h

* revert those
This commit is contained in:
chenyu
2024-03-25 17:48:57 -04:00
committed by GitHub
parent 514c43201d
commit 4ecd5789ab
3 changed files with 4 additions and 4 deletions

View File

@@ -793,7 +793,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(10,10,10)], lambda x: x.softmax(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([(10,10,10)], lambda x: x.softmax(1), atol=1e-7, grad_atol=1e-7)
helper_test_op([(10,10,10)], lambda x: x.softmax(2), atol=1e-7, grad_atol=1e-7)
@unittest.skipIf(CI and Device.DEFAULT in ["CLANG", "PYTHON"], "Broken ISSUE #3552")
@unittest.skipIf(Device.DEFAULT in ["PYTHON"], "Broken ISSUE #3552")
def test_softmax_argmax(self):
helper_test_op([(45,65)], lambda x: x.softmax(0).argmax().type(torch.int32),
lambda x: x.softmax(0).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7)