Merge branch 'master' into retinanet_mlperf

Francis Lata
2025-02-26 13:27:35 -05:00
11 changed files with 44 additions and 56 deletions

View File

@@ -110,7 +110,7 @@ class TinyNet:
def __call__(self, x):
x = self.l1(x)
-x = x.leakyrelu()
+x = x.leaky_relu()
x = self.l2(x)
return x
@@ -118,7 +118,7 @@ net = TinyNet()
```
We can see that the forward pass of our neural network is just the sequence of operations performed on the input tensor `x`.
-We can also see that functional operations like `leakyrelu` are not defined as classes and instead are just methods we can just call.
+We can also see that functional operations like `leaky_relu` are not defined as classes and instead are just methods we can just call.
Finally, we just initialize an instance of our neural network, and we are ready to start training it.
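As a minimal aside (not part of this diff), the renamed method is invoked directly on a tensor with no layer object involved; the input values and the 0.2 slope below are purely illustrative:

```python
from tinygrad import Tensor

x = Tensor([[-1.0, 0.0, 2.0]])
# functional ops are plain Tensor methods, so nothing needs to be instantiated
y = x.leaky_relu(0.2)
print(y.numpy())  # positives pass through, negatives are scaled by 0.2
```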
## Training

View File

@@ -52,7 +52,7 @@ Elementwise ops operate on a per element basis. They don't change the shape of t
::: tinygrad.Tensor.erf
::: tinygrad.Tensor.gelu
::: tinygrad.Tensor.quick_gelu
-::: tinygrad.Tensor.leakyrelu
+::: tinygrad.Tensor.leaky_relu
::: tinygrad.Tensor.mish
::: tinygrad.Tensor.softplus
::: tinygrad.Tensor.softsign
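As a quick, illustrative check of the statement above that elementwise ops don't change the shape of the tensor (an aside, not part of the API listing):

```python
from tinygrad import Tensor

t = Tensor.randn(3, 4)
# each elementwise op maps every element independently, so the shape is preserved
assert t.gelu().shape == (3, 4)
assert t.leaky_relu().shape == (3, 4)
```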

View File

@@ -16,9 +16,9 @@ class LinearGen:
self.l4 = Tensor.scaled_uniform(1024, 784)
def forward(self, x):
-x = x.dot(self.l1).leakyrelu(0.2)
-x = x.dot(self.l2).leakyrelu(0.2)
-x = x.dot(self.l3).leakyrelu(0.2)
+x = x.dot(self.l1).leaky_relu(0.2)
+x = x.dot(self.l2).leaky_relu(0.2)
+x = x.dot(self.l3).leaky_relu(0.2)
x = x.dot(self.l4).tanh()
return x
@@ -31,9 +31,9 @@ class LinearDisc:
def forward(self, x):
# balance the discriminator inputs with const bias (.add(1))
-x = x.dot(self.l1).add(1).leakyrelu(0.2).dropout(0.3)
-x = x.dot(self.l2).leakyrelu(0.2).dropout(0.3)
-x = x.dot(self.l3).leakyrelu(0.2).dropout(0.3)
+x = x.dot(self.l1).add(1).leaky_relu(0.2).dropout(0.3)
+x = x.dot(self.l2).leaky_relu(0.2).dropout(0.3)
+x = x.dot(self.l3).leaky_relu(0.2).dropout(0.3)
x = x.dot(self.l4).log_softmax()
return x

View File

@@ -437,14 +437,14 @@ class Generator:
x = self.conv_pre(x)
if g is not None: x = x + self.cond(g)
for i in range(self.num_upsamples):
-x, xs = self.ups[i](x.leakyrelu(LRELU_SLOPE)), None
+x, xs = self.ups[i](x.leaky_relu(LRELU_SLOPE)), None
x_source = self.noise_convs[i](har_source)
x = x + x_source
for j in range(self.num_kernels):
if xs is None: xs = self.resblocks[i * self.num_kernels + j].forward(x)
else: xs += self.resblocks[i * self.num_kernels + j].forward(x)
x = xs / self.num_kernels
-return self.conv_post(x.leakyrelu()).tanh()
+return self.conv_post(x.leaky_relu()).tanh()
# **** helpers ****

View File

@@ -105,12 +105,12 @@ class Vgg7:
Output format: (1, 3, Y - 14, X - 14)
(the - 14 represents the 7-pixel context border that is lost)
"""
-x = self.conv1.forward(x).leakyrelu(0.1)
-x = self.conv2.forward(x).leakyrelu(0.1)
-x = self.conv3.forward(x).leakyrelu(0.1)
-x = self.conv4.forward(x).leakyrelu(0.1)
-x = self.conv5.forward(x).leakyrelu(0.1)
-x = self.conv6.forward(x).leakyrelu(0.1)
+x = self.conv1.forward(x).leaky_relu(0.1)
+x = self.conv2.forward(x).leaky_relu(0.1)
+x = self.conv3.forward(x).leaky_relu(0.1)
+x = self.conv4.forward(x).leaky_relu(0.1)
+x = self.conv5.forward(x).leaky_relu(0.1)
+x = self.conv6.forward(x).leaky_relu(0.1)
x = self.conv7.forward(x)
return x
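A small sketch of where the `- 14` in the docstring comes from, assuming each of the seven convolutions is an unpadded 3x3 (so each layer trims one pixel from every edge):

```python
# seven unpadded 3x3 convolutions each shrink a spatial dim by (kernel_size - 1) = 2
layers, kernel_size = 7, 3
print(layers * (kernel_size - 1))  # 14, i.e. (Y, X) -> (Y - 14, X - 14)
```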

View File

@@ -193,10 +193,10 @@ class Generator:
x = self.conv_pre(x)
if g is not None: x = x + self.cond(g)
for i in range(self.num_upsamples):
-x = self.ups[i](x.leakyrelu(LRELU_SLOPE))
+x = self.ups[i](x.leaky_relu(LRELU_SLOPE))
xs = sum(self.resblocks[i * self.num_kernels + j].forward(x) for j in range(self.num_kernels))
x = (xs / self.num_kernels).realize()
-res = self.conv_post(x.leakyrelu()).tanh().realize()
+res = self.conv_post(x.leaky_relu()).tanh().realize()
return res
class LayerNorm(nn.LayerNorm):
@@ -238,8 +238,8 @@ class ResBlock1:
self.convs2 = [nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) for _ in range(3)]
def forward(self, x: Tensor, x_mask=None):
for c1, c2 in zip(self.convs1, self.convs2):
-xt = x.leakyrelu(LRELU_SLOPE)
-xt = c1(xt if x_mask is None else xt * x_mask).leakyrelu(LRELU_SLOPE)
+xt = x.leaky_relu(LRELU_SLOPE)
+xt = c1(xt if x_mask is None else xt * x_mask).leaky_relu(LRELU_SLOPE)
x = c2(xt if x_mask is None else xt * x_mask) + x
return x if x_mask is None else x * x_mask

View File

@@ -228,7 +228,7 @@ class Darknet:
module.append(BatchNorm2d(filters, eps=1e-05, track_running_stats=True))
# LeakyReLU activation
if activation == "leaky":
-module.append(lambda x: x.leakyrelu(0.1))
+module.append(lambda x: x.leaky_relu(0.1))
elif module_type == "maxpool":
size, stride = int(x["size"]), int(x["stride"])
module.append(lambda x: x.max_pool2d(kernel_size=(size, size), stride=stride))

View File

@@ -295,7 +295,7 @@ def get_onnx_ops():
def PRelu(X:Tensor, slope:Tensor):
slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope
return (X > 0).where(X, X * slope)
-def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha)
+def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leaky_relu(alpha)
def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0)
def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis)
def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()
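For reference, the two activation handlers above implement the same piecewise rule and differ only in whether the negative-side slope is a fixed scalar or a tensor broadcast against the input; a rough NumPy check with an example slope of 0.01:

```python
import numpy as np

x = np.array([-2.0, -0.5, 0.0, 1.5])
alpha = 0.01
# LeakyRelu: keep positives, scale negatives by a scalar alpha
leaky = np.where(x > 0, x, alpha * x)
# PRelu: same rule, but the slope is itself an array broadcast against x
slope = np.full_like(x, alpha)
prelu = np.where(x > 0, x, slope * x)
assert np.allclose(leaky, prelu)
```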

View File

@@ -215,7 +215,7 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
"aten.sub.out": lambda input,other,alpha=1: input-alpha*other, # NOTE: this is also needed to handle reverse
"aten.mul.out": operator.mul,
"aten.bmm.out": operator.matmul,
"aten.leaky_relu.out": Tensor.leakyrelu, # TODO: this should be renamed in tinygrad
"aten.leaky_relu.out": Tensor.leaky_relu,
# NOTE: because these methods have a name with "Tensor" in them, they can't go in simple tensor methods
"aten.remainder.Tensor_out": Tensor.mod,
"aten.pow.Tensor_Tensor_out": Tensor.pow,

View File

@@ -54,31 +54,19 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
compare("forward pass", ret.numpy(), out.detach().cpu().numpy(), atol=atol, rtol=rtol)
torch_fbp, tinygrad_fbp = np.nan, np.nan
-if not forward_only and not FORWARD_ONLY:
+if not forward_only and not FORWARD_ONLY and ts and tst:
st = time.monotonic()
-(out+1).square().mean().backward()
+torch_grads = torch.autograd.grad(torch_fxn(*ts).sum(), ts)
torch_fbp = time.monotonic() - st
st = time.monotonic()
-# NOTE: we now have to recompute the forward pass since we realized it
-ret = tinygrad_fxn(*tst)
-loss:Tensor = (ret+1).square().mean()
-# test_ops uses new style gradient
-tst_grads = loss.gradient(*tst)
-if len(tst_grads): Tensor.realize(*tst_grads)
+tiny_grads = tinygrad_fxn(*tst).sum().gradient(*tst)
+Tensor.realize(*tiny_grads)
tinygrad_fbp = time.monotonic() - st
-for i, (t, tt_grad) in enumerate(zip(ts, tst_grads)):
-compare(f"backward pass tensor {i}", tt_grad.numpy(), t.grad.detach().cpu().numpy(), atol=grad_atol, rtol=grad_rtol)
-"""
-(ret+1).square().mean().backward()
-for tt in tst: tt.grad.realize()
-tinygrad_fbp = time.monotonic() - st
-for i, (t, tt) in enumerate(zip(ts, tst)):
-compare(f"backward pass tensor {i}", tt.grad.numpy(), t.grad.detach().numpy(), atol=grad_atol, rtol=grad_rtol)
-"""
+for i, (t, torch_grad) in enumerate(zip(tiny_grads, torch_grads)):
+compare(f"backward pass tensor {i}", t.numpy(), torch_grad.detach().cpu().numpy(), atol=grad_atol, rtol=grad_rtol)
if not CI:
print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % \
@@ -818,9 +806,9 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x: x.relu(), vals=[[-1.,0,1]])
def test_relu_maximum_exact(self):
helper_test_op(None, lambda x: torch.maximum(x, torch.zeros_like(x, requires_grad=False)), lambda x: Tensor.maximum(x, 0), vals=[[-1.,0,1]])
-def test_leakyrelu(self):
-helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
-helper_test_op([()], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
+def test_leaky_relu(self):
+helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leaky_relu)
+helper_test_op([()], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leaky_relu)
def test_celu(self):
for val in range(1, 5):
helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
@@ -1339,8 +1327,8 @@ class TestOps(unittest.TestCase):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="var\\(\\): degrees of freedom is <= 0")
helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,3)))
-helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,3), correction=5))
+# TODO: fix backward
+helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,3), correction=5), forward_only=True)
+helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,4), correction=5), forward_only=True)
helper_test_op([(1,)], lambda x: x.var(axis=(0,), correction=0))
helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,3), correction=0))
@@ -1401,9 +1389,9 @@ class TestOps(unittest.TestCase):
helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
def test_softmax_other_axis(self):
-helper_test_op([(10,10,10)], lambda x: x.softmax(0), atol=1e-7, grad_atol=1e-7)
-helper_test_op([(10,10,10)], lambda x: x.softmax(1), atol=1e-7, grad_atol=1e-7)
-helper_test_op([(10,10,10)], lambda x: x.softmax(2), atol=1e-7, grad_atol=1e-7)
+helper_test_op([(10,10,10)], lambda x: x.softmax(0), atol=1e-7, grad_atol=2e-7)
+helper_test_op([(10,10,10)], lambda x: x.softmax(1), atol=1e-7, grad_atol=2e-7)
+helper_test_op([(10,10,10)], lambda x: x.softmax(2), atol=1e-7, grad_atol=2e-7)
def test_softmax_argmax(self):
helper_test_op([(45,65)], lambda x: x.softmax(0).argmax().type(torch.int32),
lambda x: x.softmax(0).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7)
@@ -1459,12 +1447,12 @@ class TestOps(unittest.TestCase):
helper_test_op([()], lambda x: torch.nn.functional.hardtanh(x, -val, val), lambda x: x.hardtanh(-val, val), grad_atol=1e-6)
def test_asinh(self):
helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6)
-# NOTE: this one has larger atol
-helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, rtol=2e-2, grad_atol=1e-6, low=-300, high=-297)
+# TODO: this one has larger tol?
+helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, rtol=2e-2, grad_rtol=2e-2, low=-300, high=-297)
helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6, low=300, high=303)
def test_acosh(self):
helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-6)
-helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-6, low=-300, high=-297)
+helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-3, grad_rtol=1e-2, low=-300, high=-297)
helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-6, low=300, high=303)
def test_atanh(self):
helper_test_op([(45,65)], lambda x: x.atanh(), grad_atol=1e-6)
@@ -2033,7 +2021,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(bs,cin,64,64), (6,cin//groups,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
# needed to relax tolerance on NVIDIA
-lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
+lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_atol=1e-4, grad_rtol=1e-4)
def test_simple_grouped_conv2d(self):
bs = 1

View File

@@ -3078,17 +3078,17 @@ class Tensor(SimpleMathTrait):
"""
return self * (self * 1.702).sigmoid()
-def leakyrelu(self, neg_slope=0.01):
+def leaky_relu(self, neg_slope=0.01):
"""
Applies the Leaky ReLU function element-wise.
- Described: https://paperswithcode.com/method/leaky-relu
```python exec="true" source="above" session="tensor" result="python"
-print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu().numpy())
+print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
-print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu(neg_slope=0.42).numpy())
+print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leaky_relu(neg_slope=0.42).numpy())
```
"""
return self.relu() - (-neg_slope*self).relu()
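The two-ReLU expression above matches the usual piecewise definition whenever `neg_slope >= 0`: for `x >= 0` the second term vanishes, and for `x < 0` the first term vanishes while `relu(-neg_slope*x)` equals `-neg_slope*x`. A quick NumPy check of that identity (an aside, not part of the diff):

```python
import numpy as np

def leaky_relu_via_two_relus(x, neg_slope=0.01):
    # mirrors the tensor.py expression: self.relu() - (-neg_slope * self).relu()
    relu = lambda t: np.maximum(t, 0.0)
    return relu(x) - relu(-neg_slope * x)

x = np.array([-3., -2., -1., 0., 1., 2., 3.])
reference = np.where(x > 0, x, 0.01 * x)  # textbook leaky ReLU with slope 0.01
assert np.allclose(leaky_relu_via_two_relus(x), reference)
```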