fix batchnorm not realizing

This commit is contained in:
George Hotz
2023-02-27 10:19:54 -08:00
parent c9252d38b2
commit 2f17d151b3
4 changed files with 24 additions and 10 deletions

View File

@@ -2,7 +2,7 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device
from tinygrad.nn import optim
from tinygrad.nn import optim, BatchNorm2d
from extra.training import train, evaluate
from datasets import fetch_mnist
@@ -23,7 +23,7 @@ class TinyBobNet:
# create a model with a conv layer
class TinyConvNet:
def __init__(self):
def __init__(self, has_batchnorm=False):
# https://keras.io/examples/vision/mnist_convnet/
conv = 3
#inter_chan, out_chan = 32, 64
@@ -31,14 +31,19 @@ class TinyConvNet:
self.c1 = Tensor.scaled_uniform(inter_chan,1,conv,conv)
self.c2 = Tensor.scaled_uniform(out_chan,inter_chan,conv,conv)
self.l1 = Tensor.scaled_uniform(out_chan*5*5, 10)
if has_batchnorm:
self.bn1 = BatchNorm2d(inter_chan)
self.bn2 = BatchNorm2d(out_chan)
else:
self.bn1, self.bn2 = lambda x: x, lambda x: x
def parameters(self):
return optim.get_parameters(self)
def forward(self, x):
x = x.reshape(shape=(-1, 1, 28, 28)) # hacks
x = x.conv2d(self.c1).relu().max_pool2d()
x = x.conv2d(self.c2).relu().max_pool2d()
x = self.bn1(x.conv2d(self.c1)).relu().max_pool2d()
x = self.bn2(x.conv2d(self.c2)).relu().max_pool2d()
x = x.reshape(shape=[x.shape[0], -1])
return x.dot(self.l1).log_softmax()
@@ -89,6 +94,13 @@ class TestMNIST(unittest.TestCase):
train(model, X_train, Y_train, optimizer, steps=100)
assert evaluate(model, X_test, Y_test) > 0.94 # torch gets 0.9415 sometimes
def test_conv_with_bn(self):
np.random.seed(1337)
model = TinyConvNet(has_batchnorm=True)
optimizer = optim.Adam(model.parameters(), lr=0.001)
train(model, X_train, Y_train, optimizer, steps=100)
# TODO: batchnorm doesn't work!!!
def test_sgd(self):
np.random.seed(1337)
model = TinyBobNet()