diff --git a/test/test_ops.py b/test/test_ops.py index bac4f82216..fda6402089 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1661,6 +1661,7 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0, keepdim=True), lambda x: x.logsumexp(0, True), atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=1), lambda x: x.logsumexp(1), atol=1e-7, grad_atol=1e-7) + helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=1, keepdim=True), lambda x: x.logsumexp(1, True), atol=1e-7, grad_atol=1e-7) helper_test_op([(6,6,6)], lambda x: torch.logsumexp(x, dim=2), lambda x: x.logsumexp(2), atol=1e-7, grad_atol=1e-7) helper_test_op([(6,6,6,6)], lambda x: torch.logsumexp(x, dim=2), lambda x: x.logsumexp(2), atol=1e-7, grad_atol=1e-7) helper_test_op([(6,6,6,6)], lambda x: torch.logsumexp(x, dim=3), lambda x: x.logsumexp(3), atol=1e-7, grad_atol=1e-7) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index a175eb7088..89250bb8b3 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2012,7 +2012,7 @@ class Tensor(OpMixin): ``` """ m = self.max(axis=axis, keepdim=True) - return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis) + return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + (m if keepdim else m.squeeze(axis)) def logcumsumexp(self, axis=0) -> Tensor: """