mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix batchnorm at training (#753)
* e2e testing * min failure * no affine on bn, still fails * why did i think i could detach that? * allow more kernels for bn * some test issue i don't understand
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
||||
- name: Install Dependencies
|
||||
run: pip install -e .
|
||||
- name: Compile EfficientNet to C
|
||||
run: CLANG=1 python3 examples/compile_efficientnet.py > recognize.c
|
||||
run: PYTHONPATH="." CLANG=1 python3 examples/compile_efficientnet.py > recognize.c
|
||||
- name: Compile C to native
|
||||
run: clang -O2 recognize.c -lm -o recognize
|
||||
- name: Test EfficientNet
|
||||
|
||||
13
test/external/external_hlb_cifar.py
vendored
Normal file
13
test/external/external_hlb_cifar.py
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
from examples.hlb_cifar10 import SpeedyResNet, fetch_batch
|
||||
from examples.hlb_cifar10_torch import SpeedyResNet as SpeedyResNetTorch
|
||||
from datasets import fetch_cifar
|
||||
from test.models.test_end2end import compare_tiny_torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
X_test, Y_test = fetch_cifar(train=False)
|
||||
X, Y = fetch_batch(X_test, Y_test, 32)
|
||||
print(X.shape, Y.shape)
|
||||
model = SpeedyResNet()
|
||||
model_torch = SpeedyResNetTorch()
|
||||
compare_tiny_torch(model, model_torch, X, Y)
|
||||
4
test/external/external_test_opt.py
vendored
4
test/external/external_test_opt.py
vendored
@@ -228,13 +228,11 @@ class TestOpt(unittest.TestCase):
|
||||
c1 = nn.Conv2d(3,32,3)
|
||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||
opt = optim.SGD(optim.get_parameters([c1, bn]))
|
||||
with CLCache():
|
||||
with CLCache(allowed=18): # this is too high
|
||||
img_bn = bn(c1(img)).elu().sum()
|
||||
opt.zero_grad()
|
||||
img_bn.backward()
|
||||
opt.step()
|
||||
# TODO: broken with optim fixes
|
||||
assert len(GlobalCounters.cache) in [9,10,13,14], f"optimizer didn't fold conv-backward batchnorm, got {len(GlobalCounters.cache)}"
|
||||
Tensor.training = False
|
||||
|
||||
def test_fold_conv_batchnorm_notrain(self):
|
||||
|
||||
160
test/models/test_end2end.py
Normal file
160
test/models/test_end2end.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d
|
||||
from tinygrad.tensor import Tensor
|
||||
from datasets import fetch_mnist
|
||||
|
||||
def compare_tiny_torch(model, model_torch, X, Y):
|
||||
Tensor.training = True
|
||||
model_torch.train()
|
||||
model_state_dict = optim.get_state_dict(model)
|
||||
for k,v in model_torch.named_parameters():
|
||||
print(f"initting {k} from torch")
|
||||
model_state_dict[k].assign(Tensor(v.detach().numpy())).realize()
|
||||
|
||||
optimizer = optim.SGD(optim.get_parameters(model), lr=0.01)
|
||||
optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.01)
|
||||
|
||||
Xt = torch.Tensor(X.numpy())
|
||||
np.testing.assert_allclose(X.numpy(), Xt.detach().numpy())
|
||||
|
||||
out = model(X)
|
||||
loss = (out * Y).mean()
|
||||
print(loss.realize().numpy()[0])
|
||||
|
||||
out_torch = model_torch(torch.Tensor(X.numpy()))
|
||||
loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean()
|
||||
print(loss_torch.detach().numpy())
|
||||
|
||||
# assert losses match
|
||||
np.testing.assert_allclose(loss.realize().numpy()[0], loss_torch.detach().numpy(), atol=1e-4)
|
||||
|
||||
# zero and backward
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer_torch.zero_grad()
|
||||
loss_torch.backward()
|
||||
|
||||
for k,v in list(model_torch.named_parameters())[::-1]:
|
||||
g = model_state_dict[k].grad.numpy()
|
||||
gt = v.grad.detach().numpy()
|
||||
print("testing grads", k)
|
||||
np.testing.assert_allclose(g, gt, atol=1e-3, err_msg=f'grad mismatch {k}')
|
||||
|
||||
# take the steps
|
||||
optimizer.step()
|
||||
optimizer_torch.step()
|
||||
|
||||
# assert weights match (they don't!)
|
||||
for k,v in model_torch.named_parameters():
|
||||
print("testing weight", k)
|
||||
np.testing.assert_allclose(model_state_dict[k].numpy(), v.detach().numpy(), atol=1e-3, err_msg=f'weight mismatch {k}')
|
||||
|
||||
def get_mnist_data():
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
BS = 32
|
||||
num_classes = 10
|
||||
X = Tensor(X_test[0:BS].astype(np.float32))
|
||||
Y = np.zeros((BS, num_classes), np.float32)
|
||||
Y[range(BS),Y_test[0:BS]] = -1.0*num_classes
|
||||
return X, Tensor(Y)
|
||||
|
||||
class TestEnd2End(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.X, cls.Y = get_mnist_data()
|
||||
|
||||
def test_linear_mnist(self):
|
||||
class LinTiny:
|
||||
def __init__(self, has_batchnorm=False):
|
||||
self.l1 = Linear(784, 128)
|
||||
self.l2 = Linear(128, 10)
|
||||
self.bn1 = BatchNorm2d(128) if has_batchnorm else lambda x: x
|
||||
def __call__(self, x):
|
||||
return self.l2(self.l1(x)).relu().log_softmax(-1)
|
||||
class LinTorch(nn.Module):
|
||||
def __init__(self, has_batchnorm=False):
|
||||
super().__init__()
|
||||
self.l1 = nn.Linear(784, 128)
|
||||
self.l2 = nn.Linear(128, 10)
|
||||
def forward(self, x):
|
||||
return self.l2(self.l1(x)).relu().log_softmax(-1)
|
||||
compare_tiny_torch(LinTiny(), LinTorch(), self.X, self.Y)
|
||||
|
||||
def test_bn_mnist(self):
|
||||
class LinTiny:
|
||||
def __init__(self):
|
||||
self.l1 = Linear(784, 128)
|
||||
self.l2 = Linear(128, 10)
|
||||
self.bn1 = BatchNorm2d(128)
|
||||
def __call__(self, x):
|
||||
return self.l2(self.bn1(self.l1(x).reshape(x.shape[0], -1, 1, 1)).reshape(x.shape[0], -1).relu()).log_softmax(-1)
|
||||
class LinTorch(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.l1 = nn.Linear(784, 128)
|
||||
self.l2 = nn.Linear(128, 10)
|
||||
self.bn1 = nn.BatchNorm2d(128)
|
||||
def forward(self, x):
|
||||
return self.l2(self.bn1(self.l1(x).reshape(x.shape[0], -1, 1, 1)).reshape(x.shape[0], -1).relu()).log_softmax(-1)
|
||||
compare_tiny_torch(LinTiny(), LinTorch(), self.X, self.Y)
|
||||
|
||||
def test_bn_alone(self):
|
||||
np.random.seed(1337)
|
||||
X = Tensor(np.random.randn(32, 10, 1, 1).astype(np.float32))
|
||||
Y = Tensor(np.random.randn(32, 10, 1, 1).astype(np.float32))
|
||||
compare_tiny_torch(BatchNorm2d(10), nn.BatchNorm2d(10), X, Y)
|
||||
|
||||
def test_bn_linear(self):
|
||||
BS, K = 2, 1
|
||||
eps = 0
|
||||
X = Tensor([1,0]).reshape(BS, K, 1, 1)
|
||||
Y = Tensor([-1,0]).reshape(BS, K, 1, 1)
|
||||
class LinTiny:
|
||||
def __init__(self):
|
||||
self.l1 = Conv2d(K, K, 1, bias=False)
|
||||
self.bn1 = BatchNorm2d(K, affine=False, track_running_stats=False, eps=eps)
|
||||
def __call__(self, x): return self.bn1(self.l1(x))
|
||||
class LinTorch(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.l1 = nn.Conv2d(K, K, 1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(K, affine=False, track_running_stats=False, eps=eps)
|
||||
def forward(self, x): return self.bn1(self.l1(x))
|
||||
model_torch = LinTorch()
|
||||
with torch.no_grad():
|
||||
model_torch.l1.weight[:] = 1.
|
||||
compare_tiny_torch(LinTiny(), model_torch, X, Y)
|
||||
|
||||
def test_conv_mnist(self):
|
||||
class LinTiny:
|
||||
def __init__(self, has_batchnorm=False):
|
||||
self.c1 = Conv2d(1, 8, 3, stride=2)
|
||||
self.c2 = Conv2d(8, 16, 3, stride=2)
|
||||
self.l1 = Linear(16*6*6, 10)
|
||||
if has_batchnorm:
|
||||
self.bn1, self.bn2 = BatchNorm2d(8), BatchNorm2d(16)
|
||||
else:
|
||||
self.bn1, self.bn2 = lambda x: x, lambda x: x
|
||||
def __call__(self, x):
|
||||
return self.l1(self.bn2(self.c2(self.bn1(self.c1(x)).relu())).relu().reshape(x.shape[0], -1)).log_softmax(-1)
|
||||
class LinTorch(nn.Module):
|
||||
def __init__(self, has_batchnorm=False):
|
||||
super().__init__()
|
||||
self.c1 = nn.Conv2d(1, 8, 3, stride=2)
|
||||
self.c2 = nn.Conv2d(8, 16, 3, stride=2)
|
||||
self.l1 = nn.Linear(16*6*6, 10)
|
||||
if has_batchnorm:
|
||||
self.bn1, self.bn2 = nn.BatchNorm2d(8), nn.BatchNorm2d(16)
|
||||
else:
|
||||
self.bn1, self.bn2 = lambda x: x, lambda x: x
|
||||
def forward(self, x):
|
||||
return self.l1(self.bn2(self.c2(self.bn1(self.c1(x)).relu())).relu().reshape(x.shape[0], -1)).log_softmax(-1)
|
||||
for has_batchnorm in [False, True]:
|
||||
with self.subTest(has_batchnorm=has_batchnorm):
|
||||
compare_tiny_torch(LinTiny(has_batchnorm), LinTorch(has_batchnorm), self.X.reshape((-1, 1, 28, 28)), self.Y)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -3,10 +3,10 @@ from tinygrad.tensor import Tensor
|
||||
|
||||
class BatchNorm2d:
|
||||
def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
|
||||
assert affine, "BatchNorm2d is only supported with affine"
|
||||
self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
|
||||
|
||||
self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
|
||||
if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
|
||||
else: self.weight, self.bias = None, None
|
||||
|
||||
self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
|
||||
self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
|
||||
@@ -16,16 +16,15 @@ class BatchNorm2d:
|
||||
# This requires two full memory accesses to x
|
||||
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
|
||||
# There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
|
||||
x_detached = x.detach()
|
||||
batch_mean = x_detached.mean(axis=(0,2,3))
|
||||
y = (x_detached - batch_mean.reshape(shape=[1, -1, 1, 1]))
|
||||
batch_mean = x.mean(axis=(0,2,3))
|
||||
y = (x - batch_mean.reshape(shape=[1, -1, 1, 1]))
|
||||
batch_var = (y*y).mean(axis=(0,2,3))
|
||||
batch_invstd = batch_var.add(self.eps).pow(-0.5)
|
||||
|
||||
# NOTE: wow, this is done all throughout training in most PyTorch models
|
||||
if self.track_running_stats:
|
||||
self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean)
|
||||
self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * batch_var)
|
||||
self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
|
||||
self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * batch_var.detach())
|
||||
self.num_batches_tracked += 1
|
||||
else:
|
||||
batch_mean = self.running_mean
|
||||
|
||||
@@ -463,9 +463,11 @@ class Tensor:
|
||||
y = (self - self.mean(axis, keepdim=True))
|
||||
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
|
||||
|
||||
def batchnorm(self, weight:Tensor, bias:Tensor, mean:Tensor, invstd:Tensor) -> Tensor:
|
||||
x = (self - mean.reshape(shape=[1, -1, 1, 1])) * weight.reshape(shape=[1, -1, 1, 1])
|
||||
return x.mul(invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd) + bias.reshape(shape=[1, -1, 1, 1])
|
||||
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor:
|
||||
x = (self - mean.reshape(shape=[1, -1, 1, 1]))
|
||||
if weight: x = x * weight.reshape(shape=[1, -1, 1, 1])
|
||||
ret = x.mul(invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd)
|
||||
return (ret + bias.reshape(shape=[1, -1, 1, 1])) if bias else ret
|
||||
|
||||
def dropout(self, p=0.5) -> Tensor:
|
||||
if not Tensor.training: return self
|
||||
|
||||
Reference in New Issue
Block a user