From 6350725e2dde26a301adee5d00eab80b6d42204e Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 26 Feb 2025 13:43:48 -0500 Subject: [PATCH] simpler leaky_relu (#9271) rendered as `*(data0+alu0) = ((val0<0.0f)?(0.01f*val0):val0);` instead of two wheres. possible to update rewrite rules too --- tinygrad/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 8466cb3d06..fad1135cf3 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3091,7 +3091,7 @@ class Tensor(SimpleMathTrait): print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu(neg_slope=0.42).numpy()) ``` """ - return self.relu() - (-neg_slope*self).relu() + return (self<0).where(neg_slope*self, self) def mish(self): """