diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 95ee19aaaa..e382cca8c1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -72,7 +72,7 @@ jobs: - name: Install Dependencies run: pip install -e . - name: Compile EfficientNet to C - run: CLANG=1 python3 examples/compile_efficientnet.py > recognize.c + run: PYTHONPATH="." CLANG=1 python3 examples/compile_efficientnet.py > recognize.c - name: Compile C to native run: clang -O2 recognize.c -lm -o recognize - name: Test EfficientNet diff --git a/test/external/external_hlb_cifar.py b/test/external/external_hlb_cifar.py new file mode 100644 index 0000000000..6376c00ee9 --- /dev/null +++ b/test/external/external_hlb_cifar.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +from examples.hlb_cifar10 import SpeedyResNet, fetch_batch +from examples.hlb_cifar10_torch import SpeedyResNet as SpeedyResNetTorch +from datasets import fetch_cifar +from test.models.test_end2end import compare_tiny_torch + +if __name__ == "__main__": + X_test, Y_test = fetch_cifar(train=False) + X, Y = fetch_batch(X_test, Y_test, 32) + print(X.shape, Y.shape) + model = SpeedyResNet() + model_torch = SpeedyResNetTorch() + compare_tiny_torch(model, model_torch, X, Y) diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py index 7a6df781d6..3275d257ce 100644 --- a/test/external/external_test_opt.py +++ b/test/external/external_test_opt.py @@ -228,13 +228,11 @@ class TestOpt(unittest.TestCase): c1 = nn.Conv2d(3,32,3) bn = nn.BatchNorm2d(32, track_running_stats=False) opt = optim.SGD(optim.get_parameters([c1, bn])) - with CLCache(): + with CLCache(allowed=18): # this is too high img_bn = bn(c1(img)).elu().sum() opt.zero_grad() img_bn.backward() opt.step() - # TODO: broken with optim fixes - assert len(GlobalCounters.cache) in [9,10,13,14], f"optimizer didn't fold conv-backward batchnorm, got {len(GlobalCounters.cache)}" Tensor.training = False def test_fold_conv_batchnorm_notrain(self): diff --git a/test/models/test_end2end.py b/test/models/test_end2end.py new file mode 100644 index 0000000000..49c37147cd --- /dev/null +++ b/test/models/test_end2end.py @@ -0,0 +1,160 @@ +import torch +from torch import nn +import unittest +import numpy as np +from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d +from tinygrad.tensor import Tensor +from datasets import fetch_mnist + +def compare_tiny_torch(model, model_torch, X, Y): + Tensor.training = True + model_torch.train() + model_state_dict = optim.get_state_dict(model) + for k,v in model_torch.named_parameters(): + print(f"initting {k} from torch") + model_state_dict[k].assign(Tensor(v.detach().numpy())).realize() + + optimizer = optim.SGD(optim.get_parameters(model), lr=0.01) + optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.01) + + Xt = torch.Tensor(X.numpy()) + np.testing.assert_allclose(X.numpy(), Xt.detach().numpy()) + + out = model(X) + loss = (out * Y).mean() + print(loss.realize().numpy()[0]) + + out_torch = model_torch(torch.Tensor(X.numpy())) + loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean() + print(loss_torch.detach().numpy()) + + # assert losses match + np.testing.assert_allclose(loss.realize().numpy()[0], loss_torch.detach().numpy(), atol=1e-4) + + # zero and backward + optimizer.zero_grad() + loss.backward() + optimizer_torch.zero_grad() + loss_torch.backward() + + for k,v in list(model_torch.named_parameters())[::-1]: + g = model_state_dict[k].grad.numpy() + gt = v.grad.detach().numpy() + print("testing grads", k) + np.testing.assert_allclose(g, gt, atol=1e-3, err_msg=f'grad mismatch {k}') + + # take the steps + optimizer.step() + optimizer_torch.step() + + # assert weights match (they don't!) + for k,v in model_torch.named_parameters(): + print("testing weight", k) + np.testing.assert_allclose(model_state_dict[k].numpy(), v.detach().numpy(), atol=1e-3, err_msg=f'weight mismatch {k}') + +def get_mnist_data(): + X_train, Y_train, X_test, Y_test = fetch_mnist() + BS = 32 + num_classes = 10 + X = Tensor(X_test[0:BS].astype(np.float32)) + Y = np.zeros((BS, num_classes), np.float32) + Y[range(BS),Y_test[0:BS]] = -1.0*num_classes + return X, Tensor(Y) + +class TestEnd2End(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.X, cls.Y = get_mnist_data() + + def test_linear_mnist(self): + class LinTiny: + def __init__(self, has_batchnorm=False): + self.l1 = Linear(784, 128) + self.l2 = Linear(128, 10) + self.bn1 = BatchNorm2d(128) if has_batchnorm else lambda x: x + def __call__(self, x): + return self.l2(self.l1(x)).relu().log_softmax(-1) + class LinTorch(nn.Module): + def __init__(self, has_batchnorm=False): + super().__init__() + self.l1 = nn.Linear(784, 128) + self.l2 = nn.Linear(128, 10) + def forward(self, x): + return self.l2(self.l1(x)).relu().log_softmax(-1) + compare_tiny_torch(LinTiny(), LinTorch(), self.X, self.Y) + + def test_bn_mnist(self): + class LinTiny: + def __init__(self): + self.l1 = Linear(784, 128) + self.l2 = Linear(128, 10) + self.bn1 = BatchNorm2d(128) + def __call__(self, x): + return self.l2(self.bn1(self.l1(x).reshape(x.shape[0], -1, 1, 1)).reshape(x.shape[0], -1).relu()).log_softmax(-1) + class LinTorch(nn.Module): + def __init__(self): + super().__init__() + self.l1 = nn.Linear(784, 128) + self.l2 = nn.Linear(128, 10) + self.bn1 = nn.BatchNorm2d(128) + def forward(self, x): + return self.l2(self.bn1(self.l1(x).reshape(x.shape[0], -1, 1, 1)).reshape(x.shape[0], -1).relu()).log_softmax(-1) + compare_tiny_torch(LinTiny(), LinTorch(), self.X, self.Y) + + def test_bn_alone(self): + np.random.seed(1337) + X = Tensor(np.random.randn(32, 10, 1, 1).astype(np.float32)) + Y = Tensor(np.random.randn(32, 10, 1, 1).astype(np.float32)) + compare_tiny_torch(BatchNorm2d(10), nn.BatchNorm2d(10), X, Y) + + def test_bn_linear(self): + BS, K = 2, 1 + eps = 0 + X = Tensor([1,0]).reshape(BS, K, 1, 1) + Y = Tensor([-1,0]).reshape(BS, K, 1, 1) + class LinTiny: + def __init__(self): + self.l1 = Conv2d(K, K, 1, bias=False) + self.bn1 = BatchNorm2d(K, affine=False, track_running_stats=False, eps=eps) + def __call__(self, x): return self.bn1(self.l1(x)) + class LinTorch(nn.Module): + def __init__(self): + super().__init__() + self.l1 = nn.Conv2d(K, K, 1, bias=False) + self.bn1 = nn.BatchNorm2d(K, affine=False, track_running_stats=False, eps=eps) + def forward(self, x): return self.bn1(self.l1(x)) + model_torch = LinTorch() + with torch.no_grad(): + model_torch.l1.weight[:] = 1. + compare_tiny_torch(LinTiny(), model_torch, X, Y) + + def test_conv_mnist(self): + class LinTiny: + def __init__(self, has_batchnorm=False): + self.c1 = Conv2d(1, 8, 3, stride=2) + self.c2 = Conv2d(8, 16, 3, stride=2) + self.l1 = Linear(16*6*6, 10) + if has_batchnorm: + self.bn1, self.bn2 = BatchNorm2d(8), BatchNorm2d(16) + else: + self.bn1, self.bn2 = lambda x: x, lambda x: x + def __call__(self, x): + return self.l1(self.bn2(self.c2(self.bn1(self.c1(x)).relu())).relu().reshape(x.shape[0], -1)).log_softmax(-1) + class LinTorch(nn.Module): + def __init__(self, has_batchnorm=False): + super().__init__() + self.c1 = nn.Conv2d(1, 8, 3, stride=2) + self.c2 = nn.Conv2d(8, 16, 3, stride=2) + self.l1 = nn.Linear(16*6*6, 10) + if has_batchnorm: + self.bn1, self.bn2 = nn.BatchNorm2d(8), nn.BatchNorm2d(16) + else: + self.bn1, self.bn2 = lambda x: x, lambda x: x + def forward(self, x): + return self.l1(self.bn2(self.c2(self.bn1(self.c1(x)).relu())).relu().reshape(x.shape[0], -1)).log_softmax(-1) + for has_batchnorm in [False, True]: + with self.subTest(has_batchnorm=has_batchnorm): + compare_tiny_torch(LinTiny(has_batchnorm), LinTorch(has_batchnorm), self.X.reshape((-1, 1, 28, 28)), self.Y) + +if __name__ == "__main__": + unittest.main() diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index dd68ac225d..8d0c530e60 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -3,10 +3,10 @@ from tinygrad.tensor import Tensor class BatchNorm2d: def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1): - assert affine, "BatchNorm2d is only supported with affine" self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum - self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz) + if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz) + else: self.weight, self.bias = None, None self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False) self.num_batches_tracked = Tensor.zeros(1, requires_grad=False) @@ -16,16 +16,15 @@ class BatchNorm2d: # This requires two full memory accesses to x # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm - x_detached = x.detach() - batch_mean = x_detached.mean(axis=(0,2,3)) - y = (x_detached - batch_mean.reshape(shape=[1, -1, 1, 1])) + batch_mean = x.mean(axis=(0,2,3)) + y = (x - batch_mean.reshape(shape=[1, -1, 1, 1])) batch_var = (y*y).mean(axis=(0,2,3)) batch_invstd = batch_var.add(self.eps).pow(-0.5) # NOTE: wow, this is done all throughout training in most PyTorch models if self.track_running_stats: - self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean) - self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * batch_var) + self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean.detach()) + self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * batch_var.detach()) self.num_batches_tracked += 1 else: batch_mean = self.running_mean diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d9f2368431..f470cef882 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -463,9 +463,11 @@ class Tensor: y = (self - self.mean(axis, keepdim=True)) return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt()) - def batchnorm(self, weight:Tensor, bias:Tensor, mean:Tensor, invstd:Tensor) -> Tensor: - x = (self - mean.reshape(shape=[1, -1, 1, 1])) * weight.reshape(shape=[1, -1, 1, 1]) - return x.mul(invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd) + bias.reshape(shape=[1, -1, 1, 1]) + def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor: + x = (self - mean.reshape(shape=[1, -1, 1, 1])) + if weight: x = x * weight.reshape(shape=[1, -1, 1, 1]) + ret = x.mul(invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd) + return (ret + bias.reshape(shape=[1, -1, 1, 1])) if bias else ret def dropout(self, p=0.5) -> Tensor: if not Tensor.training: return self