mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix batchnorm not realizing
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.nn import optim, BatchNorm2d
|
||||
from extra.training import train, evaluate
|
||||
from datasets import fetch_mnist
|
||||
|
||||
@@ -23,7 +23,7 @@ class TinyBobNet:
|
||||
|
||||
# create a model with a conv layer
|
||||
class TinyConvNet:
|
||||
def __init__(self):
|
||||
def __init__(self, has_batchnorm=False):
|
||||
# https://keras.io/examples/vision/mnist_convnet/
|
||||
conv = 3
|
||||
#inter_chan, out_chan = 32, 64
|
||||
@@ -31,14 +31,19 @@ class TinyConvNet:
|
||||
self.c1 = Tensor.scaled_uniform(inter_chan,1,conv,conv)
|
||||
self.c2 = Tensor.scaled_uniform(out_chan,inter_chan,conv,conv)
|
||||
self.l1 = Tensor.scaled_uniform(out_chan*5*5, 10)
|
||||
if has_batchnorm:
|
||||
self.bn1 = BatchNorm2d(inter_chan)
|
||||
self.bn2 = BatchNorm2d(out_chan)
|
||||
else:
|
||||
self.bn1, self.bn2 = lambda x: x, lambda x: x
|
||||
|
||||
def parameters(self):
|
||||
return optim.get_parameters(self)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.reshape(shape=(-1, 1, 28, 28)) # hacks
|
||||
x = x.conv2d(self.c1).relu().max_pool2d()
|
||||
x = x.conv2d(self.c2).relu().max_pool2d()
|
||||
x = self.bn1(x.conv2d(self.c1)).relu().max_pool2d()
|
||||
x = self.bn2(x.conv2d(self.c2)).relu().max_pool2d()
|
||||
x = x.reshape(shape=[x.shape[0], -1])
|
||||
return x.dot(self.l1).log_softmax()
|
||||
|
||||
@@ -89,6 +94,13 @@ class TestMNIST(unittest.TestCase):
|
||||
train(model, X_train, Y_train, optimizer, steps=100)
|
||||
assert evaluate(model, X_test, Y_test) > 0.94 # torch gets 0.9415 sometimes
|
||||
|
||||
def test_conv_with_bn(self):
|
||||
np.random.seed(1337)
|
||||
model = TinyConvNet(has_batchnorm=True)
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
train(model, X_train, Y_train, optimizer, steps=100)
|
||||
# TODO: batchnorm doesn't work!!!
|
||||
|
||||
def test_sgd(self):
|
||||
np.random.seed(1337)
|
||||
model = TinyBobNet()
|
||||
|
||||
Reference in New Issue
Block a user