mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix bug in logsumexp keepdim=True (#14268)
This commit is contained in:
@@ -1661,6 +1661,7 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0, keepdim=True), lambda x: x.logsumexp(0, True), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=1), lambda x: x.logsumexp(1), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=1, keepdim=True), lambda x: x.logsumexp(1, True), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(6,6,6)], lambda x: torch.logsumexp(x, dim=2), lambda x: x.logsumexp(2), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(6,6,6,6)], lambda x: torch.logsumexp(x, dim=2), lambda x: x.logsumexp(2), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(6,6,6,6)], lambda x: torch.logsumexp(x, dim=3), lambda x: x.logsumexp(3), atol=1e-7, grad_atol=1e-7)
|
||||
|
||||
@@ -2012,7 +2012,7 @@ class Tensor(OpMixin):
|
||||
```
|
||||
"""
|
||||
m = self.max(axis=axis, keepdim=True)
|
||||
-        return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)
+        return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + (m if keepdim else m.squeeze(axis))
|
||||
|
||||
def logcumsumexp(self, axis=0) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user