mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
LayerNorm2d for 2 lines
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d, LayerNorm, Linear
|
||||
from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear
|
||||
|
||||
class Block:
|
||||
def __init__(self, dim):
|
||||
@@ -18,8 +18,8 @@ class Block:
|
||||
class ConvNeXt:
|
||||
def __init__(self, in_chans=3, num_classes=1000, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]):
|
||||
self.downsample_layers = [
|
||||
[Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm((dims[0], 1, 1), eps=1e-6)],
|
||||
*[[LayerNorm((dims[i], 1, 1), eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
|
||||
[Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm2d(dims[0], eps=1e-6)],
|
||||
*[[LayerNorm2d(dims[i], eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
|
||||
]
|
||||
self.stages = [[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))]
|
||||
self.norm = LayerNorm(dims[-1])
|
||||
|
||||
Reference in New Issue
Block a user