From 9ad3c865ace3dbd8d86902a5b7b0abc3fbf06b56 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 21 Jan 2026 09:49:55 -0500 Subject: [PATCH] fix bug in logsumexp keepdim=True (#14268) --- test/test_ops.py | 1 + tinygrad/tensor.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index bac4f82216..fda6402089 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1661,6 +1661,7 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0, keepdim=True), lambda x: x.logsumexp(0, True), atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=1), lambda x: x.logsumexp(1), atol=1e-7, grad_atol=1e-7) + helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=1, keepdim=True), lambda x: x.logsumexp(1, True), atol=1e-7, grad_atol=1e-7) helper_test_op([(6,6,6)], lambda x: torch.logsumexp(x, dim=2), lambda x: x.logsumexp(2), atol=1e-7, grad_atol=1e-7) helper_test_op([(6,6,6,6)], lambda x: torch.logsumexp(x, dim=2), lambda x: x.logsumexp(2), atol=1e-7, grad_atol=1e-7) helper_test_op([(6,6,6,6)], lambda x: torch.logsumexp(x, dim=3), lambda x: x.logsumexp(3), atol=1e-7, grad_atol=1e-7) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index a175eb7088..89250bb8b3 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2012,7 +2012,7 @@ class Tensor(OpMixin): ``` """ m = self.max(axis=axis, keepdim=True) - return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis) + return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + (m if keepdim else m.squeeze(axis)) def logcumsumexp(self, axis=0) -> Tensor: """