LayerNorm2d for 2 lines

This commit is contained in:
George Hotz
2023-03-20 16:58:43 -07:00
parent 128ca160ac
commit d6f4219952
3 changed files with 31 additions and 8 deletions

View File

@@ -2,7 +2,7 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device
from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm
from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm, LayerNorm2d
import torch
class TestNN(unittest.TestCase):
@@ -76,7 +76,7 @@ class TestNN(unittest.TestCase):
def test_conv2d(self):
BS, C1, H, W = 4, 16, 224, 224
C2, K, S, P = 64, 7, 2, 1
# create in tinygrad
layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
@@ -131,5 +131,24 @@ class TestNN(unittest.TestCase):
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
def test_layernorm_2d(self):
N, C, H, W = 20, 5, 10, 10
# create in tinygrad
layer = LayerNorm2d(C)
# create in torch
with torch.no_grad():
torch_layer = torch.nn.LayerNorm([C]).eval()
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
# test
x = Tensor.randn(N, C, H, W)
z = layer(x)
torch_x = torch.tensor(x.cpu().numpy())
torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
if __name__ == '__main__':
unittest.main()