From d6f421995296fe4e86ec724ef417ce5d305f59ea Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Mon, 20 Mar 2023 16:58:43 -0700
Subject: [PATCH] LayerNorm2d for 2 lines

---
 models/convnext.py      |  6 +++---
 test/test_nn.py         | 23 +++++++++++++++++++++--
 tinygrad/nn/__init__.py | 10 +++++++---
 3 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/models/convnext.py b/models/convnext.py
index f79071133d..d9d58ad8a6 100644
--- a/models/convnext.py
+++ b/models/convnext.py
@@ -1,5 +1,5 @@
 from tinygrad.tensor import Tensor
-from tinygrad.nn import Conv2d, LayerNorm, Linear
+from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear
 
 class Block:
   def __init__(self, dim):
@@ -18,8 +18,8 @@ class Block:
 class ConvNeXt:
   def __init__(self, in_chans=3, num_classes=1000, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]):
     self.downsample_layers = [
-      [Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm((dims[0], 1, 1), eps=1e-6)],
-      *[[LayerNorm((dims[i], 1, 1), eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
+      [Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm2d(dims[0], eps=1e-6)],
+      *[[LayerNorm2d(dims[i], eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
     ]
     self.stages = [[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))]
     self.norm = LayerNorm(dims[-1])
diff --git a/test/test_nn.py b/test/test_nn.py
index 736d9e3e1d..136138672e 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2,7 +2,7 @@
 import unittest
 import numpy as np
 from tinygrad.tensor import Tensor, Device
-from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm
+from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm, LayerNorm2d
 import torch
 
 class TestNN(unittest.TestCase):
@@ -76,7 +76,7 @@ class TestNN(unittest.TestCase):
   def test_conv2d(self):
     BS, C1, H, W = 4, 16, 224, 224
     C2, K, S, P = 64, 7, 2, 1
-    
+
     # create in tinygrad
     layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
 
@@ -131,5 +131,24 @@ class TestNN(unittest.TestCase):
     torch_z = torch_layer(torch_x)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
 
+  def test_layernorm_2d(self):
+    N, C, H, W = 20, 5, 10, 10
+
+    # create in tinygrad
+    layer = LayerNorm2d(C)
+
+    # create in torch
+    with torch.no_grad():
+      torch_layer = torch.nn.LayerNorm([C]).eval()
+      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
+      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
+
+    # test
+    x = Tensor.randn(N, C, H, W)
+    z = layer(x)
+    torch_x = torch.tensor(x.cpu().numpy())
+    torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
+    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index 8f3ffba6e2..88e82f550f 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -73,11 +73,15 @@ class GroupNorm:
 
 class LayerNorm:
   def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
-    normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
-    self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(normalized_shape))), eps, elementwise_affine
-    self.weight, self.bias = (Tensor.ones(*normalized_shape), Tensor.zeros(*normalized_shape)) if elementwise_affine else (None, None)
+    self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
+    self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
+    self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None)
 
   def __call__(self, x:Tensor):
+    assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
     x = x.layernorm(eps=self.eps, axis=self.axis)
     if not self.elementwise_affine: return x
     return x * self.weight + self.bias
+
+class LayerNorm2d(LayerNorm):
+  def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
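LayerNorm2d applies LayerNorm over the channel dimension of an NCHW tensor by permuting to NHWC, normalizing the last axis, and permuting back, which is the channels-first norm the ConvNeXt downsample layers need. A minimal usage sketch follows; it is not part of the patch, and the batch size, channel count, spatial size, and eps value are illustrative assumptions.

from tinygrad.tensor import Tensor
from tinygrad.nn import LayerNorm2d

# illustrative NCHW batch: 4 images, 16 channels, 32x32 spatial (values assumed, not from the patch)
x = Tensor.randn(4, 16, 32, 32)

# normalizes the 16 channel values at each (h, w) position, then applies the learned per-channel weight/bias
norm = LayerNorm2d(16, eps=1e-6)
y = norm(x)  # same NCHW shape as x: (4, 16, 32, 32)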