Mirror of https://github.com/tinygrad/tinygrad.git
Broken Sigmoid backward: Add test and mlop for Sigmoid (#1113)
* Add failing sigmoid test
* update more tests
* add mlop for sigmoid
* add back test
* math.log(math.e) = 1
* remove divides

Co-authored-by: Kunwar Raj Singh <kunwar31@pop-os.localdomain>
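For context, a minimal numpy sketch (not part of the commit) of why the old chained formulation `(1.0 + (-self).exp()).reciprocal()` produces NaN gradients at large negative inputs, while the fused rule `sigmoid(x) * (1 - sigmoid(x))` stays finite:

```python
import numpy as np

x = np.float32(-100.0)
e = np.exp(-x)                           # exp(100) overflows float32 -> inf (numpy warns)
s = np.float32(1) / (np.float32(1) + e)  # forward still "works": 1/inf -> 0.0

# Backward through the chained ops: d/dx 1/(1+e^-x) = (-1/(1+e^-x)^2) * (-e^-x),
# which evaluates as 0 * inf -> nan once e^-x has overflowed.
grad_chained = (np.float32(-1) / (np.float32(1) + e) ** 2) * (-e)
print(grad_chained)                      # nan

# The fused backward added in this commit uses only the saved output:
grad_fused = s * (np.float32(1) - s)
print(grad_fused)                        # 0.0 (the true gradient underflows to ~0 anyway)
```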
@@ -267,15 +267,17 @@ class TestOps(unittest.TestCase):
     helper_test_op([()], lambda x: torch.nn.functional.softsign(x), Tensor.softsign)
   def test_sigmoid(self):
     helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid)
+    helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, a=-100)
     helper_test_op([()], lambda x: x.sigmoid(), Tensor.sigmoid, forward_only=True)
   def test_softplus(self):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
     helper_test_op([()], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
   @unittest.skip("not supported in older pytorch")
   def test_gelu(self):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu)
+    helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, a=-100)
   def test_quick_gelu(self):
     helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
+    helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, a=-100)
     helper_test_op([()], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
   def test_elu(self):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.elu(x), Tensor.elu)
@@ -396,6 +398,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7)
   def test_tanh(self):
     helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
+    helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6, a=-100)
     helper_test_op([()], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
   def test_hardtanh(self):
     for val in range(10, 30, 5):
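The `a=-100` cases are assumed to shift helper_test_op's random inputs far into the negative range; the torch reference they compare against stays finite there (an illustrative check, not part of the commit):

```python
import torch

x = torch.full((4,), -100.0, requires_grad=True)
x.sigmoid().sum().backward()
print(torch.isfinite(x.grad).all())  # True: the gradient is ~0 but finite, no nan/inf
```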
@@ -41,7 +41,7 @@ class Log(Function):
   __slots__ = "x"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, x.const_like(math.log(2)/math.log(math.e)))
+    return x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, x.const_like(math.log(2)))
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return grad_output.binary_op(BinaryOps.DIV, self.x)
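The cleanup above leans on math.log(math.e) == 1.0, so the divide was a no-op; for illustration, the identities the Log mlop relies on:

```python
import math

assert math.log(math.e) == 1.0                                # ln(e) = 1 -> the old divide was a no-op
x = 7.3
assert abs(math.log(x) - math.log2(x) * math.log(2)) < 1e-12  # ln(x) = log2(x) * ln(2)
```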
@@ -49,12 +49,21 @@ class Log(Function):
 class Exp(Function):
   __slots__ = "ret"
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.binary_op(BinaryOps.MUL, x.const_like(math.log(math.e)/math.log(2))).unary_op(UnaryOps.EXP2)
+    self.ret = x.binary_op(BinaryOps.MUL, x.const_like(1/math.log(2))).unary_op(UnaryOps.EXP2)
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return self.ret.binary_op(BinaryOps.MUL, grad_output)
 
+class Sigmoid(Function):
+  __slots__ = "ret"
+  def forward(self, x:LazyBuffer) -> LazyBuffer:
+    self.ret = x.const_like(1).binary_op(BinaryOps.DIV, x.const_like(1).binary_op(BinaryOps.ADD, x.binary_op(BinaryOps.MUL, x.const_like(-1/math.log(2))).unary_op(UnaryOps.EXP2)))
+    return self.ret
+
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    return self.ret.binary_op(BinaryOps.MUL, self.ret.const_like(1).binary_op(BinaryOps.SUB, self.ret)).binary_op(BinaryOps.MUL, grad_output)
+
 # ************* reduce ops *************
 
 class Sum(Function):
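For reference, the new Sigmoid forward replaces e^(-x) with exp2(-x/ln 2), and the backward uses sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)); a small pure-Python sanity check of both (illustration only, not part of the commit):

```python
import math

def sigmoid_fused(x):
  # mirrors the mlop above: 1 / (1 + exp2(x * (-1/ln 2)))
  return 1.0 / (1.0 + 2.0 ** (x * (-1.0 / math.log(2))))

for x in (-100.0, -1.0, 0.0, 1.0, 100.0):
  assert abs(sigmoid_fused(x) - 1.0 / (1.0 + math.exp(-x))) < 1e-12  # 2**(x/ln 2) == e**x
  s, h = sigmoid_fused(x), 1e-5
  numerical = (sigmoid_fused(x + h) - sigmoid_fused(x - h)) / (2 * h)
  assert abs(s * (1 - s) - numerical) < 1e-6                         # backward rule matches
```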
@@ -138,7 +147,7 @@ class Pow(Function):
 
   def backward(self, grad_output:LazyBuffer):
     return grad_output.binary_op(BinaryOps.MUL, self.y.binary_op(BinaryOps.MUL, self.ret.binary_op(BinaryOps.DIV, self.x))) if self.needs_input_grad[0] else None, \
-           grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, self.x.const_like(math.log(2)/math.log(math.e))).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None
+           grad_output.binary_op(BinaryOps.MUL, self.x.unary_op(UnaryOps.LOG2).binary_op(BinaryOps.MUL, self.x.const_like(math.log(2))).binary_op(BinaryOps.MUL, self.ret)) if self.needs_input_grad[1] else None
 
 class Div(Function):
   __slots__ = 'x', 'y'
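The same constant shows up in Pow's backward, where the gradient w.r.t. the exponent is x**y * ln(x) and ln(x) is built from LOG2 as log2(x) * ln(2); a quick numerical check (illustration only):

```python
import math

x, y, h = 1.7, 2.3, 1e-6
analytic = (x ** y) * math.log(x)                    # d/dy x**y = x**y * ln(x)
numerical = (x ** (y + h) - x ** (y - h)) / (2 * h)
assert abs(analytic - numerical) < 1e-6
assert abs(math.log(x) - math.log2(x) * math.log(2)) < 1e-12
```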
@@ -485,6 +485,7 @@ class Tensor:
   def log2(self): return mlops.Log.apply(self)/log(2)
   def exp(self): return mlops.Exp.apply(self)
   def relu(self): return mlops.Relu.apply(self)
+  def sigmoid(self): return mlops.Sigmoid.apply(self)
   def sin(self): return mlops.Sin.apply(self)
   def cos(self): return ((pi/2)-self).sin()
   def tan(self): return self.sin() / self.cos()
@@ -512,8 +513,6 @@ class Tensor:
   def reciprocal(self): return 1.0/self
 
   # ***** activation functions (unary) *****
-
-  def sigmoid(self): return (1.0 + (-self).exp()).reciprocal()
   def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
   def celu(self, alpha=1.0): return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
   def swish(self): return self * self.sigmoid()
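With Tensor.sigmoid now dispatching to the mlop, a usage sketch (tinygrad API as of this commit; shown for illustration):

```python
from tinygrad.tensor import Tensor

t = Tensor([-100.0, 0.0, 100.0], requires_grad=True)
t.sigmoid().sum().backward()
print(t.grad.numpy())  # finite values (~0, 0.25, ~0) instead of nan from the old exp/reciprocal chain
```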