fix batchnorm at training (#753)

* e2e testing

* min failure

* no affine on bn, still fails

* why did i think i could detach that?

* allow more kernels for bn

* some test issue i don't understand
This commit is contained in:
George Hotz
2023-04-19 08:01:04 -07:00
committed by GitHub
parent 1aa0648d6a
commit 03b38864db
6 changed files with 186 additions and 14 deletions

View File

@@ -72,7 +72,7 @@ jobs:
- name: Install Dependencies
run: pip install -e .
- name: Compile EfficientNet to C
run: CLANG=1 python3 examples/compile_efficientnet.py > recognize.c
run: PYTHONPATH="." CLANG=1 python3 examples/compile_efficientnet.py > recognize.c
- name: Compile C to native
run: clang -O2 recognize.c -lm -o recognize
- name: Test EfficientNet

13
test/external/external_hlb_cifar.py vendored Normal file
View File

@@ -0,0 +1,13 @@
#!/usr/bin/env python3
from examples.hlb_cifar10 import SpeedyResNet, fetch_batch
from examples.hlb_cifar10_torch import SpeedyResNet as SpeedyResNetTorch
from datasets import fetch_cifar
from test.models.test_end2end import compare_tiny_torch
if __name__ == "__main__":
X_test, Y_test = fetch_cifar(train=False)
X, Y = fetch_batch(X_test, Y_test, 32)
print(X.shape, Y.shape)
model = SpeedyResNet()
model_torch = SpeedyResNetTorch()
compare_tiny_torch(model, model_torch, X, Y)

View File

@@ -228,13 +228,11 @@ class TestOpt(unittest.TestCase):
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
opt = optim.SGD(optim.get_parameters([c1, bn]))
with CLCache():
with CLCache(allowed=18): # this is too high
img_bn = bn(c1(img)).elu().sum()
opt.zero_grad()
img_bn.backward()
opt.step()
# TODO: broken with optim fixes
assert len(GlobalCounters.cache) in [9,10,13,14], f"optimizer didn't fold conv-backward batchnorm, got {len(GlobalCounters.cache)}"
Tensor.training = False
def test_fold_conv_batchnorm_notrain(self):

160
test/models/test_end2end.py Normal file
View File

@@ -0,0 +1,160 @@
import torch
from torch import nn
import unittest
import numpy as np
from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d
from tinygrad.tensor import Tensor
from datasets import fetch_mnist
def compare_tiny_torch(model, model_torch, X, Y):
Tensor.training = True
model_torch.train()
model_state_dict = optim.get_state_dict(model)
for k,v in model_torch.named_parameters():
print(f"initting {k} from torch")
model_state_dict[k].assign(Tensor(v.detach().numpy())).realize()
optimizer = optim.SGD(optim.get_parameters(model), lr=0.01)
optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.01)
Xt = torch.Tensor(X.numpy())
np.testing.assert_allclose(X.numpy(), Xt.detach().numpy())
out = model(X)
loss = (out * Y).mean()
print(loss.realize().numpy()[0])
out_torch = model_torch(torch.Tensor(X.numpy()))
loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean()
print(loss_torch.detach().numpy())
# assert losses match
np.testing.assert_allclose(loss.realize().numpy()[0], loss_torch.detach().numpy(), atol=1e-4)
# zero and backward
optimizer.zero_grad()
loss.backward()
optimizer_torch.zero_grad()
loss_torch.backward()
for k,v in list(model_torch.named_parameters())[::-1]:
g = model_state_dict[k].grad.numpy()
gt = v.grad.detach().numpy()
print("testing grads", k)
np.testing.assert_allclose(g, gt, atol=1e-3, err_msg=f'grad mismatch {k}')
# take the steps
optimizer.step()
optimizer_torch.step()
# assert weights match (they don't!)
for k,v in model_torch.named_parameters():
print("testing weight", k)
np.testing.assert_allclose(model_state_dict[k].numpy(), v.detach().numpy(), atol=1e-3, err_msg=f'weight mismatch {k}')
def get_mnist_data():
X_train, Y_train, X_test, Y_test = fetch_mnist()
BS = 32
num_classes = 10
X = Tensor(X_test[0:BS].astype(np.float32))
Y = np.zeros((BS, num_classes), np.float32)
Y[range(BS),Y_test[0:BS]] = -1.0*num_classes
return X, Tensor(Y)
class TestEnd2End(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.X, cls.Y = get_mnist_data()
def test_linear_mnist(self):
class LinTiny:
def __init__(self, has_batchnorm=False):
self.l1 = Linear(784, 128)
self.l2 = Linear(128, 10)
self.bn1 = BatchNorm2d(128) if has_batchnorm else lambda x: x
def __call__(self, x):
return self.l2(self.l1(x)).relu().log_softmax(-1)
class LinTorch(nn.Module):
def __init__(self, has_batchnorm=False):
super().__init__()
self.l1 = nn.Linear(784, 128)
self.l2 = nn.Linear(128, 10)
def forward(self, x):
return self.l2(self.l1(x)).relu().log_softmax(-1)
compare_tiny_torch(LinTiny(), LinTorch(), self.X, self.Y)
def test_bn_mnist(self):
class LinTiny:
def __init__(self):
self.l1 = Linear(784, 128)
self.l2 = Linear(128, 10)
self.bn1 = BatchNorm2d(128)
def __call__(self, x):
return self.l2(self.bn1(self.l1(x).reshape(x.shape[0], -1, 1, 1)).reshape(x.shape[0], -1).relu()).log_softmax(-1)
class LinTorch(nn.Module):
def __init__(self):
super().__init__()
self.l1 = nn.Linear(784, 128)
self.l2 = nn.Linear(128, 10)
self.bn1 = nn.BatchNorm2d(128)
def forward(self, x):
return self.l2(self.bn1(self.l1(x).reshape(x.shape[0], -1, 1, 1)).reshape(x.shape[0], -1).relu()).log_softmax(-1)
compare_tiny_torch(LinTiny(), LinTorch(), self.X, self.Y)
def test_bn_alone(self):
np.random.seed(1337)
X = Tensor(np.random.randn(32, 10, 1, 1).astype(np.float32))
Y = Tensor(np.random.randn(32, 10, 1, 1).astype(np.float32))
compare_tiny_torch(BatchNorm2d(10), nn.BatchNorm2d(10), X, Y)
def test_bn_linear(self):
BS, K = 2, 1
eps = 0
X = Tensor([1,0]).reshape(BS, K, 1, 1)
Y = Tensor([-1,0]).reshape(BS, K, 1, 1)
class LinTiny:
def __init__(self):
self.l1 = Conv2d(K, K, 1, bias=False)
self.bn1 = BatchNorm2d(K, affine=False, track_running_stats=False, eps=eps)
def __call__(self, x): return self.bn1(self.l1(x))
class LinTorch(nn.Module):
def __init__(self):
super().__init__()
self.l1 = nn.Conv2d(K, K, 1, bias=False)
self.bn1 = nn.BatchNorm2d(K, affine=False, track_running_stats=False, eps=eps)
def forward(self, x): return self.bn1(self.l1(x))
model_torch = LinTorch()
with torch.no_grad():
model_torch.l1.weight[:] = 1.
compare_tiny_torch(LinTiny(), model_torch, X, Y)
def test_conv_mnist(self):
class LinTiny:
def __init__(self, has_batchnorm=False):
self.c1 = Conv2d(1, 8, 3, stride=2)
self.c2 = Conv2d(8, 16, 3, stride=2)
self.l1 = Linear(16*6*6, 10)
if has_batchnorm:
self.bn1, self.bn2 = BatchNorm2d(8), BatchNorm2d(16)
else:
self.bn1, self.bn2 = lambda x: x, lambda x: x
def __call__(self, x):
return self.l1(self.bn2(self.c2(self.bn1(self.c1(x)).relu())).relu().reshape(x.shape[0], -1)).log_softmax(-1)
class LinTorch(nn.Module):
def __init__(self, has_batchnorm=False):
super().__init__()
self.c1 = nn.Conv2d(1, 8, 3, stride=2)
self.c2 = nn.Conv2d(8, 16, 3, stride=2)
self.l1 = nn.Linear(16*6*6, 10)
if has_batchnorm:
self.bn1, self.bn2 = nn.BatchNorm2d(8), nn.BatchNorm2d(16)
else:
self.bn1, self.bn2 = lambda x: x, lambda x: x
def forward(self, x):
return self.l1(self.bn2(self.c2(self.bn1(self.c1(x)).relu())).relu().reshape(x.shape[0], -1)).log_softmax(-1)
for has_batchnorm in [False, True]:
with self.subTest(has_batchnorm=has_batchnorm):
compare_tiny_torch(LinTiny(has_batchnorm), LinTorch(has_batchnorm), self.X.reshape((-1, 1, 28, 28)), self.Y)
if __name__ == "__main__":
unittest.main()

View File

@@ -3,10 +3,10 @@ from tinygrad.tensor import Tensor
class BatchNorm2d:
def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
assert affine, "BatchNorm2d is only supported with affine"
self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
else: self.weight, self.bias = None, None
self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
@@ -16,16 +16,15 @@ class BatchNorm2d:
# This requires two full memory accesses to x
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
# There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
x_detached = x.detach()
batch_mean = x_detached.mean(axis=(0,2,3))
y = (x_detached - batch_mean.reshape(shape=[1, -1, 1, 1]))
batch_mean = x.mean(axis=(0,2,3))
y = (x - batch_mean.reshape(shape=[1, -1, 1, 1]))
batch_var = (y*y).mean(axis=(0,2,3))
batch_invstd = batch_var.add(self.eps).pow(-0.5)
# NOTE: wow, this is done all throughout training in most PyTorch models
if self.track_running_stats:
self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean)
self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * batch_var)
self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * batch_var.detach())
self.num_batches_tracked += 1
else:
batch_mean = self.running_mean

View File

@@ -463,9 +463,11 @@ class Tensor:
y = (self - self.mean(axis, keepdim=True))
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
def batchnorm(self, weight:Tensor, bias:Tensor, mean:Tensor, invstd:Tensor) -> Tensor:
x = (self - mean.reshape(shape=[1, -1, 1, 1])) * weight.reshape(shape=[1, -1, 1, 1])
return x.mul(invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd) + bias.reshape(shape=[1, -1, 1, 1])
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor:
x = (self - mean.reshape(shape=[1, -1, 1, 1]))
if weight: x = x * weight.reshape(shape=[1, -1, 1, 1])
ret = x.mul(invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd)
return (ret + bias.reshape(shape=[1, -1, 1, 1])) if bias else ret
def dropout(self, p=0.5) -> Tensor:
if not Tensor.training: return self