mirror of https://github.com/tinygrad/tinygrad.git
@@ -37,6 +37,7 @@ Elementwise ops operate on a per element basis. They don't change the shape of t
 ::: tinygrad.Tensor.hardsigmoid
 ::: tinygrad.Tensor.elu
 ::: tinygrad.Tensor.celu
+::: tinygrad.Tensor.selu
 ::: tinygrad.Tensor.swish
 ::: tinygrad.Tensor.silu
 ::: tinygrad.Tensor.relu6
@@ -8,7 +8,7 @@ import numpy as np
 
 tensor_methods = {"Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan","Relu",
                   "Sigmoid", "MatMul", "Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign",
-                  "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Xor", "Round", "Erf"}
+                  "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf"}
 
 # **************** Free Ops ****************
 
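Note: `tensor_methods` evidently lists ONNX ops that map one-to-one onto identically named `Tensor` methods, which is why adding `"Selu"` here lets the dedicated `Selu` handler below be dropped. A minimal sketch of that kind of by-name dispatch; the reduced set and the `run_onnx_op` helper are illustrative assumptions, not tinygrad's actual ONNX runner:

```python
# Illustrative sketch of by-name dispatch: if an ONNX op appears in a
# tensor_methods-style set, call the lowercased Tensor method directly.
# `run_onnx_op` and this reduced set are assumptions for illustration.
from tinygrad import Tensor

tensor_methods = {"Relu", "Elu", "Celu", "Selu"}

def run_onnx_op(op: str, *inputs: Tensor, **attrs):
  if op in tensor_methods:
    return getattr(Tensor, op.lower())(*inputs, **attrs)  # e.g. "Selu" -> Tensor.selu
  raise NotImplementedError(f"no handler for {op}")

print(run_onnx_op("Selu", Tensor([-2.0, 0.0, 2.0])).numpy())
```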
@@ -44,7 +44,6 @@ def Constant(value:Optional[Tensor]=None, value_float=None, value_floats=None, v
 
 def HardSigmoid(x: Tensor, alpha=0.2, beta=0.5): return (alpha*x + beta).clip(0, 1)
 def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
-def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
 def PRelu(X:Tensor, slope:Tensor):
   slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
   return (X > 0).where(X, X * slope)
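Note: the deleted branch-free `Selu` helper and the new `Tensor.selu` added in `tensor.py` compute the same function, just written with `relu` terms versus a `where`. A quick numerical cross-check, assuming a tinygrad build that already includes `Tensor.selu`; the input values and tolerances are illustrative:

```python
# Sketch: confirm the old branch-free ONNX formulation matches Tensor.selu.
import numpy as np
from tinygrad import Tensor

ALPHA, GAMMA = 1.67326319217681884765625, 1.05070102214813232421875

def selu_onnx_style(x: Tensor):
  # gamma * (relu(x) - relu(alpha - alpha*exp(x))), as in the removed helper
  return GAMMA * (x.relu() - (-ALPHA * x.exp() + ALPHA).relu())

x = Tensor([-4.0, -2.0, -0.5, 0.0, 0.5, 2.0, 4.0])
np.testing.assert_allclose(selu_onnx_style(x).numpy(),
                           x.selu(alpha=ALPHA, gamma=GAMMA).numpy(),
                           rtol=1e-5, atol=1e-6)
print(x.selu().numpy())  # defaults (1.67326, 1.0507) differ only in printed precision
```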
@@ -697,6 +697,9 @@ class TestOps(unittest.TestCase):
     for val in range(1, 5):
       helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
       helper_test_op([()], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
+  def test_selu(self):
+    helper_test_op([(45,65)], torch.nn.functional.selu, Tensor.selu)
+    helper_test_op([()], torch.nn.functional.selu, Tensor.selu)
 
   def test_abs(self):
     helper_test_op([(45,65)], torch.abs, Tensor.abs)
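Note: `helper_test_op` presumably pushes the same data through the torch and tinygrad callables and compares the results; a stripped-down, forward-only analogue of what the new `test_selu` exercises could look like the sketch below (the `check_forward` helper, seed, and tolerances are illustrative, not the repo's test harness):

```python
# Forward-only analogue of the selu test: same data through torch and tinygrad.
import numpy as np
import torch
from tinygrad import Tensor

def check_forward(shape, torch_fn, tiny_fn, atol=1e-6, rtol=1e-4):
  data = np.asarray(np.random.default_rng(0).standard_normal(size=shape), dtype=np.float32)
  expected = torch_fn(torch.tensor(data)).numpy()
  got = tiny_fn(Tensor(data)).numpy()
  np.testing.assert_allclose(got, expected, atol=atol, rtol=rtol)

check_forward((45, 65), torch.nn.functional.selu, Tensor.selu)
check_forward((), torch.nn.functional.selu, Tensor.selu)
print("Tensor.selu matches torch.nn.functional.selu")
```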
@@ -2688,6 +2688,19 @@ class Tensor(SimpleMathTrait):
     """
     return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
 
+  def selu(self, alpha=1.67326, gamma=1.0507):
+    """
+    Applies the Scaled Exponential Linear Unit (SELU) function element-wise.
+
+    - Described: https://paperswithcode.com/method/selu
+    - Paper: https://arxiv.org/abs/1706.02515v5
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).selu().numpy())
+    ```
+    """
+    return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1))
+
   def swish(self):
     """
     See `.silu()`
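For reference, the element-wise definition the new method encodes, written out in plain NumPy (a sketch mirroring the docstring example, not tinygrad code):

```python
import numpy as np

def selu_ref(x, alpha=1.67326, gamma=1.0507):
  # SELU(x) = gamma * x                     for x >= 0
  #           gamma * alpha * (exp(x) - 1)  for x <  0
  x = np.asarray(x, dtype=np.float32)
  return gamma * np.where(x >= 0, x, alpha * (np.exp(x) - 1))

print(selu_ref([-3., -2., -1., 0., 1., 2., 3.]))
```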