leakyrelu to leaky_relu (#9270)

Author: Francis Lata
Date: 2025-02-26 13:22:08 -05:00
Committed by: GitHub
Parent: cd822bbe11
Commit: 86b737a120
11 changed files with 30 additions and 30 deletions

View File

@@ -110,7 +110,7 @@ class TinyNet:
   def __call__(self, x):
     x = self.l1(x)
-    x = x.leakyrelu()
+    x = x.leaky_relu()
     x = self.l2(x)
     return x
@@ -118,7 +118,7 @@ net = TinyNet()
 ```
 We can see that the forward pass of our neural network is just the sequence of operations performed on the input tensor `x`.
-We can also see that functional operations like `leakyrelu` are not defined as classes and instead are just methods we can just call.
+We can also see that functional operations like `leaky_relu` are not defined as classes and instead are just methods we can just call.
 Finally, we just initialize an instance of our neural network, and we are ready to start training it.
 ## Training
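
The documentation change above notes that functional operations like `leaky_relu` are plain `Tensor` methods rather than layer classes. A minimal usage sketch of the renamed method (illustrative only, assuming a tinygrad version that includes this rename):

```python
from tinygrad import Tensor

t = Tensor([-2.0, -0.5, 0.0, 1.5])
print(t.leaky_relu().numpy())     # default negative slope of 0.01
print(t.leaky_relu(0.2).numpy())  # explicit negative slope, no nn-style activation class needed
```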

View File

@@ -52,7 +52,7 @@ Elementwise ops operate on a per element basis. They don't change the shape of t
 ::: tinygrad.Tensor.erf
 ::: tinygrad.Tensor.gelu
 ::: tinygrad.Tensor.quick_gelu
-::: tinygrad.Tensor.leakyrelu
+::: tinygrad.Tensor.leaky_relu
 ::: tinygrad.Tensor.mish
 ::: tinygrad.Tensor.softplus
 ::: tinygrad.Tensor.softsign

View File

@@ -16,9 +16,9 @@ class LinearGen:
     self.l4 = Tensor.scaled_uniform(1024, 784)
   def forward(self, x):
-    x = x.dot(self.l1).leakyrelu(0.2)
-    x = x.dot(self.l2).leakyrelu(0.2)
-    x = x.dot(self.l3).leakyrelu(0.2)
+    x = x.dot(self.l1).leaky_relu(0.2)
+    x = x.dot(self.l2).leaky_relu(0.2)
+    x = x.dot(self.l3).leaky_relu(0.2)
     x = x.dot(self.l4).tanh()
     return x
@@ -31,9 +31,9 @@ class LinearDisc:
   def forward(self, x):
     # balance the discriminator inputs with const bias (.add(1))
-    x = x.dot(self.l1).add(1).leakyrelu(0.2).dropout(0.3)
-    x = x.dot(self.l2).leakyrelu(0.2).dropout(0.3)
-    x = x.dot(self.l3).leakyrelu(0.2).dropout(0.3)
+    x = x.dot(self.l1).add(1).leaky_relu(0.2).dropout(0.3)
+    x = x.dot(self.l2).leaky_relu(0.2).dropout(0.3)
+    x = x.dot(self.l3).leaky_relu(0.2).dropout(0.3)
     x = x.dot(self.l4).log_softmax()
     return x

View File

@@ -437,14 +437,14 @@ class Generator:
     x = self.conv_pre(x)
     if g is not None: x = x + self.cond(g)
     for i in range(self.num_upsamples):
-      x, xs = self.ups[i](x.leakyrelu(LRELU_SLOPE)), None
+      x, xs = self.ups[i](x.leaky_relu(LRELU_SLOPE)), None
       x_source = self.noise_convs[i](har_source)
      x = x + x_source
       for j in range(self.num_kernels):
         if xs is None: xs = self.resblocks[i * self.num_kernels + j].forward(x)
         else: xs += self.resblocks[i * self.num_kernels + j].forward(x)
       x = xs / self.num_kernels
-    return self.conv_post(x.leakyrelu()).tanh()
+    return self.conv_post(x.leaky_relu()).tanh()
 # **** helpers ****

View File

@@ -105,12 +105,12 @@ class Vgg7:
     Output format: (1, 3, Y - 14, X - 14)
     (the - 14 represents the 7-pixel context border that is lost)
     """
-    x = self.conv1.forward(x).leakyrelu(0.1)
-    x = self.conv2.forward(x).leakyrelu(0.1)
-    x = self.conv3.forward(x).leakyrelu(0.1)
-    x = self.conv4.forward(x).leakyrelu(0.1)
-    x = self.conv5.forward(x).leakyrelu(0.1)
-    x = self.conv6.forward(x).leakyrelu(0.1)
+    x = self.conv1.forward(x).leaky_relu(0.1)
+    x = self.conv2.forward(x).leaky_relu(0.1)
+    x = self.conv3.forward(x).leaky_relu(0.1)
+    x = self.conv4.forward(x).leaky_relu(0.1)
+    x = self.conv5.forward(x).leaky_relu(0.1)
+    x = self.conv6.forward(x).leaky_relu(0.1)
     x = self.conv7.forward(x)
     return x
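
The docstring above attributes the `- 14` to a lost 7-pixel border on each side. Assuming each of the seven convolutions uses a 3x3 kernel with no padding (the usual waifu2x-style VGG-7 layout), every layer trims one pixel per edge, so the arithmetic is 7 * 2 = 14 per spatial dimension. A small sketch with a hypothetical helper:

```python
def vgg7_output_hw(y: int, x: int, n_convs: int = 7, kernel: int = 3) -> tuple[int, int]:
  # each unpadded k x k convolution removes (k - 1) pixels per spatial dimension
  lost = n_convs * (kernel - 1)  # 7 * 2 = 14
  return y - lost, x - lost

print(vgg7_output_hw(256, 256))  # (242, 242), matching the (Y - 14, X - 14) in the docstring
```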

View File

@@ -193,10 +193,10 @@ class Generator:
     x = self.conv_pre(x)
     if g is not None: x = x + self.cond(g)
     for i in range(self.num_upsamples):
-      x = self.ups[i](x.leakyrelu(LRELU_SLOPE))
+      x = self.ups[i](x.leaky_relu(LRELU_SLOPE))
       xs = sum(self.resblocks[i * self.num_kernels + j].forward(x) for j in range(self.num_kernels))
       x = (xs / self.num_kernels).realize()
-    res = self.conv_post(x.leakyrelu()).tanh().realize()
+    res = self.conv_post(x.leaky_relu()).tanh().realize()
     return res
 class LayerNorm(nn.LayerNorm):
@@ -238,8 +238,8 @@ class ResBlock1:
     self.convs2 = [nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) for _ in range(3)]
   def forward(self, x: Tensor, x_mask=None):
     for c1, c2 in zip(self.convs1, self.convs2):
-      xt = x.leakyrelu(LRELU_SLOPE)
-      xt = c1(xt if x_mask is None else xt * x_mask).leakyrelu(LRELU_SLOPE)
+      xt = x.leaky_relu(LRELU_SLOPE)
+      xt = c1(xt if x_mask is None else xt * x_mask).leaky_relu(LRELU_SLOPE)
       x = c2(xt if x_mask is None else xt * x_mask) + x
     return x if x_mask is None else x * x_mask

View File

@@ -228,7 +228,7 @@ class Darknet:
         module.append(BatchNorm2d(filters, eps=1e-05, track_running_stats=True))
       # LeakyReLU activation
       if activation == "leaky":
-        module.append(lambda x: x.leakyrelu(0.1))
+        module.append(lambda x: x.leaky_relu(0.1))
     elif module_type == "maxpool":
       size, stride = int(x["size"]), int(x["stride"])
       module.append(lambda x: x.max_pool2d(kernel_size=(size, size), stride=stride))

View File

@@ -295,7 +295,7 @@ def get_onnx_ops():
   def PRelu(X:Tensor, slope:Tensor):
     slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope
     return (X > 0).where(X, X * slope)
-  def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha)
+  def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leaky_relu(alpha)
   def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0)
   def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis)
   def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()
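
The ops in this hunk are simple elementwise definitions: `LeakyRelu` is effectively `PRelu` with a scalar slope, and `ThresholdedRelu` zeroes everything at or below `alpha`. A NumPy sketch of the intended ONNX semantics (reference only, not the tinygrad implementation):

```python
import numpy as np

def leaky_relu_ref(x: np.ndarray, alpha: float = 0.01) -> np.ndarray:
  # LeakyRelu: x where x > 0, alpha * x elsewhere
  return np.where(x > 0, x, alpha * x)

def thresholded_relu_ref(x: np.ndarray, alpha: float = 1.0) -> np.ndarray:
  # ThresholdedRelu: x where x > alpha, 0 elsewhere
  return np.where(x > alpha, x, 0.0)

x = np.array([-2.0, -0.5, 0.5, 2.0])
print(leaky_relu_ref(x))        # roughly [-0.02 -0.005 0.5 2.]
print(thresholded_relu_ref(x))  # [0. 0. 0. 2.]
```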

View File

@@ -215,7 +215,7 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
   "aten.sub.out": lambda input,other,alpha=1: input-alpha*other, # NOTE: this is also needed to handle reverse
   "aten.mul.out": operator.mul,
   "aten.bmm.out": operator.matmul,
-  "aten.leaky_relu.out": Tensor.leakyrelu, # TODO: this should be renamed in tinygrad
+  "aten.leaky_relu.out": Tensor.leaky_relu,
   # NOTE: because these methods have a name with "Tensor" in them, they can't go in simple tensor methods
   "aten.remainder.Tensor_out": Tensor.mod,
   "aten.pow.Tensor_Tensor_out": Tensor.pow,

View File

@@ -806,9 +806,9 @@ class TestOps(unittest.TestCase):
     helper_test_op(None, lambda x: x.relu(), vals=[[-1.,0,1]])
   def test_relu_maximum_exact(self):
     helper_test_op(None, lambda x: torch.maximum(x, torch.zeros_like(x, requires_grad=False)), lambda x: Tensor.maximum(x, 0), vals=[[-1.,0,1]])
-  def test_leakyrelu(self):
-    helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
-    helper_test_op([()], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
+  def test_leaky_relu(self):
+    helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leaky_relu)
+    helper_test_op([()], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leaky_relu)
   def test_celu(self):
     for val in range(1, 5):
       helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
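
The renamed test still checks tinygrad against `torch.nn.functional.leaky_relu` through `helper_test_op`. A standalone version of the same check against a plain NumPy reference (a sketch, assuming a tinygrad install that includes this rename):

```python
import numpy as np
from tinygrad import Tensor

x = np.random.randn(45, 65).astype(np.float32)
out = Tensor(x).leaky_relu(0.01).numpy()
ref = np.where(x > 0, x, 0.01 * x)  # elementwise definition of leaky ReLU
assert np.allclose(out, ref, atol=1e-6), "leaky_relu should match the elementwise reference"
```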

View File

@@ -3078,17 +3078,17 @@ class Tensor(SimpleMathTrait):
     """
     return self * (self * 1.702).sigmoid()
-  def leakyrelu(self, neg_slope=0.01):
+  def leaky_relu(self, neg_slope=0.01):
     """
     Applies the Leaky ReLU function element-wise.
     - Described: https://paperswithcode.com/method/leaky-relu
     ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu().numpy())
+    print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu().numpy())
     ```
     ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu(neg_slope=0.42).numpy())
+    print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu(neg_slope=0.42).numpy())
     ```
     """
     return self.relu() - (-neg_slope*self).relu()
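
The implementation expresses leaky ReLU through two ReLUs: for `neg_slope >= 0`, when `x >= 0` we have `relu(x) = x` and `relu(-neg_slope*x) = 0`, giving `x`; when `x < 0` we have `relu(x) = 0` and `relu(-neg_slope*x) = -neg_slope*x`, giving `neg_slope*x`. A small NumPy sketch checking that identity (illustrative only):

```python
import numpy as np

def relu(x): return np.maximum(x, 0)

x = np.linspace(-3, 3, 7)
neg_slope = 0.42
via_two_relus = relu(x) - relu(-neg_slope * x)  # the form used by Tensor.leaky_relu
direct = np.where(x > 0, x, neg_slope * x)      # the textbook elementwise definition
assert np.allclose(via_two_relus, direct)
print(via_two_relus)  # approximately [-1.26 -0.84 -0.42  0.  1.  2.  3.]
```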