mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
enable batchnorm in serious mnist
This commit is contained in:
@@ -19,20 +19,22 @@ class SeriousModel:
|
||||
def __init__(self):
|
||||
self.blocks = 3
|
||||
self.block_convs = 3
|
||||
self.chans = 128
|
||||
|
||||
# TODO: raise back to 128 when it's fast
|
||||
self.chans = 32
|
||||
|
||||
self.convs = [Tensor.uniform(self.chans, self.chans if i > 0 else 1, 3, 3) for i in range(self.blocks * self.block_convs)]
|
||||
# TODO: Make batchnorm work at train time
|
||||
#self.bn = [BatchNorm2D(self.chans) for i in range(3)]
|
||||
self.bn = [BatchNorm2D(self.chans, training=True) for i in range(3)]
|
||||
self.fc = Tensor.uniform(self.chans, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.reshape(shape=(-1, 1, 28, 28)) # hacks
|
||||
for i in range(self.blocks):
|
||||
for j in range(self.block_convs):
|
||||
#print(i, j, x.shape, x.sum().cpu())
|
||||
# TODO: should padding be used?
|
||||
x = x.conv2d(self.convs[i*3+j]).relu()
|
||||
#x = self.bn[i](x)
|
||||
x = self.bn[i](x)
|
||||
if i > 0:
|
||||
x = x.avg_pool2d(kernel_size=(2,2))
|
||||
# TODO: Add concat support to concat with max_pool2d
|
||||
|
||||
Reference in New Issue
Block a user