diff --git a/test/test_ops.py b/test/test_ops.py index 90aca02224..d39e877622 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1534,6 +1534,9 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0, keepdim=True), lambda x: x.logsumexp(0, True), atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=1), lambda x: x.logsumexp(1), atol=1e-7, grad_atol=1e-7) + helper_test_op([(6,6,6)], lambda x: torch.logsumexp(x, dim=2), lambda x: x.logsumexp(2), atol=1e-7, grad_atol=1e-7) + helper_test_op([(6,6,6,6)], lambda x: torch.logsumexp(x, dim=2), lambda x: x.logsumexp(2), atol=1e-7, grad_atol=1e-7) + helper_test_op([(6,6,6,6)], lambda x: torch.logsumexp(x, dim=3), lambda x: x.logsumexp(3), atol=1e-7, grad_atol=1e-7) helper_test_op([(45)], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7) helper_test_op([()], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7) helper_test_op([()], lambda x: torch.logsumexp(x, dim=-1), lambda x: x.logsumexp(-1), atol=1e-7, grad_atol=1e-7) @@ -1541,6 +1544,9 @@ class TestOps(unittest.TestCase): def test_logcumsumexp(self): helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7) helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=1), lambda x: x.logcumsumexp(1), atol=1e-7, grad_atol=1e-7) + helper_test_op([(6,6,6)], lambda x: torch.logcumsumexp(x, dim=2), lambda x: x.logcumsumexp(2), atol=1e-7, grad_atol=1e-7) + helper_test_op([(6,6,6,6)], lambda x: torch.logcumsumexp(x, dim=2), lambda x: x.logcumsumexp(2), atol=1e-7, grad_atol=1e-7) + helper_test_op([(6,6,6,6)], lambda x: torch.logcumsumexp(x, dim=3), lambda x: x.logcumsumexp(3), atol=1e-7, grad_atol=1e-7) helper_test_op([(45)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7) helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7) helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(), atol=1e-7, grad_atol=1e-7) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 2a0609fb4a..65d449be44 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2098,7 +2098,6 @@ class Tensor(MathTrait): ``` """ if self.ndim == 0: return self - axis = self._resolve_dim(axis) x = self.transpose(axis, -1) last_dim_size = x.shape[-1] x_reshaped = x.reshape(-1, last_dim_size)