Tensor.logsumexp (#4442)
the subtract-max part should be shared with safe softmax more cleanly
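The subtract-max trick the message refers to keeps exp() from overflowing: log Σᵢ exp(xᵢ) = m + log Σᵢ exp(xᵢ − m) for any m, and choosing m = max(x) makes every exponent non-positive. A minimal NumPy sketch of that trick (an illustration only, not tinygrad's actual implementation):

```python
import numpy as np

def logsumexp(x, axis=None, keepdims=False):
  # log(sum(exp(x))) == m + log(sum(exp(x - m))) for any m; picking
  # m = max(x) makes every exponent <= 0, so exp() cannot overflow.
  m = np.max(x, axis=axis, keepdims=True)
  out = np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True)) + m
  return out if keepdims else np.squeeze(out, axis=axis)
```

The commit's tests (below) exercise logsumexp across 2-D, 1-D, and scalar shapes, with and without keepdim, checking against torch.logsumexp.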
```diff
@@ -810,6 +810,14 @@ class TestOps(unittest.TestCase):
     helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7)
 
+  def test_logsumexp(self):
+    helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=0, keepdim=True), lambda x: x.logsumexp(0, True), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(45,65)], lambda x: torch.logsumexp(x, dim=1), lambda x: x.logsumexp(1), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(45)], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([()], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([()], lambda x: torch.logsumexp(x, dim=-1), lambda x: x.logsumexp(-1), atol=1e-7, grad_atol=1e-7)
+
   def test_sinh(self):
     helper_test_op([(45,65)], lambda x: x.sinh(), grad_atol=1e-6)
     # TODO: backward nan instead of inf
```
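The new tests sit next to the log_softmax ones for a reason: log_softmax(x) = x − logsumexp(x), so a safe softmax and logsumexp can share the same subtract-max step, which is what the commit message suggests cleaning up. A quick demonstration against torch, with inputs chosen so the naive formula overflows:

```python
import torch

x = torch.tensor([1000.0, 1000.1])
print(torch.log(torch.exp(x).sum()))  # inf: exp(1000) overflows, naive formula fails
print(torch.logsumexp(x, dim=0))      # ~1000.7444: the subtract-max version is stable
print(torch.allclose(torch.log_softmax(x, dim=0),
                     x - torch.logsumexp(x, dim=0)))  # True: the shared identity
```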