fix batchnorm not realizing

This commit is contained in:
George Hotz
2023-02-27 10:19:54 -08:00
parent c9252d38b2
commit 2f17d151b3
4 changed files with 24 additions and 10 deletions

View File

@@ -2,7 +2,7 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device
from tinygrad.nn import optim
from tinygrad.nn import optim, BatchNorm2d
from extra.training import train, evaluate
from datasets import fetch_mnist
@@ -23,7 +23,7 @@ class TinyBobNet:
# create a model with a conv layer
class TinyConvNet:
def __init__(self):
def __init__(self, has_batchnorm=False):
# https://keras.io/examples/vision/mnist_convnet/
conv = 3
#inter_chan, out_chan = 32, 64
@@ -31,14 +31,19 @@ class TinyConvNet:
self.c1 = Tensor.scaled_uniform(inter_chan,1,conv,conv)
self.c2 = Tensor.scaled_uniform(out_chan,inter_chan,conv,conv)
self.l1 = Tensor.scaled_uniform(out_chan*5*5, 10)
if has_batchnorm:
self.bn1 = BatchNorm2d(inter_chan)
self.bn2 = BatchNorm2d(out_chan)
else:
self.bn1, self.bn2 = lambda x: x, lambda x: x
def parameters(self):
return optim.get_parameters(self)
def forward(self, x):
x = x.reshape(shape=(-1, 1, 28, 28)) # hacks
x = x.conv2d(self.c1).relu().max_pool2d()
x = x.conv2d(self.c2).relu().max_pool2d()
x = self.bn1(x.conv2d(self.c1)).relu().max_pool2d()
x = self.bn2(x.conv2d(self.c2)).relu().max_pool2d()
x = x.reshape(shape=[x.shape[0], -1])
return x.dot(self.l1).log_softmax()
@@ -89,6 +94,13 @@ class TestMNIST(unittest.TestCase):
train(model, X_train, Y_train, optimizer, steps=100)
assert evaluate(model, X_test, Y_test) > 0.94 # torch gets 0.9415 sometimes
def test_conv_with_bn(self):
np.random.seed(1337)
model = TinyConvNet(has_batchnorm=True)
optimizer = optim.Adam(model.parameters(), lr=0.001)
train(model, X_train, Y_train, optimizer, steps=100)
# TODO: batchnorm doesn't work!!!
def test_sgd(self):
np.random.seed(1337)
model = TinyBobNet()

View File

@@ -11,7 +11,7 @@ class BatchNorm2d:
self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
def __call__(self, x):
def __call__(self, x:Tensor):
if Tensor.training:
# This requires two full memory accesses to x
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
@@ -25,8 +25,8 @@ class BatchNorm2d:
# NOTE: wow, this is done all throughout training in most PyTorch models
if self.track_running_stats:
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean)
self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * batch_var)
self.num_batches_tracked += 1
else:
batch_mean, batch_var = self.running_mean, self.running_var

View File

@@ -10,6 +10,8 @@ class Optimizer:
x.requires_grad = True
self.params : List[Tensor] = [x for x in params if x.requires_grad]
self.buffers : List[Tensor] = [x for x in params if not x.requires_grad] # buffers are still realized
self.realize()
# TODO: this probably shouldn't change the gradients, just the ones used by the optimizer
def clipnorm(self, amount=1):
@@ -24,7 +26,7 @@ class Optimizer:
def realize(self, extra=None):
# TODO: corealize
for p in extra + self.params if extra is not None else self.params:
for p in extra + self.params + self.buffers if extra is not None else self.params + self.buffers:
p.realize()
class SGD(Optimizer):

View File

@@ -114,10 +114,10 @@ class Tensor:
# TODO: remove use of numpy here and make lazy
@staticmethod
def zeros(*shape, **kwargs): return Tensor([0], **kwargs).reshape([1]*len(shape)).expand(shape)
def zeros(*shape, **kwargs): return Tensor([0], **kwargs).reshape([1]*len(shape)).expand(shape).contiguous()
@staticmethod
def ones(*shape, **kwargs): return Tensor([1], **kwargs).reshape([1]*len(shape)).expand(shape)
def ones(*shape, **kwargs): return Tensor([1], **kwargs).reshape([1]*len(shape)).expand(shape).contiguous()
@staticmethod
def zeros_like(tensor, **kwargs): return Tensor.zeros(*tensor.shape, **kwargs)