fix batchnorm not realizing
@@ -2,7 +2,7 @@
 import unittest
 import numpy as np
 from tinygrad.tensor import Tensor, Device
-from tinygrad.nn import optim
+from tinygrad.nn import optim, BatchNorm2d
 from extra.training import train, evaluate
 from datasets import fetch_mnist
 
@@ -23,7 +23,7 @@ class TinyBobNet:
 
 # create a model with a conv layer
 class TinyConvNet:
-  def __init__(self):
+  def __init__(self, has_batchnorm=False):
     # https://keras.io/examples/vision/mnist_convnet/
     conv = 3
     #inter_chan, out_chan = 32, 64
@@ -31,14 +31,19 @@ class TinyConvNet:
     self.c1 = Tensor.scaled_uniform(inter_chan,1,conv,conv)
     self.c2 = Tensor.scaled_uniform(out_chan,inter_chan,conv,conv)
     self.l1 = Tensor.scaled_uniform(out_chan*5*5, 10)
+    if has_batchnorm:
+      self.bn1 = BatchNorm2d(inter_chan)
+      self.bn2 = BatchNorm2d(out_chan)
+    else:
+      self.bn1, self.bn2 = lambda x: x, lambda x: x
 
   def parameters(self):
     return optim.get_parameters(self)
 
   def forward(self, x):
     x = x.reshape(shape=(-1, 1, 28, 28)) # hacks
-    x = x.conv2d(self.c1).relu().max_pool2d()
-    x = x.conv2d(self.c2).relu().max_pool2d()
+    x = self.bn1(x.conv2d(self.c1)).relu().max_pool2d()
+    x = self.bn2(x.conv2d(self.c2)).relu().max_pool2d()
     x = x.reshape(shape=[x.shape[0], -1])
     return x.dot(self.l1).log_softmax()
 
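When has_batchnorm is false, the constructor substitutes identity lambdas for self.bn1/self.bn2, so forward() can call them unconditionally. The sketch below is not part of the commit; it only illustrates how the batchnorm variant would be exercised by hand. Tensor.training gates the running-stat branch changed further down, and the all-zeros input is just a shape placeholder.

# Hedged usage sketch, not part of the diff: one forward pass through the
# batchnorm variant of TinyConvNet. The zero-filled input only fixes shapes.
model = TinyConvNet(has_batchnorm=True)
x = Tensor.zeros(32, 28*28)     # batch of 32 flattened 28x28 images
Tensor.training = True          # enables the running-stat update in BatchNorm2d
out = model.forward(x)          # (32, 10) log-probabilities
out.realize()                   # force evaluation of the lazy graph
Tensor.training = False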
@@ -89,6 +94,13 @@ class TestMNIST(unittest.TestCase):
     train(model, X_train, Y_train, optimizer, steps=100)
     assert evaluate(model, X_test, Y_test) > 0.94 # torch gets 0.9415 sometimes
 
+  def test_conv_with_bn(self):
+    np.random.seed(1337)
+    model = TinyConvNet(has_batchnorm=True)
+    optimizer = optim.Adam(model.parameters(), lr=0.001)
+    train(model, X_train, Y_train, optimizer, steps=100)
+    # TODO: batchnorm doesn't work!!!
+
   def test_sgd(self):
     np.random.seed(1337)
     model = TinyBobNet()
@@ -11,7 +11,7 @@ class BatchNorm2d:
     self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
     self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
 
-  def __call__(self, x):
+  def __call__(self, x:Tensor):
     if Tensor.training:
       # This requires two full memory accesses to x
       # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
@@ -25,8 +25,8 @@ class BatchNorm2d:
 
       # NOTE: wow, this is done all throughout training in most PyTorch models
       if self.track_running_stats:
-        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
-        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
+        self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean)
+        self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * batch_var)
         self.num_batches_tracked += 1
     else:
       batch_mean, batch_var = self.running_mean, self.running_var
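The switch from rebinding to .assign() is the heart of the fix: the optimizer holds references to the buffer Tensors it was handed, so rebinding self.running_mean to a brand-new Tensor each step leaves the tracked buffer stale and never realized. A plain-Python sketch of the difference (stand-in classes, not tinygrad):

# Plain-Python sketch (not tinygrad): whoever tracks the original buffer object
# only sees updates that happen in place on that same object.
class FakeBuffer:
  def __init__(self, value): self.value = value
  def assign(self, value): self.value = value   # in-place update, same object

class Layer:
  def __init__(self): self.running_mean = FakeBuffer(0.0)

layer = Layer()
tracked = layer.running_mean              # what the optimizer holds on to

layer.running_mean = FakeBuffer(1.0)      # old code: rebind the attribute
print(tracked.value)                      # 0.0 -- the tracked buffer went stale

layer.running_mean = tracked              # reset for the second variant
layer.running_mean.assign(1.0)            # new code: update in place
print(tracked.value)                      # 1.0 -- the tracked buffer sees it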
@@ -10,6 +10,8 @@ class Optimizer:
         x.requires_grad = True
 
     self.params : List[Tensor] = [x for x in params if x.requires_grad]
+    self.buffers : List[Tensor] = [x for x in params if not x.requires_grad]  # buffers are still realized
+    self.realize()
 
   # TODO: this probably shouldn't change the gradients, just the ones used by the optimizer
   def clipnorm(self, amount=1):
@@ -24,7 +26,7 @@ class Optimizer:
 
   def realize(self, extra=None):
     # TODO: corealize
-    for p in extra + self.params if extra is not None else self.params:
+    for p in extra + self.params + self.buffers if extra is not None else self.params + self.buffers:
      p.realize()
 
 class SGD(Optimizer):
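With the non-trainable buffers tracked separately, every realize() call now forces the batchnorm running stats as well, so the lazy .assign() updates above are actually evaluated each step instead of piling up. A hedged sketch of the resulting flow (the model and the Adam call come from this diff; the assumption that optim.get_parameters returns every Tensor on the model, buffers included, is mine):

# Hedged sketch, not part of the diff: how the buffer tracking plays out.
# optim.get_parameters(model) is assumed to return every Tensor on the model,
# trainable weights and requires_grad=False buffers alike; Optimizer.__init__
# then splits that list on requires_grad as shown above.
model = TinyConvNet(has_batchnorm=True)
params = optim.get_parameters(model)
opt = optim.Adam(params, lr=0.001)   # __init__ now calls self.realize() once

# ... a forward pass in training mode updates running_mean/var via .assign() ...

opt.realize()                        # walks self.params + self.buffers,
                                     # so the running stats are computed too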
@@ -114,10 +114,10 @@ class Tensor:
   # TODO: remove use of numpy here and make lazy
 
   @staticmethod
-  def zeros(*shape, **kwargs): return Tensor([0], **kwargs).reshape([1]*len(shape)).expand(shape)
+  def zeros(*shape, **kwargs): return Tensor([0], **kwargs).reshape([1]*len(shape)).expand(shape).contiguous()
 
   @staticmethod
-  def ones(*shape, **kwargs): return Tensor([1], **kwargs).reshape([1]*len(shape)).expand(shape)
+  def ones(*shape, **kwargs): return Tensor([1], **kwargs).reshape([1]*len(shape)).expand(shape).contiguous()
 
   @staticmethod
   def zeros_like(tensor, **kwargs): return Tensor.zeros(*tensor.shape, **kwargs)
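The added .contiguous() matters because expand() only broadcasts the single stored element to the requested shape as a view; it does not allocate a full buffer. The running-stat buffers created with Tensor.zeros/Tensor.ones need real, independently materialized storage for the .assign() updates above to land somewhere. A hedged sketch (the momentum numbers and batch_mean are placeholders):

# Hedged sketch, not part of the diff: Tensor.zeros(sz) used to be a broadcast
# view of one element; with .contiguous() it is materialized into a real
# sz-element buffer that later .assign() calls can overwrite.
sz = 64
running_mean = Tensor.zeros(sz, requires_grad=False)        # realized 64-float buffer
batch_mean = Tensor.ones(sz)                                # placeholder for a real batch mean
running_mean.assign(0.9 * running_mean + 0.1 * batch_mean)  # momentum-style update
running_mean.realize()                                      # force the lazy update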