test more dims in test_logsumexp and test_logcumsumexp (#10907)

refactoring squeeze and unsqueeze is easy to get wrong
This commit is contained in:
chenyu
2025-06-20 21:42:18 -04:00
committed by GitHub
parent 3771cc0f77
commit 2d9c61e39e
2 changed files with 6 additions and 1 deletions

View File

@@ -1534,6 +1534,9 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0, keepdim=True), lambda x: x.logsumexp(0, True), atol=1e-7, grad_atol=1e-7)
helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=1), lambda x: x.logsumexp(1), atol=1e-7, grad_atol=1e-7)
helper_test_op([(6,6,6)], lambda x: torch.logsumexp(x, dim=2), lambda x: x.logsumexp(2), atol=1e-7, grad_atol=1e-7)
helper_test_op([(6,6,6,6)], lambda x: torch.logsumexp(x, dim=2), lambda x: x.logsumexp(2), atol=1e-7, grad_atol=1e-7)
helper_test_op([(6,6,6,6)], lambda x: torch.logsumexp(x, dim=3), lambda x: x.logsumexp(3), atol=1e-7, grad_atol=1e-7)
helper_test_op([(45)], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([()], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([()], lambda x: torch.logsumexp(x, dim=-1), lambda x: x.logsumexp(-1), atol=1e-7, grad_atol=1e-7)
@@ -1541,6 +1544,9 @@ class TestOps(unittest.TestCase):
def test_logcumsumexp(self):
helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=1), lambda x: x.logcumsumexp(1), atol=1e-7, grad_atol=1e-7)
helper_test_op([(6,6,6)], lambda x: torch.logcumsumexp(x, dim=2), lambda x: x.logcumsumexp(2), atol=1e-7, grad_atol=1e-7)
helper_test_op([(6,6,6,6)], lambda x: torch.logcumsumexp(x, dim=2), lambda x: x.logcumsumexp(2), atol=1e-7, grad_atol=1e-7)
helper_test_op([(6,6,6,6)], lambda x: torch.logcumsumexp(x, dim=3), lambda x: x.logcumsumexp(3), atol=1e-7, grad_atol=1e-7)
helper_test_op([(45)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(), atol=1e-7, grad_atol=1e-7)

View File

@@ -2098,7 +2098,6 @@ class Tensor(MathTrait):
```
"""
if self.ndim == 0: return self
axis = self._resolve_dim(axis)
x = self.transpose(axis, -1)
last_dim_size = x.shape[-1]
x_reshaped = x.reshape(-1, last_dim_size)