mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 23:38:58 -05:00
simpler leaky_relu (#9271)
rendered as `*(data0+alu0) = ((val0<0.0f)?(0.01f*val0):val0);` instead of two wheres. possible to update rewrite rules too
This commit is contained in:
@@ -3091,7 +3091,7 @@ class Tensor(SimpleMathTrait):
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu(neg_slope=0.42).numpy())
|
||||
```
|
||||
"""
|
||||
return self.relu() - (-neg_slope*self).relu()
|
||||
return (self<0).where(neg_slope*self, self)
|
||||
|
||||
def mish(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user