leakyrelu to leaky_relu (#9270)

Author: Francis Lata
Date: 2025-02-26 13:22:08 -05:00
Committed by: GitHub
Parent: cd822bbe11
Commit: 86b737a120
11 changed files with 30 additions and 30 deletions

View File

@@ -110,7 +110,7 @@ class TinyNet:
   def __call__(self, x):
     x = self.l1(x)
-    x = x.leakyrelu()
+    x = x.leaky_relu()
     x = self.l2(x)
     return x
@@ -118,7 +118,7 @@ net = TinyNet()
 ```
 We can see that the forward pass of our neural network is just the sequence of operations performed on the input tensor `x`.
-We can also see that functional operations like `leakyrelu` are not defined as classes and instead are just methods we can just call.
+We can also see that functional operations like `leaky_relu` are not defined as classes and instead are just methods we can just call.
 Finally, we just initialize an instance of our neural network, and we are ready to start training it.
 ## Training
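To make the docs passage above concrete, here is a minimal sketch (not part of this commit) of calling the renamed activation directly as a Tensor method; the input values are illustrative only:

```python
# Minimal sketch: activations like leaky_relu are plain Tensor methods, so no
# layer object is needed. Input values below are made up for illustration.
from tinygrad import Tensor

t = Tensor([-2.0, -0.5, 0.0, 1.5])
print(t.leaky_relu(0.2).numpy())  # negatives scaled by 0.2, positives passed through
```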

View File

@@ -52,7 +52,7 @@ Elementwise ops operate on a per element basis. They don't change the shape of t
 ::: tinygrad.Tensor.erf
 ::: tinygrad.Tensor.gelu
 ::: tinygrad.Tensor.quick_gelu
-::: tinygrad.Tensor.leakyrelu
+::: tinygrad.Tensor.leaky_relu
 ::: tinygrad.Tensor.mish
 ::: tinygrad.Tensor.softplus
 ::: tinygrad.Tensor.softsign

View File

@@ -16,9 +16,9 @@ class LinearGen:
     self.l4 = Tensor.scaled_uniform(1024, 784)
   def forward(self, x):
-    x = x.dot(self.l1).leakyrelu(0.2)
-    x = x.dot(self.l2).leakyrelu(0.2)
-    x = x.dot(self.l3).leakyrelu(0.2)
+    x = x.dot(self.l1).leaky_relu(0.2)
+    x = x.dot(self.l2).leaky_relu(0.2)
+    x = x.dot(self.l3).leaky_relu(0.2)
     x = x.dot(self.l4).tanh()
     return x
@@ -31,9 +31,9 @@ class LinearDisc:
   def forward(self, x):
     # balance the discriminator inputs with const bias (.add(1))
-    x = x.dot(self.l1).add(1).leakyrelu(0.2).dropout(0.3)
-    x = x.dot(self.l2).leakyrelu(0.2).dropout(0.3)
-    x = x.dot(self.l3).leakyrelu(0.2).dropout(0.3)
+    x = x.dot(self.l1).add(1).leaky_relu(0.2).dropout(0.3)
+    x = x.dot(self.l2).leaky_relu(0.2).dropout(0.3)
+    x = x.dot(self.l3).leaky_relu(0.2).dropout(0.3)
     x = x.dot(self.l4).log_softmax()
     return x

View File

@@ -437,14 +437,14 @@ class Generator:
     x = self.conv_pre(x)
     if g is not None: x = x + self.cond(g)
     for i in range(self.num_upsamples):
-      x, xs = self.ups[i](x.leakyrelu(LRELU_SLOPE)), None
+      x, xs = self.ups[i](x.leaky_relu(LRELU_SLOPE)), None
       x_source = self.noise_convs[i](har_source)
       x = x + x_source
       for j in range(self.num_kernels):
        if xs is None: xs = self.resblocks[i * self.num_kernels + j].forward(x)
        else: xs += self.resblocks[i * self.num_kernels + j].forward(x)
       x = xs / self.num_kernels
-    return self.conv_post(x.leakyrelu()).tanh()
+    return self.conv_post(x.leaky_relu()).tanh()
 # **** helpers ****

View File

@@ -105,12 +105,12 @@ class Vgg7:
     Output format: (1, 3, Y - 14, X - 14)
     (the - 14 represents the 7-pixel context border that is lost)
     """
-    x = self.conv1.forward(x).leakyrelu(0.1)
-    x = self.conv2.forward(x).leakyrelu(0.1)
-    x = self.conv3.forward(x).leakyrelu(0.1)
-    x = self.conv4.forward(x).leakyrelu(0.1)
-    x = self.conv5.forward(x).leakyrelu(0.1)
-    x = self.conv6.forward(x).leakyrelu(0.1)
+    x = self.conv1.forward(x).leaky_relu(0.1)
+    x = self.conv2.forward(x).leaky_relu(0.1)
+    x = self.conv3.forward(x).leaky_relu(0.1)
+    x = self.conv4.forward(x).leaky_relu(0.1)
+    x = self.conv5.forward(x).leaky_relu(0.1)
+    x = self.conv6.forward(x).leaky_relu(0.1)
     x = self.conv7.forward(x)
     return x
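As a side note on the docstring above, the quoted "- 14" follows from the seven stacked convolutions, assuming each is a 3x3 convolution with no padding (an assumption stated here for illustration, not part of the diff):

```python
# Sketch of the shape arithmetic behind "(1, 3, Y - 14, X - 14)": a KxK convolution
# with no padding trims K-1 pixels per spatial dimension, and Vgg7 stacks seven of them.
kernel_size, num_convs = 3, 7
pixels_lost = num_convs * (kernel_size - 1)  # 7 * 2 = 14
print(pixels_lost)                           # 14, matching the docstring
```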

View File

@@ -193,10 +193,10 @@ class Generator:
     x = self.conv_pre(x)
     if g is not None: x = x + self.cond(g)
     for i in range(self.num_upsamples):
-      x = self.ups[i](x.leakyrelu(LRELU_SLOPE))
+      x = self.ups[i](x.leaky_relu(LRELU_SLOPE))
       xs = sum(self.resblocks[i * self.num_kernels + j].forward(x) for j in range(self.num_kernels))
       x = (xs / self.num_kernels).realize()
-    res = self.conv_post(x.leakyrelu()).tanh().realize()
+    res = self.conv_post(x.leaky_relu()).tanh().realize()
     return res
 class LayerNorm(nn.LayerNorm):
@@ -238,8 +238,8 @@ class ResBlock1:
     self.convs2 = [nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) for _ in range(3)]
   def forward(self, x: Tensor, x_mask=None):
     for c1, c2 in zip(self.convs1, self.convs2):
-      xt = x.leakyrelu(LRELU_SLOPE)
-      xt = c1(xt if x_mask is None else xt * x_mask).leakyrelu(LRELU_SLOPE)
+      xt = x.leaky_relu(LRELU_SLOPE)
+      xt = c1(xt if x_mask is None else xt * x_mask).leaky_relu(LRELU_SLOPE)
       x = c2(xt if x_mask is None else xt * x_mask) + x
     return x if x_mask is None else x * x_mask

View File

@@ -228,7 +228,7 @@ class Darknet:
       module.append(BatchNorm2d(filters, eps=1e-05, track_running_stats=True))
       # LeakyReLU activation
       if activation == "leaky":
-        module.append(lambda x: x.leakyrelu(0.1))
+        module.append(lambda x: x.leaky_relu(0.1))
     elif module_type == "maxpool":
       size, stride = int(x["size"]), int(x["stride"])
       module.append(lambda x: x.max_pool2d(kernel_size=(size, size), stride=stride))
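A hypothetical illustration (not from the repo) of how the appended callables above are meant to be used: each `module` entry, whether a layer or a lambda, is simply applied to `x` in order.

```python
# Hypothetical usage of the pattern above: activation entries are plain callables,
# so running a block is simple function composition. Values are illustrative only.
from tinygrad import Tensor

module = [lambda x: x.leaky_relu(0.1)]  # as appended when activation == "leaky"
x = Tensor([[-1.0, 2.0, -3.0]])
for layer in module:
  x = layer(x)
print(x.numpy())  # [[-0.1, 2.0, -0.3]]
```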

View File

@@ -295,7 +295,7 @@ def get_onnx_ops():
   def PRelu(X:Tensor, slope:Tensor):
     slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope
     return (X > 0).where(X, X * slope)
-  def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha)
+  def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leaky_relu(alpha)
   def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0)
   def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis)
   def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()
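For context on the `LeakyRelu` mapping above, a small sketch (an illustration, not part of the diff) checking that `leaky_relu` agrees with the elementwise `where` form used for `PRelu` just above it:

```python
# Sketch: X.leaky_relu(alpha) should match where(X > 0, X, X * alpha), the same
# select pattern PRelu uses above. Input values are illustrative only.
from tinygrad import Tensor

X, alpha = Tensor([-3.0, -1.0, 0.0, 2.0]), 0.01
via_method = X.leaky_relu(alpha)
via_where = (X > 0).where(X, X * alpha)
print((via_method - via_where).abs().max().numpy())  # expected to print ~0.0
```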

View File

@@ -215,7 +215,7 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
   "aten.sub.out": lambda input,other,alpha=1: input-alpha*other, # NOTE: this is also needed to handle reverse
   "aten.mul.out": operator.mul,
   "aten.bmm.out": operator.matmul,
-  "aten.leaky_relu.out": Tensor.leakyrelu, # TODO: this should be renamed in tinygrad
+  "aten.leaky_relu.out": Tensor.leaky_relu,
   # NOTE: because these methods have a name with "Tensor" in them, they can't go in simple tensor methods
   "aten.remainder.Tensor_out": Tensor.mod,
   "aten.pow.Tensor_Tensor_out": Tensor.pow,

View File

@@ -806,9 +806,9 @@ class TestOps(unittest.TestCase):
     helper_test_op(None, lambda x: x.relu(), vals=[[-1.,0,1]])
   def test_relu_maximum_exact(self):
     helper_test_op(None, lambda x: torch.maximum(x, torch.zeros_like(x, requires_grad=False)), lambda x: Tensor.maximum(x, 0), vals=[[-1.,0,1]])
-  def test_leakyrelu(self):
-    helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
-    helper_test_op([()], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
+  def test_leaky_relu(self):
+    helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leaky_relu)
+    helper_test_op([()], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leaky_relu)
   def test_celu(self):
     for val in range(1, 5):
       helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))

View File

@@ -3078,17 +3078,17 @@ class Tensor(SimpleMathTrait):
     """
     return self * (self * 1.702).sigmoid()
-  def leakyrelu(self, neg_slope=0.01):
+  def leaky_relu(self, neg_slope=0.01):
     """
     Applies the Leaky ReLU function element-wise.
    - Described: https://paperswithcode.com/method/leaky-relu
     ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu().numpy())
+    print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu().numpy())
     ```
     ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu(neg_slope=0.42).numpy())
+    print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu(neg_slope=0.42).numpy())
     ```
     """
     return self.relu() - (-neg_slope*self).relu()
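A brief sanity note on the unchanged return expression above: for neg_slope > 0, when x >= 0 we have relu(x) = x and relu(-neg_slope * x) = 0, giving x; when x < 0 we have relu(x) = 0 and relu(-neg_slope * x) = -neg_slope * x, giving neg_slope * x. That is exactly leaky ReLU, which the hedged sketch below (using a NumPy reference for illustration, not library code) confirms numerically:

```python
# Sketch comparing the identity relu(x) - relu(-neg_slope * x), as implemented above,
# against a textbook NumPy reference for leaky ReLU. Values are illustrative only.
import numpy as np
from tinygrad import Tensor

xs = np.array([-3., -2., -1., 0., 1., 2., 3.], dtype=np.float32)
neg_slope = 0.42
out = Tensor(xs).leaky_relu(neg_slope).numpy()
ref = np.where(xs > 0, xs, neg_slope * xs)  # reference definition of leaky ReLU
print(np.abs(out - ref).max())              # expected to be ~0
```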