enable batchnorm in serious mnist

This commit is contained in:
George Hotz
2020-12-09 03:29:40 -08:00
parent ffb96b2d0b
commit 99fa65f057

View File

@@ -19,20 +19,22 @@ class SeriousModel:
def __init__(self):
    """Build the conv stack for serious MNIST: `blocks` blocks of
    `block_convs` 3x3 convolutions each, one BatchNorm2D per block,
    and a final fully-connected layer to the 10 digit classes.
    """
    self.blocks = 3
    self.block_convs = 3
    # TODO: raise back to 128 when it's fast
    self.chans = 32
    # First conv takes the single grayscale input channel; all later convs are chans->chans.
    self.convs = [Tensor.uniform(self.chans, self.chans if i > 0 else 1, 3, 3) for i in range(self.blocks * self.block_convs)]
    # One batchnorm per block (indexed as self.bn[i] in forward); training=True so it
    # runs in train mode — presumably updating batch statistics, per the framework's semantics.
    self.bn = [BatchNorm2D(self.chans, training=True) for i in range(self.blocks)]
    self.fc = Tensor.uniform(self.chans, 10)
def forward(self, x):
x = x.reshape(shape=(-1, 1, 28, 28)) # hacks
for i in range(self.blocks):
for j in range(self.block_convs):
#print(i, j, x.shape, x.sum().cpu())
# TODO: should padding be used?
x = x.conv2d(self.convs[i*3+j]).relu()
#x = self.bn[i](x)
x = self.bn[i](x)
if i > 0:
x = x.avg_pool2d(kernel_size=(2,2))
# TODO: Add concat support to concat with max_pool2d