LayerNorm2d for 2 lines

This commit is contained in:
George Hotz
2023-03-20 16:58:43 -07:00
parent 128ca160ac
commit d6f4219952
3 changed files with 31 additions and 8 deletions

View File

@@ -1,5 +1,5 @@
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, LayerNorm, Linear
from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear
class Block:
def __init__(self, dim):
@@ -18,8 +18,8 @@ class Block:
class ConvNeXt:
def __init__(self, in_chans=3, num_classes=1000, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]):
self.downsample_layers = [
[Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm((dims[0], 1, 1), eps=1e-6)],
*[[LayerNorm((dims[i], 1, 1), eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
[Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm2d(dims[0], eps=1e-6)],
*[[LayerNorm2d(dims[i], eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
]
self.stages = [[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))]
self.norm = LayerNorm(dims[-1])