activation function docs (#4705)

This commit is contained in:
wozeparrot
2024-05-24 00:12:16 +00:00
committed by GitHub
parent 27abbd5b2b
commit 2c56aa7fe0

View File

@@ -1485,8 +1485,38 @@ class Tensor:
def log2(self):
  """
  Computes the base-2 logarithm element-wise.

  Implemented via the change-of-base rule: `log2(x) = ln(x) / ln(2)`.
  """
  natural_log = self.log()
  return natural_log / math.log(2)
def exp(self):
  """
  Computes the exponential function element-wise.

  The input is first cast to `least_upper_float(self.dtype)` before `F.Exp` is applied.
  """
  float_dtype = least_upper_float(self.dtype)
  return F.Exp.apply(self.cast(float_dtype))
def exp2(self):
  """
  Computes `2 ** x` element-wise.

  Uses the identity `2**x = e**(x * ln(2))`, so the input is scaled by `ln(2)` and fed to `F.Exp`.
  """
  scaled = self*math.log(2)
  return F.Exp.apply(scaled)
def relu(self): return F.Relu.apply(self)
def sigmoid(self): return F.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
def relu(self):
  """
  Element-wise Rectified Linear Unit (ReLU) activation.

  - Described: https://paperswithcode.com/method/relu

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.relu().numpy())
  ```
  """
  rectified = F.Relu.apply(self)
  return rectified
def sigmoid(self):
  """
  Element-wise Sigmoid activation.

  The input is cast to `least_upper_float(self.dtype)` before `F.Sigmoid` is applied.

  - Described: https://en.wikipedia.org/wiki/Sigmoid_function

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.sigmoid().numpy())
  ```
  """
  float_dtype = least_upper_float(self.dtype)
  return F.Sigmoid.apply(self.cast(float_dtype))
def sin(self):
  """
  Computes the sine element-wise.

  The input is cast to `least_upper_float(self.dtype)` before `F.Sin` is applied.
  """
  float_dtype = least_upper_float(self.dtype)
  return F.Sin.apply(self.cast(float_dtype))
def sqrt(self):
  """
  Computes the square root element-wise.

  The input is cast to `least_upper_float(self.dtype)` before `F.Sqrt` is applied.
  """
  float_dtype = least_upper_float(self.dtype)
  return F.Sqrt.apply(self.cast(float_dtype))
def rsqrt(self):
  """
  Computes the reciprocal square root element-wise, as `sqrt(1 / x)`.
  """
  inverse = self.reciprocal()
  return inverse.sqrt()
@@ -1509,25 +1539,320 @@ class Tensor:
# ***** activation functions (unary) *****
def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
def celu(self, alpha=1.0): return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
def swish(self): return self * self.sigmoid()
def silu(self): return self.swish() # The SiLU function is also known as the swish function.
def relu6(self): return self.relu() - (self-6).relu()
def hardswish(self): return self * (self+3).relu6() * (1/6)
def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0
def sinh(self): return (self.exp() - self.neg().exp()) / 2
def cosh(self): return (self.exp() + self.neg().exp()) / 2
def atanh(self): return ((1 + self)/(1 - self)).log() / 2
def asinh(self): return (self + (self.square() + 1).sqrt()).log()
def acosh(self): return (self + (self.square() - 1).sqrt()).log()
def hardtanh(self, min_val=-1, max_val=1): return self.clip(min_val, max_val)
def gelu(self): return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
def quick_gelu(self): return self * (self * 1.702).sigmoid()
def leakyrelu(self, neg_slope=0.01): return self.relu() - (-neg_slope*self).relu()
def mish(self): return self * self.softplus().tanh()
def softplus(self, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
def softsign(self): return self / (1 + self.abs())
def elu(self, alpha=1.0):
  """
  Element-wise Exponential Linear Unit (ELU) activation.

  Positive inputs pass through unchanged; negative inputs map to `alpha * (exp(x) - 1)`.

  - Described: https://paperswithcode.com/method/elu
  - Paper: https://arxiv.org/abs/1511.07289v5

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.elu().numpy())
  ```
  """
  positive_part = self.relu()
  # relu(1 - e^x) is non-zero only where x < 0, and there it equals -(e^x - 1)
  negative_part = (1-self.exp()).relu()
  return positive_part - alpha*negative_part
def celu(self, alpha=1.0):
  """
  Element-wise Continuously differentiable Exponential Linear Unit (CELU) activation.

  Computed as `max(x, 0) + min(alpha * (exp(x / alpha) - 1), 0)`.

  - Described: https://paperswithcode.com/method/celu
  - Paper: https://arxiv.org/abs/1704.07483

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.celu().numpy())
  ```
  """
  positive_part = self.maximum(0)
  negative_part = (alpha * ((self / alpha).exp() - 1)).minimum(0)
  return positive_part + negative_part
def swish(self):
  """
  Element-wise Swish activation, `x * sigmoid(x)`. See `.silu()`

  - Paper: https://arxiv.org/abs/1710.05941v1

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.swish().numpy())
  ```
  """
  gate = self.sigmoid()
  return self * gate
def silu(self):
  """
  Element-wise Sigmoid Linear Unit (SiLU) activation.

  SiLU is another name for the swish function, so this simply forwards to `.swish()`.

  - Described: https://paperswithcode.com/method/silu
  - Paper: https://arxiv.org/abs/1606.08415

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.silu().numpy())
  ```
  """
  activated = self.swish()
  return activated
def relu6(self):
  """
  Element-wise ReLU6 activation, i.e. ReLU with the output capped at 6.

  - Described: https://paperswithcode.com/method/relu6
  - Paper: https://arxiv.org/abs/1704.04861v1

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.relu6().numpy())
  ```
  """
  rectified = self.relu()
  # whatever exceeds 6 is subtracted back off, which caps the result at 6
  overshoot = (self-6).relu()
  return rectified - overshoot
def hardswish(self):
  """
  Element-wise Hardswish activation, `x * relu6(x + 3) / 6`.

  - Described: https://paperswithcode.com/method/hard-swish
  - Paper: https://arxiv.org/abs/1905.02244v5

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.hardswish().numpy())
  ```
  """
  gate = (self+3).relu6()
  return self * gate * (1/6)
def tanh(self):
  """
  Element-wise Hyperbolic Tangent (tanh).

  Computed through the sigmoid identity `tanh(x) = 2 * sigmoid(2x) - 1`.

  - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Tanh

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.tanh().numpy())
  ```
  """
  doubled = 2.0 * self
  return 2.0 * doubled.sigmoid() - 1.0
def sinh(self):
  """
  Element-wise Hyperbolic Sine (sinh), computed as `(exp(x) - exp(-x)) / 2`.

  - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Sinh

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.sinh().numpy())
  ```
  """
  growing = self.exp()
  decaying = self.neg().exp()
  return (growing - decaying) / 2
def cosh(self):
  """
  Element-wise Hyperbolic Cosine (cosh), computed as `(exp(x) + exp(-x)) / 2`.

  - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Cosh

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.cosh().numpy())
  ```
  """
  growing = self.exp()
  decaying = self.neg().exp()
  return (growing + decaying) / 2
def atanh(self):
  """
  Element-wise Inverse Hyperbolic Tangent (atanh), computed as `log((1 + x) / (1 - x)) / 2`.

  - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#atanh

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = (Tensor.rand(2, 3) + 1) / 2
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.atanh().numpy())
  ```
  """
  ratio = (1 + self)/(1 - self)
  return ratio.log() / 2
def asinh(self):
  """
  Element-wise Inverse Hyperbolic Sine (asinh), computed as `log(x + sqrt(x^2 + 1))`.

  - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#asinh

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = (Tensor.rand(2, 3) + 1) / 2
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.asinh().numpy())
  ```
  """
  root = (self.square() + 1).sqrt()
  return (self + root).log()
def acosh(self):
  """
  Element-wise Inverse Hyperbolic Cosine (acosh), computed as `log(x + sqrt(x^2 - 1))`.

  - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#acosh

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.rand(2, 3) + 1
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.acosh().numpy())
  ```
  """
  root = (self.square() - 1).sqrt()
  return (self + root).log()
def hardtanh(self, min_val=-1, max_val=1):
  """
  Element-wise Hardtanh activation: clips every value into `[min_val, max_val]`.

  - Described: https://paperswithcode.com/method/hardtanh-activation

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.hardtanh().numpy())
  ```
  """
  lower, upper = min_val, max_val
  return self.clip(lower, upper)
def gelu(self):
  """
  Element-wise Gaussian Error Linear Unit (GELU) activation, using the tanh approximation
  `0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2)))`.

  - Described: https://paperswithcode.com/method/gelu
  - Paper: https://arxiv.org/abs/1606.08415v5

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.gelu().numpy())
  ```
  """
  cubic_term = 1 + 0.044715 * self * self
  # 0.7978845608 is sqrt(2/pi)
  tanh_arg = self * 0.7978845608 * cubic_term
  return 0.5 * self * (1 + tanh_arg.tanh())
def quick_gelu(self):
  """
  Element-wise sigmoid approximation of GELU, `x * sigmoid(1.702 * x)`.

  - Described: https://paperswithcode.com/method/gelu

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.quick_gelu().numpy())
  ```
  """
  gate = (self * 1.702).sigmoid()
  return self * gate
def leakyrelu(self, neg_slope=0.01):
  """
  Element-wise Leaky ReLU activation.

  Computed as `relu(x) - relu(-neg_slope * x)`.

  - Described: https://paperswithcode.com/method/leaky-relu

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.leakyrelu().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.leakyrelu(neg_slope=1).numpy())
  ```
  """
  positive_part = self.relu()
  leaked_part = (-neg_slope*self).relu()
  return positive_part - leaked_part
def mish(self):
  """
  Element-wise Mish activation, `x * tanh(softplus(x))`.

  - Described: https://paperswithcode.com/method/mish
  - Paper: https://arxiv.org/abs/1908.08681v3

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.mish().numpy())
  ```
  """
  squashed = self.softplus().tanh()
  return self * squashed
def softplus(self, beta=1):
  """
  Element-wise Softplus activation, `(1 / beta) * log(1 + exp(beta * x))`.

  Evaluated in an overflow-safe form so that large `beta * x` yields approximately `x`
  instead of `inf`.

  - Described: https://paperswithcode.com/method/softplus

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.softplus().numpy())
  ```
  """
  scaled = self*beta
  # log(1 + e^z) == max(z, 0) + log(1 + e^-|z|); the exponent is never positive, so exp() cannot overflow
  # NOTE(review): the sub-gradient at exactly z == 0 now depends on how relu/abs define theirs -- confirm if that matters
  stable_tail = (1 + scaled.abs().neg().exp()).log()
  return (1/beta) * (scaled.relu() + stable_tail)
def softsign(self):
  """
  Element-wise Softsign activation, `x / (1 + |x|)`.

  - Described: https://paperswithcode.com/method/softsign

  ```python exec="true" source="above" session="tensor" result="python"
  Tensor.manual_seed(42)
  t = Tensor.randn(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.softsign().numpy())
  ```
  """
  denominator = 1 + self.abs()
  return self / denominator
# ***** broadcasted elementwise mlops *****
def _broadcast_to(self, shape:Tuple[sint, ...]):