mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
LayerNorm2d for 2 lines
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm
|
||||
from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm, LayerNorm2d
|
||||
import torch
|
||||
|
||||
class TestNN(unittest.TestCase):
|
||||
@@ -76,7 +76,7 @@ class TestNN(unittest.TestCase):
|
||||
def test_conv2d(self):
|
||||
BS, C1, H, W = 4, 16, 224, 224
|
||||
C2, K, S, P = 64, 7, 2, 1
|
||||
|
||||
|
||||
# create in tinygrad
|
||||
layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
|
||||
@@ -131,5 +131,24 @@ class TestNN(unittest.TestCase):
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
|
||||
|
||||
def test_layernorm_2d(self):
|
||||
N, C, H, W = 20, 5, 10, 10
|
||||
|
||||
# create in tinygrad
|
||||
layer = LayerNorm2d(C)
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.LayerNorm([C]).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
|
||||
# test
|
||||
x = Tensor.randn(N, C, H, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.cpu().numpy())
|
||||
torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user