mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
Stable Sigmoid op (#59)
* 🔨 Added stable sigmoid * ✅ added sigmoid test * 🔧 suppressed overflow warning * 🔧 clean up
This commit is contained in:
@@ -53,6 +53,8 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, gpu=self.gpu)
|
||||
def test_relu(self):
|
||||
helper_test_op([(45,65)], lambda x: x.relu(), Tensor.relu, gpu=self.gpu)
|
||||
def test_sigmoid(self):
|
||||
helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, gpu=self.gpu)
|
||||
def test_dot(self):
|
||||
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5, gpu=self.gpu)
|
||||
|
||||
|
||||
@@ -124,10 +124,11 @@ register('relu', ReLU)
|
||||
class Sigmoid(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
# TODO: stable sigmoid? does the overflow matter?
|
||||
with np.warnings.catch_warnings():
|
||||
np.warnings.filterwarnings('ignore')
|
||||
ret = 1/(1 + np.exp(-input))
|
||||
ret = np.where(
|
||||
input >= 0,1/(1 + np.exp(-input)),np.exp(input)/(1 + np.exp(input))
|
||||
)
|
||||
ctx.save_for_backward(ret)
|
||||
return ret
|
||||
|
||||
|
||||
Reference in New Issue
Block a user