tinygrad/test/models/test_end2end.py
cheeetoo a0965ee198 CI < 5 minutes (#1252)
2023-07-23 13:00:56 -07:00

import torch
from torch import nn
import unittest
import numpy as np
from tinygrad.state import get_parameters, get_state_dict
from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d
from tinygrad.tensor import Tensor
from extra.datasets import fetch_mnist
from tinygrad.helpers import CI
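
# End-to-end parity tests: initialize a tinygrad model from the weights of an
# equivalent torch model, run one forward/backward/SGD step on both, and assert
# that the loss, every gradient, and every updated weight agree within tolerance.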
def compare_tiny_torch(model, model_torch, X, Y):
  Tensor.training = True
  model_torch.train()

  # copy the torch weights into the tinygrad model so both start identical
  model_state_dict = get_state_dict(model)
  for k,v in model_torch.named_parameters():
    if not CI: print(f"initting {k} from torch")
    model_state_dict[k].assign(Tensor(v.detach().numpy())).realize()

  optimizer = optim.SGD(get_parameters(model), lr=0.01)
  optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.01)

  Xt = torch.Tensor(X.numpy())
  np.testing.assert_allclose(X.numpy(), Xt.detach().numpy())

  out = model(X)
  loss = (out * Y).mean()
  if not CI: print(loss.realize().numpy())

  out_torch = model_torch(torch.Tensor(X.numpy()))
  loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean()
  if not CI: print(loss_torch.detach().numpy())

  # assert losses match
  np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4)

  # zero and backward
  optimizer.zero_grad()
  loss.backward()
  optimizer_torch.zero_grad()
  loss_torch.backward()

  # assert gradients match
  for k,v in list(model_torch.named_parameters())[::-1]:
    g = model_state_dict[k].grad.numpy()
    gt = v.grad.detach().numpy()
    if not CI: print("testing grads", k)
    np.testing.assert_allclose(g, gt, atol=1e-3, err_msg=f'grad mismatch {k}')

  # take the steps
  optimizer.step()
  optimizer_torch.step()

  # assert weights match after the step
  for k,v in model_torch.named_parameters():
    if not CI: print("testing weight", k)
    np.testing.assert_allclose(model_state_dict[k].numpy(), v.detach().numpy(), atol=1e-3, err_msg=f'weight mismatch {k}')
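
# compare_tiny_torch works for any tinygrad/torch module pair whose parameter
# names and shapes line up under get_state_dict/named_parameters. A minimal
# sketch (hypothetical, not one of the tests below):
#   compare_tiny_torch(Linear(4, 2), nn.Linear(4, 2),
#                      Tensor(np.random.randn(8, 4).astype(np.float32)),
#                      Tensor(np.random.randn(8, 2).astype(np.float32)))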

def get_mnist_data():
  X_train, Y_train, X_test, Y_test = fetch_mnist()
  BS = 32
  num_classes = 10
  X = Tensor(X_test[0:BS].astype(np.float32))
  Y = np.zeros((BS, num_classes), np.float32)
  Y[range(BS), Y_test[0:BS]] = -1.0*num_classes
  return X, Tensor(Y)
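
# Why the -num_classes scaling in Y works: with out = log_softmax(logits),
#   (out * Y).mean() = (1/(BS*num_classes)) * sum_b (-num_classes) * out[b, y_b]
#                    = (1/BS) * sum_b -log p(y_b)
# i.e. exactly the mean negative log-likelihood, computed identically on the
# torch side of compare_tiny_torch.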

class TestEnd2End(unittest.TestCase):
  @classmethod
  def setUpClass(cls):
    cls.X, cls.Y = get_mnist_data()

  def test_linear_mnist(self):
    class LinTiny:
      def __init__(self, has_batchnorm=False):
        self.l1 = Linear(784, 128)
        self.l2 = Linear(128, 10)
        self.bn1 = BatchNorm2d(128) if has_batchnorm else lambda x: x
      def __call__(self, x):
        return self.l2(self.l1(x)).relu().log_softmax(-1)
    class LinTorch(nn.Module):
      def __init__(self, has_batchnorm=False):
        super().__init__()
        self.l1 = nn.Linear(784, 128)
        self.l2 = nn.Linear(128, 10)
      def forward(self, x):
        return self.l2(self.l1(x)).relu().log_softmax(-1)
    compare_tiny_torch(LinTiny(), LinTorch(), self.X, self.Y)

  def test_bn_mnist(self):
    class LinTiny:
      def __init__(self):
        self.l1 = Linear(784, 128)
        self.l2 = Linear(128, 10)
        self.bn1 = BatchNorm2d(128)
      def __call__(self, x):
        return self.l2(self.bn1(self.l1(x).reshape(x.shape[0], -1, 1, 1)).reshape(x.shape[0], -1).relu()).log_softmax(-1)
    class LinTorch(nn.Module):
      def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(784, 128)
        self.l2 = nn.Linear(128, 10)
        self.bn1 = nn.BatchNorm2d(128)
      def forward(self, x):
        return self.l2(self.bn1(self.l1(x).reshape(x.shape[0], -1, 1, 1)).reshape(x.shape[0], -1).relu()).log_softmax(-1)
    compare_tiny_torch(LinTiny(), LinTorch(), self.X, self.Y)

  def test_bn_alone(self):
    np.random.seed(1337)
    X = Tensor(np.random.randn(32, 10, 1, 1).astype(np.float32))
    Y = Tensor(np.random.randn(32, 10, 1, 1).astype(np.float32))
    compare_tiny_torch(BatchNorm2d(10), nn.BatchNorm2d(10), X, Y)

  def test_bn_linear(self):
    BS, K = 2, 1
    eps = 0
    X = Tensor([1,0]).reshape(BS, K, 1, 1)
    Y = Tensor([-1,0]).reshape(BS, K, 1, 1)
    class LinTiny:
      def __init__(self):
        self.l1 = Conv2d(K, K, 1, bias=False)
        self.bn1 = BatchNorm2d(K, affine=False, track_running_stats=False, eps=eps)
      def __call__(self, x): return self.bn1(self.l1(x))
    class LinTorch(nn.Module):
      def __init__(self):
        super().__init__()
        self.l1 = nn.Conv2d(K, K, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(K, affine=False, track_running_stats=False, eps=eps)
      def forward(self, x): return self.bn1(self.l1(x))
    model_torch = LinTorch()
    with torch.no_grad():
      model_torch.l1.weight[:] = 1.
    compare_tiny_torch(LinTiny(), model_torch, X, Y)

  def test_conv_mnist(self):
    class LinTiny:
      def __init__(self, has_batchnorm=False):
        self.c1 = Conv2d(1, 8, 3, stride=2)
        self.c2 = Conv2d(8, 16, 3, stride=2)
        self.l1 = Linear(16*6*6, 10)
        if has_batchnorm:
          self.bn1, self.bn2 = BatchNorm2d(8), BatchNorm2d(16)
        else:
          self.bn1, self.bn2 = lambda x: x, lambda x: x
      def __call__(self, x):
        return self.l1(self.bn2(self.c2(self.bn1(self.c1(x)).relu())).relu().reshape(x.shape[0], -1)).log_softmax(-1)
    class LinTorch(nn.Module):
      def __init__(self, has_batchnorm=False):
        super().__init__()
        self.c1 = nn.Conv2d(1, 8, 3, stride=2)
        self.c2 = nn.Conv2d(8, 16, 3, stride=2)
        self.l1 = nn.Linear(16*6*6, 10)
        if has_batchnorm:
          self.bn1, self.bn2 = nn.BatchNorm2d(8), nn.BatchNorm2d(16)
        else:
          self.bn1, self.bn2 = lambda x: x, lambda x: x
      def forward(self, x):
        return self.l1(self.bn2(self.c2(self.bn1(self.c1(x)).relu())).relu().reshape(x.shape[0], -1)).log_softmax(-1)
    for has_batchnorm in [False, True]:
      with self.subTest(has_batchnorm=has_batchnorm):
        compare_tiny_torch(LinTiny(has_batchnorm), LinTorch(has_batchnorm), self.X.reshape((-1, 1, 28, 28)), self.Y)

if __name__ == "__main__":
  unittest.main()
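
# These tests can also be run selectively with pytest (assuming it is installed),
# e.g. from the repo root: pytest test/models/test_end2end.py -k test_bn_alone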