simpler leaky_relu (#9271)

rendered as `*(data0+alu0) = ((val0<0.0f)?(0.01f*val0):val0);` instead of two wheres.

possible to update rewrite rules too
This commit is contained in:
chenyu
2025-02-26 13:43:48 -05:00
committed by GitHub
parent 86b737a120
commit 6350725e2d

View File

@@ -3091,7 +3091,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu(neg_slope=0.42).numpy())
```
"""
return self.relu() - (-neg_slope*self).relu()
return (self<0).where(neg_slope*self, self)
def mish(self):
"""