From 2f17d151b33dee7672e535fcbb930bfdc369bb96 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 27 Feb 2023 10:19:54 -0800 Subject: [PATCH] fix batchnorm not realizing --- test/test_mnist.py | 20 ++++++++++++++++---- tinygrad/nn/__init__.py | 6 +++--- tinygrad/nn/optim.py | 4 +++- tinygrad/tensor.py | 4 ++-- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/test/test_mnist.py b/test/test_mnist.py index 544bf1eaa8..66086a02b7 100644 --- a/test/test_mnist.py +++ b/test/test_mnist.py @@ -2,7 +2,7 @@ import unittest import numpy as np from tinygrad.tensor import Tensor, Device -from tinygrad.nn import optim +from tinygrad.nn import optim, BatchNorm2d from extra.training import train, evaluate from datasets import fetch_mnist @@ -23,7 +23,7 @@ class TinyBobNet: # create a model with a conv layer class TinyConvNet: - def __init__(self): + def __init__(self, has_batchnorm=False): # https://keras.io/examples/vision/mnist_convnet/ conv = 3 #inter_chan, out_chan = 32, 64 @@ -31,14 +31,19 @@ class TinyConvNet: self.c1 = Tensor.scaled_uniform(inter_chan,1,conv,conv) self.c2 = Tensor.scaled_uniform(out_chan,inter_chan,conv,conv) self.l1 = Tensor.scaled_uniform(out_chan*5*5, 10) + if has_batchnorm: + self.bn1 = BatchNorm2d(inter_chan) + self.bn2 = BatchNorm2d(out_chan) + else: + self.bn1, self.bn2 = lambda x: x, lambda x: x def parameters(self): return optim.get_parameters(self) def forward(self, x): x = x.reshape(shape=(-1, 1, 28, 28)) # hacks - x = x.conv2d(self.c1).relu().max_pool2d() - x = x.conv2d(self.c2).relu().max_pool2d() + x = self.bn1(x.conv2d(self.c1)).relu().max_pool2d() + x = self.bn2(x.conv2d(self.c2)).relu().max_pool2d() x = x.reshape(shape=[x.shape[0], -1]) return x.dot(self.l1).log_softmax() @@ -89,6 +94,13 @@ class TestMNIST(unittest.TestCase): train(model, X_train, Y_train, optimizer, steps=100) assert evaluate(model, X_test, Y_test) > 0.94 # torch gets 0.9415 sometimes + def test_conv_with_bn(self): + np.random.seed(1337) + model = TinyConvNet(has_batchnorm=True) + optimizer = optim.Adam(model.parameters(), lr=0.001) + train(model, X_train, Y_train, optimizer, steps=100) + # TODO: batchnorm doesn't work!!! + def test_sgd(self): np.random.seed(1337) model = TinyBobNet() diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index af247bb6b2..6d203da21b 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -11,7 +11,7 @@ class BatchNorm2d: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False) self.num_batches_tracked = Tensor.zeros(1, requires_grad=False) - def __call__(self, x): + def __call__(self, x:Tensor): if Tensor.training: # This requires two full memory accesses to x # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh @@ -25,8 +25,8 @@ class BatchNorm2d: # NOTE: wow, this is done all throughout training in most PyTorch models if self.track_running_stats: - self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean - self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var + self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean) + self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * batch_var) self.num_batches_tracked += 1 else: batch_mean, batch_var = self.running_mean, self.running_var diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py index ce7fb19bb5..d2c68b9433 100644 --- a/tinygrad/nn/optim.py +++ b/tinygrad/nn/optim.py @@ -10,6 +10,8 @@ class Optimizer: x.requires_grad = True self.params : List[Tensor] = [x for x in params if x.requires_grad] + self.buffers : List[Tensor] = [x for x in params if not x.requires_grad] # buffers are still realized + self.realize() # TODO: this probably shouldn't change the gradients, just the ones used by the optimizer def clipnorm(self, amount=1): @@ -24,7 +26,7 @@ class Optimizer: def realize(self, extra=None): # TODO: corealize - for p in extra + self.params if extra is not None else self.params: + for p in extra + self.params + self.buffers if extra is not None else self.params + self.buffers: p.realize() class SGD(Optimizer): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index cc5d339365..d9ddc1ffd7 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -114,10 +114,10 @@ class Tensor: # TODO: remove use of numpy here and make lazy @staticmethod - def zeros(*shape, **kwargs): return Tensor([0], **kwargs).reshape([1]*len(shape)).expand(shape) + def zeros(*shape, **kwargs): return Tensor([0], **kwargs).reshape([1]*len(shape)).expand(shape).contiguous() @staticmethod - def ones(*shape, **kwargs): return Tensor([1], **kwargs).reshape([1]*len(shape)).expand(shape) + def ones(*shape, **kwargs): return Tensor([1], **kwargs).reshape([1]*len(shape)).expand(shape).contiguous() @staticmethod def zeros_like(tensor, **kwargs): return Tensor.zeros(*tensor.shape, **kwargs)