Tensor.logsumexp (#4442)

the subtract-max part should be shared with safe softmax

cleaner
commit d3dc332c2e
parent 78b298aa2a
Author: chenyu
Date: 2024-05-09 20:49:06 -04:00
committed by GitHub

2 changed files with 12 additions and 0 deletions
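
The tensor.py side of the change is not shown below, but the commit message points at the standard stability trick: logsumexp(x) = max(x) + log(sum(exp(x - max(x)))), where the (x - max) term is exactly what safe softmax also computes. A minimal sketch of that idea, written against torch for clarity; the function name and keepdim handling here are illustrative, not the actual tinygrad implementation:

  import torch

  def logsumexp_sketch(x: torch.Tensor, dim: int, keepdim: bool = False) -> torch.Tensor:
    # Subtract the max before exponentiating so exp() cannot overflow;
    # safe softmax computes the same (x - max) term, which is the sharing
    # the commit message suggests.
    m = x.max(dim=dim, keepdim=True).values
    out = (x - m).exp().sum(dim=dim, keepdim=True).log() + m
    return out if keepdim else out.squeeze(dim)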

test/test_ops.py

@@ -810,6 +810,14 @@ class TestOps(unittest.TestCase):
     helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7)
 
+  def test_logsumexp(self):
+    helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0, keepdim=True), lambda x: x.logsumexp(0, True), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=1), lambda x: x.logsumexp(1), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(45)], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([()], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([()], lambda x: torch.logsumexp(x, dim=-1), lambda x: x.logsumexp(-1), atol=1e-7, grad_atol=1e-7)
+
   def test_sinh(self):
     helper_test_op([(45,65)], lambda x: x.sinh(), grad_atol=1e-6)
     # TODO: backward nan instead of inf
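
As a quick sanity check (not part of the commit), the subtract-max form stays finite where the naive exp-sum-log overflows:

  import torch

  x = torch.full((2,), 1000.0)
  print(x.exp().sum().log())        # inf: exp(1000) overflows float32
  print(torch.logsumexp(x, dim=0))  # tensor(1000.6931), i.e. 1000 + log(2)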