Tensor.erf (#7419)

the same one used in onnx and the one in bert.
This commit is contained in:
chenyu
2024-10-30 18:12:28 -04:00
committed by GitHub
parent e955aa1bee
commit fb694a63eb
5 changed files with 25 additions and 20 deletions

View File

@@ -647,6 +647,12 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6)
helper_test_op([()], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6)
def test_erf(self):
helper_test_op([(45,65)], torch.erf, Tensor.erf)
helper_test_op([(45,65)], torch.erf, Tensor.erf, low=300, high=400)
helper_test_op([(45,65)], torch.erf, Tensor.erf, low=-400, high=-300)
helper_test_op([()], torch.erf, Tensor.erf)
def test_gelu(self):
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu)
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, low=300, high=400)