Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
test more dims in test_logsumexp and test_logcumsumexp (#10907)
refactoring squeeze and unsqueeze is easy to get wrong
@@ -1534,6 +1534,9 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0, keepdim=True), lambda x: x.logsumexp(0, True), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=1), lambda x: x.logsumexp(1), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(6,6,6)], lambda x: torch.logsumexp(x, dim=2), lambda x: x.logsumexp(2), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(6,6,6,6)], lambda x: torch.logsumexp(x, dim=2), lambda x: x.logsumexp(2), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(6,6,6,6)], lambda x: torch.logsumexp(x, dim=3), lambda x: x.logsumexp(3), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(45)], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
     helper_test_op([()], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
     helper_test_op([()], lambda x: torch.logsumexp(x, dim=-1), lambda x: x.logsumexp(-1), atol=1e-7, grad_atol=1e-7)
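As a reading aid (not part of the commit), here is a minimal NumPy sketch of the reduction these tests exercise; `ref_logsumexp` is a hypothetical name, and the max-subtraction identity is the standard stability trick:

```python
import numpy as np

def ref_logsumexp(x: np.ndarray, axis: int, keepdims: bool = False) -> np.ndarray:
  # log(sum(exp(x))) along `axis`, computed stably: subtracting the max
  # before exp() keeps the largest exponent at exactly 0, so nothing overflows.
  m = np.max(x, axis=axis, keepdims=True)
  out = m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))
  return out if keepdims else np.squeeze(out, axis=axis)

# the reduced axis disappears (unless keepdim is set), which is what the
# new 3-D and 4-D cases check for inner and last axes
x = np.random.randn(6, 6, 6, 6)
assert ref_logsumexp(x, axis=2).shape == (6, 6, 6)
assert ref_logsumexp(x, axis=3, keepdims=True).shape == (6, 6, 6, 1)
```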
@@ -1541,6 +1544,9 @@ class TestOps(unittest.TestCase):
   def test_logcumsumexp(self):
     helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=1), lambda x: x.logcumsumexp(1), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(6,6,6)], lambda x: torch.logcumsumexp(x, dim=2), lambda x: x.logcumsumexp(2), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(6,6,6,6)], lambda x: torch.logcumsumexp(x, dim=2), lambda x: x.logcumsumexp(2), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(6,6,6,6)], lambda x: torch.logcumsumexp(x, dim=3), lambda x: x.logcumsumexp(3), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(45)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
     helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
     helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(), atol=1e-7, grad_atol=1e-7)
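logcumsumexp is the cumulative counterpart: it scans the axis instead of reducing it, so the output keeps the input shape. A naive NumPy reference (hypothetical, and overflow-prone for large inputs; a production version subtracts a running max):

```python
import numpy as np

def ref_logcumsumexp(x: np.ndarray, axis: int) -> np.ndarray:
  # log of the running sum of exp along `axis`; fine for small values,
  # but exp() can overflow without a running-max stabilization
  return np.log(np.cumsum(np.exp(x), axis=axis))

x = np.random.randn(6, 6, 6, 6)
for axis in range(4):
  # unlike logsumexp, the output shape matches the input shape
  assert ref_logcumsumexp(x, axis).shape == x.shape
```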
@@ -2098,7 +2098,6 @@ class Tensor(MathTrait):
     ```
     """
     if self.ndim == 0: return self
-    axis = self._resolve_dim(axis)
     x = self.transpose(axis, -1)
     last_dim_size = x.shape[-1]
     x_reshaped = x.reshape(-1, last_dim_size)
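The snippet above reduces the N-dimensional case to a 2-D one: transpose the target axis to the end, then flatten everything else into a batch dimension. That shape dance is exactly the kind of squeeze/unsqueeze juggling the commit message warns about, and what the new 3-D and 4-D tests guard. A sketch of the same pattern in NumPy (`apply_along_last` is a hypothetical helper, not tinygrad API):

```python
import numpy as np

def apply_along_last(x: np.ndarray, axis: int, f) -> np.ndarray:
  # Same pattern as the diff above: move `axis` to the end, flatten the
  # remaining axes into one batch dimension, run a kernel that only
  # understands (batch, n), then invert the reshape and the transpose.
  x_t = np.swapaxes(x, axis, -1)
  out = f(x_t.reshape(-1, x_t.shape[-1]))
  return np.swapaxes(out.reshape(x_t.shape), axis, -1)

x = np.random.randn(6, 6, 6, 6)
got = apply_along_last(x, 2, lambda m: np.log(np.cumsum(np.exp(m), axis=-1)))
assert np.allclose(got, np.log(np.cumsum(np.exp(x), axis=2)))
```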