# Patch summary (leading text before the first "diff --git" header).
#
# test/test_ops.py
#   - Hunk -267,6 +267,7: adds one helper_test_op(..., a=100) call to test_sigmoid,
#     placed before the existing a=-100 call.
#   - Hunk -274,9 +275,11: adds one helper_test_op(..., a=100) call to each of
#     test_gelu and test_quick_gelu, again before the existing a=-100 calls.
#   NOTE(review): presumably `a` controls the scale/sign of the generated test
#   inputs, so a=100 would exercise large positive values -- confirm against
#   helper_test_op, which is not shown in this patch.
#
# tinygrad/mlops.py
#   - Hunk -55,6 +55,9: adds three comment lines directly above
#     `class Sigmoid(Function)`: a NOTE that the implicit derivative of sigmoid
#     is not stable, a reference URL, and a TODO about having the backend find
#     this automatically. The hunk header accounts for exactly these three
#     added lines; no code lines are added or removed by this hunk.
diff --git a/test/test_ops.py b/test/test_ops.py index 6a975e8a79..952bf0930c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -267,6 +267,7 @@ class TestOps(unittest.TestCase): helper_test_op([()], lambda x: torch.nn.functional.softsign(x), Tensor.softsign) def test_sigmoid(self): helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid) + helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, a=100) helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, a=-100) helper_test_op([()], lambda x: x.sigmoid(), Tensor.sigmoid, forward_only=True) def test_softplus(self): @@ -274,9 +275,11 @@ class TestOps(unittest.TestCase): helper_test_op([()], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6) def test_gelu(self): helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu) + helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, a=100) helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, a=-100) def test_quick_gelu(self): helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu) + helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, a=100) helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, a=-100) helper_test_op([()], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu) def test_elu(self): diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 87e4c4ab01..1bf04ff4f9 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -55,6 +55,9 @@ class Exp(Function): def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.binary_op(BinaryOps.MUL, grad_output) +# NOTE: the implicit derivative of sigmoid is not stable +# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e +# TODO: have the backend automatically find this class
Sigmoid(Function): __slots__ = "ret" def forward(self, x:LazyBuffer) -> LazyBuffer: