mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
support threshold in Tensor.softplus (#11564)
fix gradient for large input
This commit is contained in:
@@ -957,8 +957,9 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6)
|
||||
helper_test_op([(45,65)], lambda t: torch.nn.functional.softplus(t, beta=3), lambda t: Tensor.softplus(t, beta=3), grad_atol=1e-6)
|
||||
helper_test_op([(45,65)], lambda t: torch.nn.functional.softplus(t, beta=1/3), lambda t: Tensor.softplus(t, beta=1/3), grad_atol=1e-6)
|
||||
# # TODO: support threshold and enable this
|
||||
# helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6, low=300, high=400)
|
||||
helper_test_op([(45,65)], lambda t: torch.nn.functional.softplus(t, beta=3, threshold=0.5),
|
||||
lambda t: Tensor.softplus(t, beta=3, threshold=0.5), grad_atol=1e-6)
|
||||
helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6, low=300, high=400)
|
||||
helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6, low=-400, high=-300)
|
||||
helper_test_op([()], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user